Skip to main content

outrig_cli/llm/
retry.rs

1//! Transient-error retry wrapper around a [`CompletionModel`].
2//!
3//! [`RetryingModel`] wraps the OpenAi-backed completion model so that each
4//! individual model call retries on transient HTTP failures (request timeouts,
5//! connection errors, 408/425/429/5xx) with bounded exponential backoff.
6//!
7//! Retrying at the model-call layer -- rather than around `agent.prompt(...)`
8//! -- is deliberate: a single user turn drives a model -> tool -> model loop,
9//! and wrapping the whole prompt would re-execute already-run container tool
10//! calls when a *later* model call fails. A failed model call has produced no
11//! tool calls yet, so retrying just that call replays nothing observable.
12
13use std::time::Duration;
14
15use rand::RngExt;
16use rig::completion::{CompletionError, CompletionModel, CompletionRequest, CompletionResponse};
17use rig::streaming::StreamingCompletionResponse;
18
/// How many times a failed call may be re-issued; the wrapped model is thus
/// called at most `MAX_RETRIES + 1` times for a single request.
const MAX_RETRIES: usize = 2;
/// Delay before the first retry. Every later retry doubles it, never going
/// past [`MAX_DELAY`].
const BASE_DELAY: Duration = Duration::from_millis(1_000);
/// Upper bound on one backoff delay, measured before jitter is applied.
const MAX_DELAY: Duration = Duration::from_secs(30);
25
/// A [`CompletionModel`] decorator: every call is forwarded to the wrapped
/// model, and transient failures of `completion` are retried with backoff.
#[derive(Clone)]
pub struct RetryingModel<M> {
    /// The model all calls are delegated to.
    inner: M,
}

impl<M> RetryingModel<M> {
    /// Wrap `inner`; the retry behavior itself lives in the
    /// `CompletionModel` impl for this type.
    pub fn new(inner: M) -> Self {
        RetryingModel { inner }
    }
}
37
38impl<M: CompletionModel> CompletionModel for RetryingModel<M> {
39    type Response = M::Response;
40    type StreamingResponse = M::StreamingResponse;
41    type Client = M::Client;
42
43    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
44        Self::new(M::make(client, model))
45    }
46
47    async fn completion(
48        &self,
49        request: CompletionRequest,
50    ) -> std::result::Result<CompletionResponse<Self::Response>, CompletionError> {
51        let mut attempt = 0;
52        loop {
53            match self.inner.completion(request.clone()).await {
54                Ok(response) => return Ok(response),
55                Err(err) if attempt < MAX_RETRIES && is_transient(&err) => {
56                    let delay = backoff(attempt);
57                    eprintln!(
58                        "[outrig] LLM call failed ({err}); retrying in {:.1}s \
59                         (attempt {}/{MAX_RETRIES})",
60                        delay.as_secs_f64(),
61                        attempt + 1,
62                    );
63                    tokio::time::sleep(delay).await;
64                    attempt += 1;
65                }
66                Err(err) => return Err(err),
67            }
68        }
69    }
70
71    async fn stream(
72        &self,
73        request: CompletionRequest,
74    ) -> std::result::Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
75    {
76        // The OpenAi path is non-streaming in outrig; delegate without retry so
77        // the wrapper stays a faithful `CompletionModel` for any future caller.
78        self.inner.stream(request).await
79    }
80}
81
82/// Classify an error as a transient failure worth retrying: request timeouts,
83/// connection errors, and the retry-friendly HTTP status codes. Everything else
84/// (other 4xx, malformed/JSON responses, provider-reported errors) is terminal.
85fn is_transient(err: &CompletionError) -> bool {
86    use rig::http_client::Error as HttpError;
87    match err {
88        CompletionError::HttpError(HttpError::InvalidStatusCode(code))
89        | CompletionError::HttpError(HttpError::InvalidStatusCodeWithMessage(code, _)) => {
90            matches!(code.as_u16(), 408 | 425 | 429 | 500 | 502 | 503 | 504)
91        }
92        // `Instance` wraps the underlying reqwest transport error -- connection
93        // failures, read timeouts, and the like -- all retry-worthy.
94        CompletionError::HttpError(HttpError::Instance(_)) => true,
95        _ => false,
96    }
97}
98
99/// Pre-jitter backoff in seconds: `BASE_DELAY * 2^attempt`, capped at
100/// [`MAX_DELAY`]. Factored out so the (deterministic) schedule is testable.
101fn backoff_secs(attempt: usize) -> f64 {
102    let factor = 2f64.powi(attempt.min(16) as i32);
103    (BASE_DELAY.as_secs_f64() * factor).min(MAX_DELAY.as_secs_f64())
104}
105
106/// Backoff delay for a given retry attempt, with equal jitter: the capped delay
107/// scaled by a random factor in `[0.5, 1.0]` to avoid synchronized retries.
108fn backoff(attempt: usize) -> Duration {
109    let frac = rand::rng().random_range(0.5..=1.0);
110    Duration::from_secs_f64(backoff_secs(attempt) * frac)
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;
    use rig::http_client::Error as HttpError;

    /// Build the error rig reports for a non-success HTTP status.
    fn status_error(status: u16) -> CompletionError {
        let code = StatusCode::from_u16(status).unwrap();
        let http = HttpError::InvalidStatusCodeWithMessage(code, "boom".to_string());
        CompletionError::HttpError(http)
    }

    #[test]
    fn retryable_status_codes_are_transient() {
        for code in [408, 425, 429, 500, 502, 503, 504] {
            assert!(is_transient(&status_error(code)), "{code} should retry");
        }
    }

    #[test]
    fn client_errors_are_terminal() {
        for code in [400, 401, 403, 404, 422] {
            assert!(!is_transient(&status_error(code)), "{code} should not retry");
        }
    }

    #[test]
    fn transport_errors_are_transient() {
        let timeout = std::io::Error::new(std::io::ErrorKind::TimedOut, "read timed out");
        let transport = CompletionError::HttpError(HttpError::Instance(Box::new(timeout)));
        assert!(is_transient(&transport));
    }

    #[test]
    fn non_http_errors_are_terminal() {
        let provider = CompletionError::ProviderError("nope".into());
        let response = CompletionError::ResponseError("nope".into());
        assert!(!is_transient(&provider));
        assert!(!is_transient(&response));
    }

    #[test]
    fn backoff_schedule_grows_then_saturates() {
        // Expected: 1, 2, 4, 8, 16, then pinned at MAX_DELAY (30).
        assert_eq!(backoff_secs(0), 1.0);
        assert_eq!(backoff_secs(1), 2.0);
        assert_eq!(backoff_secs(2), 4.0);
        let ceiling = MAX_DELAY.as_secs_f64();
        assert_eq!(backoff_secs(5), ceiling);
        assert_eq!(backoff_secs(50), ceiling);
        // The schedule never decreases.
        let schedule: Vec<f64> = (0..=20).map(backoff_secs).collect();
        for pair in schedule.windows(2) {
            assert!(pair[0] <= pair[1]);
        }
    }

    #[test]
    fn backoff_jitter_stays_within_bounds() {
        for attempt in 0..=5 {
            let cap = backoff_secs(attempt);
            let floor = cap * 0.5;
            for _ in 0..100 {
                let secs = backoff(attempt).as_secs_f64();
                assert!(secs >= floor - f64::EPSILON, "{secs} < {floor}");
                assert!(secs <= cap + f64::EPSILON, "{secs} > {cap}");
            }
        }
    }
}