use crate::api::SocketTime;
use std::cmp::min;
use std::time::Duration;
/// Upper bound on any interval a [`Timer`] computes: one day.
///
/// Used both as the default maximum backoff duration and as a clamp on any
/// caller-supplied maximum.
pub const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24);

/// Restart limit used when a [`Timer`] is created without an explicit one
/// (the largest representable `u32`).
pub const MAX_RESTARTS: u32 = u32::MAX;

/// Largest exponent applied by exponential backoff: the interval never grows
/// beyond `2^MAX_BACKOFF_COUNT` times the base duration.
const MAX_BACKOFF_COUNT: u32 = 10;
/// Strategy a [`Timer`] uses to grow its interval across consecutive expirations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BackoffAlgorithm {
    /// Every restart waits the timer's base duration.
    Fixed,
    /// The first restart waits the base duration. Each later restart doubles the wait, up to a
    /// fixed cap on the multiplier and never beyond the timer's maximum backoff duration.
    Exponential,
}
/// A restartable timer with optional exponential backoff.
///
/// The timer is driven by its caller: the current time is passed to [`Timer::expire`]. That
/// method reports whether the pending deadline has been reached and re-arms the timer while
/// restarts remain.
pub struct Timer {
    /// Deadline of the pending expiration; `None` while the timer is stopped.
    next_expiry: Option<SocketTime>,
    /// Base interval; the unit from which backed-off intervals are derived.
    base_duration: Duration,
    /// How the interval grows across consecutive expirations.
    backoff_algorithm: BackoffAlgorithm,
    /// Ceiling applied to every computed interval.
    max_backoff_duration: Duration,
    /// Number of expirations since the timer was last started.
    expiration_count: u32,
    /// How many times the timer may re-arm itself after expiring.
    max_restarts: u32,
}
impl Timer {
    /// Creates a timer that is not yet running; call [`Timer::start`] to arm it.
    ///
    /// * `duration` - base interval. A zero duration means the timer never arms.
    /// * `backoff_algorithm` - how the interval grows across consecutive expirations.
    /// * `max_restarts` - how many times the timer re-arms itself after expiring before it
    ///   stops; `None` selects [`MAX_RESTARTS`].
    /// * `max_backoff_duration` - ceiling for the backed-off interval; `None` selects
    ///   [`MAX_DURATION`], and larger values are clamped to it.
    pub fn new(
        duration: Duration,
        backoff_algorithm: BackoffAlgorithm,
        max_restarts: Option<u32>,
        max_backoff_duration: Option<Duration>,
    ) -> Self {
        let max_restarts = max_restarts.unwrap_or(MAX_RESTARTS);
        let max_backoff_duration =
            max_backoff_duration.map_or(MAX_DURATION, |cap| cap.min(MAX_DURATION));
        Self {
            next_expiry: None,
            base_duration: duration,
            backoff_algorithm,
            max_backoff_duration,
            expiration_count: 0,
            max_restarts,
        }
    }

    /// Interval until the next expiration.
    ///
    /// The value depends on the base duration, the backoff algorithm and the number of
    /// expirations so far, and is bounded by the maximum backoff duration.
    fn get_backoff_duration(&self) -> Duration {
        let unbounded = match self.backoff_algorithm {
            BackoffAlgorithm::Exponential => {
                // The first restart still uses the base interval; each later one doubles it.
                // The exponent is capped, which keeps the multiplier at most
                // 2^MAX_BACKOFF_COUNT and the shift in range.
                let exponent = min(self.expiration_count.saturating_sub(1), MAX_BACKOFF_COUNT);
                self.base_duration.saturating_mul(1u32 << exponent)
            }
            BackoffAlgorithm::Fixed => self.base_duration,
        };
        unbounded.min(self.max_backoff_duration)
    }

    /// Deadline of the next expiration, measured from `from_time`.
    ///
    /// Returns `None` when the base duration is zero, so such a timer never runs.
    fn compute_expiry(&self, from_time: SocketTime) -> Option<SocketTime> {
        if self.base_duration.is_zero() {
            return None;
        }
        Some(from_time + self.get_backoff_duration())
    }

