Skip to main content

auth_framework/server/core/
common_jwt.rs

1//! Common JWT Operations
2//!
3//! This module provides shared JWT functionality to eliminate
4//! duplication across server modules.
5
6use crate::errors::{AuthError, Result};
7use crate::server::core::common_validation;
8use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
9use serde::{Deserialize, Serialize};
10use std::collections::HashMap;
11use std::time::{SystemTime, UNIX_EPOCH};
12
13/// Common JWT configuration
14#[derive(Clone)]
15pub struct JwtConfig {
16    /// Signing algorithm
17    pub algorithm: Algorithm,
18    /// Signing key
19    pub signing_key: EncodingKey,
20    /// Verification key
21    pub verification_key: DecodingKey,
22    /// Default expiration time in seconds
23    pub default_expiration: u64,
24    /// Issuer
25    pub issuer: String,
26    /// Audiences
27    pub audiences: Vec<String>,
28}
29
30impl JwtConfig {
31    /// Create new JWT config with symmetric key
32    pub fn with_symmetric_key(secret: &[u8], issuer: String) -> Self {
33        Self {
34            algorithm: Algorithm::HS256,
35            signing_key: EncodingKey::from_secret(secret),
36            verification_key: DecodingKey::from_secret(secret),
37            default_expiration: 3600, // 1 hour
38            issuer,
39            audiences: vec![],
40        }
41    }
42
43    /// Create new JWT config with RSA keys
44    pub fn with_rsa_keys(private_key: &[u8], public_key: &[u8], issuer: String) -> Result<Self> {
45        let signing_key = EncodingKey::from_rsa_pem(private_key)
46            .map_err(|e| AuthError::validation(format!("Invalid private key: {}", e)))?;
47
48        let verification_key = DecodingKey::from_rsa_pem(public_key)
49            .map_err(|e| AuthError::validation(format!("Invalid public key: {}", e)))?;
50
51        Ok(Self {
52            algorithm: Algorithm::RS256,
53            signing_key,
54            verification_key,
55            default_expiration: 3600, // 1 hour
56            issuer,
57            audiences: vec![],
58        })
59    }
60
61    /// Add audience
62    pub fn with_audience(mut self, audience: String) -> Self {
63        self.audiences.push(audience);
64        self
65    }
66
67    /// Set expiration time
68    pub fn with_expiration(mut self, expiration: u64) -> Self {
69        self.default_expiration = expiration;
70        self
71    }
72}
73
74/// Common JWT claims structure
75#[derive(Debug, Clone, Serialize, Deserialize)]
76pub struct CommonJwtClaims {
77    /// Issuer
78    pub iss: String,
79    /// Subject
80    pub sub: String,
81    /// Audience
82    pub aud: Vec<String>,
83    /// Expiration time
84    pub exp: i64,
85    /// Issued at
86    pub iat: i64,
87    /// Not before
88    pub nbf: Option<i64>,
89    /// JWT ID
90    pub jti: Option<String>,
91    /// Custom claims
92    #[serde(flatten)]
93    pub custom: HashMap<String, serde_json::Value>,
94}
95
96impl CommonJwtClaims {
97    /// Create new claims with required fields
98    pub fn new(issuer: String, subject: String, audiences: Vec<String>, expiration: i64) -> Self {
99        let now = SystemTime::now()
100            .duration_since(UNIX_EPOCH)
101            .unwrap_or_default()
102            .as_secs() as i64;
103
104        Self {
105            iss: issuer,
106            sub: subject,
107            aud: audiences,
108            exp: expiration,
109            iat: now,
110            nbf: None,
111            jti: None,
112            custom: HashMap::new(),
113        }
114    }
115
116    /// Add custom claim
117    pub fn with_custom_claim(mut self, key: String, value: serde_json::Value) -> Self {
118        self.custom.insert(key, value);
119        self
120    }
121
122    /// Set JWT ID
123    pub fn with_jti(mut self, jti: String) -> Self {
124        self.jti = Some(jti);
125        self
126    }
127
128    /// Set not before
129    pub fn with_nbf(mut self, nbf: i64) -> Self {
130        self.nbf = Some(nbf);
131        self
132    }
133}
134
135/// Common JWT token management for OAuth 2.0 and OpenID Connect operations.
136///
137/// `JwtManager` provides comprehensive JWT token creation, verification, and
138/// management capabilities specifically designed for OAuth 2.0 authorization
139/// servers and OpenID Connect providers. It supports both symmetric and
140/// asymmetric signing algorithms with security best practices.
141///
142/// # Supported Algorithms
143///
144/// - **HMAC**: HS256, HS384, HS512 (symmetric)
145/// - **RSA**: RS256, RS384, RS512 (asymmetric)
146/// - **ECDSA**: ES256, ES384, ES512 (asymmetric)
147/// - **EdDSA**: EdDSA (asymmetric, Ed25519)
148///
149/// # Security Features
150///
151/// - **Algorithm Validation**: Prevents algorithm confusion attacks
152/// - **Time Validation**: Automatic `exp`, `nbf`, and `iat` claim validation
153/// - **Audience Validation**: Ensures tokens are used by intended recipients
154/// - **Issuer Validation**: Verifies token origin
155/// - **Secure Defaults**: Uses secure algorithm choices and expiration times
156///
157/// # Token Types Supported
158///
159/// - **Access Tokens**: OAuth 2.0 access tokens with scopes
160/// - **ID Tokens**: OpenID Connect identity tokens
161/// - **Refresh Tokens**: Long-lived tokens for access token renewal
162/// - **Custom Tokens**: Application-specific token types
163///
164/// # Key Management
165///
166/// - **Symmetric Keys**: HMAC-based signing with shared secrets
167/// - **RSA Keys**: Support for PKCS#1 and PKCS#8 key formats
168/// - **Key Rotation**: Support for multiple signing keys
169/// - **Key Security**: Secure key storage and access patterns
170///
171/// # Example
172///
173/// ```rust,no_run
174/// use auth_framework::server::core::common_jwt::{JwtManager, JwtConfig, CommonJwtClaims};
175///
176/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
177/// # let private_key_bytes: &[u8] = unimplemented!();
178/// # let public_key_bytes: &[u8] = unimplemented!();
179/// # let expiration_time: i64 = unimplemented!();
180/// // Create JWT manager with RSA keys
181/// let config = JwtConfig::with_rsa_keys(
182///     private_key_bytes,
183///     public_key_bytes,
184///     "https://auth.example.com".to_string()
185/// )?;
186/// let jwt_manager = JwtManager::new(config);
187///
188/// // Create access token
189/// let claims = CommonJwtClaims::new(
190///     "https://auth.example.com".to_string(),
191///     "user123".to_string(),
192///     vec!["api".to_string()],
193///     expiration_time
194/// ).with_custom_claim("scope".to_string(), serde_json::json!("read write"));
195///
196/// let token = jwt_manager.create_token(&claims)?;
197///
198/// // Verify token
199/// let verified_claims = jwt_manager.verify_token(&token)?;
200/// # Ok(())
201/// # }
202/// ```
203///
204/// # Performance Considerations
205///
206/// - Asymmetric algorithms are more computationally expensive
207/// - Token verification is optimized for high-throughput scenarios
208/// - Key caching reduces cryptographic operation overhead
209///
210/// # RFC Compliance
211///
212/// - **RFC 7519**: JSON Web Token (JWT)
213/// - **RFC 7515**: JSON Web Signature (JWS)
214/// - **RFC 8725**: JWT Best Current Practices
215/// - **RFC 9068**: JWT Profile for OAuth 2.0 Access Tokens
216pub struct JwtManager {
217    config: JwtConfig,
218}
219
220impl JwtManager {
221    /// Create new JWT manager
222    pub fn new(config: JwtConfig) -> Self {
223        Self { config }
224    }
225
226    /// Create signed JWT token
227    pub fn create_token(&self, claims: &CommonJwtClaims) -> Result<String> {
228        let header = Header {
229            alg: self.config.algorithm,
230            ..Default::default()
231        };
232
233        encode(&header, claims, &self.config.signing_key)
234            .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
235    }
236
237    /// Create signed token with custom claims
238    pub fn create_token_with_custom_claims<T>(&self, claims: &T) -> Result<String>
239    where
240        T: Serialize,
241    {
242        let header = Header {
243            alg: self.config.algorithm,
244            ..Default::default()
245        };
246
247        encode(&header, claims, &self.config.signing_key)
248            .map_err(|e| AuthError::validation(format!("Failed to encode JWT: {}", e)))
249    }
250
251    /// Verify and decode JWT token
252    pub fn verify_token(&self, token: &str) -> Result<CommonJwtClaims> {
253        // Basic format validation
254        common_validation::jwt::validate_jwt_format(token)?;
255
256        let mut validation = Validation::new(self.config.algorithm);
257        validation.set_issuer(&[&self.config.issuer]);
258
259        if !self.config.audiences.is_empty() {
260            validation.set_audience(
261                &self
262                    .config
263                    .audiences
264                    .iter()
265                    .map(String::as_str)
266                    .collect::<Vec<_>>(),
267            );
268        }
269
270        let token_data =
271            decode::<CommonJwtClaims>(token, &self.config.verification_key, &validation)
272                .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
273
274        // Additional validation using common validation utilities
275        let claims_value = serde_json::to_value(&token_data.claims)
276            .map_err(|e| AuthError::validation(format!("Failed to serialize claims: {}", e)))?;
277
278        common_validation::jwt::validate_time_claims(&claims_value)?;
279
280        Ok(token_data.claims)
281    }
282
283    /// Verify token and extract custom claims
284    pub fn verify_token_with_custom_claims<T>(&self, token: &str) -> Result<T>
285    where
286        T: for<'de> Deserialize<'de>,
287    {
288        common_validation::jwt::validate_jwt_format(token)?;
289
290        let mut validation = Validation::new(self.config.algorithm);
291        validation.set_issuer(&[&self.config.issuer]);
292
293        if !self.config.audiences.is_empty() {
294            validation.set_audience(
295                &self
296                    .config
297                    .audiences
298                    .iter()
299                    .map(String::as_str)
300                    .collect::<Vec<_>>(),
301            );
302        }
303
304        let token_data = decode::<T>(token, &self.config.verification_key, &validation)
305            .map_err(|e| AuthError::validation(format!("Invalid JWT: {}", e)))?;
306
307        Ok(token_data.claims)
308    }
309
310    /// Create access token with standard claims
311    pub fn create_access_token(
312        &self,
313        subject: String,
314        scope: Vec<String>,
315        client_id: Option<String>,
316    ) -> Result<String> {
317        let exp = SystemTime::now()
318            .duration_since(UNIX_EPOCH)
319            .unwrap_or_default()
320            .as_secs() as i64
321            + self.config.default_expiration as i64;
322
323        let mut claims = CommonJwtClaims::new(
324            self.config.issuer.clone(),
325            subject,
326            self.config.audiences.clone(),
327            exp,
328        );
329
330        claims
331            .custom
332            .insert("scope".to_string(), serde_json::json!(scope.join(" ")));
333
334        if let Some(client_id) = client_id {
335            claims.custom.insert(
336                "client_id".to_string(),
337                serde_json::Value::String(client_id),
338            );
339        }
340
341        claims.custom.insert(
342            "token_type".to_string(),
343            serde_json::Value::String("access_token".to_string()),
344        );
345
346        self.create_token(&claims)
347    }
348
349    /// Create refresh token
350    pub fn create_refresh_token(&self, subject: String, client_id: String) -> Result<String> {
351        // Refresh tokens typically have longer expiration
352        let exp = SystemTime::now()
353            .duration_since(UNIX_EPOCH)
354            .unwrap_or_default()
355            .as_secs() as i64
356            + (self.config.default_expiration * 24) as i64; // 24x longer
357
358        let mut claims = CommonJwtClaims::new(
359            self.config.issuer.clone(),
360            subject,
361            self.config.audiences.clone(),
362            exp,
363        );
364
365        claims.custom.insert(
366            "client_id".to_string(),
367            serde_json::Value::String(client_id),
368        );
369        claims.custom.insert(
370            "token_type".to_string(),
371            serde_json::Value::String("refresh_token".to_string()),
372        );
373
374        self.create_token(&claims)
375    }
376
377    /// Create ID token for OpenID Connect
378    pub fn create_id_token(
379        &self,
380        subject: String,
381        nonce: Option<String>,
382        auth_time: Option<i64>,
383        user_info: HashMap<String, serde_json::Value>,
384    ) -> Result<String> {
385        let exp = SystemTime::now()
386            .duration_since(UNIX_EPOCH)
387            .unwrap_or_default()
388            .as_secs() as i64
389            + 300; // 5 minutes for ID token
390
391        let mut claims = CommonJwtClaims::new(
392            self.config.issuer.clone(),
393            subject,
394            self.config.audiences.clone(),
395            exp,
396        );
397
398        claims.custom.insert(
399            "token_type".to_string(),
400            serde_json::Value::String("id_token".to_string()),
401        );
402
403        if let Some(nonce) = nonce {
404            claims
405                .custom
406                .insert("nonce".to_string(), serde_json::Value::String(nonce));
407        }
408
409        if let Some(auth_time) = auth_time {
410            claims.custom.insert(
411                "auth_time".to_string(),
412                serde_json::Value::Number(auth_time.into()),
413            );
414        }
415
416        // Add user info claims
417        for (key, value) in user_info {
418            claims.custom.insert(key, value);
419        }
420
421        self.create_token(&claims)
422    }
423}
424
425/// JWT utilities for token introspection and manipulation
426pub(crate) mod utils {
427    use super::*;
428
429    /// Extract claims from JWT without verification (for inspection only)
430    ///
431    /// # Security Warning
432    /// This function bypasses JWT signature verification! Only use for:
433    /// - Token inspection and debugging
434    /// - Extracting metadata before full validation
435    /// - Non-security-critical token analysis
436    ///
437    /// Never use for authentication or authorization decisions!
438    #[allow(dead_code)]
439    pub(crate) fn extract_claims_unsafe(token: &str) -> Result<serde_json::Value> {
440        common_validation::jwt::extract_claims_unsafe(token)
441    }
442
443    /// Check if token is expired without full verification
444    ///
445    /// # Security Warning
446    /// This function checks expiration without validating the JWT signature.
447    /// Only use for preliminary checks - always validate the token fully
448    /// before making security decisions!
449    #[allow(dead_code)]
450    pub(crate) fn is_token_expired(token: &str) -> Result<bool> {
451        let claims = extract_claims_unsafe(token)?;
452
453        let now = SystemTime::now()
454            .duration_since(UNIX_EPOCH)
455            .unwrap_or_default()
456            .as_secs() as i64;
457
458        if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) {
459            Ok(now >= exp)
460        } else {
461            Ok(false) // No expiration claim means not expired
462        }
463    }
464
465    /// Get token expiration time without signature validation
466    ///
467    /// # Security Warning
468    /// This function extracts expiration time without validating the JWT signature.
469    /// Only use for inspection - validate the token before trusting the data!
470    #[allow(dead_code)]
471    pub(crate) fn get_token_expiration(token: &str) -> Result<Option<i64>> {
472        let claims = extract_claims_unsafe(token)?;
473        Ok(claims.get("exp").and_then(|v| v.as_i64()))
474    }
475
476    /// Get token subject without signature validation
477    ///
478    /// # Security Warning
479    /// This function extracts the subject without validating the JWT signature.
480    /// Only use for inspection - validate the token before trusting the data!
481    #[allow(dead_code)]
482    pub(crate) fn get_token_subject(token: &str) -> Result<Option<String>> {
483        let claims = extract_claims_unsafe(token)?;
484        Ok(claims.get("sub").and_then(|v| v.as_str()).map(String::from))
485    }
486
487    /// Get token scopes without signature validation
488    ///
489    /// # Security Warning
490    /// This function extracts scopes without validating the JWT signature.
491    /// Only use for inspection - validate the token before trusting the data!
492    #[allow(dead_code)]
493    pub(crate) fn get_token_scopes(token: &str) -> Result<Vec<String>> {
494        let claims = extract_claims_unsafe(token)?;
495
496        if let Some(scope_str) = claims.get("scope").and_then(|v| v.as_str()) {
497            Ok(scope_str.split_whitespace().map(String::from).collect())
498        } else if let Some(scopes_array) = claims.get("scopes").and_then(|v| v.as_array()) {
499            Ok(scopes_array
500                .iter()
501                .filter_map(|v| v.as_str())
502                .map(String::from)
503                .collect())
504        } else {
505            Ok(vec![])
506        }
507    }
508}
509
#[cfg(test)]
mod tests {
    use super::*;

