use std::time::{Duration, Instant};
use rand::Rng;
/// A stateful policy that decides whether an operation may be attempted
/// (again) and how long the caller should wait before that attempt.
///
/// Implementors must be `Send + Sync`, so a policy can be moved to or shared
/// with other threads.
pub trait RetryPolicy: Send + Sync {
    /// Number of attempts the policy has approved so far.
    fn attempt_count(&self) -> u32;

    /// Registers a new attempt, returning `true` when it may go ahead and
    /// `false` once the policy's budget is used up.
    fn should_retry(&mut self) -> bool;

    /// How long to wait before the upcoming attempt.
    fn next_sleep(&self) -> Duration;
}
/// Retry policy bounded by wall-clock time.
///
/// Attempts are approved until `max_duration` has elapsed since `start`. The
/// base sleep between attempts doubles after each retry, up to `max_sleep`.
#[derive(Debug, Clone)]
pub struct ExponentialTimeBoundedRetry {
    /// Total time budget, measured from `start`.
    max_duration: Duration,
    /// Upper bound for the base (pre-jitter) sleep.
    max_sleep: Duration,
    /// Instant the time budget is measured from.
    start: Instant,
    /// Number of attempts approved so far.
    attempts: u32,
    /// Current base sleep, before jitter is added.
    current_sleep: Duration,
}
impl ExponentialTimeBoundedRetry {
    /// Creates a policy whose time budget starts at this call.
    ///
    /// * `max_duration` – how long, measured from now, attempts are approved.
    /// * `initial_sleep` – base sleep before the first retry. It is clamped to
    ///   `max_sleep` so the cap also holds for the very first wait.
    /// * `max_sleep` – upper bound on the base sleep as it doubles.
    pub fn new(max_duration: Duration, initial_sleep: Duration, max_sleep: Duration) -> Self {
        Self {
            max_duration,
            max_sleep,
            start: Instant::now(),
            attempts: 0,
            // Without the clamp, an `initial_sleep` above `max_sleep` would be used
            // verbatim for the first retry; the cap only kicked in after doubling.
            current_sleep: std::cmp::min(initial_sleep, max_sleep),
        }
    }

    /// Creates a policy with a 120 s budget, 50 ms initial sleep and 3 s sleep cap.
    pub fn with_defaults() -> Self {
        Self::new(
            Duration::from_secs(120),
            Duration::from_millis(50),
            Duration::from_secs(3),
        )
    }
}
impl RetryPolicy for ExponentialTimeBoundedRetry {
    /// Always approves the first attempt. After that, attempts are approved only
    /// while less than `max_duration` has elapsed since `start`.
    ///
    /// From the third attempt onward the base sleep doubles, capped at `max_sleep`.
    /// So the wait before attempt 2 is the initial sleep, before attempt 3 twice
    /// that, and so on.
    fn should_retry(&mut self) -> bool {
        if self.attempts == 0 {
            self.attempts = 1;
            return true;
        }
        if self.start.elapsed() >= self.max_duration {
            return false;
        }
        // A large budget and a caller that does not sleep could push the counter
        // past `u32::MAX`. Saturate instead of overflowing, which would panic in
        // debug builds.
        self.attempts = self.attempts.saturating_add(1);
        if self.attempts > 2 {
            // The doubling happens before the cap is applied, so a huge `max_sleep`
            // (e.g. `Duration::MAX`) would otherwise overflow and panic.
            let doubled = self.current_sleep.saturating_mul(2);
            self.current_sleep = std::cmp::min(doubled, self.max_sleep);
        }
        true
    }

    fn attempt_count(&self) -> u32 {
        self.attempts
    }

    /// Current base sleep plus the jitter produced by `add_jitter`.
    fn next_sleep(&self) -> Duration {
        add_jitter(self.current_sleep)
    }
}
/// Retry policy bounded by attempt count.
///
/// It allows one initial attempt plus up to `max_retries` retries. The base
/// sleep between attempts doubles after each retry, up to `max_sleep`.
#[derive(Debug, Clone)]
pub struct ExponentialBackoffRetry {
    /// Upper bound for the base (pre-jitter) sleep.
    max_sleep: Duration,
    /// Maximum number of retries after the initial attempt.
    max_retries: u32,
    /// Number of attempts approved so far (including the initial one).
    attempts: u32,
    /// Current base sleep, before jitter is added.
    current_sleep: Duration,
}
impl ExponentialBackoffRetry {
    /// Creates a policy allowing one initial attempt plus `max_retries` retries.
    ///
    /// * `base_sleep` – base sleep before the first retry. It is clamped to
    ///   `max_sleep` so the cap also holds for the very first wait.
    /// * `max_sleep` – upper bound on the base sleep as it doubles.
    /// * `max_retries` – number of retries permitted after the initial attempt.
    pub fn new(base_sleep: Duration, max_sleep: Duration, max_retries: u32) -> Self {
        Self {
            max_sleep,
            max_retries,
            attempts: 0,
            // Without the clamp, a `base_sleep` above `max_sleep` would be used
            // verbatim for the first retry; the cap only kicked in after doubling.
            current_sleep: std::cmp::min(base_sleep, max_sleep),
        }
    }
}
impl RetryPolicy for ExponentialBackoffRetry {
    /// Approves the initial attempt plus up to `max_retries` further attempts.
    ///
    /// From the third attempt onward the base sleep doubles, capped at `max_sleep`.
    fn should_retry(&mut self) -> bool {
        if self.attempts == 0 {
            self.attempts = 1;
            return true;
        }
        // `attempts` already counts the initial try, so `max_retries + 1` attempts
        // are allowed in total. Comparing this way avoids computing
        // `max_retries + 1`, which could overflow.
        if self.attempts > self.max_retries {
            return false;
        }
        // With `max_retries == u32::MAX`, a plain `+= 1` would overflow and panic
        // in debug builds. Saturating keeps the policy effectively unbounded.
        self.attempts = self.attempts.saturating_add(1);
        if self.attempts > 2 {
            // The doubling happens before the cap is applied, so a huge `max_sleep`
            // (e.g. `Duration::MAX`) would otherwise overflow and panic.
            let doubled = self.current_sleep.saturating_mul(2);
            self.current_sleep = std::cmp::min(doubled, self.max_sleep);
        }
        true
    }

