rocket_jwt/
lib.rs

1use proc_macro::{Span, TokenStream};
2use quote::{format_ident, quote};
3use std::collections::HashMap;
4use syn::{
5    parse_macro_input, AttributeArgs, DeriveInput, Expr, ExprLit, ExprPath, Lit, LitStr, Meta,
6    NestedMeta,
7};
8
/// Default token lifetime used when the `exp` option is absent: 30 days, in seconds.
const ONE_MONTH_IN_SECONDS: u64 = 30 * 24 * 60 * 60;
/// Default validation leeway used when the `leeway` option is absent, in seconds.
const ONE_MINUTE_IN_SECONDS: u64 = 60;
11
12fn get_lit_int(lit: Option<&Lit>, default_value: u64) -> u64 {
13    match lit {
14        Some(exp_lit) => {
15            if let Lit::Int(exp_lit_int) = exp_lit {
16                exp_lit_int.base10_digits().parse::<u64>().unwrap()
17            } else {
18                default_value
19            }
20        }
21        None => default_value,
22    }
23}
24
25fn get_lit_str(lit: Option<&Lit>, default_value: String) -> String {
26    match lit {
27        Some(exp_lit) => {
28            if let Lit::Str(exp_lit_str) = exp_lit {
29                exp_lit_str.value()
30            } else {
31                default_value
32            }
33        }
34        None => default_value,
35    }
36}
37
38fn parse_invocation(attr: Vec<NestedMeta>, input: DeriveInput) -> TokenStream {
39    let mut attr_into_iter = attr.into_iter();
40
41    // get secret
42    let secret = attr_into_iter.next();
43    let mut secrete_value: Expr = Expr::Lit(ExprLit {
44        attrs: Vec::new(),
45        lit: Lit::Str(LitStr::new("", Span::call_site().into())),
46    });
47
48    if let Some(secret) = secret {
49        match secret {
50            NestedMeta::Lit(lit) => {
51                if let Lit::Str(lit_str) = lit {
52                    secrete_value = Expr::Lit(ExprLit {
53                        attrs: Vec::new(),
54                        lit: Lit::Str(lit_str),
55                    });
56                }
57            }
58            NestedMeta::Meta(meta) => {
59                if let Meta::Path(secret_path) = meta {
60                    secrete_value = Expr::Path(ExprPath {
61                        attrs: Vec::new(),
62                        qself: None,
63                        path: secret_path,
64                    })
65                }
66            }
67        }
68    }
69
70    let mut hashmap: HashMap<String, Lit> = HashMap::new();
71    for attr_iter in attr_into_iter {
72        if let NestedMeta::Meta(Meta::NameValue(namevalue)) = attr_iter {
73            let name = namevalue.path;
74            let value = namevalue.lit;
75            let name = name.segments[0].ident.to_string();
76            hashmap.insert(name, value);
77        }
78    }
79
80    let exp = get_lit_int(hashmap.get("exp"), ONE_MONTH_IN_SECONDS);
81    let leeway = get_lit_int(hashmap.get("leeway"), ONE_MINUTE_IN_SECONDS);
82    let cookie_key = get_lit_str(hashmap.get("cookie"), "".to_string());
83    let query_key = get_lit_str(hashmap.get("query"), "".to_string());
84
85    // handle input
86    let guard_type = &input.ident;
87    let vis = &input.vis;
88    let fairing_name = format!("'{}' JwtFairing", &guard_type.to_string());
89    let guard_claim = format_ident!("{}JwtClaim", &guard_type);
90
91    let jwt = quote!(::jsonwebtoken);
92    #[allow(non_snake_case)]
93    let Result = quote!(::jsonwebtoken::errors::Result);
94    #[allow(non_snake_case)]
95    let Status = quote!(::rocket::http::Status);
96    #[allow(non_snake_case)]
97    let Outcome = quote!(::rocket::outcome::Outcome);
98    let request = quote!(::rocket::request);
99    let response = quote!(::rocket::response);
100    let std_time = quote!(::std::time);
101    let serder = quote!(::serde);
102
103    let async_trait = quote!(#[::rocket::async_trait]);
104
105    let guard_types = quote! {
106        #[derive(Debug, #serder::Deserialize, #serder::Serialize)]
107        #input
108
109        #[derive(Debug, #serder::Deserialize,#serder::Serialize)]
110        #vis struct #guard_claim {
111            exp: u64,
112            iat: u64,
113            user: #guard_type
114        }
115    };
116
117    quote! {
118        #guard_types
119
120        impl #guard_type {
121            pub fn fairing() -> impl ::rocket::fairing::Fairing {
122                ::rocket::fairing::AdHoc::on_ignite(#fairing_name, |rocket| async {
123                    rocket
124                })
125            }
126
127            pub fn sign(user: #guard_type) -> String {
128                let now = #std_time::SystemTime::now().duration_since(#std_time::UNIX_EPOCH).unwrap().as_secs();
129                let payload = #guard_claim {
130                    exp: #exp + now,
131                    iat: now,
132                    user,
133                };
134
135                #jwt::encode(&#jwt::Header::default(), &payload, &#jwt::EncodingKey::from_secret((#secrete_value).as_bytes())).unwrap()
136            }
137
138            pub fn decode(token: String) -> #Result<#guard_claim> {
139                let mut validation = #jwt::Validation::default();
140                validation.leeway = #leeway;
141
142                let result = #jwt::decode::<#guard_claim>(&token, &#jwt::DecodingKey::from_secret((#secrete_value).as_bytes()), &validation);
143                match result {
144                    Ok(token_claim) => Ok(token_claim.claims),
145                    Err(err) => Err(err),
146                }
147            }
148        }
149
150        #async_trait
151        impl<'r> #request::FromRequest<'r> for #guard_type {
152            type Error = #response::status::Custom<String>;
153            // type Error = ();
154
155            async fn from_request(request: &'r #request::Request<'_>,) -> #request::Outcome<Self, #response::status::Custom<String>> {
156                let mut auth_str: Option<String> = None;
157                if (#cookie_key) != "" {
158                    auth_str = match request.cookies().get(#cookie_key) {
159                        None => None,
160                        Some(t) => Some(t.value().to_string()),
161                    };
162                } else if (#query_key) != "" {
163                    auth_str = match request.query_value::<String>(#query_key) {
164                        None => None,
165                        Some(t) => match t {
166                            Ok(r) => Some(r),
167                            Err(_) => None,
168                        }
169                    }
170                } else {
171                    auth_str = match auth_str {
172                        Some(auth_str) => Some(auth_str),
173                        None => match request.headers().get_one("Authorization") {
174                            Some(s) => Some(s.to_string()),
175                            None => None,
176                        }
177                    };
178                };
179
180                if let Some(auth_str) = auth_str {
181                    if auth_str.starts_with("Bearer") {
182                        let token = auth_str[6..auth_str.len()].trim();
183                        match #guard_type::decode(token.to_string()) {
184                            Ok(token_data) => {
185                                return #Outcome::Success(token_data.user);
186                            },
187                            Err(err) => {
188                                return #Outcome::Error((
189                                    #Status::Unauthorized,
190                                    #response::status::Custom(
191                                        #Status::Unauthorized,
192                                        err.to_string(),
193                                    ),
194                                ));
195                            },
196                            // Err(_) => {
197                            //     return #Outcome::Forward(());
198                            // },
199                        }
200                    }
201                }
202
203                // #Outcome::Forward(())
204                #Outcome::Error((
205                    #Status::Unauthorized,
206                    #response::status::Custom(
207                        #Status::Unauthorized,
208                        String::from("EmptySignature"),
209                    ),
210                ))
211            }
212        }
213    }.into()
214}
215
216///
217/// Attribute to generate a [`jsonwebtoken claim`] and associated metadata.
218///
219/// ```rust
220/// // expire default in 2592_000s
221/// #[rocket_jwt::jwt("secret")]
222/// struct User { id: String }
223/// ```
224///
225/// or
226///
227/// ```rust
228/// // expire in 10s
229/// #[rocket_jwt::jwt("secret", exp = 10)]
230/// struct User { id: String }
231/// ```
232///
233/// ## Example
234/// ---
235/// ```rust
236
237/// #[macro_use]
238/// extern crate rocket;
239///
240/// use rocket_jwt::jwt;
241///
242/// static SECRET_KEY: &str = "secret_key";
243///
244/// #[jwt(SECRET_KEY)]
245/// pub struct UserClaim {
246///     id: String,
247/// }
248///
249/// #[get("/")]
250/// fn index() -> String {
251///     let user_claim = UserClaim {
252///         id: format!("hello_rocket_jwt"),
253///     };
254///     let token = UserClaim::sign(user_claim);
255///     println!("{:#?}", UserClaim::decode(token.clone()));
256///     token
257/// }
258///
259/// #[get("/user_id")]
260/// fn get_uer_id_from_jwt(user: UserClaim) -> String {
261///     format!("user id is {}", user.id)
262/// }
263///
264/// fn main() {
265///     rocket::build()
266///         .attach(UserClaim::fairing())
267///         .mount("/", routes![index, get_uer_id_from_jwt])
268///         .launch();
269/// }
270/// ```
271/// token default comes from request.header, if want get from cookie or query, user
272///
273/// ```rust
274/// #[rocket_jwt::jwt("secret", cookie = "token")]
275/// pub struct UserClaim {
276///     id: String,
277/// }
278/// ```
279///
280/// /// ```rust
281/// #[jwt("secret", query = "token")]
282/// pub struct UserClaim {
283///     id: String,
284/// }
285/// ```
286#[proc_macro_attribute]
287pub fn jwt(attr: TokenStream, input: TokenStream) -> TokenStream {
288    let input = parse_macro_input!(input as DeriveInput);
289    let attr = parse_macro_input!(attr as AttributeArgs);
290
291    parse_invocation(attr, input)
292}