Skip to main content

atomr_core/pattern/
retry.rs

//! `retry` — wrap an async fallible operation in a bounded retry loop
//! with a fixed or exponential backoff between attempts.
//!
//! ```ignore
//! use std::time::Duration;
//! use atomr_core::pattern::{retry, RetrySchedule};
//!
//! let result = retry(
//!     || async { fetch().await },
//!     5,
//!     RetrySchedule::exponential(Duration::from_millis(50), Duration::from_secs(2)),
//! ).await;
//! ```
14
15use std::future::Future;
16use std::time::Duration;
17
/// Schedule for the delay between attempts.
///
/// Schedules are plain values: they can be copied, compared, and hashed
/// (e.g. to deduplicate configurations). The enum is `#[non_exhaustive]`, so
/// `match`es on it outside this crate need a wildcard arm, which lets further
/// schedules be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum RetrySchedule {
    /// Fixed delay between every attempt.
    Fixed(Duration),
    /// Exponential backoff: `min`, `min*2`, `min*4`, … capped at `max`.
    Exponential {
        /// Base delay; it is doubled for each subsequent retry.
        min: Duration,
        /// Ceiling applied to every computed delay.
        max: Duration,
    },
}

28impl RetrySchedule {
29    pub fn fixed(d: Duration) -> Self {
30        Self::Fixed(d)
31    }
32
33    pub fn exponential(min: Duration, max: Duration) -> Self {
34        Self::Exponential { min, max }
35    }
36
37    /// Delay before the `attempt`th retry (0-indexed: attempt 0 is the
38    /// first retry, i.e. after the initial call has already failed).
39    pub fn delay_for(self, attempt: u32) -> Duration {
40        match self {
41            Self::Fixed(d) => d,
42            Self::Exponential { min, max } => {
43                let factor = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
44                let nanos = (min.as_nanos() as u64).saturating_mul(factor);
45                let capped = nanos.min(max.as_nanos() as u64);
46                Duration::from_nanos(capped)
47            }
48        }
49    }
50}
51
52/// Run `op`, retrying up to `max_attempts` total times (including the
53/// initial call). Returns the last error if every attempt fails.
54///
55/// `max_attempts == 1` means "no retries" — `op` runs exactly once.
56pub async fn retry<T, E, F, Fut>(mut op: F, max_attempts: u32, schedule: RetrySchedule) -> Result<T, E>
57where
58    F: FnMut() -> Fut,
59    Fut: Future<Output = Result<T, E>>,
60{
61    assert!(max_attempts >= 1, "max_attempts must be ≥ 1");
62    let mut last_err: Option<E> = None;
63    for attempt in 0..max_attempts {
64        match op().await {
65            Ok(v) => return Ok(v),
66            Err(e) => {
67                last_err = Some(e);
68                if attempt + 1 < max_attempts {
69                    tokio::time::sleep(schedule.delay_for(attempt)).await;
70                }
71            }
72        }
73    }
74    Err(last_err.expect("loop ran ≥1 time"))
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn returns_immediately_on_first_success() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        let r: Result<i32, &'static str> = retry(
            move || {
                let c2 = c2.clone();
                async move {
                    c2.fetch_add(1, Ordering::SeqCst);
                    Ok(42)
                }
            },
            5,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        assert_eq!(r, Ok(42));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_until_success() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        let r: Result<i32, &'static str> = retry(
            move || {
                let c2 = c2.clone();
                async move {
                    let n = c2.fetch_add(1, Ordering::SeqCst) + 1;
                    if n < 3 {
                        Err("not yet")
                    } else {
                        Ok(n as i32)
                    }
                }
            },
            5,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        assert_eq!(r, Ok(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn returns_last_error_after_max_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        // Each attempt fails with its own 1-based index, so the assertion can
        // tell the *last* error apart from earlier ones.
        let r: Result<i32, u32> = retry(
            move || {
                let c2 = c2.clone();
                async move { Err(c2.fetch_add(1, Ordering::SeqCst) + 1) }
            },
            3,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        assert_eq!(r, Err(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn single_attempt_means_no_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        // An hour-long delay would blow the timeout if `retry` ever slept after
        // the final (here: only) attempt.
        let r: Result<i32, &'static str> = tokio::time::timeout(
            Duration::from_secs(5),
            retry(
                move || {
                    let c2 = c2.clone();
                    async move {
                        c2.fetch_add(1, Ordering::SeqCst);
                        Err("nope")
                    }
                },
                1,
                RetrySchedule::fixed(Duration::from_secs(3600)),
            ),
        )
        .await
        .expect("retry must not sleep after its final attempt");
        assert_eq!(r, Err("nope"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn fixed_schedule_is_constant() {
        let s = RetrySchedule::fixed(Duration::from_millis(7));
        assert_eq!(s.delay_for(0), Duration::from_millis(7));
        assert_eq!(s.delay_for(42), Duration::from_millis(7));
    }

    #[test]
    fn exponential_backoff_doubles_until_cap() {
        let s = RetrySchedule::exponential(Duration::from_millis(10), Duration::from_millis(80));
        assert_eq!(s.delay_for(0), Duration::from_millis(10));
        assert_eq!(s.delay_for(1), Duration::from_millis(20));
        assert_eq!(s.delay_for(2), Duration::from_millis(40));
        assert_eq!(s.delay_for(3), Duration::from_millis(80));
        assert_eq!(s.delay_for(10), Duration::from_millis(80)); // capped
    }

    #[test]
    #[should_panic(expected = "max_attempts must be ≥ 1")]
    fn zero_max_attempts_panics() {
        let _ = futures::executor::block_on(retry::<(), &'static str, _, _>(
            || async { Ok(()) },
            0,
            RetrySchedule::fixed(Duration::ZERO),
        ));
    }
}