allframe_core/auth/
jwt.rs

1//! JWT (JSON Web Token) validation.
2//!
//! Provides JWT validation using the `jsonwebtoken` crate with support for
//! HMAC (HS256/HS384/HS512), RSA (RS256/RS384/RS512), and EdDSA algorithms.
5//!
6//! # Example
7//!
8//! ```rust,ignore
9//! use allframe_core::auth::{JwtValidator, JwtConfig, Authenticator};
10//! use serde::Deserialize;
11//!
12//! #[derive(Debug, Clone, Deserialize)]
13//! struct MyClaims {
14//!     sub: String,
15//!     email: Option<String>,
16//!     role: Option<String>,
17//! }
18//!
19//! // Create validator with HS256
20//! let config = JwtConfig::hs256("my-secret-key")
21//!     .with_issuer("my-app")
22//!     .with_leeway(60);
23//!
24//! let validator = JwtValidator::<MyClaims>::new(config);
25//!
26//! // Validate a token
27//! let claims = validator.authenticate("eyJ...").await?;
28//! println!("User: {}", claims.sub);
29//! ```
30
31use std::marker::PhantomData;
32
33use jsonwebtoken::{decode, DecodingKey, TokenData, Validation};
34use serde::de::DeserializeOwned;
35
36use super::{AuthError, Authenticator};
37
/// JWT algorithm to use for validation, together with its key material.
///
/// HMAC variants carry the shared secret; RSA and EdDSA variants carry the
/// public key in PEM format.
///
/// The [`Debug`] implementation redacts HMAC secrets so that logging a
/// `JwtAlgorithm` (or any config that contains one) never exposes them.
#[derive(Clone)]
pub enum JwtAlgorithm {
    /// HMAC with SHA-256 (symmetric key).
    HS256(String),
    /// HMAC with SHA-384 (symmetric key).
    HS384(String),
    /// HMAC with SHA-512 (symmetric key).
    HS512(String),
    /// RSA with SHA-256 (public key in PEM format).
    RS256(String),
    /// RSA with SHA-384 (public key in PEM format).
    RS384(String),
    /// RSA with SHA-512 (public key in PEM format).
    RS512(String),
    /// EdDSA (public key in PEM format).
    EdDSA(String),
}

