use std::{marker::PhantomData, sync::Arc, time::Duration};
use crate::{
core::{
BoxedError, EndpointUrl,
crypto::verifier::{JwsVerifierFactory, JwsVerifierPlatform},
jwt::{
BoxedJtiUniquenessChecker,
validator::{ClaimCheck, JwtValidator},
},
platform::MaybeSendSync,
server_metadata::AuthorizationServerMetadata,
},
validator::{
dpop_nonce::NoNonceCheck, rfc9068::rfc9068_validator_builder::SetDpopNonceChecker,
},
};
use http::HeaderName;
use serde::{Deserialize, Serialize};
use crate::{
AccessTokenValidator,
validator::{
ValidationResult,
binding::DPoPBindingChecker,
common::ValidatorInner,
dpop_nonce::DpopNonceChecker,
error::ValidateHeadersError,
metadata::{ProvideValidatorMetadata, ValidatorMetadata},
observe::{OnValidate, ValidationOutcome},
},
};
/// Resource-server validator for JWT access tokens following RFC 9068.
///
/// `N` is the DPoP nonce checker type and `Claims` is the type that extra token claims are
/// deserialized into (`()` by default). Construct it with `Rfc9068Validator::builder()` or
/// `Rfc9068Validator::builder_from_metadata(..)`.
pub struct Rfc9068Validator<N: DpopNonceChecker, Claims = ()> {
/// Shared validation machinery: the JWT validator, the DPoP binding checker, the token
/// header name and the mTLS requirement flag.
inner: ValidatorInner<N>,
/// Expected `iss` value; also advertised as the sole authorization server in the metadata.
issuer: String,
/// Optional observer that is told the outcome of every `validate_request` call.
on_validate: Option<Arc<dyn OnValidate>>,
/// Ties the `Claims` type parameter to the struct without storing a value of it.
_phantom: PhantomData<Claims>,
}
#[bon::bon]
impl<N: DpopNonceChecker, Claims: for<'de> Deserialize<'de> + Clone + 'static>
    Rfc9068Validator<N, Claims>
{
    /// Builds an RFC 9068 (JWT profile for OAuth 2.0 access tokens) validator.
    ///
    /// This is reached through the `bon`-generated builder (`Rfc9068Validator::builder()`);
    /// the generated start function itself is private and named `builder_internal`.
    ///
    /// The configured JWT validator requires:
    /// - `aud` equal to `audience` and `iss` equal to `issuer`,
    /// - a `typ` of `"at+jwt"`,
    /// - `exp`, `iat` and `sub` to be present,
    /// - `jti` to be present only when a `jti_checker` is supplied (the checker is then
    ///   handed to the JWT validator).
    ///
    /// `allowed_signing_algorithms` is passed to the JWT validator as its allowed
    /// algorithms. `allowed_dpop_signing_algorithms`, `max_dpop_proof_age` (default 60 s),
    /// `dpop_nonce_checker`, `dpop_jti_checker` and `require_dpop` configure the DPoP binding
    /// checker. `token_header` defaults to `Authorization`. With the
    /// `default-jws-verifier-platform` feature, `jws_verifier_platform` has a default too.
    ///
    /// # Errors
    ///
    /// Returns the error from `jws_verifier_factory.build(..)` when the JWS verifier cannot
    /// be constructed.
    #[builder(
        start_fn(vis = "", name = "builder_internal"),
        generics(setters(vis = "", name = "with_{}_internal")),
        on(String, into)
    )]
    pub async fn new(
        issuer: String,
        audience: String,
        #[builder(into)]
        allowed_signing_algorithms: Option<Vec<String>>,
        #[builder(into)]
        allowed_dpop_signing_algorithms: Option<Vec<String>>,
        #[builder(default = Duration::from_secs(60))]
        max_dpop_proof_age: Duration,
        #[builder(default)]
        require_dpop: bool,
        #[builder(default)]
        require_mtls: bool,
        jwks_uri: Option<EndpointUrl>,
        jws_verifier_factory: Arc<dyn JwsVerifierFactory>,
        #[cfg_attr(feature = "default-jws-verifier-platform", builder(default = crate::DefaultJwsVerifierPlatform::default().into()))]
        jws_verifier_platform: Arc<dyn JwsVerifierPlatform>,
        jti_checker: Option<BoxedJtiUniquenessChecker>,
        #[builder(setters(vis = "", name = "dpop_nonce_checker_internal"))]
        dpop_nonce_checker: Option<N>,
        dpop_jti_checker: Option<BoxedJtiUniquenessChecker>,
        #[builder(default = http::header::AUTHORIZATION)]
        token_header: HeaderName,
        on_validate: Option<Arc<dyn OnValidate>>,
    ) -> Result<Self, BoxedError> {
        // The factory receives its own handle on the platform; the original `Arc` is moved
        // into the DPoP binding checker further down.
        let jws_verifier = jws_verifier_factory
            .build(jwks_uri.as_ref(), Arc::clone(&jws_verifier_platform))
            .await?;

        // `jti` is only mandatory when there is a checker to run against it.
        let require_jti = jti_checker.is_some();

        // NOTE(review): RFC 9068 §4 also accepts `typ` = "application/at+jwt"; confirm that
        // `JwtValidator` normalizes the `application/` prefix before this exact-value check.
        let jwt_validator = JwtValidator::builder()
            .verifier(jws_verifier)
            .aud(ClaimCheck::required_value(&audience))
            .maybe_allowed_algorithms(allowed_signing_algorithms)
            .typ(ClaimCheck::required_value("at+jwt"))
            .iss(ClaimCheck::required_value(&issuer))
            .require_exp(true)
            .require_iat(true)
            .sub(ClaimCheck::present())
            .require_jti(require_jti)
            .maybe_jti_checker(jti_checker)
            .build();

        let dpop_binding_checker = DPoPBindingChecker {
            required: require_dpop,
            max_proof_age: max_dpop_proof_age,
            allowed_signing_algorithms: allowed_dpop_signing_algorithms,
            dpop_nonce_checker,
            dpop_jti_checker,
            jws_verifier_platform,
        };

        let inner = ValidatorInner {
            jwt_validator,
            dpop_binding_checker,
            token_header,
            require_mtls,
        };

        Ok(Self {
            inner,
            issuer,
            on_validate,
            _phantom: PhantomData,
        })
    }
}
impl Rfc9068Validator<NoNonceCheck, ()> {
    /// Starts a builder with no DPoP nonce checking and no extra claims.
    ///
    /// Call `with_claims::<T>()` on the builder to deserialize additional claims into `T`,
    /// and `dpop_nonce_checker(..)` to install a nonce checker.
    pub fn builder() -> Rfc9068ValidatorBuilder<NoNonceCheck, ()> {
        Self::builder_internal()
    }

