1use crate::error::jwt_error::JWTError;
11use crate::parser::jwt_claims_parser::JWTClaimsParser;
12use chrono::{DateTime, Duration, TimeZone, Utc};
13use jwt_compact::{Claims, Token};
14use jwt_simple::prelude::{Audiences, UnixTimeStamp};
15use pdk_core::logger::debug;
16use serde::{Deserialize, Serialize};
17use serde_json::Value;
18use std::collections::HashMap;
19
/// Payload key names of the claims this module gives dedicated accessors to.
///
/// All of them except `nonce` are registered claim names from RFC 7519 §4.1.
pub mod claim_names {
    // --- Identity and audience claims -------------------------------------------------

    /// `iss`: the principal that issued the token.
    pub const ISSUER_CLAIM: &'static str = "iss";
    /// `sub`: the principal the token is about.
    pub const SUBJECT_CLAIM: &'static str = "sub";
    /// `aud`: intended recipient(s); a single string or an array of strings.
    pub const AUD_CLAIM: &'static str = "aud";
    /// `jti`: unique identifier of the token.
    pub const JTI_CLAIM: &'static str = "jti";
    /// `nonce`: value binding the token to a request.
    pub const NONCE_CLAIM: &'static str = "nonce";

    // --- Time-based claims ------------------------------------------------------------

