use std::future::Future;
use std::pin::Pin;
use std::task::Context;
use std::task::Poll;
use std::time::Duration;
use std::time::Instant;
use super::context::ContextExt;
#[derive(thiserror::Error, Debug, PartialEq, Eq)]
#[error("deadline expired")]
pub struct Expired;
/// Future returned by [`timeout`] and [`IntoTimeout::timeout`].
///
/// Resolves to `Ok(output)` when the wrapped future completes, or to
/// `Err(Expired)` when it is still pending once `deadline` has passed.
///
/// `Debug` is implemented whenever the wrapped future is `Debug`.
#[derive(Debug)]
pub struct Timeout<F> {
    /// The wrapped future. It is pin-projected via `pin_get_future` and must
    /// not be moved out while the `Timeout` is pinned.
    future: F,
    /// Absolute point in time at or after which a still-pending future is
    /// reported as expired.
    deadline: Instant,
}
/// Wraps `f` so that it resolves to `Err(Expired)` if it has not completed
/// within `timeout` of the moment this function is called.
///
/// The deadline is computed eagerly here, not on first poll. If
/// `now + timeout` does not fit in an [`Instant`] (for example with
/// [`Duration::MAX`]), the deadline falls back to roughly 30 years from now.
#[inline]
pub fn timeout<F: Future>(timeout: Duration, f: F) -> Timeout<F> {
    // ~30 years of 365-day years; effectively "no deadline".
    const FAR_FUTURE: Duration = Duration::from_secs(60 * 60 * 24 * 365 * 30);

    let started_at = Instant::now();
    let deadline = match started_at.checked_add(timeout) {
        Some(at) => at,
        None => started_at + FAR_FUTURE,
    };

    Timeout {
        future: f,
        deadline,
    }
}
impl<F: Future> Timeout<F> {
    /// Projects a pinned `Timeout` onto a pinned reference to its inner future.
    #[inline]
    fn pin_get_future(self: Pin<&mut Self>) -> Pin<&mut F> {
        // SAFETY: `future` is treated as structurally pinned. It is only
        // reached through this projection and is never moved out of a pinned
        // `Timeout` (no `Drop` impl that could move it is visible in this
        // module). `Timeout<F>` is auto-`Unpin` only when `F: Unpin`, so
        // handing out `Pin<&mut F>` upholds the pinning guarantee.
        unsafe { Pin::map_unchecked_mut(self, |this| &mut this.future) }
    }
}
impl<F: Future> Future for Timeout<F> {
    type Output = Result<F::Output, Expired>;

    /// Polls the inner future first, so a future that is ready is reported as
    /// `Ok` even if the deadline has already passed.
    ///
    /// If the inner future is still pending, returns `Err(Expired)` once the
    /// deadline is reached. Otherwise it passes the deadline to
    /// `ContextExt::set_deadline` and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Copy the deadline out before `self` is consumed by the projection.
        let deadline = self.deadline;

        match self.pin_get_future().poll(cx) {
            Poll::Ready(value) => Poll::Ready(Ok(value)),
            Poll::Pending if Instant::now() >= deadline => Poll::Ready(Err(Expired)),
            Poll::Pending => {
                // SAFETY: NOTE(review): the contract of `ContextExt::set_deadline`
                // is defined in `super::context` and is not visible here.
                // Presumably `cx` must come from this crate's fiber executor;
                // confirm. Also confirm how it interacts with a deadline the
                // inner future may already have set during the poll above.
                unsafe { ContextExt::set_deadline(cx, deadline) };
                Poll::Pending
            }
        }
    }
}
/// Extension trait that adds a `.timeout(duration)` combinator to futures.
///
/// Every sized [`Future`] implements it through the blanket impl that
/// follows the trait.
pub trait IntoTimeout: Future + Sized {
    /// Wraps `self` in a [`Timeout`] that expires `timeout` from now.
    #[inline]
    fn timeout(self, timeout: Duration) -> Timeout<Self> {
        // The `timeout` parameter shadows the free function of the same name
        // in the value namespace, so the function is reached by a full path.
        super::timeout::timeout(timeout, self)
    }
}

