use proc_macro::TokenStream;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{FnArg, ItemFn, Lit, Meta, Pat, Token, parse_macro_input};
/// Parsed arguments of a step attribute (`given`, `when`, `then`, `step`).
struct StepArgs {
/// Pattern text taken verbatim from the string literal in the attribute.
expression: String,
/// `true` when the pattern was written as `regex = "..."`, `false` for a bare string literal.
is_regex: bool,
}
impl Parse for StepArgs {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
if input.peek(syn::Ident) {
let ident: syn::Ident = input.fork().parse()?;
if ident == "regex" {
let _: syn::Ident = input.parse()?;
let _: Token![=] = input.parse()?;
let lit: Lit = input.parse()?;
return match lit {
Lit::Str(s) => Ok(Self {
expression: s.value(),
is_regex: true,
}),
_ => Err(syn::Error::new_spanned(lit, "expected a string literal regex pattern")),
};
}
}
let lit: Lit = input.parse()?;
match lit {
Lit::Str(s) => Ok(Self {
expression: s.value(),
is_regex: false,
}),
_ => Err(syn::Error::new_spanned(
lit,
"expected a string literal cucumber expression",
)),
}
}
}
/// Parsed arguments of a hook attribute (`before`, `after`).
struct HookArgs {
/// Hook point keyword; the parser only accepts `all`, `feature`, `scenario` or `step`.
point: String,
/// Optional tag filter taken from `tags = "..."`.
tags: Option<String>,
/// Value taken from `order = N`; 0 when not given.
order: i32,
}
impl Parse for HookArgs {
fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
let point_ident: syn::Ident = input.parse()?;
let point = point_ident.to_string();
let valid = ["all", "feature", "scenario", "step"];
if !valid.contains(&point.as_str()) {
return Err(syn::Error::new_spanned(
&point_ident,
format!("expected one of: {}", valid.join(", ")),
));
}
let mut tags = None;
let mut order = 0i32;
if input.peek(Token![,]) {
let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
for meta in metas {
if let Meta::NameValue(nv) = &meta {
let ident = nv.path.get_ident().map(ToString::to_string).unwrap_or_default();
match ident.as_str() {
"tags" => {
if let syn::Expr::Lit(lit) = &nv.value {
if let Lit::Str(s) = &lit.lit {
tags = Some(s.value());
}
}
},
"order" => {
if let syn::Expr::Lit(lit) = &nv.value {
if let Lit::Int(i) = &lit.lit {
order = i.base10_parse()?;
}
}
},
_ => {
return Err(syn::Error::new_spanned(
&nv.path,
format!("unknown hook attribute: {ident}"),
));
},
}
}
}
}
Ok(Self { point, tags, order })
}
}
fn generate_step(kind: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as StepArgs);
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_name_str = fn_name.to_string();
let vis = &input.vis;
let block = &input.block;
let attrs = &input.attrs;
let expression = &args.expression;
let is_regex = args.is_regex;
let kind_ident = syn::Ident::new(kind, proc_macro2::Span::call_site());
let mut param_extractions = Vec::new();
let mut param_names = Vec::new();
let mut param_idx = 0usize;
let inputs: Vec<_> = input.sig.inputs.iter().collect();
let special_params = ["table", "data_table", "docstring", "doc_string"];
for arg in inputs.iter().skip(1) {
if let FnArg::Typed(pat_type) = arg {
if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
if special_params.contains(&pat_ident.ident.to_string().as_str()) {
continue;
}
let param_name = &pat_ident.ident;
let param_type = &pat_type.ty;
let idx = param_idx;
let extraction = type_to_extraction(param_type, idx);
param_extractions.push(quote! {
let #param_name: #param_type = #extraction;
});
param_names.push(quote! { #param_name });
param_idx += 1;
}
}
}
let has_table = inputs.iter().any(|arg| {
if let FnArg::Typed(pat_type) = arg {
if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
return pat_ident.ident == "data_table" || pat_ident.ident == "table";
}
}
false
});
let has_docstring = inputs.iter().any(|arg| {
if let FnArg::Typed(pat_type) = arg {
if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
return pat_ident.ident == "docstring" || pat_ident.ident == "doc_string";
}
}
false
});
let table_binding = if has_table {
quote! { let table = __table; let data_table = __table; }
} else {
quote! { let _ = __table; }
};
let docstring_binding = if has_docstring {
quote! { let docstring = __docstring; let doc_string = __docstring; }
} else {
quote! { let _ = __docstring; }
};
let handler_name = syn::Ident::new(
&format!("__bdd_step_handler_{fn_name_str}"),
proc_macro2::Span::call_site(),
);
let reg_name = syn::Ident::new(&format!("__bdd_step_reg_{fn_name_str}"), proc_macro2::Span::call_site());
let expanded = quote! {
#(#attrs)*
#vis async fn #fn_name(
__world: &mut ferridriver_bdd::world::BrowserWorld,
__params: Vec<ferridriver_bdd::step::StepParam>,
__table: Option<&ferridriver_bdd::step::DataTable>,
__docstring: Option<&str>,
) -> Result<(), ferridriver_bdd::step::StepError> {
#(#param_extractions)*
#table_binding
#docstring_binding
let world = __world;
#block
Ok(())
}
fn #handler_name() -> ferridriver_bdd::step::StepHandler {
std::sync::Arc::new(
|world, params, table, docstring| {
Box::pin(#fn_name(world, params, table, docstring))
},
)
}
ferridriver_bdd::submit_step! {
#reg_name,
ferridriver_bdd::step::StepKind::#kind_ident,
#expression,
#handler_name,
regex = #is_regex,
}
};
expanded.into()
}
/// Builds the expression that reads step parameter `idx` out of `__params`
/// for a parameter declared with type `ty`.
///
/// The type is matched by its token text: `i64` uses `as_int` with a fallback
/// of `0`, `f64` uses `as_float` with a fallback of `0.0`, and every other type
/// (including `String`) uses `as_string` with `unwrap_or_default()`.
fn type_to_extraction(ty: &syn::Type, idx: usize) -> proc_macro2::TokenStream {
    let rendered = quote!(#ty).to_string();

    // Pick the accessor method and the fallback call; the surrounding lookup
    // is identical for every type.
    let (accessor, fallback) = match rendered.trim() {
        "i64" => (quote!(as_int), quote!(unwrap_or(0))),
        "f64" => (quote!(as_float), quote!(unwrap_or(0.0))),
        _ => (quote!(as_string), quote!(unwrap_or_default())),
    };

    quote! {
        __params.get(#idx)
            .and_then(|p| p.#accessor())
            .#fallback
    }
}
fn generate_hook(prefix: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as HookArgs);
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_name_str = fn_name.to_string();
let vis = &input.vis;
let block = &input.block;
let attrs = &input.attrs;
let point = &args.point;
let order = args.order;
let hook_point = match point.as_str() {
"all" => {
if prefix == "Before" {
quote! { ferridriver_bdd::hook::HookPoint::BeforeAll }
} else {
quote! { ferridriver_bdd::hook::HookPoint::AfterAll }
}
},
"feature" => {
if prefix == "Before" {
quote! { ferridriver_bdd::hook::HookPoint::BeforeFeature }
} else {
quote! { ferridriver_bdd::hook::HookPoint::AfterFeature }
}
},
"scenario" => {
if prefix == "Before" {
quote! { ferridriver_bdd::hook::HookPoint::BeforeScenario }
} else {
quote! { ferridriver_bdd::hook::HookPoint::AfterScenario }
}
},
"step" => {
if prefix == "Before" {
quote! { ferridriver_bdd::hook::HookPoint::BeforeStep }
} else {
quote! { ferridriver_bdd::hook::HookPoint::AfterStep }
}
},
_ => unreachable!(),
};
let tag_filter_expr = match &args.tags {
Some(tags) => quote! { Some(#tags.to_string()) },
None => quote! { None },
};
let is_global = point == "all";
let has_world_param = input.sig.inputs.iter().any(|arg| {
if let FnArg::Typed(_) = arg {
return true;
}
false
});
let handler_name = syn::Ident::new(
&format!("__bdd_hook_handler_{fn_name_str}"),
proc_macro2::Span::call_site(),
);
let reg_name = syn::Ident::new(&format!("__bdd_hook_reg_{fn_name_str}"), proc_macro2::Span::call_site());
let (fn_sig, handler_factory) = if is_global {
(
quote! {
#vis async fn #fn_name() -> ::std::result::Result<(), ::ferridriver::FerriError> {
#block
Ok(())
}
},
quote! {
fn #handler_name() -> ferridriver_bdd::hook::HookHandler {
ferridriver_bdd::hook::HookHandler::Global(std::sync::Arc::new(|| {
Box::pin(async { #fn_name().await })
}))
}
},
)
} else if has_world_param {
(
quote! {
#(#attrs)*
#vis async fn #fn_name(
world: &mut ferridriver_bdd::world::BrowserWorld,
) -> ::std::result::Result<(), ::ferridriver::FerriError> {
#block
Ok(())
}
},
quote! {
fn #handler_name() -> ferridriver_bdd::hook::HookHandler {
ferridriver_bdd::hook::HookHandler::Scenario(std::sync::Arc::new(|world| {
Box::pin(async move { #fn_name(world).await })
}))
}
},
)
} else {
(
quote! {
#(#attrs)*
#vis async fn #fn_name() -> ::std::result::Result<(), ::ferridriver::FerriError> {
#block
Ok(())
}
},
quote! {
fn #handler_name() -> ferridriver_bdd::hook::HookHandler {
ferridriver_bdd::hook::HookHandler::Global(std::sync::Arc::new(|| {
Box::pin(async { #fn_name().await })
}))
}
},
)
};
let expanded = quote! {
#fn_sig
#handler_factory
ferridriver_bdd::submit_hook! {
#reg_name,
#hook_point,
#tag_filter_expr,
#order,
#handler_name,
}
};
expanded.into()
}
/// Parsed arguments of the `param_type` attribute; both fields are required.
struct ParamTypeArgs {
/// Value of `name = "..."`.
name: String,
/// Value of `regex = "..."`.
regex: String,
}
impl Parse for ParamTypeArgs {
    /// Parses `name = "...", regex = "..."` (in any order, trailing comma allowed).
    ///
    /// # Errors
    ///
    /// Fails when an entry is not `key = value`, when a key other than `name`
    /// or `regex` is used, when a value is not a string literal, or when either
    /// required key is missing.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let metas = Punctuated::<Meta, Token![,]>::parse_terminated(input)?;
        let mut name: Option<String> = None;
        let mut regex: Option<String> = None;

        for meta in &metas {
            let nv = match meta {
                Meta::NameValue(nv) => nv,
                other => {
                    return Err(syn::Error::new_spanned(
                        other,
                        "expected `key = \"value\"` (supported keys: name, regex)",
                    ));
                },
            };

            // Resolve the key first so an unknown key is rejected regardless
            // of what kind of value follows it.
            let ident = nv.path.get_ident().map(ToString::to_string).unwrap_or_default();
            let slot = match ident.as_str() {
                "name" => &mut name,
                "regex" => &mut regex,
                _ => {
                    return Err(syn::Error::new_spanned(
                        &nv.path,
                        format!("unknown param_type attribute: {ident} (expected name, regex)"),
                    ));
                },
            };

            // Report a wrongly-typed value where it occurs instead of letting
            // it surface later as a misleading "missing attribute" error.
            match &nv.value {
                syn::Expr::Lit(syn::ExprLit { lit: Lit::Str(s), .. }) => *slot = Some(s.value()),
                other => {
                    return Err(syn::Error::new_spanned(
                        other,
                        format!("`{ident}` expects a string literal"),
                    ));
                },
            }
        }

        Ok(Self {
            name: name.ok_or_else(|| syn::Error::new(input.span(), "missing `name` attribute"))?,
            regex: regex.ok_or_else(|| syn::Error::new(input.span(), "missing `regex` attribute"))?,
        })
    }
}
/// Attribute macro that registers the annotated function as a `Given` step.
///
/// Accepts either a string literal (`#[given("...")]`) or an explicit regex
/// pattern (`#[given(regex = "...")]`). Expansion is delegated to `generate_step`.
#[proc_macro_attribute]
pub fn given(attr: TokenStream, item: TokenStream) -> TokenStream {
generate_step("Given", attr, item)
}
/// Attribute macro that registers the annotated function as a `When` step.
///
/// Accepts either a string literal (`#[when("...")]`) or an explicit regex
/// pattern (`#[when(regex = "...")]`). Expansion is delegated to `generate_step`.
#[proc_macro_attribute]
pub fn when(attr: TokenStream, item: TokenStream) -> TokenStream {
generate_step("When", attr, item)
}
/// Attribute macro that registers the annotated function as a `Then` step.
///
/// Accepts either a string literal (`#[then("...")]`) or an explicit regex
/// pattern (`#[then(regex = "...")]`). Expansion is delegated to `generate_step`.
#[proc_macro_attribute]
pub fn then(attr: TokenStream, item: TokenStream) -> TokenStream {
generate_step("Then", attr, item)
}
/// Attribute macro that registers the annotated function under the `Step` kind.
///
/// Accepts either a string literal (`#[step("...")]`) or an explicit regex
/// pattern (`#[step(regex = "...")]`). Expansion is delegated to `generate_step`.
#[proc_macro_attribute]
pub fn step(attr: TokenStream, item: TokenStream) -> TokenStream {
generate_step("Step", attr, item)
}
/// Attribute macro that registers the annotated function as a `Before*` hook.
///
/// The first argument selects the hook point: `all`, `feature`, `scenario` or
/// `step`. Expansion is delegated to `generate_hook`.
#[proc_macro_attribute]
pub fn before(attr: TokenStream, item: TokenStream) -> TokenStream {
generate_hook("Before", attr, item)
}
/// Attribute macro that registers the annotated function as an `After*` hook.
///
/// The first argument selects the hook point: `all`, `feature`, `scenario` or
/// `step`. Expansion is delegated to `generate_hook`.
#[proc_macro_attribute]
pub fn after(attr: TokenStream, item: TokenStream) -> TokenStream {
generate_hook("After", attr, item)
}
#[proc_macro_attribute]
pub fn param_type(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as ParamTypeArgs);
let input = parse_macro_input!(item as ItemFn);
let fn_name = &input.sig.ident;
let fn_name_str = fn_name.to_string();
let name = &args.name;
let regex = &args.regex;
let _reg_name = syn::Ident::new(
&format!("__bdd_param_type_reg_{fn_name_str}"),
proc_macro2::Span::call_site(),
);
let expanded = quote! {
ferridriver_bdd::inventory::submit! {
ferridriver_bdd::param_type::ParameterTypeRegistration {
name: #name,
regex: #regex,
transformer_factory: None,
}
}
#[allow(dead_code)]
fn #fn_name() {}
};
expanded.into()
}