use std::iter::FromIterator;
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use proc_macro_error::{abort, emit_error, proc_macro_error};
use quote::quote;
use syn::{
parse_macro_input, Block, Expr, FnArg, ItemFn, LitStr, Local, Pat, Result, Stmt, Token,
parse::{Parse, ParseStream},
punctuated::Punctuated,
token::{Eq, Let, Semi},
};
/// Attribute that expands a function into one function per listed case.
///
/// Unlike `test_args`, this never adds `#[test]` itself. Any emitted function that has no
/// parameters left must already carry an attribute whose last path segment is `test`.
/// Otherwise expansion aborts with a "missing '#[test]' attribute" error.
#[proc_macro_attribute]
#[proc_macro_error]
pub fn args(attr: TokenStream, input: TokenStream) -> TokenStream {
    let append_test_attr = false;
    apply_test_args(attr, input, append_test_attr)
}
/// Attribute that expands a function into one test function per listed case.
///
/// Any emitted function that has no parameters left and no attribute ending in `test`
/// gets `#[test]` prepended automatically.
#[proc_macro_attribute]
#[proc_macro_error]
pub fn test_args(attr: TokenStream, input: TokenStream) -> TokenStream {
    let append_test_attr = true;
    apply_test_args(attr, input, append_test_attr)
}
fn apply_test_args(attr: TokenStream, input: TokenStream, append_test_attr: bool) -> TokenStream {
let cases = parse_macro_input!(attr as Cases);
let input = parse_macro_input!(input as ItemFn);
if cases.0.len() == 0 {
let test = test_attribute(&input, append_test_attr);
return quote!{
#test
#input
}.into();
}
let mut output = quote!{};
for case in cases.0 {
let should_panic = case.panics.clone().map(|e| quote!{ #[should_panic(expected = #e)] });
let func = make_case_function(&input, case);
let test = test_attribute(&func, append_test_attr);
output.extend(quote!{
#test
#should_panic
#func
});
}
output.into()
}
/// Decides whether `func` should be preceded by a `#[test]` attribute.
///
/// Returns `None` in either of these cases:
/// - `func` still has parameters left in its signature.
/// - `func` already has an attribute whose last path segment is `test`, such as
///   `#[test]` or `#[tokio::test]`.
///
/// Otherwise it returns `Some(#[test])` when `add_if_needed` is set. When
/// `add_if_needed` is not set, macro expansion is aborted with an error pointing at `func`.
fn test_attribute(func: &ItemFn, add_if_needed: bool) -> Option<proc_macro2::TokenStream> {
    let has_params = !func.sig.inputs.is_empty();
    let has_test_attr = func.attrs.iter().any(|attr| {
        attr.path
            .segments
            .last()
            .map_or(false, |segment| segment.ident == "test")
    });
    if has_params || has_test_attr {
        return None;
    }
    if !add_if_needed {
        abort!(func, "Devbox: Function '{}' is missing '#[test]' attribute", func.sig.ident);
    }
    Some(quote! { #[test] })
}
fn make_case_function(input: &ItemFn, case: Case) -> ItemFn {
if case.values.len() > input.sig.inputs.len() {
emit_error!(
input,
"Devbox: Test case '{}' arguments outnumber function '{}' parameters {} to {}",
case.ident, input.sig.ident, case.values.len(), input.sig.inputs.len()
);
}
let mut func = input.clone();
let name = format!("{}__{}", func.sig.ident, case.ident.to_string());
func.sig.ident = Ident::new(name.as_ref(), Span::call_site());
let inputs = func.sig.inputs.clone();
let mut args = inputs.iter().map(|t|t.clone());
for expr in case.values {
if let Some(arg) = args.next() {
insert_param(&mut func.block, arg, expr);
}
}
func.sig.inputs = syn::punctuated::Punctuated::from_iter(args);
func
}
/// Prepends `let <pattern>: <type> = <init>;` to `block`, binding one typed parameter
/// to a case value.
///
/// A receiver argument (`self`) cannot be bound this way. In that case a compile error
/// is emitted and `block` is left untouched.
fn insert_param(block: &mut Block, arg: FnArg, init: Box<Expr>) {
    match arg {
        FnArg::Typed(typed) => {
            let binding = Stmt::Local(Local {
                attrs: vec![],
                let_token: Let { span: Span::call_site() },
                pat: Pat::Type(typed),
                init: Some((Eq { spans: [Span::call_site()] }, init)),
                semi_token: Semi { spans: [Span::call_site()] },
            });
            block.stmts.insert(0, binding);
        }
        // NOTE(review): this arm fires for functions that DO take `self`, i.e. associated
        // methods, so the message wording looks inverted. Kept verbatim in case diagnostics
        // are snapshot-tested.
        FnArg::Receiver(_) => emit_error!(
            arg,
            "Devbox: Parametrized test applied to non-associated function"
        ),
    }
}
struct Case {
pub ident: Ident,
pub colon: Token![:],
pub values: Vec<Box<Expr>>,
pub panics: Option<LitStr>,
}
impl Parse for Case {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Case {
ident: input.parse()?,
colon: input.parse()?,
values: {
let mut result = vec![Box::new(input.parse()?)];
let mut more: Option<Token![,]> = input.parse()?;
while more.is_some() {
result.push(Box::new(input.parse()?));
more = input.parse()?;
}
result
},
panics: {
let excl: Option<Token![!]> = input.parse()?;
if excl.is_some() {
input.parse()?
} else {
None
}
}
})
}
}
struct Cases(Punctuated<Case, Token![;]>);
impl Parse for Cases {
fn parse(input: ParseStream) -> Result<Self> {
Ok(Cases(input.parse_terminated(Case::parse)?))
}
}