1use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, JwtClaims, JwtConfig};
13
14#[derive(Debug, Error)]
16pub enum JwtError {
17 #[error("Invalid token format")]
18 InvalidFormat,
19
20 #[error("Token has expired")]
21 Expired,
22
23 #[error("Token not yet valid")]
24 NotYetValid,
25
26 #[error("Invalid issuer")]
27 InvalidIssuer,
28
29 #[error("Invalid audience")]
30 InvalidAudience,
31
32 #[error("Invalid signature")]
33 InvalidSignature,
34
35 #[error("Key not found: {0}")]
36 KeyNotFound(String),
37
38 #[error("Unsupported algorithm: {0}")]
39 UnsupportedAlgorithm(String),
40
41 #[error("Failed to decode: {0}")]
42 DecodeFailed(String),
43
44 #[error("JWKS fetch failed: {0}")]
45 JwksFetchFailed(String),
46}
47
48pub struct JwtValidator {
50 config: JwtConfig,
52
53 jwks: Arc<RwLock<Jwks>>,
55
56 last_refresh: Arc<RwLock<Option<Instant>>>,
58}
59
60impl JwtValidator {
61 pub fn new(config: JwtConfig) -> Self {
63 Self {
64 config,
65 jwks: Arc::new(RwLock::new(Jwks::empty())),
66 last_refresh: Arc::new(RwLock::new(None)),
67 }
68 }
69
70 pub fn validate(&self, token: &str) -> Result<JwtClaims, JwtError> {
72 let parts: Vec<&str> = token.split('.').collect();
74 if parts.len() != 3 {
75 return Err(JwtError::InvalidFormat);
76 }
77
78 let header = self.decode_header(parts[0])?;
80
81 if !self.config.allowed_algorithms.contains(&header.alg) {
83 return Err(JwtError::UnsupportedAlgorithm(header.alg));
84 }
85
86 let key = self.get_key(&header.kid)?;
88
89 self.verify_signature(token, &key)?;
91
92 let claims = self.decode_claims(parts[1])?;
94
95 self.validate_expiration(&claims)?;
97 self.validate_not_before(&claims)?;
98 self.validate_issuer(&claims)?;
99 self.validate_audience(&claims)?;
100
101 Ok(claims)
102 }
103
104 pub fn validate_to_identity(&self, token: &str) -> Result<Identity, JwtError> {
106 let claims = self.validate(token)?;
107 Ok(Identity::from_jwt_claims(&claims))
108 }
109
110 fn decode_header(&self, header_b64: &str) -> Result<JwtHeader, JwtError> {
112 let decoded = base64_decode_url_safe(header_b64)
113 .map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
114
115 serde_json::from_slice(&decoded).map_err(|e| JwtError::DecodeFailed(e.to_string()))
116 }
117
118 fn decode_claims(&self, claims_b64: &str) -> Result<JwtClaims, JwtError> {
120 let decoded = base64_decode_url_safe(claims_b64)
121 .map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
122
123 serde_json::from_slice(&decoded).map_err(|e| JwtError::DecodeFailed(e.to_string()))
124 }
125
126 fn get_key(&self, kid: &Option<String>) -> Result<Jwk, JwtError> {
128 let jwks = self.jwks.read();
129
130 match kid {
131 Some(kid) => jwks
132 .get_key(kid)
133 .cloned()
134 .ok_or_else(|| JwtError::KeyNotFound(kid.clone())),
135 None => jwks
136 .keys
137 .first()
138 .cloned()
139 .ok_or_else(|| JwtError::KeyNotFound("(default)".to_string())),
140 }
141 }
142
143 fn verify_signature(&self, token: &str, key: &Jwk) -> Result<(), JwtError> {
151 let parts: Vec<&str> = token.split('.').collect();
152 if parts.len() != 3 {
153 return Err(JwtError::InvalidFormat);
154 }
155 let header = self.decode_header(parts[0])?;
156 let provided_sig =
157 base64_decode_url_safe(parts[2]).map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
158 let signing_input = format!("{}.{}", parts[0], parts[1]);
160
161 match header.alg.as_str() {
162 "HS256" => {
163 use hmac::{Hmac, Mac};
164 use sha2::Sha256;
165 let secret = key
166 .k
167 .as_deref()
168 .map(base64_decode_url_safe)
169 .transpose()
170 .map_err(|e| JwtError::DecodeFailed(e.to_string()))?
171 .ok_or_else(|| JwtError::KeyNotFound("HS256 symmetric key".to_string()))?;
172 let mut mac = <Hmac<Sha256>>::new_from_slice(&secret)
173 .map_err(|_| JwtError::InvalidSignature)?;
174 mac.update(signing_input.as_bytes());
175 mac.verify_slice(&provided_sig)
177 .map_err(|_| JwtError::InvalidSignature)
178 }
179 other => Err(JwtError::UnsupportedAlgorithm(other.to_string())),
182 }
183 }
184
185 pub fn set_hs256_secret(&self, kid: Option<String>, secret: &[u8]) {
188 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
189 let jwk = Jwk {
190 kty: "oct".to_string(),
191 kid,
192 alg: Some("HS256".to_string()),
193 use_: Some("sig".to_string()),
194 n: None,
195 e: None,
196 x: None,
197 y: None,
198 crv: None,
199 k: Some(URL_SAFE_NO_PAD.encode(secret)),
200 };
201 *self.jwks.write() = Jwks { keys: vec![jwk] };
202 }
203
204 fn validate_expiration(&self, claims: &JwtClaims) -> Result<(), JwtError> {
206 let now = chrono::Utc::now().timestamp();
207 let exp_with_skew = claims.exp + self.config.clock_skew.as_secs() as i64;
208
209 if now > exp_with_skew {
210 return Err(JwtError::Expired);
211 }
212
213 Ok(())
214 }
215
216 fn validate_not_before(&self, claims: &JwtClaims) -> Result<(), JwtError> {
218 if let Some(nbf) = claims.nbf {
219 let now = chrono::Utc::now().timestamp();
220 let nbf_with_skew = nbf - self.config.clock_skew.as_secs() as i64;
221
222 if now < nbf_with_skew {
223 return Err(JwtError::NotYetValid);
224 }
225 }
226
227 Ok(())
228 }
229
230 fn validate_issuer(&self, claims: &JwtClaims) -> Result<(), JwtError> {
232 if !self.config.allowed_issuers.is_empty()
233 && !self.config.allowed_issuers.contains(&claims.iss)
234 {
235 return Err(JwtError::InvalidIssuer);
236 }
237
238 Ok(())
239 }
240
241 fn validate_audience(&self, claims: &JwtClaims) -> Result<(), JwtError> {
243 if let Some(required_aud) = &self.config.required_audience {
244 match &claims.aud {
245 Some(aud) if aud.contains(required_aud) => Ok(()),
246 Some(_) => Err(JwtError::InvalidAudience),
247 None => Err(JwtError::InvalidAudience),
248 }
249 } else {
250 Ok(())
251 }
252 }
253
254 pub async fn refresh_jwks(&self) -> Result<(), JwtError> {
256 let jwks = Jwks {
262 keys: vec![Jwk {
263 kty: "RSA".to_string(),
264 kid: Some("default".to_string()),
265 alg: Some("RS256".to_string()),
266 use_: Some("sig".to_string()),
267 n: Some("dummy_modulus".to_string()),
268 e: Some("AQAB".to_string()),
269 x: None,
270 y: None,
271 crv: None,
272 k: None,
273 }],
274 };
275
276 *self.jwks.write() = jwks;
277 *self.last_refresh.write() = Some(Instant::now());
278
279 Ok(())
280 }
281
282 pub fn needs_refresh(&self) -> bool {
284 match *self.last_refresh.read() {
285 Some(last) => last.elapsed() > self.config.jwks_refresh_interval,
286 None => true,
287 }
288 }
289
290 pub fn jwks_url(&self) -> &str {
292 &self.config.jwks_url
293 }
294
295 pub fn last_refresh_time(&self) -> Option<Instant> {
297 *self.last_refresh.read()
298 }
299}
300
301#[derive(Debug, serde::Deserialize)]
303pub struct JwtHeader {
304 pub alg: String,
306
307 #[serde(default)]
309 pub typ: Option<String>,
310
311 pub kid: Option<String>,
313}
314
315#[derive(Debug, Clone)]
317pub struct Jwks {
318 pub keys: Vec<Jwk>,
320}
321
322impl Jwks {
323 pub fn empty() -> Self {
325 Self { keys: Vec::new() }
326 }
327
328 pub fn get_key(&self, kid: &str) -> Option<&Jwk> {
330 self.keys.iter().find(|k| k.kid.as_deref() == Some(kid))
331 }
332
333 pub fn is_empty(&self) -> bool {
335 self.keys.is_empty()
336 }
337}
338
339#[derive(Debug, Clone, serde::Deserialize)]
341pub struct Jwk {
342 pub kty: String,
344
345 pub kid: Option<String>,
347
348 pub alg: Option<String>,
350
351 #[serde(rename = "use")]
353 pub use_: Option<String>,
354
355 pub n: Option<String>,
357
358 pub e: Option<String>,
360
361 pub x: Option<String>,
363
364 pub y: Option<String>,
366
367 pub crv: Option<String>,
369
370 #[serde(default)]
372 pub k: Option<String>,
373}
374
375fn base64_decode_url_safe(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
377 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
378 URL_SAFE_NO_PAD.decode(input)
379}
380
381pub struct TokenCache {
383 cache: HashMap<String, CachedToken>,
385
386 max_size: usize,
388
389 ttl: Duration,
391}
392
393struct CachedToken {
394 claims: JwtClaims,
395 cached_at: Instant,
396}
397
398impl TokenCache {
399 pub fn new(max_size: usize, ttl: Duration) -> Self {
401 Self {
402 cache: HashMap::new(),
403 max_size,
404 ttl,
405 }
406 }
407
408 pub fn get(&self, token: &str) -> Option<&JwtClaims> {
410 self.cache.get(token).and_then(|cached| {
411 if cached.cached_at.elapsed() < self.ttl {
412 Some(&cached.claims)
413 } else {
414 None
415 }
416 })
417 }
418
419 pub fn insert(&mut self, token: String, claims: JwtClaims) {
421 if self.cache.len() >= self.max_size {
423 self.evict_expired();
424 }
425
426 self.cache.insert(
427 token,
428 CachedToken {
429 claims,
430 cached_at: Instant::now(),
431 },
432 );
433 }
434
435 pub fn evict_expired(&mut self) {
437 self.cache
438 .retain(|_, cached| cached.cached_at.elapsed() < self.ttl);
439 }
440
441 pub fn clear(&mut self) {
443 self.cache.clear();
444 }
445
446 pub fn len(&self) -> usize {
448 self.cache.len()
449 }
450
451 pub fn is_empty(&self) -> bool {
453 self.cache.is_empty()
454 }
455}
456
457impl Default for TokenCache {
458 fn default() -> Self {
459 Self::new(1000, Duration::from_secs(60))
460 }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    /// HMAC secret shared by the HS256 round-trip tests.
    const SECRET: &[u8] = b"top-secret";

    fn test_config() -> JwtConfig {
        JwtConfig::new("https://example.com/.well-known/jwks.json")
            .with_issuer("https://example.com")
            .with_audience("test-api")
    }

    /// Minimal claims for `sub`, issued now and expiring in one hour.
    ///
    /// Tests override individual fields with struct-update syntax.
    fn sample_claims(sub: &str) -> JwtClaims {
        let now = chrono::Utc::now().timestamp();
        JwtClaims {
            sub: sub.to_string(),
            iss: "test".to_string(),
            aud: None,
            exp: now + 3600,
            iat: now,
            nbf: None,
            jti: None,
            name: None,
            email: None,
            roles: Vec::new(),
            tenant_id: None,
            custom: HashMap::new(),
        }
    }

    #[test]
    fn test_jwt_validator_creation() {
        let validator = JwtValidator::new(test_config());
        assert!(validator.needs_refresh());
    }

    #[test]
    fn test_jwks_empty() {
        let jwks = Jwks::empty();
        assert!(jwks.is_empty());
        assert!(jwks.get_key("test").is_none());
    }

    #[test]
    fn test_token_cache() {
        let mut cache = TokenCache::new(10, Duration::from_secs(60));
        let claims = JwtClaims {
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            ..sample_claims("user123")
        };

        cache.insert("token123".to_string(), claims);

        assert_eq!(cache.len(), 1);
        assert!(cache.get("token123").is_some());
        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn test_token_cache_eviction() {
        let mut cache = TokenCache::new(2, Duration::from_millis(1));
        let claims = sample_claims("user");

        cache.insert("token1".to_string(), claims.clone());
        cache.insert("token2".to_string(), claims);

        // Sleep well past the 1 ms TTL so both entries are stale.
        std::thread::sleep(Duration::from_millis(5));

        cache.evict_expired();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_invalid_token_format() {
        let validator = JwtValidator::new(test_config());

        for token in ["invalid", "only.two", "one.too.many.segments"] {
            assert!(
                matches!(validator.validate(token), Err(JwtError::InvalidFormat)),
                "expected InvalidFormat for {token:?}"
            );
        }
    }

    fn b64(d: &[u8]) -> String {
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
        URL_SAFE_NO_PAD.encode(d)
    }

    /// Signs `header_json.claims_json` with HS256 under `secret` and returns the
    /// compact token.
    fn sign_hs256(header_json: &str, secret: &[u8], claims_json: &str) -> String {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;
        let signing_input = format!(
            "{}.{}",
            b64(header_json.as_bytes()),
            b64(claims_json.as_bytes())
        );
        let mut mac = <Hmac<Sha256>>::new_from_slice(secret).unwrap();
        mac.update(signing_input.as_bytes());
        let sig = b64(&mac.finalize().into_bytes());
        format!("{signing_input}.{sig}")
    }

    fn hs256_token(secret: &[u8], claims_json: &str) -> String {
        sign_hs256(r#"{"alg":"HS256","typ":"JWT"}"#, secret, claims_json)
    }

    fn hs256_config() -> JwtConfig {
        JwtConfig {
            allowed_algorithms: vec!["HS256".to_string()],
            ..Default::default()
        }
    }

    fn hs256_validator(secret: &[u8]) -> JwtValidator {
        let v = JwtValidator::new(hs256_config());
        v.set_hs256_secret(None, secret);
        v
    }

    fn future_claims() -> String {
        let exp = chrono::Utc::now().timestamp() + 3600;
        format!(r#"{{"sub":"alice","iss":"test","exp":{exp},"iat":0}}"#)
    }

    #[test]
    fn hs256_valid_token_accepted() {
        let v = hs256_validator(SECRET);
        let token = hs256_token(SECRET, &future_claims());
        let claims = v.validate(&token).expect("valid HS256 token");
        assert_eq!(claims.sub, "alice");
    }

    #[test]
    fn hs256_wrong_secret_rejected() {
        let v = hs256_validator(SECRET);
        let token = hs256_token(b"WRONG-secret", &future_claims());
        assert!(matches!(
            v.validate(&token),
            Err(JwtError::InvalidSignature)
        ));
    }

    #[test]
    fn hs256_tampered_payload_rejected() {
        let v = hs256_validator(SECRET);
        let token = hs256_token(SECRET, &future_claims());
        let parts: Vec<&str> = token.split('.').collect();
        let evil = b64(br#"{"sub":"attacker","iss":"test","exp":9999999999,"iat":0}"#);
        // Keep the original header and signature but swap in attacker-chosen claims.
        let forged = format!("{}.{}.{}", parts[0], evil, parts[2]);
        assert!(matches!(
            v.validate(&forged),
            Err(JwtError::InvalidSignature)
        ));
    }

    #[test]
    fn hs256_expired_token_rejected() {
        let v = hs256_validator(SECRET);
        let token = hs256_token(SECRET, r#"{"sub":"alice","iss":"test","exp":1000,"iat":0}"#);
        assert!(matches!(v.validate(&token), Err(JwtError::Expired)));
    }

    #[test]
    fn hs256_matching_kid_accepted() {
        let v = JwtValidator::new(hs256_config());
        v.set_hs256_secret(Some("k1".to_string()), SECRET);
        let token = sign_hs256(r#"{"alg":"HS256","kid":"k1"}"#, SECRET, &future_claims());
        let claims = v.validate(&token).expect("kid k1 is installed");
        assert_eq!(claims.sub, "alice");
    }

    #[test]
    fn hs256_unknown_kid_rejected() {
        let v = JwtValidator::new(hs256_config());
        v.set_hs256_secret(Some("k1".to_string()), SECRET);
        let token = sign_hs256(r#"{"alg":"HS256","kid":"k2"}"#, SECRET, &future_claims());
        assert!(matches!(
            v.validate(&token),
            Err(JwtError::KeyNotFound(kid)) if kid == "k2"
        ));
    }

    #[test]
    fn unsupported_alg_is_rejected_not_trusted() {
        // RS256 passes the allow-list but has no verifier. It must be rejected, not
        // accepted unchecked.
        let config = JwtConfig {
            allowed_algorithms: vec!["RS256".to_string()],
            ..Default::default()
        };
        let v = JwtValidator::new(config);
        v.set_hs256_secret(None, b"x");

        let header = b64(br#"{"alg":"RS256","typ":"JWT"}"#);
        let payload = b64(future_claims().as_bytes());
        let token = format!("{header}.{payload}.{}", b64(b"whatever"));
        assert!(matches!(
            v.validate(&token),
            Err(JwtError::UnsupportedAlgorithm(_))
        ));
    }
}