use std::time::{Duration, Instant};
use crate::types::{AnthropicError, Result};
/// Configuration for retrying failed requests with exponential backoff.
///
/// Start from [`RetryPolicy::default`] (or [`RetryPolicy::exponential`]), adjust it with the
/// chainable setter methods, then pass it to [`RetryExecutor::new`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
/// Maximum number of retries after the first attempt; an operation runs at most
/// `max_retries + 1` times.
pub max_retries: u32,
/// Delay before the first retry; later delays grow from this value.
pub initial_delay: Duration,
/// Upper bound on the computed backoff delay. Jitter is added after this cap is applied,
/// so the final delay can be slightly larger when `jitter` is enabled.
pub max_delay: Duration,
/// Growth factor applied per retry: the base delay is `initial_delay * multiplier^attempt`.
pub multiplier: f64,
/// When `true`, a small positive offset (at most 10% of the delay) is added to each delay.
pub jitter: bool,
/// Overall wall-clock budget for one `RetryExecutor::execute` call, measured from its start;
/// once it is exhausted no further attempts are made. `None` disables the limit.
pub max_elapsed_time: Option<Duration>,
/// An error is retried only if it matches at least one of these conditions
/// (see [`RetryPolicy::should_retry`]).
pub retry_conditions: Vec<RetryCondition>,
}
/// A class of errors that a [`RetryPolicy`] treats as retryable.
///
/// The enum only carries a `u16`, so it is cheap to copy and can be compared, hashed and
/// stored in sets or maps.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RetryCondition {
    /// Matches `AnthropicError::Timeout`.
    Timeout,
    /// Matches `AnthropicError::NetworkError(_)`.
    ConnectionError,
    /// Matches an `AnthropicError::HttpError` with exactly this status code.
    HttpStatus(u16),
    /// Matches an `AnthropicError::HttpError` with status 429 (Too Many Requests).
    RateLimit,
    /// Matches an `AnthropicError::HttpError` with a status in `500..=599`.
    ServerError,
    /// Matches an `AnthropicError::HttpError` with status 401.
    ///
    /// Note: this does not match the separate `AnthropicError::InvalidApiKey` variant.
    AuthenticationError,
    /// Matches every error.
    All,
}
/// Runs fallible async operations, retrying failures as dictated by its [`RetryPolicy`].
#[derive(Debug)]
pub struct RetryExecutor {
/// The policy consulted after every failed attempt.
policy: RetryPolicy,
}
/// Outcome of [`RetryExecutor::execute`].
#[derive(Debug)]
pub enum RetryResult<T> {
/// The operation succeeded; holds the value returned by the successful attempt.
Success(T),
/// Retrying stopped without a success; typically holds the error from the final attempt.
Failed(AnthropicError),
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_retries: 3,
initial_delay: Duration::from_millis(100),
max_delay: Duration::from_secs(30),
multiplier: 2.0,
jitter: true,
max_elapsed_time: Some(Duration::from_secs(60)),
retry_conditions: vec![
RetryCondition::Timeout,
RetryCondition::ConnectionError,
RetryCondition::RateLimit,
RetryCondition::ServerError,
],
}
}
}
impl RetryPolicy {
pub fn exponential() -> Self {
Self::default()
}
pub fn max_retries(mut self, max_retries: u32) -> Self {
self.max_retries = max_retries;
self
}
pub fn initial_delay(mut self, delay: Duration) -> Self {
self.initial_delay = delay;
self
}
pub fn max_delay(mut self, delay: Duration) -> Self {
self.max_delay = delay;
self
}
pub fn multiplier(mut self, multiplier: f64) -> Self {
self.multiplier = multiplier;
self
}
pub fn jitter(mut self, jitter: bool) -> Self {
self.jitter = jitter;
self
}
pub fn max_elapsed_time(mut self, max_elapsed: Duration) -> Self {
self.max_elapsed_time = Some(max_elapsed);
self
}
pub fn retry_conditions(mut self, conditions: Vec<RetryCondition>) -> Self {
self.retry_conditions = conditions;
self
}
pub fn should_retry(&self, error: &AnthropicError) -> bool {
for condition in &self.retry_conditions {
match condition {
RetryCondition::All => return true,
RetryCondition::Timeout => {
if matches!(error, AnthropicError::Timeout) {
return true;
}
}
RetryCondition::ConnectionError => {
if matches!(error, AnthropicError::NetworkError(_)) {
return true;
}
}
RetryCondition::HttpStatus(code) => {
if let AnthropicError::HttpError { status, .. } = error {
if status == code {
return true;
}
}
}
RetryCondition::RateLimit => {
if let AnthropicError::HttpError { status, .. } = error {
if *status == 429 {
return true;
}
}
}
RetryCondition::ServerError => {
if let AnthropicError::HttpError { status, .. } = error {
if *status >= 500 && *status < 600 {
return true;
}
}
}
RetryCondition::AuthenticationError => {
if let AnthropicError::HttpError { status, .. } = error {
if *status == 401 {
return true;
}
}
}
}
}
false
}
pub fn calculate_delay(&self, attempt: u32) -> Duration {
let base_delay = self.initial_delay.as_millis() as f64;
let delay_ms = base_delay * self.multiplier.powi(attempt as i32);
let delay = Duration::from_millis(delay_ms as u64);
let delay = std::cmp::min(delay, self.max_delay);
if self.jitter {
self.add_jitter(delay)
} else {
delay
}
}
fn add_jitter(&self, delay: Duration) -> Duration {
let jitter_range = delay.as_millis() as f64 * 0.1; let jitter = (std::ptr::addr_of!(self) as usize % 100) as f64 / 100.0 * jitter_range;
let jittered_ms = (delay.as_millis() as f64 + jitter) as u64;
Duration::from_millis(jittered_ms)
}
}
impl RetryExecutor {
pub fn new(policy: RetryPolicy) -> Self {
Self { policy }
}
pub async fn execute<T, F, Fut>(&self, operation: F) -> RetryResult<T>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let start_time = Instant::now();
let mut last_error = None;
for attempt in 0..=self.policy.max_retries {
if let Some(max_elapsed) = self.policy.max_elapsed_time {
if start_time.elapsed() >= max_elapsed {
break;
}
}
match operation().await {
Ok(result) => {
return RetryResult::Success(result);
}
Err(error) => {
last_error = Some(error.clone());
if attempt < self.policy.max_retries && self.policy.should_retry(&error) {
let delay = self.policy.calculate_delay(attempt);
tracing::debug!(
"Request failed (attempt {}/{}): {}. Retrying in {:?}",
attempt + 1,
self.policy.max_retries + 1,
error,
delay
);
tokio::time::sleep(delay).await;
} else {
break;
}
}
}
}
RetryResult::Failed(last_error.unwrap_or_else(|| {
AnthropicError::Other("Unknown error in retry executor".to_string())
}))
}
pub fn get_policy(&self) -> &RetryPolicy {
&self.policy
}
}
pub fn default_retry() -> RetryExecutor {
RetryExecutor::new(RetryPolicy::default())
}
pub fn api_retry() -> RetryExecutor {
RetryExecutor::new(
RetryPolicy::exponential()
.max_retries(3)
.initial_delay(Duration::from_millis(500))
.max_delay(Duration::from_secs(30))
.retry_conditions(vec![
RetryCondition::RateLimit,
RetryCondition::ServerError,
RetryCondition::Timeout,
RetryCondition::ConnectionError,
])
)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[test]
    fn test_retry_policy_should_retry() {
        let policy = RetryPolicy::default();
        assert!(policy.should_retry(&AnthropicError::Timeout));
        assert!(policy.should_retry(&AnthropicError::HttpError {
            status: 429,
            message: "Rate limited".to_string(),
        }));
        assert!(policy.should_retry(&AnthropicError::HttpError {
            status: 500,
            message: "Server error".to_string(),
        }));
        assert!(!policy.should_retry(&AnthropicError::HttpError {
            status: 404,
            message: "Not found".to_string(),
        }));
        assert!(!policy.should_retry(&AnthropicError::InvalidApiKey));
    }

    #[test]
    fn test_custom_retry_conditions() {
        let policy = RetryPolicy::default().retry_conditions(vec![RetryCondition::HttpStatus(409)]);
        assert!(policy.should_retry(&AnthropicError::HttpError {
            status: 409,
            message: "Conflict".to_string(),
        }));
        assert!(!policy.should_retry(&AnthropicError::Timeout));

        let retry_all = RetryPolicy::default().retry_conditions(vec![RetryCondition::All]);
        assert!(retry_all.should_retry(&AnthropicError::InvalidApiKey));
    }

    #[test]
    fn test_delay_calculation() {
        let policy = RetryPolicy::exponential()
            .initial_delay(Duration::from_millis(100))
            .multiplier(2.0)
            .jitter(false);
        assert_eq!(policy.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(policy.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(policy.calculate_delay(2), Duration::from_millis(400));
    }

    #[test]
    fn test_delay_is_capped_at_max_delay() {
        let policy = RetryPolicy::exponential()
            .initial_delay(Duration::from_millis(100))
            .max_delay(Duration::from_secs(1))
            .multiplier(2.0)
            .jitter(false);
        assert_eq!(policy.calculate_delay(10), Duration::from_secs(1));
    }

    #[test]
    fn test_jitter_stays_within_ten_percent() {
        let policy = RetryPolicy::exponential()
            .initial_delay(Duration::from_millis(100))
            .multiplier(2.0)
            .jitter(true);
        for _ in 0..20 {
            let delay = policy.calculate_delay(0);
            assert!(delay >= Duration::from_millis(100));
            assert!(delay <= Duration::from_millis(110));
        }
    }

    #[tokio::test]
    async fn test_retry_executor_success() {
        let policy = RetryPolicy::exponential().max_retries(2);
        let executor = RetryExecutor::new(policy);
        let result = executor
            .execute(|| async { Ok::<i32, AnthropicError>(42) })
            .await;
        match result {
            RetryResult::Success(value) => assert_eq!(value, 42),
            _ => panic!("Expected success"),
        }
    }

    #[tokio::test]
    async fn test_retry_executor_failure() {
        let policy = RetryPolicy::exponential()
            .max_retries(1)
            .initial_delay(Duration::from_millis(1));
        let executor = RetryExecutor::new(policy);
        let calls = AtomicU32::new(0);
        let result = executor
            .execute(|| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, AnthropicError>(AnthropicError::InvalidApiKey) }
            })
            .await;
        // A non-retryable error must fail immediately, without any retry.
        assert!(matches!(result, RetryResult::Failed(AnthropicError::InvalidApiKey)));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_executor_retries_retryable_errors() {
        let policy = RetryPolicy::exponential()
            .max_retries(2)
            .initial_delay(Duration::from_millis(1))
            .jitter(false);
        let executor = RetryExecutor::new(policy);
        let calls = AtomicU32::new(0);
        let result = executor
            .execute(|| {
                calls.fetch_add(1, Ordering::SeqCst);
                async { Err::<i32, AnthropicError>(AnthropicError::Timeout) }
            })
            .await;
        assert!(matches!(result, RetryResult::Failed(AnthropicError::Timeout)));
        // One initial attempt plus `max_retries` retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_executor_recovers_after_transient_error() {
        let policy = RetryPolicy::exponential()
            .max_retries(3)
            .initial_delay(Duration::from_millis(1))
            .jitter(false);
        let executor = RetryExecutor::new(policy);
        let calls = AtomicU32::new(0);
        let result = executor
            .execute(|| {
                let call = calls.fetch_add(1, Ordering::SeqCst);
                async move {
                    if call == 0 {
                        Err::<i32, AnthropicError>(AnthropicError::Timeout)
                    } else {
                        Ok(7)
                    }
                }
            })
            .await;
        match result {
            RetryResult::Success(value) => assert_eq!(value, 7),
            _ => panic!("Expected success"),
        }
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }
}