use crate::crate_paths::get_reinhardt_di_crate;
use crate::injectable_common::{
detect_inject_params, generate_di_context_extraction, generate_injection_calls,
strip_inject_attrs,
};
use proc_macro2::TokenStream;
use quote::quote;
use syn::{
Expr, ExprLit, FnArg, ItemFn, Lit, Meta, Result, Token, parse::Parser, punctuated::Punctuated,
};
/// Arguments of an `#[action(...)]` attribute after validation and default filling,
/// as produced by `parse_action_args_with_defaults`.
#[cfg_attr(test, derive(Debug))]
pub(crate) struct ActionMeta {
    /// HTTP methods listed in the `methods` argument; `["GET"]` when that list was empty.
    // `methods` is not read anywhere in this file, so the allow silences the dead-code lint.
    // NOTE(review): confirm whether other modules read it before removing the allow.
    #[allow(dead_code)]
    pub methods: Vec<String>,
    /// Value of the required `detail` boolean argument.
    pub detail: bool,
    /// Value of `url_path`, or an empty string when the argument was not given.
    pub url_path: String,
    /// Value of `url_name`, or the decorated function's name when the argument was not given.
    pub url_name: String,
}
pub(crate) fn parse_action_args_with_defaults(
args: TokenStream,
fn_ident: &syn::Ident,
) -> Result<ActionMeta> {
let mut methods = Vec::<String>::new();
let mut detail = false;
let mut url_path: Option<String> = None;
let mut url_name: Option<String> = None;
let mut has_methods = false;
let mut has_detail = false;
let meta_list = Punctuated::<Meta, Token![,]>::parse_terminated.parse2(args)?;
for meta in meta_list {
if let Meta::NameValue(nv) = meta {
if nv.path.is_ident("methods") {
let Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) = &nv.value
else {
return Err(syn::Error::new_spanned(
&nv.value,
"methods parameter must be a string literal",
));
};
has_methods = true;
let s = lit.value();
let s = s.trim_matches(|c| c == '[' || c == ']');
for m in s.split(',') {
let m = m.trim().trim_matches('"');
if !m.is_empty() {
methods.push(m.to_string());
}
}
} else if nv.path.is_ident("detail") {
let Expr::Lit(ExprLit {
lit: Lit::Bool(lit),
..
}) = &nv.value
else {
return Err(syn::Error::new_spanned(
&nv.value,
"detail parameter must be a boolean literal (true or false)",
));
};
has_detail = true;
detail = lit.value;
} else if nv.path.is_ident("url_path") {
let Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) = &nv.value
else {
return Err(syn::Error::new_spanned(
&nv.value,
"url_path parameter must be a string literal",
));
};
let p = lit.value();
if p.contains(' ') {
return Err(syn::Error::new_spanned(
lit,
"url_path cannot contain spaces",
));
}
if !p.starts_with('/') {
return Err(syn::Error::new_spanned(lit, "url_path must start with '/'"));
}
url_path = Some(p);
} else if nv.path.is_ident("url_name") {
let Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) = &nv.value
else {
return Err(syn::Error::new_spanned(
&nv.value,
"url_name parameter must be a string literal",
));
};
let value = lit.value();
if syn::parse_str::<syn::Ident>(&value).is_err() {
return Err(syn::Error::new_spanned(
lit,
format!(
"url_name `{value}` is not a valid Rust identifier (use snake_case ASCII letters, digits, and underscores; cannot start with a digit)"
),
));
}
url_name = Some(value);
}
}
}
if !has_methods {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"action macro requires 'methods' parameter",
));
}
if !has_detail {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"action macro requires 'detail' parameter",
));
}
if methods.is_empty() {
methods.push("GET".to_string());
}
Ok(ActionMeta {
methods,
detail,
url_path: url_path.unwrap_or_default(),
url_name: url_name.unwrap_or_else(|| fn_ident.to_string()),
})
}
pub(crate) fn action_impl(args: TokenStream, input: ItemFn) -> Result<TokenStream> {
let mut methods = Vec::new();
let mut detail = false;
let mut _url_path: Option<String> = None;
let mut _url_name: Option<String> = None;
let mut use_inject = false;
let mut methods_lit = None;
let mut has_methods = false;
let mut has_detail = false;
let meta_list = Punctuated::<Meta, Token![,]>::parse_terminated.parse2(args)?;
for meta in meta_list {
match meta {
Meta::NameValue(nv) => {
if nv.path.is_ident("methods") {
has_methods = true;
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) = &nv.value
{
methods_lit = Some(lit.clone());
let methods_str = lit.value();
let methods_str = methods_str.trim_matches(|c| c == '[' || c == ']');
for method in methods_str.split(',') {
let method = method.trim().trim_matches('"');
if !method.is_empty() {
methods.push(method.to_string());
}
}
} else {
return Err(syn::Error::new_spanned(
&nv.value,
"methods parameter must be a string literal",
));
}
} else if nv.path.is_ident("detail") {
has_detail = true;
if let Expr::Lit(ExprLit {
lit: Lit::Bool(lit),
..
}) = &nv.value
{
detail = lit.value;
} else {
return Err(syn::Error::new_spanned(
&nv.value,
"detail parameter must be a boolean literal (true or false)",
));
}
} else if nv.path.is_ident("url_path") {
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) = &nv.value
{
let path = lit.value();
if path.contains(' ') {
return Err(syn::Error::new_spanned(
lit,
"url_path cannot contain spaces",
));
}
if !path.starts_with('/') {
return Err(syn::Error::new_spanned(
lit,
"url_path must start with '/'",
));
}
_url_path = Some(path);
} else {
return Err(syn::Error::new_spanned(
&nv.value,
"url_path parameter must be a string literal",
));
}
} else if nv.path.is_ident("url_name") {
if let Expr::Lit(ExprLit {
lit: Lit::Str(lit), ..
}) = &nv.value
{
_url_name = Some(lit.value());
} else {
return Err(syn::Error::new_spanned(
&nv.value,
"url_name parameter must be a string literal",
));
}
} else if nv.path.is_ident("use_inject") {
if let Expr::Lit(ExprLit {
lit: Lit::Bool(lit),
..
}) = &nv.value
{
use_inject = lit.value;
} else {
return Err(syn::Error::new_spanned(
&nv.value,
"use_inject parameter must be a boolean literal (true or false)",
));
}
}
}
Meta::Path(path) => {
if path.is_ident("methods") {
return Err(syn::Error::new_spanned(
path,
"methods parameter requires a value: methods = \"GET,POST\"",
));
} else if path.is_ident("detail") {
return Err(syn::Error::new_spanned(
path,
"detail parameter requires a value: detail = true or detail = false",
));
}
}
_ => {}
}
}
if !has_methods {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"action macro requires 'methods' parameter",
));
}
if !has_detail {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"action macro requires 'detail' parameter",
));
}
if methods.is_empty() {
methods.push("GET".to_string());
}
const VALID_METHODS: &[&str] = &["GET", "POST", "PUT", "PATCH", "DELETE", "HEAD", "OPTIONS"];
for method in &methods {
let method_upper = method.to_uppercase();
if !VALID_METHODS.contains(&method_upper.as_str()) {
let error_msg = format!(
"Invalid HTTP method '{}'. Valid methods are: GET, POST, PUT, PATCH, DELETE, HEAD, OPTIONS",
method
);
return Err(syn::Error::new_spanned(
methods_lit.as_ref().unwrap(),
error_msg,
));
}
}
let fn_name = &input.sig.ident;
let fn_block = &input.block;
let fn_inputs = &input.sig.inputs;
let fn_output = &input.sig.output;
let fn_vis = &input.vis;
let fn_attrs = &input.attrs;
let asyncness = &input.sig.asyncness;
let generics = &input.sig.generics;
let where_clause = &input.sig.generics.where_clause;
let detail_flag = detail;
let method_list = methods.join(", ");
let inject_params = detect_inject_params(fn_inputs);
if !use_inject && !inject_params.is_empty() {
return Err(syn::Error::new_spanned(
&inject_params[0].pat,
"#[inject] attribute requires use_inject = true option",
));
}
if use_inject && !inject_params.is_empty() {
let original_fn_name = quote::format_ident!("{}_original", fn_name);
let stripped_inputs = strip_inject_attrs(fn_inputs);
let stripped_inputs = Punctuated::<FnArg, Token![,]>::from_iter(stripped_inputs);
let request_ident = syn::Ident::new("request", proc_macro2::Span::call_site());
let di_extraction = generate_di_context_extraction(&request_ident);
let injection_calls = generate_injection_calls(&inject_params);
let inject_args: Vec<_> = inject_params.iter().map(|p| &p.pat).collect();
let regular_args: Vec<_> = stripped_inputs
.iter()
.filter_map(|arg| {
if let FnArg::Typed(pat_type) = arg {
Some(&pat_type.pat)
} else {
None
}
})
.collect();
let di_crate = get_reinhardt_di_crate();
Ok(quote! {
#asyncness fn #original_fn_name #generics (#stripped_inputs) #fn_output #where_clause {
#fn_block
}
#(#fn_attrs)*
#[doc = "Custom action with DI support"]
#[doc = concat!("Methods: ", #method_list)]
#[doc = concat!("Detail: ", stringify!(#detail_flag))]
#fn_vis #asyncness fn #fn_name(request: ::Request) #fn_output {
#di_extraction
#di_crate::resolve_context::RESOLVE_CTX.scope(__resolve_ctx, async {
#(#injection_calls)*
#original_fn_name(#(#regular_args,)* #(#inject_args),*).await
}).await
}
})
} else {
Ok(quote! {
#(#fn_attrs)*
#[doc = "Custom action"]
#[doc = concat!("Methods: ", #method_list)]
#[doc = concat!("Detail: ", stringify!(#detail_flag))]
#fn_vis #asyncness fn #fn_name #generics (#fn_inputs) #fn_output #where_clause {
#fn_block
}
})
}
}
// Unit tests for `parse_action_args_with_defaults`: defaults, literal-type checks, and the
// `url_path` / `url_name` validation rules.
#[cfg(test)]
mod meta_extractor_tests {
    use super::*;
    use quote::quote;

