nonparallel_async/lib.rs
//! A procedural macro for Rust that ensures annotated async functions (e.g.
//! unit tests) do not run at the same time.
3//!
//! This is achieved by acquiring a mutex at the beginning of the annotated
//! function and holding it until the function body finishes.
6//!
//! Different functions can synchronize on different mutexes. That's why the
//! name of a static mutex must be passed to the `nonparallel_async` attribute.
9//!
10//! ## Usage
11//!
12//! ```rust
13//! use tokio::sync::Mutex;
14//! use nonparallel_async::nonparallel_async;
15//!
16//! // Create two locks
17//! static MUT_A: Mutex<()> = Mutex::const_new(());
18//! static MUT_B: Mutex<()> = Mutex::const_new(());
19//!
20//! // Mutually exclude parallel runs of functions using those two locks
21//!
22//! #[nonparallel_async(MUT_A)]
23//! async fn function_a1() {
24//! // This will not run in parallel to function_a2
25//! }
26//!
27//! #[nonparallel_async(MUT_A)]
28//! async fn function_a2() {
29//! // This will not run in parallel to function_a1
30//! }
31//!
32//! #[nonparallel_async(MUT_B)]
33//! async fn function_b() {
34//! // This may run in parallel to function_a*
35//! }
36//! ```
37
38extern crate proc_macro;
39
40use proc_macro::TokenStream;
41use quote::{quote, ToTokens};
42use syn::parse::{Parse, ParseStream};
43use syn::{parse, parse_macro_input, Ident, ItemFn, Stmt};
44
45#[derive(Debug)]
46struct Nonparallel {
47 ident: Ident,
48}
49
50impl Parse for Nonparallel {
51 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
52 let ident = input.parse::<Ident>()?;
53 Ok(Nonparallel { ident })
54 }
55}
56
57#[proc_macro_attribute]
58pub fn nonparallel_async(attr: TokenStream, item: TokenStream) -> TokenStream {
59 // Parse macro attributes
60 let Nonparallel { ident } = parse_macro_input!(attr);
61
62 // Parse function
63 let mut function: ItemFn = parse(item).expect("Could not parse ItemFn");
64
65 // Insert locking code
66 let quoted = quote! { let guard = #ident.lock().await; };
67 let stmt: Stmt = parse(quoted.into()).expect("Could not parse quoted statement");
68 function.block.stmts.insert(0, stmt);
69
70 // Generate token stream
71 TokenStream::from(function.to_token_stream())
72}