    /// `exp`: expiration time.
    pub const EXPIRATION_CLAIM: &'static str = "exp";
    /// `nbf`: time before which the token must not be accepted.
    pub const NOT_BEFORE_CLAIM: &'static str = "nbf";
    /// `iat`: time at which the token was issued.
    pub const ISSUED_AT_CLAIM: &'static str = "iat";
}
39
40use claim_names::*;
41use pdk_script::IntoValue;
42
/// Free-form claim payload used as the custom part of `jwt_compact` / `jwt_simple` claim sets
/// (see the `TryFrom*` conversions in this module).
///
/// Thanks to `#[serde(flatten)]`, the map entries are read from and written to the top level of
/// the enclosing JSON object instead of being nested under an `inner` key.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct CustomClaims {
    /// Claim name -> raw JSON value.
    #[serde(flatten)]
    pub inner: HashMap<String, serde_json::Value>,
}
50
/// Decoded JWT claims together with the header parameters of the token they came from.
///
/// `claims` holds the payload claims. The `exp`, `nbf` and `iat` entries written by
/// [`JWTClaims::new`] are integer nanosecond timestamps, not the seconds used in the encoded
/// token.
///
/// Both maps are `#[serde(flatten)]`ed, so serializing this struct (see `to_serde`) yields a
/// single JSON object with claim and header keys side by side.
/// NOTE(review): a key present in both maps (e.g. a `typ` claim) is emitted twice into the same
/// object. Which value survives depends on the serializer; confirm this is acceptable.
#[derive(Clone, Debug, PartialEq, Eq, Serialize)]
pub struct JWTClaims {
    /// Payload claims keyed by claim name.
    #[serde(flatten)]
    pub(crate) claims: HashMap<String, serde_json::Value>,
    /// Header parameters (e.g. `alg`, `typ`) keyed by name.
    #[serde(flatten)]
    pub(crate) headers: HashMap<String, serde_json::Value>,
}
60
61impl JWTClaims {
62 pub fn new(
65 expiration: Option<DateTime<Utc>>,
66 not_before: Option<DateTime<Utc>>,
67 issued_at: Option<DateTime<Utc>>,
68 claims: HashMap<String, Value>,
69 headers: HashMap<String, Value>,
70 ) -> Result<Self, JWTError> {
71 let mut claims = claims;
72
73 extract_datetime_value(EXPIRATION_CLAIM, &expiration, &mut claims);
74 extract_datetime_value(NOT_BEFORE_CLAIM, ¬_before, &mut claims);
75 extract_datetime_value(ISSUED_AT_CLAIM, &issued_at, &mut claims);
76
77 Ok(JWTClaims { claims, headers })
78 }
79
80 pub fn audience(&self) -> Option<Result<Vec<String>, JWTError>> {
82 self.claims.get(AUD_CLAIM).map(parse_audiences_to_list)
83 }
84
85 pub fn not_before(&self) -> Option<DateTime<Utc>> {
87 self.get_claim(NOT_BEFORE_CLAIM)
88 }
89
90 pub fn expiration(&self) -> Option<DateTime<Utc>> {
92 self.get_claim(EXPIRATION_CLAIM)
93 }
94
95 pub fn issued_at(&self) -> Option<DateTime<Utc>> {
97 self.get_claim(ISSUED_AT_CLAIM)
98 }
99
100 pub fn issuer(&self) -> Option<String> {
102 self.get_claim(ISSUER_CLAIM)
103 }
104
105 pub fn jti(&self) -> Option<String> {
107 self.get_claim(JTI_CLAIM)
108 }
109
110 pub fn nonce(&self) -> Option<String> {
112 self.get_claim(NONCE_CLAIM)
113 }
114
115 pub fn subject(&self) -> Option<String> {
117 self.get_claim(SUBJECT_CLAIM)
118 }
119
120 pub fn has_claim(&self, name: &str) -> bool {
122 self.claims.contains_key(name)
123 }
124
125 pub fn get_claim<T>(&self, name: &str) -> Option<T>
127 where
128 T: ValueRetrieval,
129 {
130 let value = self.claims.get(name);
131
132 match value {
133 Some(val) => T::retrieve(val.clone()),
134 None => None,
135 }
136 }
137
138 pub fn has_header(&self, name: &str) -> bool {
140 self.headers.contains_key(name)
141 }
142
143 pub fn get_header(&self, name: &str) -> Option<String> {
145 let value = self.headers.get(name);
146
147 match value {
148 Some(val) => val.as_str().map(|str| str.into()),
149 None => None,
150 }
151 }
152
153 pub fn to_serde(&self) -> Result<Value, JWTError> {
155 serde_json::to_value(self)
156 .map_err(|e| JWTError::TokenParseFailed(format!("Error serializing claims: {e}")))
157 }
158
159 pub fn get_claims(&self) -> pdk_script::Value {
161 self.claims.clone().into_value()
162 }
163
164 pub fn get_headers(&self) -> pdk_script::Value {
166 self.headers.clone().into_value()
167 }
168}
169
170fn parse_audiences_to_list(aud: &Value) -> Result<Vec<String>, JWTError> {
171 let correct_type = if aud.is_string() || aud.is_array() {
172 Ok(aud)
173 } else {
174 Err(JWTError::ClaimAudValidationFailed)
175 }?;
176
177 if correct_type.is_string() {
178 let aud_str = correct_type
179 .as_str()
180 .ok_or(JWTError::ClaimAudValidationFailed)?;
181 Ok(vec![String::from(aud_str)])
182 } else {
183 let aud_array = correct_type
184 .as_array()
185 .ok_or(JWTError::ClaimAudValidationFailed)?;
186
187 Ok(aud_array
188 .iter()
189 .filter_map(|v| v.as_str().map(String::from))
190 .collect::<Vec<String>>())
191 }
192}
193
/// Conversion from a bare `jwt_compact` [`Claims`] set.
///
/// The implementation for [`JWTClaims`] performs no signature check; it only copies the claims
/// and re-reads the header parameters from the encoded token.
pub trait TryFromUntrustedJWTCompactClaims {
    /// Builds [`JWTClaims`] from `claims`, taking header parameters from the encoded `token`.
    ///
    /// # Errors
    /// The implementation for [`JWTClaims`] propagates failures of
    /// `JWTClaimsParser::parse_headers`.
    fn try_from_untrusted_jwt_compact_claims(
        claims: Claims<CustomClaims>,
        token: String,
    ) -> Result<JWTClaims, JWTError>;
}
201
202impl TryFromUntrustedJWTCompactClaims for JWTClaims {
203 fn try_from_untrusted_jwt_compact_claims(
204 claims: Claims<CustomClaims>,
205 token: String,
206 ) -> Result<JWTClaims, JWTError> {
207 let headers: HashMap<String, Value> = JWTClaimsParser::parse_headers(token)?;
208
209 JWTClaims::new(
210 claims.expiration,
211 claims.not_before,
212 claims.issued_at,
213 claims.custom.inner,
214 headers,
215 )
216 }
217}
218
/// Conversion from a `jwt_compact` [`Token`].
///
/// The "trusted" in the name suggests callers pass a token whose signature has already been
/// checked. This conversion itself performs no verification.
pub trait TryFromTrustedJWTCompactClaims {
    /// Builds [`JWTClaims`] from the token's claim set, taking header parameters from the
    /// encoded `token`.
    ///
    /// # Errors
    /// The implementation for [`JWTClaims`] propagates failures of
    /// `JWTClaimsParser::parse_headers`.
    fn try_from_trusted_jwt_compact_claims(
        claims: Token<CustomClaims>,
        token: String,
    ) -> Result<JWTClaims, JWTError>;
}
226
227impl TryFromTrustedJWTCompactClaims for JWTClaims {
228 fn try_from_trusted_jwt_compact_claims(
229 trusted_token: Token<CustomClaims>,
230 token_str: String,
231 ) -> Result<Self, JWTError> {
232 let claims = trusted_token.claims().to_owned();
233
234 JWTClaims::try_from_untrusted_jwt_compact_claims(claims, token_str)
235 }
236}
237
/// Conversion from a `jwt_simple` claim set.
///
/// The "trusted" in the name suggests callers pass claims from an already-verified token. This
/// conversion itself performs no verification. The implementation for [`JWTClaims`] folds the
/// dedicated `jwt_simple` fields (`iss`, `sub`, `jti`, `nonce`, `aud`, `exp`, `nbf`, `iat`)
/// back into the claim map.
pub trait TryFromTrustedJWTSimpleClaims {
    /// Builds [`JWTClaims`] from a `jwt_simple` claim set, taking header parameters from the
    /// encoded `token`.
    ///
    /// # Errors
    /// The implementation for [`JWTClaims`] propagates failures of
    /// `JWTClaimsParser::parse_headers`.
    fn try_from_trusted_jwt_simple_claims(
        claims: jwt_simple::claims::JWTClaims<CustomClaims>,
        token: String,
    ) -> Result<JWTClaims, JWTError>;
}
245
246impl TryFromTrustedJWTSimpleClaims for JWTClaims {
247 fn try_from_trusted_jwt_simple_claims(
248 claims: jwt_simple::claims::JWTClaims<CustomClaims>,
249 token: String,
250 ) -> Result<Self, JWTError> {
251 let expiration = utc_from_duration(&claims.expires_at);
252 let not_before = utc_from_duration(&claims.invalid_before);
253 let issued_at = utc_from_duration(&claims.issued_at);
254 let mut custom_claims = claims.custom.inner;
255
256 extract_value(ISSUER_CLAIM, &claims.issuer, &mut custom_claims);
257 extract_value(SUBJECT_CLAIM, &claims.subject, &mut custom_claims);
258 extract_value(JTI_CLAIM, &claims.jwt_id, &mut custom_claims);
259 extract_value(NONCE_CLAIM, &claims.nonce, &mut custom_claims);
260 parse_custom_claim_audiences_to_list(claims.audiences, &mut custom_claims);
261
262 let headers: HashMap<String, Value> = JWTClaimsParser::parse_headers(token)?;
263
264 JWTClaims::new(expiration, not_before, issued_at, custom_claims, headers)
265 }
266}
267
268fn parse_custom_claim_audiences_to_list(
269 aud: Option<Audiences>,
270 custom_claims: &mut HashMap<String, Value>,
271) {
272 if let Some(audiences) = aud {
273 match audiences {
274 Audiences::AsSet(set) => {
275 if let Ok(value_set) = serde_json::to_value(set) {
276 custom_claims.insert(AUD_CLAIM.to_string(), value_set);
277 }
278 }
279 Audiences::AsString(str) => {
280 custom_claims.insert(AUD_CLAIM.to_string(), Value::String(str));
281 }
282 }
283 }
284}
285
286pub trait ValueRetrieval {
289 fn retrieve(claim_value: serde_json::Value) -> Option<Self>
290 where
291 Self: Sized;
292}
293
294impl ValueRetrieval for Value {
295 fn retrieve(claim_value: serde_json::Value) -> Option<Self> {
296 Some(claim_value)
297 }
298}
299
300impl ValueRetrieval for String {
301 fn retrieve(claim_value: serde_json::Value) -> Option<Self> {
302 Some(claim_value.as_str().unwrap().into())
303 }
304}
305
306impl ValueRetrieval for f64 {
307 fn retrieve(claim_value: serde_json::Value) -> Option<Self> {
308 claim_value.as_f64()
309 }
310}
311
312impl ValueRetrieval for Vec<String> {
313 fn retrieve(claim_value: serde_json::Value) -> Option<Self> {
314 claim_value.as_array().map(|val| {
315 val.iter()
316 .map(|item| item.as_str().unwrap().into())
317 .collect()
318 })
319 }
320}
321
322impl ValueRetrieval for DateTime<Utc> {
323 fn retrieve(claim_value: serde_json::Value) -> Option<Self> {
324 claim_value.as_i64().map(|val| Utc.timestamp_nanos(val))
325 }
326}
327
328fn extract_datetime_value(
331 key: &str,
332 datetime: &Option<DateTime<Utc>>,
333 custom_claims: &mut HashMap<String, Value>,
334) {
335 if datetime.is_some() {
336 custom_claims.insert(
337 key.to_string(),
338 Value::from(datetime.unwrap().timestamp_nanos_opt().unwrap()),
341 );
342 }
343}
344
345fn extract_value(
346 key: &str,
347 claim_value: &Option<String>,
348 custom_claims: &mut HashMap<String, Value>,
349) {
350 if let Some(claim) = claim_value {
351 let value = Value::String(claim.to_string());
352 custom_claims.insert(key.to_string(), value);
353 }
354}
355
356fn utc_from_duration(duration: &Option<UnixTimeStamp>) -> Option<DateTime<Utc>> {
357 let start_utc_date = match Utc.with_ymd_and_hms(1970, 1, 1, 0, 0, 0).single() {
358 None => {
359 debug!("Using default start_utc_date");
361 DateTime::default()
362 }
363 Some(date) => date,
364 };
365
366 duration
367 .as_ref()
368 .and_then(|timestamp| Duration::try_milliseconds(timestamp.as_millis() as i64))
369 .map(|some_duration| start_utc_date + some_duration)
370}
371
#[cfg(test)]
mod tests {
    use crate::api::JWTClaimsParser;
    use pdk_script::IntoValue;
    use serde_json::json;

