use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, ItemFn, Meta};
/// Thread-count overrides parsed from the `#[test(...)]` attribute arguments.
///
/// A `None` field means the corresponding key was not supplied, so the
/// expansion falls back to its built-in default for that runtime.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
struct TestConfig {
    /// Value of `tokio_thread_count = N`, if present.
    tokio_thread_count: Option<usize>,
    /// Value of `rayon_thread_count = N`, if present.
    rayon_thread_count: Option<usize>,
}
impl TestConfig {
    /// Builds a `TestConfig` from the `key = value` arguments of `#[test(...)]`.
    ///
    /// Recognised keys are `tokio_thread_count` and `rayon_thread_count`. Each
    /// value must be an integer literal that fits in a `usize`, and each key may
    /// appear at most once.
    ///
    /// # Errors
    ///
    /// Returns a spanned `syn::Error` when an entry:
    /// - is not of `key = value` form;
    /// - has a key that is not a plain identifier;
    /// - has a key that is unknown or repeated;
    /// - has a value that is not an integer literal, or does not fit in `usize`.
    fn parse(attrs: &[Meta]) -> syn::Result<Self> {
        let mut config = Self::default();
        for meta in attrs {
            let nv = match meta {
                Meta::NameValue(nv) => nv,
                _ => {
                    return Err(syn::Error::new_spanned(
                        meta,
                        "expected `key = value` format",
                    ))
                }
            };
            let ident = nv
                .path
                .get_ident()
                .ok_or_else(|| syn::Error::new_spanned(&nv.path, "expected identifier"))?;

            // Resolve the key before looking at the value, so an unknown key is
            // reported as such even when its value is malformed too.
            let slot = match ident.to_string().as_str() {
                "tokio_thread_count" => &mut config.tokio_thread_count,
                "rayon_thread_count" => &mut config.rayon_thread_count,
                _ => {
                    return Err(syn::Error::new_spanned(
                        ident,
                        format!(
                            "unknown attribute `{}`, expected `tokio_thread_count` or `rayon_thread_count`",
                            ident
                        ),
                    ))
                }
            };

            // A repeated key would otherwise silently override the earlier value.
            if slot.is_some() {
                return Err(syn::Error::new_spanned(
                    ident,
                    format!("duplicate attribute `{}`", ident),
                ));
            }

            let value = match &nv.value {
                syn::Expr::Lit(syn::ExprLit {
                    lit: syn::Lit::Int(lit),
                    ..
                }) => lit.base10_parse::<usize>()?,
                _ => {
                    return Err(syn::Error::new_spanned(
                        &nv.value,
                        "expected integer literal",
                    ))
                }
            };
            *slot = Some(value);
        }
        Ok(config)
    }
}
/// Attribute macro that runs an `async fn` test on a `loom_rs` runtime.
///
/// Optional comma-separated `key = value` arguments are accepted:
/// `tokio_thread_count = N` and `rayon_thread_count = N`.
///
/// ```ignore
/// #[test(tokio_thread_count = 2, rayon_thread_count = 4)]
/// async fn my_test() { /* ... */ }
/// ```
///
/// Any parse or validation failure is emitted as a `compile_error!` pointing at
/// the offending tokens.
#[proc_macro_attribute]
pub fn test(attr: TokenStream, item: TokenStream) -> TokenStream {
    // The annotated function is parsed first, so its errors take precedence
    // over errors in the attribute arguments.
    let input = parse_macro_input!(item as ItemFn);
    let parsed = parse_macro_input!(
        attr with syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated
    );
    let metas: Vec<Meta> = parsed.into_iter().collect();

    TestConfig::parse(&metas)
        .and_then(|config| generate_test(input, config))
        .unwrap_or_else(|err| err.to_compile_error())
        .into()
}
/// Expands an `async fn` test into a synchronous `#[test]` function.
///
/// The generated body does the following:
/// 1. Builds a `loom_rs` runtime with thread pinning disabled. Thread names are
///    prefixed with `test-<fn name>`.
/// 2. Drives the original body to completion with `block_on`.
/// 3. Waits for the runtime to go idle.
/// 4. Returns the body's value. This is `()` for tests without a return type,
///    or e.g. a `Result` for fallible tests.
///
/// Thread counts default to 1 (tokio) and 2 (rayon) when not configured.
///
/// # Errors
///
/// Returns a `syn::Error` spanned on the `fn` keyword if the function is not
/// `async`.
fn generate_test(input: ItemFn, config: TestConfig) -> syn::Result<TokenStream2> {
    let ItemFn {
        attrs,
        vis,
        mut sig,
        block,
    } = input;
    if sig.asyncness.is_none() {
        return Err(syn::Error::new_spanned(
            sig.fn_token,
            "test function must be async",
        ));
    }
    // The emitted wrapper is synchronous; the body runs inside `block_on` instead.
    sig.asyncness = None;

    let fn_name = &sig.ident;
    let tokio_threads = config.tokio_thread_count.unwrap_or(1);
    let rayon_threads = config.rayon_thread_count.unwrap_or(2);

    // A single expansion serves both unit-returning and value-returning tests.
    // With no declared return type, `__result` is `()`. A non-unit tail value is
    // then rejected by the compiler, just as it was in the original `async fn`.
    Ok(quote! {
        #[::core::prelude::v1::test]
        #(#attrs)*
        #vis #sig {
            let __loom_runtime = ::loom_rs::LoomBuilder::new()
                .prefix(concat!("test-", stringify!(#fn_name)))
                .tokio_threads(#tokio_threads)
                .rayon_threads(#rayon_threads)
                .pin_threads(false)
                .build()
                .expect("failed to create test runtime");
            let __result = __loom_runtime.block_on(async #block);
            __loom_runtime.block_until_idle();
            __result
        }
    })
}