use crate::common::error::JwtError;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::fmt;
/// Claims common to the Cognito JWT payloads modelled in this file.
///
/// Every named field is read from / written to the JSON key of the same
/// name. Keys that no named field claims are gathered into `custom_claims`
/// through `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub struct CognitoJwtClaims {
/// `sub` claim (subject identifier).
pub sub: String,
/// `iss` claim (issuer); compared by `validate_issuer`.
pub iss: String,
/// Client identifier. Written as `client_id`; when reading, the key `aud`
/// is accepted as an alias.
// NOTE(review): because `aud` is an alias of this field, an incoming `aud`
// key is presumably consumed here and never reaches `custom_claims` —
// confirm against serde's alias/flatten behaviour.
#[serde(rename = "client_id", alias = "aud")]
pub client_id: String,
/// Optional `origin_jti` claim.
pub origin_jti: Option<String>,
/// Optional `event_id` claim.
pub event_id: Option<String>,
/// `token_use` claim; the `TryFrom` conversions in this file require it to
/// be `"id"` or `"access"` respectively.
#[serde(rename = "token_use")]
pub token_use: String,
/// Optional `scope` claim, kept as the raw string.
pub scope: Option<String>,
/// `auth_time` claim (presumably seconds since the Unix epoch — confirm).
pub auth_time: u64,
/// `exp` claim (presumably seconds since the Unix epoch — confirm).
pub exp: u64,
/// `iat` claim (presumably seconds since the Unix epoch — confirm).
pub iat: u64,
/// `jti` claim (token identifier).
pub jti: String,
/// Optional `username` claim.
pub username: Option<String>,
/// All remaining payload keys as raw JSON values; read through the
/// `get_custom_claim_*` helpers.
#[serde(flatten)]
pub custom_claims: HashMap<String, Value>,
}
/// Claims of an ID token: the shared `CognitoJwtClaims` (flattened into the
/// same JSON object) plus the profile and `cognito:*` claims listed below.
// NOTE(review): `base.client_id` also accepts the key `aud` (alias), and this
// struct declares its own `aud` field, so both compete for the same key when
// this type is deserialized directly — verify that path, or build values via
// `TryFrom<CognitoJwtClaims>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CognitoIdTokenClaims {
/// Shared claims, flattened into this struct's JSON object.
#[serde(flatten)]
pub base: CognitoJwtClaims,
/// Optional `email` claim.
pub email: Option<String>,
/// Optional `email_verified` claim; `is_email_verified` treats absence as `false`.
pub email_verified: Option<bool>,
/// Optional `phone_number` claim.
pub phone_number: Option<String>,
/// Optional `phone_number_verified` claim; `is_phone_verified` treats absence as `false`.
pub phone_number_verified: Option<bool>,
/// Optional `cognito:groups` claim (list of group names).
#[serde(rename = "cognito:groups")]
pub cognito_groups: Option<Vec<String>>,
/// Optional `name` claim.
pub name: Option<String>,
/// `aud` claim (audience); required when deserializing this struct.
pub aud: String,
/// Optional `at_hash` claim.
pub at_hash: Option<String>,
/// Optional `cognito:username` claim.
#[serde(rename = "cognito:username")]
pub cognito_username: Option<String>,
/// Optional `cognito:roles` claim (list of role identifiers).
#[serde(rename = "cognito:roles")]
pub cognito_roles: Option<Vec<String>>,
/// Optional `cognito:preferred_role` claim.
#[serde(rename = "cognito:preferred_role")]
pub cognito_preferred_role: Option<String>,
}
/// Claims of an access token: the shared `CognitoJwtClaims` (flattened into
/// the same JSON object) plus the access-token specific claims listed below.
// NOTE(review): `scope` is declared both here and on `base`; confirm which of
// the two receives the key when this type is deserialized directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CognitoAccessTokenClaims {
/// Shared claims, flattened into this struct's JSON object.
#[serde(flatten)]
pub base: CognitoJwtClaims,
/// Optional `scope` claim; `get_scopes` splits it on whitespace.
pub scope: Option<String>,
/// Optional `version` claim; `get_version` falls back to `1` when absent.
pub version: Option<u32>,
/// Optional `cognito:username` claim.
#[serde(rename = "cognito:username")]
pub cognito_username: Option<String>,
/// Optional `cognito:groups` claim (list of group names).
#[serde(rename = "cognito:groups")]
pub cognito_groups: Option<Vec<String>>,
}
/// Serializable record describing a linked identity; every field is carried
/// as a plain string. Not referenced by the other items visible in this file.
// NOTE(review): presumably mirrors one entry of Cognito's `identities`
// claim — confirm against the caller that produces or consumes it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Identity {
/// `user_id` key.
pub user_id: String,
/// `provider_name` key.
pub provider_name: String,
/// `provider_type` key.
pub provider_type: String,
/// `primary` key, held as a string rather than a `bool`.
pub primary: String,
/// `date_created` key, held as a string (format not established here).
pub date_created: String,
}
impl CognitoJwtClaims {
    /// Returns `true` when the `iss` claim equals `expected_issuer` exactly.
    pub fn validate_issuer(&self, expected_issuer: &str) -> bool {
        self.iss.as_str() == expected_issuer
    }

