1use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use std::time::{Duration, SystemTime, UNIX_EPOCH};
8use parking_lot::RwLock;
9use tracing::{debug, warn};
10
11use crate::{AuthError, Result};
12
/// A set of tokens together with their expiry instant and scopes.
///
/// Serializable with serde; expiry helpers live in the `impl TokenSet` block.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenSet {
    /// The access token string.
    pub access_token: String,
    /// Refresh token, when one is present.
    pub refresh_token: Option<String>,
    /// ID token (JWT) string, when one is present.
    pub id_token: Option<String>,
    /// Expiry instant compared against the current UTC time by `is_expired`,
    /// `expires_within` and `remaining_lifetime`.
    pub expires_at: DateTime<Utc>,
    /// Token type string (the tests in this file use `"Bearer"`).
    pub token_type: String,
    /// Scopes associated with this token set.
    pub scopes: Vec<String>,
}
29
30impl TokenSet {
31 pub fn is_expired(&self) -> bool {
33 Utc::now() > self.expires_at
34 }
35
36 pub fn expires_within(&self, duration: chrono::Duration) -> bool {
38 Utc::now() + duration > self.expires_at
39 }
40
41 pub fn remaining_lifetime(&self) -> chrono::Duration {
43 self.expires_at - Utc::now()
44 }
45}
46
/// Claims decoded from the payload segment of an ID token (JWT).
///
/// Optional claims fall back to their `Default` (`None` / empty vector) when absent, and
/// any claim not named below is captured in `additional`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdTokenClaims {
    /// Issuer (`iss`); `TokenValidator::validate_claims` requires it to equal the
    /// configured issuer.
    pub iss: String,
    /// Subject identifier (`sub`).
    pub sub: String,
    /// Audience (`aud`), which may be a single string or an array of strings.
    pub aud: StringOrArray,
    /// Expiry (`exp`) in seconds since the Unix epoch (compared against
    /// `Utc::now().timestamp()`).
    pub exp: i64,
    /// Issued-at time (`iat`) in seconds since the Unix epoch.
    pub iat: i64,
    /// Nonce echoed back by the provider, if present.
    #[serde(default)]
    pub nonce: Option<String>,
    /// Email address claim, if present.
    #[serde(default)]
    pub email: Option<String>,
    /// `email_verified` claim, if present.
    #[serde(default)]
    pub email_verified: Option<bool>,
    /// Full display name, if present.
    #[serde(default)]
    pub name: Option<String>,
    /// Given (first) name, if present.
    #[serde(default)]
    pub given_name: Option<String>,
    /// Family (last) name, if present.
    #[serde(default)]
    pub family_name: Option<String>,
    /// Picture URL claim, if present.
    #[serde(default)]
    pub picture: Option<String>,
    /// Values of a `groups` claim; empty when the claim is absent.
    #[serde(default)]
    pub groups: Vec<String>,
    /// Every other claim in the payload, keyed by claim name.
    #[serde(flatten)]
    pub additional: HashMap<String, serde_json::Value>,
}
88
/// A claim value that may be encoded either as a single string or as an array of strings
/// (as the JWT `aud` claim can be).
///
/// `#[serde(untagged)]` means serde tries the variants in declaration order, so the
/// order below is significant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum StringOrArray {
    /// A single string value.
    String(String),
    /// A list of string values.
    Array(Vec<String>),
}
98
99impl StringOrArray {
100 pub fn contains(&self, value: &str) -> bool {
102 match self {
103 StringOrArray::String(s) => s == value,
104 StringOrArray::Array(arr) => arr.iter().any(|s| s == value),
105 }
106 }
107}
108
/// A JSON Web Key Set document: the `keys` array fetched from a JWKS URI.
#[derive(Debug, Clone, Deserialize)]
struct JwkSet {
    /// The published keys.
    keys: Vec<Jwk>,
}
114
/// A single JSON Web Key as it appears in a JWKS document.
///
/// Only the fields needed to build RSA verification keys are deserialized.
#[derive(Debug, Clone, Deserialize)]
struct Jwk {
    /// Key ID, matched against the `kid` in a JWT header.
    kid: Option<String>,
    /// Key type (`kty`); only `"RSA"` is supported by `to_decoding_key`.
    kty: String,
    /// Algorithm the key is intended for (`alg`), if published.
    alg: Option<String>,
    /// RSA modulus (`n`), as given in the JWKS document.
    n: Option<String>,
    /// RSA public exponent (`e`), as given in the JWKS document.
    e: Option<String>,
    /// Intended key use (`use`); renamed because `use` is a Rust keyword.
    #[serde(rename = "use")]
    key_use: Option<String>,
}
132
133impl Jwk {
134 fn to_decoding_key(&self) -> std::result::Result<jsonwebtoken::DecodingKey, String> {
136 match self.kty.as_str() {
137 "RSA" => {
138 let n = self.n.as_ref().ok_or("Missing 'n' in RSA key")?;
139 let e = self.e.as_ref().ok_or("Missing 'e' in RSA key")?;
140 jsonwebtoken::DecodingKey::from_rsa_components(n, e)
141 .map_err(|e| format!("Failed to create RSA key: {}", e))
142 }
143 _ => Err(format!("Unsupported key type: {}", self.kty)),
144 }
145 }
146}
147
/// A cached JWKS document together with the instant after which it is considered stale.
#[derive(Clone)]
struct JwksCacheEntry {
    /// The cached key set.
    jwks: JwkSet,
    /// Wall-clock instant at which this entry stops being served by `JwksCache::get`.
    expires_at: SystemTime,
}
154
/// In-memory cache of fetched JWKS documents with a fixed time-to-live.
struct JwksCache {
    /// Cached entries, keyed by the JWKS URI they were fetched from.
    entries: HashMap<String, JwksCacheEntry>,
    /// Lifetime given to each newly inserted entry.
    ttl: Duration,
}
160
161impl JwksCache {
162 fn new(ttl: Duration) -> Self {
163 Self {
164 entries: HashMap::new(),
165 ttl,
166 }
167 }
168
169 fn get(&self, jwks_uri: &str) -> Option<JwkSet> {
170 let entry = self.entries.get(jwks_uri)?;
171 if SystemTime::now() < entry.expires_at {
172 Some(entry.jwks.clone())
173 } else {
174 None
175 }
176 }
177
178 fn insert(&mut self, jwks_uri: String, jwks: JwkSet) {
179 let expires_at = SystemTime::now() + self.ttl;
180 self.entries.insert(jwks_uri, JwksCacheEntry { jwks, expires_at });
181 }
182
183 fn clear_expired(&mut self) {
184 let now = SystemTime::now();
185 self.entries.retain(|_, entry| entry.expires_at > now);
186 }
187}
188
/// Validates ID tokens for one expected issuer and audience.
///
/// When a JWKS URI is configured, signatures can be verified against keys fetched from
/// that URI (cached in `jwks_cache`).
pub struct TokenValidator {
    /// Expected value of the `iss` claim.
    issuer: String,
    /// Audience that must be present in the `aud` claim.
    audience: String,
    /// Tolerance, in seconds, applied to the time-based claim checks (60 by default in
    /// the constructors).
    clock_skew: i64,
    /// URI of the issuer's JSON Web Key Set.
    ///
    /// When `None`, `decode_and_verify_jwt` skips signature verification and only logs
    /// a warning.
    jwks_uri: Option<String>,
    /// HTTP client used to download the JWKS.
    http_client: reqwest::Client,
    /// Cache of downloaded key sets, keyed by JWKS URI.
    jwks_cache: Arc<RwLock<JwksCache>>,
}
204
205impl TokenValidator {
206 pub fn new(issuer: &str, audience: &str) -> Self {
208 Self {
209 issuer: issuer.to_string(),
210 audience: audience.to_string(),
211 clock_skew: 60, jwks_uri: None,
213 http_client: reqwest::Client::new(),
214 jwks_cache: Arc::new(RwLock::new(JwksCache::new(Duration::from_secs(3600)))),
215 }
216 }
217
218 pub fn with_jwks_uri(issuer: &str, audience: &str, jwks_uri: &str) -> Self {
220 Self {
221 issuer: issuer.to_string(),
222 audience: audience.to_string(),
223 clock_skew: 60,
224 jwks_uri: Some(jwks_uri.to_string()),
225 http_client: reqwest::Client::new(),
226 jwks_cache: Arc::new(RwLock::new(JwksCache::new(Duration::from_secs(3600)))),
227 }
228 }
229
230 pub fn with_clock_skew(mut self, seconds: i64) -> Self {
232 self.clock_skew = seconds;
233 self
234 }
235
236 pub fn validate_claims(&self, claims: &IdTokenClaims, expected_nonce: Option<&str>) -> Result<()> {
241 if claims.iss != self.issuer {
243 return Err(AuthError::TokenValidationFailed(format!(
244 "invalid issuer: expected {}, got {}",
245 self.issuer, claims.iss
246 )));
247 }
248
249 if !claims.aud.contains(&self.audience) {
251 return Err(AuthError::TokenValidationFailed(
252 "token audience mismatch".into(),
253 ));
254 }
255
256 let now = Utc::now().timestamp();
258 if claims.exp < now - self.clock_skew {
259 return Err(AuthError::TokenExpired);
260 }
261
262 if claims.iat > now + self.clock_skew {
264 return Err(AuthError::TokenValidationFailed(
265 "token issued in the future".into(),
266 ));
267 }
268
269 if let Some(expected) = expected_nonce {
271 if claims.nonce.as_deref() != Some(expected) {
272 return Err(AuthError::InvalidNonce);
273 }
274 }
275
276 Ok(())
277 }
278
279 async fn fetch_jwks(&self, jwks_uri: &str) -> Result<JwkSet> {
281 {
283 let cache = self.jwks_cache.read();
284 if let Some(jwks) = cache.get(jwks_uri) {
285 debug!("JWKS cache hit for {}", jwks_uri);
286 return Ok(jwks);
287 }
288 }
289
290 debug!("Fetching JWKS from {}", jwks_uri);
292 let response = self.http_client
293 .get(jwks_uri)
294 .send()
295 .await
296 .map_err(|e| AuthError::HttpError(format!("Failed to fetch JWKS: {}", e)))?;
297
298 if !response.status().is_success() {
299 return Err(AuthError::HttpError(format!(
300 "JWKS fetch failed with status: {}",
301 response.status()
302 )));
303 }
304
305 let jwks: JwkSet = response
306 .json()
307 .await
308 .map_err(|e| AuthError::HttpError(format!("Failed to parse JWKS: {}", e)))?;
309
310 {
312 let mut cache = self.jwks_cache.write();
313 cache.insert(jwks_uri.to_string(), jwks.clone());
314 cache.clear_expired();
315 }
316
317 Ok(jwks)
318 }
319
320 async fn verify_signature(&self, token: &str) -> Result<()> {
322 let jwks_uri = self.jwks_uri.as_ref()
323 .ok_or_else(|| AuthError::TokenValidationFailed("JWKS URI not configured".into()))?;
324
325 let header = jsonwebtoken::decode_header(token)
327 .map_err(|e| AuthError::TokenValidationFailed(format!("Invalid JWT header: {}", e)))?;
328
329 let kid = header.kid.ok_or_else(|| {
330 AuthError::TokenValidationFailed("JWT missing key ID (kid)".into())
331 })?;
332
333 let jwks = self.fetch_jwks(jwks_uri).await?;
335
336 let jwk = jwks.keys.iter()
338 .find(|k| k.kid.as_deref() == Some(&kid))
339 .ok_or_else(|| AuthError::TokenValidationFailed(format!("Key {} not found in JWKS", kid)))?;
340
341 let decoding_key = jwk.to_decoding_key()
343 .map_err(|e| AuthError::TokenValidationFailed(e))?;
344
345 let mut validation = jsonwebtoken::Validation::new(header.alg);
347 validation.set_issuer(&[&self.issuer]);
348 validation.set_audience(&[&self.audience]);
349 validation.leeway = self.clock_skew as u64;
350
351 let _decoded = jsonwebtoken::decode::<serde_json::Value>(token, &decoding_key, &validation)
352 .map_err(|e| AuthError::TokenValidationFailed(format!("JWT signature verification failed: {}", e)))?;
353
354 Ok(())
355 }
356
357 pub async fn decode_and_verify_jwt(&self, token: &str) -> Result<IdTokenClaims> {
361 if self.jwks_uri.is_some() {
363 self.verify_signature(token).await?;
364 } else {
365 warn!("JWT signature verification skipped - JWKS URI not configured");
366 }
367
368 self.decode_jwt_claims(token)
370 }
371
372 pub fn decode_jwt_claims(&self, token: &str) -> Result<IdTokenClaims> {
376 use base64::Engine;
377
378 let parts: Vec<&str> = token.split('.').collect();
379 if parts.len() != 3 {
380 return Err(AuthError::TokenValidationFailed("invalid JWT format".into()));
381 }
382
383 let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
384 .decode(parts[1])
385 .map_err(|e| AuthError::TokenValidationFailed(format!("base64 decode error: {}", e)))?;
386
387 let claims: IdTokenClaims = serde_json::from_slice(&payload)?;
388
389 Ok(claims)
390 }
391}
392
/// Provider-independent user profile built from ID-token claims.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// Subject identifier copied from the `sub` claim.
    pub sub: String,
    /// Email address, if the claims carried one.
    pub email: Option<String>,
    /// Copied from `email_verified`; `false` when that claim was absent.
    pub email_verified: bool,
    /// Full display name, if present.
    pub name: Option<String>,
    /// Given (first) name, if present.
    pub given_name: Option<String>,
    /// Family (last) name, if present.
    pub family_name: Option<String>,
    /// Picture URL, if present.
    pub picture: Option<String>,
    /// Group names copied from the claims.
    pub groups: Vec<String>,
    /// Provider label supplied by the caller of `from_claims`.
    pub provider: String,
}
415
416impl UserInfo {
417 pub fn from_claims(claims: &IdTokenClaims, provider: &str) -> Self {
419 Self {
420 sub: claims.sub.clone(),
421 email: claims.email.clone(),
422 email_verified: claims.email_verified.unwrap_or(false),
423 name: claims.name.clone(),
424 given_name: claims.given_name.clone(),
425 family_name: claims.family_name.clone(),
426 picture: claims.picture.clone(),
427 groups: claims.groups.clone(),
428 provider: provider.to_string(),
429 }
430 }
431
432 pub fn email_domain(&self) -> Option<&str> {
434 self.email.as_ref().and_then(|e| e.split('@').nth(1))
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    const ISSUER: &str = "https://accounts.google.com";
    const AUDIENCE: &str = "client-id";

    /// Validator configured for the issuer/audience used by every fixture below.
    fn validator() -> TokenValidator {
        TokenValidator::new(ISSUER, AUDIENCE)
    }

    /// Claims that pass `validate_claims` against `validator()` with nonce `"test-nonce"`.
    fn valid_claims() -> IdTokenClaims {
        let now = Utc::now().timestamp();
        IdTokenClaims {
            iss: ISSUER.to_string(),
            sub: "user123".to_string(),
            aud: StringOrArray::String(AUDIENCE.to_string()),
            exp: now + 3600,
            iat: now,
            nonce: Some("test-nonce".to_string()),
            email: Some("user@example.com".to_string()),
            email_verified: Some(true),
            name: Some("Test User".to_string()),
            given_name: None,
            family_name: None,
            picture: None,
            groups: vec![],
            additional: HashMap::new(),
        }
    }

    /// Encodes `value` as an unpadded base64url JWT segment.
    fn encode_segment(value: &serde_json::Value) -> String {
        use base64::Engine;
        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(value.to_string())
    }

    #[test]
    fn test_token_expiration() {
        let token = TokenSet {
            access_token: "test".to_string(),
            refresh_token: None,
            id_token: None,
            expires_at: Utc::now() + chrono::Duration::hours(1),
            token_type: "Bearer".to_string(),
            scopes: vec![],
        };

        assert!(!token.is_expired());
        assert!(!token.expires_within(chrono::Duration::minutes(30)));
        assert!(token.expires_within(chrono::Duration::hours(2)));
    }

    #[test]
    fn test_string_or_array() {
        let single = StringOrArray::String("test".to_string());
        assert!(single.contains("test"));
        assert!(!single.contains("other"));

        let array = StringOrArray::Array(vec!["one".to_string(), "two".to_string()]);
        assert!(array.contains("one"));
        assert!(array.contains("two"));
        assert!(!array.contains("three"));
    }

    #[test]
    fn test_claim_validation() {
        let validator = validator();
        let claims = valid_claims();

        assert!(validator.validate_claims(&claims, Some("test-nonce")).is_ok());
        assert!(validator.validate_claims(&claims, Some("wrong-nonce")).is_err());
        // Without an expected nonce the token's nonce is not inspected.
        assert!(validator.validate_claims(&claims, None).is_ok());
    }

    #[test]
    fn test_claim_validation_rejects_wrong_issuer() {
        let mut claims = valid_claims();
        claims.iss = "https://evil.example.com".to_string();
        assert!(matches!(
            validator().validate_claims(&claims, None),
            Err(AuthError::TokenValidationFailed(_))
        ));
    }

    #[test]
    fn test_claim_validation_audience() {
        let mut claims = valid_claims();
        claims.aud = StringOrArray::Array(vec!["other".to_string(), AUDIENCE.to_string()]);
        assert!(validator().validate_claims(&claims, None).is_ok());

        claims.aud = StringOrArray::String("other".to_string());
        assert!(matches!(
            validator().validate_claims(&claims, None),
            Err(AuthError::TokenValidationFailed(_))
        ));
    }

    #[test]
    fn test_claim_validation_expiry_respects_clock_skew() {
        let now = Utc::now().timestamp();
        let mut claims = valid_claims();

        // Expired 30 s ago: still accepted with the default 60 s skew.
        claims.exp = now - 30;
        assert!(validator().validate_claims(&claims, None).is_ok());

        // ...but rejected once the skew is removed.
        let strict = validator().with_clock_skew(0);
        assert!(matches!(
            strict.validate_claims(&claims, None),
            Err(AuthError::TokenExpired)
        ));

        claims.exp = now - 3600;
        assert!(matches!(
            validator().validate_claims(&claims, None),
            Err(AuthError::TokenExpired)
        ));
    }

    #[test]
    fn test_claim_validation_rejects_future_iat() {
        let mut claims = valid_claims();
        claims.iat = Utc::now().timestamp() + 3600;
        assert!(matches!(
            validator().validate_claims(&claims, None),
            Err(AuthError::TokenValidationFailed(_))
        ));
    }

    #[test]
    fn test_decode_jwt_claims() {
        let header = serde_json::json!({ "alg": "none" });
        let payload = serde_json::json!({
            "iss": ISSUER,
            "sub": "user123",
            "aud": [AUDIENCE, "other"],
            "exp": 2_000_000_000_i64,
            "iat": 1_000_000_000_i64,
            "custom": "value"
        });
        let token = format!("{}.{}.sig", encode_segment(&header), encode_segment(&payload));

        let claims = match validator().decode_jwt_claims(&token) {
            Ok(claims) => claims,
            Err(_) => panic!("claims should decode"),
        };
        assert_eq!(claims.iss, ISSUER);
        assert_eq!(claims.sub, "user123");
        assert!(claims.aud.contains(AUDIENCE));
        assert!(claims.nonce.is_none());
        assert!(claims.groups.is_empty());
        assert_eq!(
            claims.additional.get("custom"),
            Some(&serde_json::json!("value"))
        );
    }

    #[test]
    fn test_decode_jwt_claims_rejects_malformed() {
        assert!(validator().decode_jwt_claims("only.two").is_err());
        assert!(validator().decode_jwt_claims("a.!!!not-base64!!!.c").is_err());
    }

    #[test]
    fn test_user_info_from_claims() {
        let info = UserInfo::from_claims(&valid_claims(), "google");
        assert_eq!(info.sub, "user123");
        assert!(info.email_verified);
        assert_eq!(info.provider, "google");
        assert_eq!(info.email_domain(), Some("example.com"));
    }
}