// agent_code_lib/llm/retry.rs
//! Retry logic and streaming fallback handling.
//!
//! When streaming fails mid-response, the retry handler can:
//! - Discard partial tool executions with synthetic error blocks
//! - Fall back to a smaller model on repeated overload errors
//! - Apply exponential backoff with jitter

use std::time::Duration;

/// Retry configuration.
///
/// Controls the backoff timing and the per-error-class retry budgets used by
/// [`RetryState::next_action`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum retry attempts for transient errors
    /// (429 rate limits and stream interruptions).
    pub max_retries: u32,
    /// Initial backoff duration (the delay for attempt 1).
    pub initial_backoff: Duration,
    /// Maximum backoff duration; exponential growth is capped here.
    pub max_backoff: Duration,
    /// Backoff multiplier (exponential): each further attempt multiplies the
    /// base delay by this factor.
    pub multiplier: f64,
    /// Maximum 529 (overloaded) retries before falling back to the smaller model.
    pub max_overload_retries: u32,
}

25impl Default for RetryConfig {
26    fn default() -> Self {
27        Self {
28            max_retries: 3,
29            initial_backoff: Duration::from_millis(1000),
30            max_backoff: Duration::from_secs(60),
31            multiplier: 2.0,
32            max_overload_retries: 3,
33        }
34    }
35}
36
/// State tracker for retry logic across multiple attempts.
///
/// One instance tracks a single logical request; call [`RetryState::reset`]
/// after a success so transient counters don't leak into the next failure.
#[derive(Debug, Default)]
pub struct RetryState {
    /// Number of consecutive failures (any error class).
    pub consecutive_failures: u32,
    /// Number of 429 (rate limit) retries.
    pub rate_limit_retries: u32,
    /// Number of 529 (overload) retries; zeroed when switching to the fallback model.
    pub overload_retries: u32,
    /// Whether we've fallen back to the smaller model.
    /// Persists across `reset` — a success does not switch back.
    pub using_fallback: bool,
}

50impl RetryState {
51    /// Determine the next action after a failure.
52    pub fn next_action(&mut self, error: &RetryableError, config: &RetryConfig) -> RetryAction {
53        self.consecutive_failures += 1;
54
55        match error {
56            RetryableError::RateLimited { retry_after } => {
57                self.rate_limit_retries += 1;
58                if self.rate_limit_retries > config.max_retries {
59                    return RetryAction::Abort("Rate limit retries exhausted".into());
60                }
61                RetryAction::Retry {
62                    after: Duration::from_millis(*retry_after),
63                }
64            }
65            RetryableError::Overloaded => {
66                self.overload_retries += 1;
67                if self.overload_retries > config.max_overload_retries {
68                    if !self.using_fallback {
69                        self.using_fallback = true;
70                        self.overload_retries = 0;
71                        return RetryAction::FallbackModel;
72                    }
73                    return RetryAction::Abort("Overload retries exhausted on fallback".into());
74                }
75                let backoff = calculate_backoff(
76                    self.overload_retries,
77                    config.initial_backoff,
78                    config.max_backoff,
79                    config.multiplier,
80                );
81                RetryAction::Retry { after: backoff }
82            }
83            RetryableError::StreamInterrupted => {
84                if self.consecutive_failures > config.max_retries {
85                    return RetryAction::Abort("Stream retry limit reached".into());
86                }
87                let backoff = calculate_backoff(
88                    self.consecutive_failures,
89                    config.initial_backoff,
90                    config.max_backoff,
91                    config.multiplier,
92                );
93                RetryAction::Retry { after: backoff }
94            }
95            RetryableError::NonRetryable(msg) => RetryAction::Abort(msg.clone()),
96        }
97    }
98
99    /// Reset counters after a successful call.
100    pub fn reset(&mut self) {
101        self.consecutive_failures = 0;
102        self.rate_limit_retries = 0;
103        // Don't reset overload_retries or using_fallback — those persist.
104    }
105}
106
/// Categorized error for retry logic.
///
/// Callers classify raw transport/API failures into one of these variants
/// before handing them to [`RetryState::next_action`].
// Derives added: every other public type here is `Debug`, and an error type
// that can't be logged or compared is awkward to use and test.
#[derive(Debug, Clone, PartialEq)]
pub enum RetryableError {
    /// HTTP 429: rate limited; `retry_after` is the suggested delay in milliseconds.
    RateLimited { retry_after: u64 },
    /// HTTP 529: the provider is overloaded.
    Overloaded,
    /// The streaming response was cut off mid-way.
    StreamInterrupted,
    /// Unrecoverable error carrying a human-readable message.
    NonRetryable(String),
}

