1use rand::Rng;
7use std::time::Duration;
8use thiserror::Error;
9use tracing::{debug, warn};
10
/// Tuning parameters for [`RetryPolicy`].
///
/// Build one with [`RetryConfig::default`], struct-update syntax
/// (`RetryConfig { max_retries: 5, ..Default::default() }`), or [`RetryConfigBuilder`].
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Number of retries after the first attempt, so an operation runs at most
    /// `max_retries + 1` times. `0` disables retrying.
    pub max_retries: u32,
    /// Base delay before the first retry.
    pub initial_delay: Duration,
    /// Cap on the exponentially growing base delay.
    pub max_delay: Duration,
    /// Growth factor of the base delay: the wait after failed attempt `n` (1-based) is
    /// `initial_delay * multiplier^(n - 1)` before capping.
    pub multiplier: f64,
    /// Whether each delay is randomly perturbed by up to `±jitter_factor` of its length.
    pub use_jitter: bool,
    /// Jitter amplitude as a fraction of the delay. Intended to lie in `0.0..=1.0`, which is
    /// the range [`RetryConfigBuilder::jitter_factor`] clamps to.
    pub jitter_factor: f64,
}

impl Default for RetryConfig {
    /// 3 retries, a 100 ms initial delay that doubles per attempt up to 30 s, and ±30% jitter.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            multiplier: 2.0,
            use_jitter: true,
            jitter_factor: 0.3,
        }
    }
}
40
41#[derive(Debug, Error)]
43pub enum RetryError<E> {
44 #[error("Operation failed after {attempts} attempts: {last_error}")]
45 MaxRetriesExceeded { attempts: u32, last_error: E },
46 #[error("Retry aborted: {0}")]
47 Aborted(String),
48}
49
50pub struct RetryPolicy {
52 config: RetryConfig,
53}
54
55impl RetryPolicy {
56 pub fn new(config: RetryConfig) -> Self {
58 Self { config }
59 }
60
61 pub fn with_default_config() -> Self {
63 Self {
64 config: RetryConfig::default(),
65 }
66 }
67
68 pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
70 where
71 F: FnMut() -> Fut,
72 Fut: std::future::Future<Output = Result<T, E>>,
73 E: std::fmt::Display + Clone,
74 {
75 let mut attempts = 0;
76 #[allow(unused_assignments)]
77 let mut _last_error: Option<E> = None;
78
79 loop {
80 attempts += 1;
81
82 debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
83
84 match operation().await {
85 Ok(result) => {
86 if attempts > 1 {
87 debug!("Operation succeeded after {} attempts", attempts);
88 }
89 return Ok(result);
90 }
91 Err(e) => {
92 warn!("Attempt {} failed: {}", attempts, e);
93
94 if attempts > self.config.max_retries {
95 return Err(RetryError::MaxRetriesExceeded {
96 attempts,
97 last_error: e,
98 });
99 }
100
101 _last_error = Some(e);
102
103 let delay = self.calculate_delay(attempts);
105 debug!("Waiting {:?} before retry", delay);
106
107 tokio::time::sleep(delay).await;
108 }
109 }
110 }
111 }
112
113 pub fn execute_sync<F, T, E>(&self, mut operation: F) -> Result<T, RetryError<E>>
115 where
116 F: FnMut() -> Result<T, E>,
117 E: std::fmt::Display + Clone,
118 {
119 let mut attempts = 0;
120 #[allow(unused_assignments)]
121 let mut _last_error: Option<E> = None;
122
123 loop {
124 attempts += 1;
125
126 debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
127
128 match operation() {
129 Ok(result) => {
130 if attempts > 1 {
131 debug!("Operation succeeded after {} attempts", attempts);
132 }
133 return Ok(result);
134 }
135 Err(e) => {
136 warn!("Attempt {} failed: {}", attempts, e);
137
138 if attempts > self.config.max_retries {
139 return Err(RetryError::MaxRetriesExceeded {
140 attempts,
141 last_error: e,
142 });
143 }
144
145 _last_error = Some(e);
146
147 let delay = self.calculate_delay(attempts);
149 debug!("Waiting {:?} before retry", delay);
150
151 std::thread::sleep(delay);
152 }
153 }
154 }
155 }
156
157 pub async fn execute_with_condition<F, Fut, T, E, C>(
159 &self,
160 mut operation: F,
161 mut should_retry: C,
162 ) -> Result<T, RetryError<E>>
163 where
164 F: FnMut() -> Fut,
165 Fut: std::future::Future<Output = Result<T, E>>,
166 E: std::fmt::Display + Clone,
167 C: FnMut(&E) -> bool,
168 {
169 let mut attempts = 0;
170
171 loop {
172 attempts += 1;
173
174 debug!("Retry attempt {}/{}", attempts, self.config.max_retries + 1);
175
176 match operation().await {
177 Ok(result) => {
178 if attempts > 1 {
179 debug!("Operation succeeded after {} attempts", attempts);
180 }
181 return Ok(result);
182 }
183 Err(e) => {
184 if !should_retry(&e) {
185 debug!("Error is not retryable: {}", e);
186 return Err(RetryError::Aborted(format!(
187 "Non-retryable error after {} attempts: {}",
188 attempts, e
189 )));
190 }
191
192 warn!("Attempt {} failed: {}", attempts, e);
193
194 if attempts > self.config.max_retries {
195 return Err(RetryError::MaxRetriesExceeded {
196 attempts,
197 last_error: e,
198 });
199 }
200
201 let delay = self.calculate_delay(attempts);
203 debug!("Waiting {:?} before retry", delay);
204
205 tokio::time::sleep(delay).await;
206 }
207 }
208 }
209 }
210
211 fn calculate_delay(&self, attempt: u32) -> Duration {
213 let base_delay_ms = self.config.initial_delay.as_millis() as f64
215 * self.config.multiplier.powi((attempt - 1) as i32);
216
217 let base_delay = Duration::from_millis(base_delay_ms as u64);
218
219 let capped_delay = if base_delay > self.config.max_delay {
221 self.config.max_delay
222 } else {
223 base_delay
224 };
225
226 if self.config.use_jitter {
228 self.add_jitter(capped_delay)
229 } else {
230 capped_delay
231 }
232 }
233
234 fn add_jitter(&self, delay: Duration) -> Duration {
236 let mut rng = rand::thread_rng();
237 let delay_ms = delay.as_millis() as f64;
238
239 let jitter_range = delay_ms * self.config.jitter_factor;
241
242 let jitter = rng.gen_range(-jitter_range..=jitter_range);
244 let jittered_ms = (delay_ms + jitter).max(0.0);
245
246 Duration::from_millis(jittered_ms as u64)
247 }
248
249 pub fn config(&self) -> &RetryConfig {
251 &self.config
252 }
253}
254
255pub struct RetryConfigBuilder {
257 config: RetryConfig,
258}
259
260impl RetryConfigBuilder {
261 pub fn new() -> Self {
263 Self {
264 config: RetryConfig::default(),
265 }
266 }
267
268 pub fn max_retries(mut self, max_retries: u32) -> Self {
270 self.config.max_retries = max_retries;
271 self
272 }
273
274 pub fn initial_delay(mut self, delay: Duration) -> Self {
276 self.config.initial_delay = delay;
277 self
278 }
279
280 pub fn max_delay(mut self, delay: Duration) -> Self {
282 self.config.max_delay = delay;
283 self
284 }
285
286 pub fn multiplier(mut self, multiplier: f64) -> Self {
288 self.config.multiplier = multiplier;
289 self
290 }
291
292 pub fn use_jitter(mut self, use_jitter: bool) -> Self {
294 self.config.use_jitter = use_jitter;
295 self
296 }
297
298 pub fn jitter_factor(mut self, factor: f64) -> Self {
300 self.config.jitter_factor = factor.clamp(0.0, 1.0);
301 self
302 }
303
304 pub fn build(self) -> RetryConfig {
306 self.config
307 }
308}
309
310impl Default for RetryConfigBuilder {
311 fn default() -> Self {
312 Self::new()
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Config with a 1 ms, jitter-free backoff so tests that exhaust retries stay fast.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            initial_delay: Duration::from_millis(1),
            use_jitter: false,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_immediate_success() {
        let policy = RetryPolicy::with_default_config();
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Ok::<_, String>("success")
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_on_failure() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                let count = counter_clone.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err("temporary failure")
                } else {
                    Ok("success")
                }
            })
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_max_retries_exceeded() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("persistent failure")
            })
            .await;

        assert!(result.is_err());
        // One initial attempt plus two retries.
        assert_eq!(counter.load(Ordering::SeqCst), 3);
        match result {
            Err(RetryError::MaxRetriesExceeded {
                attempts,
                last_error,
            }) => {
                assert_eq!(attempts, 3);
                assert_eq!(last_error, "persistent failure");
            }
            _ => panic!("Expected MaxRetriesExceeded error"),
        }
    }

    #[test]
    fn test_exponential_backoff() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(50),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        assert_eq!(policy.calculate_delay(1), Duration::from_millis(50));
        assert_eq!(policy.calculate_delay(2), Duration::from_millis(100));
        assert_eq!(policy.calculate_delay(3), Duration::from_millis(200));
    }

    #[test]
    fn test_max_delay_cap() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        // 100 ms * 2^4 = 1600 ms, which must be capped at 500 ms.
        assert_eq!(policy.calculate_delay(5), Duration::from_millis(500));
    }

    #[test]
    fn test_jitter_adds_variation() {
        let config = RetryConfig {
            max_retries: 1,
            initial_delay: Duration::from_millis(100),
            use_jitter: true,
            jitter_factor: 0.5,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);

        let delays: Vec<Duration> = (0..10).map(|_| policy.calculate_delay(1)).collect();

        // ±50% of 100 ms.
        for delay in &delays {
            let ms = delay.as_millis();
            assert!(ms >= 50 && ms <= 150, "Delay {} outside expected range", ms);
        }

        let all_same = delays.iter().all(|d| d == &delays[0]);
        assert!(!all_same, "All delays are the same, jitter not working");
    }

    #[test]
    fn test_synchronous_retry() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy.execute_sync(|| {
            let count = counter_clone.fetch_add(1, Ordering::SeqCst);
            if count < 2 {
                Err("temporary failure")
            } else {
                Ok("success")
            }
        });

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_synchronous_max_retries_exceeded() {
        let policy = RetryPolicy::new(fast_config(1));
        let mut calls = 0;

        let result = policy.execute_sync(|| {
            calls += 1;
            Err::<(), _>("boom")
        });

        assert_eq!(calls, 2);
        assert!(matches!(
            result,
            Err(RetryError::MaxRetriesExceeded {
                attempts: 2,
                last_error: "boom"
            })
        ));
    }

    #[tokio::test]
    async fn test_retry_with_custom_condition() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute_with_condition(
                || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        Err::<String, _>("permanent error")
                    }
                },
                |e| e.contains("temporary"),
            )
            .await;

        // The predicate rejects the first error, so no retry happens.
        assert!(matches!(result, Err(RetryError::Aborted(_))));
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_condition_retries_until_exhausted() {
        let policy = RetryPolicy::new(fast_config(2));
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute_with_condition(
                || {
                    let counter = Arc::clone(&counter_clone);
                    async move {
                        counter.fetch_add(1, Ordering::SeqCst);
                        Err::<(), _>("temporary error")
                    }
                },
                |e| e.contains("temporary"),
            )
            .await;

        assert_eq!(counter.load(Ordering::SeqCst), 3);
        match result {
            Err(RetryError::MaxRetriesExceeded {
                attempts,
                last_error,
            }) => {
                assert_eq!(attempts, 3);
                assert_eq!(last_error, "temporary error");
            }
            other => panic!("Expected MaxRetriesExceeded error, got {:?}", other),
        }
    }

    #[test]
    fn test_retry_builder() {
        let config = RetryConfigBuilder::new()
            .max_retries(5)
            .initial_delay(Duration::from_millis(50))
            .max_delay(Duration::from_secs(10))
            .multiplier(3.0)
            .use_jitter(false)
            .build();

        assert_eq!(config.max_retries, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(50));
        assert_eq!(config.max_delay, Duration::from_secs(10));
        assert_eq!(config.multiplier, 3.0);
        assert!(!config.use_jitter);
    }

    #[tokio::test]
    async fn test_zero_retries() {
        let config = RetryConfig {
            max_retries: 0,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let result = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;

        // Only the initial attempt runs.
        assert_eq!(counter.load(Ordering::SeqCst), 1);
        assert!(matches!(
            result,
            Err(RetryError::MaxRetriesExceeded { attempts: 1, .. })
        ));
    }

    #[tokio::test]
    async fn test_concurrent_retries() {
        // A fast backoff keeps the failing tasks from sleeping for most of a second.
        let policy = Arc::new(RetryPolicy::new(fast_config(3)));
        let mut handles = vec![];

        for i in 0..5 {
            let policy_clone = Arc::clone(&policy);
            let handle = tokio::spawn(async move {
                policy_clone
                    .execute(|| async move {
                        if i % 2 == 0 {
                            Ok(format!("success {}", i))
                        } else {
                            Err(format!("error {}", i))
                        }
                    })
                    .await
            });
            handles.push(handle);
        }

        for (i, handle) in handles.into_iter().enumerate() {
            let result = handle.await.unwrap();
            if i % 2 == 0 {
                assert!(result.is_ok());
            } else {
                assert!(result.is_err());
            }
        }
    }

    #[test]
    fn test_jitter_factor_clamping() {
        let config = RetryConfigBuilder::new().jitter_factor(1.5).build();
        assert_eq!(config.jitter_factor, 1.0);

        let config = RetryConfigBuilder::new().jitter_factor(-0.5).build();
        assert_eq!(config.jitter_factor, 0.0);
    }

    #[tokio::test]
    async fn test_timing_accuracy() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(100),
            multiplier: 2.0,
            use_jitter: false,
            ..Default::default()
        };
        let policy = RetryPolicy::new(config);
        let counter = Arc::new(AtomicU32::new(0));
        let counter_clone = Arc::clone(&counter);

        let start = std::time::Instant::now();

        let _ = policy
            .execute(|| async {
                counter_clone.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("error")
            })
            .await;

        let elapsed = start.elapsed();

        assert_eq!(counter.load(Ordering::SeqCst), 3);
        // Two waits: 100 ms + 200 ms.
        assert!(
            elapsed >= Duration::from_millis(300),
            "Elapsed time {:?} less than expected",
            elapsed
        );
    }
}