use regex::Regex;
use serde_json::Value;
use std::collections::HashSet;
use std::fmt;
use std::ops::RangeInclusive;
use crate::claims::CognitoJwtClaims;
/// A single rule that decoded Cognito JWT claims must satisfy.
///
/// Implementations return `Ok(())` when the claims satisfy the rule and `Err` carrying a
/// human-readable reason otherwise. The `Send + Sync` supertraits let boxed validators be
/// shared between threads.
pub trait ClaimValidator: Send + Sync {
    /// Checks `claims` against this rule.
    ///
    /// # Errors
    ///
    /// Returns a message describing why the claims were rejected.
    fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String>;
}
pub struct ExistenceValidator {
claim_name: String,
}
impl ExistenceValidator {
pub fn new(claim_name: &str) -> Self {
Self {
claim_name: claim_name.to_string(),
}
}
}
impl ClaimValidator for ExistenceValidator {
fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String> {
if claims.custom_claims.contains_key(&self.claim_name) {
Ok(())
} else {
Err(format!("Claim '{}' is required", self.claim_name))
}
}
}
impl fmt::Debug for ExistenceValidator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ExistenceValidator")
.field("claim_name", &self.claim_name)
.finish()
}
}
pub struct StringValueValidator {
claim_name: String,
expected_value: String,
}
impl StringValueValidator {
pub fn new(claim_name: &str, expected_value: &str) -> Self {
Self {
claim_name: claim_name.to_string(),
expected_value: expected_value.to_string(),
}
}
}
impl ClaimValidator for StringValueValidator {
fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String> {
if let Some(Value::String(value)) = claims.custom_claims.get(&self.claim_name) {
if value == &self.expected_value {
Ok(())
} else {
Err(format!(
"Claim '{}' has invalid value: expected '{}', got '{}'",
self.claim_name, self.expected_value, value
))
}
} else {
Err(format!(
"Claim '{}' is missing or not a string",
self.claim_name
))
}
}
}
impl fmt::Debug for StringValueValidator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("StringValueValidator")
.field("claim_name", &self.claim_name)
.field("expected_value", &self.expected_value)
.finish()
}
}
pub struct AllowedValuesValidator {
claim_name: String,
allowed_values: HashSet<String>,
}
impl AllowedValuesValidator {
pub fn new(claim_name: &str, allowed_values: &[&str]) -> Self {
Self {
claim_name: claim_name.to_string(),
allowed_values: allowed_values.iter().map(|s| s.to_string()).collect(),
}
}
}
impl ClaimValidator for AllowedValuesValidator {
fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String> {
if let Some(Value::String(value)) = claims.custom_claims.get(&self.claim_name) {
if self.allowed_values.contains(value) {
Ok(())
} else {
Err(format!(
"Claim '{}' has invalid value: '{}' is not one of the allowed values",
self.claim_name, value
))
}
} else {
Err(format!(
"Claim '{}' is missing or not a string",
self.claim_name
))
}
}
}
impl fmt::Debug for AllowedValuesValidator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("AllowedValuesValidator")
.field("claim_name", &self.claim_name)
.field("allowed_values", &self.allowed_values)
.finish()
}
}
/// Requires that a custom claim is numeric and lies within an inclusive range.
///
/// The claim may be a JSON number or a JSON string that parses as an `f64` (for example
/// `"15"`). Any other value, including a string that does not parse, is rejected.
pub struct NumericRangeValidator {
    claim_name: String,
    range: RangeInclusive<f64>,
}

impl NumericRangeValidator {
    /// Creates a validator accepting values `v` with `min <= v <= max`.
    ///
    /// If `min > max`, or either bound is NaN, the range is empty and every value is
    /// rejected.
    pub fn new(claim_name: &str, min: f64, max: f64) -> Self {
        Self {
            claim_name: claim_name.to_string(),
            range: min..=max,
        }
    }
}

