use std::{
cell::RefCell,
task::Waker,
time::{Duration, Instant},
};
use slab::Slab;
/// Number of hierarchical levels in the timing wheel.
const NUM_LEVELS: usize = 7;
/// Slots per level. Slot indices are extracted with the mask `SLOTS_PER_LEVEL - 1`,
/// so this must stay a power of two.
const SLOTS_PER_LEVEL: usize = 1 << 6;
/// Opaque token identifying one scheduled timer, used to cancel it.
///
/// `slab_index` locates the timer's entry; `generation` guards against index reuse.
/// Once a timer fires or is cancelled, its slab index may be handed to a newer timer.
/// The stale handle's generation then no longer matches, so cancelling with it does nothing.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, Hash)]
pub struct TimerHandle {
    /// Index of the timer's entry in the wheel's slab.
    slab_index: usize,
    /// Generation stamped on the entry when it was inserted.
    generation: u64,
}
/// Book-keeping for one pending timer, stored in the wheel's slab.
struct TimerEntry {
    /// Handed back by `advance` once the timer's deadline is reached.
    waker: Waker,
    /// Absolute tick at which the timer is due.
    expiration_tick: u64,
    /// Must equal a `TimerHandle`'s generation for a cancellation to apply.
    generation: u64,
    /// Wheel level the entry is currently filed in.
    level: usize,
    /// Slot within `level` the entry is currently filed in, so removal can go straight to it.
    slot: usize,
}
/// Hierarchical timing wheel keyed by integer ticks.
///
/// Timers are filed by the distance to their deadline. Level 0 takes delays shorter than
/// `SLOTS_PER_LEVEL` ticks, and each higher level spans `SLOTS_PER_LEVEL` times the range of
/// the level below it. The last level takes everything longer.
struct TimingWheel {
    /// Pending timers; slots and `TimerHandle`s refer to them by slab index.
    entries: Slab<TimerEntry>,
    /// `wheels[level][slot]` holds the slab indices of the timers filed in that slot.
    wheels: [Vec<Vec<usize>>; NUM_LEVELS],
    /// Tick the wheel has advanced to.
    current_tick: u64,
    /// Incremented on every insert; its new value becomes the timer's generation.
    generation_counter: u64,
    /// Cached earliest expiration tick among pending timers (`None` when empty).
    min_expiration_tick: Option<u64>,
}
impl TimingWheel {
    /// Number of tick bits covered by one level (log2 of `SLOTS_PER_LEVEL`).
    const LEVEL_BITS: u32 = SLOTS_PER_LEVEL.trailing_zeros();
    /// Mask extracting a slot index; relies on `SLOTS_PER_LEVEL` being a power of two.
    const SLOT_MASK: u64 = SLOTS_PER_LEVEL as u64 - 1;

    /// Creates an empty wheel positioned at tick 0.
    #[inline]
    pub fn new() -> Self {
        Self {
            entries: Slab::new(),
            wheels: std::array::from_fn(|_| vec![Vec::new(); SLOTS_PER_LEVEL]),
            current_tick: 0,
            generation_counter: 0,
            min_expiration_tick: None,
        }
    }

    /// Returns `true` when no timers are pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Returns the tick the wheel has advanced to.
    #[inline]
    pub fn now(&self) -> u64 {
        self.current_tick
    }

    /// Returns the absolute tick of the earliest pending timer, or `None` if none are pending.
    ///
    /// A timer due at tick 0 is reported as tick 1. `NonZeroU64` cannot represent 0, and such
    /// a timer can only fire after the wheel advances by at least one tick. Returning `None`
    /// there would wrongly tell callers that nothing is pending.
    #[inline]
    pub fn nearest_wakeup(&self) -> Option<std::num::NonZeroU64> {
        self.min_expiration_tick
            .and_then(|tick| std::num::NonZeroU64::new(tick.max(1)))
    }

    /// Lowers the cached minimum if `expiration_tick` is earlier than it.
    #[inline]
    fn update_min_on_insert(&mut self, expiration_tick: u64) {
        self.min_expiration_tick = Some(
            self.min_expiration_tick
                .map_or(expiration_tick, |min| min.min(expiration_tick)),
        );
    }

    /// Recomputes the cached minimum by scanning every pending timer (O(n)).
    fn update_min_on_remove(&mut self) {
        self.min_expiration_tick = self
            .entries
            .iter()
            .map(|(_, entry)| entry.expiration_tick)
            .min();
    }

