Skip to main content

fraiseql_auth/
jwt.rs

//! JWT validation, claims parsing, and token generation.
2use std::collections::HashMap;
3
4use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8    audit::logger::{AuditEventType, SecretType, get_audit_logger},
9    error::{AuthError, Result},
10};
11
12/// Standard JWT claims with support for custom claims
13#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14pub struct Claims {
15    /// Subject (typically user ID)
16    pub sub:   String,
17    /// Issued at (Unix timestamp)
18    pub iat:   u64,
19    /// Expiration time (Unix timestamp)
20    pub exp:   u64,
21    /// Issuer
22    pub iss:   String,
23    /// Audience
24    pub aud:   Vec<String>,
25    /// Additional custom claims
26    #[serde(flatten)]
27    pub extra: HashMap<String, serde_json::Value>,
28}
29
30impl Claims {
31    /// Get a custom claim by name
32    pub fn get_custom(&self, key: &str) -> Option<&serde_json::Value> {
33        self.extra.get(key)
34    }
35
36    /// Check if token is expired
37    ///
38    /// SECURITY: If system time cannot be determined, returns true (treats token as expired)
39    /// This is a fail-safe approach to prevent accepting tokens when we can't verify expiry
40    pub fn is_expired(&self) -> bool {
41        let now = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
42            Ok(duration) => duration.as_secs(),
43            Err(e) => {
44                // CRITICAL: System time failure - treat token as expired (fail-safe)
45                // Log this critical error for operators to investigate
46                tracing::error!(
47                    error = %e,
48                    "CRITICAL: System time error in token expiry check — \
49                     this indicates a system clock issue. Token rejected as safety measure."
50                );
51                // Return current time as far in the future to ensure token is expired
52                u64::MAX
53            },
54        };
55        self.exp <= now
56    }
57}
58
59/// JWT validator configuration and validation logic
60pub struct JwtValidator {
61    validation: Validation,
62    issuer:     String,
63}
64
65impl JwtValidator {
66    /// Create a new JWT validator for a specific issuer
67    ///
68    /// # Arguments
69    /// * `issuer` - The expected issuer URL
70    /// * `algorithm` - The signing algorithm (e.g., RS256, HS256)
71    ///
72    /// # Errors
73    /// Returns error if configuration is invalid
74    pub fn new(issuer: &str, algorithm: Algorithm) -> Result<Self> {
75        if issuer.is_empty() {
76            return Err(AuthError::ConfigError {
77                message: "Issuer cannot be empty".to_string(),
78            });
79        }
80
81        let mut validation = Validation::new(algorithm);
82        validation.set_issuer(&[issuer]);
83        // Require the `aud` claim to be present in every token.
84        // `validate_aud = true` without a configured expected audience means any non-empty
85        // `aud` value is accepted; callers should further restrict this by calling
86        // `with_audiences()` to pin the validator to specific service audiences.
87        // Setting `validate_aud = false` (the previous default) silently accepts tokens
88        // issued for any service — a cross-service token replay vulnerability.
89        validation.validate_aud = true;
90
91        Ok(Self {
92            validation,
93            issuer: issuer.to_string(),
94        })
95    }
96
97    /// Set the audiences that this validator will accept.
98    ///
99    /// Recommended for production to restrict JWT usage to specific services.
100    ///
101    /// # Errors
102    ///
103    /// Returns [`AuthError::ConfigError`] if `audiences` is empty.
104    pub fn with_audiences(mut self, audiences: &[&str]) -> Result<Self> {
105        if audiences.is_empty() {
106            return Err(AuthError::ConfigError {
107                message: "At least one audience must be configured".to_string(),
108            });
109        }
110
111        self.validation
112            .set_audience(&audiences.iter().map(|s| (*s).to_string()).collect::<Vec<_>>());
113        self.validation.validate_aud = true;
114
115        Ok(self)
116    }
117
118    /// Validate a JWT token and extract claims
119    ///
120    /// # Arguments
121    /// * `token` - The JWT token string
122    /// * `key` - The public key bytes for signature verification
123    ///
124    /// # Errors
125    /// Returns various errors: invalid token, expired token, invalid signature, etc.
126    pub fn validate(&self, token: &str, key: &[u8]) -> Result<Claims> {
127        let decoding_key = DecodingKey::from_rsa_pem(key).map_err(|e| AuthError::InvalidToken {
128            reason: format!("Failed to parse public key: {}", e),
129        })?;
130
131        let token_data = decode::<Claims>(token, &decoding_key, &self.validation).map_err(|e| {
132            use jsonwebtoken::errors::ErrorKind;
133            let error = match e.kind() {
134                ErrorKind::ExpiredSignature => AuthError::TokenExpired,
135                ErrorKind::InvalidSignature => AuthError::InvalidSignature,
136                ErrorKind::InvalidIssuer => AuthError::InvalidToken {
137                    reason: format!("Invalid issuer, expected: {}", self.issuer),
138                },
139                ErrorKind::MissingRequiredClaim(claim) => AuthError::MissingClaim {
140                    claim: claim.clone(),
141                },
142                _ => AuthError::InvalidToken {
143                    reason: e.to_string(),
144                },
145            };
146
147            // Audit log: JWT validation failure
148            let audit_logger = get_audit_logger();
149            audit_logger.log_failure(
150                AuditEventType::JwtValidation,
151                SecretType::JwtToken,
152                None, // Subject not yet known at this point
153                "validate",
154                &e.to_string(),
155            );
156
157            error
158        })?;
159
160        let claims = token_data.claims;
161
162        // Additional validation: check if token is expired (redundant but explicit)
163        if claims.is_expired() {
164            let audit_logger = get_audit_logger();
165            audit_logger.log_failure(
166                AuditEventType::JwtValidation,
167                SecretType::JwtToken,
168                Some(claims.sub),
169                "validate",
170                "Token expired",
171            );
172            return Err(AuthError::TokenExpired);
173        }
174
175        // Audit log: JWT validation success
176        let audit_logger = get_audit_logger();
177        audit_logger.log_success(
178            AuditEventType::JwtValidation,
179            SecretType::JwtToken,
180            Some(claims.sub.clone()),
181            "validate",
182        );
183
184        Ok(claims)
185    }
186
187    /// Validate with HMAC secret (symmetric key)
188    ///
189    /// # Arguments
190    /// * `token` - The JWT token string
191    /// * `secret` - The shared secret for HMAC algorithms
192    ///
193    /// # Errors
194    /// Returns various errors similar to `validate`
195    pub fn validate_hmac(&self, token: &str, secret: &[u8]) -> Result<Claims> {
196        let decoding_key = DecodingKey::from_secret(secret);
197
198        let token_data = decode::<Claims>(token, &decoding_key, &self.validation).map_err(|e| {
199            use jsonwebtoken::errors::ErrorKind;
200            match e.kind() {
201                ErrorKind::ExpiredSignature => AuthError::TokenExpired,
202                ErrorKind::InvalidSignature => AuthError::InvalidSignature,
203                ErrorKind::InvalidIssuer => AuthError::InvalidToken {
204                    reason: format!("Invalid issuer, expected: {}", self.issuer),
205                },
206                ErrorKind::MissingRequiredClaim(claim) => AuthError::MissingClaim {
207                    claim: claim.clone(),
208                },
209                _ => AuthError::InvalidToken {
210                    reason: e.to_string(),
211                },
212            }
213        })?;
214
215        let claims = token_data.claims;
216
217        if claims.is_expired() {
218            return Err(AuthError::TokenExpired);
219        }
220
221        Ok(claims)
222    }
223}
224
225/// Generate a JWT token with RS256 signature
226///
227/// # Arguments
228/// * `claims` - The JWT claims to sign
229/// * `private_key_pem` - RSA private key in PEM format
230///
231/// # Errors
232/// Returns error if token generation or signing fails
233pub fn generate_rs256_token(claims: &Claims, private_key_pem: &[u8]) -> Result<String> {
234    let encoding_key =
235        EncodingKey::from_rsa_pem(private_key_pem).map_err(|e| AuthError::Internal {
236            message: format!("Failed to parse private key: {}", e),
237        })?;
238
239    let header = Header::new(Algorithm::RS256);
240    encode(&header, claims, &encoding_key).map_err(|e| AuthError::Internal {
241        message: format!("Failed to generate RS256 token: {}", e),
242    })
243}
244
245/// Generate a JWT token with HMAC secret (HS256)
246///
247/// # Arguments
248/// * `claims` - The JWT claims to sign
249/// * `secret` - The shared secret for HMAC
250///
251/// # Errors
252/// Returns error if token generation or signing fails
253pub fn generate_hs256_token(claims: &Claims, secret: &[u8]) -> Result<String> {
254    let encoding_key = EncodingKey::from_secret(secret);
255    encode(&Header::default(), claims, &encoding_key).map_err(|e| AuthError::Internal {
256        message: format!("Failed to generate HS256 token: {}", e),
257    })
258}
259
/// Sign `claims` with HS256 using `secret`; test-only convenience wrapper.
///
/// Compiled only under `cfg(test)`. Non-test code should call [`generate_hs256_token`] or
/// [`generate_rs256_token`] directly.
///
/// # Errors
///
/// Returns `AuthError::Internal` if token encoding fails.
#[cfg(test)]
pub fn generate_test_token(claims: &Claims, secret: &[u8]) -> Result<String> {
    let token = generate_hs256_token(claims, secret)?;
    Ok(token)
}
269
#[cfg(test)]
mod tests {
    #[allow(clippy::wildcard_imports)]
    // Reason: test module — wildcard keeps test boilerplate minimal
    use super::*;

