use crate::errors::{AuthError, Result};
use crate::security::secure_jwt::SecureJwtValidator;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
use chrono::{DateTime, Duration, Utc};
use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Claim set carried by a `private_key_jwt` client assertion.
///
/// The same struct is used both when minting an assertion
/// (`create_client_assertion`) and when parsing an incoming one.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrivateKeyJwtClaims {
/// Issuer. Validation requires it to equal both `sub` and the registered client id.
pub iss: String,
/// Subject. Validation requires it to equal `iss`.
pub sub: String,
/// Audience as a single string; compared against the client's
/// `expected_audiences` when that list is non-empty.
pub aud: String,
/// Unique token identifier used for replay detection; must not be blank.
pub jti: String,
/// Expiry, compared against `Utc::now().timestamp()` (seconds).
pub exp: i64,
/// Issued-at, compared against `Utc::now().timestamp()` (seconds).
pub iat: i64,
/// Optional not-before, compared against `Utc::now().timestamp()` (seconds).
pub nbf: Option<i64>,
}
/// Per-client settings used to validate `private_key_jwt` assertions.
///
/// Build one with [`ClientJwtConfig::builder`] and hand it to
/// `PrivateKeyJwtManager::register_client`.
#[derive(Debug, Clone)]
pub struct ClientJwtConfig {
/// Client identifier; the key under which the config is registered and the
/// value the assertion's `iss` must match. Must not be blank.
pub client_id: String,
/// The client's key as a JWK JSON object; registration requires a `kty` member.
pub public_key_jwk: serde_json::Value,
/// Algorithms the assertion header's `alg` may use; must not be empty.
pub allowed_algorithms: Vec<Algorithm>,
/// Upper bound on `exp - iat` for an accepted assertion.
pub max_jwt_lifetime: Duration,
/// Tolerance applied to the `exp`, `nbf` and `iat` checks.
pub clock_skew: Duration,
/// Accepted `aud` values. When empty, the structural audience check is skipped.
pub expected_audiences: Vec<String>,
}
impl ClientJwtConfig {
    /// Starts a [`ClientJwtConfigBuilder`] for `client_id` and its JWK.
    ///
    /// The builder is pre-populated with the defaults: RS256 and ES256 allowed,
    /// a five-minute maximum assertion lifetime, sixty seconds of clock skew and
    /// no expected audiences.
    pub fn builder(
        client_id: impl Into<String>,
        public_key_jwk: serde_json::Value,
    ) -> ClientJwtConfigBuilder {
        let client_id = client_id.into();
        let allowed_algorithms = vec![Algorithm::RS256, Algorithm::ES256];
        let max_jwt_lifetime = Duration::minutes(5);
        let clock_skew = Duration::seconds(60);
        let expected_audiences = Vec::new();

        ClientJwtConfigBuilder {
            client_id,
            public_key_jwk,
            allowed_algorithms,
            max_jwt_lifetime,
            clock_skew,
            expected_audiences,
        }
    }
}
/// Builder for [`ClientJwtConfig`]; obtained from `ClientJwtConfig::builder`.
///
/// Each field mirrors the field of the same name on [`ClientJwtConfig`] and is
/// copied over unchanged by `build`.
pub struct ClientJwtConfigBuilder {
// Client identifier supplied to `builder`.
client_id: String,
// Client key as a JWK JSON object supplied to `builder`.
public_key_jwk: serde_json::Value,
// Algorithms accepted for this client.
allowed_algorithms: Vec<Algorithm>,
// Maximum accepted `exp - iat`.
max_jwt_lifetime: Duration,
// Tolerance for time-based checks.
clock_skew: Duration,
// Accepted `aud` values.
expected_audiences: Vec<String>,
}
impl ClientJwtConfigBuilder {
    /// Restricts the client to RS256 only.
    pub fn rs256_only(self) -> Self {
        Self {
            allowed_algorithms: vec![Algorithm::RS256],
            ..self
        }
    }

    /// Replaces the allowed-algorithm list.
    pub fn algorithms(self, algorithms: Vec<Algorithm>) -> Self {
        Self {
            allowed_algorithms: algorithms,
            ..self
        }
    }

