use std::{cmp::max, collections::HashMap, num::Wrapping};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::{helpers::now, types::OidcClientError};
/// Input parameters consumed by [`TokenSet::new`].
///
/// Every field is optional. When the struct is serialized, any field that is
/// `None` is omitted from the output instead of being written as `null`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct TokenSetParams {
/// Access token string; copied unchanged into the resulting `TokenSet`.
#[serde(skip_serializing_if = "Option::is_none")]
pub access_token: Option<String>,
/// Token type string; copied unchanged into the resulting `TokenSet`.
/// NOTE(review): presumably a value such as `Bearer` - confirm with callers.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_type: Option<String>,
/// ID token string. `TokenSet::claims` expects it to consist of
/// dot-separated segments whose second segment is base64url-encoded JSON.
#[serde(skip_serializing_if = "Option::is_none")]
pub id_token: Option<String>,
/// Refresh token string; copied unchanged into the resulting `TokenSet`.
#[serde(skip_serializing_if = "Option::is_none")]
pub refresh_token: Option<String>,
/// Lifetime of the token relative to the moment `TokenSet::new` is called.
///
/// `TokenSet::new` adds it to `now()` to derive `expires_at` when that field
/// is `None`, and stores negative values as `0`.
/// NOTE(review): the unit is whatever `crate::helpers::now` uses, presumably
/// seconds - confirm.
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_in: Option<i64>,
/// Absolute expiry instant, on the same scale as `crate::helpers::now`
/// (presumably a Unix timestamp in seconds - TODO confirm).
///
/// When present, `TokenSet::new` keeps it as-is and does not derive a value
/// from `expires_in`.
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<i64>,
/// Session state string; copied unchanged into the resulting `TokenSet`.
#[serde(skip_serializing_if = "Option::is_none")]
pub session_state: Option<String>,
/// Scope string; copied unchanged into the resulting `TokenSet`.
#[serde(skip_serializing_if = "Option::is_none")]
pub scope: Option<String>,
/// Arbitrary additional key/value pairs; copied unchanged into the resulting
/// `TokenSet`. There is no `flatten` attribute, so these are serialized as a
/// nested map under the `other` key.
#[serde(skip_serializing_if = "Option::is_none")]
pub other: Option<HashMap<String, Value>>,
}
/// A set of tokens together with their expiry metadata.
///
/// All fields are private and are read through the `get_*` accessors. Values
/// are normally created with [`TokenSet::new`], which also derives
/// `expires_at` from `expires_in` when needed. When serialized, any field that
/// is `None` is omitted from the output.
///
/// NOTE(review): the derived `Deserialize` and `Default` impls populate the
/// fields directly, so the adjustments made by `TokenSet::new` are not applied
/// on those paths - confirm that this is intended.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct TokenSet {
/// Access token string, if any.
#[serde(skip_serializing_if = "Option::is_none")]
access_token: Option<String>,
/// Token type string, if any.
#[serde(skip_serializing_if = "Option::is_none")]
token_type: Option<String>,
/// ID token string, if any; decoded by `claims`.
#[serde(skip_serializing_if = "Option::is_none")]
id_token: Option<String>,
/// Refresh token string, if any.
#[serde(skip_serializing_if = "Option::is_none")]
refresh_token: Option<String>,
/// Lifetime as supplied at construction time (negative inputs are stored as
/// `0` by `new`). This stored value is not recomputed as time passes.
#[serde(skip_serializing_if = "Option::is_none")]
expires_in: Option<i64>,
/// Absolute expiry instant on the same scale as `crate::helpers::now`
/// (presumably a Unix timestamp in seconds - TODO confirm). `expired`
/// compares this against the current time.
#[serde(skip_serializing_if = "Option::is_none")]
expires_at: Option<i64>,
/// Session state string, if any.
#[serde(skip_serializing_if = "Option::is_none")]
session_state: Option<String>,
/// Scope string, if any.
#[serde(skip_serializing_if = "Option::is_none")]
scope: Option<String>,
/// Arbitrary additional key/value pairs, serialized as a nested map under
/// the `other` key.
#[serde(skip_serializing_if = "Option::is_none")]
other: Option<HashMap<String, Value>>,
}
impl TokenSet {
    /// Builds a [`TokenSet`] from the supplied [`TokenSetParams`].
    ///
    /// Every field is moved over unchanged, with two adjustments:
    ///
    /// * If `expires_at` is `None` and `expires_in` is `Some`, `expires_at` is
    ///   set to `now() + expires_in`. The addition saturates at the `i64`
    ///   bounds, so an extreme `expires_in` can no longer wrap around into a
    ///   timestamp on the opposite side of the number line.
    /// * A negative `expires_in` is stored as `0`.
    ///
    /// An explicitly supplied `expires_at` always takes precedence over
    /// `expires_in`.
    pub fn new(params: TokenSetParams) -> Self {
        // Derive the deadline from the *unclamped* lifetime: a negative
        // `expires_in` deliberately yields an instant in the past.
        let expires_at = params.expires_at.or_else(|| {
            params
                .expires_in
                .map(|lifetime| now().saturating_add(lifetime))
        });

        Self {
            access_token: params.access_token,
            token_type: params.token_type,
            id_token: params.id_token,
            refresh_token: params.refresh_token,
            expires_in: params.expires_in.map(|lifetime| max(lifetime, 0)),
            expires_at,
            session_state: params.session_state,
            scope: params.scope,
            other: params.other,
        }
    }

