use crate::error::ZerobusError;
use rand::Rng;
use std::time::Duration;
use tokio::time::sleep;
/// Retry policy: how many times an operation may be attempted and how long to
/// back off between consecutive attempts.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Maximum number of attempts, counting the initial call.
    pub max_attempts: u32,
    /// Starting backoff in milliseconds; doubled for each subsequent retry.
    pub base_delay_ms: u64,
    /// Upper bound in milliseconds applied to the computed backoff.
    pub max_delay_ms: u64,
    /// When `true`, the actual wait is randomized between zero and the computed backoff.
    pub jitter: bool,
}

impl Default for RetryConfig {
    /// Five attempts, a 100 ms base delay capped at 30 s, with jitter enabled.
    fn default() -> Self {
        let max_attempts = 5;
        let base_delay_ms = 100;
        let max_delay_ms = 30_000;
        let jitter = true;
        RetryConfig {
            max_attempts,
            base_delay_ms,
            max_delay_ms,
            jitter,
        }
    }
}
impl RetryConfig {
    /// Creates a retry policy with the given attempt budget and delay bounds.
    ///
    /// Jitter is enabled; set the `jitter` field to `false` afterwards for
    /// deterministic delays.
    pub fn new(max_attempts: u32, base_delay_ms: u64, max_delay_ms: u64) -> Self {
        Self {
            max_attempts,
            base_delay_ms,
            max_delay_ms,
            jitter: true,
        }
    }

    /// Runs `f` until it succeeds, fails with a non-retryable error, or the
    /// attempt budget is used up.
    ///
    /// Equivalent to [`execute_with_retry_tracked`](Self::execute_with_retry_tracked)
    /// with the attempt count discarded.
    ///
    /// # Errors
    ///
    /// Same as [`execute_with_retry_tracked`](Self::execute_with_retry_tracked).
    pub async fn execute_with_retry<F, Fut, T>(&self, f: F) -> Result<T, ZerobusError>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, ZerobusError>>,
    {
        self.execute_with_retry_tracked(f).await.0
    }

    /// Runs `f` with retries and reports how many times `f` was invoked.
    ///
    /// `f` is always called at least once: a `max_attempts` of 0 is treated as 1.
    /// Without that clamp, the operation would be skipped entirely and reported as
    /// exhausted with an "unknown" error.
    ///
    /// After each retryable failure, except the last one, the task sleeps for
    /// [`calculate_delay`](Self::calculate_delay). No sleep follows the final attempt.
    ///
    /// Returns the outcome paired with the number of invocations of `f`.
    ///
    /// # Errors
    ///
    /// * The error produced by `f`, unchanged, as soon as
    ///   `ZerobusError::is_retryable` returns `false` for it.
    /// * `ZerobusError::RetryExhausted`, whose message includes the last error,
    ///   when every attempt failed with a retryable error.
    pub async fn execute_with_retry_tracked<F, Fut, T>(
        &self,
        mut f: F,
    ) -> (Result<T, ZerobusError>, u32)
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = Result<T, ZerobusError>>,
    {
        // A zero budget would never run the operation; guarantee one attempt.
        let max_attempts = self.max_attempts.max(1);
        let mut last_error: Option<ZerobusError> = None;

        for attempt in 0..max_attempts {
            let attempt_number = attempt + 1;
            match f().await {
                Ok(value) => return (Ok(value), attempt_number),
                // Non-retryable errors are surfaced immediately, without cloning.
                Err(e) if !e.is_retryable() => return (Err(e), attempt_number),
                Err(e) => {
                    // Take ownership of the error rather than cloning it per failure.
                    last_error = Some(e);
                    if attempt_number < max_attempts {
                        sleep(self.calculate_delay(attempt)).await;
                    }
                }
            }
        }

        let last_message = last_error
            .map(|e| e.to_string())
            .unwrap_or_else(|| "unknown".to_string());
        (
            Err(ZerobusError::RetryExhausted(format!(
                "All {} retry attempts exhausted. Last error: {}",
                max_attempts, last_message
            ))),
            max_attempts,
        )
    }

    /// Computes the backoff to wait after the failed zero-based `attempt`.
    ///
    /// The delay is `base_delay_ms * 2^min(attempt, 20)`, saturating on
    /// overflow, capped at `max_delay_ms`. When `jitter` is set, the actual delay
    /// is drawn uniformly from `0..=capped`.
    fn calculate_delay(&self, attempt: u32) -> Duration {
        // Clamp the exponent so the shift can never exceed the width of u64.
        let multiplier = 1u64 << attempt.min(20);
        let exponential_delay_ms = self.base_delay_ms.saturating_mul(multiplier);
        let capped_delay_ms = exponential_delay_ms.min(self.max_delay_ms);
        let delay_ms = if self.jitter {
            rand::thread_rng().gen_range(0..=capped_delay_ms)
        } else {
            capped_delay_ms
        };
        Duration::from_millis(delay_ms)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An operation that succeeds on the first call yields its value untouched.
    #[tokio::test]
    async fn test_retry_succeeds_on_first_attempt() {
        let outcome = RetryConfig::default()
            .execute_with_retry(|| async { Ok::<_, ZerobusError>("success".to_string()) })
            .await;
        let value = outcome.expect("first attempt should succeed");
        assert_eq!(value, "success");
    }

    /// An operation that keeps failing with a connection error is invoked
    /// exactly `max_attempts` times and then reported as `RetryExhausted`.
    #[tokio::test]
    async fn test_retry_exhausted_after_max_attempts() {
        let policy = RetryConfig::new(3, 10, 1000);
        let mut calls = 0;
        let outcome = policy
            .execute_with_retry(|| {
                calls += 1;
                async { Err::<String, _>(ZerobusError::ConnectionError("test error".to_string())) }
            })
            .await;
        let err = outcome.expect_err("every attempt fails, so retries must be exhausted");
        assert!(
            matches!(err, ZerobusError::RetryExhausted(_)),
            "expected RetryExhausted, got {:?}",
            err
        );
        assert_eq!(calls, 3);
    }
}