/// Action the caller should take after a failure.
///
/// Returned by [`RetryState::next_action`].
#[derive(Debug)]
pub enum RetryAction {
    /// Wait for `after`, then retry with the same model.
    Retry { after: Duration },
    /// Switch to the fallback (smaller) model and retry.
    FallbackModel,
    /// Give up — unrecoverable error or retry budget exhausted.
    Abort(String),
}

/// Calculate exponential backoff with jitter.
///
/// Attempt numbers start at 1, which yields `initial`; each further attempt
/// multiplies the base delay by `multiplier`. The returned duration never
/// exceeds `max`, even after jitter is applied.
fn calculate_backoff(attempt: u32, initial: Duration, max: Duration, multiplier: f64) -> Duration {
    // Saturate so attempt 0 behaves like attempt 1 instead of dividing the
    // initial backoff by the multiplier (powi(-1) bug in the old version).
    let exponent = attempt.saturating_sub(1);
    let base = initial.as_millis() as f64 * multiplier.powi(exponent as i32);
    let max_ms = max.as_millis() as f64;
    let capped = base.min(max_ms);
    // Add up to 10% jitter to de-synchronize concurrent retriers.
    let jitter = capped * 0.1 * rand_f64();
    // Re-cap after jitter: previously jitter could push the result past `max`
    // by up to 10%, violating the documented maximum.
    Duration::from_millis((capped + jitter).min(max_ms) as u64)
}

/// Simple pseudo-random f64 in [0, 1) derived from the clock's sub-second nanos.
/// Not cryptographically secure and coarse-grained — only suitable for jitter.
fn rand_f64() -> f64 {
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .subsec_nanos();
    (nanos % 1000) as f64 / 1000.0
}

