use std::{
collections::BTreeMap,
error::Error,
fmt::Display,
future::Future,
io,
marker::PhantomData,
pin::Pin,
task::{Context, Poll, Waker},
time::{Duration, Instant},
};
use futures_core::Stream;
use pin_project_lite::pin_project;
use crate::{Id, REACTOR};
/// Ordered collection of pending timers and the wakers to notify when they expire.
pub(crate) struct TimerQueue {
// Next ID handed out by `register`.
current_id: Id,
// Keyed by `(expiry, id)`: ordering by expiry first lets the earliest deadline be read with
// `first_key_value` / `first_entry`, and the ID keeps timers with equal deadlines distinct.
timers: BTreeMap<(Instant, Id), Waker>,
}
impl TimerQueue {
    /// Creates an empty queue whose first issued timer ID is 1.
    pub(crate) const fn new() -> Self {
        Self {
            current_id: const { Id::new(1) },
            timers: BTreeMap::new(),
        }
    }

    /// Registers `waker` to be woken at `expiry` and returns the new timer's ID.
    ///
    /// A timer is identified by the pair `(expiry, id)`, so callers must pass the same
    /// `expiry` back to [`modify`](Self::modify) and [`cancel`](Self::cancel).
    pub(crate) fn register(&mut self, expiry: Instant, waker: Waker) -> Id {
        use std::collections::btree_map::Entry;

        loop {
            let id = self.current_id;
            self.current_id = id.overflowing_incr();
            // The `(expiry, id)` key can only be taken if the ID counter has come back around
            // (presumably `overflowing_incr` wraps — TODO confirm); in that case try the next
            // ID. The entry API checks occupancy without displacing the existing waker.
            if let Entry::Vacant(slot) = self.timers.entry((expiry, id)) {
                slot.insert(waker);
                return id;
            }
        }
    }

    /// Replaces the waker stored for the timer identified by `(expiry, id)`.
    ///
    /// Logs an error if no such timer is registered (for example because it was already
    /// cancelled or removed by [`clear_expired`](Self::clear_expired)).
    pub(crate) fn modify(&mut self, id: Id, expiry: Instant, waker: &Waker) {
        match self.timers.get_mut(&(expiry, id)) {
            // `clone_from` lets `Waker` skip the clone when the stored waker already wakes the
            // same task.
            Some(stored) => stored.clone_from(waker),
            None => {
                log::error!(
                    "{:?} Modifying non-existent timer ID = {}",
                    std::thread::current().id(),
                    id.0
                );
            }
        }
    }

    /// Removes the timer identified by `(expiry, id)`; does nothing if it is not registered.
    pub(crate) fn cancel(&mut self, id: Id, expiry: Instant) {
        self.timers.remove(&(expiry, id));
    }

    /// Returns the time left until the earliest registered timer expires, or `None` if no
    /// timers are registered.
    ///
    /// Returns zero if that timer is already due.
    pub(crate) fn next_timeout(&self) -> Option<Duration> {
        let now = Instant::now();
        self.timers
            .first_key_value()
            .map(|((expiry, _), _)| expiry.saturating_duration_since(now))
    }

    /// Removes every timer whose expiry is at or before the current instant and wakes its
    /// waker, in deadline order.
    pub(crate) fn clear_expired(&mut self) {
        let now = Instant::now();
        while let Some(entry) = self.timers.first_entry() {
            if entry.key().0 > now {
                // Keys are ordered by expiry, so every remaining timer is still pending.
                break;
            }
            entry.remove().wake();
        }
    }