impl ClaimValidator for NumericRangeValidator {
    fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String> {
        let value = claims
            .custom_claims
            .get(&self.claim_name)
            .ok_or_else(|| format!("Claim '{}' is missing", self.claim_name))?;

        // A value that cannot be interpreted as a number must fail validation. Substituting
        // a default such as 0.0 would let a malformed claim pass whenever the configured
        // range happens to include that default.
        let num = match value {
            Value::Number(n) => n.as_f64(),
            Value::String(s) => s.parse::<f64>().ok(),
            _ => None,
        }
        .ok_or_else(|| format!("Claim '{}' is not a number", self.claim_name))?;

        if self.range.contains(&num) {
            Ok(())
        } else {
            Err(format!(
                "Claim '{}' is outside the allowed range: {} not in {:?}",
                self.claim_name, num, self.range
            ))
        }
    }
}

impl fmt::Debug for NumericRangeValidator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("NumericRangeValidator")
            .field("claim_name", &self.claim_name)
            .field("range", &format!("{:?}", self.range))
            .finish()
    }
}
pub struct BooleanValidator {
claim_name: String,
expected_value: bool,
}
impl BooleanValidator {
pub fn new(claim_name: &str, expected_value: bool) -> Self {
Self {
claim_name: claim_name.to_string(),
expected_value,
}
}
}
impl ClaimValidator for BooleanValidator {
fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String> {
if let Some(Value::Bool(value)) = claims.custom_claims.get(&self.claim_name) {
if *value == self.expected_value {
Ok(())
} else {
Err(format!(
"Claim '{}' has invalid value: expected {}, got {}",
self.claim_name, self.expected_value, value
))
}
} else {
Err(format!(
"Claim '{}' is missing or not a boolean",
self.claim_name
))
}
}
}
impl fmt::Debug for BooleanValidator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("BooleanValidator")
.field("claim_name", &self.claim_name)
.field("expected_value", &self.expected_value)
.finish()
}
}
/// Requires that a string-valued custom claim matches a regular expression.
///
/// Matching uses [`Regex::is_match`], which succeeds if the pattern matches anywhere in the
/// value. Anchor the pattern with `^` and `$` to require a full-string match.
pub struct RegexValidator {
    claim_name: String,
    /// Compiled pattern. `Regex::as_str` recovers the original source text for messages,
    /// so no separate copy of the pattern string is stored.
    pattern: Regex,
}

impl RegexValidator {
    /// Compiles `pattern` and creates a validator for the custom claim `claim_name`.
    ///
    /// # Errors
    ///
    /// Returns the [`regex::Error`] produced when `pattern` is not a valid regular
    /// expression.
    pub fn new(claim_name: &str, pattern: &str) -> Result<Self, regex::Error> {
        let pattern = Regex::new(pattern)?;
        Ok(Self {
            claim_name: claim_name.to_string(),
            pattern,
        })
    }
}

impl ClaimValidator for RegexValidator {
    fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String> {
        match claims.custom_claims.get(&self.claim_name) {
            Some(Value::String(value)) if self.pattern.is_match(value) => Ok(()),
            Some(Value::String(value)) => Err(format!(
                "Claim '{}' does not match pattern: '{}' does not match '{}'",
                self.claim_name,
                value,
                self.pattern.as_str()
            )),
            // Absent, or present with a non-string JSON type.
            _ => Err(format!(
                "Claim '{}' is missing or not a string",
                self.claim_name
            )),
        }
    }
}

