Skip to main content

permkit_auth_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{
3    format_ident,
4    quote,
5};
6use syn::punctuated::Punctuated;
7use syn::{
8    Expr,
9    ExprAssign,
10    ExprPath,
11    FnArg,
12    Ident,
13    ItemFn,
14    Pat,
15    PatType,
16    Path,
17    Token,
18    Type,
19    TypePath,
20    parse_macro_input,
21    parse_quote,
22};
23
24#[proc_macro_attribute]
25pub fn permissions(args: TokenStream, input: TokenStream) -> TokenStream {
26    let exprs = parse_macro_input!(args with Punctuated::<Expr, Token![,]>::parse_terminated);
27    let func = parse_macro_input!(input as ItemFn);
28
29    match expand_permissions_impl(func, &exprs) {
30        Ok(expanded) => quote!(#expanded).into(),
31        Err(error) => error.to_compile_error().into(),
32    }
33}
34
35#[derive(Default)]
36struct PermissionArgs {
37    context: Option<Expr>,
38    error: Option<Expr>,
39    permissions: Vec<Expr>,
40}
41
42fn expand_permissions_impl(
43    mut func: ItemFn,
44    exprs: &Punctuated<Expr, Token![,]>,
45) -> syn::Result<ItemFn> {
46    let args = parse_args(exprs)?;
47    let context = if let Some(context) = args.context {
48        context
49    } else {
50        let db_ident = ensure_typed_arg(
51            &mut func,
52            &syn::parse_quote!(crate::database::Database),
53            "db",
54        )?;
55        parse_quote!(#db_ident)
56    };
57    let denied_error = args.error.map_or_else(
58        || quote!(::permkit::PermissionDenied::permission_denied()),
59        |error| quote!(::core::convert::Into::into(#error)),
60    );
61
62    let checks = args.permissions.iter().map(|permission| {
63        quote! {
64            if !::permkit::HasPermission::has_permission(&(#permission), &(#context)).await? {
65                return Err(#denied_error);
66            }
67        }
68    });
69
70    let body = func.block;
71    func.block = Box::new(syn::parse_quote!({
72        #(#checks)*
73        #body
74    }));
75
76    Ok(func)
77}
78
79fn ensure_typed_arg(func: &mut ItemFn, ty: &Type, fallback: &str) -> syn::Result<Ident> {
80    let Type::Path(TypePath { path: expected, .. }) = ty else {
81        return Err(syn::Error::new_spanned(ty, "expected a path type"));
82    };
83
84    let expected_last = expected.segments.last().map(|segment| &segment.ident);
85    let matches_expected = |path: &Path| {
86        path == expected || path.segments.last().map(|segment| &segment.ident) == expected_last
87    };
88
89    if let Some(ident) = func.sig.inputs.iter().find_map(|arg| match arg {
90        FnArg::Typed(PatType { pat, ty, .. }) => {
91            let (Pat::Ident(pat), Type::Path(TypePath { path, .. })) = (pat.as_ref(), ty.as_ref())
92            else {
93                return None;
94            };
95
96            matches_expected(path).then_some(pat.ident.clone())
97        }
98        _ => None,
99    }) {
100        return Ok(ident);
101    }
102
103    let ident = format_ident!("{fallback}");
104    func.sig.inputs.insert(0, parse_quote! { #ident: #ty });
105    Ok(ident)
106}
107
108fn parse_args(exprs: &Punctuated<Expr, Token![,]>) -> syn::Result<PermissionArgs> {
109    let mut args = PermissionArgs::default();
110
111    for expr in exprs {
112        match expr {
113            Expr::Assign(assign) if is_assignment_to(assign, "context") => {
114                if args.context.replace((*assign.right).clone()).is_some() {
115                    return Err(syn::Error::new_spanned(expr, "duplicate `context = ...`"));
116                }
117            }
118            Expr::Assign(assign) if is_assignment_to(assign, "error") => {
119                if args.error.replace((*assign.right).clone()).is_some() {
120                    return Err(syn::Error::new_spanned(expr, "duplicate `error = ...`"));
121                }
122            }
123            Expr::Assign(assign) => {
124                return Err(syn::Error::new_spanned(
125                    assign,
126                    "unsupported assignment in `#[permissions(...)]`",
127                ));
128            }
129            _ => args.permissions.push(expr.clone()),
130        }
131    }
132
133    Ok(args)
134}
135
136fn is_assignment_to(assign: &ExprAssign, ident: &str) -> bool {
137    let Expr::Path(ExprPath { path, .. }) = assign.left.as_ref() else {
138        return false;
139    };
140
141    path.is_ident(ident)
142}
143
#[cfg(test)]
mod tests {
    use quote::quote;
    use syn::parse::Parser as _;
    use syn::punctuated::Punctuated;
    use syn::{
        Expr,
        ItemFn,
        Token,
    };

    /// Parses attribute arguments and the annotated function the same way the
    /// attribute entry point does, then runs the expansion.
    fn expand_permissions(
        args: proc_macro2::TokenStream,
        input: proc_macro2::TokenStream,
    ) -> ItemFn {
        let exprs = Punctuated::<Expr, Token![,]>::parse_terminated
            .parse2(args)
            .expect("attribute args should parse");
        let func = syn::parse2::<ItemFn>(input).expect("function should parse");
        super::expand_permissions_impl(func, &exprs).expect("failed to expand")
    }

    /// Renders the body block of `func` as a token string.
    fn body_string(func: &ItemFn) -> String {
        let block = &func.block;
        quote!(#block).to_string()
    }

    #[test]
    fn inserts_permission_checks_before_handler_body() {
        let args = quote! {
            Permission::Read,
            context = context,
            error = Error::Forbidden
        };
        let handler = quote! {
            async fn sample(context: Context) -> Result<(), Error> {
                Ok(())
            }
        };

        let body = body_string(&expand_permissions(args, handler));

        for needle in [
            "HasPermission :: has_permission",
            "Permission :: Read",
            "Error :: Forbidden",
            "Ok (())",
        ] {
            assert!(body.contains(needle), "missing `{needle}` in: {body}");
        }
    }

    #[test]
    fn infers_database_context_and_backend_error() {
        let handler = quote! {
            async fn sample() -> Result<(), Error> {
                Ok(())
            }
        };

        let expanded = expand_permissions(quote! { Permission::Read }, handler);
        let function_tokens = quote!(#expanded).to_string();
        let body = body_string(&expanded);

        assert!(function_tokens.contains("db : crate :: database :: Database"));
        for needle in ["Permission :: Read", "PermissionDenied :: permission_denied"] {
            assert!(body.contains(needle), "missing `{needle}` in: {body}");
        }
    }
}