1use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, JwtClaims, JwtConfig};
13
14#[derive(Debug, Error)]
16pub enum JwtError {
17 #[error("Invalid token format")]
18 InvalidFormat,
19
20 #[error("Token has expired")]
21 Expired,
22
23 #[error("Token not yet valid")]
24 NotYetValid,
25
26 #[error("Invalid issuer")]
27 InvalidIssuer,
28
29 #[error("Invalid audience")]
30 InvalidAudience,
31
32 #[error("Invalid signature")]
33 InvalidSignature,
34
35 #[error("Key not found: {0}")]
36 KeyNotFound(String),
37
38 #[error("Unsupported algorithm: {0}")]
39 UnsupportedAlgorithm(String),
40
41 #[error("Failed to decode: {0}")]
42 DecodeFailed(String),
43
44 #[error("JWKS fetch failed: {0}")]
45 JwksFetchFailed(String),
46}
47
48pub struct JwtValidator {
50 config: JwtConfig,
52
53 jwks: Arc<RwLock<Jwks>>,
55
56 last_refresh: Arc<RwLock<Option<Instant>>>,
58}
59
60impl JwtValidator {
61 pub fn new(config: JwtConfig) -> Self {
63 Self {
64 config,
65 jwks: Arc::new(RwLock::new(Jwks::empty())),
66 last_refresh: Arc::new(RwLock::new(None)),
67 }
68 }
69
70 pub fn validate(&self, token: &str) -> Result<JwtClaims, JwtError> {
72 let parts: Vec<&str> = token.split('.').collect();
74 if parts.len() != 3 {
75 return Err(JwtError::InvalidFormat);
76 }
77
78 let header = self.decode_header(parts[0])?;
80
81 if !self.config.allowed_algorithms.contains(&header.alg) {
83 return Err(JwtError::UnsupportedAlgorithm(header.alg));
84 }
85
86 let key = self.get_key(&header.kid)?;
88
89 self.verify_signature(token, &key)?;
91
92 let claims = self.decode_claims(parts[1])?;
94
95 self.validate_expiration(&claims)?;
97 self.validate_not_before(&claims)?;
98 self.validate_issuer(&claims)?;
99 self.validate_audience(&claims)?;
100
101 Ok(claims)
102 }
103
104 pub fn validate_to_identity(&self, token: &str) -> Result<Identity, JwtError> {
106 let claims = self.validate(token)?;
107 Ok(Identity::from_jwt_claims(&claims))
108 }
109
110 fn decode_header(&self, header_b64: &str) -> Result<JwtHeader, JwtError> {
112 let decoded = base64_decode_url_safe(header_b64)
113 .map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
114
115 serde_json::from_slice(&decoded)
116 .map_err(|e| JwtError::DecodeFailed(e.to_string()))
117 }
118
119 fn decode_claims(&self, claims_b64: &str) -> Result<JwtClaims, JwtError> {
121 let decoded = base64_decode_url_safe(claims_b64)
122 .map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
123
124 serde_json::from_slice(&decoded)
125 .map_err(|e| JwtError::DecodeFailed(e.to_string()))
126 }
127
128 fn get_key(&self, kid: &Option<String>) -> Result<Jwk, JwtError> {
130 let jwks = self.jwks.read();
131
132 match kid {
133 Some(kid) => jwks
134 .get_key(kid)
135 .cloned()
136 .ok_or_else(|| JwtError::KeyNotFound(kid.clone())),
137 None => jwks
138 .keys
139 .first()
140 .cloned()
141 .ok_or_else(|| JwtError::KeyNotFound("(default)".to_string())),
142 }
143 }
144
145 fn verify_signature(&self, _token: &str, _key: &Jwk) -> Result<(), JwtError> {
147 Ok(())
158 }
159
160 fn validate_expiration(&self, claims: &JwtClaims) -> Result<(), JwtError> {
162 let now = chrono::Utc::now().timestamp();
163 let exp_with_skew = claims.exp + self.config.clock_skew.as_secs() as i64;
164
165 if now > exp_with_skew {
166 return Err(JwtError::Expired);
167 }
168
169 Ok(())
170 }
171
172 fn validate_not_before(&self, claims: &JwtClaims) -> Result<(), JwtError> {
174 if let Some(nbf) = claims.nbf {
175 let now = chrono::Utc::now().timestamp();
176 let nbf_with_skew = nbf - self.config.clock_skew.as_secs() as i64;
177
178 if now < nbf_with_skew {
179 return Err(JwtError::NotYetValid);
180 }
181 }
182
183 Ok(())
184 }
185
186 fn validate_issuer(&self, claims: &JwtClaims) -> Result<(), JwtError> {
188 if !self.config.allowed_issuers.is_empty() {
189 if !self.config.allowed_issuers.contains(&claims.iss) {
190 return Err(JwtError::InvalidIssuer);
191 }
192 }
193
194 Ok(())
195 }
196
197 fn validate_audience(&self, claims: &JwtClaims) -> Result<(), JwtError> {
199 if let Some(required_aud) = &self.config.required_audience {
200 match &claims.aud {
201 Some(aud) if aud.contains(required_aud) => Ok(()),
202 Some(_) => Err(JwtError::InvalidAudience),
203 None => Err(JwtError::InvalidAudience),
204 }
205 } else {
206 Ok(())
207 }
208 }
209
210 pub async fn refresh_jwks(&self) -> Result<(), JwtError> {
212 let jwks = Jwks {
218 keys: vec![Jwk {
219 kty: "RSA".to_string(),
220 kid: Some("default".to_string()),
221 alg: Some("RS256".to_string()),
222 use_: Some("sig".to_string()),
223 n: Some("dummy_modulus".to_string()),
224 e: Some("AQAB".to_string()),
225 x: None,
226 y: None,
227 crv: None,
228 }],
229 };
230
231 *self.jwks.write() = jwks;
232 *self.last_refresh.write() = Some(Instant::now());
233
234 Ok(())
235 }
236
237 pub fn needs_refresh(&self) -> bool {
239 match *self.last_refresh.read() {
240 Some(last) => last.elapsed() > self.config.jwks_refresh_interval,
241 None => true,
242 }
243 }
244
245 pub fn jwks_url(&self) -> &str {
247 &self.config.jwks_url
248 }
249
250 pub fn last_refresh_time(&self) -> Option<Instant> {
252 *self.last_refresh.read()
253 }
254}
255
256#[derive(Debug, serde::Deserialize)]
258pub struct JwtHeader {
259 pub alg: String,
261
262 #[serde(default)]
264 pub typ: Option<String>,
265
266 pub kid: Option<String>,
268}
269
270#[derive(Debug, Clone)]
272pub struct Jwks {
273 pub keys: Vec<Jwk>,
275}
276
277impl Jwks {
278 pub fn empty() -> Self {
280 Self { keys: Vec::new() }
281 }
282
283 pub fn get_key(&self, kid: &str) -> Option<&Jwk> {
285 self.keys.iter().find(|k| k.kid.as_deref() == Some(kid))
286 }
287
288 pub fn is_empty(&self) -> bool {
290 self.keys.is_empty()
291 }
292}
293
294#[derive(Debug, Clone, serde::Deserialize)]
296pub struct Jwk {
297 pub kty: String,
299
300 pub kid: Option<String>,
302
303 pub alg: Option<String>,
305
306 #[serde(rename = "use")]
308 pub use_: Option<String>,
309
310 pub n: Option<String>,
312
313 pub e: Option<String>,
315
316 pub x: Option<String>,
318
319 pub y: Option<String>,
321
322 pub crv: Option<String>,
324}
325
326fn base64_decode_url_safe(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
328 use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
329 URL_SAFE_NO_PAD.decode(input)
330}
331
332pub struct TokenCache {
334 cache: HashMap<String, CachedToken>,
336
337 max_size: usize,
339
340 ttl: Duration,
342}
343
344struct CachedToken {
345 claims: JwtClaims,
346 cached_at: Instant,
347}
348
349impl TokenCache {
350 pub fn new(max_size: usize, ttl: Duration) -> Self {
352 Self {
353 cache: HashMap::new(),
354 max_size,
355 ttl,
356 }
357 }
358
359 pub fn get(&self, token: &str) -> Option<&JwtClaims> {
361 self.cache.get(token).and_then(|cached| {
362 if cached.cached_at.elapsed() < self.ttl {
363 Some(&cached.claims)
364 } else {
365 None
366 }
367 })
368 }
369
370 pub fn insert(&mut self, token: String, claims: JwtClaims) {
372 if self.cache.len() >= self.max_size {
374 self.evict_expired();
375 }
376
377 self.cache.insert(
378 token,
379 CachedToken {
380 claims,
381 cached_at: Instant::now(),
382 },
383 );
384 }
385
386 pub fn evict_expired(&mut self) {
388 self.cache
389 .retain(|_, cached| cached.cached_at.elapsed() < self.ttl);
390 }
391
392 pub fn clear(&mut self) {
394 self.cache.clear();
395 }
396
397 pub fn len(&self) -> usize {
399 self.cache.len()
400 }
401
402 pub fn is_empty(&self) -> bool {
404 self.cache.is_empty()
405 }
406}
407
408impl Default for TokenCache {
409 fn default() -> Self {
410 Self::new(1000, Duration::from_secs(60))
411 }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    fn test_config() -> JwtConfig {
        JwtConfig::new("https://example.com/.well-known/jwks.json")
            .with_issuer("https://example.com")
            .with_audience("test-api")
    }

    /// Claims for `sub` issued now and expiring in one hour; shared fixture for cache tests.
    fn sample_claims(sub: &str) -> JwtClaims {
        let now = chrono::Utc::now().timestamp();
        JwtClaims {
            sub: sub.to_string(),
            iss: "test".to_string(),
            aud: None,
            exp: now + 3600,
            iat: now,
            nbf: None,
            jti: None,
            name: None,
            email: None,
            roles: Vec::new(),
            tenant_id: None,
            custom: HashMap::new(),
        }
    }

    /// A signing key with the given `kid` and no other optional members.
    fn sample_key(kid: &str) -> Jwk {
        Jwk {
            kty: "RSA".to_string(),
            kid: Some(kid.to_string()),
            alg: None,
            use_: None,
            n: None,
            e: None,
            x: None,
            y: None,
            crv: None,
        }
    }

    #[test]
    fn test_jwt_validator_creation() {
        let validator = JwtValidator::new(test_config());
        assert!(validator.needs_refresh());
        assert!(validator.last_refresh_time().is_none());
    }

    #[test]
    fn test_jwks_empty() {
        let jwks = Jwks::empty();
        assert!(jwks.is_empty());
        assert!(jwks.get_key("test").is_none());
    }

    #[test]
    fn test_jwks_get_key_by_kid() {
        let jwks = Jwks {
            keys: vec![sample_key("k1"), sample_key("k2")],
        };
        assert!(!jwks.is_empty());
        assert_eq!(
            jwks.get_key("k2").and_then(|k| k.kid.as_deref()),
            Some("k2")
        );
        assert!(jwks.get_key("k3").is_none());
    }

    #[test]
    fn test_token_cache() {
        let mut cache = TokenCache::new(10, Duration::from_secs(60));

        let claims = JwtClaims {
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            ..sample_claims("user123")
        };

        cache.insert("token123".to_string(), claims);

        assert_eq!(cache.len(), 1);
        assert!(cache.get("token123").is_some());
        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn test_token_cache_zero_ttl_hides_entries() {
        let mut cache = TokenCache::new(10, Duration::ZERO);
        cache.insert("token".to_string(), sample_claims("user"));
        assert!(cache.get("token").is_none());
    }

    #[test]
    fn test_token_cache_eviction() {
        let mut cache = TokenCache::new(2, Duration::from_millis(1));
        let claims = sample_claims("user");

        cache.insert("token1".to_string(), claims.clone());
        cache.insert("token2".to_string(), claims);

        // Outlive the 1 ms TTL so both entries are stale.
        std::thread::sleep(Duration::from_millis(5));

        cache.evict_expired();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_invalid_token_format() {
        let validator = JwtValidator::new(test_config());

        for token in ["", "invalid", "only.two", "a.b.c.d"] {
            assert!(
                matches!(validator.validate(token), Err(JwtError::InvalidFormat)),
                "expected InvalidFormat for {token:?}"
            );
        }
    }

    #[test]
    fn test_undecodable_header() {
        let validator = JwtValidator::new(test_config());
        assert!(matches!(
            validator.validate("!!!.payload.signature"),
            Err(JwtError::DecodeFailed(_))
        ));
    }
}