    /// Returns `true` when `client_id` is one of `expected_client_ids`.
    ///
    /// An empty allow-list accepts every client id.
    pub fn validate_client_id(&self, expected_client_ids: &[String]) -> bool {
        if expected_client_ids.is_empty() {
            return true;
        }
        expected_client_ids
            .iter()
            .any(|allowed| allowed == &self.client_id)
    }

    /// Returns `true` when the `token_use` claim equals `expected_token_use`.
    pub fn validate_token_use(&self, expected_token_use: &str) -> bool {
        self.token_use.as_str() == expected_token_use
    }

    /// Looks up `claim_name` in the custom claims and returns an owned copy
    /// of it when it is a JSON string; `None` when absent or of another type.
    pub fn get_custom_claim_string(&self, claim_name: &str) -> Option<String> {
        let raw = self.custom_claims.get(claim_name)?;
        raw.as_str().map(str::to_owned)
    }

    /// Looks up `claim_name` in the custom claims and returns it as `f64`
    /// when it is a JSON number; `None` when absent or of another type.
    pub fn get_custom_claim_number(&self, claim_name: &str) -> Option<f64> {
        let raw = self.custom_claims.get(claim_name)?;
        raw.as_f64()
    }

    /// Looks up `claim_name` in the custom claims and returns it when it is
    /// a JSON boolean; `None` when absent or of another type.
    pub fn get_custom_claim_bool(&self, claim_name: &str) -> Option<bool> {
        let raw = self.custom_claims.get(claim_name)?;
        raw.as_bool()
    }

    /// Looks up `claim_name` in the custom claims and, when it is a JSON
    /// array, returns its string elements in order.
    ///
    /// Non-string elements are skipped rather than treated as an error, so
    /// the result may be shorter than the array. Returns `None` when the
    /// claim is absent or is not an array.
    pub fn get_custom_claim_string_array(&self, claim_name: &str) -> Option<Vec<String>> {
        let elements = self.custom_claims.get(claim_name)?.as_array()?;
        let mut strings = Vec::new();
        for element in elements {
            if let Some(text) = element.as_str() {
                strings.push(text.to_owned());
            }
        }
        Some(strings)
    }
}
/// One-line summary of the claims: subject, issuer and expiry only.
impl fmt::Display for CognitoJwtClaims {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Pull out just the three fields that are printed.
        let Self { sub, iss, exp, .. } = self;
        write!(f, "JWT Claims (sub={}, iss={}, exp={})", sub, iss, exp)
    }
}
impl CognitoIdTokenClaims {
    /// Borrows the `email` claim, if present.
    pub fn get_email(&self) -> Option<&str> {
        let email = &self.email;
        email.as_deref()
    }

    /// `true` only when `email_verified` is present and `true`.
    pub fn is_email_verified(&self) -> bool {
        self.email_verified == Some(true)
    }

    /// Borrows the `phone_number` claim, if present.
    pub fn get_phone_number(&self) -> Option<&str> {
        let phone = &self.phone_number;
        phone.as_deref()
    }

    /// `true` only when `phone_number_verified` is present and `true`.
    pub fn is_phone_verified(&self) -> bool {
        self.phone_number_verified == Some(true)
    }

    /// Returns a copy of the `cognito:groups` list; empty when the claim is absent.
    pub fn get_groups(&self) -> Vec<String> {
        match &self.cognito_groups {
            Some(groups) => groups.to_vec(),
            None => Vec::new(),
        }
    }

    /// `true` when `group` is one of the `cognito:groups` entries (exact match).
    pub fn is_in_group(&self, group: &str) -> bool {
        self.cognito_groups
            .iter()
            .flatten()
            .any(|member| member.as_str() == group)
    }

    /// Returns a copy of the `cognito:roles` list; empty when the claim is absent.
    pub fn get_roles(&self) -> Vec<String> {
        match &self.cognito_roles {
            Some(roles) => roles.to_vec(),
            None => Vec::new(),
        }
    }

    /// `true` when `role` is one of the `cognito:roles` entries (exact match).
    pub fn has_role(&self, role: &str) -> bool {
        self.cognito_roles
            .iter()
            .flatten()
            .any(|granted| granted.as_str() == role)
    }

    /// Borrows the `cognito:preferred_role` claim, if present.
    pub fn get_preferred_role(&self) -> Option<&str> {
        let preferred = &self.cognito_preferred_role;
        preferred.as_deref()
    }
}
/// One-line summary of the ID token: subject, issuer and expiry of the
/// shared base claims.
impl fmt::Display for CognitoIdTokenClaims {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shared = &self.base;
        write!(
            f,
            "ID Token Claims (sub={}, iss={}, exp={})",
            shared.sub, shared.iss, shared.exp
        )
    }
}
/// Builds ID-token claims from generic claims by copying the ID-token
/// specific entries out of `custom_claims`.
impl TryFrom<CognitoJwtClaims> for CognitoIdTokenClaims {
    type Error = JwtError;

