omega_runtime/
retry.rs

1//! Retry Policy Implementation
2//!
3//! Provides configurable retry logic with exponential backoff and jitter
4//! for handling transient failures in distributed systems.
5
6use rand::Rng;
7use std::time::Duration;
8use thiserror::Error;
9use tracing::{debug, warn};
10
/// Tuning knobs for a retry policy: how many times to retry and how long to wait in between.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// How many retries are allowed after the initial attempt.
    pub max_retries: u32,
    /// Backoff used after the first failed attempt.
    pub initial_delay: Duration,
    /// Ceiling applied to the exponentially growing backoff (before any jitter is added).
    pub max_delay: Duration,
    /// Factor by which the backoff grows after each further failed attempt.
    pub multiplier: f64,
    /// Randomize each delay; recommended so that many clients do not retry in lockstep.
    pub use_jitter: bool,
    /// Relative size of the random variation, expected in `0.0..=1.0`
    /// (e.g. `0.5` lets a delay vary by up to +/- 50%).
    pub jitter_factor: f64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 100 ms and doubling up to a 30 s ceiling, with +/- 30% jitter.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(30);
        RetryConfig {
            max_retries: 3,
            initial_delay,
            max_delay,
            multiplier: 2.0,
            use_jitter: true,
            jitter_factor: 0.3,
        }
    }
}
40
41/// Retry policy error types
42#[derive(Debug, Error)]
43pub enum RetryError<E> {
44    #[error("Operation failed after {attempts} attempts: {last_error}")]
45    MaxRetriesExceeded { attempts: u32, last_error: E },
46    #[error("Retry aborted: {0}")]
47    Aborted(String),
48}
49
50/// Retry policy for executing operations with automatic retries
51pub struct RetryPolicy {
52    config: RetryConfig,
53}
54
55impl RetryPolicy {
56    /// Create a new retry policy with the given configuration
57    pub fn new(config: RetryConfig) -> Self {
58        Self { config }
59    }
60
61    /// Create a retry policy with default configuration
62    pub fn with_default_config() -> Self {
63        Self {
64            config: RetryConfig::default(),
65        }
66    }
67
68    /// Execute an async operation with retry logic
69    pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
70    where
71        F: FnMut() -> Fut,
72        Fut: std::future::Future<Output = Result<T, E>>,
73        E: std::fmt::Display + Clone,
74    {
75        let mut attempts = 0;
76        #[allow(unused_assignments)]
77        let mut _last_error: Option<E> = None;
78
79        loop {
80            attempts += 1;
81
82            debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
83
84            match operation().await {
85                Ok(result) => {
86                    if attempts > 1 {
87                        debug!("Operation succeeded after {} attempts", attempts);
88                    }
89                    return Ok(result);
90                }
91                Err(e) => {
92                    warn!("Attempt {} failed: {}", attempts, e);
93
94                    if attempts > self.config.max_retries {
95                        return Err(RetryError::MaxRetriesExceeded {
96                            attempts,
97                            last_error: e,
98                        });
99                    }
100
101                    _last_error = Some(e);
102
103                    // Calculate delay with exponential backoff
104                    let delay = self.calculate_delay(attempts);
105                    debug!("Waiting {:?} before retry", delay);
106
107                    tokio::time::sleep(delay).await;
108                }
109            }
110        }
111    }
112
113    /// Execute a synchronous operation with retry logic
114    pub fn execute_sync<F, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
115    where
116        F: FnMut() -> Result<T, E>,
117        E: std::fmt::Display + Clone,
118    {
119        let mut attempts = 0;
120        #[allow(unused_assignments)]
121        let mut _last_error: Option<E> = None;
122
123        loop {
124            attempts += 1;
125
126            debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
127
128            match operation() {
129                Ok(result) => {
130                    if attempts > 1 {
131                        debug!("Operation succeeded after {} attempts", attempts);
132                    }
133                    return Ok(result);
134                }
135                Err(e) => {
136                    warn!("Attempt {} failed: {}", attempts, e);
137
138                    if attempts > self.config.max_retries {
139                        return Err(RetryError::MaxRetriesExceeded {
140                            attempts,
141                            last_error: e,
142                        });
143                    }
144
145                    _last_error = Some(e);
146
147                    // Calculate delay with exponential backoff
148                    let delay = self.calculate_delay(attempts);
149                    debug!("Waiting {:?} before retry", delay);
150
151                    std::thread::sleep(delay);
152                }
153            }
154        }
155    }
156
157    /// Execute with custom retry condition
158    pub async fn execute_with_condition<F, Fut, T, E, C>(
159        &self,
160        mut operation: F,
161        mut should_retry: C,
162    ) -> Result<T, RetryError<E>>
163    where
164        F: FnMut() -> Fut,
165        Fut: std::future::Future<Output = Result<T, E>>,
166        E: std::fmt::Display + Clone,
167        C: FnMut(&E) -> bool,
168    {
169        let mut attempts = 0;
170
171        loop {
172            attempts += 1;
173
174            debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
175
176            match operation().await {
177                Ok(result) => {
178                    if attempts > 1 {
179                        debug!("Operation succeeded after {} attempts", attempts);
180                    }
181                    return Ok(result);
182                }
183                Err(e) => {
184                    if !should_retry(&e) {
185                        debug!("Error is not retryable: {}", e);
186                        return Err(RetryError::Aborted(format!(
187                            "Non-retryable error after {} attempts: {}",
188                            attempts, e
189                        )));
190                    }
191
192                    warn!("Attempt {} failed: {}", attempts, e);
193
194                    if attempts > self.config.max_retries {
195                        return Err(RetryError::MaxRetriesExceeded {
196                            attempts,
197                            last_error: e,
198                        });
199                    }
200
201                    // Calculate delay with exponential backoff
202                    let delay = self.calculate_delay(attempts);
203                    debug!("Waiting {:?} before retry", delay);
204
205                    tokio::time::sleep(delay).await;
206                }
207            }
208        }
209    }
210
211    /// Calculate delay for the given attempt number
212    fn calculate_delay(&self, attempt: u32) -> Duration {
213        // Calculate base delay using exponential backoff
214        let base_delay_ms = self.config.initial_delay.as_millis() as f64
215            * self.config.multiplier.powi((attempt - 1) as i32);
216
217        let base_delay = Duration::from_millis(base_delay_ms as u64);
218
219        // Cap at max delay
220        let capped_delay = if base_delay > self.config.max_delay {
221            self.config.max_delay
222        } else {
223            base_delay
224        };
225
226        // Add jitter if enabled
227        if self.config.use_jitter {
228            self.add_jitter(capped_delay)
229        } else {
230            capped_delay
231        }
232    }
233
234    /// Add jitter to a delay to prevent thundering herd
235    fn add_jitter(&self, delay: Duration) -> Duration {
236        let mut rng = rand::thread_rng();
237        let delay_ms = delay.as_millis() as f64;
238
239        // Calculate jitter range
240        let jitter_range = delay_ms * self.config.jitter_factor;
241
242        // Add random jitter in range [-jitter_range, +jitter_range]
243        let jitter = rng.gen_range(-jitter_range..=jitter_range);
244        let jittered_ms = (delay_ms + jitter).max(0.0);
245
246        Duration::from_millis(jittered_ms as u64)
247    }
248
249    /// Get the configuration
250    pub fn config(&self) -> &RetryConfig {
251        &self.config
252    }
253}
254
255/// Builder for retry configuration
256pub struct RetryConfigBuilder {
257    config: RetryConfig,
258}
259
260impl RetryConfigBuilder {
261    /// Create a new builder with default configuration
262    pub fn new() -> Self {
263        Self {
264            config: RetryConfig::default(),
265        }
266    }
267
268    /// Set maximum number of retries
269    pub fn max_retries(mut self, max_retries: u32) -> Self {
270        self.config.max_retries = max_retries;
271        self
272    }
273
274    /// Set initial delay
275    pub fn initial_delay(mut self, delay: Duration) -> Self {
276        self.config.initial_delay = delay;
277        self
278    }
279
280    /// Set maximum delay
281    pub fn max_delay(mut self, delay: Duration) -> Self {
282        self.config.max_delay = delay;
283        self
284    }
285
286    /// Set backoff multiplier
287    pub fn multiplier(mut self, multiplier: f64) -> Self {
288        self.config.multiplier = multiplier;
289        self
290    }
291
292    /// Enable or disable jitter
293    pub fn use_jitter(mut self, use_jitter: bool) -> Self {
294        self.config.use_jitter = use_jitter;
295        self
296    }
297
298    /// Set jitter factor
299    pub fn jitter_factor(mut self, factor: f64) -> Self {
300        self.config.jitter_factor = factor.clamp(0.0, 1.0);
301        self
302    }
303
304    /// Build the configuration
305    pub fn build(self) -> RetryConfig {
306        self.config
307    }
308}
309
310impl Default for RetryConfigBuilder {
311    fn default() -> Self {
312        Self::new()
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Config with a 10 ms initial delay so tests that really sleep stay fast.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_immediate_success() {
        let policy = RetryPolicy::with_default_config();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok::<_, String>("success")
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_on_failure() {
        let policy = RetryPolicy::new(fast_config(3));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                let count = counter_clone.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err("temporary failure")
                } else {
                    Ok("success")
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let policy = RetryPolicy::new(fast_config(2));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("persistent failure")
            })
            .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 3); // Initial + 2 retries

        match result {
            Err(RetryError::MaxRetriesExceeded {
                attempts,
                last_error,
            }) => {
                assert_eq!(attempts, 3);
                assert_eq!(last_error, "persistent failure");
            }
            _ => panic!("Expected MaxRetriesExceeded error"),
        }
    }

    #[test]
    fn test_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(50),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        assert_eq!(policy.calculate_delay(1), Duration::from_millis(50));
        assert_eq!(policy.calculate_delay(2), Duration::from_millis(100));
        assert_eq!(policy.calculate_delay(3), Duration::from_millis(200));
    }

    #[test]
    fn test_max_delay_cap() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        let delay5 = policy.calculate_delay(5);
        assert_eq!(delay5, Duration::from_millis(500)); // Capped at max_delay
    }

    #[test]
    fn test_jitter_adds_variation() {
        let config = RetryConfig {
            max_retries: 1,
            initial_delay: Duration::from_millis(100),
            use_jitter: true,
            jitter_factor: 0.5,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        // Generate multiple delays and check they vary
        let delays: Vec<Duration> = (0..10).map(|_| policy.calculate_delay(1)).collect();

        // Check that delays are in reasonable range (50-150ms with 50% jitter)
        for delay in &delays {
            let ms = delay.as_millis();
            assert!((50..=150).contains(&ms), "Delay {} outside expected range", ms);
        }

        // Check that we got some variation
        let all_same = delays.iter().all(|d| d == &delays[0]);
        assert!(!all_same, "All delays are the same, jitter not working");
    }

    // Plain #[test]: `execute_sync` blocks the thread, so it should not run inside a runtime.
    #[test]
    fn test_synchronous_retry() {
        let policy = RetryPolicy::new(fast_config(3));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy.execute_sync(|| {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if count < 2 {
                Err("temporary failure")
            } else {
                Ok("success")
            }
        });

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_sync_max_retries_exceeded() {
        let policy = RetryPolicy::new(fast_config(2));
        let mut calls = 0;

        let result = policy.execute_sync(|| {
            calls += 1;
            Err::<(), _>("persistent failure")
        });

        assert_eq!(calls, 3); // Initial + 2 retries
        match result {
            Err(RetryError::MaxRetriesExceeded {
                attempts,
                last_error,
            }) => {
                assert_eq!(attempts, 3);
                assert_eq!(last_error, "persistent failure");
            }
            _ => panic!("Expected MaxRetriesExceeded error"),
        }
    }

    #[tokio::test]
    async fn test_retry_with_custom_condition() {
        let policy = RetryPolicy::new(fast_config(3));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        // Only retry on "temporary" errors, not "permanent" ones
        let result = policy
            .execute_with_condition(
                || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        Err::<String, _>("permanent error")
                    }
                },
                |e| e.contains("temporary"),
            )
            .await;

        // Should abort on first attempt since error is not retryable
        assert!(matches!(result, Err(RetryError::Aborted(_))));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_condition_retryable_error_exhausts_retries() {
        let policy = RetryPolicy::new(fast_config(2));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute_with_condition(
                || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        Err::<String, _>("temporary error")
                    }
                },
                |e| e.contains("temporary"),
            )
            .await;

        assert!(matches!(
            result,
            Err(RetryError::MaxRetriesExceeded { attempts: 3, .. })
        ));
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_retry_builder() {
        let config = RetryConfigBuilder::new()
            .max_retries(5)
            .initial_delay(Duration::from_millis(50))
            .max_delay(Duration::from_secs(10))
            .multiplier(3.0)
            .use_jitter(false)
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(50));
        assert_eq!(config.max_delay, Duration::from_secs(10));
        assert_eq!(config.multiplier, 3.0);
        assert!(!config.use_jitter);
    }

    #[tokio::test]
    async fn test_zero_retries() {
        let policy = RetryPolicy::new(fast_config(0));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1); // Only initial attempt
    }

    #[tokio::test]
    async fn test_concurrent_retries() {
        // Short delays: the failing tasks sleep through every retry.
        let policy = Arc::new(RetryPolicy::new(fast_config(3)));
        let mut handles = vec![];

        for i in 0..5 {
            let policy_clone = Arc::clone(&policy);
            handles.push(tokio::spawn(async move {
                policy_clone
                    .execute(|| async move {
                        if i % 2 == 0 {
                            Ok(format!("success {}", i))
                        } else {
                            Err(format!("error {}", i))
                        }
                    })
                    .await
            }));
        }

        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            if i % 2 == 0 {
                assert_eq!(result.unwrap(), format!("success {}", i));
            } else {
                assert!(matches!(
                    result,
                    Err(RetryError::MaxRetriesExceeded { attempts: 4, .. })
                ));
            }
        }
    }

    #[test]
    fn test_jitter_factor_clamping() {
        let config = RetryConfigBuilder::new()
            .jitter_factor(1.5) // Should be clamped to 1.0
            .build();

        assert_eq!(config.jitter_factor, 1.0);

        let config = RetryConfigBuilder::new()
            .jitter_factor(-0.5) // Should be clamped to 0.0
            .build();

        assert_eq!(config.jitter_factor, 0.0);
    }

    #[tokio::test]
    async fn test_timing_accuracy() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(100),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let start = std::time::Instant::now();

        let _ = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;

        let elapsed = start.elapsed();

        // Should take at least 100ms (first retry) + 200ms (second retry) = 300ms
        assert!(
            elapsed >= Duration::from_millis(300),
            "Elapsed time {:?} less than expected",
            elapsed
        );
    }
}