    /// HS256 token (`alg: HS256`, `typ: JWT` header) carrying the same claims as
    /// [`no_exp_token`] plus an `exp` claim. The tests below only rely on its header.
    fn valid_token() -> String {
        "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL3JlZGFjdGVkLmFuei5jb20iLCJhdWQiOiJyZWRhY3RlZC5hbnouY29tL2VtYWlsQWRkcmVzcy1yZWRhY3RlZEBhbnouY29tIiwic3ViIjoicmVkYWN0ZWQuYW56LmNvbS9lbWFpbEFkZHJlc3MtcmVkYWN0ZWRAYW56LmNvbSIsImV4cCI6MTY5Mzk4ODY1Ni41OTcsInNjb3BlcyI6WyJyZWRhY3RlZDEiLCJyZWRhY3RlZCJdLCJhbXIiOlsicmVkYWN0ZWQiXSwiYWNyIjoicmVkYWN0ZWQifQ.EswBSyD9976PVC6o4tWwsT5rGTf2RJcvL7hdpcTwXUo"
            .to_string()
    }

    /// HS256 token without time claims. Its full header and claim set are spelled out in
    /// `basic_methods`.
    fn no_exp_token() -> String {
        "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJodHRwczovL3JlZGFjdGVkLmFuei5jb20iLCJhdWQiOiJyZWRhY3RlZC5hbnouY29tL2VtYWlsQWRkcmVzcy1yZWRhY3RlZEBhbnouY29tIiwic3ViIjoicmVkYWN0ZWQuYW56LmNvbS9lbWFpbEFkZHJlc3MtcmVkYWN0ZWRAYW56LmNvbSIsInNjb3BlcyI6WyJyZWRhY3RlZDEiLCJyZWRhY3RlZCJdLCJhbXIiOlsicmVkYWN0ZWQiXSwiYWNyIjoicmVkYWN0ZWQifQ.vxNPrXPIG8VuGb9Bf2MjLFXPRX9EmN97WZtM-jzQniE"
            .to_string()
    }

