use std::time::Duration;
use rand::RngExt;
use rig::completion::{CompletionError, CompletionModel, CompletionRequest, CompletionResponse};
use rig::streaming::StreamingCompletionResponse;
/// Number of additional attempts made after a transient failure.
const MAX_RETRIES: usize = 2;
/// Un-jittered delay before the first retry; it doubles for each later retry.
const BASE_DELAY: Duration = Duration::from_millis(1_000);
/// Upper bound on the un-jittered delay before any single retry.
const MAX_DELAY: Duration = Duration::from_millis(30_000);
/// A [`CompletionModel`] wrapper that retries the wrapped model's completion
/// requests when they fail with a transient error.
///
/// `Debug` is derived so the wrapper can be printed whenever the wrapped
/// model can; the derive only applies when `M: Debug`.
#[derive(Clone, Debug)]
pub struct RetryingModel<M> {
    /// The model that actually performs the requests.
    inner: M,
}

impl<M> RetryingModel<M> {
    /// Wraps `inner` in the retry layer.
    pub fn new(inner: M) -> Self {
        Self { inner }
    }
}
impl<M: CompletionModel> CompletionModel for RetryingModel<M> {
    type Response = M::Response;
    type StreamingResponse = M::StreamingResponse;
    type Client = M::Client;

    /// Builds the inner model with `M::make` and wraps it.
    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
        Self::new(M::make(client, model))
    }

    /// Runs a non-streaming completion, retrying transient failures.
    ///
    /// The request is cloned for every attempt so each attempt sends the same
    /// input. Non-transient errors, and the error from the last permitted
    /// attempt, are returned unchanged.
    async fn completion(
        &self,
        request: CompletionRequest,
    ) -> std::result::Result<CompletionResponse<Self::Response>, CompletionError> {
        retry_transient(|| self.inner.completion(request.clone())).await
    }

    /// Opens a streaming completion, retrying transient failures of the
    /// initial call.
    ///
    /// Only an `Err` returned before a stream exists is retried. At that point
    /// nothing has been handed to the caller, so a retry cannot duplicate
    /// output. Errors surfaced while consuming the returned stream are passed
    /// through untouched.
    async fn stream(
        &self,
        request: CompletionRequest,
    ) -> std::result::Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError>
    {
        retry_transient(|| self.inner.stream(request.clone())).await
    }
}