    /// Checks the timer against `now`.
    ///
    /// Returns `true` when the timer was running and its deadline is not after `now`. In that
    /// case the expiration is counted, and the timer is either re-armed (while restarts remain)
    /// or stopped. Returns `false` and changes nothing when the timer is stopped or not yet due.
    pub fn expire(&mut self, now: SocketTime) -> bool {
        let deadline = match self.next_expiry {
            Some(deadline) => deadline,
            None => return false,
        };
        if deadline > now {
            // Not due yet.
            return false;
        }

        let may_restart = self.expiration_count < self.max_restarts;
        self.expiration_count = self.expiration_count.saturating_add(1);
        // The next deadline is measured from the deadline that just passed, not from `now`,
        // so a late `expire` call does not shift later deadlines.
        self.next_expiry = if may_restart {
            self.compute_expiry(deadline)
        } else {
            None
        };
        true
    }

    /// Deadline the timer is armed for, or `None` when it is stopped.
    pub fn next_expiry(&self) -> Option<SocketTime> {
        self.next_expiry
    }

    /// Whether the timer currently has a pending deadline.
    pub fn is_running(&self) -> bool {
        self.next_expiry.is_some()
    }

    /// Disarms the timer.
    ///
    /// The expiration count is left as-is; [`Timer::start`] resets it.
    pub fn stop(&mut self) {
        self.next_expiry = None;
    }

    /// (Re)arms the timer at `now` plus the base duration and clears any accumulated backoff.
    ///
    /// The interval is bounded by the maximum backoff duration. The timer stays stopped if the
    /// base duration is zero.
    pub fn start(&mut self, now: SocketTime) {
        // Reset first: the expiry computation reads the expiration count.
        self.expiration_count = 0;
        self.next_expiry = self.compute_expiry(now);
    }

    /// Replaces the base duration.
    ///
    /// A deadline that is already scheduled is not moved. The new value takes effect the next
    /// time the timer is started or re-armed after expiring.
    pub fn set_duration(&mut self, duration: Duration) {
        self.base_duration = duration;
    }

    /// Current base duration.
    pub fn duration(&self) -> Duration {
        self.base_duration
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const START_TIME: SocketTime = SocketTime::zero();

    #[test]
    fn new_timer_is_not_running() {
        let t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Fixed,
            None,
            None,
        );
        assert_eq!(t.duration(), Duration::from_millis(1000));
        assert!(!t.is_running());
        assert!(t.next_expiry().is_none());
    }

