use std::{future::Future, net::SocketAddr, time::Duration};
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
time::timeout,
};
use tracing::*;
#[cfg(doc)]
use crate::{
Connection, Node,
protocols::{Reading, Writing},
};
use crate::{
Pea2Pea, connections::create_connection_span, node::NodeTask, protocols::ProtocolHandler,
};
pub trait OnDisconnect: Pea2Pea
where
Self: Clone + Send + Sync + 'static,
{
const TIMEOUT_MS: u64 = 3_000;
fn enable_on_disconnect(&self) -> impl Future<Output = ()> {
async {
let (from_node_sender, mut from_node_receiver) =
mpsc::channel::<(
SocketAddr,
oneshot::Sender<(JoinHandle<()>, oneshot::Receiver<()>)>,
)>(self.node().config().max_connections as usize);
let (tx, rx) = oneshot::channel::<()>();
let self_clone = self.clone();
let disconnect_task = tokio::spawn(async move {
trace!(parent: self_clone.node().span(), "spawned the OnDisconnect handler task");
if tx.send(()).is_err() {
error!(parent: self_clone.node().span(), "OnDisconnect handler creation interrupted! shutting down the node");
self_clone.node().shut_down().await;
return;
}
while let Some((addr, notifier)) = from_node_receiver.recv().await {
let self_clone2 = self_clone.clone();
let (done_tx, done_rx) = oneshot::channel();
let handle = tokio::spawn(async move {
if timeout(
Duration::from_millis(Self::TIMEOUT_MS),
self_clone2.on_disconnect(addr),
)
.await
.is_err()
{
let conn_span = create_connection_span(addr, self_clone2.node().span());
warn!(parent: conn_span, "OnDisconnect logic timed out");
}
let _ = done_tx.send(());
});
let _ = notifier.send((handle, done_rx)); }
});
let _ = rx.await;
self.node()
.tasks
.lock()
.insert(NodeTask::OnDisconnect, disconnect_task);
let hdl = ProtocolHandler(from_node_sender);
assert!(
self.node().protocols.on_disconnect.set(hdl).is_ok(),
"the OnDisconnect protocol was enabled more than once!"
);
}
}
fn on_disconnect(&self, addr: SocketAddr) -> impl Future<Output = ()> + Send;
}