use std::time::Duration;
use eyre::{Result, ensure};
use rand::RngExt as _;
use super::RetryConfig;
/// State for an exponential backoff schedule with optional random jitter.
///
/// Construct via [`ExponentialBackoff::try_new`] (or `TryFrom<&RetryConfig>`), which
/// validates the parameters before any state is created.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// Base delay of the first backoff step; restored by `reset`.
    delay_initial: Duration,
    /// Upper bound for the base delay (jitter may still be added on top of it).
    delay_max: Duration,
    /// Base delay (before jitter) that the next backoff step will return.
    delay_current: Duration,
    /// Multiplier applied to `delay_current` after each backoff step.
    factor: f64,
    /// Inclusive upper bound, in milliseconds, of the random jitter added to each delay.
    jitter_ms: u64,
    /// When `true` (and no backoff step has happened yet), the next call returns zero.
    immediate_reconnect: bool,
    /// Value of `immediate_reconnect` at construction; restored by `reset`.
    immediate_reconnect_original: bool,
}
impl ExponentialBackoff {
    /// Creates a validated exponential backoff schedule.
    ///
    /// * `delay_initial` – base delay of the first backoff step (restored by [`Self::reset`]).
    /// * `delay_max` – upper bound for the base delay; jitter is added on top of it.
    /// * `factor` – multiplier applied to the base delay after each backoff step.
    /// * `jitter_ms` – each returned delay gets a random `0..=jitter_ms` milliseconds added.
    /// * `immediate_first` – if `true`, the first call to [`Self::next_duration`] (and the
    ///   first call after each [`Self::reset`]) returns [`Duration::ZERO`].
    ///
    /// # Errors
    ///
    /// Returns an error if `delay_initial` is zero, `delay_max < delay_initial`,
    /// `delay_max` exceeds `u64::MAX` nanoseconds, or `factor` is not within
    /// `[1.0, 100.0]` (NaN is rejected as well).
    pub fn try_new(delay_initial: Duration, delay_max: Duration, factor: f64, jitter_ms: u64, immediate_first: bool) -> Result<Self> {
        ensure!(!delay_initial.is_zero(), "delay_initial must be non-zero");
        ensure!(delay_max >= delay_initial, "delay_max must be >= delay_initial");
        ensure!(delay_max.as_nanos() <= u128::from(u64::MAX), "delay_max exceeds maximum representable duration (≈584 years)");
        ensure!((1.0..=100.0).contains(&factor), "factor must be in range [1.0, 100.0], got {factor}");
        Ok(Self {
            delay_initial,
            delay_max,
            delay_current: delay_initial,
            factor,
            jitter_ms,
            immediate_reconnect: immediate_first,
            immediate_reconnect_original: immediate_first,
        })
    }

    /// Returns the delay to wait before the next attempt and advances the schedule.
    ///
    /// If immediate reconnect is armed and no backoff step has happened yet, this returns
    /// [`Duration::ZERO`] once without advancing. Otherwise it returns the current base
    /// delay plus `0..=jitter_ms` random milliseconds. It then multiplies the base delay
    /// by `factor`, clamped to `delay_max`.
    ///
    /// Jitter is applied on top of the base delay, so a returned value may exceed
    /// `delay_max` by up to `jitter_ms` milliseconds.
    pub fn next_duration(&mut self) -> Duration {
        if self.immediate_reconnect && self.delay_current == self.delay_initial {
            self.immediate_reconnect = false;
            return Duration::ZERO;
        }

        // `0..=0` can only yield 0, so skip the RNG entirely when jitter is disabled.
        let jitter = if self.jitter_ms == 0 {
            0
        } else {
            rand::rng().random_range(0..=self.jitter_ms)
        };
        // Cannot overflow: the base is <= u64::MAX ns and the jitter is <= u64::MAX ms.
        // Both are far below `Duration::MAX` (u64::MAX seconds).
        let delay_with_jitter = self.delay_current + Duration::from_millis(jitter);

        // `try_new` guarantees `delay_current <= delay_max <= u64::MAX` ns, so these
        // conversions are lossless. The fallback only saturates defensively.
        let current_nanos = u64::try_from(self.delay_current.as_nanos()).unwrap_or(u64::MAX);
        let max_nanos = u64::try_from(self.delay_max.as_nanos()).unwrap_or(u64::MAX);
        // Float-to-integer `as` casts saturate, so a product >= 2^64 becomes u64::MAX.
        let grown_nanos = (current_nanos as f64 * self.factor) as u64;
        self.delay_current = Duration::from_nanos(grown_nanos.min(max_nanos));

        delay_with_jitter
    }

    /// Restores the base delay to its initial value.
    ///
    /// Also re-arms the immediate-first behaviour if it was requested at construction.
    pub const fn reset(&mut self) {
        self.delay_current = self.delay_initial;
        self.immediate_reconnect = self.immediate_reconnect_original;
    }

    /// Returns the base delay (without jitter) that the next backoff step will use.
    #[must_use]
    pub const fn current_delay(&self) -> Duration {
        self.delay_current
    }
}
impl TryFrom<&RetryConfig> for ExponentialBackoff {
    type Error = eyre::Report;