    /// Returns `true` when `expires_at` is set and is not later than the
    /// current time reported by `now()`.
    ///
    /// A token set without an `expires_at` is never reported as expired.
    pub fn expired(&self) -> bool {
        self.get_expires_in_internal() == Some(0)
    }

    /// Decodes the payload segment of the `id_token` into a JSON object.
    ///
    /// The token is split on `.` and the second segment is base64url-decoded
    /// and parsed as a map of claim names to JSON values. This method only
    /// decodes; it performs no signature or claim validation.
    ///
    /// # Errors
    ///
    /// Returns a type error when:
    ///
    /// * no `id_token` is present,
    /// * the token contains no second `.`-separated segment,
    /// * that segment is not valid base64url, or
    /// * the decoded bytes are not a JSON object.
    pub fn claims(&self) -> Result<HashMap<String, Value>, OidcClientError> {
        let id_token = self.id_token.as_deref().ok_or_else(|| {
            OidcClientError::new_type_error("id_token not present in TokenSet", None)
        })?;

        // The payload is the second dot-separated segment of the token.
        let payload = id_token.split('.').nth(1).ok_or_else(|| {
            OidcClientError::new_type_error(
                "id_token invalid. payload component not found",
                None,
            )
        })?;

        let decoded = base64_url::decode(payload).map_err(|_| {
            OidcClientError::new_type_error("id_token payload is not base64url encoded", None)
        })?;

        serde_json::from_slice::<HashMap<String, Value>>(&decoded).map_err(|_| {
            OidcClientError::new_type_error("id_token payload is not a json object", None)
        })
    }

    /// Returns a copy of the access token, if any.
    pub fn get_access_token(&self) -> Option<String> {
        self.access_token.clone()
    }

    /// Returns a copy of the token type, if any.
    pub fn get_token_type(&self) -> Option<String> {
        self.token_type.clone()
    }

    /// Returns a copy of the ID token, if any.
    pub fn get_id_token(&self) -> Option<String> {
        self.id_token.clone()
    }

    /// Returns a copy of the refresh token, if any.
    pub fn get_refresh_token(&self) -> Option<String> {
        self.refresh_token.clone()
    }

    /// Returns the stored `expires_in` value.
    ///
    /// This is the value captured at construction time (negative inputs are
    /// stored as `0`); it is not recomputed from `expires_at`.
    pub fn get_expires_in(&self) -> Option<i64> {
        self.expires_in
    }

    /// Returns the stored absolute expiry instant, if any.
    pub fn get_expires_at(&self) -> Option<i64> {
        self.expires_at
    }

    /// Returns a copy of the session state, if any.
    pub fn get_session_state(&self) -> Option<String> {
        self.session_state.clone()
    }

    /// Returns a copy of the scope string, if any.
    pub fn get_scope(&self) -> Option<String> {
        self.scope.clone()
    }

    /// Returns a copy of the additional key/value pairs, if any.
    pub fn get_other(&self) -> Option<HashMap<String, Value>> {
        self.other.clone()
    }

    /// Remaining lifetime relative to `now()`, floored at `0`.
    ///
    /// Returns `None` when no `expires_at` is set. The subtraction saturates,
    /// so an `expires_at` near `i64::MIN` yields `0` instead of wrapping into a
    /// large positive lifetime.
    fn get_expires_in_internal(&self) -> Option<i64> {
        self.expires_at
            .map(|expires_at| max(expires_at.saturating_sub(now()), 0))
    }
}
// Unit tests for this module. They are compiled only for test builds and live
// in a separate file located via the `#[path]` attribute below.
#[cfg(test)]
#[path = "./tests/tokenset_tests.rs"]
mod tokenset_tests;