1use std::collections::HashMap;
22use std::task::Waker;
23use std::time::Instant;
24
/// Number of slots in every level of the wheel. Must be a power of two so that
/// [`SLOTS_MASK`] can stand in for `% SLOTS`.
const SLOTS: usize = 64;
/// Mask that maps an absolute slot number onto an index in `0..SLOTS`.
const SLOTS_MASK: u64 = (SLOTS - 1) as u64;

/// Number of hierarchical levels. Level `l` holds timers whose distance from
/// the last tick is below `slot_width_ms(l) * SLOTS`; the last level also
/// absorbs anything farther out.
const LEVELS: usize = 6;

/// Width of one level-0 slot in milliseconds.
///
/// `TimerWheel::tick` treats every millisecond as a level-0 slot boundary (and
/// drains the slot containing the previous tick again), so this is expected to
/// stay `1`.
const LEVEL0_MS: u64 = 1;

/// Width, in milliseconds, of a single slot at `level` (`LEVEL0_MS * SLOTS^level`).
fn slot_width_ms(level: usize) -> u64 {
    LEVEL0_MS * (SLOTS as u64).pow(level as u32)
}

/// A single pending timer parked in one wheel slot.
#[derive(Debug)]
pub(crate) struct TimerEntry {
    /// Identifier handed back to the caller as a [`TimerId`].
    pub id: u64,
    /// Requested expiry time; compared against tick times at millisecond resolution.
    pub deadline: Instant,
    /// Waker returned from [`TimerWheel::tick`] once the deadline has been reached.
    pub waker: Waker,
}

/// Opaque handle returned by [`TimerWheel::insert`], used with [`TimerWheel::cancel`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TimerId(u64);

/// Hierarchical timing wheel with millisecond resolution.
///
/// Level 0 has one slot per millisecond and every higher level's slots are
/// `SLOTS` times wider. A timer is placed according to its distance from the
/// last tick; as time advances it is re-slotted into finer levels until it fires.
pub struct TimerWheel {
    /// Reference point from which all millisecond offsets are measured.
    origin: Instant,
    /// `wheel[level][slot]` holds the entries currently parked in that slot.
    wheel: Vec<Vec<Vec<TimerEntry>>>,
    /// Maps a timer id to the `(level, slot)` it currently lives in.
    index: HashMap<u64, (usize, usize)>,
    /// Next id to hand out; ids start at 1.
    next_id: u64,
    /// Millisecond offset (from `origin`) of the most recent `tick`.
    last_tick_ms: u64,
}

impl TimerWheel {
    /// Creates an empty wheel whose millisecond clock starts at `origin`.
    pub(crate) fn new(origin: Instant) -> Self {
        let wheel = (0..LEVELS)
            .map(|_| (0..SLOTS).map(|_| Vec::new()).collect())
            .collect();
        Self {
            origin,
            wheel,
            index: HashMap::new(),
            next_id: 1,
            last_tick_ms: 0,
        }
    }

    /// Converts `t` into whole milliseconds since `origin`.
    ///
    /// Instants before `origin` map to 0, the sub-millisecond remainder is
    /// truncated, and offsets that do not fit in a `u64` saturate to `u64::MAX`.
    fn instant_to_ms(&self, t: Instant) -> u64 {
        t.saturating_duration_since(self.origin)
            .as_millis()
            .try_into()
            .unwrap_or(u64::MAX)
    }

    /// Registers `waker` to be returned by the first [`tick`](Self::tick) whose
    /// time has reached `deadline`, and returns a handle for
    /// [`cancel`](Self::cancel).
    ///
    /// A deadline at or before the last tick fires on the next tick.
    pub(crate) fn insert(&mut self, deadline: Instant, waker: Waker) -> TimerId {
        let id = self.next_id;
        self.next_id += 1;
        self.insert_raw(TimerEntry {
            id,
            deadline,
            waker,
        });
        TimerId(id)
    }

