use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{broadcast, Mutex, RwLock};
use tokio::time::sleep;
use tracing::{debug, info, warn};
use crate::errors::{Error, Result, TransportError};
use crate::session::{Session, SessionConfig, SessionManager, SessionState};
use crate::shutdown::ShutdownSignal;
/// Settings that control how a dropped connection is retried.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
/// Maximum number of reconnection attempts per reconnection cycle.
/// A value of `0` disables the limit, so the cycle retries indefinitely.
pub max_retries: u32,
/// Delay used before the first reconnection attempt of a cycle.
pub initial_delay: Duration,
/// Upper bound applied to the delay each time it grows after a failed attempt.
pub max_delay: Duration,
/// Factor the delay is multiplied by after every failed attempt.
pub backoff_multiplier: f64,
/// When `true`, each wait is scaled by a clock-derived factor that is at
/// least 0.5 and at most 1.0, so the actual wait is never longer than the
/// computed delay.
pub jitter: bool,
/// Configuration cloned for every session that is started; its `target`
/// field is overwritten with the target of the reconnecting session.
pub session_config: SessionConfig,
}
impl Default for ReconnectConfig {
fn default() -> Self {
Self {
max_retries: 10,
initial_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
backoff_multiplier: 2.0,
jitter: true,
session_config: SessionConfig::default(),
}
}
}
/// Notifications broadcast while a reconnection cycle is in progress.
#[derive(Debug, Clone)]
pub enum ReconnectEvent {
/// Sent once at the start of a reconnection cycle.
Disconnected {
/// Text of the error that triggered the cycle.
error: String,
},
/// Sent before the wait that precedes each reconnection attempt.
Reconnecting {
/// 1-based number of the attempt about to be made.
attempt: u32,
/// The wait that will actually be applied (after jitter, if enabled).
delay: Duration,
},
/// Sent after an attempt succeeded and a new session is in place.
Reconnected {
/// Identifier of the new session; empty if no session was stored.
session_id: String,
/// Number of attempts the cycle needed.
attempts: u32,
},
/// Sent when the cycle gives up because the retry limit was reached.
Failed {
/// Number of attempts made before giving up.
total_attempts: u32,
/// Text of the most recent error.
last_error: String,
},
}
/// Point-in-time copy of the reconnection counters.
#[derive(Debug, Clone, Default)]
pub struct ReconnectStats {
/// Total number of reconnection attempts made so far.
pub total_attempts: u64,
/// Number of reconnection attempts that succeeded.
pub successful_reconnects: u64,
/// Number of reconnection cycles that gave up after reaching the retry limit.
pub failed_reconnects: u64,
/// Number of failed attempts since the last successful reconnection.
pub consecutive_failures: u32,
}
/// A session wrapper that re-establishes its connection after connection errors.
pub struct ReconnectingSession {
// Target written into the session configuration on every connect.
target: String,
// Retry policy and base session configuration.
config: ReconnectConfig,
// Currently active session; `None` when no session is installed.
session: Arc<RwLock<Option<Session>>>,
// Manager used to start sessions.
manager: Arc<SessionManager>,
// Sender side of the broadcast channel carrying `ReconnectEvent`s.
events: broadcast::Sender<ReconnectEvent>,
// Atomic counters backing `stats()`.
stats: Arc<ReconnectStatsInner>,
// Signal triggered by `terminate()`; it interrupts the wait between attempts.
shutdown: ShutdownSignal,
// `true` while a reconnection cycle is running.
reconnecting: Arc<Mutex<bool>>,
}
/// Lock-free counters that back the public reconnection statistics.
///
/// `Default` is derived: every `AtomicU64` defaults to zero, which is exactly
/// what a hand-written zero-initialising impl would produce.
#[derive(Default)]
struct ReconnectStatsInner {
    /// Total number of reconnection attempts made.
    total_attempts: AtomicU64,
    /// Number of attempts that ended with a new session.
    successful_reconnects: AtomicU64,
    /// Number of reconnection cycles that exhausted the retry limit.
    failed_reconnects: AtomicU64,
    /// Failed attempts since the last successful reconnection.
    consecutive_failures: AtomicU64,
}
impl ReconnectingSession {
pub async fn new(target: &str, config: ReconnectConfig) -> Result<Self> {
let manager = Arc::new(SessionManager::new().await?);
Self::with_manager(target, config, manager).await
}
pub async fn with_manager(
target: &str,
config: ReconnectConfig,
manager: Arc<SessionManager>,
) -> Result<Self> {
let (events_tx, _) = broadcast::channel(100);
let instance = Self {
target: target.to_string(),
config,
session: Arc::new(RwLock::new(None)),
manager,
events: events_tx,
stats: Arc::new(ReconnectStatsInner::default()),
shutdown: ShutdownSignal::new(),
reconnecting: Arc::new(Mutex::new(false)),
};
instance.connect().await?;
Ok(instance)
}
async fn connect(&self) -> Result<()> {
let mut session_config = self.config.session_config.clone();
session_config.target = self.target.clone();
let session = self.manager.start_session(session_config).await?;
{
let mut guard = self.session.write().await;
*guard = Some(session);
}
info!(target = %self.target, "Connected to target");
Ok(())
}
async fn reconnect(&self, initial_error: &str) -> Result<()> {
{
let mut reconnecting = self.reconnecting.lock().await;
if *reconnecting {
debug!("Already reconnecting, skipping");
return Err(Error::InvalidState("Already reconnecting".to_string()));
}
*reconnecting = true;
}
let _ = self.events.send(ReconnectEvent::Disconnected {
error: initial_error.to_string(),
});
let mut delay = self.config.initial_delay;
let mut attempts = 0u32;
let mut last_error = initial_error.to_string();
loop {
if self.config.max_retries > 0 && attempts >= self.config.max_retries {
self.stats.failed_reconnects.fetch_add(1, Ordering::Relaxed);
let _ = self.events.send(ReconnectEvent::Failed {
total_attempts: attempts,
last_error: last_error.clone(),
});
*self.reconnecting.lock().await = false;
return Err(Error::Transport(TransportError::ConnectionFailed(format!(
"Max reconnection attempts ({}) exceeded: {}",
self.config.max_retries, last_error
))));
}
attempts += 1;
self.stats.total_attempts.fetch_add(1, Ordering::Relaxed);
let actual_delay = if self.config.jitter {
let jitter_factor = 0.5 + rand_jitter() * 0.5; Duration::from_secs_f64(delay.as_secs_f64() * jitter_factor)
} else {
delay
};
let _ = self.events.send(ReconnectEvent::Reconnecting {
attempt: attempts,
delay: actual_delay,
});
info!(
target = %self.target,
attempt = attempts,
delay_ms = actual_delay.as_millis(),
"Attempting reconnection"
);
tokio::select! {
_ = sleep(actual_delay) => {}
_ = self.shutdown.cancelled() => {
*self.reconnecting.lock().await = false;
return Err(Error::InvalidState("Shutdown requested".to_string()));
}
}
match self.connect().await {
Ok(()) => {
self.stats
.successful_reconnects
.fetch_add(1, Ordering::Relaxed);
self.stats.consecutive_failures.store(0, Ordering::Relaxed);
let session_id = {
let guard = self.session.read().await;
guard
.as_ref()
.map(|s| s.id().to_string())
.unwrap_or_default()
};
let _ = self.events.send(ReconnectEvent::Reconnected {
session_id,
attempts,
});
*self.reconnecting.lock().await = false;
return Ok(());
}
Err(e) => {
last_error = e.to_string();
self.stats
.consecutive_failures
.fetch_add(1, Ordering::Relaxed);
warn!(
target = %self.target,
attempt = attempts,
error = %e,
"Reconnection attempt failed"
);
}
}
delay = Duration::from_secs_f64(
(delay.as_secs_f64() * self.config.backoff_multiplier)
.min(self.config.max_delay.as_secs_f64()),
);
}
}
pub async fn send(&self, data: bytes::Bytes) -> Result<()> {
loop {
let result = {
let guard = self.session.read().await;
match guard.as_ref() {
Some(session) => session.send(data.clone()).await,
None => Err(Error::InvalidState("No active session".to_string())),
}
};
match result {
Ok(()) => return Ok(()),
Err(e) => {
if is_connection_error(&e) {
warn!(error = %e, "Connection error, attempting reconnection");
self.reconnect(&e.to_string()).await?;
continue;
}
return Err(e);
}
}
}
}
pub async fn output(&self) -> Option<crate::channels::OutputStream> {
let guard = self.session.read().await;
guard.as_ref().map(|s| s.output())
}
pub fn events(&self) -> broadcast::Receiver<ReconnectEvent> {
self.events.subscribe()
}
pub fn stats(&self) -> ReconnectStats {
ReconnectStats {
total_attempts: self.stats.total_attempts.load(Ordering::Relaxed),
successful_reconnects: self.stats.successful_reconnects.load(Ordering::Relaxed),
failed_reconnects: self.stats.failed_reconnects.load(Ordering::Relaxed),
consecutive_failures: self.stats.consecutive_failures.load(Ordering::Relaxed) as u32,
}
}
pub async fn state(&self) -> SessionState {
let guard = self.session.read().await;
match guard.as_ref() {
Some(session) => session.state().await,
None => SessionState::Terminated,
}
}
pub async fn is_ready(&self) -> bool {
let guard = self.session.read().await;
guard.as_ref().map(|s| s.is_ready()).unwrap_or(false)
}
pub async fn is_reconnecting(&self) -> bool {
*self.reconnecting.lock().await
}
pub fn target(&self) -> &str {
&self.target
}
pub fn shutdown_signal(&self) -> ShutdownSignal {
self.shutdown.clone()
}
pub async fn terminate(&self) -> Result<()> {
self.shutdown.shutdown();
let mut guard = self.session.write().await;
if let Some(mut session) = guard.take() {
session.terminate().await?;
}
info!(target = %self.target, "Reconnecting session terminated");
Ok(())
}
pub async fn force_reconnect(&self) -> Result<()> {
info!(target = %self.target, "Forcing reconnection");
{
let mut guard = self.session.write().await;
if let Some(mut session) = guard.take() {
let _ = session.terminate().await;
}
}
self.connect().await
}
}
fn is_connection_error(error: &Error) -> bool {
match error {
Error::Transport(TransportError::WebSocket(_)) => true,
Error::Transport(TransportError::ConnectionFailed(_)) => true,
Error::Transport(TransportError::ConnectionClosed { .. }) => true,
Error::Transport(TransportError::HeartbeatTimeout) => true,
Error::Timeout => true,
Error::Io(e) => {
use std::io::ErrorKind;
matches!(
e.kind(),
ErrorKind::ConnectionReset
| ErrorKind::ConnectionAborted
| ErrorKind::BrokenPipe
| ErrorKind::TimedOut
| ErrorKind::NotConnected
)
}
_ => false,
}
}
/// Returns a pseudo-random value in `[0.0, 1.0)` for spreading out retry delays.
///
/// The value is the sub-second part of the current system time expressed as a
/// fraction of a second. This is cheap and needs no extra dependency, but it
/// is not a statistically sound random source. If the system clock is set
/// before the Unix epoch the result is `0.0`.
fn rand_jitter() -> f64 {
    use std::time::SystemTime;

    // `subsec_nanos()` is always below one billion, so dividing by the number
    // of nanoseconds in a second maps it onto the full `[0.0, 1.0)` range.
    const NANOS_PER_SEC: f64 = 1_000_000_000.0;

    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    f64::from(nanos) / NANOS_PER_SEC
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration exposes the expected retry policy values.
    #[test]
    fn test_reconnect_config_default() {
        let ReconnectConfig {
            max_retries,
            initial_delay,
            max_delay,
            backoff_multiplier,
            jitter,
            ..
        } = ReconnectConfig::default();

        assert_eq!(max_retries, 10);
        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(max_delay, Duration::from_secs(60));
        assert!((backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert!(jitter);
    }

    /// A default statistics snapshot has every counter at zero.
    #[test]
    fn test_reconnect_stats_default() {
        let ReconnectStats {
            total_attempts,
            successful_reconnects,
            failed_reconnects,
            consecutive_failures,
        } = ReconnectStats::default();

        assert_eq!(total_attempts, 0);
        assert_eq!(successful_reconnects, 0);
        assert_eq!(failed_reconnects, 0);
        assert_eq!(consecutive_failures, 0);
    }

    /// Connection-level errors are classified as such; configuration and
    /// protocol errors are not.
    #[test]
    fn test_is_connection_error() {
        use crate::errors::ProtocolError;

        let connection_errors = [
            Error::Transport(TransportError::WebSocket("connection closed".to_string())),
            Error::Transport(TransportError::ConnectionFailed(
                "network unreachable".to_string(),
            )),
            Error::Timeout,
        ];
        for error in connection_errors.iter() {
            assert!(is_connection_error(error));
        }

        let other_errors = [
            Error::Config("bad config".to_string()),
            Error::Protocol(ProtocolError::InvalidMessage("invalid message".to_string())),
        ];
        for error in other_errors.iter() {
            assert!(!is_connection_error(error));
        }
    }

    /// The jitter value stays within the unit interval.
    #[test]
    fn test_rand_jitter() {
        let sample = rand_jitter();
        assert!((0.0..=1.0).contains(&sample));
    }

    /// The `Debug` output of every event names its variant.
    #[test]
    fn test_reconnect_event_debug() {
        let cases = [
            (
                ReconnectEvent::Disconnected {
                    error: "test".to_string(),
                },
                "Disconnected",
            ),
            (
                ReconnectEvent::Reconnecting {
                    attempt: 1,
                    delay: Duration::from_secs(1),
                },
                "Reconnecting",
            ),
            (
                ReconnectEvent::Reconnected {
                    session_id: "abc".to_string(),
                    attempts: 2,
                },
                "Reconnected",
            ),
            (
                ReconnectEvent::Failed {
                    total_attempts: 5,
                    last_error: "failed".to_string(),
                },
                "Failed",
            ),
        ];

        for (event, variant_name) in cases.iter() {
            let rendered = format!("{:?}", event);
            assert!(rendered.contains(*variant_name));
        }
    }
}