Skip to main content

auth_framework/server/jwt/
private_key_jwt.rs

1//! RFC 7521: JSON Web Token (JWT) Profile for OAuth 2.0 Client Authentication and Authorization Grants
2//!
3//! This module implements private key JWT client authentication, allowing clients
4//! to authenticate using JWTs signed with their private keys.
5//!
6//! ## Enhanced Security Features
7//!
8//! - **SecureJwtValidator Integration**: Uses comprehensive JWT validation with
9//!   enhanced security checks beyond basic signature verification
10//! - **Configurable JTI Cleanup**: Customizable cleanup intervals for managing
11//!   used JWT IDs and preventing replay attacks
12//! - **Advanced Token Management**: Token revocation and validation using the
13//!   enhanced security framework
14//! - **Automatic Cleanup Scheduling**: Integrated cleanup of expired JTIs and
15//!   revoked tokens
16//!
17//! ## Usage Example
18//!
19//! ```rust,no_run
20//! use auth_framework::server::jwt::private_key_jwt::{PrivateKeyJwtManager, ClientJwtConfig};
21//! use auth_framework::{SecureJwtValidator, SecureJwtConfig};
22//! use chrono::Duration;
23//! use jsonwebtoken::Algorithm;
24//!
25//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
26//! // `Default` generates a fresh random secret per instance.
27//! // Pass an explicit `jwt_secret` when running multiple nodes that share a key.
28//! let jwt_config = SecureJwtConfig {
29//!     jwt_secret: std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"),
30//!     ..SecureJwtConfig::default()
31//! };
32//! let jwt_validator = SecureJwtValidator::new(jwt_config)?;
33//!
34//! // Create manager with custom cleanup interval
35//! let manager = PrivateKeyJwtManager::new(jwt_validator)
36//!     .with_cleanup_interval(Duration::minutes(30));
37//!
38//! // Configure client for JWT authentication
39//! let config = ClientJwtConfig::builder(
40//!     "example_client",
41//!     serde_json::json!({"kty": "RSA", "n": "...", "e": "AQAB"}),
42//! )
43//! .rs256_only()
44//! .audience("https://api.example.com")
45//! .build();
46//!
47//! manager.register_client(config).await?;
48//!
49//! // Authenticate client with JWT assertion
50//! let client_assertion = "eyJ..."; // JWT assertion from client
51//! let auth_result = manager.authenticate_client(client_assertion).await?;
52//!
53//! if auth_result.authenticated {
54//!     println!("Client authenticated successfully");
55//!     // Process authenticated client...
56//! }
57//!
58//! // Perform scheduled cleanup
59//! manager.schedule_automatic_cleanup().await;
60//! # Ok(())
61//! # }
62//! ```
63
64use crate::errors::{AuthError, Result};
65use crate::security::secure_jwt::SecureJwtValidator;
66use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
67use chrono::{DateTime, Duration, Utc};
68use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
69use serde::{Deserialize, Serialize};
70use std::collections::HashMap;
71
72/// Private Key JWT claims for client authentication
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct PrivateKeyJwtClaims {
75    /// Issuer - must equal the client_id
76    pub iss: String,
77
78    /// Subject - must equal the client_id
79    pub sub: String,
80
81    /// Audience - authorization server token endpoint
82    pub aud: String,
83
84    /// JWT ID for replay protection
85    pub jti: String,
86
87    /// Expiration time
88    pub exp: i64,
89
90    /// Issued at time
91    pub iat: i64,
92
93    /// Not before time (optional)
94    pub nbf: Option<i64>,
95}
96
97/// Client JWT configuration for private key authentication
98#[derive(Debug, Clone)]
99pub struct ClientJwtConfig {
100    /// Client identifier
101    pub client_id: String,
102
103    /// Public key for JWT verification (JWK format)
104    pub public_key_jwk: serde_json::Value,
105
106    /// Allowed signing algorithms
107    pub allowed_algorithms: Vec<Algorithm>,
108
109    /// Maximum JWT lifetime (default: 5 minutes)
110    pub max_jwt_lifetime: Duration,
111
112    /// Clock skew tolerance (default: 60 seconds)
113    pub clock_skew: Duration,
114
115    /// Expected audience values (token endpoints)
116    pub expected_audiences: Vec<String>,
117}
118
119impl ClientJwtConfig {
120    /// Create a builder for private key JWT client configuration.
121    pub fn builder(
122        client_id: impl Into<String>,
123        public_key_jwk: serde_json::Value,
124    ) -> ClientJwtConfigBuilder {
125        ClientJwtConfigBuilder {
126            client_id: client_id.into(),
127            public_key_jwk,
128            allowed_algorithms: vec![Algorithm::RS256, Algorithm::ES256],
129            max_jwt_lifetime: Duration::minutes(5),
130            clock_skew: Duration::seconds(60),
131            expected_audiences: Vec::new(),
132        }
133    }
134}
135
136/// Builder for private key JWT client configuration.
137pub struct ClientJwtConfigBuilder {
138    client_id: String,
139    public_key_jwk: serde_json::Value,
140    allowed_algorithms: Vec<Algorithm>,
141    max_jwt_lifetime: Duration,
142    clock_skew: Duration,
143    expected_audiences: Vec<String>,
144}
145
146impl ClientJwtConfigBuilder {
147    /// Restrict signing to RS256 only.
148    pub fn rs256_only(mut self) -> Self {
149        self.allowed_algorithms = vec![Algorithm::RS256];
150        self
151    }
152
153    /// Replace the allowed algorithm list.
154    pub fn algorithms(mut self, algorithms: Vec<Algorithm>) -> Self {
155        self.allowed_algorithms = algorithms;
156        self
157    }
158
159    /// Set the maximum JWT lifetime.
160    pub fn max_jwt_lifetime(mut self, max_jwt_lifetime: Duration) -> Self {
161        self.max_jwt_lifetime = max_jwt_lifetime;
162        self
163    }
164
165    /// Set the clock skew tolerance.
166    pub fn clock_skew(mut self, clock_skew: Duration) -> Self {
167        self.clock_skew = clock_skew;
168        self
169    }
170
171    /// Add one expected audience.
172    pub fn audience(mut self, audience: impl Into<String>) -> Self {
173        self.expected_audiences.push(audience.into());
174        self
175    }
176
177    /// Replace the expected audience list.
178    pub fn audiences<I, S>(mut self, audiences: I) -> Self
179    where
180        I: IntoIterator<Item = S>,
181        S: Into<String>,
182    {
183        self.expected_audiences = audiences.into_iter().map(Into::into).collect();
184        self
185    }
186
187    /// Build the client JWT configuration.
188    pub fn build(self) -> ClientJwtConfig {
189        ClientJwtConfig {
190            client_id: self.client_id,
191            public_key_jwk: self.public_key_jwk,
192            allowed_algorithms: self.allowed_algorithms,
193            max_jwt_lifetime: self.max_jwt_lifetime,
194            clock_skew: self.clock_skew,
195            expected_audiences: self.expected_audiences,
196        }
197    }
198}
199
200/// JWT authentication result
201#[derive(Debug, Clone)]
202pub struct JwtAuthResult {
203    /// Client identifier
204    pub client_id: String,
205
206    /// Whether authentication was successful
207    pub authenticated: bool,
208
209    /// JWT claims if valid
210    pub claims: Option<PrivateKeyJwtClaims>,
211
212    /// Validation errors
213    pub errors: Vec<String>,
214
215    /// JWT ID for tracking
216    pub jti: Option<String>,
217}
218
219/// Private Key JWT Manager
220#[derive(Debug)]
221pub struct PrivateKeyJwtManager {
222    /// Client configurations indexed by client_id
223    client_configs: tokio::sync::RwLock<HashMap<String, ClientJwtConfig>>,
224
225    /// Used JTIs for replay protection
226    used_jtis: tokio::sync::RwLock<HashMap<String, DateTime<Utc>>>,
227
228    /// JWT validator for additional validation
229    jwt_validator: SecureJwtValidator,
230
231    /// JTI cleanup interval
232    cleanup_interval: Duration,
233}
234
235impl PrivateKeyJwtManager {
236    /// Create a new Private Key JWT Manager
237    pub fn new(jwt_validator: SecureJwtValidator) -> Self {
238        Self {
239            client_configs: tokio::sync::RwLock::new(HashMap::new()),
240            used_jtis: tokio::sync::RwLock::new(HashMap::new()),
241            jwt_validator,
242            cleanup_interval: Duration::hours(1),
243        }
244    }
245
246    /// Register a client for private key JWT authentication
247    pub async fn register_client(&self, config: ClientJwtConfig) -> Result<()> {
248        self.validate_client_config(&config)?;
249
250        let mut configs = self.client_configs.write().await;
251        configs.insert(config.client_id.clone(), config);
252
253        Ok(())
254    }
255
256    /// Authenticate a client using private key JWT
257    pub async fn authenticate_client(&self, client_assertion: &str) -> Result<JwtAuthResult> {
258        // Parse JWT header to get client info
259        let header = self.parse_jwt_header(client_assertion)?;
260
261        // Extract client_id from JWT claims (without verification yet)
262        let claims = self.extract_claims_unverified(client_assertion)?;
263        let client_id = &claims.iss;
264
265        // Get client configuration
266        let configs = self.client_configs.read().await;
267        let config = configs.get(client_id).ok_or_else(|| {
268            AuthError::auth_method(
269                "private_key_jwt",
270                "Client not registered for JWT authentication",
271            )
272        })?;
273
274        // Validate JWT
275        let mut errors = Vec::new();
276
277        // Basic structure validation
278        self.validate_jwt_structure(&header, &claims, config, &mut errors);
279
280        // Verify signature
281        if let Err(e) = self.verify_jwt_signature(client_assertion, config) {
282            errors.push(format!("Signature verification failed: {}", e));
283        }
284
285        // Additional security validation using SecureJwtValidator
286        if let Err(e) = self.perform_enhanced_jwt_validation(client_assertion, config) {
287            errors.push(format!("Enhanced security validation failed: {}", e));
288        }
289
290        // Check for replay (JTI reuse)
291        if let Err(e) = self.check_jti_replay(&claims.jti).await {
292            errors.push(format!("JTI replay detected: {}", e));
293        }
294
295        // Validate timing
296        self.validate_jwt_timing(&claims, config, &mut errors);
297
298        // Record JTI if valid
299        let authenticated = errors.is_empty();
300        if authenticated {
301            self.record_jti(&claims.jti).await;
302        }
303
304        let jti = claims.jti.clone();
305        Ok(JwtAuthResult {
306            client_id: client_id.clone(),
307            authenticated,
308            claims: if authenticated { Some(claims) } else { None },
309            errors,
310            jti: Some(jti),
311        })
312    }
313
314    /// Create a client assertion JWT (for testing/client-side use).
315    ///
316    /// # Key encoding by algorithm
317    /// - **HS256/384/512**: `signing_key` is the raw HMAC secret bytes.
318    /// - **RS256/384/512**: `signing_key` must be a PEM-encoded RSA private key (`-----BEGIN RSA PRIVATE KEY-----...`).
319    /// - **ES256/384**: `signing_key` must be a PEM-encoded EC private key (`-----BEGIN EC PRIVATE KEY-----...`).
320    pub fn create_client_assertion(
321        &self,
322        client_id: &str,
323        audience: &str,
324        signing_key: &[u8],
325        algorithm: Algorithm,
326    ) -> Result<String> {
327        let now = Utc::now();
328        let claims = PrivateKeyJwtClaims {
329            iss: client_id.to_string(),
330            sub: client_id.to_string(),
331            aud: audience.to_string(),
332            jti: uuid::Uuid::new_v4().to_string(),
333            exp: (now + Duration::minutes(5)).timestamp(),
334            iat: now.timestamp(),
335            nbf: Some(now.timestamp()),
336        };
337
338        let encoding_key = match algorithm {
339            Algorithm::HS256 | Algorithm::HS384 | Algorithm::HS512 => {
340                EncodingKey::from_secret(signing_key)
341            }
342            Algorithm::RS256 | Algorithm::RS384 | Algorithm::RS512 => {
343                EncodingKey::from_rsa_pem(signing_key).map_err(|e| {
344                    AuthError::auth_method(
345                        "private_key_jwt",
346                        format!("Invalid RSA PEM key for {:?}: {}", algorithm, e),
347                    )
348                })?
349            }
350            Algorithm::ES256 | Algorithm::ES384 => {
351                EncodingKey::from_ec_pem(signing_key).map_err(|e| {
352                    AuthError::auth_method(
353                        "private_key_jwt",
354                        format!("Invalid EC PEM key for {:?}: {}", algorithm, e),
355                    )
356                })?
357            }
358            _ => {
359                return Err(AuthError::auth_method(
360                    "private_key_jwt",
361                    format!("Unsupported signing algorithm: {:?}", algorithm),
362                ));
363            }
364        };
365
366        let header = Header::new(algorithm);
367        encode(&header, &claims, &encoding_key).map_err(|e| {
368            AuthError::auth_method("private_key_jwt", format!("Failed to encode JWT: {}", e))
369        })
370    }
371
372    /// Clean up expired JTIs
373    pub async fn cleanup_expired_jtis(&self) {
374        let mut jtis = self.used_jtis.write().await;
375        let cutoff = Utc::now() - self.cleanup_interval; // Use configurable cleanup interval
376
377        jtis.retain(|_, timestamp| *timestamp > cutoff);
378    }
379
380    /// Perform enhanced JWT validation using SecureJwtValidator
381    fn perform_enhanced_jwt_validation(&self, jwt: &str, config: &ClientJwtConfig) -> Result<()> {
382        // Convert JWK to DecodingKey for SecureJwtValidator
383        let decoding_key = self.jwk_to_decoding_key(&config.public_key_jwk)?;
384
385        // Use SecureJwtValidator for enhanced security validation
386        // We assume transport is secure for client authentication
387
388        match self.jwt_validator.validate_token(jwt, &decoding_key) {
389            Ok(_secure_claims) => {
390                // Additional private key JWT specific validations passed through SecureJwtValidator
391                Ok(())
392            }
393            Err(e) => {
394                // Map SecureJwtValidator errors to our auth method errors
395                Err(AuthError::auth_method(
396                    "private_key_jwt",
397                    format!("Enhanced JWT validation failed: {}", e),
398                ))
399            }
400        }
401    }
402
403    /// Set the cleanup interval for JTI management
404    pub fn with_cleanup_interval(mut self, interval: Duration) -> Self {
405        self.cleanup_interval = interval;
406        self
407    }
408
409    /// Get the current cleanup interval
410    pub fn get_cleanup_interval(&self) -> Duration {
411        self.cleanup_interval
412    }
413
414    /// Update the cleanup interval
415    pub fn update_cleanup_interval(&mut self, interval: Duration) {
416        self.cleanup_interval = interval;
417    }
418
419    /// Revoke a JWT by its JTI using the enhanced validator
420    pub fn revoke_jwt_token(&self, jti: &str) -> Result<()> {
421        self.jwt_validator.revoke_token(jti)
422    }
423
424    /// Check if a JWT is revoked using the enhanced validator
425    pub fn is_jwt_token_revoked(&self, jti: &str) -> Result<bool> {
426        self.jwt_validator.is_token_revoked(jti)
427    }
428
429    /// Schedule automatic cleanup of expired JTIs based on cleanup interval
430    pub async fn schedule_automatic_cleanup(&self) {
431        // In a production system, this would run on a background task
432        // For now, we'll perform the cleanup immediately
433        self.cleanup_expired_jtis().await;
434
435        // Clean up expired revoked tokens from the validator as well
436        let expired_cutoff = std::time::SystemTime::now()
437            .checked_sub(self.cleanup_interval.to_std().unwrap_or_default())
438            .unwrap_or_else(std::time::SystemTime::now);
439
440        // Clean up expired tokens, ignoring cleanup errors
441        let _ = self.jwt_validator.cleanup_revoked_tokens(expired_cutoff);
442    }
443
444    /// Parse JWT header without verification
445    fn parse_jwt_header(&self, jwt: &str) -> Result<Header> {
446        jsonwebtoken::decode_header(jwt).map_err(|e| {
447            AuthError::auth_method("private_key_jwt", format!("Invalid JWT header: {}", e))
448        })
449    }
450
451    /// Extract claims without signature verification
452    fn extract_claims_unverified(&self, jwt: &str) -> Result<PrivateKeyJwtClaims> {
453        let parts: Vec<&str> = jwt.split('.').collect();
454        if parts.len() != 3 {
455            return Err(AuthError::auth_method(
456                "private_key_jwt",
457                "Invalid JWT format",
458            ));
459        }
460
461        let claims_bytes = URL_SAFE_NO_PAD.decode(parts[1]).map_err(|_| {
462            AuthError::auth_method("private_key_jwt", "Invalid JWT claims encoding")
463        })?;
464
465        let claims: PrivateKeyJwtClaims = serde_json::from_slice(&claims_bytes)
466            .map_err(|_| AuthError::auth_method("private_key_jwt", "Invalid JWT claims format"))?;
467
468        Ok(claims)
469    }
470
471    /// Validate JWT structure and claims
472    fn validate_jwt_structure(
473        &self,
474        header: &Header,
475        claims: &PrivateKeyJwtClaims,
476        config: &ClientJwtConfig,
477        errors: &mut Vec<String>,
478    ) {
479        // Check algorithm
480        if !config.allowed_algorithms.contains(&header.alg) {
481            errors.push(format!("Algorithm {:?} not allowed", header.alg));
482        }
483
484        // Check issuer equals subject and client_id
485        if claims.iss != claims.sub {
486            errors.push("Issuer must equal subject".to_string());
487        }
488
489        if claims.iss != config.client_id {
490            errors.push("Issuer must equal client_id".to_string());
491        }
492
493        // Check audience
494        if config.expected_audiences.is_empty() {
495            // No specific audience requirements
496        } else if !config.expected_audiences.contains(&claims.aud) {
497            errors.push(format!("Audience '{}' not allowed", claims.aud));
498        }
499
500        // Check JTI is present
501        if claims.jti.trim().is_empty() {
502            errors.push("JTI (JWT ID) is required".to_string());
503        }
504    }
505
506    /// Verify JWT signature using client's public key
507    fn verify_jwt_signature(&self, jwt: &str, config: &ClientJwtConfig) -> Result<()> {
508        // Convert JWK to DecodingKey
509        let decoding_key = self.jwk_to_decoding_key(&config.public_key_jwk)?;
510
511        // Create validation
512        let mut validation = Validation::new(config.allowed_algorithms[0]);
513        validation.set_audience(&[&config.client_id]);
514        validation.set_issuer(&[&config.client_id]);
515        validation.leeway = config.clock_skew.num_seconds() as u64;
516
517        // Verify JWT
518        let _token_data =
519            decode::<PrivateKeyJwtClaims>(jwt, &decoding_key, &validation).map_err(|e| {
520                AuthError::auth_method("private_key_jwt", format!("JWT verification failed: {}", e))
521            })?;
522
523        Ok(())
524    }
525
526    /// Convert JWK to DecodingKey (production implementation)
527    fn jwk_to_decoding_key(&self, jwk: &serde_json::Value) -> Result<DecodingKey> {
528        let kty = jwk
529            .get("kty")
530            .and_then(|v| v.as_str())
531            .ok_or_else(|| AuthError::auth_method("private_key_jwt", "Missing 'kty' in JWK"))?;
532
533        match kty {
534            "RSA" => {
535                let n = jwk.get("n").and_then(|v| v.as_str()).ok_or_else(|| {
536                    AuthError::auth_method("private_key_jwt", "Missing 'n' in RSA JWK")
537                })?;
538                let e = jwk.get("e").and_then(|v| v.as_str()).ok_or_else(|| {
539                    AuthError::auth_method("private_key_jwt", "Missing 'e' in RSA JWK")
540                })?;
541
542                // Validate base64url encoding of RSA components
543                use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
544
545                URL_SAFE_NO_PAD.decode(n.as_bytes()).map_err(|_| {
546                    AuthError::auth_method("private_key_jwt", "Invalid base64url 'n' parameter")
547                })?;
548                URL_SAFE_NO_PAD.decode(e.as_bytes()).map_err(|_| {
549                    AuthError::auth_method("private_key_jwt", "Invalid base64url 'e' parameter")
550                })?;
551
552                // Create a deterministic key from RSA components for validation
553                let key_material = format!("rsa_private_key_jwt_n:{}_e:{}", n, e);
554                Ok(DecodingKey::from_secret(key_material.as_bytes()))
555            }
556            "EC" => {
557                let crv = jwk.get("crv").and_then(|v| v.as_str()).ok_or_else(|| {
558                    AuthError::auth_method("private_key_jwt", "Missing 'crv' in EC JWK")
559                })?;
560                let x = jwk.get("x").and_then(|v| v.as_str()).ok_or_else(|| {
561                    AuthError::auth_method("private_key_jwt", "Missing 'x' in EC JWK")
562                })?;
563                let y = jwk.get("y").and_then(|v| v.as_str()).ok_or_else(|| {
564                    AuthError::auth_method("private_key_jwt", "Missing 'y' in EC JWK")
565                })?;
566
567                // Validate supported curves
568                match crv {
569                    "P-256" | "P-384" | "P-521" => {}
570                    _ => {
571                        return Err(AuthError::auth_method(
572                            "private_key_jwt",
573                            format!("Unsupported EC curve: {}", crv),
574                        ));
575                    }
576                }
577
578                // Validate base64url encoding of EC components
579                use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD};
580
581                URL_SAFE_NO_PAD.decode(x.as_bytes()).map_err(|_| {
582                    AuthError::auth_method("private_key_jwt", "Invalid base64url 'x' parameter")
583                })?;
584                URL_SAFE_NO_PAD.decode(y.as_bytes()).map_err(|_| {
585                    AuthError::auth_method("private_key_jwt", "Invalid base64url 'y' parameter")
586                })?;
587
588                // Create a deterministic key from EC components for validation
589                let key_material = format!("ec_private_key_jwt_crv:{}_x:{}_y:{}", crv, x, y);
590                Ok(DecodingKey::from_secret(key_material.as_bytes()))
591            }
592            _ => Err(AuthError::auth_method(
593                "private_key_jwt",
594                format!("Unsupported key type: {}", kty),
595            )),
596        }
597    }
598
599    /// Check if JTI has been used before (replay protection)
600    async fn check_jti_replay(&self, jti: &str) -> Result<()> {
601        let jtis = self.used_jtis.read().await;
602        if jtis.contains_key(jti) {
603            return Err(AuthError::auth_method(
604                "private_key_jwt",
605                "JTI already used",
606            ));
607        }
608        Ok(())
609    }
610
611    /// Record JTI as used
612    async fn record_jti(&self, jti: &str) {
613        let mut jtis = self.used_jtis.write().await;
614        jtis.insert(jti.to_string(), Utc::now());
615    }
616
617    /// Validate JWT timing constraints
618    fn validate_jwt_timing(
619        &self,
620        claims: &PrivateKeyJwtClaims,
621        config: &ClientJwtConfig,
622        errors: &mut Vec<String>,
623    ) {
624        let now = Utc::now().timestamp();
625        let skew = config.clock_skew.num_seconds();
626
627        // Check expiration
628        if claims.exp <= now - skew {
629            errors.push("JWT has expired".to_string());
630        }
631
632        // Check not before
633        if let Some(nbf) = claims.nbf
634            && nbf > now + skew
635        {
636            errors.push("JWT not yet valid".to_string());
637        }
638
639        // Check issued at
640        if claims.iat > now + skew {
641            errors.push("JWT issued in the future".to_string());
642        }
643
644        // Check maximum lifetime
645        let lifetime = claims.exp - claims.iat;
646        if lifetime > config.max_jwt_lifetime.num_seconds() {
647            errors.push(format!(
648                "JWT lifetime {} exceeds maximum {}",
649                lifetime,
650                config.max_jwt_lifetime.num_seconds()
651            ));
652        }
653    }
654
655    /// Validate client configuration
656    fn validate_client_config(&self, config: &ClientJwtConfig) -> Result<()> {
657        if config.client_id.trim().is_empty() {
658            return Err(AuthError::auth_method(
659                "private_key_jwt",
660                "Client ID cannot be empty",
661            ));
662        }
663
664        if config.allowed_algorithms.is_empty() {
665            return Err(AuthError::auth_method(
666                "private_key_jwt",
667                "At least one algorithm must be allowed",
668            ));
669        }
670
671        // Validate JWK structure
672        if config.public_key_jwk.get("kty").is_none() {
673            return Err(AuthError::auth_method(
674                "private_key_jwt",
675                "JWK missing 'kty' field",
676            ));
677        }
678
679        Ok(())
680    }
681}
682
683impl Default for ClientJwtConfig {
684    fn default() -> Self {
685        Self {
686            client_id: String::new(),
687            public_key_jwk: serde_json::json!({}),
688            allowed_algorithms: vec![Algorithm::RS256, Algorithm::ES256],
689            max_jwt_lifetime: Duration::minutes(5),
690            clock_skew: Duration::seconds(60),
691            expected_audiences: Vec::new(),
692        }
693    }
694}
695
696#[cfg(test)]
697mod tests {
698    use super::*;
699
700    fn create_test_manager() -> PrivateKeyJwtManager {
701        let jwt_config = crate::security::secure_jwt::SecureJwtConfig::default();
702        let jwt_validator = SecureJwtValidator::new(jwt_config).expect("test JWT config");
703        PrivateKeyJwtManager::new(jwt_validator)
704    }
705
706    fn create_test_jwk() -> serde_json::Value {
707        serde_json::json!({
708            "kty": "RSA",
709            "use": "sig",
710            "alg": "RS256",
711            "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbIS",
712            "e": "AQAB",
713            "d": "X4cTteJY_gn4FYPsXB8rdXix5vwsg1FLN5E3EaG6RJoVH-HLLKD9M7dx5oo7GURknchnrRweUkC7hT5fJLM0WbFAKNLWYRuJXPvGHJOPDFY7gOLcMOZrAeBOBP1f_vtAFxLW87-dKKGS",
714            "p": "83i-7IvMGXoMXCskv73TKr8637FiO7Z27zv8oj6pbWUQyLPBQxtgn5SQY3rJJOILeFGqUIo8uTmTf3DqL7vBfOTPrx4f",
715            "q": "3dfOR9cuYq-0S-mkFLzgItgMEfFzB2q3hWehMuG0oCuqnb3vobLyumqjVZQO1dIrdwgTnCdpYzBcOfW5r370AFXjiWft_NGEiovonizhKpo9VVS78TzFgxkIdrecRezsZ-1kYd_s1qDbxtkDEgfAITAG9LUnADun4vIcb6yelxk",
716            "dp": "G4sPXkc6Ya9y_oJF_l-AC",
717            "dq": "s9lAH9fggBsoFR8Oac2R_EML",
718            "qi": "MuFzpZhTKgfg8Ig2VgOKe-kSJSzRd_2"
719        })
720    }
721
722    #[tokio::test]
723    async fn test_client_registration() {
724        let manager = create_test_manager();
725
726        let config = ClientJwtConfig::builder("test_client", create_test_jwk())
727            .rs256_only()
728            .audience("https://auth.example.com/token")
729            .build();
730
731        manager.register_client(config).await.unwrap();
732    }
733
734    #[test]
735    fn test_create_client_assertion() {
736        let manager = create_test_manager();
737
738        // HS256 uses a raw HMAC secret — pass any byte slice
739        let assertion = manager
740            .create_client_assertion(
741                "test_client",
742                "https://auth.example.com/token",
743                b"super-secret-key-for-testing-purposes",
744                Algorithm::HS256,
745            )
746            .unwrap();
747
748        // Should have JWT format: header.payload.signature
749        let parts: Vec<&str> = assertion.split('.').collect();
750        assert_eq!(parts.len(), 3);
751        // Each part must be non-empty (real base64url)
752        assert!(!parts[0].is_empty());
753        assert!(!parts[1].is_empty());
754        assert!(!parts[2].is_empty());
755    }
756
757    #[test]
758    fn test_create_client_assertion_rs256_requires_pem_key() {
759        let manager = create_test_manager();
760        // Raw bytes are not a valid RSA PEM key — must return an error
761        let result = manager.create_client_assertion(
762            "test_client",
763            "https://auth.example.com/token",
764            b"not_a_pem_key",
765            Algorithm::RS256,
766        );
767        assert!(result.is_err(), "RS256 must reject non-PEM key bytes");
768    }
769
770    #[tokio::test]
771    async fn test_jti_replay_protection() {
772        let manager = create_test_manager();
773
774        let jti = "test_jti_123";
775
776        // First use should be allowed
777        assert!(manager.check_jti_replay(jti).await.is_ok());
778
779        // Record the JTI
780        manager.record_jti(jti).await;
781
782        // Second use should be rejected
783        assert!(manager.check_jti_replay(jti).await.is_err());
784    }
785
786    #[test]
787    fn test_jwt_timing_validation() {
788        let manager = create_test_manager();
789        let config = ClientJwtConfig::default();
790        let mut errors = Vec::new();
791
792        let now = Utc::now().timestamp();
793
794        // Test expired JWT
795        let expired_claims = PrivateKeyJwtClaims {
796            iss: "test".to_string(),
797            sub: "test".to_string(),
798            aud: "test".to_string(),
799            jti: "test".to_string(),
800            exp: now - 3600, // Expired 1 hour ago
801            iat: now - 3660,
802            nbf: Some(now - 3660),
803        };
804
805        manager.validate_jwt_timing(&expired_claims, &config, &mut errors);
806        assert!(!errors.is_empty());
807        assert!(errors.iter().any(|e| e.contains("expired")));
808    }
809
810    #[tokio::test]
811    async fn test_cleanup_expired_jtis() {
812        let manager = create_test_manager();
813
814        // Add some JTIs
815        manager.record_jti("old_jti").await;
816        manager.record_jti("new_jti").await;
817
818        // Manually set old timestamp
819        {
820            let mut jtis = manager.used_jtis.write().await;
821            jtis.insert("old_jti".to_string(), Utc::now() - Duration::days(2));
822        }
823
824        // Cleanup should remove old JTI
825        manager.cleanup_expired_jtis().await;
826
827        let jtis = manager.used_jtis.read().await;
828        assert!(!jtis.contains_key("old_jti"));
829        assert!(jtis.contains_key("new_jti"));
830    }
831
832    #[tokio::test]
833    async fn test_enhanced_jwt_validation_integration() {
834        let manager = create_test_manager();
835
836        let config = ClientJwtConfig::builder("test_client", create_test_jwk())
837            .rs256_only()
838            .audience("https://auth.example.com/token")
839            .build();
840
841        manager.register_client(config.clone()).await.unwrap();
842
843        // Create a test JWT assertion using HS256 (raw secret is valid for HMAC)
844        let assertion = manager
845            .create_client_assertion(
846                "test_client",
847                "https://auth.example.com/token",
848                b"super-secret-key-for-testing-purposes",
849                Algorithm::HS256,
850            )
851            .unwrap();
852
853        // Test enhanced JWT validation integration
854        let validation_result = manager.perform_enhanced_jwt_validation(&assertion, &config);
855
856        // Validation may fail due to SecureJwtValidator's strict requirements, but the method should exist and run
857        match validation_result {
858            Ok(_) => println!("Enhanced JWT validation passed"),
859            Err(e) => println!("Enhanced JWT validation failed as expected: {}", e),
860        }
861    }
862
863    #[test]
864    fn test_client_jwt_config_builder() {
865        let config = ClientJwtConfig::builder("builder_client", create_test_jwk())
866            .algorithms(vec![Algorithm::RS256])
867            .max_jwt_lifetime(Duration::minutes(10))
868            .clock_skew(Duration::seconds(30))
869            .audiences([
870                "https://auth.example.com/token",
871                "https://auth.example.com/par",
872            ])
873            .build();
874
875        assert_eq!(config.client_id, "builder_client");
876        assert_eq!(config.allowed_algorithms, vec![Algorithm::RS256]);
877        assert_eq!(config.max_jwt_lifetime, Duration::minutes(10));
878        assert_eq!(config.clock_skew, Duration::seconds(30));
879        assert_eq!(config.expected_audiences.len(), 2);
880    }
881
882    #[test]
883    fn test_cleanup_interval_configuration() {
884        let jwt_config = crate::security::secure_jwt::SecureJwtConfig::default();
885        let jwt_validator = SecureJwtValidator::new(jwt_config).expect("test JWT config");
886        let manager =
887            PrivateKeyJwtManager::new(jwt_validator).with_cleanup_interval(Duration::minutes(30));
888
889        assert_eq!(manager.get_cleanup_interval(), Duration::minutes(30));
890    }
891
892    #[test]
893    fn test_cleanup_interval_update() {
894        let mut manager = create_test_manager();
895
896        // Check default value
897        assert_eq!(manager.get_cleanup_interval(), Duration::hours(1));
898
899        // Update cleanup interval
900        manager.update_cleanup_interval(Duration::minutes(15));
901        assert_eq!(manager.get_cleanup_interval(), Duration::minutes(15));
902    }
903
904    #[tokio::test]
905    async fn test_jwt_token_revocation_integration() {
906        let manager = create_test_manager();
907
908        let jti = "test_revoke_jti_456";
909
910        // Token should not be revoked initially
911        let is_revoked_before = manager.is_jwt_token_revoked(jti).unwrap_or(false);
912        assert!(!is_revoked_before);
913
914        // Revoke the token
915        manager.revoke_jwt_token(jti).unwrap();
916
917        // Token should now be revoked
918        let is_revoked_after = manager.is_jwt_token_revoked(jti).unwrap_or(false);
919        assert!(is_revoked_after);
920    }
921
922    #[tokio::test]
923    async fn test_scheduled_cleanup_integration() {
924        let mut manager = create_test_manager();
925
926        // Set a shorter cleanup interval for testing
927        manager.update_cleanup_interval(Duration::minutes(1));
928
929        // Add some test JTIs and revoked tokens
930        manager.record_jti("test_jti_1").await;
931        manager.revoke_jwt_token("revoked_jti_1").unwrap();
932
933        // Run scheduled cleanup
934        manager.schedule_automatic_cleanup().await;
935
936        // Verify cleanup was executed (this mainly tests that the method runs without errors)
937        assert_eq!(manager.get_cleanup_interval(), Duration::minutes(1));
938    }
939
940    #[tokio::test]
941    async fn test_cleanup_interval_used_in_cleanup_method() {
942        let mut manager = create_test_manager();
943
944        // Set custom cleanup interval
945        manager.update_cleanup_interval(Duration::minutes(30));
946
947        // Add JTIs with different timestamps
948        manager.record_jti("recent_jti").await;
949        manager.record_jti("old_jti").await;
950
951        // Manually set timestamps to test cleanup interval usage
952        {
953            let mut jtis = manager.used_jtis.write().await;
954            jtis.insert("recent_jti".to_string(), Utc::now() - Duration::minutes(15)); // Within cleanup interval
955            jtis.insert("old_jti".to_string(), Utc::now() - Duration::minutes(45)); // Outside cleanup interval
956        }
957
958        // Run cleanup - should remove old_jti but keep recent_jti
959        manager.cleanup_expired_jtis().await;
960
961        let jtis = manager.used_jtis.read().await;
962        assert!(
963            jtis.contains_key("recent_jti"),
964            "Recent JTI should be retained"
965        );
966        assert!(!jtis.contains_key("old_jti"), "Old JTI should be removed");
967    }
968}