1use std::collections::HashMap;
3
4use jsonwebtoken::{Algorithm, DecodingKey, EncodingKey, Header, Validation, decode, encode};
5use serde::{Deserialize, Serialize};
6
7use crate::{
8 audit::logger::{AuditEventType, SecretType, get_audit_logger},
9 error::{AuthError, Result},
10};
11
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
14pub struct Claims {
15 pub sub: String,
17 pub iat: u64,
19 pub exp: u64,
21 pub iss: String,
23 pub aud: Vec<String>,
25 #[serde(flatten)]
27 pub extra: HashMap<String, serde_json::Value>,
28}
29
30impl Claims {
31 pub fn get_custom(&self, key: &str) -> Option<&serde_json::Value> {
33 self.extra.get(key)
34 }
35
36 pub fn is_expired(&self) -> bool {
41 let now = match std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH) {
42 Ok(duration) => duration.as_secs(),
43 Err(e) => {
44 tracing::error!(
47 error = %e,
48 "CRITICAL: System time error in token expiry check — \
49 this indicates a system clock issue. Token rejected as safety measure."
50 );
51 u64::MAX
53 },
54 };
55 self.exp <= now
56 }
57}
58
59pub struct JwtValidator {
61 validation: Validation,
62 issuer: String,
63}
64
65impl JwtValidator {
66 pub fn new(issuer: &str, algorithm: Algorithm) -> Result<Self> {
75 if issuer.is_empty() {
76 return Err(AuthError::ConfigError {
77 message: "Issuer cannot be empty".to_string(),
78 });
79 }
80
81 let mut validation = Validation::new(algorithm);
82 validation.set_issuer(&[issuer]);
83 validation.validate_aud = true;
90
91 Ok(Self {
92 validation,
93 issuer: issuer.to_string(),
94 })
95 }
96
97 pub fn with_audiences(mut self, audiences: &[&str]) -> Result<Self> {
105 if audiences.is_empty() {
106 return Err(AuthError::ConfigError {
107 message: "At least one audience must be configured".to_string(),
108 });
109 }
110
111 self.validation
112 .set_audience(&audiences.iter().map(|s| (*s).to_string()).collect::<Vec<_>>());
113 self.validation.validate_aud = true;
114
115 Ok(self)
116 }
117
118 pub fn validate(&self, token: &str, key: &[u8]) -> Result<Claims> {
127 let decoding_key = DecodingKey::from_rsa_pem(key).map_err(|e| AuthError::InvalidToken {
128 reason: format!("Failed to parse public key: {}", e),
129 })?;
130
131 let token_data = decode::<Claims>(token, &decoding_key, &self.validation).map_err(|e| {
132 use jsonwebtoken::errors::ErrorKind;
133 let error = match e.kind() {
134 ErrorKind::ExpiredSignature => AuthError::TokenExpired,
135 ErrorKind::InvalidSignature => AuthError::InvalidSignature,
136 ErrorKind::InvalidIssuer => AuthError::InvalidToken {
137 reason: format!("Invalid issuer, expected: {}", self.issuer),
138 },
139 ErrorKind::MissingRequiredClaim(claim) => AuthError::MissingClaim {
140 claim: claim.clone(),
141 },
142 _ => AuthError::InvalidToken {
143 reason: e.to_string(),
144 },
145 };
146
147 let audit_logger = get_audit_logger();
149 audit_logger.log_failure(
150 AuditEventType::JwtValidation,
151 SecretType::JwtToken,
152 None, "validate",
154 &e.to_string(),
155 );
156
157 error
158 })?;
159
160 let claims = token_data.claims;
161
162 if claims.is_expired() {
164 let audit_logger = get_audit_logger();
165 audit_logger.log_failure(
166 AuditEventType::JwtValidation,
167 SecretType::JwtToken,
168 Some(claims.sub),
169 "validate",
170 "Token expired",
171 );
172 return Err(AuthError::TokenExpired);
173 }
174
175 let audit_logger = get_audit_logger();
177 audit_logger.log_success(
178 AuditEventType::JwtValidation,
179 SecretType::JwtToken,
180 Some(claims.sub.clone()),
181 "validate",
182 );
183
184 Ok(claims)
185 }
186
187 pub fn validate_hmac(&self, token: &str, secret: &[u8]) -> Result<Claims> {
196 let decoding_key = DecodingKey::from_secret(secret);
197
198 let token_data = decode::<Claims>(token, &decoding_key, &self.validation).map_err(|e| {
199 use jsonwebtoken::errors::ErrorKind;
200 match e.kind() {
201 ErrorKind::ExpiredSignature => AuthError::TokenExpired,
202 ErrorKind::InvalidSignature => AuthError::InvalidSignature,
203 ErrorKind::InvalidIssuer => AuthError::InvalidToken {
204 reason: format!("Invalid issuer, expected: {}", self.issuer),
205 },
206 ErrorKind::MissingRequiredClaim(claim) => AuthError::MissingClaim {
207 claim: claim.clone(),
208 },
209 _ => AuthError::InvalidToken {
210 reason: e.to_string(),
211 },
212 }
213 })?;
214
215 let claims = token_data.claims;
216
217 if claims.is_expired() {
218 return Err(AuthError::TokenExpired);
219 }
220
221 Ok(claims)
222 }
223}
224
225pub fn generate_rs256_token(claims: &Claims, private_key_pem: &[u8]) -> Result<String> {
234 let encoding_key =
235 EncodingKey::from_rsa_pem(private_key_pem).map_err(|e| AuthError::Internal {
236 message: format!("Failed to parse private key: {}", e),
237 })?;
238
239 let header = Header::new(Algorithm::RS256);
240 encode(&header, claims, &encoding_key).map_err(|e| AuthError::Internal {
241 message: format!("Failed to generate RS256 token: {}", e),
242 })
243}
244
245pub fn generate_hs256_token(claims: &Claims, secret: &[u8]) -> Result<String> {
254 let encoding_key = EncodingKey::from_secret(secret);
255 encode(&Header::default(), claims, &encoding_key).map_err(|e| AuthError::Internal {
256 message: format!("Failed to generate HS256 token: {}", e),
257 })
258}
259
/// Test-only convenience wrapper: signs `claims` with HS256 and `secret`.
///
/// # Errors
///
/// Propagates any error from [`generate_hs256_token`].
#[cfg(test)]
pub fn generate_test_token(claims: &Claims, secret: &[u8]) -> Result<String> {
    let token = generate_hs256_token(claims, secret)?;
    Ok(token)
}
269
#[cfg(test)]
mod tests {
    // Glob import is deliberate: the tests exercise most of the parent module.
    #[allow(clippy::wildcard_imports)]
    use super::*;

