// Backoff math deliberately mixes integer seconds (`u32` retry counts, `u64` delays)
// with `f64` multipliers and powers, so lossy `as` casts between them are intended.
// Float-to-integer `as` casts saturate (NaN -> 0, out-of-range -> MIN/MAX) rather than
// invoking UB, which is why the cast lints below are silenced for this module.
#![allow(
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss,
    clippy::cast_possible_wrap
)]
7use rand::Rng;
13use serde::{Deserialize, Serialize};
14
/// Backoff strategy: how long to wait before each retry of a failed operation.
///
/// All delays are whole seconds — the unit the `Display` impl renders them in
/// (e.g. `Fixed(5s)`). `RetryStrategy::calculate_delay` turns a strategy plus a
/// `retry_count` (starting at `0`) into a concrete delay; the formulas below use `n`
/// for that count. Wherever a variant has `max_delay: Option<u64>`, `None` means
/// "uncapped" and `Some(max)` clamps the computed delay to `max`.
///
/// The serde representation is internally tagged: the variant name is stored in a
/// `"type"` field next to the variant's own fields, e.g. `{"type": "Fixed", "delay": 5}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum RetryStrategy {
    /// The same delay before every retry.
    Fixed {
        /// Delay returned for every `n`.
        delay: u64,
    },

    /// `initial + increment * n`.
    Linear {
        /// Delay for the first retry (`n = 0`).
        initial: u64,
        /// Amount added for each further retry.
        increment: u64,
        /// Optional upper bound on the computed delay.
        max_delay: Option<u64>,
    },

    /// `initial * multiplier^n`, truncated to whole seconds.
    Exponential {
        /// Delay for the first retry (`n = 0`).
        initial: u64,
        /// Growth factor applied once per retry.
        multiplier: f64,
        /// Optional upper bound on the computed delay.
        max_delay: Option<u64>,
    },

    /// `initial * (n + 1)^power`, truncated to whole seconds.
    Polynomial {
        /// Delay for the first retry (`n = 0`, since `1^power == 1`).
        initial: u64,
        /// Exponent applied to `n + 1`.
        power: f64,
        /// Optional upper bound on the computed delay.
        max_delay: Option<u64>,
    },

    /// `initial * F(n + 2)` where `F` is the Fibonacci sequence, i.e. `initial`
    /// multiplied by 1, 2, 3, 5, 8, … for `n = 0, 1, 2, …`.
    Fibonacci {
        /// Scale factor applied to the Fibonacci number.
        initial: u64,
        /// Optional upper bound on the computed delay.
        max_delay: Option<u64>,
    },

    /// "Decorrelated jitter": a random delay between `base` and
    /// `min(3 * previous_delay, max_delay)`, where `previous_delay` is the value passed
    /// to `calculate_delay` (or `base` when none is given). See `calculate_delay` for
    /// the handling of an empty range.
    DecorrelatedJitter {
        /// Lower bound of the random range, and the starting point when no previous
        /// delay is known.
        base: u64,
        /// Cap on the upper end of the random range.
        max_delay: u64,
    },

    /// "Full jitter": a uniformly random delay in `0..=capped`, where `capped` is the
    /// `Exponential` delay for `n` after applying `max_delay`.
    FullJitter {
        /// Exponential base delay for `n = 0`.
        initial: u64,
        /// Exponential growth factor.
        multiplier: f64,
        /// Optional cap applied before randomizing.
        max_delay: Option<u64>,
    },

    /// "Equal jitter": `capped / 2` (integer division) plus a random amount in
    /// `0..=capped / 2`, where `capped` is the `Exponential` delay for `n` after
    /// applying `max_delay`.
    EqualJitter {
        /// Exponential base delay for `n = 0`.
        initial: u64,
        /// Exponential growth factor.
        multiplier: f64,
        /// Optional cap applied before randomizing.
        max_delay: Option<u64>,
    },

    /// Explicit per-retry delays: `delays[n]`, or `fallback` once `n` runs past the list.
    Custom {
        /// Delay for each retry, indexed by `n`.
        delays: Vec<u64>,
        /// Delay used for every retry beyond the end of `delays`.
        fallback: u64,
    },

    /// Retry without waiting (the delay is always `0`).
    Immediate,
}
103
104impl Default for RetryStrategy {
105 fn default() -> Self {
106 Self::Exponential {
107 initial: 1,
108 multiplier: 2.0,
109 max_delay: Some(3600),
110 }
111 }
112}
113
114impl RetryStrategy {
115 #[must_use]
117 pub fn fixed(delay: u64) -> Self {
118 Self::Fixed { delay }
119 }
120
121 #[must_use]
123 pub fn linear(initial: u64, increment: u64) -> Self {
124 Self::Linear {
125 initial,
126 increment,
127 max_delay: None,
128 }
129 }
130
131 #[must_use]
133 pub fn linear_with_max(initial: u64, increment: u64, max_delay: u64) -> Self {
134 Self::Linear {
135 initial,
136 increment,
137 max_delay: Some(max_delay),
138 }
139 }
140
141 #[must_use]
143 pub fn exponential(initial: u64, multiplier: f64) -> Self {
144 Self::Exponential {
145 initial,
146 multiplier,
147 max_delay: None,
148 }
149 }
150
151 #[must_use]
153 pub fn exponential_with_max(initial: u64, multiplier: f64, max_delay: u64) -> Self {
154 Self::Exponential {
155 initial,
156 multiplier,
157 max_delay: Some(max_delay),
158 }
159 }
160
161 #[must_use]
163 pub fn polynomial(initial: u64, power: f64) -> Self {
164 Self::Polynomial {
165 initial,
166 power,
167 max_delay: None,
168 }
169 }
170
171 #[must_use]
173 pub fn fibonacci(initial: u64) -> Self {
174 Self::Fibonacci {
175 initial,
176 max_delay: None,
177 }
178 }
179
180 #[must_use]
182 pub fn decorrelated_jitter(base: u64, max_delay: u64) -> Self {
183 Self::DecorrelatedJitter { base, max_delay }
184 }
185
186 #[must_use]
188 pub fn full_jitter(initial: u64, multiplier: f64, max_delay: u64) -> Self {
189 Self::FullJitter {
190 initial,
191 multiplier,
192 max_delay: Some(max_delay),
193 }
194 }
195
196 #[must_use]
198 pub fn equal_jitter(initial: u64, multiplier: f64, max_delay: u64) -> Self {
199 Self::EqualJitter {
200 initial,
201 multiplier,
202 max_delay: Some(max_delay),
203 }
204 }
205
206 #[must_use]
208 pub fn custom(delays: Vec<u64>, fallback: u64) -> Self {
209 Self::Custom { delays, fallback }
210 }
211
212 #[must_use]
214 pub fn immediate() -> Self {
215 Self::Immediate
216 }
217
218 #[must_use]
227 pub fn calculate_delay(&self, retry_count: u32, previous_delay: Option<u64>) -> u64 {
228 match self {
229 Self::Fixed { delay } => *delay,
230
231 Self::Linear {
232 initial,
233 increment,
234 max_delay,
235 } => {
236 let delay = *initial + (*increment * u64::from(retry_count));
237 max_delay.map_or(delay, |max| delay.min(max))
238 }
239
240 Self::Exponential {
241 initial,
242 multiplier,
243 max_delay,
244 } => {
245 let delay = (*initial as f64 * multiplier.powi(retry_count as i32)) as u64;
246 max_delay.map_or(delay, |max| delay.min(max))
247 }
248
249 Self::Polynomial {
250 initial,
251 power,
252 max_delay,
253 } => {
254 let delay = (*initial as f64 * (f64::from(retry_count) + 1.0).powf(*power)) as u64;
255 max_delay.map_or(delay, |max| delay.min(max))
256 }
257
258 Self::Fibonacci { initial, max_delay } => {
259 let delay = *initial * fibonacci_number(retry_count + 2);
262 max_delay.map_or(delay, |max| delay.min(max))
263 }
264
265 Self::DecorrelatedJitter { base, max_delay } => {
266 let prev = previous_delay.unwrap_or(*base);
267 let upper = (prev * 3).min(*max_delay);
268 let lower = *base;
269 if upper <= lower {
270 lower
271 } else {
272 rand::rng().random_range(lower..=upper)
273 }
274 }
275
276 Self::FullJitter {
277 initial,
278 multiplier,
279 max_delay,
280 } => {
281 let exp_delay = (*initial as f64 * multiplier.powi(retry_count as i32)) as u64;
282 let capped = max_delay.map_or(exp_delay, |max| exp_delay.min(max));
283 if capped == 0 {
284 0
285 } else {
286 rand::rng().random_range(0..=capped)
287 }
288 }
289
290 Self::EqualJitter {
291 initial,
292 multiplier,
293 max_delay,
294 } => {
295 let exp_delay = (*initial as f64 * multiplier.powi(retry_count as i32)) as u64;
296 let capped = max_delay.map_or(exp_delay, |max| exp_delay.min(max));
297 let half = capped / 2;
298 if half == 0 {
299 half
300 } else {
301 half + rand::rng().random_range(0..=half)
302 }
303 }
304
305 Self::Custom { delays, fallback } => delays
306 .get(retry_count as usize)
307 .copied()
308 .unwrap_or(*fallback),
309
310 Self::Immediate => 0,
311 }
312 }
313
314 #[inline]
316 #[must_use]
317 pub const fn name(&self) -> &'static str {
318 match self {
319 Self::Fixed { .. } => "fixed",
320 Self::Linear { .. } => "linear",
321 Self::Exponential { .. } => "exponential",
322 Self::Polynomial { .. } => "polynomial",
323 Self::Fibonacci { .. } => "fibonacci",
324 Self::DecorrelatedJitter { .. } => "decorrelated_jitter",
325 Self::FullJitter { .. } => "full_jitter",
326 Self::EqualJitter { .. } => "equal_jitter",
327 Self::Custom { .. } => "custom",
328 Self::Immediate => "immediate",
329 }
330 }
331
332 #[inline]
334 #[must_use]
335 pub const fn is_jittered(&self) -> bool {
336 matches!(
337 self,
338 Self::DecorrelatedJitter { .. } | Self::FullJitter { .. } | Self::EqualJitter { .. }
339 )
340 }
341}
342
343impl std::fmt::Display for RetryStrategy {
344 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
345 match self {
346 Self::Fixed { delay } => write!(f, "Fixed({delay}s)"),
347 Self::Linear {
348 initial, increment, ..
349 } => write!(f, "Linear({initial}s + {increment}s/retry)"),
350 Self::Exponential {
351 initial,
352 multiplier,
353 ..
354 } => write!(f, "Exponential({initial}s * {multiplier}^n)"),
355 Self::Polynomial { initial, power, .. } => {
356 write!(f, "Polynomial({initial}s * n^{power})")
357 }
358 Self::Fibonacci { initial, .. } => write!(f, "Fibonacci({initial}s)"),
359 Self::DecorrelatedJitter { base, max_delay } => {
360 write!(f, "DecorrelatedJitter(base={base}s, max={max_delay}s)")
361 }
362 Self::FullJitter {
363 initial,
364 multiplier,
365 ..
366 } => write!(f, "FullJitter({initial}s * {multiplier}^n)"),
367 Self::EqualJitter {
368 initial,
369 multiplier,
370 ..
371 } => write!(f, "EqualJitter({initial}s * {multiplier}^n)"),
372 Self::Custom { delays, fallback } => {
373 write!(f, "Custom({} delays, fallback={}s)", delays.len(), fallback)
374 }
375 Self::Immediate => write!(f, "Immediate"),
376 }
377 }
378}
379
/// Returns the `n`-th Fibonacci number (`F(0) = 0`, `F(1) = 1`), saturating at
/// `u64::MAX`.
///
/// `F(93)` is the largest Fibonacci number that fits in a `u64`, so every `n >= 94`
/// yields `u64::MAX`. The loop stops as soon as the sum would overflow, so even
/// `n = u32::MAX` finishes after at most ~93 iterations. An unchecked `a + b` would
/// panic (debug) or wrap (release) here.
fn fibonacci_number(n: u32) -> u64 {
    if n <= 1 {
        return u64::from(n);
    }

    let (mut prev, mut curr) = (0u64, 1u64);
    for _ in 2..=n {
        match prev.checked_add(curr) {
            Some(next) => {
                prev = curr;
                curr = next;
            }
            None => return u64::MAX,
        }
    }

    curr
}
397
/// Decides whether a failed operation should be retried, and how long to wait first.
///
/// Serializable with serde. `retry_on`, `dont_retry_on`, `retry_on_timeout` and
/// `preserve_on_failure` may be omitted and then take the defaults noted on each
/// field; `max_retries` and `strategy` have no serde default and must be present.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of retries. `should_retry` refuses once `retry_count` reaches
    /// this value, so `0` disables retrying entirely.
    pub max_retries: u32,

    /// Backoff strategy used by `get_retry_delay` to compute the wait before a retry.
    pub strategy: RetryStrategy,

    /// Substrings that make an error retryable. When empty, every error not matched
    /// by `dont_retry_on` is retryable. Defaults to empty.
    #[serde(default)]
    pub retry_on: Vec<String>,

    /// Substrings that make an error non-retryable. Checked before `retry_on`, so a
    /// match here always wins. Defaults to empty.
    #[serde(default)]
    pub dont_retry_on: Vec<String>,

    /// Whether timeouts should be retried; defaults to `true` when absent.
    /// NOTE(review): not consulted by `should_retry` — presumably read by whichever
    /// caller detects timeouts; confirm.
    #[serde(default = "default_true")]
    pub retry_on_timeout: bool,

    /// Defaults to `false` when absent.
    /// NOTE(review): not read anywhere in this file; judging by the name it controls
    /// whether state from a finally-failed operation is kept — confirm with consumers.
    #[serde(default)]
    pub preserve_on_failure: bool,
}

