use chrono::TimeDelta;
use rand::Rng;
/// A retry policy: maps an attempt number to the delay to wait before that attempt.
pub trait Strategy {
/// Returns the delay to wait for the given `attempt`.
///
/// NOTE(review): whether `attempt` is 0- or 1-based is left to the caller; the tests in
/// this file start counting at 1.
fn backoff(&self, attempt: u16) -> TimeDelta;
}
/// Backoff strategy that returns the same delay for every attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Constant {
/// Delay returned regardless of the attempt number.
delay: TimeDelta,
}
impl Strategy for Constant {
    /// Ignores the attempt number and yields the configured delay unchanged.
    fn backoff(&self, _attempt: u16) -> TimeDelta {
        let Self { delay } = self;
        *delay
    }
}
/// Backoff strategy whose delay grows as `base^attempt`, computed in whole seconds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Exponential {
    /// Base of the exponentiation; only its whole-second part is used.
    base: TimeDelta,
    /// Optional upper bound on the computed delay (compared in whole seconds).
    max: Option<TimeDelta>,
}
impl Strategy for Exponential {
    /// Returns `base^attempt` seconds, optionally capped at `max`.
    ///
    /// Both `base` and `max` are truncated to whole seconds before use, so a sub-second
    /// base yields zero and a one-second base always yields one second.
    ///
    /// Results that do not fit in a `TimeDelta` saturate to the largest representable
    /// whole-second delay instead of panicking.
    fn backoff(&self, attempt: u16) -> TimeDelta {
        // Largest whole-second magnitude `TimeDelta::seconds` accepts without panicking
        // (`TimeDelta` is bounded by ±i64::MAX milliseconds).
        const MAX_SECONDS: i64 = i64::MAX / 1_000;

        let mut seconds = self
            .base
            .num_seconds()
            .checked_pow(u32::from(attempt))
            .unwrap_or(i64::MAX);
        if let Some(max) = self.max {
            seconds = seconds.min(max.num_seconds());
        }
        // Without this clamp, any power above MAX_SECONDS (e.g. 60^9) or an overflowed
        // power (saturated to i64::MAX above) makes `TimeDelta::seconds` panic.
        TimeDelta::seconds(seconds.clamp(-MAX_SECONDS, MAX_SECONDS))
    }
}
/// Backoff strategy whose delay grows as `factor * attempt`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Linear {
    /// Delay contributed by each attempt.
    factor: TimeDelta,
    /// Optional upper bound on the computed delay.
    max: Option<TimeDelta>,
}
impl Strategy for Linear {
    /// Returns `factor * attempt`, limited to `max` when one is configured.
    fn backoff(&self, attempt: u16) -> TimeDelta {
        let uncapped = self.factor * i32::from(attempt);
        match self.max {
            Some(cap) => uncapped.min(cap),
            None => uncapped,
        }
    }
}
/// Backoff strategy whose delay grows as `factor * attempt^power`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Polynomial {
    /// Multiplier applied to `attempt^power`.
    factor: TimeDelta,
    /// Exponent applied to the attempt number.
    power: u32,
    /// Optional upper bound on the computed delay.
    max: Option<TimeDelta>,
}
impl Strategy for Polynomial {
    /// Returns `factor * attempt^power`, optionally capped at `max`.
    ///
    /// `attempt^power` is computed in `i32` and saturates at `i32::MAX`, so large attempt
    /// numbers produce a very large (then capped) delay rather than an overflow.
    ///
    /// # Panics
    ///
    /// May panic if `factor * attempt^power` exceeds the range of `TimeDelta` (chrono's
    /// `TimeDelta * i32` checks for overflow). This only happens with very large factors.
    fn backoff(&self, attempt: u16) -> TimeDelta {
        // `i32::pow` overflows for perfectly valid `u16` attempts (e.g. attempt > 46_340
        // with power 2). That panics in debug builds and wraps to a garbage, often
        // negative, multiplier in release builds, so saturate instead.
        let multiplier = i32::from(attempt).saturating_pow(self.power);
        let backoff = self.factor * multiplier;
        match self.max {
            Some(max) => backoff.min(max),
            None => backoff,
        }
    }
}
/// Random perturbation applied to a computed backoff delay (millisecond resolution).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Jitter {
/// Perturb by a uniformly random amount within ± this fixed delta.
Absolute(TimeDelta),
/// Perturb by a uniformly random amount within ± this fraction of the delay being
/// jittered (e.g. `0.1` for ±10%).
Relative(f64),
}
impl Jitter {
    /// Returns `value` shifted by a uniformly random offset in `[-spread, +spread]`
    /// milliseconds.
    ///
    /// `spread` is the absolute delta for [`Jitter::Absolute`], or `value * ratio`
    /// (rounded) for [`Jitter::Relative`]. The sign of the configured delta or ratio is
    /// ignored; only its magnitude matters.
    ///
    /// # Panics
    ///
    /// May panic if `value` plus the random offset overflows `TimeDelta` (chrono's
    /// `TimeDelta + TimeDelta` checks for overflow). This needs an absurdly large spread.
    fn apply_jitter(&self, value: TimeDelta) -> TimeDelta {
        let raw_spread_ms = match self {
            Self::Absolute(delta) => delta.num_milliseconds(),
            Self::Relative(ratio) => (value.num_milliseconds() as f64 * *ratio).round() as i64,
        };
        // A negative delta or ratio would make `-spread..=spread` an empty range, and
        // `gen_range` panics on empty ranges. `saturating_abs` also keeps `-spread` from
        // overflowing when the spread is i64::MIN.
        let spread_ms = raw_spread_ms.saturating_abs();
        let offset_ms = rand::thread_rng().gen_range(-spread_ms..=spread_ms);
        value + TimeDelta::milliseconds(offset_ms)
    }
}
pub struct BackoffStrategy<T: Strategy> {
strategy: T,
jitter: Option<Jitter>,
additional_offset: Option<TimeDelta>,
min: TimeDelta,
}
impl BackoffStrategy<Constant> {
    /// Creates a strategy that waits `delay` before every attempt.
    pub const fn constant(delay: TimeDelta) -> Self {
        let inner = Constant { delay };
        Self::new(inner)
    }
}
impl BackoffStrategy<Exponential> {
    /// Creates a strategy whose delay is `base` (in whole seconds) raised to the attempt
    /// number, with no upper bound.
    pub const fn exponential(base: TimeDelta) -> Self {
        let inner = Exponential { base, max: None };
        Self::new(inner)
    }

