ora_timer/
wheel.rs

1//! A hierarchical timing wheel based on <https://github.com/Bathtor/rust-hash-wheel-timer>.
2
3mod byte;
4
5use alloc::{boxed::Box, vec::Vec};
6use core::{marker::PhantomData, mem, time::Duration};
7
8use byte::ByteWheel;
9
10use crate::resolution::{MillisecondResolution, Resolution};
11
12/// A hierarchical timing wheel with a given entry type and resolution.
13#[must_use]
14#[derive(Debug)]
15pub struct TimingWheel<T, R = MillisecondResolution>
16where
17    R: Resolution,
18{
19    primary: Box<ByteWheel<T, [u8; 0]>>,
20    secondary: Box<ByteWheel<T, [u8; 1]>>,
21    tertiary: Box<ByteWheel<T, [u8; 2]>>,
22    quarternary: Box<ByteWheel<T, [u8; 3]>>,
23    // We use double buffering to avoid allocations in the tick function
24    // due to the overflow list being emptied.
25    overflow: Vec<OverflowEntry<T>>,
26    overflow_buf: Vec<OverflowEntry<T>>,
27    _resolution: PhantomData<R>,
28}
29
30impl<T, R> Default for TimingWheel<T, R>
31where
32    R: Resolution,
33{
34    fn default() -> Self {
35        TimingWheel::new()
36    }
37}
38
39impl<T, R> TimingWheel<T, R>
40where
41    R: Resolution,
42{
43    /// Create a new timing wheel.
44    pub fn new() -> Self {
45        TimingWheel {
46            primary: Box::new(ByteWheel::new()),
47            secondary: Box::new(ByteWheel::new()),
48            tertiary: Box::new(ByteWheel::new()),
49            quarternary: Box::new(ByteWheel::new()),
50            overflow: Vec::new(),
51            overflow_buf: Vec::new(),
52            _resolution: PhantomData,
53        }
54    }
55
56    /// Returns the entry if it has already expired.
57    #[allow(clippy::cast_possible_truncation)]
58    pub fn insert(&mut self, entry: T, delay: Duration) -> Option<T> {
59        if delay > R::MAX_DURATION {
60            let remaining_delay = R::steps_as_duration(self.remaining_time_in_cycle());
61            let new_delay = delay - remaining_delay;
62            let overflow_e = OverflowEntry::new(entry, new_delay);
63            self.overflow.push(overflow_e);
64            None
65        } else {
66            let delay = R::cycle_steps(&delay, true);
67            let current_time = self.cycle_timestamp();
68            let absolute_time = delay.wrapping_add(current_time);
69            let absolute_bytes: [u8; 4] = absolute_time.to_be_bytes();
70            let zero_time = absolute_time ^ current_time; // a-b%2
71            let zero_bytes: [u8; 4] = zero_time.to_be_bytes();
72            match zero_bytes {
73                [0, 0, 0, 0] => Some(entry),
74                [0, 0, 0, _] => {
75                    self.primary.insert(absolute_bytes[3], entry, []);
76                    None
77                }
78                [0, 0, _, _] => {
79                    self.secondary
80                        .insert(absolute_bytes[2], entry, [absolute_bytes[3]]);
81                    None
82                }
83                [0, _, _, _] => {
84                    self.tertiary.insert(
85                        absolute_bytes[1],
86                        entry,
87                        [absolute_bytes[2], absolute_bytes[3]],
88                    );
89                    None
90                }
91                [_, _, _, _] => {
92                    self.quarternary.insert(
93                        absolute_bytes[0],
94                        entry,
95                        [absolute_bytes[1], absolute_bytes[2], absolute_bytes[3]],
96                    );
97                    None
98                }
99            }
100        }
101    }
102
103    /// Advance the timing wheel and collect all entries that have been expired.
104    pub fn tick(&mut self) -> Vec<T> {
105        let mut res = Vec::new();
106        self.tick_with(&mut res);
107        res
108    }
109
110    /// Advance the timing wheel and collect all entries that have been expired.
111    pub fn tick_with(&mut self, res: &mut Vec<T>) {
112        // primary
113        let (move0, current0) = self.primary.tick();
114        res.extend(move0.map(|we| we.entry));
115        if current0 == 0u8 {
116            // secondary
117            let (move1, current1) = self.secondary.tick();
118            // Don't bother reserving, as most of the values will likely be redistributed over the primary wheel instead of being returned
119            for we in move1 {
120                if we.rest[0] == 0u8 {
121                    res.push(we.entry);
122                } else {
123                    self.primary.insert(we.rest[0], we.entry, []);
124                }
125            }
126            if current1 == 0u8 {
127                // tertiary
128                let (move2, current2) = self.tertiary.tick();
129                for we in move2 {
130                    match we.rest {
131                        [0, 0] => {
132                            res.push(we.entry);
133                        }
134                        [0, b0] => {
135                            self.primary.insert(b0, we.entry, []);
136                        }
137                        [b1, b0] => {
138                            self.secondary.insert(b1, we.entry, [b0]);
139                        }
140                    }
141                }
142                if current2 == 0u8 {
143                    // quarternary
144                    let (move3, current3) = self.quarternary.tick();
145                    for we in move3 {
146                        match we.rest {
147                            [0, 0, 0] => {
148                                res.push(we.entry);
149                            }
150                            [0, 0, b0] => {
151                                self.primary.insert(b0, we.entry, []);
152                            }
153                            [0, b1, b0] => {
154                                self.secondary.insert(b1, we.entry, [b0]);
155                            }
156                            [b2, b1, b0] => {
157                                self.tertiary.insert(b2, we.entry, [b1, b0]);
158                            }
159                        }
160                    }
161                    if current3 == 0u8 {
162                        // overflow list
163                        if !self.overflow.is_empty() {
164                            mem::swap(&mut self.overflow, &mut self.overflow_buf);
165                            let mut overflow_buf = mem::take(&mut self.overflow_buf);
166                            for overflow_e in overflow_buf.drain(..) {
167                                if let Some(entry) =
168                                    self.insert(overflow_e.entry, overflow_e.remaining_delay)
169                                {
170                                    res.push(entry);
171                                }
172                            }
173                            self.overflow_buf = overflow_buf;
174                        }
175                    }
176                }
177            }
178        }
179    }
180
181    /// Skip `amount` steps, note that this will succeed
182    /// and no checks will take place.
183    ///
184    /// Use [`TimingWheel::can_skip`] to determine if this function
185    /// can be used without silently dropping any entries that
186    /// have not been expired.
187    pub fn skip(&mut self, amount: u32) {
188        let new_time = self.cycle_timestamp().wrapping_add(amount);
189        let new_time_bytes: [u8; 4] = new_time.to_be_bytes();
190        self.primary.set_current(new_time_bytes[3]);
191        self.secondary.set_current(new_time_bytes[2]);
192        self.tertiary.set_current(new_time_bytes[1]);
193        self.quarternary.set_current(new_time_bytes[0]);
194    }
195
196    /// Returns how many steps can be skipped safely without
197    /// missing entries.
198    #[must_use]
199    #[allow(clippy::cast_possible_truncation, clippy::cast_lossless)]
200    pub fn can_skip(&self) -> u32 {
201        if self.primary.is_empty() {
202            if self.secondary.is_empty() {
203                if self.tertiary.is_empty() {
204                    if self.quarternary.is_empty() {
205                        if self.overflow.is_empty() {
206                            0
207                        } else {
208                            (self.remaining_time_in_cycle() - 1u64) as u32
209                        }
210                    } else {
211                        let tertiary_current = self.cycle_timestamp() & (TERTIARY_LENGTH - 1u32);
212                        let rem = TERTIARY_LENGTH - tertiary_current;
213                        rem - 1u32
214                    }
215                } else {
216                    let secondary_current = self.cycle_timestamp() & (SECONDARY_LENGTH - 1u32);
217                    let rem = SECONDARY_LENGTH - secondary_current;
218                    rem - 1u32
219                }
220            } else {
221                let primary_current = self.primary.current() as u32;
222                let rem = PRIMARY_LENGTH - primary_current;
223                rem - 1u32
224            }
225        } else {
226            0
227        }
228    }
229
230    /// Return the amount of entries in the wheel.
231    #[must_use]
232    pub fn len(&self) -> usize {
233        self.primary.len()
234            + self.secondary.len()
235            + self.tertiary.len()
236            + self.quarternary.len()
237            + self.overflow.len()
238    }
239
240    /// Return whether the wheel is empty.
241    #[must_use]
242    pub fn is_empty(&self) -> bool {
243        self.len() == 0
244    }
245
246    /// Due to double buffering the overflow list
247    /// is never shrunk.
248    ///
249    /// This function can be used to shrink the overflow
250    /// list if its capacity equals or is over the given threshold.
251    pub fn gc(&mut self, threshold: usize) {
252        if self.overflow.capacity() >= threshold {
253            self.overflow.shrink_to_fit();
254            self.overflow_buf.shrink_to_fit();
255        }
256    }
257
258    #[allow(clippy::cast_lossless)]
259    fn remaining_time_in_cycle(&self) -> u64 {
260        CYCLE_LENGTH - (self.cycle_timestamp() as u64)
261    }
262
263    #[must_use]
264    fn cycle_timestamp(&self) -> u32 {
265        let time_bytes = [
266            self.quarternary.current(),
267            self.tertiary.current(),
268            self.secondary.current(),
269            self.primary.current(),
270        ];
271        u32::from_be_bytes(time_bytes)
272    }
273}
274
/// Number of steps in one full cycle of the wheel (2^32).
const CYCLE_LENGTH: u64 = 0x1_0000_0000;
/// Number of steps after which the primary wheel wraps around (2^8).
const PRIMARY_LENGTH: u32 = 0x100;
/// Number of steps after which the secondary wheel wraps around (2^16).
const SECONDARY_LENGTH: u32 = 0x1_0000;
/// Number of steps after which the tertiary wheel wraps around (2^24).
const TERTIARY_LENGTH: u32 = 0x100_0000;
279
/// An entry whose delay does not fit into the wheels yet, together with
/// the delay that is still outstanding.
#[derive(Debug)]
struct OverflowEntry<T> {
    /// The user-provided entry.
    entry: T,
    /// Delay that remains for this entry.
    remaining_delay: Duration,
}