    /// Issuer shared by the manager fixture and the claims it must accept.
    const TEST_ISSUER: &str = "https://test-issuer.example.com";

    /// HS256 manager with a fixed test secret, issuing as [`TEST_ISSUER`].
    fn make_manager() -> JwtManager {
        JwtManager::new(JwtConfig::with_symmetric_key(
            b"a-test-secret-key-with-enough-bytes-for-hmac",
            TEST_ISSUER.into(),
        ))
    }

    /// Unix timestamp one hour from now, suitable for `exp`.
    fn one_hour_from_now() -> i64 {
        let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();
        (since_epoch.as_secs() + 3600) as i64
    }

    /// Claims for `subject` from [`TEST_ISSUER`] with no audience, valid for an hour.
    fn valid_claims(subject: &str) -> CommonJwtClaims {
        CommonJwtClaims::new(TEST_ISSUER.into(), subject.into(), vec![], one_hour_from_now())
    }

    // ── JwtConfig ───────────────────────────────────────────────────────

    #[test]
    fn test_jwt_config_symmetric() {
        let cfg = JwtConfig::with_symmetric_key(b"secret", "iss".into());
        assert_eq!(cfg.issuer, "iss");
        assert_eq!(cfg.default_expiration, 3600);
    }

    #[test]
    fn test_jwt_config_with_audience() {
        let cfg = JwtConfig::with_symmetric_key(b"secret", "iss".into());
        let cfg = cfg.with_audience("aud1".into());
        assert_eq!(cfg.audiences, vec!["aud1"]);
    }