/// Serde default for `RetryPolicy::retry_on_timeout`. Serde's `default = "..."`
/// attribute takes a function path, not a literal, hence this helper.
fn default_true() -> bool {
    true
}
427
428impl Default for RetryPolicy {
429 fn default() -> Self {
430 Self {
431 max_retries: 3,
432 strategy: RetryStrategy::default(),
433 retry_on: Vec::new(),
434 dont_retry_on: Vec::new(),
435 retry_on_timeout: true,
436 preserve_on_failure: false,
437 }
438 }
439}
440
441impl RetryPolicy {
442 #[must_use]
444 pub fn new(max_retries: u32, strategy: RetryStrategy) -> Self {
445 Self {
446 max_retries,
447 strategy,
448 ..Default::default()
449 }
450 }
451
452 #[must_use]
454 pub fn no_retry() -> Self {
455 Self {
456 max_retries: 0,
457 strategy: RetryStrategy::Immediate,
458 ..Default::default()
459 }
460 }
461
462 #[must_use]
464 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
465 self.max_retries = max_retries;
466 self
467 }
468
469 #[must_use]
471 pub fn with_strategy(mut self, strategy: RetryStrategy) -> Self {
472 self.strategy = strategy;
473 self
474 }
475
476 #[must_use]
478 pub fn retry_on(mut self, patterns: Vec<String>) -> Self {
479 self.retry_on = patterns;
480 self
481 }
482
483 #[must_use]
485 pub fn dont_retry_on(mut self, patterns: Vec<String>) -> Self {
486 self.dont_retry_on = patterns;
487 self
488 }
489
490 #[must_use]
492 pub fn with_retry_on_timeout(mut self, retry: bool) -> Self {
493 self.retry_on_timeout = retry;
494 self
495 }
496
497 #[must_use]
499 pub fn should_retry(&self, error: &str, retry_count: u32) -> bool {
500 if retry_count >= self.max_retries {
502 return false;
503 }
504
505 for pattern in &self.dont_retry_on {
507 if error.contains(pattern) {
508 return false;
509 }
510 }
511
512 if self.retry_on.is_empty() {
514 return true;
515 }
516
517 for pattern in &self.retry_on {
519 if error.contains(pattern) {
520 return true;
521 }
522 }
523
524 false
525 }
526
527 #[inline]
529 #[must_use]
530 pub fn get_retry_delay(&self, retry_count: u32, previous_delay: Option<u64>) -> u64 {
531 self.strategy.calculate_delay(retry_count, previous_delay)
532 }
533
534 #[inline]
536 #[must_use]
537 pub const fn allows_retry(&self) -> bool {
538 self.max_retries > 0
539 }
540}
541
542impl std::fmt::Display for RetryPolicy {
543 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
544 write!(
545 f,
546 "RetryPolicy(max={}, strategy={})",
547 self.max_retries, self.strategy
548 )
549 }
550}
551
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed: the same delay regardless of the retry count.
    #[test]
    fn test_fixed_delay() {
        let strategy = RetryStrategy::fixed(5);
        assert_eq!(strategy.calculate_delay(0, None), 5);
        assert_eq!(strategy.calculate_delay(1, None), 5);
        assert_eq!(strategy.calculate_delay(10, None), 5);
    }

    // Linear: initial + increment * n.
    #[test]
    fn test_linear_backoff() {
        let strategy = RetryStrategy::linear(1, 2);
        assert_eq!(strategy.calculate_delay(0, None), 1);
        assert_eq!(strategy.calculate_delay(1, None), 3);
        assert_eq!(strategy.calculate_delay(2, None), 5);
        assert_eq!(strategy.calculate_delay(3, None), 7);
    }

    // Linear with a cap: growth stops at max_delay (5).
    #[test]
    fn test_linear_with_max() {
        let strategy = RetryStrategy::linear_with_max(1, 2, 5);
        assert_eq!(strategy.calculate_delay(0, None), 1);
        assert_eq!(strategy.calculate_delay(1, None), 3);
        assert_eq!(strategy.calculate_delay(2, None), 5);
        assert_eq!(strategy.calculate_delay(3, None), 5);
        assert_eq!(strategy.calculate_delay(10, None), 5);
    }

    // Exponential: initial * multiplier^n.
    #[test]
    fn test_exponential_backoff() {
        let strategy = RetryStrategy::exponential(1, 2.0);
        assert_eq!(strategy.calculate_delay(0, None), 1);
        assert_eq!(strategy.calculate_delay(1, None), 2);
        assert_eq!(strategy.calculate_delay(2, None), 4);
        assert_eq!(strategy.calculate_delay(3, None), 8);
    }

    // Exponential with a cap: the uncapped 8 at n = 3 is clamped to 5.
    #[test]
    fn test_exponential_with_max() {
        let strategy = RetryStrategy::exponential_with_max(1, 2.0, 5);
        assert_eq!(strategy.calculate_delay(0, None), 1);
        assert_eq!(strategy.calculate_delay(1, None), 2);
        assert_eq!(strategy.calculate_delay(2, None), 4);
        assert_eq!(strategy.calculate_delay(3, None), 5);
    }

    // Fibonacci: initial * F(n + 2).
    #[test]
    fn test_fibonacci_backoff() {
        let strategy = RetryStrategy::fibonacci(1);
        assert_eq!(strategy.calculate_delay(0, None), 1); // F(2)
        assert_eq!(strategy.calculate_delay(1, None), 2); // F(3)
        assert_eq!(strategy.calculate_delay(2, None), 3); // F(4)
        assert_eq!(strategy.calculate_delay(3, None), 5); // F(5)
        assert_eq!(strategy.calculate_delay(4, None), 8); // F(6)
    }

    // Custom: listed delays by index, then the fallback once the list runs out.
    #[test]
    fn test_custom_delays() {
        let strategy = RetryStrategy::custom(vec![1, 5, 10, 30], 60);
        assert_eq!(strategy.calculate_delay(0, None), 1);
        assert_eq!(strategy.calculate_delay(1, None), 5);
        assert_eq!(strategy.calculate_delay(2, None), 10);
        assert_eq!(strategy.calculate_delay(3, None), 30);
        assert_eq!(strategy.calculate_delay(4, None), 60);
        assert_eq!(strategy.calculate_delay(10, None), 60);
    }

    // Immediate: always zero.
    #[test]
    fn test_immediate() {
        let strategy = RetryStrategy::immediate();
        assert_eq!(strategy.calculate_delay(0, None), 0);
        assert_eq!(strategy.calculate_delay(10, None), 0);
    }

    // Pattern filtering: dont_retry_on wins, retry_on whitelists, max_retries caps.
    #[test]
    fn test_retry_policy_should_retry() {
        let policy = RetryPolicy::new(3, RetryStrategy::fixed(1))
            .retry_on(vec!["timeout".to_string(), "connection".to_string()])
            .dont_retry_on(vec!["fatal".to_string()]);

        assert!(policy.should_retry("connection refused", 0));
        assert!(policy.should_retry("timeout error", 1));
        assert!(!policy.should_retry("fatal error", 0));
        // Retry budget exhausted (retry_count == max_retries).
        assert!(!policy.should_retry("connection error", 3));
    }

    // With no retry_on patterns, any error is retryable within the budget.
    #[test]
    fn test_retry_policy_empty_retry_on() {
        let policy = RetryPolicy::new(3, RetryStrategy::fixed(1));

        assert!(policy.should_retry("any error", 0));
        assert!(policy.should_retry("another error", 1));
    }

    // fibonacci_number uses F(0) = 0, F(1) = 1.
    #[test]
    fn test_fibonacci_numbers() {
        assert_eq!(fibonacci_number(0), 0);
        assert_eq!(fibonacci_number(1), 1);
        assert_eq!(fibonacci_number(2), 1);
        assert_eq!(fibonacci_number(3), 2);
        assert_eq!(fibonacci_number(4), 3);
        assert_eq!(fibonacci_number(5), 5);
        assert_eq!(fibonacci_number(6), 8);
        assert_eq!(fibonacci_number(10), 55);
    }

    #[test]
    fn test_strategy_names() {
        assert_eq!(RetryStrategy::fixed(1).name(), "fixed");
        assert_eq!(RetryStrategy::linear(1, 1).name(), "linear");
        assert_eq!(RetryStrategy::exponential(1, 2.0).name(), "exponential");
        assert_eq!(RetryStrategy::fibonacci(1).name(), "fibonacci");
        assert_eq!(RetryStrategy::immediate().name(), "immediate");
    }

    // Display output is compared verbatim; f64 2.0 renders as "2".
    #[test]
    fn test_strategy_display() {
        assert_eq!(format!("{}", RetryStrategy::fixed(5)), "Fixed(5s)");
        assert_eq!(
            format!("{}", RetryStrategy::linear(1, 2)),
            "Linear(1s + 2s/retry)"
        );
        assert_eq!(
            format!("{}", RetryStrategy::exponential(1, 2.0)),
            "Exponential(1s * 2^n)"
        );
    }

    // Property-based checks over randomized inputs (proptest).
    mod proptests {
        use super::*;
        use proptest::prelude::*;

        proptest! {
            #[test]
            fn test_fixed_delay_is_constant(delay in 1u64..10000, attempt in 0u32..100) {
                let strategy = RetryStrategy::fixed(delay);
                let calculated_delay = strategy.calculate_delay(attempt, Some(delay));
                prop_assert_eq!(calculated_delay, delay);
            }

            #[test]
            fn test_linear_delay_increases_linearly(
                initial in 100u64..1000,
                increment in 100u64..1000,
                attempt in 0u32..50
            ) {
                let strategy = RetryStrategy::linear(initial, increment);
                let expected = initial + (increment * u64::from(attempt));
                let calculated = strategy.calculate_delay(attempt, None);
                prop_assert_eq!(calculated, expected);
            }

            // Multipliers > 1 must give a non-decreasing sequence.
            #[test]
            fn test_exponential_delay_grows(
                initial in 100u64..1000,
                multiplier in 1.5f64..3.0,
                attempt in 0u32..10
            ) {
                let strategy = RetryStrategy::exponential(initial, multiplier);
                let delay1 = strategy.calculate_delay(attempt, None);
                let delay2 = strategy.calculate_delay(attempt + 1, Some(delay1));

                prop_assert!(delay2 >= delay1);
            }

            #[test]
            fn test_exponential_with_max_respects_limit(
                initial in 100u64..1000,
                multiplier in 2.0f64..4.0,
                max_delay in 5000u64..10000,
                attempt in 0u32..20
            ) {
                let strategy = RetryStrategy::exponential_with_max(initial, multiplier, max_delay);
                let calculated = strategy.calculate_delay(attempt, None);
                prop_assert!(calculated <= max_delay);
            }

            #[test]
            fn test_fibonacci_delay_grows(
                initial in 100u64..1000,
                attempt in 1u32..15
            ) {
                let strategy = RetryStrategy::fibonacci(initial);
                let delay1 = strategy.calculate_delay(attempt, None);
                let delay2 = strategy.calculate_delay(attempt + 1, Some(delay1));

                prop_assert!(delay2 >= delay1);
            }

            #[test]
            fn test_immediate_is_always_zero(attempt in 0u32..1000) {
                let strategy = RetryStrategy::immediate();
                let delay = strategy.calculate_delay(attempt, None);
                prop_assert_eq!(delay, 0);
            }

            // Random draw must stay within the capped range.
            #[test]
            fn test_full_jitter_within_bounds(
                initial in 100u64..1000,
                multiplier in 2.0f64..3.0,
                max_delay in 10000u64..20000,
                attempt in 0u32..10
            ) {
                let strategy = RetryStrategy::full_jitter(initial, multiplier, max_delay);
                let delay = strategy.calculate_delay(attempt, None);

                prop_assert!(delay <= max_delay);
            }

            // Inputs keep base < max_delay, so both bounds are satisfiable.
            #[test]
            fn test_decorrelated_jitter_within_bounds(
                base in 100u64..1000,
                max_delay in 10000u64..20000,
                attempt in 0u32..50,
                prev_delay in 100u64..5000
            ) {
                let strategy = RetryStrategy::decorrelated_jitter(base, max_delay);
                let delay = strategy.calculate_delay(attempt, Some(prev_delay));

                prop_assert!(delay <= max_delay);
                prop_assert!(delay >= base);
            }

            // The `power >= 1.0` guard is always true for this input range.
            #[test]
            fn test_polynomial_delay_grows(
                initial in 100u64..1000,
                power in 1.0f64..3.0,
                attempt in 1u32..10
            ) {
                let strategy = RetryStrategy::polynomial(initial, power);
                let delay1 = strategy.calculate_delay(attempt, None);
                let delay2 = strategy.calculate_delay(attempt + 1, Some(delay1));

                if power >= 1.0 {
                    prop_assert!(delay2 >= delay1);
                }
            }

            #[test]
            fn test_custom_strategy_uses_provided_delays(
                delays in prop::collection::vec(100u64..5000, 1..10),
                fallback in 1000u64..5000,
                attempt in 0u32..20
            ) {
                let strategy = RetryStrategy::custom(delays.clone(), fallback);
                let calculated = strategy.calculate_delay(attempt, None);

                if (attempt as usize) < delays.len() {
                    prop_assert_eq!(calculated, delays[attempt as usize]);
                } else {
                    prop_assert_eq!(calculated, fallback);
                }
            }

            // With no patterns configured, only the retry budget decides.
            #[test]
            fn test_retry_policy_respects_max_retries(
                max_retries in 0u32..100,
                current_retry in 0u32..150
            ) {
                let policy = RetryPolicy::new(max_retries, RetryStrategy::fixed(1000));

                let should_retry = policy.should_retry("test error", current_retry);

                if current_retry < max_retries {
                    prop_assert!(should_retry);
                } else {
                    prop_assert!(!should_retry);
                }
            }
        }
    }
}