impl fmt::Debug for RegexValidator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("RegexValidator")
            .field("claim_name", &self.claim_name)
            .field("pattern", &self.pattern.as_str())
            .finish()
    }
}
pub struct ArrayContainsValidator {
claim_name: String,
value: Value,
}
impl ArrayContainsValidator {
pub fn new_string(claim_name: &str, value: &str) -> Self {
Self {
claim_name: claim_name.to_string(),
value: Value::String(value.to_string()),
}
}
pub fn new_number(claim_name: &str, value: f64) -> Self {
Self {
claim_name: claim_name.to_string(),
value: Value::Number(serde_json::Number::from_f64(value).unwrap()),
}
}
pub fn new_bool(claim_name: &str, value: bool) -> Self {
Self {
claim_name: claim_name.to_string(),
value: Value::Bool(value),
}
}
}
impl ClaimValidator for ArrayContainsValidator {
fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String> {
if let Some(Value::Array(array)) = claims.custom_claims.get(&self.claim_name) {
if array.contains(&self.value) {
Ok(())
} else {
Err(format!(
"Claim '{}' does not contain expected value: {:?}",
self.claim_name, self.value
))
}
} else {
Err(format!(
"Claim '{}' is missing or not an array",
self.claim_name
))
}
}
}
impl fmt::Debug for ArrayContainsValidator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ArrayContainsValidator")
.field("claim_name", &self.claim_name)
.field("value", &self.value)
.finish()
}
}
/// Combines validators so that all of them must pass.
///
/// Validators run in the order given. Evaluation stops at the first failure, and that
/// validator's message is returned unchanged. An empty list passes.
pub struct AndValidator {
    validators: Vec<Box<dyn ClaimValidator>>,
}

impl AndValidator {
    /// Wraps `validators` into a single conjunctive validator.
    pub fn new(validators: Vec<Box<dyn ClaimValidator>>) -> Self {
        AndValidator { validators }
    }
}

impl ClaimValidator for AndValidator {
    fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String> {
        // `try_for_each` short-circuits on the first `Err`, matching sequential checking.
        self.validators
            .iter()
            .try_for_each(|validator| validator.validate(claims))
    }
}

impl fmt::Debug for AndValidator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Inner validators are trait objects without a `Debug` bound, so only the count is
        // shown.
        let count = self.validators.len();
        f.debug_struct("AndValidator")
            .field("validators_count", &count)
            .finish()
    }
}
/// Combines validators so that at least one of them must pass.
///
/// Validators run in the order given, and evaluation stops at the first success. When all
/// fail, their messages are joined with `", "` into one error. An empty list always fails.
pub struct OrValidator {
    validators: Vec<Box<dyn ClaimValidator>>,
}

impl OrValidator {
    /// Wraps `validators` into a single disjunctive validator.
    pub fn new(validators: Vec<Box<dyn ClaimValidator>>) -> Self {
        OrValidator { validators }
    }
}

impl ClaimValidator for OrValidator {
    fn validate(&self, claims: &CognitoJwtClaims) -> Result<(), String> {
        let mut failures: Vec<String> = Vec::with_capacity(self.validators.len());
        // The `map` is lazy, so validators after the first success are never invoked.
        for outcome in self.validators.iter().map(|validator| validator.validate(claims)) {
            if let Err(reason) = outcome {
                failures.push(reason);
            } else {
                return Ok(());
            }
        }
        Err(format!(
            "None of the validators passed: {}",
            failures.join(", ")
        ))
    }
}