    /// Sets the maximum accepted assertion lifetime (`exp - iat`).
    pub fn max_jwt_lifetime(self, max_jwt_lifetime: Duration) -> Self {
        Self {
            max_jwt_lifetime,
            ..self
        }
    }

    /// Sets the tolerance applied to time-based checks.
    pub fn clock_skew(self, clock_skew: Duration) -> Self {
        Self { clock_skew, ..self }
    }

    /// Appends one accepted audience, keeping those already configured.
    pub fn audience(self, audience: impl Into<String>) -> Self {
        let mut builder = self;
        builder.expected_audiences.push(audience.into());
        builder
    }

    /// Replaces the accepted audiences with `audiences`.
    pub fn audiences<I, S>(self, audiences: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        let expected_audiences = audiences.into_iter().map(|aud| aud.into()).collect();
        Self {
            expected_audiences,
            ..self
        }
    }

    /// Finishes the builder, moving every field into a [`ClientJwtConfig`].
    pub fn build(self) -> ClientJwtConfig {
        let Self {
            client_id,
            public_key_jwk,
            allowed_algorithms,
            max_jwt_lifetime,
            clock_skew,
            expected_audiences,
        } = self;

        ClientJwtConfig {
            client_id,
            public_key_jwk,
            allowed_algorithms,
            max_jwt_lifetime,
            clock_skew,
            expected_audiences,
        }
    }
}
/// Outcome of `PrivateKeyJwtManager::authenticate_client`.
///
/// A failed validation is reported through `authenticated == false` plus
/// `errors`, not through an `Err` return.
#[derive(Debug, Clone)]
pub struct JwtAuthResult {
/// The assertion's `iss` claim, which is used to look up the client config.
pub client_id: String,
/// `true` only when no validation step reported an error.
pub authenticated: bool,
/// The parsed claims; populated only when `authenticated` is `true`.
pub claims: Option<PrivateKeyJwtClaims>,
/// Human-readable messages, one per failed validation step.
pub errors: Vec<String>,
/// The assertion's `jti` claim.
pub jti: Option<String>,
}
/// Registers clients and authenticates their `private_key_jwt` assertions.
///
/// Shared state lives behind `tokio::sync::RwLock`s, so registration and
/// authentication take `&self`.
#[derive(Debug)]
pub struct PrivateKeyJwtManager {
// Registered client configurations, keyed by client id.
client_configs: tokio::sync::RwLock<HashMap<String, ClientJwtConfig>>,
// JTIs of accepted assertions mapped to the time they were recorded.
used_jtis: tokio::sync::RwLock<HashMap<String, DateTime<Utc>>>,
// Additional token validator; also backs the revoke / is-revoked helpers.
jwt_validator: SecureJwtValidator,
// Age beyond which recorded JTIs are dropped by `cleanup_expired_jtis`.
cleanup_interval: Duration,
}
impl PrivateKeyJwtManager {
pub fn new(jwt_validator: SecureJwtValidator) -> Self {
Self {
client_configs: tokio::sync::RwLock::new(HashMap::new()),
used_jtis: tokio::sync::RwLock::new(HashMap::new()),
jwt_validator,
cleanup_interval: Duration::hours(1),
}
}
pub async fn register_client(&self, config: ClientJwtConfig) -> Result<()> {
self.validate_client_config(&config)?;
let mut configs = self.client_configs.write().await;
configs.insert(config.client_id.clone(), config);
Ok(())
}
pub async fn authenticate_client(&self, client_assertion: &str) -> Result<JwtAuthResult> {
let header = self.parse_jwt_header(client_assertion)?;
let claims = self.extract_claims_unverified(client_assertion)?;
let client_id = &claims.iss;
let configs = self.client_configs.read().await;
let config = configs.get(client_id).ok_or_else(|| {
AuthError::auth_method(
"private_key_jwt",
"Client not registered for JWT authentication",
)
})?;
let mut errors = Vec::new();
self.validate_jwt_structure(&header, &claims, config, &mut errors);
if let Err(e) = self.verify_jwt_signature(client_assertion, config) {
errors.push(format!("Signature verification failed: {}", e));
}
if let Err(e) = self.perform_enhanced_jwt_validation(client_assertion, config) {
errors.push(format!("Enhanced security validation failed: {}", e));
}
if let Err(e) = self.check_jti_replay(&claims.jti).await {
errors.push(format!("JTI replay detected: {}", e));
}
self.validate_jwt_timing(&claims, config, &mut errors);
let authenticated = errors.is_empty();
if authenticated {
self.record_jti(&claims.jti).await;
}
let jti = claims.jti.clone();
Ok(JwtAuthResult {
client_id: client_id.clone(),
authenticated,
claims: if authenticated { Some(claims) } else { None },
errors,
jti: Some(jti),
})
}
pub fn create_client_assertion(
&self,
client_id: &str,
audience: &str,
signing_key: &[u8],
algorithm: Algorithm,
) -> Result<String> {
let now = Utc::now();
let claims = PrivateKeyJwtClaims {
iss: client_id.to_string(),
sub: client_id.to_string(),
aud: audience.to_string(),
jti: uuid::Uuid::new_v4().to_string(),
exp: (now + Duration::minutes(5)).timestamp(),
iat: now.timestamp(),
nbf: Some(now.timestamp()),
};
let encoding_key = match algorithm {
Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
EncodingKey::from_secret(signing_key)
}
Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
EncodingKey::from_rsa_pem(signing_key).map_err(|e| {
AuthError::auth_method(
"private_key_jwt",
format!("Invalid RSA PEM key for {:?}: {}", algorithm, e),
)
})?
}
Algorithm::ES256 | Algorithm::ES384 => {
EncodingKey::from_ec_pem(signing_key).map_err(|e| {
AuthError::auth_method(
"private_key_jwt",
format!("Invalid EC PEM key for {:?}: {}", algorithm, e),
)
})?
}
_ => {
return Err(AuthError::auth_method(
"private_key_jwt",
format!("Unsupported signing algorithm: {:?}", algorithm),
));
}
};
let header = Header::new(algorithm);
encode(&header, &claims, &encoding_key).map_err(|e| {
AuthError::auth_method("private_key_jwt", format!("Failed to encode JWT: {}", e))
})
}
pub async fn cleanup_expired_jtis(&self) {
let mut jtis = self.used_jtis.write().await;
let cutoff = Utc::now() - self.cleanup_interval;
jtis.retain(|_, timestamp| *timestamp > cutoff);
}
fn perform_enhanced_jwt_validation(&self, jwt: &str, config: &ClientJwtConfig) -> Result<()> {
let decoding_key = self.jwk_to_decoding_key(&config.public_key_jwk)?;
match self.jwt_validator.validate_token(jwt, &decoding_key) {
Ok(_secure_claims) => {
Ok(())
}
Err(e) => {
Err(AuthError::auth_method(
"private_key_jwt",
format!("Enhanced JWT validation failed: {}", e),
))
}
}
}
pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
self.cleanup_interval = interval;
self
}
pub fn get_cleanup_interval(&self) -> Duration {
self.cleanup_interval
}
pub fn update_cleanup_interval(&mut self, interval: Duration) {
self.cleanup_interval = interval;
}
pub fn revoke_jwt_token(&self, jti: &str) -> Result<()> {
self.jwt_validator.revoke_token(jti)
}
pub fn is_jwt_token_revoked(&self, jti: &str) -> Result<bool> {
self.jwt_validator.is_token_revoked(jti)
}
pub async fn schedule_automatic_cleanup(&self) {
self.cleanup_expired_jtis().await;
let expired_cutoff = std::time::SystemTime::now()
.checked_sub(self.cleanup_interval.to_std().unwrap_or_default())
.unwrap_or_else(std::time::SystemTime::now);
let _ = self.jwt_validator.cleanup_revoked_tokens(expired_cutoff);
}
fn parse_jwt_header(&self, jwt: &str) -> Result<Header> {
jsonwebtoken::decode_header(jwt).map_err(|e| {
AuthError::auth_method("private_key_jwt", format!("Invalid JWT header: {}", e))
})
}
fn extract_claims_unverified(&self, jwt: &str) -> Result<PrivateKeyJwtClaims> {
let parts: Vec<&str> = jwt.split('.').collect();
if parts.len() != 3 {
return Err(AuthError::auth_method(
"private_key_jwt",
"Invalid JWT format",
));
}
let claims_bytes = URL_SAFE_NO_PAD.decode(parts[1]).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid JWT claims encoding")
})?;
let claims: PrivateKeyJwtClaims = serde_json::from_slice(&claims_bytes)
.map_err(|_| AuthError::auth_method("private_key_jwt", "Invalid JWT claims format"))?;
Ok(claims)
}
fn validate_jwt_structure(
&self,
header: &Header,
claims: &PrivateKeyJwtClaims,
config: &ClientJwtConfig,
errors: &mut Vec<String>,
) {
if !config.allowed_algorithms.contains(&header.alg) {
errors.push(format!("Algorithm {:?} not allowed", header.alg));
}
if claims.iss != claims.sub {
errors.push("Issuer must equal subject".to_string());
}
if claims.iss != config.client_id {
errors.push("Issuer must equal client_id".to_string());
}
if config.expected_audiences.is_empty() {
} else if !config.expected_audiences.contains(&claims.aud) {
errors.push(format!("Audience '{}' not allowed", claims.aud));
}
if claims.jti.trim().is_empty() {
errors.push("JTI (JWT ID) is required".to_string());
}
}
fn verify_jwt_signature(&self, jwt: &str, config: &ClientJwtConfig) -> Result<()> {
let decoding_key = self.jwk_to_decoding_key(&config.public_key_jwk)?;
let mut validation = Validation::new(config.allowed_algorithms[0]);
validation.set_audience(&[&config.client_id]);
validation.set_issuer(&[&config.client_id]);
validation.leeway = config.clock_skew.num_seconds() as u64;
let _token_data =
decode::<PrivateKeyJwtClaims>(jwt, &decoding_key, &validation).map_err(|e| {
AuthError::auth_method("private_key_jwt", format!("JWT verification failed: {}", e))
})?;
Ok(())
}
fn jwk_to_decoding_key(&self, jwk: &serde_json::Value) -> Result<DecodingKey> {
let kty = jwk
.get("kty")
.and_then(|v| v.as_str())
.ok_or_else(|| AuthError::auth_method("private_key_jwt", "Missing 'kty' in JWK"))?;
match kty {
"RSA" => {
let n = jwk.get("n").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'n' in RSA JWK")
})?;
let e = jwk.get("e").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'e' in RSA JWK")
})?;
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
URL_SAFE_NO_PAD.decode(n.as_bytes()).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid base64url 'n' parameter")
})?;
URL_SAFE_NO_PAD.decode(e.as_bytes()).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid base64url 'e' parameter")
})?;
let key_material = format!("rsa_private_key_jwt_n:{}_e:{}", n, e);
Ok(DecodingKey::from_secret(key_material.as_bytes()))
}
"EC" => {
let crv = jwk.get("crv").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'crv' in EC JWK")
})?;
let x = jwk.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'x' in EC JWK")
})?;
let y = jwk.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
AuthError::auth_method("private_key_jwt", "Missing 'y' in EC JWK")
})?;
match crv {
"P-256" | "P-384" | "P-521" => {}
_ => {
return Err(AuthError::auth_method(
"private_key_jwt",
format!("Unsupported EC curve: {}", crv),
));
}
}
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
URL_SAFE_NO_PAD.decode(x.as_bytes()).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid base64url 'x' parameter")
})?;
URL_SAFE_NO_PAD.decode(y.as_bytes()).map_err(|_| {
AuthError::auth_method("private_key_jwt", "Invalid base64url 'y' parameter")
})?;
let key_material = format!("ec_private_key_jwt_crv:{}_x:{}_y:{}", crv, x, y);
Ok(DecodingKey::from_secret(key_material.as_bytes()))
}
_ => Err(AuthError::auth_method(
"private_key_jwt",
format!("Unsupported key type: {}", kty),
)),
}
}
async fn check_jti_replay(&self, jti: &str) -> Result<()> {
let jtis = self.used_jtis.read().await;
if jtis.contains_key(jti) {
return Err(AuthError::auth_method(
"private_key_jwt",
"JTI already used",
));
}
Ok(())
}
async fn record_jti(&self, jti: &str) {
let mut jtis = self.used_jtis.write().await;
jtis.insert(jti.to_string(), Utc::now());
}
fn validate_jwt_timing(
&self,
claims: &PrivateKeyJwtClaims,
config: &ClientJwtConfig,
errors: &mut Vec<String>,
) {
let now = Utc::now().timestamp();
let skew = config.clock_skew.num_seconds();
if claims.exp <= now - skew {
errors.push("JWT has expired".to_string());
}
if let Some(nbf) = claims.nbf
&& nbf > now + skew
{
errors.push("JWT not yet valid".to_string());
}
if claims.iat > now + skew {
errors.push("JWT issued in the future".to_string());
}
let lifetime = claims.exp - claims.iat;
if lifetime > config.max_jwt_lifetime.num_seconds() {
errors.push(format!(
"JWT lifetime {} exceeds maximum {}",
lifetime,
config.max_jwt_lifetime.num_seconds()
));
}
}
fn validate_client_config(&self, config: &ClientJwtConfig) -> Result<()> {
if config.client_id.trim().is_empty() {
return Err(AuthError::auth_method(
"private_key_jwt",
"Client ID cannot be empty",
));
}
if config.allowed_algorithms.is_empty() {
return Err(AuthError::auth_method(
"private_key_jwt",
"At least one algorithm must be allowed",
));
}
if config.public_key_jwk.get("kty").is_none() {
return Err(AuthError::auth_method(
"private_key_jwt",
"JWK missing 'kty' field",
));
}
Ok(())
}
}
impl Default for ClientJwtConfig {
    /// An unregistered placeholder: empty client id and empty JWK object, with
    /// the same algorithm, lifetime, skew and audience defaults as
    /// [`ClientJwtConfig::builder`].
    fn default() -> Self {
        let empty_jwk = serde_json::json!({});
        Self::builder(String::new(), empty_jwk).build()
    }
}
// Unit tests for client registration, assertion creation, JTI replay
// tracking, timing validation and cleanup-interval handling.
#[cfg(test)]
mod tests {
use super::*;
// Builds a manager backed by a validator created from the default secure JWT config.
fn create_test_manager() -> PrivateKeyJwtManager {
let jwt_config = crate::security::secure_jwt::SecureJwtConfig::default();
let jwt_validator = SecureJwtValidator::new(jwt_config).expect("test JWT config");
PrivateKeyJwtManager::new(jwt_validator)
}
// RSA JWK fixture. Besides the public `n`/`e` members it also carries
// private-key members (`d`, `p`, `q`, `dp`, `dq`, `qi`).
fn create_test_jwk() -> serde_json::Value {
serde_json::json!({
"kty": "RSA",
"use": "sig",
"alg": "RS256",
"n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbIS",
"e": "AQAB",
"d": "X4cTteJY_gn4FYPsXB8rdXix5vwsg1FLN5E3EaG6RJoVH-HLLKD9M7dx5oo7GURknchnrRweUkC7hT5fJLM0WbFAKNLWYRuJXPvGHJOPDFY7gOLcMOZrAeBOBP1f_vtAFxLW87-dKKGS",
"p": "83i-7IvMGXoMXCskv73TKr8637FiO7Z27zv8oj6pbWUQyLPBQxtgn5SQY3rJJOILeFGqUIo8uTmTf3DqL7vBfOTPrx4f",
"q": "3dfOR9cuYq-0S-mkFLzgItgMEfFzB2q3hWehMuG0oCuqnb3vobLyumqjVZQO1dIrdwgTnCdpYzBcOfW5r370AFXjiWft_NGEiovonizhKpo9VVS78TzFgxkIdrecRezsZ-1kYd_s1qDbxtkDEgfAITAG9LUnADun4vIcb6yelxk",
"dp": "G4sPXkc6Ya9y_oJF_l-AC",
"dq": "s9lAH9fggBsoFR8Oac2R_EML",
"qi": "MuFzpZhTKgfg8Ig2VgOKe-kSJSzRd_2"
})
}
// A well-formed RS256-only config registers without error.
#[tokio::test]
async fn test_client_registration() {
let manager = create_test_manager();
let config = ClientJwtConfig::builder("test_client", create_test_jwk())
.rs256_only()
.audience("https://auth.example.com/token")
.build();
manager.register_client(config).await.unwrap();
}
// An HS256 assertion signed with a raw secret has three non-empty
// dot-separated segments.
#[test]
fn test_create_client_assertion() {
let manager = create_test_manager();
let assertion = manager
.create_client_assertion(
"test_client",
"https://auth.example.com/token",
b"super-secret-key-for-testing-purposes",
Algorithm::HS256,
)
.unwrap();
let parts: Vec<&str> = assertion.split('.').collect();
assert_eq!(parts.len(), 3);
assert!(!parts[0].is_empty());
assert!(!parts[1].is_empty());
assert!(!parts[2].is_empty());
}
// RS256 signing must fail when the key bytes are not a PEM document.
#[test]
fn test_create_client_assertion_rs256_requires_pem_key() {
let manager = create_test_manager();
let result = manager.create_client_assertion(
"test_client",
"https://auth.example.com/token",
b"not_a_pem_key",
Algorithm::RS256,
);
assert!(result.is_err(), "RS256 must reject non-PEM key bytes");
}
// A JTI passes the replay check until it has been recorded, then fails it.
#[tokio::test]
async fn test_jti_replay_protection() {
let manager = create_test_manager();
let jti = "test_jti_123";
assert!(manager.check_jti_replay(jti).await.is_ok());
manager.record_jti(jti).await;
assert!(manager.check_jti_replay(jti).await.is_err());
}
// Claims that expired an hour ago produce an "expired" error under the
// default config.
#[test]
fn test_jwt_timing_validation() {
let manager = create_test_manager();
let config = ClientJwtConfig::default();
let mut errors = Vec::new();
let now = Utc::now().timestamp();
let expired_claims = PrivateKeyJwtClaims {
iss: "test".to_string(),
sub: "test".to_string(),
aud: "test".to_string(),
jti: "test".to_string(),
exp: now - 3600, iat: now - 3660,
nbf: Some(now - 3660),
};
manager.validate_jwt_timing(&expired_claims, &config, &mut errors);
assert!(!errors.is_empty());
assert!(errors.iter().any(|e| e.contains("expired")));
}
// With the default one-hour interval, a JTI back-dated by two days is
// purged while a freshly recorded one is kept.
#[tokio::test]
async fn test_cleanup_expired_jtis() {
let manager = create_test_manager();
manager.record_jti("old_jti").await;
manager.record_jti("new_jti").await;
{
// Back-date the entry directly in the map to simulate an old recording.
let mut jtis = manager.used_jtis.write().await;
jtis.insert("old_jti".to_string(), Utc::now() - Duration::days(2));
}
manager.cleanup_expired_jtis().await;
let jtis = manager.used_jtis.read().await;
assert!(!jtis.contains_key("old_jti"));
assert!(jtis.contains_key("new_jti"));
}
// Exercises the enhanced validation path end to end.
// NOTE(review): the outcome is only printed, never asserted, so this test
// cannot fail on a wrong result — only on a panic.
#[tokio::test]
async fn test_enhanced_jwt_validation_integration() {
let manager = create_test_manager();
let config = ClientJwtConfig::builder("test_client", create_test_jwk())
.rs256_only()
.audience("https://auth.example.com/token")
.build();
manager.register_client(config.clone()).await.unwrap();
let assertion = manager
.create_client_assertion(
"test_client",
"https://auth.example.com/token",
b"super-secret-key-for-testing-purposes",
Algorithm::HS256,
)
.unwrap();
let validation_result = manager.perform_enhanced_jwt_validation(&assertion, &config);
match validation_result {
Ok(_) => println!("Enhanced JWT validation passed"),
Err(e) => println!("Enhanced JWT validation failed as expected: {}", e),
}
}
// Every builder setter ends up in the built config.
#[test]
fn test_client_jwt_config_builder() {
let config = ClientJwtConfig::builder("builder_client", create_test_jwk())
.algorithms(vec![Algorithm::RS256])
.max_jwt_lifetime(Duration::minutes(10))
.clock_skew(Duration::seconds(30))
.audiences([
"https://auth.example.com/token",
"https://auth.example.com/par",
])
.build();
assert_eq!(config.client_id, "builder_client");
assert_eq!(config.allowed_algorithms, vec![Algorithm::RS256]);
assert_eq!(config.max_jwt_lifetime, Duration::minutes(10));
assert_eq!(config.clock_skew, Duration::seconds(30));
assert_eq!(config.expected_audiences.len(), 2);
}
// `with_cleanup_interval` overrides the default interval at construction.
#[test]
fn test_cleanup_interval_configuration() {
let jwt_config = crate::security::secure_jwt::SecureJwtConfig::default();
let jwt_validator = SecureJwtValidator::new(jwt_config).expect("test JWT config");
let manager =
PrivateKeyJwtManager::new(jwt_validator).with_cleanup_interval(Duration::minutes(30));
assert_eq!(manager.get_cleanup_interval(), Duration::minutes(30));
}
// The default interval is one hour and can be changed after construction.
#[test]
fn test_cleanup_interval_update() {
let mut manager = create_test_manager();
assert_eq!(manager.get_cleanup_interval(), Duration::hours(1));
manager.update_cleanup_interval(Duration::minutes(15));
assert_eq!(manager.get_cleanup_interval(), Duration::minutes(15));
}
// A JTI is reported as revoked only after `revoke_jwt_token` is called.
// A lookup error is treated as "not revoked" via `unwrap_or(false)`.
#[tokio::test]
async fn test_jwt_token_revocation_integration() {
let manager = create_test_manager();
let jti = "test_revoke_jti_456";
let is_revoked_before = manager.is_jwt_token_revoked(jti).unwrap_or(false);
assert!(!is_revoked_before);
manager.revoke_jwt_token(jti).unwrap();
let is_revoked_after = manager.is_jwt_token_revoked(jti).unwrap_or(false);
assert!(is_revoked_after);
}
// Runs a full cleanup pass with recorded and revoked JTIs present.
// NOTE(review): only the interval is asserted afterwards; the effect of the
// cleanup on either store is not checked.
#[tokio::test]
async fn test_scheduled_cleanup_integration() {
let mut manager = create_test_manager();
manager.update_cleanup_interval(Duration::minutes(1));
manager.record_jti("test_jti_1").await;
manager.revoke_jwt_token("revoked_jti_1").unwrap();
manager.schedule_automatic_cleanup().await;
assert_eq!(manager.get_cleanup_interval(), Duration::minutes(1));
}
// With a 30-minute interval, a JTI aged 15 minutes survives cleanup and one
// aged 45 minutes is removed.
#[tokio::test]
async fn test_cleanup_interval_used_in_cleanup_method() {
let mut manager = create_test_manager();
manager.update_cleanup_interval(Duration::minutes(30));
manager.record_jti("recent_jti").await;
manager.record_jti("old_jti").await;
{
// Back-date both entries directly in the map.
let mut jtis = manager.used_jtis.write().await;
jtis.insert("recent_jti".to_string(), Utc::now() - Duration::minutes(15)); jtis.insert("old_jti".to_string(), Utc::now() - Duration::minutes(45)); }
manager.cleanup_expired_jtis().await;
let jtis = manager.used_jtis.read().await;
assert!(
jtis.contains_key("recent_jti"),
"Recent JTI should be retained"
);
assert!(!jtis.contains_key("old_jti"), "Old JTI should be removed");
}
}