use std::time::Duration;
use nautilus_core::correctness::{check_in_range_inclusive_f64, check_predicate_true};
use rand::RngExt;
/// Exponential backoff state used to compute successive retry/reconnect delays.
///
/// Each call to `next_duration` hands out the current base delay (plus optional
/// random jitter, capped at `delay_max`) and then multiplies the base delay by
/// `factor`, also capped at `delay_max`.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
/// Base delay the backoff starts from; `reset` restores `delay_current` to this value.
delay_initial: Duration,
/// Upper bound applied both to every returned delay and to the stored base delay.
delay_max: Duration,
/// Base delay (before jitter) that the next call to `next_duration` will use.
delay_current: Duration,
/// Multiplier applied to `delay_current` after each non-immediate call.
factor: f64,
/// Inclusive upper bound, in milliseconds, of the random jitter added to a returned delay.
jitter_ms: u64,
/// While `true` and the backoff sits at `delay_initial`, the next call returns a zero
/// delay and clears this flag.
immediate_reconnect: bool,
/// The `immediate_first` value supplied at construction; `reset` copies it back into
/// `immediate_reconnect`.
immediate_reconnect_original: bool,
}
impl ExponentialBackoff {
    /// Creates a new [`ExponentialBackoff`] instance.
    ///
    /// - `delay_initial`: first non-zero delay handed out, and the value restored by
    ///   [`Self::reset`].
    /// - `delay_max`: upper bound for every returned delay and for the internal base delay.
    /// - `factor`: multiplier applied to the base delay after each non-immediate call.
    /// - `jitter_ms`: inclusive upper bound (milliseconds) of the random jitter added to
    ///   each returned delay before it is capped at `delay_max`.
    /// - `immediate_first`: when `true`, the first call to [`Self::next_duration`] (and the
    ///   first call after [`Self::reset`]) returns [`Duration::ZERO`].
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `delay_initial` is zero.
    /// - `delay_max` is less than `delay_initial`.
    /// - `delay_max` does not fit in `u64` nanoseconds.
    /// - `factor` fails the inclusive range check against `1.0` and `100.0`.
    pub fn new(
        delay_initial: Duration,
        delay_max: Duration,
        factor: f64,
        jitter_ms: u64,
        immediate_first: bool,
    ) -> anyhow::Result<Self> {
        check_predicate_true(!delay_initial.is_zero(), "delay_initial must be non-zero")?;
        check_predicate_true(
            delay_max >= delay_initial,
            "delay_max must be >= delay_initial",
        )?;
        // Keeps every delay representable as `u64` nanoseconds, which the growth
        // arithmetic in `grown_delay` relies on.
        check_predicate_true(
            delay_max.as_nanos() <= u128::from(u64::MAX),
            "delay_max exceeds maximum representable duration (≈584 years)",
        )?;
        check_in_range_inclusive_f64(factor, 1.0, 100.0, "factor")?;

        Ok(Self {
            delay_initial,
            delay_max,
            delay_current: delay_initial,
            factor,
            jitter_ms,
            immediate_reconnect: immediate_first,
            immediate_reconnect_original: immediate_first,
        })
    }

    /// Returns the delay to wait before the next attempt and advances the backoff.
    ///
    /// If the immediate-first flag is still armed and the backoff sits at its initial
    /// delay, this returns [`Duration::ZERO`], disarms the flag, and leaves the base
    /// delay untouched. Otherwise it returns the current base delay plus a random jitter
    /// in `0..=jitter_ms` milliseconds, capped at `delay_max`, and then grows the base
    /// delay by `factor` (also capped at `delay_max`).
    pub fn next_duration(&mut self) -> Duration {
        if self.immediate_reconnect && self.delay_current == self.delay_initial {
            self.immediate_reconnect = false;
            return Duration::ZERO;
        }

        let jitter = Duration::from_millis(rand::rng().random_range(0..=self.jitter_ms));
        // Jitter is applied to the returned value only; it never feeds into the base delay
        let delay = (self.delay_current + jitter).min(self.delay_max);

        self.delay_current = self.grown_delay();
        delay
    }

    /// Computes the base delay for the following call: `delay_current * factor`,
    /// capped at `delay_max`.
    fn grown_delay(&self) -> Duration {
        // `new` guarantees `delay_max` (and therefore `delay_current`) fits in `u64`
        // nanoseconds; the fallback only keeps the conversion total.
        let max_nanos = u64::try_from(self.delay_max.as_nanos()).unwrap_or(u64::MAX);
        let current_nanos = u64::try_from(self.delay_current.as_nanos()).unwrap_or(u64::MAX);

        // Float-to-integer `as` casts saturate, so a product beyond `u64::MAX` becomes
        // `u64::MAX` and is then clamped to `max_nanos` below.
        let next_nanos = (current_nanos as f64 * self.factor) as u64;

        Duration::from_nanos(next_nanos.min(max_nanos))
    }

    /// Resets the backoff to its freshly constructed state.
    ///
    /// The base delay returns to `delay_initial` and the immediate-first flag is restored
    /// to the value supplied at construction.
    pub const fn reset(&mut self) {
        self.delay_current = self.delay_initial;
        self.immediate_reconnect = self.immediate_reconnect_original;
    }

    /// Returns the current base delay (without jitter) that the next non-immediate call
    /// to [`Self::next_duration`] will start from.
    #[must_use]
    pub const fn current_delay(&self) -> Duration {
        self.delay_current
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    /// Shorthand for a millisecond [`Duration`].
    fn ms(value: u64) -> Duration {
        Duration::from_millis(value)
    }

    /// Builds a backoff from millisecond bounds, panicking if the arguments are rejected.
    fn backoff_ms(
        initial: u64,
        max: u64,
        factor: f64,
        jitter: u64,
        immediate_first: bool,
    ) -> ExponentialBackoff {
        ExponentialBackoff::new(ms(initial), ms(max), factor, jitter, immediate_first).unwrap()
    }

    #[rstest]
    fn test_no_jitter_exponential_growth() {
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, false);

        // Doubles on every call, then stays pinned at the 1600 ms ceiling
        for expected in [100, 200, 400, 800, 1600, 1600] {
            assert_eq!(backoff.next_duration(), ms(expected));
        }
    }

