#[macro_use]
extern crate syn;
extern crate proc_macro;
use std::collections::{BTreeMap, HashMap};
use std::iter::FromIterator;
use std::sync::atomic::{AtomicUsize, Ordering};
use proc_macro2::TokenStream;
use quote::quote;
use syn::export::fmt::Display;
use syn::parse_macro_input;
use syn::spanned::Spanned;
mod attribute;
#[proc_macro_attribute]
pub fn parameterized(
args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let args = parse_macro_input!(args as attribute::AttributeArgList);
let func = parse_macro_input!(input as syn::ItemFn);
let name = &func.sig.ident;
let vis = &func.vis;
let func_args = &func.sig.inputs;
let body_block = &func.block;
let attributes = &func.attrs;
let mod_name = format!("{}", name);
let mod_ident = syn::Ident::new(mod_name.as_str(), name.span());
let generated_ident_id = AtomicUsize::new(0);
let identifiers_defined = args.args.len();
let exprs_by_id: HashMap<syn::Ident, Vec<syn::Expr>> = args
.args
.iter()
.map(|v| (v.id.clone(), v.param_args.iter().cloned().collect()))
.collect();
if exprs_by_id.len() != identifiers_defined {
panic!("Duplicate identifier(s) found. Please use unique parameter names.")
}
let amount_of_test_cases = check_all_input_lengths(&exprs_by_id);
let test_case_fns = (0..amount_of_test_cases).map(|i| {
let binds: Vec<TokenStream> = func_args
.iter()
.map(|fn_arg| {
if let syn::FnArg::Typed(pat) = fn_arg {
let fn_expected_ty = &pat.ty;
let fn_ident = pat.pat.as_ref();
if let syn::Pat::Ident(pat_ident) = fn_ident {
let fn_arg_ident = &pat_ident.ident;
if let Some(exprs) = exprs_by_id.get(&fn_arg_ident) {
let expr = &exprs[i];
return quote! {
let #fn_arg_ident: #fn_expected_ty = #expr;
};
} else {
panic!("not enough test cases found, [this should never happen] ")
}
} else {
panic!("Unable to find a parameter name...")
}
} else {
panic!("Malformed function input.")
}
})
.collect();
let next_id = generated_ident_id.fetch_add(1, Ordering::SeqCst);
let ident = format!("case_{}", next_id);
let ident = syn::Ident::new(ident.as_str(), func.span());
quote! {
#[test]
#(#attributes)*
#vis fn #ident() {
#(#binds)*
#body_block
}
}
});
let token_stream = quote! {
#[cfg(test)]
#vis mod #mod_ident {
use super::*;
#(#test_case_fns)*
}
};
token_stream.into()
}
/// Returns how many expressions every identifier was given, i.e. the number of test cases
/// to generate. An empty map yields `0`.
///
/// # Panics
/// Panics (via `panic_on_inequal_length`) at the first identifier whose list length differs
/// from that of the first list visited.
fn check_all_input_lengths(map: &HashMap<syn::Ident, Vec<syn::Expr>>) -> usize {
    let mut lengths = map.values().map(Vec::len);
    let expected = match lengths.next() {
        Some(len) => len,
        None => return 0,
    };
    for len in lengths {
        if len != expected {
            panic_on_inequal_length(map);
            unreachable!()
        }
    }
    expected
}
/// Aborts because the supplied input lists do not all have the same length.
///
/// The panic message names every key in ascending `Ord` order (duplicate keys collapse to
/// one), joined by `", "`, e.g. `All inputs (a, b) should have equal length.`
///
/// # Panics
/// Always.
fn panic_on_inequal_length<K: Ord + Display, V>(map: impl IntoIterator<Item = (K, V)>) {
    // Ordering through a BTreeMap makes the message independent of the input's iteration
    // order (a HashMap's order is unspecified).
    let sorted_by_id = BTreeMap::from_iter(map);
    let names: Vec<String> = sorted_by_id.keys().map(ToString::to_string).collect();
    panic!("All inputs ({}) should have equal length.", names.join(", "))
}