    fn attempt_count(&self) -> u32 {
        self.attempts
    }

    /// Current base sleep plus the jitter produced by `add_jitter`.
    fn next_sleep(&self) -> Duration {
        add_jitter(self.current_sleep)
    }
}
/// Returns `base` plus a random extra of at least 0 % and under 10 % of `base`.
///
/// The addition saturates at `Duration::MAX` instead of panicking, so very large
/// bases (e.g. an uncapped `max_sleep`) are handled safely.
fn add_jitter(base: Duration) -> Duration {
    let mut rng = rand::thread_rng();
    let jitter_fraction: f64 = rng.gen_range(0.0..0.1);
    // `mul_f64` with a fraction below 0.1 cannot overflow, but the sum can.
    base.saturating_add(base.mul_f64(jitter_fraction))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Generous upper bound for a jittered `base`: the jitter is below 10 %, and
    /// the extra 1 % leaves slack for floating-point rounding.
    fn with_max_jitter(base: Duration) -> Duration {
        base + base.mul_f64(0.11)
    }

    #[test]
    fn test_time_bounded_first_attempt_always_allowed() {
        // A zero budget still permits the initial attempt, but nothing after it.
        let mut policy = ExponentialTimeBoundedRetry::new(
            Duration::from_millis(0),
            Duration::from_millis(10),
            Duration::from_millis(100),
        );
        assert!(policy.should_retry());
        assert_eq!(policy.attempt_count(), 1);
        assert!(!policy.should_retry());
        assert_eq!(policy.attempt_count(), 1);
    }

    #[test]
    fn test_time_bounded_multiple_retries() {
        let mut policy = ExponentialTimeBoundedRetry::new(
            Duration::from_secs(10),
            Duration::from_millis(10),
            Duration::from_millis(200),
        );
        for _ in 0..5 {
            assert!(policy.should_retry());
        }
        assert_eq!(policy.attempt_count(), 5);
    }

    #[test]
    fn test_time_bounded_sleep_grows() {
        let initial = Duration::from_millis(50);
        let max_sleep = Duration::from_secs(3);
        let mut policy =
            ExponentialTimeBoundedRetry::new(Duration::from_secs(60), initial, max_sleep);

        // Attempts 1 and 2 use the initial sleep; attempt 3 doubles it.
        assert!(policy.should_retry());
        let s1 = policy.next_sleep();
        assert!(policy.should_retry());
        let s2 = policy.next_sleep();
        assert!(policy.should_retry());
        let s3 = policy.next_sleep();

        assert!(s1 >= initial && s1 <= with_max_jitter(initial));
        assert!(s2 >= initial && s2 <= with_max_jitter(initial));
        assert!(s3 >= initial * 2 && s3 <= with_max_jitter(initial * 2));
    }

    #[test]
    fn test_backoff_retry_max_retries() {
        // One initial attempt plus three retries.
        let mut policy = ExponentialBackoffRetry::new(
            Duration::from_millis(10),
            Duration::from_millis(100),
            3,
        );
        for _ in 0..4 {
            assert!(policy.should_retry());
        }
        assert!(!policy.should_retry());
        assert_eq!(policy.attempt_count(), 4);
    }

    #[test]
    fn test_backoff_retry_zero_retries() {
        let mut policy =
            ExponentialBackoffRetry::new(Duration::from_millis(10), Duration::from_millis(100), 0);
        assert!(policy.should_retry());
        assert!(!policy.should_retry());
        assert_eq!(policy.attempt_count(), 1);
    }

    #[test]
    fn test_backoff_sleep_capped() {
        let base = Duration::from_millis(50);
        let max_sleep = Duration::from_millis(100);
        let mut policy = ExponentialBackoffRetry::new(base, max_sleep, 10);
        for _ in 0..6 {
            assert!(policy.should_retry());
        }
        // By now the base sleep has doubled past the cap, so it sits exactly at it.
        let sleep = policy.next_sleep();
        assert!(sleep >= max_sleep);
        assert!(sleep <= with_max_jitter(max_sleep));
    }

    #[test]
    fn test_jitter_within_bounds() {
        let base = Duration::from_millis(100);
        for _ in 0..100 {
            let result = add_jitter(base);
            assert!(result >= base);
            assert!(result <= with_max_jitter(base));
        }
    }
}