atomr_core/pattern/
retry.rs1use std::future::Future;
16use std::time::Duration;
17
/// Policy describing how long to wait between consecutive retry attempts.
#[derive(Debug, Clone, Copy)]
#[non_exhaustive]
pub enum RetrySchedule {
    /// Wait the same duration before every retry.
    Fixed(Duration),
    /// Start at `min` and double the delay for each further attempt,
    /// never exceeding `max`.
    Exponential { min: Duration, max: Duration },
}

impl RetrySchedule {
    /// Builds a schedule that always waits `d`.
    pub fn fixed(d: Duration) -> Self {
        Self::Fixed(d)
    }

    /// Builds a doubling schedule that starts at `min` and is capped at `max`.
    pub fn exponential(min: Duration, max: Duration) -> Self {
        Self::Exponential { min, max }
    }

    /// Returns the delay to apply after the zero-based `attempt` has failed.
    ///
    /// * `Fixed(d)` always yields `d`.
    /// * `Exponential { min, max }` yields `min * 2^attempt`, saturating on
    ///   overflow and clamped to `max`. If `min` exceeds `max`, the result is
    ///   `max`.
    pub fn delay_for(self, attempt: u32) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;

        match self {
            Self::Fixed(d) => d,
            Self::Exponential { min, max } => {
                // Work in u128, the native width of `as_nanos()`, so that very
                // large `min`/`max` values are never truncated.
                let factor = 1u128.checked_shl(attempt).unwrap_or(u128::MAX);
                let capped = min.as_nanos().saturating_mul(factor).min(max.as_nanos());

                // `capped <= max.as_nanos()`, and `max` is itself a valid
                // `Duration`, so the whole-second part always fits in a u64
                // and the remainder is always below one second.
                let secs = (capped / NANOS_PER_SEC) as u64;
                let subsec_nanos = (capped % NANOS_PER_SEC) as u32;
                Duration::new(secs, subsec_nanos)
            }
        }
    }
}
51
52pub async fn retry<T, E, F, Fut>(mut op: F, max_attempts: u32, schedule: RetrySchedule) -> Result<T, E>
57where
58 F: FnMut() -> Fut,
59 Fut: Future<Output = Result<T, E>>,
60{
61 assert!(max_attempts >= 1, "max_attempts must be ≥ 1");
62 let mut last_err: Option<E> = None;
63 for attempt in 0..max_attempts {
64 match op().await {
65 Ok(v) => return Ok(v),
66 Err(e) => {
67 last_err = Some(e);
68 if attempt + 1 < max_attempts {
69 tokio::time::sleep(schedule.delay_for(attempt)).await;
70 }
71 }
72 }
73 }
74 Err(last_err.expect("loop ran ≥1 time"))
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// A first-call success must return right away without further calls.
    #[tokio::test]
    async fn returns_immediately_on_first_success() {
        let invocations = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&invocations);
        let no_delay = RetrySchedule::fixed(Duration::from_millis(0));

        let outcome: Result<i32, &'static str> = retry(
            move || {
                let counter = Arc::clone(&counter);
                async move {
                    counter.fetch_add(1, Ordering::SeqCst);
                    Ok(42)
                }
            },
            5,
            no_delay,
        )
        .await;

        assert_eq!(outcome, Ok(42));
        assert_eq!(invocations.load(Ordering::SeqCst), 1);
    }

    /// The operation fails twice, then succeeds on the third call.
    #[tokio::test]
    async fn retries_until_success() {
        let invocations = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&invocations);
        let no_delay = RetrySchedule::fixed(Duration::from_millis(0));

        let outcome: Result<i32, &'static str> = retry(
            move || {
                let counter = Arc::clone(&counter);
                async move {
                    let call_number = counter.fetch_add(1, Ordering::SeqCst) + 1;
                    if call_number >= 3 {
                        Ok(call_number as i32)
                    } else {
                        Err("not yet")
                    }
                }
            },
            5,
            no_delay,
        )
        .await;

        assert_eq!(outcome, Ok(3));
        assert_eq!(invocations.load(Ordering::SeqCst), 3);
    }

    /// When every attempt fails, the final error is handed back to the caller.
    #[tokio::test]
    async fn returns_last_error_after_max_attempts() {
        let no_delay = RetrySchedule::fixed(Duration::from_millis(0));
        let outcome: Result<i32, &'static str> = retry(|| async { Err("nope") }, 3, no_delay).await;
        assert_eq!(outcome, Err("nope"));
    }

    /// Delays double from `min` and then stay pinned at `max`.
    #[test]
    fn exponential_backoff_doubles_until_cap() {
        let schedule = RetrySchedule::exponential(Duration::from_millis(10), Duration::from_millis(80));
        let expectations: [(u32, u64); 5] = [(0, 10), (1, 20), (2, 40), (3, 80), (10, 80)];
        for (attempt, expected_millis) in expectations {
            assert_eq!(schedule.delay_for(attempt), Duration::from_millis(expected_millis));
        }
    }

    /// A zero attempt budget violates the precondition and must panic.
    #[test]
    #[should_panic]
    fn zero_max_attempts_panics() {
        let pending = retry::<(), &'static str, _, _>(
            || async { Ok(()) },
            0,
            RetrySchedule::fixed(Duration::ZERO),
        );
        let _ = futures::executor::block_on(pending);
    }
}