use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use std::time::{Duration, Instant};
use nexus_timer::{Wheel, WheelBuilder};
/// Timer wheel holding the waker of every pending timer, keyed by deadline.
pub(crate) struct TimerDriver {
    /// Pending timers; each entry's waker is woken by `fire_expired` once the wheel reports it expired.
    wheel: Wheel<Waker>,
    /// Scratch buffer reused by `fire_expired` to collect expired wakers without reallocating.
    expired: Vec<Waker>,
}
impl TimerDriver {
    /// Builds a driver whose wheel is anchored at the current instant.
    ///
    /// `capacity` is forwarded to `WheelBuilder::unbounded` (presumably the
    /// wheel's initial entry capacity — confirm against the nexus_timer docs).
    pub(crate) fn new(capacity: usize) -> Self {
        Self {
            wheel: WheelBuilder::default()
                .unbounded(capacity)
                .build(Instant::now()),
            expired: Vec::with_capacity(64),
        }
    }

    /// Registers `waker` to be woken once the wheel considers `deadline` expired.
    ///
    /// Uses `schedule_forget`, so no handle is kept for cancelling the entry.
    pub(crate) fn schedule(&mut self, deadline: Instant, waker: Waker) {
        self.wheel.schedule_forget(deadline, waker)
    }

    /// Earliest deadline currently held by the wheel, or `None` when it is empty.
    pub(crate) fn next_deadline(&self) -> Option<Instant> {
        self.wheel.next_deadline()
    }

    /// Wakes every timer the wheel reports as expired at `now`.
    ///
    /// Returns the count reported by the wheel's `poll`.
    pub(crate) fn fire_expired(&mut self, now: Instant) -> usize {
        self.expired.clear();
        let fired_count = self.wheel.poll(now, &mut self.expired);
        // Draining (rather than consuming) keeps the buffer's allocation for the next call.
        self.expired.drain(..).for_each(Waker::wake);
        fired_count
    }
}
/// Copyable handle for creating [`Sleep`] futures tied to a [`TimerDriver`].
///
/// NOTE(review): this stores a raw pointer, and nothing in this file ties the
/// handle (or the `Sleep`s it creates) to the driver's lifetime. Presumably the
/// runtime keeps the driver alive and in place while handles exist — confirm.
#[derive(Clone, Copy)]
pub struct TimerHandle {
    /// Driver that `Sleep::poll` dereferences to register wakers.
    driver: *mut TimerDriver,
}
impl TimerHandle {
    /// Wraps a pointer to `driver`; the borrow is not retained by the handle.
    pub(crate) fn new(driver: &mut TimerDriver) -> Self {
        let driver: *mut TimerDriver = driver;
        Self { driver }
    }

    /// Returns a future that completes once `duration` has passed, measured from this call.
    pub fn sleep(&self, duration: Duration) -> Sleep {
        self.sleep_until(Instant::now() + duration)
    }

    /// Returns a future that completes once `deadline` has been reached.
    ///
    /// Nothing is registered with the driver until the future is first polled.
    pub fn sleep_until(&self, deadline: Instant) -> Sleep {
        Sleep {
            deadline,
            driver: self.driver,
            registered: false,
            waker: None,
        }
    }
}
/// Future that completes once `deadline` has been reached.
///
/// Created by `TimerHandle::sleep` / `TimerHandle::sleep_until`.
pub struct Sleep {
    /// Instant at or after which `poll` returns `Ready`.
    deadline: Instant,
    /// Driver used to register wakers; copied from the creating `TimerHandle`.
    driver: *mut TimerDriver,
    /// Whether a waker has been scheduled with the driver yet.
    registered: bool,
    /// Waker most recently scheduled, used to skip re-registering the same task.
    waker: Option<Waker>,
}
impl Future for Sleep {
    type Output = ();

    /// Returns `Ready` once the deadline has passed.
    ///
    /// Otherwise, makes sure the current task's waker is scheduled with the
    /// driver. A new waker is registered on the first poll, or whenever the
    /// polling task's waker differs from the one last scheduled. Earlier
    /// entries stay in the wheel and will still fire.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        if self.deadline <= Instant::now() {
            return Poll::Ready(());
        }

        let current = cx.waker();
        let already_armed = self.registered
            && matches!(&self.waker, Some(previous) if previous.will_wake(current));

        if !already_armed {
            // SAFETY: relies on the driver outliving this `Sleep` and on no other
            // `&mut TimerDriver` being live during this call. NOTE(review): nothing
            // in this file enforces either; presumably the runtime guarantees both.
            let driver = unsafe { &mut *self.driver };
            driver.schedule(self.deadline, current.clone());
            self.waker = Some(current.clone());
            self.registered = true;
        }

        Poll::Pending
    }
}
/// Error produced by [`Timeout`] when its deadline is reached.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Elapsed;

