1use crate::{Error, Result};
2use std::time::Duration;
3
/// Configuration for retrying a fallible operation with exponential backoff.
///
/// As used by `retry_with_policy`, `max_attempts` counts *retries*: the
/// operation is always tried once and then retried up to `max_attempts` more
/// times, so `0` means "never retry".
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of retries after the initial attempt.
    pub max_attempts: u32,
    /// Delay before the first retry, in milliseconds.
    pub initial_backoff_ms: u64,
    /// Upper bound on any single delay, in milliseconds.
    pub max_backoff_ms: u64,
    /// Factor by which the delay is multiplied for each subsequent retry.
    pub backoff_multiplier: f64,
}

impl RetryPolicy {
    /// Creates a policy from explicit parameters.
    pub fn new(
        max_attempts: u32,
        initial_backoff_ms: u64,
        max_backoff_ms: u64,
        backoff_multiplier: f64,
    ) -> Self {
        Self {
            max_attempts,
            initial_backoff_ms,
            max_backoff_ms,
            backoff_multiplier,
        }
    }

    /// A policy that never retries and never waits.
    pub fn no_retry() -> Self {
        Self {
            max_attempts: 0,
            initial_backoff_ms: 0,
            max_backoff_ms: 0,
            backoff_multiplier: 1.0,
        }
    }

    /// A short policy: 3 retries starting at 10 ms, doubling, capped at 100 ms.
    pub fn fast() -> Self {
        Self {
            max_attempts: 3,
            initial_backoff_ms: 10,
            max_backoff_ms: 100,
            backoff_multiplier: 2.0,
        }
    }

    /// The default policy: 5 retries starting at 100 ms, doubling, capped at 5 s.
    pub fn standard() -> Self {
        Self {
            max_attempts: 5,
            initial_backoff_ms: 100,
            max_backoff_ms: 5000,
            backoff_multiplier: 2.0,
        }
    }

    /// Returns the delay to wait before retry number `attempt` (0-based).
    ///
    /// The delay is `initial_backoff_ms * backoff_multiplier^attempt`, capped at
    /// `max_backoff_ms`. A zero initial delay always yields a zero delay.
    pub fn backoff_duration(&self, attempt: u32) -> Duration {
        // Without this guard, `0 * inf` (huge attempt numbers overflow the growth
        // factor) is NaN, and `f64::min` would then return the cap instead of 0.
        if self.initial_backoff_ms == 0 {
            return Duration::from_millis(0);
        }
        // `powi` takes an `i32`; a plain `as` cast wraps attempts >= 2^31 to
        // negative exponents and shrinks the delay. Saturate so the delay keeps
        // growing until it reaches the cap.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let growth = self.backoff_multiplier.powi(exponent);
        let backoff_ms =
            (self.initial_backoff_ms as f64 * growth).min(self.max_backoff_ms as f64);
        // Float-to-int `as` saturates (and maps NaN to 0), so this cast cannot
        // overflow.
        Duration::from_millis(backoff_ms as u64)
    }
}

