Skip to main content

reliability_toolkit/
retry.rs

1//! Exponential backoff with full jitter.
2//!
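//! `Retry::run` re-invokes the supplied closure until it succeeds, the attempt
//! budget is exhausted, or the retry predicate rejects the error. The delay
//! before the `n`-th retry (counting from 0) is capped at
//! `min(max_delay, base * 2^n)`, and [full jitter] then picks the actual sleep
//! pseudo-randomly from `[0, cap]`.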
6//!
7//! Pass a custom retry predicate via [`RetryConfig::retry_if`] to skip
8//! re-execution for errors that aren't transient (auth failures, 4xx, …).
9//!
10//! [full jitter]: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/
11
12use std::future::Future;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::Arc;
15use std::time::Duration;
16
17use tokio::time::sleep;
18
/// Decides whether a given error should trigger another retry attempt.
///
/// Return `true` to allow another attempt (still subject to
/// [`RetryConfig::max_attempts`]) or `false` to give up immediately.
pub type RetryPredicate<E> = Arc<dyn Fn(&E) -> bool + Send + Sync>;

/// Configuration for [`Retry`].
///
/// `Clone` and `Debug` are implemented for every `E` (no `E: Clone` or
/// `E: Debug` bound): no `E` value is stored, only a predicate over `&E`.
pub struct RetryConfig<E = std::io::Error> {
    /// Maximum total attempts (including the first). Default: 3.
    ///
    /// The first attempt always runs, so `0` behaves like `1`.
    pub max_attempts: u32,
    /// Initial backoff. Default: 100ms.
    pub base_delay: Duration,
    /// Cap on any single backoff. Default: 5s.
    pub max_delay: Duration,
    /// Predicate that decides if a given error should be retried.
    /// Default: retry every error.
    pub retry_if: Option<RetryPredicate<E>>,
}

// Hand-written instead of `#[derive(Clone)]`: the derive would require
// `E: Clone`, which the default `std::io::Error` does not implement, even
// though cloning the config only bumps the predicate's `Arc` refcount.
impl<E> Clone for RetryConfig<E> {
    fn clone(&self) -> Self {
        Self {
            max_attempts: self.max_attempts,
            base_delay: self.base_delay,
            max_delay: self.max_delay,
            retry_if: self.retry_if.clone(),
        }
    }
}

impl<E> Default for RetryConfig<E> {
    /// 3 attempts, 100ms base delay, 5s cap, every error retried.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(5),
            retry_if: None,
        }
    }
}

impl<E> std::fmt::Debug for RetryConfig<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Closures have no `Debug`; show only whether a predicate is set.
        f.debug_struct("RetryConfig")
            .field("max_attempts", &self.max_attempts)
            .field("base_delay", &self.base_delay)
            .field("max_delay", &self.max_delay)
            .field("retry_if", &self.retry_if.as_ref().map(|_| "<predicate>"))
            .finish()
    }
}
57
58/// Exponential-backoff retry helper.
59#[derive(Clone, Debug)]
60pub struct Retry<E = std::io::Error> {
61    cfg: RetryConfig<E>,
62    seed: Arc<AtomicU64>,
63}
64
65impl<E> Retry<E> {
66    /// Build a `Retry` with the given config.
67    pub fn new(cfg: RetryConfig<E>) -> Self {
68        // Seed our tiny PRNG from the wall clock at construction time.
69        let seed = std::time::SystemTime::now()
70            .duration_since(std::time::UNIX_EPOCH)
71            .map_or(0xdead_beef, |d| d.as_nanos() as u64);
72        Self {
73            cfg,
74            seed: Arc::new(AtomicU64::new(seed.wrapping_add(1))),
75        }
76    }
77
78    /// Execute `make_fut` up to `max_attempts` times, backing off between tries.
79    ///
80    /// `make_fut` is a closure (not a single future) so we can re-poll a fresh
81    /// future on each attempt — most futures aren't `Clone`.
82    pub async fn run<F, Fut, T>(&self, mut make_fut: F) -> Result<T, E>
83    where
84        F: FnMut() -> Fut,
85        Fut: Future<Output = Result<T, E>>,
86    {
87        let mut attempt: u32 = 0;
88        loop {
89            attempt += 1;
90            match make_fut().await {
91                Ok(v) => return Ok(v),
92                Err(e) => {
93                    if attempt >= self.cfg.max_attempts {
94                        return Err(e);
95                    }
96                    if let Some(pred) = &self.cfg.retry_if {
97                        if !pred(&e) {
98                            return Err(e);
99                        }
100                    }
101                    let delay = self.backoff(attempt);
102                    sleep(delay).await;
103                }
104            }
105        }
106    }
107
108    fn backoff(&self, attempt: u32) -> Duration {
109        // attempt is 1-based here; exponent should grow as 0,1,2,...
110        let exp = attempt.saturating_sub(1).min(30);
111        let raw = self.cfg.base_delay.saturating_mul(1u32 << exp);
112        let capped = raw.min(self.cfg.max_delay);
113        // Full jitter: pick a random duration in [0, capped].
114        let max_ms = capped.as_millis().min(u128::from(u64::MAX)) as u64;
115        let jitter_ms = self.next_rand() % (max_ms + 1);
116        Duration::from_millis(jitter_ms)
117    }
118
119    /// XorShift64* PRNG — good enough for jitter, no `rand` dep needed.
120    fn next_rand(&self) -> u64 {
121        let mut x = self.seed.load(Ordering::Relaxed);
122        if x == 0 {
123            x = 0xdead_beef;
124        }
125        x ^= x >> 12;
126        x ^= x << 25;
127        x ^= x >> 27;
128        self.seed.store(x, Ordering::Relaxed);
129        x.wrapping_mul(0x2545_F491_4F6C_DD1D)
130    }
131}