use crate::{
callback::{CallbackContext, CompletionFuture, CompletionSignal},
threadpool::ThreadpoolTimer,
};
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use windows_sys::Win32::System::Threading::{PTP_CALLBACK_INSTANCE, PTP_TIMER};
/// State shared between a [`Delay`] and its threadpool timer callback.
struct TimerState {
    /// Completed with `()` by `timer_callback` when the timer fires; `Delay` awaits a
    /// listener obtained from this signal via `listen()`.
    signal: CompletionSignal<()>,
}
#[cfg(debug_assertions)]
mod debug {
    //! Debug-only bookkeeping of which threadpool timers have a callback currently executing
    //! on this thread, so `Delay`'s `Drop` can detect being dropped from inside its own
    //! timer callback.
    use windows_sys::Win32::System::Threading::PTP_TIMER;

    thread_local! {
        /// Timer handles (as `usize`) whose callbacks are on this thread's stack, innermost last.
        pub(super) static TIMER_CALLBACK_STACK: std::cell::RefCell<Vec<usize>> =
            const { std::cell::RefCell::new(Vec::new()) };
    }

    /// Scope guard that keeps one timer handle on `TIMER_CALLBACK_STACK` while it is alive.
    pub(super) struct TimerCallbackStackGuard {
        timer_id: usize,
    }

    impl TimerCallbackStackGuard {
        /// Records `timer` as executing on the current thread; the record is removed when the
        /// returned guard is dropped.
        pub(super) fn enter(timer: PTP_TIMER) -> Self {
            let timer_id = timer.cast_unsigned();
            TIMER_CALLBACK_STACK.with_borrow_mut(|stack| stack.push(timer_id));
            Self { timer_id }
        }
    }

    impl Drop for TimerCallbackStackGuard {
        /// Pops the innermost entry and checks that it is the one this guard pushed.
        fn drop(&mut self) {
            let top = TIMER_CALLBACK_STACK.with_borrow_mut(Vec::pop);
            debug_assert_eq!(
                top,
                Some(self.timer_id),
                "threadpool timer callback stack tracking became unbalanced"
            );
        }
    }
}
/// Threadpool timer callback registered by [`Delay::new`]; completes the delay's signal.
///
/// A null `context` is ignored. Any panic raised while signalling is caught here so it never
/// unwinds across the `extern "system"` boundary into the OS threadpool.
///
/// # Safety
///
/// `context` must be null, or a value obtained from `CallbackContext::<TimerState>::as_raw`
/// whose `CallbackContext` has not yet been released.
unsafe extern "system" fn timer_callback(
    _instance: PTP_CALLBACK_INSTANCE,
    context: *mut core::ffi::c_void,
    timer: PTP_TIMER,
) {
    if !context.is_null() {
        let fire = || {
            // Debug builds record that this timer's callback is on the stack, so that
            // `Delay`'s `Drop` can assert it is not running re-entrantly from here.
            #[cfg(debug_assertions)]
            let _callback_stack_guard = debug::TimerCallbackStackGuard::enter(timer);
            #[cfg(not(debug_assertions))]
            let _ = timer;
            // SAFETY: `context` is non-null, and per this function's contract it came from
            // `CallbackContext::<TimerState>::as_raw` on a context that is still alive.
            let state: &TimerState =
                unsafe { CallbackContext::<TimerState>::borrow_raw(context as usize) };
            state.signal.signal(());
        };
        // Panics are deliberately swallowed: unwinding into the threadpool is not allowed.
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(fire));
    }
}
/// A future that completes after a fixed duration, driven by a Windows threadpool timer.
///
/// If the timer cannot be created, `Delay::new` logs a warning and the timeout never fires.
pub(crate) struct Delay {
    /// The armed timer. `None` if timer creation failed, or after `Drop` has taken it.
    timer: Option<ThreadpoolTimer>,
    /// Owns the raw context pointer handed to `timer_callback`. `Drop` releases it only after
    /// `timer` has been dropped.
    callback_ctx: Option<CallbackContext<TimerState>>,
    /// Keeps the shared state, and therefore its signal, alive for the future's lifetime.
    _state: Arc<TimerState>,
    /// Resolves once `timer_callback` signals `TimerState::signal`.
    listener: CompletionFuture<()>,
    /// Set after `listener` resolves, so later polls return `Ready` without re-polling it.
    fired: bool,
}
impl Delay {
    /// Creates a `Delay` that completes once `dur` has elapsed.
    ///
    /// The threadpool timer is armed here, so the countdown starts at construction, not at the
    /// first poll. If `CreateThreadpoolTimer` fails, a warning is logged and the returned
    /// future has no timer behind it, so the timeout will not fire.
    pub(crate) fn new(dur: Duration) -> Self {
        let state = Arc::new(TimerState {
            signal: CompletionSignal::new(),
        });
        let listener = state.signal.listen();
        let ctx = CallbackContext::<TimerState>::new(&state);
        let raw_ctx = ctx.as_raw() as *mut core::ffi::c_void;
        // SAFETY: `raw_ctx` stays valid for as long as the timer can call back. On success,
        // `ctx` is stored in the returned `Delay`, whose `Drop` drops the timer before the
        // context. On failure there is no timer, and `ctx` is released below.
        let created = unsafe { ThreadpoolTimer::new(Some(timer_callback), raw_ctx) };

        let (timer, callback_ctx) = if let Some(armed) = created {
            armed.set_relative(Some(dur));
            (Some(armed), Some(ctx))
        } else {
            drop(ctx);
            warn!(
                "wrest::timer::Delay: CreateThreadpoolTimer failed; \
                 timeout will not fire (request will run unbounded)"
            );
            (None, None)
        };

        Self {
            timer,
            callback_ctx,
            _state: state,
            listener,
            fired: false,
        }
    }
}
impl Future for Delay {
    type Output = ();