    /// Caps the exponential delay at `max_delay` (in whole seconds).
    ///
    /// The cap is applied before any offset or jitter is added.
    pub const fn with_max(self, max_delay: TimeDelta) -> Self {
        let mut capped = self;
        capped.strategy.max = Some(max_delay);
        capped
    }
}
impl BackoffStrategy<Linear> {
    /// Creates a strategy whose delay is `factor * attempt`, with no upper bound.
    pub const fn linear(factor: TimeDelta) -> Self {
        let inner = Linear { factor, max: None };
        Self::new(inner)
    }

    /// Caps the linear delay at `max_delay`.
    ///
    /// The cap is applied before any offset or jitter is added.
    pub const fn with_max(self, max_delay: TimeDelta) -> Self {
        let mut capped = self;
        capped.strategy.max = Some(max_delay);
        capped
    }
}
impl BackoffStrategy<Polynomial> {
    /// Creates a strategy whose delay is `factor * attempt^power`, with no upper bound.
    pub const fn polynomial(factor: TimeDelta, power: u32) -> Self {
        let inner = Polynomial {
            max: None,
            power,
            factor,
        };
        Self::new(inner)
    }

    /// Caps the polynomial delay at `max_delay`.
    ///
    /// The cap is applied before any offset or jitter is added.
    pub const fn with_max(self, max_delay: TimeDelta) -> Self {
        let mut capped = self;
        capped.strategy.max = Some(max_delay);
        capped
    }
}
impl<T: Strategy> BackoffStrategy<T> {
    /// Wraps `strategy` with no jitter, no additional offset and a lower bound of zero.
    pub const fn new(strategy: T) -> Self {
        Self {
            min: TimeDelta::zero(),
            additional_offset: None,
            jitter: None,
            strategy,
        }
    }

