use anyhow::Result;
use serde::de::DeserializeOwned;
use std::collections::HashMap;
use std::sync::Arc;
use crate::claims::{
ClaimsValidator, CognitoAccessTokenClaims, CognitoIdTokenClaims, CognitoJwtClaims,
};
use crate::cognito::config::{self, TokenUse, VerifierConfig};
use crate::cognito::token::CognitoTokenParser;
use crate::common::error::{ErrorLogger, JwtError};
use crate::common::token::TokenParser;
use crate::jwk::provider::JwkProvider;
use crate::jwk::registry::{JwkProviderRegistry, RegistryError};
use crate::verifier::{AccessTokenClaims, IdTokenClaims, JwtVerifier};
/// Verifies Cognito-issued JWTs against one or more registered user pools.
///
/// Every registered pool has an entry in both `jwk_registry` and `configs`, stored under
/// the same caller-chosen id.
#[derive(Debug)]
pub struct CognitoJwtVerifier {
    /// JWK providers, one per registered pool, keyed by pool id.
    jwk_registry: JwkProviderRegistry,
    /// Per-pool verification settings, keyed by the same pool id as `jwk_registry`.
    configs: HashMap<String, VerifierConfig>,
    /// Error logger; replaced wholesale when the verbosity is changed.
    error_logger: ErrorLogger,
}
impl CognitoJwtVerifier {
pub fn new(configs: Vec<VerifierConfig>) -> Result<Self, JwtError> {
let mut verifier = Self {
jwk_registry: JwkProviderRegistry::new(),
configs: HashMap::new(),
error_logger: ErrorLogger::new(crate::common::error::ErrorVerbosity::Standard),
};
for config in configs {
let id = format!("{}_{}", config.region, config.user_pool_id);
verifier.add_user_pool(&id, config)?;
}
Ok(verifier)
}
pub fn new_single_pool(
region: &str,
user_pool_id: &str,
client_ids: &[String],
) -> Result<Self, JwtError> {
let config = VerifierConfig::new(region, user_pool_id, client_ids, None)?;
Self::new(vec![config])
}
pub fn add_user_pool(&mut self, id: &str, config: VerifierConfig) -> Result<(), JwtError> {
let jwk_provider = JwkProvider::new(
&config.region,
&config.user_pool_id,
config.jwk_cache_duration,
)?;
self.jwk_registry
.register(id, jwk_provider)
.map_err(|e| match e {
RegistryError::JwtError(err) => err,
other => JwtError::ConfigurationError {
parameter: Some("provider_registration".to_string()),
error: other.to_string(),
},
})?;
self.configs.insert(id.to_string(), config);
Ok(())
}
pub fn add_user_pool_with_params(
&mut self,
id: &str,
region: &str,
user_pool_id: &str,
client_ids: &[String],
) -> Result<(), JwtError> {
let config = VerifierConfig::new(region, user_pool_id, client_ids, None)?;
self.add_user_pool(id, config)
}
pub fn get_user_pool_ids(&self) -> Vec<String> {
self.jwk_registry.list_ids()
}
pub fn remove_user_pool(&mut self, id: &str) -> Result<(), JwtError> {
self.jwk_registry.remove(id).map_err(|e| match e {
RegistryError::ProviderNotFound(_) => JwtError::ConfigurationError {
parameter: Some("pool_id".to_string()),
error: format!("User pool '{}' not found", id),
},
RegistryError::JwtError(err) => err,
other => JwtError::ConfigurationError {
parameter: Some("pool_id".to_string()),
error: other.to_string(),
},
})?;
self.configs.remove(id);
Ok(())
}
pub fn set_error_verbosity(&mut self, verbosity: crate::common::error::ErrorVerbosity) {
self.error_logger = ErrorLogger::new(verbosity);
}
pub async fn hydrate(&self) -> Vec<(String, Result<(), JwtError>)> {
self.jwk_registry.hydrate().await
}
pub async fn verify<T>(&self, token: &str) -> Result<T, JwtError>
where
T: DeserializeOwned + TryFrom<CognitoJwtClaims, Error = JwtError>,
{
let header = TokenParser::parse_token_header(token)?;
let issuer = TokenParser::extract_issuer(token)?;
let pool_id = self.find_pool_id_by_issuer(&issuer)?;
self.verify_generic_with_pool::<T>(token, &pool_id).await
}
fn find_pool_id_by_issuer(&self, issuer: &str) -> Result<String, JwtError> {
for (id, config) in &self.configs {
let expected_issuer = format!(
"https://cognito-idp.{}.amazonaws.com/{}",
config.region, config.user_pool_id
);
if expected_issuer == issuer {
return Ok(id.clone());
}
}
Err(JwtError::InvalidIssuer {
expected: "a registered Cognito user pool".to_string(),
actual: issuer.to_string(),
})
}
async fn verify_generic_with_pool<T>(&self, token: &str, pool_id: &str) -> Result<T, JwtError>
where
T: DeserializeOwned + TryFrom<CognitoJwtClaims, Error = JwtError>,
{
let jwk_provider = self.jwk_registry.get(pool_id).map_err(|e| match e {
RegistryError::ProviderNotFound(_) => JwtError::ConfigurationError {
parameter: Some("pool_id".to_string()),
error: format!("User pool '{}' not found", pool_id),
},
RegistryError::JwtError(err) => err,
other => JwtError::ConfigurationError {
parameter: Some("pool_id".to_string()),
error: other.to_string(),
},
})?;
let config = self
.configs
.get(pool_id)
.ok_or_else(|| JwtError::ConfigurationError {
parameter: Some("pool_id".to_string()),
error: format!("Configuration for user pool '{}' not found", pool_id),
})?;
let header = TokenParser::parse_token_header(token)?;
let key = jwk_provider.get_key(&header.kid).await?;
let validation = self
.jwk_registry
.create_validation_for_issuer(
jwk_provider.get_issuer(),
config.clock_skew,
&config.client_ids,
)
.map_err(|e| match e {
RegistryError::JwtError(err) => err,
other => JwtError::ConfigurationError {
parameter: Some("validation".to_string()),
error: other.to_string(),
},
})?;
let claims: CognitoJwtClaims = TokenParser::parse_token_claims(token, &key, &validation)?;
self.validate_claims(&claims, config)?;
T::try_from(claims)
}
fn validate_claims(
&self,
claims: &CognitoJwtClaims,
config: &VerifierConfig,
) -> Result<(), JwtError> {
if claims.token_use.is_empty() {
return Err(JwtError::InvalidClaim {
claim: "token_use".to_string(),
reason: "Token use is empty".to_string(),
value: None,
});
}
let token_use = match TokenUse::from_str(&claims.token_use) {
Some(tu) => tu,
None => {
if claims.token_use == "refresh" {
return Err(JwtError::UnsupportedTokenType {
token_type: "refresh".to_string(),
});
} else {
return Err(JwtError::InvalidTokenUse {
expected: "id or access".to_string(),
actual: claims.token_use.clone(),
});
}
}
};
if !config.allowed_token_uses.contains(&token_use) {
let expected = config
.allowed_token_uses
.iter()
.map(|t| t.as_str())
.collect::<Vec<_>>()
.join(" or ");
return Err(JwtError::InvalidTokenUse {
expected,
actual: claims.token_use.clone(),
});
}
let claims_validator = ClaimsValidator::new(Arc::new(config.clone()));
claims_validator.validate_claims(claims)
}
}
impl JwtVerifier for CognitoJwtVerifier {
    /// Verifies `token` as a Cognito ID token and returns its claims behind the
    /// [`IdTokenClaims`] trait object.
    async fn verify_id_token(&self, token: &str) -> Result<Box<dyn IdTokenClaims>, JwtError> {
        self.verify::<CognitoIdTokenClaims>(token)
            .await
            .map(|id_claims| Box::new(id_claims) as Box<dyn IdTokenClaims>)
    }