    /// Returns `true` if no timers are registered.
    #[cfg(test)]
    pub(crate) fn is_empty(&self) -> bool {
        self.timers.is_empty()
    }
}
/// A future that completes once a fixed [`Instant`] has passed, resolving to that instant.
///
/// The first poll that happens before the deadline registers the timer with `REACTOR`. The
/// registration is removed when the timer completes or is dropped.
#[derive(Debug)]
#[must_use = "Futures do nothing unless polled"]
pub struct Timer {
// Deadline. It is passed together with `timer_id` to the reactor's modify/cancel calls, so it
// must not change while `timer_id` is `Some`.
expiry: Instant,
// ID returned by the reactor while registered; `None` when not registered.
timer_id: Option<Id>,
// The raw-pointer marker makes `Timer` `!Send` (and `!Sync`, re-added below). The registration
// lives in `REACTOR`, accessed via `REACTOR.with`, which is presumably a thread-local, so the
// timer must be polled and dropped on the thread that registered it.
_phantom: PhantomData<*const ()>,
}
// SAFETY: every method that touches the reactor registration takes `&mut self` or
// `Pin<&mut Self>`. Through a shared `&Timer` the only possible access is reading the
// plain-data fields (e.g. via `Debug`), which is safe from any thread.
unsafe impl Sync for Timer {}
impl Timer {
    /// Creates a timer that completes at `expiry`.
    ///
    /// Nothing is registered with the reactor until the timer is polled before its deadline.
    pub fn at(expiry: Instant) -> Self {
        Timer {
            expiry,
            timer_id: None,
            _phantom: PhantomData,
        }
    }

    /// Creates a timer that completes once `delay` has elapsed from now.
    ///
    /// If `Instant::now() + delay` is not representable (e.g. `Duration::MAX`), the deadline
    /// falls back to roughly 30 years from now instead of panicking. This lets a huge delay
    /// (for example `timeout(fut, Duration::MAX)`) mean "effectively never".
    pub fn delay(delay: Duration) -> Self {
        // Far beyond any realistic deadline, yet small enough to add to `now` without overflow.
        const FAR_FUTURE: Duration = Duration::from_secs(30 * 365 * 24 * 60 * 60);
        let now = Instant::now();
        let expiry = now.checked_add(delay).unwrap_or_else(|| now + FAR_FUTURE);
        Self::at(expiry)
    }

    /// Registers `cx`'s waker with the reactor on first use. Later calls instead replace the
    /// waker stored for the existing registration.
    fn register(&mut self, cx: &mut Context<'_>) {
        REACTOR.with(|r| match self.timer_id {
            None => {
                self.timer_id = Some(r.register_timer(self.expiry, cx.waker().clone()));
            }
            Some(id) => r.modify_timer(id, self.expiry, cx.waker()),
        });
    }
}
impl Future for Timer {
    type Output = Instant;

    /// Resolves to the deadline once it has passed.
    ///
    /// While the deadline is still ahead, `cx`'s waker is registered (or refreshed) with the
    /// reactor. On completion, any outstanding registration is removed.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        let deadline = this.expiry;

        if Instant::now() < deadline {
            this.register(cx);
            return Poll::Pending;
        }

        if let Some(id) = this.timer_id.take() {
            REACTOR.with(|r| r.cancel_timer(id, deadline));
        }
        Poll::Ready(deadline)
    }
}
impl Drop for Timer {
    /// Removes the timer's registration from the reactor, if it has one.
    fn drop(&mut self) {
        let Some(id) = self.timer_id.take() else {
            return;
        };
        let deadline = self.expiry;
        REACTOR.with(|r| r.cancel_timer(id, deadline));
    }
}
/// Returns a [`Timer`] that completes once `duration` has elapsed from now.
///
/// Equivalent to [`Timer::delay`].
pub fn sleep(duration: Duration) -> Timer {
Timer::delay(duration)
}
/// A stream that yields an [`Instant`] for each tick, with ticks `period` apart.
///
/// Each tick is scheduled from the previous tick's deadline rather than from the time of the
/// poll. If the consumer falls behind, the overdue ticks are therefore yielded back to back on
/// subsequent polls.
#[derive(Debug)]
#[must_use = "Streams do nothing unless polled"]
pub struct Periodic {
    // Timer for the next tick.
    timer: Timer,
    // Interval added to each tick's deadline to schedule the next one.
    period: Duration,
}
impl Periodic {
    /// Creates a stream whose first tick is one `period` from now, then every `period` after.
    // Allowed because the constructor is deliberately named after the type
    // (`Periodic::periodic`).
    #[allow(clippy::self_named_constructors)]
    pub fn periodic(period: Duration) -> Self {
        let timer = Timer::delay(period);
        Self { period, timer }
    }

    /// Creates a stream whose first tick is at `start`, then every `period` after.
    pub fn periodic_at(start: Instant, period: Duration) -> Self {
        let timer = Timer::at(start);
        Self { period, timer }
    }

