kode_bridge/
retry.rs

1use crate::errors::KodeBridgeError;
2use rand::{random_range, rngs::StdRng, SeedableRng};
3use std::time::{Duration, Instant};
4use tracing::{debug, warn};
5
6/// Type alias for complex retry function
7pub type RetryFn = Box<dyn Fn(&KodeBridgeError, usize) -> bool + Send + Sync>;
8
9/// Advanced retry configuration with adaptive strategies
10pub struct RetryConfig {
11    /// Maximum number of retry attempts
12    pub max_attempts: usize,
13    /// Base delay between retries
14    pub base_delay: Duration,
15    /// Maximum delay between retries (for exponential backoff)
16    pub max_delay: Duration,
17    /// Backoff strategy to use
18    pub backoff_strategy: BackoffStrategy,
19    /// Jitter strategy to avoid thundering herd
20    pub jitter_strategy: JitterStrategy,
21    /// Custom retry decision function (not cloneable, so we'll skip it in Clone)
22    pub should_retry_fn: Option<RetryFn>,
23}
24
25impl Clone for RetryConfig {
26    fn clone(&self) -> Self {
27        Self {
28            max_attempts: self.max_attempts,
29            base_delay: self.base_delay,
30            max_delay: self.max_delay,
31            backoff_strategy: self.backoff_strategy,
32            jitter_strategy: self.jitter_strategy,
33            should_retry_fn: None, // Skip cloning function pointer
34        }
35    }
36}
37
/// How the delay between consecutive retries evolves.
///
/// In the formulas below, `n` is the 1-based number of the retry being scheduled.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BackoffStrategy {
    /// The same delay, `base_delay`, before every retry.
    Fixed,
    /// Geometric growth: `base_delay * multiplier^(n - 1)`.
    Exponential { multiplier: f64 },
    /// Arithmetic growth: `base_delay + increment * (n - 1)`.
    Linear { increment: Duration },
}
47
/// Randomization applied on top of the backoff delay so that many clients retrying
/// the same failure do not wake up in lockstep.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JitterStrategy {
    /// Use the (capped) backoff delay as-is.
    None,
    /// Add a uniformly random extra delay of 0–50% of the capped backoff delay.
    Full,
    /// Add a uniformly random extra delay of 0–25% of the capped backoff delay.
    Partial,
    /// Decorrelated jitter: ignore the backoff strategy and pick uniformly between
    /// `base_delay` and three times the previous delay, bounded by `max_delay`.
    Decorrelated,
}
59
60impl Default for RetryConfig {
61    fn default() -> Self {
62        Self {
63            max_attempts: 3,
64            base_delay: Duration::from_millis(100),
65            max_delay: Duration::from_secs(30),
66            backoff_strategy: BackoffStrategy::Exponential { multiplier: 2.0 },
67            jitter_strategy: JitterStrategy::Partial,
68            should_retry_fn: None,
69        }
70    }
71}
72
73impl RetryConfig {
74    /// Create a new retry configuration
75    pub fn new() -> Self {
76        Self::default()
77    }
78
79    /// Set maximum retry attempts
80    pub fn max_attempts(mut self, max_attempts: usize) -> Self {
81        self.max_attempts = max_attempts;
82        self
83    }
84
85    /// Set base delay
86    pub fn base_delay(mut self, delay: Duration) -> Self {
87        self.base_delay = delay;
88        self
89    }
90
91    /// Set maximum delay
92    pub fn max_delay(mut self, delay: Duration) -> Self {
93        self.max_delay = delay;
94        self
95    }
96
97    /// Use exponential backoff strategy
98    pub fn exponential_backoff(mut self, multiplier: f64) -> Self {
99        self.backoff_strategy = BackoffStrategy::Exponential { multiplier };
100        self
101    }
102
103    /// Use fixed backoff strategy
104    pub fn fixed_backoff(mut self) -> Self {
105        self.backoff_strategy = BackoffStrategy::Fixed;
106        self
107    }
108
109    /// Use linear backoff strategy
110    pub fn linear_backoff(mut self, increment: Duration) -> Self {
111        self.backoff_strategy = BackoffStrategy::Linear { increment };
112        self
113    }
114
115    /// Set jitter strategy
116    pub fn jitter(mut self, strategy: JitterStrategy) -> Self {
117        self.jitter_strategy = strategy;
118        self
119    }
120
121    /// Set custom retry condition
122    pub fn should_retry<F>(mut self, f: F) -> Self
123    where
124        F: Fn(&KodeBridgeError, usize) -> bool + Send + Sync + 'static,
125    {
126        self.should_retry_fn = Some(Box::new(f));
127        self
128    }
129
130    /// Smart defaults for different scenarios
131    pub fn for_network_operations() -> Self {
132        Self::new()
133            .max_attempts(5)
134            .base_delay(Duration::from_millis(50))
135            .max_delay(Duration::from_secs(10))
136            .exponential_backoff(2.0)
137            .jitter(JitterStrategy::Full)
138    }
139
140    pub fn for_rate_limited_apis() -> Self {
141        Self::new()
142            .max_attempts(10)
143            .base_delay(Duration::from_secs(1))
144            .max_delay(Duration::from_secs(60))
145            .exponential_backoff(1.5)
146            .jitter(JitterStrategy::Decorrelated)
147    }
148
149    pub fn for_quick_operations() -> Self {
150        Self::new()
151            .max_attempts(3)
152            .base_delay(Duration::from_millis(10))
153            .max_delay(Duration::from_millis(500))
154            .linear_backoff(Duration::from_millis(50))
155            .jitter(JitterStrategy::Partial)
156    }
157}
158
/// Per-run bookkeeping kept by `RetryExecutor::execute`.
#[derive(Debug)]
pub struct RetryState {
    /// Attempts started so far; becomes 1 when the first attempt begins.
    attempt: usize,
    /// Elapsed time recorded by the executor (reported in its log messages).
    total_elapsed: Duration,
    /// Delay chosen for the most recent retry; input to decorrelated jitter.
    last_delay: Duration,
}
166
167impl Default for RetryState {
168    fn default() -> Self {
169        Self {
170            attempt: 0,
171            total_elapsed: Duration::ZERO,
172            last_delay: Duration::ZERO,
173        }
174    }
175}
176
177impl RetryState {
178    pub fn new() -> Self {
179        Self::default()
180    }
181
182    pub fn attempt(&self) -> usize {
183        self.attempt
184    }
185
186    pub fn total_elapsed(&self) -> Duration {
187        self.total_elapsed
188    }
189
190    pub fn last_delay(&self) -> Duration {
191        self.last_delay
192    }
193}
194
195/// Smart retry executor
196pub struct RetryExecutor {
197    config: RetryConfig,
198}
199
200impl RetryExecutor {
201    pub fn new(config: RetryConfig) -> Self {
202        Self { config }
203    }
204
205    /// Execute operation with retry logic
206    pub async fn execute<F, Fut, T>(&self, mut operation: F) -> Result<T, KodeBridgeError>
207    where
208        F: FnMut() -> Fut,
209        Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
210    {
211        let mut state = RetryState::new();
212        let mut rng = StdRng::from_seed([0u8; 32]); // Use deterministic seed for Send compatibility
213
214        loop {
215            state.attempt += 1;
216            let attempt_start = Instant::now();
217
218            debug!("Retry attempt {} starting", state.attempt);
219
220            match operation().await {
221                Ok(result) => {
222                    if state.attempt > 1 {
223                        debug!(
224                            "Operation succeeded on attempt {} after {}ms",
225                            state.attempt,
226                            state.total_elapsed.as_millis()
227                        );
228                    }
229                    return Ok(result);
230                }
231                Err(error) => {
232                    let attempt_duration = attempt_start.elapsed();
233                    state.total_elapsed += attempt_duration;
234
235                    // Check if we should retry this error
236                    let should_retry = if let Some(ref custom_fn) = self.config.should_retry_fn {
237                        custom_fn(&error, state.attempt)
238                    } else {
239                        self.default_should_retry(&error, state.attempt)
240                    };
241
242                    if !should_retry || state.attempt >= self.config.max_attempts {
243                        warn!(
244                            "Operation failed after {} attempts in {}ms: {}",
245                            state.attempt,
246                            state.total_elapsed.as_millis(),
247                            error
248                        );
249                        return Err(error);
250                    }
251
252                    // Calculate next delay
253                    let next_delay = self.calculate_delay(&mut state, &mut rng);
254
255                    debug!(
256                        "Retrying after {}ms (attempt {}/{}, error: {})",
257                        next_delay.as_millis(),
258                        state.attempt,
259                        self.config.max_attempts,
260                        error
261                    );
262
263                    tokio::time::sleep(next_delay).await;
264                }
265            }
266        }
267    }
268
269    /// Execute operation with context for better error reporting
270    pub async fn execute_with_context<F, Fut, T>(
271        &self,
272        operation_name: &str,
273        operation: F,
274    ) -> Result<T, KodeBridgeError>
275    where
276        F: FnMut() -> Fut,
277        Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
278    {
279        debug!("Starting retry execution for operation: {}", operation_name);
280
281        match self.execute(operation).await {
282            Ok(result) => {
283                debug!("Operation '{}' completed successfully", operation_name);
284                Ok(result)
285            }
286            Err(error) => {
287                warn!(
288                    "Operation '{}' failed with error: {}",
289                    operation_name, error
290                );
291                Err(KodeBridgeError::custom(format!(
292                    "Operation '{}' failed after retries: {}",
293                    operation_name, error
294                )))
295            }
296        }
297    }
298
299    /// Default retry logic based on error type
300    fn default_should_retry(&self, error: &KodeBridgeError, attempt: usize) -> bool {
301        use KodeBridgeError::*;
302
303        match error {
304            // Always retry network-related errors
305            Io(_) | Connection { .. } | Timeout { .. } | StreamClosed => true,
306
307            // Retry server errors (5xx) but not client errors (4xx)
308            ServerError { status } => *status >= 500,
309            ClientError { .. } | InvalidRequest { .. } => false,
310
311            // Don't retry parsing or protocol errors
312            HttpParse(_) | Http(_) | Protocol { .. } => false,
313
314            // Don't retry configuration or validation errors
315            Configuration { .. } => false,
316
317            // Don't retry JSON errors (likely application issue)
318            Json(_) | JsonSerialize { .. } => false,
319
320            // Don't retry UTF-8 errors
321            Utf8(_) | FromUtf8(_) => false,
322
323            // Retry resource exhaustion with exponential backoff
324            PoolExhausted => attempt <= 5, // But limit attempts for pool exhaustion
325
326            // Custom errors - be conservative
327            Custom { .. } => false,
328
329            // HTTP status code errors need special handling
330            InvalidStatusCode(_) => false,
331        }
332    }
333
334    /// Calculate next retry delay with backoff and jitter
335    fn calculate_delay(&self, state: &mut RetryState, _rng: &mut impl rand::Rng) -> Duration {
336        let base_delay = match self.config.backoff_strategy {
337            BackoffStrategy::Fixed => self.config.base_delay,
338            BackoffStrategy::Exponential { multiplier } => {
339                if state.attempt == 1 {
340                    self.config.base_delay
341                } else {
342                    let exponential = (self.config.base_delay.as_millis() as f64
343                        * multiplier.powi((state.attempt - 1) as i32))
344                        as u64;
345                    Duration::from_millis(exponential)
346                }
347            }
348            BackoffStrategy::Linear { increment } => {
349                self.config.base_delay + increment * (state.attempt as u32 - 1)
350            }
351        };
352
353        // Cap at maximum delay
354        let capped_delay = std::cmp::min(base_delay, self.config.max_delay);
355
356        // Apply jitter
357        let final_delay = match self.config.jitter_strategy {
358            JitterStrategy::None => capped_delay,
359            JitterStrategy::Full => {
360                let jitter = random_range(0..=capped_delay.as_millis() / 2) as u64;
361                capped_delay + Duration::from_millis(jitter)
362            }
363            JitterStrategy::Partial => {
364                let jitter = random_range(0..=capped_delay.as_millis() / 4) as u64;
365                capped_delay + Duration::from_millis(jitter)
366            }
367            JitterStrategy::Decorrelated => {
368                // Decorrelated jitter: next_delay = random_between(base_delay, last_delay * 3)
369                let min_delay = self.config.base_delay.as_millis() as u64;
370                let max_delay = std::cmp::min(
371                    (state.last_delay.as_millis() as u64 * 3).max(min_delay),
372                    self.config.max_delay.as_millis() as u64,
373                );
374                Duration::from_millis(random_range(min_delay..=max_delay))
375            }
376        };
377
378        state.last_delay = final_delay;
379        final_delay
380    }
381}
382
383/// Convenience function for simple retry operations
384pub async fn retry<F, Fut, T>(config: RetryConfig, operation: F) -> Result<T, KodeBridgeError>
385where
386    F: FnMut() -> Fut,
387    Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
388{
389    RetryExecutor::new(config).execute(operation).await
390}
391
392/// Convenience function with default configuration
393pub async fn retry_default<F, Fut, T>(operation: F) -> Result<T, KodeBridgeError>
394where
395    F: FnMut() -> Fut,
396    Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
397{
398    retry(RetryConfig::default(), operation).await
399}
400
401/// Circuit breaker pattern for failing services
402#[derive(Debug)]
403pub struct CircuitBreaker {
404    failure_threshold: usize,
405    recovery_timeout: Duration,
406    consecutive_failures: usize,
407    last_failure_time: Option<Instant>,
408    state: CircuitState,
409}
410
/// Phase of a [`CircuitBreaker`].
#[derive(Debug, Clone, PartialEq)]
enum CircuitState {
    /// Normal operation: calls pass through.
    Closed,
    /// Too many failures: calls are rejected without running.
    Open,
    /// Recovery probe: the next call decides whether to close or reopen.
    HalfOpen,
}
417
418impl CircuitBreaker {
419    pub fn new(failure_threshold: usize, recovery_timeout: Duration) -> Self {
420        Self {
421            failure_threshold,
422            recovery_timeout,
423            consecutive_failures: 0,
424            last_failure_time: None,
425            state: CircuitState::Closed,
426        }
427    }
428
429    pub async fn execute<F, Fut, T>(&mut self, operation: F) -> Result<T, KodeBridgeError>
430    where
431        F: FnOnce() -> Fut,
432        Fut: std::future::Future<Output = Result<T, KodeBridgeError>>,
433    {
434        if self.state == CircuitState::Open {
435            if let Some(last_failure) = self.last_failure_time {
436                if last_failure.elapsed() >= self.recovery_timeout {
437                    debug!("Circuit breaker entering half-open state");
438                    self.state = CircuitState::HalfOpen;
439                } else {
440                    return Err(KodeBridgeError::custom("Circuit breaker is open"));
441                }
442            } else {
443                return Err(KodeBridgeError::custom("Circuit breaker is open"));
444            }
445        }
446
447        match operation().await {
448            Ok(result) => {
449                // Success - reset circuit breaker
450                if self.state == CircuitState::HalfOpen {
451                    debug!("Circuit breaker closing after successful operation");
452                }
453                self.consecutive_failures = 0;
454                self.last_failure_time = None;
455                self.state = CircuitState::Closed;
456                Ok(result)
457            }
458            Err(error) => {
459                // Failure - update circuit breaker state
460                self.consecutive_failures += 1;
461                self.last_failure_time = Some(Instant::now());
462
463                if self.consecutive_failures >= self.failure_threshold {
464                    debug!(
465                        "Circuit breaker opening after {} consecutive failures",
466                        self.consecutive_failures
467                    );
468                    self.state = CircuitState::Open;
469                }
470
471                Err(error)
472            }
473        }
474    }
475
476    pub fn is_open(&self) -> bool {
477        matches!(self.state, CircuitState::Open)
478    }
479
480    pub fn reset(&mut self) {
481        self.consecutive_failures = 0;
482        self.last_failure_time = None;
483        self.state = CircuitState::Closed;
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Executor with a 1 ms base delay so tests that do retry finish quickly.
    fn fast_executor(max_attempts: usize) -> RetryExecutor {
        RetryExecutor::new(
            RetryConfig::new()
                .max_attempts(max_attempts)
                .base_delay(Duration::from_millis(1)),
        )
    }

    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let executor = fast_executor(3);

        let outcome = executor
            .execute(|| async { Ok::<i32, KodeBridgeError>(42) })
            .await;

        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let executor = fast_executor(3);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    // The first two calls fail with a retriable connection error.
                    match calls.fetch_add(1, Ordering::SeqCst) {
                        0 | 1 => Err(KodeBridgeError::connection("Temporary failure")),
                        _ => Ok(42),
                    }
                }
            })
            .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_max_attempts_exceeded() {
        let executor = fast_executor(2);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(KodeBridgeError::connection("Always fails"))
                }
            })
            .await;

        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_non_retriable_error() {
        let executor = fast_executor(3);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = executor
            .execute(|| {
                let calls = Arc::clone(&calls);
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<i32, _>(KodeBridgeError::ClientError { status: 400 })
                }
            })
            .await;

        assert!(outcome.is_err());
        // Client errors are permanent: exactly one call, no retry.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_circuit_breaker() {
        let mut breaker = CircuitBreaker::new(2, Duration::from_millis(100));

        // The first failure leaves the breaker closed; the second reaches the threshold.
        for (label, expect_open) in [("Failure 1", false), ("Failure 2", true)] {
            let outcome = breaker
                .execute(|| async move { Err::<i32, _>(KodeBridgeError::connection(label)) })
                .await;
            assert!(outcome.is_err());
            assert_eq!(breaker.is_open(), expect_open);
        }

        // While open, even an operation that would succeed is rejected up front.
        let rejected = breaker
            .execute(|| async { Ok::<i32, KodeBridgeError>(42) })
            .await;
        let message = rejected
            .expect_err("an open breaker must reject the call")
            .to_string();
        assert!(message.contains("Circuit breaker is open"));
    }

    #[test]
    fn test_backoff_strategies() {
        let executor = RetryExecutor::new(
            RetryConfig::new()
                .exponential_backoff(2.0)
                .base_delay(Duration::from_millis(100))
                .jitter(JitterStrategy::None),
        );
        let mut state = RetryState::new();
        // Required by calculate_delay's signature; irrelevant with JitterStrategy::None.
        let mut rng = StdRng::from_seed([0u8; 32]);

        for (attempt, expected_ms) in [(1, 100), (2, 200), (3, 400)] {
            state.attempt = attempt;
            assert_eq!(
                executor.calculate_delay(&mut state, &mut rng),
                Duration::from_millis(expected_ms)
            );
        }
    }

    #[test]
    fn test_retry_config_builder() {
        let network = RetryConfig::for_network_operations();
        assert_eq!(network.max_attempts, 5);
        assert_eq!(network.base_delay, Duration::from_millis(50));

        let rate_limited = RetryConfig::for_rate_limited_apis();
        assert_eq!(rate_limited.max_attempts, 10);
        assert_eq!(rate_limited.base_delay, Duration::from_secs(1));
    }
}