1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
use crate::await_all::await_all;
use crate::time_handler_guard::TimeHandlerGuard;
use crate::timeout::Timeout;
use crate::timer::{Timer, TimerListener};
use crate::{Instant, Interval};
use event_listener::Event;
use std::collections::{BTreeMap, VecDeque};
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock, RwLockWriteGuard};
use std::time::Duration;

/// Registry of test-time timers driven by a manually advanced clock.
///
/// Test time only moves through `advance_time`; `sleep` / `sleep_until` register timers that are
/// triggered once `advance_time` reaches their wakeup time.
pub struct TimerRegistry {
	/// Identifier of this registry; `Instant`s produced by `now` are tagged with it, and
	/// `sleep_until` / `timeout_at` reject `Instant`s carrying a different id.
	id: u64,
	/// Current test time, as an offset from the registry's start (zero for a `Default` registry).
	current_time: RwLock<Duration>,
	/// Pending timers keyed by wakeup time; timers sharing a wakeup time keep scheduling order.
	timers_by_time: RwLock<TimersByTime>,
	/// Notified after a timer has been scheduled, so `advance_time` can wait for the first timer.
	any_timer_scheduled_signal: Event,
	/// Held for the whole duration of `advance_time`, so concurrent calls run one after another.
	advance_time_lock: async_lock::Mutex<()>,
}

impl Default for TimerRegistry {
	fn default() -> Self {
		Self {
			id: Self::next_id(),
			current_time: Default::default(),
			timers_by_time: Default::default(),
			any_timer_scheduled_signal: Default::default(),
			advance_time_lock: Default::default(),
		}
	}
}

/// Pending timers grouped by wakeup time. The `BTreeMap` keeps wakeup times sorted; each
/// `VecDeque` holds the timers for one wakeup time in the order they were scheduled.
type TimersByTime = BTreeMap<Duration, VecDeque<Timer>>;

impl TimerRegistry {
	/// Schedules a timer that expires `duration` after the current test time.
	///
	/// The timer is registered immediately (when `sleep` is called, not when the returned future
	/// is first polled). Once expired — i.e. once [`advance_time`](Self::advance_time) reaches the
	/// wakeup time — the future yields a `TimeHandlerGuard` that must be dropped only once the
	/// timer event has been fully processed (all side effects finished).
	///
	/// Roughly equivalent to `pub async fn sleep(&self, duration: Duration) -> TimeHandlerGuard`.
	///
	/// # Panics
	/// When `duration` is zero, or when adding `duration` to the current test time overflows
	/// [`Duration`].
	pub fn sleep(&self, duration: Duration) -> impl Future<Output = TimeHandlerGuard> + Send + Sync + 'static {
		assert!(!duration.is_zero(), "Sleeping for zero time is not allowed");

		let timer = {
			// Lock order `timers_by_time` -> `current_time`, the same as in `advance_time`, so the
			// wakeup time is computed against a current time that cannot change underneath us.
			let timers_by_time = self.timers_by_time.write().expect("RwLock was poisoned");
			let wakeup_time = *self.current_time.read().expect("RwLock was poisoned") + duration;
			Self::schedule_timer(timers_by_time, wakeup_time)
		};
		// Notify only after the timer is in the map (and the lock is released), so a woken
		// `advance_time` is guaranteed to see it.
		self.any_timer_scheduled_signal.notify(1);