    /// HMAC secret shared by the round-trip tests.
    const SECRET: &[u8] = b"test_secret_key_at_least_32_bytes_long";

    /// Current Unix time in seconds (0 if the system clock reads before the epoch).
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Claims for `user123`, issued by `https://example.com` for audience `api`,
    /// expiring one hour from now.
    fn create_test_claims() -> Claims {
        let now = now_secs();

        Claims {
            sub:   "user123".to_string(),
            iat:   now,
            exp:   now + 3600, // 1 hour expiry
            iss:   "https://example.com".to_string(),
            aud:   vec!["api".to_string()],
            extra: HashMap::new(),
        }
    }

    /// Helper: create a validator configured for the test audience "api".
    fn make_test_validator() -> JwtValidator {
        JwtValidator::new("https://example.com", Algorithm::HS256)
            .expect("Failed to create validator")
            .with_audiences(&["api"])
            .expect("Failed to set audiences")
    }

    #[test]
    fn test_jwt_validator_creation() {
        JwtValidator::new("https://example.com", Algorithm::HS256)
            .unwrap_or_else(|e| panic!("expected Ok for valid issuer: {e}"));
    }

    /// An empty issuer is rejected as a configuration error.
    #[test]
    fn test_jwt_validator_invalid_issuer() {
        let validator = JwtValidator::new("", Algorithm::HS256);
        assert!(matches!(validator, Err(AuthError::ConfigError { .. })));
    }

