use super::MsgListener;
use sn_interface::types::{log_markers::LogMarker, Peer};
use bytes::Bytes;
use priority_queue::DoublePriorityQueue;
use qp2p::{Endpoint, RetryConfig};
use std::{
collections::BTreeMap,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
time::{Duration, Instant},
};
use tokio::sync::{Mutex, RwLock};
/// Ordering key of a connection inside a link's queue; a larger value means the
/// connection was used more recently.
type Priority = u64;
/// Key under which a connection is cached: the value returned by `qp2p::Connection::id()`.
type ConnId = usize;
/// Queue length at (or above) which `Link::drop_excess` starts evicting the least
/// recently used connections.
const CAPACITY: u8 = 255;
/// How long a connection may go untouched before it counts as expired. Also the
/// minimum interval between two expiry sweeps of a link.
const UNUSED_TTL: Duration = Duration::from_secs(2 * 60);
/// A cache of qp2p connections to one peer, ordered by how recently each was used.
///
/// All connection bookkeeping (map, queue, counter, mutex, sweep deadline) sits behind
/// `Arc`s, so clones of a `Link` share the same connection set.
#[derive(Clone)]
pub(crate) struct Link {
    // --- who we talk to and how ---
    /// The peer every cached connection points at.
    peer: Peer,
    /// Local endpoint used to open new connections to `peer`.
    endpoint: Endpoint,
    /// Receives the incoming-message stream of every connection this link opens,
    /// and is told about each successfully sent message.
    listener: MsgListener,

    // --- shared connection state ---
    /// Held while checking for / creating a connection so concurrent senders do not
    /// each open their own. The wrapped value itself is never read.
    create_mutex: Arc<Mutex<usize>>,
    /// Open connections, keyed by their qp2p connection id.
    connections: Arc<RwLock<BTreeMap<ConnId, ExpiringConn>>>,
    /// Connection ids ranked by recency of use; the maximum is the most recently used.
    queue: Arc<RwLock<DoublePriorityQueue<ConnId, Priority>>>,
    /// Source of ever-increasing queue priorities.
    access_counter: Arc<AtomicU64>,
    /// Earliest instant at which `remove_expired` will do another sweep.
    expiration_check: Arc<RwLock<Instant>>,
}
impl Link {
    /// Creates a link to `peer` that holds no connections yet.
    ///
    /// A connection is opened lazily on the first send, or can be supplied through
    /// `Link::add`.
    pub(crate) fn new(peer: Peer, endpoint: Endpoint, listener: MsgListener) -> Self {
        Self {
            peer,
            endpoint,
            create_mutex: Arc::new(Mutex::new(0)),
            connections: Arc::new(RwLock::new(BTreeMap::new())),
            queue: Arc::new(RwLock::new(DoublePriorityQueue::new())),
            access_counter: Arc::new(AtomicU64::new(0)),
            listener,
            expiration_check: Arc::new(RwLock::new(expiration())),
        }
    }

    /// Creates a link to `peer` that starts out holding the already-open `conn`.
    pub(crate) async fn new_with(
        peer: Peer,
        endpoint: Endpoint,
        listener: MsgListener,
        conn: qp2p::Connection,
    ) -> Self {
        let instance = Self::new(peer, endpoint, listener);
        instance.insert(conn).await;
        instance
    }

    /// The peer this link connects to.
    #[cfg(feature = "test-utils")]
    pub(crate) fn peer(&self) -> &Peer {
        &self.peer
    }

    /// Adds an externally established connection, ranking it as the most recently used.
    pub(crate) async fn add(&self, conn: qp2p::Connection) {
        self.insert(conn).await;
    }

    /// Closes every cached connection and empties the link.
    pub(crate) async fn disconnect(self) {
        self.queue.write().await.clear();
        let mut connections = self.connections.write().await;
        for item in connections.values() {
            item.conn
                .close(Some("We disconnected from peer.".to_string()));
        }
        connections.clear();
    }

    /// Sends `msg` with default priority and retry settings. See `send_with`.
    #[allow(unused)]
    pub(crate) async fn send(&self, msg: Bytes) -> Result<(), SendToOneError> {
        self.send_with(msg, 0, None).await
    }

