use crate::runbeam_api::types::{RunbeamError, TeamInfo, UserInfo};
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use once_cell::sync::Lazy;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};
/// A JSON Web Key Set document as returned by a JWKS endpoint (`{"keys": [...]}`).
#[derive(Debug, Clone, Deserialize)]
pub struct Jwks {
/// The keys published in the set.
pub keys: Vec<JwkKey>,
}
/// A single JSON Web Key from a JWKS document (RFC 7517).
///
/// Only the members needed for RSA signature verification are modeled. `kid` is optional
/// per RFC 7517, and non-RSA keys (e.g. EC keys, which carry `x`/`y` instead of `n`/`e`)
/// have no modulus or exponent. These members therefore default to empty strings when
/// absent, so that one such key does not make the whole key set fail to deserialize.
/// Non-RSA keys are then rejected individually by [`JwkKey::to_decoding_key`].
#[derive(Debug, Clone, Deserialize)]
pub struct JwkKey {
    /// Key type, e.g. `"RSA"`.
    pub kty: String,
    /// Intended key use (`"sig"` or `"enc"`), if published.
    #[serde(rename = "use")]
    pub key_use: Option<String>,
    /// Key ID matched against the JWT header's `kid`; empty when the JWK omits it.
    #[serde(default)]
    pub kid: String,
    /// Algorithm the key is intended for, if published (not consulted in this module).
    pub alg: Option<String>,
    /// Base64url-encoded RSA modulus; empty when absent.
    #[serde(default)]
    pub n: String,
    /// Base64url-encoded RSA public exponent; empty when absent.
    #[serde(default)]
    pub e: String,
}
impl JwkKey {
    /// Converts this JWK into a `jsonwebtoken` RSA [`DecodingKey`].
    ///
    /// # Errors
    /// Returns `RunbeamError::JwtValidation` when:
    /// - `kty` is not `"RSA"` (only RSA keys are supported);
    /// - `use` is present and is not `"sig"`, i.e. the key is not published for
    ///   signature verification (RFC 7517 §4.2);
    /// - the modulus `n` or exponent `e` is empty;
    /// - `jsonwebtoken` rejects the base64url-encoded components.
    pub fn to_decoding_key(&self) -> Result<DecodingKey, RunbeamError> {
        if self.kty != "RSA" {
            return Err(RunbeamError::JwtValidation(format!(
                "Unsupported key type: {}. Only RSA is supported.",
                self.kty
            )));
        }
        // An encryption key must never be used to verify signatures.
        if let Some(key_use) = self.key_use.as_deref() {
            if key_use != "sig" {
                return Err(RunbeamError::JwtValidation(format!(
                    "Unsupported key use: {}. Only 'sig' keys can verify JWTs.",
                    key_use
                )));
            }
        }
        // Empty components would yield an unusable key; reject them up front
        // with a clear message instead of failing later at verification time.
        if self.n.is_empty() || self.e.is_empty() {
            return Err(RunbeamError::JwtValidation(
                "RSA JWK is missing its modulus (n) or exponent (e)".to_string(),
            ));
        }
        DecodingKey::from_rsa_components(&self.n, &self.e).map_err(|e| {
            RunbeamError::JwtValidation(format!(
                "Failed to create RSA decoding key from JWK components: {}",
                e
            ))
        })
    }
}
/// Configuration for [`validate_jwt_token`]; build it with [`JwtValidationOptions::new`]
/// and the `with_*` methods.
#[derive(Debug, Clone)]
pub struct JwtValidationOptions {
/// Issuers whose tokens are accepted. When `None`, tokens from any issuer are accepted
/// and a security warning is logged.
pub trusted_issuers: Option<Vec<String>>,
/// Base URL of the JWKS source. `/api/.well-known/jwks.json` is appended to it when the
/// key set is fetched. When `None`, the base URL is derived from the token's `iss`.
pub jwks_uri: Option<String>,
/// Signing algorithms accepted in the JWT header; `None` means `[RS256]`.
pub algorithms: Option<Vec<Algorithm>>,
/// Additional claim names that must be present in the validated claims.
pub required_claims: Option<Vec<String>>,
/// Clock-skew leeway in seconds; when `None` the `jsonwebtoken::Validation` default is kept.
pub leeway_seconds: Option<u64>,
/// Whether the `exp` claim is validated.
pub validate_expiry: bool,
/// How long a fetched key set is reused before being fetched again, in hours.
pub jwks_cache_duration_hours: u64,
}
impl Default for JwtValidationOptions {
    /// Expiry is enforced and fetched key sets are cached for 24 hours. Every optional
    /// setting is left unset: no issuer allow-list, no explicit JWKS URI, the default
    /// algorithm list, no extra required claims, and the library's default leeway.
    fn default() -> Self {
        Self {
            validate_expiry: true,
            jwks_cache_duration_hours: 24,
            trusted_issuers: None,
            jwks_uri: None,
            algorithms: None,
            required_claims: None,
            leeway_seconds: None,
        }
    }
}
impl JwtValidationOptions {
    /// Creates options populated with the [`Default`] values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Accepts only tokens whose issuer matches one of `issuers`.
    pub fn with_trusted_issuers(self, issuers: Vec<String>) -> Self {
        Self {
            trusted_issuers: Some(issuers),
            ..self
        }
    }

