1use super::jwe::Jwe;
2use crate::jose::jwe::{JweAlg, JweEnc, JweError, JweHeader};
3use crate::jose::jws::{Jws, JwsAlg, JwsError, JwsHeader};
4use crate::key::{PrivateKey, PublicKey};
5use core::fmt;
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use thiserror::Error;
9
/// Errors produced while encoding, decoding or validating a JWT.
///
/// Variants carrying a field named `source` expose it through `std::error::Error::source`
/// (thiserror treats a field with that name as the error source).
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum JwtError {
    /// Failure bubbled up from the JWS layer (`Jws::encode` / `Jws::decode`).
    #[error("JWS error: {source}")]
    Jws { source: JwsError },

    /// Failure bubbled up from the JWE layer (`Jwe::encode` / `Jwe::decode`).
    #[error("JWE error: {source}")]
    Jwe { source: JweError },

    /// The claims could not be serialized to, or deserialized from, JSON.
    #[error("JSON error: {source}")]
    Json { source: serde_json::Error },

    /// A checked registered claim (`nbf` / `exp`) is present but is not representable as an
    /// `i64` (e.g. a string, or a non-integer number).
    #[error("registered claim `{claim}` has invalid type")]
    InvalidRegisteredClaimType { claim: &'static str },

    /// The validator requires a claim that the token does not carry.
    #[error("required claim `{claim}` is missing")]
    RequiredClaimMissing { claim: &'static str },

    /// The current date (minus leeway) is earlier than the token's `nbf` claim.
    #[error("token not yet valid (not before: {}, now: {} [leeway: {}])", not_before, now.numeric_date, now.leeway)]
    NotYetValid { not_before: i64, now: JwtDate },

    /// The current date is at or past the token's `exp` claim (plus leeway).
    #[error("token expired (not after: {}, now: {} [leeway: {}])", not_after, now.numeric_date, now.leeway)]
    Expired { not_after: i64, now: JwtDate },

    /// The validator is inconsistent, e.g. a check is required but no current date was given.
    #[error("invalid validator: {description}")]
    InvalidValidator { description: &'static str },
}
47
48impl From<JwsError> for JwtError {
49 fn from(s: JwsError) -> Self {
50 Self::Jws { source: s }
51 }
52}
53
54impl From<serde_json::Error> for JwtError {
55 fn from(e: serde_json::Error) -> Self {
56 Self::Json { source: e }
57 }
58}
59
60impl From<JweError> for JwtError {
61 fn from(s: JweError) -> Self {
62 Self::Jwe { source: s }
63 }
64}
65
/// The "current date" used to validate the `nbf` / `exp` claims of a JWT.
///
/// `numeric_date` is a JWT `NumericDate` (RFC 7519: seconds since 1970-01-01T00:00:00Z UTC).
/// `leeway`, in the same unit, widens every comparison in the token's favour to tolerate
/// clock skew between issuer and verifier.
#[derive(Clone, Debug)]
pub struct JwtDate {
    pub numeric_date: i64,
    pub leeway: u16,
}

impl JwtDate {
    /// Creates a date with zero leeway.
    pub const fn new(numeric_date: i64) -> Self {
        Self {
            numeric_date,
            leeway: 0,
        }
    }

    /// Creates a date that tolerates `leeway` units of clock skew.
    pub const fn new_with_leeway(numeric_date: i64, leeway: u16) -> Self {
        Self { numeric_date, leeway }
    }

    /// Returns `true` if `self <= other_numeric_date + leeway`.
    pub const fn is_before(&self, other_numeric_date: i64) -> bool {
        self.wide() <= self.plus_leeway(other_numeric_date)
    }

    /// Returns `true` if `self < other_numeric_date + leeway`.
    pub const fn is_before_strict(&self, other_numeric_date: i64) -> bool {
        self.wide() < self.plus_leeway(other_numeric_date)
    }

    /// Returns `true` if `self >= other_numeric_date - leeway`.
    pub const fn is_after(&self, other_numeric_date: i64) -> bool {
        self.wide() >= self.minus_leeway(other_numeric_date)
    }

    /// Returns `true` if `self > other_numeric_date - leeway`.
    pub const fn is_after_strict(&self, other_numeric_date: i64) -> bool {
        self.wide() > self.minus_leeway(other_numeric_date)
    }

    // The comparisons are carried out in i128. `other_numeric_date` usually comes from
    // untrusted token claims and may be i64::MIN / i64::MAX, where `± leeway` would overflow
    // i64 (a panic in debug builds, a silent wrap-around and wrong verdict in release builds).
    // All casts below are lossless widenings (`From` is not usable in a `const fn`).

    /// `numeric_date` widened to i128.
    const fn wide(&self) -> i128 {
        self.numeric_date as i128
    }

    /// `other + leeway`, computed without overflow.
    const fn plus_leeway(&self, other: i64) -> i128 {
        other as i128 + self.leeway as i128
    }

    /// `other - leeway`, computed without overflow.
    const fn minus_leeway(&self, other: i64) -> i128 {
        other as i128 - self.leeway as i128
    }
}
106
/// How strictly a time-based registered claim (`nbf` or `exp`) is enforced during decoding.
#[derive(Clone, Copy, Debug)]
enum CheckStrictness {
    /// Never inspected, even when the claim is present.
    Ignored,
    /// Validated when present; absence is accepted.
    Optional,
    /// Must be present and valid.
    Required,
}
115
116#[derive(Debug, Clone)]
117pub struct JwtValidator<'a> {
118 current_date: Option<&'a JwtDate>,
119 expiration_claim: CheckStrictness,
120 not_before_claim: CheckStrictness,
121}
122
123pub const NO_CHECK_VALIDATOR: JwtValidator<'static> = JwtValidator::no_check();
124
125impl<'a> JwtValidator<'a> {
126 pub const fn strict(current_date: &'a JwtDate) -> Self {
128 Self {
129 current_date: Some(current_date),
130 expiration_claim: CheckStrictness::Required,
131 not_before_claim: CheckStrictness::Required,
132 }
133 }
134
135 pub const fn lenient(current_date: &'a JwtDate) -> Self {
137 Self {
138 current_date: Some(current_date),
139 expiration_claim: CheckStrictness::Optional,
140 not_before_claim: CheckStrictness::Optional,
141 }
142 }
143
144 pub const fn no_check() -> Self {
146 Self {
147 current_date: None,
148 expiration_claim: CheckStrictness::Ignored,
149 not_before_claim: CheckStrictness::Ignored,
150 }
151 }
152
153 pub fn current_date(self, current_date: &'a JwtDate) -> Self {
154 Self {
155 current_date: Some(current_date),
156 expiration_claim: CheckStrictness::Required,
157 not_before_claim: CheckStrictness::Required,
158 }
159 }
160
161 pub fn expiration_check_required(self) -> Self {
162 Self {
163 expiration_claim: CheckStrictness::Required,
164 ..self
165 }
166 }
167
168 pub fn expiration_check_optional(self) -> Self {
169 Self {
170 expiration_claim: CheckStrictness::Optional,
171 ..self
172 }
173 }
174
175 pub fn expiration_check_ignored(self) -> Self {
176 Self {
177 expiration_claim: CheckStrictness::Ignored,
178 ..self
179 }
180 }
181
182 pub fn not_before_check_required(self) -> Self {
183 Self {
184 not_before_claim: CheckStrictness::Required,
185 ..self
186 }
187 }
188
189 pub fn not_before_check_optional(self) -> Self {
190 Self {
191 not_before_claim: CheckStrictness::Optional,
192 ..self
193 }
194 }
195
196 pub fn not_before_check_ignored(self) -> Self {
197 Self {
198 not_before_claim: CheckStrictness::Ignored,
199 ..self
200 }
201 }
202}
203
/// Value written into the `typ` header parameter of every JWT created by this module.
const JWT_TYPE: &str = "JWT";

/// Registered claim name: expiration time.
const EXPIRATION_TIME_CLAIM: &str = "exp";
/// Registered claim name: not before.
const NOT_BEFORE_CLAIM: &str = "nbf";
209
/// A JSON Web Token: a JOSE header `H` together with the claims `C` it protects.
pub struct Jwt<H, C> {
    /// JOSE header (a `JwsHeader` for signed tokens, a `JweHeader` for encrypted ones).
    pub header: H,
    /// Claims set carried as the token payload.
    pub claims: C,
}
214
215pub type JwtSig<C> = Jwt<JwsHeader, C>;
216pub type JwtEnc<C> = Jwt<JweHeader, C>;
217
218impl<H, C> Clone for Jwt<H, C>
219where
220 H: Clone,
221 C: Clone,
222{
223 fn clone(&self) -> Self {
224 Self {
225 header: self.header.clone(),
226 claims: self.claims.clone(),
227 }
228 }
229}
230
231impl<H, C> fmt::Debug for Jwt<H, C>
232where
233 H: fmt::Debug,
234 C: fmt::Debug,
235{
236 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
237 fmt.debug_struct("Jwt")
238 .field("header", &self.header)
239 .field("claims", &self.claims)
240 .finish()
241 }
242}
243
244impl<C> Jwt<JwsHeader, C> {
245 pub fn new(alg: JwsAlg, claims: C) -> Self {
246 Jwt {
247 header: JwsHeader {
248 typ: Some(JWT_TYPE.to_owned()),
249 ..JwsHeader::new(alg)
250 },
251 claims,
252 }
253 }
254
255 pub fn new_signed(alg: JwsAlg, claims: C) -> Self {
256 Self::new(alg, claims)
257 }
258}
259
260impl<C> Jwt<JweHeader, C> {
261 pub fn new(alg: JweAlg, enc: JweEnc, claims: C) -> Self {
262 Jwt {
263 header: JweHeader {
264 typ: Some(JWT_TYPE.to_owned()),
265 ..JweHeader::new(alg, enc)
266 },
267 claims,
268 }
269 }
270
271 pub fn new_encrypted(alg: JweAlg, enc: JweEnc, claims: C) -> Self {
272 Self::new(alg, enc, claims)
273 }
274}
275
276impl<H, C> Jwt<H, C> {
277 pub fn new_with_header(header: H, claims: C) -> Self {
278 Jwt { header, claims }
279 }
280}
281
282impl<C> Jwt<JwsHeader, C>
283where
284 C: Serialize,
285{
286 pub fn encode(self, private_key: &PrivateKey) -> Result<String, JwtError> {
287 let jws = Jws {
288 header: self.header,
289 payload: serde_json::to_vec(&self.claims)?,
290 };
291 let encoded = jws.encode(private_key)?;
292 Ok(encoded)
293 }
294}
295
296impl<C> Jwt<JwsHeader, C>
297where
298 C: DeserializeOwned,
299{
300 pub fn decode(encoded_token: &str, public_key: &PublicKey, validator: &JwtValidator) -> Result<Self, JwtError> {
302 let jws = Jws::decode(encoded_token, public_key)?;
303 Ok(Jwt {
304 header: jws.header,
305 claims: h_decode_and_validate_claims(&jws.payload, validator)?,
306 })
307 }
308
309 pub fn decode_dangerous(encoded_token: &str, validator: &JwtValidator) -> Result<Self, JwtError> {
311 let jws = Jws::decode_without_validation(encoded_token)?;
312 Ok(Jwt {
313 header: jws.header,
314 claims: h_decode_and_validate_claims(&jws.payload, validator)?,
315 })
316 }
317}
318
319impl<C> Jwt<JweHeader, C>
320where
321 C: Serialize,
322{
323 pub fn encode(self, asymmetric_key: &PublicKey) -> Result<String, JwtError> {
325 let jwe = Jwe {
326 header: self.header,
327 payload: serde_json::to_vec(&self.claims)?,
328 };
329 let encoded = jwe.encode(asymmetric_key)?;
330 Ok(encoded)
331 }
332
333 pub fn encode_direct(self, cek: &[u8]) -> Result<String, JweError> {
335 let jwe = Jwe {
336 header: self.header,
337 payload: serde_json::to_vec(&self.claims)?,
338 };
339 let encoded = jwe.encode_direct(cek)?;
340 Ok(encoded)
341 }
342}
343
344impl<C> Jwt<JweHeader, C>
345where
346 C: DeserializeOwned,
347{
348 pub fn decode(encoded_token: &str, key: &PrivateKey, validator: &JwtValidator) -> Result<Self, JwtError> {
350 let jwe = Jwe::decode(encoded_token, key)?;
351 Ok(Jwt {
352 header: jwe.header,
353 claims: h_decode_and_validate_claims(&jwe.payload, validator)?,
354 })
355 }
356
357 pub fn decode_direct(encoded_token: &str, cek: &[u8], validator: &JwtValidator) -> Result<Self, JwtError> {
359 let jwe = Jwe::decode_direct(encoded_token, cek)?;
360 Ok(Jwt {
361 header: jwe.header,
362 claims: h_decode_and_validate_claims(&jwe.payload, validator)?,
363 })
364 }
365}
366
367fn h_decode_and_validate_claims<C: DeserializeOwned>(
368 claims_json: &[u8],
369 validator: &JwtValidator,
370) -> Result<C, JwtError> {
371 let claims = match (
372 validator.current_date,
373 validator.not_before_claim,
374 validator.expiration_claim,
375 ) {
376 (None, CheckStrictness::Required, _) | (None, _, CheckStrictness::Required) => {
377 return Err(JwtError::InvalidValidator {
378 description: "current date is missing",
379 })
380 }
381 (Some(current_date), nbf_strictness, exp_strictness) => {
382 let claims = serde_json::from_slice::<serde_json::Value>(claims_json)?;
383
384 let nbf_opt = claims.get(NOT_BEFORE_CLAIM);
385 match (nbf_strictness, nbf_opt) {
386 (CheckStrictness::Ignored, _) | (CheckStrictness::Optional, None) => {}
387 (CheckStrictness::Required, None) => {
388 return Err(JwtError::RequiredClaimMissing {
389 claim: NOT_BEFORE_CLAIM,
390 })
391 }
392 (_, Some(nbf)) => {
393 let nbf_i64 = nbf.as_i64().ok_or(JwtError::InvalidRegisteredClaimType {
394 claim: NOT_BEFORE_CLAIM,
395 })?;
396 if !current_date.is_after(nbf_i64) {
397 return Err(JwtError::NotYetValid {
398 not_before: nbf_i64,
399 now: current_date.clone(),
400 });
401 }
402 }
403 }
404
405 let exp_opt = claims.get(EXPIRATION_TIME_CLAIM);
406 match (exp_strictness, exp_opt) {
407 (CheckStrictness::Ignored, _) | (CheckStrictness::Optional, None) => {}
408 (CheckStrictness::Required, None) => {
409 return Err(JwtError::RequiredClaimMissing {
410 claim: EXPIRATION_TIME_CLAIM,
411 })
412 }
413 (_, Some(exp)) => {
414 let exp_i64 = exp.as_i64().ok_or(JwtError::InvalidRegisteredClaimType {
415 claim: EXPIRATION_TIME_CLAIM,
416 })?;
417 if !current_date.is_before_strict(exp_i64) {
418 return Err(JwtError::Expired {
419 not_after: exp_i64,
420 now: current_date.clone(),
421 });
422 }
423 }
424 }
425
426 serde_json::value::from_value(claims)?
427 }
428 (None, _, _) => serde_json::from_slice(claims_json)?,
429 };
430
431 Ok(claims)
432}
433
#[cfg(test)]
mod tests {
    use super::*;
    use crate::pem::Pem;
    use serde::Deserialize;
    use std::borrow::Cow;

    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct MyClaims {
        sub: Cow<'static, str>,
        name: Cow<'static, str>,
        admin: bool,
        iat: i32,
    }

    /// Claims expected in the payload of `JOSE_JWT_SIG_EXAMPLE`.
    const fn get_strongly_typed_claims() -> MyClaims {
        MyClaims {
            sub: Cow::Borrowed("1234567890"),
            name: Cow::Borrowed("John Doe"),
            admin: true,
            iat: 1516239022,
        }
    }

    fn get_private_key_1() -> PrivateKey {
        let pk_pem = crate::test_files::RSA_2048_PK_1.parse::<Pem>().unwrap();
        PrivateKey::from_pkcs8(pk_pem.data()).unwrap()
    }

    #[test]
    fn encode_jws_rsa_sha256() {
        let claims = get_strongly_typed_claims();
        let jwt = JwtSig::new(JwsAlg::RS256, claims);
        let encoded = jwt.encode(&get_private_key_1()).unwrap();
        assert_eq!(encoded, crate::test_files::JOSE_JWT_SIG_EXAMPLE);
    }

    #[test]
    fn decode_jws_rsa_sha256() {
        let public_key = get_private_key_1().to_public_key();
        let jwt = JwtSig::<MyClaims>::decode(
            crate::test_files::JOSE_JWT_SIG_EXAMPLE,
            &public_key,
            &JwtValidator::no_check(),
        )
        .unwrap();
        assert_eq!(jwt.claims, get_strongly_typed_claims());

        // Lenient validation (claims checked only when present) must also accept the token.
        let now = JwtDate::new(0);
        JwtSig::<MyClaims>::decode(
            crate::test_files::JOSE_JWT_SIG_EXAMPLE,
            &public_key,
            &JwtValidator::lenient(&now),
        )
        .unwrap();
    }

    #[test]
    fn decode_jws_invalid_validator_err() {
        let public_key = get_private_key_1().to_public_key();
        let validator = JwtValidator::no_check()
            .expiration_check_required()
            .not_before_check_optional();
        let err = JwtSig::<MyClaims>::decode(crate::test_files::JOSE_JWT_SIG_EXAMPLE, &public_key, &validator)
            .err()
            .unwrap();
        assert_eq!(err.to_string(), "invalid validator: current date is missing");
    }

    #[test]
    fn decode_jws_required_claim_missing_err() {
        let public_key = get_private_key_1().to_public_key();
        let now = JwtDate::new(0);
        let validator = JwtValidator::strict(&now);
        let err = JwtSig::<MyClaims>::decode(crate::test_files::JOSE_JWT_SIG_EXAMPLE, &public_key, &validator)
            .err()
            .unwrap();
        assert_eq!(err.to_string(), "required claim `nbf` is missing");
    }

    #[test]
    fn decode_jws_rsa_sha256_using_json_value_claims() {
        let public_key = get_private_key_1().to_public_key();
        let validator = JwtValidator::no_check();
        let jwt = JwtSig::<serde_json::Value>::decode(crate::test_files::JOSE_JWT_SIG_EXAMPLE, &public_key, &validator)
            .unwrap();
        assert_eq!(jwt.claims["sub"].as_str().expect("sub"), "1234567890");
        assert_eq!(jwt.claims["name"].as_str().expect("name"), "John Doe");
        assert!(jwt.claims["admin"].as_bool().expect("admin"));
        assert_eq!(jwt.claims["iat"].as_i64().expect("iat"), 1516239022);
    }

    #[test]
    fn jwe_direct_aes_256_gcm() {
        let claims = get_strongly_typed_claims();
        let key = crate::hash::HashAlgorithm::SHA2_256.digest(b"magic_password");
        let jwt = Jwt::new_encrypted(JweAlg::Direct, JweEnc::Aes256Gcm, claims);
        let encoded = jwt.encode_direct(&key).unwrap();
        let decoded = Jwt::<_, MyClaims>::decode_direct(&encoded, &key, &NO_CHECK_VALIDATOR).unwrap();
        assert_eq!(decoded.claims, get_strongly_typed_claims());
    }

    #[derive(Serialize, Deserialize)]
    struct MyExpirableClaims {
        exp: i64,
        nbf: i64,
        msg: String,
    }

    #[test]
    fn decode_jws_not_expired() {
        let public_key = get_private_key_1().to_public_key();

        let jwt = JwtSig::<MyExpirableClaims>::decode(
            crate::test_files::JOSE_JWT_SIG_WITH_EXP,
            &public_key,
            &JwtValidator::strict(&JwtDate::new(1545263999)),
        )
        .expect("couldn't decode jwt without leeway");

        assert_eq!(jwt.claims.exp, 1545264000);
        assert_eq!(jwt.claims.nbf, 1545263000);
        assert_eq!(jwt.claims.msg, "THIS IS TIME SENSITIVE DATA");

        // One second past `exp`, but within the 10 s leeway.
        JwtSig::<MyExpirableClaims>::decode(
            crate::test_files::JOSE_JWT_SIG_WITH_EXP,
            &public_key,
            &JwtValidator::strict(&JwtDate::new_with_leeway(1545264001, 10)),
        )
        .expect("couldn't decode jwt with leeway for exp");

        // One second before `nbf`, but within the 10 s leeway.
        JwtSig::<MyExpirableClaims>::decode(
            crate::test_files::JOSE_JWT_SIG_WITH_EXP,
            &public_key,
            &JwtValidator::strict(&JwtDate::new_with_leeway(1545262999, 10)),
        )
        .expect("couldn't decode jwt with leeway for nbf");
    }

    #[test]
    fn decode_jws_invalid_date_err() {
        let public_key = get_private_key_1().to_public_key();

        let err = JwtSig::<MyExpirableClaims>::decode(
            crate::test_files::JOSE_JWT_SIG_WITH_EXP,
            &public_key,
            &JwtValidator::strict(&JwtDate::new(1545264001)),
        )
        .err()
        .unwrap();

        assert_eq!(
            err.to_string(),
            "token expired (not after: 1545264000, now: 1545264001 [leeway: 0])"
        );

        let err = JwtSig::<MyExpirableClaims>::decode(
            crate::test_files::JOSE_JWT_SIG_WITH_EXP,
            &public_key,
            &JwtValidator::strict(&JwtDate::new_with_leeway(1545262998, 1)),
        )
        .err()
        .unwrap();

        assert_eq!(
            err.to_string(),
            "token not yet valid (not before: 1545263000, now: 1545262998 [leeway: 1])"
        );
    }

    #[test]
    fn decode_step_cli_generated_token() {
        let pk_pem = crate::test_files::RSA_2048_PK_7.parse::<Pem>().unwrap();
        let pk = PrivateKey::from_pem(&pk_pem).expect("private_key 7");

        let token: &str = "eyJhbGciOiJSU0EtT0FFUC0yNTYiLCJlbmMiOiJBMjU2R0NNIiwia2lkIjoiekZIVnNNOWRjNE9sSEl3dkVuVnFmS3pSajF1anFZR3NablhBY2duX0NxSSJ9.nYfNbetIs-ot-lc2_kWdDATEduiY-uEjF6FWUwYgsKHCrrqwbgKnx0qu7gdhghqJ3-WwgOywlwyQL8EUSxhYFJqkBuISTpyUEdBmcEAjKgdG9wDiajzsHF32awTmVQCVKbS45knI4rnNQj6o37h7JX1IU9p0ZLl5s8SQ4HhwwD6yRRdFgrCk811LIfSWhBOadQNX2AqODGAKU9Dz30BqZMlmrrh0yoGltandsYsNcsQTCgP6a6kFW9tSIx2PN7ox7PpPL2DIos6VS-7qpGgxHwvhxGmsBYLWqE9D3Q0oqx-oiAdxEgknU355Ld6PiJm_Q2K8SnS0Fk83laRDU2FRuQ.OP91ilVez0ErRRXt._Q2yVghXSubt44LbhS4iIF5A8vohaVasnsa_Xnx3cH6LU5kPr_gtSVNT5ZXV67mBz9xY2QNTlArtmR7z1yJrx2yftePxxDOBqz9Bdo189h1iQ_QrzLaGQogkuCFf2BuOAv4wYh6kJ4S835MXM6afNmItQcLV45LX_Bu02GuUa7syx9n4UU0KMKKpyEt79Fx-WN9BDrrQ-P-6eJTuiGi4x7d5O1Dp7ocV4CxgIA4faZznMi05fZsY4ebEP2O9VZ2zfMyv_KT7WeGyB2pcBfpupMGmKybqXweT8QdoFSMnfmE_vqnIxQzFHHVOYrrrUKu9T-294TicUdmgohqIAWzBq0_dm9HFrdD_BnHfOtfnFJn_uHdtsTPPA7L54Mb_81ijLrooZvbrIPXZsJc_YLq65vkYWtdbfA5JDZIK8jDZr-79YyBIrqsYgn3w2LwgNHuKU0Ro-zheV208xCsKYbOooX4E86YAgeltwt_W-VyD-06fKpADUyN2p8ck3AG5k2FV2LUJ7ZSkesixprcOmzDDIjmMrKFyqsbEj0Fwm4kk-RNO3M8T2b00IEdcrP3EUoDm6CG-Ur7NNOosR-7xuK4wnH9KN8x9ePRJeil3G0zWNpIsV9dQhDQaP42HdYfyJ28LLn1tBn9aG_L8Erp7_Yv0Y21VrpoNLLnsptms4N82le3iYXN6Rlk-R6Mv04SupNEOoOFG3NpPa7NF-phcoR65BIFgjonTLPabEanAxu3vBhqGiX9N9A57N1av10cjVhqiOY-FxUTlubIDaw00F1974AuDGhx5bWllVr-68qXEpmatyee8j7tJd1XlEvHy6CpDrOFh-fEFKuwy_e0iMKPEF_Jj2vdX5sb8DAriAkoUY_m9zL29RNZzhdsbWamUMlIFlObkj09f_Db1P-FdolZga3xfdteUOB5Ig9vecEm7B9iE4hBJ4HcJY8yMGT1XS1_b9MSmWkUz4wf_3DHbgK6EUlDjqrLfHWBWZr7--stcFlVmxChu5wL5zAqIDoIcoAT_yIU8DgeRknZImbRAhXBtGoGvcCoRktLTNwoul4I.5SToM5GtHWm-beaADd1uhg";
        let jwe = Jwe::decode(token, &pk).unwrap();

        #[derive(Deserialize)]
        struct SomeJetClaims {
            jet_ap: String,
            prx_usr: String,
            nbf: i64,
        }

        // The decrypted JWE payload is itself a compact JWS; parse it without signature checks.
        let payload = core::str::from_utf8(&jwe.payload).unwrap();
        let jwt = JwtSig::<SomeJetClaims>::decode_dangerous(payload, &JwtValidator::no_check()).unwrap();

        assert_eq!(jwt.claims.jet_ap, "rdp");
        assert_eq!(jwt.claims.prx_usr, "username");
        assert_eq!(jwt.claims.nbf, 1600373587);
    }
}