use graph_error::{AuthorizationFailure, GraphFailure, AF};
use serde::{Deserialize, Deserializer};
use serde_aux::prelude::*;
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
use std::fmt::Display;
use std::ops::{Add, Sub};
use crate::identity::{AuthorizationResponse, IdToken};
use graph_core::{cache::AsBearer, identity::Claims};
use jsonwebtoken::{Algorithm, DecodingKey, TokenData, Validation};
use time::OffsetDateTime;
fn deserialize_scope<'de, D>(scope: D) -> Result<Vec<String>, D::Error>
where
D: Deserializer<'de>,
{
let scope_string: Result<String, D::Error> = serde::Deserialize::deserialize(scope);
if let Ok(scope) = scope_string {
Ok(scope.split(' ').map(|scope| scope.to_owned()).collect())
} else {
Ok(vec![])
}
}
/// Wire representation of a token endpoint response.
///
/// `Token` implements `Deserialize` by hand: it first deserializes into this struct, then
/// wraps `id_token` in an `IdToken` and computes `timestamp`/`expires_on`.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PhantomToken {
    access_token: String,
    token_type: String,
    // Accepts either a JSON number or a numeric string.
    #[serde(deserialize_with = "deserialize_number_from_string")]
    expires_in: i64,
    // Same number-or-numeric-string leniency as `expires_in`. `default` keeps the field
    // optional: with a custom `deserialize_with`, serde no longer treats a missing `Option`
    // field as `None` on its own.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_option_number_from_string")]
    ext_expires_in: Option<i64>,
    // Space-delimited string (or array of strings); missing means no scopes.
    #[serde(default)]
    #[serde(deserialize_with = "deserialize_scope")]
    scope: Vec<String>,
    refresh_token: Option<String>,
    user_id: Option<String>,
    // Raw (encoded) ID token string.
    id_token: Option<String>,
    state: Option<String>,
    session_state: Option<String>,
    nonce: Option<String>,
    correlation_id: Option<String>,
    client_info: Option<String>,
    // Every field not listed above is kept verbatim.
    #[serde(flatten)]
    additional_fields: HashMap<String, Value>,
}
/// An OAuth 2.0 / OpenID Connect token response plus the bookkeeping this crate adds.
///
/// `Deserialize` is implemented by hand (through the private `PhantomToken` wire type), so the
/// serde attributes on these fields only affect serialization.
#[derive(Clone, Eq, PartialEq, Serialize)]
pub struct Token {
    /// The access token, sent as the bearer credential.
    pub access_token: String,
    /// Token type reported by the server, e.g. `"Bearer"`.
    pub token_type: String,
    /// Lifetime of the access token in seconds.
    pub expires_in: i64,
    /// Extended lifetime in seconds, if the server supplied `ext_expires_in`.
    pub ext_expires_in: Option<i64>,
    /// Granted scopes, one entry per scope.
    pub scope: Vec<String>,
    /// Refresh token, if one was issued.
    pub refresh_token: Option<String>,
    /// User identifier, if present in the response.
    pub user_id: Option<String>,
    /// OpenID Connect ID token, if one was issued.
    pub id_token: Option<IdToken>,
    /// `state` value echoed back by the authorization server.
    pub state: Option<String>,
    /// `session_state` value, if present.
    pub session_state: Option<String>,
    /// `nonce` value, if present.
    pub nonce: Option<String>,
    /// `correlation_id` value, if present.
    pub correlation_id: Option<String>,
    /// `client_info` value, if present.
    pub client_info: Option<String>,
    /// UTC time at which this value was created or last re-stamped.
    pub timestamp: Option<time::OffsetDateTime>,
    /// `timestamp + expires_in`; `None` means the expiry is unknown and `is_expired` reports
    /// `false`.
    pub expires_on: Option<time::OffsetDateTime>,
    /// Response fields that are not modeled explicitly.
    #[serde(flatten)]
    pub additional_fields: HashMap<String, Value>,
    /// Whether `Debug` output may include sensitive values. Never serialized.
    #[serde(skip)]
    pub log_pii: bool,
}
impl Token {
pub fn new<T: ToString, I: IntoIterator<Item = T>>(
token_type: &str,
expires_in: i64,
access_token: &str,
scope: I,
) -> Token {
let timestamp = time::OffsetDateTime::now_utc();
let expires_on = timestamp.add(time::Duration::seconds(expires_in));
Token {
token_type: token_type.into(),
ext_expires_in: None,
expires_in,
scope: scope.into_iter().map(|s| s.to_string()).collect(),
access_token: access_token.into(),
refresh_token: None,
user_id: None,
id_token: None,
state: None,
session_state: None,
nonce: None,
correlation_id: None,
client_info: None,
timestamp: Some(timestamp),
expires_on: Some(expires_on),
additional_fields: Default::default(),
log_pii: false,
}
}
pub fn with_token_type(&mut self, s: &str) -> &mut Self {
self.token_type = s.into();
self
}
pub fn with_expires_in(&mut self, expires_in: i64) -> &mut Self {
self.expires_in = expires_in;
let timestamp = time::OffsetDateTime::now_utc();
self.expires_on = Some(timestamp.add(time::Duration::seconds(self.expires_in)));
self.timestamp = Some(timestamp);
self
}
pub fn with_scope<T: ToString, I: IntoIterator<Item = T>>(&mut self, scope: I) -> &mut Self {
self.scope = scope.into_iter().map(|s| s.to_string()).collect();
self
}
pub fn with_access_token(&mut self, s: &str) -> &mut Self {
self.access_token = s.into();
self
}
pub fn with_refresh_token(&mut self, s: &str) -> &mut Self {
self.refresh_token = Some(s.to_string());
self
}
pub fn with_user_id(&mut self, s: &str) -> &mut Self {
self.user_id = Some(s.to_string());
self
}
pub fn set_id_token(&mut self, s: &str) -> &mut Self {
self.id_token = Some(IdToken::new(s, None, None, None));
self
}
pub fn with_id_token(&mut self, id_token: IdToken) {
self.id_token = Some(id_token);
}
pub fn with_state(&mut self, s: &str) -> &mut Self {
self.state = Some(s.to_string());
self
}
pub fn enable_pii_logging(&mut self, log_pii: bool) {
self.log_pii = log_pii;
}
pub fn gen_timestamp(&mut self) {
let timestamp = time::OffsetDateTime::now_utc();
let expires_on = timestamp.add(time::Duration::seconds(self.expires_in));
self.timestamp = Some(timestamp);
self.expires_on = Some(expires_on);
}
pub fn is_expired(&self) -> bool {
if let Some(expires_on) = self.expires_on.as_ref() {
expires_on.lt(&OffsetDateTime::now_utc())
} else {
false
}
}
pub fn is_expired_sub(&self, duration: time::Duration) -> bool {
if let Some(expires_on) = self.expires_on.as_ref() {
expires_on.sub(duration).lt(&OffsetDateTime::now_utc())
} else {
false
}
}
pub fn elapsed(&self) -> Option<time::Duration> {
Some(self.expires_on? - self.timestamp?)
}
pub fn decode_header(&self) -> jsonwebtoken::errors::Result<jsonwebtoken::Header> {
let id_token = self
.id_token
.as_ref()
.ok_or(jsonwebtoken::errors::Error::from(
jsonwebtoken::errors::ErrorKind::InvalidToken,
))?;
jsonwebtoken::decode_header(id_token.as_ref())
}
pub fn decode(
&self,
n: &str,
e: &str,
client_id: &str,
issuer: &str,
) -> jsonwebtoken::errors::Result<TokenData<Claims>> {
let id_token = self
.id_token
.as_ref()
.ok_or(jsonwebtoken::errors::Error::from(
jsonwebtoken::errors::ErrorKind::InvalidToken,
))?;
let mut validation = Validation::new(Algorithm::RS256);
validation.set_audience(&[client_id]);
validation.set_issuer(&[issuer]);
jsonwebtoken::decode::<Claims>(
id_token.as_ref(),
&DecodingKey::from_rsa_components(n, e).unwrap(),
&validation,
)
}
}
impl Default for Token {
fn default() -> Self {
Token {
token_type: String::new(),
expires_in: 0,
ext_expires_in: None,
scope: vec![],
access_token: String::new(),
refresh_token: None,
user_id: None,
id_token: None,
state: None,
session_state: None,
nonce: None,
correlation_id: None,
client_info: None,
timestamp: Some(time::OffsetDateTime::now_utc()),
expires_on: Some(
OffsetDateTime::from_unix_timestamp(0).unwrap_or(time::OffsetDateTime::UNIX_EPOCH),
),
additional_fields: Default::default(),
log_pii: false,
}
}
}
impl TryFrom<AuthorizationResponse> for Token {
type Error = AuthorizationFailure;
fn try_from(value: AuthorizationResponse) -> Result<Self, Self::Error> {
let id_token = IdToken::try_from(value.clone()).ok();
Ok(Token {
access_token: value
.access_token
.ok_or_else(|| AF::msg_err("access_token", "access_token is None"))?,
token_type: "Bearer".to_string(),
expires_in: value.expires_in.unwrap_or_default(),
ext_expires_in: None,
scope: vec![],
refresh_token: None,
user_id: None,
id_token,
state: value.state,
session_state: value.session_state,
nonce: value.nonce,
correlation_id: None,
client_info: None,
timestamp: None,
expires_on: None,
additional_fields: Default::default(),
log_pii: false,
})
}
}
impl Display for Token {
    /// Writes the raw access token, ignoring any width/fill flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.access_token)
    }
}
impl AsBearer for Token {
    /// Returns an owned copy of the access token.
    fn as_bearer(&self) -> String {
        self.access_token.clone()
    }
}
impl TryFrom<&str> for Token {
type Error = GraphFailure;
fn try_from(value: &str) -> Result<Self, Self::Error> {
Ok(serde_json::from_str(value)?)
}
}
impl TryFrom<reqwest::blocking::RequestBuilder> for Token {
    type Error = GraphFailure;

