1extern crate proc_macro;
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{ImplItem, ItemImpl, Type, parse_macro_input};
6
/// Role a method inside a `#[test_suite]` impl block plays, as declared by the
/// attribute placed on it (`#[constructor]`, `#[before_all]`, `#[before_each]`,
/// `#[after_each]`, `#[after_all]` or `#[test_case("...")]`).
#[derive(Debug, Clone, PartialEq, Eq)]
enum DetectedItem {
    /// `#[constructor]`: builds the suite from a reference to the config type.
    Constructor,
    /// `#[before_all]`: hook run by the generated `TestSuite::before_all`.
    BeforeAll,
    /// `#[before_each]`: hook run by the generated `TestSuite::before_each`.
    BeforeEach,
    /// `#[after_each]`: hook run by the generated `TestSuite::after_each`.
    AfterEach,
    /// `#[after_all]`: hook run by the generated `TestSuite::after_all`.
    AfterAll,
    /// `#[test_case("name")]`: a test case; carries the string-literal test name.
    TestCase(String),
}
15
16#[proc_macro_attribute]
17pub fn test_suite(attr: TokenStream, item: TokenStream) -> TokenStream {
18 test_suite_impl(attr, item)
19}
20
21fn test_suite_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
22 let suite_name = parse_macro_input!(attr as syn::Lit);
23
24 let input = parse_macro_input!(item as ItemImpl);
25
26 let struct_ty = input.self_ty.clone();
27 let Type::Path(struct_ty) = *struct_ty else {
28 panic!("The test suite must be implemented for a struct type");
29 };
30 let struct_ty_name = struct_ty.path.get_ident().expect("Expected a struct name");
31
32 let mut constructor = None;
33 let mut before_all = None;
34 let mut before_each = None;
35 let mut after_each = None;
36 let mut after_all = None;
37 let mut test_cases = vec![];
38
39 let mut cleaned_items = vec![];
40
41 for item in &input.items {
42 if let ImplItem::Fn(mut method) = item.clone() {
43 let mut detected_item = None;
44 method.attrs.retain(|attr| {
45 let ident = attr.meta.path().get_ident();
46 if let Some(ident) = ident {
47 match ident.to_string().as_str() {
48 "constructor" => {
49 detected_item = Some(DetectedItem::Constructor);
50 return false;
51 }
52 "before_all" => {
53 detected_item = Some(DetectedItem::BeforeAll);
54 return false;
55 }
56 "before_each" => {
57 detected_item = Some(DetectedItem::BeforeEach);
58 return false;
59 }
60 "after_each" => {
61 detected_item = Some(DetectedItem::AfterEach);
62 return false;
63 }
64 "after_all" => {
65 detected_item = Some(DetectedItem::AfterAll);
66 return false;
67 }
68 "test_case" => {
69 if let Ok(syn::Lit::Str(lit_str)) = attr
70 .meta
71 .require_list()
72 .expect("`test_case` attribute must contain test name")
73 .parse_args()
74 {
75 detected_item = Some(DetectedItem::TestCase(lit_str.value()));
76 }
77 return false;
78 }
79 _ => {}
80 }
81 }
82 true
83 });
84 cleaned_items.push(ImplItem::Fn(method.clone()));
85
86 match detected_item {
87 Some(DetectedItem::Constructor) => {
88 if constructor.is_some() {
89 panic!("Only one constructor is allowed in a test suite");
90 }
91 constructor = Some(method);
92 }
93 Some(DetectedItem::BeforeAll) => {
94 if before_all.is_some() {
95 panic!("Only one 'before_all' method is allowed in a test suite");
96 }
97 before_all = Some(method);
98 }
99 Some(DetectedItem::BeforeEach) => {
100 if before_each.is_some() {
101 panic!("Only one 'before_each' method is allowed in a test suite");
102 }
103 before_each = Some(method);
104 }
105 Some(DetectedItem::AfterEach) => {
106 if after_each.is_some() {
107 panic!("Only one 'after_each' method is allowed in a test suite");
108 }
109 after_each = Some(method);
110 }
111 Some(DetectedItem::AfterAll) => {
112 if after_all.is_some() {
113 panic!("Only one 'after_all' method is allowed in a test suite");
114 }
115 after_all = Some(method);
116 }
117 Some(DetectedItem::TestCase(name)) => {
118 test_cases.push((name, method));
119 }
120 None => {}
121 }
122 } else {
123 cleaned_items.push(item.clone());
124 }
125 }
126
127 let constructor = constructor.unwrap_or_else(|| {
128 panic!("A test suite must have a constructor method annotated with #[constructor]");
129 });
130 let config_ty = constructor.sig.inputs.first().cloned().unwrap_or_else(|| {
132 panic!("Constructor method must have a single argument for the config type");
133 });
134 let config_ty = if let syn::FnArg::Typed(pat_type) = config_ty {
135 pat_type.ty
136 } else {
137 panic!("Constructor method must have a single argument for the config type");
138 };
139 let Type::Reference(config_ty) = *config_ty else {
140 panic!("Constructor method must take reference to the config type as an argument");
141 };
142 let Type::Path(config_ty) = *config_ty.elem else {
143 panic!("Constructor method must take reference to the config type as an argument");
144 };
145 let config_ty_name = config_ty
146 .path
147 .get_ident()
148 .expect("Expected a config type name");
149 let constructor_fn_name = &constructor.sig.ident;
150
151 let crate_name = quote::format_ident!("e2e");
152
153 let before_all_code = if let Some(before_all) = before_all {
154 let fn_name = &before_all.sig.ident;
155 quote! {
156 #struct_ty_name::#fn_name(&self).await
157 }
158 } else {
159 quote! {
160 Ok(())
161 }
162 };
163 let before_each_code = if let Some(before_each) = before_each {
164 let fn_name = &before_each.sig.ident;
165 quote! {
166 #struct_ty_name::#fn_name(&self).await
167 }
168 } else {
169 quote! {
170 Ok(())
171 }
172 };
173 let after_each_code = if let Some(after_each) = after_each {
174 let fn_name = &after_each.sig.ident;
175 quote! {
176 #struct_ty_name::#fn_name(&self).await
177 }
178 } else {
179 quote! {
180 Ok(())
181 }
182 };
183 let after_all_code = if let Some(after_all) = after_all {
184 let fn_name = &after_all.sig.ident;
185 quote! {
186 #struct_ty_name::#fn_name(&self).await
187 }
188 } else {
189 quote! {
190 Ok(())
191 }
192 };
193
194 let factory_name = quote::format_ident!("{}Factory", struct_ty_name);
195
196 let mut test_case_code = Vec::new();
197 let mut test_case_objects = Vec::new();
198 for (id, (name, method)) in test_cases.into_iter().enumerate() {
199 let test_fn_name = &method.sig.ident;
200
201 let test_ty_name = quote::format_ident!("{}Test{}", struct_ty_name, id);
202 let test_case = quote! {
203 struct #test_ty_name(#struct_ty_name);
204
205 #[async_trait::async_trait]
206 impl #crate_name::Test for #test_ty_name {
207 fn name(&self) -> String {
208 #name.to_string()
209 }
210
211 async fn run(&self) -> anyhow::Result<()> {
212 self.0.#test_fn_name().await
213 }
214 }
215 };
216 test_case_code.push(test_case);
217 test_case_objects.push(quote! {
218 Box::new(#test_ty_name(self.clone()))
219 });
220 }
221
222 let factory_fn: syn::ImplItem = syn::parse_quote! {
223 pub fn factory() -> Box<dyn #crate_name::TestSuiteFactory<#config_ty_name>> {
224 Box::new(#factory_name)
225 }
226 };
227 cleaned_items.push(factory_fn);
228
229 let cleaned_impl = ItemImpl {
230 items: cleaned_items,
231 ..input
232 };
233
234 let output = quote! {
236 #cleaned_impl
237
238 #[async_trait::async_trait]
239 impl #crate_name::TestSuite for #struct_ty_name {
240 fn name(&self) -> String {
241 #suite_name.to_string()
242 }
243
244 fn tests(&self) -> Vec<Box<dyn #crate_name::Test>> {
245 vec![
246 #(#test_case_objects),*
247 ]
248 }
249
250 async fn before_all(&self) -> anyhow::Result<()> {
251 #before_all_code
252 }
253
254 async fn before_each(&self) -> anyhow::Result<()> {
255 #before_each_code
256 }
257
258 async fn after_each(&self) -> anyhow::Result<()> {
259 #after_each_code
260 }
261
262 async fn after_all(&self) -> anyhow::Result<()> {
263 #after_all_code
264 }
265 }
266
267 struct #factory_name;
268
269 #[async_trait::async_trait]
270 impl #crate_name::TestSuiteFactory<#config_ty_name> for #factory_name {
271 fn name(&self) -> String {
272 #suite_name.to_string()
273 }
274
275 async fn create_suite(&self, config: &#config_ty_name) -> anyhow::Result<Box<dyn #crate_name::TestSuite>> {
277 let self_ = #struct_ty_name::#constructor_fn_name(config).await?;
278 Ok(Box::new(self_))
279 }
280 }
281
282 impl std::fmt::Debug for #factory_name {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 write!(f, "{}", <Self as #crate_name::TestSuiteFactory<#config_ty_name>>::name(self))
285 }
286 }
287
288 #(#test_case_code)*
289 };
290
291 TokenStream::from(output)
292}