use crate::io::AgentIO;
use anyhow::{anyhow, Result};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Tuning knobs for exponential-backoff retry behavior.
///
/// Every field has a serde default, so any subset may be omitted when this
/// struct is deserialized from configuration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Maximum number of retries made after the initial attempt.
    #[serde(default = "RetryConfig::default_max_retries")]
    pub max_retries: u32,
    /// Delay before the first retry, in milliseconds.
    #[serde(default = "RetryConfig::default_initial_delay_ms")]
    pub initial_delay_ms: u64,
    /// Upper bound on any single backoff delay, in milliseconds.
    #[serde(default = "RetryConfig::default_max_delay_ms")]
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by after each failed attempt.
    #[serde(default = "RetryConfig::default_backoff_multiplier")]
    pub backoff_multiplier: f64,
}
impl RetryConfig {
    // Serde default providers, referenced by string name from the
    // `#[serde(default = "...")]` attributes on the struct's fields.

    /// Default retry budget: 5 retries after the initial attempt.
    fn default_max_retries() -> u32 { 5 }

    /// Default delay before the first retry: one second.
    fn default_initial_delay_ms() -> u64 { 1_000 }

    /// Default cap on any single backoff delay: one minute.
    fn default_max_delay_ms() -> u64 { 60_000 }

    /// Default exponential-backoff factor: double on each failure.
    fn default_backoff_multiplier() -> f64 { 2.0 }
}
/// Reads an optional override for a configuration value from the environment.
///
/// Returns `None` when the variable is unset or its value fails to parse,
/// so callers fall back to the compiled-in default in either case.
fn env_override<T: std::str::FromStr>(name: &str) -> Option<T> {
    std::env::var(name).ok().and_then(|raw| raw.parse().ok())
}

impl Default for RetryConfig {
    /// Builds the runtime default configuration, honoring env-var overrides.
    ///
    /// `XCODEAI_RETRY_MAX` overrides the retry count and
    /// `XCODEAI_RETRY_INITIAL_MS` the initial delay.
    // NOTE(review): only max_retries and initial_delay_ms are
    // env-configurable; max_delay_ms and backoff_multiplier always use the
    // compiled-in defaults — confirm this asymmetry is intentional.
    fn default() -> Self {
        RetryConfig {
            max_retries: env_override("XCODEAI_RETRY_MAX")
                .unwrap_or_else(Self::default_max_retries),
            initial_delay_ms: env_override("XCODEAI_RETRY_INITIAL_MS")
                .unwrap_or_else(Self::default_initial_delay_ms),
            max_delay_ms: Self::default_max_delay_ms(),
            backoff_multiplier: Self::default_backoff_multiplier(),
        }
    }
}
/// Outcome of deciding whether a failed request should be retried.
#[derive(Debug, PartialEq)]
pub enum RetryDecision {
    /// Wait for the given duration, then try again.
    RetryAfter(Duration),
    /// Give up immediately; the error is treated as permanent.
    Fail,
}

