use std::ffi::c_void;
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use std::task::{Context, Poll, Waker};
use std::time::Duration;
use tokio::time::Instant;
use windows_sys::Win32::Foundation::{CloseHandle, FILETIME, HANDLE, TRUE};
use windows_sys::Win32::System::Threading::{
CancelWaitableTimer, CloseThreadpoolTimer, CloseThreadpoolWait, CreateThreadpoolTimer,
CreateThreadpoolWait, CreateWaitableTimerExW, SetThreadpoolTimer, SetThreadpoolWait,
SetWaitableTimer, WaitForThreadpoolTimerCallbacks, WaitForThreadpoolWaitCallbacks,
CREATE_WAITABLE_TIMER_HIGH_RESOLUTION, PTP_CALLBACK_INSTANCE, PTP_TIMER, PTP_WAIT,
TIMER_ALL_ACCESS,
};
/// Expiry state shared between the owning timer and the thread-pool callback
/// that signals expiry.
///
/// The callback side only calls `fire`; the task side calls `poll` and `reset`.
struct State {
    /// Set to `true` when the timer expires; cleared again by a successful `poll`.
    fired: AtomicBool,
    /// Waker registered by the most recent `poll` that returned `Pending`.
    waker: Mutex<Option<Waker>>,
}
impl State {
    /// Creates a state that has not fired and has no registered waker.
    fn new() -> Self {
        Self {
            fired: AtomicBool::new(false),
            waker: Mutex::new(None),
        }
    }

    /// Locks the waker slot, recovering the guard if the mutex is poisoned.
    ///
    /// The slot holds a plain `Option<Waker>` that is only ever replaced as a
    /// whole, so a panic while it was held (e.g. inside a user waker's `clone`)
    /// cannot leave it half-updated. Recovering matters because `fire` runs in
    /// an `extern "system"` thread-pool callback, where a panic cannot unwind
    /// and would abort the process.
    fn waker_slot(&self) -> std::sync::MutexGuard<'_, Option<Waker>> {
        self.waker
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Marks the timer as expired and wakes the registered task, if any.
    ///
    /// The flag is published before the waker is taken, so a concurrent `poll`
    /// either observes `fired` or has already stored a waker that this call
    /// then takes and wakes.
    fn fire(&self) {
        self.fired.store(true, Ordering::Release);
        // Take the waker in its own statement so the lock is released before
        // `wake` runs arbitrary executor code.
        let waker = self.waker_slot().take();
        if let Some(w) = waker {
            w.wake();
        }
    }

    /// Returns `Ready` (consuming the expiry) if the timer has fired; otherwise
    /// registers the waker from `cx` and returns `Pending`.
    fn poll(&self, cx: &mut Context<'_>) -> Poll<()> {
        // Fast path: no lock needed when the timer already fired.
        if self.fired.swap(false, Ordering::Acquire) {
            return Poll::Ready(());
        }
        let mut slot = self.waker_slot();
        // Re-check under the lock: `fire` may have run between the fast-path
        // check and acquiring the lock, in which case it found no waker to wake.
        if self.fired.swap(false, Ordering::Acquire) {
            return Poll::Ready(());
        }
        match slot.as_ref() {
            // Avoid a needless clone when the same task polls again.
            Some(w) if w.will_wake(cx.waker()) => {}
            _ => *slot = Some(cx.waker().clone()),
        }
        Poll::Pending
    }

