use std::collections::HashSet;
use std::time::{SystemTime, UNIX_EPOCH};
use serde_json::map::Map;
use serde_json::Value;
use crate::Algorithm;
use crate::Error;
use crate::Header;
/// Settings controlling which checks are applied when validating a token's header and claims.
///
/// Start from [`ValidationOptions::new`] or [`ValidationOptions::default`] and refine with the
/// `with_*` / `without_*` builder methods; every field is also public for direct construction.
#[derive(Debug, Clone, PartialEq)]
pub struct ValidationOptions {
    /// Clock-skew tolerance, in seconds, applied to the `exp` and `nbf` time checks.
    pub leeway: u64,
    /// When `true`, an unsigned-integer `exp` claim is required and the current time must not
    /// exceed `exp + leeway`.
    pub validate_exp: bool,
    /// When `true`, an unsigned-integer `nbf` claim is required and the current time must be at
    /// least `nbf - leeway`.
    pub validate_nbf: bool,
    /// Accepted `aud` values; at least one must appear in the token. `None` disables the check.
    pub audiences: Option<HashSet<String>>,
    /// Exact `iss` value the token must carry. `None` disables the check.
    pub issuer: Option<String>,
    /// Exact `sub` value the token must carry. `None` disables the check.
    pub subject: Option<String>,
    /// Header algorithms to accept. An empty set accepts any algorithm.
    pub algorithms: HashSet<Algorithm>,
    /// Claim names that must be present in the token. `None` disables the check.
    pub required_claims: Option<HashSet<String>>,
}

impl ValidationOptions {
    /// Creates default options that accept only `alg` in the token header.
    pub fn new(alg: Algorithm) -> Self {
        Self {
            algorithms: HashSet::from([alg]),
            ..Self::default()
        }
    }

    /// Disables the `exp` check while keeping every other setting already configured.
    pub fn without_expiry(self) -> Self {
        // Spread `self`, not `Self::default()`: otherwise earlier builder calls (issuer,
        // audiences, algorithms, leeway, ...) would be silently thrown away.
        Self {
            validate_exp: false,
            ..self
        }
    }

    /// Replaces the accepted audience set with the given values.
    pub fn with_audiences<T: ToString>(self, audiences: &[T]) -> Self {
        Self {
            audiences: Some(audiences.iter().map(ToString::to_string).collect()),
            ..self
        }
    }

    /// Replaces the accepted audience set with the single value `audience`.
    pub fn with_audience<T: ToString>(self, audience: T) -> Self {
        Self {
            audiences: Some(HashSet::from([audience.to_string()])),
            ..self
        }
    }

    /// Requires the `iss` claim to equal `issuer`.
    pub fn with_issuer<T: ToString>(self, issuer: T) -> Self {
        Self {
            issuer: Some(issuer.to_string()),
            ..self
        }
    }

    /// Requires the `sub` claim to equal `subject`.
    pub fn with_subject<T: ToString>(self, subject: T) -> Self {
        Self {
            subject: Some(subject.to_string()),
            ..self
        }
    }

    /// Sets the clock-skew tolerance, in seconds, for the `exp`/`nbf` checks.
    pub fn with_leeway(self, leeway: u64) -> Self {
        Self { leeway, ..self }
    }

    /// Adds `alg` to the set of accepted header algorithms.
    pub fn with_algorithm(mut self, alg: Algorithm) -> Self {
        self.algorithms.insert(alg);
        self
    }

    /// Adds `claim` to the set of claim names that must be present.
    pub fn with_required_claim<T: ToString>(mut self, claim: T) -> Self {
        self.required_claims
            .get_or_insert_with(HashSet::new)
            .insert(claim.to_string());
        self
    }
}

