use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::time::Duration;
use tokio::time::sleep;
/// How the un-jittered delay grows from one retry to the next.
///
/// The descriptions below follow `RetryPolicy::calculate_delay`, where
/// `attempt` is the 1-based number of the retry about to be scheduled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackoffStrategy {
/// Always wait `base_delay`.
Fixed,
/// Wait `base_delay * backoff_multiplier^(attempt - 1)`.
Exponential,
/// Wait `base_delay * attempt`.
Linear,
}
/// Randomisation applied to the delay after it has been capped at `max_delay`.
///
/// "Capped delay" below means the strategy delay limited to `max_delay`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum JitterType {
/// Use the capped delay unchanged.
None,
/// Scale the capped delay by a random factor drawn from `fastrand::f64()`.
Full,
/// Keep half of the capped delay and add a random share of the other half.
Equal,
/// Derive the delay from the previous attempt's delay instead of the
/// strategy delay: `base + random * (3 * previous - base)`, limited to
/// `max_delay`. For the first retry the previous delay is `base_delay`.
Decorrelated,
}
/// Configuration for retrying a fallible asynchronous operation.
///
/// All fields are public, so a policy can be built literally or through the
/// `exponential` / `fixed` / `linear` constructors and `with_*` builders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
/// Upper bound on how many times the operation is invoked, counting the
/// first call.
pub max_attempts: u32,
/// Starting delay fed into the backoff strategy. Delay arithmetic works in
/// whole milliseconds, so any sub-millisecond part is ignored.
pub base_delay: Duration,
/// Ceiling applied to every computed delay.
pub max_delay: Duration,
/// Growth rule for the delay between attempts.
pub strategy: BackoffStrategy,
/// Randomisation applied to the computed delay.
pub jitter: JitterType,
/// Growth factor; only read by `BackoffStrategy::Exponential`.
pub backoff_multiplier: f64,
/// Optional wall-clock budget for one `retry` call, measured from its
/// start. Once it has elapsed no further attempt is started. `None` means
/// no overall limit.
pub total_timeout: Option<Duration>,
}
impl RetryPolicy {
    /// Creates an exponential-backoff policy.
    ///
    /// Defaults: delays capped at 60 seconds, [`JitterType::Equal`], a
    /// multiplier of `2.0` and no total timeout.
    pub fn exponential(base_delay: Duration, max_attempts: u32) -> Self {
        Self {
            max_attempts,
            base_delay,
            max_delay: Duration::from_secs(60),
            strategy: BackoffStrategy::Exponential,
            jitter: JitterType::Equal,
            backoff_multiplier: 2.0,
            total_timeout: None,
        }
    }

    /// Creates a policy that waits exactly `delay` between attempts.
    ///
    /// `max_delay` is set to `delay`, jitter is disabled and there is no total
    /// timeout.
    pub fn fixed(delay: Duration, max_attempts: u32) -> Self {
        Self {
            max_attempts,
            base_delay: delay,
            max_delay: delay,
            strategy: BackoffStrategy::Fixed,
            jitter: JitterType::None,
            backoff_multiplier: 1.0,
            total_timeout: None,
        }
    }

    /// Creates a linear-backoff policy (`base_delay * attempt`).
    ///
    /// Defaults: delays capped at 60 seconds, [`JitterType::Equal`] and no
    /// total timeout. The multiplier is stored as `1.0` but is not read by the
    /// linear strategy.
    pub fn linear(base_delay: Duration, max_attempts: u32) -> Self {
        Self {
            max_attempts,
            base_delay,
            max_delay: Duration::from_secs(60),
            strategy: BackoffStrategy::Linear,
            jitter: JitterType::Equal,
            backoff_multiplier: 1.0,
            total_timeout: None,
        }
    }

    /// Returns the policy with its per-delay ceiling replaced by `max_delay`.
    pub fn with_max_delay(mut self, max_delay: Duration) -> Self {
        self.max_delay = max_delay;
        self
    }

    /// Returns the policy with its jitter mode replaced by `jitter`.
    pub fn with_jitter(mut self, jitter: JitterType) -> Self {
        self.jitter = jitter;
        self
    }