    /// Clears both the expiry flag and any registered waker.
    fn reset(&self) {
        self.fired.store(false, Ordering::Release);
        *self.waker_slot() = None;
    }
}
/// Reports whether this system can create high-resolution waitable timers.
///
/// The probe creates a timer with `CREATE_WAITABLE_TIMER_HIGH_RESOLUTION`,
/// closes it immediately, and caches the outcome for the life of the process.
fn high_res_supported() -> bool {
    static SUPPORTED: OnceLock<bool> = OnceLock::new();
    let probe = || {
        // SAFETY: null security attributes and a null name are permitted; a
        // non-null result is a fresh handle owned here and closed right away.
        let handle = unsafe {
            CreateWaitableTimerExW(
                ptr::null(),
                ptr::null(),
                CREATE_WAITABLE_TIMER_HIGH_RESOLUTION,
                TIMER_ALL_ACCESS,
            )
        };
        let created = !handle.is_null();
        if created {
            // SAFETY: `handle` is the valid handle returned just above.
            unsafe {
                CloseHandle(handle);
            }
        }
        created
    };
    *SUPPORTED.get_or_init(probe)
}
/// Expiry source backed by a high-resolution waitable timer whose signal is
/// observed through a thread-pool wait object.
struct HighRes {
    /// Waitable timer created with `CREATE_WAITABLE_TIMER_HIGH_RESOLUTION`.
    htimer: HANDLE,
    /// Thread-pool wait that runs `wait_callback` when `htimer` is signaled.
    pwait: PTP_WAIT,
    /// Shared expiry state polled by the owning task.
    state: Arc<State>,
    /// Extra strong reference to `state`, leaked with `Arc::into_raw` and handed
    /// to the wait callback as its context; reclaimed in `Drop`.
    ctx: *const State,
}
// SAFETY: the timer and thread-pool wait handles are process-wide rather than
// thread-affine, and `ctx` points into an `Arc<State>` allocation (`State` is
// `Sync`) that stays alive until `Drop` has drained every callback.
unsafe impl Send for HighRes {}
impl HighRes {
    /// Creates a high-resolution timer, or returns `None` when the system does
    /// not support one or a required Win32 object cannot be created.
    ///
    /// On success an extra strong count of `state` is leaked into the wait
    /// callback's context pointer; `Drop` reclaims it.
    fn new(state: Arc<State>) -> Option<Self> {
        if !high_res_supported() {
            return None;
        }
        // SAFETY: null security attributes and a null name are permitted; the
        // handle is owned by the returned value or closed on the failure path.
        let htimer = unsafe {
            CreateWaitableTimerExW(
                ptr::null(),
                ptr::null(),
                CREATE_WAITABLE_TIMER_HIGH_RESOLUTION,
                TIMER_ALL_ACCESS,
            )
        };
        if htimer.is_null() {
            return None;
        }
        let ctx: *const State = Arc::into_raw(Arc::clone(&state));
        // SAFETY: `ctx` stays valid until `Drop` reclaims it after draining all
        // callbacks; a null environment selects the default callback environment.
        let pwait =
            unsafe { CreateThreadpoolWait(Some(wait_callback), ctx as *mut c_void, ptr::null()) };
        if pwait == 0 {
            // SAFETY: no wait object exists, so nothing else can observe `ctx`;
            // reclaim the leaked reference and close the timer created above.
            unsafe {
                drop(Arc::from_raw(ctx));
                CloseHandle(htimer);
            }
            return None;
        }
        Some(Self {
            htimer,
            pwait,
            state,
            ctx,
        })
    }

    /// Stops any pending expiry and blocks until in-flight callbacks finish.
    ///
    /// Un-setting the wait only stops new callbacks from being queued;
    /// already-queued ones still run. Draining them here guarantees that no
    /// callback from a previous arming can call `State::fire` afterwards.
    fn cancel_and_drain(&mut self) {
        // SAFETY: `pwait` and `htimer` are valid for the lifetime of `self`.
        unsafe {
            SetThreadpoolWait(self.pwait, ptr::null_mut(), ptr::null());
            CancelWaitableTimer(self.htimer);
            // TRUE cancels queued callbacks that have not started yet.
            WaitForThreadpoolWaitCallbacks(self.pwait, TRUE);
        }
    }

    /// Arms the timer to expire at `deadline` (as soon as possible if it has
    /// already passed).
    ///
    /// The previous arming is cancelled and drained first; otherwise a callback
    /// already queued for the old deadline could run after `reset` and complete
    /// the new deadline early.
    ///
    /// # Panics
    /// Panics if `SetWaitableTimer` fails, because the wait would otherwise never
    /// be signaled and the owning task would hang forever.
    fn arm(&mut self, deadline: Instant) {
        self.cancel_and_drain();
        self.state.reset();
        // Measure the remaining time after draining so the drain is accounted for.
        let due = relative_due_time(deadline.saturating_duration_since(Instant::now()));
        // SAFETY: `htimer` is valid, `due` outlives the call, and a null
        // completion routine with a null argument is permitted.
        let ok = unsafe { SetWaitableTimer(self.htimer, &due, 0, None, ptr::null(), 0) };
        if ok == 0 {
            let err = std::io::Error::last_os_error();
            panic!("SetWaitableTimer failed: {err}");
        }
        // SAFETY: both handles are valid; a null timeout means the wait never times out.
        unsafe {
            SetThreadpoolWait(self.pwait, self.htimer, ptr::null());
        }
    }

