// oxirs_stream/reconnect.rs

//! Automatic reconnection logic with exponential backoff
//!
//! Provides resilient connection management with configurable retry strategies,
//! connection failure callbacks, and comprehensive error handling.
5
6use crate::connection_pool::{ConnectionFactory, PooledConnection};
7use anyhow::{anyhow, Result};
8use serde::{Deserialize, Serialize};
9use std::future::Future;
10use std::pin::Pin;
11use std::sync::Arc;
12use std::time::{Duration, Instant};
13use tokio::sync::{broadcast, RwLock};
14use tokio::time::sleep;
15use tracing::{error, info, warn};
16
17/// Reconnection configuration
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct ReconnectConfig {
20    /// Initial retry delay
21    pub initial_delay: Duration,
22    /// Maximum retry delay
23    pub max_delay: Duration,
24    /// Exponential backoff multiplier
25    pub multiplier: f64,
26    /// Maximum retry attempts (0 for unlimited)
27    pub max_attempts: u32,
28    /// Jitter factor (0.0 to 1.0)
29    pub jitter_factor: f64,
30    /// Connection timeout
31    pub connection_timeout: Duration,
32    /// Enable connection failure callbacks
33    pub enable_callbacks: bool,
34}
35
36impl Default for ReconnectConfig {
37    fn default() -> Self {
38        Self {
39            initial_delay: Duration::from_millis(100),
40            max_delay: Duration::from_secs(60),
41            multiplier: 2.0,
42            max_attempts: 10,
43            jitter_factor: 0.1,
44            connection_timeout: Duration::from_secs(30),
45            enable_callbacks: true,
46        }
47    }
48}
49
50/// Reconnection strategy
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub enum ReconnectStrategy {
53    /// Exponential backoff with jitter
54    ExponentialBackoff,
55    /// Fixed delay between attempts
56    FixedDelay(Duration),
57    /// Linear backoff
58    LinearBackoff(Duration),
59    /// Custom strategy with callback
60    Custom,
61}
62
/// Progress notifications broadcast while a reconnection is in flight.
#[derive(Debug, Clone)]
pub enum ReconnectEvent {
    /// A retry is about to begin; emitted just before the manager sleeps for
    /// `delay` ahead of the attempt.
    AttemptStarted {
        connection_id: String,
        attempt: u32,
        delay: Duration,
    },
    /// An attempt produced a working connection; `total_time` is measured
    /// from the start of the whole reconnection.
    AttemptSucceeded {
        connection_id: String,
        attempt: u32,
        total_time: Duration,
    },
    /// An attempt failed or timed out; `next_delay` carries the delay
    /// scheduled before the following attempt, when one is reported.
    AttemptFailed {
        connection_id: String,
        attempt: u32,
        error: String,
        next_delay: Option<Duration>,
    },
    /// The attempt budget was spent without obtaining a connection.
    ReconnectionExhausted {
        connection_id: String,
        total_attempts: u32,
        total_time: Duration,
    },
}
92
/// Running counters and timings collected by the reconnection manager.
#[derive(Debug, Clone, Default)]
pub struct ReconnectStatistics {
    /// Connection attempts made, counted individually.
    pub total_attempts: u64,
    /// Reconnections that ended with a working connection.
    pub successful_reconnects: u64,
    /// Reconnections abandoned after the attempt budget was spent.
    pub failed_reconnects: u64,
    /// Successful reconnections since the last exhausted one.
    pub current_streak: u32,
    /// Highest value `current_streak` has reached.
    pub longest_streak: u32,
    /// Summed duration of all successful reconnections.
    pub total_reconnect_time: Duration,
    /// `total_reconnect_time` divided by `successful_reconnects`.
    pub avg_reconnect_time: Duration,
    /// When the most recent attempt was started, if any.
    pub last_reconnect_attempt: Option<Instant>,
    /// When the most recent reconnection succeeded, if any.
    pub last_successful_reconnect: Option<Instant>,
}
106
/// Shared asynchronous hook invoked when a connection fails.
///
/// It is called with the connection id, an error message, and the attempt
/// number, and returns a boxed future for the caller to drive.
pub type ConnectionFailureCallback =
    Arc<dyn Fn(String, String, u32) -> Pin<Box<dyn Future<Output = ()> + Send>> + Send + Sync>;
