use std::future::Future;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use rand::Rng;
use rust_decimal::Decimal;
use crate::llm::error::LlmError;
use crate::llm::provider::{
CompletionRequest, CompletionResponse, LlmProvider, ModelMetadata, ToolCompletionRequest,
ToolCompletionResponse,
};
/// Upper bound (one hour) applied to any server-supplied retry delay,
/// e.g. a `Retry-After` header, so a misbehaving provider cannot stall us.
pub(crate) const MAX_RETRY_AFTER_SECS: u64 = 3600;
/// Classifies an [`LlmError`] as transient (worth retrying) or permanent.
///
/// Request/network/IO failures, rate limits, malformed responses, and
/// session-renewal hiccups are considered transient; everything else
/// (auth failures, context overflow, unknown models, ...) is returned to
/// the caller immediately.
pub(crate) fn is_retryable(err: &LlmError) -> bool {
    match err {
        LlmError::RequestFailed { .. }
        | LlmError::RateLimited { .. }
        | LlmError::InvalidResponse { .. }
        | LlmError::SessionRenewalFailed { .. }
        | LlmError::Http(_)
        | LlmError::Io(_) => true,
        _ => false,
    }
}
/// Computes the sleep before retry number `attempt` (0-based):
/// exponential backoff (1s, 2s, 4s, ...) with ±25% jitter and a 100ms floor.
pub(crate) fn retry_backoff_delay(attempt: u32) -> Duration {
    // Ceiling on the exponential base; mirrors MAX_RETRY_AFTER_SECS (1h),
    // in milliseconds. Without it, large `attempt` values saturate
    // `base_ms` toward u64::MAX, and the `base_ms as i64` conversion below
    // wraps negative (u64::MAX as i64 == -1), collapsing huge backoffs to
    // the 100ms floor — or, with positive jitter, producing absurd delays.
    const BACKOFF_CEILING_MS: u64 = 3_600_000;
    let base_ms: u64 = 1000u64
        .saturating_mul(2u64.saturating_pow(attempt))
        .min(BACKOFF_CEILING_MS);
    // ±25% jitter so concurrent clients don't retry in lockstep.
    let jitter_range = base_ms / 4;
    let jitter = if jitter_range > 0 {
        let offset = rand::thread_rng().gen_range(0..=jitter_range * 2);
        offset as i64 - jitter_range as i64
    } else {
        0
    };
    // Never sleep less than 100ms, even if jitter pushes the delay down.
    let delay_ms = (base_ms as i64 + jitter).max(100) as u64;
    Duration::from_millis(delay_ms)
}
/// Clamps a server-requested retry delay to [`MAX_RETRY_AFTER_SECS`].
pub(crate) fn cap_retry_after(duration: Duration) -> Duration {
    let ceiling = Duration::from_secs(MAX_RETRY_AFTER_SECS);
    if duration > ceiling {
        ceiling
    } else {
        duration
    }
}
/// Parses an HTTP `Retry-After` header value into a bounded [`Duration`].
///
/// Accepts both forms from RFC 7231: delta-seconds (`"120"`) and an
/// HTTP-date (`"Fri, 31 Dec 1999 23:59:59 GMT"`, parsed as RFC 2822).
/// Parsed values are clamped via [`cap_retry_after`]; a missing,
/// non-ASCII, or unparseable header falls back to
/// [`DEFAULT_RETRY_AFTER_SECS`]. HTTP-dates in the past yield zero.
pub(crate) fn parse_retry_after(header: Option<&reqwest::header::HeaderValue>) -> Duration {
    let fallback = Duration::from_secs(DEFAULT_RETRY_AFTER_SECS);
    let raw = match header.and_then(|value| value.to_str().ok()) {
        Some(text) => text.trim(),
        None => return fallback,
    };
    if let Ok(secs) = raw.parse::<u64>() {
        return cap_retry_after(Duration::from_secs(secs));
    }
    match chrono::DateTime::parse_from_rfc2822(raw) {
        Ok(date) => {
            let remaining = date.signed_duration_since(chrono::Utc::now());
            cap_retry_after(Duration::from_secs(remaining.num_seconds().max(0) as u64))
        }
        Err(_) => fallback,
    }
}
/// Fallback wait when a `Retry-After` header is missing or unparseable.
const DEFAULT_RETRY_AFTER_SECS: u64 = 60;
/// Tuning knobs for [`RetryProvider`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries after the initial attempt, so up to
    /// `max_retries + 1` calls are made in the worst case.
    pub max_retries: u32,
}
impl Default for RetryConfig {
    fn default() -> Self {
        Self { max_retries: 3 }
    }
}
/// Decorator around an [`LlmProvider`] that transparently retries
/// transient failures (see [`is_retryable`]) with exponential backoff.
/// All non-completion methods delegate straight to the wrapped provider.
pub struct RetryProvider {
    // The wrapped provider that performs the actual LLM calls.
    inner: Arc<dyn LlmProvider>,
    // Retry policy (currently just the retry count).
    config: RetryConfig,
}
impl RetryProvider {
pub fn new(inner: Arc<dyn LlmProvider>, config: RetryConfig) -> Self {
Self { inner, config }
}
async fn retry_loop<T, F, Fut>(&self, mut op: F, label: &str) -> Result<T, LlmError>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, LlmError>>,
{
let mut last_error: Option<LlmError> = None;
for attempt in 0..=self.config.max_retries {
match op().await {
Ok(resp) => return Ok(resp),
Err(err) => {
if !is_retryable(&err) || attempt == self.config.max_retries {
return Err(err);
}
let delay = match &err {
LlmError::RateLimited {
retry_after: Some(duration),
..
} => *duration,
_ => retry_backoff_delay(attempt),
};
tracing::warn!(
provider = %self.inner.model_name(),
attempt = attempt + 1,
max_retries = self.config.max_retries,
delay_ms = delay.as_millis() as u64,
error = %err,
"Retrying after transient error{label}"
);
last_error = Some(err);
tokio::time::sleep(delay).await;
}
}
}
Err(last_error.unwrap_or_else(|| LlmError::RequestFailed {
provider: self.inner.model_name().to_string(),
reason: "retry loop exited unexpectedly".to_string(),
}))
}
}
#[async_trait]
impl LlmProvider for RetryProvider {
    // Metadata, pricing, and model-management methods delegate straight to
    // the wrapped provider; only `complete` and `complete_with_tools` gain
    // retry behavior.
    fn model_name(&self) -> &str {
        self.inner.model_name()
    }
    fn cost_per_token(&self) -> (Decimal, Decimal) {
        self.inner.cost_per_token()
    }
    fn cache_write_multiplier(&self) -> Decimal {
        self.inner.cache_write_multiplier()
    }
    fn cache_read_discount(&self) -> Decimal {
        self.inner.cache_read_discount()
    }
    /// Completion with retries on transient errors (see `retry_loop`).
    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, LlmError> {
        let inner = &self.inner;
        self.retry_loop(
            || {
                // Clone per attempt: each retry needs a fresh owned request.
                let req = request.clone();
                async move { inner.complete(req).await }
            },
            "",
        )
        .await
    }
    /// Tool-calling completion with the same retry policy as `complete`;
    /// the " (tools)" label only affects the retry log line.
    async fn complete_with_tools(
        &self,
        request: ToolCompletionRequest,
    ) -> Result<ToolCompletionResponse, LlmError> {
        let inner = &self.inner;
        self.retry_loop(
            || {
                let req = request.clone();
                async move { inner.complete_with_tools(req).await }
            },
            " (tools)",
        )
        .await
    }
    // NOTE(review): list_models / model_metadata are not retried — presumably
    // intentional since they are infrequent; confirm if they should be.
    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
        self.inner.list_models().await
    }
    async fn model_metadata(&self) -> Result<ModelMetadata, LlmError> {
        self.inner.model_metadata().await
    }
    fn effective_model_name(&self, requested_model: Option<&str>) -> String {
        self.inner.effective_model_name(requested_model)
    }
    fn active_model_name(&self) -> String {
        self.inner.active_model_name()
    }
    fn set_model(&self, model: &str) -> Result<(), LlmError> {
        self.inner.set_model(model)
    }
    fn calculate_cost(&self, input_tokens: u32, output_tokens: u32) -> Decimal {
        self.inner.calculate_cost(input_tokens, output_tokens)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::testing::StubLlm;
    // Helpers to build minimal requests/configs for the retry tests.
    fn make_request() -> CompletionRequest {
        CompletionRequest::new(vec![crate::llm::ChatMessage::user("hello")])
    }
    fn make_tool_request() -> ToolCompletionRequest {
        ToolCompletionRequest::new(vec![crate::llm::ChatMessage::user("hello")], vec![])
    }
    fn fast_config(max_retries: u32) -> RetryConfig {
        RetryConfig { max_retries }
    }
    // Backoff is randomized (jitter), so sample repeatedly and assert the
    // ±25% windows around the 1s/2s/4s exponential bases.
    #[test]
    fn test_retry_backoff_delay_exponential_growth() {
        for _ in 0..20 {
            let d0 = retry_backoff_delay(0);
            let d1 = retry_backoff_delay(1);
            let d2 = retry_backoff_delay(2);
            assert!(d0.as_millis() >= 750, "attempt 0 too low: {:?}", d0);
            assert!(d0.as_millis() <= 1250, "attempt 0 too high: {:?}", d0);
            assert!(d1.as_millis() >= 1500, "attempt 1 too low: {:?}", d1);
            assert!(d1.as_millis() <= 2500, "attempt 1 too high: {:?}", d1);
            assert!(d2.as_millis() >= 3000, "attempt 2 too low: {:?}", d2);
            assert!(d2.as_millis() <= 5000, "attempt 2 too high: {:?}", d2);
        }
    }
    #[test]
    fn test_retry_backoff_delay_minimum() {
        for _ in 0..20 {
            let delay = retry_backoff_delay(0);
            assert!(delay.as_millis() >= 100);
        }
    }
    // Large attempt numbers must not overflow the backoff arithmetic.
    #[test]
    fn test_retry_backoff_delay_no_overflow() {
        let delay = retry_backoff_delay(30);
        assert!(delay.as_millis() >= 100);
    }
    // One positive and one negative case per interesting LlmError variant.
    #[test]
    fn test_is_retryable_classification() {
        assert!(is_retryable(&LlmError::RequestFailed {
            provider: "p".into(),
            reason: "err".into(),
        }));
        assert!(is_retryable(&LlmError::RateLimited {
            provider: "p".into(),
            retry_after: None,
        }));
        assert!(is_retryable(&LlmError::InvalidResponse {
            provider: "p".into(),
            reason: "bad".into(),
        }));
        assert!(is_retryable(&LlmError::SessionRenewalFailed {
            provider: "p".into(),
            reason: "timeout".into(),
        }));
        assert!(is_retryable(&LlmError::Io(std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "reset"
        ))));
        assert!(!is_retryable(&LlmError::AuthFailed {
            provider: "p".into(),
        }));
        assert!(!is_retryable(&LlmError::SessionExpired {
            provider: "p".into(),
        }));
        assert!(!is_retryable(&LlmError::ContextLengthExceeded {
            used: 100_000,
            limit: 50_000,
        }));
        assert!(!is_retryable(&LlmError::ModelNotAvailable {
            provider: "p".into(),
            model: "m".into(),
        }));
    }
    #[tokio::test]
    async fn success_on_first_attempt() {
        let stub = Arc::new(StubLlm::new("ok").with_model_name("test"));
        let retry = RetryProvider::new(stub.clone(), fast_config(3));
        let resp = retry.complete(make_request()).await;
        assert!(resp.is_ok());
        assert_eq!(resp.unwrap().content, "ok");
        assert_eq!(stub.calls(), 1);
    }
    // Timing-sensitive: the stub stops failing after 1.5s, so with backoff
    // delays of ~1s then ~3s the retry loop should succeed on a later
    // attempt (at least two calls total) within fast_config(2)'s budget.
    #[tokio::test]
    async fn retries_transient_errors_then_succeeds() {
        let stub = Arc::new(StubLlm::failing("test"));
        let retry = RetryProvider::new(stub.clone(), fast_config(2));
        let stub_clone = stub.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(1500)).await;
            stub_clone.set_failing(false);
        });
        let resp = retry.complete(make_request()).await;
        assert!(resp.is_ok());
        assert!(stub.calls() >= 2);
    }
    #[tokio::test]
    async fn non_transient_error_fails_immediately() {
        let stub = Arc::new(StubLlm::failing_non_transient("test"));
        let retry = RetryProvider::new(stub.clone(), fast_config(3));
        let err = retry.complete(make_request()).await.unwrap_err();
        assert!(matches!(err, LlmError::ContextLengthExceeded { .. }));
        assert_eq!(stub.calls(), 1);
    }
    // max_retries == 0 means exactly one attempt, no sleeping.
    #[tokio::test]
    async fn exhausts_retries_then_returns_error() {
        let stub = Arc::new(StubLlm::failing("test"));
        let retry = RetryProvider::new(stub.clone(), fast_config(0));
        let err = retry.complete(make_request()).await.unwrap_err();
        assert!(matches!(err, LlmError::RequestFailed { .. }));
        assert_eq!(stub.calls(), 1);
    }
    #[tokio::test]
    async fn complete_with_tools_retries_same_as_complete() {
        let stub = Arc::new(StubLlm::failing_non_transient("test"));
        let retry = RetryProvider::new(stub.clone(), fast_config(3));
        let err = retry
            .complete_with_tools(make_tool_request())
            .await
            .unwrap_err();
        assert!(matches!(err, LlmError::ContextLengthExceeded { .. }));
        assert_eq!(stub.calls(), 1);
    }
    #[tokio::test]
    async fn passthrough_methods_delegate_to_inner() {
        let stub = Arc::new(StubLlm::new("ok").with_model_name("my-model"));
        let retry = RetryProvider::new(stub, fast_config(3));
        assert_eq!(retry.model_name(), "my-model");
        assert_eq!(retry.active_model_name(), "my-model");
        assert_eq!(retry.cost_per_token(), (Decimal::ZERO, Decimal::ZERO));
        assert_eq!(retry.calculate_cost(100, 50), Decimal::ZERO);
    }
    // NOTE(review): this test only asserts on the value it just constructed,
    // so it cannot fail for any implementation reason — it exercises no
    // production code path. Consider asserting on the error a real
    // rate-limited provider produces instead, or removing the test.
    #[test]
    fn rate_limited_error_always_has_duration() {
        let err = LlmError::RateLimited {
            provider: "test".to_string(),
            retry_after: Some(std::time::Duration::from_secs(60)),
        };
        if let LlmError::RateLimited { retry_after, .. } = err {
            assert!(
                retry_after.is_some(),
                "Rate limited error should always have retry_after duration"
            );
            assert_eq!(
                retry_after,
                Some(std::time::Duration::from_secs(60)),
                "Fallback should be 60 seconds"
            );
        } else {
            panic!("Expected RateLimited error");
        }
    }
    #[test]
    fn cap_retry_after_clamps_huge_delays() {
        assert_eq!(
            cap_retry_after(Duration::from_secs(u64::MAX)),
            Duration::from_secs(MAX_RETRY_AFTER_SECS)
        );
        assert_eq!(
            cap_retry_after(Duration::from_secs(0)),
            Duration::from_secs(0)
        );
    }
    #[test]
    fn parse_retry_after_delay_seconds() {
        let val = reqwest::header::HeaderValue::from_static("30");
        assert_eq!(parse_retry_after(Some(&val)), Duration::from_secs(30));
    }
    #[test]
    fn parse_retry_after_missing_header() {
        assert_eq!(
            parse_retry_after(None),
            Duration::from_secs(DEFAULT_RETRY_AFTER_SECS)
        );
    }
    #[test]
    fn parse_retry_after_unparseable() {
        let val = reqwest::header::HeaderValue::from_static("not-a-number");
        assert_eq!(
            parse_retry_after(Some(&val)),
            Duration::from_secs(DEFAULT_RETRY_AFTER_SECS)
        );
    }
    #[test]
    fn parse_retry_after_clamps_large_value() {
        let val = reqwest::header::HeaderValue::from_static("999999");
        assert_eq!(
            parse_retry_after(Some(&val)),
            Duration::from_secs(MAX_RETRY_AFTER_SECS)
        );
    }
    // HTTP-date form: allow a ±2s tolerance since the test computes "now"
    // slightly before parse_retry_after does.
    #[test]
    fn parse_retry_after_http_date() {
        let future = chrono::Utc::now() + chrono::Duration::seconds(30);
        let date_str = future.to_rfc2822();
        let val = reqwest::header::HeaderValue::from_str(&date_str).unwrap();
        let parsed = parse_retry_after(Some(&val));
        let diff = if parsed > Duration::from_secs(30) {
            parsed - Duration::from_secs(30)
        } else {
            Duration::from_secs(30) - parsed
        };
        assert!(
            diff <= Duration::from_secs(2),
            "expected ~30s, got {parsed:?} (diff {diff:?}) from header {date_str:?}"
        );
    }
}