    /// Returns `Ready` once the timer callback has signalled, and on every poll after that.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let delay = self.get_mut();
        if !delay.fired {
            // Any completion of the listener counts as firing; its payload is not inspected.
            let Poll::Ready(_) = Pin::new(&mut delay.listener).poll(cx) else {
                return Poll::Pending;
            };
            delay.fired = true;
        }
        Poll::Ready(())
    }
}
impl Drop for Delay {
    /// Tears down the threadpool timer first, then releases the callback context.
    ///
    /// The order matters: the raw context pointer handed to `timer_callback` must stay valid
    /// while the timer can still invoke it. This assumes `ThreadpoolTimer`'s `Drop` cancels the
    /// timer and waits for in-flight callbacks (confirm in the threadpool module).
    fn drop(&mut self) {
        let Some(timer) = self.timer.take() else {
            // Timer creation failed in `Delay::new`, so there is nothing to cancel and no
            // callback context was kept.
            return;
        };
        #[cfg(debug_assertions)]
        {
            let timer_id = timer.as_raw().cast_unsigned();
            // Use `try_with` rather than `with`. A `Delay` can be dropped from another
            // thread-local's destructor after `TIMER_CALLBACK_STACK` has already been
            // destroyed, and `with` would then panic inside `drop`. A thread that is tearing
            // down its thread-locals is not running a timer callback, so an inaccessible stack
            // is treated as "not inside a callback".
            let inside_own_callback = debug::TIMER_CALLBACK_STACK
                .try_with(|stack| stack.borrow().contains(&timer_id))
                .unwrap_or(false);
            debug_assert!(
                !inside_own_callback,
                "Delay dropped while inside its own timer callback; \
                 this suggests an executor/waker synchronously polled from wake"
            );
        }
        drop(timer);
        let _ = self.callback_ctx.take();
    }
}
#[cfg(test)]
mod tests {
    //! Tests for `Delay` and `timer_callback`. Most cases create real threadpool timers,
    //! so the timing bounds are deliberately loose.
    use super::*;
    use futures_executor::block_on;
    use std::task::{Context, Poll, Waker};
    use std::time::Instant;

    /// When `drop_lifecycle_table` drops a `Delay`, relative to it firing.
    #[derive(Clone, Copy)]
    enum DropState {
        /// Dropped right after construction, never polled.
        BeforeFire,
        /// Driven to completion by `block_on`, then dropped.
        AfterFire,
        /// Polled once with a no-op waker, then dropped (for longer durations, while still armed).
        WhilePending,
    }

    /// Each delay must complete no earlier than its minimum and in well under 2s.
    #[test]
    fn fires_table() {
        // (requested duration, minimum acceptable elapsed time, label)
        let cases: &[(Duration, Duration, &str)] = &[
            (Duration::ZERO, Duration::ZERO, "zero duration"),
            (Duration::from_millis(1), Duration::ZERO, "sub-tick (~scheduler floor)"),
            // Lower bound below 50ms, presumably to allow timer-resolution slack.
            (Duration::from_millis(50), Duration::from_millis(40), "normal 50ms"),
        ];
        for &(dur, min_elapsed, label) in cases {
            let start = Instant::now();
            block_on(Delay::new(dur));
            let elapsed = start.elapsed();
            assert!(elapsed >= min_elapsed, "{label}: fired too early ({elapsed:?})");
            assert!(elapsed < Duration::from_secs(2), "{label}: fired far too late ({elapsed:?})");
        }
    }

