1use anyhow::Result;
4use std::time::Duration;
5use tokio::time::sleep;
6
/// Exponential-backoff settings for [`retry_async`] and [`smart_retry_async`].
///
/// The operation is attempted once and then retried up to `max_retries` more times,
/// so it runs at most `max_retries + 1` times in total.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Retries allowed after the first attempt; `0` disables retrying.
    pub max_retries: usize,
    /// Delay before the first retry; later delays grow geometrically from it.
    pub initial_delay: Duration,
    /// Upper bound on every delay returned by the backoff calculation.
    pub max_delay: Duration,
    /// Factor applied to the delay for each further retry (e.g. `2.0` doubles it).
    pub backoff_multiplier: f64,
}

impl RetryConfig {
    /// Default policy: 2 retries, starting at 1 s, doubling each time, capped at 30 s.
    pub fn new() -> Self {
        Self {
            max_retries: 2,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }

    /// Builds a config from explicit values. The values are not validated.
    pub fn with_settings(
        max_retries: usize,
        initial_delay: Duration,
        max_delay: Duration,
        backoff_multiplier: f64,
    ) -> Self {
        Self {
            max_retries,
            initial_delay,
            max_delay,
            backoff_multiplier,
        }
    }

    /// A policy that makes a single attempt and never waits.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            initial_delay: Duration::ZERO,
            max_delay: Duration::ZERO,
            backoff_multiplier: 1.0,
        }
    }

    /// Returns the delay to wait after the failed attempt with 0-based index `attempt`.
    ///
    /// The delay is `initial_delay * backoff_multiplier^attempt`, capped at `max_delay`.
    /// Values that cannot be represented as a [`Duration`] (overflow to infinity, NaN,
    /// or a negative result from a negative multiplier) are treated as exceeding the cap,
    /// so this never panics.
    fn calculate_delay(&self, attempt: usize) -> Duration {
        // Attempt 0 skips the f64 round-trip so the first delay is exactly `initial_delay`.
        // A zero base stays zero however often it is multiplied; handling it here also
        // avoids `0 * inf = NaN` once the power overflows.
        if attempt == 0 || self.initial_delay.is_zero() {
            return self.initial_delay.min(self.max_delay);
        }

        // Saturate rather than let `as i32` wrap to a negative exponent; anything this
        // large overflows to infinity anyway and is clamped below.
        let exponent = attempt.min(i32::MAX as usize) as i32;
        let delay_secs = self.initial_delay.as_secs_f64() * self.backoff_multiplier.powi(exponent);

        // `Duration::from_secs_f64` panics on NaN, infinite, negative, or too-large input,
        // which a long retry loop reaches quickly (2^64 s overflows at attempt 64 with the
        // defaults). `try_from_secs_f64` reports those cases instead of panicking.
        match Duration::try_from_secs_f64(delay_secs) {
            Ok(delay) => delay.min(self.max_delay),
            Err(_) => self.max_delay,
        }
    }
}

