1use std::thread;
24use std::time::Duration;
25
/// Policy that decides how long to wait between two consecutive attempts.
#[derive(Debug, Clone, Copy)]
pub enum BackoffStrategy {
    /// Retry immediately, without sleeping.
    None,

    /// Wait the same duration before every retry.
    Constant(Duration),

    /// Wait `initial + increment * attempt`, capped at `max`.
    Linear {
        /// Delay returned for `attempt == 0`.
        initial: Duration,
        /// Amount added to the delay for each further attempt.
        increment: Duration,
        /// Upper bound on the returned delay.
        max: Duration,
    },

    /// Wait `initial * multiplier^attempt`, capped at `max`.
    Exponential {
        /// Delay returned for `attempt == 0`.
        initial: Duration,
        /// Upper bound on the returned delay.
        max: Duration,
        /// Growth factor applied once per attempt.
        multiplier: f64,
    },
}

impl BackoffStrategy {
    /// Returns the delay for the 0-based attempt index `attempt`.
    ///
    /// The arithmetic saturates instead of overflowing or truncating. Very large `attempt`
    /// values, or very large durations, yield the strategy's `max` rather than panicking
    /// or wrapping around to a small delay.
    #[must_use]
    pub fn delay_for_attempt(&self, attempt: usize) -> Duration {
        match *self {
            Self::None => Duration::ZERO,

            Self::Constant(delay) => delay,

            Self::Linear {
                initial,
                increment,
                max,
            } => {
                // `attempt as u32` would silently wrap (2^32 -> 0); clamp instead.
                let steps = u32::try_from(attempt).unwrap_or(u32::MAX);
                // `Duration`'s `*` and `+` panic on overflow; saturate so the cap still applies.
                initial
                    .saturating_add(increment.saturating_mul(steps))
                    .min(max)
            }

            Self::Exponential {
                initial,
                max,
                multiplier,
            } => {
                // `attempt as i32` would wrap to a negative exponent; clamp instead.
                let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
                let nanos = initial.as_nanos() as f64 * multiplier.powi(exponent);
                // Float-to-int `as` casts saturate: +inf -> u64::MAX, NaN/negative -> 0.
                Duration::from_nanos(nanos as u64).min(max)
            }
        }
    }
}

impl Default for BackoffStrategy {
    /// Exponential backoff starting at 100 ms, multiplied by 2 per attempt, capped at 30 s.
    fn default() -> Self {
        Self::Exponential {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(30),
            multiplier: 2.0,
        }
    }
}
97
98#[derive(Debug, Clone, Copy)]
100pub struct RetryConfig {
101 pub max_attempts: usize,
103 pub backoff: BackoffStrategy,
105 pub jitter: bool,
107}
108
109impl RetryConfig {
110 #[must_use]
112 pub fn new() -> Self {
113 Self::default()
114 }
115
116 #[must_use]
118 pub fn max_attempts(mut self, n: usize) -> Self {
119 self.max_attempts = n.max(1);
120 self
121 }
122
123 #[must_use]
125 pub fn backoff(mut self, strategy: BackoffStrategy) -> Self {
126 self.backoff = strategy;
127 self
128 }
129
130 #[must_use]
132 pub fn jitter(mut self, enabled: bool) -> Self {
133 self.jitter = enabled;
134 self
135 }
136
137 #[must_use]
139 pub fn no_retry() -> Self {
140 Self {
141 max_attempts: 1,
142 backoff: BackoffStrategy::None,
143 jitter: false,
144 }
145 }
146
147 #[must_use]
149 pub fn with_constant_delay(attempts: usize, delay: Duration) -> Self {
150 Self {
151 max_attempts: attempts,
152 backoff: BackoffStrategy::Constant(delay),
153 jitter: false,
154 }
155 }
156
157 #[must_use]
159 pub fn with_exponential_backoff(attempts: usize, initial: Duration, max: Duration) -> Self {
160 Self {
161 max_attempts: attempts,
162 backoff: BackoffStrategy::Exponential {
163 initial,
164 max,
165 multiplier: 2.0,
166 },
167 jitter: true,
168 }
169 }
170}
171
172impl Default for RetryConfig {
173 fn default() -> Self {
174 Self {
175 max_attempts: 3,
176 backoff: BackoffStrategy::default(),
177 jitter: true,
178 }
179 }
180}
181
/// Outcome of a retry run: the final result plus bookkeeping about how it was reached.
#[derive(Debug)]
pub struct RetryResult<T, E> {
    /// `Ok` with the successful value, or `Err` with the error from the last attempt.
    pub result: Result<T, E>,
    /// Number of attempts reported by the retry loop.
    pub attempts: usize,
    /// Time elapsed from the start of the retry loop until the result was produced.
    pub total_time: Duration,
}

impl<T, E> RetryResult<T, E> {
    /// Returns `true` if `result` holds a value.
    #[must_use]
    pub fn is_ok(&self) -> bool {
        self.result.is_ok()
    }

    /// Returns `true` if `result` holds an error.
    #[must_use]
    pub fn is_err(&self) -> bool {
        self.result.is_err()
    }

    /// Consumes the wrapper and returns the success value.
    ///
    /// # Panics
    ///
    /// Panics if `result` is `Err`. The message includes the error's `Debug` output.
    /// `#[track_caller]` makes the panic point at the caller's line rather than this method.
    #[track_caller]
    pub fn unwrap(self) -> T
    where
        E: std::fmt::Debug,
    {
        self.result.unwrap()
    }

