use crate::body;
use crate::helpers::*;
use pm::{Span, TokenStream};
use quote::{quote, ToTokens};
use std::{iter::FromIterator, mem::replace};
use syn::{
parse::Error, parse_quote, punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut,
Attribute, Block, FnArg, FnDecl, Ident, ItemMod, Type, Visibility,
};
/// Comma-separated parameter list of a function signature (the type of `FnDecl::inputs`).
type FnArgs = Punctuated<syn::FnArg, syn::token::Comma>;
/// Comma-separated argument expressions, as parsed from the contents of a `#[case(...)]` attribute.
type CallArgs = Punctuated<syn::Expr, syn::token::Comma>;
pub fn generate(
ident: Ident,
decl: &mut FnDecl,
block: &mut Block,
attrs: &mut Vec<Attribute>,
session: Option<&Type>,
) -> Result<ItemMod, Error> {
let case_ident = generate_case_ident();
let cases = extract_cases(attrs)?;
let stringified_cases = generate_stringified_cases(&decl.inputs, &cases, session.is_some())?;
let span = ident.span();
let test_fn_idents: Vec<_> = cases.iter().map(move |s| fn_name(s, span)).collect();
let vis = Visibility::Inherited;
let attrs = {
let mut attrs = replace(attrs, Vec::new());
add_test_attribute(&mut attrs);
TokenStream::from_iter(attrs.into_iter().flat_map(ToTokens::into_token_stream))
};
decl.inputs.push(parse_quote!(#case_ident: &'static str));
body::Transform::new(Some(case_ident)).visit_block_mut(block);
let tests =
zip2(test_fn_idents, cases, stringified_cases).map(|(case_ident, case, stringified)| {
match session {
Some(ty) => quote! {
#attrs
pub fn #case_ident() {
#ty::default().#ident(#case, #stringified)
}
},
None => quote! {
#attrs
pub fn #case_ident() {
#ident(#case, #stringified)
}
},
}
});
let output = parse_quote! {
#vis mod #ident {
use super::*;
#( #tests )*
}
};
Ok(output)
}
/// Derives a test-function name from a case's argument expressions.
///
/// The name is `case` followed by each argument's token text, joined with `_`:
/// * every character that is not an ASCII letter or digit becomes `_`;
/// * letters are lowercased;
/// * runs of `_` are collapsed into one;
/// * trailing `_` are removed.
///
/// For example, `#[case(1, "Foo Bar")]` yields `case_1_foo_bar`.
///
/// Only ASCII alphanumerics are kept. `Ident::new` panics on characters that
/// `char::is_alphanumeric` accepts but that are not valid in identifiers (e.g. `½`).
fn fn_name(args: &CallArgs, span: Span) -> Ident {
    let mut raw = String::from("case");
    for arg in args.iter() {
        raw.push('_');
        raw.push_str(&arg.to_token_stream().to_string());
    }

    let mut name = String::with_capacity(raw.len());
    for c in raw.chars() {
        let c = if c.is_ascii_alphanumeric() {
            c.to_ascii_lowercase()
        } else {
            '_'
        };
        // Collapse consecutive separators into a single underscore.
        if c == '_' && name.ends_with('_') {
            continue;
        }
        name.push(c);
    }

    Ident::new(name.trim_end_matches('_'), span)
}
fn generate_stringified_cases(
args: &FnArgs,
cases: &[CallArgs],
skip_self: bool,
) -> Result<Vec<String>, Error> {
let args = {
let mut output = Vec::new();
for arg in args.iter().skip(skip_self as usize) {
match arg {
FnArg::Captured(arg) => {
output.push(format!("{}", arg.pat.clone().into_token_stream()))
}
_ => Err(Error::new(arg.span(), "cannot format this argument"))?,
}
}
output
};
let stringified = cases
.iter()
.cloned()
.map(|case| {
case.iter()
.zip(&args)
.map(|(val, arg)| format!("{} = {}", arg, val.into_token_stream()))
.join()
})
.collect();
Ok(stringified)
}
/// Creates a fresh identifier `case_<uuid>` from a random v4 UUID.
///
/// The UUID's hyphens are replaced with underscores, and the identifier is spanned
/// at the macro call site.
fn generate_case_ident() -> Ident {
    let suffix = uuid::Uuid::new_v4().to_string().replace('-', "_");
    Ident::new(&format!("case_{}", suffix), Span::call_site())
}
fn extract_cases(attrs: &mut Vec<Attribute>) -> Result<Vec<CallArgs>, Error> {
attrs.sort_by_key(|attr| attr.path.is_ident("case"));
let to_skip = attrs
.iter()
.take_while(|attr| !attr.path.is_ident("case"))
.count();
let mut output = Vec::new();
for case in attrs.drain(to_skip..) {
let args = case.tts.stream_args()?;
output.push(parse_quote!(#args))
}
Ok(output)
}