1use proc_macro::TokenStream;
7use proc_macro2::TokenStream as TokenStream2;
8use quote::quote;
9use syn::{parse_macro_input, ItemFn, Meta};
10
11#[derive(Default)]
13struct TestConfig {
14 tokio_thread_count: Option<usize>,
15 rayon_thread_count: Option<usize>,
16}
17
18impl TestConfig {
19 fn parse(attrs: &[Meta]) -> syn::Result<Self> {
20 let mut config = Self::default();
21
22 for meta in attrs {
23 if let Meta::NameValue(nv) = meta {
24 let ident = nv
25 .path
26 .get_ident()
27 .ok_or_else(|| syn::Error::new_spanned(&nv.path, "expected identifier"))?;
28
29 let value = match &nv.value {
30 syn::Expr::Lit(syn::ExprLit {
31 lit: syn::Lit::Int(lit),
32 ..
33 }) => lit.base10_parse::<usize>()?,
34 _ => {
35 return Err(syn::Error::new_spanned(
36 &nv.value,
37 "expected integer literal",
38 ))
39 }
40 };
41
42 match ident.to_string().as_str() {
43 "tokio_thread_count" => config.tokio_thread_count = Some(value),
44 "rayon_thread_count" => config.rayon_thread_count = Some(value),
45 _ => {
46 return Err(syn::Error::new_spanned(
47 ident,
48 format!(
49 "unknown attribute `{}`, expected `tokio_thread_count` or `rayon_thread_count`",
50 ident
51 ),
52 ))
53 }
54 }
55 } else {
56 return Err(syn::Error::new_spanned(
57 meta,
58 "expected `key = value` format",
59 ));
60 }
61 }
62
63 Ok(config)
64 }
65}
66
67#[proc_macro_attribute]
112pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
113 let input = parse_macro_input!(item as ItemFn);
114
115 let attr_parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
117 let attrs = match syn::parse::Parser::parse(attr_parser, attr) {
118 Ok(attrs) => attrs,
119 Err(e) => return e.to_compile_error().into(),
120 };
121
122 let config = match TestConfig::parse(&attrs.into_iter().collect::<Vec<_>>()) {
123 Ok(c) => c,
124 Err(e) => return e.to_compile_error().into(),
125 };
126
127 match generate_test(input, config) {
128 Ok(tokens) => tokens.into(),
129 Err(e) => e.to_compile_error().into(),
130 }
131}
132
133fn generate_test(input: ItemFn, config: TestConfig) -> syn::Result<TokenStream2> {
134 let ItemFn {
135 attrs,
136 vis,
137 sig,
138 block,
139 } = input;
140
141 if sig.asyncness.is_none() {
143 return Err(syn::Error::new_spanned(
144 sig.fn_token,
145 "test function must be async",
146 ));
147 }
148
149 let fn_name = &sig.ident;
150
151 let tokio_threads = config.tokio_thread_count.unwrap_or(1);
153 let rayon_threads = config.rayon_thread_count.unwrap_or(2);
154
155 let mut new_sig = sig.clone();
157 new_sig.asyncness = None;
158
159 let has_return_type = !matches!(&sig.output, syn::ReturnType::Default);
161
162 let output = if has_return_type {
164 quote! {
166 #[::core::prelude::v1::test]
167 #(#attrs)*
168 #vis #new_sig {
169 let __loom_runtime = ::loom_rs::LoomBuilder::new()
170 .prefix(concat!("test-", stringify!(#fn_name)))
171 .tokio_threads(#tokio_threads)
172 .rayon_threads(#rayon_threads)
173 .pin_threads(false)
174 .build()
175 .expect("failed to create test runtime");
176
177 let __result = __loom_runtime.block_on(async #block);
178 __loom_runtime.block_until_idle();
179 __result
180 }
181 }
182 } else {
183 quote! {
185 #[::core::prelude::v1::test]
186 #(#attrs)*
187 #vis #new_sig {
188 let __loom_runtime = ::loom_rs::LoomBuilder::new()
189 .prefix(concat!("test-", stringify!(#fn_name)))
190 .tokio_threads(#tokio_threads)
191 .rayon_threads(#rayon_threads)
192 .pin_threads(false)
193 .build()
194 .expect("failed to create test runtime");
195
196 __loom_runtime.block_on(async #block);
197 __loom_runtime.block_until_idle();
198 }
199 }
200 };
201
202 Ok(output)
203}