#![cfg_attr(feature = "no_std", no_std)]
extern crate alloc;
use alloc::boxed::Box;
use core::sync::atomic::{AtomicU64, Ordering};
#[cfg(not(feature = "no_std"))]
use core::time::Duration;
#[cfg(feature = "async")]
use core::{future::Future, pin::Pin};
#[cfg(not(feature = "no_std"))]
mod throttle;
#[cfg(not(feature = "no_std"))]
pub use throttle::ExponentialBackoffWaiter;
#[cfg(not(feature = "no_std"))]
pub use throttle::ThrottleWaiter;
#[cfg(not(feature = "no_std"))]
mod timeout;
#[cfg(not(feature = "no_std"))]
pub use timeout::TimeoutWaiter;
mod compose;
pub use compose::DelayComposer;
#[cfg(test)]
mod tests;
/// Reasons a [`Waiter`] can refuse to wait.
///
/// Variant order is significant: the derived `Ord` ranks `Timeout` before
/// `NotStarted`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum WaiterError {
    /// The waiter's limit has been reached (for example, the count-based waiter
    /// after its allowed number of `wait` calls).
    Timeout,
    /// The waiter was used before [`Waiter::start`] was called.
    NotStarted,
}
/// Common interface implemented by every waiting strategy in this crate.
///
/// Implementors must also be [`WaiterClone`] (so boxed waiters can be cloned)
/// and `Send + Sync`. Only [`Waiter::wait`] is required; every other method has
/// a default.
pub trait Waiter: WaiterClone + Send + Sync {
    /// Prepares the waiter for a run of `wait` calls.
    ///
    /// The default implementation does nothing.
    fn start(&mut self) {}

    /// Restarts the waiter (for example, the count-based waiter resets its
    /// counter to zero).
    ///
    /// The default implementation does nothing and returns `Ok(())`. Stateful
    /// waiters may return [`WaiterError::NotStarted`] when [`Waiter::start`] has
    /// not been called yet.
    fn restart(&mut self) -> Result<(), WaiterError> {
        Ok(())
    }

    /// Performs one wait.
    ///
    /// Returns an error such as [`WaiterError::Timeout`] or
    /// [`WaiterError::NotStarted`] when the waiter cannot continue.
    fn wait(&mut self) -> Result<(), WaiterError>;

    /// Asynchronous counterpart of [`Waiter::wait`].
    ///
    /// The default implementation calls `wait` synchronously, right away, and
    /// returns an already-completed future carrying its result.
    #[cfg(feature = "async")]
    fn async_wait(&mut self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
        let outcome = self.wait();
        Box::pin(futures_util::future::ready(outcome))
    }
}
/// Object-safe cloning for [`Waiter`] trait objects.
///
/// `Clone` requires `Sized`, so it cannot be called on a `dyn Waiter` directly.
/// Waiters instead expose `clone_box`, which returns a freshly boxed copy.
pub trait WaiterClone {
    /// Returns a boxed copy of `self`.
    fn clone_box(&self) -> Box<dyn Waiter>;
}

/// Every `'static` waiter that is `Clone` gets `clone_box` for free.
impl<W> WaiterClone for W
where
    W: Waiter + Clone + 'static,
{
    fn clone_box(&self) -> Box<dyn Waiter> {
        let copy: W = Clone::clone(self);
        Box::new(copy)
    }
}

/// Makes boxed waiters cloneable by delegating to the concrete waiter.
impl Clone for Box<dyn Waiter> {
    fn clone(&self) -> Self {
        // Deref all the way down to `dyn Waiter` so `clone_box` dispatches through
        // the vtable to the concrete waiter. Calling it on the box itself would
        // select the blanket impl for `Box<dyn Waiter>`, which calls back into this
        // `clone` and recurses forever.
        (**self).clone_box()
    }
}
/// Lets a boxed waiter be used wherever a `Waiter` is expected by forwarding
/// every method to the boxed value.
///
/// Each method explicitly derefs to `dyn Waiter` (`**self`) so the call goes
/// through the trait object's vtable; calling `self.method()` here would resolve
/// back to this impl and recurse.
impl Waiter for Box<dyn Waiter> {
    fn start(&mut self) {
        (**self).start()
    }

    fn restart(&mut self) -> Result<(), WaiterError> {
        (**self).restart()
    }

    fn wait(&mut self) -> Result<(), WaiterError> {
        (**self).wait()
    }

    #[cfg(feature = "async")]
    fn async_wait(&mut self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
        (**self).async_wait()
    }
}
/// A type-erased, cloneable [`Waiter`].
///
/// `Delay` owns a boxed waiter and forwards every [`Waiter`] method to it. The
/// strategies produced by the constructors below (or composed with
/// [`DelayBuilder`]) therefore all share one concrete type.
pub struct Delay {
    inner: Box<dyn Waiter>,
}

