multi_llm/internals/retry.rs

//! Retry logic with exponential backoff and rate limiting
//!
//! This module provides resilient HTTP request handling for LLM providers with:
//! - Exponential backoff: 1s, 2s, 4s, 8s, capped at 16s
//! - Rate limit handling with Retry-After header support
//! - Circuit breaker pattern: 5 consecutive failures trigger a 30s cooldown
//! - Configurable timeouts: 120s per request (default), 5m total operation
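//!
//! A minimal usage sketch of the executor (crate-internal API; `call_provider`
//! is a hypothetical placeholder for an actual provider request):
//!
//! ```ignore
//! let mut executor = RetryExecutor::new(RetryPolicy::default());
//! let response = executor
//!     .execute(|| async { call_provider().await })
//!     .await?;
//! ```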

use crate::error::{LlmError, LlmResult};
use crate::logging::{log_debug, log_error, log_warn};

use std::time::{Duration, Instant};
use tokio::time::sleep;

/// Retry policy configuration for LLM requests
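///
/// # Examples
///
/// A minimal sketch of a custom policy (field values are illustrative; the
/// `multi_llm` crate path is an assumption based on the file location):
///
/// ```ignore
/// use std::time::Duration;
/// use multi_llm::RetryPolicy;
///
/// let policy = RetryPolicy {
///     max_attempts: 3,
///     initial_delay: Duration::from_millis(500),
///     max_delay: Duration::from_secs(8),
///     backoff_multiplier: 2.0,
///     total_timeout: Duration::from_secs(60),
///     request_timeout: Duration::from_secs(30),
/// };
/// ```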
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RetryPolicy {
    /// Maximum number of retry attempts
    pub max_attempts: u32,
    /// Initial delay before first retry
    pub initial_delay: Duration,
    /// Maximum delay between retries
    pub max_delay: Duration,
    /// Multiplier for exponential backoff
    pub backoff_multiplier: f64,
    /// Maximum total operation time
    pub total_timeout: Duration,
    /// Request timeout for individual attempts
    pub request_timeout: Duration,
}

impl Default for RetryPolicy {
    fn default() -> Self {
        Self {
            max_attempts: 5,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(16),
            backoff_multiplier: 2.0,
            total_timeout: Duration::from_secs(300), // 5 minutes
            request_timeout: Duration::from_secs(120), // Increased for slower models like Hermes
        }
    }
}

/// Circuit breaker states
#[derive(Debug, Clone, PartialEq)]
pub(crate) enum CircuitState {
    Closed,   // Normal operation
    Open,     // Failing, blocking requests
    HalfOpen, // Testing if service recovered
}

/// Circuit breaker for provider resilience
#[derive(Debug)]
pub(crate) struct CircuitBreaker {
    pub(crate) state: CircuitState,
    pub(crate) failure_count: u32,
    pub(crate) last_failure_time: Option<Instant>,
    pub(crate) failure_threshold: u32,
    pub(crate) recovery_timeout: Duration,
}

impl Default for CircuitBreaker {
    fn default() -> Self {
        Self {
            state: CircuitState::Closed,
            failure_count: 0,
            last_failure_time: None,
            failure_threshold: 5,
            recovery_timeout: Duration::from_secs(30),
        }
    }
}

impl CircuitBreaker {
    /// Check if request should be allowed through the circuit breaker
    pub fn should_allow_request(&mut self) -> bool {
        match self.state {
            CircuitState::Closed => true,
            CircuitState::Open => self.check_recovery_timeout(),
            CircuitState::HalfOpen => true,
        }
    }

    /// Check if circuit breaker should transition from Open to HalfOpen
    fn check_recovery_timeout(&mut self) -> bool {
        let Some(last_failure) = self.last_failure_time else {
            return false;
        };

        if last_failure.elapsed() >= self.recovery_timeout {
            log_debug!(
                circuit_breaker = "transitioning_to_half_open",
                recovery_timeout_seconds = self.recovery_timeout.as_secs(),
                "Circuit breaker attempting recovery"
            );
            self.state = CircuitState::HalfOpen;
            true
        } else {
            false
        }
    }

    /// Record a successful request
    pub fn record_success(&mut self) {
        match self.state {
            CircuitState::HalfOpen => {
                log_debug!(
                    circuit_breaker = "recovered",
                    "Circuit breaker recovered, returning to closed state"
                );
                self.state = CircuitState::Closed;
                self.failure_count = 0;
                self.last_failure_time = None;
            }
            CircuitState::Closed => {
                self.failure_count = 0;
            }
            CircuitState::Open => {
                // Shouldn't happen, but reset if it does
                self.failure_count = 0;
                self.last_failure_time = None;
            }
        }
    }

    /// Record a failed request
    pub fn record_failure(&mut self) {
        self.failure_count += 1;
        self.last_failure_time = Some(Instant::now());

        if self.failure_count >= self.failure_threshold {
            if self.state != CircuitState::Open {
                log_warn!(
                    circuit_breaker = "opened",
                    failure_count = self.failure_count,
                    failure_threshold = self.failure_threshold,
                    recovery_timeout_seconds = self.recovery_timeout.as_secs(),
                    "Circuit breaker opened due to repeated failures"
                );
            }
            self.state = CircuitState::Open;
        }
    }

    /// Get current circuit breaker state
    pub fn state(&self) -> CircuitState {
        self.state.clone()
    }
}

/// Retry executor that handles exponential backoff and circuit breaking
#[derive(Debug)]
pub(crate) struct RetryExecutor {
    pub(crate) policy: RetryPolicy,
    pub(crate) circuit_breaker: CircuitBreaker,
}

impl Default for RetryExecutor {
    fn default() -> Self {
        Self::new(RetryPolicy::default())
    }
}

impl RetryExecutor {
    /// Create a new retry executor with the given policy
    pub fn new(policy: RetryPolicy) -> Self {
        Self {
            policy,
            circuit_breaker: CircuitBreaker::default(),
        }
    }

    /// Execute a request with retry logic and circuit breaking
    pub async fn execute<F, Fut, T>(&mut self, operation: F) -> LlmResult<T>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = LlmResult<T>>,
    {
        let start_time = Instant::now();
        let mut attempt = 0;
        let mut last_error = None;

        while attempt < self.policy.max_attempts {
            self.check_circuit_breaker()?;
            self.check_total_timeout(&start_time)?;

            attempt += 1;

            match self
                .execute_single_attempt(&operation, attempt, &mut last_error)
                .await
            {
                Ok(response) => return Ok(response),
                Err(should_continue) => {
                    if !should_continue {
                        break;
                    }
                }
            }
        }

        self.handle_exhausted_retries(attempt, last_error, &start_time)
    }

    /// Execute a single attempt; on failure, `Err(should_continue)` tells the
    /// caller whether another retry is warranted
    async fn execute_single_attempt<F, Fut, T>(
        &mut self,
        operation: &F,
        attempt: u32,
        last_error: &mut Option<LlmError>,
    ) -> Result<T, bool>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = LlmResult<T>>,
    {
        self.log_attempt(attempt);

        let operation_start = Instant::now();
        let result = tokio::time::timeout(self.policy.request_timeout, operation()).await;

        match result {
            Ok(Ok(response)) => {
                self.circuit_breaker.record_success();
                log_debug!(
                    attempt = attempt,
                    duration_ms = operation_start.elapsed().as_millis(),
                    "Request succeeded"
                );
                Ok(response)
            }
            Ok(Err(error)) => {
                let should_continue = self.handle_error(error, attempt, last_error).await;
                Err(should_continue)
            }
            Err(_timeout) => {
                let should_continue = self.handle_timeout(attempt, last_error).await;
                Err(should_continue)
            }
        }
    }

    /// Fail fast if the circuit breaker is currently rejecting requests
    fn check_circuit_breaker(&mut self) -> LlmResult<()> {
        if !self.circuit_breaker.should_allow_request() {
            return Err(LlmError::request_failed(
                "Circuit breaker is open - service temporarily unavailable".to_string(),
                None,
            ));
        }
        Ok(())
    }

    /// Fail if the total operation time budget has been exhausted
    fn check_total_timeout(&mut self, start_time: &Instant) -> LlmResult<()> {
        if start_time.elapsed() >= self.policy.total_timeout {
            return Err(LlmError::timeout(self.policy.total_timeout.as_secs()));
        }
        Ok(())
    }

    fn log_attempt(&mut self, attempt: u32) {
        log_debug!(
            attempt = attempt,
            max_attempts = self.policy.max_attempts,
            circuit_state = ?self.circuit_breaker.state(),
            "Executing request with retry logic"
        );
    }

    /// Record a failed attempt and return `true` if the caller should retry
    async fn handle_error(
        &mut self,
        error: LlmError,
        attempt: u32,
        last_error: &mut Option<LlmError>,
    ) -> bool {
        let should_retry = self.should_retry_error(&error);
        *last_error = Some(error);

        if should_retry && attempt < self.policy.max_attempts {
            self.circuit_breaker.record_failure();
            let delay = self.calculate_delay(attempt);
            log_debug!(
                attempt = attempt,
                max_attempts = self.policy.max_attempts,
                delay_ms = delay.as_millis(),
                error = ?last_error.as_ref(),
                "Request failed, retrying after delay"
            );
            sleep(delay).await;
            true
        } else {
            self.circuit_breaker.record_failure();
            false
        }
    }

    /// Record a timed-out attempt and return `true` if the caller should retry
    async fn handle_timeout(&mut self, attempt: u32, last_error: &mut Option<LlmError>) -> bool {
        let timeout_error = LlmError::timeout(self.policy.request_timeout.as_secs());
        *last_error = Some(timeout_error);

        if attempt < self.policy.max_attempts {
            self.circuit_breaker.record_failure();
            let delay = self.calculate_delay(attempt);
            log_debug!(
                attempt = attempt,
                max_attempts = self.policy.max_attempts,
                delay_ms = delay.as_millis(),
                timeout_seconds = self.policy.request_timeout.as_secs(),
                "Request timed out, retrying after delay"
            );
            sleep(delay).await;
            true
        } else {
            self.circuit_breaker.record_failure();
            false
        }
    }

    /// Build and log the final error once all retry attempts are exhausted
    fn handle_exhausted_retries<T>(
        &mut self,
        attempt: u32,
        last_error: Option<LlmError>,
        start_time: &Instant,
    ) -> LlmResult<T> {
        let final_error = last_error.unwrap_or_else(|| {
            LlmError::request_failed("Maximum retry attempts exceeded".to_string(), None)
        });

        log_error!(
            attempts = attempt,
            total_duration_ms = start_time.elapsed().as_millis(),
            circuit_state = ?self.circuit_breaker.state(),
            error = %final_error,
            "Request failed after all retry attempts"
        );

        Err(final_error)
    }

    /// Determine if an error should trigger a retry
    fn should_retry_error(&self, error: &LlmError) -> bool {
        match error {
            LlmError::RequestFailed { .. } => true,
            LlmError::Timeout { .. } => true,
            LlmError::RateLimitExceeded { .. } => true,
            LlmError::AuthenticationFailed { .. } => false, // Don't retry auth errors
            LlmError::ConfigurationError { .. } => false,   // Don't retry config errors
            LlmError::TokenLimitExceeded { .. } => false,   // Don't retry token limit errors
            LlmError::UnsupportedProvider { .. } => false,  // Don't retry unsupported provider
            LlmError::ResponseParsingError { .. } => false, // Don't retry parsing errors
            LlmError::ToolExecutionFailed { .. } => false,  // Don't retry tool errors
            LlmError::SchemaValidationFailed { .. } => false, // Don't retry schema errors
        }
    }

355    /// Calculate delay for exponential backoff
    pub fn calculate_delay(&self, attempt: u32) -> Duration {
        let delay_seconds = self.policy.initial_delay.as_secs_f64()
            * self.policy.backoff_multiplier.powi((attempt - 1) as i32);

        let delay = Duration::from_secs_f64(delay_seconds.min(self.policy.max_delay.as_secs_f64()));

        // Add jitter to prevent thundering herd (this can nudge the delay slightly above max_delay)
        let jitter = fastrand::f64() * 0.1; // Up to 10% jitter
        Duration::from_secs_f64(delay.as_secs_f64() * (1.0 + jitter))
    }
}
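
// A small test sketch exercising the circuit breaker state machine and the
// pre-jitter backoff schedule; it assumes only the defaults and pub(crate)
// fields defined above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn circuit_breaker_opens_after_threshold_failures() {
        let mut breaker = CircuitBreaker::default();
        assert!(breaker.should_allow_request());

        // Five consecutive failures (the default threshold) should open the circuit.
        for _ in 0..breaker.failure_threshold {
            breaker.record_failure();
        }
        assert_eq!(breaker.state(), CircuitState::Open);
        assert!(!breaker.should_allow_request());
    }

    #[test]
    fn circuit_breaker_recovers_after_success_in_half_open() {
        let mut breaker = CircuitBreaker {
            recovery_timeout: Duration::from_millis(0),
            ..CircuitBreaker::default()
        };
        for _ in 0..breaker.failure_threshold {
            breaker.record_failure();
        }

        // With a zero recovery timeout the next check transitions Open -> HalfOpen.
        assert!(breaker.should_allow_request());
        assert_eq!(breaker.state(), CircuitState::HalfOpen);

        // A success in HalfOpen closes the circuit and clears the failure count.
        breaker.record_success();
        assert_eq!(breaker.state(), CircuitState::Closed);
        assert_eq!(breaker.failure_count, 0);
    }

    #[test]
    fn delay_doubles_up_to_max_delay() {
        let executor = RetryExecutor::default();

        // Expected pre-jitter schedule with the default policy; jitter adds at
        // most 10%, so allow a generous upper bound.
        for (attempt, expected_secs) in [(1u32, 1.0f64), (2, 2.0), (3, 4.0), (4, 8.0), (5, 16.0)] {
            let delay = executor.calculate_delay(attempt).as_secs_f64();
            assert!(delay >= expected_secs);
            assert!(delay < expected_secs * 1.2);
        }
    }
}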