okee_wheel_timer/time_wheel/
core.rs1use std::collections::HashMap;
2
3use crate::event::{Event, EventId};
4
5use super::types::{ScheduleResult, TimeWheelError, UpdateResult};
6
7type Bucket<T> = HashMap<EventId, Event<T>>;
8
/// Reverse-index record telling which bucket currently holds a given event.
#[derive(Clone, Copy, Debug)]
struct IdEntry {
    /// Position of the owning bucket inside the timer's bucket vector.
    bucket_index: usize,
}
13
14#[derive(Debug)]
15pub struct HashedWheelTimer<T> {
24 curr_tick: u64,
25 curr_bucket: usize,
26 curr_delta_tick: u64,
27 curr_sequence_id: EventId,
28 buckets: Vec<Bucket<T>>,
29 by_id: HashMap<EventId, IdEntry>,
30}
31
32impl<T> HashedWheelTimer<T> {
33 pub fn new(buckets_num: usize) -> Self {
37 assert!(buckets_num > 0, "buckets_num must be greater than 0");
38
39 Self {
40 curr_tick: 0,
41 curr_bucket: 0,
42 curr_delta_tick: 0,
43 curr_sequence_id: 0,
44 buckets: build_buckets(buckets_num),
45 by_id: HashMap::new(),
46 }
47 }
48
49 pub fn count_all(&self) -> usize {
51 self.by_id.len()
52 }
53
54 pub fn count_in_bucket(&self, bucket_index: usize) -> Result<usize, TimeWheelError> {
58 self.buckets
59 .get(bucket_index)
60 .map(|bucket| bucket.len())
61 .ok_or(TimeWheelError::InvalidBucketIndex {
62 index: bucket_index,
63 buckets: self.buckets.len(),
64 })
65 }
66
67 pub fn is_empty(&self) -> bool {
69 self.by_id.is_empty()
70 }
71
72 pub fn is_empty_bucket(&self, bucket_index: usize) -> Result<bool, TimeWheelError> {
74 self.count_in_bucket(bucket_index).map(|count| count == 0)
75 }
76
77 pub fn has_events_in_current_tick(&self) -> bool {
79 self.find_min_delta_tick().is_some()
80 }
81
82 pub fn curr_tick(&self) -> u64 {
84 self.curr_tick
85 }
86
87 pub fn curr_bucket(&self) -> usize {
89 self.curr_bucket
90 }
91
92 pub fn curr_delta_tick(&self) -> u64 {
94 self.curr_delta_tick
95 }
96
97 pub fn curr_seq_id(&self) -> EventId {
99 self.curr_sequence_id
100 }
101
102 pub fn schedule(&mut self, on_tick: u64, data: T) -> ScheduleResult {
106 self.curr_sequence_id += 1;
107 self.insert(self.curr_sequence_id, on_tick, data)
108 }
109
110 pub fn contains(&self, id: EventId) -> bool {
112 self.by_id.contains_key(&id)
113 }
114
115 pub fn get(&self, id: EventId) -> Option<&Event<T>> {
117 let meta = self.by_id.get(&id)?;
118 self.buckets[meta.bucket_index].get(&id)
119 }
120
121 pub fn remove(&mut self, id: EventId) -> Option<Event<T>> {
123 self.remove_internal(id)
124 }
125
126 pub fn update(&mut self, id: EventId, on_tick: u64, data: T) -> Option<UpdateResult> {
128 self.remove_internal(id)?;
129 let inserted = self.insert(id, on_tick, data);
130 Some(UpdateResult { id: inserted.id })
131 }
132
133 pub fn reschedule(&mut self, id: EventId, on_tick: u64) -> Option<UpdateResult> {
135 let old_event = self.remove_internal(id)?;
136 let inserted = self.insert(id, on_tick, old_event.into_data());
137 Some(UpdateResult { id: inserted.id })
138 }
139
140 pub fn pop_events(&mut self) -> Vec<Event<T>> {
144 let Some(min_delta) = self.find_min_delta_tick() else {
145 return Vec::new();
146 };
147
148 self.curr_delta_tick = min_delta;
149 self.drain_wave()
150 }
151
152 pub fn step(&mut self) {
154 self.curr_tick += 1;
155 self.curr_bucket = (self.curr_bucket + 1) % self.buckets.len();
156 self.curr_delta_tick = 0;
157 }
158
159 pub fn reset(&mut self) {
161 self.curr_tick = 0;
162 self.curr_bucket = 0;
163 self.curr_delta_tick = 0;
164 self.curr_sequence_id = 0;
165 self.buckets = build_buckets(self.buckets.len());
166 self.by_id.clear();
167 }
168
169 fn insert(&mut self, event_id: EventId, on_tick: u64, data: T) -> ScheduleResult {
170 let tick = on_tick.max(self.curr_tick);
171 let delta_tick = if tick == self.curr_tick {
172 self.curr_delta_tick + 1
173 } else {
174 0
175 };
176
177 let bucket_index = self.bucket_index(tick);
178 let event = Event::new(event_id, tick, delta_tick, data);
179
180 self.buckets[bucket_index].insert(event_id, event);
181 self.by_id.insert(event_id, IdEntry { bucket_index });
182
183 ScheduleResult { id: event_id }
184 }
185
186 fn remove_internal(&mut self, id: EventId) -> Option<Event<T>> {
187 let meta = self.by_id.remove(&id)?;
188 self.buckets[meta.bucket_index].remove(&id)
189 }
190
191 fn find_min_delta_tick(&self) -> Option<u64> {
192 self.buckets[self.curr_bucket]
193 .values()
194 .filter_map(|event| (event.tick() == self.curr_tick).then_some(event.delta_tick()))
195 .min()
196 }
197
198 fn drain_wave(&mut self) -> Vec<Event<T>> {
199 let mut event_ids: Vec<EventId> = self.buckets[self.curr_bucket]
200 .iter()
201 .filter_map(|(id, event)| {
202 (event.tick() == self.curr_tick && event.delta_tick() == self.curr_delta_tick)
203 .then_some(*id)
204 })
205 .collect();
206
207 event_ids.sort_unstable();
208
209 event_ids
210 .into_iter()
211 .filter_map(|id| self.remove_internal(id))
212 .collect()
213 }
214
215 fn bucket_index(&self, tick: u64) -> usize {
216 (tick % self.buckets.len() as u64) as usize
217 }
218}
219
220fn build_buckets<T>(buckets_num: usize) -> Vec<Bucket<T>> {
221 let mut buckets = Vec::with_capacity(buckets_num);
222 for _ in 0..buckets_num {
223 buckets.push(HashMap::new());
224 }
225 buckets
226}