use std::str::FromStr;
use async_trait::async_trait;
use time::OffsetDateTime;
use ppoppo_token::{AuthError, Claims, VerifyConfig};
use super::keyset::JwksCache;
use super::port::{AuthSession, BearerVerifier, Expectations, VerifyError};
use crate::types::{Ppnum, PpnumId, SessionId};
/// Bearer-token verifier backed by a [`JwksCache`] and a fixed set of
/// [`Expectations`].
///
/// Construct it with [`PasJwtVerifier::from_jwks_url`]; verification itself is
/// exposed through the [`BearerVerifier`] implementation.
#[derive(Debug, Clone)]
pub struct PasJwtVerifier {
    /// Key cache; `verify` takes a snapshot of it on every call.
    jwks: JwksCache,
    /// Supplies the issuer and audience handed to the verification config.
    expectations: Expectations,
}
impl PasJwtVerifier {
    /// Creates a verifier by fetching a key cache from `jwks_url`.
    ///
    /// `jwks_url` is forwarded unchanged to [`JwksCache::fetch`]; the given
    /// `expectations` are stored as-is for later use by `verify`.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`JwksCache::fetch`] reports, converted into a
    /// [`VerifyError`] by the `?` operator.
    pub async fn from_jwks_url(
        jwks_url: impl Into<String>,
        expectations: Expectations,
    ) -> Result<Self, VerifyError> {
        let fetched = JwksCache::fetch(jwks_url).await?;
        let verifier = Self {
            jwks: fetched,
            expectations,
        };
        Ok(verifier)
    }
}
#[async_trait]
impl BearerVerifier for PasJwtVerifier {
    /// Verifies `bearer_token` and converts its claims into an [`AuthSession`].
    ///
    /// Steps, in order:
    /// 1. Reject tokens that are empty or do not have the three-segment shape
    ///    checked by `looks_like_jws_compact`.
    /// 2. Build an access-token [`VerifyConfig`] from the stored issuer and
    ///    audience.
    /// 3. Take a snapshot of the key cache and run `ppoppo_token::verify`.
    /// 4. Translate the resulting claims with `claims_to_auth_session`.
    ///
    /// # Errors
    ///
    /// * [`VerifyError::InvalidFormat`] when the shape check in step 1 fails.
    /// * The result of `map_auth_error` when `ppoppo_token::verify` fails.
    /// * Any error produced by `claims_to_auth_session`.
    async fn verify(&self, bearer_token: &str) -> Result<AuthSession, VerifyError> {
        // Cheap shape check first so obviously malformed input never reaches
        // the verification call below.
        let well_formed = !bearer_token.is_empty() && looks_like_jws_compact(bearer_token);
        if !well_formed {
            return Err(VerifyError::InvalidFormat);
        }

        let issuer = self.expectations.issuer.clone();
        let audience = self.expectations.audience.clone();
        let cfg = VerifyConfig::access_token(issuer, audience);

        let keyset = self.jwks.snapshot().await;
        let outcome = ppoppo_token::verify(bearer_token, &cfg, &keyset).await;

        match outcome {
            Ok(claims) => claims_to_auth_session(claims),
            Err(cause) => Err(map_auth_error(cause, &self.expectations)),
        }
    }
}
/// Returns `true` when `token` consists of exactly three `.`-separated
/// segments, i.e. it contains exactly two `.` characters.
///
/// Only the separator count is inspected; segments may be empty and their
/// contents are not validated here.
fn looks_like_jws_compact(token: &str) -> bool {
    // Three segments <=> two separators.
    let separators = token.matches('.').count();
    separators == 2
}
fn map_auth_error(err: AuthError, _expectations: &Expectations) -> VerifyError {
use AuthError as E;
match err {
E::AlgNone
| E::AlgNotWhitelisted
| E::AlgHmacRejected
| E::AlgRsaRejected
| E::AlgEcdsaRejected
| E::HeaderJku
| E::HeaderX5u
| E::HeaderJwk
| E::HeaderX5c
| E::HeaderCrit
| E::HeaderExtraParam
| E::KidUnknown
| E::TypMismatch
| E::NestedJws => VerifyError::SignatureInvalid,
E::OversizedToken | E::JwsJsonRejected | E::JwePayload | E::LaxBase64 => {
VerifyError::InvalidFormat
}
E::Expired | E::ExpUpperBound | E::IatFuture | E::NotYetValid => VerifyError::Expired,
E::IssMismatch => VerifyError::IssuerInvalid,
E::AudMismatch => VerifyError::AudienceInvalid,
E::ExpMissing => VerifyError::MissingClaim("exp"),
E::AudMissing => VerifyError::MissingClaim("aud"),
E::IatMissing => VerifyError::MissingClaim("iat"),
E::JtiMissing => VerifyError::MissingClaim("jti"),
E::SubMissing => VerifyError::MissingClaim("sub"),
E::ClientIdMissing => VerifyError::MissingClaim("client_id"),
other => VerifyError::Other(other.to_string()),
}
}
fn claims_to_auth_session(claims: Claims) -> Result<AuthSession, VerifyError> {
let ppnum_id = ulid::Ulid::from_str(&claims.sub)
.map(PpnumId)
.map_err(|_| VerifyError::MissingClaim("sub"))?;
let ppnum = match claims.active_ppnum {
Some(p) => Some(Ppnum::try_from(p).map_err(|_| VerifyError::MissingClaim("active_ppnum"))?),
None => None,
};
let session_id = claims.sid.map(SessionId::from);
let expires_at = OffsetDateTime::from_unix_timestamp(claims.exp)
.map_err(|_| VerifyError::MissingClaim("exp"))?;
Ok(AuthSession::new(ppnum_id, ppnum, session_id, expires_at))
}