use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Backoff and retry policy consumed by `ReconnectionManager`.
#[derive(Debug, Clone)]
pub struct ReconnectionConfig {
    /// Starting delay for exponential backoff; also the constant delay in flat mode.
    pub base_delay: Duration,
    /// Upper bound for the exponential delay (not applied in flat mode).
    pub max_delay: Duration,
    /// Number of recorded attempts after which a peer counts as exhausted.
    pub max_attempts: u32,
    /// Polling interval exposed via `ReconnectionManager::check_interval`; the backoff
    /// math itself does not read it.
    pub check_interval: Duration,
    /// When `true`, every retry waits exactly `base_delay` instead of backing off.
    pub use_flat_delay: bool,
    /// When `true`, an exhausted peer gets its attempt counter reset so retries continue.
    pub reset_on_exhaustion: bool,
}

impl Default for ReconnectionConfig {
    /// 2 s base, 60 s cap, 10 attempts, 5 s check interval; exponential, no reset.
    fn default() -> Self {
        Self::new(
            Duration::from_secs(2),
            Duration::from_secs(60),
            10,
            Duration::from_secs(5),
        )
    }
}

impl ReconnectionConfig {
    /// Builds an exponential-backoff policy with automatic reset disabled.
    pub fn new(
        base_delay: Duration,
        max_delay: Duration,
        max_attempts: u32,
        check_interval: Duration,
    ) -> Self {
        Self {
            base_delay,
            max_delay,
            max_attempts,
            check_interval,
            use_flat_delay: false,
            reset_on_exhaustion: false,
        }
    }

    /// Aggressive preset: 500 ms base, 5 s cap, 5 attempts, 1 s check interval.
    pub fn fast() -> Self {
        Self::new(
            Duration::from_millis(500),
            Duration::from_secs(5),
            5,
            Duration::from_secs(1),
        )
    }

    /// Patient preset: 5 s base, 120 s cap, 5 attempts, 10 s check interval.
    pub fn conservative() -> Self {
        Self::new(
            Duration::from_secs(5),
            Duration::from_secs(120),
            5,
            Duration::from_secs(10),
        )
    }

    /// 1 s base, 15 s cap, 20 attempts, 5 s check interval, exponential backoff.
    /// Presumably mirrors the Kotlin client's normal-priority settings — confirm.
    pub fn kotlin_normal() -> Self {
        Self::new(
            Duration::from_secs(1),
            Duration::from_secs(15),
            20,
            Duration::from_secs(5),
        )
    }

    /// Same numbers as [`Self::kotlin_normal`] but with a flat 1 s delay and
    /// reset-on-exhaustion, so retries never stop.
    pub fn kotlin_high_priority() -> Self {
        Self {
            use_flat_delay: true,
            reset_on_exhaustion: true,
            ..Self::kotlin_normal()
        }
    }
}
/// Retry bookkeeping for a single disconnected peer.
#[derive(Debug, Clone)]
struct PeerReconnectionState {
    /// Attempts recorded since tracking started (or since the last reset).
    attempts: u32,
    /// When the most recent attempt (or reset) happened.
    last_attempt: Instant,
    /// When tracking for this peer began.
    disconnected_at: Instant,
}

impl PeerReconnectionState {
    /// Fresh state: zero attempts, both timestamps pinned to the same creation instant.
    fn new() -> Self {
        let created = Instant::now();
        Self {
            disconnected_at: created,
            last_attempt: created,
            attempts: 0,
        }
    }
}
/// Whether a peer may be retried, as reported by `ReconnectionManager::get_status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReconnectionStatus {
    /// The peer may be retried right now.
    Ready,
    /// The peer is tracked but its backoff delay has not elapsed yet.
    Waiting {
        /// Time left until the delay elapses.
        remaining: Duration,
    },
    /// The peer used up its attempt budget (only when reset-on-exhaustion is off).
    Exhausted {
        /// Attempts recorded for the peer.
        attempts: u32,
    },
    /// The address is not being tracked.
    NotTracked,
}
/// Tracks disconnected peers and decides when each one may be retried.
///
/// Peers are keyed by address. Each tracked peer carries an attempt counter and the time of
/// its last attempt; the configured [`ReconnectionConfig`] turns those into a backoff delay.
#[derive(Debug)]
pub struct ReconnectionManager {
    config: ReconnectionConfig,
    /// Per-peer retry state, keyed by peer address.
    peers: HashMap<String, PeerReconnectionState>,
}

