1use futures::{
54 future::poll_fn,
55 task::{self, Poll, Waker},
56 Future,
57};
58
59use async_task::{self, Runnable};
60
61use std::cell::RefCell;
62use std::panic;
63use std::pin::Pin;
64use std::sync::{
65 atomic::{AtomicBool, Ordering},
66 Arc, Mutex,
67};
68use std::time::Duration;
69use std::vec::Vec;
70
71use derive_builder::Builder;
72use flume;
73use pin_project::pin_project;
74use threadpool::ThreadPool;
75
76pub use async_metronome_attributes::test;
78
/// Panic message used by `run_opt` when tasks are still active but none is waiting for a tick.
const DEADLOCK: &str = "deadlock";
/// Panic/expect message used when no test context is installed on the current thread,
/// and by `run_opt` when one is already installed (nested run).
const HASCONTEXT: &str = "hascontext";
81
/// Configuration for `run_opt`.
///
/// Build it with `Options::default()` or with the builder generated by `derive_builder`.
#[derive(Clone, Default, Builder, Debug)]
pub struct Options {
    /// Presumably a per-run time limit.
    /// NOTE(review): this field is not read by any code in this file; confirm whether a
    /// timeout is enforced anywhere.
    #[builder(setter(into, strip_option), default)]
    _timeout: Option<Duration>,

    /// When `true`, the executor prints trace lines to stdout for spawns, polls, task
    /// completion, queue exhaustion and ticks.
    #[builder(setter(into), default)]
    debug: bool,
}
91
/// An item on the run queue: the id of the task that owns it and the runnable to poll it.
struct RunQueueEntry(usize, Runnable);
93
/// Executor state for one `run_opt` invocation, shared by all of its tasks.
struct TestContext {
    /// Current tick; advanced by `next_tick` only when at least one task is waiting.
    tick: usize,
    /// Id handed to the next spawned task (the root task gets 0).
    task_id: usize,
    /// Number of spawned tasks that have neither completed nor panicked.
    task_active: usize,
    /// Producer side of the run queue drained by `run_opt`.
    sender: flume::Sender<RunQueueEntry>,
    /// Wakers of tasks suspended in `__private_wait_tick`, woken on the next tick.
    wakers: Vec<Waker>,
    /// Options of the current run (only `debug` is read here).
    options: Arc<Options>,
}
102
/// Handle to a task spawned on the test executor.
///
/// Awaiting the handle yields the task's output. Dropping the handle detaches the task
/// instead of cancelling it.
pub struct JoinHandle<O> {
    // Wrapped in `Option` only so that `Drop` can move the task out and detach it;
    // it is `Some` for as long as the handle can be polled.
    task: Option<async_task::Task<O>>,
}
110
111impl<O> Future for JoinHandle<O> {
112 type Output = O;
113
114 fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
115 Pin::new(&mut self.task.as_mut().unwrap()).poll(cx)
116 }
117}
118
119impl<T> Drop for JoinHandle<T> {
120 fn drop(&mut self) {
121 if let Some(task) = self.task.take() {
122 task.detach();
123 }
124 }
125}
126
127impl TestContext {
128 fn new(sender: flume::Sender<RunQueueEntry>, options: Arc<Options>) -> Self {
129 TestContext {
130 tick: 0,
131 task_id: 0,
132 task_active: 0,
133 sender,
134 wakers: Vec::new(),
135 options,
136 }
137 }
138
139 fn register_wait(&mut self, waker: Waker) {
140 self.wakers.push(waker);
141 }
142
143 fn next_tick(&mut self) -> usize {
144 let wakers = self.wakers.len();
145
146 if wakers > 0 {
147 self.tick += 1;
148
149 for waker in &self.wakers {
150 waker.wake_by_ref();
151 }
152
153 self.wakers.clear();
154
155 wakers
156 } else {
157 wakers
158 }
159 }
160
161 fn spawn<F, O>(&mut self, future: F) -> JoinHandle<O>
162 where
163 F: Future<Output = O> + Send + 'static,
164 O: Send + 'static,
165 {
166 let sender = self.sender.clone();
167
168 let task_id = self.task_id;
169 self.task_id += 1;
170 self.task_active += 1;
171
172 let schedule = move |runnable| {
173 sender.send(RunQueueEntry(task_id, runnable)).unwrap();
174 };
175
176 let options = self.options.clone();
177 if options.debug {
178 println!("{:?} ** spawn", task_id);
179 }
180 let (runnable, task) = async_task::spawn(
181 TaskWrapper {
182 future,
183 task_id,
184 options,
185 },
186 schedule,
187 );
188 runnable.schedule();
189
190 JoinHandle { task: Some(task) }
191 }
192}
193
/// Executor state shared between the thread driving `run_opt` and its worker threads.
type WrappedTestContext = Arc<Mutex<TestContext>>;