impl std::fmt::Debug for JwtAlgorithm {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Printed in place of symmetric secrets, which must never reach logs.
        const REDACTED: &str = "<redacted>";
        match self {
            JwtAlgorithm::HS256(_) => f.debug_tuple("HS256").field(&REDACTED).finish(),
            JwtAlgorithm::HS384(_) => f.debug_tuple("HS384").field(&REDACTED).finish(),
            JwtAlgorithm::HS512(_) => f.debug_tuple("HS512").field(&REDACTED).finish(),
            // Public keys are not secret; keep them visible for diagnostics.
            JwtAlgorithm::RS256(pem) => f.debug_tuple("RS256").field(pem).finish(),
            JwtAlgorithm::RS384(pem) => f.debug_tuple("RS384").field(pem).finish(),
            JwtAlgorithm::RS512(pem) => f.debug_tuple("RS512").field(pem).finish(),
            JwtAlgorithm::EdDSA(pem) => f.debug_tuple("EdDSA").field(pem).finish(),
        }
    }
}
56
57impl JwtAlgorithm {
58    fn jsonwebtoken_algorithm(&self) -> jsonwebtoken::Algorithm {
59        match self {
60            JwtAlgorithm::HS256(_) => jsonwebtoken::Algorithm::HS256,
61            JwtAlgorithm::HS384(_) => jsonwebtoken::Algorithm::HS384,
62            JwtAlgorithm::HS512(_) => jsonwebtoken::Algorithm::HS512,
63            JwtAlgorithm::RS256(_) => jsonwebtoken::Algorithm::RS256,
64            JwtAlgorithm::RS384(_) => jsonwebtoken::Algorithm::RS384,
65            JwtAlgorithm::RS512(_) => jsonwebtoken::Algorithm::RS512,
66            JwtAlgorithm::EdDSA(_) => jsonwebtoken::Algorithm::EdDSA,
67        }
68    }
69
70    fn decoding_key(&self) -> Result<DecodingKey, AuthError> {
71        match self {
72            JwtAlgorithm::HS256(secret)
73            | JwtAlgorithm::HS384(secret)
74            | JwtAlgorithm::HS512(secret) => Ok(DecodingKey::from_secret(secret.as_bytes())),
75            JwtAlgorithm::RS256(pem)
76            | JwtAlgorithm::RS384(pem)
77            | JwtAlgorithm::RS512(pem) => DecodingKey::from_rsa_pem(pem.as_bytes())
78                .map_err(|e| AuthError::Internal(format!("Invalid RSA key: {}", e))),
79            JwtAlgorithm::EdDSA(pem) => DecodingKey::from_ed_pem(pem.as_bytes())
80                .map_err(|e| AuthError::Internal(format!("Invalid EdDSA key: {}", e))),
81        }
82    }
83}
84
85/// Configuration for JWT validation.
86#[derive(Debug, Clone)]
87pub struct JwtConfig {
88    /// Algorithm and key for validation.
89    pub algorithm: JwtAlgorithm,
90    /// Expected issuer (iss claim).
91    pub issuer: Option<String>,
92    /// Expected audience (aud claim).
93    pub audience: Option<String>,
94    /// Leeway in seconds for expiration checks.
95    pub leeway_seconds: u64,
96    /// Whether to validate expiration.
97    pub validate_exp: bool,
98    /// Whether to validate not-before.
99    pub validate_nbf: bool,
100}
101
102impl JwtConfig {
103    /// Create a new config with HS256 algorithm.
104    ///
105    /// # Example
106    ///
107    /// ```rust
108    /// use allframe_core::auth::JwtConfig;
109    ///
110    /// let config = JwtConfig::hs256("my-secret-key");
111    /// ```
112    pub fn hs256(secret: impl Into<String>) -> Self {
113        Self {
114            algorithm: JwtAlgorithm::HS256(secret.into()),
115            issuer: None,
116            audience: None,
117            leeway_seconds: 0,
118            validate_exp: true,
119            validate_nbf: true,
120        }
121    }
122
123    /// Create a new config with HS384 algorithm.
124    pub fn hs384(secret: impl Into<String>) -> Self {
125        Self {
126            algorithm: JwtAlgorithm::HS384(secret.into()),
127            issuer: None,
128            audience: None,
129            leeway_seconds: 0,
130            validate_exp: true,
131            validate_nbf: true,
132        }
133    }
134
135    /// Create a new config with HS512 algorithm.
136    pub fn hs512(secret: impl Into<String>) -> Self {
137        Self {
138            algorithm: JwtAlgorithm::HS512(secret.into()),
139            issuer: None,
140            audience: None,
141            leeway_seconds: 0,
142            validate_exp: true,
143            validate_nbf: true,
144        }
145    }
146
147    /// Create a new config with RS256 algorithm.
148    ///
149    /// # Arguments
150    /// * `public_key_pem` - RSA public key in PEM format
151    pub fn rs256(public_key_pem: impl Into<String>) -> Self {
152        Self {
153            algorithm: JwtAlgorithm::RS256(public_key_pem.into()),
154            issuer: None,
155            audience: None,
156            leeway_seconds: 0,
157            validate_exp: true,
158            validate_nbf: true,
159        }
160    }
161
162    /// Create a new config with RS384 algorithm.
163    pub fn rs384(public_key_pem: impl Into<String>) -> Self {
164        Self {
165            algorithm: JwtAlgorithm::RS384(public_key_pem.into()),
166            issuer: None,
167            audience: None,
168            leeway_seconds: 0,
169            validate_exp: true,
170            validate_nbf: true,
171        }
172    }
173
174    /// Create a new config with RS512 algorithm.
175    pub fn rs512(public_key_pem: impl Into<String>) -> Self {
176        Self {
177            algorithm: JwtAlgorithm::RS512(public_key_pem.into()),
178            issuer: None,
179            audience: None,
180            leeway_seconds: 0,
181            validate_exp: true,
182            validate_nbf: true,
183        }
184    }
185
186    /// Create a new config with EdDSA algorithm.
187    pub fn eddsa(public_key_pem: impl Into<String>) -> Self {
188        Self {
189            algorithm: JwtAlgorithm::EdDSA(public_key_pem.into()),
190            issuer: None,
191            audience: None,
192            leeway_seconds: 0,
193            validate_exp: true,
194            validate_nbf: true,
195        }
196    }
197
198    /// Create config from environment variables.
199    ///
200    /// Reads:
201    /// - `JWT_SECRET` - HMAC secret (for HS256)
202    /// - `JWT_PUBLIC_KEY` - RSA public key (for RS256, overrides JWT_SECRET)
203    /// - `JWT_ISSUER` - Expected issuer
204    /// - `JWT_AUDIENCE` - Expected audience
205    /// - `JWT_LEEWAY` - Leeway in seconds (default: 60)
206    pub fn from_env() -> Option<Self> {
207        let config = if let Ok(public_key) = std::env::var("JWT_PUBLIC_KEY") {
208            Self::rs256(public_key)
209        } else if let Ok(secret) = std::env::var("JWT_SECRET") {
210            Self::hs256(secret)
211        } else {
212            return None;
213        };
214
215        let config = if let Ok(issuer) = std::env::var("JWT_ISSUER") {
216            config.with_issuer(issuer)
217        } else {
218            config
219        };
220
221        let config = if let Ok(audience) = std::env::var("JWT_AUDIENCE") {
222            config.with_audience(audience)
223        } else {
224            config
225        };
226
227        let config = if let Ok(leeway) = std::env::var("JWT_LEEWAY") {
228            if let Ok(seconds) = leeway.parse() {
229                config.with_leeway(seconds)
230            } else {
231                config
232            }
233        } else {
234            config.with_leeway(60) // Default 60 second leeway
235        };
236
237        Some(config)
238    }
239
240    /// Set the expected issuer.
241    pub fn with_issuer(mut self, issuer: impl Into<String>) -> Self {
242        self.issuer = Some(issuer.into());
243        self
244    }
245
246    /// Set the expected audience.
247    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
248        self.audience = Some(audience.into());
249        self
250    }
251
252    /// Set the leeway for expiration checks (in seconds).
253    pub fn with_leeway(mut self, seconds: u64) -> Self {
254        self.leeway_seconds = seconds;
255        self
256    }
257
258    /// Disable expiration validation.
259    pub fn without_exp_validation(mut self) -> Self {
260        self.validate_exp = false;
261        self
262    }
263
264    /// Disable not-before validation.
265    pub fn without_nbf_validation(mut self) -> Self {
266        self.validate_nbf = false;
267        self
268    }
269}
270
271/// JWT token validator.
272///
273/// Validates JWT tokens and extracts claims into a typed struct.
274///
275/// # Type Parameters
276///
277/// * `C` - The claims type. Must implement `DeserializeOwned`.
278///
279/// # Example
280///
281/// ```rust,ignore
282/// use allframe_core::auth::{JwtValidator, JwtConfig, Authenticator};
283/// use serde::Deserialize;
284///
285/// #[derive(Debug, Clone, Deserialize)]
286/// struct Claims {
287///     sub: String,
288///     exp: i64,
289/// }
290///
291/// let validator = JwtValidator::<Claims>::new(JwtConfig::hs256("secret"));
292/// let claims = validator.authenticate("eyJ...").await?;
293/// ```
294pub struct JwtValidator<C> {
295    config: JwtConfig,
296    decoding_key: DecodingKey,
297    validation: Validation,
298    _phantom: PhantomData<C>,
299}
300
301impl<C> JwtValidator<C> {
302    /// Create a new JWT validator.
303    ///
304    /// # Panics
305    ///
306    /// Panics if the key in the config is invalid (e.g., malformed PEM).
307    /// Use `try_new` for fallible construction.
308    pub fn new(config: JwtConfig) -> Self {
309        Self::try_new(config).expect("Invalid JWT configuration")
310    }
311
312    /// Try to create a new JWT validator.
313    ///
314    /// Returns an error if the key is invalid.
315    pub fn try_new(config: JwtConfig) -> Result<Self, AuthError> {
316        let decoding_key = config.algorithm.decoding_key()?;
317
318        let mut validation = Validation::new(config.algorithm.jsonwebtoken_algorithm());
319        validation.leeway = config.leeway_seconds;
320        validation.validate_exp = config.validate_exp;
321        validation.validate_nbf = config.validate_nbf;
322
323        if let Some(ref iss) = config.issuer {
324            validation.set_issuer(&[iss]);
325        }
326
327        if let Some(ref aud) = config.audience {
328            validation.set_audience(&[aud]);
329        } else {
330            // By default jsonwebtoken requires audience validation
331            // Disable if not configured
332            validation.validate_aud = false;
333        }
334
335        Ok(Self {
336            config,
337            decoding_key,
338            validation,
339            _phantom: PhantomData,
340        })
341    }
342
343    /// Get the configuration.
344    pub fn config(&self) -> &JwtConfig {
345        &self.config
346    }
347}
348
349impl<C: Clone + Send + Sync + DeserializeOwned + 'static> JwtValidator<C> {
350    /// Validate a token and return the claims.
351    pub fn validate(&self, token: &str) -> Result<C, AuthError> {
352        let token_data: TokenData<C> =
353            decode(token, &self.decoding_key, &self.validation).map_err(|e| {
354                use jsonwebtoken::errors::ErrorKind;
355                match e.kind() {
356                    ErrorKind::ExpiredSignature => AuthError::TokenExpired,
357                    ErrorKind::InvalidSignature => AuthError::InvalidSignature,
358                    ErrorKind::InvalidIssuer => AuthError::InvalidIssuer,
359                    ErrorKind::InvalidAudience => AuthError::InvalidAudience,
360                    ErrorKind::InvalidToken => AuthError::InvalidToken("malformed token".into()),
361                    ErrorKind::InvalidAlgorithm => {
362                        AuthError::InvalidToken("wrong algorithm".into())
363                    }
364                    _ => AuthError::InvalidToken(e.to_string()),
365                }
366            })?;
367
368        Ok(token_data.claims)
369    }
370}
371
372#[async_trait::async_trait]
373impl<C: Clone + Send + Sync + DeserializeOwned + 'static> Authenticator for JwtValidator<C> {
374    type Claims = C;
375
376    async fn authenticate(&self, token: &str) -> Result<Self::Claims, AuthError> {
377        // JWT validation is synchronous, but we implement async trait for consistency
378        self.validate(token)
379    }
380}
381
382/// Standard JWT claims structure.
383///
384/// Use this if you don't need custom claims, or as a base for your own type.
385#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
386pub struct StandardClaims {
387    /// Subject (usually user ID).
388    pub sub: String,
389    /// Expiration time (Unix timestamp).
390    #[serde(default)]
391    pub exp: Option<i64>,
392    /// Issued at (Unix timestamp).
393    #[serde(default)]
394    pub iat: Option<i64>,
395    /// Not before (Unix timestamp).
396    #[serde(default)]
397    pub nbf: Option<i64>,
398    /// Issuer.
399    #[serde(default)]
400    pub iss: Option<String>,
401    /// Audience.
402    #[serde(default)]
403    pub aud: Option<String>,
404}
405
406impl super::HasSubject for StandardClaims {
407    fn subject(&self) -> &str {
408        &self.sub
409    }
410}
411
412impl super::HasExpiration for StandardClaims {
413    fn expiration(&self) -> Option<i64> {
414        self.exp
415    }
416}
417
#[cfg(test)]
mod tests {
    use super::*;

