use crate::async_rt::task::{spawn, JoinHandle};
use crate::backend::DisconnectNotifier;
use crate::endpoint::Endpoint;
use crate::transport;
use crate::util::{greet_exchange, ready_exchange, PeerIdentity};
use crate::MultiPeerBackend;
use futures::channel::{mpsc, oneshot};
use futures::{FutureExt, StreamExt};
use rand::Rng;
use std::sync::Arc;
use std::time::Duration;
/// Backoff policy used by [`spawn_reconnect_task`] between reconnection attempts.
///
/// Every new disconnect restarts the sequence at `initial_interval`. After each
/// failed attempt the wait is multiplied by `backoff_multiplier`, a small random
/// jitter (under 100 ms) is added, and the result is capped at `max_interval`.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconnectConfig {
    /// Delay before the first reconnection attempt after a disconnect.
    pub initial_interval: Duration,
    /// Upper bound on the delay between two attempts.
    pub max_interval: Duration,
    /// Factor applied to the delay after every failed attempt.
    pub backoff_multiplier: f64,
}
impl Default for ReconnectConfig {
    /// Starts retrying after 100 ms and grows the wait by a factor of 2.0 per
    /// failed attempt, never waiting longer than 30 s.
    fn default() -> Self {
        let initial_interval = Duration::from_millis(100);
        let max_interval = Duration::from_secs(30);
        Self {
            backoff_multiplier: 2.0,
            initial_interval,
            max_interval,
        }
    }
}
/// Handle to a task started by [`spawn_reconnect_task`].
///
/// Call [`ReconnectHandle::shutdown`] to stop the task. Dropping the handle stops
/// it as well: the shutdown sender is dropped with the handle, which completes the
/// task's shutdown receiver, so a handle that is ignored ends reconnection at once.
#[must_use = "dropping a ReconnectHandle stops the reconnect task"]
pub struct ReconnectHandle {
    /// One-shot stop signal for the task; an `Option` so `shutdown` can take it.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// Kept only so the handle owns the spawned task; it is never read or awaited,
    /// which is why `dead_code` is allowed here.
    #[allow(dead_code)]
    task_handle: JoinHandle<()>,
}
impl ReconnectHandle {
    /// Signals the reconnect task to stop, consuming the handle.
    ///
    /// Only the signal is sent; this does not wait for the task to finish.
    pub fn shutdown(mut self) {
        let pending = self.shutdown_tx.take();
        if let Some(sender) = pending {
            // `send` only fails when the receiver is already gone, meaning the
            // task has exited on its own; there is nothing left to stop.
            sender.send(()).ok();
        }
    }
}
/// Callback that attaches a disconnect notifier to a peer.
///
/// [`spawn_reconnect_task`] invokes it for the initial peer and again for every
/// peer it reconnects. It passes a sender on which that peer's identity should be
/// sent once the peer disconnects.
pub type RegisterDisconnectFn = Box<dyn Fn(PeerIdentity, DisconnectNotifier) + Sync + Send>;
/// Spawns a background task that reconnects to `endpoint` whenever a registered
/// peer reports a disconnect.
///
/// # Behavior
///
/// * `register_disconnect_fn` is called once up front for `initial_peer_id`. It is
///   called again for every peer produced by a successful reconnection. Each call
///   receives a sender; a peer identity arriving on it starts a reconnection cycle.
/// * A cycle waits `config.initial_interval`, then makes a connection attempt.
///   After each failure the wait is multiplied by `config.backoff_multiplier`, a
///   random jitter of up to 100 ms is added, and the result is clamped to
///   `[0, config.max_interval]`.
/// * On success, a `SocketEvent::Connected` is offered to the socket monitor if one
///   is installed. This is best effort: the result of `try_send` is ignored. The
///   task then goes back to waiting for the next disconnect.
///
/// # Shutdown
///
/// The task stops when [`ReconnectHandle::shutdown`] is called or when the returned
/// handle is dropped. Shutdown is noticed while waiting for a disconnect and while
/// sleeping between attempts, but not while a connection attempt is in flight.
pub fn spawn_reconnect_task(
    endpoint: Endpoint,
    backend: Arc<dyn MultiPeerBackend>,
    initial_peer_id: PeerIdentity,
    register_disconnect_fn: RegisterDisconnectFn,
    config: ReconnectConfig,
) -> ReconnectHandle {
    let (shutdown_tx, shutdown_rx) = oneshot::channel();
    let (disconnect_tx, mut disconnect_rx) = mpsc::channel::<PeerIdentity>(1);
    register_disconnect_fn(initial_peer_id, disconnect_tx.clone());
    let task_handle = spawn(async move {
        log::debug!("Reconnect task started for endpoint: {}", endpoint);
        // Fused so it can be polled from both `select!`s below. A dropped sender
        // (handle dropped) completes it just like an explicit shutdown.
        let mut shutdown_rx = shutdown_rx.fuse();
        loop {
            let peer_id = futures::select! {
                peer_id = disconnect_rx.next() => {
                    if let Some(id) = peer_id {
                        id
                    } else {
                        // Defensive only: this task keeps `disconnect_tx` alive, so
                        // the channel cannot close while the loop is running.
                        log::debug!("Disconnect channel closed, stopping reconnect task");
                        return;
                    }
                }
                _ = shutdown_rx => {
                    log::debug!("Shutdown received, stopping reconnect task");
                    return;
                }
            };
            log::info!(
                "Peer {:?} disconnected from {}, starting reconnection",
                peer_id,
                endpoint
            );
            // Every disconnect starts a fresh backoff sequence.
            let mut current_interval = config.initial_interval;
            let mut attempt = 0u32;
            'retry: loop {
                // Saturate instead of overflowing (a panic in debug builds) during
                // an extremely long outage.
                attempt = attempt.saturating_add(1);
                log::debug!(
                    "Reconnection attempt {} to {} (waiting {:?})",
                    attempt,
                    endpoint,
                    current_interval
                );
                let sleep_fut = crate::async_rt::task::sleep(current_interval).fuse();
                futures::pin_mut!(sleep_fut);
                futures::select! {
                    _ = sleep_fut => {}
                    _ = shutdown_rx => {
                        log::debug!("Shutdown received during backoff, stopping reconnect task");
                        return;
                    }
                }
                match try_reconnect(&endpoint, backend.clone()).await {
                    Ok((new_peer_id, resolved_endpoint)) => {
                        log::info!(
                            "Successfully reconnected to {} (peer {:?})",
                            endpoint,
                            new_peer_id
                        );
                        if let Some(monitor) = backend.monitor().lock().as_mut() {
                            // Best effort: a monitor that cannot accept the event
                            // right now simply misses it.
                            let _ = monitor.try_send(crate::SocketEvent::Connected(
                                resolved_endpoint,
                                new_peer_id.clone(),
                            ));
                        }
                        register_disconnect_fn(new_peer_id, disconnect_tx.clone());
                        break 'retry;
                    }
                    Err(e) => {
                        log::warn!(
                            "Reconnection attempt {} to {} failed: {:?}",
                            attempt,
                            endpoint,
                            e
                        );
                        // `ThreadRng` is not `Send`; keep it scoped so it is gone
                        // before the next `.await`.
                        let jitter = {
                            let mut rng = rand::rng();
                            rng.random_range(0.0..0.1)
                        };
                        current_interval =
                            next_backoff_interval(current_interval, &config, jitter);
                    }
                }
            }
        }
    });
    ReconnectHandle {
        shutdown_tx: Some(shutdown_tx),
        task_handle,
    }
}

