// netdb_auth_macro_derive/lib.rs
use proc_macro::TokenStream;
use quote::quote;
use syn::ext::IdentExt;
use syn::parse::{Parse, ParseStream};
use syn::{parse_macro_input, parse_quote, FnArg, ItemFn, LitInt, LitStr, Token};
5
6struct ScopeArgs {
7 category: LitStr,
8 scope_code: Option<LitInt>,
9}
10
11impl Parse for ScopeArgs {
12 fn parse(input: ParseStream) -> syn::Result<Self> {
13 let category = input.parse()?;
14 let scope_code = if input.peek(Token![,]) {
15 input.parse::<Token![,]>()?;
16 Some(input.parse()?)
17 } else {
18 None
19 };
20
21 Ok(ScopeArgs { category, scope_code })
22 }
23}
24
25#[proc_macro_attribute]
26pub fn has_scope(attr: TokenStream, item: TokenStream) -> TokenStream {
27 let ScopeArgs { category, scope_code } = parse_macro_input!(attr as ScopeArgs);
28 let scope_formatted: Option<u32> = scope_code.map(|lit| lit.base10_parse().unwrap());
29
30 let category = category.value();
31 let input = parse_macro_input!(item as ItemFn);
32 let fn_name = &input.sig.ident;
33 let block = &input.block;
34 let inputs = &input.sig.inputs;
35 let output = &input.sig.output;
36 let attrs = &input.attrs;
37 let is_async = input.sig.asyncness.is_some();
38
39 let fn_string = fn_name.to_string();
40 let mut guard_str = fn_string.clone();
41 if let Some(first) = guard_str.get_mut(0..1) {
42 first.make_ascii_uppercase();
43 }
44 guard_str.push_str("ScopeGuard");
45 let guard_name = syn::Ident::new(&guard_str, fn_name.span());
46
47 let scope_expr = if let Some(code) = scope_formatted {
48 quote! { Some(#code) }
49 } else {
50 quote! { None }
51 };
52
53 let guard_def = quote! {
54 struct #guard_name;
55
56 #[rocket::async_trait]
57 impl<'r> ::rocket::request::FromRequest<'r> for #guard_name {
58 type Error = ::rocket::http::Status;
59 async fn from_request(req: &'r ::rocket::Request<'_>)
60 -> ::rocket::request::Outcome<Self, Self::Error>
61 {
62 let user = match req.guard::<User>().await {
64 ::rocket::request::Outcome::Success(u) => u,
65 ::rocket::request::Outcome::Error((status, _)) => {
66 return ::rocket::request::Outcome::Error((
69 status,
70 ::rocket::http::Status::Unauthorized,
71 ));
72 }
73 ::rocket::request::Outcome::Forward(status) => {
74 return ::rocket::request::Outcome::Forward(status);
75 }
76 };
77
78 if !user.has_scope(#category, #scope_expr) {
79 return ::rocket::request::Outcome::Error((
80 ::rocket::http::Status::Forbidden,
81 ::rocket::http::Status::Forbidden,
82 ));
83 }
84
85 ::rocket::request::Outcome::Success(#guard_name)
86 }
87 }
88 };
89
90 let fn_inputs = quote! { _guard: #guard_name, #inputs };
92
93 let expanded = if is_async {
94 quote! {
95 #guard_def
96 #(#attrs)*
97 async fn #fn_name(#fn_inputs) #output {
98 #block
99 }
100 }
101 } else {
102 quote! {
103 #guard_def
104 #(#attrs)*
105 fn #fn_name(#fn_inputs) #output {
106 #block
107 }
108 }
109 };
110
111 TokenStream::from(expanded)
112}