graph_oauth/identity/
token.rs

1use graph_error::{AuthorizationFailure, GraphFailure, AF};
2use serde::{Deserialize, Deserializer};
3use serde_aux::prelude::*;
4use serde_json::Value;
5use std::collections::HashMap;
6use std::fmt;
7use std::fmt::Display;
8use std::ops::{Add, Sub};
9
10use crate::identity::{AuthorizationResponse, IdToken};
11use graph_core::{cache::AsBearer, identity::Claims};
12use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
13use time::OffsetDateTime;
14
15fn deserialize_scope<'de, D>(scope: D) -> Result<Vec<String>, D::Error>
16where
17    D: Deserializer<'de>,
18{
19    let scope_string: Result<String, D::Error> = serde::Deserialize::deserialize(scope);
20    if let Ok(scope) = scope_string {
21        Ok(scope.split(' ').map(|scope| scope.to_owned()).collect())
22    } else {
23        Ok(vec![])
24    }
25}
26
27// Used to set timestamp based on expires in
28// which can only be done after deserialization.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30struct PhantomToken {
31    access_token: String,
32    token_type: String,
33    #[serde(deserialize_with = "deserialize_number_from_string")]
34    expires_in: i64,
35    /// Legacy version of expires_in
36    ext_expires_in: Option<i64>,
37    #[serde(default)]
38    #[serde(deserialize_with = "deserialize_scope")]
39    scope: Vec<String>,
40    refresh_token: Option<String>,
41    user_id: Option<String>,
42    id_token: Option<String>,
43    state: Option<String>,
44    session_state: Option<String>,
45    nonce: Option<String>,
46    correlation_id: Option<String>,
47    client_info: Option<String>,
48    #[serde(flatten)]
49    additional_fields: HashMap<String, Value>,
50}
51
52/// An access token is a security token issued by an authorization server as part of an OAuth 2.0 flow.
53/// It contains information about the user and the resource for which the token is intended.
54/// The information can be used to access web APIs and other protected resources.
55/// Resources validate access tokens to grant access to a client application.
56/// For more information, see [Access tokens in the Microsoft Identity Platform](https://learn.microsoft.com/en-us/azure/active-directory/develop/access-tokens)
57///
58/// For more info from the specification see [Successful Response](https://www.rfc-editor.org/rfc/rfc6749.html#section-5.1)
59///
60/// Create a new AccessToken.
61/// # Example
62/// ```
63/// # use graph_oauth::Token;
64/// let token_response = Token::new("Bearer", 3600, "ASODFIUJ34KJ;LADSK", vec!["User.Read"]);
65/// ```
66/// The [Token::decode] method parses the id token into a JWT and returns it. Calling
67/// [Token::decode] when the [Token]'s `id_token` field is None returns an error result.
68/// For more info see:
69/// [Microsoft identity platform access tokens](https://docs.microsoft.com/en-us/azure/active-directory/develop/access-tokens)
70/// ```
71#[derive(Clone, Eq, PartialEq, Serialize)]
72pub struct Token {
73    /// Access tokens are credentials used to access protected resources.  An
74    /// access token is a string representing an authorization issued to the
75    /// client.  The string is usually opaque to the client.  Tokens
76    /// represent specific scopes and durations of access, granted by the
77    /// resource owner, and enforced by the resource server and authorization
78    /// server.
79    ///
80    /// See [Access Token](https://www.rfc-editor.org/rfc/rfc6749.html#section-1.4) in
81    /// the specification
82    pub access_token: String,
83    pub token_type: String,
84    #[serde(deserialize_with = "deserialize_number_from_string")]
85    pub expires_in: i64,
86    /// Legacy version of expires_in
87    pub ext_expires_in: Option<i64>,
88    #[serde(default)]
89    #[serde(deserialize_with = "deserialize_scope")]
90    pub scope: Vec<String>,
91
92    /// Refresh tokens are credentials used to obtain access tokens. Refresh tokens are issued
93    /// to the client by the authorization server and are used to obtain a new access token
94    /// when the current access token becomes invalid or expires, or to obtain additional
95    /// access tokens with identical or narrower scope (access tokens may have a shorter
96    /// lifetime and fewer permissions than authorized by the resource owner).
97    /// Issuing a refresh token is optional at the discretion of the authorization server.
98    /// If the authorization server issues a refresh token, it is included when issuing an
99    /// access token
100    ///
101    /// See [Refresh Token](https://www.rfc-editor.org/rfc/rfc6749.html#section-1.5) in the specification
102    ///
103    /// Because access tokens are valid for only a short period of time,
104    /// authorization servers sometimes issue a refresh token at the same
105    /// time the access token is issued. The client application can then
106    /// exchange this refresh token for a new access token when needed.
107    /// For more information, see
108    /// [Refresh tokens in the Microsoft identity platform.](https://learn.microsoft.com/en-us/azure/active-directory/develop/refresh-tokens)
109    pub refresh_token: Option<String>,
110    pub user_id: Option<String>,
111    pub id_token: Option<IdToken>,
112    pub state: Option<String>,
113    pub session_state: Option<String>,
114    pub nonce: Option<String>,
115    pub correlation_id: Option<String>,
116    pub client_info: Option<String>,
117    pub timestamp: Option<time::OffsetDateTime>,
118    pub expires_on: Option<time::OffsetDateTime>,
119    /// Any extra returned fields for AccessToken.
120    #[serde(flatten)]
121    pub additional_fields: HashMap<String, Value>,
122    #[serde(skip)]
123    pub log_pii: bool,
124}
125
126impl Token {
127    pub fn new<T: ToString, I: IntoIterator<Item = T>>(
128        token_type: &str,
129        expires_in: i64,
130        access_token: &str,
131        scope: I,
132    ) -> Token {
133        let timestamp = time::OffsetDateTime::now_utc();
134        let expires_on = timestamp.add(time::Duration::seconds(expires_in));
135
136        Token {
137            token_type: token_type.into(),
138            ext_expires_in: None,
139            expires_in,
140            scope: scope.into_iter().map(|s| s.to_string()).collect(),
141            access_token: access_token.into(),
142            refresh_token: None,
143            user_id: None,
144            id_token: None,
145            state: None,
146            session_state: None,
147            nonce: None,
148            correlation_id: None,
149            client_info: None,
150            timestamp: Some(timestamp),
151            expires_on: Some(expires_on),
152            additional_fields: Default::default(),
153            log_pii: false,
154        }
155    }
156
157    /// Set the token type.
158    ///
159    /// # Example
160    /// ```
161    /// # use graph_oauth::Token;
162    ///
163    /// let mut access_token = Token::default();
164    /// access_token.with_token_type("Bearer");
165    /// ```
166    pub fn with_token_type(&mut self, s: &str) -> &mut Self {
167        self.token_type = s.into();
168        self
169    }
170
171    /// Set the expies in time. This should usually be done in seconds.
172    ///
173    /// # Example
174    /// ```
175    /// # use graph_oauth::Token;
176    ///
177    /// let mut access_token = Token::default();
178    /// access_token.with_expires_in(3600);
179    /// ```
180    pub fn with_expires_in(&mut self, expires_in: i64) -> &mut Self {
181        self.expires_in = expires_in;
182        let timestamp = time::OffsetDateTime::now_utc();
183        self.expires_on = Some(timestamp.add(time::Duration::seconds(self.expires_in)));
184        self.timestamp = Some(timestamp);
185        self
186    }
187
188    /// Set the scope.
189    ///
190    /// # Example
191    /// ```
192    /// # use graph_oauth::Token;
193    ///
194    /// let mut access_token = Token::default();
195    /// access_token.with_scope(vec!["User.Read"]);
196    /// ```
197    pub fn with_scope<T: ToString, I: IntoIterator<Item = T>>(&mut self, scope: I) -> &mut Self {
198        self.scope = scope.into_iter().map(|s| s.to_string()).collect();
199        self
200    }
201
202    /// Set the access token.
203    ///
204    /// # Example
205    /// ```
206    /// # use graph_oauth::Token;
207    ///
208    /// let mut access_token = Token::default();
209    /// access_token.with_access_token("ASODFIUJ34KJ;LADSK");
210    /// ```
211    pub fn with_access_token(&mut self, s: &str) -> &mut Self {
212        self.access_token = s.into();
213        self
214    }
215
216    /// Set the refresh token.
217    ///
218    /// # Example
219    /// ```
220    /// # use graph_oauth::Token;
221    ///
222    /// let mut access_token = Token::default();
223    /// access_token.with_refresh_token("#ASOD323U5342");
224    /// ```
225    pub fn with_refresh_token(&mut self, s: &str) -> &mut Self {
226        self.refresh_token = Some(s.to_string());
227        self
228    }
229
230    /// Set the user id.
231    ///
232    /// # Example
233    /// ```
234    /// # use graph_oauth::Token;
235    ///
236    /// let mut access_token = Token::default();
237    /// access_token.with_user_id("user_id");
238    /// ```
239    pub fn with_user_id(&mut self, s: &str) -> &mut Self {
240        self.user_id = Some(s.to_string());
241        self
242    }
243
244    /// Set the id token.
245    ///
246    /// # Example
247    /// ```
248    /// # use graph_oauth::{Token, IdToken};
249    ///
250    /// let mut access_token = Token::default();
251    /// access_token.set_id_token("id_token");
252    /// ```
253    pub fn set_id_token(&mut self, s: &str) -> &mut Self {
254        self.id_token = Some(IdToken::new(s, None, None, None));
255        self
256    }
257
258    /// Set the id token.
259    ///
260    /// # Example
261    /// ```
262    /// # use graph_oauth::{Token, IdToken};
263    ///
264    /// let mut access_token = Token::default();
265    /// access_token.with_id_token(IdToken::new("id_token", Some("code"), Some("state"), Some("session_state")));
266    /// ```
267    pub fn with_id_token(&mut self, id_token: IdToken) {
268        self.id_token = Some(id_token);
269    }
270
271    /// Set the state.
272    ///
273    /// # Example
274    /// ```
275    /// # use graph_oauth::Token;
276    /// # use graph_oauth::IdToken;
277    ///
278    /// let mut access_token = Token::default();
279    /// access_token.with_state("state");
280    /// ```
281    pub fn with_state(&mut self, s: &str) -> &mut Self {
282        self.state = Some(s.to_string());
283        self
284    }
285
286    /// Enable or disable logging of personally identifiable information such
287    /// as logging the id_token. This is disabled by default. When log_pii is enabled
288    /// passing [Token] to logging or print functions will log both the bearer
289    /// access token value, the refresh token value if any, and the id token value.
290    /// By default these do not get logged.
291    pub fn enable_pii_logging(&mut self, log_pii: bool) {
292        self.log_pii = log_pii;
293    }
294
295    /// Timestamp field is used to tell whether the access token is expired.
296    /// This method is mainly used internally as soon as the access token
297    /// is deserialized from the api response for an accurate reading
298    /// on when the access token expires.
299    ///
300    /// You most likely do not want to use this method unless you are deserializing
301    /// the access token using custom deserialization or creating your own access tokens
302    /// manually.
303    ///
304    /// This method resets the access token timestamp based on the expires_in field
305    /// which is the total seconds that the access token is valid for starting
306    /// from when the token was first retrieved.
307    ///
308    /// This will reset the the timestamp from Utc Now + expires_in. This means
309    /// that if calling [Token::gen_timestamp] will only be reliable if done
310    /// when the access token is first retrieved.
311    ///
312    ///
313    /// # Example
314    /// ```
315    /// # use graph_oauth::Token;
316    ///
317    /// let mut access_token = Token::default();
318    /// access_token.expires_in = 86999;
319    /// access_token.gen_timestamp();
320    /// println!("{:#?}", access_token.timestamp);
321    /// ```
322    pub fn gen_timestamp(&mut self) {
323        let timestamp = time::OffsetDateTime::now_utc();
324        let expires_on = timestamp.add(time::Duration::seconds(self.expires_in));
325        self.timestamp = Some(timestamp);
326        self.expires_on = Some(expires_on);
327    }
328
329    /// Check whether the access token is expired. Checks if expires_on timestamp
330    /// is less than UTC now timestamp.
331    ///
332    /// # Example
333    /// ```
334    /// # use graph_oauth::Token;
335    ///
336    /// let mut access_token = Token::default();
337    /// println!("{:#?}", access_token.is_expired());
338    /// ```
339    pub fn is_expired(&self) -> bool {
340        if let Some(expires_on) = self.expires_on.as_ref() {
341            expires_on.lt(&OffsetDateTime::now_utc())
342        } else {
343            false
344        }
345    }
346
347    /// Check whether the access token is expired sub duration.
348    /// This is useful in scenarios where you want to eagerly refresh
349    /// the access token before it expires to prevent a failed request.
350    ///
351    /// # Example
352    /// ```
353    /// # use graph_oauth::Token;
354    ///
355    /// let mut access_token = Token::default();
356    /// println!("{:#?}", access_token.is_expired_sub(time::Duration::minutes(5)));
357    /// ```
358    pub fn is_expired_sub(&self, duration: time::Duration) -> bool {
359        if let Some(expires_on) = self.expires_on.as_ref() {
360            expires_on.sub(duration).lt(&OffsetDateTime::now_utc())
361        } else {
362            false
363        }
364    }
365
366    /// Get the time left in seconds until the access token expires.
367    /// See the HumanTime crate. If you just need to know if the access token
368    /// is expired then use the is_expired() message which returns a boolean
369    /// true for the token has expired and false otherwise.
370    ///
371    /// # Example
372    /// ```
373    /// # use graph_oauth::Token;
374    ///
375    /// let mut access_token = Token::default();
376    /// println!("{:#?}", access_token.elapsed());
377    /// ```
378    pub fn elapsed(&self) -> Option<time::Duration> {
379        Some(self.expires_on? - self.timestamp?)
380    }
381
382    pub fn decode_header(&self) -> jsonwebtoken::errors::Result<jsonwebtoken::Header> {
383        let id_token = self
384            .id_token
385            .as_ref()
386            .ok_or(jsonwebtoken::errors::Error::from(
387                jsonwebtoken::errors::ErrorKind::InvalidToken,
388            ))?;
389        jsonwebtoken::decode_header(id_token.as_ref())
390    }
391
392    /// Decode and validate the id token.
393    pub fn decode(
394        &self,
395        n: &str,
396        e: &str,
397        client_id: &str,
398        issuer: &str,
399    ) -> jsonwebtoken::errors::Result<TokenData<Claims>> {
400        let id_token = self
401            .id_token
402            .as_ref()
403            .ok_or(jsonwebtoken::errors::Error::from(
404                jsonwebtoken::errors::ErrorKind::InvalidToken,
405            ))?;
406        let mut validation = Validation::new(Algorithm::RS256);
407        validation.set_audience(&[client_id]);
408        validation.set_issuer(&[issuer]);
409
410        jsonwebtoken::decode::<Claims>(
411            id_token.as_ref(),
412            &DecodingKey::from_rsa_components(n, e).unwrap(),
413            &validation,
414        )
415    }
416}
417
418impl Default for Token {
419    fn default() -> Self {
420        Token {
421            token_type: String::new(),
422            expires_in: 0,
423            ext_expires_in: None,
424            scope: vec![],
425            access_token: String::new(),
426            refresh_token: None,
427            user_id: None,
428            id_token: None,
429            state: None,
430            session_state: None,
431            nonce: None,
432            correlation_id: None,
433            client_info: None,
434            timestamp: Some(time::OffsetDateTime::now_utc()),
435            expires_on: Some(
436                OffsetDateTime::from_unix_timestamp(0).unwrap_or(time::OffsetDateTime::UNIX_EPOCH),
437            ),
438            additional_fields: Default::default(),
439            log_pii: false,
440        }
441    }
442}
443
444impl TryFrom<AuthorizationResponse> for Token {
445    type Error = AuthorizationFailure;
446
447    fn try_from(value: AuthorizationResponse) -> Result<Self, Self::Error> {
448        let id_token = IdToken::try_from(value.clone()).ok();
449
450        Ok(Token {
451            access_token: value
452                .access_token
453                .ok_or_else(|| AF::msg_err("access_token", "access_token is None"))?,
454            token_type: "Bearer".to_string(),
455            expires_in: value.expires_in.unwrap_or_default(),
456            ext_expires_in: None,
457            scope: vec![],
458            refresh_token: None,
459            user_id: None,
460            id_token,
461            state: value.state,
462            session_state: value.session_state,
463            nonce: value.nonce,
464            correlation_id: None,
465            client_info: None,
466            timestamp: None,
467            expires_on: None,
468            additional_fields: Default::default(),
469            log_pii: false,
470        })
471    }
472}
473
474impl Display for Token {
475    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
476        write!(f, "{}", self.access_token)
477    }
478}
479
480impl AsBearer for Token {
481    fn as_bearer(&self) -> String {
482        self.access_token.to_string()
483    }
484}
485
486impl TryFrom<&str> for Token {
487    type Error = GraphFailure;
488
489    fn try_from(value: &str) -> Result<Self, Self::Error> {
490        Ok(serde_json::from_str(value)?)
491    }
492}
493
494impl TryFrom<reqwest::blocking::RequestBuilder> for Token {
495    type Error = GraphFailure;
496
497    fn try_from(value: reqwest::blocking::RequestBuilder) -> Result<Self, Self::Error> {
498        let response = value.send()?;
499        Token::try_from(response)
500    }
501}
502
503impl TryFrom<Result<reqwest::blocking::Response, reqwest::Error>> for Token {
504    type Error = GraphFailure;
505
506    fn try_from(
507        value: Result<reqwest::blocking::Response, reqwest::Error>,
508    ) -> Result<Self, Self::Error> {
509        let response = value?;
510        Token::try_from(response)
511    }
512}
513
514impl TryFrom<reqwest::blocking::Response> for Token {
515    type Error = GraphFailure;
516
517    fn try_from(value: reqwest::blocking::Response) -> Result<Self, Self::Error> {
518        Ok(value.json::<Token>()?)
519    }
520}
521
522impl fmt::Debug for Token {
523    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
524        if self.log_pii {
525            f.debug_struct("MsalAccessToken")
526                .field("bearer_token", &self.access_token)
527                .field("refresh_token", &self.refresh_token)
528                .field("token_type", &self.token_type)
529                .field("expires_in", &self.expires_in)
530                .field("scope", &self.scope)
531                .field("user_id", &self.user_id)
532                .field("id_token", &self.id_token)
533                .field("state", &self.state)
534                .field("timestamp", &self.timestamp)
535                .field("expires_on", &self.expires_on)
536                .field("additional_fields", &self.additional_fields)
537                .finish()
538        } else {
539            f.debug_struct("MsalAccessToken")
540                .field(
541                    "bearer_token",
542                    &"[REDACTED]  - call enable_pii_logging(true) to log value",
543                )
544                .field(
545                    "refresh_token",
546                    &"[REDACTED] - call enable_pii_logging(true) to log value",
547                )
548                .field("token_type", &self.token_type)
549                .field("expires_in", &self.expires_in)
550                .field("scope", &self.scope)
551                .field("user_id", &self.user_id)
552                .field(
553                    "id_token",
554                    &"[REDACTED] - call enable_pii_logging(true) to log value",
555                )
556                .field("state", &self.state)
557                .field("timestamp", &self.timestamp)
558                .field("expires_on", &self.expires_on)
559                .field("additional_fields", &self.additional_fields)
560                .finish()
561        }
562    }
563}
564
565impl AsRef<str> for Token {
566    fn as_ref(&self) -> &str {
567        self.access_token.as_str()
568    }
569}
570
571impl<'de> Deserialize<'de> for Token {
572    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
573    where
574        D: Deserializer<'de>,
575    {
576        let phantom_access_token: PhantomToken = Deserialize::deserialize(deserializer)?;
577        let timestamp = OffsetDateTime::now_utc();
578        let expires_on = timestamp.add(time::Duration::seconds(phantom_access_token.expires_in));
579        let id_token = phantom_access_token
580            .id_token
581            .map(|id_token_string| IdToken::new(id_token_string.as_ref(), None, None, None));
582
583        let token = Token {
584            access_token: phantom_access_token.access_token,
585            token_type: phantom_access_token.token_type,
586            expires_in: phantom_access_token.expires_in,
587            ext_expires_in: phantom_access_token.ext_expires_in,
588            scope: phantom_access_token.scope,
589            refresh_token: phantom_access_token.refresh_token,
590            user_id: phantom_access_token.user_id,
591            id_token,
592            state: phantom_access_token.state,
593            session_state: phantom_access_token.session_state,
594            nonce: phantom_access_token.nonce,
595            correlation_id: phantom_access_token.correlation_id,
596            client_info: phantom_access_token.client_info,
597            timestamp: Some(timestamp),
598            expires_on: Some(expires_on),
599            additional_fields: phantom_access_token.additional_fields,
600            log_pii: false,
601        };
602
603        // tracing::debug!(target: "phantom", token.as_value());
604
605        Ok(token)
606    }
607}
608
#[cfg(test)]
mod test {
    use super::*;