    /// HMAC secret shared by the signing and validation tests.
    const TEST_SECRET: &[u8] = b"test_secret_key_at_least_32_bytes_long";
    /// Issuer every test validator and test token is configured with.
    const TEST_ISSUER: &str = "https://example.com";

    /// Current time as whole seconds since the Unix epoch (0 if the clock is unreadable).
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Claims for `user123`, valid for one hour, addressed to the `api` audience.
    fn create_test_claims() -> Claims {
        let now = unix_now();

        Claims {
            sub: "user123".to_string(),
            iat: now,
            exp: now + 3600,
            iss: TEST_ISSUER.to_string(),
            aud: vec!["api".to_string()],
            extra: HashMap::new(),
        }
    }

    /// HS256 validator for the test issuer and the `api` audience.
    fn make_test_validator() -> JwtValidator {
        JwtValidator::new(TEST_ISSUER, Algorithm::HS256)
            .expect("Failed to create validator")
            .with_audiences(&["api"])
            .expect("Failed to set audiences")
    }

    #[test]
    fn test_jwt_validator_creation() {
        JwtValidator::new(TEST_ISSUER, Algorithm::HS256)
            .unwrap_or_else(|e| panic!("expected Ok for valid issuer: {e}"));
    }

    #[test]
    fn test_jwt_validator_invalid_issuer() {
        let validator = JwtValidator::new("", Algorithm::HS256);
        assert!(matches!(validator, Err(AuthError::ConfigError { .. })));
    }