    /// Drops `Delay`s at each lifecycle point, repeatedly. There are no explicit assertions:
    /// the test passes if nothing panics, deadlocks, or crashes.
    #[test]
    fn drop_lifecycle_table() {
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);
        // (timer duration, when to drop, label)
        let cases: &[(Duration, DropState, &str)] = &[
            (Duration::ZERO, DropState::BeforeFire, "zero / drop before fire"),
            (Duration::from_millis(100), DropState::BeforeFire, "100ms / armed at drop"),
            (Duration::ZERO, DropState::AfterFire, "zero / drop after await"),
            (Duration::from_millis(1), DropState::AfterFire, "1ms / drop after await"),
            (Duration::from_millis(20), DropState::WhilePending, "20ms / drop while polled"),
            (Duration::from_millis(100), DropState::WhilePending, "100ms / polled, still armed"),
        ];
        // Repeat each case to widen the window for drop/callback races.
        const REPS: usize = 8;
        for &(dur, state, label) in cases {
            for i in 0..REPS {
                match state {
                    DropState::BeforeFire => {
                        let d = Delay::new(dur);
                        drop(d);
                    }
                    DropState::AfterFire => {
                        block_on(Delay::new(dur));
                    }
                    DropState::WhilePending => {
                        let mut d = Box::pin(Delay::new(dur));
                        let _ = d.as_mut().poll(&mut cx);
                        drop(d);
                    }
                }
                // `label` and `i` are otherwise unused; this silences unused-variable warnings.
                let _ = (label, i);
            }
        }
    }

    /// A 10ms delay must win a `select` against a 60s delay. The loser is dropped while
    /// still armed.
    #[test]
    fn races_with_select() {
        use futures_util::future::{Either, select};
        let fast = Delay::new(Duration::from_millis(10));
        let slow = Delay::new(Duration::from_secs(60));
        let fast = std::pin::pin!(fast);
        let slow = std::pin::pin!(slow);
        match block_on(select(fast, slow)) {
            Either::Left(_) => {}
            Either::Right(_) => panic!("slow timer beat fast timer"),
        }
    }

    /// After completing, a `Delay` must keep returning `Ready` when polled again.
    #[test]
    fn poll_after_ready_is_idempotent() {
        let waker = Waker::noop();
        let mut cx = Context::from_waker(waker);
        let mut d = Box::pin(Delay::new(Duration::ZERO));
        block_on(async {
            (&mut d).await;
        });
        match d.as_mut().poll(&mut cx) {
            Poll::Ready(()) => {}
            Poll::Pending => panic!("Delay re-polled after Ready returned Pending"),
        }
    }

    /// Calls `timer_callback` directly, without a real timer, and checks whether the
    /// listener was signalled.
    #[test]
    fn timer_callback_dispatch_table() {
        enum TimerOutcome {
            Signaled,
            Pending,
        }
        // (label, pass a real context?, expected listener state)
        let cases: &[(&str, bool, TimerOutcome)] = &[
            ("NULL context -> early return, no signal", false, TimerOutcome::Pending),
            ("valid context -> signals listener", true, TimerOutcome::Signaled),
        ];
        for (label, supply_ctx, expected) in cases {
            let state = Arc::new(TimerState {
                signal: CompletionSignal::new(),
            });
            let listener = state.signal.listen();
            let ctx = CallbackContext::<TimerState>::new(&state);
            let raw_ctx: *mut core::ffi::c_void = if *supply_ctx {
                ctx.as_raw() as *mut core::ffi::c_void
            } else {
                std::ptr::null_mut()
            };
            // The instance and timer handles are dummies. The callback ignores the instance and
            // uses the timer value only for debug-build callback-stack bookkeeping.
            unsafe {
                timer_callback(0 as PTP_CALLBACK_INSTANCE, raw_ctx, 1 as PTP_TIMER);
            }
            let waker = Waker::noop();
            let mut cx = Context::from_waker(waker);
            let mut listener = std::pin::pin!(listener);
            let poll = listener.as_mut().poll(&mut cx);
            match (expected, poll) {
                (TimerOutcome::Signaled, Poll::Ready(Ok(()))) => {}
                (TimerOutcome::Signaled, other) => {
                    panic!("{label}: expected Ready(Ok(())), got {other:?}");
                }
                (TimerOutcome::Pending, Poll::Pending) => {}
                (TimerOutcome::Pending, other) => {
                    panic!("{label}: expected Pending (no signal), got {other:?}");
                }
            }
            // Keep the context alive until after the callback has used its raw pointer.
            drop(ctx);
        }
    }
}