    /// Sends the request and parses the response body as a token.
    fn try_from(value: reqwest::blocking::RequestBuilder) -> Result<Self, Self::Error> {
        // Delegates to the `TryFrom<Result<Response, reqwest::Error>>` impl, which turns a
        // transport error into `GraphFailure` before parsing the response.
        Token::try_from(value.send())
    }
}
impl TryFrom<Result<reqwest::blocking::Response, reqwest::Error>> for Token {
    type Error = GraphFailure;

    /// Converts a transport error into `GraphFailure`, otherwise parses the response body.
    fn try_from(
        value: Result<reqwest::blocking::Response, reqwest::Error>,
    ) -> Result<Self, Self::Error> {
        match value {
            Ok(response) => Token::try_from(response),
            Err(err) => Err(err.into()),
        }
    }
}
impl TryFrom<reqwest::blocking::Response> for Token {
type Error = GraphFailure;
fn try_from(value: reqwest::blocking::Response) -> Result<Self, Self::Error> {
Ok(value.json::<Token>()?)
}
}
impl fmt::Debug for Token {
    /// Formats the token for diagnostics.
    ///
    /// Unless PII logging was enabled with `enable_pii_logging(true)`, credentials
    /// (`bearer_token`, `refresh_token`, `id_token`) and user/session identifiers (`user_id`,
    /// `session_state`, `client_info`) are replaced with a redaction notice.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Returns `value` when PII logging is enabled, otherwise the redaction notice.
        fn redact(log_pii: bool, value: &dyn fmt::Debug) -> &dyn fmt::Debug {
            const REDACTED: &str = "[REDACTED] - call enable_pii_logging(true) to log value";
            if log_pii {
                value
            } else {
                &REDACTED
            }
        }

