jwks_client/
jwt.rs

1use std::ops::Add;
2use std::time::{Duration, SystemTime};
3
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5use serde_json::{Map, Value};
6
7use crate::error::{err_inv, Error};
8
/// Generates the accessors shared by the JWT segment types ([`Header`] and [`Payload`]).
///
/// Must be expanded inside an `impl` block of a type that has a `json: Value` field.
/// Every `get_*` accessor returns `None` when `key` is absent or when its value is not
/// of the requested JSON type.
macro_rules! impl_segment {
    () => {
        /// Wraps a raw JSON value as this segment.
        pub fn new(json: Value) -> Self {
            Self { json }
        }

        /// Returns `key` as a string slice.
        pub fn get_str(&self, key: &str) -> Option<&str> {
            self.json.get(key)?.as_str()
        }

        /// Returns `key` as a signed integer (only if the JSON number fits an `i64`).
        pub fn get_i64(&self, key: &str) -> Option<i64> {
            self.json.get(key)?.as_i64()
        }

        /// Returns `key` as an unsigned integer (only if the JSON number fits a `u64`).
        pub fn get_u64(&self, key: &str) -> Option<u64> {
            self.json.get(key)?.as_u64()
        }

        /// Returns `key` as a floating-point number (any JSON number qualifies).
        pub fn get_f64(&self, key: &str) -> Option<f64> {
            self.json.get(key)?.as_f64()
        }

        /// Returns `key` as a boolean.
        pub fn get_bool(&self, key: &str) -> Option<bool> {
            self.json.get(key)?.as_bool()
        }

        /// Returns `key` as a JSON object.
        pub fn get_object(&self, key: &str) -> Option<&Map<String, Value>> {
            self.json.get(key)?.as_object()
        }

        /// Returns `key` as a JSON array.
        pub fn get_array(&self, key: &str) -> Option<&Vec<Value>> {
            self.json.get(key)?.as_array()
        }

        /// Returns `Some(())` if `key` is present and explicitly `null`.
        pub fn get_null(&self, key: &str) -> Option<()> {
            self.json.get(key)?.as_null()
        }

        /// Deserializes the whole segment into `T`.
        ///
        /// # Errors
        ///
        /// Returns the error produced by `err_inv("Failed to deserialize segment")`
        /// if the JSON does not match `T`.
        pub fn into<T: DeserializeOwned>(&self) -> Result<T, Error> {
            // `&Value` is itself a serde Deserializer, so there is no need to clone the
            // whole JSON tree just to deserialize it.
            T::deserialize(&self.json).map_err(|_| err_inv("Failed to deserialize segment"))
        }
    };
}
54
55#[derive(Debug, Serialize, Deserialize)]
56pub struct Header {
57    pub(crate) json: Value,
58}
59
60impl Header {
61    impl_segment!();
62
63    pub fn alg(&self) -> Option<&str> {
64        self.get_str("alg")
65    }
66
67    pub fn enc(&self) -> Option<&str> {
68        self.get_str("enc")
69    }
70
71    pub fn zip(&self) -> Option<&str> {
72        self.get_str("zip")
73    }
74
75    pub fn jku(&self) -> Option<&str> {
76        self.get_str("jku")
77    }
78
79    pub fn jkw(&self) -> Option<&str> {
80        self.get_str("jkw")
81    }
82
83    pub fn kid(&self) -> Option<&str> {
84        self.get_str("kid")
85    }
86
87    pub fn x5u(&self) -> Option<&str> {
88        self.get_str("x5u")
89    }
90
91    pub fn x5c(&self) -> Option<&str> {
92        self.get_str("x5c")
93    }
94
95    pub fn x5t(&self) -> Option<&str> {
96        self.get_str("x5t")
97    }
98
99    pub fn typ(&self) -> Option<&str> {
100        self.get_str("typ")
101    }
102
103    pub fn cty(&self) -> Option<&str> {
104        self.get_str("cty")
105    }
106
107    pub fn crit(&self) -> Option<&str> {
108        self.get_str("crit")
109    }
110}
111
112#[derive(Debug, Serialize, Deserialize)]
113pub struct Payload {
114    pub(crate) json: Value,
115}
116
117impl Payload {
118    impl_segment!();
119
120    pub fn iss(&self) -> Option<&str> {
121        self.get_str("iss")
122    }
123
124    pub fn sub(&self) -> Option<&str> {
125        self.get_str("sub")
126    }
127
128    pub fn aud(&self) -> Option<&str> {
129        self.get_str("aud")
130    }
131
132    pub fn exp(&self) -> Option<u64> {
133        self.get_f64("exp").and_then(|f| Some(f as u64))
134    }
135
136    pub fn nbf(&self) -> Option<u64> {
137        self.get_f64("nbf").and_then(|f| Some(f as u64))
138    }
139
140    pub fn iat(&self) -> Option<u64> {
141        self.get_f64("iat").and_then(|f| Some(f as u64))
142    }
143
144    pub fn jti(&self) -> Option<&str> {
145        self.get_str("jti")
146    }
147
148    pub fn expiry(&self) -> Option<SystemTime> {
149        if let Some(time) = self.exp() {
150            Some(SystemTime::UNIX_EPOCH.add(Duration::new(time, 0)))
151        } else {
152            None
153        }
154    }
155
156    pub fn issued_at(&self) -> Option<SystemTime> {
157        if let Some(time) = self.iat() {
158            Some(SystemTime::UNIX_EPOCH.add(Duration::new(time, 0)))
159        } else {
160            None
161        }
162    }
163
164    pub fn not_before(&self) -> Option<SystemTime> {
165        if let Some(time) = self.nbf() {
166            Some(SystemTime::UNIX_EPOCH.add(Duration::new(time, 0)))
167        } else {
168            None
169        }
170    }
171}
172
173#[derive(Debug, Serialize, Deserialize)]
174pub struct Jwt {
175    header: Header,
176    payload: Payload,
177    signature: String,
178}
179
180impl Jwt {
181    pub fn new(header: Header, payload: Payload, signature: String) -> Self {
182        Jwt {
183            header,
184            payload,
185            signature,
186        }
187    }
188
189    pub fn header(&self) -> &Header {
190        &self.header
191    }
192
193    pub fn payload(&self) -> &Payload {
194        &self.payload
195    }
196
197    pub fn signature(&self) -> &String {
198        &self.signature
199    }
200
201    pub fn expired(&self) -> Option<bool> {
202        self.expired_time(SystemTime::now())
203    }
204
205    pub fn expired_time(&self, time: SystemTime) -> Option<bool> {
206        match self.payload.expiry() {
207            Some(token_time) => Some(time > token_time),
208            None => None,
209        }
210    }
211
212    pub fn early(&self) -> Option<bool> {
213        self.early_time(SystemTime::now())
214    }
215
216    pub fn early_time(&self, time: SystemTime) -> Option<bool> {
217        match self.payload.not_before() {
218            Some(token_time) => Some(time < token_time),
219            None => None,
220        }
221    }
222
223    pub fn issued_by(&self, issuer: &str) -> Option<bool> {
224        match self.payload.iss() {
225            Some(t) => Some(t == issuer),
226            None => None,
227        }
228    }
229
230    pub fn valid(&self) -> Option<bool> {
231        self.valid_time(SystemTime::now())
232    }
233
234    pub fn valid_time(&self, time: SystemTime) -> Option<bool> {
235        Some(!self.expired_time(time)? && !self.early_time(time)?)
236    }
237}
238
#[cfg(test)]
mod tests {
    use std::time::{Duration, SystemTime};

