rocket_jwt_authorization/
lib.rs1mod panic;
10
11use proc_macro::TokenStream;
12use quote::quote;
13use syn::{
14 parse::{Parse, ParseStream},
15 DeriveInput, Expr, Lit, Meta, Path, Token,
16};
17
/// Accepted spellings of the `#[jwt(...)]` attribute.
///
/// Passed to `panic::attribute_incorrect_format` whenever the attribute (or one of its
/// arguments) is malformed. Raw string literals keep the embedded quotes readable.
const CORRECT_USAGE_FOR_JWT_ATTRIBUTE: &[&str] = &[
    r#"#[jwt("key")]"#,
    r#"#[jwt(PATH)]"#,
    r#"#[jwt("key", sha2::Sha512)]"#,
    r#"#[jwt(PATH, sha2::Sha512)]"#,
    r#"#[jwt(PATH, sha2::Sha512, Header)]"#,
    r#"#[jwt(PATH, sha2::Sha512, Cookie("access_token"), Header, Query(PATH))]"#,
];
26
27enum Source {
28 Header,
29 Cookie(Expr),
30 Query(Expr),
31 #[allow(dead_code)]
33 Body(Expr),
34}
35
36impl Source {
37 #[inline]
38 fn as_str(&self) -> &'static str {
39 match self {
40 Source::Header => "header",
41 Source::Cookie(_) => "cookie",
42 Source::Query(_) => "query",
43 Source::Body(_) => "body",
44 }
45 }
46
47 #[inline]
48 fn from<S: AsRef<str>>(name: S, expr: Expr) -> Option<Source> {
49 let name = name.as_ref();
50
51 match name {
52 "query" => Some(Source::Query(expr)),
53 "cookie" => Some(Source::Cookie(expr)),
54 "body" => unimplemented!(),
55 _ => None,
56 }
57 }
58
59 #[inline]
60 fn search<S: AsRef<str>>(sources: &[Source], name: S) -> Option<&Source> {
61 let name = name.as_ref();
62
63 sources.iter().find(|source| source.as_str() == name)
64 }
65
66 #[inline]
67 fn search_cookie_get_expr(sources: &[Source]) -> Option<&Expr> {
68 for source in sources.iter() {
69 if let Source::Cookie(expr) = source {
70 return Some(expr);
71 }
72 }
73
74 None
75 }
76}
77
78struct Parser2 {
79 expr: Expr,
80}
81
82impl Parse for Parser2 {
83 #[inline]
84 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
85 match input.parse::<Expr>() {
86 Ok(expr) => {
87 let pass = match &expr {
88 Expr::Path(_) => true,
89 Expr::Lit(lit) => matches!(lit.lit, Lit::Str(_)),
90 _ => false,
91 };
92
93 if !pass {
94 panic::attribute_incorrect_format("jwt", CORRECT_USAGE_FOR_JWT_ATTRIBUTE);
95 }
96
97 Ok(Parser2 {
98 expr,
99 })
100 },
101 _ => panic::attribute_incorrect_format("jwt", CORRECT_USAGE_FOR_JWT_ATTRIBUTE),
102 }
103 }
104}
105
106struct Parser {
107 key: Expr,
108 algorithm: Path,
109 sources: Vec<Source>,
110}
111
112impl Parse for Parser {
113 #[inline]
114 fn parse(input: ParseStream) -> Result<Self, syn::Error> {
115 let key = input.parse::<Parser2>()?.expr;
116
117 let (algorithm, sources): (Path, Vec<Source>) = {
118 if input.is_empty() {
119 (syn::parse2(quote!(::sha2::Sha256))?, vec![Source::Header])
120 } else {
121 input.parse::<Token!(,)>()?;
122
123 match input.parse::<Path>() {
124 Ok(p) => {
125 let mut sources = Vec::new();
126
127 while !input.is_empty() {
128 input.parse::<Token!(,)>()?;
129
130 let m = input.parse::<Meta>()?;
131
132 let attr_name = match m.path().get_ident() {
133 Some(ident) => ident.to_string().to_ascii_lowercase(),
134 None => {
135 panic::attribute_incorrect_format(
136 "jwt",
137 CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
138 );
139 },
140 };
141
142 if Source::search(&sources, attr_name.as_str()).is_some() {
143 panic::duplicated_source(attr_name.as_str());
144 }
145
146 match m {
147 Meta::Path(_) => {
148 if attr_name.eq_ignore_ascii_case("header") {
149 sources.push(Source::Header);
150 } else {
151 panic::attribute_incorrect_format(
152 "jwt",
153 CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
154 );
155 }
156 },
157 Meta::NameValue(v) => {
158 let expr = v.value;
159
160 let pass = match &expr {
161 Expr::Path(_) => true,
162 Expr::Lit(lit) => matches!(lit.lit, Lit::Str(_)),
163 _ => false,
164 };
165
166 if !pass {
167 panic::attribute_incorrect_format(
168 "jwt",
169 CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
170 );
171 }
172
173 match Source::from(attr_name, expr) {
174 Some(source) => sources.push(source),
175 None => panic::attribute_incorrect_format(
176 "jwt",
177 CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
178 ),
179 }
180 },
181 Meta::List(list) => {
182 let parsed: Parser2 = list.parse_args()?;
183
184 let expr = parsed.expr;
185
186 match Source::from(attr_name, expr) {
187 Some(source) => sources.push(source),
188 None => panic::attribute_incorrect_format(
189 "jwt",
190 CORRECT_USAGE_FOR_JWT_ATTRIBUTE,
191 ),
192 }
193 },
194 }
195 }
196
197 if sources.is_empty() {
198 sources.push(Source::Header);
199 }
200
201 (p, sources)
202 },
203 Err(_) => {
204 panic::attribute_incorrect_format("jwt", CORRECT_USAGE_FOR_JWT_ATTRIBUTE)
205 },
206 }
207 }
208 };
209
210 Ok(Parser {
211 key,
212 algorithm,
213 sources,
214 })
215 }
216}
217
218fn derive_input_handler(ast: DeriveInput) -> TokenStream {
219 for attr in ast.attrs {
220 if attr.path().is_ident("jwt") {
221 match attr.meta {
222 Meta::List(list) => {
223 let parsed: Parser = list.parse_args().unwrap();
224
225 let algorithm = parsed.algorithm;
226 let key = parsed.key;
227 let sources = parsed.sources;
228
229 let name = &ast.ident;
230 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
231
232 let get_jwt_hasher = quote! {
233 #[inline]
234 pub fn get_jwt_hasher() -> &'static hmac::Hmac<#algorithm> {
235 static START: ::std::sync::Once = ::std::sync::Once::new();
236 static mut HMAC: Option<hmac::Hmac<#algorithm>> = None;
237
238 unsafe {
239 START.call_once(|| {
240 use ::hmac::Hmac;
241 use ::hmac::digest::KeyInit;
242
243 HMAC = Some(Hmac::new_from_slice(unsafe {#key}.as_ref()).unwrap())
244 });
245
246 HMAC.as_ref().unwrap()
247 }
248 }
249 };
250
251 let get_jwt_token = quote! {
252 #[inline]
253 pub fn get_jwt_token(&self) -> String {
254 use ::jwt::SignWithKey;
255
256 let hasher = Self::get_jwt_hasher();
257
258 self.sign_with_key(hasher).unwrap()
259 }
260 };
261
262 let verify_jwt_token = quote! {
263 #[inline]
264 pub fn verify_jwt_token<S: AsRef<str>>(token: S) -> Result<Self, ::jwt::Error> {
265 use ::jwt::VerifyWithKey;
266
267 let token = token.as_ref();
268
269 let hasher = Self::get_jwt_hasher();
270
271 token.verify_with_key(hasher)
272 }
273 };
274
275 let (set_cookie, set_cookie_insecure, remove_cookie) = if let Some(expr) =
276 Source::search_cookie_get_expr(&sources)
277 {
278 let set_cookie = quote! {
279 #[inline]
280 pub fn set_cookie(&self, cookies: &::rocket::http::CookieJar) {
281 let mut cookie = ::rocket::http::Cookie::new(unsafe {#expr}, self.get_jwt_token());
282
283 cookie.set_secure(true);
284
285 cookies.add(cookie);
286 }
287 };
288
289 let set_cookie_insecure = quote! {
290 #[inline]
291 pub fn set_cookie_insecure(&self, cookies: &::rocket::http::CookieJar) {
292 let mut cookie = ::rocket::http::Cookie::new(unsafe {#expr}, self.get_jwt_token());
293
294 cookie.set_same_site(::rocket::http::SameSite::Strict);
295
296 cookies.add(cookie);
297 }
298 };
299
300 let remove_cookie = quote! {
301 #[inline]
302 pub fn remove_cookie(cookies: &::rocket::http::CookieJar) {
303 cookies.remove(::rocket::http::Cookie::named(unsafe {#expr}));
304 }
305 };
306
307 (set_cookie, set_cookie_insecure, remove_cookie)
308 } else {
309 (quote!(), quote!(), quote!())
310 };
311
312 let (from_request, from_request_cache) = {
313 let mut source_streams = Vec::with_capacity(sources.len());
314
315 for source in sources.iter() {
316 let source_stream = match source {
317 Source::Header => {
318 quote! {
319 else if let Some(authorization) = request.headers().get("authorization").next() {
320 if let Some(token) = authorization.strip_prefix("Bearer ") {
321 match #name::verify_jwt_token(token) {
322 Ok(o) => Some(o),
323 Err(_) => None
324 }
325 } else {
326 None
327 }
328 }
329 }
330 },
331 Source::Cookie(expr) => {
332 quote! {
333 else if let Some(token) = request.cookies().get(unsafe {#expr}) {
334 match #name::verify_jwt_token(token.value()) {
335 Ok(o) => Some(o),
336 Err(_) => {
337 #name::remove_cookie(&request.cookies());
338
339 None
340 }
341 }
342 }
343 }
344 },
345 Source::Query(expr) => {
346 quote! {
347 else if let Some(token) = request.query_value(unsafe {#expr}) {
348 let token: &str = token.unwrap();
349
350 match #name::verify_jwt_token(token) {
351 Ok(o) => Some(o),
352 Err(_) => None
353 }
354 }
355 }
356 },
357 _ => unimplemented!(),
358 };
359
360 source_streams.push(source_stream);
361 }
362
363 let from_request_body = quote! {
364 if false {
365 None
366 }
367 #(
368 #source_streams
369 )*
370 else {
371 None
372 }
373 };
374
375 let from_request = quote! {
376 #[rocket::async_trait]
377 impl<'r> ::rocket::request::FromRequest<'r> for #name {
378 type Error = ();
379
380 async fn from_request(request: &'r ::rocket::request::Request<'_>) -> ::rocket::request::Outcome<Self, Self::Error> {
381 match #from_request_body {
382 Some(o) => ::rocket::outcome::Outcome::Success(o),
383 None => ::rocket::outcome::Outcome::Forward(::rocket::http::Status::Unauthorized),
384 }
385 }
386 }
387 };
388
389 let from_request_cache = quote! {
390 #[rocket::async_trait]
391 impl<'r> ::rocket::request::FromRequest<'r> for &'r #name {
392 type Error = ();
393
394 async fn from_request(request: &'r ::rocket::request::Request<'_>) -> ::rocket::request::Outcome<Self, Self::Error> {
395 let cache = request.local_cache(|| {
396 #from_request_body
397 });
398
399 match cache.as_ref() {
400 Some(o) => ::rocket::outcome::Outcome::Success(o),
401 None => ::rocket::outcome::Outcome::Forward(::rocket::http::Status::Unauthorized),
402 }
403 }
404 }
405 };
406
407 (from_request, from_request_cache)
408 };
409
410 let jwt_impl = quote! {
411 impl #impl_generics #name #ty_generics #where_clause {
412 #get_jwt_hasher
413
414 #get_jwt_token
415
416 #verify_jwt_token
417
418 #set_cookie
419
420 #set_cookie_insecure
421
422 #remove_cookie
423 }
424
425 #from_request
426
427 #from_request_cache
428 };
429
430 return jwt_impl.into();
431 },
432 _ => panic::attribute_incorrect_format("jwt", CORRECT_USAGE_FOR_JWT_ATTRIBUTE),
433 }
434 }
435 }
436
437 panic::jwt_not_found();
438}
439
440#[proc_macro_derive(JWT, attributes(jwt))]
441pub fn jwt_derive(input: TokenStream) -> TokenStream {
442 derive_input_handler(syn::parse(input).unwrap())
443}