    /// Test-only constructors and inspectors for [`JWTClaims`].
    mod tests_accessors {
        use crate::{error::jwt_error::JWTError, model::claims::JWTClaims};
        use chrono::{DateTime, Utc};
        use serde_json::Value;
        use std::collections::HashMap;

        impl JWTClaims {
            /// Claim set with no claims and no headers.
            pub fn empty() -> Result<Self, JWTError> {
                JWTClaims::new(None, None, None, HashMap::default(), HashMap::default())
            }

            /// Claim set containing only an `exp` claim.
            pub fn just_expiration(expiration: DateTime<Utc>) -> Result<Self, JWTError> {
                JWTClaims::new(
                    Some(expiration),
                    None,
                    None,
                    HashMap::default(),
                    HashMap::default(),
                )
            }

            /// Claim set containing only an `nbf` claim.
            pub fn just_not_before(not_before: DateTime<Utc>) -> Result<Self, JWTError> {
                JWTClaims::new(
                    None,
                    Some(not_before),
                    None,
                    HashMap::default(),
                    HashMap::default(),
                )
            }

            /// Claim set built from `custom_claims` alone: no time claims, no headers.
            pub fn from_custom_claims(
                custom_claims: HashMap<String, Value>,
            ) -> Result<Self, JWTError> {
                JWTClaims::new(None, None, None, custom_claims, HashMap::default())
            }

