use std::time::Duration;
use crate::error::{AgentError, OperationError};
/// Initial backoff that `RetryPolicy::new` installs (200 ms).
const DEFAULT_INITIAL_BACKOFF: Duration = Duration::from_millis(200);
/// Delay ceiling that `RetryPolicy::new` installs (30 s).
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(30);
/// Per-attempt growth factor that `RetryPolicy::new` installs.
const DEFAULT_MULTIPLIER: f64 = 2.0;
/// Exponential-backoff retry configuration.
///
/// Create one with [`RetryPolicy::new`] and refine it with the chained
/// setters [`backoff`](RetryPolicy::backoff),
/// [`max_backoff`](RetryPolicy::max_backoff) and
/// [`multiplier`](RetryPolicy::multiplier).
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Retry budget passed to [`RetryPolicy::new`]; non-zero when built via `new`.
    pub(crate) max_retries: u32,
    /// Delay produced for attempt 0, before any growth is applied.
    pub(crate) initial_backoff: Duration,
    /// Ceiling applied to every computed delay.
    pub(crate) max_backoff: Duration,
    /// Factor the delay is multiplied by for each successive attempt.
    pub(crate) multiplier: f64,
}

impl RetryPolicy {
    /// Creates a policy with the given retry budget and the default initial
    /// backoff, maximum backoff and multiplier.
    ///
    /// # Panics
    ///
    /// Panics if `max_retries` is 0.
    pub fn new(max_retries: u32) -> Self {
        assert!(max_retries > 0, "max_retries must be greater than 0");
        Self {
            max_retries,
            initial_backoff: DEFAULT_INITIAL_BACKOFF,
            max_backoff: DEFAULT_MAX_BACKOFF,
            multiplier: DEFAULT_MULTIPLIER,
        }
    }

    /// Sets the initial backoff, i.e. the delay computed for attempt 0.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is zero.
    #[must_use]
    pub fn backoff(mut self, duration: Duration) -> Self {
        assert!(!duration.is_zero(), "initial backoff must not be zero");
        self.initial_backoff = duration;
        self
    }

    /// Sets the ceiling that computed delays are clamped to.
    ///
    /// # Panics
    ///
    /// Panics if `duration` is zero.
    #[must_use]
    pub fn max_backoff(mut self, duration: Duration) -> Self {
        assert!(!duration.is_zero(), "max backoff must not be zero");
        self.max_backoff = duration;
        self
    }

    /// Sets the growth factor applied for each successive attempt.
    ///
    /// # Panics
    ///
    /// Panics if `multiplier` is below 1.0, NaN, or infinite.
    #[must_use]
    pub fn multiplier(mut self, multiplier: f64) -> Self {
        assert!(
            multiplier >= 1.0 && multiplier.is_finite(),
            "multiplier must be >= 1.0 and finite, got {multiplier}"
        );
        self.multiplier = multiplier;
        self
    }

    /// Returns the configured retry budget.
    #[must_use]
    pub fn max_retries(&self) -> u32 {
        self.max_retries
    }

    /// Computes the delay for the given zero-based attempt:
    /// `initial_backoff * multiplier^attempt`, clamped to `max_backoff`.
    ///
    /// Attempts above `i32::MAX` are treated as `i32::MAX`, so the delay never
    /// shrinks as the attempt number grows. When the cap applies, the stored
    /// `max_backoff` is returned exactly.
    pub(crate) fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // `powi` takes an `i32`. A bare `attempt as i32` wraps values above
        // `i32::MAX` into negative exponents, which would shrink the delay
        // instead of growing it, so clamp before casting.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let uncapped = self.initial_backoff.as_secs_f64() * self.multiplier.powi(exponent);

