test_context_macros/
lib.rs1mod macro_args;
2mod test_args;
3
4use crate::test_args::{ContextArg, ContextArgMode, TestArg};
5use macro_args::MacroArgs;
6use proc_macro::TokenStream;
7use quote::{format_ident, quote};
8use syn::ItemFn;
9
10#[proc_macro_attribute]
31pub fn test_context(attr: TokenStream, item: TokenStream) -> TokenStream {
32 let args = syn::parse_macro_input!(attr as MacroArgs);
33 let input = syn::parse_macro_input!(item as syn::ItemFn);
34
35 let (input, context_args) = remove_context_args(input, args.context_type.clone());
36
37 if context_args.len() != 1 {
38 panic!("Exactly one Context argument must be defined");
39 }
40
41 let context_arg = context_args.into_iter().next().unwrap();
42
43 if !args.skip_teardown && context_arg.mode.is_owned() {
44 panic!(
45 "It is not possible to take ownership of the context if the teardown has to be ran."
46 );
47 }
48
49 let input = refactor_input_body(input, &args, context_arg);
50
51 quote! { #input }.into()
52}
53
54fn remove_context_args(
55 mut input: syn::ItemFn,
56 expected_context_type: syn::Type,
57) -> (syn::ItemFn, Vec<ContextArg>) {
58 let test_args: Vec<TestArg> = input
59 .sig
60 .inputs
61 .into_iter()
62 .map(|arg| TestArg::parse_arg_with_expected_context(arg, &expected_context_type))
63 .collect();
64
65 let context_args: Vec<ContextArg> = test_args
66 .iter()
67 .cloned()
68 .filter_map(|arg| match arg {
69 TestArg::Any(_) => None,
70 TestArg::Context(context_arg_info) => Some(context_arg_info),
71 })
72 .collect();
73
74 let new_args: syn::punctuated::Punctuated<_, _> = test_args
75 .into_iter()
76 .filter_map(|arg| match arg {
77 TestArg::Any(fn_arg) => Some(fn_arg),
78 TestArg::Context(_) => None,
79 })
80 .collect();
81
82 input.sig.inputs = new_args;
83
84 (input, context_args)
85}
86
87fn refactor_input_body(
88 input: syn::ItemFn,
89 args: &MacroArgs,
90 context_arg: ContextArg,
91) -> syn::ItemFn {
92 let context_type = &args.context_type;
93 let result_name = format_ident!("wrapped_result");
94 let body = &input.block;
95 let is_async = input.sig.asyncness.is_some();
96 let context_arg_name = context_arg.name;
97
98 let context_binding = match context_arg.mode {
99 ContextArgMode::Owned => quote! { let #context_arg_name = __context; },
100 ContextArgMode::OwnedMut => quote! { let mut #context_arg_name = __context; },
101 ContextArgMode::Reference => quote! { let #context_arg_name = &__context; },
102 ContextArgMode::MutableReference => quote! { let #context_arg_name = &mut __context; },
103 };
104
105 let body = if args.skip_teardown && is_async {
106 quote! {
107 use test_context::futures::FutureExt;
108 let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
109 #context_binding
110 let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
111 }
112 } else if args.skip_teardown && !is_async {
113 quote! {
114 let mut __context = <#context_type as test_context::TestContext>::setup();
115 let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
116 #context_binding
117 #body
118 }));
119 }
120 } else if !args.skip_teardown && is_async {
121 quote! {
122 use test_context::futures::FutureExt;
123 let mut __context = <#context_type as test_context::AsyncTestContext>::setup().await;
124 #context_binding
125 let #result_name = std::panic::AssertUnwindSafe( async { #body } ).catch_unwind().await;
126 <#context_type as test_context::AsyncTestContext>::teardown(__context).await;
127 }
128 }
129 else {
131 quote! {
132 let mut __context = <#context_type as test_context::TestContext>::setup();
133 let #result_name = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
134 #context_binding
135 #body
136 }));
137 <#context_type as test_context::TestContext>::teardown(__context);
138 }
139 };
140
141 let body = quote! {
142 {
143 #body
144 match #result_name {
145 Ok(value) => value,
146 Err(err) => {
147 std::panic::resume_unwind(err);
148 }
149 }
150 }
151 };
152
153 ItemFn {
154 block: Box::new(syn::parse2(body).unwrap()),
155 ..input
156 }
157}