    /// Sets the JWKS base URL to use instead of the one derived from the token's `iss`.
    pub fn with_jwks_uri(self, uri: String) -> Self {
        Self {
            jwks_uri: Some(uri),
            ..self
        }
    }

    /// Replaces the list of accepted signing algorithms.
    pub fn with_algorithms(self, algorithms: Vec<Algorithm>) -> Self {
        Self {
            algorithms: Some(algorithms),
            ..self
        }
    }

    /// Lists claim names that must be present in every validated token.
    pub fn with_required_claims(self, claims: Vec<String>) -> Self {
        Self {
            required_claims: Some(claims),
            ..self
        }
    }

    /// Sets the clock-skew leeway in seconds.
    ///
    /// Values above 300 (five minutes) are clamped to 300.
    pub fn with_leeway_seconds(self, leeway: u64) -> Self {
        let capped = if leeway > 300 { 300 } else { leeway };
        Self {
            leeway_seconds: Some(capped),
            ..self
        }
    }

    /// Enables or disables validation of the `exp` claim.
    pub fn with_validate_expiry(self, validate: bool) -> Self {
        Self {
            validate_expiry: validate,
            ..self
        }
    }

    /// Sets how many hours a fetched key set stays cached.
    pub fn with_jwks_cache_duration_hours(self, hours: u64) -> Self {
        Self {
            jwks_cache_duration_hours: hours,
            ..self
        }
    }
}
/// Claims decoded from a JWT.
///
/// `Serialize` is also used by [`validate_jwt_token`] to look up `required_claims` in the
/// serialized form of this struct, so only the fields modeled here can be required.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwtClaims {
/// Issuer. Its origin is also used as the default JWKS base URL (see `api_base_url`).
pub iss: String,
/// Subject; [`validate_jwt_token`] rejects tokens where it is empty.
pub sub: String,
/// Audience, if present.
/// NOTE(review): modeled as a single string. A token whose `aud` is a JSON array (allowed by
/// RFC 7519) will fail to deserialize into this struct.
#[serde(default)]
pub aud: Option<String>,
/// Expiration time, in seconds since the Unix epoch.
pub exp: i64,
/// Issued-at time, in seconds since the Unix epoch.
pub iat: i64,
/// User information embedded in the token, if any.
#[serde(default)]
pub user: Option<UserInfo>,
/// Team information embedded in the token, if any.
#[serde(default)]
pub team: Option<TeamInfo>,
}
/// Decoding keys converted from one fetched JWKS document, plus the fetch time.
struct JwksCache {
/// Decoding keys indexed by their JWK `kid`.
keys: HashMap<String, DecodingKey>,
/// When the key set was fetched; compared against the cache duration in `is_expired`.
last_fetched: Instant,
}
impl JwksCache {
    /// Reports whether this entry has outlived `cache_duration`.
    ///
    /// An entry whose age is exactly `cache_duration` still counts as fresh.
    fn is_expired(&self, cache_duration: Duration) -> bool {
        let age = self.last_fetched.elapsed();
        cache_duration < age
    }
}
/// Process-wide cache of converted JWKS keys.
///
/// The cache key is the string passed as `issuer` to `get_decoding_key`.
/// `validate_jwt_token` passes either the configured `jwks_uri` or the token issuer's base URL.
/// This is a `std::sync::RwLock`, so a guard must never be held across an `.await`.
/// The lock scopes in this file contain no await points.
static JWKS_CACHE: Lazy<Arc<RwLock<HashMap<String, JwksCache>>>> =
Lazy::new(|| Arc::new(RwLock::new(HashMap::new())));
/// Resolves the decoding key identified by `kid` from the JWKS served under `issuer`.
///
/// `issuer` is the JWKS base URL and also the cache key in [`JWKS_CACHE`]. Callers pass
/// either the configured `jwks_uri` or the token issuer's origin.
///
/// A cached key set is reused while it is younger than `cache_duration`. On any miss
/// (unknown source, expired entry, or `kid` absent from the cached set), the JWKS is
/// fetched again and the complete fresh key set replaces the cached entry. Keys that
/// cannot be converted to a [`DecodingKey`] are logged and skipped.
///
/// # Errors
/// Returns `RunbeamError::JwtValidation` if:
/// - the cache lock is poisoned;
/// - the JWKS cannot be fetched or parsed;
/// - `kid` is not present in the freshly fetched key set.
async fn get_decoding_key(
    issuer: &str,
    kid: &str,
    cache_duration: Duration,
) -> Result<DecodingKey, RunbeamError> {
    if let Some(key) = lookup_cached_key(issuer, kid, cache_duration)? {
        return Ok(key);
    }

    tracing::info!("Fetching fresh JWKS for issuer={}", issuer);
    // No cache lock is held here: a std::sync guard must not live across `.await`.
    let jwks = fetch_jwks(issuer).await?;

    let keys_map: HashMap<String, DecodingKey> = jwks
        .keys
        .iter()
        .filter_map(|jwk| match jwk.to_decoding_key() {
            Ok(key) => Some((jwk.kid.clone(), key)),
            Err(e) => {
                tracing::warn!(
                    "Failed to convert JWK kid='{}' to decoding key: {}",
                    jwk.kid,
                    e
                );
                None
            }
        })
        .collect();
    let decoding_key = keys_map.get(kid).cloned();

    // Cache the fresh key set even when `kid` is missing from it. Otherwise later
    // requests for keys that *are* present would re-fetch the JWKS every time.
    JWKS_CACHE
        .write()
        .map_err(|e| RunbeamError::JwtValidation(format!("Cache lock poisoned: {}", e)))?
        .insert(
            issuer.to_string(),
            JwksCache {
                keys: keys_map,
                last_fetched: Instant::now(),
            },
        );
    tracing::debug!("JWKS cache updated for issuer={}", issuer);

    decoding_key.ok_or_else(|| {
        RunbeamError::JwtValidation(format!(
            "Key ID '{}' not found in JWKS from issuer {}",
            kid, issuer
        ))
    })
}