impl<T> OverflowEntry<T> {
    /// Bundle `entry` with its `remaining_delay`.
    fn new(entry: T, remaining_delay: Duration) -> Self {
        Self {
            entry,
            remaining_delay,
        }
    }
}
293
#[cfg(test)]
mod tests {
    use super::*;
    use crate::resolution::MillisecondResolution;

    /// Build a fresh wheel of `usize` entries with millisecond resolution.
    fn new_wheel() -> TimingWheel<usize, MillisecondResolution> {
        TimingWheel::new()
    }

    #[test]
    fn smoke_millis() {
        let mut wheel = new_wheel();
        // A zero delay hands the entry straight back.
        assert!(wheel.insert(0, Duration::ZERO).is_some());

        // A one millisecond delay expires on the next tick.
        assert!(wheel.insert(0, Duration::from_millis(1)).is_none());
        assert_eq!(wheel.len(), 1);
        assert_eq!(wheel.tick().pop(), Some(0));

        // An entry in the primary wheel forbids skipping.
        assert!(wheel.insert(0, Duration::from_millis(10)).is_none());
        assert_eq!(wheel.len(), 1);
        assert_eq!(wheel.can_skip(), 0);
    }

    #[test]
    fn skip_millis() {
        // An entry in the primary wheel forbids skipping.
        let mut wheel = new_wheel();
        assert!(wheel.insert(0, Duration::from_millis(0xFF)).is_none());
        assert_eq!(wheel.len(), 1);
        assert_eq!(wheel.can_skip(), 0);

        // (delay in milliseconds, expected number of skippable steps):
        // one case each for the secondary, tertiary and quarternary wheel
        // and for the overflow list. After skipping, a single tick must
        // deliver the entry.
        let cases: [(u64, u32); 4] = [
            (0x100, 0xFF),
            (0x10000, 0xFFFF),
            (0x0100_0000, 0xFF_FFFF),
            (0x0001_0000_0000, 0xFFFF_FFFF),
        ];
        for &(delay_ms, skippable) in &cases {
            let mut wheel = new_wheel();
            assert!(wheel.insert(0, Duration::from_millis(delay_ms)).is_none());
            assert_eq!(wheel.len(), 1);
            assert_eq!(wheel.can_skip(), skippable);
            wheel.skip(skippable);
            assert_eq!(wheel.tick().pop(), Some(0));
        }

        // A delay spanning many cycles.
        let mut wheel = new_wheel();
        assert!(wheel
            .insert(0, Duration::from_millis(0x100_0000_0000))
            .is_none());
        assert_eq!(wheel.len(), 1);
        assert_eq!(wheel.can_skip(), 0xFFFF_FFFF);
        wheel.skip(0xFFFF_FFFF);
        // The overflow list was processed, but nothing has expired yet.
        assert!(wheel.tick().is_empty());
        assert_eq!(wheel.can_skip(), 0xFFFF_FFFF);
        wheel.skip(0xFFFF_FFFF);

        // We cannot skip as we need to check the overflow list.
        assert_eq!(wheel.can_skip(), 0);
        wheel.tick();
        // The value is still in the overflow list, and will stay there
        // for quite a while.
        assert_eq!(wheel.can_skip(), 0xFFFF_FFFF);
    }
}
371}