use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use super::context::ContextExt;
use crate::fiber;
use crate::time::Instant;
/// Error produced by a [`Timeout`] future.
///
/// `E` is the error type carried by the wrapped future's `Result` output.
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
pub enum Error<E> {
    /// The wrapped future was still pending when polled at or after the
    /// deadline.
    #[error("deadline expired")]
    Expired,
    /// The wrapped future completed with an error. Its `Display` output is
    /// forwarded unchanged, and `#[from]` lets `?` convert an `E` into this
    /// variant.
    #[error("{0}")]
    Failed(#[from] E),
}
/// `Result` alias for the output of a [`Timeout`] future wrapping a future
/// whose output is `std::result::Result<T, E>`.
pub type Result<T, E> = std::result::Result<T, Error<E>>;
/// Future returned by [`timeout`], [`deadline`] and the [`IntoTimeout`]
/// combinators.
///
/// Resolves to the wrapped future's `Result`, with errors wrapped in
/// [`Error::Failed`]. If the wrapped future is still pending when polled at or
/// after `deadline`, it resolves to [`Error::Expired`].
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Timeout<F> {
    /// The wrapped future. It is structurally pinned: once `Timeout` is pinned,
    /// it is only reached through `pin_get_future`.
    future: F,
    /// Moment at which the future expires. `None` (produced by `timeout` when
    /// adding the duration to the current clock overflows) means it never
    /// expires.
    deadline: Option<Instant>,
}
/// Wraps `f` so that it resolves to [`Error::Expired`] if it has not completed
/// within `timeout`, measured from the current [`fiber::clock`] reading.
///
/// If `timeout` is too large to add to the current time, the returned future
/// has no deadline and waits for `f` indefinitely.
#[inline(always)]
pub fn timeout<F: Future>(timeout: Duration, f: F) -> Timeout<F> {
    let expires_at = fiber::clock().checked_add(timeout);
    Timeout {
        deadline: expires_at,
        future: f,
    }
}
/// Wraps `f` so that it resolves to [`Error::Expired`] if it has not completed
/// by the absolute moment `deadline`, as measured by [`fiber::clock`].
#[inline(always)]
pub fn deadline<F: Future>(deadline: Instant, f: F) -> Timeout<F> {
    let deadline = Some(deadline);
    Timeout { deadline, future: f }
}
impl<F: Future> Timeout<F> {
    /// Pin projection from `Pin<&mut Timeout<F>>` to the wrapped future.
    #[inline]
    fn pin_get_future(self: Pin<&mut Self>) -> Pin<&mut F> {
        // SAFETY: `future` is treated as structurally pinned. It is never moved
        // out of a pinned `Timeout`, and `Timeout` must not gain a `Drop` impl
        // that moves it. `Timeout<F>` has no manual `Unpin` impl, so it is
        // `Unpin` only when `F` is. Only the `Copy` field `deadline` is read
        // through the pin elsewhere.
        unsafe {
            let this = self.get_unchecked_mut();
            Pin::new_unchecked(&mut this.future)
        }
    }
}
impl<F, T, E> Future for Timeout<F>
where
F: Future<Output = std::result::Result<T, E>>,
{
type Output = Result<T, E>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let deadline = self.deadline;
if let Poll::Ready(v) = self.pin_get_future().poll(cx) {
return Poll::Ready(v.map_err(Error::Failed));
}
match deadline {
Some(deadline) if fiber::clock() >= deadline => {
Poll::Ready(Err(Error::Expired)) }
Some(deadline) => {
unsafe { ContextExt::set_deadline(cx, deadline) };
Poll::Pending
}
None => {
Poll::Pending
}
}
}
}
/// Extension trait that adds `.timeout(..)` and `.deadline(..)` combinators to
/// every future.
pub trait IntoTimeout: Future + Sized {
    /// Method form of the free function `timeout`: the returned future expires
    /// `timeout` after the current fiber clock reading.
    #[inline(always)]
    fn timeout(self, timeout: Duration) -> Timeout<Self> {
        // `self::` is required: a bare `timeout` would name the parameter.
        self::timeout(timeout, self)
    }

    /// Method form of the free function `deadline`: the returned future
    /// expires at the absolute moment `deadline`.
    #[inline(always)]
    fn deadline(self, deadline: Instant) -> Timeout<Self> {
        self::deadline(deadline, self)
    }
}

// Every future gets the combinators; the generic `F` is implicitly `Sized`.
impl<F: Future> IntoTimeout for F {}
// Tests for `timeout` / `deadline`. Compiled only with the `internal_test`
// feature; each test is registered through the crate's own `#[crate::test]`
// attribute rather than the standard `#[test]`.
#[cfg(feature = "internal_test")]
mod tests {
    use super::*;
    use crate::fiber;
    use crate::fiber::check_yield;
    use crate::fiber::r#async::{oneshot, RecvError};
    use crate::fiber::YieldResult::{DidntYield, Yielded};
    use crate::test::util::ok;
    use std::time::Duration;

