1use proc_macro::{Span, TokenStream};
2use quote::{format_ident, quote};
3use std::collections::HashMap;
4use syn::{
5 parse_macro_input, AttributeArgs, DeriveInput, Expr, ExprLit, ExprPath, Lit, LitStr, Meta,
6 NestedMeta,
7};
8
/// Default value of the `exp` option: a 30-day token lifetime, expressed in seconds.
const ONE_MONTH_IN_SECONDS: u64 = 30 * 24 * 60 * 60;
/// Default value of the `leeway` option: one minute, expressed in seconds.
const ONE_MINUTE_IN_SECONDS: u64 = 60;
11
12fn get_lit_int(lit: Option<&Lit>, default_value: u64) -> u64 {
13 match lit {
14 Some(exp_lit) => {
15 if let Lit::Int(exp_lit_int) = exp_lit {
16 exp_lit_int.base10_digits().parse::<u64>().unwrap()
17 } else {
18 default_value
19 }
20 }
21 None => default_value,
22 }
23}
24
25fn get_lit_str(lit: Option<&Lit>, default_value: String) -> String {
26 match lit {
27 Some(exp_lit) => {
28 if let Lit::Str(exp_lit_str) = exp_lit {
29 exp_lit_str.value()
30 } else {
31 default_value
32 }
33 }
34 None => default_value,
35 }
36}
37
38fn parse_invocation(attr: Vec<NestedMeta>, input: DeriveInput) -> TokenStream {
39 let mut attr_into_iter = attr.into_iter();
40
41 let secret = attr_into_iter.next();
43 let mut secrete_value: Expr = Expr::Lit(ExprLit {
44 attrs: Vec::new(),
45 lit: Lit::Str(LitStr::new("", Span::call_site().into())),
46 });
47
48 if let Some(secret) = secret {
49 match secret {
50 NestedMeta::Lit(lit) => {
51 if let Lit::Str(lit_str) = lit {
52 secrete_value = Expr::Lit(ExprLit {
53 attrs: Vec::new(),
54 lit: Lit::Str(lit_str),
55 });
56 }
57 }
58 NestedMeta::Meta(meta) => {
59 if let Meta::Path(secret_path) = meta {
60 secrete_value = Expr::Path(ExprPath {
61 attrs: Vec::new(),
62 qself: None,
63 path: secret_path,
64 })
65 }
66 }
67 }
68 }
69
70 let mut hashmap: HashMap<String, Lit> = HashMap::new();
71 for attr_iter in attr_into_iter {
72 if let NestedMeta::Meta(Meta::NameValue(namevalue)) = attr_iter {
73 let name = namevalue.path;
74 let value = namevalue.lit;
75 let name = name.segments[0].ident.to_string();
76 hashmap.insert(name, value);
77 }
78 }
79
80 let exp = get_lit_int(hashmap.get("exp"), ONE_MONTH_IN_SECONDS);
81 let leeway = get_lit_int(hashmap.get("leeway"), ONE_MINUTE_IN_SECONDS);
82 let cookie_key = get_lit_str(hashmap.get("cookie"), "".to_string());
83 let query_key = get_lit_str(hashmap.get("query"), "".to_string());
84
85 let guard_type = &input.ident;
87 let vis = &input.vis;
88 let fairing_name = format!("'{}' JwtFairing", &guard_type.to_string());
89 let guard_claim = format_ident!("{}JwtClaim", &guard_type);
90
91 let jwt = quote!(::jsonwebtoken);
92 #[allow(non_snake_case)]
93 let Result = quote!(::jsonwebtoken::errors::Result);
94 #[allow(non_snake_case)]
95 let Status = quote!(::rocket::http::Status);
96 #[allow(non_snake_case)]
97 let Outcome = quote!(::rocket::outcome::Outcome);
98 let request = quote!(::rocket::request);
99 let response = quote!(::rocket::response);
100 let std_time = quote!(::std::time);
101 let serder = quote!(::serde);
102
103 let async_trait = quote!(#[::rocket::async_trait]);
104
105 let guard_types = quote! {
106 #[derive(Debug, #serder::Deserialize, #serder::Serialize)]
107 #input
108
109 #[derive(Debug, #serder::Deserialize,#serder::Serialize)]
110 #vis struct #guard_claim {
111 exp: u64,
112 iat: u64,
113 user: #guard_type
114 }
115 };
116
117 quote! {
118 #guard_types
119
120 impl #guard_type {
121 pub fn fairing() -> impl ::rocket::fairing::Fairing {
122 ::rocket::fairing::AdHoc::on_ignite(#fairing_name, |rocket| async {
123 rocket
124 })
125 }
126
127 pub fn sign(user: #guard_type) -> String {
128 let now = #std_time::SystemTime::now().duration_since(#std_time::UNIX_EPOCH).unwrap().as_secs();
129 let payload = #guard_claim {
130 exp: #exp + now,
131 iat: now,
132 user,
133 };
134
135 #jwt::encode(&#jwt::Header::default(), &payload, &#jwt::EncodingKey::from_secret((#secrete_value).as_bytes())).unwrap()
136 }
137
138 pub fn decode(token: String) -> #Result<#guard_claim> {
139 let mut validation = #jwt::Validation::default();
140 validation.leeway = #leeway;
141
142 let result = #jwt::decode::<#guard_claim>(&token, &#jwt::DecodingKey::from_secret((#secrete_value).as_bytes()), &validation);
143 match result {
144 Ok(token_claim) => Ok(token_claim.claims),
145 Err(err) => Err(err),
146 }
147 }
148 }
149
150 #async_trait
151 impl<'r> #request::FromRequest<'r> for #guard_type {
152 type Error = #response::status::Custom<String>;
153 async fn from_request(request: &'r #request::Request<'_>,) -> #request::Outcome<Self, #response::status::Custom<String>> {
156 let mut auth_str: Option<String> = None;
157 if (#cookie_key) != "" {
158 auth_str = match request.cookies().get(#cookie_key) {
159 None => None,
160 Some(t) => Some(t.value().to_string()),
161 };
162 } else if (#query_key) != "" {
163 auth_str = match request.query_value::<String>(#query_key) {
164 None => None,
165 Some(t) => match t {
166 Ok(r) => Some(r),
167 Err(_) => None,
168 }
169 }
170 } else {
171 auth_str = match auth_str {
172 Some(auth_str) => Some(auth_str),
173 None => match request.headers().get_one("Authorization") {
174 Some(s) => Some(s.to_string()),
175 None => None,
176 }
177 };
178 };
179
180 if let Some(auth_str) = auth_str {
181 if auth_str.starts_with("Bearer") {
182 let token = auth_str[6..auth_str.len()].trim();
183 match #guard_type::decode(token.to_string()) {
184 Ok(token_data) => {
185 return #Outcome::Success(token_data.user);
186 },
187 Err(err) => {
188 return #Outcome::Error((
189 #Status::Unauthorized,
190 #response::status::Custom(
191 #Status::Unauthorized,
192 err.to_string(),
193 ),
194 ));
195 },
196 }
200 }
201 }
202
203 #Outcome::Error((
205 #Status::Unauthorized,
206 #response::status::Custom(
207 #Status::Unauthorized,
208 String::from("EmptySignature"),
209 ),
210 ))
211 }
212 }
213 }.into()
214}
215
216#[proc_macro_attribute]
287pub fn jwt(attr: TokenStream, input: TokenStream) -> TokenStream {
288 let input = parse_macro_input!(input as DeriveInput);
289 let attr = parse_macro_input!(attr as AttributeArgs);
290
291 parse_invocation(attr, input)
292}