mixtape_core/provider/retry.rs

//! Shared retry logic for model providers
//!
//! This module provides exponential backoff with jitter for retrying
//! transient errors like rate limiting, service unavailability, and
//! network issues.
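//!
//! The delay doubles with each attempt, starting at `base_delay_ms` and capped at
//! `max_delay_ms`, with ±20% jitter applied on top (see [`backoff_delay`]).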

use super::ProviderError;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};

/// Configuration for retry behavior on transient errors (throttling, rate limits)
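///
/// A minimal sketch of overriding the defaults (the values here are illustrative,
/// not tuned recommendations):
///
/// ```ignore
/// let config = RetryConfig {
///     max_attempts: 4,
///     base_delay_ms: 250,
///     max_delay_ms: 10_000,
/// };
/// ```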
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum total number of attempts, including the initial call (default: 8)
    pub max_attempts: usize,
    /// Base delay in milliseconds for exponential backoff (default: 500ms)
    pub base_delay_ms: u64,
    /// Maximum delay cap in milliseconds (default: 30000ms)
    pub max_delay_ms: u64,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_attempts: 8,
            base_delay_ms: 500,
            max_delay_ms: 30_000,
        }
    }
}

/// Information about a retry attempt
#[derive(Debug, Clone)]
pub struct RetryInfo {
    /// Which attempt this is (1-based)
    pub attempt: usize,
    /// Maximum attempts configured
    pub max_attempts: usize,
    /// How long we'll wait before retrying
    pub delay: Duration,
    /// The error that triggered the retry
    pub error: String,
}

/// Callback type for retry events
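///
/// A sketch of a callback that simply logs each retry to stderr:
///
/// ```ignore
/// let on_retry: RetryCallback = Arc::new(|info: RetryInfo| {
///     eprintln!(
///         "retry {}/{} in {:?}: {}",
///         info.attempt, info.max_attempts, info.delay, info.error
///     );
/// });
/// ```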
pub type RetryCallback = Arc<dyn Fn(RetryInfo) + Send + Sync>;

/// Determine if an error is transient and should be retried
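///
/// For example (the error messages here are arbitrary):
///
/// ```ignore
/// assert!(is_retryable_error(&ProviderError::RateLimited("429".into())));
/// assert!(!is_retryable_error(&ProviderError::Authentication("401".into())));
/// ```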
pub fn is_retryable_error(err: &ProviderError) -> bool {
    match err {
        // These are transient and should be retried
        ProviderError::RateLimited(_) => true,
        ProviderError::ServiceUnavailable(_) => true,
        ProviderError::Network(_) => true,
        ProviderError::Communication(_) => true,

        // These are permanent and should not be retried
        ProviderError::Authentication(_) => false,
        ProviderError::Configuration(_) => false,
        ProviderError::Model(_) => false,
        ProviderError::Other(_) => false,
    }
}

/// Calculate backoff delay for a given attempt using exponential backoff with jitter
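///
/// With the default config the uncapped delays grow as roughly 500ms, 1s, 2s, 4s, ...
/// before the cap and jitter are applied. A small sketch:
///
/// ```ignore
/// let config = RetryConfig::default();
/// let delay = backoff_delay(3, &config); // ~2000ms ± 20% jitter
/// ```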
pub fn backoff_delay(attempt: usize, config: &RetryConfig) -> Duration {
    let shift = (attempt.saturating_sub(1)).min(10) as u32;
    let exp = 1_u64.checked_shl(shift).unwrap_or(u64::MAX);
    let base = config.base_delay_ms.saturating_mul(exp);
    let capped = base.min(config.max_delay_ms);
    let jittered = jitter_ms(capped);
    Duration::from_millis(jittered)
}

/// Apply ±20% jitter to a base delay
fn jitter_ms(base_ms: u64) -> u64 {
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos() as i64;
    let jitter_pct = (nanos % 41) - 20; // -20..=20
    let base = base_ms as i64;
    let jittered = base + (base * jitter_pct / 100);
    jittered.max(0) as u64
}

/// Retry an async operation with exponential backoff
///
/// Only retries on transient errors (rate limiting, service unavailable, network).
/// Permanent errors (authentication, configuration, model) fail immediately.
///
/// # Example
///
/// ```ignore
/// let result = retry_with_backoff(
///     || async { provider.generate(messages.clone(), tools.clone(), system.clone()).await },
///     &config,
///     &Some(Arc::new(|info| eprintln!("Retry {}: {}", info.attempt, info.error))),
/// ).await?;
/// ```
pub async fn retry_with_backoff<F, Fut, T>(
    mut op: F,
    config: &RetryConfig,
    on_retry: &Option<RetryCallback>,
) -> Result<T, ProviderError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, ProviderError>>,
{
    let mut attempt = 0;
    loop {
        attempt += 1;
        match op().await {
            Ok(result) => return Ok(result),
            Err(err) => {
                if attempt >= config.max_attempts || !is_retryable_error(&err) {
                    return Err(err);
                }
                let delay = backoff_delay(attempt, config);

                // Notify callback if set
                if let Some(callback) = on_retry {
                    callback(RetryInfo {
                        attempt,
                        max_attempts: config.max_attempts,
                        delay,
                        error: err.to_string(),
                    });
                }

                tokio::time::sleep(delay).await;
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_attempts, 8);
        assert_eq!(config.base_delay_ms, 500);
        assert_eq!(config.max_delay_ms, 30_000);
    }

    #[test]
    fn test_is_retryable_error_rate_limited() {
        assert!(is_retryable_error(&ProviderError::RateLimited(
            "too many requests".into()
        )));
    }

    #[test]
    fn test_is_retryable_error_service_unavailable() {
        assert!(is_retryable_error(&ProviderError::ServiceUnavailable(
            "503".into()
        )));
    }

    #[test]
    fn test_is_retryable_error_network() {
        assert!(is_retryable_error(&ProviderError::Network(
            "connection refused".into()
        )));
    }

    #[test]
    fn test_is_retryable_error_communication() {
        assert!(is_retryable_error(&ProviderError::Communication(
            "timeout".into()
        )));
    }

    #[test]
    fn test_is_retryable_error_not_retryable() {
        // Authentication errors should not be retried
        assert!(!is_retryable_error(&ProviderError::Authentication(
            "bad creds".into()
        )));

        // Configuration errors should not be retried
        assert!(!is_retryable_error(&ProviderError::Configuration(
            "invalid model".into()
        )));

        // Model errors should not be retried
        assert!(!is_retryable_error(&ProviderError::Model(
            "content filtered".into()
        )));

        // Generic errors should not be retried
        assert!(!is_retryable_error(&ProviderError::Other("unknown".into())));
    }

    #[test]
    fn test_backoff_delay_first_attempt() {
        let config = RetryConfig::default();
        let delay = backoff_delay(1, &config);

        // First attempt: base_delay (500ms) * 2^0 = 500ms, with jitter
        // Allow for ±20% jitter
        assert!(delay.as_millis() >= 400);
        assert!(delay.as_millis() <= 600);
    }

    #[test]
    fn test_backoff_delay_exponential_growth() {
        let config = RetryConfig {
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            max_attempts: 10,
        };

        let delay1 = backoff_delay(1, &config);
        let delay2 = backoff_delay(2, &config);
        let delay3 = backoff_delay(3, &config);

        // Each delay should roughly double (accounting for jitter)
        // delay1 ~ 100ms, delay2 ~ 200ms, delay3 ~ 400ms
        assert!(delay2.as_millis() > delay1.as_millis());
        assert!(delay3.as_millis() > delay2.as_millis());
    }

    #[test]
    fn test_backoff_delay_respects_max() {
        let config = RetryConfig {
            base_delay_ms: 1000,
            max_delay_ms: 2000,
            max_attempts: 10,
        };

        // After several attempts, should cap at max_delay_ms
        let delay = backoff_delay(10, &config);
        // With jitter, should be around 2000ms ± 20%
        assert!(delay.as_millis() <= 2400);
    }

    #[test]
    fn test_jitter_ms_produces_variation() {
        // Jitter should produce values within ±20% of base
        let base = 1000u64;

        // The jitter is derived from the current time rather than a seedable RNG,
        // so a single call is checked against the expected ±20% range
        let jittered = jitter_ms(base);
        assert!(jittered >= 800); // base - 20%
        assert!(jittered <= 1200); // base + 20%
    }

    #[tokio::test]
    async fn test_retry_with_backoff_success_first_try() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay_ms: 10,
            max_delay_ms: 100,
        };

        let mut call_count = 0;
        let result = retry_with_backoff(
            || {
                call_count += 1;
                async { Ok::<_, ProviderError>("success") }
            },
            &config,
            &None,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(call_count, 1);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_retries_on_transient_error() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay_ms: 1, // Very short for testing
            max_delay_ms: 10,
        };

        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let count_clone = call_count.clone();

        let result = retry_with_backoff(
            || {
                let count = count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                async move {
                    if count < 2 {
                        Err(ProviderError::RateLimited("throttled".into()))
                    } else {
                        Ok("success after retry")
                    }
                }
            },
            &config,
            &None,
        )
        .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success after retry");
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_gives_up_after_max_attempts() {
        let config = RetryConfig {
            max_attempts: 2,
            base_delay_ms: 1,
            max_delay_ms: 10,
        };

        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let count_clone = call_count.clone();

        let result: Result<(), ProviderError> = retry_with_backoff(
            || {
                count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                async { Err(ProviderError::RateLimited("always throttled".into())) }
            },
            &config,
            &None,
        )
        .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_no_retry_on_permanent_error() {
        let config = RetryConfig {
            max_attempts: 5,
            base_delay_ms: 1,
            max_delay_ms: 10,
        };

        let call_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let count_clone = call_count.clone();

        let result: Result<(), ProviderError> = retry_with_backoff(
            || {
                count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                async { Err(ProviderError::Authentication("bad credentials".into())) }
            },
            &config,
            &None,
        )
        .await;

        assert!(result.is_err());
        // Should only try once since auth errors are not retryable
        assert_eq!(call_count.load(std::sync::atomic::Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_callback_invoked() {
        let config = RetryConfig {
            max_attempts: 3,
            base_delay_ms: 1,
            max_delay_ms: 10,
        };

        let callback_count = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let callback_count_clone = callback_count.clone();

        let callback: RetryCallback = Arc::new(move |info: RetryInfo| {
            callback_count_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            assert!(info.attempt > 0);
            assert_eq!(info.max_attempts, 3);
        });

        let attempt = std::sync::Arc::new(std::sync::atomic::AtomicUsize::new(0));
        let attempt_clone = attempt.clone();

        let _result: Result<(), ProviderError> = retry_with_backoff(
            || {
                let count = attempt_clone.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
                async move {
                    if count < 2 {
                        Err(ProviderError::ServiceUnavailable("503".into()))
                    } else {
                        Ok(())
                    }
                }
            },
            &config,
            &Some(callback),
        )
        .await;

        // Callback should be invoked for each retry (not the initial attempt)
        assert_eq!(callback_count.load(std::sync::atomic::Ordering::SeqCst), 2);
    }
}