    /// Randomly perturbs every computed delay according to `jitter`.
    ///
    /// Jitter is applied after the base strategy's cap and after any additional offset,
    /// so a jittered delay can exceed the strategy's configured maximum.
    pub const fn with_jitter(self, jitter: Jitter) -> Self {
        let mut configured = self;
        configured.jitter = Some(jitter);
        configured
    }

    /// Sets the lower bound applied to the final delay, after jitter.
    pub const fn with_min(self, min: TimeDelta) -> Self {
        let mut configured = self;
        configured.min = min;
        configured
    }

    /// Adds a fixed `offset` to the base strategy's delay before jitter is applied.
    #[doc(hidden)]
    pub const fn with_offset(self, offset: TimeDelta) -> Self {
        let mut configured = self;
        configured.additional_offset = Some(offset);
        configured
    }
}
impl<T: Strategy> Strategy for BackoffStrategy<T> {
    /// Computes the decorated delay for `attempt`.
    ///
    /// The pipeline runs in this order:
    /// 1. the base strategy's delay (including its own cap, if any);
    /// 2. plus the additional offset, if configured;
    /// 3. randomly perturbed by the jitter, if configured;
    /// 4. raised to at least `min`.
    fn backoff(&self, attempt: u16) -> TimeDelta {
        let raw = self.strategy.backoff(attempt);
        let offset_applied = match self.additional_offset {
            Some(extra) => raw + extra,
            None => raw,
        };
        let jittered = match self.jitter {
            Some(jitter) => jitter.apply_jitter(offset_applied),
            None => offset_applied,
        };
        // With the default `min` of zero, this also stops negative jitter from producing
        // a negative delay.
        jittered.max(self.min)
    }
}
// Unit tests for the backoff strategies. Jitter is random, so the jitter tests sample
// many attempts and assert only lower/upper bounds rather than exact values.
#[cfg(test)]
mod test {
use super::*;
#[test]
fn constant_backoff() {
let delay = TimeDelta::minutes(1);
let strategy = BackoffStrategy::constant(delay);
for i in 1..100 {
assert_eq!(strategy.backoff(i), delay);
}
}
#[test]
fn constant_backoff_with_absolute_jitter() {
let delay = TimeDelta::minutes(1);
let jitter = TimeDelta::seconds(10);
let strategy = BackoffStrategy::constant(delay).with_jitter(Jitter::Absolute(jitter));
for i in 1..100 {
let backoff = strategy.backoff(i);
assert!(backoff >= delay - jitter);
assert!(backoff <= delay + jitter);
}
}
#[test]
fn constant_backoff_with_relative_jitter() {
let delay = TimeDelta::minutes(1);
let strategy = BackoffStrategy::constant(delay).with_jitter(Jitter::Relative(0.1));
for i in 1..100 {
// 10% of 60s is 6s; the asserted ±10s window is deliberately looser than that.
let jitter = TimeDelta::seconds(10);
let backoff = strategy.backoff(i);
assert!(backoff >= delay - jitter);
assert!(backoff <= delay + jitter);
}
}
#[test]
fn constant_backoff_with_jitter_min() {
// Jitter equal to the delay can push it down to zero; `min` must still hold.
let delay = TimeDelta::seconds(20);
let jitter = TimeDelta::seconds(20);
let min = TimeDelta::seconds(5);
let strategy = BackoffStrategy::constant(delay)
.with_jitter(Jitter::Absolute(jitter))
.with_min(min);
for i in 1..100 {
let backoff = strategy.backoff(i);
assert!(backoff >= min);
assert!(backoff <= delay + jitter);
}
}
#[test]
fn polynomial_backoff() {
let delay = TimeDelta::minutes(1);
let strategy = BackoffStrategy::polynomial(delay, 2);
for i in 1..100 {
assert_eq!(strategy.backoff(i), delay * i.pow(2) as _);
}
}
#[test]
fn polynomial_backoff_with_max() {
let delay = TimeDelta::minutes(1);
let max = TimeDelta::minutes(10);
let strategy = BackoffStrategy::polynomial(delay, 2).with_max(max);
for i in 1..100 {
let backoff = strategy.backoff(i);
assert!(backoff <= max);
}
}
#[test]
fn linear_backoff() {
let delay = TimeDelta::minutes(1);
let strategy = BackoffStrategy::linear(delay);
for i in 1..100 {
assert_eq!(strategy.backoff(i), delay * i as _);
}
}
#[test]
fn linear_backoff_with_max() {
let delay = TimeDelta::minutes(1);
let max = TimeDelta::minutes(10);
let strategy = BackoffStrategy::linear(delay).with_max(max);
for i in 1..100 {
let backoff = strategy.backoff(i);
assert!(backoff <= max);
}
}
#[test]
fn linear_backoff_with_absolute_jitter() {
let delay = TimeDelta::minutes(1);
let jitter = TimeDelta::seconds(10);
let strategy = BackoffStrategy::linear(delay).with_jitter(Jitter::Absolute(jitter));
for i in 1..100 {
let backoff = strategy.backoff(i);
assert!(backoff >= delay * i as _ - jitter);
assert!(backoff <= delay * i as _ + jitter);
}
}
#[test]
fn linear_backoff_with_relative_jitter() {
let delay = TimeDelta::minutes(1);
let strategy = BackoffStrategy::linear(delay).with_jitter(Jitter::Relative(0.1));
for i in 1..100 {
let backoff = strategy.backoff(i);
// Actual jitter is at most 6s * i (10% of 60s * i); the window checked is ±10s * i.
let jitter = TimeDelta::seconds(10) * i as _;
assert!(backoff >= delay * i as _ - jitter);
assert!(backoff <= delay * i as _ + jitter);
}
}
#[test]
fn linear_backoff_with_jitter_min() {
let delay = TimeDelta::seconds(20);
let jitter = TimeDelta::seconds(20);
let min = TimeDelta::seconds(5);
let strategy = BackoffStrategy::linear(delay)
.with_jitter(Jitter::Absolute(jitter))
.with_min(min);
for i in 1..100 {
let backoff = strategy.backoff(i);
assert!(backoff >= min);
assert!(backoff <= delay * i as _ + jitter);
}
}
#[test]
fn exponential_backoff() {
// NOTE(review): with a 1-second base, `1^i == 1`, so this only exercises the
// degenerate case where the delay never grows.
let delay = TimeDelta::seconds(1);
let strategy = BackoffStrategy::exponential(delay);
for i in 1..10 {
assert_eq!(
strategy.backoff(i).num_seconds(),
delay.num_seconds().pow(i as _)
);
}
}
#[test]
fn exponential_backoff_with_max() {
let delay = TimeDelta::minutes(1);
let max = TimeDelta::minutes(10);
let strategy = BackoffStrategy::exponential(delay).with_max(max);
for i in 1..100 {
let backoff = strategy.backoff(i);
assert!(backoff <= max);
}
}
#[test]
fn exponential_backoff_with_absolute_jitter() {
// Attempts are kept small because 60^i seconds grows past `TimeDelta`'s range quickly.
let delay = TimeDelta::minutes(1);
let jitter = TimeDelta::seconds(10);
let strategy = BackoffStrategy::exponential(delay).with_jitter(Jitter::Absolute(jitter));
for i in 1..5 {
let backoff = strategy.backoff(i);
assert!(
backoff.num_seconds() >= delay.num_seconds().pow(i as _) - jitter.num_seconds()
);
assert!(
backoff.num_seconds() <= delay.num_seconds().pow(i as _) + jitter.num_seconds()
);
}
}
#[test]
fn exponential_backoff_with_relative_jitter() {
let delay = TimeDelta::minutes(1);
let strategy = BackoffStrategy::exponential(delay).with_jitter(Jitter::Relative(0.1));
for i in 1..5 {
let backoff = strategy.backoff(i);
assert!(backoff.num_seconds() as f64 >= delay.num_seconds().pow(i as _) as f64 * 0.9);
assert!(backoff.num_seconds() as f64 <= delay.num_seconds().pow(i as _) as f64 * 1.1);
}
}
#[test]
fn exponential_backoff_with_jitter_min() {
// Base 1s yields a constant 1s delay; ±10s jitter would go negative without `min`.
let delay = TimeDelta::seconds(1);
let jitter = TimeDelta::seconds(10);
let min = TimeDelta::seconds(5);
let strategy = BackoffStrategy::exponential(delay)
.with_jitter(Jitter::Absolute(jitter))
.with_min(min);
for i in 1..5 {
let backoff = strategy.backoff(i);
assert!(backoff >= min);
}
}
}