impl ReconnectionManager {
    /// Creates an empty manager that applies `config`.
    pub fn new(config: ReconnectionConfig) -> Self {
        Self {
            config,
            peers: HashMap::new(),
        }
    }

    /// Creates an empty manager using [`ReconnectionConfig::default`].
    pub fn with_defaults() -> Self {
        Self::new(ReconnectionConfig::default())
    }

    /// Starts tracking `address` after a disconnection.
    ///
    /// If the peer is already tracked this is a no-op, so its attempt counter and timestamps
    /// survive repeated disconnection reports.
    pub fn track_disconnection(&mut self, address: String) {
        use std::collections::hash_map::Entry;
        // Move `address` into the entry instead of cloning it; the vacant entry still exposes
        // the key for logging.
        if let Entry::Vacant(entry) = self.peers.entry(address) {
            log::debug!("Tracking {} for reconnection", entry.key());
            entry.insert(PeerReconnectionState::new());
        }
    }

    /// Returns `true` if `address` is currently tracked.
    pub fn is_tracked(&self, address: &str) -> bool {
        self.peers.contains_key(address)
    }

    /// Reports whether `address` may be retried now.
    ///
    /// Returns [`ReconnectionStatus::NotTracked`] for unknown addresses.
    pub fn get_status(&self, address: &str) -> ReconnectionStatus {
        self.peers
            .get(address)
            .map_or(ReconnectionStatus::NotTracked, |state| self.status_of(state))
    }

    /// Readiness of one tracked peer. This is shared by `get_status` and
    /// `get_peers_to_reconnect` so the two cannot drift apart.
    fn status_of(&self, state: &PeerReconnectionState) -> ReconnectionStatus {
        if state.attempts >= self.config.max_attempts {
            // In reset mode an exhausted peer is about to have its counter zeroed by
            // `get_peers_to_reconnect`, so it is reported as immediately retryable.
            return if self.config.reset_on_exhaustion {
                ReconnectionStatus::Ready
            } else {
                ReconnectionStatus::Exhausted {
                    attempts: state.attempts,
                }
            };
        }
        if state.attempts == 0 {
            return ReconnectionStatus::Ready;
        }
        let delay = self.calculate_delay(state.attempts);
        let elapsed = state.last_attempt.elapsed();
        if elapsed >= delay {
            ReconnectionStatus::Ready
        } else {
            ReconnectionStatus::Waiting {
                remaining: delay - elapsed,
            }
        }
    }

    /// Backoff delay to wait after `attempts` attempts.
    ///
    /// Flat mode returns `base_delay` unchanged (no cap). Otherwise the delay is
    /// `base_delay * 2^min(attempts, 30)`, capped at `max_delay`. The arithmetic is done on
    /// `Duration` with saturation, so a large base cannot overflow (which previously panicked
    /// in debug builds and wrapped to a tiny delay in release builds). Sub-millisecond
    /// precision is kept, matching flat mode.
    fn calculate_delay(&self, attempts: u32) -> Duration {
        if self.config.use_flat_delay {
            return self.config.base_delay;
        }
        // Clamp the exponent so the multiplier always fits in a u32 (2^30 < u32::MAX).
        let multiplier = 1u32 << attempts.min(30);
        self.config
            .base_delay
            .saturating_mul(multiplier)
            .min(self.config.max_delay)
    }

    /// Returns the addresses of every tracked peer that may be retried now.
    ///
    /// When `reset_on_exhaustion` is set, exhausted peers first have their attempt counter
    /// reset to zero, which makes them immediately eligible. The order of the returned
    /// addresses follows `HashMap` iteration order and is unspecified.
    pub fn get_peers_to_reconnect(&mut self) -> Vec<String> {
        if self.config.reset_on_exhaustion {
            let max = self.config.max_attempts;
            let now = Instant::now();
            for state in self.peers.values_mut().filter(|s| s.attempts >= max) {
                log::debug!("Auto-resetting exhausted peer (reset_on_exhaustion)");
                state.attempts = 0;
                state.last_attempt = now;
            }
        }
        self.peers
            .iter()
            // The explicit budget check keeps peers that are still over budget out of the
            // result even where `status_of` says `Ready` (e.g. `max_attempts == 0` in reset
            // mode).
            .filter(|(_, state)| {
                state.attempts < self.config.max_attempts
                    && self.status_of(state) == ReconnectionStatus::Ready
            })
            .map(|(address, _)| address.clone())
            .collect()
    }