    #[test]
    fn test_with_audiences_rejects_empty_list() {
        let result = JwtValidator::new("https://example.com", Algorithm::HS256)
            .expect("Failed to create validator")
            .with_audiences(&[]);
        assert!(matches!(result, Err(AuthError::ConfigError { .. })));
    }

    #[test]
    fn test_claims_is_expired() {
        let mut claims = create_test_claims();
        claims.exp = now_secs().saturating_sub(100); // Already expired

        assert!(claims.is_expired());
    }

    #[test]
    fn test_claims_not_expired() {
        let mut claims = create_test_claims();
        claims.exp = now_secs() + 3600; // Expires in 1 hour

        assert!(!claims.is_expired());
    }

    #[test]
    fn test_generate_and_validate_token() {
        let validator = make_test_validator();

        let claims = create_test_claims();
        let token = generate_test_token(&claims, SECRET).expect("Failed to generate token");

        let validated_claims =
            validator.validate_hmac(&token, SECRET).expect("Failed to validate token");

        assert_eq!(validated_claims.sub, claims.sub);
        assert_eq!(validated_claims.iss, claims.iss);
    }

    #[test]
    fn test_validate_without_audience_rejects_token() {
        // A validator created without `with_audiences()` must reject tokens
        // (audience claim required but no expected audience configured).
        let validator = JwtValidator::new("https://example.com", Algorithm::HS256)
            .expect("Failed to create validator");

        let claims = create_test_claims();
        let token = generate_test_token(&claims, SECRET).expect("Failed to generate token");

        let result = validator.validate_hmac(&token, SECRET);
        assert!(result.is_err(), "validator without configured audience must reject tokens");
    }

