async_metronome/
lib.rs

//! ## Unit testing framework for async Rust
//!
//! This crate implements an async unit testing framework based on
//! [MultithreadedTC](https://www.cs.umd.edu/projects/PL/multithreadedtc/overview.html),
//! developed by William Pugh and Nathaniel Ayewah at the University of Maryland.
//!
//! ## Example
//!
//! ```
//! # use futures::{channel::mpsc, SinkExt, StreamExt};
//! use async_metronome::{assert_tick, await_tick};
//!
//! #[async_metronome::test]
//! async fn test_send_receive() {
//!     let (mut sender, mut receiver) = mpsc::channel::<usize>(1);
//!
//!     let sender = async move {
//!         assert_tick!(0);
//!         sender.send(42).await.unwrap();
//!         sender.send(17).await.unwrap();
//!         assert_tick!(1);
//!     };
//!
//!     let receiver = async move {
//!         assert_tick!(0);
//!         await_tick!(1);
//!         receiver.next().await;
//!         receiver.next().await;
//!     };
//!
//!     let sender = async_metronome::spawn(sender);
//!     let receiver = async_metronome::spawn(receiver);
//!
//!     sender.await;
//!     receiver.await;
//! }
//! ```
//!
//! ## Explanation
//!
//! async-metronome has an internal clock. The clock only advances to the next
//! tick when all tasks are in a pending state.
//!
//! The clock starts at `tick 0`. In this example, the macro `await_tick!(1)` makes
//! the receiver block until the clock reaches `tick 1` before resuming.
//! The sender task is allowed to run freely in `tick 0`, until it blocks on
//! the call to `sender.send(17)`. At this point, all tasks are blocked, and
//! the clock can advance to the next tick.
//!
//! In `tick 1`, the receiver resumes and the statement `receiver.next()` is executed,
//! which frees up the sender. The final statement in the sender asserts that the
//! clock is in `tick 1`, in effect asserting that the task blocked on the
//! call to `sender.send(17)`.

use futures::{
    future::poll_fn,
    task::{self, Poll, Waker},
    Future,
};

use async_task::{self, Runnable};

use std::cell::RefCell;
use std::panic;
use std::pin::Pin;
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc, Mutex,
};
use std::time::Duration;
use std::vec::Vec;

use derive_builder::Builder;
use flume;
use pin_project::pin_project;
use threadpool::ThreadPool;

/// Marks an async test function.
pub use async_metronome_attributes::test;

const DEADLOCK: &str = "deadlock";
const HASCONTEXT: &str = "hascontext";

/// Test run options.
#[derive(Clone, Default, Builder, Debug)]
pub struct Options {
    #[builder(setter(into, strip_option), default)]
    _timeout: Option<Duration>,

    #[builder(setter(into), default)]
    debug: bool,
}

/// A `(task id, runnable)` pair queued for execution on the thread pool.
struct RunQueueEntry(usize, Runnable);

struct TestContext {
    /// Current tick of the internal clock.
    tick: usize,
    /// Id assigned to the next spawned task.
    task_id: usize,
    /// Number of spawned tasks that have not completed yet.
    task_active: usize,
    /// Sender side of the run queue; scheduled runnables are pushed here.
    sender: flume::Sender<RunQueueEntry>,
    /// Wakers of tasks waiting for the next tick.
    wakers: Vec<Waker>,
    options: Arc<Options>,
}

/// A future that awaits the result of a task.
///
/// Dropping a [`JoinHandle`] will detach the task, meaning that there is no longer
/// a handle to the task and no way to `join` on it.
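///
/// A minimal sketch of both behaviours:
///
/// ```
/// async_metronome::run(async {
///     // Keeping the handle and awaiting it joins the task and yields its output.
///     let joined = async_metronome::spawn(async { 6 * 7 });
///     assert_eq!(joined.await, 42);
///
///     // Dropping the handle detaches the task; it still runs to completion.
///     drop(async_metronome::spawn(async {}));
/// });
/// ```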
pub struct JoinHandle<O> {
    task: Option<async_task::Task<O>>,
}

impl<O> Future for JoinHandle<O> {
    type Output = O;

    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        Pin::new(&mut self.task.as_mut().unwrap()).poll(cx)
    }
}

impl<T> Drop for JoinHandle<T> {
    fn drop(&mut self) {
        if let Some(task) = self.task.take() {
            task.detach();
        }
    }
}

impl TestContext {
    fn new(sender: flume::Sender<RunQueueEntry>, options: Arc<Options>) -> Self {
        TestContext {
            tick: 0,
            task_id: 0,
            task_active: 0,
            sender,
            wakers: Vec::new(),
            options,
        }
    }

    fn register_wait(&mut self, waker: Waker) {
        self.wakers.push(waker);
    }

    fn next_tick(&mut self) -> usize {
        let wakers = self.wakers.len();

        if wakers > 0 {
            self.tick += 1;

            for waker in &self.wakers {
                waker.wake_by_ref();
            }

            self.wakers.clear();
        }

        wakers
    }

    fn spawn<F, O>(&mut self, future: F) -> JoinHandle<O>
    where
        F: Future<Output = O> + Send + 'static,
        O: Send + 'static,
    {
        let sender = self.sender.clone();

        let task_id = self.task_id;
        self.task_id += 1;
        self.task_active += 1;

        let schedule = move |runnable| {
            sender.send(RunQueueEntry(task_id, runnable)).unwrap();
        };

        let options = self.options.clone();
        if options.debug {
            println!("{:?} ** spawn", task_id);
        }
        let (runnable, task) = async_task::spawn(
            TaskWrapper {
                future,
                task_id,
                options,
            },
            schedule,
        );
        runnable.schedule();

        JoinHandle { task: Some(task) }
    }
}

type WrappedTestContext = Arc<Mutex<TestContext>>;

thread_local! {
    static CONTEXT: RefCell<Option<WrappedTestContext>> = RefCell::new(None);
}

fn get_context() -> WrappedTestContext {
    CONTEXT.with(|cell| cell.borrow().as_ref().expect(HASCONTEXT).clone())
}

#[doc(hidden)]
pub fn __private_wait_tick(tick: usize) -> impl Future<Output = usize> {
    poll_fn(move |cx| {
        let test_context = get_context();
        let mut test_context = test_context.lock().unwrap();

        if test_context.tick >= tick {
            Poll::Ready(tick)
        } else {
            test_context.register_wait(cx.waker().clone());
            Poll::Pending
        }
    })
}

/// Waits until the tick counter reaches the specified value.
///
/// The tick counter increments when all tasks in the test case are in the
/// 'Pending' state and at least one of them is awaiting a tick.
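///
/// A minimal sketch: a lone task awaiting a future tick lets the clock advance.
///
/// ```
/// use async_metronome::{assert_tick, await_tick};
///
/// async_metronome::run(async {
///     assert_tick!(0);
///     // The only task goes pending here, so the clock advances to tick 1.
///     await_tick!(1);
///     assert_tick!(1);
/// });
/// ```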
#[macro_export]
macro_rules! await_tick {
    ($tick:expr) => {
        $crate::__private_wait_tick($tick as usize).await
    };
}

#[doc(hidden)]
pub fn __private_get_tick() -> usize {
    get_context().lock().unwrap().tick
}

/// Asserts the current tick counter value.
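///
/// A short sketch: asserting a tick the clock has not reached fails the test.
///
/// ```should_panic
/// use async_metronome::assert_tick;
///
/// async_metronome::run(async {
///     // The clock starts at tick 0, so this assertion panics.
///     assert_tick!(1);
/// });
/// ```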
#[macro_export]
macro_rules! assert_tick {
    ($expected:expr) => {
        let actual = $crate::__private_get_tick();
        assert!(
            actual == $expected,
            "tick mismatch: expected={}, actual={}",
            $expected,
            actual
        )
    };
}

#[pin_project]
struct TaskWrapper<T> {
    task_id: usize,
    #[pin]
    future: T,
    options: Arc<Options>,
}

impl<T: Future> Future for TaskWrapper<T> {
    type Output = T::Output;

    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<Self::Output> {
        let debug = self.options.debug;

        let this = self.project();
        let task_id = *this.task_id;

        if debug {
            println!("{:?} ** poll", task_id);
        }

        let context = get_context();
        match panic::catch_unwind(panic::AssertUnwindSafe(|| this.future.poll(cx))) {
            Ok(poll) => {
                if poll.is_ready() {
                    if debug {
                        println!("{:?} ** ready", task_id);
                    }

                    context.lock().unwrap().task_active -= 1;
                } else if debug {
                    println!("{:?} ** pending", task_id);
                }

                poll
            }
            Err(error) => {
                context.lock().unwrap().task_active -= 1;
                panic::resume_unwind(error);
            }
        }
    }
}

/// Checks whether a test context is set.
///
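/// A minimal sketch: a context exists only while a test case is running.
///
/// ```
/// assert!(!async_metronome::is_context());
///
/// async_metronome::run(async {
///     assert!(async_metronome::is_context());
/// });
/// ```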
pub fn is_context() -> bool {
    CONTEXT.with(|cell| cell.borrow().as_ref().is_some())
}

/// Spawns a task.
///
/// Panics if called outside of a test case, i.e. outside the root task started by
/// `run` / `run_opt` or one of its child tasks.
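///
/// A minimal sketch of both cases:
///
/// ```
/// async_metronome::run(async {
///     let child = async_metronome::spawn(async { 1 + 1 });
///     assert_eq!(child.await, 2);
/// });
/// ```
///
/// ```should_panic
/// // No test context is active here, so this panics.
/// async_metronome::spawn(async {});
/// ```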
pub fn spawn<F, O>(future: F) -> JoinHandle<O>
where
    F: Future<Output = O> + Send + 'static,
    O: Send + 'static,
{
    get_context().lock().expect(HASCONTEXT).spawn(future)
}

/// Runs the test case and blocks until it completes or panics.
///
/// Internally, it creates a test context that is
/// propagated to subsequently spawned futures.
///
/// # Panics
///
/// Will panic if used from an already running test case.
/// Will panic if the future it runs panics.
///
/// Will panic if a deadlock is detected, i.e. all tasks are in the 'Pending' state
/// and none of them is waiting for the next tick (`await_tick`).
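///
/// A minimal sketch, assuming the `OptionsBuilder` generated by `derive_builder`:
///
/// ```
/// let options = async_metronome::OptionsBuilder::default()
///     .debug(false)
///     .build()
///     .unwrap();
///
/// async_metronome::run_opt(async { /* test body */ }, options);
/// ```
///
/// A deadlock sketch: a task that stays pending without awaiting a tick panics the run.
///
/// ```should_panic
/// use futures::future;
///
/// async_metronome::run_opt(
///     async { future::pending::<()>().await },
///     async_metronome::Options::default(),
/// );
/// ```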
pub fn run_opt<O, F>(future: F, options: Options)
where
    F: Future<Output = O> + Send + 'static,
    O: Send + 'static,
{
    CONTEXT.with(|cell| {
        if cell.borrow().is_some() {
            panic!("{}", HASCONTEXT);
        }
    });

    let options = Arc::new(options);
    let pool = ThreadPool::new(8);
    let (sender, receiver) = flume::unbounded::<RunQueueEntry>();
    let mut context = TestContext::new(sender, options.clone());

    context.spawn(future);

    let context = Arc::new(Mutex::new(context));
    let panic_flag = Arc::new(AtomicBool::new(false));
    loop {
        if let Ok(RunQueueEntry(task_id, runnable)) = receiver.try_recv() {
            let panic_flag1 = panic_flag.clone();
            let context = context.clone();

            pool.execute(move || {
                CONTEXT.with(|cell| cell.replace(Some(context)));

                let result = panic::catch_unwind(panic::AssertUnwindSafe(|| runnable.run()));

                // Only a panic in the root task (task 0) fails the whole run;
                // panics in child tasks surface when their `JoinHandle` is awaited.
                if result.is_err() && task_id == 0 {
                    panic_flag1.store(true, Ordering::Relaxed);
                }

                CONTEXT.with(|cell| cell.replace(None));
            });

            if !panic_flag.load(Ordering::Relaxed) {
                continue;
            }
        }

        // The run queue is empty; wait for all in-flight runnables to finish.
        pool.join();

        if panic_flag.load(Ordering::Relaxed) {
            panic!("root task panic");
        }

        // If new runnables were scheduled in the meantime, keep draining the queue.
        if !receiver.is_empty() {
            continue;
        }

        let mut context = context.lock().unwrap();

        if options.debug {
            println!("queue exhausted: tc: {:?}", context.task_active);
        }

        if context.task_active > 0 {
            let wakers = context.next_tick();

            if wakers > 0 {
                if options.debug {
                    println!("tick -> {:?}, waking up {:?} wakers", context.tick, wakers);
                }
                continue;
            } else {
                panic!("{}", DEADLOCK);
            }
        } else {
            break;
        }
    }
}

/// Same as [run_opt](fn.run_opt.html) but with default options.
pub fn run<O, F>(future: F)
where
    F: Future<Output = O> + Send + 'static,
    O: Send + 'static,
{
    run_opt(future, Options::default());
}

#[cfg(test)]
mod tests {
    #[test]
    #[should_panic]
    fn test_panic_no_context() {
        super::spawn(async {});
    }

    #[test]
    #[should_panic]
    fn test_root_task_exception() {
        super::run(async {
            panic!();
        });
    }

    #[test]
    fn test_inner_task_exception() {
        super::run(async {
            super::spawn(async {
                panic!();
            });
        });
    }

    #[test]
    #[should_panic]
    fn test_inner_task_exception_propagates() {
        super::run(async {
            let jh = super::spawn(async {
                panic!();
            });

            jh.await;
        });
    }

    #[test]
    fn test_has_context() {
        super::run(async {
            super::CONTEXT.with(|cell| assert!(cell.borrow().is_some()));

            super::spawn(async {
                super::CONTEXT.with(|cell| {
                    assert!(cell.borrow().is_some());
                });
            })
            .await;
        });
    }

    #[test]
    #[should_panic]
    fn test_panic_nested() {
        super::run(async {
            super::run(async {});
        });
    }

    #[test]
    fn test_initial_ticks_0() {
        super::run(async {
            assert_tick!(0);
        });
    }

    #[test]
    fn test_task_count() {
        use futures::future::FutureExt;

        super::run(async {
            // just top level
            assert_eq!(super::get_context().lock().unwrap().task_active, 1);

            super::spawn(async {
                // top level + nested
                assert_eq!(super::get_context().lock().unwrap().task_active, 2);
            })
            .await;

            // nested is gone
            assert_eq!(super::get_context().lock().unwrap().task_active, 1);

            // nested starts and panics
            let handle = super::spawn(async {
                panic!();
            });

            // still top level only
            let _ = handle.catch_unwind().await;

            assert_eq!(super::get_context().lock().unwrap().task_active, 1);
        });
    }

    #[test]
    fn test_ticks_increment_on_wait() {
        super::run(async {
            super::await_tick!(1);
            super::assert_tick!(1);
        });
    }

    #[test]
    fn test_ticks_increment_on_wait_inner() {
        super::run(async {
            super::spawn(async {
                super::await_tick!(1);
            })
            .await;
            super::assert_tick!(1);
        });
    }

    #[test]
    #[should_panic]
    fn test_deadlock() {
        use async_std::task;
        use std::time::Duration;

        super::run(async {
            task::sleep(Duration::from_secs(1)).await;
        });
    }

    #[test]
    #[should_panic]
    fn test_deadlock_inner() {
        use async_std::task;
        use std::time::Duration;

        super::run(async {
            super::spawn(async {
                task::sleep(Duration::from_secs(1)).await;
            })
            .await;
        });
    }
}