use std::{
cmp,
fmt::{self, Debug},
time::Duration,
};
/// Exponential backoff state machine.
///
/// Each call to [`Backoff::advance`] yields the next delay. The sequence starts at `initial`,
/// multiplies by `mult` after every step, and clamps each yielded delay to `max`. When
/// `max_count` is non-zero, at most `max_count` delays are produced until [`Backoff::reset`]
/// is called; a `max_count` of zero means the sequence never runs out.
#[derive(Copy, Clone, Debug)]
pub struct Backoff {
    /// Candidate for the first delay after construction or `reset`.
    initial: Duration,
    /// Upper bound applied to every yielded delay.
    max: Duration,
    /// Factor applied to the delay after each step.
    mult: u32,
    /// Maximum number of delays to yield; `0` means unlimited.
    max_count: u32,
    /// Candidate delay for the next step, before clamping to `max`.
    current: Duration,
    /// Number of delays yielded since construction or the last `reset`.
    count: u32,
}
impl Backoff {
    /// Creates a backoff that starts at `initial`, multiplies by `mult` per step, caps each
    /// delay at `max`, and stops after `max_count` steps (`0` = never stops).
    #[must_use]
    pub const fn new(initial: Duration, max: Duration, mult: u32, max_count: u32) -> Self {
        Self {
            initial,
            max,
            mult,
            max_count,
            current: initial,
            count: 0,
        }
    }

    /// Returns the next delay, or `None` once `max_count` delays have been produced.
    ///
    /// The returned delay is the current candidate clamped to `max`. The next candidate is
    /// computed with saturating arithmetic, so a very large `max` or `mult` cannot panic
    /// with a `Duration` multiplication overflow. The step counter also saturates, so an
    /// unlimited backoff (`max_count == 0`) can be advanced indefinitely.
    pub fn advance(&mut self) -> Option<Duration> {
        if self.max_count != 0 && self.count >= self.max_count {
            return None;
        }
        // With a non-zero `max_count` the counter never exceeds it; saturation only matters
        // for unlimited backoffs, where the exact count is irrelevant.
        self.count = self.count.saturating_add(1);
        let delay = self.current.min(self.max);
        // `Duration * u32` panics on overflow; saturating is safe because the candidate is
        // clamped to `max` before it is ever returned.
        self.current = delay.saturating_mul(self.mult);
        Some(delay)
    }

    /// Restarts the sequence from `initial` and clears the step counter.
    pub const fn reset(&mut self) {
        self.current = self.initial;
        self.count = 0;
    }
}
/// A [`Duration`] that may be absent; an absent value stands for an indefinite duration.
///
/// The `Default` value is the absent (indefinite) duration.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct OptionalDuration(Option<Duration>);
impl OptionalDuration {
pub const NONE: Self = Self(None);
#[must_use]
pub const fn from_secs(duration: u64) -> Self {
Self(Some(Duration::from_secs(duration)))
}
pub async fn timeout<T: Future>(
self,
future: T,
) -> Result<T::Output, tokio::time::error::Elapsed> {
match self.0 {
Some(duration) => tokio::time::timeout(duration, future).await,
None => Ok(future.await),
}
}
pub async fn sleep(self) {
match self.0 {
Some(duration) => tokio::time::sleep(duration).await,
None => std::future::pending().await,
}
}
#[must_use = "This function does not modify the original value"]
pub fn map<F: FnOnce(Duration) -> Duration>(self, f: F) -> Self {
Self(self.0.map(f))
}
#[must_use]
pub fn cmp_duration(&self, other: &Duration) -> cmp::Ordering {
self.0.map_or(cmp::Ordering::Greater, |d| d.cmp(other))
}
}
/// Parses an unsigned integer number of seconds (as accepted by `u64::from_str`).
/// The value `0` parses to the absent ("indefinite") duration.
impl std::str::FromStr for OptionalDuration {
    type Err = std::num::ParseIntError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(match s.parse::<u64>()? {
            0 => Self(None),
            secs => Self::from_secs(secs),
        })
    }
}
/// Formats a present duration with `Duration`'s `Debug` output (e.g. `10s`).
/// Formats an absent one as `indefinite`.
impl fmt::Display for OptionalDuration {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if let Some(duration) = &self.0 {
            Debug::fmt(duration, f)
        } else {
            f.write_str("indefinite")
        }
    }
}
/// Wraps a [`Duration`], treating a zero duration as absent ("indefinite").
impl From<Duration> for OptionalDuration {
    fn from(duration: Duration) -> Self {
        Self((!duration.is_zero()).then_some(duration))
    }
}
impl From<OptionalDuration> for Option<Duration> {
fn from(opt_dur: OptionalDuration) -> Self {
opt_dur.0
}
}
/// Delegates to the total order defined by [`Ord`].
impl PartialOrd for OptionalDuration {
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Total order in which the absent (indefinite) duration is greater than every present duration.
impl Ord for OptionalDuration {
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        match (self.0, other.0) {
            (Some(lhs), Some(rhs)) => lhs.cmp(&rhs),
            // `false < true`, so whichever side is absent sorts last; two absents are equal.
            (lhs, rhs) => lhs.is_none().cmp(&rhs.is_none()),
        }
    }
}
/// A tokio interval that may be absent; when absent, `tick` never completes.
///
/// The `Default` value holds no interval.
#[derive(Default, Debug)]
pub struct OptionalInterval(Option<tokio::time::Interval>);
impl OptionalInterval {
    /// Forwards `behavior` to the wrapped interval; does nothing when there is none.
    pub fn set_missed_tick_behavior(&mut self, behavior: tokio::time::MissedTickBehavior) {
        if let Some(inner) = self.0.as_mut() {
            inner.set_missed_tick_behavior(behavior);
        }
    }