    #[test]
    fn test_validate_rejects_wrong_audience() {
        let validator = make_test_validator();

        let mut claims = create_test_claims();
        claims.aud = vec!["other-service".to_string()];
        let token = generate_test_token(&claims, SECRET).expect("Failed to generate token");

        let result = validator.validate_hmac(&token, SECRET);
        assert!(matches!(result, Err(AuthError::InvalidToken { .. })));
    }

    #[test]
    fn test_validate_accepts_any_matching_audience() {
        let validator = make_test_validator();

        let mut claims = create_test_claims();
        claims.aud = vec!["other-service".to_string(), "api".to_string()];
        let token = generate_test_token(&claims, SECRET).expect("Failed to generate token");

        let validated = validator.validate_hmac(&token, SECRET).expect("Failed to validate token");
        assert_eq!(validated.aud, claims.aud);
    }

    #[test]
    fn test_validate_rejects_wrong_issuer() {
        let validator = make_test_validator();

        let mut claims = create_test_claims();
        claims.iss = "https://attacker.example".to_string();
        let token = generate_test_token(&claims, SECRET).expect("Failed to generate token");

        let result = validator.validate_hmac(&token, SECRET);
        assert!(matches!(result, Err(AuthError::InvalidToken { .. })));
    }

    #[test]
    fn test_validate_expired_token() {
        let validator = make_test_validator();

        let mut claims = create_test_claims();
        claims.exp = now_secs().saturating_sub(100); // Already expired

        let token = generate_test_token(&claims, SECRET).expect("Failed to generate token");

        let result = validator.validate_hmac(&token, SECRET);
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_validate_invalid_signature() {
        let validator = make_test_validator();

        let claims = create_test_claims();
        let token = generate_test_token(&claims, SECRET).expect("Failed to generate token");

        let wrong_secret = b"wrong_secret_key_at_least_32_bytes_";
        let result = validator.validate_hmac(&token, wrong_secret);
        assert!(matches!(result, Err(AuthError::InvalidSignature)));
    }

    #[test]
    fn test_validate_rejects_unparseable_rsa_key() {
        let validator = JwtValidator::new("https://example.com", Algorithm::RS256)
            .expect("Failed to create validator")
            .with_audiences(&["api"])
            .expect("Failed to set audiences");

        let result = validator.validate("not-a-token", b"not a pem key");
        assert!(matches!(result, Err(AuthError::InvalidToken { .. })));
    }

    #[test]
    fn test_get_custom_claim() {
        let mut claims = create_test_claims();
        claims.extra.insert("email".to_string(), serde_json::json!("user@example.com"));
        claims.extra.insert("role".to_string(), serde_json::json!("admin"));

        assert_eq!(claims.get_custom("email"), Some(&serde_json::json!("user@example.com")));
        assert_eq!(claims.get_custom("role"), Some(&serde_json::json!("admin")));
        assert_eq!(claims.get_custom("nonexistent"), None);
    }
}