    /// Starts a builder pre-populated from authorization server metadata.
    ///
    /// `issuer` is taken from `metadata.issuer` and `jwks_uri` from `metadata.jwks_uri`
    /// (left unset when absent). The remaining required members, such as `audience` and
    /// `jws_verifier_factory`, must still be supplied by the caller.
    pub fn builder_from_metadata(
        metadata: &AuthorizationServerMetadata,
    ) -> Rfc9068ValidatorBuilder<
        NoNonceCheck,
        (),
        rfc9068_validator_builder::SetJwksUri<rfc9068_validator_builder::SetIssuer>,
    > {
        let issuer = metadata.issuer.clone();
        let jwks_uri = metadata.jwks_uri.clone();
        // The setter order fixes the type-state in the return type: issuer first, then jwks_uri.
        Self::builder().issuer(issuer).maybe_jwks_uri(jwks_uri)
    }
}
impl<
    N: DpopNonceChecker,
    Claims: for<'de> Deserialize<'de> + Clone + 'static,
    S: rfc9068_validator_builder::State,
> Rfc9068ValidatorBuilder<N, Claims, S>
{
    /// Changes the type that extra (non-RFC 9068) claims are deserialized into.
    ///
    /// The builder's type-state `S` is carried over unchanged.
    pub fn with_claims<Claims1: for<'de> Deserialize<'de> + Clone + 'static>(
        self,
    ) -> Rfc9068ValidatorBuilder<N, Claims1, S> {
        Self::with_claims_internal(self)
    }

    /// Installs a DPoP nonce checker and switches the builder's nonce-checker type to `N1`.
    ///
    /// This is only available while the nonce checker is still unset.
    pub fn dpop_nonce_checker<N1: DpopNonceChecker>(
        self,
        dpop_nonce_checker: N1,
    ) -> Rfc9068ValidatorBuilder<N1, Claims, SetDpopNonceChecker<S>>
    where
        S::DpopNonceChecker: rfc9068_validator_builder::IsUnset,
    {
        // Re-type the builder first (the generic `N` becomes `N1`), then set the member.
        Self::with_n_internal(self).dpop_nonce_checker_internal(dpop_nonce_checker)
    }
}
impl<N: DpopNonceChecker, Claims: for<'de> Deserialize<'de> + Clone + 'static>
    Rfc9068Validator<N, Claims>
{
    /// Describes this validator as protected-resource metadata.
    ///
    /// Advertises the configured issuer as the only authorization server, the configured
    /// DPoP signing algorithms (if any), whether DPoP-bound tokens are required, and
    /// `"header"` as the bearer method. `resource` is copied through; `realm` is left unset.
    pub fn validator_metadata(&self, resource: Option<&str>) -> ValidatorMetadata {
        let dpop = &self.inner.dpop_binding_checker;
        ValidatorMetadata {
            realm: None,
            resource: resource.map(str::to_owned),
            authorization_servers: Some(vec![self.issuer.clone()]),
            bearer_methods_supported: Some(vec!["header"]),
            dpop_signing_alg_values_supported: dpop.allowed_signing_algorithms.clone(),
            dpop_bound_access_tokens_required: Some(dpop.required),
        }
    }

    /// Validates the access token carried by an HTTP request.
    ///
    /// Delegates to the inner validator. If an `on_validate` observer is configured, it is
    /// told a coarse outcome for the result, which is then returned unchanged:
    /// `Success` for a validated token, `NoToken` when no token was found, and
    /// `ExtractError`, `InvalidToken` or `BindingError` for the matching error variants.
    pub async fn validate_request(
        &self,
        headers: &http::HeaderMap,
        http_method: &http::Method,
        http_uri: &http::Uri,
        client_cert_der: Option<&[u8]>,
    ) -> ValidationResult<Rfc9068AccessTokenClaims<Claims>, ValidateHeadersError> {
        let result = self
            .inner
            .validate_request(headers, http_method, http_uri, client_cert_der)
            .await;

        let Some(observer) = self.on_validate.as_deref() else {
            return result;
        };

        let outcome = match &result.outcome {
            Err(ValidateHeadersError::Extract { .. }) => ValidationOutcome::ExtractError,
            Err(ValidateHeadersError::InvalidJwt { .. }) => ValidationOutcome::InvalidToken,
            Err(ValidateHeadersError::Binding { .. }) => ValidationOutcome::BindingError,
            Ok(None) => ValidationOutcome::NoToken,
            Ok(Some(_)) => ValidationOutcome::Success,
        };
        observer.on_validate(outcome);

        result
    }
}
impl<N: DpopNonceChecker, ExtraClaims: for<'de> Deserialize<'de> + Clone + MaybeSendSync + 'static>
AccessTokenValidator for Rfc9068Validator<N, ExtraClaims>
{
type Claims = Rfc9068AccessTokenClaims<ExtraClaims>;
type Error = ValidateHeadersError;
async fn validate_request(
&self,
headers: &http::HeaderMap,
method: &http::Method,
uri: &http::Uri,
client_cert_der: Option<&[u8]>,
) -> ValidationResult<Self::Claims, Self::Error> {
self.validate_request(headers, method, uri, client_cert_der)
.await
}
}
impl<N: DpopNonceChecker, ExtraClaims: for<'de> Deserialize<'de> + Clone + 'static>
ProvideValidatorMetadata for Rfc9068Validator<N, ExtraClaims>
{
fn validator_metadata(&self, resource: Option<&str>) -> ValidatorMetadata {
self.validator_metadata(resource)
}
}
/// Profile-specific claims of an RFC 9068 JWT access token.
///
/// Registered claims such as `iss`, `aud`, `exp` and `sub` are not fields here. Any other
/// claims are deserialized into `ExtraClaims` (flattened into the same JSON object), which
/// defaults to `()`.
///
/// NOTE(review): absent `Option` fields serialize as `null` (there is no
/// `skip_serializing_if`), and an absent `amr` serializes as `[]`; keep this in mind if
/// these claims are ever re-emitted.
#[derive(Debug, Serialize, Deserialize, Clone)]
// Only the deserialize bound is overridden: `ExtraClaims` must be deserializable for every
// lifetime (i.e. owned), matching the `for<'de> Deserialize<'de>` bounds used by the validator.
#[serde(bound(deserialize = "ExtraClaims: for<'d> Deserialize<'d>"))]
pub struct Rfc9068AccessTokenClaims<ExtraClaims = ()> {
/// `client_id` claim; deserialization fails when it is missing.
pub client_id: String,
/// `auth_time` claim, if present.
pub auth_time: Option<u64>,
/// `acr` claim, if present.
pub acr: Option<String>,
/// `amr` claim; an empty list when the claim is missing.
#[serde(default)]
pub amr: Vec<String>,
/// `scope` claim, if present, kept as the raw (unsplit) string.
pub scope: Option<String>,
/// All remaining claims, deserialized into `ExtraClaims`.
#[serde(flatten)]
pub extra_claims: ExtraClaims,
}