    /// Cancels the timer, drains in-flight callbacks, and clears the state.
    fn disarm(&mut self) {
        self.cancel_and_drain();
        self.state.reset();
    }

    /// Polls for expiry of the current arming.
    fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        self.state.poll(cx)
    }
}
impl Drop for HighRes {
    /// Tears down the Win32 objects and releases the wait callback's reference
    /// to the shared state.
    ///
    /// Order matters: the wait is unregistered and its callbacks drained before
    /// the wait object is closed and before `ctx` is reclaimed, so no callback
    /// can dereference `ctx` after the reference it points into is released.
    fn drop(&mut self) {
        unsafe {
            // Stop the thread-pool wait from queuing new callbacks for the timer.
            SetThreadpoolWait(self.pwait, ptr::null_mut(), ptr::null());
            CancelWaitableTimer(self.htimer);
            // TRUE: cancel queued callbacks and block until running ones finish.
            WaitForThreadpoolWaitCallbacks(self.pwait, TRUE);
            CloseThreadpoolWait(self.pwait);
            CloseHandle(self.htimer);
            // Balance the `Arc::into_raw` performed in `HighRes::new`.
            drop(Arc::from_raw(self.ctx));
        }
    }
}
/// Thread-pool wait callback: signals expiry on the `State` passed as context.
///
/// `context` is the pointer produced by `Arc::into_raw` in `HighRes::new`;
/// a null context is ignored.
unsafe extern "system" fn wait_callback(
    _instance: PTP_CALLBACK_INSTANCE,
    context: *mut c_void,
    _wait: PTP_WAIT,
    _wait_result: u32,
) {
    // SAFETY: a non-null `context` comes from `Arc::into_raw`, and the owner keeps
    // that reference alive until every callback has been drained.
    if let Some(state) = unsafe { context.cast::<State>().as_ref() } {
        state.fire();
    }
}
/// Fallback expiry source backed by a plain thread-pool timer object.
struct Pool {
    /// Thread-pool timer that runs `timer_callback` on expiry.
    handle: PTP_TIMER,
    /// Shared expiry state polled by the owning task.
    state: Arc<State>,
    /// Extra strong reference to `state`, leaked with `Arc::into_raw` and handed
    /// to the timer callback as its context; reclaimed in `Drop`.
    ctx: *const State,
}
// SAFETY: the thread-pool timer handle is process-wide rather than
// thread-affine, and `ctx` points into an `Arc<State>` allocation (`State` is
// `Sync`) that stays alive until `Drop` has drained every callback.
unsafe impl Send for Pool {}
impl Pool {
    /// Creates a thread-pool timer whose callback context is an extra strong
    /// reference to `state`, leaked here and reclaimed in `Drop`.
    ///
    /// # Panics
    /// Panics if `CreateThreadpoolTimer` fails.
    fn new(state: Arc<State>) -> Self {
        let ctx: *const State = Arc::into_raw(Arc::clone(&state));
        // SAFETY: `ctx` stays valid until `Drop` reclaims it after draining all
        // callbacks; a null environment selects the default callback environment.
        let handle =
            unsafe { CreateThreadpoolTimer(Some(timer_callback), ctx as *mut c_void, ptr::null()) };
        if handle == 0 {
            // Capture the error before any other call can overwrite it.
            let err = std::io::Error::last_os_error();
            // SAFETY: no timer exists, so nothing else can observe `ctx`.
            unsafe {
                drop(Arc::from_raw(ctx));
            }
            panic!("CreateThreadpoolTimer failed: {err}");
        }
        Self { handle, state, ctx }
    }

    /// Stops any pending expiry and blocks until in-flight callbacks finish.
    ///
    /// A null due time only stops new callbacks from being queued;
    /// already-queued ones still run. Draining them here guarantees that no
    /// callback from a previous arming can call `State::fire` afterwards.
    fn cancel_and_drain(&mut self) {
        // SAFETY: `handle` is valid for the lifetime of `self`.
        unsafe {
            SetThreadpoolTimer(self.handle, ptr::null(), 0, 0);
            // TRUE cancels queued callbacks that have not started yet.
            WaitForThreadpoolTimerCallbacks(self.handle, TRUE);
        }
    }

