Skip to main content

loom_macros/
lib.rs

1//! Procedural macros for loom-rs runtime.
2//!
3//! This crate provides the `#[loom_rs::test]` attribute macro for writing
4//! tests that run within a LoomRuntime.
5
6use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::quote;
9use syn::{parse_macro_input, ItemFn, Meta};
10
11/// Configuration parsed from the macro attributes.
12#[derive(Default)]
13struct TestConfig {
14    tokio_thread_count: Option<usize>,
15    rayon_thread_count: Option<usize>,
16}
17
18impl TestConfig {
19    fn parse(attrs: &[Meta]) -> syn::Result<Self> {
20        let mut config = Self::default();
21
22        for meta in attrs {
23            if let Meta::NameValue(nv) = meta {
24                let ident = nv
25                    .path
26                    .get_ident()
27                    .ok_or_else(|| syn::Error::new_spanned(&nv.path, "expected identifier"))?;
28
29                let value = match &nv.value {
30                    syn::Expr::Lit(syn::ExprLit {
31                        lit: syn::Lit::Int(lit),
32                        ..
33                    }) => lit.base10_parse::<usize>()?,
34                    _ => {
35                        return Err(syn::Error::new_spanned(
36                            &nv.value,
37                            "expected integer literal",
38                        ))
39                    }
40                };
41
42                match ident.to_string().as_str() {
43                    "tokio_thread_count" => config.tokio_thread_count = Some(value),
44                    "rayon_thread_count" => config.rayon_thread_count = Some(value),
45                    _ => {
46                        return Err(syn::Error::new_spanned(
47                            ident,
48                            format!(
49                                "unknown attribute `{}`, expected `tokio_thread_count` or `rayon_thread_count`",
50                                ident
51                            ),
52                        ))
53                    }
54                }
55            } else {
56                return Err(syn::Error::new_spanned(
57                    meta,
58                    "expected `key = value` format",
59                ));
60            }
61        }
62
63        Ok(config)
64    }
65}
66
67/// A test attribute macro for loom-rs that sets up a LoomRuntime with test-appropriate defaults.
68///
69/// # Default Configuration
70///
71/// - 1 tokio thread
72/// - 2 rayon threads
73/// - Thread pinning disabled
74///
75/// # Attributes
76///
77/// - `tokio_thread_count = N` - Set the number of tokio worker threads
78/// - `rayon_thread_count = N` - Set the number of rayon threads
79///
80/// # Examples
81///
82/// Basic usage with defaults:
83///
84/// ```ignore
85/// #[loom_rs::test]
86/// async fn test_spawn_compute() {
87///     let result = loom_rs::spawn_compute(|| 42).await;
88///     assert_eq!(result, 42);
89/// }
90/// ```
91///
92/// With Result return type (supports anyhow::Result, etc.):
93///
94/// ```ignore
95/// #[loom_rs::test]
96/// async fn test_with_result() -> anyhow::Result<()> {
97///     let result = loom_rs::spawn_compute(|| 42).await;
98///     assert_eq!(result, 42);
99///     Ok(())
100/// }
101/// ```
102///
103/// Custom thread counts:
104///
105/// ```ignore
106/// #[loom_rs::test(tokio_thread_count = 2, rayon_thread_count = 4)]
107/// async fn test_parallel_work() {
108///     // Test code here
109/// }
110/// ```
111#[proc_macro_attribute]
112pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
113    let input = parse_macro_input!(item as ItemFn);
114
115    // Parse attributes
116    let attr_parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
117    let attrs = match syn::parse::Parser::parse(attr_parser, attr) {
118        Ok(attrs) => attrs,
119        Err(e) => return e.to_compile_error().into(),
120    };
121
122    let config = match TestConfig::parse(&attrs.into_iter().collect::<Vec<_>>()) {
123        Ok(c) => c,
124        Err(e) => return e.to_compile_error().into(),
125    };
126
127    match generate_test(input, config) {
128        Ok(tokens) => tokens.into(),
129        Err(e) => e.to_compile_error().into(),
130    }
131}
132
133fn generate_test(input: ItemFn, config: TestConfig) -> syn::Result<TokenStream2> {
134    let ItemFn {
135        attrs,
136        vis,
137        sig,
138        block,
139    } = input;
140
141    // Verify the function is async
142    if sig.asyncness.is_none() {
143        return Err(syn::Error::new_spanned(
144            sig.fn_token,
145            "test function must be async",
146        ));
147    }
148
149    let fn_name = &sig.ident;
150
151    // Get thread counts with defaults
152    let tokio_threads = config.tokio_thread_count.unwrap_or(1);
153    let rayon_threads = config.rayon_thread_count.unwrap_or(2);
154
155    // Create the new synchronous function signature
156    let mut new_sig = sig.clone();
157    new_sig.asyncness = None;
158
159    // Check if the function returns a Result (has a non-unit return type)
160    let has_return_type = !matches!(&sig.output, syn::ReturnType::Default);
161
162    // Generate the test function
163    let output = if has_return_type {
164        // Function returns something (likely Result<()>), capture and return it
165        quote! {
166            #[::core::prelude::v1::test]
167            #(#attrs)*
168            #vis #new_sig {
169                let __loom_runtime = ::loom_rs::LoomBuilder::new()
170                    .prefix(concat!("test-", stringify!(#fn_name)))
171                    .tokio_threads(#tokio_threads)
172                    .rayon_threads(#rayon_threads)
173                    .pin_threads(false)
174                    .build()
175                    .expect("failed to create test runtime");
176
177                let __result = __loom_runtime.block_on(async #block);
178                __loom_runtime.block_until_idle();
179                __result
180            }
181        }
182    } else {
183        // Function returns (), no need to capture
184        quote! {
185            #[::core::prelude::v1::test]
186            #(#attrs)*
187            #vis #new_sig {
188                let __loom_runtime = ::loom_rs::LoomBuilder::new()
189                    .prefix(concat!("test-", stringify!(#fn_name)))
190                    .tokio_threads(#tokio_threads)
191                    .rayon_threads(#rayon_threads)
192                    .pin_threads(false)
193                    .build()
194                    .expect("failed to create test runtime");
195
196                __loom_runtime.block_on(async #block);
197                __loom_runtime.block_until_idle();
198            }
199        }
200    };
201
202    Ok(output)
203}