    /// Schedules `waker` to be returned by `advance` once the wheel reaches the absolute
    /// tick `expiration_tick`.
    ///
    /// A deadline at or before the current tick is filed in the current tick's level-0 slot,
    /// so it fires on the next non-zero `advance`.
    #[inline]
    pub fn insert(&mut self, waker: Waker, expiration_tick: u64) -> TimerHandle {
        let (level, slot) = self.calculate_level_and_slot(expiration_tick);
        self.generation_counter += 1;
        let generation = self.generation_counter;
        let slab_index = self.entries.insert(TimerEntry {
            waker,
            expiration_tick,
            generation,
            level,
            slot,
        });
        self.wheels[level][slot].push(slab_index);
        self.update_min_on_insert(expiration_tick);
        TimerHandle {
            slab_index,
            generation,
        }
    }

    /// Cancels the timer identified by `handle`, dropping its waker.
    ///
    /// Does nothing if the timer already fired or was cancelled. In that case the slab slot
    /// is vacant, or it has been reused by a timer with a different generation.
    #[inline]
    pub fn remove(&mut self, handle: TimerHandle) {
        let (level, slot) = match self.entries.get(handle.slab_index) {
            Some(entry) if entry.generation == handle.generation => (entry.level, entry.slot),
            _ => return,
        };
        let bucket = &mut self.wheels[level][slot];
        if let Some(pos) = bucket.iter().position(|&idx| idx == handle.slab_index) {
            // Order within a slot carries no meaning, so O(1) swap_remove is fine.
            bucket.swap_remove(pos);
        }
        let removed = self.entries.remove(handle.slab_index);
        // The cached minimum only changes if the removed timer was the earliest one.
        if self.min_expiration_tick == Some(removed.expiration_tick) {
            self.update_min_on_remove();
        }
    }

    /// Advances the wheel by `ticks`. Returns the wakers of every timer whose expiration tick
    /// is at or before the new current tick; those timers are removed from the wheel.
    ///
    /// NOTE: every call re-files all timers above level 0, so the cost grows with the number
    /// of long-delay timers.
    pub fn advance(&mut self, ticks: u64) -> Vec<Waker> {
        let mut expired_wakers = Vec::new();
        if ticks == 0 {
            return expired_wakers;
        }
        let start_tick = self.current_tick;
        self.current_tick += ticks;

        // Fire or re-file everything above level 0 against the new current tick.
        for level in (1..NUM_LEVELS).rev() {
            for slot in 0..SLOTS_PER_LEVEL {
                self.drain_slot(level, slot, &mut expired_wakers);
            }
        }

        // Level-0 slots are indexed by absolute tick, so only the slots of ticks
        // `start_tick..=current_tick` can hold due timers. `start_tick` is included because
        // past-due timers inserted since the previous advance are parked in its slot.
        if ticks >= SLOTS_PER_LEVEL as u64 {
            for slot in 0..SLOTS_PER_LEVEL {
                self.drain_slot(0, slot, &mut expired_wakers);
            }
        } else {
            for tick in start_tick..=self.current_tick {
                self.drain_slot(0, (tick & Self::SLOT_MASK) as usize, &mut expired_wakers);
            }
        }

        if !expired_wakers.is_empty() {
            self.update_min_on_remove();
        }
        expired_wakers
    }

    /// Empties one slot.
    ///
    /// Timers whose deadline has been reached are removed and their wakers pushed onto
    /// `expired_wakers`. The rest are re-filed according to their remaining delay.
    fn drain_slot(&mut self, level: usize, slot: usize, expired_wakers: &mut Vec<Waker>) {
        for slab_index in std::mem::take(&mut self.wheels[level][slot]) {
            let expiration_tick = match self.entries.get(slab_index) {
                Some(entry) => entry.expiration_tick,
                None => continue,
            };
            if expiration_tick <= self.current_tick {
                expired_wakers.push(self.entries.remove(slab_index).waker);
            } else {
                let (new_level, new_slot) = self.calculate_level_and_slot(expiration_tick);
                let entry = &mut self.entries[slab_index];
                entry.level = new_level;
                entry.slot = new_slot;
                self.wheels[new_level][new_slot].push(slab_index);
            }
        }
    }

