use std::ops::Deref;
use gdp_rs::{proven::ProvenError, Proven};
pub use self::predicate::{InvalidDPoPProof, IsValidDPoPProof};
use super::{context::DPoPProofContext, raw::RawDPoPProof};
/// A raw DPoP proof paired with the context of the request it accompanies.
///
/// This pair is the subject on which the [`IsValidDPoPProof`] predicate is
/// evaluated.
pub type ContextualDPoPProof = (RawDPoPProof<'static>, DPoPProofContext);
/// A DPoP proof that has been checked against its request context by the
/// [`IsValidDPoPProof`] predicate.
///
/// Construct it with [`ValidatedDPoPProof::try_new`]. It dereferences to the
/// raw proof; the request context is available through
/// [`ValidatedDPoPProof::context`].
pub struct ValidatedDPoPProof(Proven<ContextualDPoPProof, IsValidDPoPProof>);
/// Gives transparent read access to the validated raw proof.
///
/// The request context that accompanied it is exposed separately through
/// [`ValidatedDPoPProof::context`].
impl Deref for ValidatedDPoPProof {
    type Target = RawDPoPProof<'static>;

    /// Borrows the proof half of the validated `(proof, context)` pair.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let (proof, _context) = &*self.0;
        proof
    }
}
impl ValidatedDPoPProof {
    /// Validates `raw_proof` against the request `context`.
    ///
    /// # Errors
    ///
    /// Returns a [`ProvenError`] whose error is an [`InvalidDPoPProof`] when
    /// the [`IsValidDPoPProof`] predicate rejects the `(raw_proof, context)`
    /// pair.
    // The error type is parameterised by the whole subject pair, which makes
    // the `Err` variant large; the lint is silenced deliberately.
    #[inline]
    #[allow(clippy::result_large_err)]
    pub fn try_new(
        raw_proof: RawDPoPProof<'static>,
        context: DPoPProofContext,
    ) -> Result<Self, ProvenError<ContextualDPoPProof, InvalidDPoPProof>> {
        let subject: ContextualDPoPProof = (raw_proof, context);
        Proven::try_new(subject).map(Self)
    }

    /// Returns the request context this proof was validated against.
    #[inline]
    pub fn context(&self) -> &DPoPProofContext {
        let (_proof, context) = &*self.0;
        context
    }
}
mod predicate {
    //! The [`IsValidDPoPProof`] predicate and its evaluation error,
    //! [`InvalidDPoPProof`].

    use gdp_rs::predicate::{Predicate, PurePredicate, SyncEvaluablePredicate};
    use picky::jose::{
        jwk::JwkError,
        jws::{verify_signature, JwsError},
    };

    use super::ContextualDPoPProof;
    use crate::proof::{
        context::KeyBoundAccessToken,
        payload::{ath::Ath, jkt::Jkt},
    };

    /// Predicate over a [`ContextualDPoPProof`] that holds when the raw proof
    /// is valid for the request context it is paired with.
    ///
    /// Evaluation performs these checks in order and stops at the first
    /// failure:
    ///
    /// 1. The `jwk` embedded in the proof header converts to a public key
    ///    ([`InvalidDPoPProof::InvalidPublicKeyJwk`]).
    /// 2. The proof's signature verifies with that key under the header's
    ///    `alg` ([`InvalidDPoPProof::InvalidSignature`]).
    /// 3. The `htm` claim equals the request method
    ///    ([`InvalidDPoPProof::HtmClaimMismatch`]).
    /// 4. The `htu` claim matches the request URI
    ///    ([`InvalidDPoPProof::HtuClaimMismatch`]).
    /// 5. If the context has an active nonce, the `nonce` claim equals it
    ///    ([`InvalidDPoPProof::NonceClaimMismatch`]).
    /// 6. The `iat` claim, and the context's nonce timestamp if present, lie
    ///    strictly within `time_leeway` of the request time
    ///    ([`InvalidDPoPProof::TimestampOutOfWindow`]).
    /// 7. If the context carries a key-bound access token:
    ///    - the `ath` claim must equal the [`Ath`] derived from that token
    ///      ([`InvalidDPoPProof::AthClaimMismatch`]);
    ///    - the [`Jkt`] of the header `jwk` must equal the token's
    ///      `bound_key_jkt` ([`InvalidDPoPProof::BindingKeyMisMatch`]).
    ///
    ///    With the `unsafe-optional-ath-claim` feature enabled, a missing
    ///    `ath` claim is accepted.
    #[derive(Debug, Clone)]
    pub struct IsValidDPoPProof;

