Skip to main content

aix_core/
retry.rs

1//! Retry logic with exponential backoff and jitter.
2//!
3//! This module provides configurable retry strategies for handling transient
4//! failures like rate limits, network errors, and timeouts.
5
6use crate::error::{AixError, AixResult};
7use rand::Rng;
8use std::future::Future;
9use std::time::Duration;
10
/// Tunable parameters that govern how failed operations are retried.
///
/// Delays grow geometrically from `initial_backoff` by `multiplier`, are
/// capped at `max_backoff`, and may be randomized when `jitter` is set.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Total number of attempts the executor makes, counting the initial
    /// call (the `no_retry` preset uses `1`).
    pub max_attempts: u32,
    /// Delay used before the first retry; later delays scale from this value.
    pub initial_backoff: Duration,
    /// Upper bound applied to the exponentially growing delay.
    pub max_backoff: Duration,
    /// Growth factor applied to the delay for each successive attempt.
    pub multiplier: f64,
    /// When `true`, delays are randomized so that many clients do not retry
    /// in lockstep (thundering herd).
    pub jitter: bool,
    /// Retry when the error is a rate-limit error.
    pub retry_on_rate_limit: bool,
    /// Retry when the error is a transport error.
    pub retry_on_transport: bool,
    /// Retry when the error is a timeout error.
    pub retry_on_timeout: bool,
}
31
32impl RetryConfig {
33    /// Create a new retry configuration with sensible defaults.
34    pub fn new() -> Self {
35        Self::default()
36    }
37
38    /// Create a new retry configuration builder.
39    pub fn builder() -> RetryConfigBuilder {
40        RetryConfigBuilder::new()
41    }
42
43    /// Set the maximum number of attempts.
44    pub fn with_max_attempts(mut self, max_attempts: u32) -> Self {
45        self.max_attempts = max_attempts;
46        self
47    }
48
49    /// Set the initial backoff duration.
50    pub fn with_initial_backoff(mut self, initial_backoff: Duration) -> Self {
51        self.initial_backoff = initial_backoff;
52        self
53    }
54
55    /// Set the maximum backoff duration.
56    pub fn with_max_backoff(mut self, max_backoff: Duration) -> Self {
57        self.max_backoff = max_backoff;
58        self
59    }
60
61    /// Set the backoff multiplier.
62    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
63        self.multiplier = multiplier;
64        self
65    }
66
67    /// Enable or disable jitter.
68    pub fn with_jitter(mut self, jitter: bool) -> Self {
69        self.jitter = jitter;
70        self
71    }
72
73    /// Set whether to retry on rate limit errors.
74    pub fn with_retry_on_rate_limit(mut self, retry: bool) -> Self {
75        self.retry_on_rate_limit = retry;
76        self
77    }
78
79    /// Set whether to retry on transport errors.
80    pub fn with_retry_on_transport(mut self, retry: bool) -> Self {
81        self.retry_on_transport = retry;
82        self
83    }
84
85    /// Set whether to retry on timeout errors.
86    pub fn with_retry_on_timeout(mut self, retry: bool) -> Self {
87        self.retry_on_timeout = retry;
88        self
89    }
90
91    /// Check if an error should be retried based on this configuration.
92    pub fn should_retry(&self, error: &AixError) -> bool {
93        if !error.is_retryable() {
94            return false;
95        }
96
97        match error {
98            AixError::RateLimit { .. } => self.retry_on_rate_limit,
99            AixError::Transport { .. } => self.retry_on_transport,
100            AixError::Timeout { .. } => self.retry_on_timeout,
101            AixError::Provider { status, .. } => {
102                // For provider errors, check if it's a server error (5xx)
103                status.map_or(false, |s| s >= 500)
104            }
105            _ => false,
106        }
107    }
108
109    /// Calculate the delay for a given attempt number.
110    ///
111    /// # Arguments
112    /// * `attempt` - The attempt number (0-based)
113    ///
114    /// # Returns
115    /// The duration to wait before the next attempt
116    pub fn calculate_delay(&self, attempt: u32) -> Duration {
117        // Exponential backoff: delay = initial * multiplier^attempt
118        let base_delay = self.initial_backoff.as_secs_f64() * self.multiplier.powi(attempt as i32);
119        let base_delay = Duration::from_secs_f64(base_delay);
120
121        // Cap at max_backoff
122        let delay = std::cmp::min(base_delay, self.max_backoff);
123
124        // Add jitter if enabled
125        if self.jitter {
126            let jitter_range = delay.as_secs_f64() * 0.5; // 50% jitter
127            let jitter = rand::thread_rng().gen_range(0.0..jitter_range);
128            let actual_delay = delay.as_secs_f64() * (0.5 + jitter / jitter_range);
129            Duration::from_secs_f64(actual_delay)
130        } else {
131            delay
132        }
133    }
134
135    /// Extract retry delay from rate limit error if available.
136    pub fn extract_retry_delay(&self, error: &AixError) -> Option<Duration> {
137        match error {
138            AixError::RateLimit { retry_after, .. } => *retry_after,
139            _ => None,
140        }
141    }
142}
143
144impl Default for RetryConfig {
145    fn default() -> Self {
146        Self {
147            max_attempts: 3,
148            initial_backoff: Duration::from_millis(1000),
149            max_backoff: Duration::from_secs(30),
150            multiplier: 2.0,
151            jitter: true,
152            retry_on_rate_limit: true,
153            retry_on_transport: true,
154            retry_on_timeout: true,
155        }
156    }
157}
158
159/// Builder for creating `RetryConfig` instances.
160pub struct RetryConfigBuilder {
161    config: RetryConfig,
162}
163
164impl RetryConfigBuilder {
165    /// Create a new retry configuration builder.
166    pub fn new() -> Self {
167        Self {
168            config: RetryConfig::default(),
169        }
170    }
171
172    /// Set the maximum number of attempts.
173    pub fn max_attempts(mut self, max_attempts: u32) -> Self {
174        self.config.max_attempts = max_attempts;
175        self
176    }
177
178    /// Set the initial backoff duration.
179    pub fn initial_backoff(mut self, initial_backoff: Duration) -> Self {
180        self.config.initial_backoff = initial_backoff;
181        self
182    }
183
184    /// Set the maximum backoff duration.
185    pub fn max_backoff(mut self, max_backoff: Duration) -> Self {
186        self.config.max_backoff = max_backoff;
187        self
188    }
189
190    /// Set the backoff multiplier.
191    pub fn multiplier(mut self, multiplier: f64) -> Self {
192        self.config.multiplier = multiplier;
193        self
194    }
195
196    /// Enable or disable jitter.
197    pub fn jitter(mut self, jitter: bool) -> Self {
198        self.config.jitter = jitter;
199        self
200    }
201
202    /// Set whether to retry on rate limit errors.
203    pub fn retry_on_rate_limit(mut self, retry: bool) -> Self {
204        self.config.retry_on_rate_limit = retry;
205        self
206    }
207
208    /// Set whether to retry on transport errors.
209    pub fn retry_on_transport(mut self, retry: bool) -> Self {
210        self.config.retry_on_transport = retry;
211        self
212    }
213
214    /// Set whether to retry on timeout errors.
215    pub fn retry_on_timeout(mut self, retry: bool) -> Self {
216        self.config.retry_on_timeout = retry;
217        self
218    }
219
220    /// Build the retry configuration.
221    pub fn build(self) -> RetryConfig {
222        self.config
223    }
224}
225
226impl Default for RetryConfigBuilder {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
232/// Retry strategy executor.
233pub struct RetryStrategy {
234    config: RetryConfig,
235}
236
237impl RetryStrategy {
238    /// Create a new retry strategy with the given configuration.
239    pub fn new(config: RetryConfig) -> Self {
240        Self { config }
241    }
242
243    /// Execute an operation with retry logic.
244    ///
245    /// # Arguments
246    /// * `f` - A closure that returns a Future yielding an `AixResult`
247    ///
248    /// # Returns
249    /// The result of the operation, or the last error if all retries fail
250    pub async fn execute<F, Fut, T>(&self, mut f: F) -> AixResult<T>
251    where
252        F: FnMut() -> Fut,
253        Fut: Future<Output = AixResult<T>>,
254    {
255        let mut last_error = None;
256
257        for attempt in 0..self.config.max_attempts {
258            // Attempt the operation
259            match f().await {
260                Ok(result) => return Ok(result),
261                Err(error) => {
262                    last_error = Some(error.clone());
263
264                    // Check if we should retry this error
265                    if !self.config.should_retry(&error) {
266                        return Err(error);
267                    }
268
269                    // If this is the last attempt, don't wait
270                    if attempt == self.config.max_attempts - 1 {
271                        return Err(error);
272                    }
273
274                    // Calculate delay
275                    let delay = self
276                        .config
277                        .extract_retry_delay(&error)
278                        .unwrap_or_else(|| self.config.calculate_delay(attempt));
279
280                    // Wait before retrying
281                    tokio::time::sleep(delay).await;
282                }
283            }
284
285            // Increment attempt counter (already done in for loop)
286        }
287
288        // All attempts failed, return the last error
289        Err(last_error.unwrap_or_else(|| AixError::other("All retry attempts failed")))
290    }
291
292    /// Get a reference to the configuration.
293    pub fn config(&self) -> &RetryConfig {
294        &self.config
295    }
296
297    /// Get a mutable reference to the configuration.
298    pub fn config_mut(&mut self) -> &mut RetryConfig {
299        &mut self.config
300    }
301}
302
303impl From<RetryConfig> for RetryStrategy {
304    fn from(config: RetryConfig) -> Self {
305        Self::new(config)
306    }
307}
308
309/// Preset retry configurations.
310impl RetryConfig {
311    /// No retry - attempt operation only once.
312    pub fn no_retry() -> Self {
313        Self {
314            max_attempts: 1,
315            initial_backoff: Duration::from_millis(0),
316            max_backoff: Duration::from_millis(0),
317            multiplier: 1.0,
318            jitter: false,
319            retry_on_rate_limit: false,
320            retry_on_transport: false,
321            retry_on_timeout: false,
322        }
323    }
324
325    /// Conservative retry - fewer attempts with longer delays.
326    pub fn conservative() -> Self {
327        Self {
328            max_attempts: 2,
329            initial_backoff: Duration::from_secs(2),
330            max_backoff: Duration::from_secs(10),
331            multiplier: 2.0,
332            jitter: true,
333            retry_on_rate_limit: true,
334            retry_on_transport: false, // Don't retry transport errors conservatively
335            retry_on_timeout: false,
336        }
337    }
338
339    /// Aggressive retry - more attempts with shorter initial delays.
340    pub fn aggressive() -> Self {
341        Self {
342            max_attempts: 5,
343            initial_backoff: Duration::from_millis(500),
344            max_backoff: Duration::from_secs(30),
345            multiplier: 1.5,
346            jitter: true,
347            retry_on_rate_limit: true,
348            retry_on_transport: true,
349            retry_on_timeout: true,
350        }
351    }
352
353    /// Fast retry - for operations that should retry quickly.
354    pub fn fast() -> Self {
355        Self {
356            max_attempts: 3,
357            initial_backoff: Duration::from_millis(200),
358            max_backoff: Duration::from_secs(5),
359            multiplier: 1.5,
360            jitter: true,
361            retry_on_rate_limit: true,
362            retry_on_transport: true,
363            retry_on_timeout: false, // Don't retry timeouts in fast mode
364        }
365    }
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::AixError;

    /// Every builder setter used here must land in the built configuration.
    #[test]
    fn test_retry_config_builder() {
        let config = RetryConfig::builder()
            .max_attempts(5)
            .initial_backoff(Duration::from_millis(500))
            .max_backoff(Duration::from_secs(10))
            .multiplier(1.5)
            .jitter(false)
            .retry_on_rate_limit(false)
            .build();

        assert_eq!(config.max_attempts, 5);
        assert_eq!(config.initial_backoff, Duration::from_millis(500));
        assert_eq!(config.max_backoff, Duration::from_secs(10));
        assert_eq!(config.multiplier, 1.5);
        assert!(!config.jitter);
        assert!(!config.retry_on_rate_limit);
    }

    /// Without jitter the delay is exactly `initial * multiplier^attempt`,
    /// capped at `max_backoff`.
    #[test]
    fn test_backoff_calculation() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: false, // Disable jitter for predictable testing
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        };

        // First attempt: 1000ms
        assert_eq!(config.calculate_delay(0), Duration::from_millis(1000));
        // Second attempt: 2000ms
        assert_eq!(config.calculate_delay(1), Duration::from_millis(2000));
        // Third attempt: 4000ms
        assert_eq!(config.calculate_delay(2), Duration::from_millis(4000));

        // Test capping at max_backoff
        let long_config = RetryConfig {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_millis(3000),
            multiplier: 2.0,
            jitter: false,
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        };

        // Should be capped at 3000ms
        assert_eq!(long_config.calculate_delay(3), Duration::from_millis(3000));
    }

    /// With jitter the first delay stays within 50% of the 1000ms base.
    #[test]
    fn test_jitter() {
        let config = RetryConfig {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_secs(10),
            multiplier: 2.0,
            jitter: true,
            retry_on_rate_limit: true,
            retry_on_transport: true,
            retry_on_timeout: true,
        };

        // The delay is random, so check many samples rather than a single one.
        for _ in 0..100 {
            let delay = config.calculate_delay(0);
            assert!(delay >= Duration::from_millis(500));
            assert!(delay <= Duration::from_millis(1500));
        }
    }

    /// The default configuration retries transient failures and 5xx provider
    /// errors, and nothing else.
    #[test]
    fn test_should_retry() {
        let config = RetryConfig::default();

        // Retryable errors
        assert!(config.should_retry(&AixError::transport("network error", "request")));
        assert!(config.should_retry(&AixError::rate_limit("openai", "too many requests")));
        assert!(config.should_retry(&AixError::timeout("chat", Duration::from_secs(30))));
        assert!(config.should_retry(&AixError::provider_with_details(
            "openai",
            "server error",
            500,
            "internal_error"
        )));

        // Non-retryable errors
        assert!(!config.should_retry(&AixError::config("invalid config")));
        assert!(!config.should_retry(&AixError::auth("openai", "unauthorized")));
        assert!(!config.should_retry(&AixError::provider_with_details(
            "openai",
            "bad request",
            400,
            "invalid_request"
        )));
    }

    /// A successful first call is returned without any retry.
    #[tokio::test]
    async fn test_retry_strategy_success() {
        let strategy = RetryStrategy::new(RetryConfig::default());
        let mut call_count = 0;

        let result = strategy
            .execute(|| {
                call_count += 1;
                async move { Ok::<_, AixError>("success") }
            })
            .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(call_count, 1); // Should only be called once
    }

    /// Two transient failures followed by a success take three calls.
    #[tokio::test]
    async fn test_retry_strategy_with_retry() {
        // Millisecond backoff keeps the test from sleeping for real seconds.
        let strategy = RetryStrategy::new(
            RetryConfig::builder()
                .max_attempts(3)
                .initial_backoff(Duration::from_millis(1))
                .build(),
        );
        let mut call_count = 0;

        let result = strategy
            .execute(|| {
                call_count += 1;
                async move {
                    if call_count < 3 {
                        Err::<_, AixError>(AixError::transport("network error", "request"))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), "success");
        assert_eq!(call_count, 3); // Should be called 3 times
    }

    /// A persistently failing operation is called exactly `max_attempts` times.
    #[tokio::test]
    async fn test_retry_strategy_exhausted() {
        // Millisecond backoff keeps the test from sleeping for real seconds.
        let strategy = RetryStrategy::new(
            RetryConfig::builder()
                .max_attempts(2)
                .initial_backoff(Duration::from_millis(1))
                .build(),
        );
        let mut call_count = 0;

        let result = strategy
            .execute(|| {
                call_count += 1;
                async move { Err::<(), AixError>(AixError::transport("network error", "request")) }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(call_count, 2); // Should be called max_attempts times
    }

    /// Spot-check the distinguishing fields of each preset.
    #[test]
    fn test_preset_configs() {
        let no_retry = RetryConfig::no_retry();
        assert_eq!(no_retry.max_attempts, 1);
        assert!(!no_retry.retry_on_rate_limit);

        let conservative = RetryConfig::conservative();
        assert_eq!(conservative.max_attempts, 2);
        assert!(!conservative.retry_on_transport);

        let aggressive = RetryConfig::aggressive();
        assert_eq!(aggressive.max_attempts, 5);
        assert!(aggressive.retry_on_transport);

        let fast = RetryConfig::fast();
        assert_eq!(fast.max_attempts, 3);
        assert_eq!(fast.initial_backoff, Duration::from_millis(200));
    }
}