use std::time::Duration;
use reqwest::StatusCode;
use serde::{Deserialize, Serialize};
/// Tunable parameters for retrying a failed request with exponential backoff.
///
/// The type is (de)serializable through serde, so it can be loaded from
/// configuration as well as constructed in code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
/// Number of retries after the first attempt; `with_retry` runs the
/// operation for attempts `0..=max_retries`.
pub max_retries: u32,
/// Backoff delay, in milliseconds, for attempt 0 before jitter is applied.
/// `delay_for_attempt` doubles it for each subsequent attempt.
pub base_delay_ms: u64,
/// Cap, in milliseconds, applied to the exponential delay before jitter.
/// Jitter is added around the capped value, so a returned delay may exceed
/// this cap by up to a quarter of it.
pub max_delay_ms: u64,
/// Numeric HTTP status codes for which `is_retryable_status` returns `true`.
pub retryable_status_codes: Vec<u16>,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
base_delay_ms: 1000,
max_delay_ms: 30_000,
retryable_status_codes: vec![429, 500, 502, 503, 529],
}
}
}
impl RetryConfig {
    /// Returns a configuration that never retries: `max_retries` is 0 and
    /// every other field keeps its default value.
    pub fn disabled() -> Self {
        Self {
            max_retries: 0,
            ..Default::default()
        }
    }

    /// Reports whether `status` is one of the configured
    /// `retryable_status_codes`.
    pub fn is_retryable_status(&self, status: StatusCode) -> bool {
        self.retryable_status_codes.contains(&status.as_u16())
    }

    /// Computes the backoff delay to wait after the given zero-based `attempt`.
    ///
    /// The delay is `base_delay_ms * 2^min(attempt, 10)`, capped at
    /// `max_delay_ms`, and then jittered to a value within a quarter of the
    /// capped delay on either side. Because jitter is applied after capping,
    /// the result can exceed `max_delay_ms` by up to 25%.
    ///
    /// When a quarter of the capped delay rounds down to zero (capped delay
    /// below 4 ms), no jitter is applied and the capped delay is returned.
    ///
    /// The jitter is derived from the sub-second part of the system clock
    /// mixed with the attempt number; it spreads retries out but is not a
    /// source of high-quality randomness.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // The exponent is clamped to 10 so the shift itself can never overflow;
        // the multiplication saturates instead of wrapping.
        let exp_delay = self.base_delay_ms.saturating_mul(1u64 << attempt.min(10));
        let capped = exp_delay.min(self.max_delay_ms);

        let jitter_range = capped / 4;
        if jitter_range == 0 {
            return Duration::from_millis(capped);
        }

