use super::{JwksCache, TokenClaims, TokenError};
use async_trait::async_trait;
use std::sync::Arc;
/// Validates bearer tokens and extracts their claims.
///
/// The `Send + Sync` bound lets implementations be shared across threads.
#[async_trait]
pub trait TokenValidator: Send + Sync {
    /// Validates `token` and returns its decoded claims.
    ///
    /// # Errors
    ///
    /// Returns a [`TokenError`] when the token is rejected.
    async fn validate(&self, token: &str) -> Result<TokenClaims, TokenError>;

    /// Returns the issuer identifier this validator is bound to.
    fn issuer(&self) -> &str;
}
/// JWT validator whose verification keys come from a JWKS endpoint.
///
/// Construct it with [`JwtValidator::builder`].
#[cfg(feature = "sso")]
pub struct JwtValidator {
    /// Value the token's `iss` claim must match.
    issuer: String,
    /// Value the token's `aud` claim must match; `None` means no expected audience is set.
    audience: Option<String>,
    /// Key cache created from the JWKS URI; keys are looked up by the token's `kid`.
    jwks_cache: Arc<JwksCache>,
    /// Accepted signing algorithms. `JwtValidatorBuilder::build` guarantees this is
    /// non-empty and holds only RS*/PS*/ES* algorithms.
    algorithms: Vec<jsonwebtoken::Algorithm>,
}
#[cfg(feature = "sso")]
impl JwtValidator {
    /// Starts configuring a new validator; see [`JwtValidatorBuilder`].
    pub fn builder() -> JwtValidatorBuilder {
        Default::default()
    }

    /// Builds the `jsonwebtoken` rules applied to every token.
    ///
    /// The rules:
    /// - pin the first configured algorithm (RS256 if the list is empty);
    /// - require the configured issuer and, when one is set, the configured audience;
    /// - enable `exp` and `nbf` checks.
    ///
    /// `_kid` is currently unused.
    ///
    /// NOTE(review): when `audience` is `None`, no expected audience is configured.
    /// Depending on the pinned `jsonwebtoken` version, tokens that carry an `aud` claim
    /// may then be rejected — confirm against the version in use.
    fn validation(&self, _kid: Option<&str>) -> jsonwebtoken::Validation {
        let primary_alg = match self.algorithms.first() {
            Some(&alg) => alg,
            None => jsonwebtoken::Algorithm::RS256,
        };

        let mut rules = jsonwebtoken::Validation::new(primary_alg);
        rules.validate_exp = true;
        rules.validate_nbf = true;
        rules.set_issuer(&[self.issuer.as_str()]);
        if let Some(expected_aud) = self.audience.as_deref() {
            rules.set_audience(&[expected_aud]);
        }
        rules
    }
}
#[cfg(feature = "sso")]
#[async_trait]
impl TokenValidator for JwtValidator {
    /// Verifies `token`'s signature and standard claims, then returns its claims.
    ///
    /// Steps:
    /// 1. The header's `alg` must be one of the configured algorithms.
    /// 2. The header must carry a `kid`, which selects the key from the JWKS cache.
    /// 3. The token is decoded with issuer, optional audience, `exp` and `nbf` checks.
    ///
    /// # Errors
    ///
    /// - The header cannot be decoded, or the signature/claims fail validation
    ///   (converted from `jsonwebtoken` errors).
    /// - The header algorithm is not a configured one (`jsonwebtoken`'s
    ///   `ErrorKind::InvalidAlgorithm`, converted like other `jsonwebtoken` errors).
    /// - The header has no `kid` (`TokenError::MissingClaim("kid")`).
    /// - The key lookup in the JWKS cache fails.
    async fn validate(&self, token: &str) -> Result<TokenClaims, TokenError> {
        let header = jsonwebtoken::decode_header(token)?;
        let alg = header.alg;

        // Enforce the allow-list up front so disallowed tokens never trigger a JWKS lookup.
        if !self.algorithms.contains(&alg) {
            return Err(jsonwebtoken::errors::Error::from(
                jsonwebtoken::errors::ErrorKind::InvalidAlgorithm,
            )
            .into());
        }

        let kid = header.kid.ok_or_else(|| TokenError::MissingClaim("kid".into()))?;
        let key = self.jwks_cache.get_key(&kid).await?;

        let mut validation = self.validation(Some(&kid));
        // `validation()` pins only the first configured algorithm. Without this narrowing,
        // every other algorithm passed to the builder would be rejected. `alg` has already
        // been vetted against the allow-list above.
        // A single entry also avoids spurious failures when RSA and EC algorithms are both
        // configured: recent `jsonwebtoken` versions check every listed algorithm against
        // the key's family.
        validation.algorithms = vec![alg];

        let token_data = jsonwebtoken::decode::<TokenClaims>(token, &key, &validation)?;
        Ok(token_data.claims)
    }

    /// Returns the issuer this validator requires in the `iss` claim.
    fn issuer(&self) -> &str {
        &self.issuer
    }
}
/// Builder for [`JwtValidator`]; obtain one with [`JwtValidator::builder`].
#[cfg(feature = "sso")]
#[derive(Default)]
pub struct JwtValidatorBuilder {
    /// Required issuer; `build` fails when it is unset.
    issuer: Option<String>,
    /// Optional expected audience.
    audience: Option<String>,
    /// Required JWKS endpoint URI; `build` fails when it is unset.
    jwks_uri: Option<String>,
    /// Accepted algorithms in insertion order; empty means `build` falls back to RS256.
    algorithms: Vec<jsonwebtoken::Algorithm>,
}
#[cfg(feature = "sso")]
impl JwtValidatorBuilder {
    /// Sets the value the token's `iss` claim must match. Required.
    pub fn issuer(mut self, issuer: impl Into<String>) -> Self {
        self.issuer = Some(issuer.into());
        self
    }

