nonparallel_async/
lib.rs

//! A procedural macro for Rust that allows you to ensure that async functions
//! (e.g. async unit tests) are not running at the same time.
3//!
//! This is achieved by awaiting a mutex lock at the beginning of the annotated
//! `async` function and holding the guard until the function body finishes.
6//!
//! Different functions can synchronize on different mutexes. That's why a
//! static async mutex (such as `tokio::sync::Mutex`) must be passed to the
//! `nonparallel_async` attribute.
9//!
10//! ## Usage
11//!
12//! ```rust
13//! use tokio::sync::Mutex;
14//! use nonparallel_async::nonparallel_async;
15//!
16//! // Create two locks
17//! static MUT_A: Mutex<()> = Mutex::const_new(());
18//! static MUT_B: Mutex<()> = Mutex::const_new(());
19//!
20//! // Mutually exclude parallel runs of functions using those two locks
21//!
22//! #[nonparallel_async(MUT_A)]
23//! async fn function_a1() {
24//!     // This will not run in parallel to function_a2
25//! }
26//!
27//! #[nonparallel_async(MUT_A)]
28//! async fn function_a2() {
29//!     // This will not run in parallel to function_a1
30//! }
31//!
32//! #[nonparallel_async(MUT_B)]
33//! async fn function_b() {
34//!     // This may run in parallel to function_a*
35//! }
36//! ```
37
38extern crate proc_macro;
39
40use proc_macro::TokenStream;
41use quote::{quote, ToTokens};
42use syn::parse::{Parse, ParseStream};
43use syn::{parse, parse_macro_input, Ident, ItemFn, Stmt};
44
45#[derive(Debug)]
46struct Nonparallel {
47    ident: Ident,
48}
49
50impl Parse for Nonparallel {
51    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
52        let ident = input.parse::<Ident>()?;
53        Ok(Nonparallel { ident })
54    }
55}
56
57#[proc_macro_attribute]
58pub fn nonparallel_async(attr: TokenStream, item: TokenStream) -> TokenStream {
59    // Parse macro attributes
60    let Nonparallel { ident } = parse_macro_input!(attr);
61
62    // Parse function
63    let mut function: ItemFn = parse(item).expect("Could not parse ItemFn");
64
65    // Insert locking code
66    let quoted = quote! { let guard = #ident.lock().await; };
67    let stmt: Stmt = parse(quoted.into()).expect("Could not parse quoted statement");
68    function.block.stmts.insert(0, stmt);
69
70    // Generate token stream
71    TokenStream::from(function.to_token_stream())
72}