use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use crate::claims::CognitoJwtClaims;
use crate::cognito::config::{TokenUse, VerifierConfig};
use crate::common::error::{ErrorVerbosity, JwtError};
/// Validates the claims of an already-decoded Cognito JWT against a [`VerifierConfig`].
pub struct ClaimsValidator {
    /// Shared verifier settings: expected issuer, accepted client IDs and token uses,
    /// clock skew tolerance, required claims and custom validators.
    config: Arc<VerifierConfig>,
}

impl ClaimsValidator {
    /// Creates a validator that checks claims against `config`.
    pub fn new(config: Arc<VerifierConfig>) -> Self {
        Self { config }
    }

    /// Runs every configured check against `claims`, stopping at the first failure.
    ///
    /// Checks run in this order:
    /// 1. each name in `required_claims` that this validator knows (`sub`, `iss`,
    ///    `client_id`) must be non-empty; other names are ignored by this step;
    /// 2. `token_use`, `exp`, `iat`, `iss` and `client_id`;
    /// 3. every configured custom validator.
    ///
    /// # Errors
    ///
    /// Returns the [`JwtError`] produced by the first failing check. A failing custom
    /// validator is reported as [`JwtError::InvalidClaim`] with `claim` set to `"custom"`.
    pub fn validate_claims(&self, claims: &CognitoJwtClaims) -> Result<(), JwtError> {
        self.validate_required_claims(claims)?;
        self.validate_token_use(claims.token_use.as_str())?;
        self.validate_expiration(claims.exp)?;
        self.validate_issued_at(claims.iat)?;
        self.validate_issuer(&claims.iss)?;
        self.validate_client_id(&claims.client_id)?;
        // NOTE(review): `validate_not_before` is not invoked here; callers that need
        // `nbf` enforcement must call it themselves.
        self.run_custom_validators(claims)
    }

    /// Checks that each known claim listed in `required_claims` is non-empty.
    fn validate_required_claims(&self, claims: &CognitoJwtClaims) -> Result<(), JwtError> {
        for claim in &self.config.required_claims {
            let (is_empty, reason) = match claim.as_str() {
                "sub" => (claims.sub.is_empty(), "Subject is empty"),
                "iss" => (claims.iss.is_empty(), "Issuer is empty"),
                "client_id" => (claims.client_id.is_empty(), "Client ID is empty"),
                // Unknown claim names are not checked by this validator.
                _ => continue,
            };
            if is_empty {
                return Err(JwtError::InvalidClaim {
                    claim: claim.to_string(),
                    reason: reason.to_string(),
                    value: None,
                });
            }
        }
        Ok(())
    }

