use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::task::Waker;
/// A single scheduled entry in the wheel's heap.
///
/// Equality and ordering look only at `(deadline, timer_id)`; the waker is
/// payload and does not take part in comparisons.
#[derive(Debug)]
struct VirtualTimer {
    deadline: u64,
    timer_id: u64,
    waker: Waker,
}

impl Eq for VirtualTimer {}

impl PartialEq for VirtualTimer {
    fn eq(&self, other: &Self) -> bool {
        (self.deadline, self.timer_id) == (other.deadline, other.timer_id)
    }
}

/// Reversed ordering: the entry with the *smallest* `(deadline, timer_id)` pair
/// compares as the greatest. This turns `BinaryHeap` (a max-heap) into a
/// min-heap that yields the earliest deadline first and breaks ties by the
/// lower timer ID.
impl Ord for VirtualTimer {
    fn cmp(&self, other: &Self) -> Ordering {
        let mine = (self.deadline, self.timer_id);
        let theirs = (other.deadline, other.timer_id);
        theirs.cmp(&mine)
    }
}

impl PartialOrd for VirtualTimer {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
/// Handle returned by `VirtualTimerWheel::insert`; pass it to
/// `VirtualTimerWheel::cancel` to cancel the timer.
///
/// Fields are private, so handles can only be obtained from `insert`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct VirtualTimerHandle {
    timer_id: u64,
    deadline: u64,
}

impl VirtualTimerHandle {
    /// ID the wheel assigned to this timer at insertion.
    #[must_use]
    pub const fn timer_id(&self) -> u64 {
        let Self { timer_id, .. } = *self;
        timer_id
    }