impl Delay {
    /// Wraps an already boxed waiter.
    pub fn from(inner: Box<dyn Waiter>) -> Self {
        Delay { inner }
    }

    /// A delay whose `wait` always returns `Ok(())` immediately.
    pub fn instant() -> Self {
        Self::from(Box::new(InstantWaiter {}))
    }

    /// A delay backed by a [`TimeoutWaiter`] created with `timeout`.
    #[cfg(not(feature = "no_std"))]
    pub fn timeout(timeout: Duration) -> Self {
        Self::from(Box::new(TimeoutWaiter::new(timeout)))
    }

    /// A delay that allows `count` calls to `wait` after each `start`/`restart`.
    /// Every call after that returns [`WaiterError::Timeout`].
    ///
    /// Calling `wait` (or `restart`) before `start` returns
    /// [`WaiterError::NotStarted`].
    pub fn count_timeout(count: u64) -> Self {
        Self::from(Box::new(CountTimeoutWaiter::new(count)))
    }

    /// A delay backed by a [`ThrottleWaiter`] created with `throttle`.
    #[cfg(not(feature = "no_std"))]
    pub fn throttle(throttle: Duration) -> Self {
        Self::from(Box::new(ThrottleWaiter::new(throttle)))
    }

    /// A delay backed by an [`ExponentialBackoffWaiter`] built from `initial`,
    /// `multiplier` and `cap`. The exact meaning of each parameter is defined by
    /// `ExponentialBackoffWaiter::new`.
    #[cfg(not(feature = "no_std"))]
    pub fn exponential_backoff_capped(initial: Duration, multiplier: f32, cap: Duration) -> Self {
        Self::from(Box::new(ExponentialBackoffWaiter::new(
            initial, multiplier, cap,
        )))
    }

    /// Same as [`Delay::exponential_backoff_capped`] with `cap` set to
    /// `u64::MAX` seconds. This is presumably meant as "no cap"; confirm against
    /// `ExponentialBackoffWaiter`.
    #[cfg(not(feature = "no_std"))]
    pub fn exponential_backoff(initial: Duration, multiplier: f32) -> Self {
        // Use the associated constant rather than the legacy `std::u64::MAX`
        // module constant, which is slated for deprecation.
        Self::exponential_backoff_capped(initial, multiplier, Duration::from_secs(u64::MAX))
    }

    /// A delay whose `wait` calls `function` and returns its result. `start` and
    /// `restart` use the [`Waiter`] defaults.
    pub fn side_effect<F>(function: F) -> Self
    where
        F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
    {
        Self::from(Box::new(SideEffectWaiter::new(function)))
    }

    /// Starts an empty [`DelayBuilder`]. Building it without adding anything
    /// yields [`Delay::instant`].
    pub fn builder() -> DelayBuilder {
        DelayBuilder { inner: None }
    }
}

impl Clone for Delay {
    fn clone(&self) -> Self {
        // Go through `Clone for Box<dyn Waiter>`, which dispatches to the concrete
        // waiter's `clone_box`.
        //
        // `self.inner.clone_box()` would instead resolve to the blanket
        // `WaiterClone` impl for `Box<dyn Waiter>` itself. That impl returns
        // `Box<Box<dyn Waiter>>`, so every clone added another layer of boxing
        // and dynamic dispatch.
        Self::from(self.inner.clone())
    }
}

/// Forwards every call to the wrapped waiter.
impl Waiter for Delay {
    fn restart(&mut self) -> Result<(), WaiterError> {
        self.inner.restart()
    }

    fn start(&mut self) {
        self.inner.start()
    }

    fn wait(&mut self) -> Result<(), WaiterError> {
        self.inner.wait()
    }

    #[cfg(feature = "async")]
    fn async_wait(&mut self) -> Pin<Box<dyn Future<Output = Result<(), WaiterError>> + Send>> {
        self.inner.async_wait()
    }
}
/// Incrementally composes several [`Delay`]s into one.
///
/// Start with [`Delay::builder`], chain the methods below, then call
/// [`DelayBuilder::build`].
pub struct DelayBuilder {
    inner: Option<Delay>,
}

impl DelayBuilder {
    /// Adds `other` to the delays collected so far.
    ///
    /// The first delay is stored as-is. Each later delay is combined with the
    /// accumulated one as `DelayComposer::new(accumulated, other)`.
    pub fn with(mut self, other: Delay) -> Self {
        self.inner = Some(match self.inner.take() {
            None => other,
            Some(w) => Delay::from(Box::new(DelayComposer::new(w, other))),
        });
        self
    }

    /// Adds a [`Delay::timeout`] delay.
    #[cfg(not(feature = "no_std"))]
    pub fn timeout(self, timeout: Duration) -> Self {
        self.with(Delay::timeout(timeout))
    }

    /// Adds a [`Delay::count_timeout`] delay.
    ///
    /// This completes parity with `Delay`'s constructors. It is also the only
    /// timeout that is available without `std`.
    pub fn count_timeout(self, count: u64) -> Self {
        self.with(Delay::count_timeout(count))
    }

