async_timer_rs/hashed/mod.rs

1use std::{
2    collections::{HashMap, HashSet},
3    sync::{Arc, Mutex},
4    task::{Poll, Waker},
5    time::{Duration, Instant},
6};
7
8mod timewheel;
9use timewheel::*;
10
11#[derive(Clone)]
12pub struct TimerExecutor {
13    tick_duration: Duration,
14    inner: Arc<Mutex<TimerExecutorImpl>>,
15}
16
17struct TimerExecutorImpl {
18    timer_id_seq: usize,
19    wheel: TimeWheel<usize>,
20    wakers: HashMap<usize, std::task::Waker>,
21    fired: HashSet<usize>,
22}
23
24impl TimerExecutorImpl {
25    fn new(step: u64) -> Self {
26        Self {
27            timer_id_seq: 0,
28            wheel: TimeWheel::new(step),
29            wakers: Default::default(),
30            fired: Default::default(),
31        }
32    }
33
34    fn create_timer(&mut self, duration: u64) -> usize {
35        self.timer_id_seq += 1;
36
37        let timer = self.timer_id_seq;
38
39        self.wheel.add(duration, timer);
40
41        timer
42    }
43
44    fn poll(&mut self, timer: usize, waker: Waker) -> Poll<()> {
45        if self.fired.remove(&timer) {
46            Poll::Ready(())
47        } else {
48            log::debug!("inser timer {} waker", timer);
49            self.wakers.insert(timer, waker);
50            Poll::Pending
51        }
52    }
53
54    fn tick(&mut self) {
55        if let Poll::Ready(timers) = self.wheel.tick() {
56            log::debug!("ready timers {:?}", timers);
57            for timer in timers {
58                self.fired.insert(timer);
59
60                if let Some(waker) = self.wakers.remove(&timer) {
61                    log::debug!("wake up timer {}", timer);
62                    waker.wake_by_ref();
63                }
64            }
65        }
66    }
67}
68
69impl TimerExecutor {
70    pub fn new(step: u64, tick_duration: Duration) -> Self {
71        let inner: Arc<Mutex<TimerExecutorImpl>> =
72            Arc::new(Mutex::new(TimerExecutorImpl::new(step)));
73
74        let inner_tick = inner.clone();
75
76        std::thread::spawn(move || {
77            let mut inaccuracy: u128 = 0;
78            // When no other strong reference is alive, stop tick thread
79            while Arc::strong_count(&inner_tick) > 1 {
80                // Correct the cumulative deviation
81                let call_times = inaccuracy / tick_duration.as_millis() + 1;
82
83                inaccuracy = inaccuracy % tick_duration.as_millis();
84
85                let now = Instant::now();
86
87                for _ in 0..call_times {
88                    inner_tick.lock().unwrap().tick();
89                }
90
91                std::thread::sleep(tick_duration);
92                inaccuracy += now.elapsed().as_millis() - tick_duration.as_millis();
93            }
94        });
95
96        Self {
97            inner,
98            tick_duration,
99        }
100    }
101
102    /// Create a new timeout future instance.
103    pub fn timeout(&self, duration: Duration) -> Timeout {
104        let mut ticks = duration.as_millis() / self.tick_duration.as_millis();
105
106        if ticks == 0 {
107            ticks = 1;
108        }
109
110        let timer_id = self.inner.lock().unwrap().create_timer(ticks as u64);
111
112        Timeout {
113            timer_id,
114            executor: self.inner.clone(),
115        }
116    }
117}
118
119#[derive(Clone)]
120pub struct Timeout {
121    timer_id: usize,
122    executor: Arc<Mutex<TimerExecutorImpl>>,
123}
124
125impl std::future::Future for Timeout {
126    type Output = ();
127
128    fn poll(
129        self: std::pin::Pin<&mut Self>,
130        cx: &mut std::task::Context<'_>,
131    ) -> std::task::Poll<Self::Output> {
132        self.executor
133            .lock()
134            .unwrap()
135            .poll(self.timer_id, cx.waker().clone())
136    }
137}
138
139impl crate::Timer for Timeout {
140    fn new(duration: Duration) -> Self {
141        global_timer_executor().timeout(duration)
142    }
143}
144
145impl crate::TimerWithContext for Timeout {
146    type Context = TimerExecutor;
147    fn new_with_context<C>(duration: Duration, mut context: C) -> Self
148    where
149        C: AsMut<Self::Context>,
150    {
151        context.as_mut().timeout(duration)
152    }
153}
154
155/// Accesss global static timer executor instance
156pub fn global_timer_executor() -> &'static TimerExecutor {
157    use once_cell::sync::OnceCell;
158
159    static INSTANCE: OnceCell<TimerExecutor> = OnceCell::new();
160
161    INSTANCE.get_or_init(|| TimerExecutor::new(3600, Duration::from_millis(100)))
162}
163
#[cfg(test)]
mod tests {
    use std::time::{Duration, Instant};

    use crate::Timer;

    use super::Timeout;

    /// Tick length of the global executor that `Timeout::new` schedules on.
    const GLOBAL_TICK: Duration = Duration::from_millis(100);

    #[async_std::test]
    async fn test() {
        _ = pretty_env_logger::try_init();

        /// Awaits a timeout of `duration` and checks the measured wall-clock time.
        async fn test_timeout(duration: Duration) {
            let now = Instant::now();

            Timeout::new(duration).await;

            let elapsed = now.elapsed();

            log::debug!("system time elapsed {:?}", elapsed);

            // The timer is registered at an arbitrary point inside the current tick,
            // so its first counted tick may arrive early. Comparing `as_secs()` would
            // then fail when it fires a few milliseconds before `duration`. Allow up
            // to one tick early and (as before) less than one second late.
            assert!(
                elapsed + GLOBAL_TICK >= duration,
                "timeout {:?} fired too early: {:?}",
                duration,
                elapsed
            );
            assert!(
                elapsed < duration + Duration::from_secs(1),
                "timeout {:?} fired too late: {:?}",
                duration,
                elapsed
            );
        }

        test_timeout(Duration::from_secs(2)).await;

        test_timeout(Duration::from_secs(4)).await;

        test_timeout(Duration::from_secs(10)).await;

        test_timeout(Duration::from_secs(60)).await;
    }
}
196}