    #[test]
    fn stopped_timer_does_not_expire() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Fixed,
            None,
            None,
        );
        assert_eq!(t.duration(), Duration::from_millis(1000));
        let now = START_TIME;
        t.start(now);
        t.stop();
        assert!(!t.expire(now + Duration::from_millis(1000)));
    }

    #[test]
    fn timer_expires_after_duration() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Fixed,
            None,
            None,
        );
        let now = START_TIME;
        t.start(now);
        assert!(t.is_running());
        assert!(!t.expire(now + Duration::from_millis(999)));
        assert!(t.expire(now + Duration::from_millis(1000)));
        assert!(t.is_running());
    }

    #[test]
    fn timer_restarts_after_expired() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Fixed,
            None,
            None,
        );
        let now = START_TIME;
        t.start(now);
        assert!(t.is_running());
        assert!(t.expire(now + Duration::from_millis(1000)));
        assert_eq!(t.next_expiry, Some(now + Duration::from_millis(2000)));
        assert!(!t.expire(now + Duration::from_millis(1001)));
        assert!(t.expire(now + Duration::from_millis(2000)));
        assert_eq!(t.next_expiry, Some(now + Duration::from_millis(3000)));
        assert!(t.is_running());
    }

    #[test]
    fn timer_stops_when_exhausted() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Fixed,
            Some(0),
            None,
        );
        let now = START_TIME;
        t.start(now);
        assert!(t.is_running());
        assert!(t.expire(now + Duration::from_millis(1000)));
        assert!(!t.is_running());
        assert!(t.next_expiry.is_none());
    }

    #[test]
    fn can_be_restarted_limited_number_times() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Fixed,
            Some(2),
            None,
        );
        let now = START_TIME;
        t.start(now);
        assert!(t.is_running());
        assert!(t.expire(now + Duration::from_millis(1000)));
        assert!(t.expire(now + Duration::from_millis(2000)));
        assert!(t.expire(now + Duration::from_millis(3000)));
        assert!(!t.is_running());
    }

    #[test]
    fn timer_restart_does_not_drift() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Fixed,
            None,
            None,
        );
        let now = START_TIME;
        t.start(now);
        assert!(t.is_running());
        assert!(t.expire(now + Duration::from_millis(1050)));
        assert_eq!(t.next_expiry, Some(now + Duration::from_millis(2000)));
    }

    #[test]
    fn can_do_exponential_backoff() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Exponential,
            None,
            None,
        );
        let now = START_TIME;
        t.start(now);
        assert!(t.is_running());
        assert!(t.expire(now + Duration::from_millis(1050)));
        assert_eq!(t.next_expiry, Some(now + Duration::from_millis(2000)));
        assert!(t.expire(now + Duration::from_millis(2100)));
        assert_eq!(t.next_expiry, Some(now + Duration::from_millis(4000)));
        assert!(t.expire(now + Duration::from_millis(4400)));
        assert_eq!(t.next_expiry, Some(now + Duration::from_millis(8000)));
    }

    #[test]
    fn does_not_overflow_when_expired_many_times() {
        let base = Duration::from_millis(1000);
        let mut t = Timer::new(base, BackoffAlgorithm::Exponential, None, None);
        let now = START_TIME;
        t.start(now);
        assert!(t.is_running());
        for _ in 0..1000 {
            let now = t.next_expiry().unwrap();
            // Every call lands exactly on the deadline, and restarts are effectively unlimited.
            assert!(t.expire(now));
        }
        // Once the exponent is saturated, the interval stays at base * 2^MAX_BACKOFF_COUNT
        // (1024 s here, well below MAX_DURATION).
        let last = t.next_expiry().unwrap();
        assert!(t.expire(last));
        assert_eq!(
            t.next_expiry(),
            Some(last + base * (1u32 << MAX_BACKOFF_COUNT))
        );
    }

    #[test]
    fn zero_duration_timer_never_runs() {
        let mut t = Timer::new(Duration::ZERO, BackoffAlgorithm::Fixed, None, None);
        t.start(START_TIME);
        assert!(!t.is_running());
        assert!(t.next_expiry().is_none());
        assert!(!t.expire(START_TIME));
    }

    #[test]
    fn exponential_backoff_is_capped_by_max_backoff_duration() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Exponential,
            None,
            Some(Duration::from_millis(3000)),
        );
        let now = START_TIME;
        t.start(now);
        assert_eq!(t.next_expiry(), Some(now + Duration::from_millis(1000)));
        assert!(t.expire(now + Duration::from_millis(1000)));
        assert_eq!(t.next_expiry(), Some(now + Duration::from_millis(2000)));
        assert!(t.expire(now + Duration::from_millis(2000)));
        assert_eq!(t.next_expiry(), Some(now + Duration::from_millis(4000)));
        // The uncapped interval would be 4000 ms; it is limited to 3000 ms.
        assert!(t.expire(now + Duration::from_millis(4000)));
        assert_eq!(t.next_expiry(), Some(now + Duration::from_millis(7000)));
        assert!(t.expire(now + Duration::from_millis(7000)));
        assert_eq!(t.next_expiry(), Some(now + Duration::from_millis(10000)));
    }

    #[test]
    fn max_backoff_duration_is_clamped_to_max_duration() {
        let mut t = Timer::new(
            Duration::from_secs(48 * 3600),
            BackoffAlgorithm::Fixed,
            None,
            Some(Duration::from_secs(100 * 3600)),
        );
        t.start(START_TIME);
        assert_eq!(t.next_expiry(), Some(START_TIME + MAX_DURATION));
    }

    #[test]
    fn start_resets_exponential_backoff() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Exponential,
            None,
            None,
        );
        let now = START_TIME;
        t.start(now);
        assert!(t.expire(now + Duration::from_millis(1000)));
        assert!(t.expire(now + Duration::from_millis(2000)));
        assert_eq!(t.next_expiry(), Some(now + Duration::from_millis(4000)));
        t.start(now + Duration::from_millis(5000));
        assert_eq!(t.next_expiry(), Some(now + Duration::from_millis(6000)));
    }

    #[test]
    fn set_duration_does_not_move_pending_deadline() {
        let mut t = Timer::new(
            Duration::from_millis(1000),
            BackoffAlgorithm::Fixed,
            None,
            None,
        );
        let now = START_TIME;
        t.start(now);
        t.set_duration(Duration::from_millis(500));
        assert_eq!(t.duration(), Duration::from_millis(500));
        assert_eq!(t.next_expiry(), Some(now + Duration::from_millis(1000)));
        // The new base duration applies when the timer re-arms.
        assert!(t.expire(now + Duration::from_millis(1000)));
        assert_eq!(t.next_expiry(), Some(now + Duration::from_millis(1500)));
    }
}