hashed_wheel_timer/
lib.rs

1use std::cell::{RefCell, RefMut};
2use std::error::Error;
3use std::ops::Deref;
4use std::rc::Rc;
5use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
6use std::sync::mpsc;
7use std::sync::mpsc::{Receiver, Sender};
8use std::sync::{Arc, Condvar, Mutex};
9use std::thread;
10use std::time::{Duration, SystemTime};
11
/// Worker lifecycle: the timer exists but its background thread has not been spawned.
const WORKER_STATE_INIT: u8 = 0;
/// Worker lifecycle: the background thread is running and advancing the wheel.
const WORKER_STATE_STARTED: u8 = 1;
/// Worker lifecycle: shutdown was requested; `start` refuses to run again.
const WORKER_STATE_SHUTDOWN: u8 = 2;
15
16#[derive(Clone)]
17pub struct WheelTimer {
18    worker_state: Arc<AtomicU8>, // 0 - init, 1 - started, 2 - shutdown
19    start_time: u64,
20    tick_duration: u64, // the duration between tick, time unit is millisecond
21    ticks_per_wheel: u32,
22    mask: u64,
23    condvar: Arc<(Mutex<u64>, Condvar)>,
24    sender: Option<Sender<WheelTimeout>>,
25}
26
27impl WheelTimer {
28    pub fn new(tick_duration: u64, ticks_per_wheel: u32) -> Result<WheelTimer, Box<dyn Error>> {
29        if tick_duration <= 0 {
30            return Err(format!("tickDuration must be greater than 0: {}", tick_duration).into());
31        }
32        if ticks_per_wheel <= 0 {
33            return Err(
34                format!("ticksPerWheel must be greater than 0: {}", ticks_per_wheel).into(),
35            );
36        }
37        if ticks_per_wheel > 1073741824 {
38            return Err(format!(
39                "ticksPerWheel may not be greater than 2^30: {}",
40                ticks_per_wheel
41            )
42                .into());
43        }
44        let ticks_per_wheel = normalize_ticks_per_wheel(ticks_per_wheel);
45        let mask = (ticks_per_wheel - 1) as u64;
46
47        // Prevent overflow
48        if tick_duration >= u64::MAX / ticks_per_wheel as u64 {
49            return Err(format!(
50                "tickDuration: {} (expected: 0 < tickDuration in nanos < {}",
51                tick_duration,
52                u64::MAX / ticks_per_wheel as u64
53            )
54                .into());
55        }
56        let mut timer = WheelTimer {
57            worker_state: Arc::new(AtomicU8::new(WORKER_STATE_INIT)),
58            start_time: 0,
59            tick_duration,
60            ticks_per_wheel,
61            mask,
62            condvar: Arc::new((Mutex::new(0), Condvar::new())),
63            sender: None,
64        };
65        timer.start().unwrap();
66        Ok(timer)
67    }
68
69    pub fn start(&mut self) -> Result<(), Box<dyn Error + '_>> {
70        match self.worker_state.load(Ordering::SeqCst) {
71            WORKER_STATE_INIT => {
72                let ret = self.worker_state.compare_exchange(
73                    WORKER_STATE_INIT,
74                    WORKER_STATE_STARTED,
75                    Ordering::SeqCst,
76                    Ordering::Acquire,
77                );
78                match ret {
79                    Ok(_) => {
80                        let (tx, rx) = mpsc::channel();
81                        self.sender = Some(tx);
82                        let worker_state = self.worker_state.clone();
83                        let condvar = self.condvar.clone();
84                        let tick_duration = self.tick_duration;
85                        let mask = self.mask;
86                        let ticks_per_wheel = self.ticks_per_wheel;
87
88                        thread::spawn(move || {
89                            let mut worker = Worker::new(
90                                worker_state,
91                                condvar,
92                                tick_duration,
93                                mask,
94                                ticks_per_wheel,
95                                rx,
96                            );
97                            worker.start();
98                        });
99                    }
100                    Err(_) => {
101                        // nothing to do
102                    }
103                }
104            }
105            WORKER_STATE_STARTED => {
106                // nothing to do
107            }
108            WORKER_STATE_SHUTDOWN => return Err("cannot be started once stopped".into()),
109            _ => return Err("Invalid worker state".into()),
110        }
111        // Wait worker thread initialize start_time finish
112        let (lock, condvar) = self.condvar.deref();
113        let mut guard = lock.lock()?;
114        while *guard == 0 {
115            guard = condvar.wait(guard)?;
116            self.start_time = *guard;
117        }
118        Ok(())
119    }
120
121    pub fn stop(&self) {
122        let ret = self.worker_state.compare_exchange(
123            WORKER_STATE_STARTED,
124            WORKER_STATE_SHUTDOWN,
125            Ordering::SeqCst,
126            Ordering::Acquire,
127        );
128        match ret {
129            Ok(_) => {
130                // releases resources
131            }
132            Err(_) => {
133                self.worker_state
134                    .swap(WORKER_STATE_SHUTDOWN, Ordering::SeqCst);
135            }
136        }
137    }
138
139    pub fn new_timeout(&mut self, task: Box<dyn TimerTask + Send>, delay: Duration) {
140        let deadline = system_time_unix() + delay.as_millis() as u64 - self.start_time;
141        let timeout = WheelTimeout::new(task, deadline);
142        let sender = self.sender.as_ref().unwrap();
143        sender.send(timeout).unwrap();
144    }
145}
146
/// Rounds `ticks_per_wheel` up to the nearest power of two (0 and 1 both map to 1).
///
/// A power-of-two wheel size lets a tick be mapped to a bucket with `tick & mask`
/// instead of a modulo.
///
/// # Panics
///
/// Panics if the rounded value does not fit in a `u32` (input greater than 2^31). The
/// previous hand-rolled loop spun forever in that case. `WheelTimer::new` rejects inputs
/// above 2^30 before calling this.
fn normalize_ticks_per_wheel(ticks_per_wheel: u32) -> u32 {
    ticks_per_wheel
        .checked_next_power_of_two()
        .expect("ticks_per_wheel is too large to round up to a power of two")
}
154
155struct Worker {
156    worker_state: Arc<AtomicU8>,
157    condvar: Arc<(Mutex<u64>, Condvar)>,
158    tick: u64,
159    tick_duration: u64,
160    mask: u64,
161    wheel: Vec<WheelBucket>,
162    start_time: u64,
163    receiver: Receiver<WheelTimeout>,
164    last_task_id: AtomicU64,
165}
166
167impl Worker {
168    fn new(
169        worker_state: Arc<AtomicU8>,
170        condvar: Arc<(Mutex<u64>, Condvar)>,
171        tick_duration: u64,
172        mask: u64,
173        ticks_per_wheel: u32,
174        rx: Receiver<WheelTimeout>,
175    ) -> Worker {
176        let wheel = create_wheel(ticks_per_wheel).unwrap();
177
178        Worker {
179            worker_state,
180            condvar,
181            tick: 0,
182            tick_duration,
183            mask,
184            wheel,
185            start_time: 0,
186            receiver: rx,
187            last_task_id: AtomicU64::new(1),
188        }
189    }
190
191    fn start(&mut self) {
192        // Initialize the startTime.
193        let mut start_time = system_time_unix();
194        if start_time == 0 {
195            start_time = 1;
196        }
197        self.start_time = start_time;
198
199        // Notify the other thread waiting for the initialization at start()
200        let (lock, condvar) = self.condvar.deref();
201        let mut guard = lock.lock().unwrap();
202        *guard = start_time;
203        condvar.notify_one();
204        drop(guard);
205
206        while self.worker_state.load(Ordering::SeqCst) == WORKER_STATE_STARTED {
207            let deadline = self.wait_for_next_tick();
208            if deadline > 0 {
209                self.transfer_timeouts_to_buckets();
210                let idx = self.tick & self.mask;
211                let bucket = self.wheel.get_mut(idx as usize).unwrap();
212                bucket.expire_timeouts(deadline);
213                self.tick += 1;
214            }
215        }
216        println!("Worker shutdown")
217    }
218
219    fn wait_for_next_tick(&self) -> u64 {
220        let deadline = (self.tick_duration * (self.tick + 1)) as i64;
221        loop {
222            let current_time = (system_time_unix() - self.start_time) as i64;
223            let sleep_time_ms = (deadline - current_time + 999999) / 1000000;
224
225            if sleep_time_ms <= 0 {
226                return current_time as u64;
227            }
228            thread::sleep(Duration::new(0, (sleep_time_ms * 1000000) as u32));
229        }
230    }
231
232    fn transfer_timeouts_to_buckets(&mut self) {
233        for _ in 0..100000 {
234            match self.receiver.try_recv() {
235                Ok(timeout) => {
236                    let task_id = self.last_task_id.fetch_add(1, Ordering::SeqCst);
237                    let calculated = timeout.deadline / self.tick_duration;
238
239                    let mut bucket_timeout =
240                        BucketTimeout::new(task_id, timeout.task, timeout.deadline);
241                    bucket_timeout.remaining_rounds =
242                        (calculated - self.tick) / self.wheel.len() as u64;
243
244                    let mut ticks = self.tick;
245                    if calculated > self.tick {
246                        ticks = calculated;
247                    }
248                    let stop_index = ticks & self.mask;
249
250                    let bucket = self.wheel.get_mut(stop_index as usize).unwrap();
251                    bucket.add_timeout(bucket_timeout);
252                }
253                Err(_) => {
254                    break;
255                }
256            }
257        }
258    }
259}
260
261fn create_wheel(ticks_per_wheel: u32) -> Result<Vec<WheelBucket>, Box<dyn Error>> {
262    let mut wheel = Vec::with_capacity(ticks_per_wheel as usize);
263    for _ in 0..ticks_per_wheel {
264        wheel.push(WheelBucket {
265            head: None,
266            tail: None,
267        })
268    }
269    Ok(wheel)
270}
271
/// Returns the current wall-clock time in whole milliseconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time before the Unix epoch.
pub fn system_time_unix() -> u64 {
    let now = SystemTime::now();
    let since_epoch = now.duration_since(SystemTime::UNIX_EPOCH).unwrap();
    since_epoch.as_millis() as u64
}
278
279pub struct WheelBucket {
280    head: Option<Rc<RefCell<BucketTimeout>>>,
281    tail: Option<Rc<RefCell<BucketTimeout>>>,
282}
283
284impl WheelBucket {
285    fn add_timeout(&mut self, mut timeout: BucketTimeout) {
286        match self.head.as_ref() {
287            None => {
288                let rc_timeout = Rc::new(RefCell::new(timeout));
289                self.head = Some(rc_timeout.clone());
290                self.tail = Some(rc_timeout);
291            }
292            Some(_) => {
293                let rc_tail = self.tail.as_ref().unwrap().clone();
294                timeout.prev = Some(rc_tail);
295                let rc_timeout = Rc::new(RefCell::new(timeout));
296                {
297                    let mut tail = self.tail.as_ref().unwrap().deref().borrow_mut();
298                    tail.next = Some(rc_timeout.clone());
299                }
300                self.tail = Some(rc_timeout);
301            }
302        }
303    }
304
305    fn expire_timeouts(&mut self, deadline: u64) {
306        let mut current = self.head.clone();
307        loop {
308            match current {
309                None => {
310                    return;
311                }
312                Some(timeout) => {
313                    let mut next = RefCell::borrow(&timeout).next.clone();
314
315                    let mut timeout_mut = RefCell::borrow_mut(&timeout);
316                    if timeout_mut.remaining_rounds <= 0 {
317                        next = self.remove(timeout_mut);
318
319                        let mut timeout_mut = RefCell::borrow_mut(&timeout);
320                        timeout_mut.prev = None;
321                        timeout_mut.next = None;
322                        if timeout_mut.deadline <= deadline {
323                            timeout_mut.expire();
324                        } else {
325                            // The timeout was placed into a wrong slot. This should never happen.
326                            panic!(
327                                "timeout.deadline {} > deadline {}",
328                                timeout_mut.deadline, deadline
329                            )
330                        }
331                    } else if timeout_mut.is_cancelled() {
332                        next = self.remove(timeout_mut);
333                    } else {
334                        timeout_mut.remaining_rounds -= 1;
335                    }
336                    current = next;
337                }
338            }
339        }
340    }
341
342    fn remove(&mut self, timeout: RefMut<BucketTimeout>) -> Option<Rc<RefCell<BucketTimeout>>> {
343        let prev = timeout.prev.clone();
344        let next = timeout.next.clone();
345        match prev.clone() {
346            None => {}
347            Some(v) => {
348                let mut prev = v.deref().borrow_mut();
349                prev.next = next.clone();
350            }
351        }
352        match next.clone() {
353            None => {}
354            Some(v) => {
355                let mut next = v.deref().borrow_mut();
356                next.prev = prev.clone()
357            }
358        }
359        let task_id = timeout.task_id;
360        // release borrow
361        drop(timeout);
362
363        let head_task_id = self.head.as_ref().unwrap().deref().borrow().task_id;
364        let tail_task_id = self.tail.as_ref().unwrap().deref().borrow().task_id;
365        if task_id == head_task_id {
366            if task_id == tail_task_id {
367                self.tail = None;
368                self.head = None;
369            } else {
370                self.head = next.clone()
371            }
372        } else if task_id == tail_task_id {
373            self.tail = prev.clone();
374        }
375        next
376    }
377}
378
/// Timeout state: scheduled and neither run nor cancelled.
const ST_INIT: u8 = 0;
/// Timeout state: cancelled; it will be unlinked without running.
const ST_CANCELLED: u8 = 1;
/// Timeout state: its task has been handed to `TimerTask::run`.
const ST_EXPIRED: u8 = 2;
382
383struct BucketTimeout {
384    task_id: u64,
385    state: AtomicU8, // 0: init, 1: cancelled, 2: expired
386    deadline: u64,
387    remaining_rounds: u64,
388    task: Box<dyn TimerTask + Send>,
389    prev: Option<Rc<RefCell<BucketTimeout>>>,
390    next: Option<Rc<RefCell<BucketTimeout>>>,
391}
392
393impl BucketTimeout {
394    fn new(task_id: u64, task: Box<dyn TimerTask + Send>, deadline: u64) -> BucketTimeout {
395        BucketTimeout {
396            task_id,
397            state: AtomicU8::new(ST_INIT),
398            deadline,
399            task,
400            remaining_rounds: 0,
401            prev: None,
402            next: None,
403        }
404    }
405
406    fn state(&self) -> u8 {
407        self.state.load(Ordering::SeqCst)
408    }
409
410    fn is_cancelled(&self) -> bool {
411        self.state() == ST_CANCELLED
412    }
413
414    fn compare_exchange(&self, expected: u8, state: u8) -> bool {
415        let ret = self
416            .state
417            .compare_exchange(expected, state, Ordering::SeqCst, Ordering::Acquire);
418        match ret {
419            Err(_) => false,
420            Ok(_) => true,
421        }
422    }
423
424    fn expire(&mut self) {
425        if !self.compare_exchange(ST_INIT, ST_EXPIRED) {
426            return;
427        }
428        self.task.run();
429    }
430}
431
/// Work executed by the timer's worker thread when its timeout expires.
///
/// Tasks are submitted as `Box<dyn TimerTask + Send>` through `WheelTimer::new_timeout`.
pub trait TimerTask {
    /// Performs the task; invoked on the worker thread.
    fn run(&mut self);
}
435
436struct WheelTimeout {
437    deadline: u64,
438    task: Box<dyn TimerTask + Send>,
439}
440
441impl WheelTimeout {
442    fn new(task: Box<dyn TimerTask + Send>, deadline: u64) -> WheelTimeout {
443        WheelTimeout { deadline, task }
444    }
445}