Skip to main content

chio_guards/external/
retry.rs

//! Retry with deterministic jitter for transient external failures.
//!
//! [`retry_with_jitter`] runs an async operation up to `max_retries + 1`
//! times, sleeping between attempts with a backoff controlled by
//! [`BackoffStrategy`] and a bounded multiplicative jitter. The sleep uses
//! [`tokio::time::sleep`], which honors [`tokio::time::pause`] + `advance`
//! so tests don't depend on wall-clock time.
//!
//! By default the jitter RNG is seeded deterministically from
//! `config.max_retries`, which keeps tests reproducible; callers can
//! override the RNG via [`retry_with_jitter_rng`].
12
13use std::future::Future;
14use std::time::Duration;
15
16use rand::rngs::StdRng;
17use rand::Rng;
18use rand::SeedableRng;
19
/// Shape of the delay curve applied between consecutive retry attempts.
///
/// Each variant describes the delay *before* jitter is applied; `attempt`
/// is the 1-indexed number of the attempt that just failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// Doubling delay: `base_delay * 2^(attempt - 1)`.
    Exponential,
    /// Flat delay: always `base_delay`, whatever the attempt number.
    Constant,
    /// Evenly growing delay: `base_delay * attempt`.
    Linear,
}
30
31/// Retry configuration.
32#[derive(Debug, Clone)]
33pub struct RetryConfig {
34    /// Maximum number of retries after the initial attempt. A value of `0`
35    /// means the operation is attempted exactly once.
36    pub max_retries: u32,
37    /// Base delay for the first retry.
38    pub base_delay: Duration,
39    /// Upper bound on the sleep between attempts (before jitter is added).
40    pub max_delay: Duration,
41    /// Fraction of the computed delay to use as bounded multiplicative
42    /// jitter. Must be in `[0.0, 1.0]`; values outside that range are
43    /// clamped.
44    pub jitter_fraction: f64,
45    /// Backoff curve.
46    pub strategy: BackoffStrategy,
47}
48
49impl Default for RetryConfig {
50    fn default() -> Self {
51        Self {
52            max_retries: 3,
53            base_delay: Duration::from_millis(100),
54            max_delay: Duration::from_secs(5),
55            jitter_fraction: 0.25,
56            strategy: BackoffStrategy::Exponential,
57        }
58    }
59}
60
/// Result of a single attempt, as returned by the caller-supplied operation.
///
/// This is a plain alias for the standard [`Result`] type.
pub type AttemptResult<T, E> = std::result::Result<T, E>;
63
64/// Run `op` with retry + jitter using a deterministic RNG seeded from
65/// `config.max_retries`. For customizable randomness see
66/// [`retry_with_jitter_rng`].
67pub async fn retry_with_jitter<F, Fut, T, E>(config: &RetryConfig, op: F) -> Result<T, E>
68where
69    F: FnMut(u32) -> Fut,
70    Fut: Future<Output = AttemptResult<T, E>>,
71{
72    let seed = u64::from(config.max_retries).wrapping_add(0x9E37_79B9_7F4A_7C15);
73    let rng = StdRng::seed_from_u64(seed);
74    retry_with_jitter_rng(config, rng, op).await
75}
76
77/// Run `op` with retry + jitter using a caller-supplied RNG.
78///
79/// `op` receives the current attempt number (1-indexed).
80pub async fn retry_with_jitter_rng<F, Fut, T, E, R>(
81    config: &RetryConfig,
82    mut rng: R,
83    mut op: F,
84) -> Result<T, E>
85where
86    F: FnMut(u32) -> Fut,
87    Fut: Future<Output = AttemptResult<T, E>>,
88    R: Rng,
89{
90    let total_attempts = config.max_retries.saturating_add(1);
91    let mut last_err: Option<E> = None;
92    for attempt in 1..=total_attempts {
93        match op(attempt).await {
94            Ok(value) => return Ok(value),
95            Err(err) => {
96                last_err = Some(err);
97                if attempt >= total_attempts {
98                    break;
99                }
100                let delay = compute_delay(config, attempt, &mut rng);
101                if !delay.is_zero() {
102                    tokio::time::sleep(delay).await;
103                }
104            }
105        }
106    }
107    match last_err {
108        Some(err) => Err(err),
109        // Unreachable in practice: total_attempts >= 1 so the loop body runs
110        // at least once and either returns Ok or records an error. We still
111        // return a sensible path without panicking.
112        None => unreachable!("retry loop must have produced at least one result"),
113    }
114}
115
116fn compute_delay<R: Rng>(config: &RetryConfig, attempt: u32, rng: &mut R) -> Duration {
117    let base = config.base_delay.as_secs_f64().max(0.0);
118    let raw = match config.strategy {
119        BackoffStrategy::Constant => base,
120        BackoffStrategy::Linear => base * f64::from(attempt.max(1)),
121        BackoffStrategy::Exponential => {
122            // 2^(attempt - 1). Clamp the exponent to avoid overflow.
123            let exp = attempt.saturating_sub(1).min(30);
124            base * (1u64 << exp) as f64
125        }
126    };
127    let max_secs = config.max_delay.as_secs_f64().max(0.0);
128    let capped = raw.min(max_secs);
129    let jitter = config.jitter_fraction.clamp(0.0, 1.0);
130    let factor = if jitter == 0.0 {
131        1.0
132    } else {
133        1.0 + rng.gen_range(-jitter..=jitter)
134    };
135    let jittered = (capped * factor).max(0.0);
136    Duration::from_secs_f64(jittered)
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// A first-attempt success is returned as-is and the operation runs once.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn succeeds_on_first_attempt() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);
        let config = RetryConfig::default();
        let outcome: Result<u32, &'static str> = retry_with_jitter(&config, |_| {
            let calls = Arc::clone(&calls_for_op);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(42)
            }
        })
        .await;
        assert_eq!(outcome, Ok(42));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by a success yield the success value
    /// after exactly three invocations.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn succeeds_after_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);
        let config = RetryConfig {
            strategy: BackoffStrategy::Exponential,
            jitter_fraction: 0.0,
            base_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(40),
            max_retries: 4,
        };
        let outcome: Result<u32, &'static str> = retry_with_jitter(&config, move |_| {
            let calls = Arc::clone(&calls_for_op);
            async move {
                let invocation = calls.fetch_add(1, Ordering::SeqCst) + 1;
                if invocation >= 3 {
                    Ok(invocation)
                } else {
                    Err("transient")
                }
            }
        })
        .await;
        assert_eq!(outcome, Ok(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// When every attempt fails, the error is returned after
    /// `max_retries + 1` invocations.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn returns_last_error_after_exhausting_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);
        let config = RetryConfig {
            strategy: BackoffStrategy::Constant,
            jitter_fraction: 0.0,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(4),
            max_retries: 2,
        };
        let outcome: Result<u32, &'static str> = retry_with_jitter(&config, move |_| {
            let calls = Arc::clone(&calls_for_op);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err("always fails")
            }
        })
        .await;
        assert_eq!(outcome, Err("always fails"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// `max_retries == 0` means a single attempt and no retry.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn zero_max_retries_runs_once() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_for_op = Arc::clone(&calls);
        let config = RetryConfig {
            strategy: BackoffStrategy::Exponential,
            jitter_fraction: 0.0,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(1),
            max_retries: 0,
        };
        let outcome: Result<u32, &'static str> = retry_with_jitter(&config, move |_| {
            let calls = Arc::clone(&calls_for_op);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err("boom")
            }
        })
        .await;
        assert_eq!(outcome, Err("boom"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// With jitter disabled, an exponential delay larger than `max_delay` is
    /// clamped to exactly `max_delay`.
    #[test]
    fn compute_delay_caps_at_max_delay() {
        let config = RetryConfig {
            strategy: BackoffStrategy::Exponential,
            jitter_fraction: 0.0,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
            max_retries: 10,
        };
        let mut rng = StdRng::seed_from_u64(1);
        // Attempt 10 gives 2^9 * 100ms = 51.2s uncapped, far above the 500ms cap.
        let delay = compute_delay(&config, 10, &mut rng);
        assert_eq!(delay, Duration::from_millis(500));
    }
}