/// Returns the cached key for (`issuer`, `kid`) when a non-expired entry holds it.
///
/// `Ok(None)` means the caller must fetch. That happens when there is no entry, when the
/// entry has expired, or when the entry lacks `kid`.
fn lookup_cached_key(
    issuer: &str,
    kid: &str,
    cache_duration: Duration,
) -> Result<Option<DecodingKey>, RunbeamError> {
    let cache = JWKS_CACHE
        .read()
        .map_err(|e| RunbeamError::JwtValidation(format!("Cache lock poisoned: {}", e)))?;
    match cache.get(issuer) {
        None => {
            tracing::debug!("JWKS cache miss for issuer={}", issuer);
            Ok(None)
        }
        Some(entry) if entry.is_expired(cache_duration) => {
            tracing::debug!("JWKS cache expired for issuer={}", issuer);
            Ok(None)
        }
        Some(entry) => match entry.keys.get(kid) {
            Some(key) => {
                tracing::debug!("JWKS cache hit for issuer={}, kid={}", issuer, kid);
                Ok(Some(key.clone()))
            }
            None => {
                tracing::debug!("JWKS cache miss: kid '{}' not found in cached keys", kid);
                Ok(None)
            }
        },
    }
}
/// Drops the cached key set for `issuer`, forcing the next lookup to fetch the JWKS again.
///
/// Clearing an issuer that is not cached is a no-op.
///
/// # Errors
/// Returns `RunbeamError::JwtValidation` if the cache lock is poisoned.
fn clear_jwks_cache(issuer: &str) -> Result<(), RunbeamError> {
    let removed = match JWKS_CACHE.write() {
        Ok(mut guard) => guard.remove(issuer),
        Err(poisoned) => {
            return Err(RunbeamError::JwtValidation(format!(
                "Cache lock poisoned: {}",
                poisoned
            )));
        }
    };
    if removed.is_some() {
        tracing::debug!("Cleared JWKS cache for issuer={}", issuer);
    }
    Ok(())
}
async fn fetch_jwks(issuer: &str) -> Result<Jwks, RunbeamError> {
let jwks_url = format!("{}/api/.well-known/jwks.json", issuer.trim_end_matches('/'));
tracing::debug!("Fetching JWKS from: {}", jwks_url);
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.map_err(|e| RunbeamError::JwtValidation(format!("Failed to create HTTP client: {}", e)))?;
let response = client.get(&jwks_url).send().await.map_err(|e| {
tracing::error!("Failed to fetch JWKS from {}: {}", jwks_url, e);
if e.is_timeout() {
RunbeamError::JwtValidation(format!("JWKS endpoint timeout: {}", jwks_url))
} else if e.is_connect() {
RunbeamError::JwtValidation(format!("Failed to connect to JWKS endpoint: {}", jwks_url))
} else {
RunbeamError::JwtValidation(format!("Network error fetching JWKS: {}", e))
}
})?;
let status = response.status();
if !status.is_success() {
tracing::error!(
"JWKS endpoint returned HTTP {}: {}",
status.as_u16(),
jwks_url
);
return Err(RunbeamError::JwtValidation(format!(
"JWKS endpoint returned HTTP {}",
status.as_u16()
)));
}
let jwks = response.json::<Jwks>().await.map_err(|e| {
tracing::error!("Failed to parse JWKS response from {}: {}", jwks_url, e);
RunbeamError::JwtValidation(format!("Invalid JWKS response: {}", e))
})?;
tracing::info!(
"Successfully fetched JWKS with {} keys from {}",
jwks.keys.len(),
jwks_url
);
Ok(jwks)
}
impl JwtClaims {
    /// Returns the origin (`scheme://host[:port]`) of the `iss` claim.
    ///
    /// The port appears only when it differs from the scheme's default. If `iss` does not
    /// parse as a URL, it is returned unchanged.
    pub fn api_base_url(&self) -> String {
        if let Ok(url) = url::Url::parse(&self.iss) {
            let scheme = url.scheme();
            let host = url.host_str().unwrap_or("");
            let port = url.port().map(|p| format!(":{}", p)).unwrap_or_default();
            format!("{}://{}{}", scheme, host, port)
        } else {
            self.iss.clone()
        }
    }