    /// Arms the timer to expire at `deadline` (as soon as possible if it has
    /// already passed).
    ///
    /// The previous arming is cancelled and drained first; otherwise a callback
    /// already queued for the old deadline could run after `reset` and complete
    /// the new deadline early.
    fn arm(&mut self, deadline: Instant) {
        self.cancel_and_drain();
        self.state.reset();
        // Measure the remaining time after draining so the drain is accounted for.
        let ft = relative_filetime(deadline.saturating_duration_since(Instant::now()));
        // SAFETY: `handle` is valid and `ft` outlives the call.
        unsafe {
            SetThreadpoolTimer(self.handle, &ft, 0, 0);
        }
    }

    /// Cancels the timer, drains in-flight callbacks, and clears the state.
    fn disarm(&mut self) {
        self.cancel_and_drain();
        self.state.reset();
    }

    /// Polls for expiry of the current arming.
    fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        self.state.poll(cx)
    }
}
impl Drop for Pool {
    /// Cancels the timer, waits for its callbacks, closes it, and releases the
    /// reference handed to the timer callback in `Pool::new`.
    ///
    /// Callbacks are drained before the timer is closed and before `ctx` is
    /// reclaimed, so no callback can observe a dangling `State` pointer.
    fn drop(&mut self) {
        unsafe {
            // A null due time stops the timer from queuing new callbacks.
            SetThreadpoolTimer(self.handle, ptr::null(), 0, 0);
            // TRUE: cancel queued callbacks and block until running ones finish.
            WaitForThreadpoolTimerCallbacks(self.handle, TRUE);
            CloseThreadpoolTimer(self.handle);
            // Balance the `Arc::into_raw` performed in `Pool::new`.
            drop(Arc::from_raw(self.ctx));
        }
    }
}
/// Thread-pool timer callback: signals expiry on the `State` passed as context.
///
/// `context` is the pointer produced by `Arc::into_raw` in `Pool::new`;
/// a null context is ignored.
unsafe extern "system" fn timer_callback(
    _instance: PTP_CALLBACK_INSTANCE,
    context: *mut c_void,
    _timer: PTP_TIMER,
) {
    // SAFETY: a non-null `context` comes from `Arc::into_raw`, and the owner keeps
    // that reference alive until every callback has been drained.
    if let Some(state) = unsafe { context.cast::<State>().as_ref() } {
        state.fire();
    }
}
/// Backend selected when a [`Timer`] is constructed.
enum Inner {
    /// High-resolution waitable timer observed through a thread-pool wait.
    HighRes(HighRes),
    /// Plain thread-pool timer, used when the high-resolution backend is unavailable.
    Pool(Pool),
}
/// One-shot, re-armable timer for the parent module: prefers a high-resolution
/// waitable timer and falls back to a plain thread-pool timer.
pub(super) struct Timer(Inner);
impl Timer {
pub(super) fn new() -> Self {
let state = Arc::new(State::new());
if let Some(hr) = HighRes::new(Arc::clone(&state)) {
return Self(Inner::HighRes(hr));
}
Self(Inner::Pool(Pool::new(state)))
}
pub(super) fn arm(&mut self, deadline: Instant) {
match &mut self.0 {
Inner::HighRes(t) => t.arm(deadline),
Inner::Pool(t) => t.arm(deadline),
}
}
pub(super) fn disarm(&mut self) {
match &mut self.0 {
Inner::HighRes(t) => t.disarm(),
Inner::Pool(t) => t.disarm(),
}
}
pub(super) fn poll_expired(&mut self, cx: &mut Context<'_>) -> Poll<()> {
match &mut self.0 {
Inner::HighRes(t) => t.poll_expired(cx),
Inner::Pool(t) => t.poll_expired(cx),
}
}
}
/// Converts `d` into a relative due time for the Win32 timer APIs: a negative
/// count of 100 ns ticks.
///
/// Durations shorter than 100 ns (including zero) become one tick, so the result
/// is always strictly negative. Absurdly long durations saturate at `-i64::MAX`.
fn relative_due_time(d: Duration) -> i64 {
    let ticks = (d.as_nanos() / 100).max(1);
    // `ticks` is clamped into 1..=i64::MAX, so negating it cannot overflow.
    -clamp_to_i64(ticks)
}
/// Encodes a relative due time as a `FILETIME`: the negative 100 ns tick count
/// from `relative_due_time`, reinterpreted bit-for-bit and split into halves.
fn relative_filetime(d: Duration) -> FILETIME {
    let ticks = relative_due_time(d);
    // Two's-complement reinterpretation; the sign bit is meant to carry over.
    #[allow(clippy::cast_sign_loss)]
    let bits = ticks as u64;
    // Each half is masked or shifted into range first, so the narrowing is exact.
    #[allow(clippy::cast_possible_truncation)]
    let low = (bits & 0xFFFF_FFFF) as u32;
    #[allow(clippy::cast_possible_truncation)]
    let high = (bits >> 32) as u32;
    FILETIME {
        dwLowDateTime: low,
        dwHighDateTime: high,
    }
}
/// Narrows an unsigned tick count to `i64`, saturating at `i64::MAX`.
#[allow(clippy::cast_possible_truncation)] // the value is capped at i64::MAX before the cast
fn clamp_to_i64(value: u128) -> i64 {
    value.min(i64::MAX as u128) as i64
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Reassembles a `FILETIME` into the signed 64-bit value its halves encode.
    fn filetime_to_i64(ft: FILETIME) -> i64 {
        let high = u64::from(ft.dwHighDateTime);
        let low = u64::from(ft.dwLowDateTime);
        ((high << 32) | low) as i64
    }

    #[test]
    fn relative_due_time_zero_is_minus_one() {
        let ticks = relative_due_time(Duration::ZERO);
        assert_eq!(ticks, -1);
    }

    #[test]
    fn relative_due_time_sub_hundred_ns_floors_to_minus_one() {
        let ticks = relative_due_time(Duration::from_nanos(50));
        assert_eq!(ticks, -1);
    }

    #[test]
    fn relative_due_time_one_hundred_ns_unit() {
        let cases = [(100, -1), (200, -2)];
        for (nanos, expected) in cases {
            assert_eq!(relative_due_time(Duration::from_nanos(nanos)), expected);
        }
    }

    #[test]
    fn relative_due_time_one_millisecond() {
        let ticks = relative_due_time(Duration::from_millis(1));
        assert_eq!(ticks, -10_000);
    }

    #[test]
    fn relative_due_time_one_second() {
        let ticks = relative_due_time(Duration::from_secs(1));
        assert_eq!(ticks, -10_000_000);
    }

    #[test]
    fn relative_due_time_is_always_negative() {
        let samples = [
            Duration::ZERO,
            Duration::from_nanos(1),
            Duration::from_nanos(99),
            Duration::from_nanos(100),
            Duration::from_micros(1),
            Duration::from_millis(1),
            Duration::from_secs(1),
            Duration::from_secs(60 * 60 * 24),
        ];
        for d in samples {
            let ticks = relative_due_time(d);
            assert!(ticks < 0, "relative_due_time({d:?}) was not negative",);
        }
    }

    #[test]
    fn relative_due_time_clamps_huge_durations() {
        let v = relative_due_time(Duration::new(u64::MAX, 999_999_999));
        assert_eq!(v, -i64::MAX);
        assert!(v < 0);
    }

    #[test]
    fn relative_filetime_matches_relative_due_time() {
        let samples = [
            Duration::ZERO,
            Duration::from_nanos(50),
            Duration::from_nanos(100),
            Duration::from_micros(1),
            Duration::from_millis(1),
            Duration::from_millis(250),
            Duration::from_secs(1),
            Duration::from_secs(3600),
        ];
        samples.iter().copied().for_each(|d| {
            let round_tripped = filetime_to_i64(relative_filetime(d));
            assert_eq!(
                round_tripped,
                relative_due_time(d),
                "FILETIME for {d:?} did not round-trip to the same i64",
            );
        });
    }

    #[test]
    fn relative_filetime_one_millisecond_split() {
        let FILETIME {
            dwLowDateTime: low,
            dwHighDateTime: high,
        } = relative_filetime(Duration::from_millis(1));
        assert_eq!(high, 0xFFFF_FFFF);
        assert_eq!(low, 0xFFFF_D8F0);
    }

    #[test]
    fn relative_filetime_zero_is_minus_one_split() {
        let FILETIME {
            dwLowDateTime: low,
            dwHighDateTime: high,
        } = relative_filetime(Duration::ZERO);
        assert_eq!(high, 0xFFFF_FFFF);
        assert_eq!(low, 0xFFFF_FFFF);
    }
}