use std::future::Future;
use std::time::Duration;
use crate::error::LlmError;
use crate::provider::StatusTx;
/// Base delay, in seconds, of the exponential backoff used when the server does
/// not supply a usable numeric `Retry-After` header.
const BASE_BACKOFF_SECS: u64 = 1;

/// Computes how long to wait before retrying a rate-limited / unavailable request.
///
/// Uses the response's `Retry-After` header when its value is a plain integer
/// number of seconds (surrounding whitespace is ignored).
///
/// Otherwise the delay falls back to `BASE_BACKOFF_SECS * 2^attempt` seconds.
/// This covers a missing header, a value that is not valid visible ASCII, and
/// the HTTP-date form, which is not parsed here. `attempt` is 0 for the first
/// retry.
///
/// The backoff saturates at `u64::MAX` seconds instead of overflowing, so a large
/// `attempt` can never panic or wrap around to a tiny delay.
pub(crate) fn retry_delay(response: &reqwest::Response, attempt: u32) -> Duration {
    if let Some(val) = response.headers().get(reqwest::header::RETRY_AFTER)
        && let Ok(s) = val.to_str()
        && let Ok(secs) = s.trim().parse::<u64>()
    {
        return Duration::from_secs(secs);
    }
    // `BASE_BACKOFF_SECS << attempt` would panic in debug builds (and mask the
    // shift amount, i.e. wrap back to ~1s, in release) once `attempt >= 64`.
    // Saturating arithmetic gives the same result for every smaller attempt.
    let factor = 2u64.saturating_pow(attempt);
    Duration::from_secs(BASE_BACKOFF_SECS.saturating_mul(factor))
}
/// Sends a request produced by `f`, retrying when the server answers
/// HTTP 429 (Too Many Requests) or 503 (Service Unavailable).
///
/// `f` is called once per attempt, so it must build a fresh request every time.
/// At most `max_retries` retries are made, i.e. `f` runs at most
/// `max_retries + 1` times.
///
/// Before each retry the delay computed by [`retry_delay`] is awaited. A
/// progress message is also sent to `status_tx` (when given; a failed send is
/// ignored) and logged with `tracing::warn!`.
///
/// # Errors
///
/// - [`LlmError::Http`] as soon as `f`'s future fails; transport errors are not
///   retried.
/// - [`LlmError::RateLimited`] when every attempt was answered with 429/503.
pub(crate) async fn send_with_retry<F, Fut>(
    provider_name: &str,
    max_retries: u32,
    status_tx: Option<&StatusTx>,
    mut f: F,
) -> Result<reqwest::Response, LlmError>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<reqwest::Response, reqwest::Error>>,
{
    // Statuses that mean "try again later" rather than a hard failure.
    const RETRYABLE: [reqwest::StatusCode; 2] = [
        reqwest::StatusCode::TOO_MANY_REQUESTS,
        reqwest::StatusCode::SERVICE_UNAVAILABLE,
    ];

    let mut attempt: u32 = 0;
    loop {
        let response = f().await.map_err(LlmError::Http)?;
        if !RETRYABLE.contains(&response.status()) {
            return Ok(response);
        }
        if attempt == max_retries {
            return Err(LlmError::RateLimited);
        }

        let delay = retry_delay(&response, attempt);
        // `attempt < max_retries` at this point, so the increment cannot overflow;
        // after it, `attempt` is the 1-based number of the retry about to run.
        attempt += 1;
        let msg = format!(
            "{provider_name} rate limited or unavailable, retrying in {}s ({}/{})",
            delay.as_secs(),
            attempt,
            max_retries
        );
        if let Some(tx) = status_tx {
            let _ = tx.send(msg.clone());
        }
        tracing::warn!("{msg}");
        tokio::time::sleep(delay).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    // Every canned response carries `Connection: close`. The mock server drops
    // each connection right after answering. Without this header reqwest/hyper
    // may return the connection to its pool and race the server's close on the
    // next request, making multi-request tests flaky.

    /// 429 with `Retry-After: 0`, so `send_with_retry` retries without sleeping.
    const RATE_LIMITED: &str = "HTTP/1.1 429 Too Many Requests\r\nRetry-After: 0\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
    /// 503 with `Retry-After: 0`, so `send_with_retry` retries without sleeping.
    const UNAVAILABLE: &str = "HTTP/1.1 503 Service Unavailable\r\nRetry-After: 0\r\nContent-Length: 0\r\nConnection: close\r\n\r\n";
    /// Plain successful response.
    const OK_RESPONSE: &str =
        "HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok";

    /// Starts a TCP server on an ephemeral localhost port.
    ///
    /// The server accepts one connection per entry of `responses`, in order. For
    /// each one it reads and discards the request head, then writes the raw
    /// canned response.
    async fn spawn_mock_server(responses: Vec<&'static str>) -> (u16, tokio::task::JoinHandle<()>) {
        use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
        use tokio::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = tokio::spawn(async move {
            for resp in responses {
                let Ok((mut stream, _)) = listener.accept().await else {
                    break;
                };
                tokio::spawn(async move {
                    let (reader, mut writer) = stream.split();
                    let mut buf_reader = BufReader::new(reader);
                    let mut line = String::new();
                    // Consume request lines up to the blank line ending the head
                    // (or EOF / read error).
                    loop {
                        line.clear();
                        buf_reader.read_line(&mut line).await.unwrap_or(0);
                        if line == "\r\n" || line == "\n" || line.is_empty() {
                            break;
                        }
                    }
                    writer.write_all(resp.as_bytes()).await.ok();
                });
            }
        });
        (port, handle)
    }

    /// Performs a single GET against a mock server answering with `raw`.
    async fn fetch_once(raw: &'static str) -> reqwest::Response {
        let (port, _handle) = spawn_mock_server(vec![raw]).await;
        reqwest::Client::new()
            .get(format!("http://127.0.0.1:{port}/test"))
            .send()
            .await
            .expect("mock server should answer")
    }

    /// Runs `send_with_retry` against a mock server that serves `responses`
    /// (one per connection, in order).
    async fn run_send_with_retry(
        responses: Vec<&'static str>,
        max_retries: u32,
    ) -> Result<reqwest::Response, LlmError> {
        let (port, _handle) = spawn_mock_server(responses).await;
        let client = reqwest::Client::new();
        let url = format!("http://127.0.0.1:{port}/test");
        send_with_retry("test", max_retries, None, || {
            let req = client.get(&url).build().unwrap();
            let c = client.clone();
            async move { c.execute(req).await }
        })
        .await
    }

    #[tokio::test]
    async fn retry_delay_exponential_backoff() {
        // No Retry-After header: the delay must come from the backoff formula.
        let resp = fetch_once(
            "HTTP/1.1 429 Too Many Requests\r\nContent-Length: 0\r\nConnection: close\r\n\r\n",
        )
        .await;
        assert_eq!(retry_delay(&resp, 0), Duration::from_secs(BASE_BACKOFF_SECS));
        assert_eq!(retry_delay(&resp, 1), Duration::from_secs(BASE_BACKOFF_SECS * 2));
        assert_eq!(retry_delay(&resp, 2), Duration::from_secs(BASE_BACKOFF_SECS * 4));
    }

    #[tokio::test]
    async fn retry_delay_prefers_numeric_retry_after() {
        let resp = fetch_once(
            "HTTP/1.1 429 Too Many Requests\r\nRetry-After: 7\r\nContent-Length: 0\r\nConnection: close\r\n\r\n",
        )
        .await;
        assert_eq!(retry_delay(&resp, 0), Duration::from_secs(7));
        assert_eq!(retry_delay(&resp, 3), Duration::from_secs(7));
    }

    #[tokio::test]
    async fn retry_delay_falls_back_when_retry_after_is_a_date() {
        // The HTTP-date form of Retry-After is not parsed; backoff applies.
        let resp = fetch_once(
            "HTTP/1.1 503 Service Unavailable\r\nRetry-After: Wed, 21 Oct 2015 07:28:00 GMT\r\nContent-Length: 0\r\nConnection: close\r\n\r\n",
        )
        .await;
        assert_eq!(retry_delay(&resp, 1), Duration::from_secs(BASE_BACKOFF_SECS << 1));
    }

    #[tokio::test]
    async fn send_with_retry_success_on_first_attempt() {
        let result = run_send_with_retry(vec![OK_RESPONSE], 3).await;
        assert!(result.is_ok(), "expected Ok, got: {result:?}");
        assert_eq!(result.unwrap().status(), 200);
    }

    #[tokio::test]
    async fn send_with_retry_exhausts_retries_returns_rate_limited() {
        // max_retries = 1 => exactly two attempts, both rate limited.
        let result = run_send_with_retry(vec![RATE_LIMITED, RATE_LIMITED], 1).await;
        assert!(
            matches!(result, Err(LlmError::RateLimited)),
            "expected RateLimited, got: {result:?}"
        );
    }

    #[tokio::test]
    async fn send_with_retry_succeeds_after_one_429() {
        let result = run_send_with_retry(vec![RATE_LIMITED, OK_RESPONSE], 2).await;
        assert!(
            result.is_ok(),
            "expected Ok after one retry, got: {result:?}"
        );
        assert_eq!(result.unwrap().status(), 200);
    }

    #[tokio::test]
    async fn send_with_retry_succeeds_after_one_503() {
        let result = run_send_with_retry(vec![UNAVAILABLE, OK_RESPONSE], 2).await;
        assert!(
            result.is_ok(),
            "expected Ok after one retry, got: {result:?}"
        );
        assert_eq!(result.unwrap().status(), 200);
    }

    proptest! {
        #[test]
        fn retry_delay_range_always_valid(attempt in 0u32..63) {
            let delay = Duration::from_secs(BASE_BACKOFF_SECS << attempt);
            prop_assert!(
                delay.as_secs() >= BASE_BACKOFF_SECS,
                "delay must be at least base backoff"
            );
            if attempt > 0 {
                let prev = Duration::from_secs(BASE_BACKOFF_SECS << (attempt - 1));
                prop_assert_eq!(delay.as_secs(), prev.as_secs() * 2);
            }
        }
    }
}