use rand::Rng;
use std::convert::TryInto;
use std::num::{NonZeroU32, NonZeroU8};
use std::time::Duration;
use serde::Deserialize;
/// A randomized-backoff delay generator for retrying operations.
///
/// Each generated delay is drawn at random from a range whose lower
/// bound is fixed at construction time and whose upper bound grows
/// with the previously generated delay (see `delay_bounds`).
#[derive(Debug, Clone)]
pub struct RetryDelay {
    /// The delay we most recently returned, in milliseconds,
    /// or 0 if we have not returned a delay yet.
    last_delay_ms: u32,
    /// Lower bound for every delay we return, in milliseconds.
    /// Clamped into [MIN_LOW_BOUND, MAX_LOW_BOUND] by `from_msec`.
    low_bound_ms: u32,
}
/// Smallest permitted lower bound for a retry delay, in milliseconds.
const MIN_LOW_BOUND: u32 = 1000;
/// Largest permitted lower bound, in milliseconds.  One below
/// `u32::MAX` so that `low + 1` in `delay_bounds` cannot overflow.
const MAX_LOW_BOUND: u32 = u32::MAX - 1;
/// The upper bound for the next delay is the previous delay times
/// this factor (saturating), giving randomized exponential growth.
const MAX_DELAY_MULT: u32 = 3;
impl RetryDelay {
    /// Construct a RetryDelay whose lower bound is `base_delay_msec`
    /// milliseconds, clamped into [MIN_LOW_BOUND, MAX_LOW_BOUND].
    pub fn from_msec(base_delay_msec: u32) -> Self {
        RetryDelay {
            last_delay_ms: 0,
            low_bound_ms: base_delay_msec.clamp(MIN_LOW_BOUND, MAX_LOW_BOUND),
        }
    }

    /// Construct a RetryDelay from a Duration; durations longer than
    /// MAX_LOW_BOUND milliseconds saturate at MAX_LOW_BOUND.
    pub fn from_duration(d: Duration) -> Self {
        let millis = d.as_millis().min(u128::from(MAX_LOW_BOUND)) as u32;
        Self::from_msec(millis)
    }

    /// Return the half-open range `[low, high)` from which the next
    /// delay will be drawn.  `high` is always strictly above `low`.
    fn delay_bounds(&self) -> (u32, u32) {
        let lo = self.low_bound_ms;
        // Grow from the previous delay, saturating instead of overflowing;
        // `lo + 1` is safe because lo <= MAX_LOW_BOUND == u32::MAX - 1.
        let hi = (lo + 1).max(self.last_delay_ms.saturating_mul(MAX_DELAY_MULT));
        (lo, hi)
    }

    /// Pick the next delay, in milliseconds, and remember it so that
    /// later delays can grow from it.
    pub fn next_delay_msec<R: Rng>(&mut self, rng: &mut R) -> u32 {
        let (lo, hi) = self.delay_bounds();
        assert!(lo < hi);
        let chosen = rng.gen_range(lo..hi);
        self.last_delay_ms = chosen;
        chosen
    }

    /// As `next_delay_msec`, but return the result as a Duration.
    pub fn next_delay<R: Rng>(&mut self, rng: &mut R) -> Duration {
        let ms = self.next_delay_msec(rng);
        Duration::from_millis(ms.into())
    }
}
impl Default for RetryDelay {
fn default() -> Self {
RetryDelay::from_msec(0)
}
}
/// Configuration for retrying a download: how many attempts to make,
/// how long to wait at first, and how many downloads to run at once.
///
/// Deserializable (e.g. from a config file); unknown fields are rejected.
#[derive(Debug, Copy, Clone, Deserialize, Eq, PartialEq)]
#[serde(deny_unknown_fields)]
pub struct DownloadSchedule {
    /// Total number of attempts to make.  (Despite the name, this is
    /// the full attempt count reported by `n_attempts()`, not the
    /// number of *re*tries after the first attempt.)
    num_retries: NonZeroU32,
    /// Initial delay used to seed the RetryDelay returned by
    /// `schedule()`; deserialized via the `humantime_serde` helpers.
    #[serde(with = "humantime_serde")]
    initial_delay: Duration,
    /// How many downloads to run in parallel; defaults to 1
    /// (see `default_parallelism`).
    #[serde(default = "default_parallelism")]
    parallelism: NonZeroU8,
}
impl Default for DownloadSchedule {
fn default() -> Self {
DownloadSchedule::new(3, Duration::from_millis(1000), 1)
}
}
/// Serde default for `DownloadSchedule::parallelism`: one download at a time.
fn default_parallelism() -> NonZeroU8 {
    #![allow(clippy::unwrap_used)]
    // `NonZeroU8::new` returns None only for 0, so this cannot panic.
    NonZeroU8::new(1).unwrap()
}
impl DownloadSchedule {
#[allow(clippy::missing_panics_doc)] pub fn new(attempts: u32, initial_delay: Duration, parallelism: u8) -> Self {
#![allow(clippy::unwrap_used)]
let num_retries = attempts
.try_into()
.unwrap_or_else(|_| 1.try_into().unwrap());
let parallelism = parallelism
.try_into()
.unwrap_or_else(|_| 1.try_into().unwrap());
DownloadSchedule {
num_retries,
initial_delay,
parallelism,
}
}
pub fn attempts(&self) -> impl Iterator<Item = u32> {
0..(self.num_retries.into())
}
pub fn n_attempts(&self) -> u32 {
self.num_retries.into()
}
pub fn parallelism(&self) -> u8 {
self.parallelism.into()
}
pub fn schedule(&self) -> RetryDelay {
RetryDelay::from_duration(self.initial_delay)
}
}
#[cfg(test)]
mod test {
    use super::*;