// Unit tests for retry-state transitions and backoff math.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let c = RetryConfig::default();
        assert_eq!(c.max_retries, 3);
        assert!(c.multiplier > 1.0);
    }

    #[test]
    fn test_retry_on_rate_limit() {
        let mut state = RetryState::default();
        let config = RetryConfig::default();
        let err = RetryableError::RateLimited { retry_after: 500 };
        // The server-suggested delay (ms) is honored as the wait time.
        match state.next_action(&err, &config) {
            RetryAction::Retry { after } => assert!(after.as_millis() >= 500),
            other => panic!("Expected Retry, got {other:?}"),
        }
    }

    #[test]
    fn test_retry_exhaustion() {
        let mut state = RetryState::default();
        let config = RetryConfig {
            max_retries: 1,
            ..Default::default()
        };
        let err = RetryableError::RateLimited { retry_after: 100 };
        let _ = state.next_action(&err, &config); // First retry.
        // Second rate-limit failure exceeds max_retries => abort.
        match state.next_action(&err, &config) {
            RetryAction::Abort(_) => {}
            other => panic!("Expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn test_non_retryable_aborts() {
        let mut state = RetryState::default();
        let config = RetryConfig::default();
        let err = RetryableError::NonRetryable("bad request".into());
        // The original error message is surfaced in the Abort.
        match state.next_action(&err, &config) {
            RetryAction::Abort(msg) => assert!(msg.contains("bad request")),
            other => panic!("Expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn test_overload_escalates_to_fallback() {
        let mut state = RetryState::default();
        let config = RetryConfig {
            max_overload_retries: 2,
            ..Default::default()
        };
        let err = RetryableError::Overloaded;
        // Two overloads within budget retry; the third triggers fallback.
        let _ = state.next_action(&err, &config);
        let _ = state.next_action(&err, &config);
        match state.next_action(&err, &config) {
            RetryAction::FallbackModel => {}
            other => panic!("Expected FallbackModel, got {other:?}"),
        }
    }

    #[test]
    fn test_reset_preserves_fallback() {
        let mut state = RetryState {
            using_fallback: true,
            consecutive_failures: 5,
            ..Default::default()
        };
        state.reset();
        assert_eq!(state.consecutive_failures, 0);
        assert!(state.using_fallback); // Preserved.
    }

    #[test]
    fn test_backoff_increases_with_attempt() {
        let initial = Duration::from_millis(1000);
        let max = Duration::from_secs(60);
        let multiplier = 2.0;

        let _b1 = calculate_backoff(1, initial, max, multiplier);
        let b2 = calculate_backoff(2, initial, max, multiplier);
        let b3 = calculate_backoff(3, initial, max, multiplier);

        // Each attempt should generally produce a larger backoff (before jitter caps).
        // With multiplier 2.0: attempt 1 ~1s, attempt 2 ~2s, attempt 3 ~4s.
        assert!(b2.as_millis() >= 1500, "b2 should be >= 1.5s, got {:?}", b2);
        assert!(b3.as_millis() >= 3000, "b3 should be >= 3s, got {:?}", b3);
    }

    #[test]
    fn test_reset_clears_rate_limit_retries() {
        let mut state = RetryState {
            consecutive_failures: 3,
            rate_limit_retries: 5,
            overload_retries: 2,
            using_fallback: false,
        };
        state.reset();
        assert_eq!(state.rate_limit_retries, 0);
        assert_eq!(state.consecutive_failures, 0);
        // overload_retries and using_fallback persist.
        assert_eq!(state.overload_retries, 2);
    }

    #[test]
    fn test_overloads_then_fallback_then_abort() {
        // Full escalation path: retry -> fallback -> retry on fallback -> abort.
        let mut state = RetryState::default();
        let config = RetryConfig {
            max_overload_retries: 1,
            ..Default::default()
        };
        let err = RetryableError::Overloaded;

        // First overload: retry with backoff.
        match state.next_action(&err, &config) {
            RetryAction::Retry { .. } => {}
            other => panic!("Expected Retry, got {other:?}"),
        }

        // Second overload: exceeds max_overload_retries, triggers fallback.
        match state.next_action(&err, &config) {
            RetryAction::FallbackModel => {}
            other => panic!("Expected FallbackModel, got {other:?}"),
        }
        assert!(state.using_fallback);

        // Now on fallback model, overload again: retry (budget was reset).
        match state.next_action(&err, &config) {
            RetryAction::Retry { .. } => {}
            other => panic!("Expected Retry on fallback, got {other:?}"),
        }

        // Exceed overloads on fallback: abort.
        match state.next_action(&err, &config) {
            RetryAction::Abort(msg) => assert!(msg.contains("fallback")),
            other => panic!("Expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn test_stream_interrupted_retries_then_aborts() {
        let mut state = RetryState::default();
        let config = RetryConfig {
            max_retries: 2,
            ..Default::default()
        };
        let err = RetryableError::StreamInterrupted;

        // First two interruptions should retry.
        match state.next_action(&err, &config) {
            RetryAction::Retry { .. } => {}
            other => panic!("Expected Retry, got {other:?}"),
        }
        match state.next_action(&err, &config) {
            RetryAction::Retry { .. } => {}
            other => panic!("Expected Retry, got {other:?}"),
        }

        // Third interruption exceeds max_retries => abort.
        match state.next_action(&err, &config) {
            RetryAction::Abort(msg) => assert!(msg.contains("Stream")),
            other => panic!("Expected Abort, got {other:?}"),
        }
    }

    #[test]
    fn test_retry_state_default_values() {
        let state = RetryState::default();
        assert_eq!(state.consecutive_failures, 0);
        assert_eq!(state.rate_limit_retries, 0);
        assert_eq!(state.overload_retries, 0);
        assert!(!state.using_fallback);
    }
}