1use crate::{Error, Result};
27use std::future::Future;
28use std::time::Duration;
29use tokio::time::sleep;
30
/// Settings controlling how a fallible async operation is retried with
/// exponential backoff.
///
/// Build one with [`RetryConfig::new`] (or `Default::default()`) and refine it
/// with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Total number of times the operation is invoked, including the first call.
    pub max_attempts: u32,

    /// Delay used after the first failed attempt; later delays grow from it.
    pub initial_delay: Duration,

    /// Ceiling applied to the exponentially grown delay.
    pub max_delay: Duration,

    /// Factor the delay is multiplied by for each successive attempt.
    pub backoff_multiplier: f64,

    /// Fraction of the (capped) delay used as the width of the random jitter
    /// window. The `with_jitter_factor` builder clamps it to `[0.0, 1.0]`.
    pub jitter_factor: f64,
}
49
50impl Default for RetryConfig {
51 fn default() -> Self {
52 Self {
53 max_attempts: 3,
54 initial_delay: Duration::from_secs(1),
55 max_delay: Duration::from_secs(60),
56 backoff_multiplier: 2.0,
57 jitter_factor: 0.1,
58 }
59 }
60}
61
62impl RetryConfig {
63 pub fn new() -> Self {
65 Self::default()
66 }
67
68 pub fn with_max_attempts(mut self, attempts: u32) -> Self {
70 self.max_attempts = attempts;
71 self
72 }
73
74 pub fn with_initial_delay(mut self, delay: Duration) -> Self {
76 self.initial_delay = delay;
77 self
78 }
79
80 pub fn with_max_delay(mut self, delay: Duration) -> Self {
82 self.max_delay = delay;
83 self
84 }
85
86 pub fn with_backoff_multiplier(mut self, multiplier: f64) -> Self {
88 self.backoff_multiplier = multiplier;
89 self
90 }
91
92 pub fn with_jitter_factor(mut self, jitter: f64) -> Self {
94 self.jitter_factor = jitter.clamp(0.0, 1.0);
95 self
96 }
97
98 fn calculate_delay(&self, attempt: u32) -> Duration {
100 let base_delay_ms = self.initial_delay.as_millis() as f64;
101 let exponential_delay = base_delay_ms * self.backoff_multiplier.powi(attempt as i32);
102
103 let capped_delay = exponential_delay.min(self.max_delay.as_millis() as f64);
105
106 let jitter_range = capped_delay * self.jitter_factor;
108 let jitter = rand::random::<f64>() * jitter_range;
109 let final_delay = capped_delay + jitter - (jitter_range / 2.0);
110
111 Duration::from_millis(final_delay.max(0.0) as u64)
112 }
113}
114
115pub async fn retry_with_backoff<F, Fut, T>(config: RetryConfig, mut operation: F) -> Result<T>
148where
149 F: FnMut() -> Fut,
150 Fut: Future<Output = Result<T>>,
151{
152 let mut last_error = None;
153
154 for attempt in 0..config.max_attempts {
155 match operation().await {
156 Ok(result) => return Ok(result),
157 Err(err) => {
158 last_error = Some(err);
159
160 if attempt < config.max_attempts - 1 {
162 let delay = config.calculate_delay(attempt);
163 sleep(delay).await;
164 }
165 }
166 }
167 }
168
169 Err(last_error.unwrap_or_else(|| Error::other("Retry failed with no error")))
170}
171
172pub fn is_retryable_error(error: &Error) -> bool {
177 match error {
178 Error::Http(_) => true, Error::Timeout => true, Error::Stream(_) => true, Error::Api(msg) => {
182 msg.contains("500") || msg.contains("502") || msg.contains("503") || msg.contains("504")
185 }
186 Error::Config(_) => false, Error::InvalidInput(_) => false, _ => false, }
190}
191
192pub async fn retry_with_backoff_conditional<F, Fut, T>(
218 config: RetryConfig,
219 mut operation: F,
220) -> Result<T>
221where
222 F: FnMut() -> Fut,
223 Fut: Future<Output = Result<T>>,
224{
225 let mut last_error = None;
226
227 for attempt in 0..config.max_attempts {
228 match operation().await {
229 Ok(result) => return Ok(result),
230 Err(err) => {
231 if !is_retryable_error(&err) {
233 return Err(err);
234 }
235
236 last_error = Some(err);
237
238 if attempt < config.max_attempts - 1 {
240 let delay = config.calculate_delay(attempt);
241 sleep(delay).await;
242 }
243 }
244 }
245 }
246
247 Err(last_error.unwrap_or_else(|| Error::other("Retry failed with no error")))
248}
249
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::new()
            .with_max_attempts(5)
            .with_initial_delay(Duration::from_millis(500))
            .with_max_delay(Duration::from_secs(30))
            .with_backoff_multiplier(1.5)
            .with_jitter_factor(0.2);

        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.initial_delay, Duration::from_millis(500));
        assert_eq!(config.max_delay, Duration::from_secs(30));
        assert_eq!(config.backoff_multiplier, 1.5);
        assert_eq!(config.jitter_factor, 0.2);
    }

    #[test]
    fn test_jitter_factor_is_clamped() {
        assert_eq!(RetryConfig::new().with_jitter_factor(-0.5).jitter_factor, 0.0);
        assert_eq!(RetryConfig::new().with_jitter_factor(3.0).jitter_factor, 1.0);
    }

    #[test]
    fn test_calculate_delay() {
        let config = RetryConfig::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_backoff_multiplier(2.0)
            .with_jitter_factor(0.0);

        // Without jitter the schedule is deterministic: 1s, 2s, 4s, ...
        assert_eq!(config.calculate_delay(0), Duration::from_secs(1));
        assert_eq!(config.calculate_delay(1), Duration::from_secs(2));
        assert_eq!(config.calculate_delay(2), Duration::from_secs(4));
        // 2^10 s is far past the default 60s ceiling.
        assert_eq!(config.calculate_delay(10), Duration::from_secs(60));
    }

    #[test]
    fn test_calculate_delay_jitter_stays_within_window() {
        let config = RetryConfig::new()
            .with_initial_delay(Duration::from_secs(1))
            .with_jitter_factor(0.2);

        // A 0.2 factor spreads the 1s base delay over a 200ms window centred on it.
        for _ in 0..100 {
            let delay = config.calculate_delay(0);
            assert!(delay >= Duration::from_millis(900), "{delay:?} below window");
            assert!(delay <= Duration::from_millis(1100), "{delay:?} above window");
        }
    }

    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let config = RetryConfig::new().with_max_attempts(3);

        let call_count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&call_count);
        let result = retry_with_backoff(config, move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
            async { Ok::<i32, Error>(42) }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let config = RetryConfig::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(10));

        let call_count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&call_count);
        let result = retry_with_backoff(config, move || {
            let count = count_clone.fetch_add(1, Ordering::SeqCst) + 1;
            async move {
                if count < 3 {
                    Err(Error::timeout())
                } else {
                    Ok::<i32, Error>(42)
                }
            }
        })
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts() {
        let config = RetryConfig::new()
            .with_max_attempts(2)
            .with_initial_delay(Duration::from_millis(10));

        let call_count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&call_count);
        let result = retry_with_backoff(config, move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
            async { Err::<i32, Error>(Error::timeout()) }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_zero_attempts_never_calls_operation() {
        let config = RetryConfig::new().with_max_attempts(0);

        let call_count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&call_count);
        let result = retry_with_backoff(config, move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
            async { Ok::<i32, Error>(42) }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_conditional_retry_stops_on_non_retryable_error() {
        let config = RetryConfig::new()
            .with_max_attempts(3)
            .with_initial_delay(Duration::from_millis(10));

        let call_count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&call_count);
        let result = retry_with_backoff_conditional(config, move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
            async { Err::<i32, Error>(Error::config("Invalid config".to_string())) }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_conditional_retry_retries_retryable_error() {
        let config = RetryConfig::new()
            .with_max_attempts(2)
            .with_initial_delay(Duration::from_millis(10));

        let call_count = Arc::new(AtomicUsize::new(0));
        let count_clone = Arc::clone(&call_count);
        let result = retry_with_backoff_conditional(config, move || {
            count_clone.fetch_add(1, Ordering::SeqCst);
            async { Err::<i32, Error>(Error::timeout()) }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_is_retryable_error() {
        assert!(is_retryable_error(&Error::timeout()));
        assert!(is_retryable_error(&Error::api(
            "500 Internal Server Error".to_string()
        )));
        assert!(is_retryable_error(&Error::api(
            "503 Service Unavailable".to_string()
        )));
        assert!(!is_retryable_error(&Error::config(
            "Invalid config".to_string()
        )));
        assert!(!is_retryable_error(&Error::invalid_input(
            "Bad input".to_string()
        )));
    }
}