use std::collections::BTreeMap;
use std::str::FromStr;
use std::sync::Arc;
use async_trait::async_trait;
use ppoppo_clock::ArcClock;
use ppoppo_clock::native::WallClock;
use time::OffsetDateTime;
use ppoppo_token::access_token::{
AuthError, Claims, EpochRevocation, VerifyConfig as EngineVerifyConfig,
};
use ppoppo_token::SharedAuthError;
use super::jwks_cache::JwksCache;
use super::{BearerVerifier, VerifiedClaims, VerifyConfig, VerifyError};
use crate::audit::{AuditEvent, AuditSink, VerifyErrorKind};
use crate::session_liveness::{SessionLiveness, SessionLivenessError};
use crate::types::{Ppnum, PpnumId, SessionId};
/// Verifies bearer access tokens (JWS compact serialization) against a cached
/// JWKS and the configured issuer/audience expectations.
///
/// Optional ports extend verification: an audit sink receives an event for
/// every failure, an epoch-revocation port is handed to the token engine, and
/// a session-liveness port is consulted for tokens that carry a `sid` claim.
/// `Debug` is implemented manually and prints only `expectations`.
#[derive(Clone)]
pub struct JwtVerifier {
    /// Signing-key cache; a snapshot of it is passed to the engine on each verify.
    jwks: JwksCache,
    /// Issuer/audience expectations used to build the engine's verify config.
    expectations: VerifyConfig,
    /// Time source for the engine's `now` and for audit-event timestamps.
    clock: ArcClock,
    /// When set, receives an audit event for every verification failure.
    audit_sink: Option<Arc<dyn AuditSink>>,
    /// When set, forwarded to the engine config for epoch-based revocation.
    epoch: Option<Arc<dyn EpochRevocation>>,
    /// When set, checked after engine verification for tokens that carry `sid`.
    session_liveness: Option<Arc<dyn SessionLiveness>>,
}
impl std::fmt::Debug for JwtVerifier {
    /// Renders `JwtVerifier { expectations: .., .. }`. Key material, the clock
    /// and the trait-object ports are deliberately left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("JwtVerifier");
        builder.field("expectations", &self.expectations);
        builder.finish_non_exhaustive()
    }
}
impl JwtVerifier {
pub async fn from_jwks_url(
jwks_url: impl Into<String>,
expectations: VerifyConfig,
) -> Result<Self, VerifyError> {
let jwks = JwksCache::fetch(jwks_url).await?;
Ok(Self {
jwks,
expectations,
clock: Arc::new(WallClock),
audit_sink: None,
epoch: None,
session_liveness: None,
})
}
#[must_use]
pub fn with_clock(mut self, clock: ArcClock) -> Self {
self.jwks = self.jwks.with_clock(clock.clone());
self.clock = clock;
self
}
#[must_use]
pub fn with_audit(mut self, sink: Arc<dyn AuditSink>) -> Self {
self.audit_sink = Some(sink);
self
}
#[must_use]
pub fn with_epoch_revocation(mut self, port: Arc<dyn EpochRevocation>) -> Self {
self.epoch = Some(port);
self
}
#[must_use]
pub fn with_session_liveness(mut self, port: Arc<dyn SessionLiveness>) -> Self {
self.session_liveness = Some(port);
self
}
#[cfg(any(test, feature = "test-support"))]
#[must_use]
pub fn for_test_skip_fetch(expectations: VerifyConfig) -> Self {
Self {
jwks: JwksCache::for_test_empty(),
expectations,
clock: Arc::new(WallClock),
audit_sink: None,
epoch: None,
session_liveness: None,
}
}
async fn emit_failure(&self, bearer_token: &str, err: VerifyError) -> VerifyError {
let Some(sink) = self.audit_sink.as_ref() else {
return err;
};
let kind = VerifyErrorKind::from(&err);
let (client_id_hint, kid_hint) = peek_token_hints(bearer_token);
let mut metadata = BTreeMap::new();
if let VerifyError::Other(msg) = &err {
metadata.insert(
"engine_msg".to_owned(),
serde_json::Value::String(msg.clone()),
);
}
let event = AuditEvent::from_hints(
kind,
self.clock.now_utc(),
client_id_hint,
kid_hint,
metadata,
);
sink.record_failure(event).await;
err
}
}
impl From<&VerifyError> for VerifyErrorKind {
fn from(err: &VerifyError) -> Self {
match err {
VerifyError::InvalidFormat => Self::InvalidFormat,
VerifyError::SignatureInvalid => Self::SignatureInvalid,
VerifyError::Expired => Self::Expired,
VerifyError::IssuerInvalid => Self::IssuerInvalid,
VerifyError::AudienceInvalid => Self::AudienceInvalid,
VerifyError::MissingClaim(claim) => Self::MissingClaim((*claim).to_owned()),
VerifyError::KeysetUnavailable => Self::KeysetUnavailable,
VerifyError::IdTokenAsBearer => Self::IdTokenAsBearer,
VerifyError::SessionVersionStale => Self::SessionVersionStale,
VerifyError::SessionVersionLookupUnavailable => {
Self::SessionVersionLookupUnavailable
}
VerifyError::SessionRevoked => Self::SessionRevoked,
VerifyError::SessionLivenessLookupUnavailable => {
Self::SessionLivenessLookupUnavailable
}
VerifyError::Other(_) => Self::Other,
}
}
}
/// Best-effort extraction of `(client_id, kid)` from an UNVERIFIED token.
///
/// `kid` is read from the first (header) segment and `client_id` from the
/// second (payload) segment. Each segment is decoded as unpadded base64url
/// JSON. A missing segment, a decode or parse failure, or a non-string field
/// yields `None` for that hint. Nothing here is trusted.
// NOTE(review): both hints are attacker-controlled and unbounded in length;
// confirm the audit sink truncates or sanitizes them before persisting.
fn peek_token_hints(token: &str) -> (Option<String>, Option<String>) {
    /// Decodes an optional base64url JSON segment and returns `field` when it
    /// is a JSON string.
    fn string_claim(segment: Option<&str>, field: &str) -> Option<String> {
        use base64::Engine as _;
        let raw = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(segment?)
            .ok()?;
        let json: serde_json::Value = serde_json::from_slice(&raw).ok()?;
        json.get(field)?.as_str().map(str::to_owned)
    }

    let mut segments = token.split('.');
    let header = segments.next();
    let payload = segments.next();
    let kid_hint = string_claim(header, "kid");
    let client_id_hint = string_claim(payload, "client_id");
    (client_id_hint, kid_hint)
}
#[async_trait]
impl BearerVerifier for JwtVerifier {
    /// Verifies `bearer_token` as an access token and converts its claims.
    ///
    /// Checks run in this order and stop at the first failure. Every failure
    /// is passed through `emit_failure` before it is returned.
    /// 1. Shape pre-check: an empty token, or one without exactly three
    ///    dot-separated segments, gives `InvalidFormat`.
    /// 2. ID-token shape heuristic: gives `IdTokenAsBearer`.
    /// 3. Engine verification: runs against a JWKS snapshot, the issuer/audience
    ///    expectations and, if configured, epoch revocation. Errors are
    ///    translated by `map_auth_error`.
    /// 4. Session liveness: only runs when a liveness port is configured AND
    ///    the token carries `sid`. Gives `SessionRevoked` or
    ///    `SessionLivenessLookupUnavailable`.
    /// 5. Claim conversion via `claims_to_verified`.
    async fn verify(&self, bearer_token: &str) -> Result<VerifiedClaims, VerifyError> {
        let shape_error = if bearer_token.is_empty() || !looks_like_jws_compact(bearer_token) {
            Some(VerifyError::InvalidFormat)
        } else if peek_id_token_shape(bearer_token) {
            Some(VerifyError::IdTokenAsBearer)
        } else {
            None
        };
        if let Some(err) = shape_error {
            return Err(self.emit_failure(bearer_token, err).await);
        }

        let base_cfg = EngineVerifyConfig::access_token(
            self.expectations.issuer.clone(),
            self.expectations.audience.clone(),
        );
        let cfg = match self.epoch.as_ref() {
            Some(port) => base_cfg.with_epoch_revocation(Arc::clone(port)),
            None => base_cfg,
        };

        let keyset = self.jwks.snapshot().await;
        let now = self.clock.now_utc().unix_timestamp();
        let verdict = ppoppo_token::access_token::verify(bearer_token, &cfg, &keyset, now).await;
        let claims = match verdict {
            Ok(claims) => claims,
            Err(auth_err) => {
                let err = map_auth_error(auth_err, &self.expectations);
                return Err(self.emit_failure(bearer_token, err).await);
            }
        };

        // Tokens without `sid` skip the liveness check even when a port is set.
        let liveness_error = match (self.session_liveness.as_ref(), claims.sid.as_deref()) {
            (Some(port), Some(raw_sid)) => {
                let sid = SessionId::from(raw_sid.to_owned());
                match port.check(&sid).await {
                    Ok(()) => None,
                    Err(SessionLivenessError::Revoked) => Some(VerifyError::SessionRevoked),
                    Err(SessionLivenessError::Transient(_)) => {
                        Some(VerifyError::SessionLivenessLookupUnavailable)
                    }
                }
            }
            _ => None,
        };
        if let Some(err) = liveness_error {
            return Err(self.emit_failure(bearer_token, err).await);
        }

        match claims_to_verified(claims) {
            Ok(verified) => Ok(verified),
            Err(err) => Err(self.emit_failure(bearer_token, err).await),
        }
    }
}
/// Returns `true` when `token` has the JWS compact shape: exactly two `.`
/// separators, i.e. three segments.
///
/// Segment contents are not inspected, so empty segments (e.g. `".."`) pass.
fn looks_like_jws_compact(token: &str) -> bool {
    token.matches('.').count() == 2
}
/// Signature-free heuristic: does `token` look like an ID token presented as a
/// bearer access token?
///
/// Returns `true` when the header and payload both decode (unpadded base64url,
/// then JSON) and either of these holds:
/// - the header has a string `typ` that is not an RFC 9068 access-token type,
///   where accepted types are `at+jwt` and its full media-type form
///   `application/at+jwt`, compared ASCII-case-insensitively;
/// - the payload has `"cat": "id"`.
///
/// Returns `false` in these cases, leaving the token for the engine to accept
/// or reject:
/// - a segment is missing or fails to decode;
/// - neither signal is present, including when `typ` is absent.
///
/// Nothing here is cryptographically verified. The result only decides which
/// error is reported early; it never admits a token on its own.
pub(crate) fn peek_id_token_shape(token: &str) -> bool {
    /// Decodes one unpadded base64url segment and parses it as JSON.
    fn decode_json_segment(segment: &str) -> Option<serde_json::Value> {
        use base64::Engine as _;
        let bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(segment)
            .ok()?;
        serde_json::from_slice(&bytes).ok()
    }

    /// `typ` is a media type (RFC 7515 §4.1.9): the `application/` prefix may be
    /// omitted, and matching is case-insensitive. RFC 9068 §2.1 allows both
    /// `at+jwt` and `application/at+jwt` for access tokens.
    fn is_access_token_typ(typ: &str) -> bool {
        const PREFIX: &str = "application/";
        let subtype = match typ.get(..PREFIX.len()) {
            // `get` succeeded, so `PREFIX.len()` is a char boundary and the slice is safe.
            Some(head) if head.eq_ignore_ascii_case(PREFIX) => &typ[PREFIX.len()..],
            _ => typ,
        };
        subtype.eq_ignore_ascii_case("at+jwt")
    }

    let mut parts = token.split('.');
    let Some(header_b64) = parts.next() else {
        return false;
    };
    let Some(payload_b64) = parts.next() else {
        return false;
    };
    // Both segments must decode before either signal is considered; otherwise
    // the token is left for the engine to reject.
    let Some(header_json) = decode_json_segment(header_b64) else {
        return false;
    };
    let Some(payload_json) = decode_json_segment(payload_b64) else {
        return false;
    };

    let typ = header_json.get("typ").and_then(serde_json::Value::as_str);
    let typ_is_not_access = matches!(typ, Some(t) if !is_access_token_typ(t));
    let cat_is_id = payload_json.get("cat").and_then(serde_json::Value::as_str) == Some("id");
    typ_is_not_access || cat_is_id
}
#[cfg(test)]
// Fixture construction unwraps freely: a panic is the intended test failure.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod m73_tests {
    //! Classification tests for `peek_id_token_shape`. Each case forges an
    //! unsigned token from a chosen header and payload.

    use super::*;
    use base64::Engine as _;
    use serde_json::{json, Value};

    /// ULID-shaped subject shared by every fixture.
    const SUB: &str = "01HSAB00000000000000000000";

    /// Serializes `value` to JSON and base64url-encodes it without padding.
    fn encode_segment(value: &Value) -> String {
        let bytes = serde_json::to_vec(value).unwrap();
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
    }

    /// Builds `header.payload.<sig>`. The third segment is a placeholder
    /// because only the first two are inspected.
    fn forge(header: Value, payload: Value) -> String {
        format!("{}.{}.<sig>", encode_segment(&header), encode_segment(&payload))
    }

    /// Shorthand: forge a token and classify it.
    fn classify(header: Value, payload: Value) -> bool {
        peek_id_token_shape(&forge(header, payload))
    }

    #[test]
    fn typ_jwt_alone_is_id_token_shape() {
        assert!(classify(
            json!({"alg": "EdDSA", "typ": "JWT", "kid": "k"}),
            json!({"cat": "access", "sub": SUB}),
        ));
    }

    #[test]
    fn cat_id_alone_is_id_token_shape() {
        assert!(classify(
            json!({"alg": "EdDSA", "typ": "at+jwt", "kid": "k"}),
            json!({"cat": "id", "sub": SUB}),
        ));
    }

    #[test]
    fn typ_jwt_and_cat_id_both_signal() {
        assert!(classify(
            json!({"alg": "EdDSA", "typ": "JWT", "kid": "k"}),
            json!({"cat": "id", "sub": SUB}),
        ));
    }

    #[test]
    fn proper_access_token_shape_returns_false() {
        assert!(!classify(
            json!({"alg": "EdDSA", "typ": "at+jwt", "kid": "k"}),
            json!({"cat": "access", "sub": SUB}),
        ));
    }

    #[test]
    fn missing_typ_and_cat_admits_to_engine() {
        assert!(!classify(
            json!({"alg": "EdDSA", "kid": "k"}),
            json!({"sub": SUB}),
        ));
    }

    #[test]
    fn malformed_base64_returns_false_not_panic() {
        for junk in ["not.valid.token", "!!!.@@@.###", ""] {
            assert!(!peek_id_token_shape(junk), "misclassified {junk:?}");
        }
    }

    #[test]
    fn unrecognized_typ_value_is_id_token_shape() {
        assert!(classify(
            json!({"alg": "EdDSA", "typ": "id+jwt", "kid": "k"}),
            json!({"cat": "access", "sub": SUB}),
        ));
    }
}
/// Translates a token-engine [`AuthError`] into this crate's coarser
/// [`VerifyError`] taxonomy.
///
/// Grouping, as encoded by the match below:
/// - JOSE algorithm, header-parameter, `kid`, `typ`, nesting, duplicate-key
///   and segment-parse rejections map to `SignatureInvalid`.
/// - Oversized, JWS-JSON, JWE-payload and lax-base64 rejections map to
///   `InvalidFormat`.
/// - `Expired`, `ExpUpperBound`, `IatFuture` and `NotYetValid` all collapse
///   into `Expired`.
/// - `iss`/`aud` mismatches map to `IssuerInvalid` / `AudienceInvalid`.
/// - The `*Missing` variants map to `MissingClaim(name)`.
/// - Session-version outcomes map one-to-one.
/// - Any variant not listed maps to `Other`, carrying the engine's
///   `Display` text.
///
/// `_expectations` is currently unused.
// NOTE(review): `NotJwsCompact`, `HeaderUnparseable` and `PayloadUnparseable`
// are reported as `SignatureInvalid`, while the verifier's own shape pre-check
// reports malformed tokens as `InvalidFormat`. Confirm this split is intended.
fn map_auth_error(err: AuthError, _expectations: &VerifyConfig) -> VerifyError {
    use AuthError as E;
    use SharedAuthError as S;
    match err {
        // JOSE algorithm / header / key-lookup / segment-parse rejections.
        E::Jose(
            S::AlgNone
            | S::AlgNotWhitelisted
            | S::AlgHmacRejected
            | S::AlgRsaRejected
            | S::AlgEcdsaRejected
            | S::HeaderJku
            | S::HeaderX5u
            | S::HeaderJwk
            | S::HeaderX5c
            | S::HeaderCrit
            | S::HeaderExtraParam
            | S::HeaderB64False
            | S::KidUnknown
            | S::TypMismatch
            | S::NestedJws
            | S::DuplicateJsonKeys
            | S::HeaderUnparseable
            | S::PayloadUnparseable
            | S::NotJwsCompact,
        ) => VerifyError::SignatureInvalid,
        // Size / serialization-form rejections.
        E::Jose(
            S::OversizedToken | S::JwsJsonRejected | S::JwePayload | S::LaxBase64,
        ) => VerifyError::InvalidFormat,
        // Expiry, issued-at and not-before variants all collapse into `Expired`.
        E::Expired | E::ExpUpperBound | E::IatFuture | E::NotYetValid => VerifyError::Expired,
        E::IssMismatch => VerifyError::IssuerInvalid,
        E::AudMismatch => VerifyError::AudienceInvalid,
        // Missing-claim variants, reported with the claim name.
        E::ExpMissing => VerifyError::MissingClaim("exp"),
        E::AudMissing => VerifyError::MissingClaim("aud"),
        E::IatMissing => VerifyError::MissingClaim("iat"),
        E::JtiMissing => VerifyError::MissingClaim("jti"),
        E::SubMissing => VerifyError::MissingClaim("sub"),
        E::ClientIdMissing => VerifyError::MissingClaim("client_id"),
        E::SessionVersionStale => VerifyError::SessionVersionStale,
        E::SessionVersionLookupUnavailable => VerifyError::SessionVersionLookupUnavailable,
        // Catch-all: any variant not listed above (including ones the engine
        // adds later) becomes `Other` with the engine's Display text, instead
        // of causing a compile error.
        other => VerifyError::Other(other.to_string()),
    }
}
fn claims_to_verified(claims: Claims) -> Result<VerifiedClaims, VerifyError> {
let ppnum_id = ulid::Ulid::from_str(&claims.sub)
.map(PpnumId)
.map_err(|_| VerifyError::MissingClaim("sub"))?;
let ppnum = match claims.active_ppnum {
Some(p) => Some(Ppnum::try_from(p).map_err(|_| VerifyError::MissingClaim("active_ppnum"))?),
None => None,
};
let session_id = claims.sid.map(SessionId::from);
let expires_at = OffsetDateTime::from_unix_timestamp(claims.exp)
.map_err(|_| VerifyError::MissingClaim("exp"))?;
Ok(VerifiedClaims::new(ppnum_id, ppnum, session_id, expires_at))
}