    #[test]
    fn test_jwt_config_with_expiration() {
        let cfg = JwtConfig::with_symmetric_key(b"secret", "iss".into());
        let cfg = cfg.with_expiration(7200);
        assert_eq!(cfg.default_expiration, 7200);
    }

    // ── CommonJwtClaims ─────────────────────────────────────────────────

    #[test]
    fn test_claims_new() {
        let c = CommonJwtClaims::new(
            "issuer".into(),
            "subject".into(),
            vec!["aud".into()],
            9999999999,
        );
        assert_eq!(c.iss, "issuer");
        assert_eq!(c.sub, "subject");
        assert!(c.iat > 0);
    }

    #[test]
    fn test_claims_with_custom_claim() {
        let base = CommonJwtClaims::new("iss".into(), "sub".into(), vec![], 9999999999);
        let c = base.with_custom_claim("role".to_string(), serde_json::json!("admin"));
        assert_eq!(c.custom.get("role").unwrap(), "admin");
    }

    #[test]
    fn test_claims_with_jti() {
        let base = CommonJwtClaims::new("iss".into(), "sub".into(), vec![], 9999999999);
        let c = base.with_jti("test-jti-value".into());
        assert!(c.jti.is_some());
    }

    // ── JwtManager create/verify ────────────────────────────────────────

