// mx_core/resilience/retry.rs
1//! Retry policy with exponential backoff and jitter support.
2
3use std::future::Future;
4use std::time::Duration;
5
6use rand::Rng;
7use tokio::time::sleep;
8
/// Cap applied to exponential delays by the default policy (30 s).
const DEFAULT_MAX_DELAY: Duration = Duration::from_secs(30);
/// Default jitter: delays are randomized within ±25% of nominal.
const DEFAULT_JITTER_FACTOR: f64 = 0.25;
11
/// Configuration for retry behavior with exponential backoff support.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of attempts (the initial call plus retries).
    pub max_attempts: u32,
    /// Delay before the first retry; also the fixed delay when backoff is disabled.
    pub base_delay: Duration,
    /// Upper bound for the computed exponential delay.
    pub max_delay: Duration,
    /// Fraction in `0.0..=1.0` by which each delay is randomized (±).
    pub jitter_factor: f64,
    // false for policies built with `new` (fixed delay), true for
    // `with_exponential_backoff`.
    use_exponential_backoff: bool,
    // Optional filter over the error's Display text; `None` retries every error.
    retryable_predicate: Option<fn(&str) -> bool>,
}
22
23impl RetryPolicy {
24    /// Fixed delay retry (legacy).
25    pub fn new(max_attempts: u32, delay_ms: u64) -> Self {
26        Self {
27            max_attempts,
28            base_delay: Duration::from_millis(delay_ms),
29            max_delay: Duration::from_millis(delay_ms),
30            jitter_factor: 0.0,
31            use_exponential_backoff: false,
32            retryable_predicate: None,
33        }
34    }
35
36    /// Exponential backoff retry with jitter.
37    pub fn with_exponential_backoff(
38        max_attempts: u32,
39        base_delay: Duration,
40        max_delay: Duration,
41        jitter_factor: f64,
42    ) -> Self {
43        Self {
44            max_attempts,
45            base_delay,
46            max_delay,
47            jitter_factor: jitter_factor.clamp(0.0, 1.0),
48            use_exponential_backoff: true,
49            retryable_predicate: None,
50        }
51    }
52
53    /// Sets a predicate function to determine if an error should trigger a retry.
54    pub fn with_retryable_predicate(mut self, predicate: fn(&str) -> bool) -> Self {
55        self.retryable_predicate = Some(predicate);
56        self
57    }
58
59    /// Computes the delay for a given attempt number.
60    ///
61    /// This method contains self-contained exponential backoff logic
62    /// (`base_delay * 2^attempt`, capped at `max_delay`, with optional jitter).
63    /// There is a separate [`super::ExponentialBackoff`] type elsewhere in the
64    /// resilience module, but this implementation is intentionally kept inline:
65    /// `RetryPolicy` needs tight control over attempt numbering and jitter that
66    /// would not benefit from an extra layer of indirection.
67    pub fn next_delay(&self, attempt: u32) -> Duration {
68        if !self.use_exponential_backoff {
69            return self.base_delay;
70        }
71
72        let base_ms = self.base_delay.as_millis() as u64;
73        let max_ms = self.max_delay.as_millis() as u64;
74
75        let delay_ms = if attempt >= 64 {
76            max_ms
77        } else {
78            let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
79            base_ms.saturating_mul(multiplier).min(max_ms)
80        };
81
82        if self.jitter_factor > 0.0 {
83            self.apply_jitter(Duration::from_millis(delay_ms))
84        } else {
85            Duration::from_millis(delay_ms)
86        }
87    }
88
89    fn apply_jitter(&self, delay: Duration) -> Duration {
90        let mut rng = rand::rng();
91        let jitter_range = self.jitter_factor * 2.0;
92        let jitter_offset = rng.random::<f64>() * jitter_range - self.jitter_factor;
93        let factor = 1.0 + jitter_offset;
94        let delay_ms = delay.as_millis() as f64;
95        let jittered_ms = (delay_ms * factor).max(1.0) as u64;
96        Duration::from_millis(jittered_ms)
97    }
98
99    fn is_retryable(&self, error_msg: &str) -> bool {
100        match self.retryable_predicate {
101            Some(pred) => pred(error_msg),
102            None => true,
103        }
104    }
105
106    /// Executes an async operation with retry logic.
107    pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, E>
108    where
109        F: FnMut() -> Fut,
110        Fut: Future<Output = Result<T, E>>,
111        E: std::fmt::Display,
112    {
113        let mut last_error: Option<E> = None;
114
115        for attempt in 0..self.max_attempts {
116            match operation().await {
117                Ok(result) => return Ok(result),
118                Err(e) => {
119                    let error_msg = e.to_string();
120                    if !self.is_retryable(&error_msg) {
121                        return Err(e);
122                    }
123                    last_error = Some(e);
124                    if attempt < self.max_attempts - 1 {
125                        let delay = self.next_delay(attempt);
126                        sleep(delay).await;
127                    }
128                }
129            }
130        }
131
132        Err(last_error.expect("at least one attempt must have been made"))
133    }
134}
135
136impl Default for RetryPolicy {
137    fn default() -> Self {
138        Self::with_exponential_backoff(
139            3,
140            Duration::from_millis(200),
141            DEFAULT_MAX_DELAY,
142            DEFAULT_JITTER_FACTOR,
143        )
144    }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fixed_delay() {
        // A legacy fixed-delay policy returns the same delay for every attempt.
        let policy = RetryPolicy::new(3, 200);
        for attempt in 0..3 {
            assert_eq!(policy.next_delay(attempt), Duration::from_millis(200));
        }
    }

    #[test]
    fn test_exponential_no_jitter() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_millis(100),
            Duration::from_secs(30),
            0.0,
        );
        // base * 2^attempt with no jitter is exact.
        for (attempt, expected_ms) in [(0u32, 100u64), (1, 200), (2, 400)] {
            assert_eq!(policy.next_delay(attempt), Duration::from_millis(expected_ms));
        }
    }

    #[test]
    fn test_capped_at_max_delay() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_millis(100),
            Duration::from_secs(1),
            0.0,
        );
        // 100ms * 2^20 would be huge; the cap takes over.
        assert_eq!(policy.next_delay(20), Duration::from_secs(1));
    }

    #[test]
    fn test_overflow_protection() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_secs(1),
            Duration::from_secs(3600),
            0.0,
        );
        // An attempt number past the shift width must not overflow or exceed the cap.
        assert!(policy.next_delay(100) <= Duration::from_secs(3600));
    }

    #[tokio::test]
    async fn test_execute_success() {
        let policy = RetryPolicy::new(3, 10);
        let outcome: Result<i32, String> = policy.execute(|| async { Ok(42) }).await;
        assert_eq!(outcome.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_execute_retries_then_succeeds() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU32, Ordering};

        let call_count = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&call_count);

        let policy = RetryPolicy::new(3, 1);
        let outcome: Result<i32, String> = policy
            .execute(|| {
                let counter = Arc::clone(&counter);
                async move {
                    // Fail the first two calls, succeed on the third.
                    match counter.fetch_add(1, Ordering::SeqCst) {
                        0 | 1 => Err("transient".to_string()),
                        _ => Ok(42),
                    }
                }
            })
            .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }
}