    /// Consumes the wrapper and returns the underlying `Result`, discarding the bookkeeping.
    pub fn into_result(self) -> Result<T, E> {
        self.result
    }
}
219
220pub fn retry<T, E, F>(config: RetryConfig, mut operation: F) -> RetryResult<T, E>
231where
232 F: FnMut() -> Result<T, E>,
233{
234 let start = std::time::Instant::now();
235 let mut last_error: Option<E> = None;
236
237 for attempt in 0..config.max_attempts {
238 match operation() {
239 Ok(value) => {
240 return RetryResult {
241 result: Ok(value),
242 attempts: attempt + 1,
243 total_time: start.elapsed(),
244 };
245 }
246 Err(e) => {
247 last_error = Some(e);
248
249 if attempt + 1 < config.max_attempts {
251 let mut delay = config.backoff.delay_for_attempt(attempt);
252
253 if config.jitter && delay > Duration::ZERO {
255 let jitter_factor = simple_random() * 0.25;
256 let jitter =
257 Duration::from_nanos((delay.as_nanos() as f64 * jitter_factor) as u64);
258 delay += jitter;
259 }
260
261 if delay > Duration::ZERO {
262 thread::sleep(delay);
263 }
264 }
265 }
266 }
267 }
268
269 RetryResult {
270 result: Err(last_error.expect("At least one attempt should have been made")),
271 attempts: config.max_attempts,
272 total_time: start.elapsed(),
273 }
274}
275
276pub fn retry_with_context<T, E, F>(config: RetryConfig, mut operation: F) -> RetryResult<T, E>
278where
279 F: FnMut(usize) -> Result<T, E>,
280{
281 let start = std::time::Instant::now();
282 let mut last_error: Option<E> = None;
283
284 for attempt in 0..config.max_attempts {
285 match operation(attempt) {
286 Ok(value) => {
287 return RetryResult {
288 result: Ok(value),
289 attempts: attempt + 1,
290 total_time: start.elapsed(),
291 };
292 }
293 Err(e) => {
294 last_error = Some(e);
295
296 if attempt + 1 < config.max_attempts {
297 let delay = config.backoff.delay_for_attempt(attempt);
298 if delay > Duration::ZERO {
299 thread::sleep(delay);
300 }
301 }
302 }
303 }
304 }
305
306 RetryResult {
307 result: Err(last_error.expect("At least one attempt should have been made")),
308 attempts: config.max_attempts,
309 total_time: start.elapsed(),
310 }
311}
312
/// Returns a pseudo-random `f64` in `[0.0, 1.0)`, used to add jitter to backoff delays.
///
/// Not cryptographically secure; it only needs to spread concurrent retries apart.
fn simple_random() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::time::SystemTime;

    let nanos = SystemTime::now()
        .duration_since(SystemTime::UNIX_EPOCH)
        .unwrap_or_default()
        .as_nanos();

    // Don't take randomness from the low digits of the clock. With a microsecond-resolution
    // wall clock they are always zero, which would disable jitter entirely. Instead, hash
    // the timestamp with a `RandomState`, which std initializes with random keys.
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);

    // Keep the top 53 bits: exactly representable in an f64 mantissa, so the result is < 1.0.
    (hasher.finish() >> 11) as f64 / (1u64 << 53) as f64
}
322
#[cfg(test)]
mod tests {
    use super::*;
    use std::cell::Cell;

    /// Retry configuration with the given attempt budget and no sleeping between attempts.
    fn instant_config(attempts: usize) -> RetryConfig {
        RetryConfig::new()
            .max_attempts(attempts)
            .backoff(BackoffStrategy::None)
    }

    #[test]
    fn test_retry_succeeds_first_try() {
        let outcome = retry(RetryConfig::new().max_attempts(3), || {
            Ok::<_, &str>("success")
        });

        assert!(outcome.is_ok());
        assert_eq!(outcome.attempts, 1);
        assert_eq!(outcome.unwrap(), "success");
    }

    #[test]
    fn test_retry_succeeds_after_failures() {
        let calls = Cell::new(0);

        let outcome = retry(instant_config(3), || {
            // `replace` hands back the previous count while storing the incremented one.
            let seen = calls.replace(calls.get() + 1);
            if seen >= 2 {
                Ok("success")
            } else {
                Err("not yet")
            }
        });

        assert!(outcome.is_ok());
        assert_eq!(outcome.attempts, 3);
    }

    #[test]
    fn test_retry_exhausted() {
        let outcome = retry(instant_config(3), || Err::<(), _>("always fails"));

        assert!(outcome.is_err());
        assert_eq!(outcome.attempts, 3);
    }

    #[test]
    fn test_backoff_constant() {
        let expected = Duration::from_millis(100);
        let strategy = BackoffStrategy::Constant(expected);
        for attempt in [0, 5] {
            assert_eq!(strategy.delay_for_attempt(attempt), expected);
        }
    }

    #[test]
    fn test_backoff_exponential() {
        let strategy = BackoffStrategy::Exponential {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier: 2.0,
        };

        let table = [(0, 100), (1, 200), (2, 400), (3, 800)];
        for (attempt, millis) in table {
            assert_eq!(
                strategy.delay_for_attempt(attempt),
                Duration::from_millis(millis)
            );
        }
    }

    #[test]
    fn test_backoff_max_cap() {
        let cap = Duration::from_secs(5);
        let strategy = BackoffStrategy::Exponential {
            initial: Duration::from_secs(1),
            max: cap,
            multiplier: 2.0,
        };

        assert_eq!(strategy.delay_for_attempt(10), cap);
    }

    #[test]
    fn test_no_retry_config() {
        assert_eq!(RetryConfig::no_retry().max_attempts, 1);
    }
}