1use crate::runbeam_api::types::{RunbeamError, TeamInfo, UserInfo};
2use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
3use once_cell::sync::Lazy;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::{Arc, RwLock};
7use std::time::{Duration, Instant};
8
/// A JSON Web Key Set document as returned by an issuer's JWKS endpoint.
///
/// Only the `keys` member is read from the document.
#[derive(Debug, Clone, Deserialize)]
pub struct Jwks {
    /// Published keys; each usable RSA key is indexed by its `kid` when building the key cache.
    pub keys: Vec<JwkKey>,
}
14
15#[derive(Debug, Clone, Deserialize)]
17pub struct JwkKey {
18 pub kty: String,
20 #[serde(rename = "use")]
22 pub key_use: Option<String>,
23 pub kid: String,
25 pub alg: Option<String>,
27 pub n: String,
29 pub e: String,
31}
32
33impl JwkKey {
34 pub fn to_decoding_key(&self) -> Result<DecodingKey, RunbeamError> {
36 if self.kty != "RSA" {
37 return Err(RunbeamError::JwtValidation(format!(
38 "Unsupported key type: {}. Only RSA is supported.",
39 self.kty
40 )));
41 }
42
43 DecodingKey::from_rsa_components(&self.n, &self.e).map_err(|e| {
46 RunbeamError::JwtValidation(format!(
47 "Failed to create RSA decoding key from JWK components: {}",
48 e
49 ))
50 })
51 }
52}
53
/// Options controlling [`validate_jwt_token`].
///
/// Build with [`JwtValidationOptions::new`] (or `Default`) and the `with_*` methods.
#[derive(Debug, Clone)]
pub struct JwtValidationOptions {
    /// Issuers whose tokens are accepted. `None` accepts any issuer (a security
    /// warning is logged on every validation).
    pub trusted_issuers: Option<Vec<String>>,

    /// Location to fetch the JWKS from instead of the issuer's base URL.
    /// NOTE(review): `fetch_jwks` appends `/api/.well-known/jwks.json` to this value,
    /// so it behaves as a base URL rather than a full JWKS URI — confirm intent.
    pub jwks_uri: Option<String>,

    /// Signing algorithms accepted in the token header; `None` means `[RS256]`.
    pub algorithms: Option<Vec<Algorithm>>,

    /// Claim names that must be present, checked against the serialized `JwtClaims`
    /// (so only fields of `JwtClaims` can ever be found).
    pub required_claims: Option<Vec<String>>,

    /// Clock-skew leeway in seconds passed to `jsonwebtoken::Validation`; `None`
    /// keeps that library's default. The builder caps this at 300.
    pub leeway_seconds: Option<u64>,

    /// Whether the `exp` claim is enforced. Defaults to `true`.
    pub validate_expiry: bool,