            /// Number of entries in the claim map (time claims included, headers excluded).
            pub fn custom_claims_count(&self) -> usize {
                self.claims.len()
            }
        }
    }

    /// Headers and claims live in separate maps: `alg` is only a header and `iss` only a
    /// claim. Each accessor exposes just its own map.
    #[test]
    fn basic_methods() {
        let claims = JWTClaimsParser::parse(no_exp_token()).unwrap();

        assert!(claims.has_header("alg"));
        assert!(!claims.has_claim("alg"));
        assert!(claims.has_claim("iss"));
        assert!(!claims.has_header("iss"));

        assert_eq!(
            claims.get_headers(),
            json!({
                "alg": "HS256",
                "typ": "JWT"
            })
            .into_value()
        );

        assert_eq!(
            claims.get_claims(),
            json!(
                {
                    "iss": "https://redacted.anz.com",
                    "aud": "redacted.anz.com/emailAddress-redacted@anz.com",
                    "sub": "redacted.anz.com/emailAddress-redacted@anz.com",
                    "scopes": [
                        "redacted1",
                        "redacted"
                    ],
                    "amr": [
                        "redacted"
                    ],
                    "acr": "redacted"
                }
            )
            .into_value()
        );
    }

    /// `JWTClaims::audience` accepts a string or an array and rejects every other JSON type.
    ///
    /// The glob import brings the `Value` variants into the value namespace. The explicit
    /// imports of `Number` and `std::string::String` shadow the glob in the type namespace.
    /// So `String(..)`/`Number(..)` build JSON variants, while `String::from` / `Number::from`
    /// refer to the std / serde_json types.
    mod audience_claim_extraction {
        use crate::model::claims::{claim_names::AUD_CLAIM, JWTClaims};
        use serde_json::Value::*;
        use serde_json::{Map, Number, Value};
        use std::collections::HashMap;
        use std::string::String;