    /// Records a reconnection attempt for `address`. Untracked addresses are ignored.
    pub fn record_attempt(&mut self, address: &str) {
        let attempts = match self.peers.get_mut(address) {
            Some(state) => {
                // Saturate rather than overflow on a pathological number of attempts.
                state.attempts = state.attempts.saturating_add(1);
                state.last_attempt = Instant::now();
                state.attempts
            }
            None => return,
        };
        let next_delay = self.calculate_delay(attempts);
        log::debug!(
            "Reconnection attempt {} for {} (next delay: {:?})",
            attempts,
            address,
            next_delay
        );
    }

    /// Stops tracking `address` because it reconnected successfully.
    pub fn on_connection_success(&mut self, address: &str) {
        if self.peers.remove(address).is_some() {
            log::debug!(
                "Connection succeeded for {}, removed from reconnection tracking",
                address
            );
        }
    }

    /// Stops tracking `address` without implying a successful reconnection.
    pub fn stop_tracking(&mut self, address: &str) {
        if self.peers.remove(address).is_some() {
            log::debug!("Stopped tracking {} for reconnection", address);
        }
    }

    /// Drops all tracked peers.
    pub fn clear(&mut self) {
        let count = self.peers.len();
        self.peers.clear();
        if count > 0 {
            log::debug!("Cleared reconnection tracking for {} peers", count);
        }
    }

    /// Number of peers currently tracked.
    pub fn tracked_count(&self) -> usize {
        self.peers.len()
    }

    /// Snapshot of the retry progress for `address`, or `None` if it is not tracked.
    ///
    /// `next_attempt_delay` is the full backoff delay after the latest attempt, not the time
    /// remaining. It is `Duration::ZERO` when the peer can be retried immediately (no attempt
    /// yet, or exhausted in reset mode). It is `Duration::MAX` when the peer is exhausted and
    /// will not be retried.
    pub fn get_peer_stats(&self, address: &str) -> Option<PeerReconnectionStats> {
        let max_attempts = self.config.max_attempts;
        self.peers.get(address).map(|state| {
            let exhausted = state.attempts >= max_attempts;
            let next_attempt_delay = if exhausted && !self.config.reset_on_exhaustion {
                Duration::MAX
            } else if exhausted || state.attempts == 0 {
                // Exhausted peers in reset mode are retried right away; `get_status`
                // reports them as `Ready`, so the stats must not claim "never".
                Duration::ZERO
            } else {
                self.calculate_delay(state.attempts)
            };
            PeerReconnectionStats {
                attempts: state.attempts,
                max_attempts,
                disconnected_duration: state.disconnected_at.elapsed(),
                next_attempt_delay,
            }
        })
    }

    /// Polling interval from the configuration.
    pub fn check_interval(&self) -> Duration {
        self.config.check_interval
    }
}
/// Point-in-time view of one peer's reconnection progress, produced by
/// `ReconnectionManager::get_peer_stats`.
#[derive(Debug, Clone)]
pub struct PeerReconnectionStats {
    /// Attempts recorded so far.
    pub attempts: u32,
    /// Configured attempt budget.
    pub max_attempts: u32,
    /// Time elapsed since tracking of this peer began.
    pub disconnected_duration: Duration,
    /// Backoff delay that applies after the latest attempt (not the time remaining).
    /// `Duration::ZERO` before the first attempt; `Duration::MAX` serves as a
    /// "no further attempt" sentinel. See `ReconnectionManager::get_peer_stats` for the
    /// exact rules.
    pub next_attempt_delay: Duration,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample peer address shared by several tests.
    const MAC: &str = "00:11:22:33:44:55";

    /// Config with millisecond delays so tests never wait on real backoff.
    fn tiny_config(
        max_attempts: u32,
        use_flat_delay: bool,
        reset_on_exhaustion: bool,
    ) -> ReconnectionConfig {
        ReconnectionConfig {
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            max_attempts,
            check_interval: Duration::from_millis(1),
            use_flat_delay,
            reset_on_exhaustion,
        }
    }