    /// How long a fetched JWKS stays cached, in hours. Defaults to 24.
    pub jwks_cache_duration_hours: u64,
}
92
93impl Default for JwtValidationOptions {
94 fn default() -> Self {
95 Self {
96 trusted_issuers: None,
97 jwks_uri: None,
98 algorithms: None,
99 required_claims: None,
100 leeway_seconds: None,
101 validate_expiry: true,
102 jwks_cache_duration_hours: 24,
103 }
104 }
105}
106
107impl JwtValidationOptions {
108 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub fn with_trusted_issuers(mut self, issuers: Vec<String>) -> Self {
115 self.trusted_issuers = Some(issuers);
116 self
117 }
118
119 pub fn with_jwks_uri(mut self, uri: String) -> Self {
121 self.jwks_uri = Some(uri);
122 self
123 }
124
125 pub fn with_algorithms(mut self, algorithms: Vec<Algorithm>) -> Self {
127 self.algorithms = Some(algorithms);
128 self
129 }
130
131 pub fn with_required_claims(mut self, claims: Vec<String>) -> Self {
133 self.required_claims = Some(claims);
134 self
135 }
136
137 pub fn with_leeway_seconds(mut self, leeway: u64) -> Self {
139 self.leeway_seconds = Some(leeway.min(300)); self
141 }
142
143 pub fn with_validate_expiry(mut self, validate: bool) -> Self {
145 self.validate_expiry = validate;
146 self
147 }
148
149 pub fn with_jwks_cache_duration_hours(mut self, hours: u64) -> Self {
151 self.jwks_cache_duration_hours = hours;
152 self
153 }
154}
155
/// Claims carried by a Runbeam JWT.
///
/// `iss`, `sub`, `exp` and `iat` have no serde default, so a token missing any of
/// them fails to decode.
/// NOTE(review): `aud` is modelled as a single string; RFC 7519 also allows an array
/// of strings, which would fail to deserialize here — confirm issuers only emit a string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
    /// Issuer; also used to derive the JWKS base URL (see `api_base_url`).
    pub iss: String,
    /// Subject; validation rejects an empty value.
    pub sub: String,
    /// Audience, if present.
    #[serde(default)]
    pub aud: Option<String>,
    /// Expiry, compared against the current Unix time in seconds.
    pub exp: i64,
    /// Issued-at timestamp (presumably Unix seconds, like `exp` — not checked here).
    pub iat: i64,
    /// Embedded user details, if present.
    #[serde(default)]
    pub user: Option<UserInfo>,
    /// Embedded team details, if present.
    #[serde(default)]
    pub team: Option<TeamInfo>,
}
180
/// Cached key set for one JWKS location.
struct JwksCache {
    /// Decoding keys built from the fetched JWKS, indexed by `kid`.
    keys: HashMap<String, DecodingKey>,
    /// When the key set was fetched; compared against the configured cache duration.
    last_fetched: Instant,
}
188
189impl JwksCache {
190 fn is_expired(&self, cache_duration: Duration) -> bool {
192 self.last_fetched.elapsed() > cache_duration
193 }
194}
195
196static JWKS_CACHE: Lazy<Arc<RwLock<HashMap<String, JwksCache>>>> =
198 Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
199
200async fn get_decoding_key(
208 issuer: &str,
209 kid: &str,
210 cache_duration: Duration,
211) -> Result<DecodingKey, RunbeamError> {
212 {
214 let cache = JWKS_CACHE
215 .read()
216 .map_err(|e| RunbeamError::JwtValidation(format!("Cache lock poisoned: {}", e)))?;
217
218 if let Some(cache_entry) = cache.get(issuer) {
219 if !cache_entry.is_expired(cache_duration) {
220 if let Some(key) = cache_entry.keys.get(kid) {
221 tracing::debug!("JWKS cache hit for issuer={}, kid={}", issuer, kid);
222 return Ok(key.clone());
223 } else {
224 tracing::debug!("JWKS cache miss: kid '{}' not found in cached keys", kid);
225 }
226 } else {
227 tracing::debug!("JWKS cache expired for issuer={}", issuer);
228 }
229 } else {
230 tracing::debug!("JWKS cache miss for issuer={}", issuer);
231 }
232 }
233
234 {
236 let cache = JWKS_CACHE
237 .write()
238 .map_err(|e| RunbeamError::JwtValidation(format!("Cache lock poisoned: {}", e)))?;
239
240 if let Some(cache_entry) = cache.get(issuer) {
241 if !cache_entry.is_expired(cache_duration) {
242 if let Some(key) = cache_entry.keys.get(kid) {
243 tracing::debug!(
244 "JWKS cache hit after lock acquisition for issuer={}, kid={}",
245 issuer,
246 kid
247 );
248 return Ok(key.clone());
249 }
250 }
251 }
252 }
254
255 tracing::info!("Fetching fresh JWKS for issuer={}", issuer);
257 let jwks = fetch_jwks(issuer).await?;
258
259 let mut keys_map = HashMap::new();
261 for jwk in &jwks.keys {
262 match jwk.to_decoding_key() {
263 Ok(key) => {
264 keys_map.insert(jwk.kid.clone(), key);
265 }
266 Err(e) => {
267 tracing::warn!(
268 "Failed to convert JWK kid='{}' to decoding key: {}",
269 jwk.kid,
270 e
271 );
272 }
274 }
275 }
276
277 let decoding_key = keys_map
279 .get(kid)
280 .ok_or_else(|| {
281 RunbeamError::JwtValidation(format!(
282 "Key ID '{}' not found in JWKS from issuer {}",
283 kid, issuer
284 ))
285 })?
286 .clone();
287
288 {
290 let mut cache = JWKS_CACHE
291 .write()
292 .map_err(|e| RunbeamError::JwtValidation(format!("Cache lock poisoned: {}", e)))?;
293
294 cache.insert(
295 issuer.to_string(),
296 JwksCache {
297 keys: keys_map,
298 last_fetched: Instant::now(),
299 },
300 );
301 }
302
303 tracing::debug!("JWKS cache updated for issuer={}", issuer);
304 Ok(decoding_key)
305}
306
307fn clear_jwks_cache(issuer: &str) -> Result<(), RunbeamError> {
311 let mut cache = JWKS_CACHE
312 .write()
313 .map_err(|e| RunbeamError::JwtValidation(format!("Cache lock poisoned: {}", e)))?;
314
315 if cache.remove(issuer).is_some() {
316 tracing::debug!("Cleared JWKS cache for issuer={}", issuer);
317 }
318 Ok(())
319}
320
321async fn fetch_jwks(issuer: &str) -> Result<Jwks, RunbeamError> {
326 let jwks_url = format!("{}/api/.well-known/jwks.json", issuer.trim_end_matches('/'));
328
329 tracing::debug!("Fetching JWKS from: {}", jwks_url);
330
331 let client = reqwest::Client::builder()
332 .timeout(Duration::from_secs(10))
333 .build()
334 .map_err(|e| RunbeamError::JwtValidation(format!("Failed to create HTTP client: {}", e)))?;
335
336 let response = client.get(&jwks_url).send().await.map_err(|e| {
337 tracing::error!("Failed to fetch JWKS from {}: {}", jwks_url, e);
338 if e.is_timeout() {
339 RunbeamError::JwtValidation(format!("JWKS endpoint timeout: {}", jwks_url))
340 } else if e.is_connect() {
341 RunbeamError::JwtValidation(format!("Failed to connect to JWKS endpoint: {}", jwks_url))
342 } else {
343 RunbeamError::JwtValidation(format!("Network error fetching JWKS: {}", e))
344 }
345 })?;
346
347 let status = response.status();
348 if !status.is_success() {
349 tracing::error!(
350 "JWKS endpoint returned HTTP {}: {}",
351 status.as_u16(),
352 jwks_url
353 );
354 return Err(RunbeamError::JwtValidation(format!(
355 "JWKS endpoint returned HTTP {}",
356 status.as_u16()
357 )));
358 }
359
360 let jwks = response.json::<Jwks>().await.map_err(|e| {
361 tracing::error!("Failed to parse JWKS response from {}: {}", jwks_url, e);
362 RunbeamError::JwtValidation(format!("Invalid JWKS response: {}", e))
363 })?;
364
365 tracing::info!(
366 "Successfully fetched JWKS with {} keys from {}",
367 jwks.keys.len(),
368 jwks_url
369 );
370 Ok(jwks)
371}
372
373impl JwtClaims {
374 pub fn api_base_url(&self) -> String {
379 if let Ok(url) = url::Url::parse(&self.iss) {
381 let scheme = url.scheme();
383 let host = url.host_str().unwrap_or("");
384 let port = url.port().map(|p| format!(":{}", p)).unwrap_or_default();
385 format!("{}://{}{}", scheme, host, port)
386 } else {
387 self.iss.clone()
389 }
390 }
391
392 pub fn is_expired(&self) -> bool {
394 let now = std::time::SystemTime::now()
395 .duration_since(std::time::UNIX_EPOCH)
396 .unwrap()
397 .as_secs() as i64;
398 self.exp < now
399 }
400}
401
402pub async fn validate_jwt_token(
447 token: &str,
448 options: &JwtValidationOptions,
449) -> Result<JwtClaims, RunbeamError> {
450 tracing::debug!("Validating JWT token (length: {})", token.len());
451
452 let header = decode_header(token)
454 .map_err(|e| RunbeamError::JwtValidation(format!("Invalid JWT header: {}", e)))?;
455
456 let kid = header.kid.ok_or_else(|| {
457 RunbeamError::JwtValidation("Missing 'kid' (key ID) in JWT header".to_string())
458 })?;
459
460 let allowed_algorithms = options.algorithms.as_deref()
462 .unwrap_or(&[Algorithm::RS256]);
463
464 if !allowed_algorithms.contains(&header.alg) {
465 return Err(RunbeamError::JwtValidation(format!(
466 "Algorithm {:?} not in allowed list: {:?}",
467 header.alg, allowed_algorithms
468 )));
469 }
470
471 tracing::debug!("JWT header decoded: alg={:?}, kid={}", header.alg, kid);
472
473 let insecure_token_data = jsonwebtoken::dangerous::insecure_decode::<JwtClaims>(token)
476 .map_err(|e| RunbeamError::JwtValidation(format!("Failed to decode JWT: {}", e)))?;
477
478 let issuer = &insecure_token_data.claims.iss;
479 if issuer.is_empty() {
480 return Err(RunbeamError::JwtValidation(
481 "Missing or empty issuer (iss) claim".to_string(),
482 ));
483 }
484
485 tracing::debug!("JWT issuer extracted: {}", issuer);
486
487 if let Some(trusted_issuers) = &options.trusted_issuers {
489 let issuer_base_url = insecure_token_data.claims.api_base_url();
491 let is_trusted = trusted_issuers.iter().any(|trusted| {
492 issuer == trusted || issuer_base_url == *trusted || issuer.starts_with(trusted)
494 });
495
496 if !is_trusted {
497 return Err(RunbeamError::JwtValidation(format!(
498 "Issuer '{}' is not in the trusted issuers list",
499 issuer
500 )));
501 }
502 tracing::debug!("Issuer validated against trusted list");
503 } else {
504 tracing::warn!(
505 "⚠️ SECURITY WARNING: No trusted_issuers configured! Accepting JWT from ANY issuer: '{}'. \
506 This is a security risk - an attacker can issue their own tokens from a malicious JWKS endpoint.",
507 issuer
508 );
509 }
510
511 let base_url = insecure_token_data.claims.api_base_url();
515 tracing::debug!("JWT issuer base URL: {}", base_url);
516
517 let jwks_url = options.jwks_uri.as_deref()
520 .unwrap_or(&base_url);
521
522 let cache_duration = Duration::from_secs(options.jwks_cache_duration_hours * 3600);
523 let decoding_key = match get_decoding_key(jwks_url, &kid, cache_duration).await {
524 Ok(key) => key,
525 Err(e) => {
526 tracing::warn!("Initial JWKS fetch/cache lookup failed: {}", e);
527 return Err(e);
528 }
529 };
530
531 let primary_algorithm = allowed_algorithms.first()
533 .copied()
534 .unwrap_or(Algorithm::RS256);
535 let mut validation = Validation::new(primary_algorithm);
536
537 validation.validate_exp = options.validate_expiry;
539 validation.validate_nbf = false; if let Some(leeway) = options.leeway_seconds {
543 validation.leeway = leeway;
544 }
545
546 let validation_result = decode::<JwtClaims>(token, &decoding_key, &validation);
548
549 let claims = match validation_result {
550 Ok(token_data) => token_data.claims,
551 Err(e) => {
552 tracing::warn!("JWT validation failed, attempting cache refresh: {}", e);
554
555 if let Err(clear_err) = clear_jwks_cache(jwks_url) {
557 tracing::error!("Failed to clear JWKS cache: {}", clear_err);
558 }
559
560 let fresh_key = get_decoding_key(jwks_url, &kid, cache_duration)
562 .await
563 .map_err(|refresh_err| {
564 tracing::error!("Failed to refresh JWKS: {}", refresh_err);
565 RunbeamError::JwtValidation(format!(
566 "Token validation failed and refresh failed: {}. Original error: {}",
567 refresh_err, e
568 ))
569 })?;
570
571 decode::<JwtClaims>(token, &fresh_key, &validation)
573 .map_err(|retry_err| {
574 tracing::error!("JWT validation failed after refresh: {}", retry_err);
575 RunbeamError::JwtValidation(format!("Token validation failed: {}", retry_err))
576 })?
577 .claims
578 }
579 };
580
581 tracing::debug!(
582 "JWT validation successful: iss={}, sub={}, aud={:?}",
583 claims.iss,
584 claims.sub,
585 claims.aud
586 );
587
588 if claims.iss.is_empty() {
590 return Err(RunbeamError::JwtValidation(
591 "Missing or empty issuer (iss) claim".to_string(),
592 ));
593 }
594
595 if claims.sub.is_empty() {
596 return Err(RunbeamError::JwtValidation(
597 "Missing or empty subject (sub) claim".to_string(),
598 ));
599 }
600
601 if let Some(required_claims) = &options.required_claims {
603 let claims_json = serde_json::to_value(&claims)
605 .map_err(|e| RunbeamError::JwtValidation(format!("Failed to serialize claims: {}", e)))?;
606
607 for required_claim in required_claims {
608 if claims_json.get(required_claim).is_none() {
609 return Err(RunbeamError::JwtValidation(format!(
610 "Required claim '{}' is missing from JWT",
611 required_claim
612 )));
613 }
614 }
615 tracing::debug!("All required claims present: {:?}", required_claims);
616 }
617
618 Ok(claims)
619}
620
621pub fn extract_bearer_token(auth_header: &str) -> Result<&str, RunbeamError> {
633 if !auth_header.starts_with("Bearer ") {
634 return Err(RunbeamError::JwtValidation(
635 "Authorization header must start with 'Bearer '".to_string(),
636 ));
637 }
638
639 let token = auth_header.trim_start_matches("Bearer ").trim();
640 if token.is_empty() {
641 return Err(RunbeamError::JwtValidation(
642 "Missing token in Authorization header".to_string(),
643 ));
644 }
645
646 Ok(token)
647}
648
#[cfg(test)]
mod tests {
    use super::*;

