Skip to main content

corevpn_auth/
token.rs

1//! Token Management and Validation
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, SystemTime, UNIX_EPOCH};
8use parking_lot::RwLock;
9use tracing::{debug, warn};
10
11use crate::{AuthError, Result};
12
13/// OAuth2/OIDC Token Set
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct TokenSet {
16    /// Access token
17    pub access_token: String,
18    /// Refresh token (optional)
19    pub refresh_token: Option<String>,
20    /// ID token (for OIDC)
21    pub id_token: Option<String>,
22    /// Token expiration time
23    pub expires_at: DateTime<Utc>,
24    /// Token type (usually "Bearer")
25    pub token_type: String,
26    /// Granted scopes
27    pub scopes: Vec<String>,
28}
29
30impl TokenSet {
31    /// Check if access token is expired
32    pub fn is_expired(&self) -> bool {
33        Utc::now() > self.expires_at
34    }
35
36    /// Check if access token will expire within given duration
37    pub fn expires_within(&self, duration: chrono::Duration) -> bool {
38        Utc::now() + duration > self.expires_at
39    }
40
41    /// Get remaining lifetime
42    pub fn remaining_lifetime(&self) -> chrono::Duration {
43        self.expires_at - Utc::now()
44    }
45}
46
47/// Claims extracted from ID token
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct IdTokenClaims {
50    /// Issuer
51    pub iss: String,
52    /// Subject (user ID)
53    pub sub: String,
54    /// Audience
55    pub aud: StringOrArray,
56    /// Expiration time
57    pub exp: i64,
58    /// Issued at time
59    pub iat: i64,
60    /// Nonce
61    #[serde(default)]
62    pub nonce: Option<String>,
63    /// Email
64    #[serde(default)]
65    pub email: Option<String>,
66    /// Email verified
67    #[serde(default)]
68    pub email_verified: Option<bool>,
69    /// Name
70    #[serde(default)]
71    pub name: Option<String>,
72    /// Given name
73    #[serde(default)]
74    pub given_name: Option<String>,
75    /// Family name
76    #[serde(default)]
77    pub family_name: Option<String>,
78    /// Picture URL
79    #[serde(default)]
80    pub picture: Option<String>,
81    /// Groups
82    #[serde(default)]
83    pub groups: Vec<String>,
84    /// Additional claims
85    #[serde(flatten)]
86    pub additional: HashMap<String, serde_json::Value>,
87}
88
89/// String or array of strings (for audience claim)
90#[derive(Debug, Clone, Serialize, Deserialize)]
91#[serde(untagged)]
92pub enum StringOrArray {
93    /// Single string
94    String(String),
95    /// Array of strings
96    Array(Vec<String>),
97}
98
99impl StringOrArray {
100    /// Check if contains a value
101    pub fn contains(&self, value: &str) -> bool {
102        match self {
103            StringOrArray::String(s) => s == value,
104            StringOrArray::Array(arr) => arr.iter().any(|s| s == value),
105        }
106    }
107}
108
109/// JWKS key set
110#[derive(Debug, Clone, Deserialize)]
111struct JwkSet {
112    keys: Vec<Jwk>,
113}
114
115/// Individual JWK
116#[derive(Debug, Clone, Deserialize)]
117struct Jwk {
118    /// Key ID
119    kid: Option<String>,
120    /// Key type (e.g., "RSA")
121    kty: String,
122    /// Algorithm (e.g., "RS256")
123    alg: Option<String>,
124    /// RSA modulus (base64url)
125    n: Option<String>,
126    /// RSA exponent (base64url)
127    e: Option<String>,
128    /// Key use (e.g., "sig")
129    #[serde(rename = "use")]
130    key_use: Option<String>,
131}
132
133impl Jwk {
134    /// Convert to jsonwebtoken DecodingKey
135    fn to_decoding_key(&self) -> std::result::Result<jsonwebtoken::DecodingKey, String> {
136        match self.kty.as_str() {
137            "RSA" => {
138                let n = self.n.as_ref().ok_or("Missing 'n' in RSA key")?;
139                let e = self.e.as_ref().ok_or("Missing 'e' in RSA key")?;
140                jsonwebtoken::DecodingKey::from_rsa_components(n, e)
141                    .map_err(|e| format!("Failed to create RSA key: {}", e))
142            }
143            _ => Err(format!("Unsupported key type: {}", self.kty)),
144        }
145    }
146}
147
148/// JWKS cache entry
149#[derive(Clone)]
150struct JwksCacheEntry {
151    jwks: JwkSet,
152    expires_at: SystemTime,
153}
154
155/// JWKS cache
156struct JwksCache {
157    entries: HashMap<String, JwksCacheEntry>,
158    ttl: Duration,
159}
160
161impl JwksCache {
162    fn new(ttl: Duration) -> Self {
163        Self {
164            entries: HashMap::new(),
165            ttl,
166        }
167    }
168
169    fn get(&self, jwks_uri: &str) -> Option<JwkSet> {
170        let entry = self.entries.get(jwks_uri)?;
171        if SystemTime::now() < entry.expires_at {
172            Some(entry.jwks.clone())
173        } else {
174            None
175        }
176    }
177
178    fn insert(&mut self, jwks_uri: String, jwks: JwkSet) {
179        let expires_at = SystemTime::now() + self.ttl;
180        self.entries.insert(jwks_uri, JwksCacheEntry { jwks, expires_at });
181    }
182
183    fn clear_expired(&mut self) {
184        let now = SystemTime::now();
185        self.entries.retain(|_, entry| entry.expires_at > now);
186    }
187}
188
189/// Token validator with JWKS support
190pub struct TokenValidator {
191    /// Expected issuer
192    issuer: String,
193    /// Expected audience (client ID)
194    audience: String,
195    /// Clock skew tolerance (in seconds)
196    clock_skew: i64,
197    /// JWKS URI for signature verification
198    jwks_uri: Option<String>,
199    /// HTTP client for fetching JWKS
200    http_client: reqwest::Client,
201    /// JWKS cache
202    jwks_cache: Arc<RwLock<JwksCache>>,
203}
204
205impl TokenValidator {
206    /// Create a new token validator
207    pub fn new(issuer: &str, audience: &str) -> Self {
208        Self {
209            issuer: issuer.to_string(),
210            audience: audience.to_string(),
211            clock_skew: 60, // 1 minute tolerance
212            jwks_uri: None,
213            http_client: reqwest::Client::new(),
214            jwks_cache: Arc::new(RwLock::new(JwksCache::new(Duration::from_secs(3600)))),
215        }
216    }
217
218    /// Create a new token validator with JWKS URI for signature verification
219    pub fn with_jwks_uri(issuer: &str, audience: &str, jwks_uri: &str) -> Self {
220        Self {
221            issuer: issuer.to_string(),
222            audience: audience.to_string(),
223            clock_skew: 60,
224            jwks_uri: Some(jwks_uri.to_string()),
225            http_client: reqwest::Client::new(),
226            jwks_cache: Arc::new(RwLock::new(JwksCache::new(Duration::from_secs(3600)))),
227        }
228    }
229
230    /// Set clock skew tolerance
231    pub fn with_clock_skew(mut self, seconds: i64) -> Self {
232        self.clock_skew = seconds;
233        self
234    }
235
236    /// Validate ID token claims (without cryptographic verification)
237    ///
238    /// Note: For production, you should also verify the JWT signature
239    /// using the provider's JWKS.
240    pub fn validate_claims(&self, claims: &IdTokenClaims, expected_nonce: Option<&str>) -> Result<()> {
241        // Check issuer
242        if claims.iss != self.issuer {
243            return Err(AuthError::TokenValidationFailed(format!(
244                "invalid issuer: expected {}, got {}",
245                self.issuer, claims.iss
246            )));
247        }
248
249        // Check audience
250        if !claims.aud.contains(&self.audience) {
251            return Err(AuthError::TokenValidationFailed(
252                "token audience mismatch".into(),
253            ));
254        }
255
256        // Check expiration
257        let now = Utc::now().timestamp();
258        if claims.exp < now - self.clock_skew {
259            return Err(AuthError::TokenExpired);
260        }
261
262        // Check issued at (not in the future)
263        if claims.iat > now + self.clock_skew {
264            return Err(AuthError::TokenValidationFailed(
265                "token issued in the future".into(),
266            ));
267        }
268
269        // Check nonce if provided
270        if let Some(expected) = expected_nonce {
271            if claims.nonce.as_deref() != Some(expected) {
272                return Err(AuthError::InvalidNonce);
273            }
274        }
275
276        Ok(())
277    }
278
279    /// Fetch and cache JWKS
280    async fn fetch_jwks(&self, jwks_uri: &str) -> Result<JwkSet> {
281        // Check cache first
282        {
283            let cache = self.jwks_cache.read();
284            if let Some(jwks) = cache.get(jwks_uri) {
285                debug!("JWKS cache hit for {}", jwks_uri);
286                return Ok(jwks);
287            }
288        }
289
290        // Fetch from network
291        debug!("Fetching JWKS from {}", jwks_uri);
292        let response = self.http_client
293            .get(jwks_uri)
294            .send()
295            .await
296            .map_err(|e| AuthError::HttpError(format!("Failed to fetch JWKS: {}", e)))?;
297
298        if !response.status().is_success() {
299            return Err(AuthError::HttpError(format!(
300                "JWKS fetch failed with status: {}",
301                response.status()
302            )));
303        }
304
305        let jwks: JwkSet = response
306            .json()
307            .await
308            .map_err(|e| AuthError::HttpError(format!("Failed to parse JWKS: {}", e)))?;
309
310        // Cache the result
311        {
312            let mut cache = self.jwks_cache.write();
313            cache.insert(jwks_uri.to_string(), jwks.clone());
314            cache.clear_expired();
315        }
316
317        Ok(jwks)
318    }
319
320    /// Verify JWT signature using JWKS
321    async fn verify_signature(&self, token: &str) -> Result<()> {
322        let jwks_uri = self.jwks_uri.as_ref()
323            .ok_or_else(|| AuthError::TokenValidationFailed("JWKS URI not configured".into()))?;
324
325        // Decode header to get key ID
326        let header = jsonwebtoken::decode_header(token)
327            .map_err(|e| AuthError::TokenValidationFailed(format!("Invalid JWT header: {}", e)))?;
328
329        let kid = header.kid.ok_or_else(|| {
330            AuthError::TokenValidationFailed("JWT missing key ID (kid)".into())
331        })?;
332
333        // Fetch JWKS
334        let jwks = self.fetch_jwks(jwks_uri).await?;
335
336        // Find the key by kid
337        let jwk = jwks.keys.iter()
338            .find(|k| k.kid.as_deref() == Some(&kid))
339            .ok_or_else(|| AuthError::TokenValidationFailed(format!("Key {} not found in JWKS", kid)))?;
340
341        // Convert to decoding key
342        let decoding_key = jwk.to_decoding_key()
343            .map_err(|e| AuthError::TokenValidationFailed(e))?;
344
345        // Verify signature using jsonwebtoken
346        let mut validation = jsonwebtoken::Validation::new(header.alg);
347        validation.set_issuer(&[&self.issuer]);
348        validation.set_audience(&[&self.audience]);
349        validation.leeway = self.clock_skew as u64;
350
351        let _decoded = jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation)
352            .map_err(|e| AuthError::TokenValidationFailed(format!("JWT signature verification failed: {}", e)))?;
353
354        Ok(())
355    }
356
357    /// Decode and verify JWT token with signature verification
358    ///
359    /// This method verifies the JWT signature using JWKS before trusting the claims.
360    pub async fn decode_and_verify_jwt(&self, token: &str) -> Result<IdTokenClaims> {
361        // Verify signature if JWKS URI is configured
362        if self.jwks_uri.is_some() {
363            self.verify_signature(token).await?;
364        } else {
365            warn!("JWT signature verification skipped - JWKS URI not configured");
366        }
367
368        // Decode claims
369        self.decode_jwt_claims(token)
370    }
371
372    /// Decode JWT without verification (for extracting claims)
373    ///
374    /// NOTE: This does not verify the signature. Use decode_and_verify_jwt() for production.
375    pub fn decode_jwt_claims(&self, token: &str) -> Result<IdTokenClaims> {
376        use base64::Engine;
377
378        let parts: Vec<&str> = token.split('.').collect();
379        if parts.len() != 3 {
380            return Err(AuthError::TokenValidationFailed("invalid JWT format".into()));
381        }
382
383        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
384            .decode(parts[1])
385            .map_err(|e| AuthError::TokenValidationFailed(format!("base64 decode error: {}", e)))?;
386
387        let claims: IdTokenClaims = serde_json::from_slice(&payload)?;
388
389        Ok(claims)
390    }
391}
392
393/// User information extracted from tokens
394#[derive(Debug, Clone, Serialize, Deserialize)]
395pub struct UserInfo {
396    /// Subject (user ID from provider)
397    pub sub: String,
398    /// Email
399    pub email: Option<String>,
400    /// Email verified
401    pub email_verified: bool,
402    /// Display name
403    pub name: Option<String>,
404    /// Given name
405    pub given_name: Option<String>,
406    /// Family name
407    pub family_name: Option<String>,
408    /// Picture URL
409    pub picture: Option<String>,
410    /// Groups
411    pub groups: Vec<String>,
412    /// Provider type
413    pub provider: String,
414}
415
416impl UserInfo {
417    /// Extract from ID token claims
418    pub fn from_claims(claims: &IdTokenClaims, provider: &str) -> Self {
419        Self {
420            sub: claims.sub.clone(),
421            email: claims.email.clone(),
422            email_verified: claims.email_verified.unwrap_or(false),
423            name: claims.name.clone(),
424            given_name: claims.given_name.clone(),
425            family_name: claims.family_name.clone(),
426            picture: claims.picture.clone(),
427            groups: claims.groups.clone(),
428            provider: provider.to_string(),
429        }
430    }
431
432    /// Get email domain
433    pub fn email_domain(&self) -> Option<&str> {
434        self.email.as_ref().and_then(|e| e.split('@').nth(1))
435    }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_expiration() {
        let expiry = Utc::now() + chrono::Duration::hours(1);
        let token_set = TokenSet {
            access_token: String::from("test"),
            refresh_token: None,
            id_token: None,
            expires_at: expiry,
            token_type: String::from("Bearer"),
            scopes: Vec::new(),
        };

        // Fresh token: not expired and still valid 30 minutes from now...
        assert!(!token_set.is_expired());
        assert!(!token_set.expires_within(chrono::Duration::minutes(30)));
        // ...but past its expiry within the next two hours.
        assert!(token_set.expires_within(chrono::Duration::hours(2)));
    }

    #[test]
    fn test_string_or_array() {
        let lone = StringOrArray::String(String::from("test"));
        assert!(lone.contains("test"));
        assert!(!lone.contains("other"));

        let listed = StringOrArray::Array(vec![String::from("one"), String::from("two")]);
        for member in &["one", "two"] {
            assert!(listed.contains(member));
        }
        assert!(!listed.contains("three"));
    }

    #[test]
    fn test_claim_validation() {
        const ISSUER: &str = "https://accounts.google.com";
        let checker = TokenValidator::new(ISSUER, "client-id");

        let issued_at = Utc::now().timestamp();
        let id_claims = IdTokenClaims {
            iss: ISSUER.to_string(),
            sub: String::from("user123"),
            aud: StringOrArray::String(String::from("client-id")),
            exp: issued_at + 3600,
            iat: issued_at,
            nonce: Some(String::from("test-nonce")),
            email: Some(String::from("user@example.com")),
            email_verified: Some(true),
            name: Some(String::from("Test User")),
            given_name: None,
            family_name: None,
            picture: None,
            groups: Vec::new(),
            additional: HashMap::new(),
        };

        // The matching nonce passes; any other nonce must be rejected.
        assert!(checker.validate_claims(&id_claims, Some("test-nonce")).is_ok());
        assert!(checker.validate_claims(&id_claims, Some("wrong-nonce")).is_err());
    }
}