1use std::cell::{RefCell, RefMut};
2use std::error::Error;
3use std::ops::Deref;
4use std::rc::Rc;
5use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
6use std::sync::mpsc;
7use std::sync::mpsc::{Receiver, Sender};
8use std::sync::{Arc, Condvar, Mutex};
9use std::thread;
10use std::time::{Duration, SystemTime};
11
/// Worker lifecycle: the timer exists but its worker thread has not been launched yet.
const WORKER_STATE_INIT: u8 = 0;
/// Worker lifecycle: the worker thread is running and advancing the wheel.
const WORKER_STATE_STARTED: u8 = 1;
/// Worker lifecycle: the timer was stopped; the worker loop exits and it cannot be restarted.
const WORKER_STATE_SHUTDOWN: u8 = 2;
15
16#[derive(Clone)]
17pub struct WheelTimer {
18 worker_state: Arc<AtomicU8>, start_time: u64,
20 tick_duration: u64, ticks_per_wheel: u32,
22 mask: u64,
23 condvar: Arc<(Mutex<u64>, Condvar)>,
24 sender: Option<Sender<WheelTimeout>>,
25}
26
27impl WheelTimer {
28 pub fn new(tick_duration: u64, ticks_per_wheel: u32) -> Result<WheelTimer, Box<dyn Error>> {
29 if tick_duration <= 0 {
30 return Err(format!("tickDuration must be greater than 0: {}", tick_duration).into());
31 }
32 if ticks_per_wheel <= 0 {
33 return Err(
34 format!("ticksPerWheel must be greater than 0: {}", ticks_per_wheel).into(),
35 );
36 }
37 if ticks_per_wheel > 1073741824 {
38 return Err(format!(
39 "ticksPerWheel may not be greater than 2^30: {}",
40 ticks_per_wheel
41 )
42 .into());
43 }
44 let ticks_per_wheel = normalize_ticks_per_wheel(ticks_per_wheel);
45 let mask = (ticks_per_wheel - 1) as u64;
46
47 if tick_duration >= u64::MAX / ticks_per_wheel as u64 {
49 return Err(format!(
50 "tickDuration: {} (expected: 0 < tickDuration in nanos < {}",
51 tick_duration,
52 u64::MAX / ticks_per_wheel as u64
53 )
54 .into());
55 }
56 let mut timer = WheelTimer {
57 worker_state: Arc::new(AtomicU8::new(WORKER_STATE_INIT)),
58 start_time: 0,
59 tick_duration,
60 ticks_per_wheel,
61 mask,
62 condvar: Arc::new((Mutex::new(0), Condvar::new())),
63 sender: None,
64 };
65 timer.start().unwrap();
66 Ok(timer)
67 }
68
69 pub fn start(&mut self) -> Result<(), Box<dyn Error + '_>> {
70 match self.worker_state.load(Ordering::SeqCst) {
71 WORKER_STATE_INIT => {
72 let ret = self.worker_state.compare_exchange(
73 WORKER_STATE_INIT,
74 WORKER_STATE_STARTED,
75 Ordering::SeqCst,
76 Ordering::Acquire,
77 );
78 match ret {
79 Ok(_) => {
80 let (tx, rx) = mpsc::channel();
81 self.sender = Some(tx);
82 let worker_state = self.worker_state.clone();
83 let condvar = self.condvar.clone();
84 let tick_duration = self.tick_duration;
85 let mask = self.mask;
86 let ticks_per_wheel = self.ticks_per_wheel;
87
88 thread::spawn(move || {
89 let mut worker = Worker::new(
90 worker_state,
91 condvar,
92 tick_duration,
93 mask,
94 ticks_per_wheel,
95 rx,
96 );
97 worker.start();
98 });
99 }
100 Err(_) => {
101 }
103 }
104 }
105 WORKER_STATE_STARTED => {
106 }
108 WORKER_STATE_SHUTDOWN => return Err("cannot be started once stopped".into()),
109 _ => return Err("Invalid worker state".into()),
110 }
111 let (lock, condvar) = self.condvar.deref();
113 let mut guard = lock.lock()?;
114 while *guard == 0 {
115 guard = condvar.wait(guard)?;
116 self.start_time = *guard;
117 }
118 Ok(())
119 }
120
121 pub fn stop(&self) {
122 let ret = self.worker_state.compare_exchange(
123 WORKER_STATE_STARTED,
124 WORKER_STATE_SHUTDOWN,
125 Ordering::SeqCst,
126 Ordering::Acquire,
127 );
128 match ret {
129 Ok(_) => {
130 }
132 Err(_) => {
133 self.worker_state
134 .swap(WORKER_STATE_SHUTDOWN, Ordering::SeqCst);
135 }
136 }
137 }
138
139 pub fn new_timeout(&mut self, task: Box<dyn TimerTask + Send>, delay: Duration) {
140 let deadline = system_time_unix() + delay.as_millis() as u64 - self.start_time;
141 let timeout = WheelTimeout::new(task, deadline);
142 let sender = self.sender.as_ref().unwrap();
143 sender.send(timeout).unwrap();
144 }
145}
146
/// Rounds `ticks_per_wheel` up to the next power of two (0 and 1 both map to 1).
///
/// # Panics
///
/// Panics if the result would not fit in a `u32` (inputs above 2^31). The previous
/// shift loop spun forever on such inputs; `WheelTimer::new` already rejects
/// anything above 2^30.
fn normalize_ticks_per_wheel(ticks_per_wheel: u32) -> u32 {
    ticks_per_wheel
        .checked_next_power_of_two()
        .expect("ticks_per_wheel must not exceed 2^31")
}
154
155struct Worker {
156 worker_state: Arc<AtomicU8>,
157 condvar: Arc<(Mutex<u64>, Condvar)>,
158 tick: u64,
159 tick_duration: u64,
160 mask: u64,
161 wheel: Vec<WheelBucket>,
162 start_time: u64,
163 receiver: Receiver<WheelTimeout>,
164 last_task_id: AtomicU64,
165}
166
167impl Worker {
168 fn new(
169 worker_state: Arc<AtomicU8>,
170 condvar: Arc<(Mutex<u64>, Condvar)>,
171 tick_duration: u64,
172 mask: u64,
173 ticks_per_wheel: u32,
174 rx: Receiver<WheelTimeout>,
175 ) -> Worker {
176 let wheel = create_wheel(ticks_per_wheel).unwrap();
177
178 Worker {
179 worker_state,
180 condvar,
181 tick: 0,
182 tick_duration,
183 mask,
184 wheel,
185 start_time: 0,
186 receiver: rx,
187 last_task_id: AtomicU64::new(1),
188 }
189 }
190
191 fn start(&mut self) {
192 let mut start_time = system_time_unix();
194 if start_time == 0 {
195 start_time = 1;
196 }
197 self.start_time = start_time;
198
199 let (lock, condvar) = self.condvar.deref();
201 let mut guard = lock.lock().unwrap();
202 *guard = start_time;
203 condvar.notify_one();
204 drop(guard);
205
206 while self.worker_state.load(Ordering::SeqCst) == WORKER_STATE_STARTED {
207 let deadline = self.wait_for_next_tick();
208 if deadline > 0 {
209 self.transfer_timeouts_to_buckets();
210 let idx = self.tick & self.mask;
211 let bucket = self.wheel.get_mut(idx as usize).unwrap();
212 bucket.expire_timeouts(deadline);
213 self.tick += 1;
214 }
215 }
216 println!("Worker shutdown")
217 }
218
219 fn wait_for_next_tick(&self) -> u64 {
220 let deadline = (self.tick_duration * (self.tick + 1)) as i64;
221 loop {
222 let current_time = (system_time_unix() - self.start_time) as i64;
223 let sleep_time_ms = (deadline - current_time + 999999) / 1000000;
224
225 if sleep_time_ms <= 0 {
226 return current_time as u64;
227 }
228 thread::sleep(Duration::new(0, (sleep_time_ms * 1000000) as u32));
229 }
230 }
231
232 fn transfer_timeouts_to_buckets(&mut self) {
233 for _ in 0..100000 {
234 match self.receiver.try_recv() {
235 Ok(timeout) => {
236 let task_id = self.last_task_id.fetch_add(1, Ordering::SeqCst);
237 let calculated = timeout.deadline / self.tick_duration;
238
239 let mut bucket_timeout =
240 BucketTimeout::new(task_id, timeout.task, timeout.deadline);
241 bucket_timeout.remaining_rounds =
242 (calculated - self.tick) / self.wheel.len() as u64;
243
244 let mut ticks = self.tick;
245 if calculated > self.tick {
246 ticks = calculated;
247 }
248 let stop_index = ticks & self.mask;
249
250 let bucket = self.wheel.get_mut(stop_index as usize).unwrap();
251 bucket.add_timeout(bucket_timeout);
252 }
253 Err(_) => {
254 break;
255 }
256 }
257 }
258 }
259}
260
261fn create_wheel(ticks_per_wheel: u32) -> Result<Vec<WheelBucket>, Box<dyn Error>> {
262 let mut wheel = Vec::with_capacity(ticks_per_wheel as usize);
263 for _ in 0..ticks_per_wheel {
264 wheel.push(WheelBucket {
265 head: None,
266 tail: None,
267 })
268 }
269 Ok(wheel)
270}
271
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reads earlier than the Unix epoch.
pub fn system_time_unix() -> u64 {
    let since_epoch = SystemTime::UNIX_EPOCH.elapsed().unwrap();
    since_epoch.as_millis() as u64
}
278
279pub struct WheelBucket {
280 head: Option<Rc<RefCell<BucketTimeout>>>,
281 tail: Option<Rc<RefCell<BucketTimeout>>>,
282}
283
284impl WheelBucket {
285 fn add_timeout(&mut self, mut timeout: BucketTimeout) {
286 match self.head.as_ref() {
287 None => {
288 let rc_timeout = Rc::new(RefCell::new(timeout));
289 self.head = Some(rc_timeout.clone());
290 self.tail = Some(rc_timeout);
291 }
292 Some(_) => {
293 let rc_tail = self.tail.as_ref().unwrap().clone();
294 timeout.prev = Some(rc_tail);
295 let rc_timeout = Rc::new(RefCell::new(timeout));
296 {
297 let mut tail = self.tail.as_ref().unwrap().deref().borrow_mut();
298 tail.next = Some(rc_timeout.clone());
299 }
300 self.tail = Some(rc_timeout);
301 }
302 }
303 }
304
305 fn expire_timeouts(&mut self, deadline: u64) {
306 let mut current = self.head.clone();
307 loop {
308 match current {
309 None => {
310 return;
311 }
312 Some(timeout) => {
313 let mut next = RefCell::borrow(&timeout).next.clone();
314
315 let mut timeout_mut = RefCell::borrow_mut(&timeout);
316 if timeout_mut.remaining_rounds <= 0 {
317 next = self.remove(timeout_mut);
318
319 let mut timeout_mut = RefCell::borrow_mut(&timeout);
320 timeout_mut.prev = None;
321 timeout_mut.next = None;
322 if timeout_mut.deadline <= deadline {
323 timeout_mut.expire();
324 } else {
325 panic!(
327 "timeout.deadline {} > deadline {}",
328 timeout_mut.deadline, deadline
329 )
330 }
331 } else if timeout_mut.is_cancelled() {
332 next = self.remove(timeout_mut);
333 } else {
334 timeout_mut.remaining_rounds -= 1;
335 }
336 current = next;
337 }
338 }
339 }
340 }
341
342 fn remove(&mut self, timeout: RefMut<BucketTimeout>) -> Option<Rc<RefCell<BucketTimeout>>> {
343 let prev = timeout.prev.clone();
344 let next = timeout.next.clone();
345 match prev.clone() {
346 None => {}
347 Some(v) => {
348 let mut prev = v.deref().borrow_mut();
349 prev.next = next.clone();
350 }
351 }
352 match next.clone() {
353 None => {}
354 Some(v) => {
355 let mut next = v.deref().borrow_mut();
356 next.prev = prev.clone()
357 }
358 }
359 let task_id = timeout.task_id;
360 drop(timeout);
362
363 let head_task_id = self.head.as_ref().unwrap().deref().borrow().task_id;
364 let tail_task_id = self.tail.as_ref().unwrap().deref().borrow().task_id;
365 if task_id == head_task_id {
366 if task_id == tail_task_id {
367 self.tail = None;
368 self.head = None;
369 } else {
370 self.head = next.clone()
371 }
372 } else if task_id == tail_task_id {
373 self.tail = prev.clone();
374 }
375 next
376 }
377}
378
/// Bucket-entry state: scheduled and not yet expired or cancelled.
const ST_INIT: u8 = 0;
/// Bucket-entry state: cancelled; the entry is dropped from its bucket without running.
const ST_CANCELLED: u8 = 1;
/// Bucket-entry state: expired; the task has been handed to `run`.
const ST_EXPIRED: u8 = 2;
382
383struct BucketTimeout {
384 task_id: u64,
385 state: AtomicU8, deadline: u64,
387 remaining_rounds: u64,
388 task: Box<dyn TimerTask + Send>,
389 prev: Option<Rc<RefCell<BucketTimeout>>>,
390 next: Option<Rc<RefCell<BucketTimeout>>>,
391}
392
393impl BucketTimeout {
394 fn new(task_id: u64, task: Box<dyn TimerTask + Send>, deadline: u64) -> BucketTimeout {
395 BucketTimeout {
396 task_id,
397 state: AtomicU8::new(ST_INIT),
398 deadline,
399 task,
400 remaining_rounds: 0,
401 prev: None,
402 next: None,
403 }
404 }
405
406 fn state(&self) -> u8 {
407 self.state.load(Ordering::SeqCst)
408 }
409
410 fn is_cancelled(&self) -> bool {
411 self.state() == ST_CANCELLED
412 }
413
414 fn compare_exchange(&self, expected: u8, state: u8) -> bool {
415 let ret = self
416 .state
417 .compare_exchange(expected, state, Ordering::SeqCst, Ordering::Acquire);
418 match ret {
419 Err(_) => false,
420 Ok(_) => true,
421 }
422 }
423
424 fn expire(&mut self) {
425 if !self.compare_exchange(ST_INIT, ST_EXPIRED) {
426 return;
427 }
428 self.task.run();
429 }
430}
431
/// Work to be executed by the timer's worker thread once a timeout fires.
///
/// `run` is invoked on the worker thread itself, so a long-running implementation delays
/// every later tick.
pub trait TimerTask {
    /// Performs the scheduled work.
    fn run(&mut self);
}
435
436struct WheelTimeout {
437 deadline: u64,
438 task: Box<dyn TimerTask + Send>,
439}
440
441impl WheelTimeout {
442 fn new(task: Box<dyn TimerTask + Send>, deadline: u64) -> WheelTimeout {
443 WheelTimeout { deadline, task }
444 }
445}