use crate::common::error::JwtError;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
/// Claims common to OIDC ID tokens and access tokens.
///
/// Named fields hold the registered/common claims; every other claim in the token is
/// collected into `custom_claims` via `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcJwtClaims {
/// Subject (`sub`) claim.
pub sub: String,
/// Issuer (`iss`) claim; compared verbatim by `validate_issuer`.
pub iss: String,
/// Audience (`aud`) claim. Because of the serde alias, a `client_id` claim is also
/// deserialized into this field, so it does NOT appear in `custom_claims`.
// NOTE(review): RFC 7519 allows `aud` to be a JSON array; a `String` field presumably
// rejects such tokens. A token carrying both `aud` and `client_id` (e.g. an RFC 9068 access
// token) is presumably rejected by serde as a duplicate field — confirm with the issuers in use.
#[serde(alias = "client_id")]
pub aud: String,
/// Expiration (`exp`) claim. NOTE(review): presumably seconds since the Unix epoch; not checked here.
pub exp: u64,
/// Issued-at (`iat`) claim.
pub iat: u64,
/// Optional `auth_time` claim.
pub auth_time: Option<u64>,
/// Optional `nonce` claim.
pub nonce: Option<String>,
/// Optional authentication context class reference (`acr`).
pub acr: Option<String>,
/// Optional authentication methods references (`amr`).
pub amr: Option<Vec<String>>,
/// Optional authorized party (`azp`); `validate_client_id` also accepts a match here.
pub azp: Option<String>,
/// Every claim not mapped to a field above (e.g. `token_use`, `scope`, `email`).
#[serde(flatten)]
pub custom_claims: HashMap<String, Value>,
}
/// ID-token view of the claims: the shared base claims plus standard profile claims.
///
/// When deserialized directly, the profile keys below are consumed by these fields and do not
/// reach `base.custom_claims`.
// NOTE(review): values produced by `TryFrom<OidcJwtClaims>` keep the profile claims inside
// `base.custom_claims` as well, so serializing such a value presumably emits those keys twice.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcIdTokenClaims {
/// Shared registered/common claims.
#[serde(flatten)]
pub base: OidcJwtClaims,
/// `email` claim.
pub email: Option<String>,
/// `email_verified` claim; `is_email_verified` treats a missing value as `false`.
pub email_verified: Option<bool>,
/// `name` claim.
pub name: Option<String>,
/// `preferred_username` claim.
pub preferred_username: Option<String>,
/// `given_name` claim.
pub given_name: Option<String>,
/// `family_name` claim.
pub family_name: Option<String>,
/// `locale` claim.
pub locale: Option<String>,
/// `picture` claim.
pub picture: Option<String>,
}
/// Access-token view of the claims: the shared base claims plus `scope` and `client_id`.
// NOTE(review): when deserializing directly into this type, the outer `client_id` field takes
// the `client_id` key before the flattened `base` sees it, so a token with `client_id` but no
// `aud` presumably fails with "missing field `aud`" — confirm before relying on direct decoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OidcAccessTokenClaims {
/// Shared registered/common claims.
#[serde(flatten)]
pub base: OidcJwtClaims,
/// Space-delimited `scope` claim; see `get_scopes` / `has_scope`.
pub scope: Option<String>,
/// Client identifier the token was issued to, if known.
pub client_id: Option<String>,
}
impl OidcJwtClaims {
pub fn validate_issuer(&self, expected_issuer: &str) -> bool {
self.iss == expected_issuer
}
pub fn validate_client_id(&self, expected_client_ids: &[String]) -> bool {
if expected_client_ids.is_empty() {
return true;
}
if expected_client_ids.contains(&self.aud) {
return true;
}
if let Some(azp) = &self.azp {
if expected_client_ids.contains(azp) {
return true;
}
}
false
}
pub fn validate_token_use(&self, expected_token_use: &str) -> bool {
match self.get_custom_claim_string("token_use") {
Some(token_use) => token_use == expected_token_use,
None => false,
}
}
pub fn get_custom_claim_string(&self, claim_name: &str) -> Option<String> {
self.custom_claims
.get(claim_name)
.and_then(|v| v.as_str().map(|s| s.to_string()))
}
pub fn get_custom_claim_number(&self, claim_name: &str) -> Option<f64> {
self.custom_claims.get(claim_name).and_then(|v| v.as_f64())
}
pub fn get_custom_claim_bool(&self, claim_name: &str) -> Option<bool> {
self.custom_claims.get(claim_name).and_then(|v| v.as_bool())
}
pub fn get_custom_claim_string_array(&self, claim_name: &str) -> Option<Vec<String>> {
self.custom_claims.get(claim_name).and_then(|v| {
v.as_array().map(|arr| {
arr.iter()
.filter_map(|item| item.as_str().map(|s| s.to_string()))
.collect()
})
})
}
}
impl fmt::Display for OidcJwtClaims {
    /// Writes a short one-line summary: subject, issuer and expiry only.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { sub, iss, exp, .. } = self;
        write!(f, "OIDC JWT Claims (sub={}, iss={}, exp={})", sub, iss, exp)
    }
}
impl OidcIdTokenClaims {
    /// Borrowed `email` claim, if present.
    pub fn get_email(&self) -> Option<&str> {
        Option::as_deref(&self.email)
    }

    /// `true` only when `email_verified` is present and `true`; absent counts as unverified.
    pub fn is_email_verified(&self) -> bool {
        matches!(self.email_verified, Some(true))
    }

    /// Borrowed `name` claim, if present.
    pub fn get_name(&self) -> Option<&str> {
        Option::as_deref(&self.name)
    }

