use std::{collections::HashMap, marker::PhantomData, sync::Arc, time::Duration};
use bon::Builder;
use http::HeaderName;
use huskarl_core::{
BoxedError, EndpointUrl,
crypto::verifier::{JwsVerifierFactory, JwsVerifierPlatform},
jwt::{
BoxedJtiUniquenessChecker,
validator::{ClaimCheck, JwtValidator},
},
platform::MaybeSendSync,
server_metadata::AuthorizationServerMetadata,
};
use serde::Deserialize;
use crate::{
AccessTokenValidator,
validator::{
ValidationResult,
binding::DPoPBindingChecker,
common::ValidatorInner,
custom::custom_validator_builder::{SetAuthorizationServer, SetJwksUri},
dpop_nonce::DpopNonceChecker,
error::ValidateHeadersError,
metadata::{ProvideValidatorMetadata, ValidatorMetadata},
observe::{OnValidate, ValidationOutcome},
},
};
/// Access-token validator assembled from caller-supplied settings.
///
/// `N` is the DPoP nonce checker implementation. `Claims` is the claims type
/// returned by `validate_request`; it is only carried as a marker here and
/// defaults to a JSON object map.
pub struct CustomValidator<N: DpopNonceChecker, Claims = HashMap<String, serde_json::Value>> {
/// Holds the token validator, the DPoP binding checker, the token header
/// name and the mTLS requirement flag; request validation is delegated to it.
inner: ValidatorInner<N>,
/// Reported as the single entry of `authorization_servers` in the validator metadata.
authorization_server: Option<String>,
/// Reported as `resource` in the validator metadata and passed to the
/// `on_validate` callback (as `""` when unset).
audience: Option<String>,
/// Optional observer invoked with the outcome of every `validate_request` call.
on_validate: Option<Arc<dyn OnValidate>>,
/// Ties the `Claims` type parameter to the struct without storing a value.
_phantom: PhantomData<Claims>,
}
#[bon::bon]
impl<N: DpopNonceChecker, Claims: for<'de> Deserialize<'de> + Clone + 'static>
    CustomValidator<N, Claims>
{
    /// Builds a validator from explicit settings.
    ///
    /// The JWS verifier is obtained from `jws_verifier_factory` (given the
    /// optional `jwks_uri` and the verifier platform), then combined with the
    /// claim requirements in `rules` to form the token validator. The DPoP
    /// settings, token header name and mTLS flag are stored alongside it.
    ///
    /// When `audience` is set, the `aud` claim is checked with
    /// `ClaimCheck::if_present`; when it is unset, `aud` is not checked.
    ///
    /// # Errors
    ///
    /// Returns the error produced by `jws_verifier_factory.build(...)`; no
    /// other step in this constructor can fail.
    #[builder(
        start_fn(vis = "", name = "builder_internal"),
        generics(setters(vis = "", name = "with_{}_internal"))
    )]
    pub async fn new(
        rules: AccessTokenValidationRules,
        #[builder(into)]
        audience: Option<String>,
        #[builder(into)]
        allowed_signing_algorithms: Option<Vec<String>>,
        #[builder(into)]
        allowed_dpop_signing_algorithms: Option<Vec<String>>,
        #[builder(default = Duration::from_secs(60))]
        max_dpop_proof_age: Duration,
        #[builder(default = false)]
        require_dpop: bool,
        #[builder(default = false)]
        require_mtls: bool,
        #[builder(into)]
        authorization_server: Option<String>,
        jwks_uri: Option<EndpointUrl>,
        jws_verifier_factory: Arc<dyn JwsVerifierFactory>,
        token_jti_checker: Option<BoxedJtiUniquenessChecker>,
        dpop_nonce_checker: N,
        dpop_jti_checker: Option<BoxedJtiUniquenessChecker>,
        #[cfg_attr(feature = "default-jws-verifier-platform", builder(default = crate::DefaultJwsVerifierPlatform::default().into()))]
        jws_verifier_platform: Arc<dyn JwsVerifierPlatform>,
        #[builder(default = http::header::AUTHORIZATION)]
        token_header: HeaderName,
        on_validate: Option<Arc<dyn OnValidate>>,
    ) -> Result<Self, BoxedError> {
        // The verifier gets its own handle to the platform; the original handle
        // is kept for the DPoP binding checker below.
        let verifier = jws_verifier_factory
            .build(jwks_uri.as_ref(), Arc::clone(&jws_verifier_platform))
            .await?;

        // No configured audience means the `aud` claim is left unchecked.
        let audience_check = audience
            .as_deref()
            .map_or(ClaimCheck::NoCheck, ClaimCheck::if_present);

        let token_validator = JwtValidator::builder()
            .verifier(verifier)
            .aud(audience_check)
            .maybe_allowed_algorithms(allowed_signing_algorithms)
            .typ(rules.typ)
            .iss(rules.iss)
            .require_exp(rules.require_exp)
            .require_iat(rules.require_iat)
            .sub(rules.sub)
            .require_jti(rules.require_jti)
            .maybe_jti_checker(token_jti_checker)
            .build();

        let dpop_binding_checker = DPoPBindingChecker {
            dpop_nonce_checker,
            dpop_jti_checker,
            max_proof_age: max_dpop_proof_age,
            jws_verifier_platform,
            allowed_signing_algorithms: allowed_dpop_signing_algorithms,
            required: require_dpop,
        };

        let inner = ValidatorInner {
            token_validator,
            dpop_binding_checker,
            token_header,
            require_mtls,
        };

        Ok(Self {
            inner,
            authorization_server,
            audience,
            on_validate,
            _phantom: PhantomData,
        })
    }
}
impl<N> CustomValidator<N, ()>
where
    N: DpopNonceChecker,
{
    /// Starts a builder for a validator whose claims type is `()`.
    ///
    /// Call `with_claims` on the returned builder to select a different
    /// claims type.
    pub fn builder() -> CustomValidatorBuilder<N, ()> {
        Self::builder_internal()
    }
}
impl<N, Claims, S> CustomValidatorBuilder<N, Claims, S>
where
    N: DpopNonceChecker,
    Claims: for<'de> Deserialize<'de> + Clone + 'static,
    S: custom_validator_builder::State,
{
    /// Switches the builder to a different claims type, keeping the nonce
    /// checker type `N` and the builder state `S` unchanged.
    ///
    /// This is the public wrapper around the private, generated
    /// `with_claims_internal` setter.
    pub fn with_claims<Claims1>(self) -> CustomValidatorBuilder<N, Claims1, S>
    where
        Claims1: for<'de> Deserialize<'de> + Clone + 'static,
    {
        self.with_claims_internal()
    }
}
impl<N> CustomValidator<N, ()>
where
    N: DpopNonceChecker,
{
    /// Starts a builder pre-populated from authorization server metadata.
    ///
    /// The metadata's `issuer` becomes the builder's `authorization_server`
    /// and its `jwks_uri` becomes the builder's `jwks_uri`; both values are
    /// cloned, so `metadata` is left untouched.
    pub fn builder_from_metadata(
        metadata: &AuthorizationServerMetadata,
    ) -> CustomValidatorBuilder<N, (), SetJwksUri<SetAuthorizationServer>> {
        let issuer = metadata.issuer.clone();
        let jwks_uri = metadata.jwks_uri.clone();

        Self::builder()
            .authorization_server(issuer)
            .maybe_jwks_uri(jwks_uri)
    }
}
impl<N, Claims> AccessTokenValidator for CustomValidator<N, Claims>
where
    N: DpopNonceChecker,
    Claims: for<'de> Deserialize<'de> + Clone + MaybeSendSync + 'static,
{
    type Claims = Claims;
    type Error = ValidateHeadersError;

    /// Trait entry point; forwards every argument to the inherent
    /// `CustomValidator::validate_request` and returns its result as-is.
    async fn validate_request(
        &self,
        headers: &http::HeaderMap,
        method: &http::Method,
        uri: &http::Uri,
        client_cert_der: Option<&[u8]>,
    ) -> ValidationResult<Self::Claims, Self::Error> {
        // Inherent associated functions take precedence over trait methods of
        // the same name, so this reaches the inherent implementation rather
        // than recursing into this trait method.
        Self::validate_request(self, headers, method, uri, client_cert_der).await
    }
}
impl<N, Claims> CustomValidator<N, Claims>
where
    N: DpopNonceChecker,
    Claims: for<'de> Deserialize<'de> + Clone + 'static,
{
    /// Describes this validator's configuration as [`ValidatorMetadata`].
    ///
    /// The configured authorization server (if any) is reported as a
    /// one-element list, the audience as `resource`, and the DPoP algorithm
    /// list and requirement flag are copied from the DPoP binding checker.
    /// `realm` is always `None` and the only bearer method listed is
    /// `"header"`.
    pub fn validator_metadata(&self) -> ValidatorMetadata {
        let dpop = &self.inner.dpop_binding_checker;
        let authorization_servers = self
            .authorization_server
            .clone()
            .map(|server| vec![server]);

        ValidatorMetadata {
            realm: None,
            authorization_servers,
            dpop_signing_alg_values_supported: dpop.allowed_signing_algorithms.clone(),
            dpop_bound_access_tokens_required: Some(dpop.required),
            resource: self.audience.clone(),
            bearer_methods_supported: Some(vec!["header"]),
        }
    }

    /// Validates the access token carried by an HTTP request.
    ///
    /// The work is delegated to the inner validator. If an `on_validate`
    /// observer is configured, it is then told how validation ended, together
    /// with the configured audience (an empty string when none is set). The
    /// inner validator's result is returned unchanged.
    pub async fn validate_request(
        &self,
        headers: &http::HeaderMap,
        http_method: &http::Method,
        http_uri: &http::Uri,
        client_cert_der: Option<&[u8]>,
    ) -> ValidationResult<Claims, ValidateHeadersError> {
        let result = self
            .inner
            .validate_request(headers, http_method, http_uri, client_cert_der)
            .await;

        if let Some(observer) = self.on_validate.as_ref() {
            // Collapse the detailed result into the coarse outcome the
            // observer understands.
            let summary = match &result.outcome {
                Err(ValidateHeadersError::Extract { .. }) => ValidationOutcome::ExtractError,
                Err(ValidateHeadersError::InvalidJwt { .. }) => ValidationOutcome::InvalidToken,
                Err(ValidateHeadersError::Binding { .. }) => ValidationOutcome::BindingError,
                Ok(None) => ValidationOutcome::NoToken,
                Ok(Some(_)) => ValidationOutcome::Success,
            };
            let audience_label = self.audience.as_ref().map_or("", String::as_str);
            observer.on_validate(summary, audience_label);
        }

        result
    }
}
impl<N, Claims> ProvideValidatorMetadata for CustomValidator<N, Claims>
where
    N: DpopNonceChecker,
    Claims: for<'de> Deserialize<'de> + Clone + 'static,
{
    /// Trait entry point; returns whatever the inherent
    /// `CustomValidator::validator_metadata` produces.
    fn validator_metadata(&self) -> ValidatorMetadata {
        // Resolves to the inherent method (inherent items take precedence over
        // trait items of the same name), so this does not recurse.
        Self::validator_metadata(self)
    }
}
/// Claim requirements applied to an access token by `CustomValidator`.
///
/// `CustomValidator::new` forwards each field to the `JwtValidator` builder
/// setter of the same name. Every field has a builder default, so an empty
/// builder yields a usable rule set.
#[derive(Debug, Clone, Builder)]
// NOTE(review): presumably silences the lint for a generated setter whose name
// collides with a std trait method (e.g. `sub` vs `std::ops::Sub::sub`) — confirm.
#[allow(clippy::should_implement_trait)]
pub struct AccessTokenValidationRules {
/// Check forwarded to the validator's `typ` setting; uses the type's default when unset.
#[builder(default)]
pub(super) typ: ClaimCheck,
/// Check forwarded to the validator's `iss` setting; defaults to `ClaimCheck::Present`.
#[builder(default = ClaimCheck::Present)]
pub(super) iss: ClaimCheck,
/// Forwarded to the validator's `require_exp` setting; defaults to `true`.
#[builder(default = true)]
pub(super) require_exp: bool,
/// Forwarded to the validator's `require_iat` setting; defaults to `true`.
#[builder(default = true)]
pub(super) require_iat: bool,
/// Check forwarded to the validator's `sub` setting; defaults to `ClaimCheck::Present`.
#[builder(default = ClaimCheck::Present)]
pub(super) sub: ClaimCheck,
/// Forwarded to the validator's `require_jti` setting; defaults to `true`.
#[builder(default = true)]
pub(super) require_jti: bool,
}