    /// Changes the interval between ticks.
    ///
    /// The tick that is already scheduled is unaffected. The new period is used when the
    /// tick after it is scheduled.
    pub fn set_period(&mut self, period: Duration) {
        self.period = period;
    }
}
impl Stream for Periodic {
    type Item = Instant;

    /// Yields the deadline of each tick as it passes. The stream never ends.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        let fired_at = match Pin::new(&mut this.timer).poll(cx) {
            Poll::Ready(deadline) => deadline,
            Poll::Pending => return Poll::Pending,
        };
        // The timer deregistered itself on completion, so its deadline can be moved safely.
        // It is re-registered on the next poll.
        let next_deadline = fired_at + this.period;
        this.timer.expiry = next_deadline;
        Poll::Ready(Some(fired_at))
    }
}
/// Error produced by a [`Timeout`] whose deadline passed before the wrapped future completed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimedOut(());

impl Display for TimedOut {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Future timed out")
    }
}

impl Error for TimedOut {}

/// Converts to an [`io::Error`] of kind [`io::ErrorKind::TimedOut`], so `?` works in functions
/// returning `io::Result`.
impl From<TimedOut> for io::Error {
    fn from(_: TimedOut) -> Self {
        io::Error::from(io::ErrorKind::TimedOut)
    }
}
pin_project! {
/// Future returned by [`timeout`] and [`timeout_at`].
///
/// Resolves to `Ok` with the wrapped future's output if that future completes first, or to
/// `Err(TimedOut)` once the deadline has passed.
#[derive(Debug)]
#[must_use = "Futures do nothing unless polled"]
pub struct Timeout<F> {
// Deadline for `fut`.
#[pin]
timer: Timer,
// The wrapped future; polled before `timer` on every poll.
#[pin]
fut: F,
}
}
impl<F: Future> Future for Timeout<F> {
    type Output = Result<F::Output, TimedOut>;

