use std::time::Duration;
use crate::error::Error;
/// Decides whether a failed operation should be retried and how long to wait first.
///
/// Start from [`RetryPolicy::default`] or [`RetryPolicy::none`] and adjust the
/// settings with the `with_*` builder methods. The type is `Copy` and compares by
/// value, so policies can be stored, passed around and asserted on freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RetryPolicy {
    /// Maximum number of retries. [`RetryPolicy::should_retry`] returns `false`
    /// once `retries_done` reaches this value; `0` disables retrying entirely.
    pub max_retries: u32,
    /// Delay used for the first retry (`retries_done == 0`). Each later retry
    /// doubles it (saturating) before `max_delay` and jitter are applied.
    pub base_delay: Duration,
    /// Upper bound on every delay this policy produces, including hints taken
    /// from `Error::retry_after`.
    pub max_delay: Duration,
    /// When `true`, the capped exponential delay is scaled by a pseudo-random
    /// fraction in `[0, 1)` ("full jitter"). `Error::retry_after` hints are
    /// never jittered.
    pub jitter: bool,
}
impl Default for RetryPolicy {
fn default() -> Self {
RetryPolicy {
max_retries: 2,
base_delay: Duration::from_millis(500),
max_delay: Duration::from_secs(20),
jitter: true,
}
}
}
impl RetryPolicy {
    /// Returns a policy that never retries: the default settings with
    /// `max_retries` forced to zero.
    #[must_use]
    pub fn none() -> Self {
        Self::default().with_max_retries(0)
    }

    /// Returns a copy of this policy with `max_retries` replaced.
    #[must_use]
    pub fn with_max_retries(self, max_retries: u32) -> Self {
        Self { max_retries, ..self }
    }

    /// Returns a copy of this policy with `base_delay` replaced.
    #[must_use]
    pub fn with_base_delay(self, base_delay: Duration) -> Self {
        Self { base_delay, ..self }
    }

    /// Returns a copy of this policy with `max_delay` replaced.
    #[must_use]
    pub fn with_max_delay(self, max_delay: Duration) -> Self {
        Self { max_delay, ..self }
    }

    /// Returns a copy of this policy with jitter switched on or off.
    #[must_use]
    pub fn with_jitter(self, jitter: bool) -> Self {
        Self { jitter, ..self }
    }

    /// Reports whether another attempt should be made after `err`.
    ///
    /// `retries_done` is the number of retries already performed. The error is
    /// only consulted (`Error::is_retryable`) while retry budget remains.
    pub fn should_retry(&self, retries_done: u32, err: &Error) -> bool {
        let budget_left = retries_done < self.max_retries;
        budget_left && err.is_retryable()
    }

    /// Computes how long to sleep before the next retry.
    ///
    /// A hint from `Error::retry_after` takes precedence and is only capped at
    /// `max_delay` (never jittered). Without a hint, the exponential backoff for
    /// `retries_done` is used.
    pub fn delay_for(&self, retries_done: u32, err: &Error) -> Duration {
        match err.retry_after() {
            Some(hint) => hint.min(self.max_delay),
            None => self.backoff(retries_done),
        }
    }

    /// Exponential backoff: `base_delay * 2^n`, capped at `max_delay`, then
    /// optionally scaled by a random fraction in `[0, 1)`.
    fn backoff(&self, n: u32) -> Duration {
        // Both the power and the multiplication saturate, so huge `n` simply
        // hits the `max_delay` cap instead of overflowing.
        let multiplier = 2u32.saturating_pow(n);
        let capped = self.base_delay.saturating_mul(multiplier).min(self.max_delay);
        if !self.jitter {
            return capped;
        }
        capped.mul_f64(jitter_fraction())
    }
}
/// Returns a pseudo-random fraction in `[0.0, 1.0)` used to scale backoff delays.
///
/// This is not cryptographically secure; it only needs to de-correlate the retry
/// timing of concurrent callers. The bits come from std's randomly keyed
/// `RandomState` (SipHash), fed with the wall clock and a call counter. As a
/// result, back-to-back calls within one clock tick still produce unrelated
/// values. A single xorshift round over `nanos ^ counter` left them differing in
/// only a few fixed bit positions.
fn jitter_fraction() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::time::{SystemTime, UNIX_EPOCH};

    // Bit pattern of 1.0_f64: sign 0, biased exponent 0x3FF, zero mantissa.
    const ONE_BITS: u64 = 0x3FF0_0000_0000_0000;
    // AtomicU32 (not U64) keeps this portable to targets without 64-bit atomics.
    static COUNTER: AtomicU32 = AtomicU32::new(0);

    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_nanos());
    let mut hasher = RandomState::new().build_hasher();
    hasher.write_u128(nanos);
    hasher.write_u32(COUNTER.fetch_add(1, Ordering::Relaxed));
    let bits = hasher.finish();

    // Place the top 52 hash bits in the mantissa of a double in [1.0, 2.0), then
    // subtract 1.0. This is uniform over [0.0, 1.0), with no modulo bias and no
    // lossy integer-to-float cast.
    f64::from_bits(ONE_BITS | (bits >> 12)) - 1.0
}
#[cfg(test)]
// Test code may unwrap, panic and print freely; these lints target library code.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::print_stdout
)]
mod tests {
    use super::*;

