use backoff::{backoff::Backoff, ExponentialBackoff, ExponentialBackoffBuilder};
use std::future::Future;
use std::time::Duration;
use tracing::{debug, warn};
/// Tuning parameters for the exponential-backoff retry helpers in this module.
#[derive(Debug, Clone)]
pub struct RetryConfig {
/// Maximum number of retries after the initial attempt; `retry_async`
/// makes up to `max_retries + 1` calls in total.
pub max_retries: u32,
/// Initial backoff interval, in milliseconds.
pub base_delay_ms: u64,
/// Upper bound on a single backoff interval, in milliseconds.
pub max_delay_ms: u64,
/// Factor by which the interval grows between consecutive retries.
pub backoff_multiplier: f64,
/// When true, delays are randomized (25% randomization factor in
/// `create_backoff`) to avoid thundering-herd retries.
pub use_jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
base_delay_ms: 1000, max_delay_ms: 30_000, backoff_multiplier: 2.0, use_jitter: true,
}
}
}
impl RetryConfig {
pub fn fast() -> Self {
Self {
max_retries: 5,
base_delay_ms: 100, max_delay_ms: 5_000, backoff_multiplier: 1.5, use_jitter: true,
}
}
pub fn slow() -> Self {
Self {
max_retries: 3,
base_delay_ms: 5000, max_delay_ms: 60_000, backoff_multiplier: 2.0, use_jitter: true,
}
}
}
/// Classification of a failed attempt, used by `retry_async` to decide
/// whether (and how) to retry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorClass {
/// Retrying cannot help; the error is returned immediately.
Permanent,
/// Transient failure; retried with the normal backoff delay.
Retryable,
/// Rate-limit style failure; retried with a doubled backoff delay.
RateLimited,
}
/// Translate a [`RetryConfig`] into a `backoff::ExponentialBackoff`.
///
/// Jitter maps onto the randomization factor: 0.25 when enabled, 0.0
/// (fully deterministic delays) when disabled. `max_elapsed_time` is
/// cleared so that only the caller's attempt counter bounds retries.
fn create_backoff(config: &RetryConfig) -> ExponentialBackoff {
    let randomization_factor = if config.use_jitter { 0.25 } else { 0.0 };
    ExponentialBackoffBuilder::new()
        .with_initial_interval(Duration::from_millis(config.base_delay_ms))
        .with_max_interval(Duration::from_millis(config.max_delay_ms))
        .with_multiplier(config.backoff_multiplier)
        .with_max_elapsed_time(None)
        .with_randomization_factor(randomization_factor)
        .build()
}
pub async fn retry_async<F, Fut, T, E, C>(
mut operation: F,
classifier: C,
config: &RetryConfig,
operation_name: &str,
) -> Result<T, E>
where
F: FnMut() -> Fut,
Fut: Future<Output = Result<T, E>>,
E: std::fmt::Display + Clone,
C: Fn(&E) -> ErrorClass,
{
debug!(
"Starting operation '{}' with retry config: max_retries={}, base_delay={}ms",
operation_name, config.max_retries, config.base_delay_ms
);
let mut backoff = create_backoff(config);
let mut attempts = 0u32;
loop {
attempts += 1;
debug!("Attempt {} for '{}'", attempts, operation_name);
match operation().await {
Ok(result) => {
if attempts > 1 {
debug!(
"Operation '{}' succeeded after {} attempts",
operation_name, attempts
);
}
return Ok(result);
}
Err(error) => {
let error_class = classifier(&error);
warn!(
"Operation '{}' failed (attempt {}): {} (class: {:?})",
operation_name, attempts, error, error_class
);
match error_class {
ErrorClass::Permanent => {
debug!("Error is permanent, not retrying");
return Err(error);
}
ErrorClass::Retryable | ErrorClass::RateLimited => {
if attempts > config.max_retries {
warn!(
"Operation '{}' failed after {} attempts",
operation_name, attempts
);
return Err(error);
}
let delay = if let Some(duration) = backoff.next_backoff() {
if error_class == ErrorClass::RateLimited {
duration * 2
} else {
duration
}
} else {
warn!("Backoff exhausted for '{}'", operation_name);
return Err(error);
};
debug!("Retrying '{}' after {:?}", operation_name, delay);
tokio::time::sleep(delay).await;
}
}
}
}
}
}
/// Convenience wrapper around [`retry_async`] for operations whose error
/// type is `String`: every failure is classified as retryable.
pub async fn retry_with_backoff<F, Fut, T>(
    operation: F,
    config: &RetryConfig,
    operation_name: &str,
) -> Result<T, String>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, String>>,
{
    let always_retryable = |_err: &String| ErrorClass::Retryable;
    retry_async(operation, always_retryable, config, operation_name).await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// A successful first call returns immediately, with no retries.
    #[tokio::test]
    async fn test_retry_succeeds_first_attempt() {
        let config = RetryConfig::fast();
        let classify = |_: &String| ErrorClass::Retryable;
        let outcome = retry_async(
            || async { Ok::<_, String>("success") },
            classify,
            &config,
            "test_op",
        )
        .await;
        assert_eq!(outcome.unwrap(), "success");
    }

    /// Retryable failures are retried; success on the third call wins.
    #[tokio::test]
    async fn test_retry_succeeds_after_failures() {
        let calls = Arc::new(AtomicU32::new(0));
        let config = RetryConfig::fast();
        let calls_in_op = calls.clone();
        let outcome = retry_async(
            move || {
                let calls = calls_in_op.clone();
                async move {
                    if calls.fetch_add(1, Ordering::SeqCst) < 2 {
                        Err("temporary failure".to_string())
                    } else {
                        Ok("success")
                    }
                }
            },
            |_| ErrorClass::Retryable,
            &config,
            "test_op",
        )
        .await;
        assert_eq!(outcome.unwrap(), "success");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// A permanent error must not be retried: exactly one call is made.
    #[tokio::test]
    async fn test_retry_permanent_error_no_retry() {
        let calls = Arc::new(AtomicU32::new(0));
        let config = RetryConfig::fast();
        let calls_in_op = calls.clone();
        let outcome = retry_async(
            move || {
                let calls = calls_in_op.clone();
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>("permanent error".to_string())
                }
            },
            |_| ErrorClass::Permanent,
            &config,
            "test_op",
        )
        .await;
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    /// With max_retries = 2, the operation is invoked 3 times in total
    /// before giving up.
    #[tokio::test]
    async fn test_retry_exhausts_all_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let config = RetryConfig {
            max_retries: 2,
            base_delay_ms: 10,
            max_delay_ms: 100,
            backoff_multiplier: 2.0,
            use_jitter: false,
        };
        let calls_in_op = calls.clone();
        let outcome = retry_async(
            move || {
                let calls = calls_in_op.clone();
                async move {
                    calls.fetch_add(1, Ordering::SeqCst);
                    Err::<String, _>("always fails".to_string())
                }
            },
            |_| ErrorClass::Retryable,
            &config,
            "test_op",
        )
        .await;
        assert!(outcome.is_err());
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    /// `create_backoff` copies config fields through verbatim; disabling
    /// jitter yields a zero randomization factor.
    #[test]
    fn test_create_backoff_config() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
            use_jitter: false,
        };
        let backoff = create_backoff(&config);
        assert_eq!(backoff.initial_interval, Duration::from_millis(100));
        assert_eq!(backoff.max_interval, Duration::from_millis(10_000));
        assert_eq!(backoff.multiplier, 2.0);
        assert_eq!(backoff.randomization_factor, 0.0);
    }

    /// Enabling jitter yields the 25% randomization factor.
    #[test]
    fn test_create_backoff_with_jitter() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 10_000,
            backoff_multiplier: 2.0,
            use_jitter: true,
        };
        assert_eq!(create_backoff(&config).randomization_factor, 0.25);
    }

    /// Sanity-check the preset constructors.
    #[test]
    fn test_retry_config_presets() {
        let fast = RetryConfig::fast();
        assert_eq!(fast.base_delay_ms, 100);
        assert_eq!(fast.max_retries, 5);
        let slow = RetryConfig::slow();
        assert_eq!(slow.base_delay_ms, 5000);
        assert_eq!(slow.max_retries, 3);
    }
}