// audiobook_forge/core/retry.rs

1//! Error recovery and retry logic
2
3use anyhow::Result;
4use std::time::Duration;
5use tokio::time::sleep;
6
/// Retry configuration
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_retries: usize,
    /// Initial delay between retries
    pub initial_delay: Duration,
    /// Maximum delay between retries
    pub max_delay: Duration,
    /// Backoff multiplier (exponential backoff)
    pub backoff_multiplier: f64,
}

impl RetryConfig {
    /// Create a new retry config with sensible defaults:
    /// 2 retries, 1 s initial delay, 30 s cap, 2x exponential backoff.
    pub fn new() -> Self {
        Self {
            max_retries: 2,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }

    /// Create a retry config with custom settings
    pub fn with_settings(
        max_retries: usize,
        initial_delay: Duration,
        max_delay: Duration,
        backoff_multiplier: f64,
    ) -> Self {
        Self {
            max_retries,
            initial_delay,
            max_delay,
            backoff_multiplier,
        }
    }

    /// No retries: the operation fails on its first error.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            initial_delay: Duration::from_secs(0),
            max_delay: Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Calculate the delay before retry `attempt` (0-based):
    /// `initial_delay * backoff_multiplier^attempt`, clamped to `max_delay`.
    fn calculate_delay(&self, attempt: usize) -> Duration {
        if attempt == 0 {
            return self.initial_delay;
        }

        let delay_secs = self.initial_delay.as_secs_f64()
            * self.backoff_multiplier.powi(attempt as i32);

        // `powi` overflows to +inf for large attempts/multipliers (and
        // 0 * inf yields NaN), and `Duration::from_secs_f64` PANICS on
        // non-finite or out-of-range input — so clamp in f64 space
        // before converting back to a Duration.
        if !delay_secs.is_finite() {
            return self.max_delay;
        }

        Duration::from_secs_f64(delay_secs.clamp(0.0, self.max_delay.as_secs_f64()))
    }
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self::new()
    }
}
81
82/// Execute a function with retry logic
83pub async fn retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
84where
85    F: FnMut() -> Fut,
86    Fut: std::future::Future<Output = Result<T>>,
87{
88    let mut last_error = None;
89
90    for attempt in 0..=config.max_retries {
91        match f().await {
92            Ok(result) => {
93                if attempt > 0 {
94                    tracing::info!("Retry successful after {} attempt(s)", attempt);
95                }
96                return Ok(result);
97            }
98            Err(e) => {
99                last_error = Some(e);
100
101                if attempt < config.max_retries {
102                    let delay = config.calculate_delay(attempt);
103                    tracing::warn!(
104                        "Attempt {} failed, retrying in {:?}...",
105                        attempt + 1,
106                        delay
107                    );
108                    sleep(delay).await;
109                } else {
110                    tracing::error!("All {} retry attempts failed", config.max_retries + 1);
111                }
112            }
113        }
114    }
115
116    // If we get here, all retries failed
117    Err(last_error.unwrap())
118}
119
/// Error classification for smart retry logic.
///
/// Used by [`smart_retry_async`] to decide whether a failed attempt is
/// worth repeating; see [`classify_error`] for how errors are sorted
/// into these two buckets.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorType {
    /// Transient errors (worth retrying)
    Transient,
    /// Permanent errors (no point retrying)
    Permanent,
}
128
129/// Classify an error to determine if retry is worthwhile
130pub fn classify_error(error: &anyhow::Error) -> ErrorType {
131    let error_msg = error.to_string().to_lowercase();
132
133    // Transient errors (worth retrying)
134    if error_msg.contains("timeout")
135        || error_msg.contains("connection")
136        || error_msg.contains("temporarily unavailable")
137        || error_msg.contains("too many open files")
138        || error_msg.contains("resource temporarily unavailable")
139    {
140        return ErrorType::Transient;
141    }
142
143    // Permanent errors (no point retrying)
144    if error_msg.contains("file not found")
145        || error_msg.contains("permission denied")
146        || error_msg.contains("invalid")
147        || error_msg.contains("unsupported")
148        || error_msg.contains("corrupted")
149    {
150        return ErrorType::Permanent;
151    }
152
153    // Default to transient (conservative approach)
154    ErrorType::Transient
155}
156
157/// Execute with smart retry (only retry transient errors)
158pub async fn smart_retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
159where
160    F: FnMut() -> Fut,
161    Fut: std::future::Future<Output = Result<T>>,
162{
163    let mut last_error = None;
164
165    for attempt in 0..=config.max_retries {
166        match f().await {
167            Ok(result) => {
168                if attempt > 0 {
169                    tracing::info!("Smart retry successful after {} attempt(s)", attempt);
170                }
171                return Ok(result);
172            }
173            Err(e) => {
174                let error_type = classify_error(&e);
175
176                if error_type == ErrorType::Permanent {
177                    tracing::error!("Permanent error detected, not retrying: {}", e);
178                    return Err(e);
179                }
180
181                last_error = Some(e);
182
183                if attempt < config.max_retries {
184                    let delay = config.calculate_delay(attempt);
185                    tracing::warn!(
186                        "Transient error (attempt {}), retrying in {:?}...",
187                        attempt + 1,
188                        delay
189                    );
190                    sleep(delay).await;
191                } else {
192                    tracing::error!(
193                        "All {} retry attempts failed for transient error",
194                        config.max_retries + 1
195                    );
196                }
197            }
198        }
199    }
200
201    Err(last_error.unwrap())
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    // `new()` yields the documented defaults.
    #[test]
    fn test_retry_config_creation() {
        let config = RetryConfig::new();
        assert_eq!(config.max_retries, 2);
        assert_eq!(config.initial_delay, Duration::from_secs(1));
        assert_eq!(config.backoff_multiplier, 2.0);
    }

    // `no_retry()` disables retries entirely.
    #[test]
    fn test_retry_config_no_retry() {
        let config = RetryConfig::no_retry();
        assert_eq!(config.max_retries, 0);
    }

    // Exponential backoff: 1s, 2s, 4s, 8s with multiplier 2.0, plus
    // clamping to `max_delay` for large attempt numbers.
    #[test]
    fn test_calculate_delay() {
        let config = RetryConfig::new();

        assert_eq!(config.calculate_delay(0), Duration::from_secs(1));
        assert_eq!(config.calculate_delay(1), Duration::from_secs(2));
        assert_eq!(config.calculate_delay(2), Duration::from_secs(4));
        assert_eq!(config.calculate_delay(3), Duration::from_secs(8));

        // Test max delay clamping
        let config = RetryConfig::with_settings(
            5,
            Duration::from_secs(1),
            Duration::from_secs(5),
            2.0,
        );
        assert_eq!(config.calculate_delay(10), Duration::from_secs(5));
    }

    // A first-try success invokes the closure exactly once.
    #[tokio::test]
    async fn test_retry_async_success_first_try() {
        let config = RetryConfig::new();
        let counter = Arc::new(AtomicUsize::new(0));

        let result: Result<i32> = retry_async(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::Relaxed), 1);
    }

    // Fails twice then succeeds: result is Ok and the closure ran 3 times.
    #[tokio::test]
    async fn test_retry_async_success_after_retries() {
        let config = RetryConfig::with_settings(
            3,
            Duration::from_millis(10),
            Duration::from_millis(100),
            2.0,
        );
        let counter = Arc::new(AtomicUsize::new(0));

        let result = retry_async(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                let count = counter.fetch_add(1, Ordering::Relaxed);
                if count < 2 {
                    anyhow::bail!("Transient error");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }

    // Every attempt fails: the final error is returned after
    // max_retries + 1 calls.
    #[tokio::test]
    async fn test_retry_async_all_fail() {
        let config = RetryConfig::with_settings(
            2,
            Duration::from_millis(10),
            Duration::from_millis(100),
            2.0,
        );
        let counter = Arc::new(AtomicUsize::new(0));

        let result: Result<i32> = retry_async(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("Always fails")
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::Relaxed), 3); // Initial + 2 retries
    }

    // Message-based classification: known transient/permanent substrings,
    // and the transient default for unrecognized errors.
    #[test]
    fn test_classify_error() {
        let transient = anyhow::anyhow!("Connection timeout");
        assert_eq!(classify_error(&transient), ErrorType::Transient);

        let permanent = anyhow::anyhow!("File not found");
        assert_eq!(classify_error(&permanent), ErrorType::Permanent);

        let unknown = anyhow::anyhow!("Some random error");
        assert_eq!(classify_error(&unknown), ErrorType::Transient);
    }

    // A permanent error short-circuits smart retry: exactly one call.
    #[tokio::test]
    async fn test_smart_retry_permanent_error() {
        let config = RetryConfig::new();
        let counter = Arc::new(AtomicUsize::new(0));

        let result: Result<i32> = smart_retry_async(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("File not found")
            }
        })
        .await;

        assert!(result.is_err());
        assert_eq!(counter.load(Ordering::Relaxed), 1); // No retries for permanent error
    }

    // Transient errors are retried until success.
    #[tokio::test]
    async fn test_smart_retry_transient_error() {
        let config = RetryConfig::with_settings(
            2,
            Duration::from_millis(10),
            Duration::from_millis(100),
            2.0,
        );
        let counter = Arc::new(AtomicUsize::new(0));

        let result = smart_retry_async(&config, || {
            let counter = Arc::clone(&counter);
            async move {
                let count = counter.fetch_add(1, Ordering::Relaxed);
                if count < 2 {
                    anyhow::bail!("Connection timeout");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }
}