use std::sync::Arc;
use std::time::Duration;
use crate::error::Error;
use crate::llm::types::{CompletionRequest, CompletionResponse};
use super::LlmProvider;
/// Tunables for the retry loop: how many retries are allowed and how the
/// backoff delay is bounded.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Retries permitted after the initial attempt (`0` means a single attempt).
    pub max_retries: u32,
    /// Starting point of the exponential backoff.
    pub base_delay: Duration,
    /// Ceiling applied to every computed backoff delay.
    pub max_delay: Duration,
}

impl Default for RetryConfig {
    /// Three retries, backing off from 500 ms up to a 30 s ceiling.
    fn default() -> Self {
        let max_retries = 3;
        let base_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(30);
        Self {
            max_retries,
            base_delay,
            max_delay,
        }
    }
}
/// Callback invoked just before the provider sleeps ahead of a retry.
///
/// Arguments, in order:
/// 1. the retry attempt number (the first retry is `1`),
/// 2. the configured `max_retries`,
/// 3. the delay about to be slept, in milliseconds,
/// 4. the error class label produced by `classify_for_retry` for the
///    failure that triggered the retry.
pub type OnRetry = dyn Fn(u32, u32, u64, &str) + Send + Sync;
/// Wraps a provider `P` together with a retry policy and an optional
/// retry observer.
pub struct RetryingProvider<P> {
    /// The wrapped provider that performs the actual calls.
    inner: P,
    /// Retry budget and backoff bounds.
    config: RetryConfig,
    /// Optional observer notified before each retry.
    on_retry: Option<Arc<OnRetry>>,
}

impl<P> RetryingProvider<P> {
    /// Wraps `inner` with the given retry policy and no retry observer.
    pub fn new(inner: P, config: RetryConfig) -> Self {
        let on_retry = None;
        Self {
            inner,
            config,
            on_retry,
        }
    }

    /// Wraps `inner` using `RetryConfig::default()`.
    pub fn with_defaults(inner: P) -> Self {
        let config = RetryConfig::default();
        Self::new(inner, config)
    }

    /// Builder-style setter: returns the provider with `callback` installed
    /// as the retry observer, replacing any previous one.
    pub fn with_on_retry(self, callback: Arc<OnRetry>) -> Self {
        Self {
            on_retry: Some(callback),
            ..self
        }
    }
}
fn classify_for_retry(err: &Error) -> &'static str {
match err {
Error::Api { status: 429, .. } => "rate_limited",
Error::Api { status: 500, .. } => "server_error_500",
Error::Api { status: 502, .. } => "server_error_502",
Error::Api { status: 503, .. } => "server_error_503",
Error::Api { status: 529, .. } => "overloaded",
Error::Http(_) => "network_error",
_ => "unknown",
}
}
fn is_retryable(err: &Error) -> bool {
match err {
Error::Api { status, .. } => matches!(*status, 429 | 500 | 502 | 503 | 529),
Error::Http(_) => true,
_ => false,
}
}
/// Picks the backoff delay to sleep before a retry.
///
/// `attempt` is 0-based: `0` selects the delay before the first retry.
/// The delay is drawn pseudo-randomly from
/// `[base_delay, min(3 * base_delay * 2^attempt, max_delay)]`, and the
/// result never exceeds `max_delay` (if `base_delay > max_delay` the result
/// is exactly `max_delay`).
///
/// All arithmetic saturates, so very large `attempt` values or very large
/// durations clamp to the cap instead of wrapping.
fn compute_delay(config: &RetryConfig, attempt: u32) -> Duration {
    use std::sync::atomic::{AtomicU64, Ordering};

    // Shared state of a small linear congruential generator used for jitter.
    // It starts from a fixed constant, so the sequence of picks is the same
    // on every process start; it is advanced atomically on each call.
    static SEED: AtomicU64 = AtomicU64::new(0x9E3779B97F4A7C15);

    // `as_millis` returns a u128. Saturate instead of `as`-casting: a plain
    // cast wraps, which could turn an enormous `max_delay` into a tiny cap.
    let base_ms = config.base_delay.as_millis().min(u64::MAX as u128) as u64;
    let max_ms = config.max_delay.as_millis().min(u64::MAX as u128) as u64;

    // 2^attempt, saturating once the shift no longer fits in 64 bits so the
    // multiplier never shrinks as `attempt` grows.
    let growth = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
    let prev_max_ms = base_ms.saturating_mul(growth);

    // Jitter window: [lower, upper], with upper clamped to the configured cap
    // (but never below base, so the window is non-empty).
    let upper = prev_max_ms.saturating_mul(3).min(max_ms.max(base_ms));
    let lower = base_ms.min(upper);

    // The closure always returns `Some`, so `fetch_update` always yields the
    // previous state; the fallback is never taken.
    let next = SEED
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |s| {
            Some(s.wrapping_mul(1664525).wrapping_add(1013904223))
        })
        .unwrap_or(0);

    // `+ 1` cannot overflow: `lower == 0` implies `base_ms == 0`, which in
    // turn forces `upper == 0`.
    let span = upper - lower + 1;
    let pick = lower + (next % span);
    Duration::from_millis(pick.min(max_ms))
}
impl<P: LlmProvider> LlmProvider for RetryingProvider<P> {
    /// Reports the wrapped provider's model name unchanged.
    fn model_name(&self) -> Option<&str> {
        self.inner.model_name()
    }

