Skip to main content

ai_agent/services/
retry.rs

1//! Retry logic with exponential backoff.
2//!
//! Provides retry functionality similar to Claude Code's `withRetry`.
4
5use std::fmt::Display;
6use std::future::Future;
7use std::time::Duration;
8use tokio::time::sleep;
9
/// Default maximum number of retries; used by `RetryConfig::new` for `max_retries`.
pub const DEFAULT_MAX_RETRIES: u32 = 10;

/// Default base delay in milliseconds; the backoff for the first retry before doubling.
pub const BASE_DELAY_MS: u64 = 500;

/// Default upper bound, in milliseconds, applied to the exponential backoff delay.
pub const MAX_DELAY_MS: u64 = 32000;
18
/// Error returned when a retried operation ultimately fails.
///
/// Carries the error produced by the final attempt together with the number
/// of attempts that were made before giving up.
#[derive(Debug)]
pub struct RetryError<E> {
    /// Error produced by the last attempt of the operation.
    pub original_error: E,
    /// Number of times the operation was invoked before giving up.
    pub attempts: u32,
}

// Only `Display` is needed to format the wrapped error; requiring `Clone`
// here would make errors such as `std::io::Error` unusable.
impl<E: Display> Display for RetryError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "RetryError: {} after {} attempts",
            self.original_error, self.attempts
        )
    }
}