        #[test]
        fn succeeds_on_string_type() {
            let token_claims = JWTClaims::from_custom_claims(HashMap::from([(
                AUD_CLAIM.to_string(),
                String(String::from("an_audience")),
            )]))
            .unwrap();

            assert!(token_claims.audience().is_some());
            assert!(token_claims.audience().unwrap().is_ok());
            assert_eq!(
                token_claims.audience().unwrap().unwrap(),
                vec![String::from("an_audience")]
            );
        }

        #[test]
        fn succeeds_on_list_type() {
            let token_claims = JWTClaims::from_custom_claims(HashMap::from([(
                AUD_CLAIM.to_string(),
                Array(vec![
                    String(String::from("an_audience1")),
                    String(String::from("an_audience2")),
                ]),
            )]))
            .unwrap();

            assert!(token_claims.audience().is_some());
            assert!(token_claims.audience().unwrap().is_ok());
            assert_eq!(
                token_claims.audience().unwrap().unwrap(),
                vec![String::from("an_audience1"), String::from("an_audience2")]
            );
        }

        #[test]
        fn fails_on_bool_type() {
            let token_claims =
                JWTClaims::from_custom_claims(HashMap::from([(AUD_CLAIM.to_string(), Bool(true))]))
                    .unwrap();

            assert!(token_claims.audience().is_some());
            assert!(token_claims.audience().unwrap().is_err());
        }

        #[test]
        fn fails_on_number_type() {
            let token_claims = JWTClaims::from_custom_claims(HashMap::from([(
                AUD_CLAIM.to_string(),
                Number(Number::from(1u64)),
            )]))
            .unwrap();

            assert!(token_claims.audience().is_some());
            assert!(token_claims.audience().unwrap().is_err());
        }

        #[test]
        fn fails_on_null_type() {
            let token_claims =
                JWTClaims::from_custom_claims(HashMap::from([(AUD_CLAIM.to_string(), Null)]))
                    .unwrap();

            assert!(token_claims.audience().is_some());
            assert!(token_claims.audience().unwrap().is_err());
        }

        #[test]
        fn fails_on_object_type() {
            let token_claims = JWTClaims::from_custom_claims(HashMap::from([(
                AUD_CLAIM.to_string(),
                Object(Map::<String, Value>::new()),
            )]))
            .unwrap();

            assert!(token_claims.audience().is_some());
            assert!(token_claims.audience().unwrap().is_err());
        }
    }

    /// Converts `jwt_simple` claim sets with `try_from_trusted_jwt_simple_claims`.
    /// `valid_token()` only supplies the header map.
    mod jwt_simple_custom_claims_parsing {
        use super::super::*;

