// moduvex_runtime/time/wheel.rs
use std::collections::HashMap;
use std::task::Waker;
use std::time::Instant;

/// Number of slots in each level of the wheel.
const SLOTS: usize = 64;
/// Mask mapping a (scaled) millisecond count onto a slot index.
const SLOTS_MASK: u64 = (SLOTS - 1) as u64;

// `SLOTS_MASK` only selects a slot correctly when `SLOTS` is a power of two.
const _: () = assert!(SLOTS.is_power_of_two());

/// Number of hierarchical levels in the wheel.
const LEVELS: usize = 6;

/// Width, in milliseconds, of one level-0 slot (the wheel's resolution).
const LEVEL0_MS: u64 = 1;

35fn slot_width_ms(level: usize) -> u64 {
37 LEVEL0_MS * (SLOTS as u64).pow(level as u32)
38}
39
/// A single pending timer stored in one bucket of the wheel.
#[derive(Debug)]
pub(crate) struct TimerEntry {
    /// Identifier handed back to the caller as a [`TimerId`].
    pub id: u64,
    /// Absolute instant at which the timer becomes due.
    pub deadline: Instant,
    /// Waker handed out by the wheel once the deadline has been reached.
    pub waker: Waker,
}

/// Opaque handle returned by `TimerWheel::insert`; pass it to
/// `TimerWheel::cancel` to remove the timer before it fires.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TimerId(u64);

