1use std::time::Duration;
14
15use rand::RngExt;
16use rig::completion::{CompletionError, CompletionModel, CompletionRequest, CompletionResponse};
17use rig::streaming::StreamingCompletionResponse;
18
/// Number of retries performed after the initial attempt, so a single completion call
/// results in at most `MAX_RETRIES + 1` requests to the wrapped model.
const MAX_RETRIES: usize = 2;
/// Un-jittered delay before the first retry; it doubles for every further retry.
const BASE_DELAY: Duration = Duration::from_millis(1_000);
/// Ceiling applied to the un-jittered delay, however many retries have happened.
const MAX_DELAY: Duration = Duration::from_millis(30_000);
25
/// A [`CompletionModel`] wrapper that retries non-streaming completion calls which fail
/// with a transient error, sleeping with jittered exponential backoff between attempts.
///
/// Streaming requests are forwarded to the wrapped model without any retry.
#[derive(Clone, Debug)]
pub struct RetryingModel<M> {
    /// The model every call is ultimately delegated to.
    inner: M,
}

impl<M> RetryingModel<M> {
    /// Wraps `inner` so that its completion calls are retried on transient failures.
    pub fn new(inner: M) -> Self {
        Self { inner }
    }
}
37
38impl<M: CompletionModel> CompletionModel for RetryingModel<M> {
39 type Response = M::Response;
40 type StreamingResponse = M::StreamingResponse;
41 type Client = M::Client;
42
43 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
44 Self::new(M::make(client, model))
45 }
46
47 async fn completion(
48 &self,
49 request: CompletionRequest,
50 ) -> std::result::Result<CompletionResponse<Self::Response>, CompletionError> {
51 let mut attempt = 0;
52 loop {
53 match self.inner.completion(request.clone()).await {
54 Ok(response) => return Ok(response),
55 Err(err) if attempt < MAX_RETRIES && is_transient(&err) => {
56 let delay = backoff(attempt);
57 eprintln!(
58 "[outrig] LLM call failed ({err}); retrying in {:.1}s \
59 (attempt {}/{MAX_RETRIES})",
60 delay.as_secs_f64(),
61 attempt + 1,
62 );
63 tokio::time::sleep(delay).await;
64 attempt += 1;
65 }
66 Err(err) => return Err(err),
67 }
68 }
69 }
70
71 async fn stream(
72 &self,
73 request: CompletionRequest,
74 ) -> std::result::Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
75 {
76 self.inner.stream(request).await
79 }
80}
81
82fn is_transient(err: &CompletionError) -> bool {
86 use rig::http_client::Error as HttpError;
87 match err {
88 CompletionError::HttpError(HttpError::InvalidStatusCode(code))
89 | CompletionError::HttpError(HttpError::InvalidStatusCodeWithMessage(code, _)) => {
90 matches!(code.as_u16(), 408 | 425 | 429 | 500 | 502 | 503 | 504)
91 }
92 CompletionError::HttpError(HttpError::Instance(_)) => true,
95 _ => false,
96 }
97}
98
99fn backoff_secs(attempt: usize) -> f64 {
102 let factor = 2f64.powi(attempt.min(16) as i32);
103 (BASE_DELAY.as_secs_f64() * factor).min(MAX_DELAY.as_secs_f64())
104}
105
106fn backoff(attempt: usize) -> Duration {
109 let frac = rand::rng().random_range(0.5..=1.0);
110 Duration::from_secs_f64(backoff_secs(attempt) * frac)
111}
112
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;
    use rig::http_client::Error as HttpError;

    /// Builds the error rig reports for an HTTP response with status `code` and a body.
    fn http_status(code: u16) -> CompletionError {
        let status = StatusCode::from_u16(code).unwrap();
        let inner = HttpError::InvalidStatusCodeWithMessage(status, "boom".to_string());
        CompletionError::HttpError(inner)
    }

    #[test]
    fn retryable_status_codes_are_transient() {
        let retryable = [408, 425, 429, 500, 502, 503, 504];
        for code in retryable {
            let err = http_status(code);
            assert!(is_transient(&err), "{code} should retry");
        }
    }

    #[test]
    fn client_errors_are_terminal() {
        let terminal = [400, 401, 403, 404, 422];
        for code in terminal {
            let err = http_status(code);
            assert!(!is_transient(&err), "{code} should not retry");
        }
    }

    #[test]
    fn transport_errors_are_transient() {
        let timeout = std::io::Error::new(std::io::ErrorKind::TimedOut, "read timed out");
        let boxed = HttpError::Instance(Box::new(timeout));
        assert!(is_transient(&CompletionError::HttpError(boxed)));
    }

    #[test]
    fn non_http_errors_are_terminal() {
        let provider = CompletionError::ProviderError("nope".into());
        let response = CompletionError::ResponseError("nope".into());
        assert!(!is_transient(&provider));
        assert!(!is_transient(&response));
    }

    #[test]
    fn backoff_schedule_grows_then_saturates() {
        let max = MAX_DELAY.as_secs_f64();
        let expected = [(0, 1.0), (1, 2.0), (2, 4.0), (5, max), (50, max)];
        for (attempt, secs) in expected {
            assert_eq!(backoff_secs(attempt), secs);
        }
        // The schedule never shrinks from one attempt to the next.
        assert!((0..20).all(|a| backoff_secs(a) <= backoff_secs(a + 1)));
    }

    #[test]
    fn backoff_jitter_stays_within_bounds() {
        for attempt in 0..=5 {
            let cap = backoff_secs(attempt);
            let floor = cap * 0.5;
            for _ in 0..100 {
                let d = backoff(attempt).as_secs_f64();
                assert!(d >= floor - f64::EPSILON, "{d} < {floor}");
                assert!(d <= cap + f64::EPSILON, "{d} > {cap}");
            }
        }
    }
}