// netdb_auth_macro_derive/lib.rs

use proc_macro::TokenStream;
use quote::quote;
use syn::ext::IdentExt;
use syn::parse::{Parse, ParseStream};
use syn::{parse_macro_input, FnArg, ItemFn, LitInt, LitStr, Token};

6struct ScopeArgs {
7    category: LitStr,
8    scope_code: Option<LitInt>,
9}
10
11impl Parse for ScopeArgs {
12    fn parse(input: ParseStream) -> syn::Result<Self> {
13        let category = input.parse()?;
14        let scope_code = if input.peek(Token![,]) {
15            input.parse::<Token![,]>()?;
16            Some(input.parse()?)
17        } else {
18            None
19        };
20
21        Ok(ScopeArgs { category, scope_code })
22    }
23}
24
25#[proc_macro_attribute]
26pub fn has_scope(attr: TokenStream, item: TokenStream) -> TokenStream {
27    let ScopeArgs { category, scope_code } = parse_macro_input!(attr as ScopeArgs);
28    let scope_formatted: Option<u32> = scope_code.map(|lit| lit.base10_parse().unwrap());
29
30    let category = category.value();
31    let input = parse_macro_input!(item as ItemFn);
32    let fn_name = &input.sig.ident;
33    let block = &input.block;
34    let inputs = &input.sig.inputs;
35    let output = &input.sig.output;
36    let attrs = &input.attrs;
37    let is_async = input.sig.asyncness.is_some();
38
39    let fn_string = fn_name.to_string();
40    let mut guard_str = fn_string.clone();
41    if let Some(first) = guard_str.get_mut(0..1) {
42        first.make_ascii_uppercase();
43    }
44    guard_str.push_str("ScopeGuard");
45    let guard_name = syn::Ident::new(&guard_str, fn_name.span());
46
47    let scope_expr = if let Some(code) = scope_formatted {
48        quote! { Some(#code) }
49    } else {
50        quote! { None }
51    };
52
53    let guard_def = quote! {
54        struct #guard_name;
55
56        #[rocket::async_trait]
57        impl<'r> ::rocket::request::FromRequest<'r> for #guard_name {
58            type Error = ::rocket::http::Status;
59            async fn from_request(req: &'r ::rocket::Request<'_>)
60                -> ::rocket::request::Outcome<Self, Self::Error>
61            {
62                // delegate to the `User` guard and propagate any failures/forwards
63                let user = match req.guard::<User>().await {
64                    ::rocket::request::Outcome::Success(u) => u,
65                    ::rocket::request::Outcome::Error((status, _)) => {
66                        // maintain the same status returned by the User guard, but
67                        // surface a generic unauthorized error for this guard.
68                        return ::rocket::request::Outcome::Error((
69                            status,
70                            ::rocket::http::Status::Unauthorized,
71                        ));
72                    }
73                    ::rocket::request::Outcome::Forward(status) => {
74                        return ::rocket::request::Outcome::Forward(status);
75                    }
76                };
77
78                if !user.has_scope(#category, #scope_expr) {
79                    return ::rocket::request::Outcome::Error((
80                        ::rocket::http::Status::Forbidden,
81                        ::rocket::http::Status::Forbidden,
82                    ));
83                }
84
85                ::rocket::request::Outcome::Success(#guard_name)
86            }
87        }
88    };
89
90    // prepend the guard parameter to the original inputs
91    let fn_inputs = quote! { _guard: #guard_name, #inputs };
92
93    let expanded = if is_async {
94        quote! {
95            #guard_def
96            #(#attrs)*
97            async fn #fn_name(#fn_inputs) #output {
98                #block
99            }
100        }
101    } else {
102        quote! {
103            #guard_def
104            #(#attrs)*
105            fn #fn_name(#fn_inputs) #output {
106                #block
107            }
108        }
109    };
110
111    TokenStream::from(expanded)
112}