Skip to main content

heliosdb_proxy/auth/jwt.rs

//! JWT Token Validation
//!
//! Validates JWT tokens, verifying signatures against keys held in a JWKS (JSON Web Key Set).
//! Only HS256 (symmetric `oct` keys) is verified today; other algorithms are rejected.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, JwtClaims, JwtConfig};
13
14/// JWT validation errors
15#[derive(Debug, Error)]
16pub enum JwtError {
17    #[error("Invalid token format")]
18    InvalidFormat,
19
20    #[error("Token has expired")]
21    Expired,
22
23    #[error("Token not yet valid")]
24    NotYetValid,
25
26    #[error("Invalid issuer")]
27    InvalidIssuer,
28
29    #[error("Invalid audience")]
30    InvalidAudience,
31
32    #[error("Invalid signature")]
33    InvalidSignature,
34
35    #[error("Key not found: {0}")]
36    KeyNotFound(String),
37
38    #[error("Unsupported algorithm: {0}")]
39    UnsupportedAlgorithm(String),
40
41    #[error("Failed to decode: {0}")]
42    DecodeFailed(String),
43
44    #[error("JWKS fetch failed: {0}")]
45    JwksFetchFailed(String),
46}
47
48/// JWT validator
49pub struct JwtValidator {
50    /// Configuration
51    config: JwtConfig,
52
53    /// Cached JWKS
54    jwks: Arc<RwLock<Jwks>>,
55
56    /// Last JWKS refresh time
57    last_refresh: Arc<RwLock<Option<Instant>>>,
58}
59
60impl JwtValidator {
61    /// Create a new JWT validator
62    pub fn new(config: JwtConfig) -> Self {
63        Self {
64            config,
65            jwks: Arc::new(RwLock::new(Jwks::empty())),
66            last_refresh: Arc::new(RwLock::new(None)),
67        }
68    }
69
70    /// Validate a JWT token and return claims
71    pub fn validate(&self, token: &str) -> Result<JwtClaims, JwtError> {
72        // Split token into parts
73        let parts: Vec<&str> = token.split('.').collect();
74        if parts.len() != 3 {
75            return Err(JwtError::InvalidFormat);
76        }
77
78        // Decode header
79        let header = self.decode_header(parts[0])?;
80
81        // Check algorithm
82        if !self.config.allowed_algorithms.contains(&header.alg) {
83            return Err(JwtError::UnsupportedAlgorithm(header.alg));
84        }
85
86        // Get signing key
87        let key = self.get_key(&header.kid)?;
88
89        // Verify signature
90        self.verify_signature(token, &key)?;
91
92        // Decode claims
93        let claims = self.decode_claims(parts[1])?;
94
95        // Validate standard claims
96        self.validate_expiration(&claims)?;
97        self.validate_not_before(&claims)?;
98        self.validate_issuer(&claims)?;
99        self.validate_audience(&claims)?;
100
101        Ok(claims)
102    }
103
104    /// Validate token and convert to Identity
105    pub fn validate_to_identity(&self, token: &str) -> Result<Identity, JwtError> {
106        let claims = self.validate(token)?;
107        Ok(Identity::from_jwt_claims(&claims))
108    }
109
110    /// Decode JWT header
111    fn decode_header(&self, header_b64: &str) -> Result<JwtHeader, JwtError> {
112        let decoded = base64_decode_url_safe(header_b64)
113            .map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
114
115        serde_json::from_slice(&decoded).map_err(|e| JwtError::DecodeFailed(e.to_string()))
116    }
117
118    /// Decode JWT claims
119    fn decode_claims(&self, claims_b64: &str) -> Result<JwtClaims, JwtError> {
120        let decoded = base64_decode_url_safe(claims_b64)
121            .map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
122
123        serde_json::from_slice(&decoded).map_err(|e| JwtError::DecodeFailed(e.to_string()))
124    }
125
126    /// Get signing key by key ID
127    fn get_key(&self, kid: &Option<String>) -> Result<Jwk, JwtError> {
128        let jwks = self.jwks.read();
129
130        match kid {
131            Some(kid) => jwks
132                .get_key(kid)
133                .cloned()
134                .ok_or_else(|| JwtError::KeyNotFound(kid.clone())),
135            None => jwks
136                .keys
137                .first()
138                .cloned()
139                .ok_or_else(|| JwtError::KeyNotFound("(default)".to_string())),
140        }
141    }
142
143    /// Verify the token signature.
144    ///
145    /// HS256 is verified with a constant-time HMAC-SHA256 comparison against the
146    /// symmetric key (`kty = "oct"`, base64url `k`). Any other algorithm is
147    /// REJECTED with `UnsupportedAlgorithm` — a forged/tampered token can no
148    /// longer slip through (the previous implementation accepted every
149    /// signature unconditionally).
150    fn verify_signature(&self, token: &str, key: &Jwk) -> Result<(), JwtError> {
151        let parts: Vec<&str> = token.split('.').collect();
152        if parts.len() != 3 {
153            return Err(JwtError::InvalidFormat);
154        }
155        let header = self.decode_header(parts[0])?;
156        let provided_sig =
157            base64_decode_url_safe(parts[2]).map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
158        // The JWS signing input is the raw `header.claims` base64url segments.
159        let signing_input = format!("{}.{}", parts[0], parts[1]);
160
161        match header.alg.as_str() {
162            "HS256" => {
163                use hmac::{Hmac, Mac};
164                use sha2::Sha256;
165                let secret = key
166                    .k
167                    .as_deref()
168                    .map(base64_decode_url_safe)
169                    .transpose()
170                    .map_err(|e| JwtError::DecodeFailed(e.to_string()))?
171                    .ok_or_else(|| JwtError::KeyNotFound("HS256 symmetric key".to_string()))?;
172                let mut mac = <Hmac<Sha256>>::new_from_slice(&secret)
173                    .map_err(|_| JwtError::InvalidSignature)?;
174                mac.update(signing_input.as_bytes());
175                // Constant-time verification.
176                mac.verify_slice(&provided_sig)
177                    .map_err(|_| JwtError::InvalidSignature)
178            }
179            // RS256 / ES256 (asymmetric JWKS) verification is a follow-on; until
180            // implemented they are rejected rather than blindly trusted.
181            other => Err(JwtError::UnsupportedAlgorithm(other.to_string())),
182        }
183    }
184
185    /// Install a static HS256 (symmetric) signing key so tokens can be verified
186    /// against a configured shared secret without a JWKS endpoint.
187    pub fn set_hs256_secret(&self, kid: Option<String>, secret: &[u8]) {
188        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
189        let jwk = Jwk {
190            kty: "oct".to_string(),
191            kid,
192            alg: Some("HS256".to_string()),
193            use_: Some("sig".to_string()),
194            n: None,
195            e: None,
196            x: None,
197            y: None,
198            crv: None,
199            k: Some(URL_SAFE_NO_PAD.encode(secret)),
200        };
201        *self.jwks.write() = Jwks { keys: vec![jwk] };
202    }
203
204    /// Validate expiration claim
205    fn validate_expiration(&self, claims: &JwtClaims) -> Result<(), JwtError> {
206        let now = chrono::Utc::now().timestamp();
207        let exp_with_skew = claims.exp + self.config.clock_skew.as_secs() as i64;
208
209        if now > exp_with_skew {
210            return Err(JwtError::Expired);
211        }
212
213        Ok(())
214    }
215
216    /// Validate not-before claim
217    fn validate_not_before(&self, claims: &JwtClaims) -> Result<(), JwtError> {
218        if let Some(nbf) = claims.nbf {
219            let now = chrono::Utc::now().timestamp();
220            let nbf_with_skew = nbf - self.config.clock_skew.as_secs() as i64;
221
222            if now < nbf_with_skew {
223                return Err(JwtError::NotYetValid);
224            }
225        }
226
227        Ok(())
228    }
229
230    /// Validate issuer claim
231    fn validate_issuer(&self, claims: &JwtClaims) -> Result<(), JwtError> {
232        if !self.config.allowed_issuers.is_empty()
233            && !self.config.allowed_issuers.contains(&claims.iss)
234        {
235            return Err(JwtError::InvalidIssuer);
236        }
237
238        Ok(())
239    }
240
241    /// Validate audience claim
242    fn validate_audience(&self, claims: &JwtClaims) -> Result<(), JwtError> {
243        if let Some(required_aud) = &self.config.required_audience {
244            match &claims.aud {
245                Some(aud) if aud.contains(required_aud) => Ok(()),
246                Some(_) => Err(JwtError::InvalidAudience),
247                None => Err(JwtError::InvalidAudience),
248            }
249        } else {
250            Ok(())
251        }
252    }
253
254    /// Refresh JWKS from remote endpoint
255    pub async fn refresh_jwks(&self) -> Result<(), JwtError> {
256        // In a real implementation, this would fetch JWKS from the configured URL
257        // using an HTTP client like reqwest.
258        //
259        // For demonstration, we create a dummy JWKS.
260
261        let jwks = Jwks {
262            keys: vec![Jwk {
263                kty: "RSA".to_string(),
264                kid: Some("default".to_string()),
265                alg: Some("RS256".to_string()),
266                use_: Some("sig".to_string()),
267                n: Some("dummy_modulus".to_string()),
268                e: Some("AQAB".to_string()),
269                x: None,
270                y: None,
271                crv: None,
272                k: None,
273            }],
274        };
275
276        *self.jwks.write() = jwks;
277        *self.last_refresh.write() = Some(Instant::now());
278
279        Ok(())
280    }
281
282    /// Check if JWKS needs refresh
283    pub fn needs_refresh(&self) -> bool {
284        match *self.last_refresh.read() {
285            Some(last) => last.elapsed() > self.config.jwks_refresh_interval,
286            None => true,
287        }
288    }
289
290    /// Get JWKS URL
291    pub fn jwks_url(&self) -> &str {
292        &self.config.jwks_url
293    }
294
295    /// Get last refresh time
296    pub fn last_refresh_time(&self) -> Option<Instant> {
297        *self.last_refresh.read()
298    }
299}
300
301/// JWT header
302#[derive(Debug, serde::Deserialize)]
303pub struct JwtHeader {
304    /// Algorithm
305    pub alg: String,
306
307    /// Token type
308    #[serde(default)]
309    pub typ: Option<String>,
310
311    /// Key ID
312    pub kid: Option<String>,
313}
314
315/// JSON Web Key Set
316#[derive(Debug, Clone)]
317pub struct Jwks {
318    /// Keys in the set
319    pub keys: Vec<Jwk>,
320}
321
322impl Jwks {
323    /// Create an empty JWKS
324    pub fn empty() -> Self {
325        Self { keys: Vec::new() }
326    }
327
328    /// Get key by ID
329    pub fn get_key(&self, kid: &str) -> Option<&Jwk> {
330        self.keys.iter().find(|k| k.kid.as_deref() == Some(kid))
331    }
332
333    /// Check if JWKS has any keys
334    pub fn is_empty(&self) -> bool {
335        self.keys.is_empty()
336    }
337}
338
339/// JSON Web Key
340#[derive(Debug, Clone, serde::Deserialize)]
341pub struct Jwk {
342    /// Key type (e.g., "RSA", "EC")
343    pub kty: String,
344
345    /// Key ID
346    pub kid: Option<String>,
347
348    /// Algorithm
349    pub alg: Option<String>,
350
351    /// Key use ("sig" or "enc")
352    #[serde(rename = "use")]
353    pub use_: Option<String>,
354
355    /// RSA modulus (for RSA keys)
356    pub n: Option<String>,
357
358    /// RSA exponent (for RSA keys)
359    pub e: Option<String>,
360
361    /// EC x coordinate (for EC keys)
362    pub x: Option<String>,
363
364    /// EC y coordinate (for EC keys)
365    pub y: Option<String>,
366
367    /// EC curve (for EC keys)
368    pub crv: Option<String>,
369
370    /// Symmetric key material, base64url-encoded (for `kty = "oct"` / HMAC).
371    #[serde(default)]
372    pub k: Option<String>,
373}
374
375/// Base64 URL-safe decode helper
376fn base64_decode_url_safe(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
377    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
378    URL_SAFE_NO_PAD.decode(input)
379}
380
381/// Cache for validated tokens
382pub struct TokenCache {
383    /// Cached tokens with their claims
384    cache: HashMap<String, CachedToken>,
385
386    /// Maximum cache size
387    max_size: usize,
388
389    /// TTL for cached tokens
390    ttl: Duration,
391}
392
393struct CachedToken {
394    claims: JwtClaims,
395    cached_at: Instant,
396}
397
398impl TokenCache {
399    /// Create a new token cache
400    pub fn new(max_size: usize, ttl: Duration) -> Self {
401        Self {
402            cache: HashMap::new(),
403            max_size,
404            ttl,
405        }
406    }
407
408    /// Get cached claims for a token
409    pub fn get(&self, token: &str) -> Option<&JwtClaims> {
410        self.cache.get(token).and_then(|cached| {
411            if cached.cached_at.elapsed() < self.ttl {
412                Some(&cached.claims)
413            } else {
414                None
415            }
416        })
417    }
418
419    /// Cache validated claims
420    pub fn insert(&mut self, token: String, claims: JwtClaims) {
421        // Evict old entries if at capacity
422        if self.cache.len() >= self.max_size {
423            self.evict_expired();
424        }
425
426        self.cache.insert(
427            token,
428            CachedToken {
429                claims,
430                cached_at: Instant::now(),
431            },
432        );
433    }
434
435    /// Remove expired entries
436    pub fn evict_expired(&mut self) {
437        self.cache
438            .retain(|_, cached| cached.cached_at.elapsed() < self.ttl);
439    }
440
441    /// Clear all cached tokens
442    pub fn clear(&mut self) {
443        self.cache.clear();
444    }
445
446    /// Get cache size
447    pub fn len(&self) -> usize {
448        self.cache.len()
449    }
450
451    /// Check if cache is empty
452    pub fn is_empty(&self) -> bool {
453        self.cache.is_empty()
454    }
455}
456
457impl Default for TokenCache {
458    fn default() -> Self {
459        Self::new(1000, Duration::from_secs(60))
460    }
461}
462
#[cfg(test)]
mod tests {
    use super::*;

