use std::cell::RefCell;
use std::time::{Duration, Instant};
#[cfg(test)]
mod tests;
/// Reasons a waiter can report instead of letting the caller continue.
///
/// The derived ordering follows declaration order, so
/// `WaiterError::Timeout < WaiterError::NotStarted`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum WaiterError {
    /// The waiter's limit (elapsed time or number of waits) has been reached.
    Timeout,
    /// The waiter was waited on or restarted before `start` was called.
    NotStarted,
}
/// A waiting strategy that can be started, restarted and waited on.
///
/// Implementors must be `Send` and must provide `WaiterClone` so that boxed
/// waiters can be duplicated.
pub trait Waiter: WaiterClone + Send {
    /// Prepares the waiter for use. The default implementation does nothing.
    fn start(&mut self) {}

    /// Resets the waiter to its freshly started state.
    ///
    /// The default implementation has no state to reset and returns `Ok(())`.
    fn restart(&mut self) -> Result<(), WaiterError> {
        Ok(())
    }

    /// Performs one wait step.
    ///
    /// Implementations in this file either return immediately, sleep, or
    /// report an error such as `WaiterError::Timeout`.
    fn wait(&self) -> Result<(), WaiterError>;
}
/// Cloning hook that works through a `dyn Waiter` trait object.
pub trait WaiterClone {
    /// Returns a boxed copy of `self` as a `Waiter` trait object.
    fn clone_box(&self) -> Box<dyn Waiter>;
}

/// Every `'static` waiter that is also `Clone` gets `clone_box` automatically.
impl<T: 'static + Waiter + Clone> WaiterClone for T {
    fn clone_box(&self) -> Box<dyn Waiter> {
        let duplicate = T::clone(self);
        Box::new(duplicate)
    }
}
/// Lets a boxed waiter be cloned by delegating to the waiter's own `clone_box`.
impl Clone for Box<dyn Waiter> {
    fn clone(&self) -> Self {
        // Deref through the reference and the box to reach the trait object.
        (**self).clone_box()
    }
}
/// A cloneable handle around a boxed waiter, with constructors for every
/// waiting strategy defined in this file.
#[derive(Clone)]
pub struct Delay {
    inner: Box<dyn Waiter>,
}

impl Delay {
    /// Wraps an already boxed waiter.
    fn from(inner: Box<dyn Waiter>) -> Self {
        Self { inner }
    }

    /// A delay whose `wait` returns `Ok(())` immediately.
    pub fn instant() -> Self {
        let waiter = InstantWaiter {};
        Self::from(Box::new(waiter))
    }

    /// A delay that reports `WaiterError::Timeout` once more than `timeout`
    /// has elapsed since it was started or restarted.
    pub fn timeout(timeout: Duration) -> Self {
        let waiter = TimeoutWaiter::new(timeout);
        Self::from(Box::new(waiter))
    }

    /// A delay that reports `WaiterError::Timeout` once `wait` has been
    /// called `count` times since it was started or restarted.
    pub fn count_timeout(count: u64) -> Self {
        let waiter = CountTimeoutWaiter::new(count);
        Self::from(Box::new(waiter))
    }

    /// A delay that sleeps for `throttle` on every `wait`.
    pub fn throttle(throttle: Duration) -> Self {
        let waiter = ThrottleWaiter::new(throttle);
        Self::from(Box::new(waiter))
    }

    /// A delay that sleeps for `initial` on the first `wait` and scales the
    /// sleep by `multiplier` after each one, never exceeding `cap`.
    pub fn exponential_backoff_capped(initial: Duration, multiplier: f32, cap: Duration) -> Self {
        let waiter = ExponentialBackoffWaiter::new(initial, multiplier, cap);
        Self::from(Box::new(waiter))
    }

    /// Same as [`Delay::exponential_backoff_capped`] with the cap set to
    /// `u64::MAX` seconds.
    pub fn exponential_backoff(initial: Duration, multiplier: f32) -> Self {
        let widest_cap = Duration::from_secs(std::u64::MAX);
        Self::exponential_backoff_capped(initial, multiplier, widest_cap)
    }

    /// A delay that calls `function` on every `wait` and returns its result.
    pub fn side_effect<F>(function: F) -> Self
    where
        F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
    {
        let waiter = SideEffectWaiter::new(function);
        Self::from(Box::new(waiter))
    }

    /// Starts an empty builder for combining several delays.
    pub fn builder() -> DelayBuilder {
        DelayBuilder { inner: None }
    }
}
/// `Delay` is itself a waiter: every call is forwarded to the wrapped waiter.
impl Waiter for Delay {
    fn restart(&mut self) -> Result<(), WaiterError> {
        let wrapped = &mut *self.inner;
        wrapped.restart()
    }

    fn start(&mut self) {
        let wrapped = &mut *self.inner;
        wrapped.start();
    }

    fn wait(&self) -> Result<(), WaiterError> {
        let wrapped = &*self.inner;
        wrapped.wait()
    }
}
/// Incrementally combines several [`Delay`]s into a single one.
///
/// Delays are chained in the order they are added; an empty builder produces
/// [`Delay::instant`].
pub struct DelayBuilder {
    inner: Option<Delay>,
}