    /// Builds an `Error::Api` with the given status and optional retry hint.
    fn api_err(status: u16, retry_after: Option<Duration>) -> Error {
        Error::Api {
            status,
            code: None,
            message: "x".into(),
            error: None,
            retry_after,
        }
    }

    #[test]
    fn does_not_retry_non_retryable() {
        let p = RetryPolicy::default();
        assert!(!p.should_retry(0, &api_err(400, None)));
        assert!(!p.should_retry(0, &Error::Decode("x".into())));
    }

    #[test]
    fn retries_transient_within_budget() {
        let p = RetryPolicy::default().with_max_retries(2);
        assert!(p.should_retry(0, &api_err(429, None)));
        assert!(p.should_retry(1, &api_err(503, None)));
        assert!(!p.should_retry(2, &api_err(503, None)));
    }

    #[test]
    fn none_disables_retry() {
        let p = RetryPolicy::none();
        assert!(!p.should_retry(0, &api_err(503, None)));
    }

    #[test]
    fn retry_after_hint_is_honored_and_clamped() {
        let p = RetryPolicy::default()
            .with_jitter(false)
            .with_max_delay(Duration::from_secs(5));
        let d = p.delay_for(0, &api_err(429, Some(Duration::from_secs(2))));
        assert_eq!(d, Duration::from_secs(2));
        let clamped = p.delay_for(0, &api_err(429, Some(Duration::from_secs(45))));
        assert_eq!(clamped, Duration::from_secs(5));
    }

    #[test]
    fn retry_after_hint_is_not_jittered() {
        // Server-provided hints must be used as-is (only capped), even with jitter on.
        let p = RetryPolicy::default().with_jitter(true);
        for _ in 0..100 {
            assert_eq!(
                p.delay_for(0, &api_err(429, Some(Duration::from_secs(2)))),
                Duration::from_secs(2)
            );
        }
    }

    #[test]
    fn backoff_grows_and_clamps_without_jitter() {
        let p = RetryPolicy::default()
            .with_jitter(false)
            .with_base_delay(Duration::from_millis(100))
            .with_max_delay(Duration::from_secs(1));
        assert_eq!(
            p.delay_for(0, &api_err(503, None)),
            Duration::from_millis(100)
        );
        assert_eq!(
            p.delay_for(1, &api_err(503, None)),
            Duration::from_millis(200)
        );
        assert_eq!(
            p.delay_for(2, &api_err(503, None)),
            Duration::from_millis(400)
        );
        assert_eq!(p.delay_for(4, &api_err(503, None)), Duration::from_secs(1));
    }

    #[test]
    fn backoff_saturates_for_huge_retry_counts() {
        // 2^n and base * 2^n both saturate, so enormous counts clamp instead of panicking.
        let p = RetryPolicy::default()
            .with_jitter(false)
            .with_base_delay(Duration::from_millis(100))
            .with_max_delay(Duration::from_secs(30));
        assert_eq!(p.delay_for(40, &api_err(503, None)), Duration::from_secs(30));
        assert_eq!(
            p.delay_for(u32::MAX, &api_err(503, None)),
            Duration::from_secs(30)
        );
    }

    #[test]
    fn jitter_stays_within_bounds() {
        let p = RetryPolicy::default()
            .with_jitter(true)
            .with_base_delay(Duration::from_millis(100))
            .with_max_delay(Duration::from_secs(10));
        for _ in 0..1000 {
            let d = p.delay_for(1, &api_err(503, None));
            assert!(
                d <= Duration::from_millis(200),
                "jitter exceeded raw: {d:?}"
            );
        }
    }
}