1use std::fmt::Display;
6use std::future::Future;
7use std::time::Duration;
8use tokio::time::sleep;
9
/// Number of retries (after the initial call) used by [`RetryConfig::default`].
pub const DEFAULT_MAX_RETRIES: u32 = 10;

/// Backoff delay in milliseconds before the first retry; it doubles for every later retry.
pub const BASE_DELAY_MS: u64 = 500;

/// Upper bound in milliseconds on the exponential backoff delay (applied before jitter).
pub const MAX_DELAY_MS: u64 = 32_000;
18
/// Error returned once an operation has failed and will not be retried any further.
///
/// Carries the error produced by the final attempt together with how many times the
/// operation was invoked in total.
#[derive(Debug)]
pub struct RetryError<E> {
    /// The error returned by the last attempt.
    pub original_error: E,
    /// Total number of invocations of the operation (initial call plus retries).
    pub attempts: u32,
}

// Only `Display` is needed to render the message; requiring `Clone` here (as before)
// needlessly excluded non-cloneable error types.
impl<E: Display> Display for RetryError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "RetryError: {} after {} attempts",
            self.original_error, self.attempts
        )
    }
}

// `source()` keeps its default (`None`): `E` is not required to implement
// `std::error::Error`, so the wrapped error cannot be exposed as a source.
impl<E: Display + std::fmt::Debug> std::error::Error for RetryError<E> {}

