use std::sync::Arc;
use jsonwebtoken::{decode, decode_header, DecodingKey, Validation};
use tracing::{debug, warn};
use crate::config::SecurityConfig;
use crate::error::SecurityError;
use crate::identity::{DefaultIdentityBuilder, IdentityBuilder, RoleExtractor};
use crate::jwks::JwksCache;
/// Where the key used to verify a token's signature comes from.
enum KeySource {
    /// Keys are looked up per token in a shared JWKS cache, using the header's `kid`.
    Jwks(Arc<JwksCache>),
    /// A single fixed key verifies every token; the header's `kid` is not consulted.
    Static(DecodingKey),
}
/// Verifies JWTs and converts their claims into an identity of type `B::Identity`.
///
/// Verification keys come either from a JWKS cache (selected by the token's `kid`) or from
/// one static key. `config` provides, among other things, the expected issuer and audience.
pub struct JwtValidator<B = DefaultIdentityBuilder>
where
    B: IdentityBuilder,
{
    /// Origin of the signature-verification key(s).
    key_source: KeySource,
    /// Validation settings (expected issuer/audience).
    config: SecurityConfig,
    /// Turns verified claims into the caller's identity type.
    identity_builder: B,
}
/// Convenience constructors for validators that use [`DefaultIdentityBuilder`].
impl JwtValidator {
    /// Creates a validator whose keys are resolved through `jwks`, paired with a freshly
    /// constructed [`DefaultIdentityBuilder`].
    pub fn new(jwks: Arc<JwksCache>, config: SecurityConfig) -> Self {
        let builder = DefaultIdentityBuilder::new();
        Self::from_jwks(jwks, config, builder)
    }

    /// Creates a validator that checks every token against the fixed `key`, paired with a
    /// freshly constructed [`DefaultIdentityBuilder`].
    pub fn new_with_static_key(key: DecodingKey, config: SecurityConfig) -> Self {
        let builder = DefaultIdentityBuilder::new();
        Self::from_static_key(key, config, builder)
    }

    /// Swaps the identity builder for `DefaultIdentityBuilder::with_extractor(extractor)`,
    /// keeping the key source and config unchanged.
    pub fn with_role_extractor(self, extractor: Box<dyn RoleExtractor>) -> Self {
        Self {
            identity_builder: DefaultIdentityBuilder::with_extractor(extractor),
            ..self
        }
    }
}
impl<B: IdentityBuilder> JwtValidator<B> {
pub fn from_jwks(jwks: Arc<JwksCache>, config: SecurityConfig, identity_builder: B) -> Self {
Self {
key_source: KeySource::Jwks(jwks),
config,
identity_builder,
}
}
pub fn from_static_key(
key: DecodingKey,
config: SecurityConfig,
identity_builder: B,
) -> Self {
Self {
key_source: KeySource::Static(key),
config,
identity_builder,
}
}
pub async fn validate(&self, token: &str) -> Result<B::Identity, SecurityError> {
let header = decode_header(token)
.map_err(|e| SecurityError::InvalidToken(format!("Failed to decode header: {e}")))?;
let algorithm = header.alg;
debug!(?algorithm, kid = ?header.kid, "Decoded JWT header");
let decoding_key = match &self.key_source {
KeySource::Static(key) => key.clone(),
KeySource::Jwks(jwks) => {
let kid = header.kid.as_deref().ok_or_else(|| {
SecurityError::InvalidToken("JWT header missing 'kid' field".into())
})?;
jwks.get_key(kid).await?
}
};
let mut validation = Validation::new(algorithm);
validation.set_issuer(&[&self.config.issuer]);
validation.set_audience(&[&self.config.audience]);
validation.validate_exp = true;
validation.validate_nbf = true;
let token_data = decode::<serde_json::Value>(token, &decoding_key, &validation)
.map_err(|e| {
let err = match e.kind() {
jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
SecurityError::TokenExpired
}
jsonwebtoken::errors::ErrorKind::InvalidIssuer => {
SecurityError::ValidationFailed("Invalid issuer".into())
}
jsonwebtoken::errors::ErrorKind::InvalidAudience => {
SecurityError::ValidationFailed("Invalid audience".into())
}
_ => SecurityError::InvalidToken(e.to_string()),
};
warn!(error = %err, "JWT claim validation failed");
err
})?;
let sub = token_data
.claims
.get("sub")
.and_then(|v| v.as_str())
.unwrap_or("unknown")
.to_owned();
let identity = self.identity_builder.build(token_data.claims).await?;
debug!(sub = %sub, "JWT validated");
Ok(identity)
}
}