use crate::time::{AsyncLocalTimeout, AsyncTimeout, Elapsed};
use ::tokio::time::{timeout, timeout_at, Timeout};
use core::{
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use std::time::Instant;
pin_project_lite::pin_project! {
#[repr(transparent)]
pub struct TokioTimeout<F> {
#[pin]
inner: Timeout<F>,
}
}
impl<F> From<Timeout<F>> for TokioTimeout<F> {
fn from(timeout: Timeout<F>) -> Self {
Self { inner: timeout }
}
}
impl<F> From<TokioTimeout<F>> for Timeout<F> {
fn from(timeout: TokioTimeout<F>) -> Self {
timeout.inner
}
}
impl<F: Future> Future for TokioTimeout<F> {
type Output = Result<F::Output, Elapsed>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match self.project().inner.poll(cx) {
Poll::Ready(Ok(rst)) => Poll::Ready(Ok(rst)),
Poll::Ready(Err(e)) => Poll::Ready(Err(e.into())),
Poll::Pending => Poll::Pending,
}
}
}
impl<F: Future + Send> AsyncTimeout<F> for TokioTimeout<F> {
fn timeout(t: Duration, fut: F) -> Self
where
Self: Sized,
{
<Self as AsyncLocalTimeout<F>>::timeout_local(t, fut)
}
fn timeout_at(deadline: Instant, fut: F) -> Self
where
Self: Sized,
{
<Self as AsyncLocalTimeout<F>>::timeout_local_at(deadline, fut)
}
}
impl<F> AsyncLocalTimeout<F> for TokioTimeout<F>
where
F: Future,
{
fn timeout_local(t: Duration, fut: F) -> Self
where
Self: Sized,
{
Self {
inner: timeout(t, fut),
}
}
fn timeout_local_at(deadline: Instant, fut: F) -> Self
where
Self: Sized,
{
Self {
inner: timeout_at(tokio::time::Instant::from_std(deadline), fut),
}
}
}
#[cfg(test)]
mod tests {
use super::{AsyncTimeout, TokioTimeout};
use std::time::{Duration, Instant};
const BAD: Duration = Duration::from_secs(1);
const GOOD: Duration = Duration::from_millis(10);
const TIMEOUT: Duration = Duration::from_millis(200);
const BOUND: Duration = Duration::from_secs(10);
#[tokio::test(flavor = "multi_thread")]
async fn test_timeout() {
futures::executor::block_on(async {
let fut = async {
tokio::time::sleep(BAD).await;
1
};
let start = Instant::now();
let rst = TokioTimeout::timeout(TIMEOUT, fut).await;
assert!(rst.is_err());
let elapsed = start.elapsed();
assert!(elapsed >= TIMEOUT && elapsed <= TIMEOUT + BOUND);
let fut = async {
tokio::time::sleep(GOOD).await;
1
};
let start = Instant::now();
let rst = TokioTimeout::timeout(TIMEOUT, fut).await;
assert!(rst.is_ok());
let elapsed = start.elapsed();
assert!(elapsed >= GOOD && elapsed <= GOOD + BOUND);
});
}
#[tokio::test(flavor = "multi_thread")]
async fn test_timeout_at() {
futures::executor::block_on(async {
let fut = async {
tokio::time::sleep(BAD).await;
1
};
let start = Instant::now();
let rst = TokioTimeout::timeout_at(Instant::now() + TIMEOUT, fut).await;
assert!(rst.is_err());
let elapsed = start.elapsed();
assert!(elapsed >= TIMEOUT && elapsed <= TIMEOUT + BOUND);
let fut = async {
tokio::time::sleep(GOOD).await;
1
};
let start = Instant::now();
let rst = TokioTimeout::timeout_at(Instant::now() + TIMEOUT, fut).await;
assert!(rst.is_ok());
let elapsed = start.elapsed();
assert!(elapsed >= GOOD && elapsed <= GOOD + BOUND);
});
}
}