    const _0_SEC: Duration = Duration::ZERO;
    const _1_SEC: Duration = Duration::from_secs(1);

    // Futures that are ready on the first poll resolve normally, even with a
    // zero timeout, because the inner future is polled before the deadline
    // check.
    #[crate::test(tarantool = "crate")]
    fn instant_future() {
        let fut = async { 78 };
        assert_eq!(fiber::block_on(fut), 78);
        let fut = timeout(Duration::ZERO, async { ok(79) });
        assert_eq!(fiber::block_on(fut), Ok(79));
    }

    // A receiver that never gets a value expires when the timeout is zero.
    #[crate::test(tarantool = "crate")]
    fn actual_timeout_promise() {
        let (tx, rx) = oneshot::channel::<i32>();
        let fut = async move { rx.timeout(_0_SEC).await };
        let jh = fiber::start_async(fut);
        assert_eq!(jh.join(), Err(Error::Expired));
        // `tx` is kept alive until after the join. Dropping it earlier would let
        // the receiver fail with `RecvError` (see `drop_tx_before_timeout`)
        // instead of expiring.
        drop(tx);
    }

    // Dropping the sender before the deadline surfaces the receiver's error as
    // `Error::Failed`.
    #[crate::test(tarantool = "crate")]
    fn drop_tx_before_timeout() {
        let (tx, rx) = oneshot::channel::<i32>();
        let fut = async move { rx.timeout(_1_SEC).await };
        let jh = fiber::start(move || fiber::block_on(fut));
        drop(tx);
        assert_eq!(jh.join(), Err(Error::Failed(RecvError)));
    }

    // A value sent before the deadline is delivered as `Ok`.
    #[crate::test(tarantool = "crate")]
    fn send_tx_before_timeout() {
        let (tx, rx) = oneshot::channel::<i32>();
        let fut = async move { rx.timeout(_1_SEC).await };
        let jh = fiber::start(move || fiber::block_on(fut));
        tx.send(400).unwrap();
        assert_eq!(jh.join(), Ok(400));
    }

    // `Duration::MAX` overflows `clock() + timeout`. `timeout` must then produce
    // a future with no deadline rather than panicking.
    #[crate::test(tarantool = "crate")]
    fn timeout_duration_max() {
        fiber::block_on(timeout(Duration::MAX, async { ok(1) })).unwrap();
    }

    // Checks which combinations of future / timeout make the current fiber
    // yield.
    #[crate::test(tarantool = "crate")]
    fn await_actually_yields() {
        // Immediately-ready futures never yield, with or without a timeout.
        assert_eq!(
            check_yield(|| fiber::block_on(async { 101 })),
            DidntYield(101)
        );
        assert_eq!(
            check_yield(|| fiber::block_on(timeout(Duration::ZERO, async { ok(202) }))),
            DidntYield(Ok(202))
        );
        assert_eq!(
            check_yield(|| fiber::block_on(timeout(Duration::from_secs(1), async { ok(303) }))),
            DidntYield(Ok(303))
        );
        // Awaiting a pending receiver without any timeout yields.
        let (_tx, rx) = oneshot::channel::<i32>();
        let f = check_yield(|| fiber::start(|| fiber::block_on(rx)));
        assert!(matches!(f, Yielded(_)));
        // The started fiber waits on a receiver whose sender (`_tx`) stays alive
        // for the rest of this function (later `let`s only shadow it), so the
        // handle is leaked instead of dropped.
        // NOTE(review): presumably dropping the join handle would block or
        // panic — confirm against `fiber::JoinHandle`.
        std::mem::forget(f);
        // A deadline that is already reached at the first poll expires without
        // yielding.
        let (_tx, rx) = oneshot::channel::<i32>();
        assert_eq!(
            check_yield(|| fiber::block_on(timeout(Duration::ZERO, rx))),
            DidntYield(Err(Error::Expired))
        );
        let (_tx, rx) = oneshot::channel::<i32>();
        let now = fiber::clock();
        assert_eq!(
            check_yield(|| fiber::block_on(deadline(now, rx))),
            DidntYield(Err(Error::Expired))
        );
        let (_tx, rx) = oneshot::channel::<i32>();
        let one_second_ago = now.saturating_sub(Duration::from_secs(1));
        assert_eq!(
            check_yield(|| fiber::block_on(deadline(one_second_ago, rx))),
            DidntYield(Err(Error::Expired))
        );
        // A deadline in the future makes the fiber yield until it expires.
        let (_tx, rx) = oneshot::channel::<i32>();
        assert_eq!(
            check_yield(|| fiber::block_on(timeout(Duration::from_millis(10), rx))),
            Yielded(Err(Error::Expired))
        );
        let (_tx, rx) = oneshot::channel::<i32>();
        let in_10_millis = fiber::clock().saturating_add(Duration::from_millis(10));
        assert_eq!(
            check_yield(|| fiber::block_on(deadline(in_10_millis, rx))),
            Yielded(Err(Error::Expired))
        );
    }
}