// `E` is not required to implement `std::error::Error`, so the wrapped error
// cannot be exposed through `source()`; the trait's default (`None`) is used.
impl<E: Display + std::fmt::Debug> std::error::Error for RetryError<E> {}
41
/// Result type for retry operations: `Ok(T)` on success, or a `RetryError`
/// wrapping the operation's own error type `E` on failure.
pub type RetryResult<T, E> = Result<T, RetryError<E>>;
44
/// Configuration for retry behavior.
///
/// Controls how many times an operation is retried, how the delay between
/// attempts grows, and which errors are considered worth retrying.
pub struct RetryConfig {
    /// Maximum number of retries after the initial attempt (default: 10).
    /// The operation is invoked at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Base delay in milliseconds, doubled for each subsequent attempt (default: 500)
    pub base_delay_ms: u64,
    /// Maximum delay cap in milliseconds applied to the exponential backoff (default: 32000)
    pub max_delay_ms: u64,
    /// Enable jitter: adds a randomized extra delay on top of the capped backoff (default: true)
    pub jitter: bool,
    /// Optional predicate deciding whether an error is retried. It receives the
    /// error rendered via `Display`; returning `false` stops retrying at once.
    /// When `None`, every error is retried.
    pub should_retry: Option<Box<dyn Fn(&str) -> bool + Send + Sync>>,
}
58
59impl RetryConfig {
60    /// Create default retry config
61    pub fn new() -> Self {
62        Self {
63            max_retries: DEFAULT_MAX_RETRIES,
64            base_delay_ms: BASE_DELAY_MS,
65            max_delay_ms: MAX_DELAY_MS,
66            jitter: true,
67            should_retry: None,
68        }
69    }
70}
71
72impl Default for RetryConfig {
73    fn default() -> Self {
74        Self::new()
75    }
76}
77
78/// Calculate retry delay with exponential backoff and optional jitter
79pub fn get_retry_delay(attempt: u32, retry_after_ms: Option<u64>, config: &RetryConfig) -> u64 {
80    // If retry-after header is provided, use it directly
81    if let Some(retry_after) = retry_after_ms {
82        return retry_after;
83    }
84
85    // Exponential backoff: base * 2^(attempt-1)
86    let base_delay = config
87        .base_delay_ms
88        .saturating_mul(2u64.saturating_pow(attempt - 1));
89    let delay = base_delay.min(config.max_delay_ms);
90
91    // Add jitter (25% of base delay)
92    if config.jitter {
93        let jitter = (delay as f64 * 0.25 * rand_jitter()) as u64;
94        delay + jitter
95    } else {
96        delay
97    }
98}
99
/// Simple pseudo-random jitter factor in the half-open range `[0, 1)`.
///
/// Derived from the sub-second nanoseconds of the current wall-clock time, so
/// it needs no external RNG crate. `subsec_nanos()` is always below
/// 1_000_000_000, so dividing by the number of nanoseconds per second spreads
/// the value over the whole unit interval. Not suitable for anything that
/// needs real randomness.
fn rand_jitter() -> f64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    const NANOS_PER_SEC: f64 = 1_000_000_000.0;

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        // A clock set before the epoch yields a zero duration, i.e. no jitter.
        .unwrap_or_default()
        .subsec_nanos();
    f64::from(nanos) / NANOS_PER_SEC
}
109
110/// Retry an async operation with exponential backoff
111///
112/// # Arguments
113/// * `operation` - The async operation to retry
114/// * `config` - Retry configuration
115///
116/// # Returns
117/// * `Ok(T)` - Success
118/// * `Err(RetryError<E>)` - All retries exhausted
119pub async fn retry_async<T, E, F, Fut>(mut operation: F, config: RetryConfig) -> RetryResult<T, E>
120where
121    F: FnMut() -> Fut,
122    Fut: Future<Output = Result<T, E>>,
123    E: std::fmt::Display + Clone,
124{
125    let mut last_error: Option<E> = None;
126
127    for attempt in 1..=config.max_retries + 1 {
128        match operation().await {
129            Ok(result) => return Ok(result),
130            Err(e) => {
131                last_error = Some(e.clone());
132
133                // Check if we should retry this error
134                if let Some(should_retry) = &config.should_retry {
135                    let error_str = format!("{}", e);
136                    if !should_retry(&error_str) {
137                        return Err(RetryError {
138                            original_error: e,
139                            attempts: attempt,
140                        });
141                    }
142                }
143
144                // Don't delay on the last attempt
145                if attempt <= config.max_retries {
146                    let delay = get_retry_delay(attempt, None, &config);
147                    sleep(Duration::from_millis(delay)).await;
148                }
149            }
150        }
151    }
152
153    Err(RetryError {
154        original_error: last_error.unwrap_or_else(|| {
155            panic!("retry_async called with max_retries=0 and no error occurred")
156        }),
157        attempts: config.max_retries + 1,
158    })
159}
160
161/// Retry an async operation with exponential backoff and retry-after support
162///
163/// # Arguments
164/// * `operation` - The async operation to retry (receives attempt number)
165/// * `config` - Retry configuration
166/// * `get_retry_after` - Extract retry-after from error (returns milliseconds)
167pub async fn retry_with_retry_after<T, E, F, Fut>(
168    mut operation: F,
169    config: RetryConfig,
170    get_retry_after: impl Fn(&E) -> Option<u64>,
171) -> RetryResult<T, E>
172where
173    F: FnMut(u32) -> Fut,
174    Fut: Future<Output = Result<T, E>>,
175    E: std::fmt::Display + Clone,
176{
177    let mut last_error: Option<E> = None;
178
179    for attempt in 1..=config.max_retries + 1 {
180        match operation(attempt).await {
181            Ok(result) => return Ok(result),
182            Err(e) => {
183                last_error = Some(e.clone());
184
185                // Check if we should retry this error
186                if let Some(should_retry) = &config.should_retry {
187                    let error_str = format!("{}", e);
188                    if !should_retry(&error_str) {
189                        return Err(RetryError {
190                            original_error: e,
191                            attempts: attempt,
192                        });
193                    }
194                }
195
196                // Don't delay on the last attempt
197                if attempt <= config.max_retries {
198                    let retry_after_ms = get_retry_after(&e);
199                    let delay = get_retry_delay(attempt, retry_after_ms, &config);
200                    sleep(Duration::from_millis(delay)).await;
201                }
202            }
203        }
204    }
205
206    Err(RetryError {
207        original_error: last_error.unwrap_or_else(|| {
208            panic!("retry_with_retry_after called with max_retries=0 and no error occurred")
209        }),
210        attempts: config.max_retries + 1,
211    })
212}
213
/// Reports whether an error message looks like a rate limit failure.
///
/// Matches the literal status code `429` anywhere in the message, or the
/// phrase "rate limit" in any letter case.
pub fn is_rate_limit_error(error: &str) -> bool {
    if error.contains("429") {
        return true;
    }
    let lowered = error.to_lowercase();
    lowered.contains("rate limit")
}
218
/// Check if an error is a service unavailable error (529).
///
/// Matches the literal status code `529` anywhere in the message, or the word
/// "overloaded" in any letter case (so "Overloaded" and "OVERLOADED" are
/// recognized too, consistent with the other case-insensitive classifiers).
pub fn is_service_unavailable_error(error: &str) -> bool {
    error.contains("529") || error.to_lowercase().contains("overloaded")
}
223
224/// Check if an error is a temporary error that should be retried
225pub fn is_retryable_error(error: &str) -> bool {
226    is_rate_limit_error(error)
227        || is_service_unavailable_error(error)
228        || is_connection_error(error)
229        || is_server_error(error)
230}
231
/// Reports whether an error message looks like a connection-level failure.
///
/// The message is lowercased and then searched for any of the known
/// connection-related markers, so matching is case-insensitive.
pub fn is_connection_error(error: &str) -> bool {
    const MARKERS: [&str; 5] = [
        "connection",
        "econnreset",
        "econnrefused",
        "epipe",
        "timeout",
    ];

    let lowered = error.to_lowercase();
    MARKERS.iter().any(|&marker| lowered.contains(marker))
}
241
/// Reports whether an error message mentions a server-side (5xx) status code.
///
/// Only the codes 500 through 504 are recognized, and they are matched as
/// plain substrings of the message.
pub fn is_server_error(error: &str) -> bool {
    const SERVER_STATUS_CODES: [&str; 5] = ["500", "501", "502", "503", "504"];
    SERVER_STATUS_CODES
        .iter()
        .any(|&code| error.contains(code))
}
251
252/// Create a retry config for rate limit errors
253pub fn rate_limit_config() -> RetryConfig {
254    RetryConfig {
255        max_retries: 5,
256        base_delay_ms: 1000,
257        max_delay_ms: 60000,
258        jitter: true,
259        should_retry: Some(Box::new(|e| is_rate_limit_error(e))),
260    }
261}
262
263/// Create a retry config for service unavailable errors
264pub fn service_unavailable_config() -> RetryConfig {
265    RetryConfig {
266        max_retries: 3,
267        base_delay_ms: 2000,
268        max_delay_ms: 30000,
269        jitter: true,
270        should_retry: Some(Box::new(|e| is_service_unavailable_error(e))),
271    }
272}
273
274/// Create a retry config for all retryable errors
275pub fn default_retry_config() -> RetryConfig {
276    RetryConfig::default()
277}
278
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Configuration with millisecond-scale delays and no jitter, so the async
    /// tests exercise the retry loop without sleeping for seconds.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            base_delay_ms: 1,
            max_delay_ms: 4,
            jitter: false,
            should_retry: None,
        }
    }

    #[tokio::test]
    async fn test_retry_success_first_try() {
        let call_count = AtomicU32::new(0);
        let operation = || {
            let call_count = &call_count;
            async move {
                call_count.fetch_add(1, Ordering::SeqCst);
                Ok::<_, &'static str>("success")
            }
        };

        let result = retry_async(operation, RetryConfig::default()).await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let call_count = AtomicU32::new(0);
        let operation = || {
            let call_count = &call_count;
            async move {
                let count = call_count.fetch_add(1, Ordering::SeqCst);
                if count < 2 {
                    Err("temporary error")
                } else {
                    Ok("success")
                }
            }
        };

        let result = retry_async(operation, fast_config(10)).await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let call_count = AtomicU32::new(0);
        let operation = || {
            let call_count = &call_count;
            async move {
                call_count.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("persistent error")
            }
        };

        let result = retry_async(operation, fast_config(3)).await;
        let err = result.unwrap_err();
        // One initial attempt plus three retries.
        assert_eq!(err.attempts, 4);
        assert_eq!(err.original_error, "persistent error");
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn test_retry_with_should_retry() {
        let call_count = AtomicU32::new(0);
        let operation = || {
            let call_count = &call_count;
            async move {
                call_count.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("rate limit")
            }
        };

        let config = RetryConfig {
            should_retry: Some(Box::new(|e: &str| e.contains("rate limit"))),
            ..fast_config(3)
        };
        let result = retry_async(operation, config).await;
        assert!(result.is_err());
        // The predicate accepts the error, so every attempt is used.
        assert_eq!(call_count.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn test_retry_stops_on_non_retryable_error() {
        let call_count = AtomicU32::new(0);
        let operation = || {
            let call_count = &call_count;
            async move {
                call_count.fetch_add(1, Ordering::SeqCst);
                Err::<String, _>("404 Not Found")
            }
        };

        let config = RetryConfig {
            should_retry: Some(Box::new(|e: &str| is_rate_limit_error(e))),
            ..fast_config(3)
        };
        let err = retry_async(operation, config).await.unwrap_err();
        // The predicate rejects the error, so no retry happens.
        assert_eq!(err.attempts, 1);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_with_retry_after_passes_attempt_number() {
        let last_attempt = AtomicU32::new(0);
        let operation = |attempt: u32| {
            let last_attempt = &last_attempt;
            async move {
                last_attempt.store(attempt, Ordering::SeqCst);
                if attempt < 3 {
                    Err("429 Too Many Requests")
                } else {
                    Ok("success")
                }
            }
        };

        let result = retry_with_retry_after(operation, fast_config(5), |_| Some(1)).await;
        assert_eq!(result.unwrap(), "success");
        assert_eq!(last_attempt.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn test_get_retry_delay_exponential() {
        let config = RetryConfig {
            base_delay_ms: 100,
            max_delay_ms: 10000,
            jitter: false,
            ..Default::default()
        };

        assert_eq!(get_retry_delay(1, None, &config), 100);
        assert_eq!(get_retry_delay(2, None, &config), 200);
        assert_eq!(get_retry_delay(3, None, &config), 400);
        assert_eq!(get_retry_delay(4, None, &config), 800);
    }

    #[test]
    fn test_get_retry_delay_max_cap() {
        let config = RetryConfig {
            base_delay_ms: 1000,
            max_delay_ms: 500,
            jitter: false,
            ..Default::default()
        };

        // Should be capped at max_delay_ms
        assert_eq!(get_retry_delay(10, None, &config), 500);
    }

    #[test]
    fn test_get_retry_delay_jitter_bounds() {
        let config = RetryConfig {
            base_delay_ms: 1000,
            max_delay_ms: 10000,
            jitter: true,
            ..Default::default()
        };

        // Jitter only ever adds, and never more than 25% of the delay.
        let delay = get_retry_delay(1, None, &config);
        assert!((1000..=1250).contains(&delay));
    }

    #[test]
    fn test_get_retry_delay_with_retry_after() {
        let config = RetryConfig::default();

        // Should use retry-after if provided
        let delay = get_retry_delay(1, Some(5000), &config);
        assert_eq!(delay, 5000);
    }

    #[test]
    fn test_is_rate_limit_error() {
        assert!(is_rate_limit_error("429 Too Many Requests"));
        assert!(is_rate_limit_error("rate limit exceeded"));
        assert!(!is_rate_limit_error("404 Not Found"));
    }

    #[test]
    fn test_is_service_unavailable_error() {
        assert!(is_service_unavailable_error("529 Service Unavailable"));
        assert!(is_service_unavailable_error("server overloaded"));
        assert!(!is_service_unavailable_error("400 Bad Request"));
    }

    #[test]
    fn test_is_connection_error() {
        assert!(is_connection_error("connection refused"));
        assert!(is_connection_error("ECONNRESET"));
        assert!(!is_connection_error("404 Not Found"));
    }

    #[test]
    fn test_is_server_error() {
        assert!(is_server_error("500 Internal Server Error"));
        assert!(is_server_error("503 Service Unavailable"));
        assert!(!is_server_error("400 Bad Request"));
    }
}