59pub struct TimerWheel {
66 origin: Instant,
68 wheel: Vec<Vec<Vec<TimerEntry>>>,
70 index: HashMap<u64, (usize, usize)>,
73 next_id: u64,
75 last_tick_ms: u64,
77}
78
79impl TimerWheel {
80 pub(crate) fn new(origin: Instant) -> Self {
82 let wheel = (0..LEVELS)
84 .map(|_| (0..SLOTS).map(|_| Vec::new()).collect())
85 .collect();
86 Self {
87 origin,
88 wheel,
89 index: HashMap::new(),
90 next_id: 1,
91 last_tick_ms: 0,
92 }
93 }
94
95 fn instant_to_ms(&self, t: Instant) -> u64 {
97 t.saturating_duration_since(self.origin)
98 .as_millis()
99 .try_into()
100 .unwrap_or(u64::MAX)
101 }
102
103 pub(crate) fn insert(&mut self, deadline: Instant, waker: Waker) -> TimerId {
106 let id = self.next_id;
107 self.next_id += 1;
108
109 let deadline_ms = self.instant_to_ms(deadline);
110 let effective_ms = deadline_ms.max(self.last_tick_ms);
112
113 let (level, slot) = self.level_slot(effective_ms);
114 self.wheel[level][slot].push(TimerEntry {
115 id,
116 deadline,
117 waker,
118 });
119 self.index.insert(id, (level, slot));
120
121 TimerId(id)
122 }
123
124 pub(crate) fn cancel(&mut self, id: TimerId) -> bool {
127 let Some((level, slot)) = self.index.remove(&id.0) else {
128 return false;
129 };
130 let bucket = &mut self.wheel[level][slot];
131 let before = bucket.len();
132 bucket.retain(|e| e.id != id.0);
133 bucket.len() < before
134 }
135
136 pub(crate) fn tick(&mut self, now: Instant) -> Vec<Waker> {
139 let now_ms = self.instant_to_ms(now);
140 let mut fired: Vec<Waker> = Vec::new();
141
142 let from = self.last_tick_ms;
145 let to = now_ms;
146
147 if to < from {
148 return fired; }
150
151 let mut t = from;
155 loop {
156 let slot0 = (t & SLOTS_MASK) as usize;
158 let entries = std::mem::take(&mut self.wheel[0][slot0]);
159 for entry in entries {
160 self.index.remove(&entry.id);
161 if self.instant_to_ms(entry.deadline) <= t {
164 fired.push(entry.waker);
165 } else {
166 self.insert_raw(entry);
168 }
169 }
170
171 for level in 1..LEVELS {
174 let width = slot_width_ms(level);
175 if t % width == 0 {
176 let slot = ((t / width) & SLOTS_MASK) as usize;
177 let entries = std::mem::take(&mut self.wheel[level][slot]);
178 for entry in entries {
179 self.index.remove(&entry.id);
180 if self.instant_to_ms(entry.deadline) <= t {
181 fired.push(entry.waker);
182 } else {
183 self.insert_raw(entry);
184 }
185 }
186 }
187 }
188
189 if t >= to {
190 break;
191 }
192 t += 1;
193 }
194
195 self.last_tick_ms = to;
196 fired
197 }
198
199 pub(crate) fn next_deadline(&self) -> Option<Instant> {
201 let mut earliest: Option<Instant> = None;
202 for level in &self.wheel {
203 for slot in level {
204 for entry in slot {
205 earliest = Some(match earliest {
206 None => entry.deadline,
207 Some(e) => e.min(entry.deadline),
208 });
209 }
210 }
211 }
212 earliest
213 }
214
215 fn insert_raw(&mut self, entry: TimerEntry) {
217 let deadline_ms = self.instant_to_ms(entry.deadline);
218 let effective_ms = deadline_ms.max(self.last_tick_ms);
219 let (level, slot) = self.level_slot(effective_ms);
220 self.index.insert(entry.id, (level, slot));
221 self.wheel[level][slot].push(entry);
222 }
223
224 fn level_slot(&self, deadline_ms: u64) -> (usize, usize) {
226 let delta = deadline_ms.saturating_sub(self.last_tick_ms);
227
228 for level in 0..LEVELS {
229 let width = slot_width_ms(level);
230 let range = width * SLOTS as u64;
231 if delta < range || level == LEVELS - 1 {
232 let slot = ((deadline_ms / width) & SLOTS_MASK) as usize;
234 return (level, slot);
235 }
236 }
237 (LEVELS - 1, 0)
239 }
240}
241
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::task::{Wake, Waker};
    use std::time::{Duration, Instant};

    /// Waker payload that sets a shared flag to `true` when woken.
    struct FlagWaker(Arc<Mutex<bool>>);

    impl Wake for FlagWaker {
        fn wake(self: Arc<Self>) {
            *self.0.lock().unwrap() = true;
        }
    }

    /// Builds a waker that sets `flag` to `true` when woken (by value or by ref).
    fn make_flag_waker(flag: Arc<Mutex<bool>>) -> Waker {
        Waker::from(Arc::new(FlagWaker(flag)))
    }

    /// Waker payload that appends its `id` to a shared log when woken.
    struct RecordingWaker {
        id: u32,
        log: Arc<Mutex<Vec<u32>>>,
    }

    impl Wake for RecordingWaker {
        fn wake(self: Arc<Self>) {
            self.log.lock().unwrap().push(self.id);
        }
    }

    #[test]
    fn insert_and_tick_fires_waker() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        let deadline = origin + Duration::from_millis(50);
        wheel.insert(deadline, waker);

        // Not yet due.
        let wakers = wheel.tick(origin + Duration::from_millis(30));
        assert!(wakers.is_empty());

        // Past the deadline: exactly one waker comes back.
        let wakers = wheel.tick(origin + Duration::from_millis(60));
        assert_eq!(wakers.len(), 1);
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap(), "waker must have fired");
    }

    #[test]
    fn cancel_prevents_firing() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        let deadline = origin + Duration::from_millis(50);
        let id = wheel.insert(deadline, waker);
        let removed = wheel.cancel(id);
        assert!(removed, "cancel must return true for existing timer");
        assert!(!wheel.cancel(id), "second cancel must report nothing removed");

        let wakers = wheel.tick(origin + Duration::from_millis(100));
        assert!(wakers.is_empty(), "cancelled timer must not fire");
        assert!(!*flag.lock().unwrap());
    }

    #[test]
    fn zero_deadline_fires_on_next_tick() {
        let flag = Arc::new(Mutex::new(false));
        let waker = make_flag_waker(Arc::clone(&flag));

        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);

        wheel.insert(origin, waker);
        let wakers = wheel.tick(origin + Duration::from_millis(1));
        assert_eq!(wakers.len(), 1);
        for w in wakers {
            w.wake();
        }
        assert!(*flag.lock().unwrap());
    }

    #[test]
    fn multiple_timers_fire_in_order() {
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        let log = Arc::new(Mutex::new(Vec::<u32>::new()));

        for i in 0u32..5 {
            let waker = Waker::from(Arc::new(RecordingWaker {
                id: i,
                log: Arc::clone(&log),
            }));
            wheel.insert(origin + Duration::from_millis((u64::from(i) + 1) * 10), waker);
        }

        let wakers = wheel.tick(origin + Duration::from_millis(60));
        assert_eq!(wakers.len(), 5);
        for w in wakers {
            w.wake();
        }
        assert_eq!(
            *log.lock().unwrap(),
            vec![0, 1, 2, 3, 4],
            "timers must fire in deadline order"
        );
    }

    #[test]
    fn next_deadline_returns_earliest() {
        let origin = Instant::now();
        let mut wheel = TimerWheel::new(origin);
        assert!(wheel.next_deadline().is_none(), "empty wheel has no deadline");

        let d1 = origin + Duration::from_millis(200);
        let d2 = origin + Duration::from_millis(50);

        wheel.insert(d1, make_flag_waker(Arc::new(Mutex::new(false))));
        wheel.insert(d2, make_flag_waker(Arc::new(Mutex::new(false))));

        let earliest = wheel.next_deadline().expect("should have a deadline");
        assert_eq!(earliest, d2, "next_deadline must return earliest");
    }
}