    #[test]
    fn test_with_audiences_rejects_empty_list() {
        let validator = JwtValidator::new(TEST_ISSUER, Algorithm::HS256)
            .expect("Failed to create validator")
            .with_audiences(&[]);
        assert!(matches!(validator, Err(AuthError::ConfigError { .. })));
    }

    #[test]
    fn test_claims_is_expired() {
        let mut claims = create_test_claims();
        claims.exp = unix_now().saturating_sub(100);
        assert!(claims.is_expired());
    }

    #[test]
    fn test_claims_not_expired() {
        let mut claims = create_test_claims();
        claims.exp = unix_now() + 3600;
        assert!(!claims.is_expired());
    }

    #[test]
    fn test_generate_and_validate_token() {
        let validator = make_test_validator();

        let claims = create_test_claims();
        let token = generate_test_token(&claims, TEST_SECRET).expect("Failed to generate token");

        let validated_claims =
            validator.validate_hmac(&token, TEST_SECRET).expect("Failed to validate token");

        assert_eq!(validated_claims.sub, claims.sub);
        assert_eq!(validated_claims.iss, claims.iss);
    }

    #[test]
    fn test_validate_without_audience_rejects_token() {
        // No `with_audiences` call: audience validation is on but nothing is accepted.
        let validator = JwtValidator::new(TEST_ISSUER, Algorithm::HS256)
            .expect("Failed to create validator");

        let claims = create_test_claims();
        let token = generate_test_token(&claims, TEST_SECRET).expect("Failed to generate token");

        let result = validator.validate_hmac(&token, TEST_SECRET);
        assert!(result.is_err(), "validator without configured audience must reject tokens");
    }

    #[test]
    fn test_validate_wrong_issuer() {
        let validator = make_test_validator();

        let mut claims = create_test_claims();
        claims.iss = "https://other.example.org".to_string();
        let token = generate_test_token(&claims, TEST_SECRET).expect("Failed to generate token");

        let result = validator.validate_hmac(&token, TEST_SECRET);
        assert!(matches!(result, Err(AuthError::InvalidToken { .. })));
    }

    #[test]
    fn test_validate_wrong_audience() {
        let validator = make_test_validator();

        let mut claims = create_test_claims();
        claims.aud = vec!["some-other-service".to_string()];
        let token = generate_test_token(&claims, TEST_SECRET).expect("Failed to generate token");

        let result = validator.validate_hmac(&token, TEST_SECRET);
        assert!(matches!(result, Err(AuthError::InvalidToken { .. })));
    }

    #[test]
    fn test_validate_expired_token() {
        let validator = make_test_validator();

        let mut claims = create_test_claims();
        claims.exp = unix_now().saturating_sub(100);
        let token = generate_test_token(&claims, TEST_SECRET).expect("Failed to generate token");

        let result = validator.validate_hmac(&token, TEST_SECRET);
        assert!(matches!(result, Err(AuthError::TokenExpired)));
    }

    #[test]
    fn test_validate_invalid_signature() {
        let validator = make_test_validator();

        let claims = create_test_claims();
        let token = generate_test_token(&claims, TEST_SECRET).expect("Failed to generate token");

        let wrong_secret = b"wrong_secret_key_at_least_32_bytes_";
        let result = validator.validate_hmac(&token, wrong_secret);
        assert!(matches!(result, Err(AuthError::InvalidSignature)));
    }

    #[test]
    fn test_get_custom_claim() {
        let mut claims = create_test_claims();
        claims.extra.insert("email".to_string(), serde_json::json!("user@example.com"));
        claims.extra.insert("role".to_string(), serde_json::json!("admin"));

        assert_eq!(claims.get_custom("email"), Some(&serde_json::json!("user@example.com")));
        assert_eq!(claims.get_custom("role"), Some(&serde_json::json!("admin")));
        assert_eq!(claims.get_custom("nonexistent"), None);
    }
}