/// Maps an HTTP status code to a retry decision.
///
/// Rate limiting (429) and transient gateway/server errors (500, 502, 503,
/// 504) are retried after `fallback_delay`; every other status — including
/// client errors like 400/422 — fails immediately.
pub fn classify_http_status(status: u16, fallback_delay: Duration) -> RetryDecision {
    if matches!(status, 429 | 500 | 502 | 503 | 504) {
        RetryDecision::RetryAfter(fallback_delay)
    } else {
        RetryDecision::Fail
    }
}
/// Parses an HTTP `Retry-After` header value into a wait duration.
///
/// Only the integer-seconds form is understood; an HTTP-date value (or any
/// other unparsable input) falls back to `default`, as does a missing header.
#[allow(dead_code)]
pub fn parse_retry_after(header: Option<&str>, default: Duration) -> Duration {
    header
        .and_then(|value| value.trim().parse::<u64>().ok())
        .map(Duration::from_secs)
        .unwrap_or(default)
}
/// Randomizes `duration` to between 75% and 125% of its value.
///
/// Jitter spreads retries from many clients over time so they do not hammer
/// a recovering server in lockstep.
fn add_jitter(duration: Duration) -> Duration {
    let factor = 0.75 + rand::thread_rng().gen::<f64>() * 0.50;
    // Truncating cast: sub-millisecond precision is irrelevant for retry waits.
    Duration::from_millis((duration.as_millis() as f64 * factor) as u64)
}
/// Computes the delay to use for the next retry attempt.
///
/// Scales `current_delay` by the configured backoff multiplier (truncating
/// to whole milliseconds) and clamps the result to `config.max_delay_ms`.
pub fn next_delay(config: &RetryConfig, current_delay: Duration) -> Duration {
    let scaled_ms = (current_delay.as_millis() as f64 * config.backoff_multiplier) as u64;
    Duration::from_millis(scaled_ms.min(config.max_delay_ms))
}
/// Runs the async operation `f`, retrying with exponential backoff on
/// retryable failures.
///
/// An error is considered for retry only if it downcasts to
/// [`RetryableError`]; any other `anyhow::Error` is returned immediately.
/// Up to `config.max_retries` retries follow the initial attempt
/// (`max_retries + 1` calls total). Waits grow by `config.backoff_multiplier`
/// per attempt, are capped at `config.max_delay_ms`, and are jittered to
/// 75–125% to avoid synchronized retries. Progress is reported via `io`.
///
/// # Errors
/// Returns the first non-retryable error, an HTTP error classified as
/// permanent, or the last observed error once the retry budget is spent.
pub async fn retry_with_backoff<T, F, Fut>(
    config: &RetryConfig,
    io: &dyn AgentIO,
    f: F,
) -> Result<T>
where
    F: Fn() -> Fut,
    Fut: std::future::Future<Output = Result<T>>,
{
    let mut current_delay = Duration::from_millis(config.initial_delay_ms);
    // Placeholder; overwritten before any path that reaches the final Err.
    let mut last_err: anyhow::Error = anyhow!("No attempts made");
    // Attempt 0 is the initial call; attempts 1..=max_retries are retries.
    for attempt in 0..=config.max_retries {
        match f().await {
            Ok(value) => return Ok(value),
            Err(err) => {
                // Borrow (not consume) err, so it can still be returned below.
                let retryable = err.downcast_ref::<RetryableError>();
                let retry_decision = match retryable {
                    Some(RetryableError::Http {
                        status,
                        retry_after,
                    }) => {
                        // Prefer the server-suggested Retry-After (seconds)
                        // over our own backoff delay when it is present.
                        let server_delay = retry_after
                            .map(Duration::from_secs)
                            .unwrap_or(current_delay);
                        classify_http_status(*status, server_delay)
                    }
                    Some(RetryableError::Timeout) => {
                        RetryDecision::RetryAfter(current_delay)
                    }
                    Some(RetryableError::Network(_)) => {
                        RetryDecision::RetryAfter(current_delay)
                    }
                    None => {
                        // Not a transport-level error: bail out immediately.
                        return Err(err);
                    }
                };
                // Retry budget exhausted: keep the error and fall through to
                // the final Err. (Even a Fail decision yields the same Err,
                // so checking the budget first does not change the outcome.)
                if attempt >= config.max_retries {
                    last_err = err;
                    break;
                }
                match retry_decision {
                    RetryDecision::Fail => {
                        return Err(err);
                    }
                    RetryDecision::RetryAfter(wait) => {
                        let jittered = add_jitter(wait);
                        // Human-readable cause for the status line shown below.
                        let reason = match retryable {
                            Some(RetryableError::Http { status, .. }) => match status {
                                429 => "rate limit",
                                500 => "server error",
                                502 => "bad gateway",
                                503 => "service unavailable",
                                504 => "gateway timeout",
                                _ => "HTTP error",
                            },
                            Some(RetryableError::Timeout) => "timeout",
                            Some(RetryableError::Network(_)) => "network error",
                            None => "error",
                        };
                        // Status display is best-effort; ignore UI failures.
                        io.show_status(&format!(
                            "⏳ Retrying in {:.1}s (attempt {}/{}) — {reason}",
                            jittered.as_secs_f64(),
                            attempt + 1,
                            config.max_retries,
                        ))
                        .await
                        .ok();
                        tokio::time::sleep(jittered).await;
                        // Grow the baseline delay for the next iteration.
                        current_delay = next_delay(config, current_delay);
                        last_err = err;
                    }
                }
            }
        }
    }
    Err(last_err)
}
/// Transport-level failures that [`retry_with_backoff`] may retry.
///
/// Errors that do not downcast to this type are treated as permanent.
#[derive(thiserror::Error, Debug)]
pub enum RetryableError {
    /// HTTP response with a non-success status.
    #[error("HTTP {status} error")]
    Http {
        // HTTP status code returned by the server.
        status: u16,
        // Server-supplied Retry-After value in whole seconds, if any.
        retry_after: Option<u64>,
    },
    /// The request exceeded its deadline.
    #[error("Request timed out")]
    Timeout,
    /// Connection-level failure (DNS, refused, reset, ...).
    #[error("Network error: {0}")]
    Network(String),
}
#[cfg(test)]
mod tests {
    //! Unit tests for status classification, Retry-After parsing, backoff
    //! math, and the end-to-end retry loop (via `NullIO`).
    use super::*;
    use crate::io::NullIO;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    // All transient codes must yield RetryAfter with the supplied delay.
    #[test]
    fn test_classify_retryable_codes() {
        let delay = Duration::from_secs(1);
        for code in [429u16, 500, 502, 503, 504] {
            assert_eq!(
                classify_http_status(code, delay),
                RetryDecision::RetryAfter(delay),
                "expected RetryAfter for {code}"
            );
        }
    }

    // Client errors and auth failures must never be retried.
    #[test]
    fn test_classify_permanent_codes() {
        let delay = Duration::from_secs(1);
        for code in [400u16, 422, 401, 403, 404] {
            assert_eq!(
                classify_http_status(code, delay),
                RetryDecision::Fail,
                "expected Fail for {code}"
            );
        }
    }