        // Falls back to 0 if the system clock reports a time before the epoch.
        let entropy = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| u64::from(d.subsec_nanos()))
            .unwrap_or(0);

        // `jitter_range <= u64::MAX / 4`, so `jitter_range * 2 + 1` cannot overflow.
        let jitter_offset = (entropy ^ u64::from(attempt).wrapping_mul(0x517cc1b727220a95))
            % (jitter_range * 2 + 1);

        // The upper bound is roughly 1.25 * capped, which overflows `u64` when
        // `max_delay_ms` is close to `u64::MAX`; saturate instead of
        // panicking (debug) or wrapping to a tiny delay (release).
        let delay_ms = (capped - jitter_range).saturating_add(jitter_offset);
        Duration::from_millis(delay_ms)
    }

    /// Parses a `Retry-After` header value expressed as a number of seconds.
    ///
    /// Surrounding whitespace is ignored and fractional seconds are accepted.
    /// Returns `None` when the header is absent, when the value does not parse
    /// as a floating-point number, or when it falls outside `(0, 300]` seconds
    /// (this also rejects NaN and infinities).
    pub fn parse_retry_after(header_value: Option<&str>) -> Option<Duration> {
        header_value?
            .trim()
            .parse::<f64>()
            .ok()
            .filter(|seconds| *seconds > 0.0 && *seconds <= 300.0)
            .map(Duration::from_secs_f64)
    }
}
/// Result of a single attempt, as reported by the operation passed to
/// `with_retry`.
#[derive(Debug)]
pub enum AttemptOutcome<T> {
/// The attempt succeeded; `with_retry` returns the value as `Ok`.
Success(T),
/// The attempt failed in a way that may be retried. `with_retry` sleeps and
/// tries again while attempts remain; the last status and body are reported
/// in the final error once attempts are exhausted.
Retryable {
/// HTTP status of the failed attempt.
status: StatusCode,
/// Response body of the failed attempt, included in the final error message.
body: String,
/// Delay to use before the next attempt; when `None`, `with_retry` uses
/// `RetryConfig::delay_for_attempt` instead.
retry_after: Option<Duration>,
},
/// The attempt failed in a way that must not be retried; `with_retry`
/// returns the error immediately.
Fatal(anyhow::Error),
}
/// Runs `operation` until it succeeds, fails fatally, or the retry budget in
/// `config` is exhausted.
///
/// `operation` is called with the zero-based attempt index, for attempts
/// `0..=config.max_retries`.
///
/// * [`AttemptOutcome::Success`] returns the value as `Ok`.
/// * [`AttemptOutcome::Fatal`] returns the error immediately without retrying.
/// * [`AttemptOutcome::Retryable`] sleeps for the supplied `retry_after`, or
///   for `config.delay_for_attempt(attempt)` when none is supplied, and then
///   tries again. No sleep happens after the final attempt.
///
/// # Errors
///
/// Returns the fatal error unchanged, or, once every attempt has been
/// retryable, an error naming the number of attempts, the last status and the
/// last response body.
pub async fn with_retry<T, F, Fut>(config: &RetryConfig, operation: F) -> anyhow::Result<T>
where
    F: Fn(u32) -> Fut,
    Fut: std::future::Future<Output = AttemptOutcome<T>>,
{
    // Computed in u64 so that `max_retries == u32::MAX` cannot overflow when
    // the total number of attempts is reported in logs and errors.
    let total_attempts = u64::from(config.max_retries) + 1;

    let mut last_status = None;
    let mut last_body = String::new();

    for attempt in 0..=config.max_retries {
        match operation(attempt).await {
            AttemptOutcome::Success(value) => {
                if attempt > 0 {
                    tracing::info!("LLM API request succeeded after {} retries", attempt);
                }
                return Ok(value);
            }
            AttemptOutcome::Fatal(err) => {
                return Err(err);
            }
            AttemptOutcome::Retryable {
                status,
                body,
                retry_after,
            } => {
                last_status = Some(status);
                last_body = body;
                // Only wait when another attempt will actually follow.
                if attempt < config.max_retries {
                    let delay = retry_after.unwrap_or_else(|| config.delay_for_attempt(attempt));
                    tracing::warn!(
                        "LLM API request failed with {} (attempt {}/{}), retrying in {:?}",
                        status,
                        // `attempt < max_retries <= u32::MAX`, so this cannot overflow.
                        attempt + 1,
                        total_attempts,
                        delay,
                    );
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }

    // The loop runs at least once and only falls through after a retryable
    // outcome, so the fallback status is purely defensive.
    let status = last_status.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
    anyhow::bail!(
        "LLM API request failed after {} attempts. Last status: {} Body: {}",
        total_attempts,
        status,
        last_body,
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds a config with millisecond-scale backoff so the async retry
    /// tests finish quickly.
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig {
            max_retries,
            base_delay_ms: 10,
            max_delay_ms: 50,
            ..Default::default()
        }
    }

    /// The default configuration carries the expected constants.
    #[test]
    fn test_retry_config_default() {
        let RetryConfig {
            max_retries,
            base_delay_ms,
            max_delay_ms,
            retryable_status_codes,
        } = RetryConfig::default();
        assert_eq!(max_retries, 3);
        assert_eq!(base_delay_ms, 1000);
        assert_eq!(max_delay_ms, 30_000);
        assert_eq!(retryable_status_codes, vec![429, 500, 502, 503, 529]);
    }

    /// A disabled configuration allows zero retries.
    #[test]
    fn test_retry_config_disabled() {
        assert_eq!(RetryConfig::disabled().max_retries, 0);
    }

    /// Default retryable codes are accepted; common client errors and 200 are not.
    #[test]
    fn test_is_retryable_status() {
        let config = RetryConfig::default();

        let retryable = [
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::from_u16(529).unwrap(),
        ];
        for status in &retryable {
            assert!(config.is_retryable_status(*status));
        }

        let not_retryable = [
            StatusCode::OK,
            StatusCode::BAD_REQUEST,
            StatusCode::UNAUTHORIZED,
            StatusCode::FORBIDDEN,
            StatusCode::NOT_FOUND,
        ];
        for status in &not_retryable {
            assert!(!config.is_retryable_status(*status));
        }
    }

    /// Each attempt doubles the delay, within a quarter of jitter either way.
    #[test]
    fn test_delay_for_attempt_exponential() {
        let config = RetryConfig {
            base_delay_ms: 1000,
            max_delay_ms: 60_000,
            ..Default::default()
        };
        // (attempt, lowest allowed millis, highest allowed millis)
        let expected_bounds: [(u32, u128, u128); 3] =
            [(0, 750, 1250), (1, 1500, 2500), (2, 3000, 5000)];
        for &(attempt, lowest, highest) in &expected_bounds {
            let millis = config.delay_for_attempt(attempt).as_millis();
            assert!((lowest..=highest).contains(&millis));
        }
    }

    /// A large attempt number is capped at `max_delay_ms` plus jitter.
    #[test]
    fn test_delay_capped_at_max() {
        let config = RetryConfig {
            base_delay_ms: 1000,
            max_delay_ms: 5000,
            ..Default::default()
        };
        let millis = config.delay_for_attempt(10).as_millis();
        assert!(millis <= 6250);
    }

    /// A zero base delay yields a zero delay.
    #[test]
    fn test_delay_zero_base() {
        let config = RetryConfig {
            base_delay_ms: 0,
            max_delay_ms: 1000,
            ..Default::default()
        };
        assert_eq!(config.delay_for_attempt(0).as_millis(), 0);
    }

    #[test]
    fn test_parse_retry_after_integer() {
        assert_eq!(
            RetryConfig::parse_retry_after(Some("5")),
            Some(Duration::from_secs(5))
        );
    }

    #[test]
    fn test_parse_retry_after_decimal() {
        assert_eq!(
            RetryConfig::parse_retry_after(Some("1.5")),
            Some(Duration::from_secs_f64(1.5))
        );
    }

    #[test]
    fn test_parse_retry_after_none() {
        let parsed = RetryConfig::parse_retry_after(None);
        assert_eq!(parsed, None);
    }

    #[test]
    fn test_parse_retry_after_invalid() {
        let parsed = RetryConfig::parse_retry_after(Some("not-a-number"));
        assert_eq!(parsed, None);
    }

    #[test]
    fn test_parse_retry_after_negative() {
        let parsed = RetryConfig::parse_retry_after(Some("-1"));
        assert_eq!(parsed, None);
    }

    #[test]
    fn test_parse_retry_after_zero() {
        let parsed = RetryConfig::parse_retry_after(Some("0"));
        assert_eq!(parsed, None);
    }

    #[test]
    fn test_parse_retry_after_too_large() {
        let parsed = RetryConfig::parse_retry_after(Some("301"));
        assert_eq!(parsed, None);
    }

    #[test]
    fn test_parse_retry_after_with_whitespace() {
        assert_eq!(
            RetryConfig::parse_retry_after(Some(" 3 ")),
            Some(Duration::from_secs(3))
        );
    }

    /// Serializing and deserializing a config preserves every field.
    #[test]
    fn test_retry_config_serde_roundtrip() {
        let original = RetryConfig::default();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: RetryConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded.max_retries, original.max_retries);
        assert_eq!(decoded.base_delay_ms, original.base_delay_ms);
        assert_eq!(decoded.max_delay_ms, original.max_delay_ms);
        assert_eq!(
            decoded.retryable_status_codes,
            original.retryable_status_codes
        );
    }

    /// Non-default values are read from JSON as written.
    #[test]
    fn test_retry_config_deserialize_custom() {
        let json = r#"{"max_retries":5,"base_delay_ms":500,"max_delay_ms":10000,"retryable_status_codes":[429,503]}"#;
        let parsed: RetryConfig = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.max_retries, 5);
        assert_eq!(parsed.base_delay_ms, 500);
        assert_eq!(parsed.max_delay_ms, 10_000);
        assert_eq!(parsed.retryable_status_codes, vec![429, 503]);
    }

    /// A first-attempt success invokes the operation exactly once.
    #[tokio::test]
    async fn test_with_retry_success_first_attempt() {
        let config = RetryConfig::default();
        let calls = Arc::new(AtomicU32::new(0));
        let result = with_retry(&config, |_| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                AttemptOutcome::Success("ok")
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// Two retryable failures followed by a success take three calls.
    #[tokio::test]
    async fn test_with_retry_success_after_retries() {
        let config = fast_config(3);
        let calls = Arc::new(AtomicU32::new(0));
        let result = with_retry(&config, |attempt| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                if attempt >= 2 {
                    AttemptOutcome::Success("recovered")
                } else {
                    AttemptOutcome::Retryable {
                        status: StatusCode::TOO_MANY_REQUESTS,
                        body: "rate limited".to_string(),
                        retry_after: None,
                    }
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "recovered");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// When every attempt is retryable, the error reports attempts, status and body.
    #[tokio::test]
    async fn test_with_retry_all_retries_exhausted() {
        let config = fast_config(2);
        let calls = Arc::new(AtomicU32::new(0));
        let result: anyhow::Result<&str> = with_retry(&config, |_| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                AttemptOutcome::Retryable {
                    status: StatusCode::SERVICE_UNAVAILABLE,
                    body: "service down".to_string(),
                    retry_after: None,
                }
            }
        })
        .await;
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        for needle in &["3 attempts", "503", "service down"] {
            assert!(message.contains(*needle));
        }
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A fatal outcome is returned after a single call.
    #[tokio::test]
    async fn test_with_retry_fatal_error_no_retry() {
        let config = fast_config(3);
        let calls = Arc::new(AtomicU32::new(0));
        let result: anyhow::Result<&str> = with_retry(&config, |_| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                AttemptOutcome::Fatal(anyhow::anyhow!("invalid API key"))
            }
        })
        .await;
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("invalid API key"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// With retries disabled, a retryable failure is not retried.
    #[tokio::test]
    async fn test_with_retry_disabled() {
        let config = RetryConfig::disabled();
        let calls = Arc::new(AtomicU32::new(0));
        let result: anyhow::Result<&str> = with_retry(&config, |_| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                AttemptOutcome::Retryable {
                    status: StatusCode::TOO_MANY_REQUESTS,
                    body: "rate limited".to_string(),
                    retry_after: None,
                }
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// A supplied `retry_after` is used as the wait before the next attempt.
    #[tokio::test]
    async fn test_with_retry_respects_retry_after_header() {
        let config = fast_config(1);
        let calls = Arc::new(AtomicU32::new(0));
        let started = tokio::time::Instant::now();
        let result = with_retry(&config, |attempt| {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::SeqCst);
                match attempt {
                    0 => AttemptOutcome::Retryable {
                        status: StatusCode::TOO_MANY_REQUESTS,
                        body: "rate limited".to_string(),
                        retry_after: Some(Duration::from_millis(100)),
                    },
                    _ => AttemptOutcome::Success("ok"),
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert!(started.elapsed() >= Duration::from_millis(90));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}