1use std::ops::Add;
2use std::time::{Duration, SystemTime};
3
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5use serde_json::{Map, Value};
6
7use crate::error::{err_inv, Error};
8
/// Expands, inside an `impl` block, the shared constructor and JSON accessors for a
/// segment type that stores its raw JSON in a field named `json` of type `Value`.
macro_rules! impl_segment {
    () => {
        /// Wraps an already-decoded JSON value as this segment.
        pub fn new(json: Value) -> Self {
            Self { json }
        }

        /// Returns the value at `key` if it is present and is a JSON string.
        pub fn get_str(&self, key: &str) -> Option<&str> {
            self.json.get(key)?.as_str()
        }

        /// Returns the value at `key` if it is present and representable as `i64`.
        pub fn get_i64(&self, key: &str) -> Option<i64> {
            self.json.get(key)?.as_i64()
        }

        /// Returns the value at `key` if it is present and representable as `u64`.
        pub fn get_u64(&self, key: &str) -> Option<u64> {
            self.json.get(key)?.as_u64()
        }

        /// Returns the value at `key` as `f64` if it is present and is a JSON number.
        pub fn get_f64(&self, key: &str) -> Option<f64> {
            self.json.get(key)?.as_f64()
        }

        /// Returns the value at `key` if it is present and is a JSON boolean.
        pub fn get_bool(&self, key: &str) -> Option<bool> {
            self.json.get(key)?.as_bool()
        }

        /// Returns the value at `key` if it is present and is a JSON object.
        pub fn get_object(&self, key: &str) -> Option<&Map<String, Value>> {
            self.json.get(key)?.as_object()
        }

        /// Returns the value at `key` if it is present and is a JSON array.
        pub fn get_array(&self, key: &str) -> Option<&Vec<Value>> {
            self.json.get(key)?.as_array()
        }

        /// Returns `Some(())` if `key` is present and its value is JSON `null`.
        pub fn get_null(&self, key: &str) -> Option<()> {
            self.json.get(key)?.as_null()
        }

        /// Deserializes the whole segment into `T`.
        ///
        /// The stored JSON is cloned, so `self` is left untouched.
        ///
        /// # Errors
        ///
        /// Returns the error produced by `err_inv("Failed to deserialize segment")` when the
        /// JSON does not match `T`; the underlying serde error detail is discarded.
        pub fn into<T: DeserializeOwned>(&self) -> Result<T, Error> {
            // `map_err` builds the error only on failure; `.or(Err(..))` built it on every call.
            let value = serde_json::from_value::<T>(self.json.clone())
                .map_err(|_| err_inv("Failed to deserialize segment"))?;
            Ok(value)
        }
    };
}
54
55#[derive(Debug, Serialize, Deserialize)]
56pub struct Header {
57 pub(crate) json: Value,
58}
59
60impl Header {
61 impl_segment!();
62
63 pub fn alg(&self) -> Option<&str> {
64 self.get_str("alg")
65 }
66
67 pub fn enc(&self) -> Option<&str> {
68 self.get_str("enc")
69 }
70
71 pub fn zip(&self) -> Option<&str> {
72 self.get_str("zip")
73 }
74
75 pub fn jku(&self) -> Option<&str> {
76 self.get_str("jku")
77 }
78
79 pub fn jkw(&self) -> Option<&str> {
80 self.get_str("jkw")
81 }
82
83 pub fn kid(&self) -> Option<&str> {
84 self.get_str("kid")
85 }
86
87 pub fn x5u(&self) -> Option<&str> {
88 self.get_str("x5u")
89 }
90
91 pub fn x5c(&self) -> Option<&str> {
92 self.get_str("x5c")
93 }
94
95 pub fn x5t(&self) -> Option<&str> {
96 self.get_str("x5t")
97 }
98
99 pub fn typ(&self) -> Option<&str> {
100 self.get_str("typ")
101 }
102
103 pub fn cty(&self) -> Option<&str> {
104 self.get_str("cty")
105 }
106
107 pub fn crit(&self) -> Option<&str> {
108 self.get_str("crit")
109 }
110}
111
112#[derive(Debug, Serialize, Deserialize)]
113pub struct Payload {
114 pub(crate) json: Value,
115}
116
117impl Payload {
118 impl_segment!();
119
120 pub fn iss(&self) -> Option<&str> {
121 self.get_str("iss")
122 }
123
124 pub fn sub(&self) -> Option<&str> {
125 self.get_str("sub")
126 }
127
128 pub fn aud(&self) -> Option<&str> {
129 self.get_str("aud")
130 }
131
132 pub fn exp(&self) -> Option<u64> {
133 self.get_f64("exp").and_then(|f| Some(f as u64))
134 }
135
136 pub fn nbf(&self) -> Option<u64> {
137 self.get_f64("nbf").and_then(|f| Some(f as u64))
138 }
139
140 pub fn iat(&self) -> Option<u64> {
141 self.get_f64("iat").and_then(|f| Some(f as u64))
142 }
143
144 pub fn jti(&self) -> Option<&str> {
145 self.get_str("jti")
146 }
147
148 pub fn expiry(&self) -> Option<SystemTime> {
149 if let Some(time) = self.exp() {
150 Some(SystemTime::UNIX_EPOCH.add(Duration::new(time, 0)))
151 } else {
152 None
153 }
154 }
155
156 pub fn issued_at(&self) -> Option<SystemTime> {
157 if let Some(time) = self.iat() {
158 Some(SystemTime::UNIX_EPOCH.add(Duration::new(time, 0)))
159 } else {
160 None
161 }
162 }
163
164 pub fn not_before(&self) -> Option<SystemTime> {
165 if let Some(time) = self.nbf() {
166 Some(SystemTime::UNIX_EPOCH.add(Duration::new(time, 0)))
167 } else {
168 None
169 }
170 }
171}
172
173#[derive(Debug, Serialize, Deserialize)]
174pub struct Jwt {
175 header: Header,
176 payload: Payload,
177 signature: String,
178}
179
180impl Jwt {
181 pub fn new(header: Header, payload: Payload, signature: String) -> Self {
182 Jwt {
183 header,
184 payload,
185 signature,
186 }
187 }
188
189 pub fn header(&self) -> &Header {
190 &self.header
191 }
192
193 pub fn payload(&self) -> &Payload {
194 &self.payload
195 }
196
197 pub fn signature(&self) -> &String {
198 &self.signature
199 }
200
201 pub fn expired(&self) -> Option<bool> {
202 self.expired_time(SystemTime::now())
203 }
204
205 pub fn expired_time(&self, time: SystemTime) -> Option<bool> {
206 match self.payload.expiry() {
207 Some(token_time) => Some(time > token_time),
208 None => None,
209 }
210 }
211
212 pub fn early(&self) -> Option<bool> {
213 self.early_time(SystemTime::now())
214 }
215
216 pub fn early_time(&self, time: SystemTime) -> Option<bool> {
217 match self.payload.not_before() {
218 Some(token_time) => Some(time < token_time),
219 None => None,
220 }
221 }
222
223 pub fn issued_by(&self, issuer: &str) -> Option<bool> {
224 match self.payload.iss() {
225 Some(t) => Some(t == issuer),
226 None => None,
227 }
228 }
229
230 pub fn valid(&self) -> Option<bool> {
231 self.valid_time(SystemTime::now())
232 }
233
234 pub fn valid_time(&self, time: SystemTime) -> Option<bool> {
235 Some(!self.expired_time(time)? && !self.early_time(time)?)
236 }
237}
238
#[cfg(test)]
mod tests {
    use std::time::{Duration, SystemTime};