    /// Chooses the `(level, slot)` for a timer due at `expiration_tick`.
    ///
    /// The level is the smallest one whose span covers the remaining delay; the top level
    /// takes everything else. The slot is derived from the absolute target tick so that it
    /// agrees with how `advance` walks level-0 slots.
    fn calculate_level_and_slot(&self, expiration_tick: u64) -> (usize, usize) {
        // Past-due deadlines are filed at the current tick so the next advance drains them.
        let target = expiration_tick.max(self.current_tick);
        let delay = target - self.current_tick;
        let level = (0..NUM_LEVELS - 1)
            .find(|&level| delay < ((SLOTS_PER_LEVEL as u64) << (Self::LEVEL_BITS * level as u32)))
            .unwrap_or(NUM_LEVELS - 1);
        let slot = (target >> (Self::LEVEL_BITS * level as u32)) & Self::SLOT_MASK;
        (level, slot as usize)
    }
}
impl Default for TimingWheel {
fn default() -> Self {
Self::new()
}
}
/// Millisecond-resolution timer driver wrapping a `TimingWheel` (one tick = one millisecond).
///
/// State lives in `RefCell`s so the timer is driven through `&self`. That also makes it
/// not `Sync`, so it cannot be shared across threads.
pub struct Timer {
    /// Wheel holding the pending timers.
    wheel: RefCell<TimingWheel>,
    /// Reference instant: `submit` measures deadlines from it and
    /// `spin_and_get_deadline` measures elapsed time from it.
    instant: RefCell<Instant>,
}
impl Timer {
    /// Creates a timer with an empty wheel, anchored at the current instant.
    #[inline]
    pub fn new() -> Self {
        Self {
            wheel: RefCell::new(TimingWheel::new()),
            instant: RefCell::new(Instant::now()),
        }
    }

    /// Registers `waker` to be woken once `deadline` is reached.
    ///
    /// The deadline is converted to whole milliseconds (rounded down) after the timer's
    /// reference instant, which corresponds to the wheel's current tick. If that is under
    /// one millisecond, the waker is woken immediately and `None` is returned. Otherwise the
    /// returned handle can be passed to `cancel`.
    #[inline]
    pub fn submit(&self, deadline: Instant, waker: Waker) -> Option<TimerHandle> {
        let anchor = *self.instant.borrow();
        let millis = u64::try_from(deadline.saturating_duration_since(anchor).as_millis())
            .unwrap_or(u64::MAX);
        if millis < 1 {
            waker.wake();
            return None;
        }
        let mut wheel = self.wheel.borrow_mut();
        let now = wheel.now();
        Some(wheel.insert(waker, now.saturating_add(millis)))
    }

    /// Cancels a pending timer; does nothing if it already fired or was cancelled.
    #[inline]
    pub fn cancel(&self, handle: TimerHandle) {
        self.wheel.borrow_mut().remove(handle);
    }

    /// Advances the wheel to the present and wakes every timer that has come due.
    ///
    /// Returns the time until the earliest remaining timer (`None` if none remain), and
    /// whether any waker was woken.
    ///
    /// Wakers are woken only after both internal borrows are released, so a waker may call
    /// back into `submit` or `cancel` without triggering a `RefCell` borrow panic.
    #[inline]
    pub fn spin_and_get_deadline(&self) -> (Option<Duration>, bool) {
        let (expired, deadline) = {
            let mut instant = self.instant.borrow_mut();
            let mut wheel = self.wheel.borrow_mut();
            let expired = if wheel.is_empty() {
                // Nothing is pending, so there is nothing to drift: re-anchor to the present.
                *instant = Instant::now();
                Vec::new()
            } else {
                let elapsed_ms = u64::try_from(instant.elapsed().as_millis()).unwrap_or(u64::MAX);
                // Move the anchor by exactly the whole milliseconds consumed. The sub-ms
                // remainder then carries into the next spin instead of being discarded;
                // discarding it would stall the wheel when spins come less than 1 ms apart.
                *instant += Duration::from_millis(elapsed_ms);
                wheel.advance(elapsed_ms)
            };
            let now = wheel.now();
            let deadline = wheel
                .nearest_wakeup()
                .map(|tick| Duration::from_millis(tick.get().saturating_sub(now)));
            (expired, deadline)
        };
        let woken_up = !expired.is_empty();
        for waker in expired {
            waker.wake();
        }
        (deadline, woken_up)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::task::{Wake, Waker};
    use std::time::Duration;

    /// Waker payload that counts how many times it has been woken.
    struct CountingWaker(Arc<Mutex<u32>>);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            *self.0.lock().unwrap() += 1;
        }

