use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, FnArg, ItemFn, Pat, Type};
struct FerroTestArgs {
migrator: Option<syn::Path>,
}
impl syn::parse::Parse for FerroTestArgs {
fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
let mut migrator = None;
while !input.is_empty() {
let ident: syn::Ident = input.parse()?;
if ident == "migrator" {
input.parse::<syn::Token![=]>()?;
migrator = Some(input.parse()?);
}
if input.peek(syn::Token![,]) {
input.parse::<syn::Token![,]>()?;
}
}
Ok(Self { migrator })
}
}
/// Reports whether `ty` names the `TestDatabase` type.
///
/// Only path types are considered, and only the identifier of their final
/// segment is compared. Both `TestDatabase` and `ferro::testing::TestDatabase`
/// match, while non-path types such as `&TestDatabase` do not.
fn is_test_database_type(ty: &Type) -> bool {
    match ty {
        Type::Path(type_path) => matches!(
            type_path.path.segments.last(),
            Some(last_segment) if last_segment.ident == "TestDatabase"
        ),
        _ => false,
    }
}
/// Locates the first parameter of `func` that has a `TestDatabase` type and is
/// bound to a plain identifier pattern, and returns a clone of that identifier.
///
/// Receivers (`self`) are skipped. So are `TestDatabase` parameters bound with
/// a non-identifier pattern such as `_`; the search then continues with the
/// remaining parameters. Returns `None` when no parameter qualifies.
fn find_db_param_name(func: &ItemFn) -> Option<syn::Ident> {
    func.sig.inputs.iter().find_map(|input| match input {
        FnArg::Typed(typed) if is_test_database_type(&typed.ty) => match &*typed.pat {
            Pat::Ident(binding) => Some(binding.ident.clone()),
            _ => None,
        },
        _ => None,
    })
}
pub fn ferro_test_impl(attr: TokenStream, input: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as FerroTestArgs);
let input_fn = parse_macro_input!(input as ItemFn);
let ferro = quote!(::ferro);
let fn_name = &input_fn.sig.ident;
let fn_block = &input_fn.block;
let fn_attrs: Vec<_> = input_fn
.attrs
.iter()
.filter(|attr| !attr.path().is_ident("ferro_test"))
.collect();
let fn_vis = &input_fn.vis;
let migrator_type = args
.migrator
.unwrap_or_else(|| syn::parse_quote!(crate::migrations::Migrator));
let db_param_name = find_db_param_name(&input_fn);
let setup_and_body = if let Some(param_name) = db_param_name {
quote! {
#ferro::App::init();
#ferro::App::boot_services();
let #param_name = #ferro::testing::TestDatabase::fresh::<#migrator_type>()
.await
.expect("Failed to set up test database");
#fn_block
}
} else {
quote! {
#ferro::App::init();
#ferro::App::boot_services();
let _db = #ferro::testing::TestDatabase::fresh::<#migrator_type>()
.await
.expect("Failed to set up test database");
#fn_block
}
};
let output = quote! {
#(#fn_attrs)*
#[::tokio::test]
#fn_vis async fn #fn_name() {
#setup_and_body
}
};
output.into()
}