    // HMAC secret used both to sign test tokens and to configure the validators
    // under test.
    const TEST_SECRET: &str = "test-secret-that-is-long-enough-for-hs256";

    /// Signs `claims` with `TEST_SECRET` using the default (HS256) header.
    fn create_test_token(claims: &impl serde::Serialize) -> String {
        use jsonwebtoken::{encode, EncodingKey, Header};
        encode(
            &Header::default(),
            claims,
            &EncodingKey::from_secret(TEST_SECRET.as_bytes()),
        )
        .unwrap()
    }

    /// Unix timestamp one hour from now, for tokens that must not be expired.
    fn one_hour_from_now() -> i64 {
        (chrono::Utc::now() + chrono::Duration::hours(1)).timestamp()
    }

    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
    struct TestClaims {
        sub: String,
        exp: i64,
    }

    /// Claims that also carry `iss` and `aud`, for issuer/audience checks.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize, PartialEq)]
    struct ScopedClaims {
        sub: String,
        exp: i64,
        iss: String,
        aud: String,
    }

    /// Builds a valid, unexpired token with the given issuer and audience.
    fn scoped_token(iss: &str, aud: &str) -> String {
        create_test_token(&ScopedClaims {
            sub: "user123".to_string(),
            exp: one_hour_from_now(),
            iss: iss.to_string(),
            aud: aud.to_string(),
        })
    }