    /// Returns the policy with a total time budget of `timeout` per
    /// [`retry`](Self::retry) call.
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.total_timeout = Some(timeout);
        self
    }

    /// Returns the policy with its exponential growth factor replaced by
    /// `multiplier`.
    pub fn with_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Computes how long to wait before retry number `attempt` (1-based).
    ///
    /// `attempt == 0` yields a zero delay. Otherwise the strategy delay is
    /// computed in milliseconds, capped at `max_delay`, and then adjusted by
    /// the configured jitter. The result is truncated to whole milliseconds.
    fn calculate_delay(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }

        let base_ms = self.base_delay.as_millis() as f64;
        let max_ms = self.max_delay.as_millis() as f64;

        let computed_delay_ms = match self.strategy {
            BackoffStrategy::Fixed => base_ms,
            BackoffStrategy::Exponential => {
                // Saturate instead of wrapping: a plain `as i32` cast turns
                // attempts above `i32::MAX` into a negative exponent.
                let exponent = i32::try_from(attempt - 1).unwrap_or(i32::MAX);
                base_ms * self.backoff_multiplier.powi(exponent)
            }
            BackoffStrategy::Linear => base_ms * f64::from(attempt),
        };
        let capped_ms = computed_delay_ms.min(max_ms);

        let final_ms = match self.jitter {
            JitterType::None => capped_ms,
            JitterType::Full => fastrand::f64() * capped_ms,
            JitterType::Equal => capped_ms / 2.0 + (fastrand::f64() * capped_ms / 2.0),
            JitterType::Decorrelated => {
                // Each delay is drawn from the previous one, starting at the
                // base delay. The chain is replayed with a loop rather than
                // by recursion so stack depth does not grow with `attempt`.
                let mut delay_ms = base_ms;
                for _ in 0..attempt {
                    let spread = delay_ms * 3.0 - base_ms;
                    // Intermediate delays are kept at whole-millisecond
                    // precision, the same precision as the returned value.
                    delay_ms = (base_ms + fastrand::f64() * spread).min(max_ms).trunc();
                }
                delay_ms
            }
        };

        // Float-to-integer `as` casts saturate and map NaN to 0.
        Duration::from_millis(final_ms as u64)
    }

    /// Runs `f` until it succeeds, the attempts are used up, or the total
    /// timeout elapses.
    ///
    /// `f` is called up to `max_attempts` times. After each failure except the
    /// last, the task sleeps for the backoff delay. When `total_timeout` is
    /// set, that sleep is shortened to the time remaining in the budget, so
    /// waiting never carries the call past its deadline. The operation itself
    /// is not interrupted.
    ///
    /// # Errors
    ///
    /// - `"Retry timeout exceeded after N attempts"` when the budget has
    ///   elapsed before an attempt would start.
    /// - `"Operation failed after N attempts: <last error>"` when every
    ///   attempt failed.
    /// - `"Operation failed after N attempts"` when no attempt was made
    ///   (`max_attempts == 0`).
    pub async fn retry<F, Fut, T, E>(&self, mut f: F) -> Result<T>
    where
        F: FnMut() -> Fut,
        Fut: Future<Output = Result<T, E>>,
        E: std::error::Error + Send + Sync + 'static,
    {
        let start_time = std::time::Instant::now();
        let mut last_error = None;

        for attempt in 0..self.max_attempts {
            if let Some(timeout) = self.total_timeout {
                if start_time.elapsed() >= timeout {
                    return Err(anyhow!("Retry timeout exceeded after {attempt} attempts"));
                }
            }

            match f().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    last_error = Some(e);
                    // No point sleeping after the final attempt.
                    if attempt + 1 < self.max_attempts {
                        let mut delay = self.calculate_delay(attempt + 1);
                        if let Some(timeout) = self.total_timeout {
                            // Never sleep past the overall deadline.
                            let remaining = timeout.saturating_sub(start_time.elapsed());
                            delay = delay.min(remaining);
                        }
                        sleep(delay).await;
                    }
                }
            }
        }

        match last_error {
            Some(e) => Err(anyhow!(
                "Operation failed after {} attempts: {}",
                self.max_attempts,
                e
            )),
            None => Err(anyhow!(
                "Operation failed after {} attempts",
                self.max_attempts
            )),
        }
    }
}
impl Default for RetryPolicy {
    /// Builds the default policy: exponential backoff starting at 100 ms,
    /// with at most 3 attempts.
    fn default() -> Self {
        let base_delay = Duration::from_millis(100);
        let max_attempts = 3;
        Self::exponential(base_delay, max_attempts)
    }
}
/// Adapter trait for values that can be driven to completion under a
/// [`RetryPolicy`].
///
/// `T` is the success type of the returned future. `E` does not appear in the
/// method signature. No implementation is visible in this part of the file.
pub trait Retryable<T, E> {
/// Consumes `self` and returns a future that resolves to `Ok(T)` or an
/// `anyhow` error, using `policy` to govern retries.
fn with_retry(self, policy: RetryPolicy) -> impl Future<Output = Result<T>>;
}
/// Aggregate counters about retry activity; `Default` starts them all at zero.
///
/// NOTE(review): nothing in the visible code reads or updates these fields, so
/// the per-field descriptions are inferred from the names. Confirm against the
/// code that maintains them.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct RetryStats {
/// Presumably the number of individual attempts made.
pub total_attempts: u64,
/// Presumably the number of operations that eventually succeeded.
pub successful_ops: u64,
/// Presumably the number of operations that exhausted their retries.
pub failed_ops: u64,
/// Presumably the accumulated backoff sleep time, in milliseconds.
pub total_delay_ms: u64,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Builds the I/O error used to simulate a failing operation.
    fn io_failure(message: &'static str) -> std::io::Error {
        std::io::Error::other(message)
    }

    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let policy = RetryPolicy::exponential(Duration::from_millis(10), 3);
        let result = policy
            .retry(|| async { Ok::<_, std::io::Error>("success") })
            .await;
        assert_eq!(result.unwrap(), "success");
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        let counter = Arc::new(AtomicU32::new(0));
        let policy = RetryPolicy::exponential(Duration::from_millis(10), 3);
        let counter_clone = counter.clone();
        let result = policy
            .retry(|| {
                let c = counter_clone.clone();
                async move {
                    // The first two calls fail, the third succeeds.
                    if c.fetch_add(1, Ordering::SeqCst) < 2 {
                        Err(io_failure("Transient failure"))
                    } else {
                        Ok("success")
                    }
                }
            })
            .await;
        assert_eq!(result.unwrap(), "success");
        assert_eq!(counter.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_all_attempts_fail() {
        let policy = RetryPolicy::exponential(Duration::from_millis(10), 3);
        let result = policy
            .retry(|| async { Err::<&str, std::io::Error>(io_failure("Always fails")) })
            .await;
        // The final error reports the attempt count and the last failure.
        let message = result.unwrap_err().to_string();
        assert!(message.contains("3 attempts"));
        assert!(message.contains("Always fails"));
    }

    // The delay calculations below are synchronous, so they run as plain
    // tests without an async runtime.

    #[test]
    fn test_fixed_backoff() {
        let policy = RetryPolicy::fixed(Duration::from_millis(50), 3);
        for attempt in 1..=3 {
            assert_eq!(policy.calculate_delay(attempt).as_millis(), 50);
        }
    }

    #[test]
    fn test_exponential_backoff() {
        let policy = RetryPolicy::exponential(Duration::from_millis(100), 4)
            .with_jitter(JitterType::None);
        assert_eq!(policy.calculate_delay(1).as_millis(), 100);
        assert_eq!(policy.calculate_delay(2).as_millis(), 200);
        assert_eq!(policy.calculate_delay(3).as_millis(), 400);
    }

    #[test]
    fn test_linear_backoff() {
        let policy =
            RetryPolicy::linear(Duration::from_millis(100), 4).with_jitter(JitterType::None);
        assert_eq!(policy.calculate_delay(1).as_millis(), 100);
        assert_eq!(policy.calculate_delay(2).as_millis(), 200);
        assert_eq!(policy.calculate_delay(3).as_millis(), 300);
    }

    #[test]
    fn test_max_delay_cap() {
        let policy = RetryPolicy::exponential(Duration::from_millis(100), 10)
            .with_max_delay(Duration::from_millis(500))
            .with_jitter(JitterType::None);
        // 100 ms * 2^4 = 1600 ms, which must be clamped to exactly the cap.
        assert_eq!(policy.calculate_delay(5).as_millis(), 500);
    }

    #[test]
    fn test_jitter_full() {
        let policy =
            RetryPolicy::exponential(Duration::from_millis(100), 3).with_jitter(JitterType::Full);
        for _ in 0..10 {
            assert!(policy.calculate_delay(1).as_millis() <= 100);
        }
    }

    #[test]
    fn test_jitter_equal() {
        let policy =
            RetryPolicy::exponential(Duration::from_millis(100), 3).with_jitter(JitterType::Equal);
        for _ in 0..10 {
            let ms = policy.calculate_delay(1).as_millis();
            assert!((50..=100).contains(&ms));
        }
    }

    #[test]
    fn test_jitter_decorrelated_stays_within_bounds() {
        let policy = RetryPolicy::exponential(Duration::from_millis(100), 5)
            .with_jitter(JitterType::Decorrelated)
            .with_max_delay(Duration::from_millis(1000));
        for _ in 0..10 {
            for attempt in 1..=5 {
                let ms = policy.calculate_delay(attempt).as_millis();
                // Never below the base delay, never above the cap.
                assert!((100..=1000).contains(&ms));
            }
        }
    }

    #[tokio::test]
    async fn test_timeout() {
        let policy = RetryPolicy::exponential(Duration::from_millis(50), 10)
            .with_timeout(Duration::from_millis(150));
        let start = std::time::Instant::now();
        let result = policy
            .retry(|| async { Err::<&str, std::io::Error>(io_failure("Always fails")) })
            .await;
        let elapsed = start.elapsed();
        assert!(result.is_err());
        assert!(elapsed < Duration::from_millis(500));
        assert!(elapsed >= Duration::from_millis(150));
    }
}