    fn test_config() -> JwtConfig {
        JwtConfig::new("https://example.com/.well-known/jwks.json")
            .with_issuer("https://example.com")
            .with_audience("test-api")
    }

    /// Claims for `sub`, issued now and expiring in one hour.
    fn sample_claims(sub: &str) -> JwtClaims {
        let now = chrono::Utc::now().timestamp();
        JwtClaims {
            sub: sub.to_string(),
            iss: "test".to_string(),
            aud: None,
            exp: now + 3600,
            iat: now,
            nbf: None,
            jti: None,
            name: None,
            email: None,
            roles: Vec::new(),
            tenant_id: None,
            custom: HashMap::new(),
        }
    }

    #[test]
    fn test_jwt_validator_creation() {
        let validator = JwtValidator::new(test_config());
        assert!(validator.needs_refresh());
    }

    #[test]
    fn test_jwks_empty() {
        let jwks = Jwks::empty();
        assert!(jwks.is_empty());
        assert!(jwks.get_key("test").is_none());
    }

    #[test]
    fn test_token_cache() {
        let mut cache = TokenCache::new(10, Duration::from_secs(60));

        let claims = JwtClaims {
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            ..sample_claims("user123")
        };

        cache.insert("token123".to_string(), claims);

        assert_eq!(cache.len(), 1);
        assert!(cache.get("token123").is_some());
        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn test_token_cache_eviction() {
        let mut cache = TokenCache::new(2, Duration::from_millis(1));

        let claims = sample_claims("user");
        cache.insert("token1".to_string(), claims.clone());
        cache.insert("token2".to_string(), claims);

        // Wait for expiration
        std::thread::sleep(Duration::from_millis(5));

        cache.evict_expired();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_invalid_token_format() {
        let validator = JwtValidator::new(test_config());

        // Too few segments, too many segments, and an empty string.
        for token in ["invalid", "only.two", "", "a.b.c.d", "a..b.c"] {
            assert!(
                matches!(validator.validate(token), Err(JwtError::InvalidFormat)),
                "expected InvalidFormat for {token:?}"
            );
        }
    }

    // ---- HS256 signature verification (real crypto; closes the accept-all hole) ----

    fn b64(d: &[u8]) -> String {
        use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
        URL_SAFE_NO_PAD.encode(d)
    }

    /// Build a real HS256 JWT from raw `header_json` and `claims_json`, signed with `secret`.
    fn hs256_token_with_header(secret: &[u8], header_json: &str, claims_json: &str) -> String {
        use hmac::{Hmac, Mac};
        use sha2::Sha256;
        let header = b64(header_json.as_bytes());
        let payload = b64(claims_json.as_bytes());
        let signing_input = format!("{header}.{payload}");
        let mut mac = <Hmac<Sha256>>::new_from_slice(secret).unwrap();
        mac.update(signing_input.as_bytes());
        let sig = b64(&mac.finalize().into_bytes());
        format!("{signing_input}.{sig}")
    }

    /// Build a real HS256 JWT over `claims_json`, signed with `secret`.
    fn hs256_token(secret: &[u8], claims_json: &str) -> String {
        hs256_token_with_header(secret, r#"{"alg":"HS256","typ":"JWT"}"#, claims_json)
    }

    fn hs256_config() -> JwtConfig {
        JwtConfig {
            allowed_algorithms: vec!["HS256".to_string()],
            ..Default::default()
        }
    }

    fn hs256_validator(secret: &[u8]) -> JwtValidator {
        let v = JwtValidator::new(hs256_config());
        v.set_hs256_secret(None, secret);
        v
    }

    fn future_claims() -> String {
        let exp = chrono::Utc::now().timestamp() + 3600;
        format!(r#"{{"sub":"alice","iss":"test","exp":{exp},"iat":0}}"#)
    }

    #[test]
    fn hs256_valid_token_accepted() {
        let v = hs256_validator(b"top-secret");
        let token = hs256_token(b"top-secret", &future_claims());
        let claims = v.validate(&token).expect("valid HS256 token");
        assert_eq!(claims.sub, "alice");
    }

    #[test]
    fn hs256_matching_kid_accepted() {
        let v = JwtValidator::new(hs256_config());
        v.set_hs256_secret(Some("k1".to_string()), b"top-secret");
        let token = hs256_token_with_header(
            b"top-secret",
            r#"{"alg":"HS256","kid":"k1"}"#,
            &future_claims(),
        );
        let claims = v.validate(&token).expect("kid matches the installed key");
        assert_eq!(claims.sub, "alice");
    }

    #[test]
    fn hs256_unknown_kid_rejected() {
        let v = hs256_validator(b"top-secret");
        let token = hs256_token_with_header(
            b"top-secret",
            r#"{"alg":"HS256","kid":"other"}"#,
            &future_claims(),
        );
        assert!(matches!(v.validate(&token), Err(JwtError::KeyNotFound(_))));
    }

    #[test]
    fn hs256_wrong_secret_rejected() {
        let v = hs256_validator(b"top-secret");
        let token = hs256_token(b"WRONG-secret", &future_claims());
        assert!(matches!(
            v.validate(&token),
            Err(JwtError::InvalidSignature)
        ));
    }

    #[test]
    fn hs256_tampered_payload_rejected() {
        let v = hs256_validator(b"top-secret");
        let token = hs256_token(b"top-secret", &future_claims());
        let parts: Vec<&str> = token.split('.').collect();
        // Forge a privilege-escalated payload, keep the original signature.
        let evil = b64(br#"{"sub":"attacker","iss":"test","exp":9999999999,"iat":0}"#);
        let forged = format!("{}.{}.{}", parts[0], evil, parts[2]);
        assert!(matches!(
            v.validate(&forged),
            Err(JwtError::InvalidSignature)
        ));
    }

    #[test]
    fn hs256_expired_token_rejected() {
        let v = hs256_validator(b"top-secret");
        let token = hs256_token(
            b"top-secret",
            r#"{"sub":"alice","iss":"test","exp":1000,"iat":0}"#,
        );
        assert!(matches!(v.validate(&token), Err(JwtError::Expired)));
    }

    #[test]
    fn alg_outside_allow_list_rejected() {
        // An unsigned `alg: none` token must be refused, not trusted.
        let v = hs256_validator(b"top-secret");
        let header = b64(br#"{"alg":"none"}"#);
        let payload = b64(future_claims().as_bytes());
        let token = format!("{header}.{payload}.");
        assert!(matches!(
            v.validate(&token),
            Err(JwtError::UnsupportedAlgorithm(_))
        ));
    }

    #[test]
    fn unsupported_alg_is_rejected_not_trusted() {
        // RS256 is allowed by config but unimplemented -> it must be REJECTED
        // (the old verify_signature returned Ok for every algorithm).
        let config = JwtConfig {
            allowed_algorithms: vec!["RS256".to_string()],
            ..Default::default()
        };
        let v = JwtValidator::new(config);
        v.set_hs256_secret(None, b"x"); // a key must exist to reach verify_signature
        let header = b64(br#"{"alg":"RS256","typ":"JWT"}"#);
        let payload = b64(future_claims().as_bytes());
        let token = format!("{header}.{payload}.{}", b64(b"whatever"));
        assert!(matches!(
            v.validate(&token),
            Err(JwtError::UnsupportedAlgorithm(_))
        ));
    }
}