    #[test]
    fn test_jwt_config_hs256() {
        let config = JwtConfig::hs256("secret");
        assert!(matches!(config.algorithm, JwtAlgorithm::HS256(_)));
        assert!(config.issuer.is_none());
        assert!(config.audience.is_none());
    }

    #[test]
    fn test_jwt_config_constructors_select_algorithm() {
        assert!(matches!(JwtConfig::hs384("s").algorithm, JwtAlgorithm::HS384(_)));
        assert!(matches!(JwtConfig::hs512("s").algorithm, JwtAlgorithm::HS512(_)));
        assert!(matches!(JwtConfig::rs256("k").algorithm, JwtAlgorithm::RS256(_)));
        assert!(matches!(JwtConfig::rs384("k").algorithm, JwtAlgorithm::RS384(_)));
        assert!(matches!(JwtConfig::rs512("k").algorithm, JwtAlgorithm::RS512(_)));
        assert!(matches!(JwtConfig::eddsa("k").algorithm, JwtAlgorithm::EdDSA(_)));
    }

    #[test]
    fn test_jwt_config_defaults_and_disable_flags() {
        let config = JwtConfig::hs256("secret");
        assert_eq!(config.leeway_seconds, 0);
        assert!(config.validate_exp);
        assert!(config.validate_nbf);

        let config = config.without_exp_validation().without_nbf_validation();
        assert!(!config.validate_exp);
        assert!(!config.validate_nbf);
    }