    #[test]
    fn url_name_defaults_to_fn_name_when_absent() {
        let args = quote! { methods = "POST", detail = true };
        let fn_ident: syn::Ident = syn::parse_quote! { highlight };
        let meta = parse_action_args_with_defaults(args, &fn_ident).unwrap();
        assert_eq!(meta.url_name, "highlight");
        assert!(meta.detail);
        assert!(meta.url_path.is_empty());
    }

    #[test]
    fn explicit_url_name_wins_over_fn_name() {
        let args = quote! { methods = "POST", detail = true, url_name = "highlight_code" };
        let fn_ident: syn::Ident = syn::parse_quote! { highlight };
        let meta = parse_action_args_with_defaults(args, &fn_ident).unwrap();
        assert_eq!(meta.url_name, "highlight_code");
    }

    #[test]
    fn url_path_with_placeholders_is_preserved() {
        let args = quote! {
            methods = "GET",
            detail = true,
            url_name = "child",
            url_path = "/children/{child_id}"
        };
        let fn_ident: syn::Ident = syn::parse_quote! { child };
        let meta = parse_action_args_with_defaults(args, &fn_ident).unwrap();
        assert_eq!(meta.url_path, "/children/{child_id}");
    }

    #[test]
    fn missing_methods_errors() {
        let args = quote! { detail = true };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        assert!(parse_action_args_with_defaults(args, &fn_ident).is_err());
    }

