#![recursion_limit = "128"]
#![deny(unused_must_use)]
extern crate proc_macro;
#[macro_use]
extern crate syn;
#[macro_use]
extern crate quote;
extern crate proc_macro2;
use proc_macro2::{Span, TokenStream};
use std::collections::HashMap;
use syn::parse::{Parse, ParseStream, Result as ParseResult};
use syn::punctuated::Punctuated;
use syn::spanned::Spanned;
use syn::token::Comma;
use syn::{ArgCaptured, FnArg, Ident, ItemFn, Pat};
/// `syn`'s parse error; used throughout to turn macro misuse into spanned compile errors
/// (via `to_compile_error`).
type Error = syn::parse::Error;
struct TemplateArg {
ident: syn::Ident,
is_pattern: bool,
ignore_fn: Option<syn::Path>,
value: syn::LitStr,
}
impl Parse for TemplateArg {
fn parse(input: ParseStream) -> ParseResult<Self> {
let mut ignore_fn = None;
let ident = input.parse::<syn::Ident>()?;
let is_pattern = if input.peek(syn::token::In) {
let _in = input.parse::<syn::token::In>()?;
true
} else {
let _eq = input.parse::<syn::token::Eq>()?;
false
};
let value = input.parse::<syn::LitStr>()?;
if is_pattern && input.peek(syn::token::If) {
let _if = input.parse::<syn::token::If>()?;
let _not = input.parse::<syn::token::Bang>()?;
ignore_fn = Some(input.parse::<syn::Path>()?);
}
Ok(Self {
ident,
is_pattern,
ignore_fn,
value,
})
}
}
struct FilesTestArgs {
root: String,
args: HashMap<Ident, TemplateArg>,
}
impl Parse for FilesTestArgs {
fn parse(input: ParseStream) -> ParseResult<Self> {
let root = input.parse::<syn::LitStr>()?;
let _comma = input.parse::<syn::token::Comma>()?;
let content;
let _brace_token = braced!(content in input);
let args: Punctuated<TemplateArg, Comma> = content.parse_terminated(TemplateArg::parse)?;
let args = args
.into_pairs()
.map(|p| {
let value = p.into_value();
(value.ident.clone(), value)
})
.collect();
Ok(Self {
root: root.value(),
args,
})
}
}
/// Attribute macro that turns a function into a file-driven test (or benchmark).
///
/// Syntax: `#[files("<root>", { arg in r"<regex>" [if !ignore_fn], other = "<value>", ... })]`.
///
/// Every function argument must have a mapping in the braces, except the leading bencher of a
/// `#[bench]` function. Exactly one mapping must use the pattern (`in`) form. The macro emits:
/// * a `#[test_case]` static `__TEST_<name>` of type `::datatest::FilesTestDesc` holding the
///   root, the per-argument string values, the index of the pattern argument and the optional
///   ignore function;
/// * a trampoline `__TEST_TRAMPOLINE_<name>` that converts each entry of `paths_arg` into the
///   declared argument type via `::datatest::DeriveArg` / `::datatest::TakeArg` and calls the
///   original function;
/// * the original function, with `#[test]` / `#[bench]` / `#[ignore]` removed.
///
/// Misuse is reported as a compile error spanned at the offending argument or mapping.
#[proc_macro_attribute]
// The signature of a proc-macro entry point is fixed by the compiler.
#[allow(clippy::needless_pass_by_value)]
pub fn files(
    args: proc_macro::TokenStream,
    func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
    let mut func_item = parse_macro_input!(func as ItemFn);
    let args: FilesTestArgs = parse_macro_input!(args as FilesTestArgs);
    let func_name_str = func_item.ident.to_string();
    let desc_ident = Ident::new(
        &format!("__TEST_{}", func_item.ident),
        func_item.ident.span(),
    );
    let trampoline_func_ident = Ident::new(
        &format!("__TEST_TRAMPOLINE_{}", func_item.ident),
        func_item.ident.span(),
    );
    let info = handle_common_attrs(&mut func_item);
    let ignore = info.ignore;
    let root = args.root;
    let mut pattern_idx = None;
    let mut params: Vec<String> = Vec::new();
    let mut invoke_args: Vec<TokenStream> = Vec::new();
    let mut ignore_fn = None;

    let mut inputs = func_item.decl.inputs.iter();
    if info.bench && inputs.next().is_some() {
        // The first argument of a benchmark is the bencher. Forward the trampoline's own
        // `bencher` parameter, not the user's binding. The user may have named it
        // differently (`b`) or written `mut bencher`; neither is a valid expression in the
        // trampoline's scope.
        invoke_args.push(quote!(bencher));
    }

    // `idx` counts data arguments only, so it lines up with `params` and `paths_arg`.
    for (idx, arg) in inputs.enumerate() {
        match arg {
            FnArg::Captured(ArgCaptured {
                pat: Pat::Ident(pat_ident),
                ty,
                ..
            }) => {
                let arg = match args.args.get(&pat_ident.ident) {
                    Some(arg) => arg,
                    None => {
                        return Error::new(
                            pat_ident.span(),
                            "mapping is not defined for the argument",
                        )
                        .to_compile_error()
                        .into();
                    }
                };
                if arg.is_pattern {
                    if pattern_idx.is_some() {
                        return Error::new(arg.ident.span(), "two patterns are not allowed!")
                            .to_compile_error()
                            .into();
                    }
                    pattern_idx = Some(idx);
                    ignore_fn = arg.ignore_fn.clone();
                }
                params.push(arg.value.value());
                invoke_args.push(quote! {
                    ::datatest::TakeArg::take(&mut <#ty as ::datatest::DeriveArg>::derive(&paths_arg[#idx]))
                });
            }
            _ => {
                return Error::new(
                    arg.span(),
                    "unexpected argument; only simple argument types are allowed (`&str`, `String`, `&[u8]`, `Vec<u8>`, `&Path`, etc)",
                ).to_compile_error().into();
            }
        }
    }

    let ignore_func_ref = match ignore_fn {
        Some(ignore_fn) => quote!(Some(#ignore_fn)),
        None => quote!(None),
    };
    let pattern_idx = match pattern_idx {
        Some(idx) => idx,
        None => {
            return Error::new(
                Span::call_site(),
                "must have exactly one pattern mapping defined via `pattern in r#\"<regular expression>\"`",
            )
            .to_compile_error()
            .into();
        }
    };
    let orig_func_name = &func_item.ident;
    let (kind, bencher_param) = if info.bench {
        (quote!(BenchFn), quote!(bencher: &mut ::datatest::Bencher,))
    } else {
        (quote!(TestFn), quote!())
    };
    let output = quote! {
        #[test_case]
        #[automatically_derived]
        #[allow(non_upper_case_globals)]
        static #desc_ident: ::datatest::FilesTestDesc = ::datatest::FilesTestDesc {
            name: concat!(module_path!(), "::", #func_name_str),
            ignore: #ignore,
            root: #root,
            params: &[#(#params),*],
            pattern: #pattern_idx,
            ignorefn: #ignore_func_ref,
            testfn: ::datatest::FilesTestFn::#kind(#trampoline_func_ident),
        };
        #[automatically_derived]
        #[allow(non_snake_case)]
        fn #trampoline_func_ident(#bencher_param paths_arg: &[::std::path::PathBuf]) {
            let result = #orig_func_name(#(#invoke_args),*);
            ::datatest::assert_test_result(result);
        }
        #func_item
    };
    output.into()
}
/// Attribute flags found (and removed) on a test function by `handle_common_attrs`.
struct FuncInfo {
    /// An `#[ignore]` attribute was present.
    ignore: bool,
    /// A `#[bench]` attribute was present; the first argument is then the bencher.
    bench: bool,
}
fn handle_common_attrs(func: &mut ItemFn) -> FuncInfo {
let test_pos = func
.attrs
.iter()
.position(|attr| attr.path.is_ident("test"));
if let Some(pos) = test_pos {
func.attrs.remove(pos);
}
let bench_pos = func
.attrs
.iter()
.position(|attr| attr.path.is_ident("bench"));
if let Some(pos) = bench_pos {
func.attrs.remove(pos);
}
let ignore_pos = func
.attrs
.iter()
.position(|attr| attr.path.is_ident("ignore"));
if let Some(pos) = ignore_pos {
func.attrs.remove(pos);
}
FuncInfo {
ignore: ignore_pos.is_some(),
bench: bench_pos.is_some(),
}
}
/// Argument of `#[data(...)]`: where the test cases come from.
enum DataTestArgs {
    /// A string literal, expanded into a call to `datatest::yaml` with it.
    Literal(syn::LitStr),
    /// Any other expression, spliced in verbatim as the source of cases.
    Expression(syn::Expr),
}
impl Parse for DataTestArgs {
    /// Treats a leading string literal as `Literal`; anything else is parsed as an `Expr`.
    /// The literal check must come first, because a string literal is also a valid expression.
    fn parse(input: ParseStream) -> ParseResult<Self> {
        if input.peek(syn::LitStr) {
            Ok(DataTestArgs::Literal(input.parse()?))
        } else {
            Ok(DataTestArgs::Expression(input.parse()?))
        }
    }
}
#[proc_macro_attribute]
#[allow(clippy::needless_pass_by_value)]
pub fn data(
args: proc_macro::TokenStream,
func: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let mut func_item = parse_macro_input!(func as ItemFn);
let cases: DataTestArgs = parse_macro_input!(args as DataTestArgs);
let cases = match cases {
DataTestArgs::Literal(path) => quote!(datatest::yaml(#path)),
DataTestArgs::Expression(expr) => quote!(#expr),
};
let func_name_str = func_item.ident.to_string();
let desc_ident = Ident::new(
&format!("__TEST_{}", func_item.ident),
func_item.ident.span(),
);
let describe_func_ident = Ident::new(
&format!("__TEST_DESCRIBE_{}", func_item.ident),
func_item.ident.span(),
);
let trampoline_func_ident = Ident::new(
&format!("__TEST_TRAMPOLINE_{}", func_item.ident),
func_item.ident.span(),
);
let info = handle_common_attrs(&mut func_item);
let ignore = info.ignore;
let orig_func_ident = &func_item.ident;
let mut args = func_item.decl.inputs.iter();
if info.bench {
args.next();
}
let arg = args.next();
let ty = match arg {
Some(FnArg::Captured(ArgCaptured { ty, .. })) => Some(ty),
_ => None,
};
let (ref_token, ty) = match ty {
Some(syn::Type::Reference(type_ref)) => (quote!(&), Some(type_ref.elem.as_ref())),
_ => (TokenStream::new(), ty),
};
let (case_ctor, bencher_param, bencher_arg) = if info.bench {
(
quote!(::datatest::DataTestFn::BenchFn(Box::new(::datatest::DataBenchFn(#trampoline_func_ident, case)))),
quote!(bencher: &mut ::datatest::Bencher,),
quote!(bencher,),
)
} else {
(
quote!(::datatest::DataTestFn::TestFn(Box::new(move || #trampoline_func_ident(case)))),
quote!(),
quote!(),
)
};
let output = quote! {
#[test_case]
#[automatically_derived]
#[allow(non_upper_case_globals)]
static #desc_ident: ::datatest::DataTestDesc = ::datatest::DataTestDesc {
name: concat!(module_path!(), "::", #func_name_str),
ignore: #ignore,
describefn: #describe_func_ident,
};
#[automatically_derived]
#[allow(non_snake_case)]
fn #trampoline_func_ident(#bencher_param arg: #ty) {
let result = #orig_func_ident(#bencher_arg #ref_token arg);
datatest::assert_test_result(result);
}
#[automatically_derived]
#[allow(non_snake_case)]
fn #describe_func_ident() -> Vec<::datatest::DataTestCaseDesc<::datatest::DataTestFn>> {
let result = #cases
.into_iter()
.map(|input| {
let case = input.case;
::datatest::DataTestCaseDesc {
case: #case_ctor,
name: input.name,
location: input.location,
}
})
.collect::<Vec<_>>();
assert!(!result.is_empty(), "no test cases were found!");
result
}
#func_item
};
output.into()
}