impl DelayBuilder {
    /// Adds `other` to the delay under construction.
    ///
    /// The first delay is stored as-is; each later one is chained after the
    /// existing delay through a `DelayComposer`.
    pub fn with(mut self, other: Delay) -> Self {
        self.inner = Some(match self.inner.take() {
            None => other,
            Some(existing) => Delay::from(Box::new(DelayComposer::new(existing, other))),
        });
        self
    }

    /// Adds a [`Delay::timeout`] with the given `timeout`.
    pub fn timeout(self, timeout: Duration) -> Self {
        self.with(Delay::timeout(timeout))
    }

    /// Adds a [`Delay::count_timeout`] limited to `count` waits.
    ///
    /// Mirrors the other constructors so count-based limits can be composed
    /// without going through [`DelayBuilder::with`].
    pub fn count_timeout(self, count: u64) -> Self {
        self.with(Delay::count_timeout(count))
    }

    /// Adds a [`Delay::throttle`] with the given `throttle` duration.
    pub fn throttle(self, throttle: Duration) -> Self {
        self.with(Delay::throttle(throttle))
    }

    /// Adds a [`Delay::exponential_backoff`] starting at `initial` and scaled
    /// by `multiplier`.
    pub fn exponential_backoff(self, initial: Duration, multiplier: f32) -> Self {
        self.with(Delay::exponential_backoff(initial, multiplier))
    }

    /// Adds a [`Delay::exponential_backoff_capped`] starting at `initial`,
    /// scaled by `multiplier` and limited to `cap`.
    pub fn exponential_backoff_capped(
        self,
        initial: Duration,
        multiplier: f32,
        cap: Duration,
    ) -> Self {
        self.with(Delay::exponential_backoff_capped(initial, multiplier, cap))
    }

    /// Adds a [`Delay::side_effect`] that runs `function`.
    pub fn side_effect<F>(self, function: F) -> Self
    where
        F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
    {
        self.with(Delay::side_effect(function))
    }

    /// Finishes the builder, falling back to [`Delay::instant`] when nothing
    /// was added.
    pub fn build(self) -> Delay {
        self.inner.unwrap_or_else(Delay::instant)
    }
}
/// Chains two delays so that `a` is consulted first and `b` second.
#[derive(Clone)]
struct DelayComposer {
    a: Delay,
    b: Delay,
}

impl DelayComposer {
    /// Pairs `a` and `b`, keeping that order.
    fn new(a: Delay, b: Delay) -> Self {
        DelayComposer { a, b }
    }
}

impl Waiter for DelayComposer {
    /// Restarts `a`, then `b`; `b` is left untouched if `a` fails.
    fn restart(&mut self) -> Result<(), WaiterError> {
        self.a.restart().and_then(|_| self.b.restart())
    }

    /// Starts both delays, `a` first.
    fn start(&mut self) {
        self.a.start();
        self.b.start();
    }

    /// Waits on `a`, then `b`; `b` is not waited on if `a` fails.
    fn wait(&self) -> Result<(), WaiterError> {
        self.a.wait().and_then(|_| self.b.wait())
    }
}
/// A stateless waiter that neither blocks nor fails.
#[derive(Clone)]
struct InstantWaiter {}

impl Waiter for InstantWaiter {
    /// Always succeeds immediately.
    fn wait(&self) -> Result<(), WaiterError> {
        Result::Ok(())
    }
}
/// Reports `WaiterError::Timeout` once more than `timeout` has elapsed since
/// the waiter was last started or restarted.
#[derive(Clone)]
struct TimeoutWaiter {
    /// Maximum time allowed since `start`.
    timeout: Duration,
    /// Moment of the last start/restart; `None` until `start` is called.
    start: Option<Instant>,
}

impl TimeoutWaiter {
    /// Creates a waiter that has not been started yet.
    pub fn new(timeout: Duration) -> Self {
        TimeoutWaiter {
            start: None,
            timeout,
        }
    }
}

impl Waiter for TimeoutWaiter {
    /// Resets the reference instant to now; fails with `NotStarted` if the
    /// waiter was never started.
    fn restart(&mut self) -> Result<(), WaiterError> {
        if self.start.is_none() {
            return Err(WaiterError::NotStarted);
        }
        self.start = Some(Instant::now());
        Ok(())
    }

    /// Records the current instant as the start of the timeout window.
    fn start(&mut self) {
        self.start = Some(Instant::now());
    }

    /// Never blocks: only compares the elapsed time against `timeout`.
    fn wait(&self) -> Result<(), WaiterError> {
        match self.start {
            None => Err(WaiterError::NotStarted),
            Some(begun) if begun.elapsed() > self.timeout => Err(WaiterError::Timeout),
            Some(_) => Ok(()),
        }
    }
}
/// Reports `WaiterError::Timeout` once `wait` has been called `max_count`
/// times since the waiter was last started or restarted.
#[derive(Clone)]
struct CountTimeoutWaiter {
    /// Number of waits at which the waiter starts failing.
    max_count: u64,
    /// Waits seen so far; `None` until `start` is called. Interior mutability
    /// is needed because `wait` only receives `&self`.
    count: Option<RefCell<u64>>,
}

