Skip to main content

ferridriver_test_macros/
lib.rs

1//! Proc macros for the ferridriver test framework.
2//!
//! Provides `#[ferritest]` to register async browser test functions. The
//! generated test binds its first parameter to a `TestContext`, from which
//! fixtures such as the page are obtained lazily (e.g. `ctx.page()`).
5//!
6//! ```ignore
7//! use ferridriver_test::prelude::*;
8//!
//! #[ferritest]
//! async fn basic_navigation(ctx: TestContext) {
//!     let page = ctx.page().await.unwrap();
//!     page.goto("https://example.com", None).await.unwrap();
//!     expect(&page).to_have_title("Example").await.unwrap();
//! }
//!
//! #[ferritest(retries = 2, timeout = "30s", tag = "smoke")]
//! async fn flaky_test(ctx: TestContext) {
//!     // ...
//! }
19//! ```
20
21use proc_macro::TokenStream;
22use quote::{format_ident, quote};
23use syn::parse::{Parse, ParseStream};
24use syn::punctuated::Punctuated;
25use syn::{Expr, FnArg, ItemFn, Lit, Meta, Pat, Token, Type, parse_macro_input};
26
/// Parsed arguments of `#[ferritest(...)]`, e.g.
/// `#[ferritest(retries = 2, timeout = "30s", tag = "smoke")]`.
///
/// The modifier fields `skip`, `slow`, `fixme` and `fail` share one encoding:
/// - `None` — the modifier was not written;
/// - `Some(None)` — written as a bare flag (unconditional);
/// - `Some(Some(cond))` — written as `= "cond"` (conditional, e.g. `"firefox"`).
struct FerritestArgs {
  // ── Execution policy ──
  /// Value of `retries = N`.
  retries: Option<u32>,
  /// Value of `timeout = "..."`, already converted to milliseconds.
  timeout_ms: Option<u64>,
  /// Raw JSON text from `use_options = "..."` (fixture/context overrides such as
  /// viewport or locale), kept verbatim.
  use_options: Option<String>,

  // ── Modifiers ──
  /// `skip` / `skip = "firefox"`.
  skip: Option<Option<String>>,
  /// `slow` / `slow = "ci"`.
  slow: Option<Option<String>>,
  /// `fixme` / `fixme = "linux"`.
  fixme: Option<Option<String>>,
  /// `fail` / `fail = "webkit"`.
  fail: Option<Option<String>>,
  /// Set by the bare `only` flag.
  only: bool,