    /// Adds a [`Delay::throttle`] delay.
    #[cfg(not(feature = "no_std"))]
    pub fn throttle(self, throttle: Duration) -> Self {
        self.with(Delay::throttle(throttle))
    }

    /// Adds a [`Delay::exponential_backoff`] delay.
    #[cfg(not(feature = "no_std"))]
    pub fn exponential_backoff(self, initial: Duration, multiplier: f32) -> Self {
        self.with(Delay::exponential_backoff(initial, multiplier))
    }

    /// Adds a [`Delay::exponential_backoff_capped`] delay.
    #[cfg(not(feature = "no_std"))]
    pub fn exponential_backoff_capped(
        self,
        initial: Duration,
        multiplier: f32,
        cap: Duration,
    ) -> Self {
        self.with(Delay::exponential_backoff_capped(initial, multiplier, cap))
    }

    /// Adds a [`Delay::side_effect`] delay running `function` on each `wait`.
    pub fn side_effect<F>(self, function: F) -> Self
    where
        F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
    {
        self.with(Delay::side_effect(function))
    }

    /// Returns the composed delay, or [`Delay::instant`] if nothing was added.
    pub fn build(self) -> Delay {
        // `self` is consumed, so the option can be moved out directly.
        self.inner.unwrap_or_else(Delay::instant)
    }
}
/// A waiter that never blocks: every `wait` succeeds immediately.
///
/// `start` and `restart` use the [`Waiter`] defaults.
#[derive(Clone)]
struct InstantWaiter;

impl Waiter for InstantWaiter {
    fn wait(&mut self) -> Result<(), WaiterError> {
        // Nothing to wait for.
        let ready: Result<(), WaiterError> = Ok(());
        ready
    }
}
/// A waiter that allows at most `max_count` `wait` calls per run.
///
/// `count` stays `None` until `start` is called. While it is `None`, both `wait`
/// and `restart` report [`WaiterError::NotStarted`].
struct CountTimeoutWaiter {
    max_count: u64,
    count: Option<AtomicU64>,
}

impl CountTimeoutWaiter {
    /// Creates an unstarted waiter that will allow `max_count` waits per run.
    pub fn new(max_count: u64) -> Self {
        let count = None;
        CountTimeoutWaiter { max_count, count }
    }
}

impl Clone for CountTimeoutWaiter {
    /// Copies the limit and a snapshot of the current count. The clone resumes
    /// from the same position but counts independently afterwards.
    fn clone(&self) -> Self {
        let snapshot = self
            .count
            .as_ref()
            .map(|counter| counter.load(Ordering::Relaxed));
        Self {
            max_count: self.max_count,
            count: snapshot.map(AtomicU64::new),
        }
    }
}

impl Waiter for CountTimeoutWaiter {
    fn restart(&mut self) -> Result<(), WaiterError> {
        match self.count.as_mut() {
            // Only a started waiter can be restarted.
            None => Err(WaiterError::NotStarted),
            Some(counter) => {
                *counter.get_mut() = 0;
                Ok(())
            }
        }
    }

    fn start(&mut self) {
        self.count = Some(AtomicU64::new(0));
    }

    fn wait(&mut self) -> Result<(), WaiterError> {
        // `&mut self` gives exclusive access, so the counter can be updated
        // through `get_mut` without an atomic read-modify-write.
        let counter = match self.count.as_mut() {
            Some(counter) => counter.get_mut(),
            None => return Err(WaiterError::NotStarted),
        };
        // Compare the value from *before* this call, so exactly `max_count` calls
        // succeed. Wrapping matches `fetch_add`'s overflow behaviour.
        let previous = *counter;
        *counter = previous.wrapping_add(1);
        if previous < self.max_count {
            Ok(())
        } else {
            Err(WaiterError::Timeout)
        }
    }
}
/// A waiter whose `wait` runs a user-supplied function and returns its result.
///
/// `start` and `restart` use the [`Waiter`] defaults. Trait bounds live on the
/// `impl` blocks that need them rather than on the struct definition, so the
/// bare type carries no redundant constraints. This follows API guideline
/// C-STRUCT-BOUNDS.
#[derive(Clone)]
struct SideEffectWaiter<F> {
    function: F,
}

impl<F> SideEffectWaiter<F>
where
    F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
{
    /// Wraps `function` so it is called once per `wait`.
    pub fn new(function: F) -> Self {
        Self { function }
    }
}

impl<F> Waiter for SideEffectWaiter<F>
where
    // `'static + Clone` are needed for the blanket `WaiterClone` impl, and
    // `Send + Sync` for `Waiter`'s supertraits.
    F: 'static + Sync + Send + Clone + Fn() -> Result<(), WaiterError>,
{
    fn wait(&mut self) -> Result<(), WaiterError> {
        (self.function)()
    }
}