    /// Sends `msg` over the most recently used connection, opening one if the link
    /// has none.
    ///
    /// `priority` and `retry_config` are passed through to `qp2p::Connection::send_with`.
    ///
    /// # Errors
    /// - `SendToOneError::Connection` if a new connection could not be opened.
    /// - `SendToOneError::Send` if sending failed. In that case the connection is
    ///   closed and dropped from the link so it will not be picked again.
    pub(crate) async fn send_with(
        &self,
        msg: Bytes,
        priority: i32,
        retry_config: Option<&RetryConfig>,
    ) -> Result<(), SendToOneError> {
        let conn = self.get_or_connect().await?;

        let queue_len = self.queue.read().await.len();
        trace!(
            "We have {} open connections to node {:?}.",
            queue_len,
            self.peer
        );

        match conn.send_with(msg, priority, retry_config).await {
            Ok(()) => {
                self.listener.count_msg().await;
                Ok(())
            }
            Err(error) => {
                let id = conn.id();
                let _ = self.connections.write().await.remove(&id);
                let _ = self.queue.write().await.remove(&id);
                conn.close(Some(format!("{:?}", error)));
                Err(SendToOneError::Send(error))
            }
        }
    }

    /// Id of the most recently used connection, if the queue holds any.
    ///
    /// The queue read lock is released before this returns.
    async fn most_recent_id(&self) -> Option<ConnId> {
        self.queue.read().await.peek_max().map(|(id, _prio)| *id)
    }

    /// Returns the most recently used connection, or opens a new one when the link
    /// has none.
    ///
    /// Creation is serialised through `create_mutex`. The queue is re-checked after
    /// acquiring that mutex, so a caller that waited reuses a connection another task
    /// just created instead of opening a second one.
    async fn get_or_connect(&self) -> Result<qp2p::Connection, SendToOneError> {
        if let Some(id) = self.most_recent_id().await {
            return self.read_conn(id).await;
        }

        let _lock = self.create_mutex.lock().await;
        match self.most_recent_id().await {
            Some(id) => self.read_conn(id).await,
            None => self.create_connection().await,
        }
    }

    /// `true` if the most recently used connection is still cached and has not
    /// passed its idle deadline.
    ///
    /// Only the deadline is checked; the connection itself is not probed.
    pub(crate) async fn is_connected(&self) -> bool {
        match self.most_recent_id().await {
            None => false,
            Some(id) => match self.connections.read().await.get(&id) {
                Some(item) => !item.expired().await,
                None => false,
            },
        }
    }

    /// Returns the cached connection `id` and marks it as just used.
    ///
    /// Queue and map are updated under separate locks, so `id` can already be gone
    /// from the map (e.g. removed by a concurrent failed send). In that case a new
    /// connection is opened instead.
    async fn read_conn(&self, id: ConnId) -> Result<qp2p::Connection, SendToOneError> {
        let cached = self.connections.read().await.get(&id).cloned();
        match cached {
            Some(item) => {
                self.touch(item.conn.id()).await;
                Ok(item.conn)
            }
            None => self.create_connection().await,
        }
    }

    /// Opens a new connection to the peer, caches it and hands its incoming
    /// messages to the listener.
    async fn create_connection(&self) -> Result<qp2p::Connection, SendToOneError> {
        let (conn, incoming_msgs) = self
            .endpoint
            .connect_to(&self.peer.addr())
            .await
            .map_err(SendToOneError::Connection)?;

        trace!(
            "{} to {} (id: {})",
            LogMarker::ConnectionOpened,
            conn.remote_address(),
            conn.id()
        );

        self.insert(conn.clone()).await;
        self.listener.listen(conn.clone(), incoming_msgs);
        Ok(conn)
    }

    /// Caches `conn` and ranks it as the most recently used connection.
    async fn insert(&self, conn: qp2p::Connection) {
        let id = conn.id();
        let _ = self
            .connections
            .write()
            .await
            .insert(id, ExpiringConn::new(conn));

        // The priority must be drawn BEFORE the queue lock is taken. On counter
        // wrap-around `priority()` takes the queue write lock itself, and tokio's
        // RwLock is not reentrant, so holding the guard here would deadlock.
        let prio = self.priority().await;
        let _ = self.queue.write().await.push(id, prio);
    }

    /// Marks connection `id` as just used: bumps its queue rank and resets its idle
    /// deadline.
    async fn touch(&self, id: ConnId) {
        // Same as in `insert`: draw the priority before locking the queue.
        let prio = self.priority().await;
        let _ = self.queue.write().await.change_priority(&id, prio);

        if let Some(item) = self.connections.read().await.get(&id) {
            item.touch().await;
        }
    }

    /// Returns the next queue priority from the access counter.
    ///
    /// When the counter hits `u64::MAX` and wraps, every queued connection is
    /// re-assigned a priority counting up from zero, in `into_sorted_iter` order.
    /// NOTE(review): this presumably preserves relative recency because
    /// `DoublePriorityQueue::into_sorted_iter` yields ascending priorities; confirm.
    ///
    /// Takes the queue write lock in that case, so callers must NOT hold it.
    async fn priority(&self) -> Priority {
        let prio = self.access_counter.fetch_add(1, Ordering::SeqCst);
        if prio != u64::MAX {
            return prio;
        }

        let mut queue = self.queue.write().await;
        let snapshot = queue.clone();
        for (id, _old_prio) in snapshot.into_sorted_iter() {
            let _ =
                queue.change_priority(&id, self.access_counter.fetch_add(1, Ordering::SeqCst));
        }
        self.access_counter.fetch_add(1, Ordering::SeqCst)
    }