thread_local! {
    /// Context of the run whose task is currently being polled on this thread.
    /// `run_opt` installs it on a worker right before `runnable.run()` and resets it to
    /// `None` afterwards, so it is `None` everywhere else.
    static CONTEXT: RefCell<Option<WrappedTestContext>> = RefCell::new(None);
}
199
200fn get_context() -> WrappedTestContext {
201 CONTEXT.with(|cell| cell.borrow().as_ref().expect(HASCONTEXT).clone())
202}
203
204#[doc(hidden)]
205pub fn __private_wait_tick(tick: usize) -> impl Future<Output = usize> {
206 poll_fn(move |cx| {
207 let test_context = get_context();
208 let mut test_context = test_context.lock().unwrap();
209
210 if test_context.tick >= tick {
211 Poll::Ready(tick)
212 } else {
213 test_context.register_wait(cx.waker().clone());
214 Poll::Pending
215 }
216 })
217}
218
#[macro_export]
/// Suspends the current task until the test executor's tick counter reaches `$tick`.
///
/// Use it inside an `async` block driven by `run`/`run_opt`. `$tick` is converted with
/// `as usize`, and the macro evaluates to that `usize` value.
macro_rules! await_tick {
    ($tick:expr) => {
        $crate::__private_wait_tick($tick as usize).await
    };
}
230
231#[doc(hidden)]
232pub fn __private_get_tick() -> usize {
233 get_context().lock().unwrap().tick
234}
235
#[macro_export]
/// Asserts that the test executor's tick counter currently equals `$expected`.
///
/// The current tick is read first. `$expected` is then evaluated exactly once, so the
/// value shown in the failure message is the value that was compared. The expansion is a
/// single block expression, so the macro works in both statement and expression position.
///
/// # Panics
///
/// Panics with `tick mismatch: expected=<e>, actual=<a>` when the two differ.
macro_rules! assert_tick {
    ($expected:expr) => {{
        let actual = $crate::__private_get_tick();
        let expected = $expected;
        assert!(
            actual == expected,
            "tick mismatch: expected={}, actual={}",
            expected,
            actual
        )
    }};
}
249
/// Future adapter wrapped around every spawned task.
///
/// It prints optional debug traces and keeps `TestContext::task_active` accurate by
/// decrementing it when the inner future completes or panics.
#[pin_project]
struct TaskWrapper<T> {
    /// Id assigned by `TestContext::spawn`; the root task of a run is 0.
    task_id: usize,
    /// The user's future, structurally pinned so it is polled in place.
    #[pin]
    future: T,
    /// Options of the current run (only `debug` is read).
    options: Arc<Options>,
}
257
258impl<T: Future> Future for TaskWrapper<T> {
259 type Output = T::Output;
260
261 fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
262 let debug = self.options.debug;
263
264 let this = self.project();
265 let task_id = *this.task_id;
266
267 if debug {
268 println!("{:?} ** poll", task_id);
269 }
270
271 let context = get_context();
272 match panic::catch_unwind(panic::AssertUnwindSafe(|| this.future.poll(cx))) {
273 Ok(poll) => {
274 if poll.is_ready() {
275 if debug {
276 println!("{:?} ** ready", task_id);
277 }
278
279 context.lock().unwrap().task_active -= 1;
280 } else {
281 if debug {
282 println!("{:?} ** pending", task_id);
283 }
284 }
285
286 poll
287 }
288 Err(error) => {
289 context.lock().unwrap().task_active -= 1;
290 panic::resume_unwind(error);
291 }
292 }
293 }
294}
295
296pub fn is_context() -> bool {
299 CONTEXT.with(|cell| cell.borrow().as_ref().is_some())
300}
301
302pub fn spawn<F, O>(future: F) -> JoinHandle<O>
307where
308 F: Future<Output = O> + Send + 'static,
309 O: Send + 'static,
310{
311 get_context().lock().expect(HASCONTEXT).spawn(future)
312}
313
314pub fn run_opt<O, F>(future: F, options: Options)
326where
327 F: Future<Output = O> + Send + 'static,
328 O: Send + 'static,
329{
330 CONTEXT.with(|cell| {
331 if cell.borrow().is_some() {
332 panic!("{}", HASCONTEXT);
333 }
334 });
335
336 let options = Arc::new(options);
337 let pool = ThreadPool::new(8);
338 let (sender, receiver) = flume::unbounded::<RunQueueEntry>();
339 let mut context = TestContext::new(sender, options.clone());
340
341 context.spawn(future);
342
343 let context = Arc::new(Mutex::new(context));
344 let panic_flag = Arc::new(AtomicBool::new(false));
345 loop {
346 if let Ok(RunQueueEntry(task_id, runnable)) = receiver.try_recv() {
347 let panic_flag1 = panic_flag.clone();
348 let context = context.clone();
349
350 pool.execute(move || {
351 CONTEXT.with(|cell| cell.replace(Some(context)));
352
353 let result = panic::catch_unwind(panic::AssertUnwindSafe(|| runnable.run()));
354
355 if let Err(_) = result {
356 if task_id == 0 {
357 panic_flag1.store(true, Ordering::Relaxed);
358 }
359 }
360
361 CONTEXT.with(|cell| cell.replace(None));
362 });
363
364 if !panic_flag.load(Ordering::Relaxed) {
365 continue;
366 }
367 }
368
369 pool.join();
372
373 if panic_flag.load(Ordering::Relaxed) {
374 panic!("root task panic");
375 }
376
377 if !receiver.is_empty() {
379 continue;
380 }
381
382 let mut context = context.lock().unwrap();
383
384 if options.debug {
385 println!("queue exhaused: tc: {:?}", context.task_active);
386 }
387
388 if context.task_active > 0 {
389 let wakers = context.next_tick();
390
391 if wakers > 0 {
392 if options.debug {
393 println!("tick -> {:?}, waking up {:?} wakers", context.tick, wakers);
394 }
395 continue;
396 } else {
397 panic!("{}", DEADLOCK);
398 }
399 } else {
400 break;
401 }
402 }
403}
404
405pub fn run<O, F>(future: F)
407where
408 F: Future<Output = O> + Send + 'static,
409 O: Send + 'static,
410{
411 run_opt(future, Options::default());
412}
413
#[cfg(test)]
mod tests {
    // `spawn` outside of `run` has no context and must panic.
    #[test]
    #[should_panic]
    fn test_panic_no_context() {
        super::spawn(async {});
    }

