use crate::async_rt::task::{spawn, JoinHandle};
use crate::endpoint::Endpoint;
use crate::engine::backend::DisconnectNotifier;
use crate::MultiPeerBackend;
use crate::PeerIdentity;
use futures::channel::{mpsc, oneshot};
use futures::{FutureExt, StreamExt};
use rand::RngExt;
use std::sync::Arc;
use std::time::Duration;
/// Backoff parameters used by the reconnect task when retrying a lost connection.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Delay waited before the first reconnection attempt after a disconnect.
    pub initial_interval: Duration,
    /// Upper bound applied when the delay is increased after a failed attempt.
    pub max_interval: Duration,
    /// Factor the current delay is multiplied by after each failed attempt.
    pub backoff_multiplier: f64,
}

impl Default for ReconnectConfig {
    /// Starts at 100 ms, doubles after every failure and is capped at 30 s.
    fn default() -> Self {
        let initial_interval = Duration::from_millis(100);
        let max_interval = Duration::from_secs(30);
        let backoff_multiplier = 2.0;

        Self {
            initial_interval,
            max_interval,
            backoff_multiplier,
        }
    }
}
/// Handle returned by `spawn_reconnect_task`; lets the owner stop the
/// background reconnect task.
pub struct ReconnectHandle {
    /// Sender half of the shutdown signal. `None` once `shutdown` has taken it.
    shutdown_tx: Option<oneshot::Sender<()>>,
    // Never read in this file. Presumably retained so the spawned task's
    // handle lives as long as this struct — TODO confirm against the runtime's
    // `JoinHandle` drop semantics.
    #[allow(dead_code)]
    task_handle: JoinHandle<()>,
}

