use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::parse::Parser as _;
use syn::{Expr, ExprLit, Ident, ItemFn, Lit, LitStr, Meta, Token, parse_quote};
/// Arguments collected from an `#[authorize(...)]` attribute.
///
/// Every field is optional at parse time; `authorize_macro` decides which ones
/// are mandatory. `Debug` is derived so parse results can be inspected in
/// diagnostics and tests (e.g. `Result::unwrap_err`).
#[derive(Debug, Default)]
struct AuthorizeArgs {
    /// Action name forwarded to the policy check (e.g. `"update"`).
    action: Option<String>,
    /// Resource type whose policy is consulted (`resource = Type`).
    resource: Option<Ident>,
    /// Handler parameter holding the resource value (`from = param`). When
    /// absent, the macro falls back to the snake_case form of the resource name.
    from: Option<Ident>,
}
fn parse_authorize_args(attr: TokenStream) -> syn::Result<AuthorizeArgs> {
if attr.is_empty() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"#[authorize] requires an action argument: #[authorize(\"update\", resource = Type)]",
));
}
let metas = syn::punctuated::Punctuated::<Meta, Token![,]>::parse_terminated.parse2(attr)?;
let mut args = AuthorizeArgs::default();
for meta in metas {
match meta {
Meta::Path(p) => {
if let Some(ident) = p.get_ident()
&& args.action.is_none()
{
args.action = Some(ident.to_string());
continue;
}
return Err(syn::Error::new_spanned(
p,
"expected `action` literal or `key = value`",
));
}
Meta::NameValue(nv) => {
let key = nv
.path
.get_ident()
.ok_or_else(|| syn::Error::new_spanned(&nv.path, "expected identifier"))?
.to_string();
match key.as_str() {
"resource" => {
let ident = expect_ident(&nv.value, "resource = TypeName")?;
args.resource = Some(ident);
}
"from" => {
let ident = expect_ident(&nv.value, "from = param_name")?;
args.from = Some(ident);
}
other => {
return Err(syn::Error::new_spanned(
&nv.path,
format!("unknown #[authorize] key: {other}"),
));
}
}
}
Meta::List(l) => {
if l.path.is_ident("action") {
let lit: LitStr = syn::parse2(l.tokens.clone())?;
args.action = Some(lit.value());
} else {
return Err(syn::Error::new_spanned(
&l.path,
"unexpected list-style argument",
));
}
}
}
}
if let Some(action) = first_string_literal(args.action.as_ref()) {
args.action = Some(action);
}
Ok(args)
}
/// Removes one layer of matching `"…"` or `'…'` quotes from `action`.
///
/// Surrounding whitespace is trimmed first. Returns `None` when `action` is
/// `None` or when the trimmed text is not wrapped in a matching pair of quotes.
/// A lone quote character (`"` or `'`) is not a quoted string, so it yields
/// `None` instead of panicking on an inverted slice range.
fn first_string_literal(action: Option<&String>) -> Option<String> {
    let trimmed = action?.trim();
    ['"', '\''].iter().find_map(|&quote| {
        trimmed
            .strip_prefix(quote)
            .and_then(|inner| inner.strip_suffix(quote))
            .map(str::to_owned)
    })
}
fn expect_ident(expr: &Expr, hint: &str) -> syn::Result<Ident> {
match expr {
Expr::Path(p) if p.path.get_ident().is_some() => Ok(p.path.get_ident().unwrap().clone()),
Expr::Lit(ExprLit {
lit: Lit::Str(s), ..
}) => Ok(format_ident!("{}", s.value())),
_ => Err(syn::Error::new_spanned(expr, format!("expected `{hint}`"))),
}
}
use crate::param_helpers::has_input_named;
/// Converts a `PascalCase` type name into `snake_case`.
///
/// An underscore is inserted before every uppercase character except the first,
/// so each capital starts a new word. `BlogPost` becomes `blog_post`, and
/// acronyms are split letter by letter (`HTTPRequest` becomes `h_t_t_p_request`).
/// Lowercasing is Unicode-aware to match `char::is_uppercase`, so non-ASCII
/// capitals are lowered as well (`ÉtatPost` becomes `état_post`).
fn snake_case(name: &str) -> String {
    let mut out = String::new();
    for (i, ch) in name.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                out.push('_');
            }
            // `to_lowercase` may yield several chars for some code points.
            out.extend(ch.to_lowercase());
        } else {
            out.push(ch);
        }
    }
    out
}
pub fn authorize_macro(attr: TokenStream, item: TokenStream) -> TokenStream {
let mut args = match parse_with_leading_literal(attr) {
Ok(a) => a,
Err(err) => return err.to_compile_error(),
};
let Some(action_str) = args.action.take() else {
return syn::Error::new(
proc_macro2::Span::call_site(),
"#[authorize] requires an action: #[authorize(\"update\", resource = Type)]",
)
.to_compile_error();
};
let Some(resource_ident) = args.resource else {
return syn::Error::new(
proc_macro2::Span::call_site(),
"#[authorize] requires `resource = TypeName`",
)
.to_compile_error();
};
let from_ident = args.from.unwrap_or_else(|| {
let name = snake_case(&resource_ident.to_string());
format_ident!("{}", name)
});
let mut input_fn: ItemFn = match syn::parse2(item) {
Ok(f) => f,
Err(err) => return err.to_compile_error(),
};
if input_fn.sig.asyncness.is_none() {
return syn::Error::new_spanned(
input_fn.sig.fn_token,
"#[authorize] can only be applied to async functions",
)
.to_compile_error();
}
if !has_input_named(&input_fn, "__autumn_state") {
let state_param: syn::FnArg = parse_quote! {
::autumn_web::reexports::axum::extract::State(__autumn_state):
::autumn_web::reexports::axum::extract::State<::autumn_web::AppState>
};
input_fn.sig.inputs.insert(0, state_param);
}
if !has_input_named(&input_fn, "__autumn_session") {
let session_param: syn::FnArg = parse_quote! {
__autumn_session: ::autumn_web::session::Session
};
input_fn.sig.inputs.insert(0, session_param);
}
let action_lit = syn::LitStr::new(&action_str, proc_macro2::Span::call_site());
let original_body = &input_fn.block;
input_fn.block = parse_quote! {
{
::autumn_web::authorization::__check_policy::<#resource_ident>(
&__autumn_state,
&__autumn_session,
#action_lit,
&#from_ident,
).await?;
#original_body
}
};
quote! { #input_fn }
}
fn parse_with_leading_literal(attr: TokenStream) -> syn::Result<AuthorizeArgs> {
use proc_macro2::TokenTree;
let mut iter = attr.into_iter().peekable();
let mut leading_action: Option<String> = None;
if let Some(TokenTree::Literal(lit)) = iter.peek() {
let lit_str = lit.to_string();
if (lit_str.starts_with('"') && lit_str.ends_with('"'))
|| (lit_str.starts_with('\'') && lit_str.ends_with('\''))
{
let s: LitStr = syn::parse2(quote! { #lit })?;
leading_action = Some(s.value());
iter.next();
if let Some(TokenTree::Punct(p)) = iter.peek()
&& p.as_char() == ','
{
iter.next();
}
}
}
let rest: TokenStream = iter.collect();
let mut parsed = if rest.is_empty() {
AuthorizeArgs::default()
} else {
parse_authorize_args(rest)?
};
if let Some(action) = leading_action {
parsed.action = Some(action);
}
Ok(parsed)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lexes a fixed test input into a token stream.
    fn tokens(src: &str) -> TokenStream {
        src.parse().expect("test input must lex")
    }

    #[test]
    fn parses_action_and_resource() {
        let args = parse_with_leading_literal(tokens(r#""update", resource = Post"#)).unwrap();
        assert_eq!(args.action.as_deref(), Some("update"));
        assert_eq!(args.resource.unwrap().to_string(), "Post");
    }

    #[test]
    fn parses_with_explicit_from() {
        let args =
            parse_with_leading_literal(tokens(r#""delete", resource = Post, from = the_post"#))
                .unwrap();
        assert_eq!(args.action.as_deref(), Some("delete"));
        assert_eq!(args.from.unwrap().to_string(), "the_post");
    }

    /// Parsing alone does not reject a missing action; `authorize_macro` does.
    #[test]
    fn leaves_action_unset_without_leading_literal() {
        let args = parse_with_leading_literal(tokens("resource = Post")).unwrap();
        assert!(args.action.is_none());
    }

    #[test]
    fn parses_bare_ident_action() {
        let args = parse_with_leading_literal(tokens("update, resource = Post")).unwrap();
        assert_eq!(args.action.as_deref(), Some("update"));
        assert_eq!(args.resource.unwrap().to_string(), "Post");
    }

    #[test]
    fn parses_list_style_action_and_string_from() {
        let args = parse_with_leading_literal(tokens(
            r#"action("read"), resource = Post, from = "the_post""#,
        ))
        .unwrap();
        assert_eq!(args.action.as_deref(), Some("read"));
        assert_eq!(args.from.unwrap().to_string(), "the_post");
    }

    #[test]
    fn rejects_unknown_key() {
        let result = parse_with_leading_literal(tokens(r#""update", resource = Post, bogus = x"#));
        assert!(result.is_err());
    }

    #[test]
    fn macro_reports_missing_action() {
        let out = authorize_macro(tokens("resource = Post"), tokens("async fn handler() {}"))
            .to_string();
        assert!(out.contains("compile_error"));
        assert!(out.contains("requires an action"));
    }

    #[test]
    fn macro_rejects_non_async_fn() {
        let out = authorize_macro(
            tokens(r#""update", resource = Post"#),
            tokens("fn handler() {}"),
        )
        .to_string();
        assert!(out.contains("compile_error"));
        assert!(out.contains("can only be applied to async functions"));
    }

    #[test]
    fn macro_wraps_body_with_policy_check() {
        let out = authorize_macro(
            tokens(r#""read", resource = Post"#),
            tokens("async fn show(post: Post) -> Result<(), Error> { Ok(()) }"),
        )
        .to_string();
        assert!(!out.contains("compile_error"));
        assert!(out.contains("__check_policy"));
        assert!(out.contains("__autumn_session"));
        assert!(out.contains("\"read\""));
    }

    #[test]
    fn snake_case_handles_pascal_case() {
        assert_eq!(snake_case("Post"), "post");
        assert_eq!(snake_case("BlogPost"), "blog_post");
        assert_eq!(snake_case("HTTPRequest"), "h_t_t_p_request");
    }
}