use std::{marker::PhantomData, sync::Arc, time::Duration};
use crate::{
core::{
BoxedError, EndpointUrl,
crypto::verifier::{JwsVerifierFactory, JwsVerifierPlatform},
jwt::{
BoxedJtiUniquenessChecker,
validator::{ClaimCheck, JwtValidator},
},
platform::MaybeSendSync,
server_metadata::AuthorizationServerMetadata,
},
validator::{custom::custom_validator_builder::SetDpopNonceChecker, dpop_nonce::NoNonceCheck},
};
use bon::Builder;
use http::HeaderName;
use serde::Deserialize;
use crate::{
AccessTokenValidator,
validator::{
ValidationResult,
binding::DPoPBindingChecker,
common::ValidatorInner,
custom::custom_validator_builder::{SetAuthorizationServer, SetJwksUri},
dpop_nonce::DpopNonceChecker,
error::ValidateHeadersError,
metadata::{ProvideValidatorMetadata, ValidatorMetadata},
observe::{OnValidate, ValidationOutcome},
},
};
/// An access-token validator whose JWT claim rules, DPoP binding and mTLS
/// requirements are configured through `CustomValidator::builder`.
///
/// `N` is the DPoP nonce checker. `Claims` is the claims type carried in the
/// returned `ValidationResult` (it must implement `Deserialize`); it defaults
/// to `()`.
pub struct CustomValidator<N: DpopNonceChecker, Claims = ()> {
    /// Shared JWT / DPoP / mTLS validation machinery.
    inner: ValidatorInner<N>,
    /// Authorization server identifier reported by `validator_metadata`.
    authorization_server: Option<String>,
    /// Optional hook notified with the outcome of every `validate_request` call.
    on_validate: Option<Arc<dyn OnValidate>>,
    // `fn() -> Claims` rather than `Claims`: the validator never stores a
    // `Claims` value, it only produces one. Using a function-pointer marker
    // keeps `Claims` from constraining the validator's auto traits
    // (`Send`/`Sync`/unwind safety) and drop-check, while remaining covariant
    // in `Claims` exactly like the previous marker.
    _phantom: PhantomData<fn() -> Claims>,
}
#[bon::bon]
impl<N: DpopNonceChecker, Claims: for<'de> Deserialize<'de> + Clone + 'static>
CustomValidator<N, Claims>
{
#[builder(
start_fn(vis = "", name = "builder_internal"),
generics(setters(vis = "", name = "with_{}_internal")),
on(String, into)
)]
pub async fn new(
#[builder(field)]
rules: AccessTokenValidationRules,
#[builder(into)]
allowed_signing_algorithms: Option<Vec<String>>,
#[builder(into)]
allowed_dpop_signing_algorithms: Option<Vec<String>>,
#[builder(default = Duration::from_secs(60))]
max_dpop_proof_age: Duration,
#[builder(default)]
require_dpop: bool,
#[builder(default)]
require_mtls: bool,
authorization_server: Option<String>,
jwks_uri: Option<EndpointUrl>,
jws_verifier_factory: Arc<dyn JwsVerifierFactory>,
token_jti_checker: Option<BoxedJtiUniquenessChecker>,
#[builder(setters(vis = "", name = "dpop_nonce_checker_internal"))]
dpop_nonce_checker: Option<N>,
dpop_jti_checker: Option<BoxedJtiUniquenessChecker>,
#[cfg_attr(feature = "default-jws-verifier-platform", builder(default = crate::DefaultJwsVerifierPlatform::default().into()))]
jws_verifier_platform: Arc<dyn JwsVerifierPlatform>,
#[builder(default = http::header::AUTHORIZATION)]
token_header: HeaderName,
on_validate: Option<Arc<dyn OnValidate>>,
) -> Result<Self, BoxedError> {
let jws_verifier = jws_verifier_factory
.build(jwks_uri.as_ref(), jws_verifier_platform.clone())
.await?;
let jwt_validator = JwtValidator::builder()
.verifier(jws_verifier)
.aud(rules.aud)
.maybe_allowed_algorithms(allowed_signing_algorithms)
.typ(rules.typ)
.iss(rules.iss)
.require_exp(rules.require_exp)
.require_iat(rules.require_iat)
.sub(rules.sub)
.require_jti(rules.require_jti)
.maybe_jti_checker(token_jti_checker)
.build();
Ok(Self {
inner: ValidatorInner {
jwt_validator,
dpop_binding_checker: DPoPBindingChecker {
dpop_nonce_checker,
dpop_jti_checker,
max_proof_age: max_dpop_proof_age,
jws_verifier_platform,
allowed_signing_algorithms: allowed_dpop_signing_algorithms,
required: require_dpop,
},
token_header,
require_mtls,
},
authorization_server,
on_validate,
_phantom: PhantomData,
})
}
}
impl CustomValidator<NoNonceCheck, ()> {
pub fn builder() -> CustomValidatorBuilder<NoNonceCheck, ()> {
CustomValidator::builder_internal()
}
pub fn builder_from_metadata(
metadata: &AuthorizationServerMetadata,
) -> CustomValidatorBuilder<NoNonceCheck, (), SetJwksUri<SetAuthorizationServer>> {
Self::builder()
.authorization_server(metadata.issuer.clone())
.maybe_jwks_uri(metadata.jwks_uri.clone())
}
}
impl<N, Claims, S> CustomValidatorBuilder<N, Claims, S>
where
    N: DpopNonceChecker,
    Claims: for<'de> Deserialize<'de> + Clone + 'static,
    S: custom_validator_builder::State,
{
    /// Re-types the builder so that validation yields `NewClaims`.
    pub fn with_claims<NewClaims: for<'de> Deserialize<'de> + Clone + 'static>(
        self,
    ) -> CustomValidatorBuilder<N, NewClaims, S> {
        self.with_claims_internal()
    }

    /// Installs a DPoP nonce checker, switching the builder's checker type to
    /// `NewN`. May only be called while the checker is still unset.
    pub fn dpop_nonce_checker<NewN>(
        self,
        dpop_nonce_checker: NewN,
    ) -> CustomValidatorBuilder<NewN, Claims, SetDpopNonceChecker<S>>
    where
        NewN: DpopNonceChecker,
        S::DpopNonceChecker: custom_validator_builder::IsUnset,
    {
        self.with_n_internal()
            .dpop_nonce_checker_internal(dpop_nonce_checker)
    }

    /// Applies `edit` to the accumulated claim rules and hands the builder back.
    fn edit_rules(mut self, edit: impl FnOnce(&mut AccessTokenValidationRules)) -> Self {
        edit(&mut self.rules);
        self
    }

    /// Replaces every claim rule at once.
    pub fn rules(self, rules: AccessTokenValidationRules) -> Self {
        self.edit_rules(|current| *current = rules)
    }

    /// Sets the check applied to the token `typ`.
    pub fn token_type(self, typ: ClaimCheck) -> Self {
        self.edit_rules(|current| current.typ = typ)
    }

    /// Sets the check applied to the `iss` claim.
    pub fn issuer(self, iss: ClaimCheck) -> Self {
        self.edit_rules(|current| current.iss = iss)
    }

    /// Sets the check applied to the `aud` claim.
    pub fn audience(self, aud: ClaimCheck) -> Self {
        self.edit_rules(|current| current.aud = aud)
    }

    /// Sets whether the `exp` claim is required.
    pub fn require_exp(self, require_exp: bool) -> Self {
        self.edit_rules(|current| current.require_exp = require_exp)
    }

    /// Sets whether the `iat` claim is required.
    pub fn require_iat(self, require_iat: bool) -> Self {
        self.edit_rules(|current| current.require_iat = require_iat)
    }

    /// Sets the check applied to the `sub` claim.
    pub fn subject(self, sub: ClaimCheck) -> Self {
        self.edit_rules(|current| current.sub = sub)
    }

    /// Sets whether the `jti` claim is required.
    pub fn require_jti(self, require_jti: bool) -> Self {
        self.edit_rules(|current| current.require_jti = require_jti)
    }
}
impl<N: DpopNonceChecker, Claims: for<'de> Deserialize<'de> + Clone + MaybeSendSync + 'static>
    AccessTokenValidator for CustomValidator<N, Claims>
{
    type Claims = Claims;
    type Error = ValidateHeadersError;

    /// Trait entry point; forwards to the inherent
    /// `CustomValidator::validate_request`, which also fires the optional
    /// `on_validate` hook.
    async fn validate_request(
        &self,
        headers: &http::HeaderMap,
        method: &http::Method,
        uri: &http::Uri,
        client_cert_der: Option<&[u8]>,
    ) -> ValidationResult<Self::Claims, Self::Error> {
        // Type-relative path: inherent associated functions take precedence
        // over trait ones, so this resolves to the inherent method, not here.
        let pending = Self::validate_request(self, headers, method, uri, client_cert_der);
        pending.await
    }
}
impl<N: DpopNonceChecker, Claims: for<'de> Deserialize<'de> + Clone + 'static>
    CustomValidator<N, Claims>
{
    /// Describes this validator's configuration (presumably for RFC 9728
    /// protected-resource metadata; confirm with `ValidatorMetadata`).
    ///
    /// Reports the configured authorization server, the allowed DPoP signing
    /// algorithms, whether DPoP-bound tokens are required, the optional
    /// `resource` identifier, and `"header"` as the only bearer method.
    pub fn validator_metadata(&self, resource: Option<&str>) -> ValidatorMetadata {
        let dpop = &self.inner.dpop_binding_checker;

        ValidatorMetadata {
            realm: None,
            authorization_servers: self
                .authorization_server
                .clone()
                .map(|issuer| vec![issuer]),
            dpop_signing_alg_values_supported: dpop.allowed_signing_algorithms.clone(),
            dpop_bound_access_tokens_required: Some(dpop.required),
            resource: resource.map(str::to_owned),
            bearer_methods_supported: Some(vec!["header"]),
        }
    }

    /// Validates the access token carried by an HTTP request.
    ///
    /// Delegates to the shared `ValidatorInner` logic. If an `on_validate` hook
    /// is configured, it is then told the coarse `ValidationOutcome`. The
    /// result is returned unchanged in every case.
    pub async fn validate_request(
        &self,
        headers: &http::HeaderMap,
        http_method: &http::Method,
        http_uri: &http::Uri,
        client_cert_der: Option<&[u8]>,
    ) -> ValidationResult<Claims, ValidateHeadersError> {
        let result = self
            .inner
            .validate_request(headers, http_method, http_uri, client_cert_der)
            .await;

        let Some(hook) = &self.on_validate else {
            return result;
        };

        // Exhaustive on purpose: a new error variant must be classified here.
        let outcome = match &result.outcome {
            Err(ValidateHeadersError::Extract { .. }) => ValidationOutcome::ExtractError,
            Err(ValidateHeadersError::InvalidJwt { .. }) => ValidationOutcome::InvalidToken,
            Err(ValidateHeadersError::Binding { .. }) => ValidationOutcome::BindingError,
            Ok(None) => ValidationOutcome::NoToken,
            Ok(Some(_)) => ValidationOutcome::Success,
        };
        hook.on_validate(outcome);

        result
    }
}
impl<N: DpopNonceChecker, Claims: for<'de> Deserialize<'de> + Clone + 'static>
    ProvideValidatorMetadata for CustomValidator<N, Claims>
{
    /// Trait entry point; forwards to the inherent
    /// `CustomValidator::validator_metadata`.
    fn validator_metadata(&self, resource: Option<&str>) -> ValidatorMetadata {
        // Type-relative path resolves to the inherent method (inherent items
        // take precedence), so this does not recurse into the trait method.
        Self::validator_metadata(self, resource)
    }
}
/// Claim-level rules applied to access tokens by `CustomValidator`.
///
/// When a validator is built, each field is forwarded to the matching
/// `JwtValidator` builder setting. Construct the rules with
/// `AccessTokenValidationRules::builder()` or `Default::default()`; each
/// field's builder default is listed below.
#[derive(Debug, Clone, Builder)]
// NOTE(review): presumably silences `should_implement_trait` on the generated
// `sub` setter, which shares its name with `std::ops::Sub::sub`. Confirm, and
// keep this comment in sync if the field is renamed.
#[allow(clippy::should_implement_trait)]
pub struct AccessTokenValidationRules {
    /// Check forwarded to `JwtValidator`'s `typ` setting.
    /// Default: `ClaimCheck::default()`.
    #[builder(default)]
    pub(super) typ: ClaimCheck,
    /// Check applied to the `iss` claim. Default: `ClaimCheck::Present`.
    #[builder(default = ClaimCheck::Present)]
    pub(super) iss: ClaimCheck,
    /// Check applied to the `aud` claim. Default: `ClaimCheck::default()`.
    #[builder(default)]
    pub(super) aud: ClaimCheck,
    /// Whether the `exp` claim is required. Default: `true`.
    #[builder(default = true)]
    pub(super) require_exp: bool,
    /// Whether the `iat` claim is required. Default: `true`.
    #[builder(default = true)]
    pub(super) require_iat: bool,
    /// Check applied to the `sub` claim. Default: `ClaimCheck::Present`.
    #[builder(default = ClaimCheck::Present)]
    pub(super) sub: ClaimCheck,
    /// Whether the `jti` claim is required. Default: `true`.
    #[builder(default = true)]
    pub(super) require_jti: bool,
}
impl Default for AccessTokenValidationRules {
fn default() -> Self {
Self::builder().build()
}
}