1use parking_lot::RwLock;
32use std::collections::BTreeMap;
33use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
34use std::sync::Arc;
35use std::time::{Duration, Instant};
36use thread_local::ThreadLocal;
37use tokio::sync::Notify;
38
/// Clock granularity in milliseconds.
///
/// Deadlines are rounded up to a multiple of this value, and the clock
/// thread sleeps this long between ticks.
const RESOLUTION_MS: u64 = 10;
/// [`RESOLUTION_MS`] expressed as a [`Duration`]; the clock thread's tick period.
const RESOLUTION_DURATION: Duration = Duration::from_millis(RESOLUTION_MS);
41
/// Rounds `raw` up to the nearest multiple of `resolution`.
///
/// Values that are already a multiple of `resolution` (including `0`) are
/// returned unchanged.
///
/// # Panics
///
/// Panics if `resolution` is zero (remainder by zero).
#[inline]
fn round_to(raw: u128, resolution: u128) -> u128 {
    // Take the remainder first so that `raw == 0` cannot underflow, which a
    // `raw - 1`-based formula would do.
    match raw % resolution {
        0 => raw,
        rem => raw + (resolution - rem),
    }
}
47#[derive(PartialEq, PartialOrd, Eq, Ord, Clone, Copy, Debug)]
49struct Time(u128);
50
51impl From<u128> for Time {
52 fn from(raw_ms: u128) -> Self {
53 Time(round_to(raw_ms, RESOLUTION_MS as u128))
54 }
55}
56
57impl From<Duration> for Time {
58 fn from(d: Duration) -> Self {
59 Time(round_to(d.as_millis(), RESOLUTION_MS as u128))
60 }
61}
62
63impl Time {
64 pub fn not_after(&self, ts: u128) -> bool {
65 self.0 <= ts
66 }
67}
68
69pub struct TimerStub(Arc<Notify>, Arc<AtomicBool>);
71
72impl TimerStub {
73 pub async fn poll(self) {
75 if self.1.load(Ordering::SeqCst) {
76 return;
77 }
78 self.0.notified().await;
79 }
80}
81
82struct Timer(Arc<Notify>, Arc<AtomicBool>);
83
84impl Timer {
85 pub fn new() -> Self {
86 Timer(Arc::new(Notify::new()), Arc::new(AtomicBool::new(false)))
87 }
88
89 pub fn fire(&self) {
90 self.1.store(true, Ordering::SeqCst);
91 self.0.notify_waiters();
92 }
93
94 pub fn subscribe(&self) -> TimerStub {
95 TimerStub(self.0.clone(), self.1.clone())
96 }
97}
98
99pub struct TimerManager {
101 timers: ThreadLocal<RwLock<BTreeMap<Time, Timer>>>,
103 zero: Instant, clock_watchdog: AtomicI64,
106 paused: AtomicBool,
107}
108
109const DELAYS_SEC: i64 = 2; impl Default for TimerManager {
113 fn default() -> Self {
114 TimerManager {
115 timers: ThreadLocal::new(),
116 zero: Instant::now(),
117 clock_watchdog: AtomicI64::new(-DELAYS_SEC),
118 paused: AtomicBool::new(false),
119 }
120 }
121}
122
123impl TimerManager {
124 pub fn new() -> Self {
126 Self::default()
127 }
128
129 pub(crate) fn clock_thread(&self) {
131 loop {
132 std::thread::sleep(RESOLUTION_DURATION);
133 let now = Instant::now() - self.zero;
134 self.clock_watchdog
135 .store(now.as_secs() as i64, Ordering::Relaxed);
136 if self.is_paused_for_fork() {
137 continue;
139 }
140 let now = now.as_millis();
141 for thread_timer in self.timers.iter() {
143 let mut timers = thread_timer.write();
144 loop {
146 let key_to_remove = timers.iter().next().and_then(|(k, _)| {
147 if k.not_after(now) {
148 Some(*k)
149 } else {
150 None
151 }
152 });
153 if let Some(k) = key_to_remove {
154 let timer = timers.remove(&k);
155 timer.unwrap().fire();
157 } else {
158 break;
159 }
160 }
161 }
163 }
164 }
165
166 pub(crate) fn should_i_start_clock(&self) -> bool {
169 let Err(prev) = self.is_clock_running() else {
170 return false;
171 };
172 let now = Instant::now().duration_since(self.zero).as_secs() as i64;
173 let res =
174 self.clock_watchdog
175 .compare_exchange(prev, now, Ordering::SeqCst, Ordering::SeqCst);
176 res.is_ok()
177 }
178
179 pub(crate) fn is_clock_running(&self) -> Result<(), i64> {
182 let now = Instant::now().duration_since(self.zero).as_secs() as i64;
183 let prev = self.clock_watchdog.load(Ordering::SeqCst);
184 if now < prev + DELAYS_SEC {
185 Ok(())
186 } else {
187 Err(prev)
188 }
189 }
190
191 pub fn register_timer(&self, duration: Duration) -> TimerStub {
195 if self.is_paused_for_fork() {
196 let timer = Timer::new();
202 timer.fire();
203 return timer.subscribe();
204 }
205 let now: Time = (Instant::now() + duration - self.zero).into();
206 {
207 let timers = self.timers.get_or(|| RwLock::new(BTreeMap::new())).read();
208 if let Some(t) = timers.get(&now) {
209 return t.subscribe();
210 }
211 } let timer = Timer::new();
214 let mut timers = self.timers.get_or(|| RwLock::new(BTreeMap::new())).write();
215 let stub = timer.subscribe();
220 timers.insert(now, timer);
221 stub
222 }
223
224 fn is_paused_for_fork(&self) -> bool {
225 self.paused.load(Ordering::SeqCst)
226 }
227
228 pub fn pause_for_fork(&self) {
235 self.paused.store(true, Ordering::SeqCst);
236 std::thread::sleep(RESOLUTION_DURATION * 2);
238 }
239
240 pub fn unpause(&self) {
244 self.paused.store(false, Ordering::SeqCst)
245 }
246}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts a clock loop for `tm` on a background thread.
    fn spawn_clock(tm: &Arc<TimerManager>) {
        let clock = Arc::clone(tm);
        std::thread::spawn(move || clock.clock_thread());
    }

    #[test]
    fn test_round() {
        let cases = [(30, 30), (31, 40), (29, 30)];
        for (raw, expected) in cases {
            assert_eq!(round_to(raw, 10), expected);
        }
    }

    #[test]
    fn test_time() {
        let t = Time::from(128u128);
        let expected: Time = Duration::from_millis(130).into();
        assert_eq!(t, expected);
        for ts in [128, 129] {
            assert!(!t.not_after(ts));
        }
        for ts in [130, 131] {
            assert!(t.not_after(ts));
        }
    }

    #[tokio::test]
    async fn test_timer_manager() {
        let tm = Arc::new(TimerManager::new());
        spawn_clock(&tm);

        let started = Instant::now();
        let first = tm.register_timer(Duration::from_secs(1));
        let second = tm.register_timer(Duration::from_secs(1));
        first.poll().await;
        assert_eq!(started.elapsed().as_secs(), 1);

        // Registered right after the first, so it is already due by now.
        let started = Instant::now();
        second.poll().await;
        assert_eq!(started.elapsed().as_secs(), 0);
    }

    #[test]
    fn test_timer_manager_start_check() {
        let tm = Arc::new(TimerManager::new());
        let claimed = tm.should_i_start_clock();
        let claimed_again = tm.should_i_start_clock();
        assert!(claimed);
        assert!(!claimed_again);
        assert!(tm.is_clock_running().is_ok());
    }

    #[test]
    fn test_timer_manager_watchdog() {
        let tm = Arc::new(TimerManager::new());
        assert!(tm.should_i_start_clock());
        assert!(!tm.should_i_start_clock());

        // Nobody refreshes the watchdog, so it goes stale.
        let stale_after = Duration::from_secs(DELAYS_SEC as u64 + 1);
        std::thread::sleep(stale_after);
        assert!(tm.is_clock_running().is_err());
        assert!(tm.should_i_start_clock());
    }

    #[tokio::test]
    async fn test_timer_manager_pause() {
        let tm = Arc::new(TimerManager::new());
        spawn_clock(&tm);

        let started = Instant::now();
        let before_pause = tm.register_timer(Duration::from_secs(2));
        tm.pause_for_fork();

        // Registered while paused: handed back already fired.
        let while_paused = tm.register_timer(Duration::from_secs(2));
        while_paused.poll().await;
        assert_eq!(started.elapsed().as_secs(), 0);

        std::thread::sleep(Duration::from_secs(1));
        tm.unpause();
        before_pause.poll().await;
        assert_eq!(started.elapsed().as_secs(), 2);
    }
}