auth_framework/server/core/
common_jwt.rs

1//! Common JWT Operations
2//!
3//! This module provides shared JWT functionality to eliminate
4//! duplication across server modules.
5
6use crate::errors::{AuthError, Result};
7use crate::server::core::common_validation;
8use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13/// Common JWT configuration
14#[derive(Clone)]
15pub struct JwtConfig {
16    /// Signing algorithm
17    pub algorithm: Algorithm,
18    /// Signing key
19    pub signing_key: EncodingKey,
20    /// Verification key
21    pub verification_key: DecodingKey,
22    /// Default expiration time in seconds
23    pub default_expiration: u64,
24    /// Issuer
25    pub issuer: String,
26    /// Audiences
27    pub audiences: Vec<String>,
28}
29
30impl JwtConfig {
31    /// Create new JWT config with symmetric key
32    pub fn with_symmetric_key(secret: &[u8], issuer: String) -> Self {
33        Self {
34            algorithm: Algorithm::HS256,
35            signing_key: EncodingKey::from_secret(secret),
36            verification_key: DecodingKey::from_secret(secret),
37            default_expiration: 3600, // 1 hour
38            issuer,
39            audiences: vec![],
40        }
41    }
42
43    /// Create new JWT config with RSA keys
44    pub fn with_rsa_keys(private_key: &[u8], public_key: &[u8], issuer: String) -> Result<Self> {
45        let signing_key = EncodingKey::from_rsa_pem(private_key)
46            .map_err(|e| AuthError::validation(format!("Invalid private key: {}", e)))?;
47
48        let verification_key = DecodingKey::from_rsa_pem(public_key)
49            .map_err(|e| AuthError::validation(format!("Invalid public key: {}", e)))?;
50
51        Ok(Self {
52            algorithm: Algorithm::RS256,
53            signing_key,
54            verification_key,
55            default_expiration: 3600, // 1 hour
56            issuer,
57            audiences: vec![],
58        })
59    }
60
61    /// Add audience
62    pub fn with_audience(mut self, audience: String) -> Self {
63        self.audiences.push(audience);
64        self
65    }
66
67    /// Set expiration time
68    pub fn with_expiration(mut self, expiration: u64) -> Self {
69        self.default_expiration = expiration;
70        self
71    }
72}
73
74/// Common JWT claims structure
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct CommonJwtClaims {
77    /// Issuer
78    pub iss: String,
79    /// Subject
80    pub sub: String,
81    /// Audience
82    pub aud: Vec<String>,
83    /// Expiration time
84    pub exp: i64,
85    /// Issued at
86    pub iat: i64,
87    /// Not before
88    pub nbf: Option<i64>,
89    /// JWT ID
90    pub jti: Option<String>,
91    /// Custom claims
92    #[serde(flatten)]
93    pub custom: HashMap<String, serde_json::Value>,
94}
95
96impl CommonJwtClaims {
97    /// Create new claims with required fields
98    pub fn new(issuer: String, subject: String, audiences: Vec<String>, expiration: i64) -> Self {
99        let now = SystemTime::now()
100            .duration_since(UNIX_EPOCH)
101            .unwrap()
102            .as_secs() as i64;
103
104        Self {
105            iss: issuer,
106            sub: subject,
107            aud: audiences,
108            exp: expiration,
109            iat: now,
110            nbf: None,
111            jti: None,
112            custom: HashMap::new(),
113        }
114    }
115
116    /// Add custom claim
117    pub fn with_custom_claim(mut self, key: String, value: serde_json::Value) -> Self {
118        self.custom.insert(key, value);
119        self
120    }
121
122    /// Set JWT ID
123    pub fn with_jti(mut self, jti: String) -> Self {
124        self.jti = Some(jti);
125        self
126    }
127
128    /// Set not before
129    pub fn with_nbf(mut self, nbf: i64) -> Self {
130        self.nbf = Some(nbf);
131        self
132    }
133}
134
135/// Common JWT token management for OAuth 2.0 and OpenID Connect operations.
136///
137/// `JwtManager` provides comprehensive JWT token creation, verification, and
138/// management capabilities specifically designed for OAuth 2.0 authorization
139/// servers and OpenID Connect providers. It supports both symmetric and
140/// asymmetric signing algorithms with security best practices.
141///
142/// # Supported Algorithms
143///
144/// - **HMAC**: HS256, HS384, HS512 (symmetric)
145/// - **RSA**: RS256, RS384, RS512 (asymmetric)
146/// - **ECDSA**: ES256, ES384, ES512 (asymmetric)
147/// - **EdDSA**: EdDSA (asymmetric, Ed25519)
148///
149/// # Security Features
150///
151/// - **Algorithm Validation**: Prevents algorithm confusion attacks
152/// - **Time Validation**: Automatic `exp`, `nbf`, and `iat` claim validation
153/// - **Audience Validation**: Ensures tokens are used by intended recipients
154/// - **Issuer Validation**: Verifies token origin
155/// - **Secure Defaults**: Uses secure algorithm choices and expiration times
156///
157/// # Token Types Supported
158///
159/// - **Access Tokens**: OAuth 2.0 access tokens with scopes
160/// - **ID Tokens**: OpenID Connect identity tokens
161/// - **Refresh Tokens**: Long-lived tokens for access token renewal
162/// - **Custom Tokens**: Application-specific token types
163///
164/// # Key Management
165///
166/// - **Symmetric Keys**: HMAC-based signing with shared secrets
167/// - **RSA Keys**: Support for PKCS#1 and PKCS#8 key formats
168/// - **Key Rotation**: Support for multiple signing keys
169/// - **Key Security**: Secure key storage and access patterns
170///
171/// # Example
172///
173/// ```rust
174/// use auth_framework::server::core::common_jwt::{JwtManager, JwtConfig, CommonJwtClaims};
175///
176/// // Create JWT manager with RSA keys
177/// let config = JwtConfig::with_rsa_keys(
178///     private_key_bytes,
179///     public_key_bytes,
180///     "https://auth.example.com".to_string()
181/// )?;
182/// let jwt_manager = JwtManager::new(config);
183///
184/// // Create access token
185/// let claims = CommonJwtClaims::new(
186///     "https://auth.example.com".to_string(),
187///     "user123".to_string(),
188///     vec!["api".to_string()],
189///     expiration_time
190/// ).with_custom_claim("scope".to_string(), json!("read write"));
191///
192/// let token = jwt_manager.create_token(&claims)?;
193///
194/// // Verify token
195/// let verified_claims = jwt_manager.verify_token(&token)?;
196/// ```
197///
198/// # Performance Considerations
199///
200/// - Asymmetric algorithms are more computationally expensive
201/// - Token verification is optimized for high-throughput scenarios
202/// - Key caching reduces cryptographic operation overhead
203///
204/// # RFC Compliance
205///
206/// - **RFC 7519**: JSON Web Token (JWT)
207/// - **RFC 7515**: JSON Web Signature (JWS)
208/// - **RFC 8725**: JWT Best Current Practices
209/// - **RFC 9068**: JWT Profile for OAuth 2.0 Access Tokens
210pub struct JwtManager {
211    config: JwtConfig,
212}
213
214impl JwtManager {
215    /// Create new JWT manager
216    pub fn new(config: JwtConfig) -> Self {
217        Self { config }
218    }
219
220    /// Create signed JWT token
221    pub fn create_token(&self, claims: &CommonJwtClaims) -> Result<String> {
222        let header = Header {
223            alg: self.config.algorithm,
224            ..Default::default()
225        };
226
227        encode(&header, claims, &self.config.signing_key)
228            .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
229    }
230
231    /// Create signed token with custom claims
232    pub fn create_token_with_custom_claims<T>(&self, claims: &T) -> Result<String>
233    where
234        T: Serialize,
235    {
236        let header = Header {
237            alg: self.config.algorithm,
238            ..Default::default()
239        };
240
241        encode(&header, claims, &self.config.signing_key)
242            .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
243    }
244
245    /// Verify and decode JWT token
246    pub fn verify_token(&self, token: &str) -> Result<CommonJwtClaims> {
247        // Basic format validation
248        common_validation::jwt::validate_jwt_format(token)?;
249
250        let mut validation = Validation::new(self.config.algorithm);
251        validation.set_issuer(&[&self.config.issuer]);
252
253        if !self.config.audiences.is_empty() {
254            validation.set_audience(
255                &self
256                    .config
257                    .audiences
258                    .iter()
259                    .map(String::as_str)
260                    .collect::<Vec<_>>(),
261            );
262        }
263
264        let token_data =
265            decode::<CommonJwtClaims>(token, &self.config.verification_key, &validation)
266                .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
267
268        // Additional validation using common validation utilities
269        let claims_value = serde_json::to_value(&token_data.claims)
270            .map_err(|e| AuthError::validation(format!("Failed to serialize claims: {}", e)))?;
271
272        common_validation::jwt::validate_time_claims(&claims_value)?;
273
274        Ok(token_data.claims)
275    }
276
277    /// Verify token and extract custom claims
278    pub fn verify_token_with_custom_claims<T>(&self, token: &str) -> Result<T>
279    where
280        T: for<'de> Deserialize<'de>,
281    {
282        common_validation::jwt::validate_jwt_format(token)?;
283
284        let mut validation = Validation::new(self.config.algorithm);
285        validation.set_issuer(&[&self.config.issuer]);
286
287        if !self.config.audiences.is_empty() {
288            validation.set_audience(
289                &self
290                    .config
291                    .audiences
292                    .iter()
293                    .map(String::as_str)
294                    .collect::<Vec<_>>(),
295            );
296        }
297
298        let token_data = decode::<T>(token, &self.config.verification_key, &validation)
299            .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
300
301        Ok(token_data.claims)
302    }
303
304    /// Create access token with standard claims
305    pub fn create_access_token(
306        &self,
307        subject: String,
308        scope: Vec<String>,
309        client_id: Option<String>,
310    ) -> Result<String> {
311        let exp = SystemTime::now()
312            .duration_since(UNIX_EPOCH)
313            .unwrap()
314            .as_secs() as i64
315            + self.config.default_expiration as i64;
316
317        let mut claims = CommonJwtClaims::new(
318            self.config.issuer.clone(),
319            subject,
320            self.config.audiences.clone(),
321            exp,
322        );
323
324        claims
325            .custom
326            .insert("scope".to_string(), serde_json::json!(scope.join(" ")));
327
328        if let Some(client_id) = client_id {
329            claims.custom.insert(
330                "client_id".to_string(),
331                serde_json::Value::String(client_id),
332            );
333        }
334
335        claims.custom.insert(
336            "token_type".to_string(),
337            serde_json::Value::String("access_token".to_string()),
338        );
339
340        self.create_token(&claims)
341    }
342
343    /// Create refresh token
344    pub fn create_refresh_token(&self, subject: String, client_id: String) -> Result<String> {
345        // Refresh tokens typically have longer expiration
346        let exp = SystemTime::now()
347            .duration_since(UNIX_EPOCH)
348            .unwrap()
349            .as_secs() as i64
350            + (self.config.default_expiration * 24) as i64; // 24x longer
351
352        let mut claims = CommonJwtClaims::new(
353            self.config.issuer.clone(),
354            subject,
355            self.config.audiences.clone(),
356            exp,
357        );
358
359        claims.custom.insert(
360            "client_id".to_string(),
361            serde_json::Value::String(client_id),
362        );
363        claims.custom.insert(
364            "token_type".to_string(),
365            serde_json::Value::String("refresh_token".to_string()),
366        );
367
368        self.create_token(&claims)
369    }
370
371    /// Create ID token for OpenID Connect
372    pub fn create_id_token(
373        &self,
374        subject: String,
375        nonce: Option<String>,
376        auth_time: Option<i64>,
377        user_info: HashMap<String, serde_json::Value>,
378    ) -> Result<String> {
379        let exp = SystemTime::now()
380            .duration_since(UNIX_EPOCH)
381            .unwrap()
382            .as_secs() as i64
383            + 300; // 5 minutes for ID token
384
385        let mut claims = CommonJwtClaims::new(
386            self.config.issuer.clone(),
387            subject,
388            self.config.audiences.clone(),
389            exp,
390        );
391
392        claims.custom.insert(
393            "token_type".to_string(),
394            serde_json::Value::String("id_token".to_string()),
395        );
396
397        if let Some(nonce) = nonce {
398            claims
399                .custom
400                .insert("nonce".to_string(), serde_json::Value::String(nonce));
401        }
402
403        if let Some(auth_time) = auth_time {
404            claims.custom.insert(
405                "auth_time".to_string(),
406                serde_json::Value::Number(auth_time.into()),
407            );
408        }
409
410        // Add user info claims
411        for (key, value) in user_info {
412            claims.custom.insert(key, value);
413        }
414
415        self.create_token(&claims)
416    }
417}
418
419/// JWT utilities for token introspection and manipulation
420pub mod utils {
421    use super::*;
422
423    /// Extract claims from JWT without verification (for inspection only)
424    ///
425    /// # Security Warning
426    /// This function bypasses JWT signature verification! Only use for:
427    /// - Token inspection and debugging
428    /// - Extracting metadata before full validation
429    /// - Non-security-critical token analysis
430    ///
431    /// Never use for authentication or authorization decisions!
432    pub fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
433        common_validation::jwt::extract_claims_unsafe(token)
434    }
435
436    /// Check if token is expired without full verification
437    ///
438    /// # Security Warning
439    /// This function checks expiration without validating the JWT signature.
440    /// Only use for preliminary checks - always validate the token fully
441    /// before making security decisions!
442    pub fn is_token_expired(token: &str) -> Result<bool> {
443        let claims = extract_claims_unsafe(token)?;
444
445        let now = SystemTime::now()
446            .duration_since(UNIX_EPOCH)
447            .unwrap()
448            .as_secs() as i64;
449
450        if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) {
451            Ok(now >= exp)
452        } else {
453            Ok(false) // No expiration claim means not expired
454        }
455    }
456
457    /// Get token expiration time without signature validation
458    ///
459    /// # Security Warning
460    /// This function extracts expiration time without validating the JWT signature.
461    /// Only use for inspection - validate the token before trusting the data!
462    pub fn get_token_expiration(token: &str) -> Result<Option<i64>> {
463        let claims = extract_claims_unsafe(token)?;
464        Ok(claims.get("exp").and_then(|v| v.as_i64()))
465    }
466
467    /// Get token subject without signature validation
468    ///
469    /// # Security Warning
470    /// This function extracts the subject without validating the JWT signature.
471    /// Only use for inspection - validate the token before trusting the data!
472    pub fn get_token_subject(token: &str) -> Result<Option<String>> {
473        let claims = extract_claims_unsafe(token)?;
474        Ok(claims.get("sub").and_then(|v| v.as_str()).map(String::from))
475    }
476
477    /// Get token scopes without signature validation
478    ///
479    /// # Security Warning
480    /// This function extracts scopes without validating the JWT signature.
481    /// Only use for inspection - validate the token before trusting the data!
482    pub fn get_token_scopes(token: &str) -> Result<Vec<String>> {
483        let claims = extract_claims_unsafe(token)?;
484
485        if let Some(scope_str) = claims.get("scope").and_then(|v| v.as_str()) {
486            Ok(scope_str.split_whitespace().map(String::from).collect())
487        } else if let Some(scopes_array) = claims.get("scopes").and_then(|v| v.as_array()) {
488            Ok(scopes_array
489                .iter()
490                .filter_map(|v| v.as_str())
491                .map(String::from)
492                .collect())
493        } else {
494            Ok(vec![])
495        }
496    }
497}
498
499