async_time_mock_core/
timer_registry.rs1use crate::await_all::await_all;
2use crate::timeout::Timeout;
3use crate::timer::{Timer, TimerListener};
4use crate::{Instant, Interval};
5use event_listener::Event;
6use std::collections::{BTreeMap, VecDeque};
7use std::fmt::{Debug, Formatter};
8use std::future::Future;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::{Arc, RwLock, RwLockWriteGuard};
11use std::time::{Duration, SystemTime};
12
/// Registry of mocked timers that all share one virtual clock.
///
/// The clock is only moved by `advance_time`; `sleep` / `sleep_until` register timers
/// that are fired as the clock passes their wakeup time.
pub struct TimerRegistry {
    /// Identifier drawn from a process-wide counter; it is passed to `Instant::new` and
    /// `Instant::into_duration`, presumably so instants can be tied to the registry that
    /// produced them (TODO confirm against `Instant`).
    id: u64,
    /// Virtual time elapsed on this registry's clock; never moves backwards.
    current_time: RwLock<Duration>,
    /// Pending timers keyed by wakeup time; timers sharing a wakeup time are queued in
    /// the order they were scheduled.
    timers_by_time: RwLock<TimersByTime>,
    /// Notified each time a timer is scheduled so `advance_time` can wait for one.
    any_timer_scheduled_signal: Event,
    /// Serializes concurrent `advance_time` calls.
    advance_time_lock: async_lock::Mutex<()>,
}
20
21impl Default for TimerRegistry {
22 fn default() -> Self {
23 Self {
24 id: Self::next_id(),
25 current_time: Default::default(),
26 timers_by_time: Default::default(),
27 any_timer_scheduled_signal: Default::default(),
28 advance_time_lock: Default::default(),
29 }
30 }
31}
32
/// Pending timers grouped by their wakeup time on the registry's virtual clock.
/// `BTreeMap` keeps the keys sorted, so the earliest wakeup time is always the first key.
type TimersByTime = BTreeMap<Duration, VecDeque<Timer>>;
34
35impl TimerRegistry {
36 pub fn sleep(&self, duration: Duration) -> TimerListener {
42 assert!(!duration.is_zero(), "Sleeping for zero time is not allowed");
43
44 let listener = {
45 let timers_by_time = self.timers_by_time.write().expect("RwLock was poisoned");
46 let wakeup_time = *self.current_time.read().expect("RwLock was poisoned") + duration;
47 Self::schedule_timer(timers_by_time, wakeup_time)
48 };
49 self.any_timer_scheduled_signal.notify(1);
50
51 listener
52 }
53
54 pub fn sleep_until(&self, until: Instant) -> TimerListener {
63 let listener = {
64 let timers_by_time = self.timers_by_time.write().expect("RwLock was poisoned");
65 let wakeup_time = until.into_duration(self.id);
66 Self::schedule_timer(timers_by_time, wakeup_time)
67 };
68 self.any_timer_scheduled_signal.notify(1);
69
70 listener
71 }
72
73 pub fn timeout<F>(&self, timeout: Duration, future: F) -> Timeout<F>
80 where
81 F: Future,
82 {
83 Timeout::new(future, self.sleep(timeout))
84 }
85
86 pub fn timeout_at<F>(&self, at: Instant, future: F) -> Timeout<F>
96 where
97 F: Future,
98 {
99 Timeout::new(future, self.sleep_until(at))
100 }
101
102 pub fn interval(self: &Arc<Self>, period: Duration) -> Interval {
103 Interval::new(self.clone(), self.now(), period)
104 }
105
106 pub fn interval_at(self: &Arc<Self>, start: Instant, period: Duration) -> Interval {
107 Interval::new(self.clone(), start, period)
108 }
109
110 fn schedule_timer(mut timers_by_time: RwLockWriteGuard<'_, TimersByTime>, at: Duration) -> TimerListener {
111 let (timer, listener) = Timer::new();
112 timers_by_time.entry(at).or_default().push_back(timer);
113 listener
114 }
115
116 pub async fn advance_time(&self, by_duration: Duration) {
124 let _guard = self.advance_time_lock.lock().await;
125
126 let finished_time = *self.current_time.read().expect("RwLock was poisoned") + by_duration;
127
128 if self.timers_by_time.read().expect("RwLock was poisoned").is_empty() {
129 self.any_timer_scheduled_signal.listen().await;
131 }
132
133 loop {
134 let timers_to_run = {
135 let mut timers_by_time = self.timers_by_time.write().expect("RwLock was poisoned");
136 match timers_by_time.keys().next() {
137 Some(&key) if key <= finished_time => {
138 let mut current_time = self.current_time.write().expect("RwLock was poisoned");
139 *current_time = key.max(*current_time);
140 timers_by_time
141 .remove(&key)
142 .unwrap_or_else(|| unreachable!("We just checked that it exists"))
143 }
144 _ => break,
145 }
146 };
147
148 await_all(timers_to_run.into_iter().map(|timer| timer.trigger().wait())).await;
149 }
150
151 *self.current_time.write().expect("RwLock was poisoned") = finished_time;
152 }
153
154 pub fn now(&self) -> Instant {
156 Instant::new(*self.current_time.read().expect("RwLock was poisoned"), self.id)
157 }
158
159 pub fn system_time(&self) -> SystemTime {
162 SystemTime::UNIX_EPOCH + *self.current_time.read().expect("RwLock was poisoned")
163 }
164
165 fn next_id() -> u64 {
166 static ID_COUNTER: AtomicU64 = AtomicU64::new(0);
167
168 ID_COUNTER.fetch_add(1, Ordering::Relaxed)
169 }
170}
171
172impl Debug for TimerRegistry {
173 fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
174 let Self {
175 id,
176 current_time,
177 timers_by_time: _,
178 any_timer_scheduled_signal: _,
179 advance_time_lock: _,
180 } = self;
181 formatter
182 .debug_struct("TimerRegistry")
183 .field("id", id)
184 .field("current_time", current_time)
185 .finish_non_exhaustive()
186 }
187}