use anyhow::Result;
use std::time::Duration;
use tokio::time::sleep;
/// Exponential-backoff policy used by [`retry_async`] and [`smart_retry_async`].
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Retries allowed after the first attempt (`0` means the operation runs exactly once).
    pub max_retries: usize,
    /// Delay before the first retry; later delays grow from this value.
    pub initial_delay: Duration,
    /// Upper bound on every delay returned by [`RetryConfig::calculate_delay`].
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each further retry.
    pub backoff_multiplier: f64,
}

impl RetryConfig {
    /// Default policy: 2 retries, starting at 1 s, doubling each time, capped at 30 s.
    pub fn new() -> Self {
        Self {
            max_retries: 2,
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
        }
    }

    /// Builds a policy from explicit values.
    ///
    /// No validation is performed here; [`RetryConfig::calculate_delay`]
    /// tolerates out-of-range values (e.g. a negative or NaN multiplier).
    pub fn with_settings(
        max_retries: usize,
        initial_delay: Duration,
        max_delay: Duration,
        backoff_multiplier: f64,
    ) -> Self {
        Self {
            max_retries,
            initial_delay,
            max_delay,
            backoff_multiplier,
        }
    }

    /// Policy that makes a single attempt and never waits.
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            initial_delay: Duration::from_secs(0),
            max_delay: Duration::from_secs(0),
            backoff_multiplier: 1.0,
        }
    }

    /// Delay to wait after the failed attempt with zero-based index `attempt`.
    ///
    /// Computes `initial_delay * backoff_multiplier^attempt` and clamps it to
    /// `[0, max_delay]`. This never panics:
    ///
    /// - a result that is infinite, NaN, or at/above the cap saturates at
    ///   `max_delay` (e.g. a very large `attempt` with a multiplier above 1);
    /// - a negative result (negative multiplier) becomes zero.
    pub fn calculate_delay(&self, attempt: usize) -> Duration {
        if attempt == 0 {
            // Use the exact Duration (no float round-trip), but still honour the cap.
            return self.initial_delay.min(self.max_delay);
        }
        // `powi` takes an i32; saturate rather than let `as` wrap to a negative exponent.
        let exponent = attempt.min(i32::MAX as usize) as i32;
        let secs = self.initial_delay.as_secs_f64() * self.backoff_multiplier.powi(exponent);
        if secs.is_nan() || secs >= self.max_delay.as_secs_f64() {
            // Also covers +inf, which `Duration::from_secs_f64` would panic on.
            return self.max_delay;
        }
        if secs <= 0.0 {
            return Duration::ZERO;
        }
        // Here 0 < secs < max_delay (finite), so the conversion cannot panic.
        Duration::from_secs_f64(secs)
    }
}

