1use std::collections::HashMap;
42use std::time::{Duration, Instant};
43
/// Tunable parameters for reconnection backoff.
#[derive(Debug, Clone)]
pub struct ReconnectionConfig {
    /// Base backoff delay; the manager scales it by `2^attempts`.
    pub base_delay: Duration,

    /// Upper bound applied to every computed backoff delay.
    pub max_delay: Duration,

    /// Attempt count at which a peer is reported as exhausted and is no
    /// longer offered for reconnection.
    pub max_attempts: u32,

    /// Value handed back by `ReconnectionManager::check_interval`.
    /// NOTE(review): presumably the caller's polling period — the manager
    /// itself never reads it for scheduling.
    pub check_interval: Duration,
}
56
57impl Default for ReconnectionConfig {
58 fn default() -> Self {
59 Self {
60 base_delay: Duration::from_secs(2),
61 max_delay: Duration::from_secs(60),
62 max_attempts: 10,
63 check_interval: Duration::from_secs(5),
64 }
65 }
66}
67
68impl ReconnectionConfig {
69 pub fn new(
71 base_delay: Duration,
72 max_delay: Duration,
73 max_attempts: u32,
74 check_interval: Duration,
75 ) -> Self {
76 Self {
77 base_delay,
78 max_delay,
79 max_attempts,
80 check_interval,
81 }
82 }
83
84 pub fn fast() -> Self {
86 Self {
87 base_delay: Duration::from_millis(500),
88 max_delay: Duration::from_secs(5),
89 max_attempts: 5,
90 check_interval: Duration::from_secs(1),
91 }
92 }
93
94 pub fn conservative() -> Self {
96 Self {
97 base_delay: Duration::from_secs(5),
98 max_delay: Duration::from_secs(120),
99 max_attempts: 5,
100 check_interval: Duration::from_secs(10),
101 }
102 }
103}
104
/// Bookkeeping kept for a single peer that is awaiting reconnection.
#[derive(Debug, Clone)]
struct PeerReconnectionState {
    /// Number of reconnection attempts recorded so far.
    attempts: u32,

    /// When the latest attempt was recorded (creation time until the first
    /// attempt is made).
    last_attempt: Instant,

    /// When this state was created, i.e. when tracking of the peer began.
    disconnected_at: Instant,
}
115
116impl PeerReconnectionState {
117 fn new() -> Self {
118 let now = Instant::now();
119 Self {
120 attempts: 0,
121 last_attempt: now,
122 disconnected_at: now,
123 }
124 }
125}
126
/// Reconnection state of a peer as reported by `ReconnectionManager::get_status`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReconnectionStatus {
    /// The peer may be retried right now.
    Ready,

    /// The backoff delay has not elapsed yet.
    Waiting {
        /// Time left until the peer becomes ready.
        remaining: Duration,
    },

    /// The configured attempt limit has been reached.
    Exhausted {
        /// Number of attempts recorded for the peer.
        attempts: u32,
    },

    /// The manager holds no state for the address.
    NotTracked,
}
145
146#[derive(Debug)]
151pub struct ReconnectionManager {
152 config: ReconnectionConfig,
154 peers: HashMap<String, PeerReconnectionState>,
156}
157
158impl ReconnectionManager {
159 pub fn new(config: ReconnectionConfig) -> Self {
161 Self {
162 config,
163 peers: HashMap::new(),
164 }
165 }
166
167 pub fn with_defaults() -> Self {
169 Self::new(ReconnectionConfig::default())
170 }
171
172 pub fn track_disconnection(&mut self, address: String) {
176 use std::collections::hash_map::Entry;
177
178 if let Entry::Vacant(entry) = self.peers.entry(address.clone()) {
179 log::debug!("Tracking {} for reconnection", address);
180 entry.insert(PeerReconnectionState::new());
181 }
182 }
183
184 pub fn is_tracked(&self, address: &str) -> bool {
186 self.peers.contains_key(address)
187 }
188
189 pub fn get_status(&self, address: &str) -> ReconnectionStatus {
191 match self.peers.get(address) {
192 None => ReconnectionStatus::NotTracked,
193 Some(state) => {
194 if state.attempts >= self.config.max_attempts {
195 return ReconnectionStatus::Exhausted {
196 attempts: state.attempts,
197 };
198 }
199
200 if state.attempts == 0 {
202 return ReconnectionStatus::Ready;
203 }
204
205 let delay = self.calculate_delay(state.attempts);
207 let elapsed = state.last_attempt.elapsed();
208
209 if elapsed >= delay {
210 ReconnectionStatus::Ready
211 } else {
212 ReconnectionStatus::Waiting {
213 remaining: delay - elapsed,
214 }
215 }
216 }
217 }
218 }
219
220 fn calculate_delay(&self, attempts: u32) -> Duration {
224 let multiplier = 1u64 << attempts.min(30); let delay_ms = self.config.base_delay.as_millis() as u64 * multiplier;
226 let max_ms = self.config.max_delay.as_millis() as u64;
227 Duration::from_millis(delay_ms.min(max_ms))
228 }
229
230 pub fn get_peers_to_reconnect(&self) -> Vec<String> {
236 self.peers
237 .iter()
238 .filter_map(|(address, state)| {
239 if state.attempts >= self.config.max_attempts {
240 return None;
241 }
242
243 if state.attempts == 0 {
245 return Some(address.clone());
246 }
247
248 let delay = self.calculate_delay(state.attempts);
250 if state.last_attempt.elapsed() >= delay {
251 Some(address.clone())
252 } else {
253 None
254 }
255 })
256 .collect()
257 }
258
259 pub fn record_attempt(&mut self, address: &str) {
263 let attempts = if let Some(state) = self.peers.get_mut(address) {
264 state.attempts += 1;
265 state.last_attempt = Instant::now();
266 Some(state.attempts)
267 } else {
268 None
269 };
270
271 if let Some(attempts) = attempts {
272 let next_delay = self.calculate_delay(attempts);
273 log::debug!(
274 "Reconnection attempt {} for {} (next delay: {:?})",
275 attempts,
276 address,
277 next_delay
278 );
279 }
280 }
281
282 pub fn on_connection_success(&mut self, address: &str) {
286 if self.peers.remove(address).is_some() {
287 log::debug!(
288 "Connection succeeded for {}, removed from reconnection tracking",
289 address
290 );
291 }
292 }
293
294 pub fn stop_tracking(&mut self, address: &str) {
296 if self.peers.remove(address).is_some() {
297 log::debug!("Stopped tracking {} for reconnection", address);
298 }
299 }
300
301 pub fn clear(&mut self) {
303 let count = self.peers.len();
304 self.peers.clear();
305 if count > 0 {
306 log::debug!("Cleared reconnection tracking for {} peers", count);
307 }
308 }
309
310 pub fn tracked_count(&self) -> usize {
312 self.peers.len()
313 }
314
315 pub fn get_peer_stats(&self, address: &str) -> Option<PeerReconnectionStats> {
317 self.peers.get(address).map(|state| PeerReconnectionStats {
318 attempts: state.attempts,
319 max_attempts: self.config.max_attempts,
320 disconnected_duration: state.disconnected_at.elapsed(),
321 next_attempt_delay: if state.attempts >= self.config.max_attempts {
322 Duration::MAX } else if state.attempts == 0 {
324 Duration::ZERO } else {
326 let delay = self.calculate_delay(state.attempts);
327 let elapsed = state.last_attempt.elapsed();
328 if elapsed >= delay {
329 Duration::ZERO
330 } else {
331 delay - elapsed
332 }
333 },
334 })
335 }
336
337 pub fn check_interval(&self) -> Duration {
339 self.config.check_interval
340 }
341}
342
/// Snapshot of a tracked peer's reconnection progress, as produced by
/// `ReconnectionManager::get_peer_stats`.
#[derive(Debug, Clone)]
pub struct PeerReconnectionStats {
    /// Attempts recorded for the peer so far.
    pub attempts: u32,

