1use graph_error::{AuthorizationFailure, GraphFailure, AF};
2use serde::{Deserialize, Deserializer};
3use serde_aux::prelude::*;
4use serde_json::Value;
5use std::collections::HashMap;
6use std::fmt;
7use std::fmt::Display;
8use std::ops::{Add, Sub};
9
10use crate::identity::{AuthorizationResponse, IdToken};
11use graph_core::{cache::AsBearer, identity::Claims};
12use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
13use time::OffsetDateTime;
14
15fn deserialize_scope<'de, D>(scope: D) -> Result<Vec<String>, D::Error>
16where
17 D: Deserializer<'de>,
18{
19 let scope_string: Result<String, D::Error> = serde::Deserialize::deserialize(scope);
20 if let Ok(scope) = scope_string {
21 Ok(scope.split(' ').map(|scope| scope.to_owned()).collect())
22 } else {
23 Ok(vec![])
24 }
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
30struct PhantomToken {
31 access_token: String,
32 token_type: String,
33 #[serde(deserialize_with = "deserialize_number_from_string")]
34 expires_in: i64,
35 ext_expires_in: Option<i64>,
37 #[serde(default)]
38 #[serde(deserialize_with = "deserialize_scope")]
39 scope: Vec<String>,
40 refresh_token: Option<String>,
41 user_id: Option<String>,
42 id_token: Option<String>,
43 state: Option<String>,
44 session_state: Option<String>,
45 nonce: Option<String>,
46 correlation_id: Option<String>,
47 client_info: Option<String>,
48 #[serde(flatten)]
49 additional_fields: HashMap<String, Value>,
50}
51
52#[derive(Clone, Eq, PartialEq, Serialize)]
72pub struct Token {
73 pub access_token: String,
83 pub token_type: String,
84 #[serde(deserialize_with = "deserialize_number_from_string")]
85 pub expires_in: i64,
86 pub ext_expires_in: Option<i64>,
88 #[serde(default)]
89 #[serde(deserialize_with = "deserialize_scope")]
90 pub scope: Vec<String>,
91
92 pub refresh_token: Option<String>,
110 pub user_id: Option<String>,
111 pub id_token: Option<IdToken>,
112 pub state: Option<String>,
113 pub session_state: Option<String>,
114 pub nonce: Option<String>,
115 pub correlation_id: Option<String>,
116 pub client_info: Option<String>,
117 pub timestamp: Option<time::OffsetDateTime>,
118 pub expires_on: Option<time::OffsetDateTime>,
119 #[serde(flatten)]
121 pub additional_fields: HashMap<String, Value>,
122 #[serde(skip)]
123 pub log_pii: bool,
124}
125
126impl Token {
127 pub fn new<T: ToString, I: IntoIterator<Item = T>>(
128 token_type: &str,
129 expires_in: i64,
130 access_token: &str,
131 scope: I,
132 ) -> Token {
133 let timestamp = time::OffsetDateTime::now_utc();
134 let expires_on = timestamp.add(time::Duration::seconds(expires_in));
135
136 Token {
137 token_type: token_type.into(),
138 ext_expires_in: None,
139 expires_in,
140 scope: scope.into_iter().map(|s| s.to_string()).collect(),
141 access_token: access_token.into(),
142 refresh_token: None,
143 user_id: None,
144 id_token: None,
145 state: None,
146 session_state: None,
147 nonce: None,
148 correlation_id: None,
149 client_info: None,
150 timestamp: Some(timestamp),
151 expires_on: Some(expires_on),
152 additional_fields: Default::default(),
153 log_pii: false,
154 }
155 }
156
157 pub fn with_token_type(&mut self, s: &str) -> &mut Self {
167 self.token_type = s.into();
168 self
169 }
170
171 pub fn with_expires_in(&mut self, expires_in: i64) -> &mut Self {
181 self.expires_in = expires_in;
182 let timestamp = time::OffsetDateTime::now_utc();
183 self.expires_on = Some(timestamp.add(time::Duration::seconds(self.expires_in)));
184 self.timestamp = Some(timestamp);
185 self
186 }
187
188 pub fn with_scope<T: ToString, I: IntoIterator<Item = T>>(&mut self, scope: I) -> &mut Self {
198 self.scope = scope.into_iter().map(|s| s.to_string()).collect();
199 self
200 }
201
202 pub fn with_access_token(&mut self, s: &str) -> &mut Self {
212 self.access_token = s.into();
213 self
214 }
215
216 pub fn with_refresh_token(&mut self, s: &str) -> &mut Self {
226 self.refresh_token = Some(s.to_string());
227 self
228 }
229
230 pub fn with_user_id(&mut self, s: &str) -> &mut Self {
240 self.user_id = Some(s.to_string());
241 self
242 }
243
244 pub fn set_id_token(&mut self, s: &str) -> &mut Self {
254 self.id_token = Some(IdToken::new(s, None, None, None));
255 self
256 }
257
258 pub fn with_id_token(&mut self, id_token: IdToken) {
268 self.id_token = Some(id_token);
269 }
270
271 pub fn with_state(&mut self, s: &str) -> &mut Self {
282 self.state = Some(s.to_string());
283 self
284 }
285
286 pub fn enable_pii_logging(&mut self, log_pii: bool) {
292 self.log_pii = log_pii;
293 }
294
295 pub fn gen_timestamp(&mut self) {
323 let timestamp = time::OffsetDateTime::now_utc();
324 let expires_on = timestamp.add(time::Duration::seconds(self.expires_in));
325 self.timestamp = Some(timestamp);
326 self.expires_on = Some(expires_on);
327 }
328
329 pub fn is_expired(&self) -> bool {
340 if let Some(expires_on) = self.expires_on.as_ref() {
341 expires_on.lt(&OffsetDateTime::now_utc())
342 } else {
343 false
344 }
345 }
346
347 pub fn is_expired_sub(&self, duration: time::Duration) -> bool {
359 if let Some(expires_on) = self.expires_on.as_ref() {
360 expires_on.sub(duration).lt(&OffsetDateTime::now_utc())
361 } else {
362 false
363 }
364 }
365
366 pub fn elapsed(&self) -> Option<time::Duration> {
379 Some(self.expires_on? - self.timestamp?)
380 }
381
382 pub fn decode_header(&self) -> jsonwebtoken::errors::Result<jsonwebtoken::Header> {
383 let id_token = self
384 .id_token
385 .as_ref()
386 .ok_or(jsonwebtoken::errors::Error::from(
387 jsonwebtoken::errors::ErrorKind::InvalidToken,
388 ))?;
389 jsonwebtoken::decode_header(id_token.as_ref())
390 }
391
392 pub fn decode(
394 &self,
395 n: &str,
396 e: &str,
397 client_id: &str,
398 issuer: &str,
399 ) -> jsonwebtoken::errors::Result<TokenData<Claims>> {
400 let id_token = self
401 .id_token
402 .as_ref()
403 .ok_or(jsonwebtoken::errors::Error::from(
404 jsonwebtoken::errors::ErrorKind::InvalidToken,
405 ))?;
406 let mut validation = Validation::new(Algorithm::RS256);
407 validation.set_audience(&[client_id]);
408 validation.set_issuer(&[issuer]);
409
410 jsonwebtoken::decode::<Claims>(
411 id_token.as_ref(),
412 &DecodingKey::from_rsa_components(n, e).unwrap(),
413 &validation,
414 )
415 }
416}
417
418impl Default for Token {
419 fn default() -> Self {
420 Token {
421 token_type: String::new(),
422 expires_in: 0,
423 ext_expires_in: None,
424 scope: vec![],
425 access_token: String::new(),
426 refresh_token: None,
427 user_id: None,
428 id_token: None,
429 state: None,
430 session_state: None,
431 nonce: None,
432 correlation_id: None,
433 client_info: None,
434 timestamp: Some(time::OffsetDateTime::now_utc()),
435 expires_on: Some(
436 OffsetDateTime::from_unix_timestamp(0).unwrap_or(time::OffsetDateTime::UNIX_EPOCH),
437 ),
438 additional_fields: Default::default(),
439 log_pii: false,
440 }
441 }
442}
443
444impl TryFrom<AuthorizationResponse> for Token {
445 type Error = AuthorizationFailure;
446
447 fn try_from(value: AuthorizationResponse) -> Result<Self, Self::Error> {
448 let id_token = IdToken::try_from(value.clone()).ok();
449
450 Ok(Token {
451 access_token: value
452 .access_token
453 .ok_or_else(|| AF::msg_err("access_token", "access_token is None"))?,
454 token_type: "Bearer".to_string(),
455 expires_in: value.expires_in.unwrap_or_default(),
456 ext_expires_in: None,
457 scope: vec![],
458 refresh_token: None,
459 user_id: None,
460 id_token,
461 state: value.state,
462 session_state: value.session_state,
463 nonce: value.nonce,
464 correlation_id: None,
465 client_info: None,
466 timestamp: None,
467 expires_on: None,
468 additional_fields: Default::default(),
469 log_pii: false,
470 })
471 }
472}
473
474impl Display for Token {
475 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
476 write!(f, "{}", self.access_token)
477 }
478}
479
480impl AsBearer for Token {
481 fn as_bearer(&self) -> String {
482 self.access_token.to_string()
483 }
484}
485
486impl TryFrom<&str> for Token {
487 type Error = GraphFailure;
488
489 fn try_from(value: &str) -> Result<Self, Self::Error> {
490 Ok(serde_json::from_str(value)?)
491 }
492}
493
494impl TryFrom<reqwest::blocking::RequestBuilder> for Token {
495 type Error = GraphFailure;
496
497 fn try_from(value: reqwest::blocking::RequestBuilder) -> Result<Self, Self::Error> {
498 let response = value.send()?;
499 Token::try_from(response)
500 }
501}
502
503impl TryFrom<Result<reqwest::blocking::Response, reqwest::Error>> for Token {
504 type Error = GraphFailure;
505
506 fn try_from(
507 value: Result<reqwest::blocking::Response, reqwest::Error>,
508 ) -> Result<Self, Self::Error> {
509 let response = value?;
510 Token::try_from(response)
511 }
512}
513
514impl TryFrom<reqwest::blocking::Response> for Token {
515 type Error = GraphFailure;
516
517 fn try_from(value: reqwest::blocking::Response) -> Result<Self, Self::Error> {
518 Ok(value.json::<Token>()?)
519 }
520}
521
522impl fmt::Debug for Token {
523 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
524 if self.log_pii {
525 f.debug_struct("MsalAccessToken")
526 .field("bearer_token", &self.access_token)
527 .field("refresh_token", &self.refresh_token)
528 .field("token_type", &self.token_type)
529 .field("expires_in", &self.expires_in)
530 .field("scope", &self.scope)
531 .field("user_id", &self.user_id)
532 .field("id_token", &self.id_token)
533 .field("state", &self.state)
534 .field("timestamp", &self.timestamp)
535 .field("expires_on", &self.expires_on)
536 .field("additional_fields", &self.additional_fields)
537 .finish()
538 } else {
539 f.debug_struct("MsalAccessToken")
540 .field(
541 "bearer_token",
542 &"[REDACTED] - call enable_pii_logging(true) to log value",
543 )
544 .field(
545 "refresh_token",
546 &"[REDACTED] - call enable_pii_logging(true) to log value",
547 )
548 .field("token_type", &self.token_type)
549 .field("expires_in", &self.expires_in)
550 .field("scope", &self.scope)
551 .field("user_id", &self.user_id)
552 .field(
553 "id_token",
554 &"[REDACTED] - call enable_pii_logging(true) to log value",
555 )
556 .field("state", &self.state)
557 .field("timestamp", &self.timestamp)
558 .field("expires_on", &self.expires_on)
559 .field("additional_fields", &self.additional_fields)
560 .finish()
561 }
562 }
563}
564
565impl AsRef<str> for Token {
566 fn as_ref(&self) -> &str {
567 self.access_token.as_str()
568 }
569}
570
571impl<'de> Deserialize<'de> for Token {
572 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
573 where
574 D: Deserializer<'de>,
575 {
576 let phantom_access_token: PhantomToken = Deserialize::deserialize(deserializer)?;
577 let timestamp = OffsetDateTime::now_utc();
578 let expires_on = timestamp.add(time::Duration::seconds(phantom_access_token.expires_in));
579 let id_token = phantom_access_token
580 .id_token
581 .map(|id_token_string| IdToken::new(id_token_string.as_ref(), None, None, None));
582
583 let token = Token {
584 access_token: phantom_access_token.access_token,
585 token_type: phantom_access_token.token_type,
586 expires_in: phantom_access_token.expires_in,
587 ext_expires_in: phantom_access_token.ext_expires_in,
588 scope: phantom_access_token.scope,
589 refresh_token: phantom_access_token.refresh_token,
590 user_id: phantom_access_token.user_id,
591 id_token,
592 state: phantom_access_token.state,
593 session_state: phantom_access_token.session_state,
594 nonce: phantom_access_token.nonce,
595 correlation_id: phantom_access_token.correlation_id,
596 client_info: phantom_access_token.client_info,
597 timestamp: Some(timestamp),
598 expires_on: Some(expires_on),
599 additional_fields: phantom_access_token.additional_fields,
600 log_pii: false,
601 };
602
603 Ok(token)
606 }
607}
608
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn is_expired_test() {
        // A negative lifetime puts `expires_on` in the past, so no sleeping is needed.
        let mut access_token = Token::default();
        access_token.with_expires_in(-1);
        assert!(access_token.is_expired());

        let mut access_token = Token::default();
        access_token.with_expires_in(8);
        assert!(!access_token.is_expired());
        // A window larger than the remaining lifetime reports the token as expired early.
        assert!(access_token.is_expired_sub(time::Duration::seconds(10)));
        assert!(!access_token.is_expired_sub(time::Duration::seconds(1)));
    }

    pub const ACCESS_TOKEN_INT: &str = r#"{
        "access_token": "fasdfasdfasfdasdfasfsdf",
        "token_type": "Bearer",
        "expires_in": 65874,
        "scope": null,
        "refresh_token": null,
        "user_id": "santa@north.pole.com",
        "id_token": "789aasdf-asdf",
        "state": null,
        "timestamp": "2020-10-27T16:31:38.788098400Z"
    }"#;

    pub const ACCESS_TOKEN_STRING: &str = r#"{
        "access_token": "fasdfasdfasfdasdfasfsdf",
        "token_type": "Bearer",
        "expires_in": "65874",
        "scope": null,
        "refresh_token": null,
        "user_id": "helpers@north.pole.com",
        "id_token": "789aasdf-asdf",
        "state": null,
        "timestamp": "2020-10-27T16:31:38.788098400Z"
    }"#;

    #[test]
    pub fn test_deserialize() {
        let token: Token = serde_json::from_str(ACCESS_TOKEN_INT).unwrap();
        assert_eq!(token.access_token, "fasdfasdfasfdasdfasfsdf");
        assert_eq!(token.token_type, "Bearer");
        assert_eq!(token.expires_in, 65874);
        assert!(token.scope.is_empty());
        assert_eq!(token.refresh_token, None);
        assert_eq!(token.user_id.as_deref(), Some("santa@north.pole.com"));
        assert_eq!(
            token.id_token,
            Some(IdToken::new("789aasdf-asdf", None, None, None))
        );
        // Deserialization stamps the token locally, so it is fresh.
        assert!(token.timestamp.is_some());
        assert!(!token.is_expired());

        // `expires_in` may also arrive as a numeric string.
        let token: Token = serde_json::from_str(ACCESS_TOKEN_STRING).unwrap();
        assert_eq!(token.expires_in, 65874);
        assert_eq!(token.user_id.as_deref(), Some("helpers@north.pole.com"));
    }

    #[test]
    pub fn try_from_url_authorization_response() {
        let authorization_response = AuthorizationResponse {
            code: Some("code".into()),
            id_token: Some("id_token".into()),
            expires_in: Some(3600),
            access_token: Some("token".into()),
            state: Some("state".into()),
            session_state: Some("session_state".into()),
            nonce: None,
            error: None,
            error_description: None,
            error_uri: None,
            additional_fields: Default::default(),
            log_pii: false,
        };

        let token = Token::try_from(authorization_response).unwrap();
        assert_eq!(
            token.id_token,
            Some(IdToken::new(
                "id_token",
                Some("code"),
                Some("state"),
                Some("session_state")
            ))
        );
        assert_eq!(token.access_token, "token".to_string());
        assert_eq!(token.state, Some("state".to_string()));
        assert_eq!(token.session_state, Some("session_state".to_string()));
        assert_eq!(token.expires_in, 3600);
    }
}