/// Awaits the future produced by `call` until one of these happens: it
/// succeeds, it fails with an error `is_transient` rejects, or `MAX_RETRIES`
/// retries have been spent.
///
/// Between attempts it logs the failure to stderr and sleeps for `backoff`.
/// `call` is invoked once per attempt, so it must build a fresh request each
/// time.
async fn retry_transient<T, F, Fut>(mut call: F) -> std::result::Result<T, CompletionError>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = std::result::Result<T, CompletionError>>,
{
    let mut attempt = 0;
    loop {
        match call().await {
            Ok(value) => return Ok(value),
            Err(err) if attempt < MAX_RETRIES && is_transient(&err) => {
                let delay = backoff(attempt);
                eprintln!(
                    "[outrig] LLM call failed ({err}); retrying in {:.1}s \
                     (attempt {}/{MAX_RETRIES})",
                    delay.as_secs_f64(),
                    attempt + 1,
                );
                tokio::time::sleep(delay).await;
                attempt += 1;
            }
            Err(err) => return Err(err),
        }
    }
}
/// Reports whether `err` is worth retrying.
///
/// Only HTTP-layer errors qualify:
/// - responses with status 408, 425, 429, 500, 502, 503 or 504;
/// - any error carried in the `Instance` variant (boxed client errors such as
///   I/O timeouts).
///
/// Every other error is treated as final.
fn is_transient(err: &CompletionError) -> bool {
    use rig::http_client::Error as HttpError;

    // Statuses that indicate a temporary condition rather than a bad request.
    const RETRYABLE_STATUSES: [u16; 7] = [408, 425, 429, 500, 502, 503, 504];

    let CompletionError::HttpError(http_err) = err else {
        return false;
    };

    match http_err {
        HttpError::InvalidStatusCode(status)
        | HttpError::InvalidStatusCodeWithMessage(status, _) => {
            RETRYABLE_STATUSES.contains(&status.as_u16())
        }
        HttpError::Instance(_) => true,
        _ => false,
    }
}
/// Un-jittered delay, in seconds, before retry number `attempt` (0-based).
///
/// The delay is `BASE_DELAY * 2^attempt`, clamped to `MAX_DELAY`. The
/// exponent is capped at 16 so the shift below can never overflow, however
/// large `attempt` is.
fn backoff_secs(attempt: usize) -> f64 {
    let doublings = attempt.min(16);
    let multiplier = f64::from(1u32 << doublings);
    let uncapped = BASE_DELAY.as_secs_f64() * multiplier;
    uncapped.min(MAX_DELAY.as_secs_f64())
}
/// Jittered delay before retry number `attempt` (0-based).
///
/// The result is `backoff_secs(attempt)` scaled by a uniformly random factor
/// in `[0.5, 1.0]`. The random factor keeps retries off a fixed schedule.
fn backoff(attempt: usize) -> Duration {
    let ceiling = backoff_secs(attempt);
    let mut rng = rand::rng();
    let scale: f64 = rng.random_range(0.5..=1.0);
    Duration::from_secs_f64(ceiling * scale)
}
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::StatusCode;
    use rig::http_client::Error as HttpError;

    /// Builds an HTTP status error carrying a response message.
    fn http_status(code: u16) -> CompletionError {
        CompletionError::HttpError(HttpError::InvalidStatusCodeWithMessage(
            StatusCode::from_u16(code).unwrap(),
            "boom".to_string(),
        ))
    }

    /// Builds an HTTP status error without a response message.
    fn bare_status(code: u16) -> CompletionError {
        CompletionError::HttpError(HttpError::InvalidStatusCode(
            StatusCode::from_u16(code).unwrap(),
        ))
    }

    #[test]
    fn retryable_status_codes_are_transient() {
        for code in [408, 425, 429, 500, 502, 503, 504] {
            assert!(is_transient(&http_status(code)), "{code} should retry");
        }
    }

    #[test]
    fn client_errors_are_terminal() {
        for code in [400, 401, 403, 404, 422] {
            assert!(!is_transient(&http_status(code)), "{code} should not retry");
        }
    }

    /// 5xx codes outside the retry list (e.g. 501 Not Implemented) must not
    /// be retried just because they are server errors.
    #[test]
    fn unlisted_server_errors_are_terminal() {
        for code in [501, 505] {
            assert!(!is_transient(&http_status(code)), "{code} should not retry");
        }
    }

    /// The message-less `InvalidStatusCode` variant follows the same policy
    /// as the variant that carries a message.
    #[test]
    fn bare_status_codes_follow_same_policy() {
        for code in [408, 425, 429, 500, 502, 503, 504] {
            assert!(is_transient(&bare_status(code)), "{code} should retry");
        }
        for code in [400, 401, 403, 404, 422, 501] {
            assert!(!is_transient(&bare_status(code)), "{code} should not retry");
        }
    }

    #[test]
    fn transport_errors_are_transient() {
        let io = std::io::Error::new(std::io::ErrorKind::TimedOut, "read timed out");
        let err = CompletionError::HttpError(HttpError::Instance(Box::new(io)));
        assert!(is_transient(&err));
    }

    #[test]
    fn non_http_errors_are_terminal() {
        assert!(!is_transient(&CompletionError::ProviderError(
            "nope".into()
        )));
        assert!(!is_transient(&CompletionError::ResponseError(
            "nope".into()
        )));
    }

    #[test]
    fn backoff_schedule_grows_then_saturates() {
        assert_eq!(backoff_secs(0), 1.0);
        assert_eq!(backoff_secs(1), 2.0);
        assert_eq!(backoff_secs(2), 4.0);
        let max = MAX_DELAY.as_secs_f64();
        assert_eq!(backoff_secs(5), max);
        assert_eq!(backoff_secs(50), max);
        for a in 0..20 {
            assert!(backoff_secs(a) <= backoff_secs(a + 1));
        }
    }

    #[test]
    fn backoff_jitter_stays_within_bounds() {
        for attempt in 0..=5 {
            let cap = backoff_secs(attempt);
            for _ in 0..100 {
                let d = backoff(attempt).as_secs_f64();
                assert!(d >= cap * 0.5 - f64::EPSILON, "{d} < {}", cap * 0.5);
                assert!(d <= cap + f64::EPSILON, "{d} > {cap}");
            }
        }
    }
}