1use std::future::Future;
14use std::time::Duration;
15
16use rand::rngs::StdRng;
17use rand::Rng;
18use rand::SeedableRng;
19
/// Selects how the un-jittered retry delay grows with the attempt number.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// Delay doubles with every failed attempt: `base_delay * 2^(attempt - 1)`,
    /// with the exponent capped at 30.
    Exponential,
    /// The same `base_delay` is used before every retry.
    Constant,
    /// Delay grows proportionally to the attempt number: `base_delay * attempt`.
    Linear,
}
30
/// Tunable parameters for the retry helpers in this module.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries allowed after the initial attempt; the operation is
    /// invoked at most `max_retries + 1` times (saturating).
    pub max_retries: u32,
    /// Base unit from which the backoff strategy derives each delay.
    pub base_delay: Duration,
    /// Upper bound applied to the delay *before* jitter is added, so the
    /// final jittered delay can exceed this value.
    pub max_delay: Duration,
    /// Relative jitter amplitude, clamped to `[0.0, 1.0]`. The capped delay is
    /// multiplied by a random factor in `[1 - jitter, 1 + jitter]`; `0.0`
    /// disables jitter.
    pub jitter_fraction: f64,
    /// How the delay grows from one attempt to the next.
    pub strategy: BackoffStrategy,
}
48
49impl Default for RetryConfig {
50 fn default() -> Self {
51 Self {
52 max_retries: 3,
53 base_delay: Duration::from_millis(100),
54 max_delay: Duration::from_secs(5),
55 jitter_fraction: 0.25,
56 strategy: BackoffStrategy::Exponential,
57 }
58 }
59}
60
/// Outcome of a single attempt of the retried operation; a plain alias for
/// [`Result`].
pub type AttemptResult<T, E> = Result<T, E>;
63
64pub async fn retry_with_jitter<F, Fut, T, E>(config: &RetryConfig, op: F) -> Result<T, E>
68where
69 F: FnMut(u32) -> Fut,
70 Fut: Future<Output = AttemptResult<T, E>>,
71{
72 let seed = u64::from(config.max_retries).wrapping_add(0x9E37_79B9_7F4A_7C15);
73 let rng = StdRng::seed_from_u64(seed);
74 retry_with_jitter_rng(config, rng, op).await
75}
76
77pub async fn retry_with_jitter_rng<F, Fut, T, E, R>(
81 config: &RetryConfig,
82 mut rng: R,
83 mut op: F,
84) -> Result<T, E>
85where
86 F: FnMut(u32) -> Fut,
87 Fut: Future<Output = AttemptResult<T, E>>,
88 R: Rng,
89{
90 let total_attempts = config.max_retries.saturating_add(1);
91 let mut last_err: Option<E> = None;
92 for attempt in 1..=total_attempts {
93 match op(attempt).await {
94 Ok(value) => return Ok(value),
95 Err(err) => {
96 last_err = Some(err);
97 if attempt >= total_attempts {
98 break;
99 }
100 let delay = compute_delay(config, attempt, &mut rng);
101 if !delay.is_zero() {
102 tokio::time::sleep(delay).await;
103 }
104 }
105 }
106 }
107 match last_err {
108 Some(err) => Err(err),
109 None => unreachable!("retry loop must have produced at least one result"),
113 }
114}
115
116fn compute_delay<R: Rng>(config: &RetryConfig, attempt: u32, rng: &mut R) -> Duration {
117 let base = config.base_delay.as_secs_f64().max(0.0);
118 let raw = match config.strategy {
119 BackoffStrategy::Constant => base,
120 BackoffStrategy::Linear => base * f64::from(attempt.max(1)),
121 BackoffStrategy::Exponential => {
122 let exp = attempt.saturating_sub(1).min(30);
124 base * (1u64 << exp) as f64
125 }
126 };
127 let max_secs = config.max_delay.as_secs_f64().max(0.0);
128 let capped = raw.min(max_secs);
129 let jitter = config.jitter_fraction.clamp(0.0, 1.0);
130 let factor = if jitter == 0.0 {
131 1.0
132 } else {
133 1.0 + rng.gen_range(-jitter..=jitter)
134 };
135 let jittered = (capped * factor).max(0.0);
136 Duration::from_secs_f64(jittered)
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a config with jitter disabled so that delays are exact.
    fn jitter_free(
        max_retries: u32,
        base_ms: u64,
        cap_ms: u64,
        strategy: BackoffStrategy,
    ) -> RetryConfig {
        RetryConfig {
            max_retries,
            base_delay: Duration::from_millis(base_ms),
            max_delay: Duration::from_millis(cap_ms),
            jitter_fraction: 0.0,
            strategy,
        }
    }

    /// An operation that succeeds immediately is invoked exactly once.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn succeeds_on_first_attempt() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);
        let config = RetryConfig::default();

        let outcome: Result<u32, &'static str> = retry_with_jitter(&config, |_attempt| {
            let calls = Arc::clone(&op_calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(42)
            }
        })
        .await;

        assert_eq!(outcome, Ok(42));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two transient failures followed by a success take three invocations.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn succeeds_after_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);
        let config = jitter_free(4, 10, 40, BackoffStrategy::Exponential);

        let outcome: Result<u32, &'static str> = retry_with_jitter(&config, move |_attempt| {
            let calls = Arc::clone(&op_calls);
            async move {
                match calls.fetch_add(1, Ordering::SeqCst) + 1 {
                    n if n < 3 => Err("transient"),
                    n => Ok(n),
                }
            }
        })
        .await;

        assert_eq!(outcome, Ok(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// When every attempt fails, the last error is returned after
    /// `max_retries + 1` invocations.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn returns_last_error_after_exhausting_retries() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);
        let config = jitter_free(2, 1, 4, BackoffStrategy::Constant);

        let outcome: Result<u32, &'static str> = retry_with_jitter(&config, move |_attempt| {
            let calls = Arc::clone(&op_calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err("always fails")
            }
        })
        .await;

        assert_eq!(outcome, Err("always fails"));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// With a zero retry budget the operation still runs exactly once.
    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn zero_max_retries_runs_once() {
        let calls = Arc::new(AtomicU32::new(0));
        let op_calls = Arc::clone(&calls);
        let config = jitter_free(0, 1, 1, BackoffStrategy::Exponential);

        let outcome: Result<u32, &'static str> = retry_with_jitter(&config, move |_attempt| {
            let calls = Arc::clone(&op_calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Err("boom")
            }
        })
        .await;

        assert_eq!(outcome, Err("boom"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Without jitter, a large exponential delay is clamped to `max_delay`.
    #[test]
    fn compute_delay_caps_at_max_delay() {
        let config = jitter_free(10, 100, 500, BackoffStrategy::Exponential);
        let mut rng = StdRng::seed_from_u64(1);

        let delay = compute_delay(&config, 10, &mut rng);

        assert_eq!(delay, Duration::from_millis(500));
    }
}