Skip to main content

heliosdb_proxy/auth/
jwt.rs

1//! JWT Token Validation
2//!
3//! Validates JWT tokens using JWKS (JSON Web Key Sets) for signature verification.
4
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, Instant};
8
9use parking_lot::RwLock;
10use thiserror::Error;
11
12use super::config::{Identity, JwtClaims, JwtConfig};
13
14/// JWT validation errors
15#[derive(Debug, Error)]
16pub enum JwtError {
17    #[error("Invalid token format")]
18    InvalidFormat,
19
20    #[error("Token has expired")]
21    Expired,
22
23    #[error("Token not yet valid")]
24    NotYetValid,
25
26    #[error("Invalid issuer")]
27    InvalidIssuer,
28
29    #[error("Invalid audience")]
30    InvalidAudience,
31
32    #[error("Invalid signature")]
33    InvalidSignature,
34
35    #[error("Key not found: {0}")]
36    KeyNotFound(String),
37
38    #[error("Unsupported algorithm: {0}")]
39    UnsupportedAlgorithm(String),
40
41    #[error("Failed to decode: {0}")]
42    DecodeFailed(String),
43
44    #[error("JWKS fetch failed: {0}")]
45    JwksFetchFailed(String),
46}
47
48/// JWT validator
49pub struct JwtValidator {
50    /// Configuration
51    config: JwtConfig,
52
53    /// Cached JWKS
54    jwks: Arc<RwLock<Jwks>>,
55
56    /// Last JWKS refresh time
57    last_refresh: Arc<RwLock<Option<Instant>>>,
58}
59
60impl JwtValidator {
61    /// Create a new JWT validator
62    pub fn new(config: JwtConfig) -> Self {
63        Self {
64            config,
65            jwks: Arc::new(RwLock::new(Jwks::empty())),
66            last_refresh: Arc::new(RwLock::new(None)),
67        }
68    }
69
70    /// Validate a JWT token and return claims
71    pub fn validate(&self, token: &str) -> Result<JwtClaims, JwtError> {
72        // Split token into parts
73        let parts: Vec<&str> = token.split('.').collect();
74        if parts.len() != 3 {
75            return Err(JwtError::InvalidFormat);
76        }
77
78        // Decode header
79        let header = self.decode_header(parts[0])?;
80
81        // Check algorithm
82        if !self.config.allowed_algorithms.contains(&header.alg) {
83            return Err(JwtError::UnsupportedAlgorithm(header.alg));
84        }
85
86        // Get signing key
87        let key = self.get_key(&header.kid)?;
88
89        // Verify signature
90        self.verify_signature(token, &key)?;
91
92        // Decode claims
93        let claims = self.decode_claims(parts[1])?;
94
95        // Validate standard claims
96        self.validate_expiration(&claims)?;
97        self.validate_not_before(&claims)?;
98        self.validate_issuer(&claims)?;
99        self.validate_audience(&claims)?;
100
101        Ok(claims)
102    }
103
104    /// Validate token and convert to Identity
105    pub fn validate_to_identity(&self, token: &str) -> Result<Identity, JwtError> {
106        let claims = self.validate(token)?;
107        Ok(Identity::from_jwt_claims(&claims))
108    }
109
110    /// Decode JWT header
111    fn decode_header(&self, header_b64: &str) -> Result<JwtHeader, JwtError> {
112        let decoded = base64_decode_url_safe(header_b64)
113            .map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
114
115        serde_json::from_slice(&decoded)
116            .map_err(|e| JwtError::DecodeFailed(e.to_string()))
117    }
118
119    /// Decode JWT claims
120    fn decode_claims(&self, claims_b64: &str) -> Result<JwtClaims, JwtError> {
121        let decoded = base64_decode_url_safe(claims_b64)
122            .map_err(|e| JwtError::DecodeFailed(e.to_string()))?;
123
124        serde_json::from_slice(&decoded)
125            .map_err(|e| JwtError::DecodeFailed(e.to_string()))
126    }
127
128    /// Get signing key by key ID
129    fn get_key(&self, kid: &Option<String>) -> Result<Jwk, JwtError> {
130        let jwks = self.jwks.read();
131
132        match kid {
133            Some(kid) => jwks
134                .get_key(kid)
135                .cloned()
136                .ok_or_else(|| JwtError::KeyNotFound(kid.clone())),
137            None => jwks
138                .keys
139                .first()
140                .cloned()
141                .ok_or_else(|| JwtError::KeyNotFound("(default)".to_string())),
142        }
143    }
144
145    /// Verify token signature
146    fn verify_signature(&self, _token: &str, _key: &Jwk) -> Result<(), JwtError> {
147        // In a real implementation, this would use a crypto library like
148        // ring or openssl to verify the signature.
149        //
150        // For now, we trust the signature (this is for demonstration).
151        // In production, you would:
152        // 1. Decode the signature from base64
153        // 2. Compute the expected signature using the key
154        // 3. Compare using constant-time comparison
155
156        // Placeholder: always succeed for demo
157        Ok(())
158    }
159
160    /// Validate expiration claim
161    fn validate_expiration(&self, claims: &JwtClaims) -> Result<(), JwtError> {
162        let now = chrono::Utc::now().timestamp();
163        let exp_with_skew = claims.exp + self.config.clock_skew.as_secs() as i64;
164
165        if now > exp_with_skew {
166            return Err(JwtError::Expired);
167        }
168
169        Ok(())
170    }
171
172    /// Validate not-before claim
173    fn validate_not_before(&self, claims: &JwtClaims) -> Result<(), JwtError> {
174        if let Some(nbf) = claims.nbf {
175            let now = chrono::Utc::now().timestamp();
176            let nbf_with_skew = nbf - self.config.clock_skew.as_secs() as i64;
177
178            if now < nbf_with_skew {
179                return Err(JwtError::NotYetValid);
180            }
181        }
182
183        Ok(())
184    }
185
186    /// Validate issuer claim
187    fn validate_issuer(&self, claims: &JwtClaims) -> Result<(), JwtError> {
188        if !self.config.allowed_issuers.is_empty() {
189            if !self.config.allowed_issuers.contains(&claims.iss) {
190                return Err(JwtError::InvalidIssuer);
191            }
192        }
193
194        Ok(())
195    }
196
197    /// Validate audience claim
198    fn validate_audience(&self, claims: &JwtClaims) -> Result<(), JwtError> {
199        if let Some(required_aud) = &self.config.required_audience {
200            match &claims.aud {
201                Some(aud) if aud.contains(required_aud) => Ok(()),
202                Some(_) => Err(JwtError::InvalidAudience),
203                None => Err(JwtError::InvalidAudience),
204            }
205        } else {
206            Ok(())
207        }
208    }
209
210    /// Refresh JWKS from remote endpoint
211    pub async fn refresh_jwks(&self) -> Result<(), JwtError> {
212        // In a real implementation, this would fetch JWKS from the configured URL
213        // using an HTTP client like reqwest.
214        //
215        // For demonstration, we create a dummy JWKS.
216
217        let jwks = Jwks {
218            keys: vec![Jwk {
219                kty: "RSA".to_string(),
220                kid: Some("default".to_string()),
221                alg: Some("RS256".to_string()),
222                use_: Some("sig".to_string()),
223                n: Some("dummy_modulus".to_string()),
224                e: Some("AQAB".to_string()),
225                x: None,
226                y: None,
227                crv: None,
228            }],
229        };
230
231        *self.jwks.write() = jwks;
232        *self.last_refresh.write() = Some(Instant::now());
233
234        Ok(())
235    }
236
237    /// Check if JWKS needs refresh
238    pub fn needs_refresh(&self) -> bool {
239        match *self.last_refresh.read() {
240            Some(last) => last.elapsed() > self.config.jwks_refresh_interval,
241            None => true,
242        }
243    }
244
245    /// Get JWKS URL
246    pub fn jwks_url(&self) -> &str {
247        &self.config.jwks_url
248    }
249
250    /// Get last refresh time
251    pub fn last_refresh_time(&self) -> Option<Instant> {
252        *self.last_refresh.read()
253    }
254}
255
256/// JWT header
257#[derive(Debug, serde::Deserialize)]
258pub struct JwtHeader {
259    /// Algorithm
260    pub alg: String,
261
262    /// Token type
263    #[serde(default)]
264    pub typ: Option<String>,
265
266    /// Key ID
267    pub kid: Option<String>,
268}
269
270/// JSON Web Key Set
271#[derive(Debug, Clone)]
272pub struct Jwks {
273    /// Keys in the set
274    pub keys: Vec<Jwk>,
275}
276
277impl Jwks {
278    /// Create an empty JWKS
279    pub fn empty() -> Self {
280        Self { keys: Vec::new() }
281    }
282
283    /// Get key by ID
284    pub fn get_key(&self, kid: &str) -> Option<&Jwk> {
285        self.keys.iter().find(|k| k.kid.as_deref() == Some(kid))
286    }
287
288    /// Check if JWKS has any keys
289    pub fn is_empty(&self) -> bool {
290        self.keys.is_empty()
291    }
292}
293
294/// JSON Web Key
295#[derive(Debug, Clone, serde::Deserialize)]
296pub struct Jwk {
297    /// Key type (e.g., "RSA", "EC")
298    pub kty: String,
299
300    /// Key ID
301    pub kid: Option<String>,
302
303    /// Algorithm
304    pub alg: Option<String>,
305
306    /// Key use ("sig" or "enc")
307    #[serde(rename = "use")]
308    pub use_: Option<String>,
309
310    /// RSA modulus (for RSA keys)
311    pub n: Option<String>,
312
313    /// RSA exponent (for RSA keys)
314    pub e: Option<String>,
315
316    /// EC x coordinate (for EC keys)
317    pub x: Option<String>,
318
319    /// EC y coordinate (for EC keys)
320    pub y: Option<String>,
321
322    /// EC curve (for EC keys)
323    pub crv: Option<String>,
324}
325
326/// Base64 URL-safe decode helper
327fn base64_decode_url_safe(input: &str) -> Result<Vec<u8>, base64::DecodeError> {
328    use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
329    URL_SAFE_NO_PAD.decode(input)
330}
331
332/// Cache for validated tokens
333pub struct TokenCache {
334    /// Cached tokens with their claims
335    cache: HashMap<String, CachedToken>,
336
337    /// Maximum cache size
338    max_size: usize,
339
340    /// TTL for cached tokens
341    ttl: Duration,
342}
343
344struct CachedToken {
345    claims: JwtClaims,
346    cached_at: Instant,
347}
348
349impl TokenCache {
350    /// Create a new token cache
351    pub fn new(max_size: usize, ttl: Duration) -> Self {
352        Self {
353            cache: HashMap::new(),
354            max_size,
355            ttl,
356        }
357    }
358
359    /// Get cached claims for a token
360    pub fn get(&self, token: &str) -> Option<&JwtClaims> {
361        self.cache.get(token).and_then(|cached| {
362            if cached.cached_at.elapsed() < self.ttl {
363                Some(&cached.claims)
364            } else {
365                None
366            }
367        })
368    }
369
370    /// Cache validated claims
371    pub fn insert(&mut self, token: String, claims: JwtClaims) {
372        // Evict old entries if at capacity
373        if self.cache.len() >= self.max_size {
374            self.evict_expired();
375        }
376
377        self.cache.insert(
378            token,
379            CachedToken {
380                claims,
381                cached_at: Instant::now(),
382            },
383        );
384    }
385
386    /// Remove expired entries
387    pub fn evict_expired(&mut self) {
388        self.cache
389            .retain(|_, cached| cached.cached_at.elapsed() < self.ttl);
390    }
391
392    /// Clear all cached tokens
393    pub fn clear(&mut self) {
394        self.cache.clear();
395    }
396
397    /// Get cache size
398    pub fn len(&self) -> usize {
399        self.cache.len()
400    }
401
402    /// Check if cache is empty
403    pub fn is_empty(&self) -> bool {
404        self.cache.is_empty()
405    }
406}
407
408impl Default for TokenCache {
409    fn default() -> Self {
410        Self::new(1000, Duration::from_secs(60))
411    }
412}
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Configuration for a fake issuer at example.com expecting the `test-api` audience.
    fn test_config() -> JwtConfig {
        JwtConfig::new("https://example.com/.well-known/jwks.json")
            .with_issuer("https://example.com")
            .with_audience("test-api")
    }

    /// Minimal claims for `sub`: issuer "test", issued now, expiring in one hour.
    fn sample_claims(sub: &str) -> JwtClaims {
        let issued_at = chrono::Utc::now().timestamp();
        JwtClaims {
            sub: sub.to_string(),
            iss: "test".to_string(),
            aud: None,
            exp: issued_at + 3600,
            iat: issued_at,
            nbf: None,
            jti: None,
            name: None,
            email: None,
            roles: Vec::new(),
            tenant_id: None,
            custom: HashMap::new(),
        }
    }

    #[test]
    fn test_jwt_validator_creation() {
        // A freshly built validator has never loaded keys.
        assert!(JwtValidator::new(test_config()).needs_refresh());
    }

    #[test]
    fn test_jwks_empty() {
        let empty = Jwks::empty();
        assert!(empty.is_empty());
        assert!(empty.get_key("test").is_none());
    }

    #[test]
    fn test_token_cache() {
        let mut cache = TokenCache::new(10, Duration::from_secs(60));
        let claims = JwtClaims {
            name: Some("Test User".to_string()),
            email: Some("test@example.com".to_string()),
            roles: vec!["user".to_string()],
            ..sample_claims("user123")
        };

        cache.insert(String::from("token123"), claims);

        assert_eq!(cache.len(), 1);
        assert!(cache.get("token123").is_some());
        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn test_token_cache_eviction() {
        let mut cache = TokenCache::new(2, Duration::from_millis(1));
        for token in ["token1", "token2"] {
            cache.insert(token.to_string(), sample_claims("user"));
        }

        // Outlive the 1 ms TTL so both entries are expired.
        std::thread::sleep(Duration::from_millis(5));

        cache.evict_expired();
        assert!(cache.is_empty());
    }

    #[test]
    fn test_invalid_token_format() {
        let validator = JwtValidator::new(test_config());

        for malformed in ["invalid", "only.two"] {
            let outcome = validator.validate(malformed);
            assert!(
                matches!(outcome, Err(JwtError::InvalidFormat)),
                "expected InvalidFormat for {:?}",
                malformed
            );
        }
    }
}