auth_framework/server/core/
common_jwt.rs

1//! Common JWT Operations
2//!
3//! This module provides shared JWT functionality to eliminate
4//! duplication across server modules.
5
6use crate::errors::{AuthError, Result};
7use crate::server::core::common_validation;
8use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13/// Common JWT configuration
14#[derive(Clone)]
15pub struct JwtConfig {
16    /// Signing algorithm
17    pub algorithm: Algorithm,
18    /// Signing key
19    pub signing_key: EncodingKey,
20    /// Verification key
21    pub verification_key: DecodingKey,
22    /// Default expiration time in seconds
23    pub default_expiration: u64,
24    /// Issuer
25    pub issuer: String,
26    /// Audiences
27    pub audiences: Vec<String>,
28}
29
30impl JwtConfig {
31    /// Create new JWT config with symmetric key
32    pub fn with_symmetric_key(secret: &[u8], issuer: String) -> Self {
33        Self {
34            algorithm: Algorithm::HS256,
35            signing_key: EncodingKey::from_secret(secret),
36            verification_key: DecodingKey::from_secret(secret),
37            default_expiration: 3600, // 1 hour
38            issuer,
39            audiences: vec![],
40        }
41    }
42
43    /// Create new JWT config with RSA keys
44    pub fn with_rsa_keys(private_key: &[u8], public_key: &[u8], issuer: String) -> Result<Self> {
45        let signing_key = EncodingKey::from_rsa_pem(private_key)
46            .map_err(|e| AuthError::validation(format!("Invalid private key: {}", e)))?;
47
48        let verification_key = DecodingKey::from_rsa_pem(public_key)
49            .map_err(|e| AuthError::validation(format!("Invalid public key: {}", e)))?;
50
51        Ok(Self {
52            algorithm: Algorithm::RS256,
53            signing_key,
54            verification_key,
55            default_expiration: 3600, // 1 hour
56            issuer,
57            audiences: vec![],
58        })
59    }
60
61    /// Add audience
62    pub fn with_audience(mut self, audience: String) -> Self {
63        self.audiences.push(audience);
64        self
65    }
66
67    /// Set expiration time
68    pub fn with_expiration(mut self, expiration: u64) -> Self {
69        self.default_expiration = expiration;
70        self
71    }
72}
73
74/// Common JWT claims structure
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct CommonJwtClaims {
77    /// Issuer
78    pub iss: String,
79    /// Subject
80    pub sub: String,
81    /// Audience
82    pub aud: Vec<String>,
83    /// Expiration time
84    pub exp: i64,
85    /// Issued at
86    pub iat: i64,
87    /// Not before
88    pub nbf: Option<i64>,
89    /// JWT ID
90    pub jti: Option<String>,
91    /// Custom claims
92    #[serde(flatten)]
93    pub custom: HashMap<String, serde_json::Value>,
94}
95
96impl CommonJwtClaims {
97    /// Create new claims with required fields
98    pub fn new(issuer: String, subject: String, audiences: Vec<String>, expiration: i64) -> Self {
99        let now = SystemTime::now()
100            .duration_since(UNIX_EPOCH)
101            .unwrap()
102            .as_secs() as i64;
103
104        Self {
105            iss: issuer,
106            sub: subject,
107            aud: audiences,
108            exp: expiration,
109            iat: now,
110            nbf: None,
111            jti: None,
112            custom: HashMap::new(),
113        }
114    }
115
116    /// Add custom claim
117    pub fn with_custom_claim(mut self, key: String, value: serde_json::Value) -> Self {
118        self.custom.insert(key, value);
119        self
120    }
121
122    /// Set JWT ID
123    pub fn with_jti(mut self, jti: String) -> Self {
124        self.jti = Some(jti);
125        self
126    }
127
128    /// Set not before
129    pub fn with_nbf(mut self, nbf: i64) -> Self {
130        self.nbf = Some(nbf);
131        self
132    }
133}
134
135/// Common JWT token management for OAuth 2.0 and OpenID Connect operations.
136///
137/// `JwtManager` provides comprehensive JWT token creation, verification, and
138/// management capabilities specifically designed for OAuth 2.0 authorization
139/// servers and OpenID Connect providers. It supports both symmetric and
140/// asymmetric signing algorithms with security best practices.
141///
142/// # Supported Algorithms
143///
144/// - **HMAC**: HS256, HS384, HS512 (symmetric)
145/// - **RSA**: RS256, RS384, RS512 (asymmetric)
146/// - **ECDSA**: ES256, ES384, ES512 (asymmetric)
147/// - **EdDSA**: EdDSA (asymmetric, Ed25519)
148///
149/// # Security Features
150///
151/// - **Algorithm Validation**: Prevents algorithm confusion attacks
152/// - **Time Validation**: Automatic `exp`, `nbf`, and `iat` claim validation
153/// - **Audience Validation**: Ensures tokens are used by intended recipients
154/// - **Issuer Validation**: Verifies token origin
155/// - **Secure Defaults**: Uses secure algorithm choices and expiration times
156///
157/// # Token Types Supported
158///
159/// - **Access Tokens**: OAuth 2.0 access tokens with scopes
160/// - **ID Tokens**: OpenID Connect identity tokens
161/// - **Refresh Tokens**: Long-lived tokens for access token renewal
162/// - **Custom Tokens**: Application-specific token types
163///
164/// # Key Management
165///
166/// - **Symmetric Keys**: HMAC-based signing with shared secrets
167/// - **RSA Keys**: Support for PKCS#1 and PKCS#8 key formats
168/// - **Key Rotation**: Support for multiple signing keys
169/// - **Key Security**: Secure key storage and access patterns
170///
171/// # Example
172///
173/// ```rust,no_run
174/// use auth_framework::server::core::common_jwt::{JwtManager, JwtConfig, CommonJwtClaims};
175/// use serde_json::json;
176/// use chrono::{Duration, Utc};
177///
178/// # #[tokio::main]
179/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
180/// # let private_key_bytes = b"dummy_private_key";
181/// # let public_key_bytes = b"dummy_public_key";
182/// # let expiration_time = (Utc::now() + Duration::hours(1)).timestamp();
183/// // Create JWT manager with RSA keys
184/// let config = JwtConfig::with_rsa_keys(
185///     private_key_bytes,
186///     public_key_bytes,
187///     "https://auth.example.com".to_string()
188/// )?;
189/// let jwt_manager = JwtManager::new(config);
190///
191/// // Create access token
192/// let claims = CommonJwtClaims::new(
193///     "https://auth.example.com".to_string(),
194///     "user123".to_string(),
195///     vec!["api".to_string()],
196///     expiration_time
197/// ).with_custom_claim("scope".to_string(), json!("read write"));
198///
199/// let token = jwt_manager.create_token(&claims)?;
200///
201/// // Verify token
202/// let verified_claims = jwt_manager.verify_token(&token)?;
203/// # Ok(())
204/// # }
205/// ```
206///
207/// # Performance Considerations
208///
209/// - Asymmetric algorithms are more computationally expensive
210/// - Token verification is optimized for high-throughput scenarios
211/// - Key caching reduces cryptographic operation overhead
212///
213/// # RFC Compliance
214///
215/// - **RFC 7519**: JSON Web Token (JWT)
216/// - **RFC 7515**: JSON Web Signature (JWS)
217/// - **RFC 8725**: JWT Best Current Practices
218/// - **RFC 9068**: JWT Profile for OAuth 2.0 Access Tokens
219pub struct JwtManager {
220    config: JwtConfig,
221}
222
223impl JwtManager {
224    /// Create new JWT manager
225    pub fn new(config: JwtConfig) -> Self {
226        Self { config }
227    }
228
229    /// Create signed JWT token
230    pub fn create_token(&self, claims: &CommonJwtClaims) -> Result<String> {
231        let header = Header {
232            alg: self.config.algorithm,
233            ..Default::default()
234        };
235
236        encode(&header, claims, &self.config.signing_key)
237            .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
238    }
239
240    /// Create signed token with custom claims
241    pub fn create_token_with_custom_claims<T>(&self, claims: &T) -> Result<String>
242    where
243        T: Serialize,
244    {
245        let header = Header {
246            alg: self.config.algorithm,
247            ..Default::default()
248        };
249
250        encode(&header, claims, &self.config.signing_key)
251            .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
252    }
253
254    /// Verify and decode JWT token
255    pub fn verify_token(&self, token: &str) -> Result<CommonJwtClaims> {
256        // Basic format validation
257        common_validation::jwt::validate_jwt_format(token)?;
258
259        let mut validation = Validation::new(self.config.algorithm);
260        validation.set_issuer(&[&self.config.issuer]);
261
262        if !self.config.audiences.is_empty() {
263            validation.set_audience(
264                &self
265                    .config
266                    .audiences
267                    .iter()
268                    .map(String::as_str)
269                    .collect::<Vec<_>>(),
270            );
271        }
272
273        let token_data =
274            decode::<CommonJwtClaims>(token, &self.config.verification_key, &validation)
275                .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
276
277        // Additional validation using common validation utilities
278        let claims_value = serde_json::to_value(&token_data.claims)
279            .map_err(|e| AuthError::validation(format!("Failed to serialize claims: {}", e)))?;
280
281        common_validation::jwt::validate_time_claims(&claims_value)?;
282
283        Ok(token_data.claims)
284    }
285
286    /// Verify token and extract custom claims
287    pub fn verify_token_with_custom_claims<T>(&self, token: &str) -> Result<T>
288    where
289        T: for<'de> Deserialize<'de>,
290    {
291        common_validation::jwt::validate_jwt_format(token)?;
292
293        let mut validation = Validation::new(self.config.algorithm);
294        validation.set_issuer(&[&self.config.issuer]);
295
296        if !self.config.audiences.is_empty() {
297            validation.set_audience(
298                &self
299                    .config
300                    .audiences
301                    .iter()
302                    .map(String::as_str)
303                    .collect::<Vec<_>>(),
304            );
305        }
306
307        let token_data = decode::<T>(token, &self.config.verification_key, &validation)
308            .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
309
310        Ok(token_data.claims)
311    }
312
313    /// Create access token with standard claims
314    pub fn create_access_token(
315        &self,
316        subject: String,
317        scope: Vec<String>,
318        client_id: Option<String>,
319    ) -> Result<String> {
320        let exp = SystemTime::now()
321            .duration_since(UNIX_EPOCH)
322            .unwrap()
323            .as_secs() as i64
324            + self.config.default_expiration as i64;
325
326        let mut claims = CommonJwtClaims::new(
327            self.config.issuer.clone(),
328            subject,
329            self.config.audiences.clone(),
330            exp,
331        );
332
333        claims
334            .custom
335            .insert("scope".to_string(), serde_json::json!(scope.join(" ")));
336
337        if let Some(client_id) = client_id {
338            claims.custom.insert(
339                "client_id".to_string(),
340                serde_json::Value::String(client_id),
341            );
342        }
343
344        claims.custom.insert(
345            "token_type".to_string(),
346            serde_json::Value::String("access_token".to_string()),
347        );
348
349        self.create_token(&claims)
350    }
351
352    /// Create refresh token
353    pub fn create_refresh_token(&self, subject: String, client_id: String) -> Result<String> {
354        // Refresh tokens typically have longer expiration
355        let exp = SystemTime::now()
356            .duration_since(UNIX_EPOCH)
357            .unwrap()
358            .as_secs() as i64
359            + (self.config.default_expiration * 24) as i64; // 24x longer
360
361        let mut claims = CommonJwtClaims::new(
362            self.config.issuer.clone(),
363            subject,
364            self.config.audiences.clone(),
365            exp,
366        );
367
368        claims.custom.insert(
369            "client_id".to_string(),
370            serde_json::Value::String(client_id),
371        );
372        claims.custom.insert(
373            "token_type".to_string(),
374            serde_json::Value::String("refresh_token".to_string()),
375        );
376
377        self.create_token(&claims)
378    }
379
380    /// Create ID token for OpenID Connect
381    pub fn create_id_token(
382        &self,
383        subject: String,
384        nonce: Option<String>,
385        auth_time: Option<i64>,
386        user_info: HashMap<String, serde_json::Value>,
387    ) -> Result<String> {
388        let exp = SystemTime::now()
389            .duration_since(UNIX_EPOCH)
390            .unwrap()
391            .as_secs() as i64
392            + 300; // 5 minutes for ID token
393
394        let mut claims = CommonJwtClaims::new(
395            self.config.issuer.clone(),
396            subject,
397            self.config.audiences.clone(),
398            exp,
399        );
400
401        claims.custom.insert(
402            "token_type".to_string(),
403            serde_json::Value::String("id_token".to_string()),
404        );
405
406        if let Some(nonce) = nonce {
407            claims
408                .custom
409                .insert("nonce".to_string(), serde_json::Value::String(nonce));
410        }
411
412        if let Some(auth_time) = auth_time {
413            claims.custom.insert(
414                "auth_time".to_string(),
415                serde_json::Value::Number(auth_time.into()),
416            );
417        }
418
419        // Add user info claims
420        for (key, value) in user_info {
421            claims.custom.insert(key, value);
422        }
423
424        self.create_token(&claims)
425    }
426}
427
428/// JWT utilities for token introspection and manipulation
429pub mod utils {
430    use super::*;
431
432    /// Extract claims from JWT without verification (for inspection only)
433    ///
434    /// # Security Warning
435    /// This function bypasses JWT signature verification! Only use for:
436    /// - Token inspection and debugging
437    /// - Extracting metadata before full validation
438    /// - Non-security-critical token analysis
439    ///
440    /// Never use for authentication or authorization decisions!
441    pub fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
442        common_validation::jwt::extract_claims_unsafe(token)
443    }
444
445    /// Check if token is expired without full verification
446    ///
447    /// # Security Warning
448    /// This function checks expiration without validating the JWT signature.
449    /// Only use for preliminary checks - always validate the token fully
450    /// before making security decisions!
451    pub fn is_token_expired(token: &str) -> Result<bool> {
452        let claims = extract_claims_unsafe(token)?;
453
454        let now = SystemTime::now()
455            .duration_since(UNIX_EPOCH)
456            .unwrap()
457            .as_secs() as i64;
458
459        if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) {
460            Ok(now >= exp)
461        } else {
462            Ok(false) // No expiration claim means not expired
463        }
464    }
465
466    /// Get token expiration time without signature validation
467    ///
468    /// # Security Warning
469    /// This function extracts expiration time without validating the JWT signature.
470    /// Only use for inspection - validate the token before trusting the data!
471    pub fn get_token_expiration(token: &str) -> Result<Option<i64>> {
472        let claims = extract_claims_unsafe(token)?;
473        Ok(claims.get("exp").and_then(|v| v.as_i64()))
474    }
475
476    /// Get token subject without signature validation
477    ///
478    /// # Security Warning
479    /// This function extracts the subject without validating the JWT signature.
480    /// Only use for inspection - validate the token before trusting the data!
481    pub fn get_token_subject(token: &str) -> Result<Option<String>> {
482        let claims = extract_claims_unsafe(token)?;
483        Ok(claims.get("sub").and_then(|v| v.as_str()).map(String::from))
484    }
485
486    /// Get token scopes without signature validation
487    ///
488    /// # Security Warning
489    /// This function extracts scopes without validating the JWT signature.
490    /// Only use for inspection - validate the token before trusting the data!
491    pub fn get_token_scopes(token: &str) -> Result<Vec<String>> {
492        let claims = extract_claims_unsafe(token)?;
493
494        if let Some(scope_str) = claims.get("scope").and_then(|v| v.as_str()) {
495            Ok(scope_str.split_whitespace().map(String::from).collect())
496        } else if let Some(scopes_array) = claims.get("scopes").and_then(|v| v.as_array()) {
497            Ok(scopes_array
498                .iter()
499                .filter_map(|v| v.as_str())
500                .map(String::from)
501                .collect())
502        } else {
503            Ok(vec![])
504        }
505    }
506}