use std::sync::OnceLock;
use std::time::Duration;
use reqwest::{RequestBuilder, Response, StatusCode};
/// Default number of retries performed after the initial attempt.
const DEFAULT_MAX_RETRIES: u32 = 3;
/// Default total number of attempts: the initial request plus `DEFAULT_MAX_RETRIES` retries.
const DEFAULT_MAX_ATTEMPTS: u32 = DEFAULT_MAX_RETRIES + 1;
/// Default base backoff delay, in milliseconds.
const DEFAULT_BASE_DELAY_MS: u64 = 500;
/// Default upper bound on a single backoff delay, in milliseconds.
const DEFAULT_MAX_DELAY_MS: u64 = 30_000;
/// Longest wait honoured from a server-supplied `Retry-After` header; larger
/// values are clamped to this in `send_with_retry`.
const RETRY_AFTER_CEILING: Duration = Duration::from_secs(60);
/// Retry policy: how many attempts to make and how long to wait between them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryConfig {
/// Total number of attempts, including the first request. Both `from_env`
/// and `send_with_retry` raise values below 1 to 1.
pub max_attempts: u32,
/// Base backoff delay; `backoff_delay` doubles it for each successive
/// attempt before capping and jittering.
pub base_delay: Duration,
/// Cap applied to the exponential backoff delay before jitter is applied.
pub max_delay: Duration,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: DEFAULT_MAX_ATTEMPTS,
base_delay: Duration::from_millis(DEFAULT_BASE_DELAY_MS),
max_delay: Duration::from_millis(DEFAULT_MAX_DELAY_MS),
}
}
}
impl RetryConfig {
    /// Builds a policy from environment variables, falling back to the
    /// corresponding [`Default`] value for any variable that is unset or
    /// does not parse.
    ///
    /// * `BAMBOO_LLM_MAX_RETRIES` — number of retries (`u32`); the resulting
    ///   `max_attempts` is that value plus one (saturating).
    /// * `BAMBOO_LLM_RETRY_BASE_DELAY_MS` — base delay in milliseconds (`u64`).
    /// * `BAMBOO_LLM_RETRY_MAX_DELAY_MS` — delay cap in milliseconds (`u64`).
    ///
    /// Values are trimmed of surrounding whitespace before parsing.
    pub fn from_env() -> Self {
        /// Reads `name` and parses its trimmed value; `None` when the variable
        /// cannot be read or the value does not parse as `T`.
        fn env_parsed<T: std::str::FromStr>(name: &str) -> Option<T> {
            let raw = std::env::var(name).ok()?;
            raw.trim().parse().ok()
        }

        let fallback = Self::default();

        let max_attempts = match env_parsed::<u32>("BAMBOO_LLM_MAX_RETRIES") {
            // The variable counts retries; attempts = retries + first try.
            Some(retries) => retries.saturating_add(1),
            None => fallback.max_attempts,
        };
        let base_delay = env_parsed::<u64>("BAMBOO_LLM_RETRY_BASE_DELAY_MS")
            .map_or(fallback.base_delay, Duration::from_millis);
        let max_delay = env_parsed::<u64>("BAMBOO_LLM_RETRY_MAX_DELAY_MS")
            .map_or(fallback.max_delay, Duration::from_millis);

        Self {
            // Always allow at least the initial attempt.
            max_attempts: max_attempts.max(1),
            base_delay,
            max_delay,
        }
    }
}
/// Returns the shared retry policy.
///
/// The policy is built by [`RetryConfig::from_env`] on the first call and the
/// same instance is returned on every later call, so environment changes made
/// after that first call are not picked up.
pub fn global() -> &'static RetryConfig {
    static SHARED: OnceLock<RetryConfig> = OnceLock::new();
    let config: &'static RetryConfig = SHARED.get_or_init(RetryConfig::from_env);
    config
}
/// Returns `true` for the HTTP statuses this module retries: 429 and the
/// 500 / 502 / 503 / 504 server errors. Every other status is final.
fn is_retryable_status(status: StatusCode) -> bool {
    matches!(
        status,
        StatusCode::TOO_MANY_REQUESTS
            | StatusCode::INTERNAL_SERVER_ERROR
            | StatusCode::BAD_GATEWAY
            | StatusCode::SERVICE_UNAVAILABLE
            | StatusCode::GATEWAY_TIMEOUT
    )
}
/// Returns `true` when `err` is one of the transport failures this module
/// retries: a timeout or a connection error, as reported by `reqwest`.
fn is_retryable_reqwest_error(err: &reqwest::Error) -> bool {
    let timed_out = err.is_timeout();
    let connect_failed = err.is_connect();
    timed_out || connect_failed
}
/// Extracts the `Retry-After` header from `response` as a whole number of
/// seconds.
///
/// Returns `None` when the header is absent, is not valid visible ASCII, or
/// (after trimming) does not parse as a `u64`. Only the integer-seconds form
/// is understood here; any other form yields `None`.
fn parse_retry_after(response: &Response) -> Option<Duration> {
    let header = response.headers().get(reqwest::header::RETRY_AFTER)?;
    let text = header.to_str().ok()?;
    let seconds = text.trim().parse::<u64>().ok()?;
    Some(Duration::from_secs(seconds))
}
/// Computes the jittered exponential backoff delay for a zero-based `attempt`.
///
/// The delay is `config.base_delay * 2^attempt`, capped at `config.max_delay`
/// and then passed through `jitter_ms`. All arithmetic saturates, so large
/// `attempt` values or very long configured durations clamp at the cap
/// instead of wrapping around.
fn backoff_delay(config: &RetryConfig, attempt: u32) -> Duration {
    // `Duration::as_millis` returns a `u128`; a bare `as u64` would silently
    // drop the high bits for very long durations, so clamp to `u64::MAX` first.
    let saturating_ms = |d: Duration| d.as_millis().min(u128::from(u64::MAX)) as u64;

    // `checked_shl` yields `None` once the shift reaches the bit width (64).
    let factor = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
    let raw_ms = saturating_ms(config.base_delay).saturating_mul(factor);
    let capped_ms = raw_ms.min(saturating_ms(config.max_delay));
    Duration::from_millis(jitter_ms(capped_ms))
}
/// Applies jitter to a delay of `ms` milliseconds.
///
/// Returns a value in `ms / 2 ..= ms` (and `0` for `0`). The offset within
/// that range is derived from the sub-second nanoseconds of the current
/// system time; if the clock reads earlier than the Unix epoch the offset is
/// zero and the lower bound is returned.
fn jitter_ms(ms: u64) -> u64 {
    if ms == 0 {
        return 0;
    }
    let clock_nanos = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
        Ok(since_epoch) => u64::from(since_epoch.subsec_nanos()),
        Err(_) => 0,
    };
    let floor = ms / 2;
    // Width of the jitter window above `floor`; `floor + width == ms`.
    let width = ms - floor;
    floor + clock_nanos % (width + 1)
}
/// Sends a request, retrying transient failures according to `config`.
///
/// `build` is called once per attempt to produce a fresh [`RequestBuilder`].
/// An attempt is retried when the response status satisfies
/// `is_retryable_status` or the transport error satisfies
/// `is_retryable_reqwest_error`, as long as attempts remain. Between attempts
/// the task sleeps for the response's `Retry-After` value (clamped to
/// `RETRY_AFTER_CEILING`) when one is present, otherwise for
/// `backoff_delay(config, attempt)`. Each retry is logged at `warn` level,
/// tagged with `provider`.
///
/// At least one attempt is always made, even if `config.max_attempts` is 0.
///
/// # Returns
///
/// The first response that is not retryable, or the response of the final
/// attempt whatever its status.
///
/// # Errors
///
/// Returns the `reqwest` error of the first non-retryable transport failure,
/// or of the final attempt.
pub async fn send_with_retry<F>(
    config: &RetryConfig,
    provider: &str,
    build: F,
) -> reqwest::Result<Response>
where
    F: Fn() -> RequestBuilder,
{
    let max_attempts = config.max_attempts.max(1);
    let mut attempt: u32 = 0;
    loop {
        let outcome = build().send().await;
        let attempts_made = attempt + 1;
        let out_of_attempts = attempts_made >= max_attempts;

        // Each arm either returns the final outcome or yields the wait before
        // the next attempt; the response/error is released before sleeping.
        let delay = match outcome {
            Ok(response) => {
                let status = response.status();
                if out_of_attempts || !is_retryable_status(status) {
                    return Ok(response);
                }
                let delay = match parse_retry_after(&response) {
                    Some(hint) => hint.min(RETRY_AFTER_CEILING),
                    None => backoff_delay(config, attempt),
                };
                tracing::warn!(
                    "[{provider}] transient HTTP {} on attempt {}/{}; retrying in {}ms",
                    status.as_u16(),
                    attempts_made,
                    max_attempts,
                    delay.as_millis()
                );
                delay
            }
            Err(err) => {
                if out_of_attempts || !is_retryable_reqwest_error(&err) {
                    return Err(err);
                }
                let delay = backoff_delay(config, attempt);
                tracing::warn!(
                    "[{provider}] transient transport error on attempt {}/{} ({}); retrying in {}ms",
                    attempts_made,
                    max_attempts,
                    err,
                    delay.as_millis()
                );
                delay
            }
        };

        tokio::time::sleep(delay).await;
        attempt = attempts_made;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use wiremock::matchers::method;
    use wiremock::{Mock, MockServer, Respond, ResponseTemplate};

    /// Retry policy with millisecond-scale delays so retries finish quickly.
    fn fast_config(max_attempts: u32) -> RetryConfig {
        let base_delay = Duration::from_millis(1);
        let max_delay = Duration::from_millis(5);
        RetryConfig {
            max_attempts,
            base_delay,
            max_delay,
        }
    }

    /// Responder that answers the first `fail_count` requests with
    /// `first_status` and every later request with `200 "ok"`, counting all
    /// requests in `hits`.
    struct FlakyResponder {
        fail_count: usize,
        first_status: u16,
        hits: Arc<AtomicUsize>,
    }

    impl Respond for FlakyResponder {
        fn respond(&self, _req: &wiremock::Request) -> ResponseTemplate {
            let seen_before = self.hits.fetch_add(1, Ordering::SeqCst);
            if seen_before >= self.fail_count {
                return ResponseTemplate::new(200).set_body_string("ok");
            }
            ResponseTemplate::new(self.first_status)
        }
    }

    /// Mounts a [`FlakyResponder`] for every POST on `server` and returns its
    /// request counter.
    async fn mount_flaky(
        server: &MockServer,
        fail_count: usize,
        first_status: u16,
    ) -> Arc<AtomicUsize> {
        let hits = Arc::new(AtomicUsize::new(0));
        let responder = FlakyResponder {
            fail_count,
            first_status,
            hits: hits.clone(),
        };
        Mock::given(method("POST"))
            .respond_with(responder)
            .mount(server)
            .await;
        hits
    }

    /// Issues a JSON POST to the mock server through `send_with_retry`.
    async fn run(config: &RetryConfig, server: &MockServer) -> reqwest::Result<reqwest::Response> {
        let client = reqwest::Client::new();
        let url = format!("{}/v1/test", server.uri());
        let build = || client.post(&url).json(&serde_json::json!({"a": 1}));
        send_with_retry(config, "test", build).await
    }

    #[tokio::test]
    async fn retries_503_then_succeeds() {
        let server = MockServer::start().await;
        let hits = mount_flaky(&server, 1, 503).await;

        let resp = run(&fast_config(3), &server).await.unwrap();

        assert_eq!(resp.status(), 200);
        assert_eq!(hits.load(Ordering::SeqCst), 2, "one 503 then one success");
        assert_eq!(resp.text().await.unwrap(), "ok");
    }

    #[tokio::test]
    async fn retries_429_then_succeeds() {
        let server = MockServer::start().await;
        let hits = mount_flaky(&server, 2, 429).await;

        let resp = run(&fast_config(3), &server).await.unwrap();

        assert_eq!(resp.status(), 200);
        assert_eq!(hits.load(Ordering::SeqCst), 3, "two 429s then one success");
    }

    #[tokio::test]
    async fn does_not_retry_400() {
        let server = MockServer::start().await;
        let hits = mount_flaky(&server, 99, 400).await;

        let resp = run(&fast_config(3), &server).await.unwrap();

        assert_eq!(resp.status(), 400);
        assert_eq!(hits.load(Ordering::SeqCst), 1, "400 is not retried");
    }

    #[tokio::test]
    async fn does_not_retry_401() {
        let server = MockServer::start().await;
        let hits = mount_flaky(&server, 99, 401).await;

        let resp = run(&fast_config(3), &server).await.unwrap();

        assert_eq!(resp.status(), 401);
        assert_eq!(hits.load(Ordering::SeqCst), 1, "401 is not retried");
    }

    #[tokio::test]
    async fn bounded_gives_up_after_max_attempts() {
        let server = MockServer::start().await;
        let hits = mount_flaky(&server, 99, 503).await;

        let resp = run(&fast_config(3), &server).await.unwrap();

        assert_eq!(resp.status(), 503);
        assert_eq!(hits.load(Ordering::SeqCst), 3, "exactly max_attempts hits");
    }

    #[tokio::test]
    async fn single_attempt_when_retries_disabled() {
        let server = MockServer::start().await;
        let hits = mount_flaky(&server, 99, 503).await;

        let resp = run(&fast_config(1), &server).await.unwrap();

        assert_eq!(resp.status(), 503);
        assert_eq!(hits.load(Ordering::SeqCst), 1, "max_attempts=1 => no retry");
    }

    #[tokio::test]
    async fn respects_retry_after_header() {
        /// Answers the first request with `429` + `Retry-After: 1`, then `200`.
        struct RetryAfterResponder {
            hits: Arc<AtomicUsize>,
        }

        impl Respond for RetryAfterResponder {
            fn respond(&self, _req: &wiremock::Request) -> ResponseTemplate {
                match self.hits.fetch_add(1, Ordering::SeqCst) {
                    0 => ResponseTemplate::new(429).insert_header("Retry-After", "1"),
                    _ => ResponseTemplate::new(200).set_body_string("ok"),
                }
            }
        }

        let server = MockServer::start().await;
        let hits = Arc::new(AtomicUsize::new(0));
        Mock::given(method("POST"))
            .respond_with(RetryAfterResponder { hits: hits.clone() })
            .mount(&server)
            .await;

        let started = std::time::Instant::now();
        let resp = run(&fast_config(3), &server).await.unwrap();
        let elapsed = started.elapsed();

        assert_eq!(resp.status(), 200);
        assert_eq!(hits.load(Ordering::SeqCst), 2);
        assert!(
            elapsed >= Duration::from_millis(900),
            "Retry-After: 1s should dominate the tiny backoff (elapsed={elapsed:?})"
        );
    }

    #[test]
    fn from_env_defaults_and_overrides() {
        const KEY: &str = "BAMBOO_LLM_MAX_RETRIES";

        // Unset: falls back to the compiled-in default.
        std::env::remove_var(KEY);
        let unset = RetryConfig::from_env();
        assert_eq!(unset.max_attempts, DEFAULT_MAX_ATTEMPTS);
        assert_eq!(unset.max_attempts, DEFAULT_MAX_RETRIES + 1);

        // N retries => N + 1 attempts.
        std::env::set_var(KEY, "5");
        let five_retries = RetryConfig::from_env();
        assert_eq!(five_retries.max_attempts, 6);

        // Zero retries still allows the initial attempt.
        std::env::set_var(KEY, "0");
        let no_retries = RetryConfig::from_env();
        assert_eq!(no_retries.max_attempts, 1);

        std::env::remove_var(KEY);
    }

    #[test]
    fn status_classification() {
        let retryable = [429u16, 500, 502, 503, 504];
        let terminal = [400u16, 401, 403, 404, 422, 200];

        for s in retryable {
            let status = StatusCode::from_u16(s).unwrap();
            assert!(is_retryable_status(status), "{s} retryable");
        }
        for s in terminal {
            let status = StatusCode::from_u16(s).unwrap();
            assert!(!is_retryable_status(status), "{s} not retryable");
        }
    }
}