impl CountTimeoutWaiter {
    /// Creates a waiter that has not been started yet.
    pub fn new(max_count: u64) -> Self {
        Self {
            count: None,
            max_count,
        }
    }
}

impl Waiter for CountTimeoutWaiter {
    /// Resets the counter to zero; fails with `NotStarted` if the waiter was
    /// never started.
    fn restart(&mut self) -> Result<(), WaiterError> {
        match self.count.as_ref() {
            Some(cell) => {
                *cell.borrow_mut() = 0;
                Ok(())
            }
            None => Err(WaiterError::NotStarted),
        }
    }

    /// Installs a fresh counter at zero.
    fn start(&mut self) {
        self.count = Some(RefCell::new(0));
    }

    /// Counts this call, then fails once the total reaches `max_count`.
    fn wait(&self) -> Result<(), WaiterError> {
        let cell = match self.count.as_ref() {
            Some(cell) => cell,
            None => return Err(WaiterError::NotStarted),
        };
        let seen = {
            let mut calls = cell.borrow_mut();
            *calls += 1;
            *calls
        };
        if seen < self.max_count {
            Ok(())
        } else {
            Err(WaiterError::Timeout)
        }
    }
}
/// Sleeps for a fixed duration on every wait.
#[derive(Clone)]
struct ThrottleWaiter {
    /// How long each `wait` call sleeps.
    throttle: Duration,
}

impl ThrottleWaiter {
    /// Creates a waiter that sleeps `throttle` per wait.
    pub fn new(throttle: Duration) -> Self {
        ThrottleWaiter { throttle }
    }
}

impl Waiter for ThrottleWaiter {
    /// Blocks the current thread for `throttle`, then succeeds.
    fn wait(&self) -> Result<(), WaiterError> {
        let pause = self.throttle;
        std::thread::sleep(pause);
        Ok(())
    }
}
/// Sleeps for a duration that is scaled by `multiplier` after every wait,
/// starting at `initial` and never exceeding `cap`.
#[derive(Clone)]
struct ExponentialBackoffWaiter {
    /// Duration of the next sleep; `None` until `start` is called. Interior
    /// mutability is needed because `wait` only receives `&self`.
    next: Option<RefCell<Duration>>,
    /// Sleep used for the first wait after a start or restart.
    initial: Duration,
    /// Factor applied to the sleep after each wait.
    multiplier: f32,
    /// Upper bound for the scaled sleep.
    cap: Duration,
}

impl ExponentialBackoffWaiter {
    /// Creates a waiter that has not been started yet.
    pub fn new(initial: Duration, multiplier: f32, cap: Duration) -> Self {
        ExponentialBackoffWaiter {
            next: None,
            initial,
            multiplier,
            cap,
        }
    }
}

impl Waiter for ExponentialBackoffWaiter {
    /// Resets the next sleep to `initial`; fails with `NotStarted` if the
    /// waiter was never started.
    fn restart(&mut self) -> Result<(), WaiterError> {
        let next = self.next.as_ref().ok_or(WaiterError::NotStarted)?;
        next.replace(self.initial);
        Ok(())
    }

    /// Arms the waiter with `initial` as the first sleep.
    fn start(&mut self) {
        self.next = Some(RefCell::new(self.initial));
    }

    /// Stores the scaled (and capped) duration for the following call, then
    /// sleeps for the current one.
    fn wait(&self) -> Result<(), WaiterError> {
        let next = self.next.as_ref().ok_or(WaiterError::NotStarted)?;
        let current = *next.borrow();
        // Scale in f64: an f32 has a 24-bit mantissa, so nanosecond counts
        // above ~16.7 ms would already be rounded before the multiplication.
        let scaled_nanos = current.as_nanos() as f64 * f64::from(self.multiplier);
        // The float-to-integer `as` cast truncates and saturates (NaN and
        // negative products become 0), so odd multipliers cannot panic here.
        let scaled = Duration::from_nanos(scaled_nanos as u64);
        next.replace(scaled.min(self.cap));
        std::thread::sleep(current);
        Ok(())
    }
}
/// Runs a user-supplied function on every wait and reports its result.
#[derive(Clone)]
struct SideEffectWaiter<F>
where
    F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
{
    /// Callback invoked by `wait`.
    function: F,
}

impl<F> SideEffectWaiter<F>
where
    F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
{
    /// Wraps `function` so it can be used as a waiter.
    pub fn new(function: F) -> Self {
        SideEffectWaiter { function }
    }
}

impl<F> Waiter for SideEffectWaiter<F>
where
    F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
{
    /// Calls the stored function and returns whatever it returns.
    fn wait(&self) -> Result<(), WaiterError> {
        let callback = &self.function;
        callback()
    }
}