    // A panic in the root task is reported by `run` ("root task panic").
    #[test]
    #[should_panic]
    fn test_root_task_exception() {
        super::run(async {
            panic!();
        });
    }

    // A panic in a detached child task is contained and does not fail the run.
    #[test]
    fn test_inner_task_exception() {
        super::run(async {
            super::spawn(async {
                panic!();
            });
        });
    }

    // Awaiting the handle of a panicked child re-panics in the root task.
    #[test]
    #[should_panic]
    fn test_inner_task_exception_propagates() {
        super::run(async {
            let jh = super::spawn(async {
                panic!();
            });

            jh.await;
        });
    }

    // Both the root task and spawned tasks see the installed context.
    #[test]
    fn test_has_context() {
        super::run(async {
            super::CONTEXT.with(|cell| assert!(cell.borrow().is_some()));

            super::spawn(async {
                super::CONTEXT.with(|cell| {
                    assert!(cell.borrow().is_some());
                });
            })
            .await;
        });
    }

    // A nested `run` finds an installed context, panics, and so fails the root task.
    #[test]
    #[should_panic]
    fn test_panic_nested() {
        super::run(async {
            super::run(async {});
        });
    }

    #[test]
    fn test_initial_ticks_0() {
        super::run(async {
            assert_tick!(0);
        });
    }

    // `task_active` counts the root task, rises while a child runs, and drops back
    // when the child completes or panics.
    #[test]
    fn test_task_count() {
        use futures::future::FutureExt;

        super::run(async {
            assert_eq!(super::get_context().lock().unwrap().task_active, 1);

            super::spawn(async {
                assert_eq!(super::get_context().lock().unwrap().task_active, 2);
            })
            .await;

            assert_eq!(super::get_context().lock().unwrap().task_active, 1);

            let handle = super::spawn(async {
                panic!();
            });

            // Awaiting a panicked task's handle panics; catch it so the root survives.
            let _ = handle.catch_unwind().await;

            assert_eq!(super::get_context().lock().unwrap().task_active, 1);
        });
    }

    #[test]
    fn test_ticks_increment_on_wait() {
        super::run(async {
            super::await_tick!(1);
            super::assert_tick!(1);
        });
    }

    #[test]
    fn test_ticks_increment_on_wait_inner() {
        super::run(async {
            super::spawn(async {
                super::await_tick!(1);
            })
            .await;
            super::assert_tick!(1);
        });
    }

    // A task blocked on an external timer, with nothing waiting on a tick, is a deadlock.
    #[test]
    #[should_panic]
    fn test_deadlock() {
        use async_std::task;
        use std::time::Duration;

        super::run(async {
            task::sleep(Duration::from_secs(1)).await;
        });
    }

    // Same as above, but the blocked task is a child the root is awaiting.
    #[test]
    #[should_panic]
    fn test_deadlock_inner() {
        use async_std::task;
        use std::time::Duration;

        super::run(async {
            super::spawn(async {
                task::sleep(Duration::from_secs(1)).await;
            })
            .await;
        });
    }
}
552}