use anyhow::{anyhow, Context, Result};
use jsonwebtoken::{decode, decode_header, Algorithm, DecodingKey, Validation};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
/// Identity claims deserialized from (or serialized to) a token payload.
///
/// Members of the payload that are not listed here are ignored when
/// deserializing. Note that `aud` is modelled as a single string, so a payload
/// whose `aud` is a JSON array will not deserialize into this struct.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    /// Subject identifier.
    pub sub: String,
    /// E-mail address of the subject.
    pub email: String,
    /// Issuer of the token.
    pub iss: String,
    /// Audience of the token.
    pub aud: String,
    /// Expiry time as a UNIX timestamp in seconds.
    pub exp: usize,
    /// Issued-at time as a UNIX timestamp in seconds.
    pub iat: usize,
    /// Display name, when present in the payload.
    #[serde(default)]
    pub name: Option<String>,
    /// Group memberships, when present in the payload.
    #[serde(default)]
    pub groups: Option<Vec<String>>,
}
/// Body of a JWKS endpoint response, i.e. `{ "keys": [ ... ] }`.
#[derive(Deserialize, Debug)]
struct JwksResponse {
    /// Every key the issuer publishes, regardless of type or intended use.
    keys: Vec<Jwk>,
}
/// A single entry of a JSON Web Key Set.
///
/// Only RSA signing keys are turned into decoding keys, but a key set may also
/// contain other key types (e.g. `EC` keys, which carry `x`/`y` rather than
/// `n`/`e`) or keys without a `kid`. The RSA-specific members and `kid`
/// therefore default to an empty string when absent, so that one such entry
/// does not make the entire key set fail to deserialize.
#[derive(Debug, Deserialize, Clone)]
struct Jwk {
    /// The JWK `use` member (e.g. `"sig"`), if present.
    #[serde(rename = "use")]
    key_use: Option<String>,
    /// Key type, e.g. `"RSA"` or `"EC"`.
    kty: String,
    /// Key id matched against the JWT header's `kid`; empty when absent.
    #[serde(default)]
    kid: String,
    /// Algorithm the key is declared for, if the issuer provides one.
    alg: Option<String>,
    /// RSA modulus (base64url); empty for keys that do not carry one.
    #[serde(default)]
    n: String,
    /// RSA public exponent (base64url); empty for keys that do not carry one.
    #[serde(default)]
    e: String,
}
/// The subset of an OpenID Connect discovery document that this module reads.
///
/// All other members of the document are ignored.
#[derive(Deserialize, Debug)]
struct OidcDiscovery {
    /// URL of the issuer's JSON Web Key Set.
    jwks_uri: String,
}
/// Decoding keys fetched for one issuer, plus the timestamp needed to decide
/// when they have to be fetched again.
#[derive(Clone)]
struct JwksCache {
    /// Decoding keys indexed by their JWK `kid`.
    keys: HashMap<String, DecodingKey>,
    /// Moment this entry was created.
    fetched_at: Instant,
    /// How long the entry stays valid before `is_expired` reports `true`.
    ttl: Duration,
}

impl JwksCache {
    /// Lifetime given to every newly created entry: one hour.
    const DEFAULT_TTL: Duration = Duration::from_secs(3600);

    /// Wraps `keys` in an entry stamped with the current time.
    fn new(keys: HashMap<String, DecodingKey>) -> Self {
        JwksCache {
            keys,
            fetched_at: Instant::now(),
            ttl: Self::DEFAULT_TTL,
        }
    }

