#![warn(rust_2018_idioms, missing_docs)]
use proc_macro::TokenStream;
use quote::{format_ident, quote};
use syn::{
parse_macro_input, punctuated::Punctuated, token::Comma, AttributeArgs, Error, FnArg, ItemFn,
NestedMeta, Pat, Type,
};
struct Arguments {
ids: Punctuated<Pat, Comma>,
tys: Punctuated<Type, Comma>,
}
fn parse_args(fn_item: &ItemFn) -> Result<Arguments, TokenStream> {
let mut args = Arguments {
ids: Punctuated::new(),
tys: Punctuated::new(),
};
for pt in fn_item.sig.inputs.iter() {
match pt {
FnArg::Receiver(_) => {
return Err(
Error::new_spanned(&fn_item, "test fn cannot take a receiver")
.to_compile_error()
.into(),
)
}
FnArg::Typed(pt) => {
args.ids.push(*pt.pat.clone());
args.tys.push(*pt.ty.clone());
}
}
}
Ok(args)
}
#[proc_macro_attribute]
pub fn tokio(args: TokenStream, item: TokenStream) -> TokenStream {
let fn_item = parse_macro_input!(item as ItemFn);
for attr in &fn_item.attrs {
if attr.path.is_ident("test") {
return Error::new_spanned(&fn_item, "multiple #[test] attributes were supplied")
.to_compile_error()
.into();
}
}
if fn_item.sig.asyncness.is_none() {
return Error::new_spanned(&fn_item, "test fn must be async")
.to_compile_error()
.into();
}
let p_args = parse_macro_input!(args as AttributeArgs);
let attrib: Punctuated<NestedMeta, Comma> = p_args.into_iter().collect();
let call_by = format_ident!("{}", fn_item.sig.ident);
let Arguments { ids, tys } = match parse_args(&fn_item) {
Err(e) => return e,
Ok(ts) => ts,
};
let ret = &fn_item.sig.output;
quote! (
#[::tokio::test(#attrib)]
async fn #call_by() {
#fn_item
let test_fn: fn(#tys) #ret = |#ids| {
::futures::executor::block_on(#call_by(#ids))
};
::tokio::task::spawn_blocking(move || {
::quickcheck::quickcheck(test_fn)
})
.await
.unwrap()
}
)
.into()
}
#[proc_macro_attribute]
pub fn async_std(args: TokenStream, item: TokenStream) -> TokenStream {
let fn_item = parse_macro_input!(item as ItemFn);
for attr in &fn_item.attrs {
if attr.path.is_ident("test") {
return Error::new_spanned(&fn_item, "multiple #[test] attributes were supplied")
.to_compile_error()
.into();
}
}
if fn_item.sig.asyncness.is_none() {
return Error::new_spanned(&fn_item, "test fn must be async")
.to_compile_error()
.into();
}
let p_args = parse_macro_input!(args as AttributeArgs);
let attrib: Punctuated<NestedMeta, Comma> = p_args.into_iter().collect();
let call_by = format_ident!("{}", fn_item.sig.ident);
let Arguments { ids, tys } = match parse_args(&fn_item) {
Err(e) => return e,
Ok(ts) => ts,
};
let ret = &fn_item.sig.output;
quote! (
#[::async_std::test(#attrib)]
async fn #call_by() {
#fn_item
let test_fn: fn(#tys) #ret = |#ids| {
::futures::executor::block_on(#call_by(#ids))
};
::quickcheck::quickcheck(test_fn);
}
)
.into()
}