    impl Predicate<ContextualDPoPProof> for IsValidDPoPProof {
        /// Human-readable name of this predicate.
        fn label() -> std::borrow::Cow<'static, str> {
            "IsValidDPoPProof".into()
        }
    }

    impl SyncEvaluablePredicate<ContextualDPoPProof> for IsValidDPoPProof {
        type EvalError = InvalidDPoPProof;

        fn evaluate_for((raw_proof, context): &ContextualDPoPProof) -> Result<(), Self::EvalError> {
            let essence = raw_proof.decoded_essence();
            let decoded_header = &essence.header;
            let decoded_claims = &essence.claims;

            // The verification key is the one embedded in the proof's own
            // header; a jwk conversion failure maps to `InvalidPublicKeyJwk`
            // via `#[from]`.
            let public_key = decoded_header.jwk().to_public_key()?;
            verify_signature(raw_proof.compact_repr(), &public_key, decoded_header.alg)
                .map_err(InvalidDPoPProof::InvalidSignature)?;

            if decoded_claims.htm != context.req_method {
                return Err(InvalidDPoPProof::HtmClaimMismatch);
            }
            if !decoded_claims.htu.matches(&context.req_uri) {
                return Err(InvalidDPoPProof::HtuClaimMismatch);
            }
            // A nonce is only demanded while the context has one active; the
            // claim must then be present and equal to it.
            if context.active_nonce.is_some() && context.active_nonce != decoded_claims.nonce {
                return Err(InvalidDPoPProof::NonceClaimMismatch);
            }

            // `iat` comes from the proof's creator, so the distance to the
            // request time is computed with checked arithmetic. A plain `-`
            // overflows for extreme values (panicking in debug builds). In
            // release builds a wrapped difference of exactly `MIN` has
            // `abs() == MIN`, which would slip under the leeway comparison.
            // A distance that cannot be represented is out of window.
            let is_out_of_window = |timestamp| {
                match context
                    .req_time
                    .checked_sub(timestamp)
                    .and_then(|distance| distance.checked_abs())
                {
                    Some(distance) => distance >= context.time_leeway.into(),
                    None => true,
                }
            };

            if is_out_of_window(decoded_claims.iat) {
                return Err(InvalidDPoPProof::TimestampOutOfWindow);
            }
            if let Some(nonce_timestamp) = context.nonce_timestamp {
                if is_out_of_window(nonce_timestamp) {
                    return Err(InvalidDPoPProof::TimestampOutOfWindow);
                }
            }

            if let Some(KeyBoundAccessToken {
                access_token,
                bound_key_jkt,
            }) = context.key_bound_access_token.as_ref()
            {
                let ath = Ath::new(access_token);
                // When the claim is absent, substitute the expected value only
                // if the (explicitly unsafe) feature permits omitting `ath`.
                let decoded_ath = decoded_claims
                    .ath
                    .as_ref()
                    .or_else(|| cfg!(feature = "unsafe-optional-ath-claim").then_some(&ath))
                    .ok_or(InvalidDPoPProof::AthClaimMismatch)?;
                if decoded_ath != &ath {
                    return Err(InvalidDPoPProof::AthClaimMismatch);
                }
                if bound_key_jkt != &Jkt::new(&decoded_header.jwk()) {
                    return Err(InvalidDPoPProof::BindingKeyMisMatch);
                }
            }

            Ok(())
        }
    }

    impl PurePredicate<ContextualDPoPProof> for IsValidDPoPProof {}

    /// Reason an [`IsValidDPoPProof`] evaluation rejected a proof.
    #[derive(Debug, thiserror::Error)]
    pub enum InvalidDPoPProof {
        /// The `jwk` in the proof header could not be converted to a public key.
        #[error("Invalid public key jwk.\n{0}")]
        InvalidPublicKeyJwk(#[from] JwkError),

        /// Signature verification with the header `jwk` and `alg` failed.
        #[error("Invalid signature.\n{0}")]
        InvalidSignature(JwsError),

        /// The `htm` claim differs from the request method.
        #[error("Htm claim mismatch.")]
        HtmClaimMismatch,

        /// The `htu` claim does not match the request URI.
        #[error("Htu claim mismatch.")]
        HtuClaimMismatch,

        /// A nonce is active and the `nonce` claim is absent or different.
        #[error("Nonce claim mismatch.")]
        NonceClaimMismatch,

        /// The `ath` claim does not match the access token, or is absent while
        /// the `unsafe-optional-ath-claim` feature is disabled.
        #[error("Ath claim mismatch.")]
        AthClaimMismatch,

        /// The proof's `jwk` does not match the key the access token is bound to.
        #[error("Binding key mismatch.")]
        BindingKeyMisMatch,

        /// `iat` or the nonce timestamp is not strictly within the leeway of
        /// the request time.
        #[error("Proof timestamp out of window.")]
        TimestampOutOfWindow,
    }
}