        if uncapped < self.max_backoff.as_secs_f64() {
            Duration::from_secs_f64(uncapped)
        } else {
            // Also reached when the product overflowed to infinity or is NaN.
            // Returning the stored cap avoids an f64 round trip, which loses
            // nanosecond precision on large caps and makes
            // `Duration::from_secs_f64` panic for caps near `Duration::MAX`.
            self.max_backoff
        }
    }
}
/// Reports whether a failed operation is worth attempting again.
///
/// Retryable:
/// - HTTP errors that carry no status code, or whose status is 429 or >= 500;
/// - agent errors of kind `ProcessFailed` or `Timeout`;
/// - operation-level timeouts.
///
/// Every other agent error, as well as shell and deserialization errors, is
/// reported as not retryable.
pub fn is_retryable(error: &OperationError) -> bool {
    match error {
        // No status code at all, or the step itself timed out.
        OperationError::Http { status: None, .. } | OperationError::Timeout { .. } => true,
        OperationError::Http {
            status: Some(code), ..
        } => *code == 429 || *code >= 500,
        OperationError::Agent(AgentError::ProcessFailed { .. } | AgentError::Timeout { .. }) => {
            true
        }
        OperationError::Agent(_) => false,
        OperationError::Shell { .. } | OperationError::Deserialize { .. } => false,
    }
}
/// Reports whether a bare HTTP status code should trigger a retry:
/// 429, or any code from 500 upwards.
pub(crate) fn is_retryable_status(status: u16) -> bool {
    matches!(status, 429 | 500..=u16::MAX)
}
// Unit tests for `RetryPolicy`, `is_retryable` and `is_retryable_status`.
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
// ---- RetryPolicy construction and builder validation ----
#[test]
fn new_creates_policy_with_defaults() {
let policy = RetryPolicy::new(3);
assert_eq!(policy.max_retries, 3);
assert_eq!(policy.initial_backoff, DEFAULT_INITIAL_BACKOFF);
assert_eq!(policy.max_backoff, DEFAULT_MAX_BACKOFF);
// f64 fields are compared with an epsilon tolerance rather than `==`.
assert!((policy.multiplier - DEFAULT_MULTIPLIER).abs() < f64::EPSILON);
}
#[test]
#[should_panic(expected = "max_retries must be greater than 0")]
fn new_zero_retries_panics() {
let _ = RetryPolicy::new(0);
}
#[test]
fn backoff_sets_initial_backoff() {
let policy = RetryPolicy::new(1).backoff(Duration::from_secs(1));
assert_eq!(policy.initial_backoff, Duration::from_secs(1));
}
#[test]
#[should_panic(expected = "initial backoff must not be zero")]
fn backoff_zero_panics() {
let _ = RetryPolicy::new(1).backoff(Duration::ZERO);
}
#[test]
fn max_backoff_sets_cap() {
let policy = RetryPolicy::new(1).max_backoff(Duration::from_secs(60));
assert_eq!(policy.max_backoff, Duration::from_secs(60));
}
#[test]
#[should_panic(expected = "max backoff must not be zero")]
fn max_backoff_zero_panics() {
let _ = RetryPolicy::new(1).max_backoff(Duration::ZERO);
}
#[test]
fn multiplier_sets_value() {
let policy = RetryPolicy::new(1).multiplier(3.0);
assert!((policy.multiplier - 3.0).abs() < f64::EPSILON);
}
#[test]
#[should_panic(expected = "multiplier must be >= 1.0")]
fn multiplier_below_one_panics() {
let _ = RetryPolicy::new(1).multiplier(0.5);
}
// NaN fails the `multiplier >= 1.0` comparison, so the same assertion fires.
#[test]
#[should_panic(expected = "multiplier must be >= 1.0 and finite")]
fn multiplier_nan_panics() {
let _ = RetryPolicy::new(1).multiplier(f64::NAN);
}
// Infinity passes `>= 1.0` but is rejected by the `is_finite` half of the check.
#[test]
#[should_panic(expected = "multiplier must be >= 1.0 and finite")]
fn multiplier_infinity_panics() {
let _ = RetryPolicy::new(1).multiplier(f64::INFINITY);
}
#[test]
fn max_retries_accessor() {
assert_eq!(RetryPolicy::new(5).max_retries(), 5);
}
// ---- delay_for_attempt: growth and capping ----
#[test]
fn delay_for_attempt_zero_is_initial_backoff() {
let policy = RetryPolicy::new(3).backoff(Duration::from_millis(100));
let delay = policy.delay_for_attempt(0);
assert_eq!(delay, Duration::from_millis(100));
}
#[test]
fn delay_for_attempt_grows_exponentially() {
let policy = RetryPolicy::new(5)
.backoff(Duration::from_millis(100))
.multiplier(2.0);
// Each attempt doubles the previous delay: 100 ms * 2^attempt.
assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(100));
assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(200));
assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(400));
assert_eq!(policy.delay_for_attempt(3), Duration::from_millis(800));
}
#[test]
fn delay_for_attempt_capped_at_max_backoff() {
let policy = RetryPolicy::new(10)
.backoff(Duration::from_secs(1))
.max_backoff(Duration::from_secs(5))
.multiplier(10.0);
assert_eq!(policy.delay_for_attempt(0), Duration::from_secs(1));
// Uncapped values would be 10 s and 100 s; both exceed the 5 s ceiling.
assert_eq!(policy.delay_for_attempt(1), Duration::from_secs(5));
assert_eq!(policy.delay_for_attempt(2), Duration::from_secs(5));
}
#[test]
fn delay_for_attempt_with_multiplier_one_is_constant() {
let policy = RetryPolicy::new(3)
.backoff(Duration::from_millis(500))
.multiplier(1.0);
assert_eq!(policy.delay_for_attempt(0), Duration::from_millis(500));
assert_eq!(policy.delay_for_attempt(1), Duration::from_millis(500));
assert_eq!(policy.delay_for_attempt(2), Duration::from_millis(500));
}
// ---- is_retryable: HTTP errors ----
#[test]
fn http_transport_error_is_retryable() {
// `status: None` — an HTTP error that carries no status code.
let err = OperationError::Http {
status: None,
message: "connection refused".to_string(),
};
assert!(is_retryable(&err));
}
#[test]
fn http_500_is_retryable() {
let err = OperationError::Http {
status: Some(500),
message: "internal server error".to_string(),
};
assert!(is_retryable(&err));
}
#[test]
fn http_502_is_retryable() {
let err = OperationError::Http {
status: Some(502),
message: "bad gateway".to_string(),
};
assert!(is_retryable(&err));
}
#[test]
fn http_503_is_retryable() {
let err = OperationError::Http {
status: Some(503),
message: "service unavailable".to_string(),
};
assert!(is_retryable(&err));
}
#[test]
fn http_429_is_retryable() {
let err = OperationError::Http {
status: Some(429),
message: "too many requests".to_string(),
};
assert!(is_retryable(&err));
}
#[test]
fn http_400_is_not_retryable() {
let err = OperationError::Http {
status: Some(400),
message: "bad request".to_string(),
};
assert!(!is_retryable(&err));
}
#[test]
fn http_404_is_not_retryable() {
let err = OperationError::Http {
status: Some(404),
message: "not found".to_string(),
};
assert!(!is_retryable(&err));
}
// ---- is_retryable: agent errors ----
#[test]
fn agent_process_failed_is_retryable() {
let err = OperationError::Agent(AgentError::ProcessFailed {
exit_code: 1,
stderr: "crash".to_string(),
});
assert!(is_retryable(&err));
}
#[test]
fn agent_timeout_is_retryable() {
let err = OperationError::Agent(AgentError::Timeout {
limit: Duration::from_secs(60),
});
assert!(is_retryable(&err));
}
#[test]
fn agent_prompt_too_large_is_not_retryable() {
let err = OperationError::Agent(AgentError::PromptTooLarge {
chars: 1_000_000,
estimated_tokens: 250_000,
model_limit: 200_000,
});
assert!(!is_retryable(&err));
}
#[test]
fn agent_schema_validation_is_not_retryable() {
let err = OperationError::Agent(AgentError::SchemaValidation {
expected: "object".to_string(),
got: "string".to_string(),
debug_messages: Vec::new(),
partial_usage: Box::default(),
});
assert!(!is_retryable(&err));
}
// ---- is_retryable: remaining OperationError variants ----
#[test]
fn operation_timeout_is_retryable() {
let err = OperationError::Timeout {
step: "fetch".to_string(),
limit: Duration::from_secs(30),
};
assert!(is_retryable(&err));
}
#[test]
fn shell_error_is_not_retryable() {
let err = OperationError::Shell {
exit_code: 1,
stderr: "fail".to_string(),
};
assert!(!is_retryable(&err));
}
#[test]
fn deserialize_error_is_not_retryable() {
let err = OperationError::Deserialize {
target_type: "MyStruct".to_string(),
reason: "missing field".to_string(),
};
assert!(!is_retryable(&err));
}
// ---- is_retryable_status ----
#[test]
fn retryable_status_codes() {
assert!(is_retryable_status(500));
assert!(is_retryable_status(502));
assert!(is_retryable_status(503));
assert!(is_retryable_status(504));
assert!(is_retryable_status(429));
}
#[test]
fn non_retryable_status_codes() {
assert!(!is_retryable_status(200));
assert!(!is_retryable_status(201));
assert!(!is_retryable_status(301));
assert!(!is_retryable_status(400));
assert!(!is_retryable_status(401));
assert!(!is_retryable_status(403));
assert!(!is_retryable_status(404));
assert!(!is_retryable_status(422));
// 428 sits directly below the retryable 429.
assert!(!is_retryable_status(428));
}
// ---- builder chaining and derived traits ----
#[test]
fn builder_chain_all_methods() {
let policy = RetryPolicy::new(5)
.backoff(Duration::from_millis(100))
.max_backoff(Duration::from_secs(10))
.multiplier(3.0);
assert_eq!(policy.max_retries, 5);
assert_eq!(policy.initial_backoff, Duration::from_millis(100));
assert_eq!(policy.max_backoff, Duration::from_secs(10));
assert!((policy.multiplier - 3.0).abs() < f64::EPSILON);
}
#[test]
fn clone_produces_independent_copy() {
let policy = RetryPolicy::new(3).backoff(Duration::from_millis(100));
let cloned = policy.clone();
assert_eq!(policy.max_retries, cloned.max_retries);
assert_eq!(policy.initial_backoff, cloned.initial_backoff);
}
#[test]
fn debug_does_not_panic() {
let policy = RetryPolicy::new(1);
let debug = format!("{:?}", policy);
assert!(debug.contains("RetryPolicy"));
}
}