1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{ImplItem, ItemImpl, Type, parse_macro_input};
6
/// Role of a method inside a `#[test_suite]` impl block, as determined by its marker attribute.
#[derive(Debug, Clone, PartialEq, Eq)]
enum DetectedItem {
    /// `#[constructor]`: builds the suite from a reference to the config type.
    Constructor,
    /// `#[before_all]`: hook invoked from the generated `TestSuite::before_all`.
    BeforeAll,
    /// `#[before_each]`: hook invoked from the generated `TestSuite::before_each`.
    BeforeEach,
    /// `#[after_each]`: hook invoked from the generated `TestSuite::after_each`.
    AfterEach,
    /// `#[after_all]`: hook invoked from the generated `TestSuite::after_all`.
    AfterAll,
    /// `#[test_case("name")]`: a test method, carrying its display name.
    TestCase(String),
}
15
16#[proc_macro_attribute]
17pub fn test_suite(attr: TokenStream, item: TokenStream) -> TokenStream {
18 test_suite_impl(attr, item)
19}
20
21fn test_suite_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
22 let suite_name = parse_macro_input!(attr as syn::Lit);
23
24 let input = parse_macro_input!(item as ItemImpl);
25
26 let struct_ty = input.self_ty.clone();
27 let Type::Path(struct_ty) = *struct_ty else {
28 panic!("The test suite must be implemented for a struct type");
29 };
30 let struct_ty_name = struct_ty.path.get_ident().expect("Expected a struct name");
31
32 let mut constructor = None;
33 let mut before_all = None;
34 let mut before_each = None;
35 let mut after_each = None;
36 let mut after_all = None;
37 let mut test_cases = vec![];
38
39 let mut cleaned_items = vec![];
40
41 for item in &input.items {
42 if let ImplItem::Fn(mut method) = item.clone() {
43 let mut detected_item = None;
44 let mut ignore = false;
45 method.attrs.retain(|attr| {
46 let ident = attr.meta.path().get_ident();
47 if let Some(ident) = ident {
48 match ident.to_string().as_str() {
49 "constructor" => {
50 detected_item = Some(DetectedItem::Constructor);
51 return false;
52 }
53 "before_all" => {
54 detected_item = Some(DetectedItem::BeforeAll);
55 return false;
56 }
57 "before_each" => {
58 detected_item = Some(DetectedItem::BeforeEach);
59 return false;
60 }
61 "after_each" => {
62 detected_item = Some(DetectedItem::AfterEach);
63 return false;
64 }
65 "after_all" => {
66 detected_item = Some(DetectedItem::AfterAll);
67 return false;
68 }
69 "test_case" => {
70 if let Ok(syn::Lit::Str(lit_str)) = attr
71 .meta
72 .require_list()
73 .expect("`test_case` attribute must contain test name")
74 .parse_args()
75 {
76 detected_item = Some(DetectedItem::TestCase(lit_str.value()));
77 }
78 return false;
79 }
80 "ignore" => {
81 ignore = true;
82 return false;
83 }
84 _ => {}
85 }
86 }
87 true
88 });
89 cleaned_items.push(ImplItem::Fn(method.clone()));
90
91 if ignore {
92 assert!(
93 matches!(detected_item, Some(DetectedItem::TestCase(_))),
94 "The `ignore` attribute can only be used with `test_case` methods"
95 );
96 }
97
98 match detected_item {
99 Some(DetectedItem::Constructor) => {
100 if constructor.is_some() {
101 panic!("Only one constructor is allowed in a test suite");
102 }
103 constructor = Some(method);
104 }
105 Some(DetectedItem::BeforeAll) => {
106 if before_all.is_some() {
107 panic!("Only one 'before_all' method is allowed in a test suite");
108 }
109 before_all = Some(method);
110 }
111 Some(DetectedItem::BeforeEach) => {
112 if before_each.is_some() {
113 panic!("Only one 'before_each' method is allowed in a test suite");
114 }
115 before_each = Some(method);
116 }
117 Some(DetectedItem::AfterEach) => {
118 if after_each.is_some() {
119 panic!("Only one 'after_each' method is allowed in a test suite");
120 }
121 after_each = Some(method);
122 }
123 Some(DetectedItem::AfterAll) => {
124 if after_all.is_some() {
125 panic!("Only one 'after_all' method is allowed in a test suite");
126 }
127 after_all = Some(method);
128 }
129 Some(DetectedItem::TestCase(name)) => {
130 test_cases.push((name, method, ignore));
131 }
132 None => {}
133 }
134 } else {
135 cleaned_items.push(item.clone());
136 }
137 }
138
139 let constructor = constructor.unwrap_or_else(|| {
140 panic!("A test suite must have a constructor method annotated with #[constructor]");
141 });
142 let config_ty = constructor.sig.inputs.first().cloned().unwrap_or_else(|| {
144 panic!("Constructor method must have a single argument for the config type");
145 });
146 let config_ty = if let syn::FnArg::Typed(pat_type) = config_ty {
147 pat_type.ty
148 } else {
149 panic!("Constructor method must have a single argument for the config type");
150 };
151 let Type::Reference(config_ty) = *config_ty else {
152 panic!("Constructor method must take reference to the config type as an argument");
153 };
154 let Type::Path(config_ty) = *config_ty.elem else {
155 panic!("Constructor method must take reference to the config type as an argument");
156 };
157 let config_ty_name = config_ty
158 .path
159 .get_ident()
160 .expect("Expected a config type name");
161 let constructor_fn_name = &constructor.sig.ident;
162
163 let crate_name = quote::format_ident!("e2e");
164
165 let before_all_code = if let Some(before_all) = before_all {
166 let fn_name = &before_all.sig.ident;
167 quote! {
168 #struct_ty_name::#fn_name(&self).await
169 }
170 } else {
171 quote! {
172 Ok(())
173 }
174 };
175 let before_each_code = if let Some(before_each) = before_each {
176 let fn_name = &before_each.sig.ident;
177 quote! {
178 #struct_ty_name::#fn_name(&self).await
179 }
180 } else {
181 quote! {
182 Ok(())
183 }
184 };
185 let after_each_code = if let Some(after_each) = after_each {
186 let fn_name = &after_each.sig.ident;
187 quote! {
188 #struct_ty_name::#fn_name(&self).await
189 }
190 } else {
191 quote! {
192 Ok(())
193 }
194 };
195 let after_all_code = if let Some(after_all) = after_all {
196 let fn_name = &after_all.sig.ident;
197 quote! {
198 #struct_ty_name::#fn_name(&self).await
199 }
200 } else {
201 quote! {
202 Ok(())
203 }
204 };
205
206 let factory_name = quote::format_ident!("{}Factory", struct_ty_name);
207
208 let mut test_case_code = Vec::new();
209 let mut test_case_objects = Vec::new();
210 for (id, (name, method, ignore)) in test_cases.into_iter().enumerate() {
211 let test_fn_name = &method.sig.ident;
212
213 let test_ty_name = quote::format_ident!("{}Test{}", struct_ty_name, id);
214 let test_case = quote! {
215 struct #test_ty_name(#struct_ty_name);
216
217 #[async_trait::async_trait]
218 impl #crate_name::Test for #test_ty_name {
219 fn name(&self) -> String {
220 #name.to_string()
221 }
222
223 async fn run(&self) -> anyhow::Result<()> {
224 self.0.#test_fn_name().await
225 }
226
227 fn ignore(&self) -> bool {
228 #ignore
229 }
230 }
231 };
232 test_case_code.push(test_case);
233 test_case_objects.push(quote! {
234 Box::new(#test_ty_name(self.clone()))
235 });
236 }
237
238 let factory_fn: syn::ImplItem = syn::parse_quote! {
239 pub fn factory() -> Box<dyn #crate_name::TestSuiteFactory<#config_ty_name>> {
240 Box::new(#factory_name)
241 }
242 };
243 cleaned_items.push(factory_fn);
244
245 let cleaned_impl = ItemImpl {
246 items: cleaned_items,
247 ..input
248 };
249
250 let output = quote! {
252 #cleaned_impl
253
254 #[async_trait::async_trait]
255 impl #crate_name::TestSuite for #struct_ty_name {
256 fn name(&self) -> String {
257 #suite_name.to_string()
258 }
259
260 fn tests(&self) -> Vec<Box<dyn #crate_name::Test>> {
261 vec![
262 #(#test_case_objects),*
263 ]
264 }
265
266 async fn before_all(&self) -> anyhow::Result<()> {
267 #before_all_code
268 }
269
270 async fn before_each(&self) -> anyhow::Result<()> {
271 #before_each_code
272 }
273
274 async fn after_each(&self) -> anyhow::Result<()> {
275 #after_each_code
276 }
277
278 async fn after_all(&self) -> anyhow::Result<()> {
279 #after_all_code
280 }
281 }
282
283 struct #factory_name;
284
285 #[async_trait::async_trait]
286 impl #crate_name::TestSuiteFactory<#config_ty_name> for #factory_name {
287 fn name(&self) -> String {
288 #suite_name.to_string()
289 }
290
291 async fn create_suite(&self, config: &#config_ty_name) -> anyhow::Result<Box<dyn #crate_name::TestSuite>> {
293 let self_ = #struct_ty_name::#constructor_fn_name(config).await?;
294 Ok(Box::new(self_))
295 }
296 }
297
298 impl std::fmt::Debug for #factory_name {
299 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
300 write!(f, "{}", <Self as #crate_name::TestSuiteFactory<#config_ty_name>>::name(self))
301 }
302 }
303
304 #(#test_case_code)*
305 };
306
307 TokenStream::from(output)
308}