		timer.wait_until_triggered()
	}

	/// Schedules a timer that expires at `until`.
	///
	/// The timer is registered immediately. Once expired — i.e. once
	/// [`advance_time`](Self::advance_time) reaches `until` (or the next call to it, if `until` is
	/// already in the past) — the future yields a `TimeHandlerGuard` that must be dropped only
	/// once the timer event has been fully processed (all side effects finished).
	///
	/// Roughly equivalent to `pub async fn sleep_until(&self, until: Instant) -> TimeHandlerGuard`.
	///
	/// # Panics
	/// When `until` was created by a different instance of `TimerRegistry`.
	pub fn sleep_until(&self, until: Instant) -> impl Future<Output = TimeHandlerGuard> + Send + Sync + 'static {
		let timer = {
			let timers_by_time = self.timers_by_time.write().expect("RwLock was poisoned");
			let wakeup_time = until.into_duration(self.id);
			Self::schedule_timer(timers_by_time, wakeup_time)
		};
		self.any_timer_scheduled_signal.notify(1);

		timer.wait_until_triggered()
	}

	/// Combines a future with a [`sleep`](Self::sleep) timer.
	///
	/// If the future finishes before the timer has expired, returns the future's output.
	/// Otherwise returns `Elapsed`, which contains a `TimeHandlerGuard` that must be dropped once
	/// the timeout has been fully processed (all side effects finished).
	///
	/// Roughly equivalent to `pub async fn timeout<F: Future>(&self, timeout: Duration, future: F) -> Result<F::Output, Elapsed>`.
	///
	/// # Panics
	/// Under the same conditions as [`sleep`](Self::sleep) (e.g. a zero `timeout`).
	pub fn timeout<F>(&self, timeout: Duration, future: F) -> Timeout<F>
	where
		F: Future,
	{
		Timeout::new(future, self.sleep(timeout))
	}

	/// Combines a future with a [`sleep_until`](Self::sleep_until) timer.
	///
	/// If the future finishes before the timer has expired, returns the future's output.
	/// Otherwise returns `Elapsed`, which contains a `TimeHandlerGuard` that must be dropped once
	/// the timeout has been fully processed (all side effects finished).
	///
	/// Roughly equivalent to `pub async fn timeout_at<F: Future>(&self, at: Instant, future: F) -> Result<F::Output, Elapsed>`.
	///
	/// # Panics
	/// When `at` was created by a different instance of `TimerRegistry`.
	pub fn timeout_at<F>(&self, at: Instant, future: F) -> Timeout<F>
	where
		F: Future,
	{
		Timeout::new(future, self.sleep_until(at))
	}

	/// Creates an [`Interval`] driven by this registry, starting at the current test time
	/// ([`now`](Self::now)) with the given `period`.
	pub fn interval(self: &Arc<Self>, period: Duration) -> Interval {
		Interval::new(Arc::clone(self), self.now(), period)
	}

	/// Creates an [`Interval`] driven by this registry, starting at `start` with the given
	/// `period`.
	pub fn interval_at(self: &Arc<Self>, start: Instant, period: Duration) -> Interval {
		Interval::new(Arc::clone(self), start, period)
	}

	/// Appends a new timer to the queue for wakeup time `at` and returns its listener.
	///
	/// Timers sharing a wakeup time are kept in FIFO order. The write guard is taken by value, so
	/// the lock is released as soon as scheduling is done.
	fn schedule_timer(mut timers_by_time: RwLockWriteGuard<'_, TimersByTime>, at: Duration) -> TimerListener {
		let (timer, listener) = Timer::new();
		timers_by_time.entry(at).or_insert_with(VecDeque::new).push_back(timer);
		listener
	}

	/// Advances test time by the given duration.
	///
	/// Triggers every scheduled timer whose wakeup time is at or before the new point in time,
	/// in the following order:
	/// 1. By the time they are scheduled to run at
	/// 2. By the order they were scheduled
	///
	/// While a batch of timers runs, the test time equals their wakeup time. Timers scheduled
	/// during processing that fall at or before the target time are run by the same call.
	///
	/// If no timer has been scheduled yet, waits until one is.
	/// Returns only once all started timers have finished processing.
	/// Concurrent calls are serialised.
	///
	/// # Panics
	/// When adding `by_duration` to the current test time overflows [`Duration`].
	pub async fn advance_time(&self, by_duration: Duration) {
		let _guard = self.advance_time_lock.lock().await;

		let finished_time = *self.current_time.read().expect("RwLock was poisoned") + by_duration;

		// If no timer has been scheduled yet, wait for one to be scheduled.
		// The listener is registered *before* inspecting the queue. `sleep`/`sleep_until` insert
		// their timer first and notify afterwards, so a timer scheduled concurrently is either
		// seen by the check below or wakes the already-registered listener. Checking first and
		// listening afterwards could miss that notification and hang.
		{
			let timer_scheduled = self.any_timer_scheduled_signal.listen();
			if self.timers_by_time.read().expect("RwLock was poisoned").is_empty() {
				timer_scheduled.await;
			}
		}

		loop {
			let timers_to_run = {
				// Lock order `timers_by_time` -> `current_time`, the same as in `sleep`.
				let mut timers_by_time = self.timers_by_time.write().expect("RwLock was poisoned");
				let mut current_time = self.current_time.write().expect("RwLock was poisoned");
				match timers_by_time.keys().next() {
					Some(&key) if key <= finished_time => {
						// Timers scheduled via `sleep_until` may lie in the past; never move
						// time backwards.
						*current_time = key.max(*current_time);
						timers_by_time
							.remove(&key)
							.unwrap_or_else(|| unreachable!("We just checked that it exists"))
					}
					_ => {
						// Publish the final time in the same critical section that found no
						// more due timers. A racing `sleep` then computes its wakeup time from
						// `finished_time` rather than from the last processed timer's time.
						*current_time = finished_time;
						break;
					}
				}
			};

			await_all(timers_to_run.into_iter().map(|timer| timer.trigger().wait())).await;
		}
	}

	/// Current test time, tagged with this registry's id.
	///
	/// Starts at zero and only changes through [`advance_time`](Self::advance_time).
	pub fn now(&self) -> Instant {
		Instant::new(*self.current_time.read().expect("RwLock was poisoned"), self.id)
	}

	/// Returns a new id from a process-wide counter (unique until the `u64` counter wraps).
	fn next_id() -> u64 {
		static ID_COUNTER: AtomicU64 = AtomicU64::new(0);

		// Only uniqueness matters, no ordering with other memory operations, so `Relaxed` suffices.
		ID_COUNTER.fetch_add(1, Ordering::Relaxed)
	}
}

impl Debug for TimerRegistry {
	fn fmt(&self, formatter: &mut Formatter<'_>) -> std::fmt::Result {
		let Self {
			id,
			current_time,
			timers_by_time: _,
			any_timer_scheduled_signal: _,
			advance_time_lock: _,
		} = self;
		formatter
			.debug_struct("TimerRegistry")
			.field("id", id)
			.field("current_time", current_time)
			.finish_non_exhaustive()
	}
}