    use serde_json::json;

    use crate::jwt::{Header, Jwt, Payload};

    /// Builds a `SystemTime` `secs` seconds after the Unix epoch.
    fn at(secs: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(secs)
    }

    #[test]
    fn test_header() {
        let json = json!({
            "alg": "test_alg",
            "enc": "test_enc",
            "zip": "test_zip",
            "jku": "test_jku",
            "jkw": "test_jkw",
            "kid": "test_kid",
            "x5u": "test_x5u",
            "x5c": "test_x5c",
            "x5t": "test_x5t",
            "typ": "test_typ",
            "cty": "test_cty",
            "crit": "test_crit"
        });

        let test_header = Header { json };

        assert_eq!("test_alg", test_header.alg().unwrap());
        assert_eq!("test_enc", test_header.enc().unwrap());
        assert_eq!("test_zip", test_header.zip().unwrap());
        assert_eq!("test_jku", test_header.jku().unwrap());
        assert_eq!("test_jkw", test_header.jkw().unwrap());
        assert_eq!("test_kid", test_header.kid().unwrap());
        assert_eq!("test_x5u", test_header.x5u().unwrap());
        assert_eq!("test_x5c", test_header.x5c().unwrap());
        assert_eq!("test_x5t", test_header.x5t().unwrap());
        assert_eq!("test_typ", test_header.typ().unwrap());
        assert_eq!("test_cty", test_header.cty().unwrap());
        assert_eq!("test_crit", test_header.crit().unwrap());
    }

    #[test]
    fn test_payload() {
        let json = json!({
            "iss": "test_iss",
            "exp": 123456f64,
            "iat": 123f64,
            "sub": "test_sub",
            "aud": "test_aud",
            "nbf": 456f64,
            "jti": "test_jti",
        });

        let payload = Payload { json };

        assert_eq!("test_iss", payload.iss().unwrap());
        assert_eq!(123456u64, payload.exp().unwrap());
        assert_eq!(123u64, payload.iat().unwrap());
        assert_eq!("test_sub", payload.sub().unwrap());
        assert_eq!("test_aud", payload.aud().unwrap());
        assert_eq!(456u64, payload.nbf().unwrap());
        assert_eq!("test_jti", payload.jti().unwrap());
    }

    #[test]
    fn test_payload_times() {
        let payload = Payload::new(json!({ "exp": 200f64, "iat": 100f64, "nbf": 150f64 }));

        assert_eq!(Some(at(200)), payload.expiry());
        assert_eq!(Some(at(100)), payload.issued_at());
        assert_eq!(Some(at(150)), payload.not_before());
    }

    #[test]
    fn test_payload_missing_times() {
        let payload = Payload::new(json!({}));

        assert_eq!(None, payload.exp());
        assert_eq!(None, payload.expiry());
        assert_eq!(None, payload.issued_at());
        assert_eq!(None, payload.not_before());
    }

    #[test]
    fn test_jwt_time_checks() {
        let jwt = Jwt::new(
            Header::new(json!({ "alg": "none" })),
            Payload::new(json!({ "exp": 200f64, "nbf": 100f64, "iss": "me" })),
            String::new(),
        );

        assert_eq!(Some(true), jwt.early_time(at(50)));
        assert_eq!(Some(false), jwt.early_time(at(150)));
        assert_eq!(Some(false), jwt.expired_time(at(150)));
        assert_eq!(Some(true), jwt.expired_time(at(250)));

        assert_eq!(Some(false), jwt.valid_time(at(50)));
        assert_eq!(Some(true), jwt.valid_time(at(150)));
        assert_eq!(Some(false), jwt.valid_time(at(250)));

        assert_eq!(Some(true), jwt.issued_by("me"));
        assert_eq!(Some(false), jwt.issued_by("you"));
    }
}