use async_trait::async_trait;
use chrono::{DateTime, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::Jwk};
use parking_lot::RwLock;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
use zeroize::Zeroize;
use crate::default_http_client;
use auths_oidc_port::{JwksClient, JwtValidator, OidcError, OidcValidationConfig};
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct OidcTokenClaims {
pub sub: String,
pub iss: String,
pub aud: String,
pub exp: i64,
#[serde(default)]
pub iat: i64,
#[serde(default)]
pub nbf: Option<i64>,
#[serde(default)]
pub jti: Option<String>,
#[serde(flatten)]
pub extra: serde_json::Map<String, serde_json::Value>,
}
pub struct HttpJwtValidator {
jwks_client: Arc<dyn JwksClient>,
}
impl HttpJwtValidator {
pub fn new(jwks_client: Arc<dyn JwksClient>) -> Self {
Self { jwks_client }
}
}
#[async_trait]
impl JwtValidator for HttpJwtValidator {
async fn validate(
&self,
token: &str,
config: &OidcValidationConfig,
now: DateTime<Utc>,
) -> Result<serde_json::Value, OidcError> {
let mut token_mut = token.to_string();
let header = decode_header(&token_mut).map_err(|e| {
let error_msg = format!("{}", e);
if error_msg.contains("unknown variant") && error_msg.contains("expected one of") {
OidcError::AlgorithmMismatch {
expected: "RS256, RS384, RS512, ES256, ES384, PS256, PS384, PS512, or EdDSA"
.to_string(),
got: "unsupported algorithm".to_string(),
}
} else {
OidcError::JwtDecode(format!("failed to decode JWT header: {}", e))
}
})?;
let kid = header
.kid
.ok_or_else(|| OidcError::JwtDecode("JWT header missing 'kid' field".to_string()))?;
let alg_str = format!("{:?}", header.alg);
if alg_str.to_uppercase() == "NONE" {
return Err(OidcError::AlgorithmMismatch {
expected: "RS256 or ES256".to_string(),
got: "none".to_string(),
});
}
if !config
.allowed_algorithms
.iter()
.any(|allowed| allowed.to_uppercase() == alg_str.to_uppercase())
{
return Err(OidcError::AlgorithmMismatch {
expected: config.allowed_algorithms.join(", "),
got: alg_str.clone(),
});
}
let jwks = self.jwks_client.fetch_jwks(&config.issuer).await?;
let keys = jwks.get("keys").and_then(|k| k.as_array()).ok_or_else(|| {
OidcError::JwksResolutionFailed("JWKS response missing 'keys' array".to_string())
})?;
let key_obj = keys
.iter()
.find(|key| {
key.get("kid")
.and_then(|k| k.as_str())
.map(|k| k == kid)
.unwrap_or(false)
})
.ok_or_else(|| OidcError::UnknownKeyId(kid.clone()))?;
let jwk: Jwk = serde_json::from_value(key_obj.clone()).map_err(|e| {
OidcError::JwksResolutionFailed(format!(
"failed to parse JWKS key for kid {}: {}",
kid, e
))
})?;
let decoding_key = DecodingKey::from_jwk(&jwk).map_err(|e| {
OidcError::JwksResolutionFailed(format!(
"failed to create decoding key for kid {}: {}",
kid, e
))
})?;
let now_secs = now.timestamp();
let leeway = config.max_clock_skew_secs as u64;
let algorithm = match alg_str.to_uppercase().as_str() {
"RS256" => Algorithm::RS256,
"ES256" => Algorithm::ES256,
_ => {
return Err(OidcError::AlgorithmMismatch {
expected: "RS256 or ES256".to_string(),
got: alg_str,
});
}
};
let mut validation = Validation::new(algorithm);
validation.set_issuer(&[&config.issuer]);
validation.set_audience(&[&config.audience]);
validation.leeway = leeway;
validation.validate_exp = true;
validation.set_required_spec_claims(&["exp", "iss", "aud", "sub"]);
let token_data = decode::<OidcTokenClaims>(&token_mut, &decoding_key, &validation)
.map_err(|e| {
let error_msg = format!("{}", e);
if error_msg.contains("ExpiredSignature") || error_msg.contains("InvalidIssuedAt") {
OidcError::ClockSkewExceeded {
token_exp: 0,
current_time: now_secs,
leeway: leeway as i64,
}
} else if error_msg.contains("InvalidSignature") {
OidcError::SignatureVerificationFailed
} else if error_msg.contains("InvalidIssuer") {
OidcError::ClaimsValidationFailed {
claim: "iss".to_string(),
reason: "issuer mismatch".to_string(),
}
} else if error_msg.contains("InvalidAudience") {
OidcError::ClaimsValidationFailed {
claim: "aud".to_string(),
reason: "audience mismatch".to_string(),
}
} else {
OidcError::JwtDecode(format!("JWT validation failed: {}", e))
}
})?;
token_mut.zeroize();
let mut json = serde_json::json!(token_data.claims);
if let Some(obj) = json.as_object_mut() {
for (k, v) in token_data.claims.extra.iter() {
obj.insert(k.clone(), v.clone());
}
}
Ok(json)
}
}
pub struct HttpJwksClient {
cache: Arc<RwLock<JwksCache>>,
}
struct JwksCache {
data: Option<serde_json::Value>,
expires_at: Option<DateTime<Utc>>,
ttl: Duration,
}
impl HttpJwksClient {
pub fn new(ttl: Duration) -> Self {
Self {
cache: Arc::new(RwLock::new(JwksCache {
data: None,
expires_at: None,
ttl,
})),
}
}
pub fn with_default_ttl() -> Self {
Self::new(Duration::from_secs(3600))
}
}
#[async_trait]
impl JwksClient for HttpJwksClient {
async fn fetch_jwks(&self, issuer_url: &str) -> Result<serde_json::Value, OidcError> {
#[allow(clippy::disallowed_methods)] let now = Utc::now();
{
let cache = self.cache.read();
if let Some(data) = &cache.data
&& let Some(expires_at) = cache.expires_at
&& now < expires_at
{
return Ok(data.clone());
}
}
let jwks_url = format!(
"{}{}",
issuer_url.trim_end_matches('/'),
"/.well-known/jwks.json"
);
let client = default_http_client();
let response = client.get(&jwks_url).send().await.map_err(|e| {
OidcError::JwksResolutionFailed(format!(
"failed to fetch JWKS from {}: {}",
jwks_url, e
))
})?;
let jwks: serde_json::Value = response.json().await.map_err(|e| {
OidcError::JwksResolutionFailed(format!(
"failed to parse JWKS response from {}: {}",
jwks_url, e
))
})?;
let mut cache = self.cache.write();
cache.data = Some(jwks.clone());
#[allow(clippy::expect_used)]
let duration_offset = chrono::Duration::from_std(cache.ttl).expect("cache TTL overflow");
cache.expires_at = Some(now + duration_offset);
Ok(jwks)
}
}
#[cfg(test)]
mod tests {
use super::*;
use auths_oidc_port::OidcValidationConfig;
#[tokio::test]
async fn test_http_jwt_validator_missing_kid() {
let mock_client = MockJwksClient::new();
let validator = HttpJwtValidator::new(Arc::new(mock_client));
let invalid_token = "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";
let config = OidcValidationConfig::builder()
.issuer("https://example.com")
.audience("test")
.build()
.unwrap();
#[allow(clippy::disallowed_methods)] let result = validator.validate(invalid_token, &config, Utc::now()).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_algorithm_none_rejected() {
let mock_client = MockJwksClient::new();
let validator = HttpJwtValidator::new(Arc::new(mock_client));
let token_none = "eyJhbGciOiJub25lIiwia2lkIjoiYWJjIn0.eyJzdWIiOiIxMjM0NTY3ODkwIn0.";
let config = OidcValidationConfig::builder()
.issuer("https://example.com")
.audience("test")
.build()
.unwrap();
#[allow(clippy::disallowed_methods)] let result = validator.validate(token_none, &config, Utc::now()).await;
assert!(matches!(result, Err(OidcError::AlgorithmMismatch { .. })));
}
struct MockJwksClient;
impl MockJwksClient {
fn new() -> Self {
Self
}
}
#[async_trait]
impl JwksClient for MockJwksClient {
async fn fetch_jwks(&self, _issuer_url: &str) -> Result<serde_json::Value, OidcError> {
Ok(serde_json::json!({
"keys": [
{
"kty": "RSA",
"kid": "test-key-1",
"use": "sig",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
"e": "AQAB"
}
]
}))
}
}
}