    #[test]
    fn test_create_and_verify_token() {
        let mgr = make_manager();
        let token = mgr.create_token(&valid_claims("user_123")).unwrap();
        let verified = mgr.verify_token(&token).unwrap();
        assert_eq!(verified.sub, "user_123");
    }

    #[test]
    fn test_verify_invalid_token() {
        assert!(make_manager().verify_token("not.a.valid.jwt").is_err());
    }

    #[test]
    fn test_verify_wrong_key() {
        let signer = make_manager();
        let other = JwtManager::new(JwtConfig::with_symmetric_key(
            b"different-key-entirely-for-testing",
            TEST_ISSUER.into(),
        ));
        let token = signer.create_token(&valid_claims("user")).unwrap();
        assert!(other.verify_token(&token).is_err());
    }

    // ── Specialized token creation ──────────────────────────────────────

    #[test]
    fn test_create_access_token() {
        let mgr = make_manager();
        let token = mgr
            .create_access_token("user_1".into(), vec!["read".into()], Some("client_1".into()))
            .unwrap();
        let c = mgr.verify_token(&token).unwrap();
        assert_eq!(c.sub, "user_1");
        assert!(c.custom.contains_key("scope"));
    }

    #[test]
    fn test_create_refresh_token() {
        let mgr = make_manager();
        let token = mgr
            .create_refresh_token("user_2".into(), "client_2".into())
            .unwrap();
        let c = mgr.verify_token(&token).unwrap();
        assert_eq!(c.sub, "user_2");
        assert_eq!(
            c.custom.get("token_type").unwrap(),
            &serde_json::json!("refresh_token")
        );
    }

    #[test]
    fn test_create_id_token() {
        let mgr = make_manager();
        let profile = HashMap::from([
            ("name".into(), serde_json::json!("Test User")),
            ("email".into(), serde_json::json!("test@example.com")),
        ]);
        let token = mgr
            .create_id_token("user_3".into(), Some("nonce_123".into()), None, profile)
            .unwrap();
        let c = mgr.verify_token(&token).unwrap();
        assert_eq!(c.sub, "user_3");
        assert_eq!(c.custom.get("nonce").unwrap(), "nonce_123");
        assert_eq!(
            c.custom.get("token_type").unwrap(),
            &serde_json::json!("id_token")
        );
    }

    // ── Utils ───────────────────────────────────────────────────────────

    #[test]
    fn test_extract_claims_unsafe_works() {
        let token = make_manager()
            .create_token(&valid_claims("peek_user"))
            .unwrap();
        let extracted = utils::extract_claims_unsafe(&token).unwrap();
        assert_eq!(extracted["sub"], "peek_user");
    }

    #[test]
    fn test_is_token_expired_not_expired() {
        let token = make_manager()
            .create_access_token("user".into(), vec![], None)
            .unwrap();
        assert!(!utils::is_token_expired(&token).unwrap());
    }
}