Skip to main content

moduvex_runtime/time/
wheel.rs

//! Hierarchical timer wheel — O(1) insert, cancel linear in the size of one bucket, and a tick that walks every elapsed millisecond.
2//!
3//! # Design
4//! 6 levels × 64 slots. Each level covers a range of deadlines:
5//!
6//! | Level | Slot width | Total range |
7//! |-------|-----------|-------------|
8//! | 0     | 1 ms      | 64 ms       |
9//! | 1     | 64 ms     | ~4 s        |
10//! | 2     | ~4 s      | ~4 min      |
11//! | 3     | ~4 min    | ~4.5 h      |
12//! | 4     | ~4.5 h    | ~12 d       |
13//! | 5     | ~12 d     | ~2 yr       |
14//!
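//! Timers beyond the range of level 5 are kept in level 5 (the slot index wraps around) and are re-inserted each time their slot is visited until the deadline comes into range.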
16//!
17//! # Cascade
18//! When the executor's "current tick" advances past a slot boundary at level N,
19//! all timers in that slot are re-inserted at level N-1 (standard wheel cascade).
20
21use std::collections::HashMap;
22use std::task::Waker;
23use std::time::Instant;
24
/// Number of slots per wheel level (must be a power of 2).
const SLOTS: usize = 64;
/// Bit mask that reduces an absolute slot number to an index in `0..SLOTS`.
const SLOTS_MASK: u64 = (SLOTS - 1) as u64;

// Masking with `SLOTS_MASK` only equals `% SLOTS` for powers of two; fail the
// build instead of silently mis-bucketing timers if `SLOTS` is ever changed.
const _: () = assert!(SLOTS.is_power_of_two(), "SLOTS must be a power of 2");

/// Number of wheel levels.
const LEVELS: usize = 6;

/// Width of level 0 in milliseconds (1 ms per slot).
const LEVEL0_MS: u64 = 1;
34
35/// Width of each slot at level N = LEVEL0_MS * SLOTS^N.
36fn slot_width_ms(level: usize) -> u64 {
37    LEVEL0_MS * (SLOTS as u64).pow(level as u32)
38}
39
40// ── Timer entry ───────────────────────────────────────────────────────────────
41
/// A single pending timer, stored in one bucket of the wheel.
#[derive(Debug)]
pub(crate) struct TimerEntry {
    /// Unique timer identifier (for cancellation); the same value that is
    /// wrapped in the `TimerId` handed back by `TimerWheel::insert`.
    pub id: u64,
    /// Absolute deadline.
    pub deadline: Instant,
    /// Waker to call when the deadline passes.
    pub waker: Waker,
}
52
53// ── TimerId ───────────────────────────────────────────────────────────────────
54
/// Opaque handle returned by `TimerWheel::insert`. Used to cancel a timer.
///
/// Wraps the numeric id stored in the corresponding `TimerEntry`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TimerId(u64);
58
59// ── TimerWheel ────────────────────────────────────────────────────────────────
60
/// Hierarchical timer wheel.
///
/// All operations are relative to a monotonic millisecond counter derived from
/// the `Instant` origin supplied at construction time.
pub struct TimerWheel {
    /// The `Instant` corresponding to tick 0.
    origin: Instant,
    /// `wheel[level][slot]` → list of timer entries.
    wheel: Vec<Vec<Vec<TimerEntry>>>,
    /// Bucket location of each active timer, so `cancel` can go straight to
    /// the bucket that holds it.
    /// Maps timer id → (level, slot).
    index: HashMap<u64, (usize, usize)>,
    /// Monotonically increasing ID counter.
    next_id: u64,
    /// Last processed tick (milliseconds since origin).
    last_tick_ms: u64,
}
78
79impl TimerWheel {
80    /// Create a new timer wheel with `origin` as the zero point.
81    pub(crate) fn new(origin: Instant) -> Self {
82        // wheel[level][slot] = vec of entries
83        let wheel = (0..LEVELS)
84            .map(|_| (0..SLOTS).map(|_| Vec::new()).collect())
85            .collect();
86        Self {
87            origin,
88            wheel,
89            index: HashMap::new(),
90            next_id: 1,
91            last_tick_ms: 0,
92        }
93    }
94
95    /// Convert an `Instant` to milliseconds since origin, saturating at 0.
96    fn instant_to_ms(&self, t: Instant) -> u64 {
97        t.saturating_duration_since(self.origin)
98            .as_millis()
99            .try_into()
100            .unwrap_or(u64::MAX)
101    }
102
103    /// Insert a timer that fires at `deadline`. Returns a `TimerId` for
104    /// cancellation. The `waker` is called when the deadline passes.
105    pub(crate) fn insert(&mut self, deadline: Instant, waker: Waker) -> TimerId {
106        let id = self.next_id;
107        self.next_id += 1;
108
109        let deadline_ms = self.instant_to_ms(deadline);
110        // Fire immediately if deadline already passed.
111        let effective_ms = deadline_ms.max(self.last_tick_ms);
112
113        let (level, slot) = self.level_slot(effective_ms);
114        self.wheel[level][slot].push(TimerEntry {
115            id,
116            deadline,
117            waker,
118        });
119        self.index.insert(id, (level, slot));
120
121        TimerId(id)
122    }
123
124    /// Cancel the timer identified by `id`. Returns `true` if the timer was
125    /// found and removed, `false` if it had already fired or was not found.
126    pub(crate) fn cancel(&mut self, id: TimerId) -> bool {
127        let Some((level, slot)) = self.index.remove(&id.0) else {
128            return false;
129        };
130        let bucket = &mut self.wheel[level][slot];
131        let before = bucket.len();
132        bucket.retain(|e| e.id != id.0);
133        bucket.len() < before
134    }
135
136    /// Advance the wheel to `now`, returning all wakers whose timers have
137    /// expired. Callers must call `wake()` on each returned `Waker`.
138    pub(crate) fn tick(&mut self, now: Instant) -> Vec<Waker> {
139        let now_ms = self.instant_to_ms(now);
140        let mut fired: Vec<Waker> = Vec::new();
141
142        // Process every millisecond tick from last processed up to now.
143        // For large jumps (e.g. after a long sleep) we cascade all levels.
144        let from = self.last_tick_ms;
145        let to = now_ms;
146
147        if to < from {
148            return fired; // clock did not advance (equal means process current slot)
149        }
150
151        // Range is inclusive of `from` so that timers inserted exactly at
152        // `last_tick_ms` (deadline ≤ last_tick_ms) get drained on the first
153        // tick call after they are inserted.
154        let mut t = from;
155        loop {
156            // Drain level-0 slot for this tick.
157            let slot0 = (t & SLOTS_MASK) as usize;
158            let entries = std::mem::take(&mut self.wheel[0][slot0]);
159            for entry in entries {
160                self.index.remove(&entry.id);
161                // Fire if deadline has passed; otherwise re-insert (shouldn't
162                // happen in correct usage, but guard against edge cases).
163                if self.instant_to_ms(entry.deadline) <= t {
164                    fired.push(entry.waker);
165                } else {
166                    // Re-insert if somehow placed in wrong slot.
167                    self.insert_raw(entry);
168                }
169            }
170
171            // Cascade higher levels when their slot boundary is crossed.
172            // Level N cascades when tick crosses a multiple of SLOTS^N.
173            for level in 1..LEVELS {
174                let width = slot_width_ms(level);
175                if t % width == 0 {
176                    let slot = ((t / width) & SLOTS_MASK) as usize;
177                    let entries = std::mem::take(&mut self.wheel[level][slot]);
178                    for entry in entries {
179                        self.index.remove(&entry.id);
180                        if self.instant_to_ms(entry.deadline) <= t {
181                            fired.push(entry.waker);
182                        } else {
183                            self.insert_raw(entry);
184                        }
185                    }
186                }
187            }
188
189            if t >= to {
190                break;
191            }
192            t += 1;
193        }
194
195        self.last_tick_ms = to;
196        fired
197    }
198
199    /// Return the nearest deadline across all wheel slots, if any timers are pending.
200    pub(crate) fn next_deadline(&self) -> Option<Instant> {
201        let mut earliest: Option<Instant> = None;
202        for level in &self.wheel {
203            for slot in level {
204                for entry in slot {
205                    earliest = Some(match earliest {
206                        None => entry.deadline,
207                        Some(e) => e.min(entry.deadline),
208                    });
209                }
210            }
211        }
212        earliest
213    }
214
215    /// Internal: insert a pre-existing `TimerEntry` into the correct bucket.
216    fn insert_raw(&mut self, entry: TimerEntry) {
217        let deadline_ms = self.instant_to_ms(entry.deadline);
218        let effective_ms = deadline_ms.max(self.last_tick_ms);
219        let (level, slot) = self.level_slot(effective_ms);
220        self.index.insert(entry.id, (level, slot));
221        self.wheel[level][slot].push(entry);
222    }
223
224    /// Compute the (level, slot) for a timer with deadline at `deadline_ms`.
225    fn level_slot(&self, deadline_ms: u64) -> (usize, usize) {
226        let delta = deadline_ms.saturating_sub(self.last_tick_ms);
227
228        for level in 0..LEVELS {
229            let width = slot_width_ms(level);
230            let range = width * SLOTS as u64;
231            if delta < range || level == LEVELS - 1 {
232                // Compute absolute slot at this level.
233                let slot = ((deadline_ms / width) & SLOTS_MASK) as usize;
234                return (level, slot);
235            }
236        }
237        // Unreachable: loop handles all cases.
238        (LEVELS - 1, 0)
239    }
240}
241
242// ── Tests ─────────────────────────────────────────────────────────────────────
243
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::task::{RawWaker, RawWakerVTable};
    use std::time::Duration;

    /// Build a waker that sets `flag` to `true` when woken.
    ///
    /// The waker's data pointer is the result of `Arc::into_raw`, i.e. it
    /// points at the `Mutex<bool>` itself and owns one strong reference.
    fn make_flag_waker(flag: Arc<Mutex<bool>>) -> Waker {
        let data = Arc::into_raw(flag) as *const ();

        unsafe fn clone_w(p: *const ()) -> RawWaker {
            Arc::increment_strong_count(p as *const Mutex<bool>);
            RawWaker::new(p, &VT)
        }
        unsafe fn wake(p: *const ()) {
            // Consumes the strong reference owned by this waker.
            *Arc::from_raw(p as *const Mutex<bool>).lock().unwrap() = true;
        }
        unsafe fn wake_ref(p: *const ()) {
            // `p` points at the mutex (not at an `Arc`), so borrow it directly
            // without touching the reference count.
            let flag: &Mutex<bool> = &*(p as *const Mutex<bool>);
            *flag.lock().unwrap() = true;
        }
        unsafe fn drop_w(p: *const ()) {
            drop(Arc::from_raw(p as *const Mutex<bool>));
        }
        static VT: RawWakerVTable = RawWakerVTable::new(clone_w, wake, wake_ref, drop_w);

        // SAFETY: `data` comes from `Arc::into_raw` and every vtable function
        // treats it as such, keeping the strong count balanced.
        unsafe { Waker::from_raw(RawWaker::new(data, &VT)) }
    }

    /// Data carried by a recording waker: its index and the shared log.
    type Payload = (u32, Arc<Mutex<Vec<u32>>>);

    /// Build a waker that appends `index` to `sink` every time it is woken.
    fn make_recording_waker(index: u32, sink: Arc<Mutex<Vec<u32>>>) -> Waker {
        unsafe fn clone_p(p: *const ()) -> RawWaker {
            let payload: &Payload = &*(p as *const Payload);
            let copy: Box<Payload> = Box::new((payload.0, Arc::clone(&payload.1)));
            RawWaker::new(Box::into_raw(copy) as *const (), &PVT)
        }
        unsafe fn wake_p(p: *const ()) {
            // Takes ownership of the box; it is freed when `payload` drops.
            let payload = Box::from_raw(p as *mut Payload);
            payload.1.lock().unwrap().push(payload.0);
        }
        unsafe fn wake_p_ref(p: *const ()) {
            let payload: &Payload = &*(p as *const Payload);
            payload.1.lock().unwrap().push(payload.0);
        }
        unsafe fn drop_p(p: *const ()) {
            drop(Box::from_raw(p as *mut Payload));
        }
        static PVT: RawWakerVTable = RawWakerVTable::new(clone_p, wake_p, wake_p_ref, drop_p);

        let data = Box::into_raw(Box::new((index, sink))) as *const ();
        // SAFETY: PVT satisfies the RawWaker contract; payload is Box-allocated
        // and each waker (original or clone) owns its own box.
        unsafe { Waker::from_raw(RawWaker::new(data, &PVT)) }
    }

    #[test]
    fn insert_and_tick_fires_waker() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        let deadline = origin + Duration::from_millis(50);
        wheel.insert(deadline, waker);

        // Tick before deadline — should not fire.
        let wakers = wheel.tick(origin + Duration::from_millis(30));
        assert!(wakers.is_empty());

        // Tick at/after deadline — should fire.
        let wakers = wheel.tick(origin + Duration::from_millis(60));
        assert_eq!(wakers.len(), 1);
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap(), "waker must have fired");
    }

    #[test]
    fn cancel_prevents_firing() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        let deadline = origin + Duration::from_millis(50);
        let id = wheel.insert(deadline, waker);
        let removed = wheel.cancel(id);
        assert!(removed, "cancel must return true for existing timer");

        // Tick past deadline — must not fire.
        let wakers = wheel.tick(origin + Duration::from_millis(100));
        assert!(wakers.is_empty(), "cancelled timer must not fire");
        assert!(!*flag.lock().unwrap());
    }

    #[test]
    fn zero_deadline_fires_on_next_tick() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        // Deadline in the past (or now) → fires immediately on next tick.
        wheel.insert(origin, waker);
        let wakers = wheel.tick(origin + Duration::from_millis(1));
        assert_eq!(wakers.len(), 1);
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap());
    }

    #[test]
    fn multiple_timers_fire_in_order() {
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        let results = Arc::new(Mutex::new(Vec::<u32>::new()));

        // Timer `i` is due at (i + 1) * 10 ms.
        for i in 0u32..5 {
            let waker = make_recording_waker(i, Arc::clone(&results));
            wheel.insert(origin + Duration::from_millis((i as u64 + 1) * 10), waker);
        }

        // Single tick past all deadlines.
        let wakers = wheel.tick(origin + Duration::from_millis(60));
        assert_eq!(wakers.len(), 5);
        for w in wakers {
            w.wake();
        }
        let v = results.lock().unwrap();
        assert_eq!(*v, vec![0u32, 1, 2, 3, 4], "timers must fire in deadline order");
    }

    #[test]
    fn next_deadline_returns_earliest() {
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        let d1 = origin + Duration::from_millis(200);
        let d2 = origin + Duration::from_millis(50);

        let f1 = Arc::new(Mutex::new(false));
        let f2 = Arc::new(Mutex::new(false));
        wheel.insert(d1, make_flag_waker(Arc::clone(&f1)));
        wheel.insert(d2, make_flag_waker(Arc::clone(&f2)));

        let earliest = wheel.next_deadline().expect("should have a deadline");
        assert_eq!(earliest, d2, "next_deadline must return earliest");
    }
}
397
398        let earliest = wheel.next_deadline().expect("should have a deadline");
399        assert_eq!(earliest, d2, "next_deadline must return earliest");
400    }
401}