use std::time::Duration;
/// Settings for exponential backoff between reconnect attempts.
///
/// Start from [`ReconnectPolicy::new`] (equivalent to [`Default::default`]) and
/// adjust with the `with_*` builder methods. Fields are public, so a policy can
/// also be built or modified directly; doing so bypasses the clamping performed
/// by [`ReconnectPolicy::with_multiplier`].
#[derive(Debug, Clone)]
pub struct ReconnectPolicy {
    /// Delay before the first reconnect attempt.
    pub initial_delay: Duration,
    /// Ceiling applied to the delay as it grows between attempts.
    pub max_delay: Duration,
    /// Factor by which the delay grows after each attempt.
    pub multiplier: f64,
    /// Maximum number of attempts, or `None` for no limit.
    pub max_attempts: Option<u32>,
    /// Whether each delay is randomly perturbed.
    pub jitter: bool,
}

impl Default for ReconnectPolicy {
    /// 1 s initial delay, 30 s ceiling, growth factor 1.5, unlimited attempts,
    /// jitter enabled.
    fn default() -> Self {
        Self {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(30),
            multiplier: 1.5,
            max_attempts: None,
            jitter: true,
        }
    }
}

impl ReconnectPolicy {
    /// Creates a policy with the [`Default`] settings.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the delay before the first reconnect attempt.
    #[must_use]
    pub fn with_initial_delay(mut self, delay: Duration) -> Self {
        self.initial_delay = delay;
        self
    }

    /// Sets the ceiling for the growing delay.
    #[must_use]
    pub fn with_max_delay(mut self, max: Duration) -> Self {
        self.max_delay = max;
        self
    }

    /// Sets the growth factor applied after each attempt.
    ///
    /// Values below `1.0` (and NaN) are clamped to `1.0`, so the stored factor
    /// is never below `1.0`.
    #[must_use]
    pub fn with_multiplier(mut self, mult: f64) -> Self {
        // `f64::max` returns the non-NaN operand, so NaN also becomes 1.0.
        self.multiplier = mult.max(1.0);
        self
    }

    /// Limits the number of reconnect attempts to `n`.
    #[must_use]
    pub fn with_max_attempts(mut self, n: u32) -> Self {
        self.max_attempts = Some(n);
        self
    }

    /// Enables or disables random jitter on each delay.
    #[must_use]
    pub fn with_jitter(mut self, jitter: bool) -> Self {
        self.jitter = jitter;
        self
    }
}
/// Stateful cursor over the reconnect delays described by a [`ReconnectPolicy`].
///
/// Call [`BackoffIter::next_delay`] before each reconnect attempt and
/// [`BackoffIter::reset`] to start the schedule over (presumably after a
/// successful connection — confirm against callers).
pub(crate) struct BackoffIter<'a> {
    /// Policy that drives the schedule.
    policy: &'a ReconnectPolicy,
    /// Un-jittered delay to hand out on the next call to `next_delay`.
    current: Duration,
    /// Delays handed out since construction or the last `reset`.
    attempts: u32,
}

impl<'a> BackoffIter<'a> {
    /// Creates a cursor positioned at the start of `policy`'s schedule.
    pub(crate) fn new(policy: &'a ReconnectPolicy) -> Self {
        Self {
            policy,
            current: policy.initial_delay,
            attempts: 0,
        }
    }

    /// Returns the delay to wait before the next attempt and advances the schedule.
    ///
    /// Returns `None` once `policy.max_attempts` delays have been handed out.
    /// The base delay never exceeds `policy.max_delay`. This includes the first
    /// delay, even when `initial_delay` is larger. The base delay is scaled by
    /// `policy.multiplier` after each call. When `policy.jitter` is set, the
    /// base delay is passed through `apply_jitter` before being returned.
    pub(crate) fn next_delay(&mut self) -> Option<Duration> {
        // Same check as an `if let ... &&` chain, but without requiring a
        // specific language edition.
        if self
            .policy
            .max_attempts
            .is_some_and(|max| self.attempts >= max)
        {
            return None;
        }
        // Saturate so an unlimited policy polled indefinitely cannot overflow
        // the counter (which would panic in debug builds).
        self.attempts = self.attempts.saturating_add(1);
        // Cap here as well: `current` starts as `initial_delay` verbatim, which
        // may be larger than `max_delay`.
        let delay = self.current.min(self.policy.max_delay);
        self.current = self.grown_delay();
        Some(if self.policy.jitter {
            apply_jitter(delay)
        } else {
            delay
        })
    }