        let pii = self.log_pii;
        f.debug_struct("MsalAccessToken")
            .field("bearer_token", redact(pii, &self.access_token))
            .field("refresh_token", redact(pii, &self.refresh_token))
            .field("token_type", &self.token_type)
            .field("expires_in", &self.expires_in)
            .field("ext_expires_in", &self.ext_expires_in)
            .field("scope", &self.scope)
            .field("user_id", redact(pii, &self.user_id))
            .field("id_token", redact(pii, &self.id_token))
            .field("state", &self.state)
            .field("session_state", redact(pii, &self.session_state))
            .field("nonce", &self.nonce)
            .field("correlation_id", &self.correlation_id)
            .field("client_info", redact(pii, &self.client_info))
            .field("timestamp", &self.timestamp)
            .field("expires_on", &self.expires_on)
            .field("additional_fields", &self.additional_fields)
            .finish()
    }
}
impl AsRef<str> for Token {
    /// Borrows the access token.
    fn as_ref(&self) -> &str {
        &self.access_token
    }
}
impl<'de> Deserialize<'de> for Token {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let phantom_access_token: PhantomToken = Deserialize::deserialize(deserializer)?;
let timestamp = OffsetDateTime::now_utc();
let expires_on = timestamp.add(time::Duration::seconds(phantom_access_token.expires_in));
let id_token = phantom_access_token
.id_token
.map(|id_token_string| IdToken::new(id_token_string.as_ref(), None, None, None));
let token = Token {
access_token: phantom_access_token.access_token,
token_type: phantom_access_token.token_type,
expires_in: phantom_access_token.expires_in,
ext_expires_in: phantom_access_token.ext_expires_in,
scope: phantom_access_token.scope,
refresh_token: phantom_access_token.refresh_token,
user_id: phantom_access_token.user_id,
id_token,
state: phantom_access_token.state,
session_state: phantom_access_token.session_state,
nonce: phantom_access_token.nonce,
correlation_id: phantom_access_token.correlation_id,
client_info: phantom_access_token.client_info,
timestamp: Some(timestamp),
expires_on: Some(expires_on),
additional_fields: phantom_access_token.additional_fields,
log_pii: false,
};
Ok(token)
}
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn is_expired_test() {
        // The default token's `expires_on` is the Unix epoch.
        assert!(Token::default().is_expired());

        // A negative lifetime puts `expires_on` in the past, so no sleeping is required.
        let mut access_token = Token::default();
        access_token.with_expires_in(-5);
        assert!(access_token.is_expired());

        let mut access_token = Token::default();
        access_token.with_expires_in(3600);
        assert!(!access_token.is_expired());
        // `is_expired_sub` moves the expiry earlier by the given margin before comparing.
        assert!(access_token.is_expired_sub(time::Duration::seconds(7200)));
        assert!(!access_token.is_expired_sub(time::Duration::seconds(60)));
    }

    pub const ACCESS_TOKEN_INT: &str = r#"{
        "access_token": "fasdfasdfasfdasdfasfsdf",
        "token_type": "Bearer",
        "expires_in": 65874,
        "scope": null,
        "refresh_token": null,
        "user_id": "santa@north.pole.com",
        "id_token": "789aasdf-asdf",
        "state": null,
        "timestamp": "2020-10-27T16:31:38.788098400Z"
    }"#;

    pub const ACCESS_TOKEN_STRING: &str = r#"{
        "access_token": "fasdfasdfasfdasdfasfsdf",
        "token_type": "Bearer",
        "expires_in": "65874",
        "scope": null,
        "refresh_token": null,
        "user_id": "helpers@north.pole.com",
        "id_token": "789aasdf-asdf",
        "state": null,
        "timestamp": "2020-10-27T16:31:38.788098400Z"
    }"#;

    #[test]
    pub fn test_deserialize() {
        let token: Token = serde_json::from_str(ACCESS_TOKEN_INT).unwrap();
        assert_eq!(token.expires_in, 65874);
        assert_eq!(token.user_id.as_deref(), Some("santa@north.pole.com"));
        assert!(token.scope.is_empty());
        assert!(token.id_token.is_some());
        assert!(!token.is_expired());

        // `expires_in` may also arrive as a numeric string.
        let token: Token = serde_json::from_str(ACCESS_TOKEN_STRING).unwrap();
        assert_eq!(token.expires_in, 65874);
        assert_eq!(token.access_token, "fasdfasdfasfdasdfasfsdf");
        assert_eq!(token.user_id.as_deref(), Some("helpers@north.pole.com"));
        assert!(!token.is_expired());
    }

    #[test]
    pub fn try_from_url_authorization_response() {
        let authorization_response = AuthorizationResponse {
            code: Some("code".into()),
            id_token: Some("id_token".into()),
            expires_in: Some(3600),
            access_token: Some("token".into()),
            state: Some("state".into()),
            session_state: Some("session_state".into()),
            nonce: None,
            error: None,
            error_description: None,
            error_uri: None,
            additional_fields: Default::default(),
            log_pii: false,
        };
        let token = Token::try_from(authorization_response).unwrap();
        assert_eq!(
            token.id_token,
            Some(IdToken::new(
                "id_token",
                Some("code"),
                Some("state"),
                Some("session_state")
            ))
        );
        assert_eq!(token.access_token, "token".to_string());
        assert_eq!(token.state, Some("state".to_string()));
        assert_eq!(token.session_state, Some("session_state".to_string()));
        assert_eq!(token.expires_in, 3600);
    }
}