use crate::{RainyError, Result};
use std::time::Duration;
use tokio::time::sleep;
/// Settings that control how often and how patiently an operation is retried.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Retries allowed after the initial attempt; the operation runs at most
    /// `max_retries + 1` times.
    pub max_retries: u32,
    /// Starting delay in milliseconds, used as-is before the first retry.
    pub base_delay_ms: u64,
    /// Ceiling, in milliseconds, applied to every computed delay.
    pub max_delay_ms: u64,
    /// Factor the delay is multiplied by for each successive attempt.
    pub backoff_multiplier: f64,
    /// When `true`, retry delays are randomly scaled (jitter).
    pub jitter: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
base_delay_ms: 1000,
max_delay_ms: 30000,
backoff_multiplier: 2.0,
jitter: true,
}
}
}
impl RetryConfig {
    /// Creates a configuration with the given retry budget; every other field
    /// comes from [`RetryConfig::default`].
    pub fn new(max_retries: u32) -> Self {
        Self {
            max_retries,
            ..Default::default()
        }
    }

    /// Returns how long to wait after failed attempt number `attempt`
    /// (0-based), before the next try.
    ///
    /// The nominal delay is `base_delay_ms * backoff_multiplier^attempt`,
    /// capped at `max_delay_ms`. When `jitter` is enabled and `attempt > 0`,
    /// the capped delay is multiplied by a random factor in `[0.75, 1.25]`
    /// and capped again, so the result never exceeds `max_delay_ms`.
    pub fn delay_for_attempt(&self, attempt: u32) -> Duration {
        let max_delay = self.max_delay_ms as f64;
        // `attempt as i32` would wrap to a negative exponent above i32::MAX and
        // shrink the delay; saturating keeps huge attempt numbers at the cap.
        let exponent = i32::try_from(attempt).unwrap_or(i32::MAX);
        let nominal = self.base_delay_ms as f64 * self.backoff_multiplier.powi(exponent);
        // Cap before jittering: if jitter were applied first, every delay past
        // the cap would collapse to exactly `max_delay_ms`, removing the spread
        // that jitter is meant to provide.
        let mut delay = nominal.min(max_delay);
        if self.jitter && attempt > 0 {
            use rand::RngExt;
            let mut rng = rand::rng();
            let jitter_factor: f64 = rng.random_range(0.75..=1.25);
            delay = (delay * jitter_factor).min(max_delay);
        }
        Duration::from_millis(delay as u64)
    }
}
/// Runs `operation` until it succeeds, fails with a non-retryable error, or the
/// retry budget in `config` is used up.
///
/// `operation` is called once per attempt to produce a fresh future. It is
/// invoked at most `config.max_retries + 1` times: one initial attempt plus up
/// to `max_retries` retries. After a retryable failure on attempt `n`
/// (0-based), the task sleeps for `config.delay_for_attempt(n)` before trying
/// again. Only `FnMut` is required, so closures that update captured state
/// (e.g. an attempt counter) are accepted.
///
/// # Errors
///
/// Returns the first error whose `is_retryable()` is `false`, or the error from
/// the final permitted attempt.
pub async fn retry_with_backoff<F, Fut, T>(config: &RetryConfig, mut operation: F) -> Result<T>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T>>,
{
    for attempt in 0..=config.max_retries {
        let error = match operation().await {
            Ok(result) => return Ok(result),
            Err(error) => error,
        };
        // Permanent failures and the last permitted attempt end the loop at
        // once; there is nothing left to wait for.
        if !error.is_retryable() || attempt == config.max_retries {
            return Err(error);
        }
        let delay = config.delay_for_attempt(attempt);
        #[cfg(feature = "tracing")]
        tracing::warn!(
            "Request failed (attempt {}/{}), retrying in {:?}: {}",
            attempt + 1,
            // Widen first: `max_retries + 1` in u32 overflows at u32::MAX.
            u64::from(config.max_retries) + 1,
            delay,
            error
        );
        sleep(delay).await;
    }
    // The final loop iteration always returns, so this is only reached if the
    // loop above is changed; it exists because a `for` loop evaluates to `()`.
    Err(RainyError::Network {
        message: "All retry attempts failed".to_string(),
        retryable: false,
        source_error: None,
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::default();
        let delay0 = config.delay_for_attempt(0);
        let delay1 = config.delay_for_attempt(1);
        let delay2 = config.delay_for_attempt(2);
        assert!(delay0.as_millis() >= 1000);
        assert!(delay1.as_millis() >= delay0.as_millis());
        assert!(delay2.as_millis() >= delay1.as_millis());
        assert!(delay2.as_millis() <= 30000);
    }

    /// With jitter off the delay is fully deterministic: exponential growth
    /// from `base_delay_ms`, clamped at `max_delay_ms`.
    #[test]
    fn test_delay_without_jitter_is_exact_and_capped() {
        let config = RetryConfig {
            max_retries: 5,
            base_delay_ms: 100,
            max_delay_ms: 1_000,
            backoff_multiplier: 2.0,
            jitter: false,
        };
        let expected_ms: [u64; 6] = [100, 200, 400, 800, 1_000, 1_000];
        for (attempt, &ms) in (0u32..).zip(expected_ms.iter()) {
            assert_eq!(
                config.delay_for_attempt(attempt),
                Duration::from_millis(ms),
                "attempt {}",
                attempt
            );
        }
        assert_eq!(config.delay_for_attempt(20), Duration::from_millis(1_000));
    }

    /// Below the cap, jitter keeps the delay within ±25% of the nominal value;
    /// at any attempt the delay never exceeds `max_delay_ms`.
    #[test]
    fn test_jitter_stays_within_bounds() {
        let config = RetryConfig::default();
        for attempt in 1..=4u32 {
            let nominal: u128 = 1_000 * 2u128.pow(attempt);
            let delay = config.delay_for_attempt(attempt).as_millis();
            assert!(delay >= nominal * 3 / 4, "attempt {}: {} ms", attempt, delay);
            assert!(delay <= nominal * 5 / 4, "attempt {}: {} ms", attempt, delay);
        }
        for attempt in 0..=16u32 {
            assert!(config.delay_for_attempt(attempt) <= Duration::from_millis(30_000));
        }
    }

    #[test]
    fn test_new_keeps_other_defaults() {
        let config = RetryConfig::new(7);
        let defaults = RetryConfig::default();
        assert_eq!(config.max_retries, 7);
        assert_eq!(config.base_delay_ms, defaults.base_delay_ms);
        assert_eq!(config.max_delay_ms, defaults.max_delay_ms);
        assert_eq!(config.backoff_multiplier, defaults.backoff_multiplier);
        assert_eq!(config.jitter, defaults.jitter);
    }
}