1use proc_macro::TokenStream;
27use quote::quote;
28use syn::parse::{Parse, ParseStream};
29use syn::punctuated::Punctuated;
30use syn::{FnArg, ItemFn, Lit, Meta, Pat, Token, parse_macro_input};
31
/// Parsed arguments of a step attribute (`given`, `when`, `then`, `step`).
struct StepArgs {
    /// The pattern text taken from the string literal in the attribute.
    expression: String,
    /// `true` when the pattern was written as `regex = "..."`, `false` when it
    /// was given as a bare string literal.
    is_regex: bool,
}
38
39impl Parse for StepArgs {
40 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
41 if input.peek(syn::Ident) {
43 let ident: syn::Ident = input.fork().parse()?;
44 if ident == "regex" {
45 let _: syn::Ident = input.parse()?;
46 let _: Token![=] = input.parse()?;
47 let lit: Lit = input.parse()?;
48 return match lit {
49 Lit::Str(s) => Ok(Self {
50 expression: s.value(),
51 is_regex: true,
52 }),
53 _ => Err(syn::Error::new_spanned(lit, "expected a string literal regex pattern")),
54 };
55 }
56 }
57 let lit: Lit = input.parse()?;
59 match lit {
60 Lit::Str(s) => Ok(Self {
61 expression: s.value(),
62 is_regex: false,
63 }),
64 _ => Err(syn::Error::new_spanned(
65 lit,
66 "expected a string literal cucumber expression",
67 )),
68 }
69 }
70}
71
/// Parsed arguments of a hook attribute (`before` / `after`).
struct HookArgs {
    /// Hook point name; the parser only accepts `all`, `feature`, `scenario` or `step`.
    point: String,
    /// Value of the optional `tags = "..."` argument, if one was supplied.
    tags: Option<String>,
    /// Value of the optional `order = <int>` argument; `0` when not supplied.
    order: i32,
}
79
80impl Parse for HookArgs {
81 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
82 let point_ident: syn::Ident = input.parse()?;
83 let point = point_ident.to_string();
84
85 let valid = ["all", "feature", "scenario", "step"];
86 if !valid.contains(&point.as_str()) {
87 return Err(syn::Error::new_spanned(
88 &point_ident,
89 format!("expected one of: {}", valid.join(", ")),
90 ));
91 }
92
93 let mut tags = None;
94 let mut order = 0i32;
95
96 if input.peek(Token![,]) {
97 let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
98 for meta in metas {
99 if let Meta::NameValue(nv) = &meta {
100 let ident = nv.path.get_ident().map(ToString::to_string).unwrap_or_default();
101 match ident.as_str() {
102 "tags" => {
103 if let syn::Expr::Lit(lit) = &nv.value {
104 if let Lit::Str(s) = &lit.lit {
105 tags = Some(s.value());
106 }
107 }
108 },
109 "order" => {
110 if let syn::Expr::Lit(lit) = &nv.value {
111 if let Lit::Int(i) = &lit.lit {
112 order = i.base10_parse()?;
113 }
114 }
115 },
116 _ => {
117 return Err(syn::Error::new_spanned(
118 &nv.path,
119 format!("unknown hook attribute: {ident}"),
120 ));
121 },
122 }
123 }
124 }
125 }
126
127 Ok(Self { point, tags, order })
128 }
129}
130
131fn generate_step(kind: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
134 let args = parse_macro_input!(attr as StepArgs);
135 let input = parse_macro_input!(item as ItemFn);
136
137 let fn_name = &input.sig.ident;
138 let fn_name_str = fn_name.to_string();
139 let vis = &input.vis;
140 let block = &input.block;
141 let attrs = &input.attrs;
142 let expression = &args.expression;
143 let is_regex = args.is_regex;
144
145 let kind_ident = syn::Ident::new(kind, proc_macro2::Span::call_site());
146
147 let mut param_extractions = Vec::new();
150 let mut param_names = Vec::new();
151 let mut param_idx = 0usize;
152
153 let inputs: Vec<_> = input.sig.inputs.iter().collect();
154 let special_params = ["table", "data_table", "docstring", "doc_string"];
155 for arg in inputs.iter().skip(1) {
156 if let FnArg::Typed(pat_type) = arg {
158 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
159 if special_params.contains(&pat_ident.ident.to_string().as_str()) {
161 continue;
162 }
163 let param_name = &pat_ident.ident;
164 let param_type = &pat_type.ty;
165 let idx = param_idx;
166
167 let extraction = type_to_extraction(param_type, idx);
168 param_extractions.push(quote! {
169 let #param_name: #param_type = #extraction;
170 });
171 param_names.push(quote! { #param_name });
172 param_idx += 1;
173 }
174 }
175 }
176
177 let has_table = inputs.iter().any(|arg| {
179 if let FnArg::Typed(pat_type) = arg {
180 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
181 return pat_ident.ident == "data_table" || pat_ident.ident == "table";
182 }
183 }
184 false
185 });
186
187 let has_docstring = inputs.iter().any(|arg| {
188 if let FnArg::Typed(pat_type) = arg {
189 if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
190 return pat_ident.ident == "docstring" || pat_ident.ident == "doc_string";
191 }
192 }
193 false
194 });
195
196 let table_binding = if has_table {
197 quote! { let table = __table; let data_table = __table; }
198 } else {
199 quote! { let _ = __table; }
200 };
201
202 let docstring_binding = if has_docstring {
203 quote! { let docstring = __docstring; let doc_string = __docstring; }
204 } else {
205 quote! { let _ = __docstring; }
206 };
207
208 let handler_name = syn::Ident::new(
209 &format!("__bdd_step_handler_{fn_name_str}"),
210 proc_macro2::Span::call_site(),
211 );
212 let reg_name = syn::Ident::new(&format!("__bdd_step_reg_{fn_name_str}"), proc_macro2::Span::call_site());
213
214 let expanded = quote! {
215 #(#attrs)*
216 #vis async fn #fn_name(
217 __world: &mut ferridriver_bdd::world::BrowserWorld,
218 __params: Vec<ferridriver_bdd::step::StepParam>,
219 __table: Option<&ferridriver_bdd::step::DataTable>,
220 __docstring: Option<&str>,
221 ) -> Result<(), ferridriver_bdd::step::StepError> {
222 #(#param_extractions)*
223 #table_binding
224 #docstring_binding
225 let world = __world;
226 #block
227 Ok(())
228 }
229
230 fn #handler_name() -> ferridriver_bdd::step::StepHandler {
231 std::sync::Arc::new(
232 |world, params, table, docstring| {
233 Box::pin(#fn_name(world, params, table, docstring))
234 },
235 )
236 }
237
238 ferridriver_bdd::submit_step! {
239 #reg_name,
240 ferridriver_bdd::step::StepKind::#kind_ident,
241 #expression,
242 #handler_name,
243 regex = #is_regex,
244 }
245 };
246
247 expanded.into()
248}
249
250fn type_to_extraction(ty: &syn::Type, idx: usize) -> proc_macro2::TokenStream {
251 let type_str = quote!(#ty).to_string();
252 match type_str.trim() {
253 "String" => quote! {
254 __params.get(#idx)
255 .and_then(|p| p.as_string())
256 .unwrap_or_default()
257 },
258 "i64" => quote! {
259 __params.get(#idx)
260 .and_then(|p| p.as_int())
261 .unwrap_or(0)
262 },
263 "f64" => quote! {
264 __params.get(#idx)
265 .and_then(|p| p.as_float())
266 .unwrap_or(0.0)
267 },
268 _ => quote! {
269 __params.get(#idx)
270 .and_then(|p| p.as_string())
271 .unwrap_or_default()
272 },
273 }
274}
275
276fn generate_hook(prefix: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
279 let args = parse_macro_input!(attr as HookArgs);
280 let input = parse_macro_input!(item as ItemFn);
281
282 let fn_name = &input.sig.ident;
283 let fn_name_str = fn_name.to_string();
284 let vis = &input.vis;
285 let block = &input.block;
286 let attrs = &input.attrs;
287
288 let point = &args.point;
289 let order = args.order;
290
291 let hook_point = match point.as_str() {
292 "all" => {
293 if prefix == "Before" {
294 quote! { ferridriver_bdd::hook::HookPoint::BeforeAll }
295 } else {
296 quote! { ferridriver_bdd::hook::HookPoint::AfterAll }
297 }
298 },
299 "feature" => {
300 if prefix == "Before" {
301 quote! { ferridriver_bdd::hook::HookPoint::BeforeFeature }
302 } else {
303 quote! { ferridriver_bdd::hook::HookPoint::AfterFeature }
304 }
305 },
306 "scenario" => {
307 if prefix == "Before" {
308 quote! { ferridriver_bdd::hook::HookPoint::BeforeScenario }
309 } else {
310 quote! { ferridriver_bdd::hook::HookPoint::AfterScenario }
311 }
312 },
313 "step" => {
314 if prefix == "Before" {
315 quote! { ferridriver_bdd::hook::HookPoint::BeforeStep }
316 } else {
317 quote! { ferridriver_bdd::hook::HookPoint::AfterStep }
318 }
319 },
320 _ => unreachable!(),
321 };
322
323 let tag_filter_expr = match &args.tags {
324 Some(tags) => quote! { Some(#tags.to_string()) },
325 None => quote! { None },
326 };
327
328 let is_global = point == "all";
330 let has_world_param = input.sig.inputs.iter().any(|arg| {
331 if let FnArg::Typed(_) = arg {
332 return true;
333 }
334 false
335 });
336
337 let handler_name = syn::Ident::new(
338 &format!("__bdd_hook_handler_{fn_name_str}"),
339 proc_macro2::Span::call_site(),
340 );
341 let reg_name = syn::Ident::new(&format!("__bdd_hook_reg_{fn_name_str}"), proc_macro2::Span::call_site());
342
343 let (fn_sig, handler_factory) = if is_global {
344 (
345 quote! {
346 #vis async fn #fn_name() -> ::std::result::Result<(), ::ferridriver::FerriError> {
347 #block
348 Ok(())
349 }
350 },
351 quote! {
352 fn #handler_name() -> ferridriver_bdd::hook::HookHandler {
353 ferridriver_bdd::hook::HookHandler::Global(std::sync::Arc::new(|| {
354 Box::pin(async { #fn_name().await })
355 }))
356 }
357 },
358 )
359 } else if has_world_param {
360 (
361 quote! {
362 #(#attrs)*
363 #vis async fn #fn_name(
364 world: &mut ferridriver_bdd::world::BrowserWorld,
365 ) -> ::std::result::Result<(), ::ferridriver::FerriError> {
366 #block
367 Ok(())
368 }
369 },
370 quote! {
371 fn #handler_name() -> ferridriver_bdd::hook::HookHandler {
372 ferridriver_bdd::hook::HookHandler::Scenario(std::sync::Arc::new(|world| {
373 Box::pin(async move { #fn_name(world).await })
374 }))
375 }
376 },
377 )
378 } else {
379 (
380 quote! {
381 #(#attrs)*
382 #vis async fn #fn_name() -> ::std::result::Result<(), ::ferridriver::FerriError> {
383 #block
384 Ok(())
385 }
386 },
387 quote! {
388 fn #handler_name() -> ferridriver_bdd::hook::HookHandler {
389 ferridriver_bdd::hook::HookHandler::Global(std::sync::Arc::new(|| {
390 Box::pin(async { #fn_name().await })
391 }))
392 }
393 },
394 )
395 };
396
397 let expanded = quote! {
398 #fn_sig
399 #handler_factory
400
401 ferridriver_bdd::submit_hook! {
402 #reg_name,
403 #hook_point,
404 #tag_filter_expr,
405 #order,
406 #handler_name,
407 }
408 };
409
410 expanded.into()
411}
412
/// Parsed arguments of the `param_type` attribute; both fields are required.
struct ParamTypeArgs {
    /// Value of the `name = "..."` argument.
    name: String,
    /// Value of the `regex = "..."` argument.
    regex: String,
}
419
420impl Parse for ParamTypeArgs {
421 fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
422 let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
423 let mut name = None;
424 let mut regex = None;
425
426 for meta in metas {
427 if let Meta::NameValue(nv) = &meta {
428 let ident = nv.path.get_ident().map(ToString::to_string).unwrap_or_default();
429 if let syn::Expr::Lit(lit) = &nv.value {
430 if let Lit::Str(s) = &lit.lit {
431 match ident.as_str() {
432 "name" => name = Some(s.value()),
433 "regex" => regex = Some(s.value()),
434 _ => {
435 return Err(syn::Error::new_spanned(
436 &nv.path,
437 format!("unknown param_type attribute: {ident} (expected name, regex)"),
438 ));
439 },
440 }
441 }
442 }
443 }
444 }
445
446 Ok(Self {
447 name: name.ok_or_else(|| syn::Error::new(input.span(), "missing `name` attribute"))?,
448 regex: regex.ok_or_else(|| syn::Error::new(input.span(), "missing `regex` attribute"))?,
449 })
450 }
451}
452
/// Registers the annotated function as a `Given` step definition.
///
/// The attribute argument is either a string literal, or `regex = "<pattern>"`
/// to mark the pattern as a regular expression. Expansion is delegated to
/// `generate_step` with the step kind `Given`.
///
/// ```ignore
/// #[given("the page is open")]
/// async fn page_is_open(world: &mut BrowserWorld) { /* ... */ }
/// ```
#[proc_macro_attribute]
pub fn given(attr: TokenStream, item: TokenStream) -> TokenStream {
    generate_step("Given", attr, item)
}
467
/// Registers the annotated function as a `When` step definition.
///
/// The attribute argument is either a string literal, or `regex = "<pattern>"`
/// to mark the pattern as a regular expression. Expansion is delegated to
/// `generate_step` with the step kind `When`.
#[proc_macro_attribute]
pub fn when(attr: TokenStream, item: TokenStream) -> TokenStream {
    generate_step("When", attr, item)
}
480
/// Registers the annotated function as a `Then` step definition.
///
/// The attribute argument is either a string literal, or `regex = "<pattern>"`
/// to mark the pattern as a regular expression. Expansion is delegated to
/// `generate_step` with the step kind `Then`.
#[proc_macro_attribute]
pub fn then(attr: TokenStream, item: TokenStream) -> TokenStream {
    generate_step("Then", attr, item)
}
494
/// Registers the annotated function as a step definition of kind `Step`.
///
/// The attribute argument is either a string literal, or `regex = "<pattern>"`
/// to mark the pattern as a regular expression. Expansion is delegated to
/// `generate_step` with the step kind `Step`.
#[proc_macro_attribute]
pub fn step(attr: TokenStream, item: TokenStream) -> TokenStream {
    generate_step("Step", attr, item)
}
507
/// Registers the annotated function as a hook at one of the `Before*` hook points.
///
/// Arguments are parsed by `HookArgs`: a hook point (`all`, `feature`,
/// `scenario` or `step`), with optional `tags = "..."` and `order = <int>`
/// settings. Expansion is delegated to `generate_hook` with the prefix `Before`.
///
/// ```ignore
/// #[before(scenario)]
/// async fn prepare(world: &mut BrowserWorld) { /* ... */ }
/// ```
#[proc_macro_attribute]
pub fn before(attr: TokenStream, item: TokenStream) -> TokenStream {
    generate_hook("Before", attr, item)
}
524
/// Registers the annotated function as a hook at one of the `After*` hook points.
///
/// Arguments are parsed by `HookArgs`: a hook point (`all`, `feature`,
/// `scenario` or `step`), with optional `tags = "..."` and `order = <int>`
/// settings. Expansion is delegated to `generate_hook` with the prefix `After`.
#[proc_macro_attribute]
pub fn after(attr: TokenStream, item: TokenStream) -> TokenStream {
    generate_hook("After", attr, item)
}
541
542#[proc_macro_attribute]
560pub fn param_type(attr: TokenStream, item: TokenStream) -> TokenStream {
561 let args = parse_macro_input!(attr as ParamTypeArgs);
562 let input = parse_macro_input!(item as ItemFn);
563
564 let fn_name = &input.sig.ident;
565 let fn_name_str = fn_name.to_string();
566 let name = &args.name;
567 let regex = &args.regex;
568
569 let _reg_name = syn::Ident::new(
570 &format!("__bdd_param_type_reg_{fn_name_str}"),
571 proc_macro2::Span::call_site(),
572 );
573
574 let expanded = quote! {
575 ferridriver_bdd::inventory::submit! {
576 ferridriver_bdd::param_type::ParameterTypeRegistration {
577 name: #name,
578 regex: #regex,
579 transformer_factory: None,
580 }
581 }
582
583 #[allow(dead_code)]
585 fn #fn_name() {}
586 };
587
588 expanded.into()
589}