1use rand::Rng;
7use std::time::Duration;
8use thiserror::Error;
9use tracing::{debug, warn};
10
/// Tunable parameters for retrying an operation with exponential backoff.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries allowed after the first attempt fails, so an
    /// operation runs at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Delay used before the first retry (prior to any jitter).
    pub initial_delay: Duration,
    /// Upper bound applied to the computed backoff delay before jitter.
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each further retry.
    pub multiplier: f64,
    /// Whether a random offset is added to each computed delay.
    pub use_jitter: bool,
    /// Fraction of the delay that bounds the jitter: the offset is drawn
    /// from `±(delay * jitter_factor)`.
    pub jitter_factor: f64,
}

impl Default for RetryConfig {
    /// Three retries, starting at 100 ms and doubling up to a 30 s cap,
    /// with 30 % jitter enabled.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(100);
        let max_delay = Duration::from_secs(30);

        RetryConfig {
            jitter_factor: 0.3,
            use_jitter: true,
            multiplier: 2.0,
            max_delay,
            initial_delay,
            max_retries: 3,
        }
    }
}
40
/// Failure outcomes of a retried operation, generic over the operation's
/// own error type `E`.
#[derive(Debug, Error)]
pub enum RetryError<E> {
    /// Every permitted attempt failed; carries the number of attempts made
    /// and the error returned by the final one.
    #[error("Operation failed after {attempts} attempts: {last_error}")]
    MaxRetriesExceeded { attempts: u32, last_error: E },
    /// Retrying stopped early; the payload is a human-readable reason
    /// (the original error is only kept in its formatted form).
    #[error("Retry aborted: {0}")]
    Aborted(String),
}
49
50pub struct RetryPolicy {
52 config: RetryConfig,
53}
54
55impl RetryPolicy {
56 pub fn new(config: RetryConfig) -> Self {
58 Self { config }
59 }
60
61 pub fn default() -> Self {
63 Self {
64 config: RetryConfig::default(),
65 }
66 }
67
68 pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
70 where
71 F: FnMut() -> Fut,
72 Fut: std::future::Future<Output = Result<T, E>>,
73 E: std::fmt::Display + Clone,
74 {
75 let mut attempts = 0;
76 let mut last_error: Option<E> = None;
77
78 loop {
79 attempts += 1;
80
81 debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
82
83 match operation().await {
84 Ok(result) => {
85 if attempts > 1 {
86 debug!("Operation succeeded after {} attempts", attempts);
87 }
88 return Ok(result);
89 }
90 Err(e) => {
91 warn!("Attempt {} failed: {}", attempts, e);
92
93 if attempts > self.config.max_retries {
94 return Err(RetryError::MaxRetriesExceeded {
95 attempts,
96 last_error: e,
97 });
98 }
99
100 last_error = Some(e);
101
102 let delay = self.calculate_delay(attempts);
104 debug!("Waiting {:?} before retry", delay);
105
106 tokio::time::sleep(delay).await;
107 }
108 }
109 }
110 }
111
112 pub fn execute_sync<F, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
114 where
115 F: FnMut() -> Result<T, E>,
116 E: std::fmt::Display + Clone,
117 {
118 let mut attempts = 0;
119 let mut last_error: Option<E> = None;
120
121 loop {
122 attempts += 1;
123
124 debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
125
126 match operation() {
127 Ok(result) => {
128 if attempts > 1 {
129 debug!("Operation succeeded after {} attempts", attempts);
130 }
131 return Ok(result);
132 }
133 Err(e) => {
134 warn!("Attempt {} failed: {}", attempts, e);
135
136 if attempts > self.config.max_retries {
137 return Err(RetryError::MaxRetriesExceeded {
138 attempts,
139 last_error: e,
140 });
141 }
142
143 last_error = Some(e);
144
145 let delay = self.calculate_delay(attempts);
147 debug!("Waiting {:?} before retry", delay);
148
149 std::thread::sleep(delay);
150 }
151 }
152 }
153 }
154
155 pub async fn execute_with_condition<F, Fut, T, E, C>(
157 &self,
158 mut operation: F,
159 mut should_retry: C,
160 ) -> Result<T, RetryError<E>>
161 where
162 F: FnMut() -> Fut,
163 Fut: std::future::Future<Output = Result<T, E>>,
164 E: std::fmt::Display + Clone,
165 C: FnMut(&E) -> bool,
166 {
167 let mut attempts = 0;
168
169 loop {
170 attempts += 1;
171
172 debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
173
174 match operation().await {
175 Ok(result) => {
176 if attempts > 1 {
177 debug!("Operation succeeded after {} attempts", attempts);
178 }
179 return Ok(result);
180 }
181 Err(e) => {
182 if !should_retry(&e) {
183 debug!("Error is not retryable: {}", e);
184 return Err(RetryError::Aborted(format!(
185 "Non-retryable error after {} attempts: {}",
186 attempts, e
187 )));
188 }
189
190 warn!("Attempt {} failed: {}", attempts, e);
191
192 if attempts > self.config.max_retries {
193 return Err(RetryError::MaxRetriesExceeded {
194 attempts,
195 last_error: e,
196 });
197 }
198
199 let delay = self.calculate_delay(attempts);
201 debug!("Waiting {:?} before retry", delay);
202
203 tokio::time::sleep(delay).await;
204 }
205 }
206 }
207 }
208
209 fn calculate_delay(&self, attempt: u32) -> Duration {
211 let base_delay_ms = self.config.initial_delay.as_millis() as f64
213 * self.config.multiplier.powi((attempt - 1) as i32);
214
215 let base_delay = Duration::from_millis(base_delay_ms as u64);
216
217 let capped_delay = if base_delay > self.config.max_delay {
219 self.config.max_delay
220 } else {
221 base_delay
222 };
223
224 if self.config.use_jitter {
226 self.add_jitter(capped_delay)
227 } else {
228 capped_delay
229 }
230 }
231
232 fn add_jitter(&self, delay: Duration) -> Duration {
234 let mut rng = rand::thread_rng();
235 let delay_ms = delay.as_millis() as f64;
236
237 let jitter_range = delay_ms * self.config.jitter_factor;
239
240 let jitter = rng.gen_range(-jitter_range..=jitter_range);
242 let jittered_ms = (delay_ms + jitter).max(0.0);
243
244 Duration::from_millis(jittered_ms as u64)
245 }
246
247 pub fn config(&self) -> &RetryConfig {
249 &self.config
250 }
251}
252
253pub struct RetryConfigBuilder {
255 config: RetryConfig,
256}
257
258impl RetryConfigBuilder {
259 pub fn new() -> Self {
261 Self {
262 config: RetryConfig::default(),
263 }
264 }
265
266 pub fn max_retries(mut self, max_retries: u32) -> Self {
268 self.config.max_retries = max_retries;
269 self
270 }
271
272 pub fn initial_delay(mut self, delay: Duration) -> Self {
274 self.config.initial_delay = delay;
275 self
276 }
277
278 pub fn max_delay(mut self, delay: Duration) -> Self {
280 self.config.max_delay = delay;
281 self
282 }
283
284 pub fn multiplier(mut self, multiplier: f64) -> Self {
286 self.config.multiplier = multiplier;
287 self
288 }
289
290 pub fn use_jitter(mut self, use_jitter: bool) -> Self {
292 self.config.use_jitter = use_jitter;
293 self
294 }
295
296 pub fn jitter_factor(mut self, factor: f64) -> Self {
298 self.config.jitter_factor = factor.clamp(0.0, 1.0);
299 self
300 }
301
302 pub fn build(self) -> RetryConfig {
304 self.config
305 }
306}
307
308impl Default for RetryConfigBuilder {
309 fn default() -> Self {
310 Self::new()
311 }
312}
313
#[cfg(test)]
mod tests {
    //! Unit tests for the retry policy, its delay calculation and the
    //! configuration builder.
    //!
    //! Tests that never await anything are plain `#[test]` functions so that
    //! blocking calls (for example `execute_sync`, which sleeps the thread)
    //! do not run inside the async runtime.

    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// A first-try success invokes the operation exactly once.
    #[tokio::test]
    async fn test_immediate_success() {
        let policy = RetryPolicy::default();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok::<_, String>("success")
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Two failures followed by a success yields `Ok` after three calls.
    #[tokio::test]
    async fn test_retry_on_failure() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                let count = counter_clone.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err("temporary failure")
                } else {
                    Ok("success")
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// A persistent failure stops after `max_retries + 1` attempts.
    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("persistent failure")
            })
            .await;

        assert!(result.is_err());
        // One initial attempt plus two retries.
        assert_eq!(counter.load(Ordering::SeqCst), 3);

        match result {
            Err(RetryError::MaxRetriesExceeded { attempts, .. }) => {
                assert_eq!(attempts, 3);
            }
            _ => panic!("Expected MaxRetriesExceeded error"),
        }
    }

    /// Without jitter the delay doubles with every attempt.
    #[test]
    fn test_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(50),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        let delay1 = policy.calculate_delay(1);
        let delay2 = policy.calculate_delay(2);
        let delay3 = policy.calculate_delay(3);

        assert_eq!(delay1, Duration::from_millis(50));
        assert_eq!(delay2, Duration::from_millis(100));
        assert_eq!(delay3, Duration::from_millis(200));
    }

    /// The computed delay never exceeds `max_delay` when jitter is off.
    #[test]
    fn test_max_delay_cap() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        // 100 ms * 2^4 = 1600 ms, which must be capped at 500 ms.
        let delay5 = policy.calculate_delay(5);
        assert_eq!(delay5, Duration::from_millis(500));
    }

    /// Jittered delays stay within `±jitter_factor` of the base delay and
    /// are not all identical.
    #[test]
    fn test_jitter_adds_variation() {
        let config = RetryConfig {
            max_retries: 1,
            initial_delay: Duration::from_millis(100),
            use_jitter: true,
            jitter_factor: 0.5,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        let delays: Vec<Duration> = (0..10).map(|_| policy.calculate_delay(1)).collect();

        // 100 ms with a 0.5 jitter factor must land in 50..=150 ms.
        for delay in &delays {
            let ms = delay.as_millis();
            assert!(
                (50..=150).contains(&ms),
                "Delay {} outside expected range",
                ms
            );
        }

        let all_same = delays.iter().all(|d| d == &delays[0]);
        assert!(!all_same, "All delays are the same, jitter not working");
    }

    /// The blocking API retries like the async one.
    #[test]
    fn test_synchronous_retry() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy.execute_sync(|| {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if count < 2 {
                Err("temporary failure")
            } else {
                Ok("success")
            }
        });

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    /// An error rejected by the predicate aborts after a single attempt.
    #[tokio::test]
    async fn test_retry_with_custom_condition() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute_with_condition(
                || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        Err::<String, _>("permanent error")
                    }
                },
                |e| e.contains("temporary"),
            )
            .await;

        assert!(matches!(result, Err(RetryError::Aborted(_))));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// Builder setters end up in the built configuration.
    #[test]
    fn test_retry_builder() {
        let config = RetryConfigBuilder::new()
            .max_retries(5)
            .initial_delay(Duration::from_millis(50))
            .max_delay(Duration::from_secs(10))
            .multiplier(3.0)
            .use_jitter(false)
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(50));
        assert_eq!(config.max_delay, Duration::from_secs(10));
        assert_eq!(config.multiplier, 3.0);
        assert!(!config.use_jitter);
    }

    /// With `max_retries: 0` the operation runs once and is not retried.
    #[tokio::test]
    async fn test_zero_retries() {
        let config = RetryConfig {
            max_retries: 0,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    /// A shared policy can drive several spawned tasks at once.
    #[tokio::test]
    async fn test_concurrent_retries() {
        let policy = Arc::new(RetryPolicy::default());
        let mut handles = vec![];

        for i in 0..5 {
            let policy_clone = Arc::clone(&policy);
            let handle = tokio::spawn(async move {
                policy_clone
                    .execute(|| async move {
                        if i % 2 == 0 {
                            Ok(format!("success {}", i))
                        } else {
                            Err(format!("error {}", i))
                        }
                    })
                    .await
            });
            handles.push(handle);
        }

        // Even-numbered tasks always succeed, odd-numbered ones always fail.
        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            if i % 2 == 0 {
                assert!(result.is_ok());
            } else {
                assert!(result.is_err());
            }
        }
    }

    /// The builder clamps the jitter factor into `0.0..=1.0`.
    #[test]
    fn test_jitter_factor_clamping() {
        let config = RetryConfigBuilder::new().jitter_factor(1.5).build();

        assert_eq!(config.jitter_factor, 1.0);

        let config = RetryConfigBuilder::new().jitter_factor(-0.5).build();

        assert_eq!(config.jitter_factor, 0.0);
    }

    /// Total wall-clock time is at least the sum of the backoff delays.
    #[tokio::test]
    async fn test_timing_accuracy() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(100),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let start = std::time::Instant::now();

        let _ = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;

        let elapsed = start.elapsed();

        // Two retries wait 100 ms and then 200 ms.
        assert!(
            elapsed >= Duration::from_millis(300),
            "Elapsed time {:?} less than expected",
            elapsed
        );
    }
}