    /// Borrowed `preferred_username` claim, if present.
    pub fn get_preferred_username(&self) -> Option<&str> {
        Option::as_deref(&self.preferred_username)
    }

    /// Borrowed `given_name` claim, if present.
    pub fn get_given_name(&self) -> Option<&str> {
        Option::as_deref(&self.given_name)
    }

    /// Borrowed `family_name` claim, if present.
    pub fn get_family_name(&self) -> Option<&str> {
        Option::as_deref(&self.family_name)
    }

    /// Borrowed `locale` claim, if present.
    pub fn get_locale(&self) -> Option<&str> {
        Option::as_deref(&self.locale)
    }

    /// Borrowed `picture` claim, if present.
    pub fn get_picture(&self) -> Option<&str> {
        Option::as_deref(&self.picture)
    }
}
impl fmt::Display for OidcIdTokenClaims {
    /// Writes a one-line summary built from the base claims (subject, issuer, expiry).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let base = &self.base;
        write!(
            f,
            "OIDC ID Token Claims (sub={}, iss={}, exp={})",
            base.sub, base.iss, base.exp
        )
    }
}
impl TryFrom<OidcJwtClaims> for OidcIdTokenClaims {
    type Error = JwtError;

    /// Converts generic claims into an ID-token view.
    ///
    /// The profile fields are read from `custom_claims`; they are copied, not removed, so they
    /// remain visible through `base.custom_claims`.
    ///
    /// # Errors
    /// Returns `JwtError::InvalidTokenUse` when the `token_use` custom claim is not the string
    /// `"id"`. A missing (or non-string) claim is reported as `actual: "None"`.
    fn try_from(claims: OidcJwtClaims) -> Result<Self, Self::Error> {
        let token_use = claims.get_custom_claim_string("token_use");
        if token_use.as_deref() != Some("id") {
            return Err(JwtError::InvalidTokenUse {
                expected: "id".to_string(),
                actual: token_use.unwrap_or_else(|| "None".to_string()),
            });
        }

        let string_claim = |key: &str| claims.get_custom_claim_string(key);
        let email = string_claim("email");
        let name = string_claim("name");
        let preferred_username = string_claim("preferred_username");
        let given_name = string_claim("given_name");
        let family_name = string_claim("family_name");
        let locale = string_claim("locale");
        let picture = string_claim("picture");
        let email_verified = claims.get_custom_claim_bool("email_verified");

        Ok(Self {
            base: claims,
            email,
            email_verified,
            name,
            preferred_username,
            given_name,
            family_name,
            locale,
            picture,
        })
    }
}
impl OidcAccessTokenClaims {
    /// Splits the space-delimited `scope` claim into individual scopes.
    ///
    /// Returns an empty vector when no `scope` claim is present. Runs of whitespace are
    /// collapsed, so no empty entries are produced.
    pub fn get_scopes(&self) -> Vec<String> {
        self.scope
            .iter()
            .flat_map(|s| s.split_whitespace())
            .map(str::to_string)
            .collect()
    }

    /// Returns `true` if `scope` exactly (case-sensitively) matches one of the granted scopes.
    ///
    /// An empty or whitespace-containing `scope` never matches.
    pub fn has_scope(&self, scope: &str) -> bool {
        // Scan the claim in place rather than building a `Vec<String>` (and a `String` for the
        // needle) on every call; the result is identical to `get_scopes().contains(..)`.
        self.scope
            .iter()
            .flat_map(|s| s.split_whitespace())
            .any(|granted| granted == scope)
    }

    /// Borrowed client id, if one was recorded for this token.
    pub fn get_client_id(&self) -> Option<&str> {
        self.client_id.as_deref()
    }
}
impl fmt::Display for OidcAccessTokenClaims {
    /// Writes a one-line summary; a missing `scope` is shown as `none`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let scope = self.scope.as_deref().unwrap_or("none");
        let base = &self.base;
        write!(
            f,
            "OIDC Access Token Claims (sub={}, iss={}, exp={}, scope={})",
            base.sub, base.iss, base.exp, scope
        )
    }
}
impl TryFrom<OidcJwtClaims> for OidcAccessTokenClaims {
    type Error = JwtError;

    /// Converts generic claims into an access-token view.
    ///
    /// `scope` is read from `custom_claims`. `client_id` is resolved in this order:
    /// a `client_id` custom claim, then `azp`, then a non-empty `aud`.
    ///
    /// # Errors
    /// Returns `JwtError::InvalidTokenUse` when the `token_use` custom claim is not the string
    /// `"access"`. A missing (or non-string) claim is reported as `actual: "None"`.
    fn try_from(claims: OidcJwtClaims) -> Result<Self, Self::Error> {
        let token_use = claims.get_custom_claim_string("token_use");
        if token_use.as_deref() != Some("access") {
            return Err(JwtError::InvalidTokenUse {
                expected: "access".to_string(),
                actual: token_use.unwrap_or_else(|| "None".to_string()),
            });
        }

        let scope = claims.get_custom_claim_string("scope");
        // `OidcJwtClaims::aud` is declared with `#[serde(alias = "client_id")]`, so for claims
        // deserialized from a token a `client_id` claim is stored in `aud` and never reaches
        // `custom_claims`. Looking only at `custom_claims` therefore always produced `None`;
        // fall back to `azp` (authorized party) and then to `aud`.
        let client_id = claims
            .get_custom_claim_string("client_id")
            .or_else(|| claims.azp.clone())
            .or_else(|| Some(claims.aud.clone()).filter(|aud| !aud.is_empty()));

        Ok(Self {
            base: claims,
            scope,
            client_id,
        })
    }
}