    use serde_json::json;

    use crate::jwt::{Header, Jwt, Payload};

    #[test]
    fn test_header() {
        let json = json!({
            "alg": "test_alg",
            "enc": "test_enc",
            "zip": "test_zip",
            "jku": "test_jku",
            "jkw": "test_jkw",
            "kid": "test_kid",
            "x5u": "test_x5u",
            "x5c": "test_x5c",
            "x5t": "test_x5t",
            "typ": "test_typ",
            "cty": "test_cty",
            "crit": "test_crit"
        });

        let test_header = Header { json };

        assert_eq!("test_alg", test_header.alg().unwrap());
        assert_eq!("test_enc", test_header.enc().unwrap());
        assert_eq!("test_zip", test_header.zip().unwrap());
        assert_eq!("test_jku", test_header.jku().unwrap());
        assert_eq!("test_jkw", test_header.jkw().unwrap());
        assert_eq!("test_kid", test_header.kid().unwrap());
        assert_eq!("test_x5u", test_header.x5u().unwrap());
        assert_eq!("test_x5c", test_header.x5c().unwrap());
        assert_eq!("test_x5t", test_header.x5t().unwrap());
        assert_eq!("test_typ", test_header.typ().unwrap());
        assert_eq!("test_cty", test_header.cty().unwrap());
        assert_eq!("test_crit", test_header.crit().unwrap());
    }

    #[test]
    fn test_payload() {
        // NumericDate claims may be any JSON number (RFC 7519 §2), so the time claims
        // are written as floats to exercise the non-integer parsing path.
        let json = json!({
            "iss": "test_iss",
            "exp": 123456f64,
            "iat": 123f64,
            "sub": "test_sub",
            "aud": "test_aud",
            "nbf": 456f64,
            "jti": "test_jti",
        });

        let payload = Payload { json };

        assert_eq!("test_iss", payload.iss().unwrap());
        assert_eq!(123456u64, payload.exp().unwrap());
        assert_eq!(123u64, payload.iat().unwrap());
        assert_eq!("test_sub", payload.sub().unwrap());
        assert_eq!("test_aud", payload.aud().unwrap());
        assert_eq!(456u64, payload.nbf().unwrap());
        assert_eq!("test_jti", payload.jti().unwrap());
    }

    #[test]
    fn test_time_and_issuer_checks() {
        let at = |secs| SystemTime::UNIX_EPOCH + Duration::from_secs(secs);
        let jwt = Jwt::new(
            Header::new(json!({})),
            Payload::new(json!({ "exp": 200, "nbf": 100 })),
            String::new(),
        );

        assert_eq!(Some(at(200)), jwt.payload().expiry());
        assert_eq!(Some(at(100)), jwt.payload().not_before());

        assert_eq!(Some(true), jwt.early_time(at(50)));
        assert_eq!(Some(false), jwt.early_time(at(150)));
        assert_eq!(Some(false), jwt.expired_time(at(150)));
        assert_eq!(Some(true), jwt.expired_time(at(250)));

        assert_eq!(Some(false), jwt.valid_time(at(50)));
        assert_eq!(Some(true), jwt.valid_time(at(150)));
        assert_eq!(Some(false), jwt.valid_time(at(250)));

        // No `iss` claim at all: the issuer cannot be checked.
        assert_eq!(None, jwt.issued_by("anyone"));
    }
}
300}