#![recursion_limit = "128"]
#![deny(unused_must_use)]
extern crate proc_macro;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use std::collections::HashMap;
use syn::parse::{Parse, ParseStream, Result as ParseResult};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{braced, parse_macro_input, FnArg, Ident, ItemFn, Pat, PatIdent, PatType, Type};
/// Shorthand for the `syn` parse error type used to build compile errors in this file.
type Error = syn::parse::Error;
/// One argument mapping parsed from the braces of a files-test attribute:
/// either `name = "literal"` or `name in "pattern"`, the latter optionally
/// followed by `if !path`.
struct TemplateArg {
// Identifier on the left-hand side of the mapping.
ident: syn::Ident,
// `true` when the mapping was written with `in`, `false` when written with `=`.
is_pattern: bool,
// Path given after `if !`; only parsed for `in` mappings.
ignore_fn: Option<syn::Path>,
// String literal on the right-hand side of the mapping.
value: syn::LitStr,
}
impl Parse for TemplateArg {
    /// Parses `ident = "value"` or `ident in "value" [if !path]`.
    ///
    /// The `if !path` suffix is only looked for after an `in` mapping; after
    /// an `=` mapping any trailing `if` is left in the stream untouched.
    fn parse(input: ParseStream) -> ParseResult<Self> {
        let ident: syn::Ident = input.parse()?;

        // The keyword `in` marks a pattern mapping; anything else must be `=`.
        let is_pattern = input.peek(syn::token::In);
        if is_pattern {
            input.parse::<syn::token::In>()?;
        } else {
            input.parse::<syn::token::Eq>()?;
        }

        let value: syn::LitStr = input.parse()?;

        let ignore_fn = if is_pattern && input.peek(syn::token::If) {
            input.parse::<syn::token::If>()?;
            input.parse::<syn::token::Bang>()?;
            Some(input.parse::<syn::Path>()?)
        } else {
            None
        };

        Ok(Self {
            ident,
            is_pattern,
            ignore_fn,
            value,
        })
    }
}
/// Parsed arguments of a files-test attribute: a root string literal followed
/// by a braced, comma-separated list of `TemplateArg` mappings.
struct FilesTestArgs {
// Value of the leading string literal.
root: String,
// Mappings keyed by the identifier on their left-hand side.
args: HashMap<Ident, TemplateArg>,
}
impl Parse for FilesTestArgs {
fn parse(input: ParseStream) -> ParseResult<Self> {
let root = input.parse::<syn::LitStr>()?;
let _comma = input.parse::<syn::token::Comma>()?;
let content;
let _brace_token = braced!(content in input);
let args: Punctuated<TemplateArg, Comma> = content.parse_terminated(TemplateArg::parse)?;
let args = args
.into_pairs()
.map(|p| {
let value = p.into_value();
(value.ident.clone(), value)
})
.collect();
Ok(Self {
root: root.value(),
args,
})
}
}
/// Selects which registration code `test_registration` emits for a test descriptor.
enum Registration {
// Emit a `ctor` function that registers the descriptor through a `RegistrationNode`.
Ctor,
// Emit a `#[test_case]` attribute on the descriptor.
Nightly,
}
/// Re-emits the annotated item behind `#[cfg(test)]` and
/// `#[::datatest::__internal::files_ctor_internal(<args>)]`, forwarding the
/// attribute arguments unchanged.
#[proc_macro_attribute]
pub fn files_ctor_registration(
    args: proc_macro::TokenStream,
    func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let implementation = Ident::new("files_ctor_internal", Span::call_site());
    guarded_test_attribute(args, func, implementation)
}
/// Re-emits the annotated item behind `#[cfg(test)]` and
/// `#[::datatest::__internal::files_test_case_internal(<args>)]`, forwarding
/// the attribute arguments unchanged.
#[proc_macro_attribute]
pub fn files_test_case_registration(
    args: proc_macro::TokenStream,
    func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let implementation = Ident::new("files_test_case_internal", Span::call_site());
    guarded_test_attribute(args, func, implementation)
}
/// Expands a files-driven test using `ctor`-based registration
/// (`Registration::Ctor`).
#[proc_macro_attribute]
pub fn files_ctor_internal(
    args: proc_macro::TokenStream,
    func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let channel = Registration::Ctor;
    files_internal(args, func, channel)
}
/// Expands a files-driven test using `#[test_case]` registration
/// (`Registration::Nightly`).
#[proc_macro_attribute]
pub fn files_test_case_internal(
    args: proc_macro::TokenStream,
    func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let channel = Registration::Nightly;
    files_internal(args, func, channel)
}
fn files_internal(
args: proc_macro::TokenStream,
func: proc_macro::TokenStream,
channel: Registration,
) -> proc_macro::TokenStream {
let mut func_item: ItemFn = parse_macro_input!(func as ItemFn);
let args: FilesTestArgs = parse_macro_input!(args as FilesTestArgs);
let info = handle_common_attrs(&mut func_item, false);
let func_ident = &func_item.sig.ident;
let func_name_str = func_ident.to_string();
let desc_ident = Ident::new(&format!("__TEST_{}", func_ident), func_ident.span());
let trampoline_func_ident = Ident::new(
&format!("__TEST_TRAMPOLINE_{}", func_ident),
func_ident.span(),
);
let ignore = info.ignore;
let root = args.root;
let mut pattern_idx = None;
let mut params: Vec<String> = Vec::new();
let mut invoke_args: Vec<TokenStream> = Vec::new();
let mut ignore_fn = None;
for (mut idx, arg) in func_item.sig.inputs.iter().enumerate() {
match match_arg(arg) {
Some((pat_ident, ty)) => {
if info.bench {
if idx == 0 {
invoke_args.push(quote!(#pat_ident));
continue;
} else {
idx -= 1;
}
}
if let Some(arg) = args.args.get(&pat_ident.ident) {
if arg.is_pattern {
if pattern_idx.is_some() {
return Error::new(arg.ident.span(), "two patterns are not allowed!")
.to_compile_error()
.into();
}
pattern_idx = Some(idx);
ignore_fn = arg.ignore_fn.clone();
}
params.push(arg.value.value());
invoke_args.push(quote! {
::datatest::__internal::TakeArg::take(&mut <#ty as ::datatest::__internal::DeriveArg>::derive(&paths_arg[#idx]))
})
} else {
return Error::new(pat_ident.span(), "mapping is not defined for the argument")
.to_compile_error()
.into();
}
}
None => {
return Error::new(
arg.span(),
"unexpected argument; only simple argument types are allowed (`&str`, `String`, `&[u8]`, `Vec<u8>`, `&Path`, etc)",
).to_compile_error().into();
}
}
}
let ignore_func_ref = if let Some(ignore_fn) = ignore_fn {
quote!(Some(#ignore_fn))
} else {
quote!(None)
};
if pattern_idx.is_none() {
return Error::new(
Span::call_site(),
"must have exactly one pattern mapping defined via `pattern in r#\"<regular expression>\"`",
)
.to_compile_error()
.into();
}
let (kind, bencher_param) = if info.bench {
(
quote!(BenchFn),
quote!(bencher: &mut ::datatest::__internal::Bencher,),
)
} else {
(quote!(TestFn), quote!())
};
let registration = test_registration(channel, &desc_ident);
let output = quote! {
#registration
#[automatically_derived]
#[allow(non_upper_case_globals)]
static #desc_ident: ::datatest::__internal::FilesTestDesc = ::datatest::__internal::FilesTestDesc {
name: concat!(module_path!(), "::", #func_name_str),
ignore: #ignore,
root: #root,
params: &[#(#params),*],
pattern: #pattern_idx,
ignorefn: #ignore_func_ref,
testfn: ::datatest::__internal::FilesTestFn::#kind(#trampoline_func_ident),
source_file: file!(),
};
#[automatically_derived]
#[allow(non_snake_case)]
fn #trampoline_func_ident(#bencher_param paths_arg: &[::std::path::PathBuf]) {
let result = #func_ident(#(#invoke_args),*);
::datatest::__internal::assert_test_result(result);
}
#func_item
};
output.into()
}
/// Splits a function argument of the form `ident: Type` into its binding and
/// its type.
///
/// Returns `None` for receivers (`self`) and for arguments whose pattern is
/// anything other than a plain identifier binding.
fn match_arg(arg: &FnArg) -> Option<(&PatIdent, &Type)> {
    match arg {
        FnArg::Typed(PatType { pat, ty, .. }) => match &**pat {
            Pat::Ident(binding) => Some((binding, &**ty)),
            _ => None,
        },
        _ => None,
    }
}
/// Panic expectation of a regular test, as read from its `should_panic` attribute.
enum ShouldPanic {
// No `should_panic` attribute was processed.
No,
// `should_panic` is present without a usable `expected` string.
Yes,
// `should_panic` is present with an `expected = "..."` string (stored here).
YesWithMessage(String),
}
/// Summary of the test-related attributes that `handle_common_attrs` found on
/// (and removed from) a function.
struct FuncInfo {
// An `ignore` attribute was present.
ignore: bool,
// A `bench` attribute was present.
bench: bool,
// Result of parsing `should_panic`; stays `ShouldPanic::No` unless the
// function was processed as a regular test.
should_panic: ShouldPanic,
}
fn handle_common_attrs(func: &mut ItemFn, regular_test: bool) -> FuncInfo {
let test_pos = func
.attrs
.iter()
.position(|attr| attr.path.is_ident("test"));
if let Some(pos) = test_pos {
func.attrs.remove(pos);
}
let bench_pos = func
.attrs
.iter()
.position(|attr| attr.path.is_ident("bench"));
if let Some(pos) = bench_pos {
func.attrs.remove(pos);
}
let ignore_pos = func
.attrs
.iter()
.position(|attr| attr.path.is_ident("ignore"));
if let Some(pos) = ignore_pos {
func.attrs.remove(pos);
}
let mut should_panic = ShouldPanic::No;
if regular_test {
let should_panic_pos = func
.attrs
.iter()
.position(|attr| attr.path.is_ident("should_panic"));
if let Some(pos) = should_panic_pos {
let attr = &func.attrs[pos];
should_panic = parse_should_panic(attr);
func.attrs.remove(pos);
}
}
FuncInfo {
ignore: ignore_pos.is_some(),
bench: bench_pos.is_some(),
should_panic,
}
}
/// Interprets a `should_panic` attribute.
///
/// * `#[should_panic(expected = "msg")]` and `#[should_panic = "msg"]` yield
///   `ShouldPanic::YesWithMessage("msg")`.
/// * Every other shape (bare `#[should_panic]`, a list without a string
///   `expected`, or an attribute that fails to parse as meta) yields
///   `ShouldPanic::Yes`.
fn parse_should_panic(attr: &syn::Attribute) -> ShouldPanic {
    match attr.parse_meta() {
        Ok(syn::Meta::List(list)) => {
            for item in list.nested {
                if let syn::NestedMeta::Meta(syn::Meta::NameValue(nv)) = item {
                    if nv.path.is_ident("expected") {
                        if let syn::Lit::Str(ref value) = nv.lit {
                            return ShouldPanic::YesWithMessage(value.value());
                        }
                    }
                }
            }
            ShouldPanic::Yes
        }
        // Name-value form: rustc treats the string as the expected message,
        // so it must not be dropped here.
        Ok(syn::Meta::NameValue(syn::MetaNameValue {
            lit: syn::Lit::Str(value),
            ..
        })) => ShouldPanic::YesWithMessage(value.value()),
        _ => ShouldPanic::Yes,
    }
}
/// Argument of a data-driven test attribute.
#[allow(clippy::large_enum_variant)]
enum DataTestArgs {
// A string literal; `data_internal` wraps it in a `datatest::yaml(...)` call.
Literal(syn::LitStr),
// Any other expression; `data_internal` uses it as the source of cases as written.
Expression(syn::Expr),
}
impl Parse for DataTestArgs {
    /// Parses the attribute argument as `Literal` when it consists of exactly
    /// one string literal, and as `Expression` otherwise.
    ///
    /// The decision is made on a fork of the input so that an expression
    /// which merely *starts* with a string literal (for example a method call
    /// on one) is parsed as a whole expression instead of being cut off after
    /// the literal and rejected for its trailing tokens.
    fn parse(input: ParseStream) -> ParseResult<Self> {
        let lone_literal = {
            let fork = input.fork();
            fork.parse::<syn::LitStr>().is_ok() && fork.is_empty()
        };
        if lone_literal {
            input.parse::<syn::LitStr>().map(DataTestArgs::Literal)
        } else {
            input.parse::<syn::Expr>().map(DataTestArgs::Expression)
        }
    }
}
/// Re-emits the annotated item behind `#[cfg(test)]` and
/// `#[::datatest::__internal::data_ctor_internal(<args>)]`, forwarding the
/// attribute arguments unchanged.
#[proc_macro_attribute]
pub fn data_ctor_registration(
    args: proc_macro::TokenStream,
    func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let implementation = Ident::new("data_ctor_internal", Span::call_site());
    guarded_test_attribute(args, func, implementation)
}
/// Re-emits the annotated item behind `#[cfg(test)]` and
/// `#[::datatest::__internal::data_test_case_internal(<args>)]`, forwarding
/// the attribute arguments unchanged.
#[proc_macro_attribute]
pub fn data_test_case_registration(
    args: proc_macro::TokenStream,
    func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let implementation = Ident::new("data_test_case_internal", Span::call_site());
    guarded_test_attribute(args, func, implementation)
}
/// Expands a data-driven test using `ctor`-based registration
/// (`Registration::Ctor`).
#[proc_macro_attribute]
pub fn data_ctor_internal(
    args: proc_macro::TokenStream,
    func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let channel = Registration::Ctor;
    data_internal(args, func, channel)
}
/// Expands a data-driven test using `#[test_case]` registration
/// (`Registration::Nightly`).
#[proc_macro_attribute]
pub fn data_test_case_internal(
    args: proc_macro::TokenStream,
    func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let channel = Registration::Nightly;
    data_internal(args, func, channel)
}
fn data_internal(
args: proc_macro::TokenStream,
func: proc_macro::TokenStream,
channel: Registration,
) -> proc_macro::TokenStream {
let mut func_item = parse_macro_input!(func as ItemFn);
let cases: DataTestArgs = parse_macro_input!(args as DataTestArgs);
let info = handle_common_attrs(&mut func_item, false);
let cases = match cases {
DataTestArgs::Literal(path) => quote!(datatest::yaml(#path)),
DataTestArgs::Expression(expr) => quote!(#expr),
};
let func_ident = &func_item.sig.ident;
let func_name_str = func_ident.to_string();
let desc_ident = Ident::new(&format!("__TEST_{}", func_ident), func_ident.span());
let describe_func_ident = Ident::new(
&format!("__TEST_DESCRIBE_{}", func_ident),
func_ident.span(),
);
let trampoline_func_ident = Ident::new(
&format!("__TEST_TRAMPOLINE_{}", func_ident),
func_ident.span(),
);
let ignore = info.ignore;
let mut args = func_item.sig.inputs.iter();
if info.bench {
args.next();
}
let arg = args.next();
let ty = match arg {
Some(FnArg::Typed(PatType { ty, .. })) => Some(ty.as_ref()),
_ => None,
};
let (ref_token, ty) = match ty {
Some(syn::Type::Reference(type_ref)) => (quote!(&), Some(type_ref.elem.as_ref())),
_ => (TokenStream::new(), ty),
};
let (case_ctor, bencher_param, bencher_arg) = if info.bench {
(
quote!(::datatest::__internal::DataTestFn::BenchFn(Box::new(::datatest::__internal::DataBenchFn(#trampoline_func_ident, case)))),
quote!(bencher: &mut ::datatest::__internal::Bencher,),
quote!(bencher,),
)
} else {
(
quote!(::datatest::__internal::DataTestFn::TestFn(Box::new(move || #trampoline_func_ident(case)))),
quote!(),
quote!(),
)
};
let registration = test_registration(channel, &desc_ident);
let output = quote! {
#registration
#[automatically_derived]
#[allow(non_upper_case_globals)]
static #desc_ident: ::datatest::__internal::DataTestDesc = ::datatest::__internal::DataTestDesc {
name: concat!(module_path!(), "::", #func_name_str),
ignore: #ignore,
describefn: #describe_func_ident,
source_file: file!(),
};
#[automatically_derived]
#[allow(non_snake_case)]
fn #trampoline_func_ident(#bencher_param arg: #ty) {
let result = #func_ident(#bencher_arg #ref_token arg);
::datatest::__internal::assert_test_result(result);
}
#[automatically_derived]
#[allow(non_snake_case)]
fn #describe_func_ident() -> Vec<::datatest::DataTestCaseDesc<::datatest::__internal::DataTestFn>> {
let result = #cases
.into_iter()
.map(|input| {
let case = input.case;
::datatest::DataTestCaseDesc {
case: #case_ctor,
name: input.name,
location: input.location,
}
})
.collect::<Vec<_>>();
assert!(!result.is_empty(), "no test cases were found!");
result
}
#func_item
};
output.into()
}
/// Produces the tokens that register the descriptor static `desc_ident`.
///
/// * `Registration::Nightly` - a bare `#[test_case]` attribute, meant to be
///   placed directly in front of the descriptor static.
/// * `Registration::Ctor` - a `ctor` function named `<desc>__REGISTRATION`
///   that passes a `RegistrationNode` pointing at the descriptor to
///   `::datatest::__internal::register`, plus a module `<desc>__CHECK` whose
///   `dtor` function calls `::datatest::__internal::check_test_runner`.
///
/// All generated paths are absolute (`::datatest::...`) so that a local item
/// named `datatest` at the expansion site cannot shadow the crate.
fn test_registration(channel: Registration, desc_ident: &syn::Ident) -> TokenStream {
    match channel {
        Registration::Nightly => quote!(#[test_case]),
        Registration::Ctor => {
            let registration_fn =
                syn::Ident::new(&format!("{}__REGISTRATION", desc_ident), desc_ident.span());
            let check_fn = syn::Ident::new(&format!("{}__CHECK", desc_ident), desc_ident.span());
            quote! {
                #[automatically_derived]
                #[allow(non_snake_case)]
                #[::datatest::__internal::ctor]
                fn #registration_fn() {
                    use ::datatest::__internal::RegistrationNode;
                    static mut REGISTRATION: RegistrationNode = RegistrationNode {
                        descriptor: &#desc_ident,
                        next: None,
                    };
                    ::datatest::__internal::register(unsafe { &mut REGISTRATION });
                }
                #[automatically_derived]
                #[allow(non_snake_case)]
                mod #check_fn {
                    #[::datatest::__internal::dtor]
                    fn check_fn() {
                        ::datatest::__internal::check_test_runner();
                    }
                }
            }
        }
    }
}
#[proc_macro_attribute]
pub fn test_ctor_registration(
_args: proc_macro::TokenStream,
func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let mut func_item = parse_macro_input!(func as ItemFn);
let info = handle_common_attrs(&mut func_item, true);
let func_ident = &func_item.sig.ident;
let func_name_str = func_ident.to_string();
let desc_ident = Ident::new(&format!("__TEST_{}", func_ident), func_ident.span());
let ignore = info.ignore;
let should_panic = match info.should_panic {
ShouldPanic::No => quote!(::datatest::__internal::RegularShouldPanic::No),
ShouldPanic::Yes => quote!(::datatest::__internal::RegularShouldPanic::Yes),
ShouldPanic::YesWithMessage(v) => {
quote!(::datatest::__internal::RegularShouldPanic::YesWithMessage(#v))
}
};
let registration = test_registration(Registration::Ctor, &desc_ident);
let output = quote! {
#registration
#[automatically_derived]
#[allow(non_upper_case_globals)]
static #desc_ident: ::datatest::__internal::RegularTestDesc = ::datatest::__internal::RegularTestDesc {
name: concat!(module_path!(), "::", #func_name_str),
ignore: #ignore,
testfn: || {
let result = #func_ident();
::datatest::__internal::assert_test_result(result);
},
should_panic: #should_panic,
source_file: file!(),
};
#func_item
};
output.into()
}
fn guarded_test_attribute(
args: proc_macro::TokenStream,
item: proc_macro::TokenStream,
implementation: Ident,
) -> proc_macro::TokenStream {
let args: TokenStream = args.into();
let header = quote! {
#[cfg(test)]
#[::datatest::__internal::#implementation(#args)]
};
let mut out: proc_macro::TokenStream = header.into();
out.extend(item);
out
}