    /// Polls the wrapped future first, so an output that is ready at the deadline still wins.
    /// The deadline is checked only while that future is pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        if let Poll::Ready(output) = this.fut.poll(cx) {
            return Poll::Ready(Ok(output));
        }
        this.timer.poll(cx).map(|_| Err(TimedOut(())))
    }
}
/// Wraps `fut` so it resolves to `Err(TimedOut)` if it has not completed within `timeout`
/// from now.
pub fn timeout<F: Future>(fut: F, timeout: Duration) -> Timeout<F> {
    let timer = Timer::delay(timeout);
    Timeout { fut, timer }
}
/// Wraps `fut` so it resolves to `Err(TimedOut)` if it has not completed by `expiry`.
pub fn timeout_at<F: Future>(fut: F, expiry: Instant) -> Timeout<F> {
    let timer = Timer::at(expiry);
    Timeout { fut, timer }
}
#[cfg(test)]
// Unit tests for the timer queue and the timer-based futures/streams. Several tests rely on
// real wall-clock sleeps to move deadlines into the past.
mod tests {
use std::{
pin::{pin, Pin},
sync::Arc,
};
use crate::test::MockWaker;
use super::*;
#[test]
fn next_timeout() {
let wakers: Vec<_> = (0..3).map(|_| Arc::new(MockWaker::default())).collect();
let mut tq = TimerQueue::new();
assert!(tq.next_timeout().is_none());
// Two timers that are already due and one 50 ms in the future.
tq.register(Instant::now(), wakers[0].clone().into());
tq.register(
Instant::now() - Duration::from_secs(1),
wakers[1].clone().into(),
);
tq.register(
Instant::now() + Duration::from_millis(50),
wakers[2].clone().into(),
);
assert_eq!(tq.next_timeout().unwrap(), Duration::ZERO);
tq.clear_expired();
// NOTE(review): the `> 40 ms` check and `!wakers[2].get()` below assume less than 10 ms pass
// between registering the 50 ms timer and this point; may be flaky on a heavily loaded machine.
assert!(tq.next_timeout().unwrap() > Duration::from_millis(40));
assert!(wakers[0].get());
assert!(wakers[1].get());
assert!(!wakers[2].get());
std::thread::sleep(Duration::from_millis(50));
tq.clear_expired();
assert!(tq.next_timeout().is_none());
assert!(wakers[2].get());
assert!(tq.timers.is_empty());
}
#[test]
fn modify() {
// Swapping the waker means only the second waker is woken when the timer expires.
let wakers: Vec<_> = (0..2).map(|_| Arc::new(MockWaker::default())).collect();
let mut tq = TimerQueue::new();
let expiry = Instant::now() + Duration::from_millis(10);
let id = tq.register(expiry, wakers[0].clone().into());
tq.clear_expired();
assert!(tq.next_timeout().is_some());
tq.modify(id, expiry, &wakers[1].clone().into());
std::thread::sleep(Duration::from_millis(10));
tq.clear_expired();
assert!(tq.next_timeout().is_none());
assert!(!wakers[0].get());
assert!(wakers[1].get());
assert!(tq.timers.is_empty());
}
#[test]
fn cancel() {
// A cancelled timer is removed without its waker ever being woken.
let waker = Arc::new(MockWaker::default());
let mut tq = TimerQueue::new();
let expiry = Instant::now() + Duration::from_secs(10);
let id = tq.register(expiry, waker.clone().into());
tq.clear_expired();
assert!(tq.next_timeout().is_some());
tq.cancel(id, expiry);
tq.clear_expired();
assert!(tq.next_timeout().is_none());
assert!(!waker.get());
assert!(tq.timers.is_empty());
}
#[test]
fn timer_expired() {
// A timer whose deadline has already passed completes on its first poll without registering.
let waker = Arc::new(MockWaker::default());
let mut timer = Timer::at(Instant::now());
assert!(Pin::new(&mut timer)
.poll(&mut Context::from_waker(&waker.into()))
.is_ready());
assert!(timer.timer_id.is_none());
assert!(REACTOR.with(|r| r.is_empty()));
}
#[test]
fn timer() {
// A pending poll registers with the reactor; the ready poll after the deadline deregisters.
let waker = Arc::new(MockWaker::default());
let mut timer = pin!(Timer::delay(Duration::from_millis(10)));
assert!(timer
.as_mut()
.poll(&mut Context::from_waker(&waker.clone().into()))
.is_pending());
assert!(timer.timer_id.is_some());
assert!(!REACTOR.with(|r| r.is_empty()));
std::thread::sleep(Duration::from_millis(10));
assert!(timer
.as_mut()
.poll(&mut Context::from_waker(&waker.into()))
.is_ready());
assert!(timer.timer_id.is_none());
assert!(REACTOR.with(|r| r.is_empty()));
}
#[test]
fn periodic() {
// After each tick the inner timer is deregistered until the stream is polled again.
let waker = Arc::new(MockWaker::default());
let mut periodic = pin!(Periodic::periodic(Duration::from_millis(5)));
assert!(periodic
.as_mut()
.poll_next(&mut Context::from_waker(&waker.clone().into()))
.is_pending());
assert!(!REACTOR.with(|r| r.is_empty()));
std::thread::sleep(Duration::from_millis(5));
assert!(periodic
.as_mut()
.poll_next(&mut Context::from_waker(&waker.clone().into()))
.is_ready());
assert!(REACTOR.with(|r| r.is_empty()));
std::thread::sleep(Duration::from_millis(5));
assert!(periodic
.as_mut()
.poll_next(&mut Context::from_waker(&waker.clone().into()))
.is_ready());
assert!(REACTOR.with(|r| r.is_empty()));
}
#[test]
fn timeouts() {
// res1: the wrapped future is already complete, so it wins over a distant deadline.
// res2: the deadline has already passed while the wrapped future is still pending.
let waker = Arc::new(MockWaker::default()).into();
let res1 = Pin::new(&mut timeout(
Timer::at(Instant::now()),
Duration::from_secs(10),
))
.poll(&mut Context::from_waker(&waker));
assert!(matches!(res1, Poll::Ready(Ok(_))));
let res2 = Pin::new(&mut timeout_at(
Timer::delay(Duration::from_secs(10)),
Instant::now(),
))
.poll(&mut Context::from_waker(&waker));
assert!(matches!(res2, Poll::Ready(Err(_))));
}
}