use rand::{rngs::StdRng, Rng, SeedableRng};
use std::time::Duration;
use crate::error::ConnectError;
use dyn_clone::DynClone;
/// Policy that decides whether, and after how long, a failed connection attempt is retried.
///
/// Implementors must be `Send + Sync + 'static` and cloneable through [`DynClone`], which is
/// what allows a `Box<dyn RetryStrategy>` to be cloned (see the `clone_trait_object!` below).
pub trait RetryStrategy: Send + Sync + DynClone + 'static {
    /// Returns how long to wait before retrying, or `None` to give up.
    ///
    /// `attempt` is the attempt counter supplied by the caller; the built-in strategies
    /// compare it against their optional attempt limit. Takes `&mut self` so stateful
    /// strategies can advance their internal state on each call.
    fn next_delay(&mut self, error: &ConnectError, attempt: u32) -> Option<Duration>;

    /// Returns the strategy to its starting state so a new retry sequence begins fresh.
    fn reset(&mut self);

    /// Clones this concrete strategy into a boxed trait object.
    fn clone_box(&self) -> Box<dyn RetryStrategy>
    where
        Self: Sized + 'static,
    {
        // `Box<Self>` unsizes to `Box<dyn RetryStrategy>` at the return site.
        dyn_clone::clone_box::<Self>(self)
    }
}

// Provides `impl Clone for Box<dyn RetryStrategy>` by delegating to `DynClone`.
dyn_clone::clone_trait_object!(RetryStrategy);
/// Exponential backoff with optional jitter and an optional attempt limit.
///
/// Each computed delay is the current delay capped at `max`. The current delay is then
/// multiplied by `factor` for the next call. Jitter is applied after the cap, so a returned
/// delay can exceed `max` by less than `jitter * max`. Delays suggested by the error itself
/// are jittered but not capped.
#[derive(Clone)]
pub struct ExponentialBackoff {
    // Growth configuration.
    /// Delay handed out first, and restored by `reset`.
    initial: Duration,
    /// Upper bound applied to the un-jittered computed delay.
    max: Duration,
    /// Multiplier applied to the current delay after each computed retry.
    factor: f64,
    /// Retrying stops once the attempt number reaches this value; `None` means unlimited.
    max_attempts: Option<u32>,

    // Randomness.
    /// Relative jitter, clamped to `[0.0, 1.0]` by `with_jitter`.
    jitter: f64,
    /// Seed recorded by `with_seed`; `reset` re-seeds `rng` from it.
    seed: Option<u64>,
    /// Random source used for jitter.
    rng: StdRng,

    // Mutable backoff state.
    /// Delay to use for the next computed (non-suggested) retry.
    current_delay: Duration,
}
impl ExponentialBackoff {
    /// Creates a backoff that starts at `initial` and whose computed delay never exceeds `max`.
    ///
    /// Defaults: growth factor `2.0`, jitter `0.1` (±10%), no attempt limit, and an RNG
    /// seeded from the operating system. Use [`Self::with_seed`] for reproducible jitter.
    #[must_use]
    pub fn new(initial: Duration, max: Duration) -> Self {
        Self {
            initial,
            max,
            factor: 2.0,
            jitter: 0.1,
            max_attempts: None,
            seed: None,
            current_delay: initial,
            rng: StdRng::from_os_rng(),
        }
    }

    /// Sets the multiplier applied to the current delay after each computed retry.
    ///
    /// Values above `1.0` make the delay grow. The value is stored as given, so callers
    /// should pass a finite, non-negative number.
    #[must_use]
    pub const fn with_factor(mut self, factor: f64) -> Self {
        self.factor = factor;
        self
    }

    /// Sets the relative jitter, clamped to `[0.0, 1.0]`.
    ///
    /// With jitter `j`, each delay `d` is shifted by a uniformly random offset in
    /// `[-j * d, j * d)`. `NaN` is treated as `0.0` (no jitter), because `f64::clamp`
    /// passes `NaN` through unchanged.
    #[must_use]
    pub fn with_jitter(mut self, jitter: f64) -> Self {
        self.jitter = if jitter.is_nan() {
            0.0
        } else {
            jitter.clamp(0.0, 1.0)
        };
        self
    }

    /// Stops retrying once the attempt number passed to `next_delay` reaches `max`.
    #[must_use]
    pub const fn with_max_attempts(mut self, max: u32) -> Self {
        self.max_attempts = Some(max);
        self
    }

