#![allow(unsafe_code)]
#![allow(unsafe_op_in_unsafe_fn)]
use std::cell::Cell;
use std::marker::PhantomPinned;
use std::ptr::NonNull;
use std::task::Waker;
use std::time::{Duration, Instant};
pub struct TimerNode {
next: Cell<Option<NonNull<Self>>>,
prev: Cell<Option<NonNull<Self>>>,
waker: Cell<Option<Waker>>,
slot: Cell<usize>,
level: Cell<u8>,
deadline: Cell<Instant>,
linked: Cell<bool>,
_pinned: PhantomPinned,
}
impl std::fmt::Debug for TimerNode {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TimerNode")
.field("next", &self.next.get().map(std::ptr::NonNull::as_ptr))
.field("prev", &self.prev.get().map(std::ptr::NonNull::as_ptr))
.field("waker", &"<waker>")
.field("slot", &self.slot.get())
.field("level", &self.level.get())
.field("deadline", &self.deadline.get())
.field("linked", &self.linked.get())
.finish_non_exhaustive()
}
}
impl Drop for TimerNode {
fn drop(&mut self) {
if self.is_linked() {
if std::thread::panicking() {
return;
}
panic!(
"TimerNode dropped while still linked in TimerWheel! This is a severe safety violation and use-after-free bug."
);
}
}
}
impl TimerNode {
#[must_use]
pub fn new() -> Self {
Self {
next: Cell::new(None),
prev: Cell::new(None),
waker: Cell::new(None),
slot: Cell::new(0),
level: Cell::new(0),
deadline: Cell::new(Instant::now()),
linked: Cell::new(false),
_pinned: PhantomPinned,
}
}
#[must_use]
pub fn is_linked(&self) -> bool {
self.linked.get()
}
#[must_use]
pub fn deadline(&self) -> Instant {
self.deadline.get()
}
#[must_use]
pub fn level(&self) -> u8 {
self.level.get()
}
fn take_waker(&self) -> Option<Waker> {
self.waker.take()
}
fn set(&self, deadline: Instant, waker: Waker, slot: usize, level: u8) {
self.deadline.set(deadline);
self.waker.set(Some(waker));
self.slot.set(slot);
self.level.set(level);
}
fn update_slot_level(&self, slot: usize, level: u8) {
self.slot.set(slot);
self.level.set(level);
}
}
impl Default for TimerNode {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
struct TimerSlot {
head: Cell<Option<NonNull<TimerNode>>>,
tail: Cell<Option<NonNull<TimerNode>>>,
count: Cell<usize>,
}
impl TimerSlot {
const fn new() -> Self {
Self {
head: Cell::new(None),
tail: Cell::new(None),
count: Cell::new(0),
}
}
unsafe fn push_back(&self, node: NonNull<TimerNode>) {
let node_ref = node.as_ref();
node_ref.next.set(None);
node_ref.prev.set(self.tail.get());
node_ref.linked.set(true);
if let Some(tail) = self.tail.get() {
tail.as_ref().next.set(Some(node));
} else {
self.head.set(Some(node));
}
self.tail.set(Some(node));
self.count.set(self.count.get() + 1);
}
unsafe fn remove(&self, node: NonNull<TimerNode>) {
let node_ref = node.as_ref();
if !node_ref.linked.get() {
return;
}
let prev = node_ref.prev.get();
let next = node_ref.next.get();
match prev {
Some(prev_ptr) => prev_ptr.as_ref().next.set(next),
None => self.head.set(next),
}
match next {
Some(next_ptr) => next_ptr.as_ref().prev.set(prev),
None => self.tail.set(prev),
}
node_ref.prev.set(None);
node_ref.next.set(None);
node_ref.linked.set(false);
self.count.set(self.count.get().saturating_sub(1));
}
unsafe fn pop_front(&self) -> Option<NonNull<TimerNode>> {
let head = self.head.get()?;
self.remove(head);
Some(head)
}
unsafe fn drain(&self) -> Vec<Waker> {
let mut wakers = Vec::with_capacity(self.count.get());
while let Some(node) = self.pop_front() {
if let Some(waker) = node.as_ref().take_waker() {
wakers.push(waker);
}
}
wakers
}
unsafe fn take_all(&self) -> Option<NonNull<TimerNode>> {
let head = self.head.get();
self.head.set(None);
self.tail.set(None);
self.count.set(0);
head
}
unsafe fn collect_expired(&self, now: Instant) -> (Vec<Waker>, usize) {
let mut wakers = Vec::new();
let mut removed_count = 0;
let mut current = self.head.get();
while let Some(node_ptr) = current {
let node_ref = node_ptr.as_ref();
let next = node_ref.next.get();
if node_ref.deadline() <= now {
self.remove(node_ptr);
if let Some(waker) = node_ref.take_waker() {
wakers.push(waker);
}
removed_count += 1;
}
current = next;
}
(wakers, removed_count)
}
}
impl Default for TimerSlot {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub struct TimerWheel<const SLOTS: usize> {
slots: [TimerSlot; SLOTS],
current: usize,
current_tick: u64,
resolution: Duration,
count: usize,
base_time: Instant,
}
impl<const SLOTS: usize> TimerWheel<SLOTS> {
#[must_use]
pub fn new(resolution: Duration) -> Self {
Self::new_at(resolution, Instant::now())
}
#[must_use]
pub fn new_at(resolution: Duration, base_time: Instant) -> Self {
Self {
slots: std::array::from_fn(|_| TimerSlot::new()),
current: 0,
current_tick: 0,
resolution,
count: 0,
base_time,
}
}
#[must_use]
pub fn resolution(&self) -> Duration {
self.resolution
}
#[must_use]
pub fn len(&self) -> usize {
self.count
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.count == 0
}
#[must_use]
pub fn current_slot(&self) -> usize {
self.current
}
#[must_use]
fn current_time(&self) -> Instant {
self.base_time
+ Duration::from_nanos(
self.current_tick
.saturating_mul(duration_to_ns(self.resolution)),
)
}
fn slot_for(&self, deadline: Instant) -> usize {
self.slot_for_with_min_tick(deadline, self.current_tick)
}
fn slot_for_with_min_tick(&self, deadline: Instant, min_tick: u64) -> usize {
let elapsed = deadline.saturating_duration_since(self.base_time);
let ticks = elapsed.as_nanos() / self.resolution.as_nanos().max(1);
let safe_ticks = ticks.max(u128::from(min_tick));
(safe_ticks % (SLOTS as u128)) as usize
}
pub unsafe fn insert(
&mut self,
mut node: std::pin::Pin<&mut TimerNode>,
deadline: Instant,
waker: Waker,
) {
assert!(
!node.is_linked(),
"attempted to insert already-linked timer node"
);
let slot = self.slot_for(deadline);
node.as_mut()
.get_unchecked_mut()
.set(deadline, waker, slot, 0);
let node_ptr = NonNull::from(node.as_mut().get_unchecked_mut());
self.slots[slot].push_back(node_ptr);
self.count += 1;
}
#[allow(clippy::needless_pass_by_value)]
pub unsafe fn cancel(&mut self, node: std::pin::Pin<&mut TimerNode>) {
if !node.is_linked() {
return;
}
let slot = node.slot.get();
let node_ptr = NonNull::from(&*node);
self.slots[slot].remove(node_ptr);
let _ = node.as_ref().take_waker();
self.count = self.count.saturating_sub(1);
}
pub unsafe fn tick(&mut self, now: Instant) -> Vec<Waker> {
let wakers = self.drain_retired_slot(self.current, now);
self.current = (self.current + 1) % SLOTS;
self.current_tick = self.current_tick.saturating_add(1);
wakers
}
pub unsafe fn advance_to(&mut self, now: Instant) -> Vec<Waker> {
let mut all_wakers = Vec::with_capacity(self.count);
let elapsed = now.saturating_duration_since(self.base_time);
let target_tick = elapsed.as_nanos() / self.resolution.as_nanos().max(1);
let target_tick_u64 = target_tick.min(u128::from(u64::MAX)) as u64;
if target_tick_u64 <= self.current_tick {
let (wakers, removed) = self.slots[self.current].collect_expired(now);
self.count = self.count.saturating_sub(removed);
return wakers;
}
let ticks_to_advance = target_tick_u64 - self.current_tick;
if ticks_to_advance >= SLOTS as u64 {
let min_tick = target_tick_u64.saturating_add(1);
for slot_idx in 0..SLOTS {
let wakers = self.drain_slot_with_min_tick(slot_idx, now, min_tick);
all_wakers.extend(wakers);
}
self.current = ((target_tick_u64 + 1) % (SLOTS as u64)) as usize;
} else {
let target_slot = (target_tick_u64 % (SLOTS as u64)) as usize;
let min_tick = target_tick_u64.saturating_add(1);
while self.current != target_slot {
let wakers = self.drain_slot_with_min_tick(self.current, now, min_tick);
all_wakers.extend(wakers);
self.current = (self.current + 1) % SLOTS;
}
let wakers = self.drain_slot_with_min_tick(self.current, now, min_tick);
all_wakers.extend(wakers);
self.current = (self.current + 1) % SLOTS;
}
self.current_tick = target_tick_u64 + 1;
all_wakers
}
#[must_use]
pub fn next_expiration(&self) -> Option<Duration> {
if self.is_empty() {
return None;
}
let now = self.current_time();
let mut min_deadline: Option<Instant> = None;
for slot in &self.slots {
let mut current = slot.head.get();
while let Some(node_ptr) = current {
let node_ref = unsafe { node_ptr.as_ref() };
let deadline = node_ref.deadline();
match min_deadline {
None => min_deadline = Some(deadline),
Some(min) if deadline < min => min_deadline = Some(deadline),
_ => {}
}
current = node_ref.next.get();
}
}
min_deadline.map(|deadline| deadline.saturating_duration_since(now))
}
pub unsafe fn clear(&mut self) {
for slot in &self.slots {
let _ = slot.drain();
}
self.count = 0;
}
unsafe fn drain_retired_slot(&mut self, slot_idx: usize, now: Instant) -> Vec<Waker> {
self.drain_slot_with_min_tick(slot_idx, now, self.current_tick.saturating_add(1))
}
unsafe fn drain_slot_with_min_tick(
&mut self,
slot_idx: usize,
now: Instant,
min_tick: u64,
) -> Vec<Waker> {
let mut wakers = Vec::new();
let mut current = self.slots[slot_idx].take_all();
while let Some(node_ptr) = current {
let node_ref = node_ptr.as_ref();
let next = node_ref.next.get();
node_ref.linked.set(false);
node_ref.prev.set(None);
node_ref.next.set(None);
if node_ref.deadline() <= now {
if let Some(waker) = node_ref.take_waker() {
wakers.push(waker);
}
self.count = self.count.saturating_sub(1);
} else {
let new_slot = self.slot_for_with_min_tick(node_ref.deadline(), min_tick);
node_ref.update_slot_level(new_slot, 0);
self.slots[new_slot].push_back(node_ptr);
}
current = next;
}
wakers
}
}
impl<const SLOTS: usize> Drop for TimerWheel<SLOTS> {
fn drop(&mut self) {
unsafe {
self.clear();
}
}
}
#[derive(Debug)]
pub struct HierarchicalTimerWheel {
level0: WheelLevel<LEVEL0_SLOTS>,
level1: WheelLevel<LEVEL1_SLOTS>,
level2: WheelLevel<LEVEL2_SLOTS>,
level3: WheelLevel<LEVEL3_SLOTS>,
base_time: Instant,
current_tick: u64,
count: usize,
}
const LEVEL0_SLOTS: usize = 256;
const LEVEL1_SLOTS: usize = 64;
const LEVEL2_SLOTS: usize = 64;
const LEVEL3_SLOTS: usize = 128;
const DEFAULT_LEVEL0_RESOLUTION: Duration = Duration::from_millis(1);
#[derive(Debug)]
struct WheelLevel<const SLOTS: usize> {
slots: [TimerSlot; SLOTS],
cursor: usize,
resolution_ns: u64,
}
impl<const SLOTS: usize> WheelLevel<SLOTS> {
fn new(resolution_ns: u64, cursor: usize) -> Self {
Self {
slots: std::array::from_fn(|_| TimerSlot::new()),
cursor,
resolution_ns,
}
}
}
impl HierarchicalTimerWheel {
#[must_use]
pub fn new() -> Self {
Self::new_at(DEFAULT_LEVEL0_RESOLUTION, Instant::now())
}
#[must_use]
pub fn new_at(level0_resolution: Duration, base_time: Instant) -> Self {
let level0_res_ns = duration_to_ns(level0_resolution);
let level1_res_ns = level0_res_ns.saturating_mul(LEVEL0_SLOTS as u64);
let level2_res_ns = level1_res_ns.saturating_mul(LEVEL1_SLOTS as u64);
let level3_res_ns = level2_res_ns.saturating_mul(LEVEL2_SLOTS as u64);
Self {
level0: WheelLevel::new(level0_res_ns, 0),
level1: WheelLevel::new(level1_res_ns, 0),
level2: WheelLevel::new(level2_res_ns, 0),
level3: WheelLevel::new(level3_res_ns, 0),
base_time,
current_tick: 0,
count: 0,
}
}
#[must_use]
pub fn resolution(&self) -> Duration {
Duration::from_nanos(self.level0.resolution_ns.max(1))
}
#[must_use]
pub fn len(&self) -> usize {
self.count
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.count == 0
}
#[must_use]
pub fn current_time(&self) -> Instant {
self.base_time
+ Duration::from_nanos(self.current_tick.saturating_mul(self.level0.resolution_ns))
}
pub unsafe fn insert(
&mut self,
mut node: std::pin::Pin<&mut TimerNode>,
deadline: Instant,
waker: Waker,
) {
assert!(
!node.is_linked(),
"attempted to insert already-linked timer node"
);
let (level, slot) = self.slot_for(deadline);
node.as_mut()
.get_unchecked_mut()
.set(deadline, waker, slot, level);
let node_ptr = NonNull::from(node.as_mut().get_unchecked_mut());
self.push_node(level, slot, node_ptr);
self.count += 1;
}
#[allow(clippy::needless_pass_by_value)]
pub unsafe fn cancel(&mut self, node: std::pin::Pin<&mut TimerNode>) {
if !node.is_linked() {
return;
}
let slot = node.slot.get();
let level = node.level.get();
let node_ptr = NonNull::from(&*node);
self.remove_node(level, slot, node_ptr);
let _ = node.as_ref().take_waker();
self.count = self.count.saturating_sub(1);
}
pub unsafe fn tick(&mut self, now: Instant) -> Vec<Waker> {
let mut wakers = self.drain_level0_current_slot(now);
self.level0.cursor = (self.level0.cursor + 1) % LEVEL0_SLOTS;
self.current_tick = self.current_tick.saturating_add(1);
if self.level0.cursor == 0 {
self.cascade(1, now, &mut wakers);
}
wakers
}
pub unsafe fn advance_to(&mut self, now: Instant) -> Vec<Waker> {
let elapsed = now.saturating_duration_since(self.base_time);
let target_tick = duration_to_ns(elapsed) / self.level0.resolution_ns.max(1);
let mut wakers = Vec::with_capacity(self.count);
let ticks_to_advance = target_tick.saturating_sub(self.current_tick);
if ticks_to_advance > 65536 {
let mut remaining = Vec::with_capacity(self.count);
macro_rules! drain_level {
($level:expr) => {
for slot in &mut $level.slots {
let mut current_node = slot.take_all();
while let Some(node_ptr) = current_node {
let node_ref = node_ptr.as_ref();
let next = node_ref.next.get();
node_ref.linked.set(false);
node_ref.prev.set(None);
node_ref.next.set(None);
if node_ref.deadline() <= now {
if let Some(w) = node_ref.take_waker() {
wakers.push(w);
}
} else {
remaining.push(node_ptr);
}
current_node = next;
}
}
};
}
drain_level!(self.level0);
drain_level!(self.level1);
drain_level!(self.level2);
drain_level!(self.level3);
let next_tick = target_tick + 1;
self.current_tick = next_tick;
self.level0.cursor = (next_tick % LEVEL0_SLOTS as u64) as usize;
self.level1.cursor = ((next_tick / LEVEL0_SLOTS as u64) % LEVEL1_SLOTS as u64) as usize;
self.level2.cursor =
((next_tick / (LEVEL0_SLOTS * LEVEL1_SLOTS) as u64) % LEVEL2_SLOTS as u64) as usize;
self.level3.cursor = ((next_tick / (LEVEL0_SLOTS * LEVEL1_SLOTS * LEVEL2_SLOTS) as u64)
% LEVEL3_SLOTS as u64) as usize;
self.count = 0;
for node in remaining {
let node_ref = node.as_ref();
let (new_level, new_slot) = self.slot_for(node_ref.deadline());
node_ref.update_slot_level(new_slot, new_level);
self.push_node(new_level, new_slot, node);
self.count += 1;
}
return wakers;
}
if ticks_to_advance == 0 {
let (mut current, removed) = self.level0.slots[self.level0.cursor].collect_expired(now);
self.count = self.count.saturating_sub(removed);
wakers.append(&mut current);
return wakers;
}
while self.current_tick < target_tick {
if self.is_empty() {
self.current_tick = target_tick;
self.level0.cursor = (target_tick % LEVEL0_SLOTS as u64) as usize;
self.level1.cursor =
((target_tick / LEVEL0_SLOTS as u64) % LEVEL1_SLOTS as u64) as usize;
self.level2.cursor = ((target_tick / (LEVEL0_SLOTS * LEVEL1_SLOTS) as u64)
% LEVEL2_SLOTS as u64) as usize;
self.level3.cursor = ((target_tick
/ (LEVEL0_SLOTS * LEVEL1_SLOTS * LEVEL2_SLOTS) as u64)
% LEVEL3_SLOTS as u64) as usize;
break;
}
let mut tick_wakers = self.tick(now);
wakers.append(&mut tick_wakers);
}
let mut tick_wakers = self.tick(now);
wakers.append(&mut tick_wakers);
wakers
}
#[must_use]
pub fn next_expiration(&self) -> Option<Duration> {
if self.is_empty() {
return None;
}
let now = self.current_time();
self.min_deadline()
.map(|deadline| deadline.saturating_duration_since(now))
}
pub unsafe fn clear(&mut self) {
let _ = self.level0.clear_slots();
let _ = self.level1.clear_slots();
let _ = self.level2.clear_slots();
let _ = self.level3.clear_slots();
self.count = 0;
}
unsafe fn drain_level0_current_slot(&mut self, now: Instant) -> Vec<Waker> {
let slot_idx = self.level0.cursor;
let mut wakers = Vec::new();
let mut current_node = self.level0.slots[slot_idx].take_all();
while let Some(node_ptr) = current_node {
let node_ref = node_ptr.as_ref();
let next = node_ref.next.get();
node_ref.linked.set(false);
node_ref.prev.set(None);
node_ref.next.set(None);
if node_ref.deadline() <= now {
if let Some(waker) = node_ref.take_waker() {
wakers.push(waker);
}
self.count = self.count.saturating_sub(1);
} else {
let (new_level, new_slot) = self
.slot_for_from_tick(node_ref.deadline(), self.current_tick.saturating_add(1));
node_ref.update_slot_level(new_slot, new_level);
self.push_node(new_level, new_slot, node_ptr);
}
current_node = next;
}
wakers
}
fn slot_for(&self, deadline: Instant) -> (u8, usize) {
self.slot_for_from_tick(deadline, self.current_tick)
}
fn slot_for_from_tick(&self, deadline: Instant, min_level0_tick: u64) -> (u8, usize) {
let current = self.base_time
+ Duration::from_nanos(min_level0_tick.saturating_mul(self.level0.resolution_ns));
let delta_ns = duration_to_ns(deadline.saturating_duration_since(current));
let ticks_until = delta_ns / self.level0.resolution_ns.max(1);
if ticks_until < LEVEL0_SLOTS as u64 {
(
0,
self.slot_for_level_from_tick(
deadline,
&self.level0,
LEVEL0_SLOTS,
min_level0_tick,
),
)
} else if ticks_until < (LEVEL0_SLOTS * LEVEL1_SLOTS) as u64 {
(
1,
self.slot_for_level_from_tick(
deadline,
&self.level1,
LEVEL1_SLOTS,
min_level0_tick,
),
)
} else if ticks_until < (LEVEL0_SLOTS * LEVEL1_SLOTS * LEVEL2_SLOTS) as u64 {
(
2,
self.slot_for_level_from_tick(
deadline,
&self.level2,
LEVEL2_SLOTS,
min_level0_tick,
),
)
} else {
(
3,
self.slot_for_level_from_tick(
deadline,
&self.level3,
LEVEL3_SLOTS,
min_level0_tick,
),
)
}
}
fn slot_for_level_from_tick<const SLOTS: usize>(
&self,
deadline: Instant,
level: &WheelLevel<SLOTS>,
slots: usize,
min_level0_tick: u64,
) -> usize {
let elapsed_ns = duration_to_ns(deadline.saturating_duration_since(self.base_time));
let tick = elapsed_ns / level.resolution_ns.max(1);
let current_elapsed_ns = min_level0_tick.saturating_mul(self.level0.resolution_ns);
let current_level_tick = current_elapsed_ns / level.resolution_ns.max(1);
let safe_tick = tick.max(current_level_tick);
(safe_tick % slots as u64) as usize
}
fn push_node(&self, level: u8, slot: usize, node: NonNull<TimerNode>) {
match level {
0 => unsafe { self.level0.slots[slot].push_back(node) },
1 => unsafe { self.level1.slots[slot].push_back(node) },
2 => unsafe { self.level2.slots[slot].push_back(node) },
_ => unsafe { self.level3.slots[slot].push_back(node) },
}
}
fn remove_node(&self, level: u8, slot: usize, node: NonNull<TimerNode>) {
match level {
0 => unsafe { self.level0.slots[slot].remove(node) },
1 => unsafe { self.level1.slots[slot].remove(node) },
2 => unsafe { self.level2.slots[slot].remove(node) },
_ => unsafe { self.level3.slots[slot].remove(node) },
}
}
#[allow(clippy::only_used_in_recursion)]
fn cascade(&mut self, level_index: u8, now: Instant, wakers: &mut Vec<Waker>) {
let (mut current_node, wrapped) = match level_index {
1 => self.level1.advance_and_take(),
2 => self.level2.advance_and_take(),
3 => self.level3.advance_and_take(),
_ => return,
};
while let Some(node_ptr) = current_node {
let node_ref = unsafe { node_ptr.as_ref() };
let next = node_ref.next.get();
node_ref.linked.set(false);
node_ref.prev.set(None);
node_ref.next.set(None);
if node_ref.deadline() <= now {
if let Some(waker) = node_ref.take_waker() {
wakers.push(waker);
}
self.count = self.count.saturating_sub(1);
} else {
let (new_level, new_slot) =
self.slot_for_from_tick(node_ref.deadline(), self.current_tick);
node_ref.update_slot_level(new_slot, new_level);
self.push_node(new_level, new_slot, node_ptr);
}
current_node = next;
}
if wrapped {
self.cascade(level_index + 1, now, wakers);
}
}
fn min_deadline(&self) -> Option<Instant> {
let mut min_deadline: Option<Instant> = None;
for deadline in self.iter_deadlines() {
min_deadline = Some(min_deadline.map_or(deadline, |current| current.min(deadline)));
}
min_deadline
}
fn iter_deadlines(&self) -> impl Iterator<Item = Instant> + '_ {
self.level0
.iter_deadlines()
.chain(self.level1.iter_deadlines())
.chain(self.level2.iter_deadlines())
.chain(self.level3.iter_deadlines())
}
}
impl Default for HierarchicalTimerWheel {
fn default() -> Self {
Self::new()
}
}
impl Drop for HierarchicalTimerWheel {
fn drop(&mut self) {
unsafe {
self.clear();
}
}
}
impl<const SLOTS: usize> WheelLevel<SLOTS> {
fn iter_deadlines(&self) -> impl Iterator<Item = Instant> + '_ {
self.slots.iter().flat_map(TimerSlot::iter_deadlines)
}
unsafe fn clear_slots(&mut self) -> Vec<Waker> {
let mut wakers = Vec::new();
for slot in &self.slots {
wakers.extend(slot.drain());
}
wakers
}
fn advance_and_take(&mut self) -> (Option<NonNull<TimerNode>>, bool) {
self.cursor = (self.cursor + 1) % SLOTS;
let wrapped = self.cursor == 0;
let head = unsafe { self.slots[self.cursor].take_all() };
(head, wrapped)
}
}
impl TimerSlot {
fn iter_deadlines(&self) -> impl Iterator<Item = Instant> + '_ {
TimerSlotIter::new(self.head.get())
}
}
struct TimerSlotIter {
current: Option<NonNull<TimerNode>>,
}
impl TimerSlotIter {
fn new(current: Option<NonNull<TimerNode>>) -> Self {
Self { current }
}
}
impl Iterator for TimerSlotIter {
type Item = Instant;
fn next(&mut self) -> Option<Self::Item> {
let node = self.current?;
let node_ref = unsafe { node.as_ref() };
self.current = node_ref.next.get();
Some(node_ref.deadline())
}
}
fn duration_to_ns(duration: Duration) -> u64 {
duration.as_nanos().min(u128::from(u64::MAX)) as u64
}
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::task::Wake;
fn init_test(name: &str) {
crate::test_utils::init_test_logging();
crate::test_phase!(name);
}
struct CounterWaker {
counter: Arc<AtomicU64>,
}
impl Wake for CounterWaker {
fn wake(self: Arc<Self>) {
self.counter.fetch_add(1, Ordering::SeqCst);
}
fn wake_by_ref(self: &Arc<Self>) {
self.counter.fetch_add(1, Ordering::SeqCst);
}
}
fn counter_waker(counter: Arc<AtomicU64>) -> Waker {
Arc::new(CounterWaker { counter }).into()
}
fn intrusive_wheel_signature<const SLOTS: usize>(
entries: &[(u16, bool)],
insertion_order: &[usize],
cancel_after_insert: bool,
) -> (Option<Duration>, usize, usize, bool) {
let base = Instant::now();
let mut wheel: TimerWheel<SLOTS> = TimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut nodes: Vec<Pin<Box<TimerNode>>> = (0..entries.len())
.map(|_| Box::pin(TimerNode::new()))
.collect();
for &index in insertion_order {
let (offset_ms, cancelled) = entries[index];
if !cancel_after_insert && cancelled {
continue;
}
let deadline = base + Duration::from_millis(u64::from(offset_ms));
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(nodes[index].as_mut(), deadline, waker);
}
}
if cancel_after_insert {
for (index, (_, cancelled)) in entries.iter().enumerate() {
if *cancelled {
unsafe {
wheel.cancel(nodes[index].as_mut());
}
}
}
}
let next = wheel.next_expiration();
let max_offset_ms = entries
.iter()
.filter_map(|(offset_ms, cancelled)| (!cancelled).then_some(*offset_ms))
.max()
.unwrap_or(0);
let advance_target = base + Duration::from_millis(u64::from(max_offset_ms) + 2);
let wakers = unsafe { wheel.advance_to(advance_target) };
let wake_count = wakers.len();
for waker in wakers {
waker.wake();
}
(
next,
wake_count,
nodes.iter().filter(|node| node.is_linked()).count(),
wheel.is_empty(),
)
}
#[test]
fn intrusive_wheel_new() {
init_test("intrusive_wheel_new");
let wheel: TimerWheel<256> = TimerWheel::new(Duration::from_millis(1));
crate::assert_with_log!(
wheel.is_empty(),
"wheel starts empty",
true,
wheel.is_empty()
);
crate::assert_with_log!(wheel.is_empty(), "len is 0", 0, wheel.len());
crate::assert_with_log!(
wheel.resolution() == Duration::from_millis(1),
"resolution",
Duration::from_millis(1),
wheel.resolution()
);
crate::test_complete!("intrusive_wheel_new");
}
#[test]
fn intrusive_wheel_insert_and_tick() {
init_test("intrusive_wheel_insert_and_tick");
let base = Instant::now();
let mut wheel: TimerWheel<256> = TimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_millis(5);
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
crate::assert_with_log!(wheel.len() == 1, "len is 1", 1, wheel.len());
crate::assert_with_log!(node.is_linked(), "node is linked", true, node.is_linked());
std::thread::sleep(Duration::from_millis(10));
let wakers = unsafe { wheel.advance_to(Instant::now()) };
crate::assert_with_log!(wakers.len() == 1, "got 1 waker", 1, wakers.len());
for waker in wakers {
waker.wake();
}
let count = counter.load(Ordering::SeqCst);
crate::assert_with_log!(count == 1, "counter is 1", 1, count);
crate::assert_with_log!(wheel.is_empty(), "wheel is empty", true, wheel.is_empty());
crate::test_complete!("intrusive_wheel_insert_and_tick");
}
#[test]
fn intrusive_wheel_cancel() {
init_test("intrusive_wheel_cancel");
let base = Instant::now();
let mut wheel: TimerWheel<256> = TimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_millis(50);
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
crate::assert_with_log!(wheel.len() == 1, "len is 1", 1, wheel.len());
unsafe {
wheel.cancel(node.as_mut());
}
crate::assert_with_log!(!node.is_linked(), "node unlinked", false, node.is_linked());
crate::assert_with_log!(wheel.is_empty(), "wheel is empty", true, wheel.is_empty());
std::thread::sleep(Duration::from_millis(60));
let wakers = unsafe { wheel.advance_to(Instant::now()) };
crate::assert_with_log!(wakers.is_empty(), "no wakers", true, wakers.is_empty());
let count = counter.load(Ordering::SeqCst);
crate::assert_with_log!(count == 0, "counter is 0", 0, count);
crate::test_complete!("intrusive_wheel_cancel");
}
#[test]
fn intrusive_wheel_multiple_timers() {
init_test("intrusive_wheel_multiple_timers");
let base = Instant::now();
let mut wheel: TimerWheel<256> = TimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut nodes: Vec<Pin<Box<TimerNode>>> =
(0..5).map(|_| Box::pin(TimerNode::new())).collect();
for (i, node) in nodes.iter_mut().enumerate() {
let deadline = base + Duration::from_millis((i as u64 + 1) * 10);
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
}
crate::assert_with_log!(wheel.len() == 5, "len is 5", 5, wheel.len());
std::thread::sleep(Duration::from_millis(60));
let wakers = unsafe { wheel.advance_to(Instant::now()) };
crate::assert_with_log!(wakers.len() == 5, "got 5 wakers", 5, wakers.len());
for waker in wakers {
waker.wake();
}
let count = counter.load(Ordering::SeqCst);
crate::assert_with_log!(count == 5, "counter is 5", 5, count);
crate::test_complete!("intrusive_wheel_multiple_timers");
}
#[test]
fn intrusive_wheel_wrap_around() {
init_test("intrusive_wheel_wrap_around");
let base = Instant::now();
let mut wheel: TimerWheel<4> = TimerWheel::new_at(Duration::from_millis(10), base);
let counter = Arc::new(AtomicU64::new(0));
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_millis(50);
let waker = counter_waker(counter);
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
let slot = wheel.slot_for(deadline);
crate::assert_with_log!(slot == 1, "slot wraps to 1", 1, slot);
std::thread::sleep(Duration::from_millis(60));
let wakers = unsafe { wheel.advance_to(Instant::now()) };
crate::assert_with_log!(wakers.len() == 1, "fired", 1, wakers.len());
crate::test_complete!("intrusive_wheel_wrap_around");
}
#[test]
fn intrusive_wheel_clear() {
init_test("intrusive_wheel_clear");
let base = Instant::now();
let mut wheel: TimerWheel<256> = TimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut nodes: Vec<Pin<Box<TimerNode>>> =
(0..3).map(|_| Box::pin(TimerNode::new())).collect();
for (i, node) in nodes.iter_mut().enumerate() {
let deadline = base + Duration::from_millis((i as u64 + 1) * 10);
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
}
crate::assert_with_log!(wheel.len() == 3, "len is 3", 3, wheel.len());
unsafe {
wheel.clear();
}
crate::assert_with_log!(wheel.is_empty(), "wheel empty", true, wheel.is_empty());
let count = counter.load(Ordering::SeqCst);
crate::assert_with_log!(count == 0, "counter still 0", 0, count);
crate::test_complete!("intrusive_wheel_clear");
}
#[test]
fn timer_node_default() {
init_test("timer_node_default");
let node = TimerNode::default();
crate::assert_with_log!(!node.is_linked(), "not linked", false, node.is_linked());
crate::test_complete!("timer_node_default");
}
#[test]
fn intrusive_wheel_next_expiration() {
init_test("intrusive_wheel_next_expiration");
let base = Instant::now();
let mut wheel: TimerWheel<256> = TimerWheel::new_at(Duration::from_millis(1), base);
let empty = wheel.next_expiration();
crate::assert_with_log!(empty.is_none(), "empty wheel", true, empty.is_none());
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_millis(100);
let waker = Arc::new(CounterWaker {
counter: Arc::new(AtomicU64::new(0)),
})
.into();
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
let next = wheel.next_expiration();
crate::assert_with_log!(next.is_some(), "has expiration", true, next.is_some());
unsafe {
wheel.cancel(node.as_mut());
}
crate::test_complete!("intrusive_wheel_next_expiration");
}
#[test]
fn intrusive_wheel_next_expiration_uses_wheel_time_not_wall_clock() {
init_test("intrusive_wheel_next_expiration_uses_wheel_time_not_wall_clock");
let base = Instant::now()
.checked_sub(Duration::from_secs(1))
.unwrap_or_else(Instant::now);
let mut wheel: TimerWheel<256> = TimerWheel::new_at(Duration::from_millis(1), base);
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_millis(100);
let waker = Arc::new(CounterWaker {
counter: Arc::new(AtomicU64::new(0)),
})
.into();
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
let next = wheel.next_expiration();
crate::assert_with_log!(
next == Some(Duration::from_millis(100)),
"next expiration is relative to wheel progress, not ambient wall clock",
Some(Duration::from_millis(100)),
next
);
unsafe {
wheel.cancel(node.as_mut());
}
crate::test_complete!("intrusive_wheel_next_expiration_uses_wheel_time_not_wall_clock");
}
#[test]
fn intrusive_wheel_rebuckets_nonexpired_timer_from_retired_slot() {
init_test("intrusive_wheel_rebuckets_nonexpired_timer_from_retired_slot");
let base = Instant::now();
let mut wheel: TimerWheel<256> = TimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_micros(5_500);
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
let early = unsafe { wheel.advance_to(base + Duration::from_millis(5)) };
crate::assert_with_log!(
early.is_empty(),
"timer later in the bucket must not fire at the bucket boundary",
true,
early.is_empty()
);
crate::assert_with_log!(
node.is_linked(),
"timer stays scheduled",
true,
node.is_linked()
);
let due = unsafe { wheel.advance_to(base + Duration::from_millis(6)) };
crate::assert_with_log!(due.len() == 1, "timer fires on next tick", 1, due.len());
for waker in due {
waker.wake();
}
let count = counter.load(Ordering::SeqCst);
crate::assert_with_log!(count == 1, "waker fired once", 1, count);
crate::test_complete!("intrusive_wheel_rebuckets_nonexpired_timer_from_retired_slot");
}
#[test]
fn intrusive_wheel_rebuckets_survivor_after_full_rotation_advance() {
init_test("intrusive_wheel_rebuckets_survivor_after_full_rotation_advance");
let base = Instant::now();
let mut wheel: TimerWheel<4> = TimerWheel::new_at(Duration::from_millis(10), base);
let counter = Arc::new(AtomicU64::new(0));
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_millis(45);
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
let early = unsafe { wheel.advance_to(base + Duration::from_millis(40)) };
crate::assert_with_log!(
early.is_empty(),
"full-rotation advance must not fire future timer early",
true,
early.is_empty()
);
crate::assert_with_log!(
node.is_linked(),
"future timer stays scheduled after full rotation",
true,
node.is_linked()
);
let due = unsafe { wheel.advance_to(base + Duration::from_millis(50)) };
crate::assert_with_log!(
due.len() == 1,
"future timer fires after rebucketing from full rotation",
1,
due.len()
);
for waker in due {
waker.wake();
}
let count = counter.load(Ordering::SeqCst);
crate::assert_with_log!(count == 1, "waker fired once", 1, count);
crate::test_complete!("intrusive_wheel_rebuckets_survivor_after_full_rotation_advance");
}
#[test]
fn hierarchical_cascade_fires_expired() {
init_test("hierarchical_cascade_fires_expired");
let base = Instant::now()
.checked_sub(Duration::from_secs(1))
.unwrap_or_else(Instant::now);
let mut wheel = HierarchicalTimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_millis(300);
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
let (level, slot) = wheel.slot_for(deadline);
crate::assert_with_log!(level == 1, "timer placed in level1", 1u8, level);
let mut wakers = Vec::new();
for _ in 0..(LEVEL0_SLOTS * (slot + 1)) {
let mut tick_wakers = unsafe { wheel.tick(Instant::now()) };
wakers.append(&mut tick_wakers);
}
for waker in wakers {
waker.wake();
}
let count = counter.load(Ordering::SeqCst);
crate::assert_with_log!(count == 1, "expired fired", 1, count);
crate::assert_with_log!(wheel.is_empty(), "wheel empty", true, wheel.is_empty());
crate::test_complete!("hierarchical_cascade_fires_expired");
}
#[test]
fn hierarchical_next_expiration_uses_wheel_time_not_wall_clock() {
init_test("hierarchical_next_expiration_uses_wheel_time_not_wall_clock");
let base = Instant::now()
.checked_sub(Duration::from_secs(1))
.unwrap_or_else(Instant::now);
let mut wheel = HierarchicalTimerWheel::new_at(Duration::from_millis(1), base);
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_millis(100);
let waker = Arc::new(CounterWaker {
counter: Arc::new(AtomicU64::new(0)),
})
.into();
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
let next = wheel.next_expiration();
crate::assert_with_log!(
next == Some(Duration::from_millis(100)),
"hierarchical next expiration is relative to wheel progress, not ambient wall clock",
Some(Duration::from_millis(100)),
next
);
unsafe {
wheel.cancel(node.as_mut());
}
crate::test_complete!("hierarchical_next_expiration_uses_wheel_time_not_wall_clock");
}
#[test]
fn hierarchical_wheel_rebuckets_nonexpired_level0_timer_from_retired_slot() {
init_test("hierarchical_wheel_rebuckets_nonexpired_level0_timer_from_retired_slot");
let base = Instant::now();
let mut wheel = HierarchicalTimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_micros(5_500);
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
let early = unsafe { wheel.advance_to(base + Duration::from_millis(5)) };
crate::assert_with_log!(
early.is_empty(),
"hierarchical wheel must not fire later-in-bucket timers early",
true,
early.is_empty()
);
crate::assert_with_log!(
node.is_linked(),
"timer stays scheduled",
true,
node.is_linked()
);
let due = unsafe { wheel.advance_to(base + Duration::from_millis(6)) };
crate::assert_with_log!(due.len() == 1, "timer fires on next tick", 1, due.len());
for waker in due {
waker.wake();
}
let count = counter.load(Ordering::SeqCst);
crate::assert_with_log!(count == 1, "waker fired once", 1, count);
crate::test_complete!(
"hierarchical_wheel_rebuckets_nonexpired_level0_timer_from_retired_slot"
);
}
#[test]
fn hierarchical_wheel_cascade_survivor_reinserts_without_extra_tick_delay() {
init_test("hierarchical_wheel_cascade_survivor_reinserts_without_extra_tick_delay");
let base = Instant::now();
let mut wheel = HierarchicalTimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut node = Box::pin(TimerNode::new());
let deadline = base + Duration::from_micros(256_500);
let waker = counter_waker(counter.clone());
unsafe {
wheel.insert(node.as_mut(), deadline, waker);
}
let before_due = unsafe { wheel.advance_to(base + Duration::from_millis(256)) };
crate::assert_with_log!(
before_due.is_empty(),
"cascade boundary must not fire timer early",
true,
before_due.is_empty()
);
crate::assert_with_log!(
node.is_linked(),
"timer remains scheduled after cascade rebucketing",
true,
node.is_linked()
);
let due = unsafe { wheel.advance_to(base + Duration::from_millis(257)) };
crate::assert_with_log!(
due.len() == 1,
"timer fires on the immediate next tick after cascade",
1,
due.len()
);
for waker in due {
waker.wake();
}
let count = counter.load(Ordering::SeqCst);
crate::assert_with_log!(count == 1, "waker fired exactly once", 1, count);
crate::assert_with_log!(
wheel.is_empty(),
"wheel is empty after firing",
true,
wheel.is_empty()
);
crate::test_complete!(
"hierarchical_wheel_cascade_survivor_reinserts_without_extra_tick_delay"
);
}
proptest! {
#[test]
fn metamorphic_intrusive_wheel_cancelled_subset_matches_filtered_rotation(
entries in prop::collection::vec((1u16..96u16, any::<bool>()), 1..12),
raw_shift in 0usize..32,
) {
let mut rotated_order: Vec<usize> = (0..entries.len()).collect();
rotated_order.rotate_left(raw_shift % entries.len());
let base_signature =
intrusive_wheel_signature::<16>(&entries, &rotated_order, true);
let filtered_signature =
intrusive_wheel_signature::<16>(&entries, &rotated_order, false);
let survivor_count = entries.iter().filter(|(_, cancelled)| !cancelled).count();
prop_assert_eq!(
base_signature.0,
filtered_signature.0,
"cancelling a subset after insertion must preserve the next expiration of surviving timers",
);
prop_assert_eq!(
base_signature.1,
survivor_count,
"exactly the uncancelled timers should fire",
);
prop_assert_eq!(
filtered_signature.1,
survivor_count,
"filtered insertion should fire the same surviving timers",
);
prop_assert_eq!(
base_signature.1,
filtered_signature.1,
"post-insert cancellation must match excluding the same timers up front",
);
prop_assert_eq!(
base_signature.2,
0,
"no timer nodes should remain linked after advancing past all survivors",
);
prop_assert_eq!(
filtered_signature.2,
0,
"filtered run must also drain all timer nodes",
);
prop_assert!(
base_signature.3 && filtered_signature.3,
"both wheels should be empty after draining surviving timers",
);
}
#[test]
fn metamorphic_intrusive_wheel_split_advance_matches_direct_frontier(
offsets in prop::collection::vec(1u16..96u16, 1..12),
raw_split_ms in 0u16..96u16,
) {
let base = Instant::now();
let mut split_wheel: TimerWheel<16> = TimerWheel::new_at(Duration::from_millis(1), base);
let mut direct_wheel: TimerWheel<16> = TimerWheel::new_at(Duration::from_millis(1), base);
let counter = Arc::new(AtomicU64::new(0));
let mut split_nodes: Vec<Pin<Box<TimerNode>>> =
(0..offsets.len()).map(|_| Box::pin(TimerNode::new())).collect();
let mut direct_nodes: Vec<Pin<Box<TimerNode>>> =
(0..offsets.len()).map(|_| Box::pin(TimerNode::new())).collect();
for (index, offset_ms) in offsets.iter().copied().enumerate() {
let deadline = base + Duration::from_millis(u64::from(offset_ms));
unsafe {
split_wheel.insert(
split_nodes[index].as_mut(),
deadline,
counter_waker(counter.clone()),
);
direct_wheel.insert(
direct_nodes[index].as_mut(),
deadline,
counter_waker(counter.clone()),
);
}
}
let max_offset_ms = offsets.iter().copied().max().unwrap_or(0);
let split_ms = raw_split_ms.min(max_offset_ms);
let early_target = base + Duration::from_millis(u64::from(split_ms));
let late_target = base + Duration::from_millis(u64::from(max_offset_ms) + 2);
let early_wakers = unsafe { split_wheel.advance_to(early_target) };
let late_wakers = unsafe { split_wheel.advance_to(late_target) };
let direct_wakers = unsafe { direct_wheel.advance_to(late_target) };
prop_assert!(
early_wakers.len() <= direct_wakers.len(),
"an earlier frontier cannot fire more timers than the later direct frontier",
);
prop_assert_eq!(
early_wakers.len() + late_wakers.len(),
direct_wakers.len(),
"splitting the advance must preserve the total timers fired by the final frontier",
);
prop_assert!(
split_wheel.is_empty() && direct_wheel.is_empty(),
"both wheels should be empty after advancing past the latest deadline",
);
prop_assert_eq!(
split_nodes.iter().filter(|node| node.is_linked()).count(),
0,
"split advance must unlink every timer node",
);
prop_assert_eq!(
direct_nodes.iter().filter(|node| node.is_linked()).count(),
0,
"direct advance must unlink every timer node",
);
}
}
}