rocket_jwt_authorization/
lib.rs

1/*!
2# `jwt-authorization` Request Guard for Rocket Framework
3
4This crate provides a procedural macro to create request guards used for authorization.
5
6See `examples`.
7*/
8
9mod panic;
10
11use proc_macro::TokenStream;
12use quote::quote;
13use syn::{
14    parse::{Parse, ParseStream},
15    DeriveInput, Expr, Lit, Meta, Path, Token,
16};
17
/// Accepted forms of the `#[jwt(...)]` attribute.
///
/// Passed to `panic::attribute_incorrect_format` whenever the attribute is malformed.
/// Raw strings are used so the embedded quotes need no escaping.
const CORRECT_USAGE_FOR_JWT_ATTRIBUTE: &[&str] = &[
    r#"#[jwt("key")]"#,
    r#"#[jwt(PATH)]"#,
    r#"#[jwt("key", sha2::Sha512)]"#,
    r#"#[jwt(PATH, sha2::Sha512)]"#,
    r#"#[jwt(PATH, sha2::Sha512, Header)]"#,
    r#"#[jwt(PATH, sha2::Sha512, Cookie("access_token"), Header, Query(PATH))]"#,
];
26
27enum Source {
28    Header,
29    Cookie(Expr),
30    Query(Expr),
31    // TODO currently it's hard to be implemented, just ignore it
32    #[allow(dead_code)]
33    Body(Expr),
34}
35
36impl Source {
37    #[inline]
38    fn as_str(&self) -> &'static str {
39        match self {
40            Source::Header => "header",
41            Source::Cookie(_) => "cookie",
42            Source::Query(_) => "query",
43            Source::Body(_) => "body",
44        }
45    }
46
47    #[inline]
48    fn from<S: AsRef<str>>(name: S, expr: Expr) -> Option<Source> {
49        let name = name.as_ref();
50
51        match name {
52            "query" => Some(Source::Query(expr)),
53            "cookie" => Some(Source::Cookie(expr)),
54            "body" => unimplemented!(),
55            _ => None,
56        }
57    }
58
59    #[inline]
60    fn search<S: AsRef<str>>(sources: &[Source], name: S) -> Option<&Source> {
61        let name = name.as_ref();
62
63        sources.iter().find(|source| source.as_str() == name)
64    }
65
66    #[inline]
67    fn search_cookie_get_expr(sources: &[Source]) -> Option<&Expr> {
68        for source in sources.iter() {
69            if let Source::Cookie(expr) = source {
70                return Some(expr);
71            }
72        }
73
74        None
75    }
76}
77
78struct Parser2 {
79    expr: Expr,
80}
81
82impl Parse for Parser2 {
83    #[inline]
84    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
85        match input.parse::<Expr>() {
86            Ok(expr) => {
87                let pass = match &expr {
88                    Expr::Path(_) => true,
89                    Expr::Lit(lit) => matches!(lit.lit, Lit::Str(_)),
90                    _ => false,
91                };
92
93                if !pass {
94                    panic::attribute_incorrect_format("jwt", CORRECT_USAGE_FOR_JWT_ATTRIBUTE);
95                }
96
97                Ok(Parser2 {
98                    expr,
99                })
100            },
101            _ => panic::attribute_incorrect_format("jwt", CORRECT_USAGE_FOR_JWT_ATTRIBUTE),
102        }
103    }
104}
105
106struct Parser {
107    key:       Expr,
108    algorithm: Path,
109    sources:   Vec<Source>,
110}
111
112impl Parse for Parser {
113    #[inline]
114    fn parse(input: ParseStream) -> Result<Self, syn::Error> {
115        let key = input.parse::<Parser2>()?.expr;
116
117        let (algorithm, sources): (Path, Vec<Source>) = {
118            if input.is_empty() {
119                (syn::parse2(quote!(::sha2::Sha256))?, vec![Source::Header])
120            } else {
121                input.parse::<Token!(,)>()?;
122
123                match input.parse::<Path>() {
124                    Ok(p) => {
125                        let mut sources = Vec::new();
126
127                        while !input.is_empty() {
128                            input.parse::<Token!(,)>()?;
129
130                            let m = input.parse::<Meta>()?;
131
132                            let attr_name = match m.path().get_ident() {
133                                Some(ident) => ident.to_string().to_ascii_lowercase(),
134                                None => {
135                                    panic::attribute_incorrect_format(
136                                        "jwt",
137                                        CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
138                                    );
139                                },
140                            };
141
142                            if Source::search(&sources, attr_name.as_str()).is_some() {
143                                panic::duplicated_source(attr_name.as_str());
144                            }
145
146                            match m {
147                                Meta::Path(_) => {
148                                    if attr_name.eq_ignore_ascii_case("header") {
149                                        sources.push(Source::Header);
150                                    } else {
151                                        panic::attribute_incorrect_format(
152                                            "jwt",
153                                            CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
154                                        );
155                                    }
156                                },
157                                Meta::NameValue(v) => {
158                                    let expr = v.value;
159
160                                    let pass = match &expr {
161                                        Expr::Path(_) => true,
162                                        Expr::Lit(lit) => matches!(lit.lit, Lit::Str(_)),
163                                        _ => false,
164                                    };
165
166                                    if !pass {
167                                        panic::attribute_incorrect_format(
168                                            "jwt",
169                                            CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
170                                        );
171                                    }
172
173                                    match Source::from(attr_name, expr) {
174                                        Some(source) => sources.push(source),
175                                        None => panic::attribute_incorrect_format(
176                                            "jwt",
177                                            CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
178                                        ),
179                                    }
180                                },
181                                Meta::List(list) => {
182                                    let parsed: Parser2 = list.parse_args()?;
183
184                                    let expr = parsed.expr;
185
186                                    match Source::from(attr_name, expr) {
187                                        Some(source) => sources.push(source),
188                                        None => panic::attribute_incorrect_format(
189                                            "jwt",
190                                            CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
191                                        ),
192                                    }
193                                },
194                            }
195                        }
196
197                        if sources.is_empty() {
198                            sources.push(Source::Header);
199                        }
200
201                        (p, sources)
202                    },
203                    Err(_) => {
204                        panic::attribute_incorrect_format("jwt", CORRECT_USAGE_FOR_JWT_ATTRIBUTE)
205                    },
206                }
207            }
208        };
209
210        Ok(Parser {
211            key,
212            algorithm,
213            sources,
214        })
215    }
216}
217
218fn derive_input_handler(ast: DeriveInput) -> TokenStream {
219    for attr in ast.attrs {
220        if attr.path().is_ident("jwt") {
221            match attr.meta {
222                Meta::List(list) => {
223                    let parsed: Parser = list.parse_args().unwrap();
224
225                    let algorithm = parsed.algorithm;
226                    let key = parsed.key;
227                    let sources = parsed.sources;
228
229                    let name = &ast.ident;
230                    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
231
232                    let get_jwt_hasher = quote! {
233                        #[inline]
234                        pub fn get_jwt_hasher() -> &'static hmac::Hmac<#algorithm> {
235                            static START: ::std::sync::Once = ::std::sync::Once::new();
236                            static mut HMAC: Option<hmac::Hmac<#algorithm>> = None;
237
238                            unsafe {
239                                START.call_once(|| {
240                                    use ::hmac::Hmac;
241                                    use ::hmac::digest::KeyInit;
242
243                                    HMAC = Some(Hmac::new_from_slice(unsafe {#key}.as_ref()).unwrap())
244                                });
245
246                                HMAC.as_ref().unwrap()
247                            }
248                        }
249                    };
250
251                    let get_jwt_token = quote! {
252                        #[inline]
253                        pub fn get_jwt_token(&self) -> String {
254                            use ::jwt::SignWithKey;
255
256                            let hasher = Self::get_jwt_hasher();
257
258                            self.sign_with_key(hasher).unwrap()
259                        }
260                    };
261
262                    let verify_jwt_token = quote! {
263                        #[inline]
264                        pub fn verify_jwt_token<S: AsRef<str>>(token: S) -> Result<Self, ::jwt::Error> {
265                            use ::jwt::VerifyWithKey;
266
267                            let token = token.as_ref();
268
269                            let hasher = Self::get_jwt_hasher();
270
271                            token.verify_with_key(hasher)
272                        }
273                    };
274
275                    let (set_cookie, set_cookie_insecure, remove_cookie) = if let Some(expr) =
276                        Source::search_cookie_get_expr(&sources)
277                    {
278                        let set_cookie = quote! {
279                            #[inline]
280                            pub fn set_cookie(&self, cookies: &::rocket::http::CookieJar) {
281                                let mut cookie = ::rocket::http::Cookie::new(unsafe {#expr}, self.get_jwt_token());
282
283                                cookie.set_secure(true);
284
285                                cookies.add(cookie);
286                            }
287                        };
288
289                        let set_cookie_insecure = quote! {
290                            #[inline]
291                            pub fn set_cookie_insecure(&self, cookies: &::rocket::http::CookieJar) {
292                                let mut cookie = ::rocket::http::Cookie::new(unsafe {#expr}, self.get_jwt_token());
293
294                                cookie.set_same_site(::rocket::http::SameSite::Strict);
295
296                                cookies.add(cookie);
297                            }
298                        };
299
300                        let remove_cookie = quote! {
301                            #[inline]
302                            pub fn remove_cookie(cookies: &::rocket::http::CookieJar) {
303                                cookies.remove(::rocket::http::Cookie::named(unsafe {#expr}));
304                            }
305                        };
306
307                        (set_cookie, set_cookie_insecure, remove_cookie)
308                    } else {
309                        (quote!(), quote!(), quote!())
310                    };
311
312                    let (from_request, from_request_cache) = {
313                        let mut source_streams = Vec::with_capacity(sources.len());
314
315                        for source in sources.iter() {
316                            let source_stream = match source {
317                                Source::Header => {
318                                    quote! {
319                                        else if let Some(authorization) = request.headers().get("authorization").next() {
320                                            if let Some(token) = authorization.strip_prefix("Bearer ") {
321                                                match #name::verify_jwt_token(token) {
322                                                    Ok(o) => Some(o),
323                                                    Err(_) => None
324                                                }
325                                            } else {
326                                                None
327                                            }
328                                        }
329                                    }
330                                },
331                                Source::Cookie(expr) => {
332                                    quote! {
333                                        else if let Some(token) = request.cookies().get(unsafe {#expr}) {
334                                            match #name::verify_jwt_token(token.value()) {
335                                                Ok(o) => Some(o),
336                                                Err(_) => {
337                                                    #name::remove_cookie(&request.cookies());
338
339                                                    None
340                                                }
341                                            }
342                                        }
343                                    }
344                                },
345                                Source::Query(expr) => {
346                                    quote! {
347                                        else if let Some(token) = request.query_value(unsafe {#expr}) {
348                                            let token: &str = token.unwrap();
349
350                                            match #name::verify_jwt_token(token) {
351                                                Ok(o) => Some(o),
352                                                Err(_) => None
353                                            }
354                                        }
355                                    }
356                                },
357                                _ => unimplemented!(),
358                            };
359
360                            source_streams.push(source_stream);
361                        }
362
363                        let from_request_body = quote! {
364                            if false {
365                                None
366                            }
367                            #(
368                                #source_streams
369                            )*
370                            else {
371                                None
372                            }
373                        };
374
375                        let from_request = quote! {
376                            #[rocket::async_trait]
377                            impl<'r> ::rocket::request::FromRequest<'r> for #name {
378                                type Error = ();
379
380                                async fn from_request(request: &'r ::rocket::request::Request<'_>) -> ::rocket::request::Outcome<Self, Self::Error> {
381                                    match #from_request_body {
382                                        Some(o) => ::rocket::outcome::Outcome::Success(o),
383                                        None => ::rocket::outcome::Outcome::Forward(::rocket::http::Status::Unauthorized),
384                                    }
385                                }
386                            }
387                        };
388
389                        let from_request_cache = quote! {
390                            #[rocket::async_trait]
391                            impl<'r> ::rocket::request::FromRequest<'r> for &'r #name {
392                                type Error = ();
393
394                                async fn from_request(request: &'r ::rocket::request::Request<'_>) -> ::rocket::request::Outcome<Self, Self::Error> {
395                                    let cache = request.local_cache(|| {
396                                        #from_request_body
397                                    });
398
399                                    match cache.as_ref() {
400                                        Some(o) => ::rocket::outcome::Outcome::Success(o),
401                                        None => ::rocket::outcome::Outcome::Forward(::rocket::http::Status::Unauthorized),
402                                    }
403                                }
404                            }
405                        };
406
407                        (from_request, from_request_cache)
408                    };
409
410                    let jwt_impl = quote! {
411                        impl #impl_generics #name #ty_generics #where_clause {
412                            #get_jwt_hasher
413
414                            #get_jwt_token
415
416                            #verify_jwt_token
417
418                            #set_cookie
419
420                            #set_cookie_insecure
421
422                            #remove_cookie
423                        }
424
425                        #from_request
426
427                        #from_request_cache
428                    };
429
430                    return jwt_impl.into();
431                },
432                _ => panic::attribute_incorrect_format("jwt", CORRECT_USAGE_FOR_JWT_ATTRIBUTE),
433            }
434        }
435    }
436
437    panic::jwt_not_found();
438}
439
440#[proc_macro_derive(JWT, attributes(jwt))]
441pub fn jwt_derive(input: TokenStream) -> TokenStream {
442    derive_input_handler(syn::parse(input).unwrap())
443}