    /// Seeds the jitter RNG so the delay sequence is reproducible.
    ///
    /// The seed is remembered and re-applied whenever the strategy is reset.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self.rng = StdRng::seed_from_u64(seed);
        self
    }

    /// Preset for quick recovery: 100ms initial, 5s cap, factor 1.5, 10% jitter.
    #[must_use]
    pub fn fast() -> Self {
        Self::new(Duration::from_millis(100), Duration::from_secs(5))
            .with_factor(1.5)
            .with_jitter(0.1)
    }

    /// General-purpose preset: 1s initial, 60s cap, factor 2.0, 10% jitter.
    #[must_use]
    pub fn standard() -> Self {
        Self::new(Duration::from_secs(1), Duration::from_secs(60))
            .with_factor(2.0)
            .with_jitter(0.1)
    }

    /// Slow preset: 2s initial, 120s cap, factor 2.0, 5% jitter.
    #[must_use]
    pub fn conservative() -> Self {
        Self::new(Duration::from_secs(2), Duration::from_secs(120))
            .with_factor(2.0)
            .with_jitter(0.05)
    }

    /// Randomly perturbs `duration` by up to `±jitter * duration`, never going below zero.
    fn apply_jitter(&mut self, duration: Duration) -> Duration {
        let secs = duration.as_secs_f64();
        let jitter_range = secs * self.jitter;
        // `random_range(-r..r)` panics on an empty range. That is exactly what a zero delay
        // produces (a zero initial delay or a zero suggested delay), even when jitter is
        // enabled. A non-finite range cannot be sampled either.
        if !jitter_range.is_finite() || jitter_range <= 0.0 {
            return duration;
        }
        let offset = self.rng.random_range(-jitter_range..jitter_range);
        // After clamping at zero, the only way the conversion can fail is overflow past
        // `Duration::MAX`; saturate instead of panicking like `from_secs_f64` would.
        Duration::try_from_secs_f64((secs + offset).max(0.0)).unwrap_or(Duration::MAX)
    }
}

impl Default for ExponentialBackoff {
    /// Same as [`ExponentialBackoff::standard`].
    fn default() -> Self {
        Self::standard()
    }
}
impl RetryStrategy for ExponentialBackoff {
    /// Returns the next backoff delay, or `None` when retrying should stop.
    ///
    /// Returns `None` when the error is not retryable, or when `attempt` has reached the
    /// configured attempt limit. A delay suggested by the error takes precedence: it is
    /// jittered and returned without advancing the backoff. Otherwise the current delay,
    /// capped at `max`, is jittered and returned, and the current delay grows by `factor`.
    fn next_delay(&mut self, error: &ConnectError, attempt: u32) -> Option<Duration> {
        if !error.is_retryable() {
            return None;
        }
        if matches!(self.max_attempts, Some(max) if attempt >= max) {
            return None;
        }
        if let Some(suggested) = error.suggested_delay() {
            return Some(self.apply_jitter(suggested));
        }
        let delay = self.current_delay.min(self.max);
        let delay_with_jitter = self.apply_jitter(delay);
        // Grow in floating point, then convert back fallibly. `Duration::from_secs_f64`
        // panics on a negative or NaN product (negative/NaN factor), and on values at or
        // above 2^64 seconds. The latter is reachable when `max` is `Duration::MAX`,
        // whose `as_secs_f64()` rounds up to 2^64. All of these saturate at `max` instead.
        let grown_secs = self.current_delay.as_secs_f64() * self.factor;
        self.current_delay = Duration::try_from_secs_f64(grown_secs)
            .map_or(self.max, |grown| grown.min(self.max));
        Some(delay_with_jitter)
    }

    /// Restarts the sequence at the initial delay. If a seed was configured, also re-seeds
    /// the RNG so the jitter sequence repeats.
    fn reset(&mut self) {
        self.current_delay = self.initial;
        if let Some(seed) = self.seed {
            self.rng = StdRng::seed_from_u64(seed);
        }
    }
}
/// Retries after the same constant delay every time, optionally up to an attempt limit.
#[derive(Clone)]
pub struct FixedDelay {
    /// Delay returned for every permitted retry.
    delay: Duration,
    /// Retrying stops once the attempt number reaches this value; `None` means unlimited.
    max_attempts: Option<u32>,
}

impl FixedDelay {
    /// Creates a fixed-delay strategy with no attempt limit.
    #[must_use]
    pub const fn new(delay: Duration) -> Self {
        Self {
            max_attempts: None,
            delay,
        }
    }

    /// Limits retries: no delay is returned once the attempt number reaches `max`.
    #[must_use]
    pub const fn with_max_attempts(self, max: u32) -> Self {
        Self {
            max_attempts: Some(max),
            ..self
        }
    }
}
impl RetryStrategy for FixedDelay {
    /// Returns the fixed delay while the error is retryable and `attempt` is still below the
    /// optional limit; `None` otherwise.
    fn next_delay(&mut self, error: &ConnectError, attempt: u32) -> Option<Duration> {
        let may_retry = error.is_retryable()
            && !matches!(self.max_attempts, Some(limit) if attempt >= limit);
        may_retry.then_some(self.delay)
    }