    /// Calls the wrapped provider's `complete`, retrying while the failure is
    /// one that `is_retryable` accepts.
    ///
    /// At most `max_retries + 1` calls are made. Before each retry the
    /// `on_retry` callback (if set) is invoked, a warning is logged, and the
    /// task sleeps for the delay chosen by `compute_delay`. A non-retryable
    /// error is returned immediately; once the budget is spent, the last
    /// retryable error is returned.
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, Error> {
        let max_retries = self.config.max_retries;
        let mut last_err: Option<Error> = None;

        for attempt in 0..=max_retries {
            if attempt > 0 {
                // `compute_delay` is 0-based, so retry N uses exponent N - 1.
                let delay = compute_delay(&self.config, attempt - 1);
                let delay_ms = delay.as_millis() as u64;
                {
                    // Scoped so the borrow of `last_err` ends before the sleep.
                    let previous = last_err.as_ref().expect("last_err set before retry");
                    let error_class = classify_for_retry(previous);
                    if let Some(notify) = self.on_retry.as_ref() {
                        notify(attempt, max_retries, delay_ms, error_class);
                    }
                    tracing::warn!(
                        attempt = attempt,
                        max_retries = max_retries,
                        delay_ms = delay_ms,
                        error = %previous,
                        "retrying LLM call after transient failure"
                    );
                }
                tokio::time::sleep(delay).await;
            }

            let failure = match self.inner.complete(request.clone()).await {
                Ok(response) => return Ok(response),
                Err(e) => e,
            };
            if !is_retryable(&failure) {
                return Err(failure);
            }
            last_err = Some(failure);
        }

        // The range above is never empty, so a retryable failure was recorded.
        Err(last_err.expect("at least one attempt must have been made"))
    }