    #[test]
    fn test_exponential_backoff() {
        let manager = ReconnectionManager::new(ReconnectionConfig::new(
            Duration::from_secs(2),
            Duration::from_secs(60),
            10,
            Duration::from_secs(5),
        ));
        // Doubling from 2 s, capped at 60 s.
        let expected_secs: [u64; 7] = [2, 4, 8, 16, 32, 60, 60];
        for (attempts, secs) in expected_secs.iter().enumerate() {
            assert_eq!(
                manager.calculate_delay(attempts as u32),
                Duration::from_secs(*secs),
                "attempts = {}",
                attempts
            );
        }
    }

    #[test]
    fn test_track_and_status() {
        let mut manager = ReconnectionManager::new(ReconnectionConfig::fast());
        assert_eq!(manager.get_status(MAC), ReconnectionStatus::NotTracked);

        manager.track_disconnection(MAC.to_string());
        assert!(manager.is_tracked(MAC));
        assert_eq!(manager.get_status(MAC), ReconnectionStatus::Ready);
    }

    #[test]
    fn test_connection_success_clears_tracking() {
        let mut manager = ReconnectionManager::with_defaults();
        manager.track_disconnection(MAC.to_string());
        assert!(manager.is_tracked(MAC));

        manager.on_connection_success(MAC);
        assert!(!manager.is_tracked(MAC));
        assert_eq!(manager.get_status(MAC), ReconnectionStatus::NotTracked);
        assert_eq!(manager.tracked_count(), 0);
    }

    #[test]
    fn test_max_attempts_exhaustion() {
        let mut manager = ReconnectionManager::new(tiny_config(3, false, false));
        manager.track_disconnection("test".to_string());
        (0..3).for_each(|_| manager.record_attempt("test"));
        assert_eq!(
            manager.get_status("test"),
            ReconnectionStatus::Exhausted { attempts: 3 }
        );
    }

    #[test]
    fn test_kotlin_normal_config_backoff() {
        let config = ReconnectionConfig::kotlin_normal();
        assert_eq!(config.base_delay, Duration::from_millis(1000));
        assert_eq!(config.max_delay, Duration::from_millis(15000));
        assert_eq!(config.max_attempts, 20);
        assert!(!config.use_flat_delay);
        assert!(!config.reset_on_exhaustion);

        let manager = ReconnectionManager::new(config);
        // Doubling from 1 s, capped at 15 s.
        let expected_ms: [u64; 6] = [1000, 2000, 4000, 8000, 15000, 15000];
        for (attempts, ms) in expected_ms.iter().enumerate() {
            assert_eq!(
                manager.calculate_delay(attempts as u32),
                Duration::from_millis(*ms),
                "attempts = {}",
                attempts
            );
        }
    }

    #[test]
    fn test_flat_delay_mode() {
        let config = ReconnectionConfig::kotlin_high_priority();
        assert!(config.use_flat_delay);
        let manager = ReconnectionManager::new(config);
        for attempts in [0u32, 1, 5, 19].iter() {
            assert_eq!(
                manager.calculate_delay(*attempts),
                Duration::from_millis(1000)
            );
        }
    }

    #[test]
    fn test_reset_on_exhaustion() {
        let mut manager = ReconnectionManager::new(tiny_config(3, true, true));
        manager.track_disconnection("test".to_string());
        for _ in 0..3 {
            manager.record_attempt("test");
        }
        assert_eq!(manager.get_status("test"), ReconnectionStatus::Ready);

        std::thread::sleep(Duration::from_millis(5));
        let peers = manager.get_peers_to_reconnect();
        assert!(peers.iter().any(|p| p == "test"));

        let stats = manager
            .get_peer_stats("test")
            .expect("peer is still tracked after reset");
        assert_eq!(stats.attempts, 0);
    }

    #[test]
    fn test_stop_tracking_matches_reset() {
        let mut manager = ReconnectionManager::with_defaults();
        for peer in ["peer1", "peer2"].iter() {
            manager.track_disconnection(peer.to_string());
        }
        assert_eq!(manager.tracked_count(), 2);

        manager.stop_tracking("peer1");
        assert!(!manager.is_tracked("peer1"));
        assert_eq!(manager.get_status("peer1"), ReconnectionStatus::NotTracked);
        assert_eq!(manager.tracked_count(), 1);

        manager.on_connection_success("peer2");
        assert!(!manager.is_tracked("peer2"));
        assert_eq!(manager.get_status("peer2"), ReconnectionStatus::NotTracked);
        assert_eq!(manager.tracked_count(), 0);
    }
}