1use std::time::Duration;
2
3use serde::{Deserialize, Serialize};
4
/// Fallback for `RetryConfig::max_retries` when the key is absent.
fn default_max_retries() -> u32 {
    3_u32
}

/// Fallback for `RetryConfig::base_delay_ms` when the key is absent: one second.
fn default_base_delay_ms() -> u64 {
    1_000_u64
}

/// Fallback for `RetryConfig::max_delay_ms` when the key is absent: thirty seconds.
fn default_max_delay_ms() -> u64 {
    30 * 1_000
}

/// Fallback for `RetryConfig::jitter` when the key is absent: jitter enabled.
fn default_jitter() -> bool {
    true
}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct RetryConfig {
21 #[serde(default = "default_max_retries")]
23 pub max_retries: u32,
24
25 #[serde(default = "default_base_delay_ms")]
27 pub base_delay_ms: u64,
28
29 #[serde(default = "default_max_delay_ms")]
31 pub max_delay_ms: u64,
32
33 #[serde(default = "default_jitter")]
35 pub jitter: bool,
36}
37
38impl Default for RetryConfig {
39 fn default() -> Self {
40 Self {
41 max_retries: default_max_retries(),
42 base_delay_ms: default_base_delay_ms(),
43 max_delay_ms: default_max_delay_ms(),
44 jitter: default_jitter(),
45 }
46 }
47}
48
49impl RetryConfig {
50 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
52 let exp = self
53 .base_delay_ms
54 .saturating_mul(1u64.checked_shl(attempt).unwrap_or(u64::MAX));
55 let capped = exp.min(self.max_delay_ms);
56 if self.jitter {
57 Duration::from_millis(fastrand::u64(0..=capped))
58 } else {
59 Duration::from_millis(capped)
60 }
61 }
62}
63
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.base_delay_ms, 1000);
        assert_eq!(config.max_delay_ms, 30_000);
        assert!(config.jitter);
    }

    #[test]
    fn delay_exponential_no_jitter() {
        let config = RetryConfig {
            jitter: false,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(1000));
        assert_eq!(config.delay_for_attempt(1), Duration::from_millis(2000));
        assert_eq!(config.delay_for_attempt(2), Duration::from_millis(4000));
        assert_eq!(config.delay_for_attempt(3), Duration::from_millis(8000));
    }

    #[test]
    fn delay_capped_at_max() {
        // Use a non-default cap so the test fails if `max_delay_ms` is ignored
        // in favour of the built-in 30 s default.
        let config = RetryConfig {
            jitter: false,
            max_delay_ms: 5_000,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(10), Duration::from_millis(5_000));
    }

    #[test]
    fn delay_cap_below_base_wins() {
        let config = RetryConfig {
            jitter: false,
            base_delay_ms: 5_000,
            max_delay_ms: 1_000,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(1_000));
    }

    #[test]
    fn delay_saturates_instead_of_overflowing() {
        // Attempts >= 64 take the `checked_shl` -> None path; 63 overflows the
        // multiplication. All must clamp to the cap rather than wrap or panic.
        let config = RetryConfig {
            jitter: false,
            ..Default::default()
        };
        for attempt in [31, 63, 64, 65, 200, u32::MAX] {
            assert_eq!(
                config.delay_for_attempt(attempt),
                Duration::from_millis(30_000),
                "attempt {attempt}"
            );
        }
    }

    #[test]
    fn delay_zero_base_stays_zero() {
        let config = RetryConfig {
            jitter: false,
            base_delay_ms: 0,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(0), Duration::ZERO);
        assert_eq!(config.delay_for_attempt(100), Duration::ZERO);
    }

    #[test]
    fn delay_with_jitter_in_bounds() {
        let config = RetryConfig::default();
        for _ in 0..100 {
            let delay = config.delay_for_attempt(0);
            assert!(delay <= Duration::from_millis(1000));
        }
    }

    #[test]
    fn delay_with_jitter_never_exceeds_cap() {
        let config = RetryConfig::default();
        for attempt in [5, 64, u32::MAX] {
            for _ in 0..100 {
                assert!(config.delay_for_attempt(attempt) <= Duration::from_millis(30_000));
            }
        }
    }

    #[test]
    fn delay_attempt_zero() {
        let config = RetryConfig {
            jitter: false,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(0), Duration::from_millis(1000));
    }

    #[test]
    fn serde_roundtrip() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 500,
            max_delay_ms: 10_000,
            jitter: false,
        };
        let toml_str = toml::to_string(&config).unwrap();
        let parsed: RetryConfig = toml::from_str(&toml_str).unwrap();
        assert_eq!(parsed.max_retries, 5);
        assert_eq!(parsed.base_delay_ms, 500);
        assert_eq!(parsed.max_delay_ms, 10_000);
        assert!(!parsed.jitter);
    }

    #[test]
    fn serde_defaults_from_empty() {
        let parsed: RetryConfig = toml::from_str("").unwrap();
        assert_eq!(parsed.max_retries, 3);
        assert_eq!(parsed.base_delay_ms, 1000);
        assert_eq!(parsed.max_delay_ms, 30_000);
        assert!(parsed.jitter);
    }
}