    #[test]
    fn missing_detail_errors() {
        let args = quote! { methods = "GET" };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let err = parse_action_args_with_defaults(args, &fn_ident)
            .expect_err("missing detail must error");
        assert!(
            err.to_string()
                .contains("action macro requires 'detail' parameter"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn empty_methods_default_to_get() {
        let args = quote! { methods = "", detail = false };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let meta = parse_action_args_with_defaults(args, &fn_ident).unwrap();
        assert_eq!(meta.methods, ["GET"]);
        assert!(!meta.detail);
    }

    #[test]
    fn bracketed_quoted_methods_are_split() {
        let args = quote! { methods = "[\"GET\", \"POST\"]", detail = false };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let meta = parse_action_args_with_defaults(args, &fn_ident).unwrap();
        assert_eq!(meta.methods, ["GET", "POST"]);
    }

    #[test]
    fn non_string_methods_value_errors() {
        let args = quote! { methods = true, detail = false };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let err = parse_action_args_with_defaults(args, &fn_ident)
            .expect_err("non-string methods value must error");
        assert!(
            err.to_string()
                .contains("methods parameter must be a string literal"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn non_bool_detail_value_errors() {
        let args = quote! { methods = "GET", detail = "false" };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let err = parse_action_args_with_defaults(args, &fn_ident)
            .expect_err("non-boolean detail value must error");
        assert!(
            err.to_string()
                .contains("detail parameter must be a boolean literal"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn non_string_url_path_value_errors() {
        let args = quote! { methods = "GET", detail = true, url_path = 42 };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let err = parse_action_args_with_defaults(args, &fn_ident)
            .expect_err("non-string url_path value must error");
        assert!(
            err.to_string()
                .contains("url_path parameter must be a string literal"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn url_path_with_spaces_errors() {
        let args = quote! { methods = "GET", detail = true, url_path = "/a b" };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let err = parse_action_args_with_defaults(args, &fn_ident)
            .expect_err("url_path containing a space must error");
        assert!(
            err.to_string().contains("url_path cannot contain spaces"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn url_path_without_leading_slash_errors() {
        let args = quote! { methods = "GET", detail = true, url_path = "children" };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let err = parse_action_args_with_defaults(args, &fn_ident)
            .expect_err("url_path without a leading '/' must error");
        assert!(
            err.to_string().contains("url_path must start with '/'"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn non_string_url_name_value_errors() {
        let args = quote! { methods = "GET", detail = true, url_name = false };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let err = parse_action_args_with_defaults(args, &fn_ident)
            .expect_err("non-string url_name value must error");
        assert!(
            err.to_string()
                .contains("url_name parameter must be a string literal"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn invalid_url_name_identifier_errors() {
        let args = quote! { methods = "GET", detail = true, url_name = "1st-child" };
        let fn_ident: syn::Ident = syn::parse_quote! { x };
        let err = parse_action_args_with_defaults(args, &fn_ident)
            .expect_err("url_name that is not an identifier must error");
        assert!(
            err.to_string().contains("is not a valid Rust identifier"),
            "unexpected error message: {err}"
        );
    }
}