1use crate::core::{Cause, Clock, Ctx, Effect, EnvRef, Exit, FiberId, ScopeExit, ScopeHandle};
2use crate::runtime::Runtime;
3use futures::future::BoxFuture;
4use std::cmp::Ordering;
5use std::collections::{BinaryHeap, HashMap};
6use std::sync::{Arc, Mutex as StdMutex};
7use tokio::sync::Mutex as TokioMutex;
8use tokio::time::{Duration, Instant};
9use tokio_util::sync::CancellationToken;
10
11struct Sleeper {
12 wake_time: Instant,
13 waker: tokio::sync::Notify,
14}
15
16impl PartialEq for Sleeper {
17 fn eq(&self, other: &Self) -> bool {
18 self.wake_time == other.wake_time
19 }
20}
21
22impl Eq for Sleeper {}
23
24impl PartialOrd for Sleeper {
25 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
26 Some(self.cmp(other))
27 }
28}
29
30impl Ord for Sleeper {
31 fn cmp(&self, other: &Self) -> Ordering {
32 other.wake_time.cmp(&self.wake_time)
34 }
35}
36
37#[derive(Clone)]
38pub struct TestClock {
39 state: Arc<StdMutex<TestClockState>>,
40}
41
42struct TestClockState {
43 now: Instant,
44 sleepers: BinaryHeap<Sleeper>,
45 }
47
48impl Default for TestClock {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl TestClock {
55 pub fn new() -> Self {
56 Self {
57 state: Arc::new(StdMutex::new(TestClockState {
58 now: Instant::now(), sleepers: BinaryHeap::new(),
60 })),
61 }
62 }
63
64 pub fn adjust(&self, duration: Duration) {
65 let mut state = self.state.lock().unwrap();
66 state.now += duration;
67 let now = state.now;
68
69 while let Some(sleeper) = state.sleepers.peek() {
71 if sleeper.wake_time <= now {
72 let sleeper = state.sleepers.pop().unwrap();
73 sleeper.waker.notify_one();
74 } else {
75 break;
76 }
77 }
78 }
79}
80
81impl Clock for TestClock {
82 fn sleep(&self, duration: Duration) -> BoxFuture<'static, ()> {
83 let state = self.state.clone();
84 Box::pin(async move {
85 let _notify = Arc::new(tokio::sync::Notify::new());
86
87 {
88 let mut guard = state.lock().unwrap();
89 let wake_time = guard.now + duration;
90 guard.sleepers.push(Sleeper {
91 wake_time,
92 waker: tokio::sync::Notify::new(), });
96 }
98 })
99 }
100
101 fn now(&self) -> Instant {
102 let guard = self.state.lock().unwrap();
103 guard.now
104 }
105}
106
107struct SharedSleeper {
109 wake_time: Instant,
110 notify: Arc<tokio::sync::Notify>,
111}
112
113impl PartialEq for SharedSleeper {
114 fn eq(&self, other: &Self) -> bool {
115 self.wake_time == other.wake_time
116 }
117}
118impl Eq for SharedSleeper {}
119#[allow(clippy::non_canonical_partial_ord_impl)]
120impl PartialOrd for SharedSleeper {
121 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
122 Some(other.wake_time.cmp(&self.wake_time)) }
124}
125impl Ord for SharedSleeper {
126 fn cmp(&self, other: &Self) -> Ordering {
127 other.wake_time.cmp(&self.wake_time)
128 }
129}
130
131#[derive(Clone)]
132pub struct TestClockImpl {
133 state: Arc<StdMutex<TestClockStateImpl>>,
134}
135
136struct TestClockStateImpl {
137 now: Instant,
138 sleepers: BinaryHeap<SharedSleeper>,
139}
140
141impl Default for TestClockImpl {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl TestClockImpl {
148 pub fn new() -> Self {
149 Self {
150 state: Arc::new(StdMutex::new(TestClockStateImpl {
151 now: Instant::now(),
152 sleepers: BinaryHeap::new(),
153 })),
154 }
155 }
156
157 pub fn adjust(&self, duration: Duration) {
158 let mut state = self.state.lock().unwrap();
159 state.now += duration;
160 let now = state.now;
161
162 while let Some(sleeper) = state.sleepers.peek() {
163 if sleeper.wake_time <= now {
164 let sleeper = state.sleepers.pop().unwrap();
165 sleeper.notify.notify_waiters();
166 } else {
167 break;
168 }
169 }
170 }
171}
172
173impl Clock for TestClockImpl {
174 fn sleep(&self, duration: Duration) -> BoxFuture<'static, ()> {
175 let state = self.state.clone();
176 Box::pin(async move {
177 let notify = Arc::new(tokio::sync::Notify::new());
178 {
179 let mut guard = state.lock().unwrap();
180 let wake_time = guard.now + duration;
181 guard.sleepers.push(SharedSleeper {
182 wake_time,
183 notify: notify.clone(),
184 });
185 }
186 notify.notified().await;
187 })
188 }
189
190 fn now(&self) -> Instant {
191 self.state.lock().unwrap().now
192 }
193}
194
195pub struct TestRuntime {
196 runtime: Runtime,
197 pub clock: TestClockImpl,
198}
199
200impl Default for TestRuntime {
201 fn default() -> Self {
202 Self::new()
203 }
204}
205
206impl TestRuntime {
207 pub fn new() -> Self {
208 Self {
209 runtime: Runtime::new(),
210 clock: TestClockImpl::new(),
211 }
212 }
213
214 pub fn block_on<R, E, A>(&self, effect: Effect<R, E, A>, env: R) -> Exit<E, A>
215 where
216 R: Clone + Send + Sync + 'static,
217 E: Send + Sync + Clone + 'static,
218 A: Send + Sync + Clone + 'static,
219 {
220 let clock = Arc::new(self.clock.clone());
227 let ctx = Ctx {
228 token: CancellationToken::new(),
229 scope: ScopeHandle::new(),
230 fiber_id: FiberId(0),
231 locals: Arc::new(TokioMutex::new(HashMap::new())),
232 clock,
233 };
234
235 self.runtime.rt.block_on(async move {
240 let result = (effect.inner)(EnvRef { value: env }, ctx.clone()).await;
241 let scope_exit = match &result {
242 Exit::Success(_) => ScopeExit::Success,
243 Exit::Failure(Cause::Interrupt) => ScopeExit::Interrupt,
244 Exit::Failure(_) => ScopeExit::Failure,
245 };
246 ctx.scope.close(scope_exit).await;
247 result
248 })
249 }
250
251 pub fn spawn<R, E, A>(
252 &self,
253 effect: Effect<R, E, A>,
254 env: R,
255 ) -> tokio::task::JoinHandle<Exit<E, A>>
256 where
257 R: Clone + Send + Sync + 'static,
258 E: Send + Sync + Clone + 'static,
259 A: Send + Sync + Clone + 'static,
260 {
261 let clock = Arc::new(self.clock.clone());
262 let ctx = Ctx {
263 token: CancellationToken::new(),
264 scope: ScopeHandle::new(),
265 fiber_id: FiberId(0),
266 locals: Arc::new(TokioMutex::new(HashMap::new())),
267 clock,
268 };
269
270 self.runtime.rt.spawn(async move {
271 let result = (effect.inner)(EnvRef { value: env }, ctx.clone()).await;
272 let scope_exit = match &result {
273 Exit::Success(_) => ScopeExit::Success,
274 Exit::Failure(Cause::Interrupt) => ScopeExit::Interrupt,
275 Exit::Failure(_) => ScopeExit::Failure,
276 };
277 ctx.scope.close(scope_exit).await;
278 result
279 })
280 }
281
282 pub async fn advance(&self, duration: Duration) {
283 self.clock.adjust(duration);
284 tokio::task::yield_now().await;
286 }
287
288 pub fn advance_blocking(&self, duration: Duration) {
289 self.runtime.rt.block_on(async {
290 self.advance(duration).await;
291 })
292 }
293
294 pub fn block_on_future<F>(&self, future: F) -> F::Output
295 where
296 F: std::future::Future,
297 {
298 self.runtime.rt.block_on(future)
299 }
300}