impl ReconnectHandle {
    /// Asks the reconnect task to stop, consuming the handle.
    ///
    /// The request is best-effort: a failed send is ignored.
    pub fn shutdown(mut self) {
        let Some(stop_signal) = self.shutdown_tx.take() else {
            return;
        };
        // Deliberately discard the result; there is nothing useful to do if
        // the signal could not be delivered.
        let _ = stop_signal.send(());
    }
}
/// Callback the reconnect task uses to (re)arm disconnect notification for a peer.
///
/// It is called with the peer's identity and a `DisconnectNotifier`; within this
/// file the notifier passed is always a clone of the task's disconnect-channel
/// sender.
///
/// NOTE(review): presumably the callee sends the peer identity on the notifier
/// once that peer disconnects — confirm against the registering implementation.
pub type RegisterDisconnectFn = Box<dyn Fn(PeerIdentity, DisconnectNotifier) + Send + Sync>;
/// Spawns a background task that reconnects to `endpoint` whenever a
/// registered peer reports a disconnect.
///
/// # Behaviour
///
/// * `register_disconnect_fn` is called once for `initial_peer_id`, and again
///   for every peer obtained through a successful reconnection, each time with
///   a clone of the task's disconnect-channel sender.
/// * When a peer identity arrives on that channel the task reads
///   `backend.socket_options().reconnect_stop`. If `AFTER_DISCONNECT` is set
///   the task ends without retrying.
/// * Otherwise it waits `config.initial_interval`, calls [`try_reconnect`],
///   and on failure grows the wait by `config.backoff_multiplier` plus up to
///   100 ms of random jitter, bounded by `config.max_interval`.
/// * A failure classified as connection-refused or as a handshake error ends
///   the task when the matching `CONN_REFUSED` / `HANDSHAKE_FAILED` stop flag
///   is set.
/// * On success a `SocketEvent::Connected` is offered to the backend monitor
///   (best-effort `try_send`), the new peer is registered for disconnect
///   notification, `backend.on_reconnect` is invoked, and the task goes back
///   to waiting for the next disconnect.
///
/// # Shutdown
///
/// The task stops when the returned handle's `shutdown` is called. The
/// shutdown receiver is matched with `_`, so it also stops when the sender is
/// dropped together with the handle. Shutdown is observed while idle and
/// during the backoff sleep.
///
/// NOTE(review): shutdown is *not* observed while `try_reconnect` is in
/// flight. Racing it against the shutdown signal would require the connect
/// future to be cancel-safe — confirm before changing.
pub fn spawn_reconnect_task<B: MultiPeerBackend + 'static>(
    endpoint: Endpoint,
    backend: Arc<B>,
    initial_peer_id: PeerIdentity,
    register_disconnect_fn: RegisterDisconnectFn,
    config: ReconnectConfig,
) -> ReconnectHandle {
    let (shutdown_tx, shutdown_rx) = oneshot::channel();
    let (disconnect_tx, mut disconnect_rx) = mpsc::channel::<PeerIdentity>(1);
    // `initial_peer_id` is not needed afterwards, so it is moved, not cloned.
    register_disconnect_fn(initial_peer_id, disconnect_tx.clone());

    let task_handle = spawn(async move {
        log::debug!("Reconnect task started for endpoint: {}", endpoint);
        // Fused so the same receiver can be polled from several `select!`s.
        let mut shutdown_rx = shutdown_rx.fuse();

        loop {
            // Idle until a peer disconnects or shutdown is requested.
            let peer_id = futures::select! {
                peer_id = disconnect_rx.next() => {
                    if let Some(id) = peer_id {
                        id
                    } else {
                        log::debug!("Disconnect channel closed, stopping reconnect task");
                        return;
                    }
                }
                _ = shutdown_rx => {
                    log::debug!("Shutdown received, stopping reconnect task");
                    return;
                }
            };
            log::info!(
                "Peer {:?} disconnected from {}, starting reconnection",
                peer_id,
                endpoint
            );

            // The stop mask is sampled once per disconnect and reused for
            // every retry of this reconnection cycle.
            let stop_mask = backend.socket_options().reconnect_stop;
            if stop_mask.contains(crate::socket::ReconnectStop::AFTER_DISCONNECT) {
                log::info!(
                    "reconnect_stop: AFTER_DISCONNECT set; not reconnecting to {}",
                    endpoint
                );
                return;
            }

            let mut current_interval = config.initial_interval;
            let mut attempt = 0u32;
            'retry: loop {
                attempt += 1;
                log::debug!(
                    "Reconnection attempt {} to {} (waiting {:?})",
                    attempt,
                    endpoint,
                    current_interval
                );

                // Back off before every attempt, but stay responsive to shutdown.
                let sleep_fut = crate::async_rt::task::sleep(current_interval).fuse();
                futures::pin_mut!(sleep_fut);
                futures::select! {
                    _ = sleep_fut => {
                    }
                    _ = shutdown_rx => {
                        log::debug!("Shutdown received during backoff, stopping reconnect task");
                        return;
                    }
                }

                match try_reconnect(&endpoint, backend.clone()).await {
                    Ok((new_peer_id, resolved_endpoint)) => {
                        log::info!(
                            "Successfully reconnected to {} (peer {:?})",
                            endpoint,
                            new_peer_id
                        );
                        // Monitor notification is best-effort: a failed
                        // `try_send` must not fail the reconnection.
                        if let Some(monitor) = backend.monitor().lock().as_mut() {
                            let _ = monitor.try_send(crate::SocketEvent::Connected(
                                resolved_endpoint,
                                new_peer_id.clone(),
                            ));
                        }
                        // Re-arm disconnect notification for the new peer so
                        // the outer loop wakes up if it drops again.
                        register_disconnect_fn(new_peer_id.clone(), disconnect_tx.clone());
                        backend.on_reconnect(&new_peer_id);
                        break 'retry;
                    }
                    Err(e) => {
                        log::warn!(
                            "Reconnection attempt {} to {} failed: {:?}",
                            attempt,
                            endpoint,
                            e
                        );
                        let is_refused = matches!(
                            &e,
                            crate::ZmqError::Network(io)
                                if io.kind() == std::io::ErrorKind::ConnectionRefused
                        );
                        let is_handshake = matches!(
                            &e,
                            crate::ZmqError::HandshakeTimeout
                                | crate::ZmqError::MechanismMismatch { .. }
                                | crate::ZmqError::PlainAuthFailed { .. }
                                | crate::ZmqError::ZapDenied { .. }
                                | crate::ZmqError::ServerRoleConflict
                        );
                        if is_refused
                            && stop_mask.contains(crate::socket::ReconnectStop::CONN_REFUSED)
                        {
                            log::info!("reconnect_stop: CONN_REFUSED; giving up on {}", endpoint);
                            return;
                        }
                        if is_handshake
                            && stop_mask.contains(crate::socket::ReconnectStop::HANDSHAKE_FAILED)
                        {
                            log::info!(
                                "reconnect_stop: HANDSHAKE_FAILED; giving up on {}",
                                endpoint
                            );
                            return;
                        }

                        // Up to 100 ms of jitter. The RNG is a temporary, so
                        // it is never held across an `.await`.
                        let jitter: f64 = rand::rng().random_range(0.0..0.1);
                        current_interval =
                            next_backoff_interval(current_interval, &config, jitter);
                    }
                }
            }
        }
    });

    ReconnectHandle {
        shutdown_tx: Some(shutdown_tx),
        task_handle,
    }
}