    /// Sets the value the token's `aud` claim must match. Optional.
    pub fn audience(mut self, audience: impl Into<String>) -> Self {
        self.audience = Some(audience.into());
        self
    }

    /// Sets the JWKS endpoint from which verification keys are obtained. Required.
    pub fn jwks_uri(mut self, uri: impl Into<String>) -> Self {
        self.jwks_uri = Some(uri.into());
        self
    }

    /// Adds an accepted signing algorithm.
    ///
    /// May be called repeatedly. If never called, `build` uses RS256.
    pub fn algorithm(mut self, alg: jsonwebtoken::Algorithm) -> Self {
        self.algorithms.push(alg);
        self
    }

    /// Validates the configuration and creates the [`JwtValidator`].
    ///
    /// # Errors
    ///
    /// Returns `TokenError::ValidationError` when:
    /// - `issuer` or `jwks_uri` is missing or blank;
    /// - an HMAC or EdDSA algorithm was requested.
    pub fn build(self) -> Result<JwtValidator, TokenError> {
        // A blank value (e.g. from an unset env var) would otherwise yield a validator that
        // demands an empty `iss` or can never fetch keys, so treat blanks as missing.
        // The stored values are kept exactly as given; only the emptiness test trims.
        let issuer = self
            .issuer
            .filter(|value| !value.trim().is_empty())
            .ok_or_else(|| TokenError::ValidationError("issuer is required".into()))?;
        let jwks_uri = self
            .jwks_uri
            .filter(|value| !value.trim().is_empty())
            .ok_or_else(|| TokenError::ValidationError("jwks_uri is required".into()))?;

        let algorithms = if self.algorithms.is_empty() {
            vec![jsonwebtoken::Algorithm::RS256]
        } else {
            self.algorithms
        };
        algorithms
            .iter()
            .copied()
            .try_for_each(Self::ensure_jwks_compatible)?;

        Ok(JwtValidator {
            issuer,
            audience: self.audience,
            jwks_cache: Arc::new(JwksCache::new(jwks_uri)),
            algorithms,
        })
    }

    /// Rejects algorithms that cannot be verified with JWKS-published keys.
    ///
    /// Only RSA (RS*/PS*) and EC (ES*) algorithms are accepted. The match is exhaustive on
    /// purpose, so a new `Algorithm` variant forces an explicit decision here.
    fn ensure_jwks_compatible(algorithm: jsonwebtoken::Algorithm) -> Result<(), TokenError> {
        match algorithm {
            jsonwebtoken::Algorithm::RS256
            | jsonwebtoken::Algorithm::RS384
            | jsonwebtoken::Algorithm::RS512
            | jsonwebtoken::Algorithm::PS256
            | jsonwebtoken::Algorithm::PS384
            | jsonwebtoken::Algorithm::PS512
            | jsonwebtoken::Algorithm::ES256
            | jsonwebtoken::Algorithm::ES384 => Ok(()),
            jsonwebtoken::Algorithm::HS256
            | jsonwebtoken::Algorithm::HS384
            | jsonwebtoken::Algorithm::HS512
            | jsonwebtoken::Algorithm::EdDSA => Err(TokenError::ValidationError(format!(
                "algorithm '{algorithm:?}' is not supported with JWKS-based validation. Use an RSA or EC algorithm instead."
            ))),
        }
    }
}
/// Placeholder used when the `sso` feature is disabled.
///
/// No validator can be built in this configuration: `JwtValidatorBuilder::build`
/// always returns an error.
#[cfg(not(feature = "sso"))]
pub struct JwtValidator;
#[cfg(not(feature = "sso"))]
impl JwtValidator {
pub fn builder() -> JwtValidatorBuilder {
JwtValidatorBuilder
}
}
/// No-op builder used when the `sso` feature is disabled.
///
/// Every setter ignores its argument and `build` always fails. It derives `Default`
/// like the `sso` builder, so `JwtValidatorBuilder::default()` compiles in both
/// configurations.
#[cfg(not(feature = "sso"))]
#[derive(Default)]
pub struct JwtValidatorBuilder;
#[cfg(not(feature = "sso"))]
impl JwtValidatorBuilder {
    /// Accepts and ignores an issuer (the `sso` feature is disabled).
    pub fn issuer(self, _: impl Into<String>) -> Self {
        self
    }

    /// Accepts and ignores an audience (the `sso` feature is disabled).
    pub fn audience(self, _: impl Into<String>) -> Self {
        self
    }

    /// Accepts and ignores a JWKS URI (the `sso` feature is disabled).
    pub fn jwks_uri(self, _: impl Into<String>) -> Self {
        self
    }

    /// Accepts and ignores an algorithm, so builder chains that call `algorithm(..)` have
    /// the same shape with or without the `sso` feature.
    ///
    /// It is generic because the `jsonwebtoken` types may not be available in this
    /// configuration.
    pub fn algorithm<A>(self, _: A) -> Self {
        self
    }

    /// Always fails, because JWT validation requires the `sso` feature.
    ///
    /// # Errors
    ///
    /// Always returns `TokenError::ValidationError` with the message
    /// "SSO feature not enabled".
    pub fn build(self) -> Result<JwtValidator, TokenError> {
        Err(TokenError::ValidationError("SSO feature not enabled".into()))
    }
}