1use proc_macro::TokenStream;
22use quote::{format_ident, quote};
23use syn::parse::{Parse, ParseStream};
24use syn::punctuated::Punctuated;
25use syn::{Expr, FnArg, ItemFn, Lit, Meta, Pat, Token, Type, parse_macro_input};
26
/// Tri-state value of a flag-like `#[ferritest(...)]` argument.
///
/// * `None` - the argument was not given.
/// * `Some(None)` - given as a bare flag, e.g. `skip`.
/// * `Some(Some(text))` - given with a string value, e.g. `skip = "..."`.
type FlagArg = Option<Option<String>>;

/// Arguments accepted by the `#[ferritest(...)]` attribute, in parsed form.
struct FerritestArgs {
    /// Value of `retries = <int>`, if given.
    retries: Option<u32>,
    /// Value of `timeout = "<duration>"` converted to milliseconds, if given.
    timeout_ms: Option<u64>,
    /// One entry per `tag = "<text>"` argument, in source order.
    tags: Vec<String>,
    /// State of the `skip` argument.
    skip: FlagArg,
    /// State of the `slow` argument.
    slow: FlagArg,
    /// State of the `fixme` argument.
    fixme: FlagArg,
    /// State of the `fail` argument.
    fail: FlagArg,
    /// `true` when the bare `only` flag was given.
    only: bool,
    /// One `(type_name, description)` pair per `info = "<type>: <description>"`
    /// argument; the description is empty when the value contains no `:`.
    infos: Vec<(String, String)>,
    /// Raw string passed as `use_options = "<text>"`, if given.
    use_options: Option<String>,
}
46
47impl Parse for FerritestArgs {
48 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
49 let mut args = Self {
50 retries: None,
51 timeout_ms: None,
52 tags: Vec::new(),
53 skip: None,
54 slow: None,
55 fixme: None,
56 fail: None,
57 only: false,
58 infos: Vec::new(),
59 use_options: None,
60 };
61
62 let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
63 for meta in metas {
64 match &meta {
65 Meta::NameValue(nv) => {
66 let ident = nv.path.get_ident().map(ToString::to_string).unwrap_or_default();
67 match ident.as_str() {
68 "retries" => {
69 if let syn::Expr::Lit(lit) = &nv.value {
70 if let Lit::Int(i) = &lit.lit {
71 args.retries = Some(i.base10_parse()?);
72 }
73 }
74 },
75 "timeout" => {
76 if let syn::Expr::Lit(lit) = &nv.value {
77 if let Lit::Str(s) = &lit.lit {
78 args.timeout_ms = Some(parse_duration_str(&s.value())?);
79 }
80 }
81 },
82 "tag" => {
83 if let syn::Expr::Lit(lit) = &nv.value {
84 if let Lit::Str(s) = &lit.lit {
85 args.tags.push(s.value());
86 }
87 }
88 },
89 "skip" => {
90 if let syn::Expr::Lit(lit) = &nv.value {
91 if let Lit::Str(s) = &lit.lit {
92 args.skip = Some(Some(s.value()));
93 }
94 }
95 },
96 "slow" => {
97 if let syn::Expr::Lit(lit) = &nv.value {
98 if let Lit::Str(s) = &lit.lit {
99 args.slow = Some(Some(s.value()));
100 }
101 }
102 },
103 "fixme" => {
104 if let syn::Expr::Lit(lit) = &nv.value {
105 if let Lit::Str(s) = &lit.lit {
106 args.fixme = Some(Some(s.value()));
107 }
108 }
109 },
110 "fail" => {
111 if let syn::Expr::Lit(lit) = &nv.value {
112 if let Lit::Str(s) = &lit.lit {
113 args.fail = Some(Some(s.value()));
114 }
115 }
116 },
117 "use_options" => {
118 if let syn::Expr::Lit(lit) = &nv.value {
119 if let Lit::Str(s) = &lit.lit {
120 args.use_options = Some(s.value());
121 }
122 }
123 },
124 "info" => {
125 if let syn::Expr::Lit(lit) = &nv.value {
126 if let Lit::Str(s) = &lit.lit {
127 let val = s.value();
128 if let Some((type_name, desc)) = val.split_once(':') {
129 args.infos.push((type_name.trim().to_string(), desc.trim().to_string()));
130 } else {
131 args.infos.push((val, String::new()));
132 }
133 }
134 }
135 },
136 _ => {
137 return Err(syn::Error::new_spanned(
138 &nv.path,
139 format!("unknown ferritest attribute: {ident}"),
140 ));
141 },
142 }
143 },
144 Meta::Path(p) => {
145 let ident = p.get_ident().map(ToString::to_string).unwrap_or_default();
146 match ident.as_str() {
147 "skip" => args.skip = Some(None),
148 "slow" => args.slow = Some(None),
149 "fixme" => args.fixme = Some(None),
150 "fail" => args.fail = Some(None),
151 "only" => args.only = true,
152 _ => return Err(syn::Error::new_spanned(p, format!("unknown ferritest flag: {ident}"))),
153 }
154 },
155 Meta::List(_) => {
156 return Err(syn::Error::new_spanned(&meta, "unexpected nested attribute"));
157 },
158 }
159 }
160 Ok(args)
161 }
162}
163
164fn parse_duration_str(s: &str) -> syn::Result<u64> {
165 let s = s.trim();
166 if let Some(secs) = s.strip_suffix('s') {
167 secs
168 .trim()
169 .parse::<u64>()
170 .map(|v| v * 1000)
171 .map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), format!("invalid timeout: {e}")))
172 } else if let Some(ms) = s.strip_suffix("ms") {
173 ms.trim()
174 .parse::<u64>()
175 .map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), format!("invalid timeout: {e}")))
176 } else {
177 s.parse::<u64>().map_err(|e| {
178 syn::Error::new(
179 proc_macro2::Span::call_site(),
180 format!("invalid timeout (use '30s' or '5000ms'): {e}"),
181 )
182 })
183 }
184}
185
186#[proc_macro_attribute]
191pub fn ferritest(attr: TokenStream, item: TokenStream) -> TokenStream {
192 let args = parse_macro_input!(attr as FerritestArgs);
193 let input = parse_macro_input!(item as ItemFn);
194
195 let fn_name = &input.sig.ident;
196 let fn_name_str = fn_name.to_string();
197 let vis = &input.vis;
198 let block = &input.block;
199 let attrs = &input.attrs;
200
201 let ctx_param_name = if let Some(FnArg::Typed(pt)) = input.sig.inputs.first() {
204 if let Pat::Ident(pi) = pt.pat.as_ref() {
205 pi.ident.clone()
206 } else {
207 format_ident!("ctx")
208 }
209 } else {
210 format_ident!("ctx")
211 };
212
213 let fixture_names: Vec<String> = Vec::new();
215 let fixture_array = fixture_names.iter().map(|f| quote! { #f });
216
217 fn annotation_tokens(variant: &str, arg: &Option<Option<String>>, annotations: &mut Vec<proc_macro2::TokenStream>) {
220 let variant_ident = quote::format_ident!("{}", variant);
221 match arg {
222 Some(None) => {
223 annotations
224 .push(quote! { ferridriver_test::model::TestAnnotation::#variant_ident { reason: None, condition: None } });
225 },
226 Some(Some(val)) => {
227 if let Some((cond, reason)) = val.split_once('|') {
229 let cond = cond.trim();
230 let reason = reason.trim();
231 annotations.push(quote! { ferridriver_test::model::TestAnnotation::#variant_ident {
232 reason: Some(#reason.to_string()),
233 condition: Some(#cond.to_string()),
234 } });
235 } else {
236 annotations.push(quote! { ferridriver_test::model::TestAnnotation::#variant_ident {
237 reason: None,
238 condition: Some(#val.to_string()),
239 } });
240 }
241 },
242 None => {},
243 }
244 }
245
246 let mut annotations = Vec::new();
247 annotation_tokens("Skip", &args.skip, &mut annotations);
248 annotation_tokens("Slow", &args.slow, &mut annotations);
249 annotation_tokens("Fixme", &args.fixme, &mut annotations);
250 annotation_tokens("Fail", &args.fail, &mut annotations);
251 if args.only {
252 annotations.push(quote! { ferridriver_test::model::TestAnnotation::Only });
253 }
254 for tag in &args.tags {
255 annotations.push(quote! { ferridriver_test::model::TestAnnotation::Tag(#tag.to_string()) });
256 }
257 for (type_name, desc) in &args.infos {
258 annotations.push(
259 quote! { ferridriver_test::model::TestAnnotation::Info { type_name: #type_name.to_string(), description: #desc.to_string() } },
260 );
261 }
262
263 let retries_expr = match args.retries {
264 Some(r) => quote! { Some(#r) },
265 None => quote! { None },
266 };
267 let timeout_ms_expr = match args.timeout_ms {
268 Some(ms) => quote! { Some(#ms) },
269 None => quote! { None },
270 };
271 let use_options_expr = match &args.use_options {
272 Some(json) => quote! { Some(#json) },
273 None => quote! { None },
274 };
275
276 let expanded = quote! {
277 #(#attrs)*
278 #vis async fn #fn_name(__pool: ferridriver_test::fixture::FixturePool) -> Result<(), ferridriver_test::model::TestFailure> {
279 let #ctx_param_name = ferridriver_test::TestContext::new(__pool);
280 #block
281 Ok(())
282 }
283
284 inventory::submit! {
285 ferridriver_test::discovery::TestRegistration {
286 file: file!(),
287 module_path: module_path!(),
288 name: #fn_name_str,
289 fixture_requests: &[#(#fixture_array),*],
290 annotations: &[#(#annotations),*],
291 timeout_ms: #timeout_ms_expr,
292 retries: #retries_expr,
293 use_options: #use_options_expr,
294 test_fn: |pool| Box::pin(#fn_name(pool)),
295 }
296 }
297 };
298
299 expanded.into()
300}
301
302struct FerritestEachArgs {
304 data: Vec<Vec<Expr>>,
305}
306
307impl Parse for FerritestEachArgs {
308 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
309 let ident: syn::Ident = input.parse()?;
311 if ident != "data" {
312 return Err(syn::Error::new_spanned(&ident, "expected `data = [...]`"));
313 }
314 let _: Token![=] = input.parse()?;
315
316 let content;
317 syn::bracketed!(content in input);
318
319 let mut data = Vec::new();
320 while !content.is_empty() {
321 let inner;
322 syn::parenthesized!(inner in content);
323 let exprs: Punctuated<Expr, Token![,]> = Punctuated::parse_terminated(&inner)?;
324 data.push(exprs.into_iter().collect());
325
326 if content.peek(Token![,]) {
327 let _: Token![,] = content.parse()?;
328 }
329 }
330
331 Ok(Self { data })
332 }
333}
334
335#[proc_macro_attribute]
349pub fn ferritest_each(attr: TokenStream, item: TokenStream) -> TokenStream {
350 let args = parse_macro_input!(attr as FerritestEachArgs);
351 let input = parse_macro_input!(item as ItemFn);
352
353 let fn_name = &input.sig.ident;
354 let fn_name_str = fn_name.to_string();
355 let block = &input.block;
356 let attrs = &input.attrs;
357
358 let all_params: Vec<_> = input.sig.inputs.iter().collect();
360 let ctx_param_name = if let Some(FnArg::Typed(pt)) = all_params.first() {
361 if let Pat::Ident(pi) = pt.pat.as_ref() {
362 pi.ident.clone()
363 } else {
364 format_ident!("ctx")
365 }
366 } else {
367 format_ident!("ctx")
368 };
369
370 let data_params: Vec<(&syn::Ident, &Type)> = all_params
371 .iter()
372 .skip(1) .filter_map(|arg| {
374 if let FnArg::Typed(pat_type) = arg {
375 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
376 return Some((&pat_ident.ident, &*pat_type.ty));
377 }
378 }
379 None
380 })
381 .collect();
382
383 let fixture_names: Vec<String> = Vec::new();
384
385 let mut submissions = Vec::new();
387 for (row_idx, row) in args.data.iter().enumerate() {
388 if row.len() != data_params.len() {
389 return syn::Error::new_spanned(
390 &input.sig.ident,
391 format!(
392 "data row {} has {} values but function expects {} data parameters",
393 row_idx,
394 row.len(),
395 data_params.len()
396 ),
397 )
398 .to_compile_error()
399 .into();
400 }
401
402 let row_values_str: Vec<String> = row.iter().map(|e| quote!(#e).to_string().replace('"', "")).collect();
404 let suffix = row_values_str.join(", ");
405 let test_name = format!("{fn_name_str} ({suffix})");
406
407 let data_bindings: Vec<_> = data_params
409 .iter()
410 .zip(row.iter())
411 .map(|((param_name, param_type), value)| {
412 quote! { let #param_name: #param_type = #value; }
413 })
414 .collect();
415
416 let inner_fn_name = format_ident!("__ferritest_each_{}_{}", fn_name, row_idx);
417 let fixture_array = fixture_names.iter().map(|f| quote! { #f });
418 let ctx_param = ctx_param_name.clone();
419
420 submissions.push(quote! {
421 async fn #inner_fn_name(__pool: ferridriver_test::fixture::FixturePool) -> Result<(), ferridriver_test::model::TestFailure> {
422 let #ctx_param = ferridriver_test::TestContext::new(__pool);
423 #(#data_bindings)*
424 #block
425 Ok(())
426 }
427
428 inventory::submit! {
429 ferridriver_test::discovery::TestRegistration {
430 file: file!(),
431 module_path: module_path!(),
432 name: #test_name,
433 fixture_requests: &[#(#fixture_array),*],
434 annotations: &[],
435 timeout_ms: None,
436 retries: None,
437 test_fn: |pool| Box::pin(#inner_fn_name(pool)),
438 }
439 }
440 });
441 }
442
443 let expanded = quote! {
444 #(#attrs)*
445 #(#submissions)*
446 };
447
448 expanded.into()
449}
450
451fn hook_impl(kind_tag: &str, is_suite_hook: bool, item: TokenStream) -> TokenStream {
455 let input = parse_macro_input!(item as ItemFn);
456 let fn_name = &input.sig.ident;
457 let vis = &input.vis;
458 let block = &input.block;
459 let attrs = &input.attrs;
460
461 let kind_ident = format_ident!("{}", kind_tag);
462
463 let ctx_param_name = if let Some(FnArg::Typed(pt)) = input.sig.inputs.first() {
465 if let Pat::Ident(pi) = pt.pat.as_ref() {
466 pi.ident.clone()
467 } else {
468 format_ident!("ctx")
469 }
470 } else {
471 format_ident!("ctx")
472 };
473
474 if is_suite_hook {
475 let expanded = quote! {
477 #(#attrs)*
478 #vis fn #fn_name(__pool: ferridriver_test::fixture::FixturePool)
479 -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = Result<(), ferridriver_test::model::TestFailure>> + Send>>
480 {
481 Box::pin(async move {
482 let #ctx_param_name = ferridriver_test::TestContext::new(__pool);
483 #block
484 Ok(())
485 })
486 }
487
488 inventory::submit! {
489 ferridriver_test::discovery::HookRegistration {
490 module_path: module_path!(),
491 suite_hook_fn: Some(#fn_name),
492 each_hook_fn: None,
493 kind: ferridriver_test::discovery::HookKindTag::#kind_ident,
494 }
495 }
496 };
497 expanded.into()
498 } else {
499 let expanded = quote! {
501 #(#attrs)*
502 #vis fn #fn_name(
503 __pool: ferridriver_test::fixture::FixturePool,
504 __info: ::std::sync::Arc<ferridriver_test::model::TestInfo>,
505 ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = Result<(), ferridriver_test::model::TestFailure>> + Send>>
506 {
507 Box::pin(async move {
508 let #ctx_param_name = ferridriver_test::TestContext::new(__pool);
509 #block
510 Ok(())
511 })
512 }
513
514 inventory::submit! {
515 ferridriver_test::discovery::HookRegistration {
516 module_path: module_path!(),
517 suite_hook_fn: None,
518 each_hook_fn: Some(#fn_name),
519 kind: ferridriver_test::discovery::HookKindTag::#kind_ident,
520 }
521 }
522 };
523 expanded.into()
524 }
525}
526
527#[proc_macro_attribute]
543pub fn before_all(_attr: TokenStream, item: TokenStream) -> TokenStream {
544 hook_impl("BeforeAll", true, item)
545}
546
547#[proc_macro_attribute]
549pub fn after_all(_attr: TokenStream, item: TokenStream) -> TokenStream {
550 hook_impl("AfterAll", true, item)
551}
552
553#[proc_macro_attribute]
570pub fn before_each(_attr: TokenStream, item: TokenStream) -> TokenStream {
571 hook_impl("BeforeEach", false, item)
572}
573
574#[proc_macro_attribute]
576pub fn after_each(_attr: TokenStream, item: TokenStream) -> TokenStream {
577 hook_impl("AfterEach", false, item)
578}