    /// Reports whether the `exp` claim is at or before the current system time.
    ///
    /// RFC 7519 §4.1.4 requires the current time to be strictly *before* `exp`. A token
    /// whose `exp` equals the current second has therefore already expired. No leeway is
    /// applied. A system clock set before the Unix epoch is treated as time 0 instead of
    /// panicking.
    pub fn is_expired(&self) -> bool {
        let now_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let now = i64::try_from(now_secs).unwrap_or(i64::MAX);
        self.exp <= now
    }
}
pub async fn validate_jwt_token(
token: &str,
options: &JwtValidationOptions,
) -> Result<JwtClaims, RunbeamError> {
tracing::debug!("Validating JWT token (length: {})", token.len());
let header = decode_header(token)
.map_err(|e| RunbeamError::JwtValidation(format!("Invalid JWT header: {}", e)))?;
let kid = header.kid.ok_or_else(|| {
RunbeamError::JwtValidation("Missing 'kid' (key ID) in JWT header".to_string())
})?;
let allowed_algorithms = options.algorithms.as_deref()
.unwrap_or(&[Algorithm::RS256]);
if !allowed_algorithms.contains(&header.alg) {
return Err(RunbeamError::JwtValidation(format!(
"Algorithm {:?} not in allowed list: {:?}",
header.alg, allowed_algorithms
)));
}
tracing::debug!("JWT header decoded: alg={:?}, kid={}", header.alg, kid);
let insecure_token_data = jsonwebtoken::dangerous::insecure_decode::<JwtClaims>(token)
.map_err(|e| RunbeamError::JwtValidation(format!("Failed to decode JWT: {}", e)))?;
let issuer = &insecure_token_data.claims.iss;
if issuer.is_empty() {
return Err(RunbeamError::JwtValidation(
"Missing or empty issuer (iss) claim".to_string(),
));
}
tracing::debug!("JWT issuer extracted: {}", issuer);
if let Some(trusted_issuers) = &options.trusted_issuers {
let issuer_base_url = insecure_token_data.claims.api_base_url();
let is_trusted = trusted_issuers.iter().any(|trusted| {
issuer == trusted || issuer_base_url == *trusted || issuer.starts_with(trusted)
});
if !is_trusted {
return Err(RunbeamError::JwtValidation(format!(
"Issuer '{}' is not in the trusted issuers list",
issuer
)));
}
tracing::debug!("Issuer validated against trusted list");
} else {
tracing::warn!(
"⚠️ SECURITY WARNING: No trusted_issuers configured! Accepting JWT from ANY issuer: '{}'. \
This is a security risk - an attacker can issue their own tokens from a malicious JWKS endpoint.",
issuer
);
}
let base_url = insecure_token_data.claims.api_base_url();
tracing::debug!("JWT issuer base URL: {}", base_url);
let jwks_url = options.jwks_uri.as_deref()
.unwrap_or(&base_url);
let cache_duration = Duration::from_secs(options.jwks_cache_duration_hours * 3600);
let decoding_key = match get_decoding_key(jwks_url, &kid, cache_duration).await {
Ok(key) => key,
Err(e) => {
tracing::warn!("Initial JWKS fetch/cache lookup failed: {}", e);
return Err(e);
}
};
let primary_algorithm = allowed_algorithms.first()
.copied()
.unwrap_or(Algorithm::RS256);
let mut validation = Validation::new(primary_algorithm);
validation.validate_exp = options.validate_expiry;
validation.validate_nbf = false;
if let Some(leeway) = options.leeway_seconds {
validation.leeway = leeway;
}
let validation_result = decode::<JwtClaims>(token, &decoding_key, &validation);
let claims = match validation_result {
Ok(token_data) => token_data.claims,
Err(e) => {
tracing::warn!("JWT validation failed, attempting cache refresh: {}", e);
if let Err(clear_err) = clear_jwks_cache(jwks_url) {
tracing::error!("Failed to clear JWKS cache: {}", clear_err);
}
let fresh_key = get_decoding_key(jwks_url, &kid, cache_duration)
.await
.map_err(|refresh_err| {
tracing::error!("Failed to refresh JWKS: {}", refresh_err);
RunbeamError::JwtValidation(format!(
"Token validation failed and refresh failed: {}. Original error: {}",
refresh_err, e
))
})?;
decode::<JwtClaims>(token, &fresh_key, &validation)
.map_err(|retry_err| {
tracing::error!("JWT validation failed after refresh: {}", retry_err);
RunbeamError::JwtValidation(format!("Token validation failed: {}", retry_err))
})?
.claims
}
};
tracing::debug!(
"JWT validation successful: iss={}, sub={}, aud={:?}",
claims.iss,
claims.sub,
claims.aud
);
if claims.iss.is_empty() {
return Err(RunbeamError::JwtValidation(
"Missing or empty issuer (iss) claim".to_string(),
));
}
if claims.sub.is_empty() {
return Err(RunbeamError::JwtValidation(
"Missing or empty subject (sub) claim".to_string(),
));
}
if let Some(required_claims) = &options.required_claims {
let claims_json = serde_json::to_value(&claims)
.map_err(|e| RunbeamError::JwtValidation(format!("Failed to serialize claims: {}", e)))?;
for required_claim in required_claims {
if claims_json.get(required_claim).is_none() {
return Err(RunbeamError::JwtValidation(format!(
"Required claim '{}' is missing from JWT",
required_claim
)));
}
}
tracing::debug!("All required claims present: {:?}", required_claims);
}
Ok(claims)
}
/// Extracts the token from an HTTP `Authorization` header that uses the `Bearer` scheme.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1) and removed exactly once.
/// Whitespace around the token is trimmed.
///
/// # Errors
/// Returns `RunbeamError::JwtValidation` in two cases:
/// - the header does not start with `Bearer ` (in any letter case);
/// - no token follows the scheme.
pub fn extract_bearer_token(auth_header: &str) -> Result<&str, RunbeamError> {
    const SCHEME: &str = "Bearer ";
    // `get` returns None if the header is too short or index 7 is not a char boundary.
    let credentials = match auth_header.get(..SCHEME.len()) {
        Some(prefix) if prefix.eq_ignore_ascii_case(SCHEME) => &auth_header[SCHEME.len()..],
        _ => {
            return Err(RunbeamError::JwtValidation(
                "Authorization header must start with 'Bearer '".to_string(),
            ));
        }
    };
    let token = credentials.trim();
    if token.is_empty() {
        return Err(RunbeamError::JwtValidation(
            "Missing token in Authorization header".to_string(),
        ));
    }
    Ok(token)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample compact-JWT-looking token shared by the bearer extraction tests.
    const SAMPLE_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.test";

    /// Current time in whole seconds since the Unix epoch.
    fn unix_now() -> i64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() as i64
    }

    /// Claims for a fixed issuer/subject/audience with the given timestamps.
    fn sample_claims(exp: i64, iat: i64) -> JwtClaims {
        JwtClaims {
            iss: "http://example.com".to_string(),
            sub: "user123".to_string(),
            aud: Some("runbeam-cli".to_string()),
            exp,
            iat,
            user: None,
            team: None,
        }
    }

    #[test]
    fn test_extract_bearer_token_valid() {
        let header = format!("Bearer {}", SAMPLE_TOKEN);
        assert_eq!(extract_bearer_token(&header).unwrap(), SAMPLE_TOKEN);
    }

    #[test]
    fn test_extract_bearer_token_with_whitespace() {
        let header = format!("Bearer {} ", SAMPLE_TOKEN);
        assert_eq!(extract_bearer_token(&header).unwrap(), SAMPLE_TOKEN);
    }

    #[test]
    fn test_extract_bearer_token_missing_bearer() {
        assert!(extract_bearer_token(SAMPLE_TOKEN).is_err());
    }

    #[test]
    fn test_extract_bearer_token_empty_token() {
        assert!(extract_bearer_token("Bearer ").is_err());
    }

    #[test]
    fn test_jwt_claims_is_expired() {
        let now = unix_now();
        // Expired an hour ago, issued two hours ago.
        assert!(sample_claims(now - 3600, now - 7200).is_expired());
        // Expires in an hour, issued now.
        assert!(!sample_claims(now + 3600, now).is_expired());
    }
}