/// Result type returned by the retry helpers in this module.
pub type RetryResult<T, E> = Result<T, RetryError<E>>;
44
45pub struct RetryConfig {
47 pub max_retries: u32,
49 pub base_delay_ms: u64,
51 pub max_delay_ms: u64,
53 pub jitter: bool,
55 pub should_retry: Option<Box<dyn Fn(&str) -> bool + Send + Sync>>,
57}
58
59impl RetryConfig {
60 pub fn new() -> Self {
62 Self {
63 max_retries: DEFAULT_MAX_RETRIES,
64 base_delay_ms: BASE_DELAY_MS,
65 max_delay_ms: MAX_DELAY_MS,
66 jitter: true,
67 should_retry: None,
68 }
69 }
70}
71
72impl Default for RetryConfig {
73 fn default() -> Self {
74 Self::new()
75 }
76}
77
78pub fn get_retry_delay(attempt: u32, retry_after_ms: Option<u64>, config: &RetryConfig) -> u64 {
80 if let Some(retry_after) = retry_after_ms {
82 return retry_after;
83 }
84
85 let base_delay = config
87 .base_delay_ms
88 .saturating_mul(2u64.saturating_pow(attempt - 1));
89 let delay = base_delay.min(config.max_delay_ms);
90
91 if config.jitter {
93 let jitter = (delay as f64 * 0.25 * rand_jitter()) as u64;
94 delay + jitter
95 } else {
96 delay
97 }
98}
99
/// Returns a pseudo-random jitter factor in `[0.0, 1.0)`.
///
/// Derived from the sub-second nanoseconds of the wall clock: cheap and dependency-free,
/// but not suitable for anything security-sensitive. If the clock reads earlier than the
/// Unix epoch, the factor is `0.0`.
fn rand_jitter() -> f64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    // `subsec_nanos()` is always strictly below one second's worth of nanoseconds.
    // Dividing by `u32::MAX` instead would confine the factor to roughly [0, 0.23).
    const NANOS_PER_SEC: f64 = 1_000_000_000.0;

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    f64::from(nanos) / NANOS_PER_SEC
}
109
110pub async fn retry_async<T, E, F, Fut>(mut operation: F, config: RetryConfig) -> RetryResult<T, E>
120where
121 F: FnMut() -> Fut,
122 Fut: Future<Output = Result<T, E>>,
123 E: std::fmt::Display + Clone,
124{
125 let mut last_error: Option<E> = None;
126
127 for attempt in 1..=config.max_retries + 1 {
128 match operation().await {
129 Ok(result) => return Ok(result),
130 Err(e) => {
131 last_error = Some(e.clone());
132
133 if let Some(should_retry) = &config.should_retry {
135 let error_str = format!("{}", e);
136 if !should_retry(&error_str) {
137 return Err(RetryError {
138 original_error: e,
139 attempts: attempt,
140 });
141 }
142 }
143
144 if attempt <= config.max_retries {
146 let delay = get_retry_delay(attempt, None, &config);
147 sleep(Duration::from_millis(delay)).await;
148 }
149 }
150 }
151 }
152
153 Err(RetryError {
154 original_error: last_error.unwrap_or_else(|| {
155 panic!("retry_async called with max_retries=0 and no error occurred")
156 }),
157 attempts: config.max_retries + 1,
158 })
159}
160
161pub async fn retry_with_retry_after<T, E, F, Fut>(
168 mut operation: F,
169 config: RetryConfig,
170 get_retry_after: impl Fn(&E) -> Option<u64>,
171) -> RetryResult<T, E>
172where
173 F: FnMut(u32) -> Fut,
174 Fut: Future<Output = Result<T, E>>,
175 E: std::fmt::Display + Clone,
176{
177 let mut last_error: Option<E> = None;
178
179 for attempt in 1..=config.max_retries + 1 {
180 match operation(attempt).await {
181 Ok(result) => return Ok(result),
182 Err(e) => {
183 last_error = Some(e.clone());
184
185 if let Some(should_retry) = &config.should_retry {
187 let error_str = format!("{}", e);
188 if !should_retry(&error_str) {
189 return Err(RetryError {
190 original_error: e,
191 attempts: attempt,
192 });
193 }
194 }
195
196 if attempt <= config.max_retries {
198 let retry_after_ms = get_retry_after(&e);
199 let delay = get_retry_delay(attempt, retry_after_ms, &config);
200 sleep(Duration::from_millis(delay)).await;
201 }
202 }
203 }
204 }
205
206 Err(RetryError {
207 original_error: last_error.unwrap_or_else(|| {
208 panic!("retry_with_retry_after called with max_retries=0 and no error occurred")
209 }),
210 attempts: config.max_retries + 1,
211 })
212}
213
/// Heuristically detects a rate-limit failure from an error message.
///
/// Matches the text `429` anywhere in `error`, or the phrase `rate limit` in any
/// letter case.
pub fn is_rate_limit_error(error: &str) -> bool {
    if error.contains("429") {
        return true;
    }
    let lowered = error.to_lowercase();
    lowered.contains("rate limit")
}
218
/// Heuristically detects an "overloaded / service unavailable" failure from an error message.
///
/// Matches the text `529` anywhere in `error`, or the word `overloaded` in any letter
/// case (so both `overloaded_error` and `Overloaded` are recognised).
pub fn is_service_unavailable_error(error: &str) -> bool {
    // Case-insensitive, consistent with `is_rate_limit_error` and `is_connection_error`;
    // a case-sensitive check missed capitalised messages such as "Overloaded".
    error.contains("529") || error.to_lowercase().contains("overloaded")
}
223
224pub fn is_retryable_error(error: &str) -> bool {
226 is_rate_limit_error(error)
227 || is_service_unavailable_error(error)
228 || is_connection_error(error)
229 || is_server_error(error)
230}
231
/// Heuristically detects a network/connection failure from an error message.
///
/// Case-insensitively matches `connection`, `ECONNRESET`, `ECONNREFUSED`, `EPIPE`,
/// or `timeout` anywhere in `error`.
pub fn is_connection_error(error: &str) -> bool {
    const MARKERS: [&str; 5] = ["connection", "econnreset", "econnrefused", "epipe", "timeout"];
    let haystack = error.to_lowercase();
    MARKERS.iter().any(|&marker| haystack.contains(marker))
}
241
/// Detects an HTTP 500–504 server error by searching `error` for the status code text.
pub fn is_server_error(error: &str) -> bool {
    const CODES: [&str; 5] = ["500", "501", "502", "503", "504"];
    CODES.iter().any(|&code| error.contains(code))
}
251
252pub fn rate_limit_config() -> RetryConfig {
254 RetryConfig {
255 max_retries: 5,
256 base_delay_ms: 1000,
257 max_delay_ms: 60000,
258 jitter: true,
259 should_retry: Some(Box::new(|e| is_rate_limit_error(e))),
260 }
261}
262
263pub fn service_unavailable_config() -> RetryConfig {
265 RetryConfig {
266 max_retries: 3,
267 base_delay_ms: 2000,
268 max_delay_ms: 30000,
269 jitter: true,
270 should_retry: Some(Box::new(|e| is_service_unavailable_error(e))),
271 }
272}
273
274pub fn default_retry_config() -> RetryConfig {
276 RetryConfig::default()
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Retry policy with millisecond-scale delays, so tests that exercise the backoff
    /// path do not spend seconds sleeping (the defaults start at 500 ms per retry).
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 4,
            jitter: false,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn test_retry_success_first_try() {
        let call_count = AtomicU32::new(0);
        let operation = || {
            let call_count = &call_count;
            async move {
                call_count.fetch_add(1, Ordering::SeqCst);
                Ok::<_, &'static str>("success")
            }
        };

        // No failure means no sleep, so the default config is fine here.
        let result = retry_async(operation, RetryConfig::default()).await;
        assert_eq!(result.ok(), Some("success"));
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let call_count = AtomicU32::new(0);
        let operation = || {
            let call_count = &call_count;
            async move {
                let count = call_count.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err("temporary error")
                } else {
                    Ok("success")
                }
            }
        };

        let result = retry_async(operation, fast_config(DEFAULT_MAX_RETRIES)).await;
        assert_eq!(result.ok(), Some("success"));
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let call_count = AtomicU32::new(0);
        let operation = || {
            let call_count = &call_count;
            async move {
                call_count.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("persistent error")
            }
        };

        let err = retry_async(operation, fast_config(3)).await.unwrap_err();
        assert_eq!(err.original_error, "persistent error");
        assert_eq!(err.attempts, 4);
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn test_retry_with_should_retry() {
        let operation = || async move { Err::<String, _>("rate limit") };

        let config = RetryConfig {
            should_retry: Some(Box::new(|e: &str| e.contains("rate limit"))),
            ..fast_config(3)
        };
        let err = retry_async(operation, config).await.unwrap_err();
        // The predicate accepts every error, so the whole retry budget is spent.
        assert_eq!(err.attempts, 4);
    }

    #[tokio::test]
    async fn test_should_retry_rejection_stops_immediately() {
        let call_count = AtomicU32::new(0);
        let operation = || {
            let call_count = &call_count;
            async move {
                call_count.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("400 Bad Request")
            }
        };

        let config = RetryConfig {
            should_retry: Some(Box::new(is_rate_limit_error)),
            ..fast_config(3)
        };
        let err = retry_async(operation, config).await.unwrap_err();
        assert_eq!(err.original_error, "400 Bad Request");
        assert_eq!(err.attempts, 1);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_with_retry_after_passes_attempt_numbers() {
        let seen = std::sync::Mutex::new(Vec::new());
        let operation = |attempt: u32| {
            seen.lock().unwrap().push(attempt);
            async move {
                if attempt < 3 {
                    Err("503 Service Unavailable".to_string())
                } else {
                    Ok(attempt)
                }
            }
        };

        let result =
            retry_with_retry_after(operation, fast_config(5), |_: &String| Some(1)).await;
        assert_eq!(result.ok(), Some(3));
        assert_eq!(*seen.lock().unwrap(), vec![1, 2, 3]);
    }

    #[test]
    fn test_get_retry_delay_exponential() {
        let config = RetryConfig {
            base_delay_ms: 100,
            max_delay_ms: 10000,
            jitter: false,
            ..Default::default()
        };

        assert_eq!(get_retry_delay(1, None, &config), 100);
        assert_eq!(get_retry_delay(2, None, &config), 200);
        assert_eq!(get_retry_delay(3, None, &config), 400);
        assert_eq!(get_retry_delay(4, None, &config), 800);
    }

    #[test]
    fn test_get_retry_delay_max_cap() {
        let config = RetryConfig {
            base_delay_ms: 1000,
            max_delay_ms: 500,
            jitter: false,
            ..Default::default()
        };

        assert_eq!(get_retry_delay(10, None, &config), 500);
    }

    #[test]
    fn test_get_retry_delay_with_retry_after() {
        let config = RetryConfig::default();

        // A server-provided delay is used verbatim, bypassing backoff and jitter.
        let delay = get_retry_delay(1, Some(5000), &config);
        assert_eq!(delay, 5000);
    }

    #[test]
    fn test_is_rate_limit_error() {
        assert!(is_rate_limit_error("429 Too Many Requests"));
        assert!(is_rate_limit_error("rate limit exceeded"));
        assert!(!is_rate_limit_error("404 Not Found"));
    }

    #[test]
    fn test_is_service_unavailable_error() {
        assert!(is_service_unavailable_error("529 Service Unavailable"));
        assert!(is_service_unavailable_error("server overloaded"));
        assert!(!is_service_unavailable_error("400 Bad Request"));
    }

    #[test]
    fn test_is_connection_error() {
        assert!(is_connection_error("connection refused"));
        assert!(is_connection_error("ECONNRESET"));
        assert!(!is_connection_error("404 Not Found"));
    }

    #[test]
    fn test_is_server_error() {
        assert!(is_server_error("500 Internal Server Error"));
        assert!(is_server_error("503 Service Unavailable"));
        assert!(!is_server_error("400 Bad Request"));
    }
}