    /// Converts `claims` into [`CognitoIdTokenClaims`].
    ///
    /// The generic claims are kept unchanged as `base`; the typed fields are
    /// copies of the matching `custom_claims` entries. Entries that are
    /// missing or have an unexpected JSON type become `None`.
    ///
    /// # Errors
    ///
    /// Returns [`JwtError::InvalidTokenUse`] when `token_use` is not `"id"`.
    fn try_from(claims: CognitoJwtClaims) -> Result<Self, Self::Error> {
        if claims.token_use != "id" {
            return Err(JwtError::InvalidTokenUse {
                expected: "id".to_string(),
                // `claims` is dropped on this path, so move instead of cloning.
                actual: claims.token_use,
            });
        }

        let email = claims.get_custom_claim_string("email");
        let email_verified = claims.get_custom_claim_bool("email_verified");
        let phone_number = claims.get_custom_claim_string("phone_number");
        let phone_number_verified = claims.get_custom_claim_bool("phone_number_verified");
        let cognito_groups = claims.get_custom_claim_string_array("cognito:groups");
        let name = claims.get_custom_claim_string("name");
        // `client_id` is declared with `alias = "aud"`, so a deserialized
        // token's `aud` ends up in `client_id`, not in `custom_claims`.
        // Fall back to it so the audience is not silently left empty.
        let aud = claims
            .get_custom_claim_string("aud")
            .unwrap_or_else(|| claims.client_id.clone());
        let at_hash = claims.get_custom_claim_string("at_hash");
        let cognito_username = claims.get_custom_claim_string("cognito:username");
        let cognito_roles = claims.get_custom_claim_string_array("cognito:roles");
        let cognito_preferred_role = claims.get_custom_claim_string("cognito:preferred_role");

        Ok(Self {
            base: claims,
            email,
            email_verified,
            phone_number,
            phone_number_verified,
            cognito_groups,
            name,
            aud,
            at_hash,
            cognito_username,
            cognito_roles,
            cognito_preferred_role,
        })
    }
}
impl CognitoAccessTokenClaims {
    /// Splits the `scope` claim on whitespace into individual scopes.
    ///
    /// Returns an empty list when the claim is absent.
    pub fn get_scopes(&self) -> Vec<String> {
        self.scope
            .as_deref()
            .map(|raw| raw.split_whitespace().map(String::from).collect())
            .unwrap_or_default()
    }

    /// `true` when `scope` is exactly one of the whitespace-separated
    /// entries of the `scope` claim.
    pub fn has_scope(&self, scope: &str) -> bool {
        self.scope
            .as_deref()
            .map_or(false, |raw| raw.split_whitespace().any(|granted| granted == scope))
    }

    /// Returns the `version` claim, or `1` when it is absent.
    pub fn get_version(&self) -> u32 {
        match self.version {
            Some(version) => version,
            None => 1,
        }
    }
}
/// One-line summary of the access token: subject, issuer and expiry of the
/// shared base claims plus the raw scope string (`none` when absent).
impl fmt::Display for CognitoAccessTokenClaims {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let shared = &self.base;
        let scope_display = self.scope.as_deref().unwrap_or("none");
        write!(
            f,
            "Access Token Claims (sub={}, iss={}, exp={}, scope={})",
            shared.sub, shared.iss, shared.exp, scope_display
        )
    }
}
/// Builds access-token claims from generic claims by copying the
/// access-token specific entries out of `scope` and `custom_claims`.
impl TryFrom<CognitoJwtClaims> for CognitoAccessTokenClaims {
    type Error = JwtError;

    /// Converts `claims` into [`CognitoAccessTokenClaims`].
    ///
    /// The generic claims are kept unchanged as `base`. `scope` is copied
    /// from the base claims; `version`, `cognito:groups` and
    /// `cognito:username` are copied from `custom_claims`, becoming `None`
    /// when missing or of an unexpected JSON type.
    ///
    /// # Errors
    ///
    /// Returns [`JwtError::InvalidTokenUse`] when `token_use` is not `"access"`.
    fn try_from(claims: CognitoJwtClaims) -> Result<Self, Self::Error> {
        if claims.token_use != "access" {
            return Err(JwtError::InvalidTokenUse {
                expected: "access".to_string(),
                // `claims` is dropped on this path, so move instead of cloning.
                actual: claims.token_use,
            });
        }

        let scope = claims.scope.clone();
        // A bare `as u32` saturates and truncates silently (-1 -> 0,
        // 1.9 -> 1, 1e12 -> u32::MAX). Only accept whole numbers that fit in
        // `u32`, so the cast below is exact; anything else is treated as absent.
        let version = claims
            .get_custom_claim_number("version")
            .filter(|v| v.fract() == 0.0 && (0.0..=f64::from(u32::MAX)).contains(v))
            .map(|v| v as u32);
        let cognito_groups = claims.get_custom_claim_string_array("cognito:groups");
        let cognito_username = claims.get_custom_claim_string("cognito:username");

        Ok(Self {
            base: claims,
            scope,
            version,
            cognito_groups,
            cognito_username,
        })
    }
}