// research_master/utils/retry.rs
1//! Retry utilities with exponential backoff for resilient API calls.
2
use std::time::{Duration, Instant};
use tokio::time::{sleep, timeout};

use crate::sources::SourceError;
7
8/// Configuration for retry behavior
/// Configuration for retry behavior.
///
/// Delays grow geometrically: retry `n` (for `n >= 2`) waits roughly
/// `initial_delay * backoff_multiplier^(n-1)`, capped at `max_delay`.
#[derive(Debug, Clone, Copy)]
pub struct RetryConfig {
    /// Maximum number of attempts (the initial call counts as attempt 1)
    pub max_attempts: u32,
    /// Delay before the first retry; also the base of the backoff curve
    pub initial_delay: Duration,
    /// Upper bound on any single backoff delay
    pub max_delay: Duration,
    /// Multiplier applied to the delay for each successive retry
    pub backoff_multiplier: f64,
    /// Maximum total time to spend on retries (including delays)
    pub max_total_time: Duration,
}
22
23impl Default for RetryConfig {
24    fn default() -> Self {
25        Self {
26            max_attempts: 3,
27            initial_delay: Duration::from_secs(1),
28            max_delay: Duration::from_secs(60),
29            backoff_multiplier: 2.0,
30            max_total_time: Duration::from_secs(120),
31        }
32    }
33}
34
35/// Transient errors that should trigger a retry
/// Transient errors that should trigger a retry.
///
/// Produced by [`TransientError::from_reqwest_error`] and
/// [`TransientError::from_source_error`]; each variant carries a
/// recommended minimum retry delay via [`TransientError::recommended_delay`].
#[derive(Debug, Clone, PartialEq)]
pub enum TransientError {
    /// Network connectivity issues
    Network,
    /// Rate limit exceeded (with optional retry-after seconds)
    RateLimit(Option<u64>),
    /// Server error (any 5xx not matched by a more specific variant)
    ServerError,
    /// Service unavailable (503)
    ServiceUnavailable,
    /// Gateway timeout (504)
    GatewayTimeout,
    /// Too many requests (429)
    TooManyRequests,
    /// Request timeout
    Timeout,
}
53
54impl TransientError {
55    /// Check if a reqwest error represents a transient error
56    pub fn from_reqwest_error(err: &reqwest::Error) -> Option<Self> {
57        if err.is_timeout() {
58            return Some(TransientError::Timeout);
59        }
60        if err.is_connect() {
61            return Some(TransientError::Network);
62        }
63
64        if let Some(status) = err.status() {
65            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
66                return Some(TransientError::TooManyRequests);
67            }
68
69            if status == reqwest::StatusCode::SERVICE_UNAVAILABLE {
70                return Some(TransientError::ServiceUnavailable);
71            }
72
73            if status == reqwest::StatusCode::GATEWAY_TIMEOUT {
74                return Some(TransientError::GatewayTimeout);
75            }
76
77            if status.is_server_error() {
78                return Some(TransientError::ServerError);
79            }
80        }
81
82        None
83    }
84
85    /// Check if a SourceError represents a transient error
86    pub fn from_source_error(err: &SourceError) -> Option<Self> {
87        match err {
88            SourceError::RateLimit => Some(TransientError::RateLimit(None)),
89            SourceError::Network(_) => Some(TransientError::Network),
90            SourceError::Api(msg) => {
91                // Heuristic: check for common transient error patterns in messages
92                let msg_lower = msg.to_lowercase();
93                if msg_lower.contains("timeout") {
94                    Some(TransientError::Timeout)
95                } else if msg_lower.contains("service unavailable")
96                    || msg_lower.contains("temporarily unavailable")
97                {
98                    Some(TransientError::ServiceUnavailable)
99                } else {
100                    None
101                }
102            }
103            _ => None,
104        }
105    }
106
107    /// Get the recommended delay for this error
108    pub fn recommended_delay(&self) -> Duration {
109        match self {
110            TransientError::RateLimit(Some(seconds)) => Duration::from_secs(*seconds + 1),
111            TransientError::RateLimit(None) => Duration::from_secs(61),
112            TransientError::TooManyRequests => Duration::from_secs(61),
113            TransientError::ServiceUnavailable => Duration::from_secs(10),
114            TransientError::GatewayTimeout => Duration::from_secs(5),
115            TransientError::Timeout => Duration::from_secs(2),
116            TransientError::Network => Duration::from_secs(2),
117            TransientError::ServerError => Duration::from_secs(2),
118        }
119    }
120}
121
122/// Result of a retry operation
/// Result of a retry operation, as returned by `with_retry_detailed`.
pub enum RetryResult<T> {
    /// Operation succeeded
    Success(T),
    /// Operation failed with a transient error after all retries.
    /// Carries the final error, its transient classification, and the
    /// number of attempts made.
    TransientFailure(SourceError, TransientError, u32),
    /// Operation failed with a permanent (non-retryable) error
    PermanentFailure(SourceError),
}
131
132/// Execute an async operation with retry logic
133///
134/// # Arguments
135///
136/// * `config` - Retry configuration
137/// * `operation` - The async operation to execute
138///
139/// # Returns
140///
141/// The result of the operation, or an error after all retries are exhausted
142pub async fn with_retry<T, F, Fut>(config: RetryConfig, operation: F) -> Result<T, SourceError>
143where
144    F: FnMut() -> Fut,
145    Fut: std::future::Future<Output = Result<T, SourceError>>,
146{
147    let mut attempts = 0;
148    let mut total_elapsed = Duration::ZERO;
149    let mut operation = operation;
150
151    loop {
152        attempts += 1;
153
154        match timeout(config.max_total_time, operation()).await {
155            Ok(Ok(result)) => {
156                // Success
157                if attempts > 1 {
158                    tracing::info!(
159                        "Operation succeeded on attempt {} after {} transient failures",
160                        attempts,
161                        attempts - 1
162                    );
163                }
164                return Ok(result);
165            }
166            Ok(Err(error)) => {
167                // Check if this is a transient error
168                if let Some(transient) = TransientError::from_source_error(&error) {
169                    // Calculate delay with exponential backoff
170                    let delay = if attempts == 1 {
171                        config.initial_delay
172                    } else {
173                        let exp_delay = config.initial_delay.as_secs_f64()
174                            * config.backoff_multiplier.powf(attempts as f64 - 1.0);
175                        let delay_secs = exp_delay.min(config.max_delay.as_secs_f64());
176                        Duration::from_secs_f64(delay_secs)
177                    };
178
179                    // Also consider error-specific recommended delay
180                    let delay = std::cmp::max(delay, transient.recommended_delay());
181
182                    total_elapsed += delay;
183
184                    if attempts >= config.max_attempts || total_elapsed >= config.max_total_time {
185                        tracing::warn!(
186                            "Operation failed after {} attempts (total elapsed: {:?}): {}",
187                            attempts,
188                            total_elapsed,
189                            error
190                        );
191                        return Err(error);
192                    }
193
194                    tracing::debug!(
195                        "Transient error on attempt {}: {:?}, retrying in {:?}",
196                        attempts,
197                        transient,
198                        delay
199                    );
200
201                    sleep(delay).await;
202                    continue;
203                } else {
204                    // Permanent error - return immediately
205                    return Err(error);
206                }
207            }
208            Err(_) => {
209                // Timeout of the entire operation
210                let error = SourceError::Network("Operation timed out".to_string());
211                if attempts >= config.max_attempts {
212                    return Err(error);
213                }
214
215                let delay = config.initial_delay;
216                total_elapsed += delay;
217
218                tracing::debug!(
219                    "Operation timed out, attempt {}/{}",
220                    attempts,
221                    config.max_attempts
222                );
223                sleep(delay).await;
224            }
225        }
226    }
227}
228
229/// Execute an async operation with retry logic that returns RetryResult
230///
231/// This provides more detailed information about failures for callers that need it
232pub async fn with_retry_detailed<T, F, Fut>(config: RetryConfig, operation: F) -> RetryResult<T>
233where
234    F: FnMut() -> Fut,
235    Fut: std::future::Future<Output = Result<T, SourceError>>,
236{
237    let mut attempts = 0;
238    let mut total_elapsed = Duration::ZERO;
239    let mut operation = operation;
240
241    loop {
242        attempts += 1;
243
244        match timeout(config.max_total_time, operation()).await {
245            Ok(Ok(result)) => {
246                return RetryResult::Success(result);
247            }
248            Ok(Err(error)) => {
249                if let Some(transient) = TransientError::from_source_error(&error) {
250                    let delay = if attempts == 1 {
251                        config.initial_delay
252                    } else {
253                        let exp_delay = config.initial_delay.as_secs_f64()
254                            * config.backoff_multiplier.powf(attempts as f64 - 1.0);
255                        Duration::from_secs_f64(exp_delay.min(config.max_delay.as_secs_f64()))
256                    };
257
258                    let delay = std::cmp::max(delay, transient.recommended_delay());
259                    total_elapsed += delay;
260
261                    if attempts >= config.max_attempts || total_elapsed >= config.max_total_time {
262                        return RetryResult::TransientFailure(error, transient, attempts);
263                    }
264
265                    sleep(delay).await;
266                    continue;
267                } else {
268                    return RetryResult::PermanentFailure(error);
269                }
270            }
271            Err(_) => {
272                let error = SourceError::Network("Operation timed out".to_string());
273                if attempts >= config.max_attempts {
274                    return RetryResult::TransientFailure(error, TransientError::Timeout, attempts);
275                }
276
277                let delay = config.initial_delay;
278                total_elapsed += delay;
279                sleep(delay).await;
280            }
281        }
282    }
283}
284
285/// Create a default retry configuration optimized for external APIs
286pub fn api_retry_config() -> RetryConfig {
287    RetryConfig {
288        max_attempts: 5,
289        initial_delay: Duration::from_secs(2),
290        max_delay: Duration::from_secs(120),
291        backoff_multiplier: 2.0,
292        max_total_time: Duration::from_secs(300),
293    }
294}
295
296/// Create a retry configuration for sources with strict rate limits
297pub fn strict_rate_limit_retry_config() -> RetryConfig {
298    RetryConfig {
299        max_attempts: 3,
300        initial_delay: Duration::from_secs(2),
301        max_delay: Duration::from_secs(120),
302        backoff_multiplier: 2.0,
303        max_total_time: Duration::from_secs(180),
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    // NOTE(review): `Rc<RefCell<..>>` (not Arc/Mutex) appears to rely on
    // `#[tokio::test]` using a single-threaded current-thread runtime, so
    // the futures never need to be `Send` — confirm no test is switched to
    // a multi-threaded runtime later.
    use std::cell::RefCell;
    use std::rc::Rc;

    /// A successful first attempt must invoke the operation exactly once
    /// and incur no retries.
    #[tokio::test]
    async fn test_retry_success_first_try() {
        let config = RetryConfig::default();
        let call_count = Rc::new(RefCell::new(0));

        let result = {
            let call_count = call_count.clone();
            with_retry(config, move || {
                // The factory closure is called per attempt, so it clones
                // the counter handle into each produced future.
                let call_count = call_count.clone();
                async move {
                    *call_count.borrow_mut() += 1;
                    Ok("success")
                }
            })
        }
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(*call_count.borrow(), 1);
    }

    /// Transient failures on the first two attempts should be retried,
    /// succeeding on the third attempt.
    #[tokio::test]
    async fn test_retry_success_after_failures() {
        // Use Network error which has 2s recommended delay, so we need longer max_total_time
        let config = RetryConfig {
            max_attempts: 4, // 4 attempts = 3 retries + final attempt
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_total_time: Duration::from_secs(10),
        };
        let call_count = Rc::new(RefCell::new(0));

        let result = {
            let call_count = call_count.clone();
            with_retry(config, move || {
                let call_count = call_count.clone();
                async move {
                    *call_count.borrow_mut() += 1;
                    let count = *call_count.borrow();
                    if count < 3 {
                        // Fail on attempts 1 and 2
                        Err(SourceError::Network("temporary error".to_string()))
                    } else {
                        // Succeed on attempt 3
                        Ok("success")
                    }
                }
            })
        }
        .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(*call_count.borrow(), 3);
    }

    /// Permanent (non-transient) errors must be returned immediately
    /// without consuming any retry attempts.
    #[tokio::test]
    async fn test_retry_returns_permanent_error() {
        let config = RetryConfig {
            max_attempts: 5,
            initial_delay: Duration::from_millis(10),
            max_delay: Duration::from_millis(50),
            backoff_multiplier: 2.0,
            max_total_time: Duration::from_secs(5),
        };
        let call_count = Rc::new(RefCell::new(0));

        let result: Result<&str, SourceError> = {
            let call_count = call_count.clone();
            with_retry(config, move || {
                let call_count = call_count.clone();
                async move {
                    *call_count.borrow_mut() += 1;
                    // NotFound is not classified as transient, so no retry.
                    Err(SourceError::NotFound("not found".to_string()))
                }
            })
        }
        .await;

        assert!(result.is_err());
        if let Err(e) = result {
            match e {
                SourceError::NotFound(_) => {} // Expected
                _ => panic!("Expected NotFound error"),
            }
        }
        assert_eq!(*call_count.borrow(), 1); // Should not retry on permanent error
    }

    /// Spot-checks the SourceError -> TransientError classification.
    #[test]
    fn test_transient_error_detection() {
        // Test rate limit detection
        let rate_limit_error = SourceError::RateLimit;
        assert!(TransientError::from_source_error(&rate_limit_error).is_some());

        // Test network error detection
        let network_error = SourceError::Network("connection refused".to_string());
        assert!(TransientError::from_source_error(&network_error).is_some());

        // Test non-transient error
        let parse_error = SourceError::Parse("invalid json".to_string());
        assert!(TransientError::from_source_error(&parse_error).is_none());
    }

    /// Spot-checks the per-variant recommended delays.
    #[test]
    fn test_recommended_delay() {
        assert_eq!(
            TransientError::RateLimit(Some(30)).recommended_delay(),
            Duration::from_secs(31)
        );

        assert_eq!(
            TransientError::RateLimit(None).recommended_delay(),
            Duration::from_secs(61)
        );

        assert_eq!(
            TransientError::ServiceUnavailable.recommended_delay(),
            Duration::from_secs(10)
        );

        assert_eq!(
            TransientError::Network.recommended_delay(),
            Duration::from_secs(2)
        );
    }
}
439}