atomr_core/pattern/
retry.rs1use std::future::Future;
20use std::time::Duration;
21
/// Back-off policy used by [`retry`] to decide how long to wait between attempts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RetrySchedule {
    /// Wait the same duration after every failed attempt.
    Fixed(Duration),
    /// Wait `min * 2^attempt` after the failed attempt with zero-based index
    /// `attempt`, never exceeding `max`.
    Exponential { min: Duration, max: Duration },
}

impl RetrySchedule {
    /// Creates a schedule that always waits `d` between attempts.
    #[must_use]
    pub fn fixed(d: Duration) -> Self {
        Self::Fixed(d)
    }

    /// Creates a doubling schedule that starts at `min` and is capped at `max`.
    ///
    /// If `min > max`, every delay is `max` because the cap always wins.
    #[must_use]
    pub fn exponential(min: Duration, max: Duration) -> Self {
        Self::Exponential { min, max }
    }

    /// Returns the delay to wait after the failed attempt with zero-based index `attempt`.
    ///
    /// - [`RetrySchedule::Fixed`] ignores `attempt` and returns its duration.
    /// - [`RetrySchedule::Exponential`] returns `min * 2^attempt`, saturating on overflow and
    ///   clamped to `max`. A zero `min` therefore always yields a zero delay.
    ///
    /// The arithmetic is done in `u128` nanoseconds. Durations longer than `u64::MAX`
    /// nanoseconds (about 584 years, e.g. `Duration::MAX` used as "no cap") are therefore
    /// neither truncated nor wrapped.
    #[must_use]
    pub fn delay_for(self, attempt: u32) -> Duration {
        match self {
            Self::Fixed(d) => d,
            Self::Exponential { min, max } => {
                const NANOS_PER_SEC: u128 = 1_000_000_000;
                // `checked_shl` fails only for attempt >= 128; saturate instead of wrapping.
                let factor = 1u128.checked_shl(attempt).unwrap_or(u128::MAX);
                let scaled = min.as_nanos().saturating_mul(factor);
                let capped = scaled.min(max.as_nanos());
                // `capped <= max.as_nanos()`, so `capped / NANOS_PER_SEC <= max.as_secs()`.
                // The seconds value therefore always fits in a `u64` and this cast is lossless.
                let secs = (capped / NANOS_PER_SEC) as u64;
                let subsec_nanos = (capped % NANOS_PER_SEC) as u32;
                Duration::new(secs, subsec_nanos)
            }
        }
    }
}
55
56pub async fn retry<T, E, F, Fut>(mut op: F, max_attempts: u32, schedule: RetrySchedule) -> Result<T, E>
61where
62 F: FnMut() -> Fut,
63 Fut: Future<Output = Result<T, E>>,
64{
65 assert!(max_attempts >= 1, "max_attempts must be ≥ 1");
66 let mut last_err: Option<E> = None;
67 for attempt in 0..max_attempts {
68 match op().await {
69 Ok(v) => return Ok(v),
70 Err(e) => {
71 last_err = Some(e);
72 if attempt + 1 < max_attempts {
73 tokio::time::sleep(schedule.delay_for(attempt)).await;
74 }
75 }
76 }
77 }
78 Err(last_err.expect("loop ran ≥1 time"))
79}
80
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[tokio::test]
    async fn returns_immediately_on_first_success() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        let r: Result<i32, &'static str> = retry(
            move || {
                let c2 = c2.clone();
                async move {
                    c2.fetch_add(1, Ordering::SeqCst);
                    Ok(42)
                }
            },
            5,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        assert_eq!(r, Ok(42));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn retries_until_success() {
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        let r: Result<i32, &'static str> = retry(
            move || {
                let c2 = c2.clone();
                async move {
                    let n = c2.fetch_add(1, Ordering::SeqCst) + 1;
                    if n < 3 {
                        Err("not yet")
                    } else {
                        Ok(n as i32)
                    }
                }
            },
            5,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        assert_eq!(r, Ok(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn returns_last_error_after_max_attempts() {
        let r: Result<i32, &'static str> =
            retry(|| async { Err("nope") }, 3, RetrySchedule::fixed(Duration::from_millis(0))).await;
        assert_eq!(r, Err("nope"));
    }

    #[tokio::test]
    async fn returns_error_from_final_attempt_only() {
        // Each call fails with its own 1-based call number. The result proves that the
        // *final* error is surfaced, and the counter proves exactly `max_attempts` calls ran.
        let calls = Arc::new(AtomicU32::new(0));
        let c2 = calls.clone();
        let r: Result<i32, u32> = retry(
            move || {
                let c2 = c2.clone();
                async move { Err(c2.fetch_add(1, Ordering::SeqCst) + 1) }
            },
            3,
            RetrySchedule::fixed(Duration::from_millis(0)),
        )
        .await;
        assert_eq!(r, Err(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn fixed_delay_ignores_attempt_index() {
        let s = RetrySchedule::fixed(Duration::from_millis(7));
        assert_eq!(s.delay_for(0), Duration::from_millis(7));
        assert_eq!(s.delay_for(u32::MAX), Duration::from_millis(7));
    }

    #[test]
    fn exponential_backoff_doubles_until_cap() {
        let s = RetrySchedule::exponential(Duration::from_millis(10), Duration::from_millis(80));
        assert_eq!(s.delay_for(0), Duration::from_millis(10));
        assert_eq!(s.delay_for(1), Duration::from_millis(20));
        assert_eq!(s.delay_for(2), Duration::from_millis(40));
        assert_eq!(s.delay_for(3), Duration::from_millis(80));
        assert_eq!(s.delay_for(10), Duration::from_millis(80));
    }

    #[test]
    fn exponential_backoff_saturates_for_huge_attempts() {
        // Shift amounts at and beyond the integer width must clamp to the cap, not wrap or panic.
        let s = RetrySchedule::exponential(Duration::from_millis(10), Duration::from_millis(80));
        assert_eq!(s.delay_for(63), Duration::from_millis(80));
        assert_eq!(s.delay_for(64), Duration::from_millis(80));
        assert_eq!(s.delay_for(u32::MAX), Duration::from_millis(80));
    }

    #[test]
    fn exponential_with_zero_min_stays_zero() {
        let s = RetrySchedule::exponential(Duration::ZERO, Duration::from_secs(1));
        assert_eq!(s.delay_for(0), Duration::ZERO);
        assert_eq!(s.delay_for(u32::MAX), Duration::ZERO);
    }

    #[test]
    // Pin the precondition message so an unrelated panic cannot make this test pass.
    #[should_panic(expected = "max_attempts must be ≥ 1")]
    fn zero_max_attempts_panics() {
        let _ = futures::executor::block_on(retry::<(), &'static str, _, _>(
            || async { Ok(()) },
            0,
            RetrySchedule::fixed(Duration::ZERO),
        ));
    }
}