impl Default for RetryConfig {
    /// Same as [`RetryConfig::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Runs `f` until it succeeds or `config.max_retries` retries have been used up.
///
/// Every error is treated as retryable. After a failure the task sleeps for
/// `config.calculate_delay(n)`, where `n` is the zero-based index of the
/// attempt that just failed. Returns the first `Ok`, or the error produced by
/// the final attempt once the retry budget is exhausted.
pub async fn retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T>>,
{
    let mut failures = 0;
    loop {
        let err = match f().await {
            Ok(value) => {
                if failures > 0 {
                    tracing::info!("Retry successful after {} attempt(s)", failures);
                }
                return Ok(value);
            }
            Err(err) => err,
        };

        // Budget exhausted: surface the last error directly.
        if failures >= config.max_retries {
            tracing::error!("All {} retry attempts failed", config.max_retries + 1);
            return Err(err);
        }

        let pause = config.calculate_delay(failures);
        tracing::warn!("Attempt {} failed, retrying in {:?}...", failures + 1, pause);
        sleep(pause).await;
        failures += 1;
    }
}
/// Verdict produced by [`classify_error`]: whether a failed operation is worth retrying.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ErrorType {
    /// Retrying may succeed (e.g. timeouts, rate limiting, HTTP 5xx).
    Transient,
    /// Retrying will not help (e.g. missing files, HTTP 4xx, unsupported media).
    Permanent,
}
/// Heuristically decides whether `error` is worth retrying.
///
/// The search runs over the lower-cased text of the *whole* `anyhow` context
/// chain, rendered with the alternate `{:#}` formatter (`"outer: cause: root"`).
/// A root cause such as "No such file or directory" is therefore still
/// recognised after callers wrap it with `.context(...)`.
///
/// Rule groups are tried in order and the first match wins:
///
/// 1. HTTP 429/500/502/503/504, "rate limit", "server error" → transient
/// 2. HTTP 400/401/403/404, "client error" → permanent
/// 3. timeouts, connection problems, "try again"-style OS errors → transient
/// 4. filesystem, unsupported-media and corrupt-input errors → permanent
///
/// Status codes only match as standalone numbers, so text like "4000ms" or
/// "frame 1500" is not mistaken for a status code. Anything unrecognised is
/// treated as [`ErrorType::Transient`].
pub fn classify_error(error: &anyhow::Error) -> ErrorType {
    // Phrases indicating a temporary condition (checked after HTTP codes).
    const TRANSIENT_PHRASES: &[&str] = &[
        "timeout",
        "connection",
        "temporarily unavailable", // also covers "resource temporarily unavailable"
        "too many open files",
        "resource deadlock",
        "try again",
    ];
    // Phrases indicating that retrying cannot help.
    const PERMANENT_PHRASES: &[&str] = &[
        // Filesystem
        "file not found",
        "no such file",
        "permission denied",
        "access denied",
        "read-only",
        "disk full",
        "no space left",
        // Invalid or unsupported media
        "invalid data found",
        "codec not found",
        "unsupported codec",
        "unknown codec",
        "invalid audio",
        "invalid sample rate",
        "invalid bit rate",
        "invalid channel",
        "not supported",
        "does not contain any stream",
        "no decoder",
        "no encoder",
        "moov atom not found",
        "invalid argument",
        "protocol not found",
        // Corrupt input ("corrupt" also covers "corrupted")
        "corrupt",
        "truncated",
        "header missing",
        "malformed",
        "end of file",
    ];

    // True if any needle occurs anywhere in `haystack`.
    fn contains_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|&needle| haystack.contains(needle))
    }

    // True if any code occurs in `haystack` as a standalone token, i.e. not
    // directly adjacent to another letter or digit ("4000", "500ms" don't count).
    fn contains_status_code(haystack: &str, codes: &[&str]) -> bool {
        codes.iter().any(|&code| {
            haystack.match_indices(code).any(|(start, _)| {
                let before = haystack[..start].chars().next_back();
                let after = haystack[start + code.len()..].chars().next();
                !matches!(before, Some(c) if c.is_alphanumeric())
                    && !matches!(after, Some(c) if c.is_alphanumeric())
            })
        })
    }

    // `{:#}` includes every cause in the chain, not just the outermost context.
    let msg = format!("{:#}", error).to_lowercase();

    if contains_status_code(&msg, &["429", "500", "502", "503", "504"])
        || contains_any(&msg, &["rate limit", "server error"])
    {
        return ErrorType::Transient;
    }
    if contains_status_code(&msg, &["400", "401", "403", "404"])
        || contains_any(&msg, &["client error"])
    {
        return ErrorType::Permanent;
    }
    if contains_any(&msg, TRANSIENT_PHRASES) {
        return ErrorType::Transient;
    }
    if contains_any(&msg, PERMANENT_PHRASES) {
        return ErrorType::Permanent;
    }
    ErrorType::Transient
}
pub async fn smart_retry_async<F, Fut, T>(config: &RetryConfig, mut f: F) -> Result<T>
where
F: FnMut() -> Fut,
Fut: std::future::Future<Output = Result<T>>,
{
let mut last_error = None;
for attempt in 0..=config.max_retries {
match f().await {
Ok(result) => {
if attempt > 0 {
tracing::info!("Smart retry successful after {} attempt(s)", attempt);
}
return Ok(result);
}
Err(e) => {
let error_type = classify_error(&e);
if error_type == ErrorType::Permanent {
tracing::error!("Permanent error detected, not retrying: {:?}", e);
return Err(e);
}
if attempt < config.max_retries {
let delay = config.calculate_delay(attempt);
tracing::warn!(
"Transient error on attempt {}: {:?}",
attempt + 1,
e
);
tracing::warn!(
"Retrying in {:?}... ({} attempts remaining)",
delay,
config.max_retries - attempt
);
sleep(delay).await;
} else {
tracing::error!(
"All {} retry attempts exhausted. Final error: {:?}",
config.max_retries + 1,
e
);
}
last_error = Some(e);
}
}
}
Err(last_error.unwrap())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Retry policy with millisecond-scale delays so async tests stay fast.
    fn fast_config(max_retries: usize) -> RetryConfig {
        RetryConfig::with_settings(
            max_retries,
            Duration::from_millis(10),
            Duration::from_millis(100),
            2.0,
        )
    }

    #[test]
    fn test_retry_config_creation() {
        let RetryConfig {
            max_retries,
            initial_delay,
            backoff_multiplier,
            ..
        } = RetryConfig::new();
        assert_eq!(max_retries, 2);
        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_no_retry() {
        assert_eq!(RetryConfig::no_retry().max_retries, 0);
    }

    #[test]
    fn test_calculate_delay() {
        let doubling = RetryConfig::new();
        for (attempt, &secs) in [1u64, 2, 4, 8].iter().enumerate() {
            assert_eq!(doubling.calculate_delay(attempt), Duration::from_secs(secs));
        }

        // 2^10 s far exceeds the 5 s cap.
        let capped =
            RetryConfig::with_settings(5, Duration::from_secs(1), Duration::from_secs(5), 2.0);
        assert_eq!(capped.calculate_delay(10), Duration::from_secs(5));
    }

    #[tokio::test]
    async fn test_retry_async_success_first_try() {
        let config = RetryConfig::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let outcome: Result<i32> = retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;
        assert_eq!(outcome.ok(), Some(42));
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_retry_async_success_after_retries() {
        let config = fast_config(3);
        let calls = Arc::new(AtomicUsize::new(0));
        let outcome = retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                // Fail on the first two calls, then succeed.
                if calls.fetch_add(1, Ordering::Relaxed) < 2 {
                    anyhow::bail!("Transient error");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;
        assert_eq!(outcome.ok(), Some(42));
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_retry_async_all_fail() {
        let config = fast_config(2);
        let calls = Arc::new(AtomicUsize::new(0));
        let outcome: Result<i32> = retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("Always fails")
            }
        })
        .await;
        assert!(outcome.is_err());
        // One initial attempt plus two retries.
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }

    #[test]
    fn test_classify_error() {
        let cases = [
            ("Connection timeout", ErrorType::Transient),
            ("File not found", ErrorType::Permanent),
            ("Some random error", ErrorType::Transient),
        ];
        for (message, expected) in cases.iter() {
            assert_eq!(classify_error(&anyhow::anyhow!("{}", message)), *expected);
        }
    }

    #[tokio::test]
    async fn test_smart_retry_permanent_error() {
        let config = RetryConfig::new();
        let calls = Arc::new(AtomicUsize::new(0));
        let outcome: Result<i32> = smart_retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                calls.fetch_add(1, Ordering::Relaxed);
                anyhow::bail!("File not found")
            }
        })
        .await;
        assert!(outcome.is_err());
        // Permanent errors must not be retried.
        assert_eq!(calls.load(Ordering::Relaxed), 1);
    }

    #[tokio::test]
    async fn test_smart_retry_transient_error() {
        let config = fast_config(2);
        let calls = Arc::new(AtomicUsize::new(0));
        let outcome = smart_retry_async(&config, || {
            let calls = Arc::clone(&calls);
            async move {
                // Two transient failures, then success on the last allowed attempt.
                if calls.fetch_add(1, Ordering::Relaxed) < 2 {
                    anyhow::bail!("Connection timeout");
                }
                Ok::<i32, anyhow::Error>(42)
            }
        })
        .await;
        assert_eq!(outcome.ok(), Some(42));
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }
}