test_context_macros/
lib.rs

1mod macro_args;
2mod test_args;
3
4use crate::test_args::{ContextArg, ContextArgMode, TestArg};
5use macro_args::MacroArgs;
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::ItemFn;
9
10/// Macro to use on tests to add the setup/teardown functionality of your context.
11///
12/// Ordering of this attribute is important, and typically `test_context` should come
13/// before other test attributes. For example, the following is valid:
14///
15/// ```ignore
16/// #[test_context(MyContext)]
17/// #[test]
18/// fn my_test() {
19/// }
20/// ```
21///
22/// The following is NOT valid...
23///
24/// ```ignore
25/// #[test]
26/// #[test_context(MyContext)]
27/// fn my_test() {
28/// }
29/// ```
30#[proc_macro_attribute]
31pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
32    let args = syn::parse_macro_input!(attr as MacroArgs);
33    let input = syn::parse_macro_input!(item as syn::ItemFn);
34
35    let (input, context_args) = remove_context_args(input, args.context_type.clone());
36
37    if context_args.len() != 1 {
38        panic!("Exactly one Context argument must be defined");
39    }
40
41    let context_arg = context_args.into_iter().next().unwrap();
42
43    if !args.skip_teardown && context_arg.mode.is_owned() {
44        panic!(
45            "It is not possible to take ownership of the context if the teardown has to be ran."
46        );
47    }
48
49    let input = refactor_input_body(input, &args, context_arg);
50
51    quote! { #input }.into()
52}
53
54fn remove_context_args(
55    mut input: syn::ItemFn,
56    expected_context_type: syn::Type,
57) -> (syn::ItemFn, Vec<ContextArg>) {
58    let test_args: Vec<TestArg> = input
59        .sig
60        .inputs
61        .into_iter()
62        .map(|arg| TestArg::parse_arg_with_expected_context(arg, &expected_context_type))
63        .collect();
64
65    let context_args: Vec<ContextArg> = test_args
66        .iter()
67        .cloned()
68        .filter_map(|arg| match arg {
69            TestArg::Any(_) => None,
70            TestArg::Context(context_arg_info) => Some(context_arg_info),
71        })
72        .collect();
73
74    let new_args: syn::punctuated::Punctuated<_, _> = test_args
75        .into_iter()
76        .filter_map(|arg| match arg {
77            TestArg::Any(fn_arg) => Some(fn_arg),
78            TestArg::Context(_) => None,
79        })
80        .collect();
81
82    input.sig.inputs = new_args;
83
84    (input, context_args)
85}
86
87fn refactor_input_body(
88    input: syn::ItemFn,
89    args: &MacroArgs,
90    context_arg: ContextArg,
91) -> syn::ItemFn {
92    let context_type = &args.context_type;
93    let result_name = format_ident!("wrapped_result");
94    let body = &input.block;
95    let is_async = input.sig.asyncness.is_some();
96    let context_arg_name = context_arg.name;
97
98    let context_binding = match context_arg.mode {
99        ContextArgMode::Owned => quote! { let #context_arg_name = __context; },
100        ContextArgMode::OwnedMut => quote! { let mut #context_arg_name = __context; },
101        ContextArgMode::Reference => quote! { let #context_arg_name = &__context; },
102        ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; },
103    };
104
105    let body = if args.skip_teardown && is_async {
106        quote! {
107            use test_context::futures::FutureExt;
108            let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
109            #context_binding
110            let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
111        }
112    } else if args.skip_teardown && !is_async {
113        quote! {
114            let mut __context = <#context_type as test_context::TestContext>::setup();
115            let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
116                #context_binding
117                #body
118            }));
119        }
120    } else if !args.skip_teardown && is_async {
121        quote! {
122            use test_context::futures::FutureExt;
123            let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
124            #context_binding
125            let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
126            <#context_type as test_context::AsyncTestContext>::teardown(__context).await;
127        }
128    }
129    // !args.skip_teardown && !is_async
130    else {
131        quote! {
132            let mut __context = <#context_type as test_context::TestContext>::setup();
133            let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
134                #context_binding
135                #body
136            }));
137            <#context_type as test_context::TestContext>::teardown(__context);
138        }
139    };
140
141    let body = quote! {
142        {
143            #body
144            match #result_name {
145                Ok(value) => value,
146                Err(err) => {
147                    std::panic::resume_unwind(err);
148                }
149            }
150        }
151    };
152
153    ItemFn {
154        block: Box::new(syn::parse2(body).unwrap()),
155        ..input
156    }
157}