use crate::error::{AgentError, ErrorRecovery};
use std::future::Future;
use std::pin::Pin;
pub type Result<T> = std::result::Result<T, AgentError>;
pub type RetryOperation<'a, T> =
Box<dyn FnMut() -> Pin<Box<dyn Future<Output = Result<T>> + Send + 'a>> + Send + 'a>;
pub async fn retry_with_recovery<'a, T: 'a>(
mut operation: RetryOperation<'a, T>,
max_attempts: Option<u32>,
) -> Result<T> {
let max_attempts = max_attempts.unwrap_or(1);
let mut attempt = 0;
loop {
attempt += 1;
match operation().await {
Ok(result) => return Ok(result),
Err(err) => {
if !err.is_retriable() || attempt >= max_attempts {
return Err(err);
}
eprintln!("Attempt {}/{} failed: {}", attempt, max_attempts, err);
eprintln!(" → {}", err.suggested_action());
let backoff = err.backoff_strategy();
if let Some(delay) = backoff.delay_for_attempt(attempt) {
eprintln!(" Waiting {:?} before retry...", delay);
tokio::time::sleep(delay).await;
} else {
return Err(err);
}
}
}
}
}
pub async fn retry<'a, T: 'a>(mut operation: RetryOperation<'a, T>) -> Result<T> {
let mut attempt = 0;
loop {
attempt += 1;
match operation().await {
Ok(result) => return Ok(result),
Err(err) => {
let max_retries = err.max_retries().unwrap_or(0);
if !err.is_retriable() || attempt > max_retries {
eprintln!("Error (giving up): {}", err);
eprintln!(" → {}", err.suggested_action());
return Err(err);
}
eprintln!("Attempt {}/{} failed: {}", attempt, max_retries, err);
eprintln!(" → {}", err.suggested_action());
let backoff = err.backoff_strategy();
if let Some(delay) = backoff.delay_for_attempt(attempt) {
eprintln!(" Waiting {:?} before retry...", delay);
tokio::time::sleep(delay).await;
} else {
return Err(err);
}
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
#[tokio::test]
async fn test_retry_succeeds_on_first_attempt() {
let result: Result<i32> = retry_with_recovery(
Box::new(|| Box::pin(async { Ok::<i32, AgentError>(42) })),
Some(3),
)
.await;
assert_eq!(result, Ok(42));
}
#[tokio::test]
async fn test_retry_succeeds_on_second_attempt() {
let counter = Arc::new(AtomicU32::new(0));
let counter_clone = counter.clone();
let result: Result<i32> = retry_with_recovery(
Box::new(move || {
let counter = counter_clone.clone();
Box::pin(async move {
let count = counter.fetch_add(1, Ordering::SeqCst);
if count == 0 {
Err(AgentError::Transport("connection lost".to_string()))
} else {
Ok::<i32, AgentError>(42)
}
})
}),
Some(3),
)
.await;
assert_eq!(result, Ok(42));
assert_eq!(counter.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn test_retry_gives_up_on_non_retriable_error() {
let result: Result<i32> = retry_with_recovery(
Box::new(|| {
Box::pin(async {
Err::<i32, AgentError>(AgentError::Config("bad config".to_string()))
})
}),
Some(3),
)
.await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_retry_respects_max_attempts() {
let counter = Arc::new(AtomicU32::new(0));
let counter_clone = counter.clone();
let result: Result<i32> = retry_with_recovery(
Box::new(move || {
let counter = counter_clone.clone();
Box::pin(async move {
counter.fetch_add(1, Ordering::SeqCst);
Err::<i32, AgentError>(AgentError::Transport("always fails".to_string()))
})
}),
Some(3),
)
.await;
assert!(result.is_err());
assert_eq!(counter.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn test_retry_uses_error_policy() {
let counter = Arc::new(AtomicU32::new(0));
let counter_clone = counter.clone();
let result: Result<i32> = retry(Box::new(move || {
let counter = counter_clone.clone();
Box::pin(async move {
let count = counter.fetch_add(1, Ordering::SeqCst);
if count < 2 {
Err::<i32, AgentError>(AgentError::Transport("failing".to_string()))
} else {
Ok(42)
}
})
}))
.await;
assert_eq!(result, Ok(42));
}
}