use std::collections::HashMap;
use std::task::Waker;
use std::time::Instant;
/// Number of slots in every level of the wheel (a power of two, because slot
/// counters are reduced with `SLOTS_MASK`).
const SLOTS: usize = 64;
/// Mask that reduces a slot counter to an index in `0..SLOTS`.
const SLOTS_MASK: u64 = (SLOTS - 1) as u64;
/// Number of hierarchical levels in the wheel.
const LEVELS: usize = 6;
/// Width of one level-0 slot, in milliseconds.
const LEVEL0_MS: u64 = 1;

/// Width in milliseconds of one slot at `level`: `LEVEL0_MS * SLOTS^level`.
fn slot_width_ms(level: usize) -> u64 {
    LEVEL0_MS * (SLOTS as u64).pow(level as u32)
}

/// A pending timer stored in one wheel bucket.
#[derive(Debug)]
pub(crate) struct TimerEntry {
    /// Identifier handed back to the caller wrapped in a [`TimerId`].
    pub id: u64,
    /// Absolute time at which the timer becomes due.
    pub deadline: Instant,
    /// Waker returned from [`TimerWheel::tick`] once the deadline has passed.
    pub waker: Waker,
}

/// Opaque handle for a timer registered with [`TimerWheel::insert`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TimerId(u64);

/// Hierarchical timing wheel with millisecond resolution.
///
/// Level 0 holds timers due within the next `SLOTS` milliseconds of the last
/// tick; each higher level covers `SLOTS` times the range of the level below.
/// As time advances, entries in higher levels are re-filed into lower levels
/// (see [`TimerWheel::tick`]).
pub struct TimerWheel {
    /// Reference point: all times are handled as whole milliseconds since it.
    origin: Instant,
    /// `wheel[level][slot]` is the bucket of entries hashed to that slot.
    wheel: Vec<Vec<Vec<TimerEntry>>>,
    /// Timer id -> `(level, slot)` of the bucket currently holding it.
    index: HashMap<u64, (usize, usize)>,
    /// Next id to hand out; ids start at 1 and are never reused.
    next_id: u64,
    /// Time (ms since `origin`) reached by the most recent `tick`.
    last_tick_ms: u64,
}

impl TimerWheel {
    /// Creates an empty wheel whose time axis starts at `origin`.
    pub(crate) fn new(origin: Instant) -> Self {
        let wheel = (0..LEVELS)
            .map(|_| (0..SLOTS).map(|_| Vec::new()).collect())
            .collect();
        Self {
            origin,
            wheel,
            index: HashMap::new(),
            next_id: 1,
            last_tick_ms: 0,
        }
    }

    /// Converts `t` to whole milliseconds since `origin`.
    ///
    /// Instants before `origin` map to 0; values too large for a `u64`
    /// saturate to `u64::MAX`.
    fn instant_to_ms(&self, t: Instant) -> u64 {
        let ms = t.saturating_duration_since(self.origin).as_millis();
        // Clamp first: a plain `as` cast of an out-of-range u128 would wrap.
        ms.min(u128::from(u64::MAX)) as u64
    }

    /// Registers a timer that should wake `waker` at `deadline`.
    ///
    /// A deadline at or before the last tick is filed into the current slot,
    /// so it fires on the next `tick` whose `now` is not earlier than the
    /// previous one.
    pub(crate) fn insert(&mut self, deadline: Instant, waker: Waker) -> TimerId {
        let id = self.next_id;
        self.next_id += 1;
        self.insert_raw(TimerEntry { id, deadline, waker });
        TimerId(id)
    }

    /// Cancels a pending timer.
    ///
    /// Returns `true` if the timer was pending and has been removed, `false`
    /// if it already fired, was already cancelled, or never existed.
    pub(crate) fn cancel(&mut self, id: TimerId) -> bool {
        let Some((level, slot)) = self.index.remove(&id.0) else {
            return false;
        };
        let bucket = &mut self.wheel[level][slot];
        let before = bucket.len();
        bucket.retain(|e| e.id != id.0);
        bucket.len() < before
    }

