// async_time_mock_core/timer_registry.rs

1use crate::await_all::await_all;
2use crate::timeout::Timeout;
3use crate::timer::{Timer, TimerListener};
4use crate::{Instant, Interval};
5use event_listener::Event;
6use std::collections::{BTreeMap, VecDeque};
7use std::fmt::{Debug, Formatter};
8use std::future::Future;
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::{Arc, RwLock, RwLockWriteGuard};
11use std::time::{Duration, SystemTime};
12
13pub struct TimerRegistry {
14	id: u64,
15	current_time: RwLock<Duration>,
16	timers_by_time: RwLock<TimersByTime>,
17	any_timer_scheduled_signal: Event,
18	advance_time_lock: async_lock::Mutex<()>,
19}
20
21impl Default for TimerRegistry {
22	fn default() -> Self {
23		Self {
24			id: Self::next_id(),
25			current_time: Default::default(),
26			timers_by_time: Default::default(),
27			any_timer_scheduled_signal: Default::default(),
28			advance_time_lock: Default::default(),
29		}
30	}
31}
32
33type TimersByTime = BTreeMap<Duration, VecDeque<Timer>>;
34
35impl TimerRegistry {
36	/// Schedules a timer to expire in "Duration", once expired, returns
37	/// a TimeHandlerGuard that must be dropped only once the timer event has been fully processed
38	/// (all sideeffects finished).
39	///
40	/// Roughly eqivalent to `async pub fn sleep(&self, duration: Duration) -> TimeHandlerGuard`.
41	pub fn sleep(&self, duration: Duration) -> TimerListener {
42		assert!(!duration.is_zero(), "Sleeping for zero time is not allowed");
43
44		let listener = {
45			let timers_by_time = self.timers_by_time.write().expect("RwLock was poisoned");
46			let wakeup_time = *self.current_time.read().expect("RwLock was poisoned") + duration;
47			Self::schedule_timer(timers_by_time, wakeup_time)
48		};
49		self.any_timer_scheduled_signal.notify(1);
50
51		listener
52	}
53
54	/// Schedules a timer to expire at "Instant", once expired, returns
55	/// a TimeHandlerGuard that must be dropped only once the timer event has been fully processed
56	/// (all sideeffects finished).
57	///
58	/// Roughly eqivalent to `async pub fn sleep_until(&self, until: Instant) -> TimeHandlerGuard`.
59	///
60	/// # Panics
61	/// When `until` was created by a different instance of `TimerRegistry`.
62	pub fn sleep_until(&self, until: Instant) -> TimerListener {
63		let listener = {
64			let timers_by_time = self.timers_by_time.write().expect("RwLock was poisoned");
65			let wakeup_time = until.into_duration(self.id);
66			Self::schedule_timer(timers_by_time, wakeup_time)
67		};
68		self.any_timer_scheduled_signal.notify(1);
69
70		listener
71	}
72
73	/// Combines a future with a `sleep` timer. If the future finishes before
74	/// the timer has expired, returns the futures output. Otherwise returns
75	/// `Elapsed` which contains a `TimeHandlerGuard` that must be dropped once the timeout has been fully processed
76	/// (all sideeffects finished).
77	///
78	/// Roughly equivalent to `async pub fn timeout<F: Future>(&self, timeout: Duration, future: F) -> Result<F::Output, Elapsed>`
79	pub fn timeout<F>(&self, timeout: Duration, future: F) -> Timeout<F>
80	where
81		F: Future,
82	{
83		Timeout::new(future, self.sleep(timeout))
84	}
85
86	/// Combines a future with a `sleep_until` timer. If the future finishes before
87	/// the timer has expired, returns the futures output. Otherwise returns
88	/// `Elapsed` which contains a `TimeHandlerGuard` that must be dropped once the timeout has been fully processed
89	/// (all sideeffects finished).
90	///
91	/// Roughly equivalent to `async pub fn timeout_at<F: Future>(&self, at: Instant, future: F) -> Result<F::Output, Elapsed>`
92	///
93	/// # Panics
94	/// When `at` was created by a different instance of `TimerRegistry`.
95	pub fn timeout_at<F>(&self, at: Instant, future: F) -> Timeout<F>
96	where
97		F: Future,
98	{
99		Timeout::new(future, self.sleep_until(at))
100	}
101
102	pub fn interval(self: &Arc<Self>, period: Duration) -> Interval {
103		Interval::new(self.clone(), self.now(), period)
104	}
105
106	pub fn interval_at(self: &Arc<Self>, start: Instant, period: Duration) -> Interval {
107		Interval::new(self.clone(), start, period)
108	}
109
110	fn schedule_timer(mut timers_by_time: RwLockWriteGuard<'_, TimersByTime>, at: Duration) -> TimerListener {
111		let (timer, listener) = Timer::new();
112		timers_by_time.entry(at).or_default().push_back(timer);
113		listener
114	}
115
116	/// Advances test time by the given duration. Starts all scheduled timers that have expired
117	/// at the new (advanced) point in time in the following order:
118	/// 1. By time they are scheduled to run at
119	/// 2. By the order they were scheduled
120	///
121	/// If no timer has been scheduled yet, waits until one is.
122	/// Returns only once all started timers have finished processing.
123	pub async fn advance_time(&self, by_duration: Duration) {
124		let _guard = self.advance_time_lock.lock().await;
125
126		let finished_time = *self.current_time.read().expect("RwLock was poisoned") + by_duration;
127
128		if self.timers_by_time.read().expect("RwLock was poisoned").is_empty() {
129			// If no timer has been scheduled yet, wait for one to be scheduled
130			self.any_timer_scheduled_signal.listen().await;
131		}
132
133		loop {
134			let timers_to_run = {
135				let mut timers_by_time = self.timers_by_time.write().expect("RwLock was poisoned");
136				match timers_by_time.keys().next() {
137					Some(&key) if key <= finished_time => {
138						let mut current_time = self.current_time.write().expect("RwLock was poisoned");
139						*current_time = key.max(*current_time);
140						timers_by_time
141							.remove(&key)
142							.unwrap_or_else(|| unreachable!("We just checked that it exists"))
143					}
144					_ => break,
145				}
146			};
147
148			await_all(timers_to_run.into_iter().map(|timer| timer.trigger().wait())).await;
149		}
150
151		*self.current_time.write().expect("RwLock was poisoned") = finished_time;
152	}
153
154	/// Current test time, increases on every call to [`advance_time`].
155	pub fn now(&self) -> Instant {
156		Instant::new(*self.current_time.read().expect("RwLock was poisoned"), self.id)
157	}
158
159	/// Current test time. Similar to [`now`] but simulating system time, not monotonic time.
160	/// Increases on every call to [`advance_time`].
161	pub fn system_time(&self) -> SystemTime {
162		SystemTime::UNIX_EPOCH + *self.current_time.read().expect("RwLock was poisoned")
163	}
164
165	fn next_id() -> u64 {
166		static ID_COUNTER: AtomicU64 = AtomicU64::new(0);
167
168		ID_COUNTER.fetch_add(1, Ordering::Relaxed)
169	}
170}
171
172impl Debug for TimerRegistry {
173	fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
174		let Self {
175			id,
176			current_time,
177			timers_by_time: _,
178			any_timer_scheduled_signal: _,
179			advance_time_lock: _,
180		} = self;
181		formatter
182			.debug_struct("TimerRegistry")
183			.field("id", id)
184			.field("current_time", current_time)
185			.finish_non_exhaustive()
186	}
187}