    /// No-op: a fixed delay keeps no per-sequence state.
    fn reset(&mut self) {}
}
/// A strategy that never requests a retry.
#[derive(Clone, Copy, Default)]
pub struct NoRetry;

impl RetryStrategy for NoRetry {
    /// Always returns `None`, whatever the error or attempt number.
    fn next_delay(&mut self, _: &ConnectError, _: u32) -> Option<Duration> {
        None
    }

    /// No-op: `NoRetry` is stateless.
    fn reset(&mut self) {}
}
/// A strategy backed by a user-supplied closure `f(error, attempt) -> Option<Duration>`.
///
/// The closure is called for every error, including non-retryable ones, so it owns the
/// whole retry policy. It must be `Clone` so the strategy can be cloned as a
/// `Box<dyn RetryStrategy>`.
///
/// Trait bounds live on the `impl` blocks rather than on the struct definition, so that
/// naming `CustomRetry<F>` does not force every user to repeat them.
#[derive(Clone)]
pub struct CustomRetry<F> {
    f: F,
}

// The `Fn` bound is kept on `new` (not only on the trait impl) so that closures passed
// straight to `new` get their `&ConnectError` parameter type inferred.
impl<F> CustomRetry<F>
where
    F: Fn(&ConnectError, u32) -> Option<Duration> + Send + Sync + Clone + 'static,
{
    /// Wraps `f` as a retry strategy.
    #[must_use]
    pub const fn new(f: F) -> Self {
        Self { f }
    }
}

impl<F> RetryStrategy for CustomRetry<F>
where
    F: Fn(&ConnectError, u32) -> Option<Duration> + Send + Sync + Clone + 'static,
{
    /// Delegates the decision entirely to the wrapped closure.
    fn next_delay(&mut self, error: &ConnectError, attempt: u32) -> Option<Duration> {
        (self.f)(error, attempt)
    }

    /// No-op: nothing is reset, including any state captured by the closure.
    fn reset(&mut self) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exponential_backoff() {
        let mut strategy =
            ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(60))
                .with_factor(2.0)
                .with_jitter(0.0)
                .with_seed(42);
        let error = ConnectError::TcpConnect("test".into());

        // With jitter disabled the delay doubles exactly on each call.
        for (attempt, expected_secs) in [(1, 1), (2, 2), (3, 4)] {
            let delay = strategy.next_delay(&error, attempt).unwrap();
            assert_eq!(delay, Duration::from_secs(expected_secs));
        }

        // Resetting restarts the sequence from the initial delay.
        strategy.reset();
        let restarted = strategy.next_delay(&error, 1).unwrap();
        assert_eq!(restarted, Duration::from_secs(1));
    }

    #[test]
    fn test_exponential_backoff_max() {
        let cap = Duration::from_secs(5);
        let mut strategy =
            ExponentialBackoff::new(Duration::from_secs(1), cap).with_jitter(0.0);
        let error = ConnectError::Refused;

        // Ten calls grow the uncapped delay far beyond the cap (default factor 2.0).
        for _ in 0..10 {
            assert!(strategy.next_delay(&error, 1).unwrap() <= cap);
        }
    }

    #[test]
    fn test_non_retryable_error() {
        let mut strategy = ExponentialBackoff::standard();
        let error = ConnectError::InvalidUri("bad".into());
        assert!(strategy.next_delay(&error, 1).is_none());
    }

    #[test]
    fn test_max_attempts() {
        let mut strategy = ExponentialBackoff::fast().with_max_attempts(3);
        let error = ConnectError::Refused;

        // Attempts below the limit get a delay; reaching the limit stops retrying.
        let retried: Vec<bool> = (1..=3)
            .map(|attempt| strategy.next_delay(&error, attempt).is_some())
            .collect();
        assert_eq!(retried, [true, true, false]);
    }

    #[test]
    fn test_fixed_delay() {
        let mut strategy = FixedDelay::new(Duration::from_secs(5));
        let error = ConnectError::Refused;

        for attempt in 1..=3 {
            assert_eq!(strategy.next_delay(&error, attempt), Some(Duration::from_secs(5)));
        }
    }

    #[test]
    fn test_no_retry() {
        let mut strategy = NoRetry;
        let error = ConnectError::Refused;
        assert!(strategy.next_delay(&error, 1).is_none());
    }
}