impl<T: Future> IntoTimeout for T {}
// Tests for `timeout` / `IntoTimeout`. Each case is registered in the
// `TESTS` distributed slice. The module is only compiled with the
// `tarantool_test` feature, presumably because the cases need the tarantool
// fiber runtime (confirm against the test harness).
#[cfg(feature = "tarantool_test")]
mod tests {
use super::*;
use crate::fiber;
use crate::fiber::check_yield;
use crate::fiber::r#async::{oneshot, RecvError};
use crate::fiber::YieldResult::{DidntYield, Yielded};
use crate::test::{TestCase, TESTS};
use crate::test_name;
use linkme::distributed_slice;
use std::time::Duration;
const _0_SEC: Duration = Duration::ZERO;
const _1_SEC: Duration = Duration::from_secs(1);
// An already-ready future completes under `block_on`. Wrapping it in a zero
// timeout still yields `Ok`, because `Timeout::poll` polls the inner future
// before it checks the deadline.
#[distributed_slice(TESTS)]
static INSTANT_FUTURE: TestCase = TestCase {
name: test_name!("instant_future"),
f: || {
let fut = async { 78 };
assert_eq!(fiber::block_on(fut), 78);
let fut = timeout(Duration::ZERO, async { 79 });
assert_eq!(fiber::block_on(fut), Ok(79));
},
};
// The receiver never gets a value (`tx` stays alive until after `join`), so
// a zero-second timeout resolves to `Err(Expired)`.
#[distributed_slice(TESTS)]
static ACTUAL_TIMEOUT_PROMISE: TestCase = TestCase {
name: test_name!("actual_timeout_promise"),
f: || {
let (tx, rx) = oneshot::channel::<i32>();
let fut = async move { rx.timeout(_0_SEC).await };
let jh = fiber::start_async(fut);
assert_eq!(jh.join(), Err(Expired));
drop(tx);
},
};
// Dropping the sender within the 1 s timeout makes the receiver resolve to
// `Err(RecvError)`. The outer `Ok` comes from `Timeout`, since the deadline
// was not hit.
#[distributed_slice(TESTS)]
static DROP_TX_BEFORE_TIMEOUT: TestCase = TestCase {
name: test_name!("drop_tx_before_timeout"),
f: || {
let (tx, rx) = oneshot::channel::<i32>();
let fut = async move { rx.timeout(_1_SEC).await };
let jh = fiber::start(move || fiber::block_on(fut));
drop(tx);
assert_eq!(jh.join(), Ok(Err(RecvError)));
},
};
// A value sent within the 1 s timeout is delivered as `Ok(Ok(400))`. The
// outer `Ok` comes from `Timeout` and the inner one from the receiver.
#[distributed_slice(TESTS)]
static SEND_TX_BEFORE_TIMEOUT: TestCase = TestCase {
name: test_name!("send_tx_before_timeout"),
f: || {
let (tx, rx) = oneshot::channel::<i32>();
let fut = async move { rx.timeout(_1_SEC).await };
let jh = fiber::start(move || fiber::block_on(fut));
tx.send(400).unwrap();
assert_eq!(jh.join(), Ok(Ok(400)));
},
};
// `now + Duration::MAX` overflows `Instant`. Building the timeout must not
// panic, because `timeout` falls back to a far-future deadline. The returned
// future is intentionally dropped without being polled.
#[distributed_slice(TESTS)]
static TIMEOUT_DURATION_MAX: TestCase = TestCase {
name: test_name!("timeout_duration_max"),
f: || {
timeout(Duration::MAX, async { 1 });
},
};
// Checks when blocking on a (possibly timed-out) future yields the fiber.
#[distributed_slice(TESTS)]
static AWAIT_ACTUALLY_YIELDS: TestCase = TestCase {
name: test_name!("await_actually_yields"),
f: || {
// Futures that are ready on the first poll complete without yielding,
// with or without a timeout wrapper.
assert_eq!(
check_yield(|| fiber::block_on(async { 101 })),
DidntYield(101)
);
assert_eq!(
check_yield(|| fiber::block_on(timeout(Duration::ZERO, async { 202 }))),
DidntYield(Ok(202))
);
assert_eq!(
check_yield(|| fiber::block_on(timeout(Duration::from_secs(1), async { 303 }))),
DidntYield(Ok(303))
);
// Blocking on a receiver that never gets a value yields. `_tx` is only
// shadowed below, not dropped, so it lives until the end of this closure
// and the spawned fiber never finishes. Its join handle (inside `f`) is
// forgotten rather than dropped, presumably to avoid joining or cleaning
// up a fiber that is still blocked (confirm against `JoinHandle`'s drop
// behaviour).
let (_tx, rx) = oneshot::channel::<i32>();
let f = check_yield(|| fiber::start(|| fiber::block_on(rx)));
assert!(matches!(f, Yielded(_)));
std::mem::forget(f);
// A zero timeout on a pending receiver expires on the very first poll,
// before any deadline is registered, so no yield happens.
let (_tx, rx) = oneshot::channel::<i32>();
assert_eq!(
check_yield(|| fiber::block_on(timeout(Duration::ZERO, rx))),
DidntYield(Err(Expired))
);
// A non-zero timeout on a pending receiver has to yield before expiring.
let (_tx, rx) = oneshot::channel::<i32>();
assert_eq!(
check_yield(|| fiber::block_on(timeout(Duration::from_millis(10), rx))),
Yielded(Err(Expired))
);
},
};
}