    // Constructors start with no previous delay and clamp the lower
    // bound into [MIN_LOW_BOUND, MAX_LOW_BOUND].
    #[test]
    fn init() {
        let rd = RetryDelay::from_msec(2000);
        assert_eq!(rd.last_delay_ms, 0);
        assert_eq!(rd.low_bound_ms, 2000);
        // A request below MIN_LOW_BOUND is clamped up to 1000 ms.
        let rd = RetryDelay::from_msec(0);
        assert_eq!(rd.last_delay_ms, 0);
        assert_eq!(rd.low_bound_ms, 1000);
        // 1 second + 500_000_000 ns converts to 1500 ms.
        let rd = RetryDelay::from_duration(Duration::new(1, 500_000_000));
        assert_eq!(rd.last_delay_ms, 0);
        assert_eq!(rd.low_bound_ms, 1500);
    }

    // delay_bounds: high is max(low + 1, last_delay * MAX_DELAY_MULT),
    // and the multiplication saturates instead of overflowing.
    #[test]
    fn bounds() {
        let mut rd = RetryDelay::from_msec(1000);
        assert_eq!(rd.delay_bounds(), (1000, 1001));
        rd.last_delay_ms = 1500;
        assert_eq!(rd.delay_bounds(), (1000, 4500));
        // 3_000_000_000 * 3 overflows u32, so the product saturates.
        rd.last_delay_ms = 3_000_000_000;
        assert_eq!(rd.delay_bounds(), (1000, std::u32::MAX));
    }

    // Repeated draws stay inside the advertised bounds and each drawn
    // value is recorded as the new last delay.
    #[test]
    fn rng() {
        let mut rd = RetryDelay::from_msec(50);
        // 50 ms is below the minimum, so the real lower bound is 1000.
        let real_low_bound = std::cmp::max(50, MIN_LOW_BOUND);
        let mut rng = rand::thread_rng();
        for _ in 1..100 {
            let (b_lo, b_hi) = rd.delay_bounds();
            assert!(b_lo == real_low_bound);
            assert!(b_hi > b_lo);
            let delay = rd.next_delay(&mut rng).as_millis() as u32;
            assert_eq!(delay, rd.last_delay_ms);
            assert!(delay >= b_lo);
            assert!(delay < b_hi);
        }
    }

    // DownloadSchedule: default values, and zero arguments to new()
    // being replaced with 1.
    #[test]
    fn config() {
        // Default: 3 attempts, 1000 ms initial delay.
        let cfg = DownloadSchedule::default();
        assert_eq!(cfg.n_attempts(), 3);
        let v: Vec<_> = cfg.attempts().collect();
        assert_eq!(&v[..], &[0, 1, 2]);
        let sched = cfg.schedule();
        assert_eq!(sched.last_delay_ms, 0);
        assert_eq!(sched.low_bound_ms, 1000);
        // Zero attempts/parallelism become 1, and the zero Duration is
        // clamped up to MIN_LOW_BOUND by RetryDelay.
        let cfg = DownloadSchedule::new(0, Duration::new(0, 0), 0);
        assert_eq!(cfg.n_attempts(), 1);
        assert_eq!(cfg.parallelism(), 1);
        let v: Vec<_> = cfg.attempts().collect();
        assert_eq!(&v[..], &[0]);
        let sched = cfg.schedule();
        assert_eq!(sched.last_delay_ms, 0);
        assert_eq!(sched.low_bound_ms, 1000);
    }
}