    /// Restarts the schedule from `initial_delay` with a fresh attempt budget.
    pub(crate) fn reset(&mut self) {
        self.current = self.policy.initial_delay;
        self.attempts = 0;
    }

    /// Computes the delay that follows `self.current`: scaled by the policy's
    /// multiplier and capped at `max_delay`.
    fn grown_delay(&self) -> Duration {
        let cap = self.policy.max_delay;
        let scaled = self.current.as_secs_f64() * self.policy.multiplier;
        if scaled >= cap.as_secs_f64() {
            // Return the cap itself rather than round-tripping it through f64:
            // for caps near `Duration::MAX` the rounded value exceeds the
            // representable range and `Duration::from_secs_f64` would panic.
            cap
        } else {
            // `multiplier` is a public field and may bypass the builder's clamp
            // (negative or NaN); fall back to the cap instead of panicking.
            Duration::try_from_secs_f64(scaled).unwrap_or(cap)
        }
    }
}
/// Randomly scales `d` by a factor drawn uniformly from `[0.9, 1.1]`.
///
/// The result saturates at [`Duration::MAX`] instead of panicking when the
/// scaled value does not fit in a `Duration` (possible for very large `d`).
fn apply_jitter(d: Duration) -> Duration {
    let factor: f64 = rand::random_range(-0.1..=0.1);
    let adjusted = d.as_secs_f64() * (1.0 + factor);
    // `Duration::from_secs_f64` panics on overflow; `try_from_secs_f64` lets us
    // saturate instead. The `max(0.0)` clamp is a cheap guard against negatives.
    Duration::try_from_secs_f64(adjusted.max(0.0)).unwrap_or(Duration::MAX)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Jitter-free policy starting at `initial_ms`, so delays are deterministic.
    fn deterministic(initial_ms: u64) -> ReconnectPolicy {
        ReconnectPolicy::new()
            .with_initial_delay(Duration::from_millis(initial_ms))
            .with_jitter(false)
    }

    /// Pulls `count` consecutive results from `iter`.
    fn take_delays(iter: &mut BackoffIter<'_>, count: usize) -> Vec<Option<Duration>> {
        (0..count).map(|_| iter.next_delay()).collect()
    }

    #[test]
    fn test_backoff_grows_to_max() {
        let policy = deterministic(100)
            .with_max_delay(Duration::from_millis(500))
            .with_multiplier(2.0);
        let mut iter = BackoffIter::new(&policy);
        let expected: Vec<Option<Duration>> = [100, 200, 400, 500, 500]
            .iter()
            .map(|&ms| Some(Duration::from_millis(ms)))
            .collect();
        assert_eq!(take_delays(&mut iter, expected.len()), expected);
    }

    #[test]
    fn test_max_attempts_exhausted() {
        let policy = deterministic(10).with_max_attempts(3);
        let mut iter = BackoffIter::new(&policy);
        assert!(take_delays(&mut iter, 3).iter().all(Option::is_some));
        assert_eq!(iter.next_delay(), None);
    }

    #[test]
    fn test_reset_restarts_delay() {
        let policy = deterministic(100)
            .with_multiplier(2.0)
            .with_max_attempts(10);
        let mut iter = BackoffIter::new(&policy);
        let _ = take_delays(&mut iter, 2);
        iter.reset();
        assert_eq!(iter.next_delay(), Some(Duration::from_millis(100)));
    }

    #[test]
    fn test_jitter_stays_within_bounds() {
        let policy = ReconnectPolicy::new()
            .with_initial_delay(Duration::from_millis(1000))
            .with_max_delay(Duration::from_millis(2000))
            .with_multiplier(1.0)
            .with_jitter(true);
        let mut iter = BackoffIter::new(&policy);
        let window = 0.9..=1.11;
        (0..50).for_each(|_| {
            let d = iter.next_delay().unwrap().as_secs_f64();
            assert!(window.contains(&d), "jitter out of range: {d}");
        });
    }

    #[test]
    fn test_multiplier_below_one_clamped() {
        let policy = deterministic(100).with_multiplier(0.5);
        let mut iter = BackoffIter::new(&policy);
        let hundred = Some(Duration::from_millis(100));
        assert_eq!(take_delays(&mut iter, 2), vec![hundred, hundred]);
    }
}