use std::collections::HashSet;
use crate::{
BoxedError,
crypto::verifier::{JwsVerifier, KeyMatch, VerifyError},
jwt::{ConfirmationClaim, JwsParseError, ParsedJws, parse_compact_jws},
platform::{Duration, SystemTime},
};
use bon::Builder;
use serde::Deserialize;
use snafu::{ResultExt as _, Snafu, ensure};
use crate::crypto::verifier::BoxedJwsVerifier;
/// Policy for checking one string-valued JWT claim (or the JOSE `typ` header).
///
/// How a value "matches" depends on the claim being checked by the validator
/// (exact equality, audience-list membership, or normalized `typ` comparison).
/// The default, [`ClaimCheck::NoCheck`], places no constraint at all.
#[derive(Clone, Debug, Default)]
pub enum ClaimCheck {
    /// No constraint: presence and value are both ignored.
    #[default]
    NoCheck,
    /// The value must be present; its content is not inspected.
    Present,
    /// Absence is accepted, but a present value must match this one.
    IfPresent(String),
    /// The value must be present and match this one.
    RequiredValue(String),
    /// The value must be present and match at least one of these.
    RequireAny(Vec<String>),
}
impl ClaimCheck {
pub fn if_present(value: impl Into<String>) -> Self {
Self::IfPresent(value.into())
}
pub fn require_any(values: impl IntoIterator<Item = impl Into<String>>) -> Self {
Self::RequireAny(values.into_iter().map(Into::into).collect())
}
pub fn required_value(value: impl Into<String>) -> Self {
Self::RequiredValue(value.into())
}
#[must_use]
pub fn present() -> Self {
Self::Present
}
}
/// Validates JWS-signed JWTs against a configurable policy.
///
/// Instances are created through the builder derived by `bon`. Fields marked
/// `#[builder(default)]` fall back to their `Default` value when not set.
// `require_exp`, `require_iat` and `require_jti` are independent switches that
// map 1:1 to builder setters, so the bool-count lint is silenced.
#[allow(clippy::struct_excessive_bools)]
// NOTE(review): presumably needed because the generated builder exposes a `sub`
// setter, which clippy mistakes for `std::ops::Sub::sub` — confirm before removing.
#[allow(clippy::should_implement_trait)] #[derive(Debug, Builder)]
pub struct JwtValidator {
    /// Signature verifier. It is called with the JWS signing input, the
    /// signature, and a `KeyMatch` built from the header's `alg` and `kid`.
    verifier: BoxedJwsVerifier,
    /// Check applied to the `iss` claim (default: `ClaimCheck::NoCheck`).
    #[builder(default)]
    iss: ClaimCheck,
    /// Check applied to the `sub` claim (default: `ClaimCheck::NoCheck`).
    #[builder(default)]
    sub: ClaimCheck,
    /// Check applied to the `aud` claim. A configured value matches when it is
    /// contained in the token's audience list (default: `ClaimCheck::NoCheck`).
    #[builder(default)]
    aud: ClaimCheck,
    /// Check applied to the JOSE `typ` *header* parameter (not a claim). The
    /// comparison is ASCII case-insensitive and ignores an `application/` prefix.
    #[builder(default)]
    typ: ClaimCheck,
    /// Reject tokens without an `exp` claim.
    #[builder(default)]
    require_exp: bool,
    /// Reject tokens without an `iat` claim.
    #[builder(default)]
    require_iat: bool,
    /// Reject tokens without a `jti` claim.
    #[builder(default)]
    require_jti: bool,
    /// When set, `iat` becomes mandatory, and tokens older than this plus
    /// `clock_leeway` are rejected.
    max_token_age: Option<Duration>,
    /// Tolerance applied when comparing `exp`, `nbf`, `iat` and token age
    /// against the current time.
    #[builder(default)]
    clock_leeway: Duration,
    /// Header parameter names permitted in `crit`. Any `crit` entry outside
    /// this set is rejected, so with the default empty set any non-empty
    /// `crit` fails. The builder setter takes an iterable
    /// (`with = FromIterator::from_iter`).
    #[builder(default, with = FromIterator::from_iter)]
    allowed_crit: HashSet<String>,
    /// When set, the header `alg` must be one of these names. `none` is
    /// rejected regardless of this setting.
    #[builder(with = FromIterator::from_iter)]
    allowed_algorithms: Option<HashSet<String>>,
}
/// Normalizes a JOSE `typ` media type for comparison.
///
/// A leading `application/` prefix, matched ASCII case-insensitively, is
/// dropped; RFC 7515 allows that prefix to be omitted. The bare string
/// `application/` (nothing after the prefix) is returned unchanged.
///
/// The prefix test is done on bytes rather than by slicing the `str`. Slicing
/// at byte 12 panics when that byte falls inside a multi-byte UTF-8 character,
/// and `typ` comes from an untrusted token header.
fn normalize_typ(typ: &str) -> &str {
    const PREFIX: &[u8] = b"application/";
    match typ.as_bytes().get(..PREFIX.len()) {
        Some(head) if typ.len() > PREFIX.len() && head.eq_ignore_ascii_case(PREFIX) => {
            // `head` equals an ASCII-only prefix, so `PREFIX.len()` is a char boundary.
            &typ[PREFIX.len()..]
        }
        _ => typ,
    }
}
/// Applies `check` to an optional string claim value.
///
/// `missing` builds the error returned when the check requires a value but the
/// claim is absent. `mismatch` receives the expected value(s) and the actual
/// value when they differ; several expected values are joined with `", "`.
fn check_str_claim(
    check: &ClaimCheck,
    value: Option<&str>,
    missing: impl FnOnce() -> JwtValidationError,
    mismatch: impl FnOnce(String, &str) -> JwtValidationError,
) -> Result<(), JwtValidationError> {
    match (check, value) {
        // Nothing to enforce: no policy, optional-and-absent, or presence-only and present.
        (ClaimCheck::NoCheck, _)
        | (ClaimCheck::IfPresent(_), None)
        | (ClaimCheck::Present, Some(_)) => Ok(()),
        // Every remaining policy needs the claim to exist.
        (
            ClaimCheck::Present | ClaimCheck::RequiredValue(_) | ClaimCheck::RequireAny(_),
            None,
        ) => Err(missing()),
        // Single expected value, whether mandatory or optional.
        (ClaimCheck::RequiredValue(expected) | ClaimCheck::IfPresent(expected), Some(actual)) => {
            if actual == expected.as_str() {
                Ok(())
            } else {
                Err(mismatch(expected.clone(), actual))
            }
        }
        (ClaimCheck::RequireAny(candidates), Some(actual)) => {
            if candidates.iter().any(|candidate| candidate.as_str() == actual) {
                Ok(())
            } else {
                Err(mismatch(candidates.join(", "), actual))
            }
        }
    }
}
impl JwtValidator {
#[allow(clippy::too_many_lines)]
pub async fn validate_parsed_jws<C: for<'de> Deserialize<'de> + Clone + 'static>(
&self,
parsed_jwt: ParsedJws<(), C>,
) -> Result<ValidatedJwt<C>, JwtValidationError> {
let now = SystemTime::now();
ensure!(parsed_jwt.header.alg != "none", UnsignedTokenSnafu);
if let Some(allowed) = &self.allowed_algorithms {
ensure!(
allowed.contains(&*parsed_jwt.header.alg),
DisallowedAlgorithmSnafu {
alg: parsed_jwt.header.alg.to_string()
}
);
}
ensure!(
parsed_jwt
.header
.crit
.iter()
.all(|v| self.allowed_crit.contains(v)),
UnrecognizedCriticalHeaderSnafu {
params: parsed_jwt.header.crit
}
);
let key_match = KeyMatch {
alg: &parsed_jwt.header.alg,
kid: parsed_jwt.header.kid.as_deref(),
};
self.verifier
.verify(&parsed_jwt.signing_input, &parsed_jwt.signature, &key_match)
.await
.context(SignatureSnafu)?;
match &self.aud {
ClaimCheck::Present => ensure!(
!parsed_jwt.claims.aud.is_empty(),
RequiredClaimMissingSnafu { claim: "aud" }
),
ClaimCheck::RequiredValue(v) => ensure!(
parsed_jwt.claims.aud.contains(v),
AudienceMismatchSnafu { expected: v }
),
ClaimCheck::RequireAny(vs) => ensure!(
vs.iter().any(|v| parsed_jwt.claims.aud.contains(v)),
AudienceMismatchSnafu {
expected: vs.join(", ")
}
),
ClaimCheck::IfPresent(v) => {
if !parsed_jwt.claims.aud.is_empty() {
ensure!(
parsed_jwt.claims.aud.contains(v),
AudienceMismatchSnafu { expected: v }
);
}
}
ClaimCheck::NoCheck => {}
}
if self.require_exp {
ensure!(
parsed_jwt.claims.exp.is_some(),
RequiredClaimMissingSnafu { claim: "exp" }
);
}
if self.require_iat {
ensure!(
parsed_jwt.claims.iat.is_some(),
RequiredClaimMissingSnafu { claim: "iat" }
);
}
if self.require_jti {
ensure!(
parsed_jwt.claims.jti.is_some(),
RequiredClaimMissingSnafu { claim: "jti" }
);
}
if let Some(max_token_age) = self.max_token_age {
let iat = parsed_jwt
.claims
.iat
.ok_or_else(|| RequiredClaimMissingSnafu { claim: "iat" }.build())?;
let issued_at = SystemTime::UNIX_EPOCH + Duration::from_secs(iat);
ensure!(
now.duration_since(issued_at)
.is_ok_and(|d| d <= max_token_age + self.clock_leeway),
TokenTooOldSnafu {
issued_at,
max_token_age
}
);
}
match &self.typ {
ClaimCheck::IfPresent(t) => ensure!(
parsed_jwt.header.typ.as_ref().is_none_or(|typ| {
normalize_typ(typ).eq_ignore_ascii_case(normalize_typ(t))
}),
InvalidTokenTypeSnafu {
typ: parsed_jwt.header.typ.map(Into::into)
}
),
ClaimCheck::RequireAny(allowed) => match parsed_jwt.header.typ.as_deref() {
None => return RequiredClaimMissingSnafu { claim: "typ" }.fail(),
Some(typ)
if allowed
.iter()
.any(|t| normalize_typ(typ).eq_ignore_ascii_case(normalize_typ(t))) => {}
Some(typ) => {
return InvalidTokenTypeSnafu {
typ: Some(typ.into()),
}
.fail();
}
},
ClaimCheck::RequiredValue(t) => match parsed_jwt.header.typ.as_deref() {
None => return RequiredClaimMissingSnafu { claim: "typ" }.fail(),
Some(typ) if normalize_typ(typ).eq_ignore_ascii_case(normalize_typ(t)) => {}
Some(typ) => {
return InvalidTokenTypeSnafu {
typ: Some(typ.into()),
}
.fail();
}
},
ClaimCheck::Present => {
ensure!(
parsed_jwt.header.typ.is_some(),
RequiredClaimMissingSnafu { claim: "typ" }
);
}
ClaimCheck::NoCheck => {}
}
check_str_claim(
&self.iss,
parsed_jwt.claims.iss.as_deref(),
|| RequiredClaimMissingSnafu { claim: "iss" }.build(),
|expected, actual| IssuerMismatchSnafu { expected, actual }.build(),
)?;
check_str_claim(
&self.sub,
parsed_jwt.claims.sub.as_deref(),
|| RequiredClaimMissingSnafu { claim: "sub" }.build(),
|expected, actual| SubjectMismatchSnafu { expected, actual }.build(),
)?;
if let Some(exp) = parsed_jwt.claims.exp {
let expiration = SystemTime::UNIX_EPOCH + Duration::from_secs(exp);
ensure!(
expiration + self.clock_leeway >= now,
ExpiredSnafu { expiration, now }
);
}
if let Some(nbf) = parsed_jwt.claims.nbf {
let not_before = SystemTime::UNIX_EPOCH + Duration::from_secs(nbf);
ensure!(
not_before <= now + self.clock_leeway,
NotYetValidSnafu { not_before, now }
);
}
if let Some(iat) = parsed_jwt.claims.iat {
let issued_at = SystemTime::UNIX_EPOCH + Duration::from_secs(iat);
ensure!(
issued_at <= now + self.clock_leeway,
IssuedInFutureSnafu { issued_at, now }
);
}
Ok(ValidatedJwt {
issuer: parsed_jwt.claims.iss.map(Into::into),
subject: parsed_jwt.claims.sub.map(Into::into),
audience: parsed_jwt.claims.aud.iter().map(Into::into).collect(),
issued_at: parsed_jwt
.claims
.iat
.map(|iat| SystemTime::UNIX_EPOCH + Duration::from_secs(iat)),
expiration: parsed_jwt
.claims
.exp
.map(|exp| SystemTime::UNIX_EPOCH + Duration::from_secs(exp)),
jti: parsed_jwt.claims.jti.map(Into::into),
cnf: parsed_jwt.claims.cnf,
claims: parsed_jwt.claims.extra_claims.map(|c| match c {
std::borrow::Cow::Borrowed(c) => c.clone(),
std::borrow::Cow::Owned(c) => c,
}),
})
}
pub async fn validate<C: Clone + for<'de> Deserialize<'de> + 'static>(
&self,
token: &str,
) -> Result<ValidatedJwt<C>, JwtValidationError> {
let parsed_jwt = parse_compact_jws::<(), _>(token).context(ParseSnafu)?;
self.validate_parsed_jws(parsed_jwt).await
}
}
/// The outcome of a successful [`JwtValidator`] validation.
///
/// The token's claims are copied into owned fields. `issued_at` and
/// `expiration` are the `iat`/`exp` values converted to `SystemTime`.
#[derive(Debug)]
pub struct ValidatedJwt<C> {
    /// The `iss` claim, if present.
    pub issuer: Option<String>,
    /// The `sub` claim, if present.
    pub subject: Option<String>,
    /// The `aud` values (empty when the token has none).
    pub audience: Vec<String>,
    /// The `jti` claim, if present.
    pub jti: Option<String>,
    /// The `iat` claim as a `SystemTime`, if present.
    pub issued_at: Option<SystemTime>,
    /// The `exp` claim as a `SystemTime`, if present.
    pub expiration: Option<SystemTime>,
    /// The `cnf` (confirmation) claim, if present.
    pub cnf: Option<ConfirmationClaim>,
    /// The token's extra (application-specific) claims as an owned `C`, if present.
    pub claims: Option<C>,
}
impl<C> ValidatedJwt<C> {
    /// Converts the extra-claims payload with `f`, keeping every other field as-is.
    ///
    /// `f` runs only when extra claims are present; otherwise `claims` stays `None`.
    pub fn map_claims<C1, F>(self, f: F) -> ValidatedJwt<C1>
    where
        F: FnOnce(C) -> C1,
    {
        let Self {
            issuer,
            subject,
            audience,
            jti,
            issued_at,
            expiration,
            cnf,
            claims,
        } = self;
        let mapped = claims.map(f);
        ValidatedJwt {
            issuer,
            subject,
            audience,
            jti,
            issued_at,
            expiration,
            cnf,
            claims: mapped,
        }
    }
}
// Errors returned by `JwtValidator::validate` and `JwtValidator::validate_parsed_jws`.
//
// NOTE(review): plain `//` comments are used deliberately. snafu can derive a
// variant's `Display` text from its `///` doc comment, so converting these to
// doc comments could change the runtime error messages.
#[derive(Debug, Snafu)]
pub enum JwtValidationError {
    // The token could not be parsed as a compact JWS.
    Parse {
        source: JwsParseError,
    },
    // The configured verifier rejected (or failed to check) the signature.
    Signature {
        source: VerifyError<BoxedError>,
    },
    // The header `alg` is `none`.
    UnsignedToken,
    // The header `alg` is not in the configured `allowed_algorithms`.
    DisallowedAlgorithm {
        alg: String,
    },
    // `crit` lists a parameter outside `allowed_crit`; `params` is the whole `crit` list.
    UnrecognizedCriticalHeader {
        params: Vec<String>,
    },
    // `exp` plus the clock leeway is earlier than `now`.
    Expired {
        expiration: SystemTime,
        now: SystemTime,
    },
    // `nbf` is later than `now` plus the clock leeway.
    NotYetValid {
        not_before: SystemTime,
        now: SystemTime,
    },
    // `iat` is later than `now` plus the clock leeway.
    IssuedInFuture {
        issued_at: SystemTime,
        now: SystemTime,
    },
    // The token's age (from `iat`) exceeds `max_token_age` plus the clock leeway.
    TokenTooOld {
        issued_at: SystemTime,
        max_token_age: Duration,
    },
    // The `typ` header does not satisfy the configured check; carries the header value.
    InvalidTokenType {
        typ: Option<String>,
    },
    // `iss` differs from the expected value(s); multiple values are joined with ", ".
    IssuerMismatch {
        expected: String,
        actual: String,
    },
    // `sub` differs from the expected value(s); multiple values are joined with ", ".
    SubjectMismatch {
        expected: String,
        actual: String,
    },
    // No `aud` entry matches; `expected` lists the configured value(s), joined with ", ".
    AudienceMismatch {
        expected: String,
    },
    // A required claim is absent; also used for a missing `typ` header parameter.
    RequiredClaimMissing {
        claim: &'static str,
    },
}