    #[test]
    fn test_jwt_config_builder() {
        let config = JwtConfig::hs256("secret")
            .with_issuer("my-app")
            .with_audience("my-audience")
            .with_leeway(120);

        assert_eq!(config.issuer, Some("my-app".to_string()));
        assert_eq!(config.audience, Some("my-audience".to_string()));
        assert_eq!(config.leeway_seconds, 120);
    }

    #[test]
    fn test_jwt_validator_valid_token() {
        let claims = TestClaims {
            sub: "user123".to_string(),
            exp: one_hour_from_now(),
        };
        let token = create_test_token(&claims);

        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256(TEST_SECRET));

        let decoded = validator
            .validate(&token)
            .expect("a correctly signed, unexpired token should validate");
        assert_eq!(decoded.sub, "user123");
    }

    #[test]
    fn test_jwt_validator_expired_token() {
        let claims = TestClaims {
            sub: "user123".to_string(),
            exp: 0, // Expired long ago
        };
        let token = create_test_token(&claims);

        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256(TEST_SECRET));

        let result = validator.validate(&token);
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_jwt_validator_wrong_secret() {
        let claims = TestClaims {
            sub: "user123".to_string(),
            exp: one_hour_from_now(),
        };
        let token = create_test_token(&claims);

        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256("wrong-secret"));

        let result = validator.validate(&token);
        assert!(matches!(result, Err(AuthError::InvalidSignature)));
    }

    #[test]
    fn test_jwt_validator_wrong_algorithm() {
        // Token is signed with HS256; the validator only accepts HS512.
        let claims = TestClaims {
            sub: "user123".to_string(),
            exp: one_hour_from_now(),
        };
        let token = create_test_token(&claims);

        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs512(TEST_SECRET));

        let result = validator.validate(&token);
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_jwt_validator_invalid_token() {
        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256(TEST_SECRET));

        let result = validator.validate("not-a-jwt");
        assert!(matches!(result, Err(AuthError::InvalidToken(_))));
    }

    #[test]
    fn test_jwt_validator_issuer_mismatch() {
        let token = scoped_token("other-app", "any-audience");
        let validator =
            JwtValidator::<ScopedClaims>::new(JwtConfig::hs256(TEST_SECRET).with_issuer("my-app"));

        let result = validator.validate(&token);
        assert!(matches!(result, Err(AuthError::InvalidIssuer)));
    }

    #[test]
    fn test_jwt_validator_audience_mismatch() {
        let token = scoped_token("my-app", "other-audience");
        let validator = JwtValidator::<ScopedClaims>::new(
            JwtConfig::hs256(TEST_SECRET).with_audience("my-audience"),
        );

        let result = validator.validate(&token);
        assert!(matches!(result, Err(AuthError::InvalidAudience)));
    }

    #[test]
    fn test_jwt_validator_issuer_and_audience_match() {
        let token = scoped_token("my-app", "my-audience");
        let validator = JwtValidator::<ScopedClaims>::new(
            JwtConfig::hs256(TEST_SECRET)
                .with_issuer("my-app")
                .with_audience("my-audience"),
        );

        let decoded = validator
            .validate(&token)
            .expect("matching issuer and audience should validate");
        assert_eq!(decoded.iss, "my-app");
        assert_eq!(decoded.aud, "my-audience");
    }

    #[tokio::test]
    async fn test_jwt_validator_authenticator_trait() {
        let claims = TestClaims {
            sub: "user123".to_string(),
            exp: one_hour_from_now(),
        };
        let token = create_test_token(&claims);

        let validator = JwtValidator::<TestClaims>::new(JwtConfig::hs256(TEST_SECRET));

        // Use via Authenticator trait
        let result = validator.authenticate(&token).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().sub, "user123");
    }

    #[test]
    fn test_standard_claims() {
        let claims = StandardClaims {
            sub: "user123".to_string(),
            exp: Some(1234567890),
            iat: Some(1234567800),
            nbf: None,
            iss: Some("my-app".to_string()),
            aud: None,
        };

        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.exp, Some(1234567890));

        // Test HasSubject trait
        use super::super::HasSubject;
        assert_eq!(claims.subject(), "user123");

        // Test HasExpiration trait
        use super::super::HasExpiration;
        assert_eq!(claims.expiration(), Some(1234567890));
    }
}