impl fmt::Debug for OrValidator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Inner validators are trait objects without a `Debug` bound, so only the count is
        // shown.
        let count = self.validators.len();
        f.debug_struct("OrValidator")
            .field("validators_count", &count)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds claims with fixed standard fields plus the given custom claims.
    fn claims_with(custom: Vec<(&str, Value)>) -> CognitoJwtClaims {
        let mut claims = CognitoJwtClaims {
            sub: "user123".to_string(),
            iss: "https://example.com".to_string(),
            client_id: "client123".to_string(),
            origin_jti: None,
            event_id: None,
            token_use: "id".to_string(),
            scope: None,
            auth_time: 0,
            exp: 0,
            iat: 0,
            jti: "jti123".to_string(),
            username: None,
            custom_claims: HashMap::new(),
        };
        for (name, value) in custom {
            claims.custom_claims.insert(name.to_string(), value);
        }
        claims
    }

    /// Shorthand for a JSON string value.
    fn string(s: &str) -> Value {
        Value::String(s.to_string())
    }

    /// Shorthand for a JSON floating-point number value.
    fn float(x: f64) -> Value {
        Value::Number(serde_json::Number::from_f64(x).unwrap())
    }

    #[test]
    fn test_existence_validator() {
        let validator = ExistenceValidator::new("test_claim");
        let present = claims_with(vec![("test_claim", string("value"))]);
        assert!(validator.validate(&present).is_ok());
        assert!(validator.validate(&claims_with(vec![])).is_err());
    }

    #[test]
    fn test_string_value_validator() {
        let validator = StringValueValidator::new("test_claim", "expected_value");
        let matching = claims_with(vec![("test_claim", string("expected_value"))]);
        assert!(validator.validate(&matching).is_ok());
        let wrong = claims_with(vec![("test_claim", string("wrong_value"))]);
        assert!(validator.validate(&wrong).is_err());
        assert!(validator.validate(&claims_with(vec![])).is_err());
    }

    #[test]
    fn test_allowed_values_validator() {
        let validator =
            AllowedValuesValidator::new("test_claim", &["value1", "value2", "value3"]);
        let allowed = claims_with(vec![("test_claim", string("value2"))]);
        assert!(validator.validate(&allowed).is_ok());
        let disallowed = claims_with(vec![("test_claim", string("value4"))]);
        assert!(validator.validate(&disallowed).is_err());
    }

    #[test]
    fn test_numeric_range_validator() {
        let validator = NumericRangeValidator::new("test_claim", 10.0, 20.0);
        let inside = claims_with(vec![("test_claim", float(15.0))]);
        assert!(validator.validate(&inside).is_ok());
        let outside = claims_with(vec![("test_claim", float(25.0))]);
        assert!(validator.validate(&outside).is_err());
    }

    #[test]
    fn test_boolean_validator() {
        let validator = BooleanValidator::new("test_claim", true);
        let truthy = claims_with(vec![("test_claim", Value::Bool(true))]);
        assert!(validator.validate(&truthy).is_ok());
        let falsy = claims_with(vec![("test_claim", Value::Bool(false))]);
        assert!(validator.validate(&falsy).is_err());
    }

    #[test]
    fn test_regex_validator() {
        let validator = RegexValidator::new("test_claim", r"^[a-z]{3}\d{2}$").unwrap();
        let lower = claims_with(vec![("test_claim", string("abc12"))]);
        assert!(validator.validate(&lower).is_ok());
        let upper = claims_with(vec![("test_claim", string("ABC12"))]);
        assert!(validator.validate(&upper).is_err());
    }

    #[test]
    fn test_array_contains_validator() {
        let validator = ArrayContainsValidator::new_string("test_claim", "value2");
        let containing = claims_with(vec![(
            "test_claim",
            Value::Array(vec![string("value1"), string("value2"), string("value3")]),
        )]);
        assert!(validator.validate(&containing).is_ok());
        let lacking = claims_with(vec![(
            "test_claim",
            Value::Array(vec![string("value1"), string("value3"), string("value4")]),
        )]);
        assert!(validator.validate(&lacking).is_err());
    }

    #[test]
    fn test_and_validator() {
        let validator1 = Box::new(ExistenceValidator::new("claim1"));
        let validator2 = Box::new(ExistenceValidator::new("claim2"));
        let and_validator = AndValidator::new(vec![validator1, validator2]);
        let both = claims_with(vec![
            ("claim1", string("value1")),
            ("claim2", string("value2")),
        ]);
        assert!(and_validator.validate(&both).is_ok());
        let only_first = claims_with(vec![("claim1", string("value1"))]);
        assert!(and_validator.validate(&only_first).is_err());
    }

    #[test]
    fn test_or_validator() {
        let validator1 = Box::new(ExistenceValidator::new("claim1"));
        let validator2 = Box::new(ExistenceValidator::new("claim2"));
        let or_validator = OrValidator::new(vec![validator1, validator2]);
        let both = claims_with(vec![
            ("claim1", string("value1")),
            ("claim2", string("value2")),
        ]);
        assert!(or_validator.validate(&both).is_ok());
        let only_second = claims_with(vec![("claim2", string("value2"))]);
        assert!(or_validator.validate(&only_second).is_ok());
        assert!(or_validator.validate(&claims_with(vec![])).is_err());
    }
}