impl Default for RetryConfig {
    /// Same as [`RetryConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
81
82pub async fn retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
84where
85 F: FnMut() -> Fut,
86 Fut: std::future::Future<Output = Result<T>>,
87{
88 let mut last_error = None;
89
90 for attempt in 0..=config.max_retries {
91 match f().await {
92 Ok(result) => {
93 if attempt > 0 {
94 tracing::info!("Retry successful after {} attempt(s)", attempt);
95 }
96 return Ok(result);
97 }
98 Err(e) => {
99 last_error = Some(e);
100
101 if attempt < config.max_retries {
102 let delay = config.calculate_delay(attempt);
103 tracing::warn!(
104 "Attempt {} failed, retrying in {:?}...",
105 attempt + 1,
106 delay
107 );
108 sleep(delay).await;
109 } else {
110 tracing::error!("All {} retry attempts failed", config.max_retries + 1);
111 }
112 }
113 }
114 }
115
116 Err(last_error.unwrap())
118}
119
/// Verdict of [`classify_error`]: whether an error is worth retrying.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorType {
    /// A temporary failure; another attempt may succeed.
    Transient,
    /// A failure expected to recur on every attempt; retrying is pointless.
    Permanent,
}
128
129pub fn classify_error(error: &anyhow::Error) -> ErrorType {
131 let error_msg = error.to_string().to_lowercase();
132
133 if error_msg.contains("timeout")
135 || error_msg.contains("connection")
136 || error_msg.contains("temporarily unavailable")
137 || error_msg.contains("too many open files")
138 || error_msg.contains("resource temporarily unavailable")
139 || error_msg.contains("resource deadlock")
140 || error_msg.contains("try again")
141 {
142 return ErrorType::Transient;
143 }
144
145 if error_msg.contains("file not found")
147 || error_msg.contains("no such file")
148 || error_msg.contains("permission denied")
149 || error_msg.contains("access denied")
150 || error_msg.contains("read-only")
151 || error_msg.contains("disk full")
152 || error_msg.contains("no space left")
153 {
154 return ErrorType::Permanent;
155 }
156
157 if error_msg.contains("invalid data found")
159 || error_msg.contains("codec not found")
160 || error_msg.contains("unsupported codec")
161 || error_msg.contains("unknown codec")
162 || error_msg.contains("invalid audio")
163 || error_msg.contains("invalid sample rate")
164 || error_msg.contains("invalid bit rate")
165 || error_msg.contains("invalid channel")
166 || error_msg.contains("not supported")
167 || error_msg.contains("does not contain any stream")
168 || error_msg.contains("no decoder")
169 || error_msg.contains("no encoder")
170 || error_msg.contains("moov atom not found")
171 || error_msg.contains("invalid argument")
172 || error_msg.contains("protocol not found")
173 {
174 return ErrorType::Permanent;
175 }
176
177 if error_msg.contains("corrupted")
179 || error_msg.contains("corrupt")
180 || error_msg.contains("truncated")
181 || error_msg.contains("header missing")
182 || error_msg.contains("malformed")
183 || error_msg.contains("end of file")
184 {
185 return ErrorType::Permanent;
186 }
187
188 ErrorType::Transient
190}
191
192pub async fn smart_retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
194where
195 F: FnMut() -> Fut,
196 Fut: std::future::Future<Output = Result<T>>,
197{
198 let mut last_error = None;
199
200 for attempt in 0..=config.max_retries {
201 match f().await {
202 Ok(result) => {
203 if attempt > 0 {
204 tracing::info!("Smart retry successful after {} attempt(s)", attempt);
205 }
206 return Ok(result);
207 }
208 Err(e) => {
209 let error_type = classify_error(&e);
210
211 if error_type == ErrorType::Permanent {
212 tracing::error!("Permanent error detected, not retrying: {}", e);
213 return Err(e);
214 }
215
216 if attempt < config.max_retries {
217 let delay = config.calculate_delay(attempt);
218 tracing::warn!(
219 "Transient error on attempt {}: {}",
220 attempt + 1,
221 e
222 );
223 tracing::warn!(
224 "Retrying in {:?}... ({} attempts remaining)",
225 delay,
226 config.max_retries - attempt
227 );
228 sleep(delay).await;
229 } else {
230 tracing::error!(
231 "All {} retry attempts exhausted. Final error: {}",
232 config.max_retries + 1,
233 e
234 );
235 }
236
237 last_error = Some(e);
238 }
239 }
240 }
241
242 Err(last_error.unwrap())
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Millisecond-scale delays so the async retry tests finish quickly.
    fn fast_config(max_retries: usize) -> RetryConfig {
        RetryConfig::with_settings(
            max_retries,
            Duration::from_millis(10),
            Duration::from_millis(100),
            2.0,
        )
    }

    #[test]
    fn test_retry_config_creation() {
        let RetryConfig {
            max_retries,
            initial_delay,
            backoff_multiplier,
            ..
        } = RetryConfig::new();
        assert_eq!(max_retries, 2);
        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_no_retry() {
        assert_eq!(RetryConfig::no_retry().max_retries, 0);
    }

    #[test]
    fn test_calculate_delay() {
        let doubling = RetryConfig::new();
        let expected_secs = [1, 2, 4, 8];
        for (attempt, &secs) in expected_secs.iter().enumerate() {
            assert_eq!(doubling.calculate_delay(attempt), Duration::from_secs(secs));
        }

        let capped = RetryConfig::with_settings(
            5,
            Duration::from_secs(1),
            Duration::from_secs(5),
            2.0,
        );
        assert_eq!(capped.calculate_delay(10), Duration::from_secs(5));
    }

    #[tokio::test]
    async fn test_retry_async_success_first_try() {
        let config = RetryConfig::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<i32> = retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_after_retries() {
        let config = fast_config(3);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                // Fail the first two calls, succeed on the third.
                if calls.fetch_add(1, Ordering::Relaxed) < 2 {
                    anyhow::bail!("Transient error");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_retry_async_all_fail() {
        let config = fast_config(2);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<i32> = retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("Always fails")
            }
        })
        .await;

        assert!(outcome.is_err());
        // One initial attempt plus `max_retries` (2) retries.
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn test_classify_error() {
        let cases = [
            ("Connection timeout", ErrorType::Transient),
            ("File not found", ErrorType::Permanent),
            ("Some random error", ErrorType::Transient),
        ];
        for (message, expected) in cases {
            assert_eq!(classify_error(&anyhow::anyhow!(message)), expected);
        }
    }

    #[tokio::test]
    async fn test_smart_retry_permanent_error() {
        let config = RetryConfig::new();
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome: Result<i32> = smart_retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("File not found")
            }
        })
        .await;

        assert!(outcome.is_err());
        // A permanent error short-circuits: exactly one call, no retries.
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_smart_retry_transient_error() {
        let config = fast_config(2);
        let calls = Arc::new(AtomicUsize::new(0));

        let outcome = smart_retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                // Two transient failures, then success.
                if calls.fetch_add(1, Ordering::Relaxed) < 2 {
                    anyhow::bail!("Connection timeout");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert_eq!(outcome.unwrap(), 42);
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }
}