/// Computes the wait before the next reconnection attempt.
///
/// `current` is multiplied by `config.backoff_multiplier` and `jitter_secs` is
/// added. The result is clamped to `[0, config.max_interval]`. Clamping here,
/// instead of passing the raw value to `Duration::from_secs_f64`, keeps a
/// misconfigured multiplier (negative or NaN) or a huge `max_interval` from
/// panicking inside the reconnect task.
fn next_backoff_interval(
    current: Duration,
    config: &ReconnectConfig,
    jitter_secs: f64,
) -> Duration {
    let next_secs = current.as_secs_f64() * config.backoff_multiplier + jitter_secs;
    if next_secs.is_nan() || next_secs >= config.max_interval.as_secs_f64() {
        // NaN maps to the cap, matching the previous `f64::min`-based computation.
        config.max_interval
    } else if next_secs > 0.0 {
        // Strictly below the cap, so this cannot overflow `Duration`.
        Duration::from_secs_f64(next_secs)
    } else {
        Duration::ZERO
    }
}
/// Makes one connection attempt to `endpoint` and registers the result with `backend`.
///
/// The attempt runs the transport connect, the greeting exchange, and the READY
/// exchange. In the READY exchange the socket's configured `peer_id`, if any, is
/// sent as the `Identity` property. The new peer is then handed to
/// `backend.peer_connected`.
///
/// Returns the peer identity from the READY exchange and the endpoint reported by
/// the transport. A failure in any step before `peer_connected` is returned as an
/// error, and no peer is registered in that case.
async fn try_reconnect(
    endpoint: &Endpoint,
    backend: Arc<dyn MultiPeerBackend>,
) -> crate::ZmqResult<(PeerIdentity, Endpoint)> {
    let (mut stream, resolved) = transport::connect(endpoint).await?;
    greet_exchange(&mut stream).await?;

    // Advertise our own identity only when one was configured on the socket.
    let handshake_props = backend.socket_options().peer_id.as_ref().map(|identity| {
        let mut ops = std::collections::HashMap::new();
        ops.insert("Identity".to_string(), identity.clone().into());
        ops
    });

    let peer_id = ready_exchange(&mut stream, backend.socket_type(), handshake_props).await?;
    backend.peer_connected(&peer_id, stream).await;
    Ok((peer_id, resolved))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default backoff policy is 100 ms initial wait, 30 s cap, 2x growth.
    #[test]
    fn test_reconnect_config_default() {
        let ReconnectConfig {
            initial_interval,
            max_interval,
            backoff_multiplier,
        } = ReconnectConfig::default();
        assert_eq!(initial_interval, Duration::from_millis(100));
        assert_eq!(max_interval, Duration::from_secs(30));
        assert!((backoff_multiplier - 2.0).abs() < f64::EPSILON);
    }
}