    /// Builds a backoff schedule from the millisecond-based fields of a [`RetryConfig`].
    ///
    /// # Errors
    ///
    /// Propagates any validation error reported by [`ExponentialBackoff::try_new`].
    fn try_from(config: &RetryConfig) -> Result<Self> {
        let delay_initial = Duration::from_millis(config.initial_delay_ms);
        let delay_max = Duration::from_millis(config.max_delay_ms);
        Self::try_new(
            delay_initial,
            delay_max,
            config.backoff_factor,
            config.jitter_ms,
            config.immediate_first,
        )
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    #[test]
    fn test_no_jitter_exponential_growth() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let mut backoff = ExponentialBackoff::try_new(initial, max, 2.0, 0, false).unwrap();
        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(100));
        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(200));
        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(400));
        let d4 = backoff.next_duration();
        assert_eq!(d4, Duration::from_millis(800));
        let d5 = backoff.next_duration();
        assert_eq!(d5, Duration::from_millis(1600));
        let d6 = backoff.next_duration();
        assert_eq!(d6, Duration::from_millis(1600));
    }

    #[test]
    fn test_reset() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let mut backoff = ExponentialBackoff::try_new(initial, max, 2.0, 0, false).unwrap();
        let _ = backoff.next_duration();
        backoff.reset();
        let d = backoff.next_duration();
        assert_eq!(d, Duration::from_millis(100));
    }

    #[test]
    fn test_jitter_within_bounds() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1000);
        let jitter = 50;
        for _ in 0..10 {
            let mut backoff = ExponentialBackoff::try_new(initial, max, 2.0, jitter, false).unwrap();
            let base = backoff.delay_current;
            let delay = backoff.next_duration();
            assert!(delay >= base, "Delay {delay:?} is less than expected minimum {base:?}");
            assert!(delay <= base + Duration::from_millis(jitter), "Delay {delay:?} exceeds expected maximum");
        }
    }

    // Jitter must stay within [base, base + jitter] on every step, not just the first.
    #[test]
    fn test_jitter_bounds_hold_across_steps() {
        let jitter = 50;
        let mut backoff = ExponentialBackoff::try_new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
            2.0,
            jitter,
            false,
        )
        .unwrap();
        for _ in 0..8 {
            let base = backoff.current_delay();
            let delay = backoff.next_duration();
            assert!(delay >= base, "Delay {delay:?} is less than base {base:?}");
            assert!(delay <= base + Duration::from_millis(jitter), "Delay {delay:?} exceeds base + jitter");
        }
    }

    #[test]
    fn test_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let mut backoff = ExponentialBackoff::try_new(initial, max, 2.0, 0, true).unwrap();
        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::ZERO, "Expected immediate reconnect on first call");
        let d2 = backoff.next_duration();
        assert_eq!(d2, initial, "Expected initial delay after immediate reconnect");
        let d3 = backoff.next_duration();
        assert_eq!(d3, initial * 2, "Expected exponential growth from initial delay");
    }

    #[test]
    fn test_reset_restores_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let mut backoff = ExponentialBackoff::try_new(initial, max, 2.0, 0, true).unwrap();
        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::ZERO);
        let d2 = backoff.next_duration();
        assert_eq!(d2, initial);
        backoff.reset();
        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::ZERO, "Reset should restore immediate_first behavior");
    }

    // A factor of exactly 1.0 is accepted and must keep the delay constant.
    #[test]
    fn test_factor_one_keeps_delay_constant() {
        let initial = Duration::from_millis(100);
        let mut backoff =
            ExponentialBackoff::try_new(initial, Duration::from_millis(1000), 1.0, 0, false).unwrap();
        for _ in 0..5 {
            assert_eq!(backoff.next_duration(), initial);
        }
    }

    // Growth beyond u64::MAX nanoseconds must saturate at delay_max instead of panicking.
    #[test]
    fn test_growth_saturates_at_largest_max() {
        let max = Duration::from_nanos(u64::MAX);
        let mut backoff = ExponentialBackoff::try_new(Duration::from_secs(1), max, 100.0, 0, false).unwrap();
        for _ in 0..10 {
            let _ = backoff.next_duration();
        }
        assert_eq!(backoff.current_delay(), max);
        assert_eq!(backoff.next_duration(), max);
    }

    #[test]
    fn test_validation_zero_initial_delay() {
        let result = ExponentialBackoff::try_new(Duration::ZERO, Duration::from_millis(1000), 2.0, 0, false);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("delay_initial must be non-zero"));
    }

    #[test]
    fn test_validation_max_less_than_initial() {
        let result = ExponentialBackoff::try_new(Duration::from_millis(1000), Duration::from_millis(500), 2.0, 0, false);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("delay_max must be >= delay_initial"));
    }

    #[test]
    fn test_validation_max_too_large() {
        let too_large = Duration::from_nanos(u64::MAX) + Duration::from_nanos(1);
        let result = ExponentialBackoff::try_new(Duration::from_millis(100), too_large, 2.0, 0, false);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("delay_max exceeds"));
    }

    #[test]
    fn test_validation_factor_out_of_range() {
        let result = ExponentialBackoff::try_new(Duration::from_millis(100), Duration::from_millis(1000), 0.5, 0, false);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
        let result2 = ExponentialBackoff::try_new(Duration::from_millis(100), Duration::from_millis(1000), 150.0, 0, false);
        assert!(result2.is_err());
        assert!(result2.unwrap_err().to_string().contains("factor"));
    }

    #[test]
    fn test_validation_factor_nan() {
        let result = ExponentialBackoff::try_new(
            Duration::from_millis(100),
            Duration::from_millis(1000),
            f64::NAN,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
    }
}