    /// Removes a pending timer.
    ///
    /// Returns `true` if the timer was still pending, and `false` if it already
    /// fired, was already cancelled, or never existed.
    pub(crate) fn cancel(&mut self, id: TimerId) -> bool {
        let Some((level, slot)) = self.index.remove(&id.0) else {
            return false;
        };
        let bucket = &mut self.wheel[level][slot];
        let before = bucket.len();
        // `retain` keeps the remaining entries in their original order.
        bucket.retain(|e| e.id != id.0);
        bucket.len() < before
    }

    /// Advances the wheel to `now` and returns the wakers of every timer whose
    /// deadline has been reached.
    ///
    /// Deadlines and `now` are both truncated to whole milliseconds since
    /// `origin` before comparison, so a timer may fire up to just under 1 ms
    /// before its exact `Instant`. A `now` earlier than the previous tick is
    /// ignored and returns an empty vector. The work per call is bounded by
    /// `LEVELS * SLOTS` slot visits plus the entries touched, regardless of how
    /// far time jumps.
    pub(crate) fn tick(&mut self, now: Instant) -> Vec<Waker> {
        let mut fired = Vec::new();
        let from = self.last_tick_ms;
        let to = self.instant_to_ms(now);
        if to < from {
            return fired;
        }

        // Advance the clock *before* draining. Entries that are not yet due are
        // re-slotted by `insert_raw` relative to `last_tick_ms`; measuring from
        // `to` moves them down to a finer level. Measuring from the stale
        // `from` would park them back in the higher-level slot they came from,
        // whose boundary this tick has already passed, and they would not fire
        // until that level wrapped around.
        self.last_tick_ms = to;

        for level in 0..LEVELS {
            self.advance_level(level, from, to, &mut fired);
        }
        fired
    }

    /// Drains every slot of `level` whose boundary (a multiple of the level's
    /// slot width) lies in `from..=to`.
    ///
    /// At most `SLOTS` boundaries are visited: that many consecutive boundaries
    /// already touch every slot once, so a longer jump needs no extra passes.
    fn advance_level(&mut self, level: usize, from: u64, to: u64, fired: &mut Vec<Waker>) {
        let width = slot_width_ms(level);
        // First slot boundary at or after `from`; `None` if it overflows `u64`.
        let first = if from % width == 0 {
            Some(from)
        } else {
            (from / width + 1).checked_mul(width)
        };
        let Some(mut boundary) = first else {
            return;
        };
        for _ in 0..SLOTS {
            if boundary > to {
                break;
            }
            let slot = ((boundary / width) & SLOTS_MASK) as usize;
            self.drain_slot(level, slot, to, fired);
            match boundary.checked_add(width) {
                Some(next) => boundary = next,
                None => break,
            }
        }
    }

    /// Empties `wheel[level][slot]`.
    ///
    /// Entries whose deadline (in ms) is `<= now_ms` are removed from the index
    /// and their waker is pushed onto `fired`. The rest are re-slotted via
    /// `insert_raw`, relative to the current `last_tick_ms`.
    fn drain_slot(&mut self, level: usize, slot: usize, now_ms: u64, fired: &mut Vec<Waker>) {
        let entries = std::mem::take(&mut self.wheel[level][slot]);
        for entry in entries {
            if self.instant_to_ms(entry.deadline) <= now_ms {
                self.index.remove(&entry.id);
                fired.push(entry.waker);
            } else {
                // `insert_raw` overwrites the index with the entry's new location.
                self.insert_raw(entry);
            }
        }
    }

    /// Returns the earliest deadline among pending timers, or `None` if the
    /// wheel is empty.
    ///
    /// Scans every slot, so the cost is linear in the number of pending timers
    /// plus `LEVELS * SLOTS`.
    pub(crate) fn next_deadline(&self) -> Option<Instant> {
        self.wheel
            .iter()
            .flatten()
            .flatten()
            .map(|entry| entry.deadline)
            .min()
    }

    /// Places `entry` in the slot matching its deadline relative to the last
    /// tick and records that location in `index`.
    ///
    /// Deadlines before the last tick are clamped to it, which puts them in the
    /// level-0 slot that the next tick drains first.
    fn insert_raw(&mut self, entry: TimerEntry) {
        let effective_ms = self
            .instant_to_ms(entry.deadline)
            .max(self.last_tick_ms);
        let (level, slot) = self.level_slot(effective_ms);
        self.index.insert(entry.id, (level, slot));
        self.wheel[level][slot].push(entry);
    }

