1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, FnArg, ItemFn};
4
/// Options accepted by the `test` attribute macro.
struct Args {
    /// Set by `should_fail = true`; when true the generated test is marked
    /// `#[should_panic]`.
    should_fail: bool,
}
8
9fn parse_args(attr_args: syn::AttributeArgs) -> syn::Result<Args> {
10 let mut should_fail = false;
11 for arg in attr_args {
12 match arg {
13 syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue))
14 if namevalue.path.is_ident("should_fail") => {
15 should_fail = match namevalue.lit {
16 syn::Lit::Bool(litbool) => litbool.value,
17 _ => {
18 return Err(syn::Error::new_spanned(
19 namevalue,
20 "expected `true` or `false`",
21 ))
22 }
23 };
24 },
25 other => {
26 return Err(syn::Error::new_spanned(
27 other,
28 "expected `should_fail = true | false`",
29 ))
30 }
31
32 }
33 }
34 Ok(Args { should_fail })
35}
36
37#[proc_macro_attribute]
38pub fn test(args: TokenStream, item: TokenStream) -> TokenStream {
39 let args = syn::parse_macro_input!(args as syn::AttributeArgs);
40 let input = parse_macro_input!(item as ItemFn);
41
42 let function_name = input.sig.ident.clone();
43 let arguments = input.sig.inputs;
44 let return_type = input.sig.output.clone();
45
46 let block = &input.block;
47 let migrate = if let Some(FnArg::Typed(arg)) = arguments.first() {
48 let name = arg.pat.clone().into_token_stream();
49 quote! { flagrant::db::migrate(&#name, semver::Version::parse("0.0.1").unwrap()).await.unwrap(); }
50 } else {
51 quote! {}
52 };
53 let should_fail = if parse_args(args).unwrap().should_fail {
54 quote! {
55 #[should_panic]
56 }
57 } else {
58 quote! { }
59 };
60
61 quote! {
62 #[sqlx::test]
63 #should_fail
64 async fn #function_name(#arguments) #return_type {
65 #migrate
66 #block
67 }
68 }.into()
69}