    /// Expiry checks, driven by the sign and size of `expires_in` instead of by
    /// sleeping, so the test is instant and independent of scheduler timing.
    #[test]
    fn is_expired_test() {
        // A default token expires at the Unix epoch.
        assert!(Token::default().is_expired());

        // A lifetime that ended five seconds ago is already expired.
        let mut access_token = Token::default();
        access_token.with_expires_in(-5);
        assert!(access_token.is_expired());

        // An hour of lifetime is not expired, unless we look further ahead than that.
        let mut access_token = Token::default();
        access_token.with_expires_in(3600);
        assert!(!access_token.is_expired());
        assert!(!access_token.is_expired_sub(time::Duration::minutes(5)));
        assert!(access_token.is_expired_sub(time::Duration::hours(2)));

        // Without an expiry there is nothing to compare against.
        access_token.expires_on = None;
        assert!(!access_token.is_expired());
        assert!(!access_token.is_expired_sub(time::Duration::hours(2)));
    }

    pub const ACCESS_TOKEN_INT: &str = r#"{
        "access_token": "fasdfasdfasfdasdfasfsdf",
        "token_type": "Bearer",
        "expires_in": 65874,
        "scope": null,
        "refresh_token": null,
        "user_id": "santa@north.pole.com",
        "id_token": "789aasdf-asdf",
        "state": null,
        "timestamp": "2020-10-27T16:31:38.788098400Z"
    }"#;

    pub const ACCESS_TOKEN_STRING: &str = r#"{
        "access_token": "fasdfasdfasfdasdfasfsdf",
        "token_type": "Bearer",
        "expires_in": "65874",
        "scope": null,
        "refresh_token": null,
        "user_id": "helpers@north.pole.com",
        "id_token": "789aasdf-asdf",
        "state": null,
        "timestamp": "2020-10-27T16:31:38.788098400Z"
    }"#;

    /// `expires_in` is accepted both as a JSON number and as a numeric string, and a
    /// `null` scope deserializes to an empty list.
    #[test]
    pub fn test_deserialize() {
        let expected_users = [
            (ACCESS_TOKEN_INT, "santa@north.pole.com"),
            (ACCESS_TOKEN_STRING, "helpers@north.pole.com"),
        ];

        for (json, user_id) in expected_users {
            let token: Token = serde_json::from_str(json).unwrap();
            assert_eq!(token.access_token, "fasdfasdfasfdasdfasfsdf");
            assert_eq!(token.token_type, "Bearer");
            assert_eq!(token.expires_in, 65874);
            assert_eq!(token.user_id.as_deref(), Some(user_id));
            assert!(token.scope.is_empty());
            assert!(token.refresh_token.is_none());
            assert!(token.timestamp.is_some());
            assert!(token.expires_on.is_some());
            assert!(!token.is_expired());
        }
    }

    #[test]
    pub fn try_from_url_authorization_response() {
        let authorization_response = AuthorizationResponse {
            code: Some("code".into()),
            id_token: Some("id_token".into()),
            expires_in: Some(3600),
            access_token: Some("token".into()),
            state: Some("state".into()),
            session_state: Some("session_state".into()),
            nonce: None,
            error: None,
            error_description: None,
            error_uri: None,
            additional_fields: Default::default(),
            log_pii: false,
        };

        let token = Token::try_from(authorization_response).unwrap();
        assert_eq!(
            token.id_token,
            Some(IdToken::new(
                "id_token",
                Some("code"),
                Some("state"),
                Some("session_state")
            ))
        );
        assert_eq!(token.access_token, "token".to_string());
        assert_eq!(token.state, Some("state".to_string()));
        assert_eq!(token.session_state, Some("session_state".to_string()));
        assert_eq!(token.expires_in, 3600);
    }
}