use crate::error::{AgentError, LlmFailureReason};
#[cfg(test)]
use crate::error::{LlmProviderError, LlmProviderErrorKind};
use serde::{Deserialize, Serialize};
use std::time::Duration;
/// Classification of a retryable LLM failure.
///
/// Serialized in `snake_case` (e.g. `"rate_limited"`). Each variant is produced by
/// [`LlmRetryFailure::from_agent_error`] from the matching `LlmFailureReason`.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum LlmRetryFailureKind {
    /// Mapped from `LlmFailureReason::RateLimited`.
    RateLimited,
    /// Mapped from `LlmFailureReason::NetworkTimeout`.
    NetworkTimeout,
    /// Mapped from `LlmFailureReason::CallTimeout`.
    CallTimeout,
    /// Mapped from `LlmFailureReason::ProviderError` when the typed provider error
    /// reports itself as retryable (`is_retryable()`).
    RetryableProviderError,
}
/// A retryable LLM failure extracted from an [`AgentError`], in a serializable form.
///
/// Built by [`LlmRetryFailure::from_agent_error`], which only produces a value for
/// retryable failure reasons.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmRetryFailure {
    /// Provider name copied from the `AgentError::Llm` error.
    pub provider: String,
    /// Which retryable failure class this is.
    pub kind: LlmRetryFailureKind,
    /// Retry-after hint in milliseconds. Set only for `RateLimited` failures whose
    /// error carried a hint; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_after_ms: Option<u64>,
    /// Timeout length in milliseconds. Set only for `NetworkTimeout` and `CallTimeout`
    /// failures; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>,
    /// Message copied from the originating error.
    pub message: String,
}
impl LlmRetryFailure {
    /// Converts an [`AgentError`] into a retry failure record, if it is retryable.
    ///
    /// Returns `None` for every error that is not `AgentError::Llm`, and for LLM errors
    /// whose reason is not retryable. Retryable reasons are rate limits, network
    /// timeouts, call timeouts, and provider errors whose typed `is_retryable()` is true.
    pub fn from_agent_error(error: &AgentError) -> Option<Self> {
        let (provider, reason, message) = match error {
            AgentError::Llm {
                provider,
                reason,
                message,
            } => (provider, reason, message),
            _ => return None,
        };

        // Classify first, then build the record once.
        let (kind, retry_after_ms, duration_ms) = match reason {
            LlmFailureReason::RateLimited { retry_after } => (
                LlmRetryFailureKind::RateLimited,
                retry_after.map(duration_millis_u64),
                None,
            ),
            LlmFailureReason::NetworkTimeout { duration_ms } => {
                (LlmRetryFailureKind::NetworkTimeout, None, Some(*duration_ms))
            }
            LlmFailureReason::CallTimeout { duration_ms } => {
                (LlmRetryFailureKind::CallTimeout, None, Some(*duration_ms))
            }
            // Only the typed retryability flag counts; the JSON payload is not consulted here.
            LlmFailureReason::ProviderError(provider_error) if provider_error.is_retryable() => {
                (LlmRetryFailureKind::RetryableProviderError, None, None)
            }
            _ => return None,
        };

        Some(Self {
            provider: (*provider).to_string(),
            kind,
            retry_after_ms,
            duration_ms,
            message: message.clone(),
        })
    }
}
/// The delay decision for one scheduled retry, as computed by
/// [`RetryPolicy::schedule_retry`]. All durations are in milliseconds.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmRetryPlan {
    /// The retry number being planned: the caller's `attempt_index + 1` (saturating).
    pub attempt: u32,
    /// The policy's `max_retries` at scheduling time.
    pub max_retries: u32,
    /// The jittered exponential backoff delay (`RetryPolicy::delay_for_attempt`).
    pub computed_delay_ms: u64,
    /// The delay to actually wait, after applying any hint or rate-limit floor and the
    /// remaining-budget cap.
    pub selected_delay_ms: u64,
    /// The error's retry-after hint, if any; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_after_hint_ms: Option<u64>,
    /// `true` when the 30 s rate-limit floor raised the delay (see `select_retry_delay`).
    pub rate_limit_floor_applied: bool,
    /// `true` when the remaining budget shortened the selected delay.
    pub budget_capped: bool,
}
impl LlmRetryPlan {
    /// Returns `selected_delay_ms` as a [`Duration`]: the time to wait before the retry.
    pub fn selected_delay(&self) -> Duration {
        let millis = self.selected_delay_ms;
        Duration::from_millis(millis)
    }
}
/// A retry decision returned by [`RetryPolicy::schedule_retry`]: the classified failure
/// together with the delay plan for the next attempt.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LlmRetrySchedule {
    /// The retryable failure that triggered this retry.
    pub failure: LlmRetryFailure,
    /// When and why the retry will happen.
    pub plan: LlmRetryPlan,
}
/// Exponential-backoff retry configuration for LLM calls.
///
/// `Duration` fields serialize with serde's default `{ "secs": .., "nanos": .. }` shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    /// Retries are permitted while the attempt index is below this value
    /// (see `RetryPolicy::should_retry`).
    pub max_retries: u32,
    /// Backoff delay for the first retry, before jitter.
    pub initial_delay: Duration,
    /// Upper bound applied to every backoff delay.
    pub max_delay: Duration,
    /// Growth factor applied to the delay for each subsequent retry.
    pub multiplier: f64,
    /// Optional per-call timeout. Defaults to `None` when absent from serialized input
    /// and is omitted from output when `None`.
    /// NOTE(review): not read anywhere in this module; presumably enforced by the caller.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub call_timeout: Option<Duration>,
}
impl Default for RetryPolicy {
    /// Three retries, starting at 500 ms and doubling per retry up to a 30 s ceiling,
    /// with no per-call timeout.
    fn default() -> Self {
        let initial_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(30);
        Self {
            initial_delay,
            max_delay,
            call_timeout: None,
            multiplier: 2.0,
            max_retries: 3,
        }
    }
}
impl RetryPolicy {
    /// Creates a policy with the default settings (same as [`RetryPolicy::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a policy that never retries: `max_retries` is 0 and there is no
    /// per-call timeout.
    ///
    /// The backoff fields keep their defaults. They do not affect scheduling, because
    /// [`RetryPolicy::should_retry`] is always `false` when `max_retries` is 0.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            call_timeout: None,
            ..Default::default()
        }
    }

    /// Sets `max_retries`, the number of retries [`RetryPolicy::should_retry`] permits.
    pub fn with_max_retries(mut self, max: u32) -> Self {
        self.max_retries = max;
        self
    }

    /// Sets the backoff delay used for the first retry (before jitter).
    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Sets the upper bound applied to every backoff delay.
    pub fn with_max_delay(mut self, delay: Duration) -> Self {
        self.max_delay = delay;
        self
    }

    /// Sets the factor by which the backoff delay grows between consecutive retries.
    ///
    /// Any value is accepted. [`RetryPolicy::delay_for_attempt`] clamps results that would
    /// be negative to zero and results that are non-finite to `max_delay`.
    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
        self.multiplier = multiplier;
        self
    }

    /// Sets, or clears with `None`, the per-call timeout.
    pub fn with_call_timeout(mut self, timeout: Option<Duration>) -> Self {
        self.call_timeout = timeout;
        self
    }

    /// Returns the jittered exponential backoff delay for retry number `attempt`.
    ///
    /// `attempt == 0` yields [`Duration::ZERO`]. Otherwise the delay is
    /// `initial_delay * multiplier^(attempt - 1)`, scaled by a jitter factor in
    /// `[0.9, 1.1]` and clamped to `[0, max_delay]`.
    ///
    /// This never panics. A NaN, infinite, or oversized product resolves to exactly
    /// `max_delay`, and a negative product (possible with a negative `multiplier`)
    /// resolves to [`Duration::ZERO`].
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }
        let exponent = i32::try_from(attempt - 1).unwrap_or(i32::MAX);
        let base_delay = self.initial_delay.as_secs_f64() * self.multiplier.powi(exponent);
        // Jitter factor in [0.9, 1.1] so concurrent callers do not retry in lockstep.
        let jitter = 1.0 + (rand_jitter() * 0.2 - 0.1);
        let delay_secs = base_delay * jitter;

        let max_secs = self.max_delay.as_secs_f64();
        // Return `max_delay` itself rather than round-tripping through f64.
        // `as_secs_f64()` of a duration near `Duration::MAX` rounds up past the
        // representable range, and `Duration::from_secs_f64` would panic on it.
        if delay_secs.is_nan() || delay_secs >= max_secs {
            return self.max_delay;
        }
        // `Duration::from_secs_f64` panics on negative input.
        if delay_secs <= 0.0 {
            return Duration::ZERO;
        }
        Duration::from_secs_f64(delay_secs)
    }

    /// Returns `true` when `attempt` is below `max_retries`.
    ///
    /// [`RetryPolicy::schedule_retry`] passes its `attempt_index` here.
    pub fn should_retry(&self, attempt: u32) -> bool {
        attempt < self.max_retries
    }

    /// Plans the retry that follows a failed LLM call, or returns `None` if no retry
    /// should happen.
    ///
    /// Returns `None` in two cases:
    /// - [`RetryPolicy::should_retry`] rejects `attempt_index`.
    /// - `error` is not a retryable LLM failure (see [`LlmRetryFailure::from_agent_error`]).
    ///
    /// The planned retry number is `attempt_index + 1` (saturating). It is used for the
    /// backoff computation.
    ///
    /// The delay is chosen by [`select_retry_delay`] from the error's retry-after hint,
    /// the backoff delay, and whether the error is a rate limit. When
    /// `remaining_budget` is given, the delay is then capped to it.
    pub fn schedule_retry(
        &self,
        error: &AgentError,
        attempt_index: u32,
        remaining_budget: Option<Duration>,
    ) -> Option<LlmRetrySchedule> {
        if !self.should_retry(attempt_index) {
            return None;
        }
        let failure = LlmRetryFailure::from_agent_error(error)?;
        let attempt = attempt_index.saturating_add(1);
        let hint = error.retry_after_hint();
        let computed = self.delay_for_attempt(attempt);
        let (selected, rate_limit_floor_applied) =
            select_retry_delay(hint, computed, error.is_rate_limited());
        let capped = match remaining_budget {
            Some(remaining) => selected.min(remaining),
            None => selected,
        };
        Some(LlmRetrySchedule {
            failure,
            plan: LlmRetryPlan {
                attempt,
                max_retries: self.max_retries,
                computed_delay_ms: duration_millis_u64(computed),
                selected_delay_ms: duration_millis_u64(capped),
                retry_after_hint_ms: hint.map(duration_millis_u64),
                rate_limit_floor_applied,
                budget_capped: capped < selected,
            },
        })
    }
}
/// Chooses the retry delay from a server hint, the computed backoff, and the
/// rate-limit flag.
///
/// Returns `(delay, floor_applied)`:
/// - A hint strictly longer than `computed` wins, with `floor_applied == false`.
/// - Otherwise, for rate-limited errors the delay is at least 30 s, and
///   `floor_applied` is `true` exactly when the floor raised `computed`.
/// - Otherwise `computed` is returned unchanged.
///
/// NOTE(review): a hint that is present but not longer than `computed` is ignored. So a
/// rate-limited error with a very short hint gets the 30 s floor, while a hint slightly
/// longer than `computed` is honored as-is. Confirm this asymmetry is intended.
pub fn select_retry_delay(
    hint: Option<Duration>,
    computed: Duration,
    is_rate_limited: bool,
) -> (Duration, bool) {
    const RATE_LIMIT_FLOOR: Duration = Duration::from_secs(30);

    if let Some(server_hint) = hint.filter(|h| *h > computed) {
        return (server_hint, false);
    }
    if !is_rate_limited {
        return (computed, false);
    }
    let floor_applied = computed < RATE_LIMIT_FLOOR;
    let delay = if floor_applied { RATE_LIMIT_FLOOR } else { computed };
    (delay, floor_applied)
}
/// Converts `duration` to whole milliseconds, truncating sub-millisecond precision.
/// Values that do not fit in `u64` saturate to `u64::MAX`.
fn duration_millis_u64(duration: Duration) -> u64 {
    match u64::try_from(duration.as_millis()) {
        Ok(millis) => millis,
        Err(_) => u64::MAX,
    }
}
fn rand_jitter() -> f64 {
use crate::time_compat::SystemTime;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
crate::time_compat::SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
.hash(&mut hasher);
let hash = hasher.finish();
(hash as f64) / (u64::MAX as f64)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    #[test]
    fn test_retry_policy_default() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_retries, 3);
        assert_eq!(policy.initial_delay, Duration::from_millis(500));
        assert_eq!(policy.max_delay, Duration::from_secs(30));
        assert_eq!(policy.multiplier, 2.0);
        assert_eq!(policy.call_timeout, None);
    }

    #[test]
    fn test_retry_policy_no_retry() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_retries, 0);
        assert_eq!(policy.call_timeout, None);
        assert!(!policy.should_retry(0));
    }

    #[test]
    fn test_call_timeout_builder() {
        let policy = RetryPolicy::default().with_call_timeout(Some(Duration::from_secs(45)));
        assert_eq!(policy.call_timeout, Some(Duration::from_secs(45)));
        let policy = policy.with_call_timeout(None);
        assert_eq!(policy.call_timeout, None);
    }

    #[test]
    fn test_delay_calculation() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.delay_for_attempt(0), Duration::ZERO);
        // Default policy: 500 ms * 2^(attempt - 1), scaled by a jitter factor in
        // [0.9, 1.1]. The bounds leave a little slack for float rounding.
        let delay1 = policy.delay_for_attempt(1).as_millis();
        let delay2 = policy.delay_for_attempt(2).as_millis();
        let delay3 = policy.delay_for_attempt(3).as_millis();
        assert!((440..=560).contains(&delay1));
        assert!((880..=1120).contains(&delay2));
        assert!((1760..=2240).contains(&delay3));
        // The jitter bands do not overlap, so growth is strictly monotonic.
        assert!(delay1 < delay2 && delay2 < delay3);
    }

    #[test]
    fn test_max_delay_cap() {
        let policy = RetryPolicy::default()
            .with_initial_delay(Duration::from_secs(10))
            .with_max_delay(Duration::from_secs(15));
        // 10 s * 2^9 is far above the cap even with -10% jitter, so the cap applies exactly.
        assert_eq!(policy.delay_for_attempt(10), Duration::from_secs(15));
    }

    #[test]
    fn test_should_retry() {
        let policy = RetryPolicy::default().with_max_retries(3);
        assert!(policy.should_retry(0));
        assert!(policy.should_retry(1));
        assert!(policy.should_retry(2));
        assert!(!policy.should_retry(3));
        assert!(!policy.should_retry(4));
    }

    #[test]
    fn select_retry_delay_prefers_longer_hint() {
        assert_eq!(
            select_retry_delay(Some(Duration::from_secs(60)), Duration::from_secs(1), true),
            (Duration::from_secs(60), false)
        );
    }

    #[test]
    fn select_retry_delay_applies_rate_limit_floor() {
        assert_eq!(
            select_retry_delay(None, Duration::from_secs(1), true),
            (Duration::from_secs(30), true)
        );
        assert_eq!(
            select_retry_delay(None, Duration::from_secs(45), true),
            (Duration::from_secs(45), false)
        );
    }

    #[test]
    fn select_retry_delay_uses_computed_when_not_rate_limited() {
        assert_eq!(
            select_retry_delay(None, Duration::from_secs(2), false),
            (Duration::from_secs(2), false)
        );
        assert_eq!(
            select_retry_delay(Some(Duration::from_secs(1)), Duration::from_secs(2), false),
            (Duration::from_secs(2), false)
        );
    }

    #[test]
    fn selected_delay_converts_millis_to_duration() {
        let plan = LlmRetryPlan {
            attempt: 1,
            max_retries: 3,
            computed_delay_ms: 500,
            selected_delay_ms: 1_250,
            retry_after_hint_ms: None,
            rate_limit_floor_applied: false,
            budget_capped: false,
        };
        assert_eq!(plan.selected_delay(), Duration::from_millis(1_250));
    }

    #[test]
    fn retry_schedule_carries_typed_failure_and_delay_plan() {
        let policy = RetryPolicy::default().with_max_retries(3);
        let error = AgentError::Llm {
            provider: "test",
            reason: LlmFailureReason::RateLimited {
                retry_after: Some(Duration::from_secs(60)),
            },
            message: "rate limited".to_string(),
        };
        let schedule = policy
            .schedule_retry(&error, 0, Some(Duration::from_secs(45)))
            .expect("rate limit should be retryable");
        assert_eq!(schedule.failure.kind, LlmRetryFailureKind::RateLimited);
        assert_eq!(schedule.failure.retry_after_ms, Some(60_000));
        assert_eq!(schedule.plan.attempt, 1);
        assert_eq!(schedule.plan.max_retries, 3);
        assert_eq!(schedule.plan.retry_after_hint_ms, Some(60_000));
        assert_eq!(schedule.plan.selected_delay_ms, 45_000);
        assert!(schedule.plan.budget_capped);
    }

    #[test]
    fn retry_schedule_stops_when_attempts_are_exhausted() {
        let policy = RetryPolicy::default().with_max_retries(3);
        let error = AgentError::Llm {
            provider: "test",
            reason: LlmFailureReason::RateLimited { retry_after: None },
            message: "rate limited".to_string(),
        };
        assert!(policy.schedule_retry(&error, 3, None).is_none());
    }

    #[test]
    fn retry_schedule_rejects_non_retryable_errors() {
        let policy = RetryPolicy::default().with_max_retries(3);
        let error = AgentError::Llm {
            provider: "test",
            reason: LlmFailureReason::AuthError,
            message: "auth".to_string(),
        };
        assert!(policy.schedule_retry(&error, 0, None).is_none());
    }

    #[test]
    fn retry_schedule_reads_typed_provider_retryability() {
        let policy = RetryPolicy::default().with_max_retries(3);
        let error = AgentError::Llm {
            provider: "test",
            reason: LlmFailureReason::ProviderError(LlmProviderError::retryable(
                LlmProviderErrorKind::ServerOverloaded,
                serde_json::json!({
                    "retryable": false,
                    "message": "json payload must not suppress typed retryability"
                }),
            )),
            message: "provider overloaded".to_string(),
        };
        let schedule = policy
            .schedule_retry(&error, 0, None)
            .expect("typed retryable provider error should be scheduled");
        assert_eq!(
            schedule.failure.kind,
            LlmRetryFailureKind::RetryableProviderError
        );
    }

    #[test]
    fn retry_schedule_ignores_json_only_provider_retryability() {
        let policy = RetryPolicy::default().with_max_retries(3);
        let error = AgentError::Llm {
            provider: "test",
            reason: LlmFailureReason::ProviderError(LlmProviderError::non_retryable(
                LlmProviderErrorKind::InvalidRequest,
                serde_json::json!({
                    "retryable": true,
                    "message": "json payload must not admit retries"
                }),
            )),
            message: "invalid request".to_string(),
        };
        assert!(policy.schedule_retry(&error, 0, None).is_none());
    }

    #[test]
    fn test_retry_policy_serialization() {
        let policy = RetryPolicy::default();
        let json = serde_json::to_string(&policy).unwrap();
        let parsed: RetryPolicy = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.max_retries, policy.max_retries);
        assert_eq!(parsed.initial_delay, policy.initial_delay);
        assert_eq!(parsed.max_delay, policy.max_delay);
        assert_eq!(parsed.multiplier, policy.multiplier);
        assert_eq!(parsed.call_timeout, None);
    }

    #[test]
    fn test_retry_policy_serialization_with_call_timeout() {
        let policy = RetryPolicy::default().with_call_timeout(Some(Duration::from_secs(60)));
        let json = serde_json::to_string(&policy).unwrap();
        let parsed: RetryPolicy = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.call_timeout, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_retry_policy_deserialization_missing_call_timeout() {
        let json = r#"{"max_retries":3,"initial_delay":{"secs":0,"nanos":500000000},"max_delay":{"secs":30,"nanos":0},"multiplier":2.0}"#;
        let parsed: RetryPolicy = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.call_timeout, None);
    }
}