1use anyhow::Result;
4use std::time::Duration;
5use tokio::time::sleep;
6
/// Exponential-backoff settings used by [`retry_async`] and [`smart_retry_async`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries after the first attempt (total attempts = `max_retries + 1`).
    pub max_retries: usize,
    /// Delay before the first retry; also the base of the exponential backoff.
    pub initial_delay: Duration,
    /// Upper bound applied to every delay returned by [`RetryConfig::calculate_delay`].
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each further attempt.
    pub backoff_multiplier: f64,
}

impl RetryConfig {
    /// Default policy: 2 retries, starting at 1 s, doubling each time, capped at 30 s.
    pub fn new() -> Self {
        Self {
            max_retries: 2,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }

    /// Builds a config from explicit values.
    ///
    /// No validation is performed; [`RetryConfig::calculate_delay`] clamps whatever
    /// delay these values produce into `[0, max_delay]` instead of panicking.
    pub fn with_settings(
        max_retries: usize,
        initial_delay: Duration,
        max_delay: Duration,
        backoff_multiplier: f64,
    ) -> Self {
        Self {
            max_retries,
            initial_delay,
            max_delay,
            backoff_multiplier,
        }
    }

    /// A config that makes exactly one attempt and never sleeps.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            initial_delay: Duration::from_secs(0),
            max_delay: Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Returns how long to wait after the failed attempt with 0-based index `attempt`.
    ///
    /// The delay is `initial_delay * backoff_multiplier^attempt`, clamped to
    /// `[0, max_delay]`. This never panics, even for huge `attempt` values or
    /// degenerate multipliers (negative, zero, NaN, infinite).
    pub fn calculate_delay(&self, attempt: usize) -> Duration {
        if attempt == 0 {
            // Avoid an f64 round-trip for the common first retry, but still honour the cap.
            return self.initial_delay.min(self.max_delay);
        }

        // Saturate rather than truncate: `attempt as i32` could wrap to a negative exponent.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let delay_secs = self.initial_delay.as_secs_f64() * self.backoff_multiplier.powi(exponent);

        // `Duration::from_secs_f64` panics on NaN, infinity, negative values and overflow,
        // so every such case is resolved in f64 space before converting.
        if delay_secs.is_nan() {
            // Only 0 * inf (zero initial delay) or a NaN multiplier get here.
            return self.initial_delay.min(self.max_delay);
        }
        if delay_secs >= self.max_delay.as_secs_f64() {
            // Also covers +inf from an overflowing `powi`.
            return self.max_delay;
        }
        if delay_secs <= 0.0 {
            // Negative multipliers (or -inf) cannot be represented as a Duration.
            return Duration::ZERO;
        }

        // Finite, positive and below the cap: the conversion is safe, and `min` absorbs
        // any rounding from the f64 comparison above.
        Duration::from_secs_f64(delay_secs).min(self.max_delay)
    }
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self::new()
    }
}
81
82pub async fn retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
84where
85 F: FnMut() -> Fut,
86 Fut: std::future::Future<Output = Result<T>>,
87{
88 let mut last_error = None;
89
90 for attempt in 0..=config.max_retries {
91 match f().await {
92 Ok(result) => {
93 if attempt > 0 {
94 tracing::info!("Retry successful after {} attempt(s)", attempt);
95 }
96 return Ok(result);
97 }
98 Err(e) => {
99 last_error = Some(e);
100
101 if attempt < config.max_retries {
102 let delay = config.calculate_delay(attempt);
103 tracing::warn!(
104 "Attempt {} failed, retrying in {:?}...",
105 attempt + 1,
106 delay
107 );
108 sleep(delay).await;
109 } else {
110 tracing::error!("All {} retry attempts failed", config.max_retries + 1);
111 }
112 }
113 }
114 }
115
116 Err(last_error.unwrap())
118}
119
/// Retry classification produced by [`classify_error`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorType {
    /// The failure may clear up on its own; [`smart_retry_async`] will retry it.
    Transient,
    /// Retrying is pointless; [`smart_retry_async`] returns it immediately.
    Permanent,
}
128
129pub fn classify_error(error: &anyhow::Error) -> ErrorType {
131 let error_msg = error.to_string().to_lowercase();
132
133 if error_msg.contains("429") || error_msg.contains("rate limit") {
136 return ErrorType::Transient;
137 }
138
139 if error_msg.contains("500") || error_msg.contains("502")
141 || error_msg.contains("503") || error_msg.contains("504")
142 || error_msg.contains("server error")
143 {
144 return ErrorType::Transient;
145 }
146
147 if error_msg.contains("400") || error_msg.contains("401")
149 || error_msg.contains("403") || error_msg.contains("404")
150 || error_msg.contains("client error")
151 {
152 return ErrorType::Permanent;
153 }
154
155 if error_msg.contains("timeout")
157 || error_msg.contains("connection")
158 || error_msg.contains("temporarily unavailable")
159 || error_msg.contains("too many open files")
160 || error_msg.contains("resource temporarily unavailable")
161 || error_msg.contains("resource deadlock")
162 || error_msg.contains("try again")
163 {
164 return ErrorType::Transient;
165 }
166
167 if error_msg.contains("file not found")
169 || error_msg.contains("no such file")
170 || error_msg.contains("permission denied")
171 || error_msg.contains("access denied")
172 || error_msg.contains("read-only")
173 || error_msg.contains("disk full")
174 || error_msg.contains("no space left")
175 {
176 return ErrorType::Permanent;
177 }
178
179 if error_msg.contains("invalid data found")
181 || error_msg.contains("codec not found")
182 || error_msg.contains("unsupported codec")
183 || error_msg.contains("unknown codec")
184 || error_msg.contains("invalid audio")
185 || error_msg.contains("invalid sample rate")
186 || error_msg.contains("invalid bit rate")
187 || error_msg.contains("invalid channel")
188 || error_msg.contains("not supported")
189 || error_msg.contains("does not contain any stream")
190 || error_msg.contains("no decoder")
191 || error_msg.contains("no encoder")
192 || error_msg.contains("moov atom not found")
193 || error_msg.contains("invalid argument")
194 || error_msg.contains("protocol not found")
195 {
196 return ErrorType::Permanent;
197 }
198
199 if error_msg.contains("corrupted")
201 || error_msg.contains("corrupt")
202 || error_msg.contains("truncated")
203 || error_msg.contains("header missing")
204 || error_msg.contains("malformed")
205 || error_msg.contains("end of file")
206 {
207 return ErrorType::Permanent;
208 }
209
210 ErrorType::Transient
212}
213
214pub async fn smart_retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
216where
217 F: FnMut() -> Fut,
218 Fut: std::future::Future<Output = Result<T>>,
219{
220 let mut last_error = None;
221
222 for attempt in 0..=config.max_retries {
223 match f().await {
224 Ok(result) => {
225 if attempt > 0 {
226 tracing::info!("Smart retry successful after {} attempt(s)", attempt);
227 }
228 return Ok(result);
229 }
230 Err(e) => {
231 let error_type = classify_error(&e);
232
233 if error_type == ErrorType::Permanent {
234 tracing::error!("Permanent error detected, not retrying: {:?}", e);
235 return Err(e);
236 }
237
238 if attempt < config.max_retries {
239 let delay = config.calculate_delay(attempt);
240 tracing::warn!(
241 "Transient error on attempt {}: {:?}",
242 attempt + 1,
243 e
244 );
245 tracing::warn!(
246 "Retrying in {:?}... ({} attempts remaining)",
247 delay,
248 config.max_retries - attempt
249 );
250 sleep(delay).await;
251 } else {
252 tracing::error!(
253 "All {} retry attempts exhausted. Final error: {:?}",
254 config.max_retries + 1,
255 e
256 );
257 }
258
259 last_error = Some(e);
260 }
261 }
262 }
263
264 Err(last_error.unwrap())
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Millisecond-scale backoff so the async retry tests finish quickly.
    fn fast_config(max_retries: usize) -> RetryConfig {
        RetryConfig::with_settings(
            max_retries,
            Duration::from_millis(10),
            Duration::from_millis(100),
            2.0,
        )
    }

    #[test]
    fn test_retry_config_creation() {
        let cfg = RetryConfig::new();
        assert_eq!(cfg.max_retries, 2);
        assert_eq!(cfg.initial_delay, Duration::from_secs(1));
        assert_eq!(cfg.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_no_retry() {
        assert_eq!(RetryConfig::no_retry().max_retries, 0);
    }

    #[test]
    fn test_calculate_delay() {
        let cfg = RetryConfig::new();
        for (attempt, secs) in [(0, 1), (1, 2), (2, 4), (3, 8)] {
            assert_eq!(cfg.calculate_delay(attempt), Duration::from_secs(secs));
        }

        let capped = RetryConfig::with_settings(
            5,
            Duration::from_secs(1),
            Duration::from_secs(5),
            2.0,
        );
        assert_eq!(capped.calculate_delay(10), Duration::from_secs(5));
    }

    #[tokio::test]
    async fn test_retry_async_success_first_try() {
        let cfg = RetryConfig::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<i32> = retry_async(&cfg, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_after_retries() {
        let cfg = fast_config(3);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = retry_async(&cfg, || {
            let calls = Arc::clone(&calls);
            async move {
                let seen = calls.fetch_add(1, Ordering::Relaxed);
                if seen < 2 {
                    anyhow::bail!("Transient error");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_retry_async_all_fail() {
        let cfg = fast_config(2);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<i32> = retry_async(&cfg, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("Always fails")
            }
        })
        .await;

        assert!(outcome.is_err());
        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn test_classify_error() {
        let cases = [
            ("Connection timeout", ErrorType::Transient),
            ("File not found", ErrorType::Permanent),
            ("Some random error", ErrorType::Transient),
        ];
        for (text, expected) in cases {
            assert_eq!(classify_error(&anyhow::Error::msg(text)), expected);
        }
    }

    #[tokio::test]
    async fn test_smart_retry_permanent_error() {
        let cfg = RetryConfig::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<i32> = smart_retry_async(&cfg, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("File not found")
            }
        })
        .await;

        assert!(outcome.is_err());
        // Permanent errors must not be retried.
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_smart_retry_transient_error() {
        let cfg = fast_config(2);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = smart_retry_async(&cfg, || {
            let calls = Arc::clone(&calls);
            async move {
                let seen = calls.fetch_add(1, Ordering::Relaxed);
                if seen < 2 {
                    anyhow::bail!("Connection timeout");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }
}