1use crate::connection_pool::{ConnectionFactory, PooledConnection};
7use anyhow::{anyhow, Result};
8use serde::{Deserialize, Serialize};
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::{broadcast, RwLock};
14use tokio::time::sleep;
15use tracing::{error, info, warn};
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ReconnectConfig {
20 pub initial_delay: Duration,
22 pub max_delay: Duration,
24 pub multiplier: f64,
26 pub max_attempts: u32,
28 pub jitter_factor: f64,
30 pub connection_timeout: Duration,
32 pub enable_callbacks: bool,
34}
35
36impl Default for ReconnectConfig {
37 fn default() -> Self {
38 Self {
39 initial_delay: Duration::from_millis(100),
40 max_delay: Duration::from_secs(60),
41 multiplier: 2.0,
42 max_attempts: 10,
43 jitter_factor: 0.1,
44 connection_timeout: Duration::from_secs(30),
45 enable_callbacks: true,
46 }
47 }
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub enum ReconnectStrategy {
53 ExponentialBackoff,
55 FixedDelay(Duration),
57 LinearBackoff(Duration),
59 Custom,
61}
62
/// Progress notifications broadcast by the reconnect manager while it works.
#[derive(Clone, Debug)]
pub enum ReconnectEvent {
    /// A retry has been scheduled and will run after `delay`.
    AttemptStarted {
        /// Identifier of the connection being re-established.
        connection_id: String,
        /// 1-based number of the attempt about to run.
        attempt: u32,
        /// Delay (jitter included) waited before the attempt.
        delay: Duration,
    },
    /// An attempt produced a working connection.
    AttemptSucceeded {
        /// Identifier of the connection that was re-established.
        connection_id: String,
        /// 1-based number of the attempt that succeeded.
        attempt: u32,
        /// Time elapsed since the reconnection procedure started.
        total_time: Duration,
    },
    /// An attempt failed with an error or timed out.
    AttemptFailed {
        /// Identifier of the connection being re-established.
        connection_id: String,
        /// 1-based number of the attempt that failed.
        attempt: u32,
        /// Human-readable description of the failure.
        error: String,
        /// Base delay before the next attempt, if one is reported.
        next_delay: Option<Duration>,
    },
    /// The attempt budget was used up without success.
    ReconnectionExhausted {
        /// Identifier of the connection that could not be re-established.
        connection_id: String,
        /// Number of attempts that were made.
        total_attempts: u32,
        /// Time elapsed since the reconnection procedure started.
        total_time: Duration,
    },
}
92
/// Counters and timings accumulated by a reconnect manager.
#[derive(Clone, Debug, Default)]
pub struct ReconnectStatistics {
    /// Number of individual connection attempts made.
    pub total_attempts: u64,
    /// Number of reconnection procedures that ended with a connection.
    pub successful_reconnects: u64,
    /// Number of reconnection procedures that exhausted their attempts.
    pub failed_reconnects: u64,
    /// Successful reconnections since the last exhausted procedure.
    pub current_streak: u32,
    /// Highest value `current_streak` has reached.
    pub longest_streak: u32,
    /// Sum of the durations of all successful reconnection procedures.
    pub total_reconnect_time: Duration,
    /// `total_reconnect_time` divided by `successful_reconnects`.
    pub avg_reconnect_time: Duration,
    /// When the most recent connection attempt was started.
    pub last_reconnect_attempt: Option<Instant>,
    /// When the most recent reconnection succeeded.
    pub last_successful_reconnect: Option<Instant>,
}
106
/// Shared, thread-safe callback invoked when reconnection attempts fail.
///
/// The arguments are, in order: the connection id, the error description and
/// the attempt number. The callback returns a boxed future that performs the
/// actual (asynchronous) notification work.
pub type ConnectionFailureCallback = Arc<
    dyn Fn(String, String, u32) -> Pin<Box<dyn Future<Output = ()> + Send>>
        + Send
        + Sync,
>;
110
111pub struct ReconnectManager<T: PooledConnection> {
113 config: ReconnectConfig,
114 strategy: ReconnectStrategy,
115 statistics: Arc<RwLock<ReconnectStatistics>>,
116 event_sender: broadcast::Sender<ReconnectEvent>,
117 failure_callbacks: Arc<RwLock<Vec<ConnectionFailureCallback>>>,
118 _phantom: std::marker::PhantomData<T>,
119}
120
121impl<T: PooledConnection + Clone> ReconnectManager<T> {
122 pub fn new(config: ReconnectConfig, strategy: ReconnectStrategy) -> Self {
124 let (event_sender, _) = broadcast::channel(1000);
125
126 Self {
127 config,
128 strategy,
129 statistics: Arc::new(RwLock::new(ReconnectStatistics::default())),
130 event_sender,
131 failure_callbacks: Arc::new(RwLock::new(Vec::new())),
132 _phantom: std::marker::PhantomData,
133 }
134 }
135
136 pub async fn reconnect(
138 &self,
139 connection_id: String,
140 factory: Arc<dyn ConnectionFactory<T>>,
141 ) -> Result<T> {
142 let start_time = Instant::now();
143 let mut attempt = 0;
144 let mut current_delay = self.config.initial_delay;
145
146 loop {
147 attempt += 1;
148
149 if self.config.max_attempts > 0 && attempt > self.config.max_attempts {
151 let total_time = start_time.elapsed();
152
153 let _ = self
154 .event_sender
155 .send(ReconnectEvent::ReconnectionExhausted {
156 connection_id: connection_id.clone(),
157 total_attempts: attempt - 1,
158 total_time,
159 });
160
161 let mut stats = self.statistics.write().await;
163 stats.failed_reconnects += 1;
164 stats.current_streak = 0;
165
166 if self.config.enable_callbacks {
168 self.invoke_failure_callbacks(
169 connection_id.clone(),
170 "Maximum retry attempts exhausted".to_string(),
171 attempt - 1,
172 )
173 .await;
174 }
175
176 return Err(anyhow!(
177 "Failed to reconnect after {} attempts",
178 self.config.max_attempts
179 ));
180 }
181
182 let jittered_delay = self.apply_jitter(current_delay);
184
185 if attempt > 1 {
186 info!(
187 "Reconnection attempt {} for {} in {:?}",
188 attempt, connection_id, jittered_delay
189 );
190
191 let _ = self.event_sender.send(ReconnectEvent::AttemptStarted {
192 connection_id: connection_id.clone(),
193 attempt,
194 delay: jittered_delay,
195 });
196
197 sleep(jittered_delay).await;
198 }
199
200 {
202 let mut stats = self.statistics.write().await;
203 stats.total_attempts += 1;
204 stats.last_reconnect_attempt = Some(Instant::now());
205 }
206
207 match tokio::time::timeout(self.config.connection_timeout, factory.create_connection())
209 .await
210 {
211 Ok(Ok(connection)) => {
212 let total_time = start_time.elapsed();
213
214 info!(
215 "Successfully reconnected {} after {} attempts in {:?}",
216 connection_id, attempt, total_time
217 );
218
219 let _ = self.event_sender.send(ReconnectEvent::AttemptSucceeded {
220 connection_id: connection_id.clone(),
221 attempt,
222 total_time,
223 });
224
225 let mut stats = self.statistics.write().await;
227 stats.successful_reconnects += 1;
228 stats.current_streak += 1;
229 stats.longest_streak = stats.longest_streak.max(stats.current_streak);
230 stats.total_reconnect_time += total_time;
231 stats.last_successful_reconnect = Some(Instant::now());
232
233 if stats.successful_reconnects > 0 {
234 stats.avg_reconnect_time =
235 stats.total_reconnect_time / stats.successful_reconnects as u32;
236 }
237
238 return Ok(connection);
239 }
240 Ok(Err(e)) => {
241 warn!(
242 "Reconnection attempt {} for {} failed: {}",
243 attempt, connection_id, e
244 );
245
246 current_delay = self.calculate_next_delay(current_delay, attempt);
248 let next_delay = if attempt < self.config.max_attempts {
249 Some(current_delay)
250 } else {
251 None
252 };
253
254 let _ = self.event_sender.send(ReconnectEvent::AttemptFailed {
255 connection_id: connection_id.clone(),
256 attempt,
257 error: e.to_string(),
258 next_delay,
259 });
260
261 if self.config.enable_callbacks && attempt % 3 == 0 {
263 self.invoke_failure_callbacks(
264 connection_id.clone(),
265 e.to_string(),
266 attempt,
267 )
268 .await;
269 }
270 }
271 Err(_) => {
272 error!(
273 "Reconnection attempt {} for {} timed out",
274 attempt, connection_id
275 );
276
277 current_delay = self.calculate_next_delay(current_delay, attempt);
278
279 let _ = self.event_sender.send(ReconnectEvent::AttemptFailed {
280 connection_id: connection_id.clone(),
281 attempt,
282 error: "Connection timeout".to_string(),
283 next_delay: Some(current_delay),
284 });
285 }
286 }
287 }
288 }
289
290 fn calculate_next_delay(&self, current_delay: Duration, attempt: u32) -> Duration {
292 match &self.strategy {
293 ReconnectStrategy::ExponentialBackoff => {
294 let next_delay = current_delay.mul_f64(self.config.multiplier);
295 next_delay.min(self.config.max_delay)
296 }
297 ReconnectStrategy::FixedDelay(delay) => *delay,
298 ReconnectStrategy::LinearBackoff(increment) => {
299 let next_delay = self.config.initial_delay + (*increment * attempt);
300 next_delay.min(self.config.max_delay)
301 }
302 ReconnectStrategy::Custom => {
303 let next_delay = current_delay.mul_f64(self.config.multiplier);
305 next_delay.min(self.config.max_delay)
306 }
307 }
308 }
309
310 fn apply_jitter(&self, delay: Duration) -> Duration {
312 if self.config.jitter_factor <= 0.0 {
313 return delay;
314 }
315
316 let jitter_range = delay.as_millis() as f64 * self.config.jitter_factor;
317 let jitter = (fastrand::f64() - 0.5) * 2.0 * jitter_range;
318 let jittered_millis = (delay.as_millis() as f64 + jitter).max(0.0) as u64;
319
320 Duration::from_millis(jittered_millis)
321 }
322
323 pub async fn register_failure_callback<F>(&self, callback: F)
325 where
326 F: Fn(String, String, u32) -> Pin<Box<dyn Future<Output = ()> + Send>>
327 + Send
328 + Sync
329 + 'static,
330 {
331 let mut callbacks = self.failure_callbacks.write().await;
332 callbacks.push(Arc::new(callback));
333 }
334
335 async fn invoke_failure_callbacks(&self, connection_id: String, error: String, attempt: u32) {
337 let callbacks = self.failure_callbacks.read().await;
338
339 for callback in callbacks.iter() {
340 let fut = callback(connection_id.clone(), error.clone(), attempt);
341 tokio::spawn(async move {
342 fut.await;
343 });
344 }
345 }
346
347 pub async fn get_statistics(&self) -> ReconnectStatistics {
349 self.statistics.read().await.clone()
350 }
351
352 pub async fn reset_statistics(&self) {
354 *self.statistics.write().await = ReconnectStatistics::default();
355 }
356
357 pub fn subscribe(&self) -> broadcast::Receiver<ReconnectEvent> {
359 self.event_sender.subscribe()
360 }
361}
362
363pub struct ResilientConnection<T: PooledConnection> {
365 connection: Option<T>,
366 connection_id: String,
367 factory: Arc<dyn ConnectionFactory<T>>,
368 reconnect_manager: Arc<ReconnectManager<T>>,
369 last_error: Option<String>,
370}
371
372impl<T: PooledConnection + Clone> ResilientConnection<T> {
373 pub async fn new(
375 connection_id: String,
376 factory: Arc<dyn ConnectionFactory<T>>,
377 reconnect_manager: Arc<ReconnectManager<T>>,
378 ) -> Result<Self> {
379 let connection = factory.create_connection().await?;
380
381 Ok(Self {
382 connection: Some(connection),
383 connection_id,
384 factory,
385 reconnect_manager,
386 last_error: None,
387 })
388 }
389
390 pub async fn get_connection(&mut self) -> Result<&mut T> {
392 let needs_reconnection = match self.connection {
394 Some(ref mut conn) => !conn.is_healthy().await,
395 None => true,
396 };
397
398 if !needs_reconnection {
399 return self
401 .connection
402 .as_mut()
403 .ok_or_else(|| anyhow!("Connection unexpectedly None"));
404 }
405
406 info!(
408 "Connection {} is unhealthy, attempting reconnection",
409 self.connection_id
410 );
411
412 match self
413 .reconnect_manager
414 .reconnect(self.connection_id.clone(), self.factory.clone())
415 .await
416 {
417 Ok(new_conn) => {
418 self.connection = Some(new_conn);
419 self.last_error = None;
420 self.connection
421 .as_mut()
422 .ok_or_else(|| anyhow!("Connection unexpectedly None"))
423 }
424 Err(e) => {
425 self.last_error = Some(e.to_string());
426 Err(e)
427 }
428 }
429 }
430
431 pub async fn is_healthy(&self) -> bool {
433 if let Some(ref conn) = self.connection {
434 conn.is_healthy().await
435 } else {
436 false
437 }
438 }
439
440 pub fn last_error(&self) -> Option<&str> {
442 self.last_error.as_deref()
443 }
444
445 pub async fn reconnect(&mut self) -> Result<()> {
447 let new_conn = self
448 .reconnect_manager
449 .reconnect(self.connection_id.clone(), self.factory.clone())
450 .await?;
451
452 if let Some(mut old_conn) = self.connection.take() {
454 let _ = old_conn.close().await;
455 }
456
457 self.connection = Some(new_conn);
458 self.last_error = None;
459 Ok(())
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};

    /// In-memory connection whose health is controlled by a shared flag.
    #[derive(Clone)]
    struct TestConnection {
        id: u32,
        healthy: Arc<AtomicBool>,
        created_at: Instant,
    }

    #[async_trait::async_trait]
    impl PooledConnection for TestConnection {
        async fn is_healthy(&self) -> bool {
            self.healthy.load(Ordering::Relaxed)
        }

        async fn close(&mut self) -> Result<()> {
            Ok(())
        }

        fn clone_connection(&self) -> Box<dyn PooledConnection> {
            // The copy gets its own flag initialised from the current state.
            Box::new(TestConnection {
                id: self.id,
                healthy: Arc::new(AtomicBool::new(self.healthy.load(Ordering::Relaxed))),
                created_at: self.created_at,
            })
        }

        fn created_at(&self) -> Instant {
            self.created_at
        }

        fn last_activity(&self) -> Instant {
            Instant::now()
        }

        fn update_activity(&mut self) {}
    }

    /// Factory that fails a configurable number of times before succeeding.
    struct TestConnectionFactory {
        counter: Arc<AtomicU32>,
        should_fail: Arc<AtomicBool>,
        fail_count: Arc<AtomicU32>,
    }

    impl TestConnectionFactory {
        /// Factory whose first `failures` calls return an error.
        fn failing(failures: u32) -> Arc<Self> {
            Arc::new(Self {
                counter: Arc::new(AtomicU32::new(0)),
                should_fail: Arc::new(AtomicBool::new(true)),
                fail_count: Arc::new(AtomicU32::new(failures)),
            })
        }

        /// Factory that always succeeds.
        fn reliable() -> Arc<Self> {
            Arc::new(Self {
                counter: Arc::new(AtomicU32::new(0)),
                should_fail: Arc::new(AtomicBool::new(false)),
                fail_count: Arc::new(AtomicU32::new(0)),
            })
        }
    }

    #[async_trait::async_trait]
    impl ConnectionFactory<TestConnection> for TestConnectionFactory {
        async fn create_connection(&self) -> Result<TestConnection> {
            if self.should_fail.load(Ordering::Relaxed) {
                // Consume one pending failure atomically; the update is
                // rejected (and nothing is decremented) once the budget is 0.
                let consumed = self
                    .fail_count
                    .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |left| {
                        left.checked_sub(1)
                    })
                    .is_ok();
                if consumed {
                    return Err(anyhow!("Simulated connection failure"));
                }
            }

            let id = self.counter.fetch_add(1, Ordering::Relaxed);
            Ok(TestConnection {
                id,
                healthy: Arc::new(AtomicBool::new(true)),
                created_at: Instant::now(),
            })
        }
    }

    #[tokio::test]
    async fn test_exponential_backoff() {
        let config = ReconnectConfig {
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            multiplier: 2.0,
            max_attempts: 5,
            jitter_factor: 0.0,
            ..Default::default()
        };

        let manager =
            ReconnectManager::<TestConnection>::new(config, ReconnectStrategy::ExponentialBackoff);

        // Three failures, then the fourth attempt succeeds.
        let factory = TestConnectionFactory::failing(3);

        let start = Instant::now();
        let result = manager.reconnect("test-conn".to_string(), factory).await;
        let elapsed = start.elapsed();

        assert!(result.is_ok());

        // The backoff sleeps between the four attempts add up to more than
        // 50 ms but stay far below 300 ms; the bounds are deliberately loose.
        assert!(elapsed >= Duration::from_millis(50));
        assert!(elapsed < Duration::from_millis(300));

        let stats = manager.get_statistics().await;
        assert_eq!(stats.total_attempts, 4);
        assert_eq!(stats.successful_reconnects, 1);
    }

    #[tokio::test]
    async fn test_max_attempts() {
        let config = ReconnectConfig {
            initial_delay: Duration::from_millis(1),
            max_attempts: 3,
            ..Default::default()
        };

        let manager =
            ReconnectManager::<TestConnection>::new(config, ReconnectStrategy::ExponentialBackoff);

        // More pending failures than the attempt budget allows.
        let factory = TestConnectionFactory::failing(100);

        let result = manager.reconnect("test-conn".to_string(), factory).await;
        assert!(result.is_err());

        let stats = manager.get_statistics().await;
        assert_eq!(stats.total_attempts, 3);
        assert_eq!(stats.failed_reconnects, 1);
    }

    #[tokio::test]
    async fn test_failure_callbacks() {
        let config = ReconnectConfig {
            initial_delay: Duration::from_millis(1),
            max_attempts: 3,
            enable_callbacks: true,
            ..Default::default()
        };

        let manager = ReconnectManager::<TestConnection>::new(
            config,
            ReconnectStrategy::FixedDelay(Duration::from_millis(1)),
        );

        let callback_called = Arc::new(AtomicBool::new(false));
        let callback_called_clone = callback_called.clone();

        manager
            .register_failure_callback(move |_id, _error, _attempt| {
                let called = callback_called_clone.clone();
                Box::pin(async move {
                    called.store(true, Ordering::Relaxed);
                })
            })
            .await;

        let factory = TestConnectionFactory::failing(100);

        let _ = manager.reconnect("test-conn".to_string(), factory).await;

        // Callbacks run as detached tasks, so poll with a generous deadline
        // instead of relying on one fixed sleep.
        let deadline = Instant::now() + Duration::from_secs(1);
        while !callback_called.load(Ordering::Relaxed) && Instant::now() < deadline {
            tokio::time::sleep(Duration::from_millis(1)).await;
        }

        assert!(callback_called.load(Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_resilient_connection() {
        let manager = Arc::new(ReconnectManager::<TestConnection>::new(
            ReconnectConfig::default(),
            ReconnectStrategy::ExponentialBackoff,
        ));

        let factory = TestConnectionFactory::reliable();

        let mut resilient = ResilientConnection::new("test-conn".to_string(), factory, manager)
            .await
            .unwrap();

        assert!(resilient.is_healthy().await);
        let conn = resilient.get_connection().await.unwrap();
        assert!(conn.is_healthy().await);
    }

    #[tokio::test]
    async fn test_resilient_connection_manual_reconnect() {
        let manager = Arc::new(ReconnectManager::<TestConnection>::new(
            ReconnectConfig::default(),
            ReconnectStrategy::ExponentialBackoff,
        ));

        let mut resilient = ResilientConnection::new(
            "test-conn".to_string(),
            TestConnectionFactory::reliable(),
            manager.clone(),
        )
        .await
        .unwrap();

        resilient.reconnect().await.unwrap();

        assert!(resilient.is_healthy().await);
        assert!(resilient.last_error().is_none());
        assert_eq!(manager.get_statistics().await.successful_reconnects, 1);
    }
}