        fn wake_by_ref(self: &Arc<Self>) {
            *self.0.lock().unwrap() += 1;
        }
    }

    /// Builds a waker that increments `counter` every time it is woken.
    ///
    /// Built on `std::task::Wake`, so clones and drops are reference-counted by `Arc`
    /// and no `unsafe` vtable is needed.
    fn mock_waker(counter: Arc<Mutex<u32>>) -> Waker {
        Waker::from(Arc::new(CountingWaker(counter)))
    }

    #[test]
    fn test_timer_submit_and_deadline() {
        let timer = Timer::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let _handle = timer.submit(Instant::now() + Duration::from_millis(50), waker);
        let (deadline, _woken_up) = timer.spin_and_get_deadline();
        assert!(deadline.is_some());
        let ms = deadline.unwrap().as_millis();
        assert!(ms <= 50 && ms > 0, "Deadline should be <= 50ms, got {}", ms);
    }

    #[test]
    fn test_timer_cancel() {
        let timer = Timer::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let handle = timer
            .submit(Instant::now() + Duration::from_millis(100), waker)
            .expect("Failed to submit timer");
        timer.cancel(handle);
        let (deadline, _woken_up) = timer.spin_and_get_deadline();
        assert!(deadline.is_none());
    }

    #[test]
    fn test_timer_wake() {
        let timer = Timer::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let _handle = timer.submit(Instant::now() + Duration::from_millis(1), waker);
        std::thread::sleep(Duration::from_millis(5));
        timer.spin_and_get_deadline();
        let count = counter.lock().unwrap();
        assert!(
            *count >= 1,
            "Waker should have been called or at least not panic"
        );
    }

    #[test]
    fn test_timing_wheel_empty() {
        let wheel = TimingWheel::new();
        assert!(wheel.is_empty());
        assert_eq!(wheel.now(), 0);
        assert!(wheel.nearest_wakeup().is_none());
    }

    #[test]
    fn test_timing_wheel_insert_and_expire_single() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let _handle = wheel.insert(waker, 10);
        assert!(!wheel.is_empty());
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(10).unwrap())
        );
        let expired = wheel.advance(9);
        assert!(expired.is_empty());
        assert_eq!(wheel.now(), 9);
        let expired = wheel.advance(1);
        assert_eq!(expired.len(), 1);
        assert!(wheel.is_empty());
        assert!(wheel.nearest_wakeup().is_none());
    }

    #[test]
    fn test_timing_wheel_insert_and_expire_immediate() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let _handle = wheel.insert(waker, 0);
        assert!(!wheel.is_empty());
        let expired = wheel.advance(1);
        assert_eq!(expired.len(), 1);
    }

    #[test]
    fn test_timing_wheel_cancel() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let handle = wheel.insert(waker, 100);
        assert!(!wheel.is_empty());
        wheel.remove(handle);
        assert!(wheel.is_empty());
        assert!(wheel.nearest_wakeup().is_none());
    }

    #[test]
    fn test_timing_wheel_cancel_wrong_generation() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let handle1 = wheel.insert(waker, 100);
        wheel.remove(handle1);
        let waker2 = mock_waker(counter.clone());
        let handle2 = wheel.insert(waker2, 200);
        wheel.remove(handle1);
        assert!(!wheel.is_empty());
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(200).unwrap())
        );
        wheel.remove(handle2);
        assert!(wheel.is_empty());
    }

    #[test]
    fn test_timing_wheel_multiple_timers_same_slot() {
        let mut wheel = TimingWheel::new();
        let counter1 = Arc::new(Mutex::new(0));
        let counter2 = Arc::new(Mutex::new(0));
        let counter3 = Arc::new(Mutex::new(0));
        let _h1 = wheel.insert(mock_waker(counter1.clone()), 50);
        let _h2 = wheel.insert(mock_waker(counter2.clone()), 50);
        let _h3 = wheel.insert(mock_waker(counter3.clone()), 50);
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(50).unwrap())
        );
        let expired = wheel.advance(50);
        assert_eq!(expired.len(), 3);
        for waker in expired {
            waker.wake();
        }
        assert_eq!(*counter1.lock().unwrap(), 1);
        assert_eq!(*counter2.lock().unwrap(), 1);
        assert_eq!(*counter3.lock().unwrap(), 1);
    }

    #[test]
    fn test_timing_wheel_multiple_timers_different_levels() {
        let mut wheel = TimingWheel::new();
        let counter1 = Arc::new(Mutex::new(0));
        let counter2 = Arc::new(Mutex::new(0));
        let counter3 = Arc::new(Mutex::new(0));
        let _h1 = wheel.insert(mock_waker(counter1.clone()), 10);
        let _h2 = wheel.insert(mock_waker(counter2.clone()), 100);
        let _h3 = wheel.insert(mock_waker(counter3.clone()), 5000);
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(10).unwrap())
        );
        let expired = wheel.advance(10);
        assert_eq!(expired.len(), 1);
        expired.into_iter().for_each(|w| w.wake());
        assert_eq!(*counter1.lock().unwrap(), 1);
        let expired = wheel.advance(90);
        assert_eq!(expired.len(), 1);
        expired.into_iter().for_each(|w| w.wake());
        assert_eq!(*counter2.lock().unwrap(), 1);
        let expired = wheel.advance(4900);
        assert_eq!(expired.len(), 1);
        expired.into_iter().for_each(|w| w.wake());
        assert_eq!(*counter3.lock().unwrap(), 1);
    }

    #[test]
    fn test_timing_wheel_large_advance() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let _handle = wheel.insert(waker, 1000);
        let expired = wheel.advance(1000);
        assert_eq!(expired.len(), 1);
        assert_eq!(wheel.now(), 1000);
    }

    #[test]
    fn test_timing_wheel_wrap_around_level_0() {
        let mut wheel = TimingWheel::new();
        let counter1 = Arc::new(Mutex::new(0));
        let counter2 = Arc::new(Mutex::new(0));
        let _h1 = wheel.insert(mock_waker(counter1.clone()), 63);
        let _h2 = wheel.insert(mock_waker(counter2.clone()), 65);
        let expired = wheel.advance(65);
        assert_eq!(expired.len(), 2);
    }

    #[test]
    fn test_timing_wheel_nearest_wakeup_updates() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let _h1 = wheel.insert(mock_waker(counter.clone()), 100);
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(100).unwrap())
        );
        let _h2 = wheel.insert(mock_waker(counter.clone()), 50);
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(50).unwrap())
        );
        let _h3 = wheel.insert(mock_waker(counter.clone()), 200);
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(50).unwrap())
        );
        wheel.advance(50);
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(100).unwrap())
        );
    }

    #[test]
    fn test_timing_wheel_boundary_values() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let _h1 = wheel.insert(mock_waker(counter.clone()), 63);
        let expired = wheel.advance(63);
        assert_eq!(expired.len(), 1);
        let _h2 = wheel.insert(mock_waker(counter.clone()), 64);
        let expired = wheel.advance(64);
        assert_eq!(expired.len(), 1);
        let _h3 = wheel.insert(mock_waker(counter.clone()), 4095);
        let expired = wheel.advance(4095);
        assert_eq!(expired.len(), 1);
        let _h4 = wheel.insert(mock_waker(counter.clone()), 4096);
        let expired = wheel.advance(4096);
        assert_eq!(expired.len(), 1);
    }

    #[test]
    fn test_timing_wheel_zero_advance() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let _handle = wheel.insert(waker, 100);
        let expired = wheel.advance(0);
        assert!(expired.is_empty());
        assert_eq!(wheel.now(), 0);
    }

    #[test]
    fn test_timing_wheel_cascade_from_higher_levels() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let _handle = wheel.insert(mock_waker(counter.clone()), 5000);
        let expired = wheel.advance(5000);
        assert_eq!(expired.len(), 1);
    }

    #[test]
    fn test_timing_wheel_many_timers() {
        let mut wheel = TimingWheel::new();
        let mut handles = Vec::new();
        let counter = Arc::new(Mutex::new(0));
        for i in 1..=100 {
            let waker = mock_waker(counter.clone());
            let handle = wheel.insert(waker, i * 10);
            handles.push(handle);
        }
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(10).unwrap())
        );
        for (i, handle) in handles.iter().enumerate() {
            if i % 2 == 0 {
                wheel.remove(*handle);
            }
        }
        let expired = wheel.advance(1000);
        assert_eq!(expired.len(), 50);
    }

    #[test]
    fn test_timing_wheel_reinsert_on_early_wake() {
        let mut wheel = TimingWheel::new();
        let counter = Arc::new(Mutex::new(0));
        let waker = mock_waker(counter.clone());
        let _handle = wheel.insert(waker, 100);
        let expired = wheel.advance(50);
        assert!(expired.is_empty());
        assert!(!wheel.is_empty());
        assert_eq!(
            wheel.nearest_wakeup(),
            Some(std::num::NonZeroU64::new(100).unwrap())
        );
        let expired = wheel.advance(50);
        assert_eq!(expired.len(), 1);
    }
}