    /// Advances the wheel to `now` and returns the wakers of every timer whose
    /// deadline is at or before `now`.
    ///
    /// A call with `now` earlier than the previous tick is ignored and returns
    /// an empty vector. The work per call is bounded by `LEVELS * SLOTS` slot
    /// visits no matter how far time jumped.
    pub(crate) fn tick(&mut self, now: Instant) -> Vec<Waker> {
        let mut fired = Vec::new();
        let from = self.last_tick_ms;
        let to = self.instant_to_ms(now);
        if to < from {
            return fired;
        }
        // Advance the clock *before* draining. Entries that are not yet due
        // are re-filed through `level_slot`, which measures distance from
        // `last_tick_ms`; with the stale `from` a cascading entry would be put
        // straight back into the higher-level slot being drained, whose next
        // visit can be a whole rotation away, so the timer would fire late.
        self.last_tick_ms = to;

        // Level 0: one slot per millisecond in `from..=to` (inclusive of
        // `from`, where past-due inserts are filed). After a full rotation
        // every slot has been passed, so each is visited at most once.
        let last0 = to.min(from.saturating_add(SLOTS as u64 - 1));
        for ms in from..=last0 {
            self.drain_slot(0, (ms & SLOTS_MASK) as usize, to, &mut fired);
        }

        // Higher levels: visit the slot of every slot boundary crossed in
        // `from..=to`. Draining a slot fires what is due and re-files the rest
        // relative to `to`, so one rotation per level is sufficient.
        for level in 1..LEVELS {
            let width = slot_width_ms(level);
            let first = from / width + u64::from(from % width != 0);
            let last = to / width;
            if first > last {
                continue;
            }
            let last = last.min(first.saturating_add(SLOTS as u64 - 1));
            for boundary in first..=last {
                self.drain_slot(level, (boundary & SLOTS_MASK) as usize, to, &mut fired);
            }
        }
        fired
    }

    /// Empties bucket `wheel[level][slot]`: wakers of entries due at or before
    /// `now_ms` are pushed onto `fired`, the remaining entries are re-filed via
    /// `insert_raw` (possibly into a lower level).
    fn drain_slot(&mut self, level: usize, slot: usize, now_ms: u64, fired: &mut Vec<Waker>) {
        let entries = std::mem::take(&mut self.wheel[level][slot]);
        for entry in entries {
            self.index.remove(&entry.id);
            if self.instant_to_ms(entry.deadline) <= now_ms {
                fired.push(entry.waker);
            } else {
                self.insert_raw(entry);
            }
        }
    }

    /// Returns the earliest deadline among all pending timers, or `None` if no
    /// timer is pending.
    ///
    /// Visits every bucket and every pending entry.
    pub(crate) fn next_deadline(&self) -> Option<Instant> {
        self.wheel
            .iter()
            .flatten()
            .flatten()
            .map(|entry| entry.deadline)
            .min()
    }

    /// Files `entry` into the bucket matching its deadline and records that
    /// location in `index`. A deadline before `last_tick_ms` is treated as due
    /// at `last_tick_ms`.
    fn insert_raw(&mut self, entry: TimerEntry) {
        let effective_ms = self.instant_to_ms(entry.deadline).max(self.last_tick_ms);
        let (level, slot) = self.level_slot(effective_ms);
        self.index.insert(entry.id, (level, slot));
        self.wheel[level][slot].push(entry);
    }