    /// Verifies `token` as a Cognito access token and returns its claims behind the
    /// [`AccessTokenClaims`] trait object.
    async fn verify_access_token(
        &self,
        token: &str,
    ) -> Result<Box<dyn AccessTokenClaims>, JwtError> {
        self.verify::<CognitoAccessTokenClaims>(token)
            .await
            .map(|access_claims| Box::new(access_claims) as Box<dyn AccessTokenClaims>)
    }
}
impl IdTokenClaims for CognitoIdTokenClaims {
    // --- Registered claims, all read from the shared `base` claim set. ---

    /// Subject (`sub`) of the token.
    fn get_sub(&self) -> &str {
        &self.base.sub
    }

    /// Issuer (`iss`) of the token.
    fn get_iss(&self) -> &str {
        &self.base.iss
    }

    /// Audience, reported as the base claims' `client_id`.
    fn get_aud(&self) -> &str {
        &self.base.client_id
    }

    /// Expiry (`exp`) as stored in the base claims.
    fn get_exp(&self) -> u64 {
        self.base.exp
    }

    /// Issued-at (`iat`) as stored in the base claims.
    fn get_iat(&self) -> u64 {
        self.base.iat
    }

    // --- Optional profile claims. ---

    /// The `email` claim, if present.
    fn get_email(&self) -> Option<&str> {
        Option::as_deref(&self.email)
    }

    /// `true` only when `email_verified` is present and set to `true`.
    fn is_email_verified(&self) -> bool {
        matches!(self.email_verified, Some(true))
    }

    /// The `name` claim, if present.
    fn get_name(&self) -> Option<&str> {
        Option::as_deref(&self.name)
    }
}
impl AccessTokenClaims for CognitoAccessTokenClaims {
    /// Subject (`sub`) of the token.
    fn get_sub(&self) -> &str {
        &self.base.sub
    }

    /// Issuer (`iss`) of the token.
    fn get_iss(&self) -> &str {
        &self.base.iss
    }

    /// Audience, reported as the base claims' `client_id`.
    fn get_aud(&self) -> &str {
        &self.base.client_id
    }

    /// Expiry (`exp`) as stored in the base claims.
    fn get_exp(&self) -> u64 {
        self.base.exp
    }

    /// Issued-at (`iat`) as stored in the base claims.
    fn get_iat(&self) -> u64 {
        self.base.iat
    }

    /// Returns the whitespace-separated entries of the `scope` claim, or an empty list
    /// when the claim is absent.
    fn get_scopes(&self) -> Vec<String> {
        match &self.scope {
            Some(scope) => scope.split_whitespace().map(String::from).collect(),
            None => Vec::new(),
        }
    }

    /// Returns `true` if `scope` exactly matches one of the whitespace-separated entries
    /// of the `scope` claim. Matching is case-sensitive; an absent claim grants nothing.
    fn has_scope(&self, scope: &str) -> bool {
        // Scan the claim in place rather than building `get_scopes()`, which would
        // allocate a `Vec` plus one `String` per granted scope on every check.
        match &self.scope {
            Some(granted) => granted.split_whitespace().any(|entry| entry == scope),
            None => false,
        }
    }

    /// The OAuth client id; always present for Cognito access tokens.
    fn get_client_id(&self) -> Option<&str> {
        Some(&self.base.client_id)
    }
}