110
111/// Automatic reconnection manager
112pub struct ReconnectManager<T: PooledConnection> {
113    config: ReconnectConfig,
114    strategy: ReconnectStrategy,
115    statistics: Arc<RwLock<ReconnectStatistics>>,
116    event_sender: broadcast::Sender<ReconnectEvent>,
117    failure_callbacks: Arc<RwLock<Vec<ConnectionFailureCallback>>>,
118    _phantom: std::marker::PhantomData<T>,
119}
120
121impl<T: PooledConnection + Clone> ReconnectManager<T> {
122    /// Create a new reconnection manager
123    pub fn new(config: ReconnectConfig, strategy: ReconnectStrategy) -> Self {
124        let (event_sender, _) = broadcast::channel(1000);
125
126        Self {
127            config,
128            strategy,
129            statistics: Arc::new(RwLock::new(ReconnectStatistics::default())),
130            event_sender,
131            failure_callbacks: Arc::new(RwLock::new(Vec::new())),
132            _phantom: std::marker::PhantomData,
133        }
134    }
135
136    /// Attempt to reconnect with configured strategy
137    pub async fn reconnect(
138        &self,
139        connection_id: String,
140        factory: Arc<dyn ConnectionFactory<T>>,
141    ) -> Result<T> {
142        let start_time = Instant::now();
143        let mut attempt = 0;
144        let mut current_delay = self.config.initial_delay;
145
146        loop {
147            attempt += 1;
148
149            // Check max attempts
150            if self.config.max_attempts > 0 && attempt > self.config.max_attempts {
151                let total_time = start_time.elapsed();
152
153                let _ = self
154                    .event_sender
155                    .send(ReconnectEvent::ReconnectionExhausted {
156                        connection_id: connection_id.clone(),
157                        total_attempts: attempt - 1,
158                        total_time,
159                    });
160
161                // Update statistics
162                let mut stats = self.statistics.write().await;
163                stats.failed_reconnects += 1;
164                stats.current_streak = 0;
165
166                // Call failure callbacks
167                if self.config.enable_callbacks {
168                    self.invoke_failure_callbacks(
169                        connection_id.clone(),
170                        "Maximum retry attempts exhausted".to_string(),
171                        attempt - 1,
172                    )
173                    .await;
174                }
175
176                return Err(anyhow!(
177                    "Failed to reconnect after {} attempts",
178                    self.config.max_attempts
179                ));
180            }
181
182            // Calculate delay with jitter
183            let jittered_delay = self.apply_jitter(current_delay);
184
185            if attempt > 1 {
186                info!(
187                    "Reconnection attempt {} for {} in {:?}",
188                    attempt, connection_id, jittered_delay
189                );
190
191                let _ = self.event_sender.send(ReconnectEvent::AttemptStarted {
192                    connection_id: connection_id.clone(),
193                    attempt,
194                    delay: jittered_delay,
195                });
196
197                sleep(jittered_delay).await;
198            }
199
200            // Update statistics
201            {
202                let mut stats = self.statistics.write().await;
203                stats.total_attempts += 1;
204                stats.last_reconnect_attempt = Some(Instant::now());
205            }
206
207            // Attempt connection with timeout
208            match tokio::time::timeout(self.config.connection_timeout, factory.create_connection())
209                .await
210            {
211                Ok(Ok(connection)) => {
212                    let total_time = start_time.elapsed();
213
214                    info!(
215                        "Successfully reconnected {} after {} attempts in {:?}",
216                        connection_id, attempt, total_time
217                    );
218
219                    let _ = self.event_sender.send(ReconnectEvent::AttemptSucceeded {
220                        connection_id: connection_id.clone(),
221                        attempt,
222                        total_time,
223                    });
224
225                    // Update statistics
226                    let mut stats = self.statistics.write().await;
227                    stats.successful_reconnects += 1;
228                    stats.current_streak += 1;
229                    stats.longest_streak = stats.longest_streak.max(stats.current_streak);
230                    stats.total_reconnect_time += total_time;
231                    stats.last_successful_reconnect = Some(Instant::now());
232
233                    if stats.successful_reconnects > 0 {
234                        stats.avg_reconnect_time =
235                            stats.total_reconnect_time / stats.successful_reconnects as u32;
236                    }
237
238                    return Ok(connection);
239                }
240                Ok(Err(e)) => {
241                    warn!(
242                        "Reconnection attempt {} for {} failed: {}",
243                        attempt, connection_id, e
244                    );
245
246                    // Calculate next delay
247                    current_delay = self.calculate_next_delay(current_delay, attempt);
248                    let next_delay = if attempt < self.config.max_attempts {
249                        Some(current_delay)
250                    } else {
251                        None
252                    };
253
254                    let _ = self.event_sender.send(ReconnectEvent::AttemptFailed {
255                        connection_id: connection_id.clone(),
256                        attempt,
257                        error: e.to_string(),
258                        next_delay,
259                    });
260
261                    // Call failure callbacks for each attempt if enabled
262                    if self.config.enable_callbacks && attempt % 3 == 0 {
263                        self.invoke_failure_callbacks(
264                            connection_id.clone(),
265                            e.to_string(),
266                            attempt,
267                        )
268                        .await;
269                    }
270                }
271                Err(_) => {
272                    error!(
273                        "Reconnection attempt {} for {} timed out",
274                        attempt, connection_id
275                    );
276
277                    current_delay = self.calculate_next_delay(current_delay, attempt);
278
279                    let _ = self.event_sender.send(ReconnectEvent::AttemptFailed {
280                        connection_id: connection_id.clone(),
281                        attempt,
282                        error: "Connection timeout".to_string(),
283                        next_delay: Some(current_delay),
284                    });
285                }
286            }
287        }
288    }
289
290    /// Calculate next delay based on strategy
291    fn calculate_next_delay(&self, current_delay: Duration, attempt: u32) -> Duration {
292        match &self.strategy {
293            ReconnectStrategy::ExponentialBackoff => {
294                let next_delay = current_delay.mul_f64(self.config.multiplier);
295                next_delay.min(self.config.max_delay)
296            }
297            ReconnectStrategy::FixedDelay(delay) => *delay,
298            ReconnectStrategy::LinearBackoff(increment) => {
299                let next_delay = self.config.initial_delay + (*increment * attempt);
300                next_delay.min(self.config.max_delay)
301            }
302            ReconnectStrategy::Custom => {
303                // For custom strategy, use exponential backoff as fallback
304                let next_delay = current_delay.mul_f64(self.config.multiplier);
305                next_delay.min(self.config.max_delay)
306            }
307        }
308    }
309
310    /// Apply jitter to delay
311    fn apply_jitter(&self, delay: Duration) -> Duration {
312        if self.config.jitter_factor <= 0.0 {
313            return delay;
314        }
315
316        let jitter_range = delay.as_millis() as f64 * self.config.jitter_factor;
317        let jitter = (fastrand::f64() - 0.5) * 2.0 * jitter_range;
318        let jittered_millis = (delay.as_millis() as f64 + jitter).max(0.0) as u64;
319
320        Duration::from_millis(jittered_millis)
321    }
322
323    /// Register a connection failure callback
324    pub async fn register_failure_callback<F>(&self, callback: F)
325    where
326        F: Fn(String, String, u32) -> Pin<Box<dyn Future<Output = ()> + Send>>
327            + Send
328            + Sync
329            + 'static,
330    {
331        let mut callbacks = self.failure_callbacks.write().await;
332        callbacks.push(Arc::new(callback));
333    }
334
335    /// Invoke all registered failure callbacks
336    async fn invoke_failure_callbacks(&self, connection_id: String, error: String, attempt: u32) {
337        let callbacks = self.failure_callbacks.read().await;
338
339        for callback in callbacks.iter() {
340            let fut = callback(connection_id.clone(), error.clone(), attempt);
341            tokio::spawn(async move {
342                fut.await;
343            });
344        }
345    }
346
347    /// Get reconnection statistics
348    pub async fn get_statistics(&self) -> ReconnectStatistics {
349        self.statistics.read().await.clone()
350    }
351
352    /// Reset reconnection statistics
353    pub async fn reset_statistics(&self) {
354        *self.statistics.write().await = ReconnectStatistics::default();
355    }
356
357    /// Subscribe to reconnection events
358    pub fn subscribe(&self) -> broadcast::Receiver<ReconnectEvent> {
359        self.event_sender.subscribe()
360    }
361}
362
363/// Helper for creating reconnection-aware connections
364pub struct ResilientConnection<T: PooledConnection> {
365    connection: Option<T>,
366    connection_id: String,
367    factory: Arc<dyn ConnectionFactory<T>>,
368    reconnect_manager: Arc<ReconnectManager<T>>,
369    last_error: Option<String>,
370}
371
372impl<T: PooledConnection + Clone> ResilientConnection<T> {
373    /// Create a new resilient connection
374    pub async fn new(
375        connection_id: String,
376        factory: Arc<dyn ConnectionFactory<T>>,
377        reconnect_manager: Arc<ReconnectManager<T>>,
378    ) -> Result<Self> {
379        let connection = factory.create_connection().await?;
380
381        Ok(Self {
382            connection: Some(connection),
383            connection_id,
384            factory,
385            reconnect_manager,
386            last_error: None,
387        })
388    }
389
390    /// Get the underlying connection, reconnecting if necessary
391    pub async fn get_connection(&mut self) -> Result<&mut T> {
392        // Check if we have a healthy connection
393        let needs_reconnection = match self.connection {
394            Some(ref mut conn) => !conn.is_healthy().await,
395            None => true,
396        };
397
398        if !needs_reconnection {
399            // Return the healthy connection
400            return self
401                .connection
402                .as_mut()
403                .ok_or_else(|| anyhow!("Connection unexpectedly None"));
404        }
405
406        // Connection is unhealthy or missing, attempt reconnection
407        info!(
408            "Connection {} is unhealthy, attempting reconnection",
409            self.connection_id
410        );
411
412        match self
413            .reconnect_manager
414            .reconnect(self.connection_id.clone(), self.factory.clone())
415            .await
416        {
417            Ok(new_conn) => {
418                self.connection = Some(new_conn);
419                self.last_error = None;
420                self.connection
421                    .as_mut()
422                    .ok_or_else(|| anyhow!("Connection unexpectedly None"))
423            }
424            Err(e) => {
425                self.last_error = Some(e.to_string());
426                Err(e)
427            }
428        }
429    }
430
431    /// Check if connection is currently healthy
432    pub async fn is_healthy(&self) -> bool {
433        if let Some(ref conn) = self.connection {
434            conn.is_healthy().await
435        } else {
436            false
437        }
438    }
439
440    /// Get the last error if any
441    pub fn last_error(&self) -> Option<&str> {
442        self.last_error.as_deref()
443    }
444
445    /// Manually trigger reconnection
446    pub async fn reconnect(&mut self) -> Result<()> {
447        let new_conn = self
448            .reconnect_manager
449            .reconnect(self.connection_id.clone(), self.factory.clone())
450            .await?;
451
452        // Close old connection if exists
453        if let Some(mut old_conn) = self.connection.take() {
454            let _ = old_conn.close().await;
455        }
456
457        self.connection = Some(new_conn);
458        self.last_error = None;
459        Ok(())
460    }
461}
462
463#[cfg(test)]
464mod tests {
465    use super::*;
466    use std::sync::atomic::{AtomicBool, AtomicU32, Ordering};
467
468    #[derive(Clone)]
469    struct TestConnection {
470        id: u32,
471        healthy: Arc<AtomicBool>,
472        created_at: Instant,
473    }
474
475    #[async_trait::async_trait]
476    impl PooledConnection for TestConnection {
477        async fn is_healthy(&self) -> bool {
478            self.healthy.load(Ordering::Relaxed)
479        }
480
481        async fn close(&mut self) -> Result<()> {
482            Ok(())
483        }
484
485        fn clone_connection(&self) -> Box<dyn PooledConnection> {
486            Box::new(TestConnection {
487                id: self.id,
488                healthy: Arc::new(AtomicBool::new(self.healthy.load(Ordering::Relaxed))),
489                created_at: self.created_at,
490            })
491        }
492
493        fn created_at(&self) -> Instant {
494            self.created_at
495        }
496
497        fn last_activity(&self) -> Instant {
498            Instant::now()
499        }
500
501        fn update_activity(&mut self) {}
502    }
503
504    struct TestConnectionFactory {
505        counter: Arc<AtomicU32>,
506        should_fail: Arc<AtomicBool>,
507        fail_count: Arc<AtomicU32>,
508    }
509
510    #[async_trait::async_trait]
511    impl ConnectionFactory<TestConnection> for TestConnectionFactory {
512        async fn create_connection(&self) -> Result<TestConnection> {
513            let current_fails = self.fail_count.load(Ordering::Relaxed);
514
515            if self.should_fail.load(Ordering::Relaxed) && current_fails > 0 {
516                self.fail_count.fetch_sub(1, Ordering::Relaxed);
517                return Err(anyhow!("Simulated connection failure"));
518            }
519
520            let id = self.counter.fetch_add(1, Ordering::Relaxed);
521            Ok(TestConnection {
522                id,
523                healthy: Arc::new(AtomicBool::new(true)),
524                created_at: Instant::now(),
525            })
526        }
527    }
528
529    #[tokio::test]
530    async fn test_exponential_backoff() {
531        let config = ReconnectConfig {
532            initial_delay: Duration::from_millis(10),
533            max_delay: Duration::from_millis(100),
534            multiplier: 2.0,
535            max_attempts: 5,
536            jitter_factor: 0.0,
537            ..Default::default()
538        };
539
540        let manager =
541            ReconnectManager::<TestConnection>::new(config, ReconnectStrategy::ExponentialBackoff);
542
543        let factory = Arc::new(TestConnectionFactory {
544            counter: Arc::new(AtomicU32::new(0)),
545            should_fail: Arc::new(AtomicBool::new(true)),
546            fail_count: Arc::new(AtomicU32::new(3)), // Fail first 3 attempts
547        });
548
549        let start = Instant::now();
550        let result = manager.reconnect("test-conn".to_string(), factory).await;
551        let elapsed = start.elapsed();
552
553        assert!(result.is_ok());
554
555        // Should have delays: 0ms, 10ms, 20ms, 40ms (total ~70ms)
556        // Allow for more timing variance during parallel test execution
557        assert!(elapsed >= Duration::from_millis(50));
558        assert!(elapsed < Duration::from_millis(300));
559
560        let stats = manager.get_statistics().await;
561        assert_eq!(stats.total_attempts, 4);
562        assert_eq!(stats.successful_reconnects, 1);
563    }
564
565    #[tokio::test]
566    async fn test_max_attempts() {
567        let config = ReconnectConfig {
568            initial_delay: Duration::from_millis(1),
569            max_attempts: 3,
570            ..Default::default()
571        };
572
573        let manager =
574            ReconnectManager::<TestConnection>::new(config, ReconnectStrategy::ExponentialBackoff);
575
576        let factory = Arc::new(TestConnectionFactory {
577            counter: Arc::new(AtomicU32::new(0)),
578            should_fail: Arc::new(AtomicBool::new(true)),
579            fail_count: Arc::new(AtomicU32::new(100)), // Always fail
580        });
581
582        let result = manager.reconnect("test-conn".to_string(), factory).await;
583        assert!(result.is_err());
584
585        let stats = manager.get_statistics().await;
586        assert_eq!(stats.total_attempts, 3);
587        assert_eq!(stats.failed_reconnects, 1);
588    }
589
590    #[tokio::test]
591    async fn test_failure_callbacks() {
592        let config = ReconnectConfig {
593            initial_delay: Duration::from_millis(1),
594            max_attempts: 3,
595            enable_callbacks: true,
596            ..Default::default()
597        };
598
599        let manager = ReconnectManager::<TestConnection>::new(
600            config,
601            ReconnectStrategy::FixedDelay(Duration::from_millis(1)),
602        );
603
604        let callback_called = Arc::new(AtomicBool::new(false));
605        let callback_called_clone = callback_called.clone();
606
607        manager
608            .register_failure_callback(move |_id, _error, _attempt| {
609                let called = callback_called_clone.clone();
610                Box::pin(async move {
611                    called.store(true, Ordering::Relaxed);
612                })
613            })
614            .await;
615
616        let factory = Arc::new(TestConnectionFactory {
617            counter: Arc::new(AtomicU32::new(0)),
618            should_fail: Arc::new(AtomicBool::new(true)),
619            fail_count: Arc::new(AtomicU32::new(100)),
620        });
621
622        let _ = manager.reconnect("test-conn".to_string(), factory).await;
623
624        // Give callback time to execute
625        tokio::time::sleep(Duration::from_millis(10)).await;
626
627        assert!(callback_called.load(Ordering::Relaxed));
628    }
629
630    #[tokio::test]
631    async fn test_resilient_connection() {
632        let config = ReconnectConfig::default();
633        let manager = Arc::new(ReconnectManager::<TestConnection>::new(
634            config,
635            ReconnectStrategy::ExponentialBackoff,
636        ));
637
638        let _healthy_flag = Arc::new(AtomicBool::new(true));
639        let factory = Arc::new(TestConnectionFactory {
640            counter: Arc::new(AtomicU32::new(0)),
641            should_fail: Arc::new(AtomicBool::new(false)),
642            fail_count: Arc::new(AtomicU32::new(0)),
643        });
644
645        let mut resilient = ResilientConnection::new("test-conn".to_string(), factory, manager)
646            .await
647            .unwrap();
648
649        // Should work normally
650        assert!(resilient.is_healthy().await);
651        let conn = resilient.get_connection().await.unwrap();
652        assert!(conn.is_healthy().await);
653    }
654}