    /// Streaming counterpart of `complete` with the same retry policy.
    ///
    /// Only the first attempt streams into the caller's `on_text`; every
    /// retry is given a no-op text callback instead, so the caller's callback
    /// is not fed by more than one attempt. The final response is still
    /// returned in full.
    async fn stream_complete(
        &self,
        request: CompletionRequest,
        on_text: &super::OnText,
    ) -> Result<CompletionResponse, Error> {
        // Text sink handed to retries so their output is discarded.
        fn noop_text(_: &str) {}
        let noop: &super::OnText = &noop_text;

        let max_retries = self.config.max_retries;
        let mut last_err: Option<Error> = None;

        for attempt in 0..=max_retries {
            if attempt > 0 {
                // `compute_delay` is 0-based, so retry N uses exponent N - 1.
                let delay = compute_delay(&self.config, attempt - 1);
                let delay_ms = delay.as_millis() as u64;
                {
                    // Scoped so the borrow of `last_err` ends before the sleep.
                    let previous = last_err.as_ref().expect("last_err set before retry");
                    let error_class = classify_for_retry(previous);
                    if let Some(notify) = self.on_retry.as_ref() {
                        notify(attempt, max_retries, delay_ms, error_class);
                    }
                    tracing::warn!(
                        attempt = attempt,
                        max_retries = max_retries,
                        delay_ms = delay_ms,
                        error = %previous,
                        "retrying streaming LLM call after transient failure (streaming suppressed)"
                    );
                }
                tokio::time::sleep(delay).await;
            }

            let callback = if attempt == 0 { on_text } else { noop };
            let failure = match self.inner.stream_complete(request.clone(), callback).await {
                Ok(response) => return Ok(response),
                Err(e) => e,
            };
            if !is_retryable(&failure) {
                return Err(failure);
            }
            last_err = Some(failure);
        }

        // The range above is never empty, so a retryable failure was recorded.
        Err(last_err.expect("at least one attempt must have been made"))
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::llm::types::{Message, StopReason, TokenUsage};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
struct FailNTimes {
remaining_failures: AtomicU32,
error_factory: Box<dyn Fn() -> Error + Send + Sync>,
call_count: Arc<AtomicU32>,
}
impl FailNTimes {
fn new(
failures: u32,
error_factory: impl Fn() -> Error + Send + Sync + 'static,
) -> (Self, Arc<AtomicU32>) {
let count = Arc::new(AtomicU32::new(0));
(
Self {
remaining_failures: AtomicU32::new(failures),
error_factory: Box::new(error_factory),
call_count: count.clone(),
},
count,
)
}
}
fn success_response() -> CompletionResponse {
CompletionResponse {
content: vec![crate::llm::types::ContentBlock::Text { text: "ok".into() }],
stop_reason: StopReason::EndTurn,
usage: TokenUsage {
input_tokens: 10,
output_tokens: 5,
..Default::default()
},
model: None,
}
}
impl LlmProvider for FailNTimes {
async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
self.call_count.fetch_add(1, Ordering::SeqCst);
if self
.remaining_failures
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
if v > 0 { Some(v - 1) } else { None }
})
.is_ok()
{
return Err((self.error_factory)());
}
Ok(success_response())
}
}
fn test_request() -> CompletionRequest {
CompletionRequest {
system: String::new(),
messages: vec![Message::user("test")],
tools: vec![],
max_tokens: 100,
tool_choice: None,
reasoning_effort: None,
}
}
fn fast_config(max_retries: u32) -> RetryConfig {
RetryConfig {
max_retries,
base_delay: Duration::from_millis(1), max_delay: Duration::from_millis(10),
}
}
#[tokio::test]
async fn succeeds_on_first_attempt() {
let (mock, count) = FailNTimes::new(0, || Error::Api {
status: 429,
message: "rate limited".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_ok());
assert_eq!(count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn retries_on_429_and_succeeds() {
let (mock, count) = FailNTimes::new(2, || Error::Api {
status: 429,
message: "rate limited".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_ok());
assert_eq!(count.load(Ordering::SeqCst), 3); }
#[tokio::test]
async fn retries_on_500_and_succeeds() {
let (mock, count) = FailNTimes::new(1, || Error::Api {
status: 500,
message: "internal server error".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_ok());
assert_eq!(count.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn retries_on_502_and_succeeds() {
let (mock, count) = FailNTimes::new(1, || Error::Api {
status: 502,
message: "bad gateway".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_ok());
assert_eq!(count.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn retries_on_503_and_succeeds() {
let (mock, count) = FailNTimes::new(1, || Error::Api {
status: 503,
message: "service unavailable".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_ok());
assert_eq!(count.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn retries_on_529_and_succeeds() {
let (mock, count) = FailNTimes::new(1, || Error::Api {
status: 529,
message: "overloaded".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_ok());
assert_eq!(count.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn exhausts_retries_and_returns_last_error() {
let (mock, count) = FailNTimes::new(10, || Error::Api {
status: 429,
message: "rate limited".into(),
});
let provider = RetryingProvider::new(mock, fast_config(2));
let result = provider.complete(test_request()).await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(matches!(err, Error::Api { status: 429, .. }));
assert_eq!(count.load(Ordering::SeqCst), 3); }
#[tokio::test]
async fn does_not_retry_400() {
let (mock, count) = FailNTimes::new(5, || Error::Api {
status: 400,
message: "bad request".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_err());
assert_eq!(count.load(Ordering::SeqCst), 1); }
#[tokio::test]
async fn does_not_retry_401() {
let (mock, count) = FailNTimes::new(5, || Error::Api {
status: 401,
message: "unauthorized".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_err());
assert_eq!(count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn does_not_retry_json_parse_error() {
let (mock, count) = FailNTimes::new(5, || {
Error::Json(serde_json::from_str::<()>("invalid").unwrap_err())
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_err());
assert_eq!(count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn zero_retries_means_single_attempt() {
let (mock, count) = FailNTimes::new(1, || Error::Api {
status: 429,
message: "rate limited".into(),
});
let provider = RetryingProvider::new(mock, fast_config(0));
let result = provider.complete(test_request()).await;
assert!(result.is_err());
assert_eq!(count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn stream_complete_retries_on_transient_failure() {
let (mock, count) = FailNTimes::new(2, || Error::Api {
status: 429,
message: "rate limited".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let on_text: &crate::llm::OnText = &|_| {};
let result = provider.stream_complete(test_request(), on_text).await;
assert!(result.is_ok());
assert_eq!(count.load(Ordering::SeqCst), 3); }
#[tokio::test]
async fn stream_complete_does_not_retry_non_retryable() {
let (mock, count) = FailNTimes::new(5, || Error::Api {
status: 400,
message: "bad request".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let on_text: &crate::llm::OnText = &|_| {};
let result = provider.stream_complete(test_request(), on_text).await;
assert!(result.is_err());
assert_eq!(count.load(Ordering::SeqCst), 1); }
#[test]
fn default_config_values() {
let config = RetryConfig::default();
assert_eq!(config.max_retries, 3);
assert_eq!(config.base_delay, Duration::from_millis(500));
assert_eq!(config.max_delay, Duration::from_secs(30));
}
#[test]
fn is_retryable_checks() {
assert!(is_retryable(&Error::Api {
status: 429,
message: "".into()
}));
assert!(is_retryable(&Error::Api {
status: 500,
message: "".into()
}));
assert!(is_retryable(&Error::Api {
status: 502,
message: "".into()
}));
assert!(is_retryable(&Error::Api {
status: 503,
message: "".into()
}));
assert!(is_retryable(&Error::Api {
status: 529,
message: "".into()
}));
assert!(!is_retryable(&Error::Api {
status: 400,
message: "".into()
}));
assert!(!is_retryable(&Error::Api {
status: 401,
message: "".into()
}));
assert!(!is_retryable(&Error::Api {
status: 403,
message: "".into()
}));
assert!(!is_retryable(&Error::Api {
status: 404,
message: "".into()
}));
assert!(!is_retryable(&Error::Agent("test".into())));
assert!(!is_retryable(&Error::Config("test".into())));
assert!(!is_retryable(&Error::Memory("test".into())));
}
#[test]
fn compute_delay_in_jitter_range() {
let config = RetryConfig {
max_retries: 5,
base_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(10),
};
for attempt in 0..4 {
let delay = compute_delay(&config, attempt);
assert!(
delay >= config.base_delay,
"attempt {attempt}: delay {delay:?} below base"
);
assert!(
delay <= config.max_delay,
"attempt {attempt}: delay {delay:?} above max"
);
}
}
#[test]
fn compute_delay_caps_at_max() {
let config = RetryConfig {
max_retries: 10,
base_delay: Duration::from_millis(1000),
max_delay: Duration::from_secs(5),
};
for _ in 0..50 {
let d = compute_delay(&config, 3);
assert!(d <= config.max_delay, "delay {d:?} exceeds max");
let d = compute_delay(&config, 10);
assert!(d <= config.max_delay, "delay {d:?} exceeds max");
}
}
#[test]
fn compute_delay_handles_overflow() {
let config = RetryConfig {
max_retries: 100,
base_delay: Duration::from_secs(1),
max_delay: Duration::from_secs(60),
};
for _ in 0..50 {
let delay = compute_delay(&config, 50);
assert!(delay <= config.max_delay);
}
}
#[tokio::test]
async fn stream_retry_suppresses_on_text_on_retry() {
let text_calls = Arc::new(AtomicU32::new(0));
let text_calls_clone = text_calls.clone();
let on_text_fn = move |_: &str| {
text_calls_clone.fetch_add(1, Ordering::SeqCst);
};
let on_text: &crate::llm::OnText = &on_text_fn;
struct StreamFailOnce {
failed: AtomicU32,
}
impl LlmProvider for StreamFailOnce {
async fn complete(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse, Error> {
Ok(success_response())
}
async fn stream_complete(
&self,
_request: CompletionRequest,
on_text: &crate::llm::OnText,
) -> Result<CompletionResponse, Error> {
on_text("hello");
if self
.failed
.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |v| {
if v == 0 { Some(1) } else { None }
})
.is_ok()
{
return Err(Error::Api {
status: 503,
message: "transient".into(),
});
}
Ok(success_response())
}
}
let provider = RetryingProvider::new(
StreamFailOnce {
failed: AtomicU32::new(0),
},
fast_config(3),
);
let result = provider.stream_complete(test_request(), on_text).await;
assert!(result.is_ok());
assert_eq!(text_calls.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn retrying_provider_fires_on_retry() {
let (mock, _count) = FailNTimes::new(2, || Error::Api {
status: 429,
message: "rate limited".into(),
});
let retries_seen = Arc::new(AtomicU32::new(0));
let retries_clone = retries_seen.clone();
let provider = RetryingProvider::new(mock, fast_config(3)).with_on_retry(Arc::new(
move |attempt, max_retries, _delay_ms, error_class| {
assert!(attempt > 0);
assert_eq!(max_retries, 3);
assert_eq!(error_class, "rate_limited");
retries_clone.fetch_add(1, Ordering::SeqCst);
},
));
let result = provider.complete(test_request()).await;
assert!(result.is_ok());
assert_eq!(retries_seen.load(Ordering::SeqCst), 2); }
#[tokio::test]
async fn retrying_provider_on_retry_none_is_noop() {
let (mock, count) = FailNTimes::new(1, || Error::Api {
status: 500,
message: "server error".into(),
});
let provider = RetryingProvider::new(mock, fast_config(3));
let result = provider.complete(test_request()).await;
assert!(result.is_ok());
assert_eq!(count.load(Ordering::SeqCst), 2);
}
#[test]
fn classify_for_retry_returns_correct_classes() {
assert_eq!(
classify_for_retry(&Error::Api {
status: 429,
message: "".into()
}),
"rate_limited"
);
assert_eq!(
classify_for_retry(&Error::Api {
status: 500,
message: "".into()
}),
"server_error_500"
);
assert_eq!(
classify_for_retry(&Error::Api {
status: 502,
message: "".into()
}),
"server_error_502"
);
assert_eq!(
classify_for_retry(&Error::Api {
status: 503,
message: "".into()
}),
"server_error_503"
);
assert_eq!(
classify_for_retry(&Error::Api {
status: 529,
message: "".into()
}),
"overloaded"
);
assert_eq!(classify_for_retry(&Error::Agent("other".into())), "unknown");
}
#[test]
fn model_name_forwards_to_inner() {
struct NamedProvider;
impl LlmProvider for NamedProvider {
async fn complete(
&self,
_request: CompletionRequest,
) -> Result<CompletionResponse, Error> {
unimplemented!()
}
fn model_name(&self) -> Option<&str> {
Some("my-model")
}
}
let provider = RetryingProvider::with_defaults(NamedProvider);
assert_eq!(provider.model_name(), Some("my-model"));
}
}