  // ── Metadata ──
  /// Every `tag = "..."`, in source order.
  tags: Vec<String>,
  /// Every `info = "type:description"` as a `(type, description)` pair.
  infos: Vec<(String, String)>,
}
41  /// Structured metadata annotations: `info = "type:description"`.
42  infos: Vec<(String, String)>,
43  /// Raw JSON string for fixture/context overrides (viewport, locale, etc.)
44  use_options: Option<String>,
45}
46
47impl Parse for FerritestArgs {
48  fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
49    let mut args = Self {
50      retries: None,
51      timeout_ms: None,
52      tags: Vec::new(),
53      skip: None,
54      slow: None,
55      fixme: None,
56      fail: None,
57      only: false,
58      infos: Vec::new(),
59      use_options: None,
60    };
61
62    let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
63    for meta in metas {
64      match &meta {
65        Meta::NameValue(nv) => {
66          let ident = nv.path.get_ident().map(ToString::to_string).unwrap_or_default();
67          match ident.as_str() {
68            "retries" => {
69              if let syn::Expr::Lit(lit) = &nv.value {
70                if let Lit::Int(i) = &lit.lit {
71                  args.retries = Some(i.base10_parse()?);
72                }
73              }
74            },
75            "timeout" => {
76              if let syn::Expr::Lit(lit) = &nv.value {
77                if let Lit::Str(s) = &lit.lit {
78                  args.timeout_ms = Some(parse_duration_str(&s.value())?);
79                }
80              }
81            },
82            "tag" => {
83              if let syn::Expr::Lit(lit) = &nv.value {
84                if let Lit::Str(s) = &lit.lit {
85                  args.tags.push(s.value());
86                }
87              }
88            },
89            "skip" => {
90              if let syn::Expr::Lit(lit) = &nv.value {
91                if let Lit::Str(s) = &lit.lit {
92                  args.skip = Some(Some(s.value()));
93                }
94              }
95            },
96            "slow" => {
97              if let syn::Expr::Lit(lit) = &nv.value {
98                if let Lit::Str(s) = &lit.lit {
99                  args.slow = Some(Some(s.value()));
100                }
101              }
102            },
103            "fixme" => {
104              if let syn::Expr::Lit(lit) = &nv.value {
105                if let Lit::Str(s) = &lit.lit {
106                  args.fixme = Some(Some(s.value()));
107                }
108              }
109            },
110            "fail" => {
111              if let syn::Expr::Lit(lit) = &nv.value {
112                if let Lit::Str(s) = &lit.lit {
113                  args.fail = Some(Some(s.value()));
114                }
115              }
116            },
117            "use_options" => {
118              if let syn::Expr::Lit(lit) = &nv.value {
119                if let Lit::Str(s) = &lit.lit {
120                  args.use_options = Some(s.value());
121                }
122              }
123            },
124            "info" => {
125              if let syn::Expr::Lit(lit) = &nv.value {
126                if let Lit::Str(s) = &lit.lit {
127                  let val = s.value();
128                  if let Some((type_name, desc)) = val.split_once(':') {
129                    args.infos.push((type_name.trim().to_string(), desc.trim().to_string()));
130                  } else {
131                    args.infos.push((val, String::new()));
132                  }
133                }
134              }
135            },
136            _ => {
137              return Err(syn::Error::new_spanned(
138                &nv.path,
139                format!("unknown ferritest attribute: {ident}"),
140              ));
141            },
142          }
143        },
144        Meta::Path(p) => {
145          let ident = p.get_ident().map(ToString::to_string).unwrap_or_default();
146          match ident.as_str() {
147            "skip" => args.skip = Some(None),
148            "slow" => args.slow = Some(None),
149            "fixme" => args.fixme = Some(None),
150            "fail" => args.fail = Some(None),
151            "only" => args.only = true,
152            _ => return Err(syn::Error::new_spanned(p, format!("unknown ferritest flag: {ident}"))),
153          }
154        },
155        Meta::List(_) => {
156          return Err(syn::Error::new_spanned(&meta, "unexpected nested attribute"));
157        },
158      }
159    }
160    Ok(args)
161  }
162}
163
164fn parse_duration_str(s: &str) -> syn::Result<u64> {
165  let s = s.trim();
166  if let Some(secs) = s.strip_suffix('s') {
167    secs
168      .trim()
169      .parse::<u64>()
170      .map(|v| v * 1000)
171      .map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), format!("invalid timeout: {e}")))
172  } else if let Some(ms) = s.strip_suffix("ms") {
173    ms.trim()
174      .parse::<u64>()
175      .map_err(|e| syn::Error::new(proc_macro2::Span::call_site(), format!("invalid timeout: {e}")))
176  } else {
177    s.parse::<u64>().map_err(|e| {
178      syn::Error::new(
179        proc_macro2::Span::call_site(),
180        format!("invalid timeout (use '30s' or '5000ms'): {e}"),
181      )
182    })
183  }
184}
185
186/// `#[ferritest]` attribute macro.
187///
188/// Transforms an async function into a registered test case with automatic
189/// fixture injection based on parameter types.
190#[proc_macro_attribute]
191pub fn ferritest(attr: TokenStream, item: TokenStream) -> TokenStream {
192  let args = parse_macro_input!(attr as FerritestArgs);
193  let input = parse_macro_input!(item as ItemFn);
194
195  let fn_name = &input.sig.ident;
196  let fn_name_str = fn_name.to_string();
197  let vis = &input.vis;
198  let block = &input.block;
199  let attrs = &input.attrs;
200
201  // The function receives a TestContext. Extract the parameter name the user chose
202  // (e.g., `ctx`, `context`, `t`, etc.)
203  let ctx_param_name = if let Some(FnArg::Typed(pt)) = input.sig.inputs.first() {
204    if let Pat::Ident(pi) = pt.pat.as_ref() {
205      pi.ident.clone()
206    } else {
207      format_ident!("ctx")
208    }
209  } else {
210    format_ident!("ctx")
211  };
212
213  // Rust tests resolve built-in fixtures lazily via TestContext getters.
214  let fixture_names: Vec<String> = Vec::new();
215  let fixture_array = fixture_names.iter().map(|f| quote! { #f });
216
217  // Build annotations.
218  // Helper: parse "condition" or "condition | reason" into (condition, reason) tokens.
219  fn annotation_tokens(variant: &str, arg: &Option<Option<String>>, annotations: &mut Vec<proc_macro2::TokenStream>) {
220    let variant_ident = quote::format_ident!("{}", variant);
221    match arg {
222      Some(None) => {
223        annotations
224          .push(quote! { ferridriver_test::model::TestAnnotation::#variant_ident { reason: None, condition: None } });
225      },
226      Some(Some(val)) => {
227        // Support "condition | reason" format.
228        if let Some((cond, reason)) = val.split_once('|') {
229          let cond = cond.trim();
230          let reason = reason.trim();
231          annotations.push(quote! { ferridriver_test::model::TestAnnotation::#variant_ident {
232            reason: Some(#reason.to_string()),
233            condition: Some(#cond.to_string()),
234          } });
235        } else {
236          annotations.push(quote! { ferridriver_test::model::TestAnnotation::#variant_ident {
237            reason: None,
238            condition: Some(#val.to_string()),
239          } });
240        }
241      },
242      None => {},
243    }
244  }
245
246  let mut annotations = Vec::new();
247  annotation_tokens("Skip", &args.skip, &mut annotations);
248  annotation_tokens("Slow", &args.slow, &mut annotations);
249  annotation_tokens("Fixme", &args.fixme, &mut annotations);
250  annotation_tokens("Fail", &args.fail, &mut annotations);
251  if args.only {
252    annotations.push(quote! { ferridriver_test::model::TestAnnotation::Only });
253  }
254  for tag in &args.tags {
255    annotations.push(quote! { ferridriver_test::model::TestAnnotation::Tag(#tag.to_string()) });
256  }
257  for (type_name, desc) in &args.infos {
258    annotations.push(
259      quote! { ferridriver_test::model::TestAnnotation::Info { type_name: #type_name.to_string(), description: #desc.to_string() } },
260    );
261  }
262
263  let retries_expr = match args.retries {
264    Some(r) => quote! { Some(#r) },
265    None => quote! { None },
266  };
267  let timeout_ms_expr = match args.timeout_ms {
268    Some(ms) => quote! { Some(#ms) },
269    None => quote! { None },
270  };
271  let use_options_expr = match &args.use_options {
272    Some(json) => quote! { Some(#json) },
273    None => quote! { None },
274  };
275
276  let expanded = quote! {
277    #(#attrs)*
278    #vis async fn #fn_name(__pool: ferridriver_test::fixture::FixturePool) -> Result<(), ferridriver_test::model::TestFailure> {
279      let #ctx_param_name = ferridriver_test::TestContext::new(__pool);
280      #block
281      Ok(())
282    }
283
284    inventory::submit! {
285      ferridriver_test::discovery::TestRegistration {
286        file: file!(),
287        module_path: module_path!(),
288        name: #fn_name_str,
289        fixture_requests: &[#(#fixture_array),*],
290        annotations: &[#(#annotations),*],
291        timeout_ms: #timeout_ms_expr,
292        retries: #retries_expr,
293        use_options: #use_options_expr,
294        test_fn: |pool| Box::pin(#fn_name(pool)),
295      }
296    }
297  };
298
299  expanded.into()
300}
301
302/// Arguments for `#[ferritest_each]`: `data = [(...), (...)]`.
303struct FerritestEachArgs {
304  data: Vec<Vec<Expr>>,
305}
306
307impl Parse for FerritestEachArgs {
308  fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
309    // Parse: data = [(...), (...)]
310    let ident: syn::Ident = input.parse()?;
311    if ident != "data" {
312      return Err(syn::Error::new_spanned(&ident, "expected `data = [...]`"));
313    }
314    let _: Token![=] = input.parse()?;
315
316    let content;
317    syn::bracketed!(content in input);
318
319    let mut data = Vec::new();
320    while !content.is_empty() {
321      let inner;
322      syn::parenthesized!(inner in content);
323      let exprs: Punctuated<Expr, Token![,]> = Punctuated::parse_terminated(&inner)?;
324      data.push(exprs.into_iter().collect());
325
326      if content.peek(Token![,]) {
327        let _: Token![,] = content.parse()?;
328      }
329    }
330
331    Ok(Self { data })
332  }
333}
334
335/// `#[ferritest_each(data = [("a", 1), ("b", 2)])]` — parameterized test macro.
336///
337/// Expands a single async test function into N registered tests, one per data row.
338/// First parameter is `FixturePool`, remaining parameters receive the data values.
339///
340/// ```ignore
341/// #[ferritest_each(data = [("admin", "admin@example.com"), ("guest", "guest@example.com")])]
342/// async fn login(pool: FixturePool, role: &str, email: &str) {
343///     let page = pool.page().await.unwrap();
344///     page.goto(&format!("/login?role={role}"), None).await.unwrap();
345/// }
346/// ```
347/// Registers: `login (admin, admin@example.com)` and `login (guest, guest@example.com)`.
348#[proc_macro_attribute]
349pub fn ferritest_each(attr: TokenStream, item: TokenStream) -> TokenStream {
350  let args = parse_macro_input!(attr as FerritestEachArgs);
351  let input = parse_macro_input!(item as ItemFn);
352
353  let fn_name = &input.sig.ident;
354  let fn_name_str = fn_name.to_string();
355  let block = &input.block;
356  let attrs = &input.attrs;
357
358  // First param is TestContext, rest are data params.
359  let all_params: Vec<_> = input.sig.inputs.iter().collect();
360  let ctx_param_name = if let Some(FnArg::Typed(pt)) = all_params.first() {
361    if let Pat::Ident(pi) = pt.pat.as_ref() {
362      pi.ident.clone()
363    } else {
364      format_ident!("ctx")
365    }
366  } else {
367    format_ident!("ctx")
368  };
369
370  let data_params: Vec<(&syn::Ident, &Type)> = all_params
371    .iter()
372    .skip(1) // skip FixturePool
373    .filter_map(|arg| {
374      if let FnArg::Typed(pat_type) = arg {
375        if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
376          return Some((&pat_ident.ident, &*pat_type.ty));
377        }
378      }
379      None
380    })
381    .collect();
382
383  let fixture_names: Vec<String> = Vec::new();
384
385  // Generate one inventory::submit! per data row.
386  let mut submissions = Vec::new();
387  for (row_idx, row) in args.data.iter().enumerate() {
388    if row.len() != data_params.len() {
389      return syn::Error::new_spanned(
390        &input.sig.ident,
391        format!(
392          "data row {} has {} values but function expects {} data parameters",
393          row_idx,
394          row.len(),
395          data_params.len()
396        ),
397      )
398      .to_compile_error()
399      .into();
400    }
401
402    // Build name suffix: "(val1, val2)"
403    let row_values_str: Vec<String> = row.iter().map(|e| quote!(#e).to_string().replace('"', "")).collect();
404    let suffix = row_values_str.join(", ");
405    let test_name = format!("{fn_name_str} ({suffix})");
406
407    // Build let bindings for data params.
408    let data_bindings: Vec<_> = data_params
409      .iter()
410      .zip(row.iter())
411      .map(|((param_name, param_type), value)| {
412        quote! { let #param_name: #param_type = #value; }
413      })
414      .collect();
415
416    let inner_fn_name = format_ident!("__ferritest_each_{}_{}", fn_name, row_idx);
417    let fixture_array = fixture_names.iter().map(|f| quote! { #f });
418    let ctx_param = ctx_param_name.clone();
419
420    submissions.push(quote! {
421      async fn #inner_fn_name(__pool: ferridriver_test::fixture::FixturePool) -> Result<(), ferridriver_test::model::TestFailure> {
422        let #ctx_param = ferridriver_test::TestContext::new(__pool);
423        #(#data_bindings)*
424        #block
425        Ok(())
426      }
427
428      inventory::submit! {
429        ferridriver_test::discovery::TestRegistration {
430          file: file!(),
431          module_path: module_path!(),
432          name: #test_name,
433          fixture_requests: &[#(#fixture_array),*],
434          annotations: &[],
435          timeout_ms: None,
436          retries: None,
437          test_fn: |pool| Box::pin(#inner_fn_name(pool)),
438        }
439      }
440    });
441  }
442
443  let expanded = quote! {
444    #(#attrs)*
445    #(#submissions)*
446  };
447
448  expanded.into()
449}
450
451// ── Hook macros ──
452
453/// Shared implementation for all four hook macros.
454fn hook_impl(kind_tag: &str, is_suite_hook: bool, item: TokenStream) -> TokenStream {
455  let input = parse_macro_input!(item as ItemFn);
456  let fn_name = &input.sig.ident;
457  let vis = &input.vis;
458  let block = &input.block;
459  let attrs = &input.attrs;
460
461  let kind_ident = format_ident!("{}", kind_tag);
462
463  // Extract parameter name for TestContext.
464  let ctx_param_name = if let Some(FnArg::Typed(pt)) = input.sig.inputs.first() {
465    if let Pat::Ident(pi) = pt.pat.as_ref() {
466      pi.ident.clone()
467    } else {
468      format_ident!("ctx")
469    }
470  } else {
471    format_ident!("ctx")
472  };
473
474  if is_suite_hook {
475    // before_all / after_all: fn(FixturePool) -> Result
476    let expanded = quote! {
477      #(#attrs)*
478      #vis fn #fn_name(__pool: ferridriver_test::fixture::FixturePool)
479        -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = Result<(), ferridriver_test::model::TestFailure>> + Send>>
480      {
481        Box::pin(async move {
482          let #ctx_param_name = ferridriver_test::TestContext::new(__pool);
483          #block
484          Ok(())
485        })
486      }
487
488      inventory::submit! {
489        ferridriver_test::discovery::HookRegistration {
490          module_path: module_path!(),
491          suite_hook_fn: Some(#fn_name),
492          each_hook_fn: None,
493          kind: ferridriver_test::discovery::HookKindTag::#kind_ident,
494        }
495      }
496    };
497    expanded.into()
498  } else {
499    // before_each / after_each: fn(FixturePool, Arc<TestInfo>) -> Result
500    let expanded = quote! {
501      #(#attrs)*
502      #vis fn #fn_name(
503        __pool: ferridriver_test::fixture::FixturePool,
504        __info: ::std::sync::Arc<ferridriver_test::model::TestInfo>,
505      ) -> ::std::pin::Pin<Box<dyn ::std::future::Future<Output = Result<(), ferridriver_test::model::TestFailure>> + Send>>
506      {
507        Box::pin(async move {
508          let #ctx_param_name = ferridriver_test::TestContext::new(__pool);
509          #block
510          Ok(())
511        })
512      }
513
514      inventory::submit! {
515        ferridriver_test::discovery::HookRegistration {
516          module_path: module_path!(),
517          suite_hook_fn: None,
518          each_hook_fn: Some(#fn_name),
519          kind: ferridriver_test::discovery::HookKindTag::#kind_ident,
520        }
521      }
522    };
523    expanded.into()
524  }
525}
526
527/// Runs once before all tests in the containing module (suite).
528///
529/// ```ignore
530/// mod my_suite {
531///     use ferridriver_test::prelude::*;
532///
533///     #[before_all]
534///     async fn setup(ctx: TestContext) {
535///         // seed database, etc.
536///     }
537///
538///     #[ferritest]
539///     async fn test_one(ctx: TestContext) { ... }
540/// }
541/// ```
542#[proc_macro_attribute]
543pub fn before_all(_attr: TokenStream, item: TokenStream) -> TokenStream {
544  hook_impl("BeforeAll", true, item)
545}
546
547/// Runs once after all tests in the containing module (suite).
548#[proc_macro_attribute]
549pub fn after_all(_attr: TokenStream, item: TokenStream) -> TokenStream {
550  hook_impl("AfterAll", true, item)
551}
552
553/// Runs before each test in the containing module (suite).
554///
555/// ```ignore
556/// mod my_suite {
557///     use ferridriver_test::prelude::*;
558///
559///     #[before_each]
560///     async fn login(ctx: TestContext) {
561///         let page = ctx.page().await?;
562///         page.goto("/login", None).await?;
563///     }
564///
565///     #[ferritest]
566///     async fn dashboard_test(ctx: TestContext) { ... }
567/// }
568/// ```
569#[proc_macro_attribute]
570pub fn before_each(_attr: TokenStream, item: TokenStream) -> TokenStream {
571  hook_impl("BeforeEach", false, item)
572}
573
574/// Runs after each test in the containing module (suite), even on failure.
575#[proc_macro_attribute]
576pub fn after_each(_attr: TokenStream, item: TokenStream) -> TokenStream {
577  hook_impl("AfterEach", false, item)
578}