    /// Chooses the `(level, slot)` for an absolute deadline in milliseconds.
    ///
    /// The level is the finest one whose span (`width * SLOTS`) exceeds the
    /// distance from the last tick; anything beyond every span goes to the last
    /// level. The slot is the deadline's absolute slot number masked into range.
    fn level_slot(&self, deadline_ms: u64) -> (usize, usize) {
        let delta = deadline_ms.saturating_sub(self.last_tick_ms);
        let level = (0..LEVELS - 1)
            .find(|&l| delta < slot_width_ms(l) * SLOTS as u64)
            .unwrap_or(LEVELS - 1);
        let slot = ((deadline_ms / slot_width_ms(level)) & SLOTS_MASK) as usize;
        (level, slot)
    }
}
267
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::task::Wake;
    use std::time::Duration;

    /// Waker payload that sets a shared flag to `true` when woken.
    struct FlagWaker(Arc<Mutex<bool>>);

    impl Wake for FlagWaker {
        fn wake(self: Arc<Self>) {
            *self.0.lock().unwrap() = true;
        }
    }

    /// Builds a `Waker` that sets `flag` to `true` when woken.
    ///
    /// Built on the safe `Wake` trait, so no hand-written `RawWakerVTable`
    /// (and no `unsafe`) is needed.
    fn make_flag_waker(flag: Arc<Mutex<bool>>) -> Waker {
        Waker::from(Arc::new(FlagWaker(flag)))
    }

    /// Waker payload that appends its `id` to a shared log when woken.
    struct RecordingWaker {
        id: u32,
        log: Arc<Mutex<Vec<u32>>>,
    }

    impl Wake for RecordingWaker {
        fn wake(self: Arc<Self>) {
            self.log.lock().unwrap().push(self.id);
        }
    }

    #[test]
    fn insert_and_tick_fires_waker() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        let deadline = origin + Duration::from_millis(50);
        wheel.insert(deadline, waker);

        let wakers = wheel.tick(origin + Duration::from_millis(30));
        assert!(wakers.is_empty());

        let wakers = wheel.tick(origin + Duration::from_millis(60));
        assert_eq!(wakers.len(), 1);
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap(), "waker must have fired");
    }

    #[test]
    fn cancel_prevents_firing() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        let deadline = origin + Duration::from_millis(50);
        let id = wheel.insert(deadline, waker);
        let removed = wheel.cancel(id);
        assert!(removed, "cancel must return true for existing timer");

        let wakers = wheel.tick(origin + Duration::from_millis(100));
        assert!(wakers.is_empty(), "cancelled timer must not fire");
        assert!(!*flag.lock().unwrap());
    }

    #[test]
    fn zero_deadline_fires_on_next_tick() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        wheel.insert(origin, waker);
        let wakers = wheel.tick(origin + Duration::from_millis(1));
        assert_eq!(wakers.len(), 1);
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap());
    }

    #[test]
    fn multiple_timers_fire_in_order() {
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        let results = Arc::new(Mutex::new(Vec::<u32>::new()));

        for i in 0u32..5 {
            let waker = Waker::from(Arc::new(RecordingWaker {
                id: i,
                log: Arc::clone(&results),
            }));
            wheel.insert(origin + Duration::from_millis((u64::from(i) + 1) * 10), waker);
        }

        let wakers = wheel.tick(origin + Duration::from_millis(60));
        assert_eq!(wakers.len(), 5);
        for w in wakers {
            w.wake();
        }
        assert_eq!(
            *results.lock().unwrap(),
            vec![0, 1, 2, 3, 4],
            "timers must fire in deadline order"
        );
    }

    #[test]
    fn next_deadline_returns_earliest() {
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        let d1 = origin + Duration::from_millis(200);
        let d2 = origin + Duration::from_millis(50);

        let f1 = Arc::new(Mutex::new(false));
        let f2 = Arc::new(Mutex::new(false));
        wheel.insert(d1, make_flag_waker(Arc::clone(&f1)));
        wheel.insert(d2, make_flag_waker(Arc::clone(&f2)));

        let earliest = wheel.next_deadline().expect("should have a deadline");
        assert_eq!(earliest, d2, "next_deadline must return earliest");
    }

    #[test]
    fn large_time_jump_fires_timer_quickly() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        let deadline = origin + Duration::from_millis(50);
        wheel.insert(deadline, waker);

        let start = std::time::Instant::now();
        let wakers = wheel.tick(origin + Duration::from_secs(10));
        let elapsed = start.elapsed();

        assert_eq!(wakers.len(), 1, "timer must fire on 10s jump");
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap(), "waker must have been called");
        assert!(
            elapsed < Duration::from_millis(10),
            "10s tick must complete in <10ms, took {:?}",
            elapsed
        );
    }

    #[test]
    fn wheel_cancel_nonexistent_returns_false() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let fake_id = TimerId(9999);
        assert!(!w.cancel(fake_id));
    }

    #[test]
    fn wheel_cancel_already_fired_returns_false() {
        let flag = Arc::new(Mutex::new(false));
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let waker = make_flag_waker(Arc::clone(&flag));
        let id = w.insert(origin + Duration::from_millis(5), waker);
        let _ = w.tick(origin + Duration::from_millis(10));
        assert!(!w.cancel(id));
    }

    #[test]
    fn wheel_tick_backwards_is_noop() {
        let flag = Arc::new(Mutex::new(false));
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let waker = make_flag_waker(Arc::clone(&flag));
        w.insert(origin + Duration::from_millis(50), waker);
        let _ = w.tick(origin + Duration::from_millis(100));
        let wakers = w.tick(origin + Duration::from_millis(10));
        assert!(wakers.is_empty());
    }

    #[test]
    fn wheel_multiple_timers_same_slot() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        for _ in 0..5 {
            let flag = Arc::new(Mutex::new(false));
            let waker = make_flag_waker(Arc::clone(&flag));
            w.insert(origin + Duration::from_millis(10), waker);
        }
        let wakers = w.tick(origin + Duration::from_millis(20));
        assert_eq!(wakers.len(), 5);
    }

    #[test]
    fn wheel_1000_timers_all_fire() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        for i in 0..1000u64 {
            let flag = Arc::new(Mutex::new(false));
            let waker = make_flag_waker(Arc::clone(&flag));
            w.insert(origin + Duration::from_millis(i % 100), waker);
        }
        let wakers = w.tick(origin + Duration::from_millis(200));
        assert_eq!(wakers.len(), 1000);
    }

    #[test]
    fn wheel_next_deadline_empty_returns_none() {
        let origin = Instant::now();
        let w = TimerWheel::new(origin);
        assert!(w.next_deadline().is_none());
    }

    #[test]
    fn wheel_next_deadline_after_cancel_updates() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let f1 = Arc::new(Mutex::new(false));
        let f2 = Arc::new(Mutex::new(false));
        let d1 = origin + Duration::from_millis(100);
        let d2 = origin + Duration::from_millis(200);
        let id1 = w.insert(d1, make_flag_waker(Arc::clone(&f1)));
        let _id2 = w.insert(d2, make_flag_waker(Arc::clone(&f2)));
        assert_eq!(w.next_deadline().unwrap(), d1);
        assert!(w.cancel(id1));
        assert_eq!(w.next_deadline().unwrap(), d2);
    }

    #[test]
    fn wheel_partial_tick_does_not_fire_future_timers() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        w.insert(
            origin + Duration::from_millis(100),
            make_flag_waker(Arc::clone(&flag)),
        );
        let wakers = w.tick(origin + Duration::from_millis(50));
        assert!(wakers.is_empty());
        assert!(!*flag.lock().unwrap());
    }

    #[test]
    fn wheel_level_boundary_cascades_correctly() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        w.insert(
            origin + Duration::from_millis(65),
            make_flag_waker(Arc::clone(&flag)),
        );
        let wakers = w.tick(origin + Duration::from_millis(70));
        assert_eq!(wakers.len(), 1);
    }

    #[test]
    fn wheel_insert_past_deadline_fires_on_first_tick() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        let past_deadline = origin;
        w.insert(past_deadline, make_flag_waker(Arc::clone(&flag)));
        let wakers = w.tick(origin + Duration::from_millis(1));
        assert_eq!(wakers.len(), 1);
    }

    #[test]
    fn wheel_two_timers_different_deadlines_only_earlier_fires() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let f1 = Arc::new(Mutex::new(false));
        let f2 = Arc::new(Mutex::new(false));
        w.insert(
            origin + Duration::from_millis(10),
            make_flag_waker(Arc::clone(&f1)),
        );
        w.insert(
            origin + Duration::from_millis(50),
            make_flag_waker(Arc::clone(&f2)),
        );
        let wakers = w.tick(origin + Duration::from_millis(20));
        assert_eq!(wakers.len(), 1);
        // Wake what was returned so the flag checks below are meaningful.
        for waker in wakers {
            waker.wake();
        }
        assert!(*f1.lock().unwrap(), "earlier timer must fire");
        assert!(!*f2.lock().unwrap(), "later timer must not fire yet");
    }

    #[test]
    fn wheel_cancel_all_removes_from_index() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let mut ids = Vec::new();
        for i in 1..=5u64 {
            let flag = Arc::new(Mutex::new(false));
            let id = w.insert(origin + Duration::from_millis(i * 10), make_flag_waker(flag));
            ids.push(id);
        }
        for id in ids {
            assert!(w.cancel(id));
        }
        let wakers = w.tick(origin + Duration::from_millis(100));
        assert!(wakers.is_empty());
    }

    #[test]
    fn wheel_many_deadlines_at_different_levels() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let f1 = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(1), make_flag_waker(Arc::clone(&f1)));
        let f2 = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(100), make_flag_waker(Arc::clone(&f2)));
        let f3 = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(5000), make_flag_waker(Arc::clone(&f3)));

        let wakers = w.tick(origin + Duration::from_millis(200));
        assert_eq!(wakers.len(), 2);
        for waker in wakers {
            waker.wake();
        }
        assert!(*f1.lock().unwrap());
        assert!(*f2.lock().unwrap());
        assert!(!*f3.lock().unwrap());

        let wakers2 = w.tick(origin + Duration::from_millis(6000));
        assert_eq!(wakers2.len(), 1);
        for waker in wakers2 {
            waker.wake();
        }
        assert!(*f3.lock().unwrap());
    }

    #[test]
    fn wheel_empty_tick_returns_empty_vec() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let wakers = w.tick(origin + Duration::from_millis(1000));
        assert!(wakers.is_empty());
    }

    #[test]
    fn wheel_same_tick_twice_second_empty() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(10), make_flag_waker(Arc::clone(&flag)));
        let wakers1 = w.tick(origin + Duration::from_millis(20));
        assert_eq!(wakers1.len(), 1);
        let wakers2 = w.tick(origin + Duration::from_millis(20));
        assert!(wakers2.is_empty());
    }

    #[test]
    fn wheel_timer_id_uniqueness() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let mut ids = std::collections::HashSet::new();
        for i in 0..10u64 {
            let flag = Arc::new(Mutex::new(false));
            let id = w.insert(origin + Duration::from_millis(i * 5 + 1), make_flag_waker(flag));
            assert!(ids.insert(id));
        }
    }

    #[test]
    fn wheel_tick_advances_last_tick_ms() {
        let origin = Instant::now();
        let mut w = TimerWheel::new(origin);
        let flag = Arc::new(Mutex::new(false));
        w.insert(origin + Duration::from_millis(200), make_flag_waker(Arc::clone(&flag)));
        let wakers1 = w.tick(origin + Duration::from_millis(100));
        assert!(wakers1.is_empty());
        let wakers2 = w.tick(origin + Duration::from_millis(250));
        assert_eq!(wakers2.len(), 1);
    }
}