    /// Closes and forgets connections whose idle deadline has passed, then evicts
    /// any excess above `CAPACITY`.
    ///
    /// Runs at most once per `UNUSED_TTL`; calls in between return immediately.
    /// Connections are examined from least to most recently used. The sweep stops
    /// before the last connection of its snapshot, so it never expires all of them.
    pub(crate) async fn remove_expired(&self) {
        {
            // Check and re-arm under ONE write lock, so concurrent callers cannot
            // both conclude that a sweep is due.
            let mut next_check = self.expiration_check.write().await;
            if Instant::now() <= *next_check {
                return;
            }
            *next_check = expiration();
        }

        let snapshot = self.queue.read().await.clone();
        let mut remaining = snapshot.len();
        let mut expired_ids = Vec::new();
        for (id, _prio) in snapshot.into_sorted_iter() {
            if remaining <= 1 {
                break;
            }
            let connections = self.connections.read().await;
            if let Some(item) = connections.get(&id) {
                if item.expired().await {
                    expired_ids.push(id);
                    remaining -= 1;
                }
            }
        }

        for id in expired_ids {
            let _ = self.queue.write().await.remove(&id);
            let removed = self.connections.write().await.remove(&id);
            if let Some(item) = removed {
                trace!("Connection expired: {}", item.conn.id());
                item.conn.close(Some("Connection expired.".to_string()));
            }
        }

        self.drop_excess().await;
    }

    /// Evicts least recently used connections until fewer than `CAPACITY` remain.
    ///
    /// This evicts ALL the excess, not just one connection per sweep.
    async fn drop_excess(&self) {
        loop {
            let popped = {
                let mut queue = self.queue.write().await;
                if queue.len() < usize::from(CAPACITY) {
                    return;
                }
                queue.pop_min()
            };

            let evicted_id = match popped {
                Some((id, _prio)) => id,
                None => return,
            };

            let removed = self.connections.write().await.remove(&evicted_id);
            if let Some(item) = removed {
                trace!("Connection evicted: {}", evicted_id);
                item.conn.close(Some("Connection evicted.".to_string()));
            }
        }
    }
}
/// Why sending a message to a single peer through a `Link` failed.
#[derive(Debug)]
pub(crate) enum SendToOneError {
    /// No connection was cached and opening a new one failed.
    Connection(qp2p::ConnectionError),
    /// A connection was available but sending over it failed.
    Send(qp2p::SendError),
}
impl SendToOneError {
    /// Returns `true` when the failure comes from a connection closed on our side
    /// (`qp2p::Close::Local`). The close may surface while connecting, or as a lost
    /// connection while sending.
    #[allow(unused)]
    pub(crate) fn is_local_close(&self) -> bool {
        let conn_error = match self {
            SendToOneError::Connection(err) => err,
            SendToOneError::Send(qp2p::SendError::ConnectionLost(err)) => err,
            _ => return false,
        };
        matches!(conn_error, qp2p::ConnectionError::Closed(qp2p::Close::Local))
    }
}
/// A cached connection paired with the instant after which it counts as unused.
///
/// The deadline sits behind an `Arc<RwLock<_>>`, so every clone of an
/// `ExpiringConn` shares (and can extend) the same deadline.
#[derive(Debug, Clone)]
struct ExpiringConn {
    /// The underlying qp2p connection.
    conn: qp2p::Connection,
    /// Idle deadline; once it lies in the past the connection is considered expired.
    expiry: Arc<RwLock<Instant>>,
}
impl ExpiringConn {
    /// Wraps `conn` with an idle deadline of `UNUSED_TTL` from now.
    fn new(conn: qp2p::Connection) -> Self {
        let expiry = Arc::new(RwLock::new(expiration()));
        Self { conn, expiry }
    }

    /// `true` once the idle deadline lies strictly in the past.
    async fn expired(&self) -> bool {
        let deadline = *self.expiry.read().await;
        deadline < Instant::now()
    }

    /// Pushes the idle deadline out to `UNUSED_TTL` from now.
    async fn touch(&self) {
        // Compute the new deadline before taking the lock, as the original assignment did.
        let new_deadline = expiration();
        *self.expiry.write().await = new_deadline;
    }
}
/// The instant `UNUSED_TTL` from now; used as a fresh idle/sweep deadline.
fn expiration() -> Instant {
    let now = Instant::now();
    now + UNUSED_TTL
}