use simple_agent_type::prelude::{ProviderError, SimpleAgentsError};
use std::future::Future;
use std::time::Duration;
/// Configuration for [`execute_with_retry`]: how many times an operation may be
/// tried and how long to wait between tries.
#[derive(Debug, Clone, Copy)]
pub struct RetryPolicy {
/// Total number of times the operation may be invoked, including the first call.
/// `execute_with_retry` rejects `0` with a `Config` error.
pub max_attempts: u32,
/// Base delay before the first retry (before capping and jitter). Later delays are
/// this value multiplied by `backoff_multiplier` raised to the zero-based attempt index.
pub initial_backoff: Duration,
/// Upper bound applied to every computed delay.
pub max_backoff: Duration,
/// Factor by which the delay grows after each failed attempt.
pub backoff_multiplier: f32,
/// When `true`, each already-capped delay is scaled by a random factor between
/// 0.5 and 1.0 so concurrent callers do not retry in lockstep.
pub jitter: bool,
}
impl Default for RetryPolicy {
fn default() -> Self {
Self {
max_attempts: 3,
initial_backoff: Duration::from_millis(100),
max_backoff: Duration::from_secs(10),
backoff_multiplier: 2.0,
jitter: true,
}
}
}
impl RetryPolicy {
    /// Returns the delay to wait after failed attempt `attempt` (zero-based) before
    /// trying again.
    ///
    /// The delay is `initial_backoff * backoff_multiplier^attempt`, capped at
    /// `max_backoff`. When `jitter` is enabled, the capped value is scaled by a random
    /// factor in `[0.5, 1.0)`. The result never exceeds `max_backoff`.
    fn backoff(&self, attempt: u32) -> Duration {
        // Work in nanoseconds so sub-millisecond backoffs are not truncated to zero,
        // and in f64 so large intermediate products keep their precision.
        let max_nanos = self.max_backoff.as_nanos() as f64;
        // A plain `as i32` would wrap attempts above i32::MAX into a negative
        // exponent, shrinking the delay instead of saturating it at the cap.
        let exponent = attempt.min(i32::MAX as u32) as i32;
        let base = self.initial_backoff.as_nanos() as f64
            * f64::from(self.backoff_multiplier).powi(exponent);
        // `f64::min` returns the non-NaN operand, so NaN/inf products fall back to the cap.
        let capped = base.min(max_nanos);
        let nanos = if self.jitter {
            capped * (0.5 + f64::from(random_f32()) * 0.5)
        } else {
            capped
        };
        // Float-to-int `as` saturates (NaN/negative -> 0, overflow -> u64::MAX);
        // the final `min` re-applies the cap after that conversion.
        Duration::from_nanos(nanos as u64).min(self.max_backoff)
    }
}
pub async fn execute_with_retry<F, Fut, T>(
policy: RetryPolicy,
operation: F,
) -> Result<T, SimpleAgentsError>
where
F: Fn() -> Fut,
Fut: Future<Output = Result<T, SimpleAgentsError>>,
{
if policy.max_attempts == 0 {
return Err(SimpleAgentsError::Config(
"retry max_attempts must be >= 1".to_string(),
));
}
let mut last_error: Option<SimpleAgentsError> = None;
for attempt in 0..policy.max_attempts {
match operation().await {
Ok(result) => return Ok(result),
Err(error) => {
if !is_retryable(&error) {
return Err(error);
}
if attempt >= policy.max_attempts - 1 {
last_error = Some(error);
break;
}
tokio::time::sleep(policy.backoff(attempt)).await;
last_error = Some(error);
}
}
}
Err(last_error.unwrap_or_else(|| {
SimpleAgentsError::Config("retry loop exhausted without attempts".to_string())
}))
}
/// Reports whether `error` is considered transient and therefore worth retrying.
///
/// Network errors, and provider rate-limit, timeout and server errors, are
/// retryable; every other error is not.
fn is_retryable(error: &SimpleAgentsError) -> bool {
    match error {
        SimpleAgentsError::Network(_) => true,
        SimpleAgentsError::Provider(provider_error) => matches!(
            provider_error,
            ProviderError::RateLimit { .. }
                | ProviderError::Timeout(_)
                | ProviderError::ServerError(_)
        ),
        _ => false,
    }
}
/// Draws a random `f32` from the thread-local RNG using `rand`'s standard
/// distribution (uniform over `[0, 1)` for floats).
fn random_f32() -> f32 {
    use rand::Rng;
    let mut rng = rand::thread_rng();
    rng.gen::<f32>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Policy with tiny, jitter-free delays so retry tests are fast and deterministic.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(5),
            backoff_multiplier: 2.0,
            jitter: false,
        }
    }

    fn timeout_error() -> SimpleAgentsError {
        SimpleAgentsError::Provider(ProviderError::Timeout(Duration::from_secs(1)))
    }

    #[tokio::test]
    async fn succeeds_without_retry() {
        let result =
            execute_with_retry(fast_policy(3), || async { Ok::<_, SimpleAgentsError>("ok") })
                .await;
        assert_eq!(result.unwrap(), "ok");
    }

    #[tokio::test]
    async fn retries_on_retryable_error() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);
        let result = execute_with_retry(fast_policy(2), move || {
            let counter = Arc::clone(&counter);
            async move {
                if counter.fetch_add(1, Ordering::SeqCst) == 0 {
                    Err(timeout_error())
                } else {
                    Ok("ok")
                }
            }
        })
        .await;
        assert_eq!(result.unwrap(), "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn fails_on_non_retryable_error() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);
        let result = execute_with_retry(fast_policy(3), move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<&str, _>(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
            }
        })
        .await;
        assert!(matches!(
            result,
            Err(SimpleAgentsError::Provider(ProviderError::InvalidApiKey))
        ));
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn returns_last_error_after_exhausting_attempts() {
        let attempts = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&attempts);
        let result = execute_with_retry(fast_policy(3), move || {
            let counter = Arc::clone(&counter);
            async move {
                counter.fetch_add(1, Ordering::SeqCst);
                Err::<&str, _>(timeout_error())
            }
        })
        .await;
        assert!(matches!(
            result,
            Err(SimpleAgentsError::Provider(ProviderError::Timeout(_)))
        ));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn zero_attempts_returns_config_error() {
        let result =
            execute_with_retry(fast_policy(0), || async { Ok::<_, SimpleAgentsError>("ok") })
                .await;
        assert!(matches!(result, Err(SimpleAgentsError::Config(_))));
    }

    #[test]
    fn backoff_grows_geometrically_and_is_capped() {
        let policy = fast_policy(3);
        assert_eq!(policy.backoff(0), Duration::from_millis(1));
        assert_eq!(policy.backoff(1), Duration::from_millis(2));
        assert_eq!(policy.backoff(2), Duration::from_millis(4));
        assert_eq!(policy.backoff(3), Duration::from_millis(5));
        assert_eq!(policy.backoff(20), Duration::from_millis(5));
    }

    #[test]
    fn jittered_backoff_stays_between_half_and_full_delay() {
        let policy = RetryPolicy {
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(10),
            jitter: true,
            ..fast_policy(3)
        };
        for _ in 0..100 {
            let delay = policy.backoff(0);
            assert!(
                delay >= Duration::from_millis(50) && delay <= Duration::from_millis(100),
                "jittered delay out of range: {:?}",
                delay
            );
        }
    }
}