    /// Deadline tick the timer was scheduled for.
    #[must_use]
    pub const fn deadline(&self) -> u64 {
        let Self { deadline, .. } = *self;
        deadline
    }
}
/// A timer that fired during an `advance_*` call on the wheel.
///
/// The wheel does not wake anything itself; it hands the waker back to the
/// caller. Expired timers are returned sorted by `(deadline, timer_id)`.
#[derive(Debug)]
pub struct ExpiredTimer {
/// ID assigned at insertion (same value as `VirtualTimerHandle::timer_id`).
pub timer_id: u64,
/// Tick the timer was scheduled for; may be earlier than the tick the clock advanced to.
pub deadline: u64,
/// Waker registered at insertion.
pub waker: Waker,
}
/// Deterministic virtual-time timer queue.
///
/// Despite the name, this is a binary heap rather than a hashed timing wheel:
/// the entry with the earliest deadline (lowest timer ID on ties) sits at the
/// top. Cancellation is lazy. Cancelled IDs are recorded in `cancelled`, and
/// the matching heap entries are skipped or discarded when they are encountered.
#[derive(Debug)]
pub struct VirtualTimerWheel {
/// Scheduled timers, including cancelled ones that have not been discarded yet.
heap: BinaryHeap<VirtualTimer>,
/// Current virtual time; only changed by `advance_to`, which never moves it backwards.
current_tick: u64,
/// ID handed to the next inserted timer; incremented with overflow checking.
next_timer_id: u64,
/// IDs passed to `cancel`; may include IDs of timers that already fired.
cancelled: std::collections::BTreeSet<u64>,
}
impl Default for VirtualTimerWheel {
fn default() -> Self {
Self::new()
}
}
impl VirtualTimerWheel {
#[must_use]
pub fn new() -> Self {
Self {
heap: BinaryHeap::new(),
current_tick: 0,
next_timer_id: 0,
cancelled: std::collections::BTreeSet::new(),
}
}
#[must_use]
pub fn starting_at(tick: u64) -> Self {
Self {
heap: BinaryHeap::new(),
current_tick: tick,
next_timer_id: 0,
cancelled: std::collections::BTreeSet::new(),
}
}
#[must_use]
pub const fn current_tick(&self) -> u64 {
self.current_tick
}
#[must_use]
pub fn len(&self) -> usize {
self.pending_count()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.pending_count() == 0
}
fn pending_count(&self) -> usize {
self.heap
.iter()
.filter(|t| !self.cancelled.contains(&t.timer_id))
.count()
}
pub fn insert(&mut self, deadline: u64, waker: Waker) -> VirtualTimerHandle {
let timer_id = self.next_timer_id;
self.next_timer_id = self
.next_timer_id
.checked_add(1)
.expect("virtual timer ID space exhausted");
self.heap.push(VirtualTimer {
deadline,
timer_id,
waker,
});
VirtualTimerHandle { timer_id, deadline }
}
pub fn cancel(&mut self, handle: VirtualTimerHandle) {
self.cancelled.insert(handle.timer_id);
}
#[must_use]
pub fn next_deadline(&mut self) -> Option<u64> {
while let Some(top) = self.heap.peek() {
if self.cancelled.remove(&top.timer_id) {
self.heap.pop();
} else {
return Some(top.deadline);
}
}
None
}
pub fn advance_to_next(&mut self) -> Vec<ExpiredTimer> {
self.next_deadline()
.map_or_else(Vec::new, |deadline| self.advance_to(deadline))
}
pub fn advance_by(&mut self, ticks: u64) -> Vec<ExpiredTimer> {
self.advance_to(self.current_tick.saturating_add(ticks))
}
pub fn advance_to(&mut self, target_tick: u64) -> Vec<ExpiredTimer> {
if target_tick < self.current_tick {
return Vec::new();
}
let mut expired = Vec::new();
while let Some(timer) = self.heap.peek() {
if timer.deadline > target_tick {
break;
}
let Some(timer) = self.heap.pop() else {
break;
};
if self.cancelled.remove(&timer.timer_id) {
continue;
}
expired.push(ExpiredTimer {
timer_id: timer.timer_id,
deadline: timer.deadline,
waker: timer.waker,
});
}
self.current_tick = target_tick;
self.cleanup_cancelled();
expired.sort_by(|a, b| {
a.deadline
.cmp(&b.deadline)
.then_with(|| a.timer_id.cmp(&b.timer_id))
});
expired
}
fn cleanup_cancelled(&mut self) {
if self.cancelled.len() > self.heap.len() {
let heap_ids: std::collections::BTreeSet<_> =
self.heap.iter().map(|t| t.timer_id).collect();
self.cancelled.retain(|id| heap_ids.contains(id));
}
}
#[must_use]
pub fn collect_wakers(&self, up_to_tick: u64) -> Vec<Waker> {
let mut ready: Vec<_> = self
.heap
.iter()
.filter(|t| t.deadline <= up_to_tick && !self.cancelled.contains(&t.timer_id))
.collect();
ready.sort_by(|a, b| {
a.deadline
.cmp(&b.deadline)
.then_with(|| a.timer_id.cmp(&b.timer_id))
});
ready.into_iter().map(|t| t.waker.clone()).collect()
}
pub fn clear(&mut self) {
self.heap.clear();
self.cancelled.clear();
}
}
// Unit tests for `VirtualTimerWheel`, plus one `insta` JSON snapshot test.
// NOTE(review): uses the `serde_json` and `insta` crates, presumably as
// dev-dependencies — confirm in Cargo.toml.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use std::sync::Arc;
use std::sync::Mutex;
// This explicit `Ordering` (atomic memory ordering) shadows the
// `std::cmp::Ordering` that the glob `use super::*` would otherwise provide.
use std::sync::atomic::{AtomicUsize, Ordering};
// Maps the first four timer IDs to fixed labels so the snapshot does not
// embed raw IDs.
fn scrub_timer_id(timer_id: u64) -> &'static str {
match timer_id {
0 => "[TIMER_A]",
1 => "[TIMER_B]",
2 => "[TIMER_C]",
3 => "[TIMER_D]",
_ => "[TIMER_OTHER]",
}
}
// Waker that increments an atomic counter each time it is woken.
struct CountingWaker(AtomicUsize);
use std::task::Wake;
impl Wake for CountingWaker {
fn wake(self: Arc<Self>) {
self.0.fetch_add(1, Ordering::Relaxed);
}
fn wake_by_ref(self: &Arc<Self>) {
self.0.fetch_add(1, Ordering::Relaxed);
}
}
// Returns the shared counter together with a `Waker` that bumps it.
fn counting_waker() -> (Arc<CountingWaker>, Waker) {
let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
let waker = Waker::from(counter.clone());
(counter, waker)
}
// Waker that appends its `id` to a shared log when woken, so tests can
// observe the order in which wakers were invoked.
struct RecordingWaker {
id: usize,
wake_order: Arc<Mutex<Vec<usize>>>,
}
impl Wake for RecordingWaker {
fn wake(self: Arc<Self>) {
self.wake_order
.lock()
.expect("wake order lock")
.push(self.id);
}
fn wake_by_ref(self: &Arc<Self>) {
self.wake_order
.lock()
.expect("wake order lock")
.push(self.id);
}
}
// Builds a `RecordingWaker` that logs `id` into `wake_order`.
fn recording_waker(id: usize, wake_order: Arc<Mutex<Vec<usize>>>) -> Waker {
Waker::from(Arc::new(RecordingWaker { id, wake_order }))
}
// A fresh wheel has its clock at 0 and no pending timers.
#[test]
fn new_wheel_starts_at_zero() {
let wheel = VirtualTimerWheel::new();
assert_eq!(wheel.current_tick(), 0);
assert!(wheel.is_empty());
}
// `starting_at` sets the initial clock value.
#[test]
fn starting_at_custom_tick() {
let wheel = VirtualTimerWheel::starting_at(1000);
assert_eq!(wheel.current_tick(), 1000);
}
// Timers fire only once the clock reaches their deadline, regardless of
// insertion order.
#[test]
fn insert_and_advance_to() {
let mut wheel = VirtualTimerWheel::new();
let (_, waker1) = counting_waker();
let (_, waker2) = counting_waker();
let (_, waker3) = counting_waker();
wheel.insert(100, waker1);
wheel.insert(50, waker2);
wheel.insert(200, waker3);
let expired = wheel.advance_to(75);
assert_eq!(expired.len(), 1);
assert_eq!(expired[0].deadline, 50);
assert_eq!(wheel.current_tick(), 75);
let expired = wheel.advance_to(150);
assert_eq!(expired.len(), 1);
assert_eq!(expired[0].deadline, 100);
let expired = wheel.advance_to(250);
assert_eq!(expired.len(), 1);
assert_eq!(expired[0].deadline, 200);
assert!(wheel.is_empty());
}
// Each call jumps to the earliest pending deadline; with nothing pending the
// clock stays where it is.
#[test]
fn advance_to_next() {
let mut wheel = VirtualTimerWheel::new();
let (_, waker1) = counting_waker();
let (_, waker2) = counting_waker();
wheel.insert(100, waker1);
wheel.insert(50, waker2);
let expired = wheel.advance_to_next();
assert_eq!(expired.len(), 1);
assert_eq!(expired[0].deadline, 50);
assert_eq!(wheel.current_tick(), 50);
let expired = wheel.advance_to_next();
assert_eq!(expired.len(), 1);
assert_eq!(expired[0].deadline, 100);
assert_eq!(wheel.current_tick(), 100);
let expired = wheel.advance_to_next();
assert!(expired.is_empty());
assert_eq!(wheel.current_tick(), 100); }
// `advance_by` moves the clock relative to the current tick.
#[test]
fn advance_by() {
let mut wheel = VirtualTimerWheel::new();
let (_, waker1) = counting_waker();
let (_, waker2) = counting_waker();
wheel.insert(100, waker1);
wheel.insert(50, waker2);
let expired = wheel.advance_by(75);
assert_eq!(expired.len(), 1);
assert_eq!(wheel.current_tick(), 75);
let expired = wheel.advance_by(50);
assert_eq!(expired.len(), 1);
assert_eq!(wheel.current_tick(), 125);
}
// Timers sharing a deadline fire in insertion (timer ID) order.
#[test]
fn deterministic_ordering_by_timer_id() {
let mut wheel = VirtualTimerWheel::new();
let (_, waker1) = counting_waker();
let (_, waker2) = counting_waker();
let (_, waker3) = counting_waker();
let h1 = wheel.insert(100, waker1);
let h2 = wheel.insert(100, waker2);
let h3 = wheel.insert(100, waker3);
let expired = wheel.advance_to(100);
assert_eq!(expired.len(), 3);
assert_eq!(expired[0].timer_id, h1.timer_id());
assert_eq!(expired[1].timer_id, h2.timer_id());
assert_eq!(expired[2].timer_id, h3.timer_id());
}
// A cancelled timer is not returned when its deadline passes.
#[test]
fn cancel_timer() {
let mut wheel = VirtualTimerWheel::new();
let (_, waker1) = counting_waker();
let (_, waker2) = counting_waker();
let h1 = wheel.insert(100, waker1);
let h2 = wheel.insert(100, waker2);
wheel.cancel(h1);
let expired = wheel.advance_to(100);
assert_eq!(expired.len(), 1);
assert_eq!(expired[0].timer_id, h2.timer_id());
}
// Cancelling a handle whose timer already fired must not hide or drop timers
// inserted afterwards.
#[test]
fn stale_cancel_handle_does_not_hide_pending_timers() {
let mut wheel = VirtualTimerWheel::new();
let (_, stale_waker) = counting_waker();
let stale_handle = wheel.insert(10, stale_waker);
let expired = wheel.advance_to(10);
assert_eq!(expired.len(), 1);
let (_, live_waker) = counting_waker();
let live_handle = wheel.insert(20, live_waker);
wheel.cancel(stale_handle);
assert_eq!(wheel.len(), 1);
assert!(!wheel.is_empty());
assert_eq!(wheel.next_deadline(), Some(20));
let expired = wheel.advance_to(20);
assert_eq!(expired.len(), 1);
assert_eq!(expired[0].timer_id, live_handle.timer_id());
}
// Once the ID counter reaches `u64::MAX`, the next insert panics and leaves
// the counter and the heap unchanged.
#[test]
fn insert_panics_before_timer_ids_wrap() {
let mut wheel = VirtualTimerWheel::new();
wheel.next_timer_id = u64::MAX - 1;
let (_, first_waker) = counting_waker();
let first = wheel.insert(10, first_waker);
assert_eq!(first.timer_id(), u64::MAX - 1);
assert_eq!(wheel.next_timer_id, u64::MAX);
let (_, overflow_waker) = counting_waker();
let panic = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let _ = wheel.insert(20, overflow_waker);
}));
assert!(
panic.is_err(),
"timer wheel must fail closed instead of wrapping timer IDs"
);
assert_eq!(
wheel.next_timer_id,
u64::MAX,
"failed insert must not wrap the next timer ID"
);
assert_eq!(
wheel.next_deadline(),
Some(10),
"overflow attempt must not enqueue a wrapped timer"
);
}
// `next_deadline` ignores a cancelled timer at the top of the heap.
#[test]
fn next_deadline_skips_cancelled() {
let mut wheel = VirtualTimerWheel::new();
let (_, waker1) = counting_waker();
let (_, waker2) = counting_waker();
let h1 = wheel.insert(50, waker1);
wheel.insert(100, waker2);
wheel.cancel(h1);
assert_eq!(wheel.next_deadline(), Some(100));
}
// Identical inputs produce an identical firing order across runs.
#[test]
fn determinism_across_runs() {
fn run_test(seed: u64) -> Vec<u64> {
let mut wheel = VirtualTimerWheel::starting_at(seed);
let deadlines = [
seed.wrapping_mul(7) % 1000,
seed.wrapping_mul(13) % 1000,
seed.wrapping_mul(17) % 1000,
];
for deadline in deadlines {
let (_, waker) = counting_waker();
wheel.insert(seed + deadline, waker);
}
let expired = wheel.advance_to(seed + 1000);
expired.iter().map(|e| e.timer_id).collect()
}
let order1 = run_test(42);
let order2 = run_test(42);
assert_eq!(order1, order2, "Same seed should produce same order");
let order3 = run_test(123);
assert_eq!(order3.len(), 3);
}
// Advancing to a tick earlier than the current one fires nothing and leaves
// the clock in place.
#[test]
fn advance_to_past_is_noop() {
let mut wheel = VirtualTimerWheel::starting_at(100);
let expired = wheel.advance_to(50);
assert!(expired.is_empty());
assert_eq!(wheel.current_tick(), 100);
}
// Advancing to the current tick still fires timers due at that tick.
#[test]
fn advance_to_current_tick_fires_due_timers() {
let mut wheel = VirtualTimerWheel::starting_at(100);
let (_, waker) = counting_waker();
wheel.insert(100, waker);
let expired = wheel.advance_to(100);
assert_eq!(expired.len(), 1);
assert_eq!(expired[0].deadline, 100);
assert_eq!(wheel.current_tick(), 100);
}
// A single large jump fires every due timer, in deadline order.
#[test]
fn large_time_jump() {
let mut wheel = VirtualTimerWheel::new();
let (_, waker1) = counting_waker();
let (_, waker2) = counting_waker();
let (_, waker3) = counting_waker();
wheel.insert(100, waker1);
wheel.insert(1000, waker2);
wheel.insert(1_000_000, waker3);
let expired = wheel.advance_to(2_000_000);
assert_eq!(expired.len(), 3);
assert_eq!(expired[0].deadline, 100);
assert_eq!(expired[1].deadline, 1000);
assert_eq!(expired[2].deadline, 1_000_000);
}
// Firing order is by deadline, then timer ID. IDs 0..=3 are assigned in the
// order of the inserts below.
#[test]
fn mixed_deadlines_ordering() {
let mut wheel = VirtualTimerWheel::new();
let (_, waker1) = counting_waker();
let (_, waker2) = counting_waker();
let (_, waker3) = counting_waker();
let (_, waker4) = counting_waker();
wheel.insert(200, waker1); wheel.insert(100, waker2); wheel.insert(100, waker3); wheel.insert(200, waker4);
let expired = wheel.advance_to(300);
assert_eq!(expired.len(), 4);
assert_eq!(expired[0].deadline, 100);
assert_eq!(expired[0].timer_id, 1);
assert_eq!(expired[1].deadline, 100);
assert_eq!(expired[1].timer_id, 2);
assert_eq!(expired[2].deadline, 200);
assert_eq!(expired[2].timer_id, 0);
assert_eq!(expired[3].deadline, 200);
assert_eq!(expired[3].timer_id, 3);
}
// `collect_wakers` returns wakers in (deadline, timer_id) order even though
// heap iteration order is arbitrary.
#[test]
fn collect_wakers_preserves_deterministic_deadline_then_id_order() {
let mut wheel = VirtualTimerWheel::new();
let wake_order = Arc::new(Mutex::new(Vec::new()));
let h0 = wheel.insert(200, recording_waker(0, wake_order.clone()));
let h1 = wheel.insert(100, recording_waker(1, wake_order.clone()));
let h2 = wheel.insert(100, recording_waker(2, wake_order.clone()));
let h3 = wheel.insert(150, recording_waker(3, wake_order.clone()));
assert_eq!(h0.timer_id(), 0);
assert_eq!(h1.timer_id(), 1);
assert_eq!(h2.timer_id(), 2);
assert_eq!(h3.timer_id(), 3);
let wakers = wheel.collect_wakers(200);
assert_eq!(wakers.len(), 4);
for waker in wakers {
waker.wake();
}
let order = wake_order.lock().expect("wake order lock").clone();
assert_eq!(
order,
vec![1, 2, 3, 0],
"collect_wakers must preserve deadline-then-id order"
);
}
// Exercises the derived traits on `VirtualTimerHandle`.
#[test]
fn virtual_timer_handle_debug_clone_copy_eq_hash() {
use std::collections::HashSet;
let mut wheel = VirtualTimerWheel::new();
let (_counter, waker) = counting_waker();
let handle = wheel.insert(100, waker);
let b = handle; let c = handle;
assert_eq!(handle, b);
assert_eq!(handle, c);
let dbg = format!("{handle:?}");
assert!(dbg.contains("VirtualTimerHandle"));
let mut set = HashSet::new();
set.insert(handle);
assert!(set.contains(&b));
}
// Snapshot of wheel state before and after one tick with a cancelled timer.
// Timer IDs are scrubbed through `scrub_timer_id`, and the output is compared
// against the stored `wheel_tick_scrubbed` snapshot.
#[test]
fn wheel_tick_snapshot_scrubbed() {
let mut wheel = VirtualTimerWheel::new();
let (_, waker_a) = counting_waker();
let (_, waker_b) = counting_waker();
let (_, waker_c) = counting_waker();
let timer_a = wheel.insert(20, waker_a);
let timer_b = wheel.insert(10, waker_b);
let timer_c = wheel.insert(10, waker_c);
wheel.cancel(timer_b);
let expired = wheel.advance_to(15);
insta::assert_json_snapshot!(
"wheel_tick_scrubbed",
json!({
"before": {
"inserted": [
{"timer": scrub_timer_id(timer_a.timer_id()), "deadline": timer_a.deadline()},
{"timer": scrub_timer_id(timer_b.timer_id()), "deadline": timer_b.deadline()},
{"timer": scrub_timer_id(timer_c.timer_id()), "deadline": timer_c.deadline()},
],
"cancelled": scrub_timer_id(timer_b.timer_id()),
},
"after": {
"current_tick": wheel.current_tick(),
"next_deadline": wheel.next_deadline(),
"pending_len": wheel.len(),
"expired": expired.into_iter().map(|timer| json!({
"timer": scrub_timer_id(timer.timer_id),
"deadline": timer.deadline,
})).collect::<Vec<_>>(),
}
})
);
}
}