    /// Waits for the next tick of the wrapped interval.
    ///
    /// Without an interval, the returned future never completes.
    pub async fn tick(&mut self) -> tokio::time::Instant {
        match self.0.as_mut() {
            Some(inner) => inner.tick().await,
            None => std::future::pending::<tokio::time::Instant>().await,
        }
    }
}
/// Builds an interval that ticks every `dur`; an absent duration yields an interval that never ticks.
///
/// A zero duration is treated like an absent one, matching `OptionalDuration`'s
/// `From<Duration>` and `FromStr`. This matters because `tokio::time::interval` panics
/// when given a zero period.
impl From<OptionalDuration> for OptionalInterval {
    fn from(dur: OptionalDuration) -> Self {
        // `OptionalDuration::from_secs(0)` and `OptionalDuration::map` can still produce
        // `Some(Duration::ZERO)`, so filter it out before constructing the interval.
        Self(
            dur.0
                .filter(|period| !period.is_zero())
                .map(tokio::time::interval),
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::FutureExt;

    /// Advances `backoff` once per entry of `expected_ms`, asserting each yielded delay.
    fn expect_delays_ms(backoff: &mut Backoff, expected_ms: &[u64]) {
        for &ms in expected_ms {
            assert_eq!(backoff.advance(), Some(Duration::from_millis(ms)));
        }
    }

    #[test]
    fn test_backoff() {
        crate::tests::setup_logging();

        // Doubling from 10ms with a budget of five steps.
        let mut backoff = Backoff::new(Duration::from_millis(10), Duration::from_secs(1), 2, 5);
        expect_delays_ms(&mut backoff, &[10, 20, 40, 80, 160]);
        assert_eq!(backoff.advance(), None);

        // A reset restarts both the delay and the step budget.
        backoff.reset();
        expect_delays_ms(&mut backoff, &[10]);
        backoff.reset();
        expect_delays_ms(&mut backoff, &[10, 20, 40, 80, 160]);
        for _ in 0..2 {
            assert_eq!(backoff.advance(), None);
        }

        // `initial` above `max` is clamped, and `max_count == 0` never runs out.
        let mut capped = Backoff::new(Duration::from_secs(10), Duration::from_secs(1), 2, 0);
        for _ in 0..3 {
            assert_eq!(capped.advance(), Some(Duration::from_secs(1)));
        }
    }

    #[test]
    fn test_optional_duration() {
        crate::tests::setup_logging();

        // Conversion from `Duration`: zero becomes indefinite.
        assert_eq!(
            OptionalDuration::from(Duration::from_secs(0)),
            OptionalDuration::NONE
        );
        assert_eq!(
            OptionalDuration::from(Duration::from_secs(10)),
            OptionalDuration::from_secs(10)
        );

        // Display and parsing.
        assert_eq!(OptionalDuration::from_secs(10).to_string(), "10s");
        assert_eq!(OptionalDuration::NONE.to_string(), "indefinite");
        let parsed: OptionalDuration = "20".parse().unwrap();
        assert_eq!(parsed.to_string(), "20s");
        let parsed_none: OptionalDuration = "0".parse().unwrap();
        assert_eq!(parsed_none, OptionalDuration::NONE);

        // `map` only touches present durations.
        assert_eq!(
            OptionalDuration::from_secs(2).map(|d| d * 2),
            OptionalDuration::from_secs(4)
        );
        assert_eq!(
            OptionalDuration::NONE.map(|d| d * 2),
            OptionalDuration::NONE
        );

        // `cmp_duration` against a fixed 10s reference.
        let reference = Duration::from_secs(10);
        let cases = [
            (OptionalDuration::from_secs(5), cmp::Ordering::Less),
            (OptionalDuration::from_secs(10), cmp::Ordering::Equal),
            (OptionalDuration::from_secs(15), cmp::Ordering::Greater),
            (OptionalDuration::NONE, cmp::Ordering::Greater),
        ];
        for (value, expected) in cases {
            assert_eq!(value.cmp_duration(&reference), expected);
        }

        // Comparison operators: indefinite sorts above every concrete duration.
        assert!(OptionalDuration::from_secs(10) > OptionalDuration::from_secs(5));
        assert!(OptionalDuration::from_secs(10) == OptionalDuration::from_secs(10));
        assert!(OptionalDuration::from_secs(15) > OptionalDuration::from_secs(10));
        assert!(OptionalDuration::NONE > OptionalDuration::from_secs(10));
        assert!(OptionalDuration::NONE == OptionalDuration::NONE);
        assert!(OptionalDuration::from_secs(10) < OptionalDuration::NONE);
    }

    #[tokio::test]
    #[cfg(not(loom))]
    async fn test_optional_interval() {
        crate::tests::setup_logging();
        let mut interval: OptionalInterval = OptionalDuration::from_secs(2).into();
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        // Polled once right away, the tick is still pending.
        assert!(interval.tick().now_or_never().is_none());
        // After waiting longer than one period, the tick is ready immediately.
        tokio::time::sleep(Duration::from_secs(3)).await;
        let fired_at = interval.tick().now_or_never().unwrap();
        assert!(fired_at < tokio::time::Instant::now());
    }
}