    /// Returns `true` once strictly more than `ttl` has elapsed since creation.
    fn is_expired(&self) -> bool {
        let age = self.fetched_at.elapsed();
        age > self.ttl
    }
}
/// Validates JWTs against the signing keys an issuer publishes through
/// OpenID Connect discovery.
pub struct JwtValidator {
    /// Per-issuer key cache, keyed by the issuer URL given to `validate`.
    jwks_cache: Arc<RwLock<HashMap<String, JwksCache>>>,
    /// HTTP client used for the discovery and JWKS requests.
    http_client: reqwest::Client,
}
impl JwtValidator {
pub fn new() -> Self {
Self {
jwks_cache: Arc::new(RwLock::new(HashMap::new())),
http_client: reqwest::Client::new(),
}
}
async fn discover_jwks_uri(&self, issuer_url: &str) -> Result<String> {
let discovery_url = format!("{}/.well-known/openid-configuration", issuer_url);
tracing::debug!("Discovering OIDC configuration from {}", discovery_url);
let response = self
.http_client
.get(&discovery_url)
.send()
.await
.context("Failed to fetch OIDC discovery document")?;
let discovery: OidcDiscovery = response
.json()
.await
.context("Failed to parse OIDC discovery document")?;
Ok(discovery.jwks_uri)
}
async fn fetch_jwks(&self, jwks_uri: &str) -> Result<HashMap<String, DecodingKey>> {
tracing::debug!("Fetching JWKS from {}", jwks_uri);
let response = self
.http_client
.get(jwks_uri)
.send()
.await
.context("Failed to fetch JWKS")?;
let response_text = response
.text()
.await
.context("Failed to read JWKS response body")?;
tracing::debug!("JWKS response: {}", response_text);
let jwks: JwksResponse = serde_json::from_str(&response_text)
.map_err(|e| anyhow!("Failed to parse JWKS response: {}", e))?;
let mut keys = HashMap::new();
for jwk in jwks.keys {
if jwk.kty == "RSA" && (jwk.key_use.is_none() || jwk.key_use.as_deref() == Some("sig"))
{
let decoding_key = DecodingKey::from_rsa_components(&jwk.n, &jwk.e)
.context("Failed to create decoding key from JWK")?;
keys.insert(jwk.kid.clone(), decoding_key);
tracing::debug!(
"Loaded JWK with kid: {}, use: {:?}, alg: {:?}",
jwk.kid,
jwk.key_use,
jwk.alg
);
}
}
tracing::info!("Loaded {} signing keys from JWKS", keys.len());
Ok(keys)
}
async fn get_jwks(&self, issuer_url: &str) -> Result<HashMap<String, DecodingKey>> {
{
let cache = self.jwks_cache.read().await;
if let Some(cached) = cache.get(issuer_url) {
if !cached.is_expired() {
tracing::debug!("Using cached JWKS for {}", issuer_url);
return Ok(cached.keys.clone());
} else {
tracing::debug!("JWKS cache expired for {}", issuer_url);
}
}
}
tracing::info!("Fetching fresh JWKS for {}", issuer_url);
let jwks_uri = self.discover_jwks_uri(issuer_url).await?;
let keys = self.fetch_jwks(&jwks_uri).await?;
{
let mut cache = self.jwks_cache.write().await;
cache.insert(issuer_url.to_string(), JwksCache::new(keys.clone()));
}
Ok(keys)
}
fn validate_custom_claims(
jwt_claims: &serde_json::Value,
expected_claims: &HashMap<String, String>,
) -> Result<()> {
let claims_obj = jwt_claims
.as_object()
.ok_or_else(|| anyhow!("JWT claims is not an object"))?;
for (key, expected_value) in expected_claims {
let actual_value = claims_obj
.get(key)
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow!("Claim '{}' not found or not a string", key))?;
if expected_value.contains('*') {
if !Self::matches_wildcard_pattern(expected_value, actual_value) {
return Err(anyhow!(
"Claim mismatch: '{}' pattern '{}' does not match '{}'",
key,
expected_value,
actual_value
));
}
} else {
if actual_value != expected_value {
return Err(anyhow!(
"Claim mismatch: '{}' expected '{}', got '{}'",
key,
expected_value,
actual_value
));
}
}
}
Ok(())
}
fn matches_wildcard_pattern(pattern: &str, text: &str) -> bool {
let parts: Vec<&str> = pattern.split('*').collect();
if parts.len() == 1 {
return pattern == text;
}
let mut pos = 0;
for (i, part) in parts.iter().enumerate() {
if i == 0 {
if !text.starts_with(part) {
return false;
}
pos = part.len();
} else if i == parts.len() - 1 {
if !text[pos..].ends_with(part) {
return false;
}
} else {
if let Some(found_pos) = text[pos..].find(part) {
pos += found_pos + part.len();
} else {
return false;
}
}
}
true
}
pub async fn validate(
&self,
token: &str,
issuer_url: &str,
expected_claims: &HashMap<String, String>,
) -> Result<serde_json::Value> {
let header = decode_header(token).context("Failed to decode JWT header")?;
let kid = header
.kid
.ok_or_else(|| anyhow!("JWT header missing kid"))?;
let keys = self.get_jwks(issuer_url).await?;
let key = keys
.get(&kid)
.ok_or_else(|| anyhow!("Key {} not found in JWKS for issuer {}", kid, issuer_url))?;
let mut validation = Validation::new(Algorithm::RS256);
validation.set_issuer(&[issuer_url]);
validation.validate_aud = false;
let token_data = decode::<serde_json::Value>(token, key, &validation)
.context("Failed to validate JWT token")?;
if let Some(exp) = token_data.claims.get("exp").and_then(|v| v.as_u64()) {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_secs();
if now > exp {
return Err(anyhow!("Token has expired"));
}
}
Self::validate_custom_claims(&token_data.claims, expected_claims)?;
Ok(token_data.claims)
}
}
impl Default for JwtValidator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an expected-claims map from `(claim, value)` pairs.
    fn expectations(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|&(claim, value)| (claim.to_owned(), value.to_owned()))
            .collect()
    }

    /// Asserts `matches_wildcard_pattern(pattern, text) == expected` for each case.
    fn check_patterns(cases: &[(&str, &str, bool)]) {
        for &(pattern, text, expected) in cases {
            assert_eq!(
                JwtValidator::matches_wildcard_pattern(pattern, text),
                expected,
                "pattern {:?} against {:?}",
                pattern,
                text
            );
        }
    }

    /// Asserts the required string claims shared by the deserialization fixtures.
    fn assert_standard_claims(claims: &Claims) {
        assert_eq!(claims.sub, "user123");
        assert_eq!(claims.email, "test@example.com");
        assert_eq!(claims.iss, "https://issuer.example.com");
        assert_eq!(claims.aud, "my-client-id");
    }

    #[test]
    fn test_jwt_validator_creation() {
        let validator = JwtValidator::new();
        assert!(validator.jwks_cache.try_read().is_ok());
    }

    #[test]
    fn test_claims_deserialization() {
        let json = r#"{
            "sub": "user123",
            "email": "test@example.com",
            "iss": "https://issuer.example.com",
            "aud": "my-client-id",
            "exp": 1234567890,
            "iat": 1234567800
        }"#;
        let claims: Claims = serde_json::from_str(json).unwrap();
        assert_standard_claims(&claims);
    }

    #[test]
    fn test_claims_deserialization_with_unknown_fields() {
        let json = r#"{
            "sub": "user123",
            "email": "test@example.com",
            "email_verified": true,
            "iss": "https://issuer.example.com",
            "aud": "my-client-id",
            "exp": 1234567890,
            "iat": 1234567800,
            "unknown_field": "should be ignored"
        }"#;
        let claims: Claims = serde_json::from_str(json).unwrap();
        assert_standard_claims(&claims);
    }

    #[test]
    fn test_validate_custom_claims_success() {
        let jwt_claims = serde_json::json!({
            "aud": "my-audience",
            "project_path": "myorg/myrepo",
            "extra": "value"
        });
        let expected = expectations(&[("aud", "my-audience"), ("project_path", "myorg/myrepo")]);
        assert!(JwtValidator::validate_custom_claims(&jwt_claims, &expected).is_ok());
    }

    #[test]
    fn test_validate_custom_claims_missing() {
        let jwt_claims = serde_json::json!({ "aud": "my-audience" });
        let expected = expectations(&[("aud", "my-audience"), ("project_path", "myorg/myrepo")]);
        let err = JwtValidator::validate_custom_claims(&jwt_claims, &expected)
            .expect_err("a missing claim must be rejected");
        assert!(err.to_string().contains("project_path"));
    }

    #[test]
    fn test_validate_custom_claims_mismatch() {
        let jwt_claims = serde_json::json!({
            "aud": "wrong-audience",
            "project_path": "myorg/myrepo"
        });
        let expected = expectations(&[("aud", "my-audience")]);
        let err = JwtValidator::validate_custom_claims(&jwt_claims, &expected)
            .expect_err("a mismatching claim must be rejected");
        assert!(err.to_string().contains("mismatch"));
    }

    #[test]
    fn test_wildcard_pattern_prefix() {
        check_patterns(&[
            ("app*", "app", true),
            ("app*", "app-mr/6", true),
            ("app*", "app-staging", true),
            ("app*", "application", true),
            ("app*", "myapp", false),
            ("app*", "webapp", false),
        ]);
    }

    #[test]
    fn test_wildcard_pattern_suffix() {
        check_patterns(&[
            ("*-prod", "api-prod", true),
            ("*-prod", "web-prod", true),
            ("*-prod", "my-service-prod", true),
            ("*-prod", "production", false),
            ("*-prod", "prod", false),
            ("*-prod", "api-prod-backup", false),
        ]);
    }

    #[test]
    fn test_wildcard_pattern_middle() {
        check_patterns(&[
            ("app-*-prod", "app-staging-prod", true),
            ("app-*-prod", "app-test-prod", true),
            ("app-*-prod", "app-mr/6-prod", true),
            ("app-*-prod", "app-prod", false),
            ("app-*-prod", "app-staging", false),
            ("app-*-prod", "web-staging-prod", false),
        ]);
    }

    #[test]
    fn test_wildcard_pattern_multiple() {
        check_patterns(&[
            ("*-app-*", "my-app-staging", true),
            ("*-app-*", "test-app-mr/6", true),
            ("*-app-*", "web-app-prod", true),
            ("*-app-*", "my-application", false),
            ("*-app-*", "app", false),
        ]);
    }

    #[test]
    fn test_wildcard_pattern_edge_cases() {
        check_patterns(&[
            ("*", "anything", true),
            ("*", "", true),
            ("app*", "application", true),
            ("app*", "app", true),
            ("*app", "myapp", true),
            ("*app", "app", true),
            ("app-*", "app-staging", true),
            ("app-*", "app-", true),
            ("app-*", "app", false),
            ("app**prod", "appprod", true),
            ("app**prod", "app-staging-prod", true),
            ("app***prod", "app-test-prod", true),
            ("app*", "", false),
            ("*app", "ap", false),
        ]);
    }

    #[test]
    fn test_validate_custom_claims_with_wildcard() {
        let jwt_claims = serde_json::json!({
            "aud": "my-audience",
            "environment": "app-mr/6"
        });
        let expected = expectations(&[("aud", "my-audience"), ("environment", "app*")]);
        assert!(JwtValidator::validate_custom_claims(&jwt_claims, &expected).is_ok());
    }

    #[test]
    fn test_validate_custom_claims_with_wildcard_no_match() {
        let jwt_claims = serde_json::json!({
            "aud": "my-audience",
            "environment": "webapp-staging"
        });
        let expected = expectations(&[("aud", "my-audience"), ("environment", "app*")]);
        let err = JwtValidator::validate_custom_claims(&jwt_claims, &expected)
            .expect_err("a non-matching pattern must be rejected");
        assert!(err.to_string().contains("pattern"));
    }

    #[test]
    fn test_validate_custom_claims_mixed_exact_and_wildcard() {
        let jwt_claims = serde_json::json!({
            "aud": "my-audience",
            "project_path": "myorg/myrepo",
            "environment": "app-mr/12"
        });
        let expected = expectations(&[
            ("aud", "my-audience"),
            ("project_path", "myorg/myrepo"),
            ("environment", "app*"),
        ]);
        assert!(JwtValidator::validate_custom_claims(&jwt_claims, &expected).is_ok());
    }

    #[test]
    fn test_validate_custom_claims_wildcard_backward_compat() {
        let jwt_claims = serde_json::json!({
            "aud": "my-audience",
            "environment": "production"
        });
        let exact = expectations(&[("aud", "my-audience"), ("environment", "production")]);
        assert!(JwtValidator::validate_custom_claims(&jwt_claims, &exact).is_ok());
        let wrong = expectations(&[("aud", "my-audience"), ("environment", "staging")]);
        assert!(JwtValidator::validate_custom_claims(&jwt_claims, &wrong).is_err());
    }
}