use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::parse::{Parse, ParseStream};
use syn::{braced, parenthesized, Ident, LitStr, Token, Type};
/// Converts a free-form, human-readable test name into a snake_case string.
///
/// Rules, applied per `char`:
/// * Alphanumeric characters are kept; uppercase ones are lowercased using
///   full Unicode lowercasing (so `Ü` becomes `ü`, not just ASCII letters).
/// * An `_` is inserted before an uppercase character that follows a
///   non-uppercase one (`fooBar` -> `foo_bar`). Runs of capitals are not
///   split (`HTTPServer` -> `httpserver`).
/// * Whitespace, `-` and `_` act as separators and become a single `_`.
/// * Any other character (punctuation, quotes, ...) is dropped.
/// * The result never starts or ends with `_` and never contains `__`.
///
/// The result is not guaranteed to be a valid Rust identifier: it may be
/// empty, start with a digit, or be a keyword.
fn to_snake_case(name: &str) -> String {
    let mut result = String::with_capacity(name.len());
    let mut prev_is_uppercase = false;
    for c in name.chars() {
        if c.is_alphanumeric() {
            if c.is_uppercase() {
                // Word boundary (`fooBar` -> `foo_bar`). Checking `result` rather than
                // the char position keeps inputs like " Foo" or "(Foo)" from gaining a
                // leading underscore when the skipped prefix produced no output.
                if !result.is_empty() && !prev_is_uppercase && !result.ends_with('_') {
                    result.push('_');
                }
                // `to_ascii_lowercase` would leave non-ASCII capitals untouched.
                result.extend(c.to_lowercase());
                prev_is_uppercase = true;
            } else {
                result.push(c);
                prev_is_uppercase = false;
            }
        } else if (c.is_whitespace() || c == '-' || c == '_')
            && !result.is_empty()
            && !result.ends_with('_')
        {
            result.push('_');
        }
    }
    // A trailing separator in the input leaves a trailing `_`; strip it.
    let trimmed_len = result.trim_end_matches('_').len();
    result.truncate(trimmed_len);
    result
}
/// A single `ident: Type` entry from the test's `fn(...)` parameter list.
struct FnParam {
    /// The binding name exactly as the user wrote it.
    name: Ident,
    /// The declared type; `is_test_database` inspects it to locate the
    /// database parameter.
    ty: Type,
}
/// Parsed macro input of the shape `"name", [async] fn(params) { body }`.
struct TestArgs {
    /// The human-readable test name; its snake_case form names the generated fn.
    name: LitStr,
    /// `true` when the `async` keyword appeared before `fn`.
    is_async: bool,
    /// Parameters declared inside the parentheses, in source order.
    params: Vec<FnParam>,
    /// The raw tokens between the body's braces, left unparsed.
    body: TokenStream2,
}
impl Parse for TestArgs {
    /// Parses `"name", [async] fn(param: Type, ...) { body }`.
    ///
    /// Each parameter is an `ident: Type` pair; the comma following a
    /// parameter is optional. The body is captured verbatim as tokens.
    fn parse(input: ParseStream) -> syn::Result<Self> {
        let test_name = input.parse::<LitStr>()?;
        let _: Token![,] = input.parse()?;

        // `Option<T>` consumes `T` only if it is the next token (peek + parse).
        let async_kw: Option<Token![async]> = input.parse()?;
        let is_async = async_kw.is_some();

        let _: Token![fn] = input.parse()?;

        let param_list;
        parenthesized!(param_list in input);
        let mut params = Vec::new();
        while !param_list.is_empty() {
            let ident = param_list.parse::<Ident>()?;
            let _: Token![:] = param_list.parse()?;
            let ty = param_list.parse::<Type>()?;
            params.push(FnParam { name: ident, ty });
            // Separator is optional, so a missing comma is tolerated.
            let _: Option<Token![,]> = param_list.parse()?;
        }

        let body_tokens;
        braced!(body_tokens in input);
        let body = body_tokens.parse::<TokenStream2>()?;

        Ok(Self {
            name: test_name,
            is_async,
            params,
            body,
        })
    }
}
/// Reports whether `ty` is a path type whose final segment is `TestDatabase`
/// (for example `TestDatabase` or `ferro::testing::TestDatabase`).
///
/// Only the last identifier is compared; generic arguments are ignored, and
/// non-path types such as `&TestDatabase` yield `false`.
fn is_test_database(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => matches!(
            type_path.path.segments.last(),
            Some(segment) if segment.ident == "TestDatabase"
        ),
        _ => false,
    }
}
pub fn test_impl(input: TokenStream) -> TokenStream {
let args = match syn::parse::<TestArgs>(input) {
Ok(args) => args,
Err(e) => return e.to_compile_error().into(),
};
let ferro = quote!(::ferro);
let name_str = args.name.value();
let fn_name = format_ident!("{}", to_snake_case(&name_str));
let body = args.body;
let has_db_param = args.params.iter().any(|p| is_test_database(&p.ty));
let db_param = args.params.iter().find(|p| is_test_database(&p.ty));
if args.is_async {
if has_db_param {
let db_param_name = &db_param.unwrap().name;
let output = quote! {
#[#ferro::ferro_test]
async fn #fn_name(#db_param_name: #ferro::testing::TestDatabase) {
#ferro::testing::set_current_test_name(Some(#name_str.to_string()));
let __test_result = async {
#body
}.await;
#ferro::testing::set_current_test_name(None);
__test_result
}
};
output.into()
} else {
let output = quote! {
#[#ferro::ferro_test]
async fn #fn_name() {
#ferro::testing::set_current_test_name(Some(#name_str.to_string()));
let __test_result = async {
#body
}.await;
#ferro::testing::set_current_test_name(None);
__test_result
}
};
output.into()
}
} else {
let output = quote! {
#[test]
fn #fn_name() {
#ferro::testing::set_current_test_name(Some(#name_str.to_string()));
let __test_result = {
#body
};
#ferro::testing::set_current_test_name(None);
__test_result
}
};
output.into()
}
}