1use std::time::Duration;
4
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
9pub struct RetryConfig {
10 pub max_attempts: u32,
12
13 #[serde(
15 rename = "initial_delay_ms",
16 serialize_with = "serialize_duration_ms",
17 deserialize_with = "deserialize_duration_ms"
18 )]
19 pub initial_delay: Duration,
20
21 #[serde(
23 rename = "max_delay_ms",
24 serialize_with = "serialize_duration_ms",
25 deserialize_with = "deserialize_duration_ms"
26 )]
27 pub max_delay: Duration,
28
29 pub multiplier: f64,
31}
32
33impl Default for RetryConfig {
34 fn default() -> Self {
35 Self {
36 max_attempts: 3,
37 initial_delay: Duration::from_millis(100),
38 max_delay: Duration::from_secs(10),
39 multiplier: 2.0,
40 }
41 }
42}
43
44impl RetryConfig {
45 pub fn builder() -> RetryConfigBuilder {
47 RetryConfigBuilder::default()
48 }
49
50 pub fn no_retry() -> Self {
52 Self {
53 max_attempts: 1,
54 ..Default::default()
55 }
56 }
57
58 #[allow(
60 clippy::cast_precision_loss,
61 clippy::cast_possible_wrap,
62 clippy::cast_possible_truncation,
63 clippy::cast_sign_loss
64 )]
65 pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
66 let delay_ms = self.initial_delay.as_millis() as f64 * self.multiplier.powi(attempt as i32);
67 let delay = Duration::from_millis(delay_ms as u64);
68 delay.min(self.max_delay)
69 }
70}
71
72#[derive(Debug, Default)]
74#[must_use]
75pub struct RetryConfigBuilder {
76 config: RetryConfig,
77}
78
79impl RetryConfigBuilder {
80 pub fn max_attempts(mut self, n: u32) -> Self {
82 self.config.max_attempts = n;
83 self
84 }
85
86 pub fn initial_delay(mut self, delay: Duration) -> Self {
88 self.config.initial_delay = delay;
89 self
90 }
91
92 pub fn max_delay(mut self, delay: Duration) -> Self {
94 self.config.max_delay = delay;
95 self
96 }
97
98 pub fn multiplier(mut self, m: f64) -> Self {
100 self.config.multiplier = m;
101 self
102 }
103
104 pub fn build(self) -> RetryConfig {
106 self.config
107 }
108}
109
110#[allow(clippy::cast_possible_truncation)]
111fn serialize_duration_ms<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
112where
113 S: serde::Serializer,
114{
115 serializer.serialize_u64(duration.as_millis() as u64)
116}
117
118fn deserialize_duration_ms<'de, D>(deserializer: D) -> Result<Duration, D::Error>
119where
120 D: serde::Deserializer<'de>,
121{
122 let ms = u64::deserialize(deserializer)?;
123 Ok(Duration::from_millis(ms))
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Epsilon comparison for the `f64` multiplier field.
    fn same_multiplier(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < f64::EPSILON
    }

    #[test]
    fn retry_config_default_values() {
        let defaults = RetryConfig::default();

        assert_eq!(3, defaults.max_attempts);
        assert_eq!(Duration::from_millis(100), defaults.initial_delay);
        assert_eq!(Duration::from_secs(10), defaults.max_delay);
        assert!(same_multiplier(defaults.multiplier, 2.0));
    }

    #[test]
    fn retry_config_builder_sets_max_attempts() {
        let built = RetryConfig::builder().max_attempts(5).build();
        assert_eq!(5, built.max_attempts);
    }

    #[test]
    fn retry_config_builder_sets_initial_delay() {
        let wanted = Duration::from_millis(500);
        let built = RetryConfig::builder().initial_delay(wanted).build();
        assert_eq!(wanted, built.initial_delay);
    }

    #[test]
    fn retry_config_builder_sets_max_delay() {
        let wanted = Duration::from_secs(30);
        let built = RetryConfig::builder().max_delay(wanted).build();
        assert_eq!(wanted, built.max_delay);
    }

    #[test]
    fn retry_config_builder_sets_multiplier() {
        let built = RetryConfig::builder().multiplier(1.5).build();
        assert!(same_multiplier(built.multiplier, 1.5));
    }

    #[test]
    fn retry_config_delay_for_attempt_increases_exponentially() {
        let policy = RetryConfig::builder()
            .initial_delay(Duration::from_millis(100))
            .multiplier(2.0)
            .max_delay(Duration::from_secs(10))
            .build();

        let expected_ms: [u64; 4] = [100, 200, 400, 800];
        for (attempt, &ms) in (0_u32..).zip(expected_ms.iter()) {
            assert_eq!(policy.delay_for_attempt(attempt), Duration::from_millis(ms));
        }
    }

    #[test]
    fn retry_config_delay_capped_at_max() {
        let policy = RetryConfig::builder()
            .initial_delay(Duration::from_secs(1))
            .multiplier(10.0)
            .max_delay(Duration::from_secs(5))
            .build();

        let expected_secs: [(u32, u64); 3] = [(0, 1), (1, 5), (2, 5)];
        for &(attempt, secs) in &expected_secs {
            assert_eq!(policy.delay_for_attempt(attempt), Duration::from_secs(secs));
        }
    }

    #[test]
    fn retry_config_no_retry_returns_single_attempt() {
        assert_eq!(1, RetryConfig::no_retry().max_attempts);
    }

    #[test]
    fn retry_config_serializes_to_json() {
        let policy = RetryConfig::builder()
            .max_attempts(5)
            .initial_delay(Duration::from_millis(200))
            .build();

        let value = serde_json::to_value(&policy).unwrap();

        assert_eq!(value["max_attempts"], 5);
        assert_eq!(value["initial_delay_ms"], 200);
    }

    #[test]
    fn retry_config_deserializes_from_json() {
        let raw = r#"{
            "max_attempts": 4,
            "initial_delay_ms": 500,
            "max_delay_ms": 30000,
            "multiplier": 1.5
        }"#;

        let parsed: RetryConfig = serde_json::from_str(raw).unwrap();

        assert_eq!(4, parsed.max_attempts);
        assert_eq!(Duration::from_millis(500), parsed.initial_delay);
        assert_eq!(Duration::from_secs(30), parsed.max_delay);
        assert!(same_multiplier(parsed.multiplier, 1.5));
    }

    #[test]
    fn retry_config_serde_roundtrip() {
        let before = RetryConfig::builder()
            .max_attempts(7)
            .initial_delay(Duration::from_millis(250))
            .max_delay(Duration::from_secs(60))
            .multiplier(3.0)
            .build();

        let encoded = serde_json::to_string(&before).unwrap();
        let after: RetryConfig = serde_json::from_str(&encoded).unwrap();

        assert_eq!(before, after);
    }
}