    /// Current Unix time in whole seconds.
    fn unix_now() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64
    }

    /// Claims for `user123` from `http://example.com` with the given timestamps.
    fn claims_with(exp: i64, iat: i64) -> JwtClaims {
        JwtClaims {
            iss: "http://example.com".to_string(),
            sub: "user123".to_string(),
            aud: Some("runbeam-cli".to_string()),
            exp,
            iat,
            user: None,
            team: None,
        }
    }

    #[test]
    fn test_extract_bearer_token_valid() {
        assert_eq!(
            extract_bearer_token("Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test").unwrap(),
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test"
        );
    }

    #[test]
    fn test_extract_bearer_token_with_whitespace() {
        assert_eq!(
            extract_bearer_token("Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test ").unwrap(),
            "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test"
        );
    }

    #[test]
    fn test_extract_bearer_token_missing_bearer() {
        assert!(extract_bearer_token("eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test").is_err());
    }

    #[test]
    fn test_extract_bearer_token_empty_token() {
        assert!(extract_bearer_token("Bearer ").is_err());
    }

    #[test]
    fn test_jwt_claims_is_expired() {
        let now = unix_now();
        // Expired an hour ago.
        assert!(claims_with(now - 3600, now - 7200).is_expired());
        // Valid for another hour.
        assert!(!claims_with(now + 3600, now).is_expired());
    }
}