// e2e_macro/lib.rs

extern crate proc_macro;

use proc_macro::TokenStream;
use quote::quote;
use syn::{ImplItem, ImplItemFn, ItemImpl, Type, parse_macro_input};

/// The role a method plays inside a `#[test_suite]` impl block, as declared
/// by the marker attribute placed on it.
enum DetectedItem {
    /// Marked `#[constructor]`: builds the suite from a reference to the config.
    Constructor,
    /// Marked `#[before_all]`: forwarded to by the generated `TestSuite::before_all`.
    BeforeAll,
    /// Marked `#[before_each]`: forwarded to by the generated `TestSuite::before_each`.
    BeforeEach,
    /// Marked `#[after_each]`: forwarded to by the generated `TestSuite::after_each`.
    AfterEach,
    /// Marked `#[after_all]`: forwarded to by the generated `TestSuite::after_all`.
    AfterAll,
    /// Marked `#[test_case("...")]`; holds the test name taken from the attribute.
    TestCase(String),
}

16#[proc_macro_attribute]
17pub fn test_suite(attr: TokenStream, item: TokenStream) -> TokenStream {
18    test_suite_impl(attr, item)
19}
20
21fn test_suite_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
22    let suite_name = parse_macro_input!(attr as syn::Lit);
23
24    let input = parse_macro_input!(item as ItemImpl);
25
26    let struct_ty = input.self_ty.clone();
27    let Type::Path(struct_ty) = *struct_ty else {
28        panic!("The test suite must be implemented for a struct type");
29    };
30    let struct_ty_name = struct_ty.path.get_ident().expect("Expected a struct name");
31
32    let mut constructor = None;
33    let mut before_all = None;
34    let mut before_each = None;
35    let mut after_each = None;
36    let mut after_all = None;
37    let mut test_cases = vec![];
38
39    let mut cleaned_items = vec![];
40
41    for item in &input.items {
42        if let ImplItem::Fn(mut method) = item.clone() {
43            let mut detected_item = None;
44            let mut ignore = false;
45            method.attrs.retain(|attr| {
46                let ident = attr.meta.path().get_ident();
47                if let Some(ident) = ident {
48                    match ident.to_string().as_str() {
49                        "constructor" => {
50                            detected_item = Some(DetectedItem::Constructor);
51                            return false;
52                        }
53                        "before_all" => {
54                            detected_item = Some(DetectedItem::BeforeAll);
55                            return false;
56                        }
57                        "before_each" => {
58                            detected_item = Some(DetectedItem::BeforeEach);
59                            return false;
60                        }
61                        "after_each" => {
62                            detected_item = Some(DetectedItem::AfterEach);
63                            return false;
64                        }
65                        "after_all" => {
66                            detected_item = Some(DetectedItem::AfterAll);
67                            return false;
68                        }
69                        "test_case" => {
70                            if let Ok(syn::Lit::Str(lit_str)) = attr
71                                .meta
72                                .require_list()
73                                .expect("`test_case` attribute must contain test name")
74                                .parse_args()
75                            {
76                                detected_item = Some(DetectedItem::TestCase(lit_str.value()));
77                            }
78                            return false;
79                        }
80                        "ignore" => {
81                            ignore = true;
82                            return false;
83                        }
84                        _ => {}
85                    }
86                }
87                true
88            });
89            cleaned_items.push(ImplItem::Fn(method.clone()));
90
91            if ignore {
92                assert!(
93                    matches!(detected_item, Some(DetectedItem::TestCase(_))),
94                    "The `ignore` attribute can only be used with `test_case` methods"
95                );
96            }
97
98            match detected_item {
99                Some(DetectedItem::Constructor) => {
100                    if constructor.is_some() {
101                        panic!("Only one constructor is allowed in a test suite");
102                    }
103                    constructor = Some(method);
104                }
105                Some(DetectedItem::BeforeAll) => {
106                    if before_all.is_some() {
107                        panic!("Only one 'before_all' method is allowed in a test suite");
108                    }
109                    before_all = Some(method);
110                }
111                Some(DetectedItem::BeforeEach) => {
112                    if before_each.is_some() {
113                        panic!("Only one 'before_each' method is allowed in a test suite");
114                    }
115                    before_each = Some(method);
116                }
117                Some(DetectedItem::AfterEach) => {
118                    if after_each.is_some() {
119                        panic!("Only one 'after_each' method is allowed in a test suite");
120                    }
121                    after_each = Some(method);
122                }
123                Some(DetectedItem::AfterAll) => {
124                    if after_all.is_some() {
125                        panic!("Only one 'after_all' method is allowed in a test suite");
126                    }
127                    after_all = Some(method);
128                }
129                Some(DetectedItem::TestCase(name)) => {
130                    test_cases.push((name, method, ignore));
131                }
132                None => {}
133            }
134        } else {
135            cleaned_items.push(item.clone());
136        }
137    }
138
139    let constructor = constructor.unwrap_or_else(|| {
140        panic!("A test suite must have a constructor method annotated with #[constructor]");
141    });
142    // The only argument to the constructor should be config type.
143    let config_ty = constructor.sig.inputs.first().cloned().unwrap_or_else(|| {
144        panic!("Constructor method must have a single argument for the config type");
145    });
146    let config_ty = if let syn::FnArg::Typed(pat_type) = config_ty {
147        pat_type.ty
148    } else {
149        panic!("Constructor method must have a single argument for the config type");
150    };
151    let Type::Reference(config_ty) = *config_ty else {
152        panic!("Constructor method must take reference to the config type as an argument");
153    };
154    let Type::Path(config_ty) = *config_ty.elem else {
155        panic!("Constructor method must take reference to the config type as an argument");
156    };
157    let config_ty_name = config_ty
158        .path
159        .get_ident()
160        .expect("Expected a config type name");
161    let constructor_fn_name = &constructor.sig.ident;
162
163    let crate_name = quote::format_ident!("e2e");
164
165    let before_all_code = if let Some(before_all) = before_all {
166        let fn_name = &before_all.sig.ident;
167        quote! {
168            #struct_ty_name::#fn_name(&self).await
169        }
170    } else {
171        quote! {
172            Ok(())
173        }
174    };
175    let before_each_code = if let Some(before_each) = before_each {
176        let fn_name = &before_each.sig.ident;
177        quote! {
178            #struct_ty_name::#fn_name(&self).await
179        }
180    } else {
181        quote! {
182            Ok(())
183        }
184    };
185    let after_each_code = if let Some(after_each) = after_each {
186        let fn_name = &after_each.sig.ident;
187        quote! {
188            #struct_ty_name::#fn_name(&self).await
189        }
190    } else {
191        quote! {
192            Ok(())
193        }
194    };
195    let after_all_code = if let Some(after_all) = after_all {
196        let fn_name = &after_all.sig.ident;
197        quote! {
198            #struct_ty_name::#fn_name(&self).await
199        }
200    } else {
201        quote! {
202            Ok(())
203        }
204    };
205
206    let factory_name = quote::format_ident!("{}Factory", struct_ty_name);
207
208    let mut test_case_code = Vec::new();
209    let mut test_case_objects = Vec::new();
210    for (id, (name, method, ignore)) in test_cases.into_iter().enumerate() {
211        let test_fn_name = &method.sig.ident;
212
213        let test_ty_name = quote::format_ident!("{}Test{}", struct_ty_name, id);
214        let test_case = quote! {
215            struct #test_ty_name(#struct_ty_name);
216
217            #[async_trait::async_trait]
218            impl #crate_name::Test for #test_ty_name {
219                fn name(&self) -> String {
220                    #name.to_string()
221                }
222
223                async fn run(&self) -> anyhow::Result<()> {
224                    self.0.#test_fn_name().await
225                }
226
227                fn ignore(&self) -> bool {
228                    #ignore
229                }
230            }
231        };
232        test_case_code.push(test_case);
233        test_case_objects.push(quote! {
234            Box::new(#test_ty_name(self.clone()))
235        });
236    }
237
238    let factory_fn: syn::ImplItem = syn::parse_quote! {
239        pub fn factory() -> Box<dyn #crate_name::TestSuiteFactory<#config_ty_name>> {
240            Box::new(#factory_name)
241        }
242    };
243    cleaned_items.push(factory_fn);
244
245    let cleaned_impl = ItemImpl {
246        items: cleaned_items,
247        ..input
248    };
249
250    // Placeholder for generating the test suite logic
251    let output = quote! {
252        #cleaned_impl
253
254        #[async_trait::async_trait]
255        impl #crate_name::TestSuite for #struct_ty_name {
256            fn name(&self) -> String {
257                #suite_name.to_string()
258            }
259
260            fn tests(&self) -> Vec<Box<dyn #crate_name::Test>> {
261                vec![
262                    #(#test_case_objects),*
263                ]
264            }
265
266            async fn before_all(&self) -> anyhow::Result<()> {
267                #before_all_code
268            }
269
270            async fn before_each(&self) -> anyhow::Result<()> {
271                #before_each_code
272            }
273
274            async fn after_each(&self) -> anyhow::Result<()> {
275                #after_each_code
276            }
277
278            async fn after_all(&self) -> anyhow::Result<()> {
279                #after_all_code
280            }
281        }
282
283        struct #factory_name;
284
285        #[async_trait::async_trait]
286        impl #crate_name::TestSuiteFactory<#config_ty_name> for #factory_name {
287            fn name(&self) -> String {
288                #suite_name.to_string()
289            }
290
291            /// Creates a new test suite instance.
292            async fn create_suite(&self, config: &#config_ty_name) -> anyhow::Result<Box<dyn #crate_name::TestSuite>> {
293                let self_ = #struct_ty_name::#constructor_fn_name(config).await?;
294                Ok(Box::new(self_))
295            }
296        }
297
298        impl std::fmt::Debug for #factory_name {
299            fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300                write!(f, "{}", <Self as #crate_name::TestSuiteFactory<#config_ty_name>>::name(self))
301            }
302        }
303
304        #(#test_case_code)*
305    };
306
307    TokenStream::from(output)
308}