impl Default for RetryPolicy {
    /// Same as [`RetryPolicy::standard`].
    fn default() -> Self {
        Self::standard()
    }
}
84
85pub fn retry_with_policy<F, T>(
108 policy: &RetryPolicy,
109 mut operation: F,
110) -> Result<T>
111where
112 F: FnMut() -> Result<T>,
113{
114 let mut last_error = None;
115
116 match operation() {
118 Ok(result) => return Ok(result),
119 Err(e) => {
120 if !e.is_retryable() {
121 return Err(e);
123 }
124 last_error = Some(e);
125 }
126 }
127
128 for attempt in 0..policy.max_attempts {
130 let backoff = policy.backoff_duration(attempt);
131 std::thread::sleep(backoff);
132
133 match operation() {
134 Ok(result) => return Ok(result),
135 Err(e) => {
136 if !e.is_retryable() {
137 return Err(e);
139 }
140 last_error = Some(e);
141 }
142 }
143 }
144
145 Err(last_error.unwrap_or_else(|| {
147 Error::Internal("retry exhausted without error".to_string())
148 }))
149}
150
151pub fn retry<F, T>(operation: F) -> Result<T>
168where
169 F: FnMut() -> Result<T>,
170{
171 retry_with_policy(&RetryPolicy::default(), operation)
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an I/O timeout error; the tests below rely on it being retryable.
    fn timeout_error() -> Error {
        Error::Io(std::io::Error::new(
            std::io::ErrorKind::TimedOut,
            "timeout",
        ))
    }

    #[test]
    fn test_retry_policy_default() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_attempts, 5);
        assert_eq!(policy.initial_backoff_ms, 100);
        assert_eq!(policy.max_backoff_ms, 5000);
        assert_eq!(policy.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_policy_fast() {
        let policy = RetryPolicy::fast();
        assert_eq!(policy.max_attempts, 3);
        assert_eq!(policy.initial_backoff_ms, 10);
    }

    #[test]
    fn test_retry_policy_no_retry() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_attempts, 0);
    }

    #[test]
    fn test_backoff_duration_exponential() {
        let policy = RetryPolicy::new(5, 100, 10000, 2.0);

        assert_eq!(policy.backoff_duration(0).as_millis(), 100);
        assert_eq!(policy.backoff_duration(1).as_millis(), 200);
        assert_eq!(policy.backoff_duration(2).as_millis(), 400);
        assert_eq!(policy.backoff_duration(3).as_millis(), 800);
    }

    #[test]
    fn test_backoff_duration_respects_max() {
        let policy = RetryPolicy::new(10, 100, 500, 2.0);

        assert_eq!(policy.backoff_duration(5).as_millis(), 500);
        assert_eq!(policy.backoff_duration(10).as_millis(), 500);
    }

    #[test]
    fn test_retry_succeeds_immediately() {
        let policy = RetryPolicy::fast();
        let mut calls = 0u32;

        let result = retry_with_policy(&policy, || {
            calls += 1;
            Ok::<i32, Error>(42)
        });

        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_retry_succeeds_after_failures() {
        let policy = RetryPolicy::fast();
        let mut calls = 0u32;

        let result = retry_with_policy(&policy, || {
            calls += 1;
            if calls < 3 {
                Err(timeout_error())
            } else {
                Ok::<i32, Error>(42)
            }
        });

        assert_eq!(result.unwrap(), 42);
        assert_eq!(calls, 3);
    }

    #[test]
    fn test_retry_fails_after_max_attempts() {
        let policy = RetryPolicy::new(2, 1, 10, 1.5);
        let mut calls = 0u32;

        let result = retry_with_policy(&policy, || {
            calls += 1;
            Err::<i32, Error>(timeout_error())
        });

        assert!(result.is_err());
        // One initial attempt plus `max_attempts` (2) retries.
        assert_eq!(calls, 3);
    }

    #[test]
    fn test_retry_does_not_retry_non_retryable_error() {
        let policy = RetryPolicy::fast();
        let mut calls = 0u32;

        let result = retry_with_policy(&policy, || {
            calls += 1;
            Err::<i32, Error>(Error::InvalidArgument("bad input".to_string()))
        });

        assert_eq!(calls, 1);
        assert!(
            matches!(result, Err(Error::InvalidArgument(_))),
            "Expected InvalidArgument error"
        );
    }

    #[test]
    fn test_retry_no_retry_policy() {
        let policy = RetryPolicy::no_retry();
        let mut calls = 0u32;

        let result = retry_with_policy(&policy, || {
            calls += 1;
            Err::<i32, Error>(timeout_error())
        });

        assert!(result.is_err());
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_retry_helper_function() {
        let mut calls = 0u32;

        let result = retry(|| {
            calls += 1;
            if calls < 2 {
                Err(timeout_error())
            } else {
                Ok::<i32, Error>(100)
            }
        });

        assert_eq!(result.unwrap(), 100);
        assert_eq!(calls, 2);
    }
}