    /// Runs each configured custom validator, converting its failure reason into an error.
    fn run_custom_validators(&self, claims: &CognitoJwtClaims) -> Result<(), JwtError> {
        for validator in &self.config.custom_validators {
            validator
                .validate(claims)
                .map_err(|reason| JwtError::InvalidClaim {
                    claim: "custom".to_string(),
                    reason,
                    value: None,
                })?;
        }
        Ok(())
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// A system clock set before the epoch yields `0` instead of an error.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Configured clock skew tolerance, truncated to whole seconds.
    fn clock_skew_secs(&self) -> u64 {
        self.config.clock_skew.as_secs()
    }

    /// Rejects a token whose `exp` lies at or before `now - clock_skew`.
    ///
    /// # Errors
    ///
    /// Returns [`JwtError::ExpiredToken`] carrying `exp` and the current time.
    pub fn validate_expiration(&self, exp: u64) -> Result<(), JwtError> {
        let now = Self::now_secs();
        // Saturating: a skew larger than `now` (e.g. a clock reported as the epoch) must
        // not underflow. Plain subtraction panics in debug builds and wraps in release,
        // which would make every token look expired.
        if exp <= now.saturating_sub(self.clock_skew_secs()) {
            return Err(JwtError::ExpiredToken {
                exp: Some(exp),
                current_time: Some(now),
            });
        }
        Ok(())
    }

    /// Rejects a token whose `nbf` lies after `now + clock_skew`; a missing `nbf` passes.
    ///
    /// # Errors
    ///
    /// Returns [`JwtError::TokenNotYetValid`] carrying `nbf` and the current time.
    pub fn validate_not_before(&self, nbf: Option<u64>) -> Result<(), JwtError> {
        if let Some(nbf) = nbf {
            let now = Self::now_secs();
            // Saturating: an enormous skew must not overflow and wrap to a small bound.
            if nbf > now.saturating_add(self.clock_skew_secs()) {
                return Err(JwtError::TokenNotYetValid {
                    nbf: Some(nbf),
                    current_time: Some(now),
                });
            }
        }
        Ok(())
    }

    /// Rejects a token whose `iat` lies after `now + clock_skew`.
    ///
    /// # Errors
    ///
    /// Returns [`JwtError::InvalidClaim`] for claim `"iat"` with the offending value.
    pub fn validate_issued_at(&self, iat: u64) -> Result<(), JwtError> {
        let now = Self::now_secs();
        // Saturating: an enormous skew must not overflow and wrap to a small bound.
        if iat > now.saturating_add(self.clock_skew_secs()) {
            return Err(JwtError::InvalidClaim {
                claim: "iat".to_string(),
                reason: "Token issued in the future".to_string(),
                value: Some(iat.to_string()),
            });
        }
        Ok(())
    }

    /// Requires `issuer` to equal the configured issuer URL exactly.
    ///
    /// # Errors
    ///
    /// Returns [`JwtError::InvalidIssuer`] with the expected and actual issuer.
    pub fn validate_issuer(&self, issuer: &str) -> Result<(), JwtError> {
        let expected_issuer = self.config.get_issuer_url();
        if issuer != expected_issuer {
            return Err(JwtError::InvalidIssuer {
                expected: expected_issuer,
                actual: issuer.to_string(),
            });
        }
        Ok(())
    }

    /// Requires `client_id` to be one of the configured client IDs.
    ///
    /// An empty `client_ids` list disables this check, and every client ID is accepted.
    ///
    /// # Errors
    ///
    /// Returns [`JwtError::InvalidClientId`] listing the accepted IDs and the actual one.
    pub fn validate_client_id(&self, client_id: &str) -> Result<(), JwtError> {
        let client_ids = &self.config.client_ids;
        if client_ids.is_empty() || client_ids.iter().any(|allowed| allowed.as_str() == client_id)
        {
            return Ok(());
        }
        Err(JwtError::InvalidClientId {
            expected: client_ids.clone(),
            actual: client_id.to_string(),
        })
    }

    /// Parses the `token_use` claim and checks it against the allowed token uses.
    ///
    /// Returns the parsed [`TokenUse`] on success.
    ///
    /// # Errors
    ///
    /// - [`JwtError::InvalidClaim`] if `claims_token_use` is empty.
    /// - [`JwtError::InvalidTokenUse`] if the value is not recognised by
    ///   `TokenUse::from_str`, or is recognised but not in `allowed_token_uses`.
    pub fn validate_token_use(&self, claims_token_use: &str) -> Result<TokenUse, JwtError> {
        if claims_token_use.is_empty() {
            return Err(JwtError::InvalidClaim {
                claim: "token_use".to_string(),
                reason: "Token use is empty".to_string(),
                value: None,
            });
        }
        let actual_token_use =
            TokenUse::from_str(claims_token_use).ok_or_else(|| JwtError::InvalidTokenUse {
                expected: "id or access".to_string(),
                actual: claims_token_use.to_string(),
            })?;
        let allowed_token_uses = &self.config.allowed_token_uses;
        if !allowed_token_uses.contains(&actual_token_use) {
            let expected = allowed_token_uses
                .iter()
                .map(|t| t.as_str())
                .collect::<Vec<_>>()
                .join(" or ");
            return Err(JwtError::InvalidTokenUse {
                expected,
                actual: claims_token_use.to_string(),
            });
        }
        Ok(actual_token_use)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::time::Duration;

    /// Builds a validator for pool `us-east-1_example` with a 60 s clock skew that
    /// accepts both token uses and exactly the given client IDs.
    fn validator_with_client_ids(client_ids: &[&str]) -> ClaimsValidator {
        ClaimsValidator::new(Arc::new(VerifierConfig {
            region: "us-east-1".to_string(),
            user_pool_id: "us-east-1_example".to_string(),
            client_ids: client_ids.iter().map(|&id| id.to_string()).collect(),
            allowed_token_uses: vec![TokenUse::Id, TokenUse::Access],
            clock_skew: Duration::from_secs(60),
            jwk_cache_duration: Duration::from_secs(3600),
            required_claims: HashSet::new(),
            custom_validators: Vec::new(),
            error_verbosity: ErrorVerbosity::Standard,
        }))
    }

    /// Default validator used by most tests: a single accepted client, `client1`.
    fn default_validator() -> ClaimsValidator {
        validator_with_client_ids(&["client1"])
    }

    /// Current time in whole seconds since the Unix epoch.
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[test]
    fn test_validate_expiration() {
        let validator = default_validator();
        let now = now_secs();
        assert!(validator.validate_expiration(now + 3600).is_ok());
        // Within the 60 s skew window.
        assert!(validator.validate_expiration(now - 30).is_ok());
        assert!(validator.validate_expiration(now - 120).is_err());
    }

    #[test]
    fn test_validate_not_before() {
        let validator = default_validator();
        let now = now_secs();
        assert!(validator.validate_not_before(None).is_ok());
        assert!(validator.validate_not_before(Some(now)).is_ok());
        // Within the 60 s skew window.
        assert!(validator.validate_not_before(Some(now + 30)).is_ok());
        assert!(validator.validate_not_before(Some(now + 3600)).is_err());
    }

    #[test]
    fn test_validate_issued_at() {
        let validator = default_validator();
        let now = now_secs();
        assert!(validator.validate_issued_at(now).is_ok());
        // Within the 60 s skew window.
        assert!(validator.validate_issued_at(now + 30).is_ok());
        assert!(validator.validate_issued_at(now + 3600).is_err());
    }

    #[test]
    fn test_validate_issuer() {
        let validator = default_validator();
        assert!(validator
            .validate_issuer("https://cognito-idp.us-east-1.amazonaws.com/us-east-1_example")
            .is_ok());
        assert!(validator
            .validate_issuer("https://cognito-idp.us-west-2.amazonaws.com/us-west-2_example")
            .is_err());
    }

    #[test]
    fn test_validate_client_id() {
        let validator = validator_with_client_ids(&["client1", "client2"]);
        assert!(validator.validate_client_id("client1").is_ok());
        assert!(validator.validate_client_id("client2").is_ok());
        assert!(validator.validate_client_id("client3").is_err());

        // An empty allow-list disables the check.
        let validator = validator_with_client_ids(&[]);
        assert!(validator.validate_client_id("any_client").is_ok());
    }

    #[test]
    fn test_validate_token_use_rejects_empty_and_unknown() {
        let validator = default_validator();
        assert!(validator.validate_token_use("").is_err());
        assert!(validator.validate_token_use("not-a-token-use").is_err());
    }
}