/// Computes the delay to wait before the next reconnection attempt.
///
/// The result is `current * config.backoff_multiplier + jitter_secs`, bounded
/// to `[0, config.max_interval]`. Unlike a bare `Duration::from_secs_f64`,
/// this never panics: a negative product is floored at zero, and NaN,
/// infinite or overflowing values collapse to `config.max_interval`.
fn next_backoff_interval(
    current: Duration,
    config: &ReconnectConfig,
    jitter_secs: f64,
) -> Duration {
    let cap_secs = config.max_interval.as_secs_f64();
    let grown_secs = current.as_secs_f64() * config.backoff_multiplier + jitter_secs;
    // `f64::min` returns the other operand when one is NaN, so a NaN product
    // becomes the cap; `.max(0.0)` then removes negative values.
    let bounded_secs = grown_secs.min(cap_secs).max(0.0);
    // The conversion can still overflow when `max_interval` is close to
    // `Duration::MAX`, because `as_secs_f64` rounds up; fall back to the cap.
    Duration::try_from_secs_f64(bounded_secs)
        .map_or(config.max_interval, |next| next.min(config.max_interval))
}
/// Runs one `connect_peer_forever` call against a clone of `endpoint` and
/// hands the outcome back to the reconnect task.
///
/// On success the callee's `(resolved endpoint, peer identity)` pair is
/// returned with the order swapped, i.e. `(peer identity, resolved endpoint)`.
/// Errors are propagated unchanged via `?`.
async fn try_reconnect<B: MultiPeerBackend + 'static>(
    endpoint: &Endpoint,
    backend: Arc<B>,
) -> crate::ZmqResult<(PeerIdentity, Endpoint)> {
    let target = endpoint.clone();
    // NOTE(review): the meaning of the trailing `None` argument is not visible
    // from this file — confirm against `connect_peer_forever`.
    let connected =
        crate::socket::handshake::connect_peer_forever(target, backend, None).await?;
    let (resolved, peer) = connected;
    Ok((peer, resolved))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration starts at 100 ms, caps at 30 s and doubles
    /// the delay on each failure.
    #[test]
    fn test_reconnect_config_default() {
        let ReconnectConfig {
            initial_interval,
            max_interval,
            backoff_multiplier,
        } = ReconnectConfig::default();

        assert_eq!(initial_interval, Duration::from_millis(100));
        assert_eq!(max_interval, Duration::from_secs(30));
        assert!((backoff_multiplier - 2.0).abs() < f64::EPSILON);
    }
}