        #[test]
        fn parse_claims_and_exp() {
            let claims: jwt_simple::claims::JWTClaims<CustomClaims> = serde_json::from_str(r#"{"exp":1617757825, "jti": "key1", "sub": "anSubject", "iss": "anIss", "nonce": "anNonce"}"#).unwrap();

            let expiration = Duration::try_seconds(1617757825).unwrap();

            let jwt_claims: JWTClaims =
                JWTClaims::try_from_trusted_jwt_simple_claims(claims, super::valid_token())
                    .unwrap();

            // Generic accessor. Time claims are compared at millisecond precision because the
            // conversion goes through milliseconds.
            assert_eq!(
                jwt_claims
                    .get_claim::<DateTime<Utc>>("exp")
                    .unwrap()
                    .timestamp_millis(),
                expiration.num_milliseconds()
            );
            assert_eq!(jwt_claims.get_claim::<String>("jti").unwrap(), "key1");
            assert_eq!(jwt_claims.get_claim::<String>("sub").unwrap(), "anSubject");
            assert_eq!(jwt_claims.get_claim::<String>("iss").unwrap(), "anIss");
            assert_eq!(jwt_claims.get_claim::<String>("nonce").unwrap(), "anNonce");
            assert!(jwt_claims.audience().is_none());

            // Dedicated accessors must agree with the generic one.
            assert_eq!(
                jwt_claims.expiration().unwrap().timestamp_millis(),
                expiration.num_milliseconds()
            );
            assert_eq!(jwt_claims.jti().unwrap(), "key1");
            assert_eq!(jwt_claims.subject().unwrap(), "anSubject");
            assert_eq!(jwt_claims.issuer().unwrap(), "anIss");
            assert_eq!(jwt_claims.nonce().unwrap(), "anNonce");
        }

        #[test]
        fn parse_custom_and_audiences_with_vec_claims() {
            let claims: jwt_simple::claims::JWTClaims<CustomClaims> = serde_json::from_str(r#"{"exp":1617757827, "otherClaim": "anCustomclaim", "aud": ["audience1", "audience2"]}"#).unwrap();

            let jwt_claims =
                JWTClaims::try_from_trusted_jwt_simple_claims(claims, super::valid_token())
                    .unwrap();

            assert_eq!(
                jwt_claims.get_claim::<String>("otherClaim").unwrap(),
                "anCustomclaim"
            );

            assert!(jwt_claims.audience().is_some());

            let audiences = jwt_claims.audience().unwrap();
            assert!(audiences.is_ok());

            let audiences_vec = audiences.unwrap();

            // Order is not asserted: the audiences go through a set.
            assert_eq!(audiences_vec.len(), 2);
            assert!(audiences_vec.contains(&String::from("audience1")));
            assert!(audiences_vec.contains(&String::from("audience2")));
        }

        #[test]
        fn parse_audiences_with_string_claims() {
            let claims: jwt_simple::claims::JWTClaims<CustomClaims> =
                serde_json::from_str(r#"{"exp":1617757827, "aud": "audience1"}"#).unwrap();

            let jwt_claims =
                JWTClaims::try_from_trusted_jwt_simple_claims(claims, super::valid_token())
                    .unwrap();

            assert!(jwt_claims.audience().is_some());

            let audiences = jwt_claims.audience().unwrap();
            assert!(audiences.is_ok());

            let audiences_vec = audiences.unwrap();

            assert_eq!(audiences_vec.len(), 1);
            assert!(audiences_vec.contains(&String::from("audience1")));
        }

        #[test]
        fn parse_other_claims() {
            let claims: jwt_simple::claims::JWTClaims<CustomClaims> =
                serde_json::from_str(r#"{"exp":1617757827, "iat":1617757827, "nbf": 1617757827}"#)
                    .unwrap();

            let duration = Duration::try_seconds(1617757827).unwrap();

            let jwt_claims =
                JWTClaims::try_from_trusted_jwt_simple_claims(claims, super::valid_token())
                    .unwrap();

            // Generic accessor for each time claim.
            assert_eq!(
                jwt_claims
                    .get_claim::<DateTime<Utc>>("nbf")
                    .unwrap()
                    .timestamp_millis(),
                duration.num_milliseconds()
            );
            assert_eq!(
                jwt_claims
                    .get_claim::<DateTime<Utc>>("iat")
                    .unwrap()
                    .timestamp_millis(),
                duration.num_milliseconds()
            );
            assert_eq!(
                jwt_claims
                    .get_claim::<DateTime<Utc>>("exp")
                    .unwrap()
                    .timestamp_millis(),
                duration.num_milliseconds()
            );

            // Dedicated accessors must agree with the generic one.
            assert_eq!(
                jwt_claims.not_before().unwrap().timestamp_millis(),
                duration.num_milliseconds()
            );
            assert_eq!(
                jwt_claims.issued_at().unwrap().timestamp_millis(),
                duration.num_milliseconds()
            );
            assert_eq!(
                jwt_claims.expiration().unwrap().timestamp_millis(),
                duration.num_milliseconds()
            );

            // Header parameters come from the encoded token, not from the claim set.
            assert_eq!(jwt_claims.get_header("alg").unwrap(), "HS256",);

            assert_eq!(jwt_claims.get_header("typ").unwrap(), "JWT",);
        }
    }
}