    #[test]
    fn test_parse_retry_after_integer() {
        let result = parse_retry_after(Some("30"), Duration::from_secs(5));
        assert_eq!(result, Duration::from_secs(30));
    }

    #[test]
    fn test_parse_retry_after_with_whitespace() {
        let result = parse_retry_after(Some(" 15 "), Duration::from_secs(5));
        assert_eq!(result, Duration::from_secs(15));
    }

    #[test]
    fn test_parse_retry_after_none_returns_default() {
        let default = Duration::from_secs(5);
        let result = parse_retry_after(None, default);
        assert_eq!(result, default);
    }

    // The HTTP-date form of Retry-After is unsupported; expect the fallback.
    #[test]
    fn test_parse_retry_after_invalid_falls_back_to_default() {
        let default = Duration::from_secs(5);
        let result = parse_retry_after(Some("Wed, 01 Mar 2026 12:30:00 GMT"), default);
        assert_eq!(result, default);
    }

    #[test]
    fn test_next_delay_doubles_by_default() {
        let config = RetryConfig::default();
        let d = Duration::from_millis(1000);
        let next = next_delay(&config, d);
        assert_eq!(next, Duration::from_millis(2000));
    }

    #[test]
    fn test_next_delay_caps_at_max() {
        let config = RetryConfig {
            max_delay_ms: 3000,
            ..RetryConfig::default()
        };
        let d = Duration::from_millis(2000);
        let next = next_delay(&config, d);
        assert_eq!(next, Duration::from_millis(3000));
    }

    #[tokio::test]
    async fn test_no_retry_on_success() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay_ms: 1,
            ..RetryConfig::default()
        };
        let io = NullIO;
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_with_backoff(&config, &io, || {
            let cc = cc.clone();
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Ok::<&str, anyhow::Error>("ok")
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Two 429s followed by success: three calls total, Ok result.
    #[tokio::test]
    async fn test_retry_on_429_then_success() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay_ms: 1,
            max_delay_ms: 10,
            backoff_multiplier: 2.0,
        };
        let io = NullIO;
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_with_backoff(&config, &io, || {
            let cc = cc.clone();
            async move {
                let n = cc.fetch_add(1, Ordering::SeqCst);
                if n < 2 {
                    Err(anyhow::Error::new(RetryableError::Http {
                        status: 429,
                        retry_after: None,
                    }))
                } else {
                    Ok("done")
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "done");
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // A permanent 400 must propagate after exactly one call.
    #[tokio::test]
    async fn test_no_retry_on_400() {
        let config = RetryConfig {
            max_retries: 3,
            initial_delay_ms: 1,
            ..RetryConfig::default()
        };
        let io = NullIO;
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_with_backoff(&config, &io, || {
            let cc = cc.clone();
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<&str, _>(anyhow::Error::new(RetryableError::Http {
                    status: 400,
                    retry_after: None,
                }))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // max_retries = 2 means 3 calls total (initial attempt + 2 retries).
    #[tokio::test]
    async fn test_exhausts_retries() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay_ms: 1,
            max_delay_ms: 10,
            backoff_multiplier: 2.0,
        };
        let io = NullIO;
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_with_backoff(&config, &io, || {
            let cc = cc.clone();
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<&str, _>(anyhow::Error::new(RetryableError::Http {
                    status: 503,
                    retry_after: None,
                }))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    // Errors that don't downcast to RetryableError bypass the retry loop.
    #[tokio::test]
    async fn test_non_retryable_error_propagates_immediately() {
        let config = RetryConfig {
            max_retries: 5,
            initial_delay_ms: 1,
            ..RetryConfig::default()
        };
        let io = NullIO;
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_with_backoff(&config, &io, || {
            let cc = cc.clone();
            async move {
                cc.fetch_add(1, Ordering::SeqCst);
                Err::<&str, _>(anyhow::anyhow!("Some parse error"))
            }
        })
        .await;
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    // Purely synchronous logic — no async runtime needed here.
    #[test]
    fn test_retry_after_header_respected() {
        let wait = parse_retry_after(Some("1"), Duration::from_secs(60));
        assert_eq!(wait, Duration::from_secs(1));
        assert_ne!(wait, Duration::from_secs(60));
    }

    // One network failure, then recovery: two calls total.
    #[tokio::test]
    async fn test_retry_on_network_error() {
        let config = RetryConfig {
            max_retries: 2,
            initial_delay_ms: 1,
            max_delay_ms: 5,
            backoff_multiplier: 1.5,
        };
        let io = NullIO;
        let call_count = Arc::new(AtomicU32::new(0));
        let cc = call_count.clone();
        let result = retry_with_backoff(&config, &io, || {
            let cc = cc.clone();
            async move {
                let n = cc.fetch_add(1, Ordering::SeqCst);
                if n < 1 {
                    Err(anyhow::Error::new(RetryableError::Network(
                        "connection refused".into(),
                    )))
                } else {
                    Ok("recovered")
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}