permkit_auth_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::{
3 format_ident,
4 quote,
5};
6use syn::punctuated::Punctuated;
7use syn::{
8 Expr,
9 ExprAssign,
10 ExprPath,
11 FnArg,
12 Ident,
13 ItemFn,
14 Pat,
15 PatType,
16 Path,
17 Token,
18 Type,
19 TypePath,
20 parse_macro_input,
21 parse_quote,
22};
23
24#[proc_macro_attribute]
25pub fn permissions(args: TokenStream, input: TokenStream) -> TokenStream {
26 let exprs = parse_macro_input!(args with Punctuated::<Expr, Token![,]>::parse_terminated);
27 let func = parse_macro_input!(input as ItemFn);
28
29 match expand_permissions_impl(func, &exprs) {
30 Ok(expanded) => quote!(#expanded).into(),
31 Err(error) => error.to_compile_error().into(),
32 }
33}
34
35#[derive(Default)]
36struct PermissionArgs {
37 context: Option<Expr>,
38 error: Option<Expr>,
39 permissions: Vec<Expr>,
40}
41
42fn expand_permissions_impl(
43 mut func: ItemFn,
44 exprs: &Punctuated<Expr, Token![,]>,
45) -> syn::Result<ItemFn> {
46 let args = parse_args(exprs)?;
47 let context = if let Some(context) = args.context {
48 context
49 } else {
50 let db_ident = ensure_typed_arg(
51 &mut func,
52 &syn::parse_quote!(crate::database::Database),
53 "db",
54 )?;
55 parse_quote!(#db_ident)
56 };
57 let denied_error = args.error.map_or_else(
58 || quote!(::permkit::PermissionDenied::permission_denied()),
59 |error| quote!(::core::convert::Into::into(#error)),
60 );
61
62 let checks = args.permissions.iter().map(|permission| {
63 quote! {
64 if !::permkit::HasPermission::has_permission(&(#permission), &(#context)).await? {
65 return Err(#denied_error);
66 }
67 }
68 });
69
70 let body = func.block;
71 func.block = Box::new(syn::parse_quote!({
72 #(#checks)*
73 #body
74 }));
75
76 Ok(func)
77}
78
79fn ensure_typed_arg(func: &mut ItemFn, ty: &Type, fallback: &str) -> syn::Result<Ident> {
80 let Type::Path(TypePath { path: expected, .. }) = ty else {
81 return Err(syn::Error::new_spanned(ty, "expected a path type"));
82 };
83
84 let expected_last = expected.segments.last().map(|segment| &segment.ident);
85 let matches_expected = |path: &Path| {
86 path == expected || path.segments.last().map(|segment| &segment.ident) == expected_last
87 };
88
89 if let Some(ident) = func.sig.inputs.iter().find_map(|arg| match arg {
90 FnArg::Typed(PatType { pat, ty, .. }) => {
91 let (Pat::Ident(pat), Type::Path(TypePath { path, .. })) = (pat.as_ref(), ty.as_ref())
92 else {
93 return None;
94 };
95
96 matches_expected(path).then_some(pat.ident.clone())
97 }
98 _ => None,
99 }) {
100 return Ok(ident);
101 }
102
103 let ident = format_ident!("{fallback}");
104 func.sig.inputs.insert(0, parse_quote! { #ident: #ty });
105 Ok(ident)
106}
107
108fn parse_args(exprs: &Punctuated<Expr, Token![,]>) -> syn::Result<PermissionArgs> {
109 let mut args = PermissionArgs::default();
110
111 for expr in exprs {
112 match expr {
113 Expr::Assign(assign) if is_assignment_to(assign, "context") => {
114 if args.context.replace((*assign.right).clone()).is_some() {
115 return Err(syn::Error::new_spanned(expr, "duplicate `context = ...`"));
116 }
117 }
118 Expr::Assign(assign) if is_assignment_to(assign, "error") => {
119 if args.error.replace((*assign.right).clone()).is_some() {
120 return Err(syn::Error::new_spanned(expr, "duplicate `error = ...`"));
121 }
122 }
123 Expr::Assign(assign) => {
124 return Err(syn::Error::new_spanned(
125 assign,
126 "unsupported assignment in `#[permissions(...)]`",
127 ));
128 }
129 _ => args.permissions.push(expr.clone()),
130 }
131 }
132
133 Ok(args)
134}
135
136fn is_assignment_to(assign: &ExprAssign, ident: &str) -> bool {
137 let Expr::Path(ExprPath { path, .. }) = assign.left.as_ref() else {
138 return false;
139 };
140
141 path.is_ident(ident)
142}
143
#[cfg(test)]
mod tests {
    use quote::quote;
    use syn::parse::Parser as _;
    use syn::punctuated::Punctuated;
    use syn::{Expr, ItemFn, Token};

    /// Parses attribute arguments and a function item, then runs the expansion. Panics with
    /// a descriptive message if any step fails.
    fn expand_permissions(
        args: proc_macro2::TokenStream,
        input: proc_macro2::TokenStream,
    ) -> ItemFn {
        let exprs = Punctuated::<Expr, Token![,]>::parse_terminated
            .parse2(args)
            .expect("attribute args should parse");
        let func: ItemFn = syn::parse2(input).expect("function should parse");
        super::expand_permissions_impl(func, &exprs).expect("failed to expand")
    }

    /// Renders the (rewritten) body block of `func` as a token string.
    fn body_tokens(func: &ItemFn) -> String {
        let block = &func.block;
        quote!(#block).to_string()
    }

    #[test]
    fn inserts_permission_checks_before_handler_body() {
        let expanded = expand_permissions(
            quote!(Permission::Read, context = context, error = Error::Forbidden),
            quote! {
                async fn sample(context: Context) -> Result<(), Error> {
                    Ok(())
                }
            },
        );
        let body = body_tokens(&expanded);

        for expected in [
            "HasPermission :: has_permission",
            "Permission :: Read",
            "Error :: Forbidden",
            "Ok (())",
        ] {
            assert!(body.contains(expected));
        }
    }

    #[test]
    fn infers_database_context_and_backend_error() {
        let expanded = expand_permissions(
            quote!(Permission::Read),
            quote! {
                async fn sample() -> Result<(), Error> {
                    Ok(())
                }
            },
        );
        let whole_fn = quote!(#expanded).to_string();
        let body = body_tokens(&expanded);

        assert!(whole_fn.contains("db : crate :: database :: Database"));
        assert!(body.contains("Permission :: Read"));
        assert!(body.contains("PermissionDenied :: permission_denied"));
    }
}
212}