    /// Chooses the `(level, slot)` bucket for the absolute time `deadline_ms`.
    ///
    /// The level is the lowest one whose `SLOTS`-slot span covers the distance
    /// from `last_tick_ms`; anything further out goes to the top level. The
    /// slot is the deadline's slot counter at that level, modulo `SLOTS`.
    fn level_slot(&self, deadline_ms: u64) -> (usize, usize) {
        let delta = deadline_ms.saturating_sub(self.last_tick_ms);
        let level = (0..LEVELS - 1)
            .find(|&lvl| delta < slot_width_ms(lvl) * SLOTS as u64)
            .unwrap_or(LEVELS - 1);
        let slot = ((deadline_ms / slot_width_ms(level)) & SLOTS_MASK) as usize;
        (level, slot)
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for `TimerWheel`: firing, cancellation, cascading between
    // levels, and `next_deadline` bookkeeping.
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::task::Wake;
    use std::time::Duration;

    /// Waker that sets a shared flag to `true` when woken.
    struct FlagWaker(Arc<Mutex<bool>>);

    impl Wake for FlagWaker {
        fn wake(self: Arc<Self>) {
            *self.0.lock().unwrap() = true;
        }
    }

    /// Builds a [`Waker`] that sets `flag` to `true` when woken.
    ///
    /// Built on the safe [`Wake`] trait so clone / wake / wake_by_ref / drop
    /// are all handled by `Arc`; no hand-written `RawWakerVTable` (and no
    /// pointer reinterpretation) is needed.
    fn make_flag_waker(flag: Arc<Mutex<bool>>) -> Waker {
        Waker::from(Arc::new(FlagWaker(flag)))
    }

    /// Waker that appends `index` to a shared log when woken, so a test can
    /// observe the order in which timers were returned.
    struct RecordingWaker {
        index: u32,
        log: Arc<Mutex<Vec<u32>>>,
    }

    impl Wake for RecordingWaker {
        fn wake(self: Arc<Self>) {
            self.log.lock().unwrap().push(self.index);
        }
    }

    #[test]
    fn insert_and_tick_fires_waker() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        let deadline = origin + Duration::from_millis(50);
        wheel.insert(deadline, waker);
        let wakers = wheel.tick(origin + Duration::from_millis(30));
        assert!(wakers.is_empty());
        let wakers = wheel.tick(origin + Duration::from_millis(60));
        assert_eq!(wakers.len(), 1);
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap(), "waker must have fired");
    }

    #[test]
    fn cancel_prevents_firing() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        let deadline = origin + Duration::from_millis(50);
        let id = wheel.insert(deadline, waker);
        let removed = wheel.cancel(id);
        assert!(removed, "cancel must return true for existing timer");
        let wakers = wheel.tick(origin + Duration::from_millis(100));
        assert!(wakers.is_empty(), "cancelled timer must not fire");
        assert!(!*flag.lock().unwrap());
    }

    #[test]
    fn zero_deadline_fires_on_next_tick() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        wheel.insert(origin, waker);
        let wakers = wheel.tick(origin + Duration::from_millis(1));
        assert_eq!(wakers.len(), 1);
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap());
    }

    #[test]
    fn multiple_timers_fire_in_order() {
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        let results = Arc::new(Mutex::new(Vec::<u32>::new()));
        for i in 0u32..5 {
            let waker = Waker::from(Arc::new(RecordingWaker {
                index: i,
                log: Arc::clone(&results),
            }));
            wheel.insert(origin + Duration::from_millis((u64::from(i) + 1) * 10), waker);
        }
        let wakers = wheel.tick(origin + Duration::from_millis(60));
        assert_eq!(wakers.len(), 5);
        for w in wakers {
            w.wake();
        }
        let v = results.lock().unwrap();
        assert_eq!(*v, vec![0, 1, 2, 3, 4], "timers must fire in deadline order");
    }

    #[test]
    fn next_deadline_returns_earliest() {
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        let d1 = origin + Duration::from_millis(200);
        let d2 = origin + Duration::from_millis(50);
        let f1 = Arc::new(Mutex::new(false));
        let f2 = Arc::new(Mutex::new(false));
        wheel.insert(d1, make_flag_waker(Arc::clone(&f1)));
        wheel.insert(d2, make_flag_waker(Arc::clone(&f2)));
        let earliest = wheel.next_deadline().expect("should have a deadline");
        assert_eq!(earliest, d2, "next_deadline must return earliest");
    }

    #[test]
    fn large_time_jump_fires_timer_quickly() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        let deadline = origin + Duration::from_millis(50);
        wheel.insert(deadline, waker);
        let start = std::time::Instant::now();
        let wakers = wheel.tick(origin + Duration::from_secs(10));
        let elapsed = start.elapsed();
        assert_eq!(wakers.len(), 1, "timer must fire on 10s jump");
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap(), "waker must have been called");
        assert!(
            elapsed < Duration::from_millis(10),
            "10s tick must complete in <10ms, took {:?}",
            elapsed
        );
    }

    #[test]
    fn wheel_cancel_nonexistent_returns_false() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let fake_id = TimerId(9999);
        assert!(!w.cancel(fake_id));
    }

    #[test]
    fn wheel_cancel_already_fired_returns_false() {
        let flag = Arc::new(Mutex::new(false));
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let waker = make_flag_waker(Arc::clone(&flag));
        let id = w.insert(origin + Duration::from_millis(5), waker);
        let _ = w.tick(origin + Duration::from_millis(10));
        assert!(!w.cancel(id));
    }

    #[test]
    fn wheel_tick_backwards_is_noop() {
        let flag = Arc::new(Mutex::new(false));
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let waker = make_flag_waker(Arc::clone(&flag));
        w.insert(origin + Duration::from_millis(50), waker);
        let _ = w.tick(origin + Duration::from_millis(100));
        let wakers = w.tick(origin + Duration::from_millis(10));
        assert!(wakers.is_empty());
    }

    #[test]
    fn wheel_multiple_timers_same_slot() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        for _ in 0..5 {
            let flag = Arc::new(Mutex::new(false));
            let waker = make_flag_waker(Arc::clone(&flag));
            w.insert(origin + Duration::from_millis(10), waker);
        }
        let wakers = w.tick(origin + Duration::from_millis(20));
        assert_eq!(wakers.len(), 5);
    }

    #[test]
    fn wheel_1000_timers_all_fire() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        for i in 0..1000u64 {
            let flag = Arc::new(Mutex::new(false));
            let waker = make_flag_waker(Arc::clone(&flag));
            w.insert(origin + Duration::from_millis(i % 100), waker);
        }
        let wakers = w.tick(origin + Duration::from_millis(200));
        assert_eq!(wakers.len(), 1000);
    }

    #[test]
    fn wheel_next_deadline_empty_returns_none() {
        let origin = Instant::now();
        let w = TimerWheel::new(origin);
        assert!(w.next_deadline().is_none());
    }

    #[test]
    fn wheel_next_deadline_after_cancel_updates() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let f1 = Arc::new(Mutex::new(false));
        let f2 = Arc::new(Mutex::new(false));
        let d1 = origin + Duration::from_millis(100);
        let d2 = origin + Duration::from_millis(200);
        let id1 = w.insert(d1, make_flag_waker(Arc::clone(&f1)));
        let _id2 = w.insert(d2, make_flag_waker(Arc::clone(&f2)));
        assert_eq!(w.next_deadline().unwrap(), d1);
        w.cancel(id1);
        assert_eq!(w.next_deadline().unwrap(), d2);
    }

    #[test]
    fn wheel_partial_tick_does_not_fire_future_timers() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        w.insert(
            origin + Duration::from_millis(100),
            make_flag_waker(Arc::clone(&flag)),
        );
        let wakers = w.tick(origin + Duration::from_millis(50));
        assert!(wakers.is_empty());
        assert!(!*flag.lock().unwrap());
    }

    #[test]
    fn wheel_level_boundary_cascades_correctly() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        w.insert(
            origin + Duration::from_millis(65),
            make_flag_waker(Arc::clone(&flag)),
        );
        let wakers = w.tick(origin + Duration::from_millis(70));
        assert_eq!(wakers.len(), 1);
    }

    #[test]
    fn wheel_insert_past_deadline_fires_on_first_tick() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        let past_deadline = origin;
        w.insert(past_deadline, make_flag_waker(Arc::clone(&flag)));
        let wakers = w.tick(origin + Duration::from_millis(1));
        assert!(!wakers.is_empty());
    }

    #[test]
    fn wheel_two_timers_different_deadlines_only_earlier_fires() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let f1 = Arc::new(Mutex::new(false));
        let f2 = Arc::new(Mutex::new(false));
        w.insert(
            origin + Duration::from_millis(10),
            make_flag_waker(Arc::clone(&f1)),
        );
        w.insert(
            origin + Duration::from_millis(50),
            make_flag_waker(Arc::clone(&f2)),
        );
        let wakers = w.tick(origin + Duration::from_millis(20));
        assert_eq!(wakers.len(), 1);
        assert!(!*f2.lock().unwrap());
    }

    #[test]
    fn wheel_cancel_all_removes_from_index() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let mut ids = Vec::new();
        for i in 1..=5u64 {
            let flag = Arc::new(Mutex::new(false));
            let id = w.insert(origin + Duration::from_millis(i * 10), make_flag_waker(flag));
            ids.push(id);
        }
        for id in ids {
            assert!(w.cancel(id));
        }
        let wakers = w.tick(origin + Duration::from_millis(100));
        assert!(wakers.is_empty());
    }

    #[test]
    fn wheel_many_deadlines_at_different_levels() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let f1 = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(1), make_flag_waker(Arc::clone(&f1)));
        let f2 = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(100), make_flag_waker(Arc::clone(&f2)));
        let f3 = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(5000), make_flag_waker(Arc::clone(&f3)));
        let wakers = w.tick(origin + Duration::from_millis(200));
        assert_eq!(wakers.len(), 2);
        assert!(!*f3.lock().unwrap());
        let wakers2 = w.tick(origin + Duration::from_millis(6000));
        assert_eq!(wakers2.len(), 1);
    }

    #[test]
    fn wheel_empty_tick_returns_empty_vec() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let wakers = w.tick(origin + Duration::from_millis(1000));
        assert!(wakers.is_empty());
    }

    #[test]
    fn wheel_same_tick_twice_second_empty() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(10), make_flag_waker(Arc::clone(&flag)));
        let wakers1 = w.tick(origin + Duration::from_millis(20));
        assert_eq!(wakers1.len(), 1);
        let wakers2 = w.tick(origin + Duration::from_millis(20));
        assert!(wakers2.is_empty());
    }

    #[test]
    fn wheel_timer_id_uniqueness() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let mut ids = std::collections::HashSet::new();
        for i in 0..10u64 {
            let flag = Arc::new(Mutex::new(false));
            let id = w.insert(origin + Duration::from_millis(i * 5 + 1), make_flag_waker(flag));
            assert!(ids.insert(id));
        }
    }

    #[test]
    fn wheel_tick_advances_last_tick_ms() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(200), make_flag_waker(Arc::clone(&flag)));
        let wakers1 = w.tick(origin + Duration::from_millis(100));
        assert!(wakers1.is_empty());
        let wakers2 = w.tick(origin + Duration::from_millis(250));
        assert_eq!(wakers2.len(), 1);
    }
}