    #[rstest]
    fn test_reset() {
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, false);

        // Advance once so that the reset has something to undo
        backoff.next_duration();
        backoff.reset();

        assert_eq!(backoff.next_duration(), ms(100));
    }

    #[rstest]
    fn test_jitter_within_bounds() {
        let jitter = 50;

        // A fresh instance per round keeps every sample centred on the initial delay
        for _ in 0..10 {
            let mut backoff = backoff_ms(100, 1000, 2.0, jitter, false);
            let min_expected = backoff.delay_current;
            let max_expected = min_expected + ms(jitter);

            let delay = backoff.next_duration();

            assert!(
                delay >= min_expected,
                "Delay {delay:?} is less than expected minimum {min_expected:?}"
            );
            assert!(
                delay <= max_expected,
                "Delay {delay:?} exceeds expected maximum {max_expected:?}"
            );
        }
    }

    #[rstest]
    fn test_factor_less_than_two() {
        let mut backoff = backoff_ms(100, 200, 1.5, 0, false);

        // 100 -> 150 -> 225 (capped to 200) -> 200
        for expected in [100, 150, 200, 200] {
            assert_eq!(backoff.next_duration(), ms(expected));
        }
    }

    #[rstest]
    fn test_max_delay_is_respected() {
        let mut backoff = backoff_ms(500, 1000, 3.0, 0, false);

        // 500 * 3 overshoots the ceiling straight away
        for expected in [500, 1000, 1000] {
            assert_eq!(backoff.next_duration(), ms(expected));
        }
    }

    #[rstest]
    fn test_current_delay_getter() {
        let initial = ms(100);
        let mut backoff = ExponentialBackoff::new(initial, ms(1600), 2.0, 0, false).unwrap();
        assert_eq!(backoff.current_delay(), initial);

        // The getter reports the delay prepared for the *next* call
        for upcoming in [200, 400] {
            backoff.next_duration();
            assert_eq!(backoff.current_delay(), ms(upcoming));
        }

        backoff.reset();
        assert_eq!(backoff.current_delay(), initial);
    }

    #[rstest]
    fn test_validation_zero_initial_delay() {
        let err = ExponentialBackoff::new(Duration::ZERO, ms(1000), 2.0, 0, false).unwrap_err();

        assert!(err.to_string().contains("delay_initial must be non-zero"));
    }

    #[rstest]
    fn test_validation_max_less_than_initial() {
        let err = ExponentialBackoff::new(ms(1000), ms(500), 2.0, 0, false).unwrap_err();

        assert!(
            err.to_string()
                .contains("delay_max must be >= delay_initial")
        );
    }

    #[rstest]
    fn test_validation_factor_too_small() {
        let err = ExponentialBackoff::new(ms(100), ms(1000), 0.5, 0, false).unwrap_err();

        assert!(err.to_string().contains("factor"));
    }

    #[rstest]
    fn test_validation_factor_too_large() {
        let err = ExponentialBackoff::new(ms(100), ms(1000), 150.0, 0, false).unwrap_err();

        assert!(err.to_string().contains("factor"));
    }

    #[rstest]
    fn test_validation_delay_max_exceeds_u64_max_nanos() {
        // One nanosecond past the largest value representable as u64 nanoseconds
        let too_large = Duration::from_nanos(u64::MAX) + Duration::from_nanos(1);

        let err = ExponentialBackoff::new(ms(100), too_large, 2.0, 0, false).unwrap_err();

        assert!(
            err.to_string()
                .contains("delay_max exceeds maximum representable duration")
        );
    }

    #[rstest]
    fn test_immediate_first() {
        let initial = ms(100);
        let mut backoff = ExponentialBackoff::new(initial, ms(1600), 2.0, 0, true).unwrap();

        assert_eq!(
            backoff.next_duration(),
            Duration::ZERO,
            "Expected immediate reconnect (zero delay) on first call"
        );
        assert_eq!(
            backoff.next_duration(),
            initial,
            "Expected the delay to be the initial delay after immediate reconnect"
        );
        assert_eq!(
            backoff.next_duration(),
            initial * 2,
            "Expected exponential growth from the initial delay"
        );
    }

    #[rstest]
    fn test_reset_restores_immediate_first() {
        let initial = ms(100);
        let mut backoff = ExponentialBackoff::new(initial, ms(1600), 2.0, 0, true).unwrap();

        // Consume the immediate reconnect and one regular delay
        assert_eq!(backoff.next_duration(), Duration::ZERO);
        assert_eq!(backoff.next_duration(), initial);

        backoff.reset();

        assert_eq!(
            backoff.next_duration(),
            Duration::ZERO,
            "Reset should restore immediate_first behavior"
        );
    }

    #[rstest]
    fn test_jitter_never_exceeds_max_delay() {
        let max = ms(1000);
        let mut backoff = ExponentialBackoff::new(ms(100), max, 2.0, 500, false).unwrap();

        // Drive the base delay up to the ceiling first
        while backoff.current_delay() < max {
            backoff.next_duration();
        }

        // With the base at the ceiling, any jitter must be clamped away
        for _ in 0..100 {
            let delay = backoff.next_duration();
            assert!(
                delay <= max,
                "Delay with jitter {delay:?} exceeded max {max:?}"
            );
        }
    }
}