Skip to main content

atomr_core/pattern/
retry.rs

1//! `retry` — wrap an async fallible operation in a bounded retry loop
2//! with optional fixed or exponential backoff.
3//!
4//! Phase 3.4 of `docs/full-port-plan.md`. Akka.NET parity:
5//! `Pattern.Retry` (with the same semantics as the JVM
6//! `Patterns.retry`).
7//!
8//! ```ignore
9//! use std::time::Duration;
10//! use atomr_core::pattern::{retry, RetrySchedule};
11//!
12//! let result = retry(
13//!     || async { fetch().await },
14//!     5,
15//!     RetrySchedule::exponential(Duration::from_millis(50), Duration::from_secs(2)),
16//! ).await;
17//! ```
18
19use std::future::Future;
20use std::time::Duration;
21
/// Schedule for the delay between attempts.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum RetrySchedule {
    /// Fixed delay between every attempt.
    Fixed(Duration),
    /// Exponential backoff: `min`, `min*2`, `min*4`, … capped at `max`.
    ///
    /// If `min > max`, every delay is `max`.
    Exponential { min: Duration, max: Duration },
}

impl RetrySchedule {
    /// A schedule that waits `d` before every retry.
    pub fn fixed(d: Duration) -> Self {
        Self::Fixed(d)
    }

    /// A doubling schedule that starts at `min` and never exceeds `max`.
    pub fn exponential(min: Duration, max: Duration) -> Self {
        Self::Exponential { min, max }
    }

    /// Delay before the `attempt`th retry (0-indexed: attempt 0 is the
    /// first retry, i.e. after the initial call has already failed).
    ///
    /// For [`RetrySchedule::Exponential`] this is `min * 2^attempt`, capped
    /// at `max`. The computation uses saturating `u128` nanosecond
    /// arithmetic, so neither a large `attempt` nor very long durations can
    /// overflow or be truncated.
    pub fn delay_for(self, attempt: u32) -> Duration {
        match self {
            Self::Fixed(d) => d,
            Self::Exponential { min, max } => {
                // `Duration::as_nanos` returns u128. Narrowing it to u64 (as a
                // plain `as` cast would) silently wraps for durations longer
                // than ~584 years, producing a bogus delay or cap.
                let factor = 1u128.checked_shl(attempt).unwrap_or(u128::MAX);
                let nanos = min.as_nanos().saturating_mul(factor);
                if nanos >= max.as_nanos() {
                    max
                } else {
                    // `nanos < max.as_nanos()`, so its whole-second part fits
                    // in a u64 exactly like every `Duration` does.
                    const NANOS_PER_SEC: u128 = 1_000_000_000;
                    let secs = u64::try_from(nanos / NANOS_PER_SEC)
                        .expect("value below `max` always fits in a Duration");
                    // Remainder is < 1e9, so it always fits in a u32.
                    let subsec_nanos = (nanos % NANOS_PER_SEC) as u32;
                    Duration::new(secs, subsec_nanos)
                }
            }
        }
    }
}
55
56/// Run `op`, retrying up to `max_attempts` total times (including the
57/// initial call). Returns the last error if every attempt fails.
58///
59/// `max_attempts == 1` means "no retries" — `op` runs exactly once.
60pub async fn retry<T, E, F, Fut>(mut op: F, max_attempts: u32, schedule: RetrySchedule) -> Result<T, E>
61where
62    F: FnMut() -> Fut,
63    Fut: Future<Output = Result<T, E>>,
64{
65    assert!(max_attempts >= 1, "max_attempts must be ≥ 1");
66    let mut last_err: Option<E> = None;
67    for attempt in 0..max_attempts {
68        match op().await {
69            Ok(v) => return Ok(v),
70            Err(e) => {
71                last_err = Some(e);
72                if attempt + 1 < max_attempts {
73                    tokio::time::sleep(schedule.delay_for(attempt)).await;
74                }
75            }
76        }
77    }
78    Err(last_err.expect("loop ran ≥1 time"))
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn returns_immediately_on_first_success() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        let r: Result<i32, &'static str> = retry(
            move || {
                let c2 = c2.clone();
                async move {
                    c2.fetch_add(1, Ordering::SeqCst);
                    Ok(42)
                }
            },
            5,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        assert_eq!(r, Ok(42));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_until_success() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        let r: Result<i32, &'static str> = retry(
            move || {
                let c2 = c2.clone();
                async move {
                    let n = c2.fetch_add(1, Ordering::SeqCst) + 1;
                    if n < 3 {
                        Err("not yet")
                    } else {
                        Ok(n as i32)
                    }
                }
            },
            5,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        assert_eq!(r, Ok(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn returns_last_error_after_max_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        let r: Result<i32, u32> = retry(
            move || {
                let c2 = c2.clone();
                // Each call fails with its own 1-based call number.
                async move { Err(c2.fetch_add(1, Ordering::SeqCst) + 1) }
            },
            3,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        // `Err(3)` proves the *last* error is surfaced (not the first one),
        // and the call count proves no extra attempt was made.
        assert_eq!(r, Err(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn single_attempt_means_no_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        let r: Result<(), &'static str> = retry(
            move || {
                let c2 = c2.clone();
                async move {
                    c2.fetch_add(1, Ordering::SeqCst);
                    Err("fail")
                }
            },
            1,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        assert_eq!(r, Err("fail"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn exponential_backoff_doubles_until_cap() {
        let s = RetrySchedule::exponential(Duration::from_millis(10), Duration::from_millis(80));
        assert_eq!(s.delay_for(0), Duration::from_millis(10));
        assert_eq!(s.delay_for(1), Duration::from_millis(20));
        assert_eq!(s.delay_for(2), Duration::from_millis(40));
        assert_eq!(s.delay_for(3), Duration::from_millis(80));
        assert_eq!(s.delay_for(10), Duration::from_millis(80)); // capped
    }

    #[test]
    // Pin the message so an unrelated panic cannot make this test pass.
    #[should_panic(expected = "max_attempts must be")]
    fn zero_max_attempts_panics() {
        let _ = futures::executor::block_on(retry::<(), &'static str, _, _>(
            || async { Ok(()) },
            0,
            RetrySchedule::fixed(Duration::ZERO),
        ));
    }
}