    /// Attempt limit taken from the manager's configuration.
    pub max_attempts: u32,

    /// Time elapsed since the peer started being tracked.
    pub disconnected_duration: Duration,

    /// Time left until the next attempt is allowed: `Duration::ZERO` when
    /// the peer is ready, `Duration::MAX` when its attempts are exhausted.
    pub next_attempt_delay: Duration,
}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Address shared by the tracking tests.
    const PEER: &str = "00:11:22:33:44:55";

    /// The delay doubles per attempt and is capped at `max_delay`.
    #[test]
    fn test_exponential_backoff() {
        let manager = ReconnectionManager::new(ReconnectionConfig::new(
            Duration::from_secs(2),
            Duration::from_secs(60),
            10,
            Duration::from_secs(5),
        ));

        // Expected delay in seconds for attempts 0, 1, 2, ...; 64 s is capped to 60 s.
        let expected_secs = [2, 4, 8, 16, 32, 60, 60];
        for (attempts, secs) in (0u32..).zip(expected_secs.iter().copied()) {
            assert_eq!(
                manager.calculate_delay(attempts),
                Duration::from_secs(secs)
            );
        }
    }

    /// An unknown peer is `NotTracked`; a freshly tracked one is `Ready`.
    #[test]
    fn test_track_and_status() {
        let mut manager = ReconnectionManager::new(ReconnectionConfig::fast());

        assert_eq!(manager.get_status(PEER), ReconnectionStatus::NotTracked);

        manager.track_disconnection(PEER.to_string());
        assert!(manager.is_tracked(PEER));

        assert_eq!(manager.get_status(PEER), ReconnectionStatus::Ready);
    }

    /// A successful connection removes the peer from tracking.
    #[test]
    fn test_connection_success_clears_tracking() {
        let mut manager = ReconnectionManager::with_defaults();

        manager.track_disconnection(PEER.to_string());
        assert!(manager.is_tracked(PEER));

        manager.on_connection_success(PEER);
        assert!(!manager.is_tracked(PEER));
    }

    /// Reaching `max_attempts` recorded attempts yields `Exhausted`.
    #[test]
    fn test_max_attempts_exhaustion() {
        let mut manager = ReconnectionManager::new(ReconnectionConfig::new(
            Duration::from_millis(1),
            Duration::from_millis(10),
            3,
            Duration::from_millis(1),
        ));

        manager.track_disconnection("test".to_string());

        (0..3).for_each(|_| manager.record_attempt("test"));

        assert_eq!(
            manager.get_status("test"),
            ReconnectionStatus::Exhausted { attempts: 3 }
        );
    }
}