impl std::fmt::Display for Elapsed {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(formatter, "deadline elapsed")
    }
}

impl std::error::Error for Elapsed {}
/// Future that bounds `future` by the deadline carried in `sleep`.
///
/// Resolves to `Ok` with the inner future's output, or to `Err(Elapsed)` when
/// the deadline wins.
pub struct Timeout<F> {
    /// The wrapped future; accessed only through a pinned projection while polled.
    future: F,
    /// Deadline timer for the wrapped future.
    sleep: Sleep,
}
impl<F> Timeout<F> {
pub(crate) fn new(future: F, sleep: Sleep) -> Self {
Self { future, sleep }
}
pub fn into_inner(self) -> F {
self.future
}
}
impl<F: Future> Future for Timeout<F> {
    type Output = Result<F::Output, Elapsed>;

    /// Polls the wrapped future first, then the deadline.
    ///
    /// Polling the inner future first means a result that is already available
    /// is returned as `Ok` even if the deadline has also passed. The deadline
    /// only decides between `Err(Elapsed)` and `Pending` when the inner future
    /// is still pending.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // SAFETY: `future` is structurally pinned. It is only accessed through
        // `Pin<&mut F>` below and is never moved out while pinned; `into_inner`
        // takes `self` by value, which is impossible once a `!Unpin` `Timeout`
        // has been pinned. `sleep` is not structurally pinned (it is `Unpin`).
        let this = unsafe { self.get_unchecked_mut() };
        let future = unsafe { Pin::new_unchecked(&mut this.future) };

        if let Poll::Ready(value) = future.poll(cx) {
            return Poll::Ready(Ok(value));
        }

        // Polling the sleep also (re)registers `cx`'s waker with the timer
        // driver, so the task is woken when the deadline passes.
        Pin::new(&mut this.sleep).poll(cx).map(|()| Err(Elapsed))
    }
}
/// How `Interval::tick` chooses the next deadline after a tick completes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MissedTickBehavior {
    /// Advance exactly one period from the previous deadline, so missed ticks fire back-to-back.
    Burst,
    /// Jump to the first `start + k * period` instant strictly after the tick completed,
    /// dropping any ticks that were missed.
    Skip,
    /// Schedule the next tick one full period after the tick completed.
    Delay,
}
/// Periodic timer: each `tick` waits for `next_deadline`, then advances it
/// according to `missed_tick_behavior`.
pub struct Interval {
    /// Time between ticks; never zero (asserted by the constructors).
    period: Duration,
    /// Origin of the `start + k * period` grid used by `MissedTickBehavior::Skip`.
    start: Instant,
    /// Deadline the next `tick` waits for.
    next_deadline: Instant,
    /// Sleep for the in-flight tick. It is kept across `tick` calls, so a
    /// dropped `tick` future resumes on the same sleep next time.
    sleep: Option<Sleep>,
    /// Policy used to compute the deadline after each tick.
    missed_tick_behavior: MissedTickBehavior,
}
impl Interval {
    /// Creates an interval whose first tick completes one `period` from now.
    ///
    /// The missed-tick behavior starts as [`MissedTickBehavior::Burst`].
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero.
    pub(crate) fn new(period: Duration) -> Self {
        assert!(!period.is_zero(), "interval period must be non-zero");
        let now = Instant::now();
        Self {
            period,
            start: now,
            next_deadline: now + period,
            sleep: None,
            missed_tick_behavior: MissedTickBehavior::Burst,
        }
    }

    /// Creates an interval whose first tick completes at `start`, then every `period`.
    ///
    /// The missed-tick behavior starts as [`MissedTickBehavior::Burst`].
    ///
    /// # Panics
    ///
    /// Panics if `period` is zero.
    pub(crate) fn new_at(start: Instant, period: Duration) -> Self {
        assert!(!period.is_zero(), "interval period must be non-zero");
        Self {
            period,
            start,
            next_deadline: start,
            sleep: None,
            missed_tick_behavior: MissedTickBehavior::Burst,
        }
    }

