#![cfg_attr(
all(
not(feature = "std"),
not(feature = "async-io"),
not(feature = "tokio")
),
no_std
)]
#![cfg_attr(docsrs, feature(doc_cfg))]
use core::{
future::Future,
pin::Pin,
sync::atomic::Ordering,
task::{Context, Poll},
time::Duration,
};
use portable_atomic::AtomicU64;
pub mod runtime;
use runtime::{Instant, Runtime, Sleep};
/// A resettable deadline measured against the clock of a [`Runtime`].
///
/// The deadline is stored as nanoseconds relative to `epoch` (the runtime's
/// `now()` captured at construction). This lets it live in an atomic and be
/// moved through `&self` (see `reset`).
#[derive(Debug)]
pub struct Timeout<R: Runtime> {
/// Runtime supplying the clock (`now`) and the sleep futures used by `wait`.
runtime: R,
/// Reference instant captured at construction; all nanosecond values below are relative to it.
epoch: R::Instant,
/// Current deadline, in nanoseconds after `epoch`.
timeout_from_epoch_ns: AtomicU64,
/// Amount (in nanoseconds) that `reset` places the deadline after the current elapsed time.
default_timeout: AtomicU64,
}
/// A [`Timeout`] backed by `runtime::Tokio` (available with the `tokio` feature).
#[cfg(feature = "tokio")]
pub type TokioTimeout = Timeout<runtime::Tokio>;
#[cfg(feature = "tokio")]
impl TokioTimeout {
    /// Creates a [`TokioTimeout`] driven by a fresh `runtime::Tokio`, whose first
    /// deadline is `default_timeout` from now.
    ///
    /// This is exactly `Timeout::new(runtime::Tokio::new(), default_timeout)`.
    /// Delegating (instead of repeating the construction logic) keeps both
    /// constructors behaving identically.
    #[must_use]
    pub fn new_tokio(default_timeout: Duration) -> Self {
        Self::new(runtime::Tokio::new(), default_timeout)
    }
}
impl<R: Runtime> Timeout<R> {
    /// Converts a [`Duration`] to whole nanoseconds, saturating at `u64::MAX`.
    ///
    /// `Duration::as_nanos` returns a `u128`. Durations that do not fit in a
    /// `u64` (longer than roughly 584 years, e.g. `Duration::MAX`) are clamped
    /// instead of panicking. For a deadline this effectively means "never".
    fn duration_to_nanos(duration: Duration) -> u64 {
        u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX)
    }

    /// Creates a timeout whose first deadline is `default_timeout` from now.
    ///
    /// `default_timeout` is also how far past the current time [`reset`](Self::reset)
    /// places the deadline. Durations longer than `u64::MAX` nanoseconds are clamped
    /// to `u64::MAX` nanoseconds.
    #[must_use]
    pub fn new(runtime: R, default_timeout: Duration) -> Self {
        let epoch = runtime.now();
        let default_timeout = Self::duration_to_nanos(default_timeout);
        Self {
            runtime,
            epoch,
            // The epoch is "now", so the first deadline is `default_timeout` from now.
            timeout_from_epoch_ns: default_timeout.into(),
            default_timeout: default_timeout.into(),
        }
    }

    /// Time elapsed since this timeout was created, according to the runtime's clock.
    fn elapsed(&self) -> Duration {
        self.runtime.now().duration_since(&self.epoch)
    }

    /// Moves the deadline to the current default timeout after the current time.
    ///
    /// A task already inside [`wait`](Self::wait) re-reads the deadline whenever its
    /// current sleep finishes. A later deadline therefore extends the wait. A deadline
    /// moved earlier (after shrinking the default timeout) is only noticed once the
    /// in-flight sleep completes.
    pub fn reset(&self) {
        let elapsed = Self::duration_to_nanos(self.elapsed());
        let default_timeout = self.default_timeout.load(Ordering::Acquire);
        // Saturate: with a very large default timeout, a plain `+` would panic in
        // debug builds and wrap to an already-expired deadline in release builds.
        self.timeout_from_epoch_ns
            .store(elapsed.saturating_add(default_timeout), Ordering::Release);
    }

    /// The duration that [`reset`](Self::reset) places the deadline after the current time.
    pub fn default_timeout(&self) -> Duration {
        Duration::from_nanos(self.default_timeout.load(Ordering::Acquire))
    }

    /// Changes the duration used by subsequent calls to [`reset`](Self::reset).
    ///
    /// The current deadline is not modified. Durations longer than `u64::MAX`
    /// nanoseconds are clamped to `u64::MAX` nanoseconds.
    pub fn set_default_timeout(&self, default_timeout: Duration) {
        self.default_timeout
            .store(Self::duration_to_nanos(default_timeout), Ordering::Release);
    }

    /// Time remaining until the deadline, or `None` if the deadline has been reached.
    fn timeout_duration(&self) -> Option<Duration> {
        let elapsed_nanos = Self::duration_to_nanos(self.elapsed());
        let target_nanos = self.timeout_from_epoch_ns.load(Ordering::Acquire);
        (elapsed_nanos < target_nanos).then(|| Duration::from_nanos(target_nanos - elapsed_nanos))
    }

    /// Completes once the deadline has passed.
    ///
    /// Returns immediately if the deadline has already passed. Each time the
    /// underlying sleep finishes, the deadline is re-read. If [`reset`](Self::reset)
    /// pushed it later in the meantime, the sleep is re-armed for the remaining time.
    pub async fn wait(&self) {
        // Adapts a runtime `Sleep` (which exposes `poll_sleep`) into a `Future`.
        pin_project_lite::pin_project! {
            struct SleepFuture<F: Sleep> {
                #[pin]
                inner: F,
            }
        }

        impl<F: Sleep> Future for SleepFuture<F> {
            type Output = ();

            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                self.project().inner.poll_sleep(cx)
            }
        }

        if let Some(timeout) = self.timeout_duration() {
            let mut future = SleepFuture {
                inner: self.runtime.create_sleep(timeout),
            };
            // SAFETY: the owned `future` is shadowed by its pinned reference on this
            // line, so it can never be moved or accessed except through the `Pin`. It
            // lives in this async fn's state (which is itself pinned while polled) and
            // is dropped in place at the end of this block. This is the same pattern
            // `core::pin::pin!` expands to.
            let mut future = unsafe { Pin::new_unchecked(&mut future) };
            // Loop rather than sleeping once: `reset` may have moved the deadline
            // later while we were asleep.
            while let Some(remaining) = self.timeout_duration() {
                future.as_mut().project().inner.reset(remaining);
                future.as_mut().await;
            }
        }
    }
}
// The `wrapper` module and its re-exports are compiled only with the `wrapper` feature.
#[cfg(feature = "wrapper")]
mod wrapper;
#[cfg(feature = "wrapper")]
pub use wrapper::Wrapper;
// `TokioWrapper` additionally requires the `tokio` feature.
#[cfg(all(feature = "wrapper", feature = "tokio"))]
pub use wrapper::TokioWrapper;
#[cfg(test)]
mod tests {
// The explicit import takes precedence over the `Instant` trait brought in by the glob below.
use tokio::time::Instant;
use crate::*;
/// A timer that is never reset completes `wait` no earlier than its default timeout.
#[test]
fn test_expiry() {
let start = Instant::now();
tokio_test::block_on(async {
let timer = Timeout::new(runtime::Tokio::new(), Duration::from_secs(1));
timer.wait().await;
});
assert!(start.elapsed() >= Duration::from_secs(1));
}
/// Resetting a 2s timer after 1s moves its deadline to ~3s after start, so the
/// 2s sleep branch of `select!` must finish first and `wait` must not have fired.
#[test]
fn test_non_expiry() {
let start = Instant::now();
assert!(tokio_test::block_on(async {
let timer = Timeout::new(runtime::Tokio::new(), Duration::from_secs(2));
tokio::select! {
// Winning this branch means the reset was not honoured.
_ = timer.wait() => {
false
}
// Sleep 1s, push the deadline out, then sleep 1s more (total ~2s < new ~3s deadline).
_ = async {
tokio::time::sleep(Duration::from_secs(1)).await;
timer.reset();
tokio::time::sleep(Duration::from_secs(1)).await;
} => {
true
}
}
}));
assert!(start.elapsed() >= Duration::from_secs(2));
}
}