e2e_macro/
lib.rs

1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{ImplItem, ItemImpl, Type, parse_macro_input};
6
/// Role a method inside a `#[test_suite]` impl block declares through its
/// marker attribute.
enum DetectedItem {
    /// `#[constructor]` — called by the generated factory with `&Config`.
    Constructor,
    /// `#[before_all]` — forwarded to by the generated `TestSuite::before_all`.
    BeforeAll,
    /// `#[before_each]` — forwarded to by the generated `TestSuite::before_each`.
    BeforeEach,
    /// `#[after_each]` — forwarded to by the generated `TestSuite::after_each`.
    AfterEach,
    /// `#[after_all]` — forwarded to by the generated `TestSuite::after_all`.
    AfterAll,
    /// `#[test_case("name")]` — the payload is the name returned by the
    /// generated `Test::name`.
    TestCase(String),
}
15
16#[proc_macro_attribute]
17pub fn test_suite(attr: TokenStream, item: TokenStream) -> TokenStream {
18    test_suite_impl(attr, item)
19}
20
21fn test_suite_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
22    let suite_name = parse_macro_input!(attr as syn::Lit);
23
24    let input = parse_macro_input!(item as ItemImpl);
25
26    let struct_ty = input.self_ty.clone();
27    let Type::Path(struct_ty) = *struct_ty else {
28        panic!("The test suite must be implemented for a struct type");
29    };
30    let struct_ty_name = struct_ty.path.get_ident().expect("Expected a struct name");
31
32    let mut constructor = None;
33    let mut before_all = None;
34    let mut before_each = None;
35    let mut after_each = None;
36    let mut after_all = None;
37    let mut test_cases = vec![];
38
39    let mut cleaned_items = vec![];
40
41    for item in &input.items {
42        if let ImplItem::Fn(mut method) = item.clone() {
43            let mut detected_item = None;
44            method.attrs.retain(|attr| {
45                let ident = attr.meta.path().get_ident();
46                if let Some(ident) = ident {
47                    match ident.to_string().as_str() {
48                        "constructor" => {
49                            detected_item = Some(DetectedItem::Constructor);
50                            return false;
51                        }
52                        "before_all" => {
53                            detected_item = Some(DetectedItem::BeforeAll);
54                            return false;
55                        }
56                        "before_each" => {
57                            detected_item = Some(DetectedItem::BeforeEach);
58                            return false;
59                        }
60                        "after_each" => {
61                            detected_item = Some(DetectedItem::AfterEach);
62                            return false;
63                        }
64                        "after_all" => {
65                            detected_item = Some(DetectedItem::AfterAll);
66                            return false;
67                        }
68                        "test_case" => {
69                            if let Ok(syn::Lit::Str(lit_str)) = attr
70                                .meta
71                                .require_list()
72                                .expect("`test_case` attribute must contain test name")
73                                .parse_args()
74                            {
75                                detected_item = Some(DetectedItem::TestCase(lit_str.value()));
76                            }
77                            return false;
78                        }
79                        _ => {}
80                    }
81                }
82                true
83            });
84            cleaned_items.push(ImplItem::Fn(method.clone()));
85
86            match detected_item {
87                Some(DetectedItem::Constructor) => {
88                    if constructor.is_some() {
89                        panic!("Only one constructor is allowed in a test suite");
90                    }
91                    constructor = Some(method);
92                }
93                Some(DetectedItem::BeforeAll) => {
94                    if before_all.is_some() {
95                        panic!("Only one 'before_all' method is allowed in a test suite");
96                    }
97                    before_all = Some(method);
98                }
99                Some(DetectedItem::BeforeEach) => {
100                    if before_each.is_some() {
101                        panic!("Only one 'before_each' method is allowed in a test suite");
102                    }
103                    before_each = Some(method);
104                }
105                Some(DetectedItem::AfterEach) => {
106                    if after_each.is_some() {
107                        panic!("Only one 'after_each' method is allowed in a test suite");
108                    }
109                    after_each = Some(method);
110                }
111                Some(DetectedItem::AfterAll) => {
112                    if after_all.is_some() {
113                        panic!("Only one 'after_all' method is allowed in a test suite");
114                    }
115                    after_all = Some(method);
116                }
117                Some(DetectedItem::TestCase(name)) => {
118                    test_cases.push((name, method));
119                }
120                None => {}
121            }
122        } else {
123            cleaned_items.push(item.clone());
124        }
125    }
126
127    let constructor = constructor.unwrap_or_else(|| {
128        panic!("A test suite must have a constructor method annotated with #[constructor]");
129    });
130    // The only argument to the constructor should be config type.
131    let config_ty = constructor.sig.inputs.first().cloned().unwrap_or_else(|| {
132        panic!("Constructor method must have a single argument for the config type");
133    });
134    let config_ty = if let syn::FnArg::Typed(pat_type) = config_ty {
135        pat_type.ty
136    } else {
137        panic!("Constructor method must have a single argument for the config type");
138    };
139    let Type::Reference(config_ty) = *config_ty else {
140        panic!("Constructor method must take reference to the config type as an argument");
141    };
142    let Type::Path(config_ty) = *config_ty.elem else {
143        panic!("Constructor method must take reference to the config type as an argument");
144    };
145    let config_ty_name = config_ty
146        .path
147        .get_ident()
148        .expect("Expected a config type name");
149    let constructor_fn_name = &constructor.sig.ident;
150
151    let crate_name = quote::format_ident!("e2e");
152
153    let before_all_code = if let Some(before_all) = before_all {
154        let fn_name = &before_all.sig.ident;
155        quote! {
156            #struct_ty_name::#fn_name(&self).await
157        }
158    } else {
159        quote! {
160            Ok(())
161        }
162    };
163    let before_each_code = if let Some(before_each) = before_each {
164        let fn_name = &before_each.sig.ident;
165        quote! {
166            #struct_ty_name::#fn_name(&self).await
167        }
168    } else {
169        quote! {
170            Ok(())
171        }
172    };
173    let after_each_code = if let Some(after_each) = after_each {
174        let fn_name = &after_each.sig.ident;
175        quote! {
176            #struct_ty_name::#fn_name(&self).await
177        }
178    } else {
179        quote! {
180            Ok(())
181        }
182    };
183    let after_all_code = if let Some(after_all) = after_all {
184        let fn_name = &after_all.sig.ident;
185        quote! {
186            #struct_ty_name::#fn_name(&self).await
187        }
188    } else {
189        quote! {
190            Ok(())
191        }
192    };
193
194    let factory_name = quote::format_ident!("{}Factory", struct_ty_name);
195
196    let mut test_case_code = Vec::new();
197    let mut test_case_objects = Vec::new();
198    for (id, (name, method)) in test_cases.into_iter().enumerate() {
199        let test_fn_name = &method.sig.ident;
200
201        let test_ty_name = quote::format_ident!("{}Test{}", struct_ty_name, id);
202        let test_case = quote! {
203            struct #test_ty_name(#struct_ty_name);
204
205            #[async_trait::async_trait]
206            impl #crate_name::Test for #test_ty_name {
207                fn name(&self) -> String {
208                    #name.to_string()
209                }
210
211                async fn run(&self) -> anyhow::Result<()> {
212                    self.0.#test_fn_name().await
213                }
214            }
215        };
216        test_case_code.push(test_case);
217        test_case_objects.push(quote! {
218            Box::new(#test_ty_name(self.clone()))
219        });
220    }
221
222    let factory_fn: syn::ImplItem = syn::parse_quote! {
223        pub fn factory() -> Box<dyn #crate_name::TestSuiteFactory<#config_ty_name>> {
224            Box::new(#factory_name)
225        }
226    };
227    cleaned_items.push(factory_fn);
228
229    let cleaned_impl = ItemImpl {
230        items: cleaned_items,
231        ..input
232    };
233
234    // Placeholder for generating the test suite logic
235    let output = quote! {
236        #cleaned_impl
237
238        #[async_trait::async_trait]
239        impl #crate_name::TestSuite for #struct_ty_name {
240            fn name(&self) -> String {
241                #suite_name.to_string()
242            }
243
244            fn tests(&self) -> Vec<Box<dyn #crate_name::Test>> {
245                vec![
246                    #(#test_case_objects),*
247                ]
248            }
249
250            async fn before_all(&self) -> anyhow::Result<()> {
251                #before_all_code
252            }
253
254            async fn before_each(&self) -> anyhow::Result<()> {
255                #before_each_code
256            }
257
258            async fn after_each(&self) -> anyhow::Result<()> {
259                #after_each_code
260            }
261
262            async fn after_all(&self) -> anyhow::Result<()> {
263                #after_all_code
264            }
265        }
266
267        struct #factory_name;
268
269        #[async_trait::async_trait]
270        impl #crate_name::TestSuiteFactory<#config_ty_name> for #factory_name {
271            fn name(&self) -> String {
272                #suite_name.to_string()
273            }
274
275            /// Creates a new test suite instance.
276            async fn create_suite(&self, config: &#config_ty_name) -> anyhow::Result<Box<dyn #crate_name::TestSuite>> {
277                let self_ = #struct_ty_name::#constructor_fn_name(config).await?;
278                Ok(Box::new(self_))
279            }
280        }
281
282        impl std::fmt::Debug for #factory_name {
283            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284                write!(f, "{}", <Self as #crate_name::TestSuiteFactory<#config_ty_name>>::name(self))
285            }
286        }
287
288        #(#test_case_code)*
289    };
290
291    TokenStream::from(output)
292}