    /// Waits for the current deadline, then schedules the following one.
    ///
    /// If the returned future is dropped before completing, the pending sleep
    /// stays in `self`, and the next call resumes waiting on the same deadline.
    pub async fn tick(&mut self) {
        if self.sleep.is_none() {
            self.sleep = Some(crate::context::sleep_until(self.next_deadline));
        }
        if let Some(sleep) = self.sleep.as_mut() {
            Pin::new(sleep).await;
        }
        self.sleep = None;
        self.advance_deadline(Instant::now());
    }

    /// Moves `next_deadline` forward after a tick that completed at `now`.
    ///
    /// Every branch (and `reset`/`reset_at`) keeps `next_deadline` on the grid
    /// `start + k * period`. The `Skip` arithmetic depends on this; otherwise
    /// it would realign to a stale schedule.
    fn advance_deadline(&mut self, now: Instant) {
        match self.missed_tick_behavior {
            MissedTickBehavior::Burst => {
                self.next_deadline += self.period;
            }
            MissedTickBehavior::Skip => {
                if now >= self.next_deadline {
                    // First grid point strictly after `now`; missed ticks are dropped.
                    let elapsed = now.duration_since(self.start);
                    let period_nanos = self.period.as_nanos();
                    let periods = elapsed.as_nanos() / period_nanos;
                    let next_nanos = (periods + 1).saturating_mul(period_nanos);
                    let offset =
                        Duration::from_nanos(u64::try_from(next_nanos).unwrap_or(u64::MAX));
                    self.next_deadline = self.start + offset;
                } else {
                    self.next_deadline += self.period;
                }
            }
            MissedTickBehavior::Delay => {
                // Re-anchor the grid on the delayed schedule so a later switch to
                // `Skip` aligns to it rather than to the original start.
                self.start = now;
                self.next_deadline = now + self.period;
            }
        }
    }

    /// Restarts the schedule so the next tick completes one `period` from now.
    pub fn reset(&mut self) {
        let now = Instant::now();
        self.start = now;
        self.next_deadline = now + self.period;
        self.sleep = None;
    }

    /// Restarts the schedule so the next tick completes at `deadline`, then every `period`.
    pub fn reset_at(&mut self, deadline: Instant) {
        self.start = deadline;
        self.next_deadline = deadline;
        self.sleep = None;
    }

    /// Time between ticks.
    pub fn period(&self) -> Duration {
        self.period
    }

    /// Current missed-tick policy.
    pub fn missed_tick_behavior(&self) -> MissedTickBehavior {
        self.missed_tick_behavior
    }

    /// Sets the policy applied after each subsequent tick.
    pub fn set_missed_tick_behavior(&mut self, behavior: MissedTickBehavior) {
        self.missed_tick_behavior = behavior;
    }
}
/// Future that returns `Pending` once (waking its task immediately) and `Ready` on the next poll.
///
/// The flag records whether the future has already yielded.
pub struct YieldNow(pub(crate) bool);

impl Future for YieldNow {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let already_yielded = std::mem::replace(&mut self.0, true);
        if already_yielded {
            return Poll::Ready(());
        }
        // Re-schedule ourselves so the executor polls us again after other work.
        cx.waker().wake_by_ref();
        Poll::Pending
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::task::Wake;

    /// Waker implementation whose wake operations do nothing.
    struct Noop;

    impl Wake for Noop {
        fn wake(self: Arc<Self>) {}
        fn wake_by_ref(self: &Arc<Self>) {}
    }

    fn noop_waker() -> Waker {
        Waker::from(Arc::new(Noop))
    }

    #[test]
    fn timer_driver_fire_expired() {
        let mut driver = TimerDriver::new(64);
        let now = Instant::now();
        let past = now - Duration::from_millis(10);
        let far_future = now + Duration::from_secs(100);

        driver.schedule(past, noop_waker());
        driver.schedule(far_future, noop_waker());

        // Only the already-expired timer should fire.
        assert_eq!(driver.fire_expired(now), 1);
        let remaining = driver
            .next_deadline()
            .expect("far-future timer should still be pending");
        assert!(remaining > now);
    }

    #[test]
    fn timer_driver_next_deadline() {
        let mut driver = TimerDriver::new(64);
        assert_eq!(driver.next_deadline(), None);

        let now = Instant::now();
        let soon = now + Duration::from_millis(10);
        let later = now + Duration::from_millis(100);
        // Schedule out of order to check the wheel reports the earliest one.
        for deadline in [later, soon] {
            driver.schedule(deadline, noop_waker());
        }

        let next = driver.next_deadline().expect("two timers were scheduled");
        // Small slack, presumably for the wheel's slot granularity.
        assert!(next <= soon + Duration::from_millis(2));
    }
}