use std::time::Duration;
/// Configuration for retrying with a capped, optionally jittered backoff delay.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts. NOTE(review): not read in this section;
    /// whether this counts the initial attempt must be confirmed at the call site.
    pub max_retries: u32,
    /// Starting backoff delay. Presumably the caller passes this as
    /// `current_backoff` to `compute_backoff` for the first retry — TODO confirm.
    pub initial_backoff: Duration,
    /// Upper bound applied to `current_backoff` in `compute_backoff`. It is applied
    /// before jitter and is not applied to an explicit `retry_after` value.
    pub max_backoff: Duration,
    /// Growth factor for the backoff. Not read by `compute_backoff`; presumably the
    /// caller multiplies `current_backoff` by it between attempts — TODO confirm.
    pub backoff_multiplier: f64,
    /// When `true`, `compute_backoff` randomises the capped delay by roughly ±20%.
    pub jitter: bool,
    /// Presumably whether HTTP 401 responses are retried. NOTE(review): not read in
    /// this section; meaning inferred from the name only.
    pub retry_on_401: bool,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_backoff: Duration::from_secs(1),
max_backoff: Duration::from_secs(30),
backoff_multiplier: 2.0,
jitter: true,
retry_on_401: true,
}
}
}
impl RetryConfig {
    /// Fraction of the base delay by which jitter may move it in either direction (±20%).
    const JITTER_FRACTION: f64 = 0.2;

    /// Returns the delay to wait before the next retry attempt.
    ///
    /// An explicit `retry_after` always wins. It is returned verbatim, neither capped by
    /// `max_backoff` nor jittered. It is presumably a server-provided hint.
    ///
    /// Otherwise `current_backoff` is capped at `max_backoff`. When `jitter` is enabled,
    /// the capped value is then randomised by up to ±20%, so the result can slightly
    /// exceed `max_backoff`.
    ///
    /// This method does not use `backoff_multiplier`; advancing `current_backoff`
    /// between attempts is the caller's responsibility.
    pub fn compute_backoff(
        &self,
        current_backoff: Duration,
        retry_after: Option<Duration>,
    ) -> Duration {
        if let Some(retry_after) = retry_after {
            return retry_after;
        }
        let backoff = current_backoff.min(self.max_backoff);
        if self.jitter {
            self.add_jitter(backoff)
        } else {
            backoff
        }
    }

    /// Scales `duration` by a random factor in `[1 - JITTER_FRACTION, 1 + JITTER_FRACTION)`.
    ///
    /// Works at nanosecond precision, so sub-millisecond delays are not truncated to
    /// zero. Saturates at `Duration::MAX` instead of panicking on overflow.
    fn add_jitter(&self, duration: Duration) -> Duration {
        use std::collections::hash_map::RandomState;
        use std::hash::BuildHasher;

        // `RandomState` is randomly keyed by std, so hashing the current time with a
        // fresh instance gives a cheap, non-cryptographic random u64 without an RNG crate.
        let bits = RandomState::new().hash_one(std::time::SystemTime::now());
        // Keep the top 53 bits (the f64 mantissa width) for a uniform value in [0, 1).
        // A small modulus would instead quantise the jitter and skew it downward.
        let unit = (bits >> 11) as f64 / (1u64 << 53) as f64;
        // `unit * 2.0 - 1.0` lies in [-1, 1), so `factor` lies in [0.8, 1.2) and is never negative.
        let factor = 1.0 + Self::JITTER_FRACTION * (unit * 2.0 - 1.0);
        // The only possible failure is overflow (huge inputs such as Duration::MAX); saturate.
        Duration::try_from_secs_f64(duration.as_secs_f64() * factor).unwrap_or(Duration::MAX)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 10 s delay run through `add_jitter` must land within ±20%, i.e. 8 000–12 000 ms.
    #[test]
    fn jitter_stays_within_bounds() {
        let config = RetryConfig::default();
        let allowed = 8_000..=12_000;
        (0..100).for_each(|_| {
            let millis = config.add_jitter(Duration::from_secs(10)).as_millis();
            assert!(allowed.contains(&millis));
        });
    }
}