omega_runtime/
retry.rs

//! Retry Policy Implementation
//!
//! Provides configurable retry logic with exponential backoff and jitter
//! for handling transient failures in distributed systems.
5
6use rand::Rng;
7use std::time::Duration;
8use thiserror::Error;
9use tracing::{debug, warn};
10
/// Tunable parameters that control how many times an operation is retried
/// and how long to wait between attempts.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Upper bound on the number of retries performed after the first attempt
    pub max_retries: u32,
    /// Base delay used before the first retry
    pub initial_delay: Duration,
    /// Ceiling applied to the exponential backoff delay (before jitter is added)
    pub max_delay: Duration,
    /// Growth factor applied to the delay from one retry to the next
    pub multiplier: f64,
    /// Randomise each delay (recommended to prevent thundering herd)
    pub use_jitter: bool,
    /// Relative jitter amplitude (0.0 to 1.0, where 0.5 means +/- 50% variation)
    pub jitter_factor: f64,
}
27
28impl Default for RetryConfig {
29    fn default() -> Self {
30        Self {
31            max_retries: 3,
32            initial_delay: Duration::from_millis(100),
33            max_delay: Duration::from_secs(30),
34            multiplier: 2.0,
35            use_jitter: true,
36            jitter_factor: 0.3,
37        }
38    }
39}
40
41/// Retry policy error types
42#[derive(Debug, Error)]
43pub enum RetryError<E> {
44    #[error("Operation failed after {attempts} attempts: {last_error}")]
45    MaxRetriesExceeded { attempts: u32, last_error: E },
46    #[error("Retry aborted: {0}")]
47    Aborted(String),
48}
49
50/// Retry policy for executing operations with automatic retries
51pub struct RetryPolicy {
52    config: RetryConfig,
53}
54
55impl RetryPolicy {
56    /// Create a new retry policy with the given configuration
57    pub fn new(config: RetryConfig) -> Self {
58        Self { config }
59    }
60
61    /// Create a retry policy with default configuration
62    pub fn default() -> Self {
63        Self {
64            config: RetryConfig::default(),
65        }
66    }
67
68    /// Execute an async operation with retry logic
69    pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
70    where
71        F: FnMut() -> Fut,
72        Fut: std::future::Future<Output = Result<T, E>>,
73        E: std::fmt::Display + Clone,
74    {
75        let mut attempts = 0;
76        let mut last_error: Option<E> = None;
77
78        loop {
79            attempts += 1;
80
81            debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
82
83            match operation().await {
84                Ok(result) => {
85                    if attempts > 1 {
86                        debug!("Operation succeeded after {} attempts", attempts);
87                    }
88                    return Ok(result);
89                }
90                Err(e) => {
91                    warn!("Attempt {} failed: {}", attempts, e);
92
93                    if attempts > self.config.max_retries {
94                        return Err(RetryError::MaxRetriesExceeded {
95                            attempts,
96                            last_error: e,
97                        });
98                    }
99
100                    last_error = Some(e);
101
102                    // Calculate delay with exponential backoff
103                    let delay = self.calculate_delay(attempts);
104                    debug!("Waiting {:?} before retry", delay);
105
106                    tokio::time::sleep(delay).await;
107                }
108            }
109        }
110    }
111
112    /// Execute a synchronous operation with retry logic
113    pub fn execute_sync<F, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
114    where
115        F: FnMut() -> Result<T, E>,
116        E: std::fmt::Display + Clone,
117    {
118        let mut attempts = 0;
119        let mut last_error: Option<E> = None;
120
121        loop {
122            attempts += 1;
123
124            debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
125
126            match operation() {
127                Ok(result) => {
128                    if attempts > 1 {
129                        debug!("Operation succeeded after {} attempts", attempts);
130                    }
131                    return Ok(result);
132                }
133                Err(e) => {
134                    warn!("Attempt {} failed: {}", attempts, e);
135
136                    if attempts > self.config.max_retries {
137                        return Err(RetryError::MaxRetriesExceeded {
138                            attempts,
139                            last_error: e,
140                        });
141                    }
142
143                    last_error = Some(e);
144
145                    // Calculate delay with exponential backoff
146                    let delay = self.calculate_delay(attempts);
147                    debug!("Waiting {:?} before retry", delay);
148
149                    std::thread::sleep(delay);
150                }
151            }
152        }
153    }
154
155    /// Execute with custom retry condition
156    pub async fn execute_with_condition<F, Fut, T, E, C>(
157        &self,
158        mut operation: F,
159        mut should_retry: C,
160    ) -> Result<T, RetryError<E>>
161    where
162        F: FnMut() -> Fut,
163        Fut: std::future::Future<Output = Result<T, E>>,
164        E: std::fmt::Display + Clone,
165        C: FnMut(&E) -> bool,
166    {
167        let mut attempts = 0;
168
169        loop {
170            attempts += 1;
171
172            debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
173
174            match operation().await {
175                Ok(result) => {
176                    if attempts > 1 {
177                        debug!("Operation succeeded after {} attempts", attempts);
178                    }
179                    return Ok(result);
180                }
181                Err(e) => {
182                    if !should_retry(&e) {
183                        debug!("Error is not retryable: {}", e);
184                        return Err(RetryError::Aborted(format!(
185                            "Non-retryable error after {} attempts: {}",
186                            attempts, e
187                        )));
188                    }
189
190                    warn!("Attempt {} failed: {}", attempts, e);
191
192                    if attempts > self.config.max_retries {
193                        return Err(RetryError::MaxRetriesExceeded {
194                            attempts,
195                            last_error: e,
196                        });
197                    }
198
199                    // Calculate delay with exponential backoff
200                    let delay = self.calculate_delay(attempts);
201                    debug!("Waiting {:?} before retry", delay);
202
203                    tokio::time::sleep(delay).await;
204                }
205            }
206        }
207    }
208
209    /// Calculate delay for the given attempt number
210    fn calculate_delay(&self, attempt: u32) -> Duration {
211        // Calculate base delay using exponential backoff
212        let base_delay_ms = self.config.initial_delay.as_millis() as f64
213            * self.config.multiplier.powi((attempt - 1) as i32);
214
215        let base_delay = Duration::from_millis(base_delay_ms as u64);
216
217        // Cap at max delay
218        let capped_delay = if base_delay > self.config.max_delay {
219            self.config.max_delay
220        } else {
221            base_delay
222        };
223
224        // Add jitter if enabled
225        if self.config.use_jitter {
226            self.add_jitter(capped_delay)
227        } else {
228            capped_delay
229        }
230    }
231
232    /// Add jitter to a delay to prevent thundering herd
233    fn add_jitter(&self, delay: Duration) -> Duration {
234        let mut rng = rand::thread_rng();
235        let delay_ms = delay.as_millis() as f64;
236
237        // Calculate jitter range
238        let jitter_range = delay_ms * self.config.jitter_factor;
239
240        // Add random jitter in range [-jitter_range, +jitter_range]
241        let jitter = rng.gen_range(-jitter_range..=jitter_range);
242        let jittered_ms = (delay_ms + jitter).max(0.0);
243
244        Duration::from_millis(jittered_ms as u64)
245    }
246
247    /// Get the configuration
248    pub fn config(&self) -> &RetryConfig {
249        &self.config
250    }
251}
252
253/// Builder for retry configuration
254pub struct RetryConfigBuilder {
255    config: RetryConfig,
256}
257
258impl RetryConfigBuilder {
259    /// Create a new builder with default configuration
260    pub fn new() -> Self {
261        Self {
262            config: RetryConfig::default(),
263        }
264    }
265
266    /// Set maximum number of retries
267    pub fn max_retries(mut self, max_retries: u32) -> Self {
268        self.config.max_retries = max_retries;
269        self
270    }
271
272    /// Set initial delay
273    pub fn initial_delay(mut self, delay: Duration) -> Self {
274        self.config.initial_delay = delay;
275        self
276    }
277
278    /// Set maximum delay
279    pub fn max_delay(mut self, delay: Duration) -> Self {
280        self.config.max_delay = delay;
281        self
282    }
283
284    /// Set backoff multiplier
285    pub fn multiplier(mut self, multiplier: f64) -> Self {
286        self.config.multiplier = multiplier;
287        self
288    }
289
290    /// Enable or disable jitter
291    pub fn use_jitter(mut self, use_jitter: bool) -> Self {
292        self.config.use_jitter = use_jitter;
293        self
294    }
295
296    /// Set jitter factor
297    pub fn jitter_factor(mut self, factor: f64) -> Self {
298        self.config.jitter_factor = factor.clamp(0.0, 1.0);
299        self
300    }
301
302    /// Build the configuration
303    pub fn build(self) -> RetryConfig {
304        self.config
305    }
306}
307
308impl Default for RetryConfigBuilder {
309    fn default() -> Self {
310        Self::new()
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// A first-attempt success invokes the operation exactly once.
    #[tokio::test]
    async fn test_immediate_success() {
        let policy = RetryPolicy::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok::<_, String>("success")
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by a success take three attempts.
    #[tokio::test]
    async fn test_retry_on_failure() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                let count = counter_clone.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err("temporary failure")
                } else {
                    Ok("success")
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A persistent failure is reported after the initial attempt plus all retries.
    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("persistent failure")
            })
            .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 3); // Initial + 2 retries

        match result {
            Err(RetryError::MaxRetriesExceeded { attempts, .. }) => {
                assert_eq!(attempts, 3);
            }
            _ => panic!("Expected MaxRetriesExceeded error"),
        }
    }

    /// Without jitter the delay doubles on every attempt.
    // Pure computation: no async runtime is needed.
    #[test]
    fn test_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(50),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        let delay1 = policy.calculate_delay(1);
        let delay2 = policy.calculate_delay(2);
        let delay3 = policy.calculate_delay(3);

        assert_eq!(delay1, Duration::from_millis(50));
        assert_eq!(delay2, Duration::from_millis(100));
        assert_eq!(delay3, Duration::from_millis(200));
    }

    /// The backoff delay never exceeds `max_delay` when jitter is off.
    #[test]
    fn test_max_delay_cap() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        let delay5 = policy.calculate_delay(5);
        assert_eq!(delay5, Duration::from_millis(500)); // Capped at max_delay
    }

    /// Jittered delays stay inside the configured band and are not all equal.
    #[test]
    fn test_jitter_adds_variation() {
        let config = RetryConfig {
            max_retries: 1,
            initial_delay: Duration::from_millis(100),
            use_jitter: true,
            jitter_factor: 0.5,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        // Generate multiple delays and check they vary
        let delays: Vec<Duration> = (0..10).map(|_| policy.calculate_delay(1)).collect();

        // Check that delays are in reasonable range (50-150ms with 50% jitter)
        for delay in &delays {
            let ms = delay.as_millis();
            assert!(
                (50..=150).contains(&ms),
                "Delay {} outside expected range",
                ms
            );
        }

        // Check that we got some variation
        let all_same = delays.iter().all(|d| d == &delays[0]);
        assert!(!all_same, "All delays are the same, jitter not working");
    }

    /// The blocking executor retries the same way as the async one.
    // `execute_sync` sleeps with `std::thread::sleep`, so run it as a plain
    // test instead of blocking an async runtime thread.
    #[test]
    fn test_synchronous_retry() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy.execute_sync(|| {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if count < 2 {
                Err("temporary failure")
            } else {
                Ok("success")
            }
        });

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A non-retryable error aborts after a single attempt.
    #[tokio::test]
    async fn test_retry_with_custom_condition() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        // Only retry on "temporary" errors, not "permanent" ones
        let result = policy
            .execute_with_condition(
                || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        Err::<String, _>("permanent error")
                    }
                },
                |e| e.contains("temporary"),
            )
            .await;

        // Should abort on first attempt since error is not retryable
        assert!(matches!(result, Err(RetryError::Aborted(_))));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Every builder setter is reflected in the built configuration.
    #[test]
    fn test_retry_builder() {
        let config = RetryConfigBuilder::new()
            .max_retries(5)
            .initial_delay(Duration::from_millis(50))
            .max_delay(Duration::from_secs(10))
            .multiplier(3.0)
            .use_jitter(false)
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(50));
        assert_eq!(config.max_delay, Duration::from_secs(10));
        assert_eq!(config.multiplier, 3.0);
        assert!(!config.use_jitter);
    }

    /// With zero retries the operation runs once and the failure is reported.
    #[tokio::test]
    async fn test_zero_retries() {
        let config = RetryConfig {
            max_retries: 0,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;

        assert_eq!(counter.load(Ordering::SeqCst), 1); // Only initial attempt

        match result {
            Err(RetryError::MaxRetriesExceeded { attempts, .. }) => {
                assert_eq!(attempts, 1);
            }
            _ => panic!("Expected MaxRetriesExceeded error"),
        }
    }

    /// One policy can be shared across concurrently running tasks.
    #[tokio::test]
    async fn test_concurrent_retries() {
        let policy = Arc::new(RetryPolicy::default());
        let mut handles = vec![];

        for i in 0..5 {
            let policy_clone = Arc::clone(&policy);
            let handle = tokio::spawn(async move {
                policy_clone
                    .execute(|| async move {
                        if i % 2 == 0 {
                            Ok(format!("success {}", i))
                        } else {
                            Err(format!("error {}", i))
                        }
                    })
                    .await
            });
            handles.push(handle);
        }

        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            if i % 2 == 0 {
                assert!(result.is_ok());
            } else {
                assert!(result.is_err());
            }
        }
    }

    /// The builder clamps out-of-range jitter factors into `[0.0, 1.0]`.
    #[test]
    fn test_jitter_factor_clamping() {
        let config = RetryConfigBuilder::new()
            .jitter_factor(1.5) // Should be clamped to 1.0
            .build();

        assert_eq!(config.jitter_factor, 1.0);

        let config = RetryConfigBuilder::new()
            .jitter_factor(-0.5) // Should be clamped to 0.0
            .build();

        assert_eq!(config.jitter_factor, 0.0);
    }

    /// The executor really waits for the computed backoff delays.
    #[tokio::test]
    async fn test_timing_accuracy() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(100),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let start = std::time::Instant::now();

        let _ = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;

        let elapsed = start.elapsed();

        // Should take at least 100ms (first retry) + 200ms (second retry) = 300ms
        assert!(
            elapsed >= Duration::from_millis(300),
            "Elapsed time {:?} less than expected",
            elapsed
        );
    }
}