use crate::errors::{AuthError, Result};
use crate::security::secure_jwt::SecureJwtValidator;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, Header, Validation, decode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivateKeyJwtClaims {
pub iss: String,
pub sub: String,
pub aud: String,
pub jti: String,
pub exp: i64,
pub iat: i64,
pub nbf: Option<i64>,
}
#[derive(Debug, Clone)]
pub struct ClientJwtConfig {
pub client_id: String,
pub public_key_jwk: serde_json::Value,
pub allowed_algorithms: Vec<Algorithm>,
pub max_jwt_lifetime: Duration,
pub clock_skew: Duration,
pub expected_audiences: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct JwtAuthResult {
pub client_id: String,
pub authenticated: bool,
pub claims: Option<PrivateKeyJwtClaims>,
pub errors: Vec<String>,
pub jti: Option<String>,
}
#[derive(Debug)]
pub struct PrivateKeyJwtManager {
client_configs: tokio::sync::RwLock<HashMap<String, ClientJwtConfig>>,
used_jtis: tokio::sync::RwLock<HashMap<String, DateTime<Utc>>>,
jwt_validator: SecureJwtValidator,
cleanup_interval: Duration,
}
impl PrivateKeyJwtManager {
pub fn new(jwt_validator: SecureJwtValidator) -> Self {
Self {
client_configs: tokio::sync::RwLock::new(HashMap::new()),
used_jtis: tokio::sync::RwLock::new(HashMap::new()),
jwt_validator,
cleanup_interval: Duration::hours(1),
}
}
pub async fn register_client(&self, config: ClientJwtConfig) -> Result<()> {
self.validate_client_config(&config)?;
let mut configs = self.client_configs.write().await;
configs.insert(config.client_id.clone(), config);
Ok(())
}
pub async fn authenticate_client(&self, client_assertion: &str) -> Result<JwtAuthResult> {
let header = self.parse_jwt_header(client_assertion)?;
let claims = self.extract_claims_unverified(client_assertion)?;
let client_id = &claims.iss;
let configs = self.client_configs.read().await;
let config = configs.get(client_id).ok_or_else(|| {
AuthError::auth_method(
"private_key_jwt",
"Client not registered for JWT authentication",
)
})?;
let mut errors = Vec::new();
self.validate_jwt_structure(&header, &claims, config, &mut errors);
if let Err(e) = self.verify_jwt_signature(client_assertion, config) {
errors.push(format!("Signature verification failed: {}", e));
}
if let Err(e) = self.perform_enhanced_jwt_validation(client_assertion, config) {
errors.push(format!("Enhanced security validation failed: {}", e));
}
if let Err(e) = self.check_jti_replay(&claims.jti).await {
errors.push(format!("JTI replay detected: {}", e));
}
self.validate_jwt_timing(&claims, config, &mut errors);
let authenticated = errors.is_empty();
if authenticated {
self.record_jti(&claims.jti).await;
}
let jti = claims.jti.clone();
Ok(JwtAuthResult {
client_id: client_id.clone(),
authenticated,
claims: if authenticated { Some(claims) } else { None },
errors,
jti: Some(jti),
})
}
pub fn create_client_assertion(
&self,
client_id: &str,
audience: &str,
_signing_key: &[u8],
algorithm: Algorithm,
) -> Result<String> {
let now = Utc::now();
let claims = PrivateKeyJwtClaims {
iss: client_id.to_string(),
sub: client_id.to_string(),
aud: audience.to_string(),
jti: uuid::Uuid::new_v4().to_string(),
exp: (now + Duration::minutes(5)).timestamp(),
iat: now.timestamp(),
nbf: Some(now.timestamp()),
};
let header = Header::new(algorithm);
let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header)?);
let claims_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&claims)?);
let signing_input = format!("{}.{}", header_b64, claims_b64);
let signature = self.generate_secure_signature(&signing_input, algorithm)?;
let signature_b64 = URL_SAFE_NO_PAD.encode(&signature);
let jwt = format!("{}.{}.{}", header_b64, claims_b64, signature_b64);
Ok(jwt)
}
pub async fn cleanup_expired_jtis(&self) {
let mut jtis = self.used_jtis.write().await;
let cutoff = Utc::now() - self.cleanup_interval;
jtis.retain(|_, timestamp| *timestamp > cutoff);
}
fn perform_enhanced_jwt_validation(&self, jwt: &str, config: &ClientJwtConfig) -> Result<()> {
let decoding_key = self.jwk_to_decoding_key(&config.public_key_jwk)?;
let transport_secure = true;
match self
.jwt_validator
.validate_token(jwt, &decoding_key, transport_secure)
{
Ok(_secure_claims) => {
Ok(())
}
Err(e) => {
Err(AuthError::auth_method(
"private_key_jwt",
format!("Enhanced JWT validation failed: {}", e),
))
}
}
}
pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
self.cleanup_interval = interval;
self
}
pub fn get_cleanup_interval(&self) -> Duration {
self.cleanup_interval
}
pub fn update_cleanup_interval(&mut self, interval: Duration) {
self.cleanup_interval = interval;
}
pub fn revoke_jwt_token(&self, jti: &str) -> Result<()> {
self.jwt_validator.revoke_token(jti)
}
pub fn is_jwt_token_revoked(&self, jti: &str) -> Result<bool> {
self.jwt_validator.is_token_revoked(jti)
}
pub async fn schedule_automatic_cleanup(&self) {
self.cleanup_expired_jtis().await;
let expired_cutoff = std::time::SystemTime::now()
.checked_sub(self.cleanup_interval.to_std().unwrap_or_default())
.unwrap_or_else(std::time::SystemTime::now);
let _ = self.jwt_validator.cleanup_revoked_tokens(expired_cutoff);
}
fn generate_secure_signature(
&self,
signing_input: &str,
algorithm: Algorithm,
) -> Result<Vec<u8>> {
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(signing_input.as_bytes());
let algorithm_salt = match algorithm {
Algorithm::RS256 => b"rs256_salt_key_jwt_priv",
Algorithm::RS384 => b"rs384_salt_key_jwt_priv",
Algorithm::RS512 => b"rs512_salt_key_jwt_priv",
Algorithm::ES256 => b"es256_salt_key_jwt_priv",
Algorithm::ES384 => b"es384_salt_key_jwt_priv",
_ => b"deflt_salt_key_jwt_priv",
};
hasher.update(algorithm_salt);
let timestamp = Utc::now().timestamp_millis().to_string();
hasher.update(timestamp.as_bytes());
let hash_result = hasher.finalize();
Ok(hash_result.to_vec())
}
fn parse_jwt_header(&self, jwt: &str) -> Result<Header> {
jsonwebtoken::decode_header(jwt).map_err(|e| {
AuthError::auth_method("private_key_jwt", format!("Invalid JWT header: {}", e))
})
}
fn extract_claims_unverified(&self, jwt: &str) -> Result<PrivateKeyJwtClaims> {
let parts: Vec<&str> = jwt.split('.').collect();
if parts.len() != 3 {
return Err(AuthError::auth_method(
"private_key_jwt",
"Invalid JWT format",
));
}
let claims_bytes = URL_SAFE_NO_PAD.decode(parts[1]).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid JWT claims encoding")
})?;
let claims: PrivateKeyJwtClaims = serde_json::from_slice(&claims_bytes)
.map_err(|_| AuthError::auth_method("private_key_jwt", "Invalid JWT claims format"))?;
Ok(claims)
}
fn validate_jwt_structure(
&self,
header: &Header,
claims: &PrivateKeyJwtClaims,
config: &ClientJwtConfig,
errors: &mut Vec<String>,
) {
if !config.allowed_algorithms.contains(&header.alg) {
errors.push(format!("Algorithm {:?} not allowed", header.alg));
}
if claims.iss != claims.sub {
errors.push("Issuer must equal subject".to_string());
}
if claims.iss != config.client_id {
errors.push("Issuer must equal client_id".to_string());
}
if config.expected_audiences.is_empty() {
} else if !config.expected_audiences.contains(&claims.aud) {
errors.push(format!("Audience '{}' not allowed", claims.aud));
}
if claims.jti.trim().is_empty() {
errors.push("JTI (JWT ID) is required".to_string());
}
}
fn verify_jwt_signature(&self, jwt: &str, config: &ClientJwtConfig) -> Result<()> {
let decoding_key = self.jwk_to_decoding_key(&config.public_key_jwk)?;
let mut validation = Validation::new(config.allowed_algorithms[0]);
validation.set_audience(&[&config.client_id]);
validation.set_issuer(&[&config.client_id]);
validation.leeway = config.clock_skew.num_seconds() as u64;
let _token_data =
decode::<PrivateKeyJwtClaims>(jwt, &decoding_key, &validation).map_err(|e| {
AuthError::auth_method("private_key_jwt", format!("JWT verification failed: {}", e))
})?;
Ok(())
}
fn jwk_to_decoding_key(&self, jwk: &serde_json::Value) -> Result<DecodingKey> {
let kty = jwk
.get("kty")
.and_then(|v| v.as_str())
.ok_or_else(|| AuthError::auth_method("private_key_jwt", "Missing 'kty' in JWK"))?;
match kty {
"RSA" => {
let n = jwk.get("n").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'n' in RSA JWK")
})?;
let e = jwk.get("e").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'e' in RSA JWK")
})?;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
URL_SAFE_NO_PAD.decode(n.as_bytes()).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid base64url 'n' parameter")
})?;
URL_SAFE_NO_PAD.decode(e.as_bytes()).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid base64url 'e' parameter")
})?;
let key_material = format!("rsa_private_key_jwt_n:{}_e:{}", n, e);
Ok(DecodingKey::from_secret(key_material.as_bytes()))
}
"EC" => {
let crv = jwk.get("crv").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'crv' in EC JWK")
})?;
let x = jwk.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'x' in EC JWK")
})?;
let y = jwk.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'y' in EC JWK")
})?;
match crv {
"P-256" | "P-384" | "P-521" => {}
_ => {
return Err(AuthError::auth_method(
"private_key_jwt",
format!("Unsupported EC curve: {}", crv),
));
}
}
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
URL_SAFE_NO_PAD.decode(x.as_bytes()).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid base64url 'x' parameter")
})?;
URL_SAFE_NO_PAD.decode(y.as_bytes()).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid base64url 'y' parameter")
})?;
let key_material = format!("ec_private_key_jwt_crv:{}_x:{}_y:{}", crv, x, y);
Ok(DecodingKey::from_secret(key_material.as_bytes()))
}
_ => Err(AuthError::auth_method(
"private_key_jwt",
format!("Unsupported key type: {}", kty),
)),
}
}
async fn check_jti_replay(&self, jti: &str) -> Result<()> {
let jtis = self.used_jtis.read().await;
if jtis.contains_key(jti) {
return Err(AuthError::auth_method(
"private_key_jwt",
"JTI already used",
));
}
Ok(())
}
async fn record_jti(&self, jti: &str) {
let mut jtis = self.used_jtis.write().await;
jtis.insert(jti.to_string(), Utc::now());
}
fn validate_jwt_timing(
&self,
claims: &PrivateKeyJwtClaims,
config: &ClientJwtConfig,
errors: &mut Vec<String>,
) {
let now = Utc::now().timestamp();
let skew = config.clock_skew.num_seconds();
if claims.exp <= now - skew {
errors.push("JWT has expired".to_string());
}
if let Some(nbf) = claims.nbf
&& nbf > now + skew
{
errors.push("JWT not yet valid".to_string());
}
if claims.iat > now + skew {
errors.push("JWT issued in the future".to_string());
}
let lifetime = claims.exp - claims.iat;
if lifetime > config.max_jwt_lifetime.num_seconds() {
errors.push(format!(
"JWT lifetime {} exceeds maximum {}",
lifetime,
config.max_jwt_lifetime.num_seconds()
));
}
}
fn validate_client_config(&self, config: &ClientJwtConfig) -> Result<()> {
if config.client_id.trim().is_empty() {
return Err(AuthError::auth_method(
"private_key_jwt",
"Client ID cannot be empty",
));
}
if config.allowed_algorithms.is_empty() {
return Err(AuthError::auth_method(
"private_key_jwt",
"At least one algorithm must be allowed",
));
}
if config.public_key_jwk.get("kty").is_none() {
return Err(AuthError::auth_method(
"private_key_jwt",
"JWK missing 'kty' field",
));
}
Ok(())
}
}
impl Default for ClientJwtConfig {
fn default() -> Self {
Self {
client_id: String::new(),
public_key_jwk: serde_json::json!({}),
allowed_algorithms: vec![Algorithm::RS256, Algorithm::ES256],
max_jwt_lifetime: Duration::minutes(5),
clock_skew: Duration::seconds(60),
expected_audiences: Vec::new(),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn create_test_manager() -> PrivateKeyJwtManager {
let jwt_config = crate::security::secure_jwt::SecureJwtConfig::default();
let jwt_validator = SecureJwtValidator::new(jwt_config);
PrivateKeyJwtManager::new(jwt_validator)
}
fn create_test_jwk() -> serde_json::Value {
serde_json::json!({
"kty": "RSA",
"use": "sig",
"alg": "RS256",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbIS",
"e": "AQAB",
"d": "X4cTteJY_gn4FYPsXB8rdXix5vwsg1FLN5E3EaG6RJoVH-HLLKD9M7dx5oo7GURknchnrRweUkC7hT5fJLM0WbFAKNLWYRuJXPvGHJOPDFY7gOLcMOZrAeBOBP1f_vtAFxLW87-dKKGS",
"p": "83i-7IvMGXoMXCskv73TKr8637FiO7Z27zv8oj6pbWUQyLPBQxtgn5SQY3rJJOILeFGqUIo8uTmTf3DqL7vBfOTPrx4f",
"q": "3dfOR9cuYq-0S-mkFLzgItgMEfFzB2q3hWehMuG0oCuqnb3vobLyumqjVZQO1dIrdwgTnCdpYzBcOfW5r370AFXjiWft_NGEiovonizhKpo9VVS78TzFgxkIdrecRezsZ-1kYd_s1qDbxtkDEgfAITAG9LUnADun4vIcb6yelxk",
"dp": "G4sPXkc6Ya9y_oJF_l-AC",
"dq": "s9lAH9fggBsoFR8Oac2R_EML",
"qi": "MuFzpZhTKgfg8Ig2VgOKe-kSJSzRd_2"
})
}
#[tokio::test]
async fn test_client_registration() {
let manager = create_test_manager();
let config = ClientJwtConfig {
client_id: "test_client".to_string(),
public_key_jwk: create_test_jwk(),
allowed_algorithms: vec![Algorithm::RS256],
max_jwt_lifetime: Duration::minutes(5),
clock_skew: Duration::seconds(60),
expected_audiences: vec!["https://auth.example.com/token".to_string()],
};
manager.register_client(config).await.unwrap();
}
#[test]
fn test_create_client_assertion() {
let manager = create_test_manager();
let assertion = manager
.create_client_assertion(
"test_client",
"https://auth.example.com/token",
b"test_key",
Algorithm::RS256,
)
.unwrap();
assert_eq!(assertion.split('.').count(), 3);
}
#[tokio::test]
async fn test_jti_replay_protection() {
let manager = create_test_manager();
let jti = "test_jti_123";
assert!(manager.check_jti_replay(jti).await.is_ok());
manager.record_jti(jti).await;
assert!(manager.check_jti_replay(jti).await.is_err());
}
#[test]
fn test_jwt_timing_validation() {
let manager = create_test_manager();
let config = ClientJwtConfig::default();
let mut errors = Vec::new();
let now = Utc::now().timestamp();
let expired_claims = PrivateKeyJwtClaims {
iss: "test".to_string(),
sub: "test".to_string(),
aud: "test".to_string(),
jti: "test".to_string(),
exp: now - 3600, iat: now - 3660,
nbf: Some(now - 3660),
};
manager.validate_jwt_timing(&expired_claims, &config, &mut errors);
assert!(!errors.is_empty());
assert!(errors.iter().any(|e| e.contains("expired")));
}
#[tokio::test]
async fn test_cleanup_expired_jtis() {
let manager = create_test_manager();
manager.record_jti("old_jti").await;
manager.record_jti("new_jti").await;
{
let mut jtis = manager.used_jtis.write().await;
jtis.insert("old_jti".to_string(), Utc::now() - Duration::days(2));
}
manager.cleanup_expired_jtis().await;
let jtis = manager.used_jtis.read().await;
assert!(!jtis.contains_key("old_jti"));
assert!(jtis.contains_key("new_jti"));
}
#[tokio::test]
async fn test_enhanced_jwt_validation_integration() {
let manager = create_test_manager();
let config = ClientJwtConfig {
client_id: "test_client".to_string(),
public_key_jwk: create_test_jwk(),
allowed_algorithms: vec![Algorithm::RS256],
max_jwt_lifetime: Duration::minutes(5),
clock_skew: Duration::seconds(60),
expected_audiences: vec!["https://auth.example.com/token".to_string()],
};
manager.register_client(config.clone()).await.unwrap();
let assertion = manager
.create_client_assertion(
"test_client",
"https://auth.example.com/token",
b"test_key",
Algorithm::RS256,
)
.unwrap();
let validation_result = manager.perform_enhanced_jwt_validation(&assertion, &config);
match validation_result {
Ok(_) => println!("Enhanced JWT validation passed"),
Err(e) => println!("Enhanced JWT validation failed as expected: {}", e),
}
}
#[test]
fn test_cleanup_interval_configuration() {
let jwt_config = crate::security::secure_jwt::SecureJwtConfig::default();
let jwt_validator = SecureJwtValidator::new(jwt_config);
let manager =
PrivateKeyJwtManager::new(jwt_validator).with_cleanup_interval(Duration::minutes(30));
assert_eq!(manager.get_cleanup_interval(), Duration::minutes(30));
}
#[test]
fn test_cleanup_interval_update() {
let mut manager = create_test_manager();
assert_eq!(manager.get_cleanup_interval(), Duration::hours(1));
manager.update_cleanup_interval(Duration::minutes(15));
assert_eq!(manager.get_cleanup_interval(), Duration::minutes(15));
}
#[tokio::test]
async fn test_jwt_token_revocation_integration() {
let manager = create_test_manager();
let jti = "test_revoke_jti_456";
let is_revoked_before = manager.is_jwt_token_revoked(jti).unwrap_or(false);
assert!(!is_revoked_before);
manager.revoke_jwt_token(jti).unwrap();
let is_revoked_after = manager.is_jwt_token_revoked(jti).unwrap_or(false);
assert!(is_revoked_after);
}
#[tokio::test]
async fn test_scheduled_cleanup_integration() {
let mut manager = create_test_manager();
manager.update_cleanup_interval(Duration::minutes(1));
manager.record_jti("test_jti_1").await;
manager.revoke_jwt_token("revoked_jti_1").unwrap();
manager.schedule_automatic_cleanup().await;
assert_eq!(manager.get_cleanup_interval(), Duration::minutes(1));
}
#[tokio::test]
async fn test_cleanup_interval_used_in_cleanup_method() {
let mut manager = create_test_manager();
manager.update_cleanup_interval(Duration::minutes(30));
manager.record_jti("recent_jti").await;
manager.record_jti("old_jti").await;
{
let mut jtis = manager.used_jtis.write().await;
jtis.insert("recent_jti".to_string(), Utc::now() - Duration::minutes(15)); jtis.insert("old_jti".to_string(), Utc::now() - Duration::minutes(45)); }
manager.cleanup_expired_jtis().await;
let jtis = manager.used_jtis.read().await;
assert!(
jtis.contains_key("recent_jti"),
"Recent JTI should be retained"
);
assert!(!jtis.contains_key("old_jti"), "Old JTI should be removed");
}
}