impl Default for ValidationOptions {
    /// `exp` validation on, `nbf` validation off, no leeway, and no issuer/subject/audience/
    /// required-claim restrictions. An empty algorithm set accepts any algorithm.
    fn default() -> Self {
        Self {
            leeway: 0,
            validate_exp: true,
            validate_nbf: false,
            audiences: None,
            issuer: None,
            subject: None,
            algorithms: HashSet::new(),
            required_claims: None,
        }
    }
}
pub(crate) fn validate_header(
header: &Header,
validation_options: &ValidationOptions,
) -> Result<(), Error> {
if !validation_options.algorithms.is_empty()
&& !validation_options.algorithms.contains(&header.alg)
{
return Err(Error::InvalidAlgorithm);
}
Ok(())
}
pub(crate) fn validate(
claims: &Map<String, Value>,
options: &ValidationOptions,
) -> Result<(), Error> {
let now = current_timestamp();
let validate_time_claim = |claim_value: Option<&Value>,
validate: bool,
validation_predicate: &dyn Fn(u64) -> bool,
validation_error: Error,
missing_claim_error: Error|
-> Result<(), Error> {
if validate {
if let Some(value) = claim_value.and_then(|v| v.as_u64()) {
if !validation_predicate(value) {
return Err(validation_error);
}
} else {
return Err(missing_claim_error);
}
}
Ok(())
};
validate_time_claim(
claims.get("exp"),
options.validate_exp,
&|timestamp| now <= timestamp + options.leeway,
Error::ExpiredSignature,
Error::InvalidClaim("Missing exp claim".to_string()),
)?;
validate_time_claim(
claims.get("nbf"),
options.validate_nbf,
&|timestamp| now >= timestamp - options.leeway,
Error::ImmatureSignature,
Error::InvalidClaim("Missing nbf claim".to_string()),
)?;
let validate_str_claim = |claim_value: Option<&Value>,
expected_value: &Option<String>,
validation_error: Error|
-> Result<(), Error> {
if let Some(expected) = expected_value {
if let Some(actual) = claim_value.and_then(|v| v.as_str()) {
if actual != expected {
return Err(validation_error);
}
} else {
return Err(validation_error);
}
}
Ok(())
};
validate_str_claim(claims.get("iss"), &options.issuer, Error::InvalidIssuer)?;
validate_str_claim(claims.get("sub"), &options.subject, Error::InvalidSubject)?;
let validate_audiences = |aud_claim: Option<&Value>,
expected_audiences: &Option<HashSet<String>>|
-> Result<(), Error> {
if let Some(expected) = expected_audiences {
match aud_claim {
Some(Value::String(aud)) => {
if !expected.contains(aud) {
return Err(Error::InvalidAudience);
}
}
Some(Value::Array(aud_array)) => {
let provided: HashSet<String> = aud_array
.iter()
.filter_map(|val| val.as_str().map(String::from))
.collect();
if provided.is_disjoint(expected) {
return Err(Error::InvalidAudience);
}
}
_ => return Err(Error::InvalidAudience),
}
}
Ok(())
};
validate_audiences(claims.get("aud"), &options.audiences)?;
if let Some(ref required_claims) = options.required_claims {
for claim in required_claims {
if !claims.contains_key(claim) {
return Err(Error::InvalidClaim(format!(
"Missing required claim: {}",
claim
)));
}
}
}
Ok(())
}
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than the Unix epoch.
fn current_timestamp() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("SystemTime before UNIX EPOCH");
    since_epoch.as_secs()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, to_value};

    /// Builds a claim set containing only `exp`, set to the given Unix timestamp.
    fn claims_with_exp(exp: u64) -> Map<String, Value> {
        let mut claims = Map::new();
        claims.insert("exp".to_string(), to_value(exp).unwrap());
        claims
    }

    /// Builds a claim set whose `exp` is one hour in the future, so the default expiry check
    /// passes.
    fn unexpired_claims() -> Map<String, Value> {
        claims_with_exp(current_timestamp() + 3600)
    }

    /// Default options with `nbf` validation switched on (it is off by default).
    fn nbf_options() -> ValidationOptions {
        ValidationOptions {
            validate_nbf: true,
            ..ValidationOptions::default()
        }
    }

    #[test]
    fn test_expiration_validation() {
        let result = validate(&unexpired_claims(), &ValidationOptions::default());
        assert!(result.is_ok(), "unexpected error: {:?}", result);
    }

    #[test]
    fn test_expiration_validation_fail() {
        let claims = claims_with_exp(current_timestamp() - 60);
        let result = validate(&claims, &ValidationOptions::default());
        assert!(matches!(result, Err(Error::ExpiredSignature)));
    }

    #[test]
    fn test_expiration_within_leeway() {
        let claims = claims_with_exp(current_timestamp() - 30);
        let options = ValidationOptions::default().with_leeway(60);
        let result = validate(&claims, &options);
        assert!(result.is_ok(), "unexpected error: {:?}", result);
    }

    #[test]
    fn test_expiration_missing() {
        let result = validate(&Map::new(), &ValidationOptions::default());
        assert!(matches!(result, Err(Error::InvalidClaim(_))));
    }

    #[test]
    fn test_without_expiry_skips_exp_check() {
        let options = ValidationOptions::default().without_expiry();
        let result = validate(&Map::new(), &options);
        assert!(result.is_ok(), "unexpected error: {:?}", result);
    }

    #[test]
    fn test_not_before_validation() {
        let mut claims = unexpired_claims();
        claims.insert("nbf".to_string(), to_value(current_timestamp()).unwrap());
        let result = validate(&claims, &nbf_options());
        assert!(result.is_ok(), "unexpected error: {:?}", result);
    }

    #[test]
    fn test_not_before_validation_fail() {
        let mut claims = unexpired_claims();
        claims.insert(
            "nbf".to_string(),
            to_value(current_timestamp() + 3600).unwrap(),
        );
        let result = validate(&claims, &nbf_options());
        assert!(matches!(result, Err(Error::ImmatureSignature)));
    }

    #[test]
    fn test_not_before_missing() {
        let result = validate(&unexpired_claims(), &nbf_options());
        assert!(matches!(result, Err(Error::InvalidClaim(_))));
    }

    #[test]
    fn test_issuer_validation() {
        let mut claims = unexpired_claims();
        claims.insert("iss".to_string(), json!("valid_issuer"));
        let options = ValidationOptions::default().with_issuer("valid_issuer");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_issuer_validation_fail() {
        let mut claims = unexpired_claims();
        claims.insert("iss".to_string(), json!("invalid_issuer"));
        let options = ValidationOptions::default().with_issuer("valid_issuer");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidIssuer)));
    }

    #[test]
    fn test_issuer_wrong_type() {
        let mut claims = unexpired_claims();
        claims.insert("iss".to_string(), json!(42));
        let options = ValidationOptions::default().with_issuer("valid_issuer");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidIssuer)));
    }

    #[test]
    fn test_subject_validation() {
        let mut claims = unexpired_claims();
        claims.insert("sub".to_string(), json!("valid_subject"));
        let options = ValidationOptions::default().with_subject("valid_subject");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_subject_validation_fail() {
        let mut claims = unexpired_claims();
        claims.insert("sub".to_string(), json!("invalid_subject"));
        let options = ValidationOptions::default().with_subject("valid_subject");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidSubject)));
    }

    #[test]
    fn test_audience_validation() {
        let mut claims = unexpired_claims();
        claims.insert("aud".to_string(), json!("valid_audience"));
        let options = ValidationOptions::default().with_audience("valid_audience");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_audience_validation_fail() {
        let mut claims = unexpired_claims();
        claims.insert("aud".to_string(), json!("invalid_audience"));
        let options = ValidationOptions::default().with_audience("valid_audience");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidAudience)));
    }

    #[test]
    fn test_audience_missing() {
        let options = ValidationOptions::default().with_audience("valid_audience");
        let result = validate(&unexpired_claims(), &options);
        assert!(matches!(result, Err(Error::InvalidAudience)));
    }

    #[test]
    fn test_audience_validation_array() {
        let mut claims = unexpired_claims();
        claims.insert(
            "aud".to_string(),
            json!(["valid_audience", "another_audience"]),
        );
        let options = ValidationOptions::default().with_audience("valid_audience");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_audience_validation_array_fail() {
        let mut claims = unexpired_claims();
        claims.insert(
            "aud".to_string(),
            json!(["invalid_audience", "another_audience"]),
        );
        let options = ValidationOptions::default().with_audience("valid_audience");
        let result = validate(&claims, &options);
        assert!(matches!(result, Err(Error::InvalidAudience)));
    }

    #[test]
    fn test_algorithm_validation() {
        let header = Header {
            alg: Algorithm::HS256,
            ..Header::default()
        };
        let options = ValidationOptions::default().with_algorithm(Algorithm::HS256);
        assert!(validate_header(&header, &options).is_ok());
    }

    #[test]
    fn test_algorithm_validation_fail_in_header() {
        let header = Header {
            alg: Algorithm::HS256,
            ..Header::default()
        };
        let options = ValidationOptions::default().with_algorithm(Algorithm::HS384);
        let result = validate_header(&header, &options);
        assert!(matches!(result, Err(Error::InvalidAlgorithm)));
    }

    #[test]
    fn test_algorithm_unrestricted_accepts_any() {
        let header = Header {
            alg: Algorithm::HS256,
            ..Header::default()
        };
        assert!(validate_header(&header, &ValidationOptions::default()).is_ok());
    }

    #[test]
    fn test_required_claims() {
        let mut claims = unexpired_claims();
        claims.insert("sub".to_string(), json!("required_subject"));
        let options = ValidationOptions::default().with_required_claim("sub");
        assert!(validate(&claims, &options).is_ok());
    }

    #[test]
    fn test_required_claims_fail() {
        let options = ValidationOptions::default().with_required_claim("sub");
        let result = validate(&unexpired_claims(), &options);
        assert!(matches!(result, Err(Error::InvalidClaim(_))));
    }

    #[test]
    fn test_required_claims_multiple() {
        let mut claims = unexpired_claims();
        claims.insert("sub".to_string(), json!("required_subject"));
        claims.insert("aud".to_string(), json!("required_audience"));
        let options = ValidationOptions::default()
            .with_required_claim("sub")
            .with_required_claim("aud");
        assert!(validate(&claims, &options).is_ok());
    }
}