use super::traits::{AuthContext, ClaimMappings};
use crate::error::{Error, ErrorCode, Result};
#[cfg(feature = "jwt-auth")]
use std::collections::HashMap;
#[cfg(feature = "jwt-auth")]
use std::sync::Arc;
use std::time::Duration;
#[cfg(feature = "jwt-auth")]
use std::time::Instant;
#[cfg(feature = "jwt-auth")]
use tokio::sync::RwLock;
#[cfg(feature = "jwt-auth")]
/// One downloaded JWKS: the decoded verification keys together with the
/// moment they were fetched and how long they are considered fresh.
struct CachedJwks {
    /// When the key set was fetched; compared against `ttl` to decide expiry.
    fetched_at: Instant,
    /// Freshness window, copied from the owning validator at fetch time.
    ttl: Duration,
    /// Decoding keys indexed by their JWK `kid`.
    keys: HashMap<String, jsonwebtoken::DecodingKey>,
}
#[cfg(feature = "jwt-auth")]
impl std::fmt::Debug for CachedJwks {
    /// Renders a summary: the key *count* (decoding keys themselves are not
    /// printed), the fetch instant and the TTL.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let keys_count = self.keys.len();
        let mut out = f.debug_struct("CachedJwks");
        out.field("keys_count", &keys_count);
        out.field("fetched_at", &self.fetched_at);
        out.field("ttl", &self.ttl);
        out.finish()
    }
}
#[cfg(feature = "jwt-auth")]
impl CachedJwks {
    /// Returns `true` once strictly more than `ttl` has passed since `fetched_at`.
    fn is_expired(&self) -> bool {
        let age = self.fetched_at.elapsed();
        age > self.ttl
    }
}
/// Validates RS256-signed JWTs against an identity provider's JWKS,
/// caching downloaded key sets per JWKS URI.
///
/// Clones share the same key cache.
#[derive(Debug)]
pub struct JwtValidator {
    /// Cached key sets keyed by JWKS URI; shared (via `Arc`) across clones.
    #[cfg(feature = "jwt-auth")]
    jwks_cache: Arc<RwLock<HashMap<String, CachedJwks>>>,
    /// Client used to download JWKS documents.
    #[cfg(not(target_arch = "wasm32"))]
    http_client: reqwest::Client,
    /// How long a fetched key set stays fresh before it is refetched.
    cache_ttl: Duration,
}
impl Default for JwtValidator {
fn default() -> Self {
Self::new()
}
}
impl Clone for JwtValidator {
fn clone(&self) -> Self {
Self {
#[cfg(feature = "jwt-auth")]
jwks_cache: Arc::clone(&self.jwks_cache),
#[cfg(not(target_arch = "wasm32"))]
http_client: self.http_client.clone(),
cache_ttl: self.cache_ttl,
}
}
}
impl JwtValidator {
    /// JWKS cache lifetime used by [`JwtValidator::new`].
    const DEFAULT_CACHE_TTL: Duration = Duration::from_secs(3600);

    /// Minimum age a still-fresh cached key set must reach before an unknown
    /// `kid` is allowed to trigger a refetch.
    ///
    /// The `kid` is read from the *unverified* token header. Without this
    /// floor, every token carrying a made-up `kid` would cause an outbound
    /// HTTP request to the identity provider.
    #[cfg(all(not(target_arch = "wasm32"), feature = "jwt-auth"))]
    const MIN_REFRESH_INTERVAL: Duration = Duration::from_secs(30);

    /// Creates a validator whose cached key sets stay fresh for one hour.
    ///
    /// # Panics
    ///
    /// On non-WASM targets, panics if the HTTP client cannot be built
    /// (see [`JwtValidator::with_cache_ttl`]).
    pub fn new() -> Self {
        Self::with_cache_ttl(Self::DEFAULT_CACHE_TTL)
    }

    /// Creates a validator whose cached key sets expire after `cache_ttl`.
    ///
    /// On non-WASM targets, JWKS documents are fetched with a `reqwest`
    /// client configured with a 10-second request timeout.
    ///
    /// # Panics
    ///
    /// On non-WASM targets, panics if the `reqwest` client cannot be built.
    pub fn with_cache_ttl(cache_ttl: Duration) -> Self {
        // Field initializers are cfg-gated to match the struct definition, so
        // this constructor (and therefore `Default`) exists on every target.
        Self {
            #[cfg(feature = "jwt-auth")]
            jwks_cache: Arc::new(RwLock::new(HashMap::new())),
            #[cfg(not(target_arch = "wasm32"))]
            http_client: reqwest::Client::builder()
                .timeout(Duration::from_secs(10))
                .build()
                .expect("Failed to create HTTP client"),
            cache_ttl,
        }
    }

    /// Creates a validator that fetches JWKS documents with a caller-supplied
    /// `http_client` and expires cached key sets after `cache_ttl`.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn with_http_client(http_client: reqwest::Client, cache_ttl: Duration) -> Self {
        Self {
            #[cfg(feature = "jwt-auth")]
            jwks_cache: Arc::new(RwLock::new(HashMap::new())),
            http_client,
            cache_ttl,
        }
    }

    /// Validates `token` against `config` and builds the resulting [`AuthContext`].
    ///
    /// The steps are:
    /// 1. Decode the header and require a `kid`.
    /// 2. Look up the matching key in the (cached) JWKS at `config.jwks_uri`.
    /// 3. Verify the RS256 signature together with the issuer, audience and
    ///    expiry checks, allowing `config.leeway_seconds` of clock skew.
    /// 4. When `config.required_token_use` is set, require `token_use` to be
    ///    present and equal to that value.
    /// 5. Require a non-empty `sub` after claim normalization.
    ///
    /// # Errors
    ///
    /// Returns an `AUTHENTICATION_REQUIRED` protocol error in these cases:
    /// - the token is malformed or has no `kid`;
    /// - the `kid` is unknown;
    /// - signature or claim validation fails;
    /// - `token_use` is wrong or missing while required;
    /// - the subject is missing.
    ///
    /// JWKS download and parse failures are returned as internal errors.
    #[cfg(all(not(target_arch = "wasm32"), feature = "jwt-auth"))]
    pub async fn validate(&self, token: &str, config: &ValidationConfig) -> Result<AuthContext> {
        use jsonwebtoken::errors::ErrorKind;
        use jsonwebtoken::{decode, decode_header, Algorithm, Validation};

        let header = decode_header(token).map_err(|e| {
            Error::protocol(
                ErrorCode::AUTHENTICATION_REQUIRED,
                format!("Invalid token header: {}", e),
            )
        })?;
        let kid = header.kid.ok_or_else(|| {
            Error::protocol(
                ErrorCode::AUTHENTICATION_REQUIRED,
                "Token missing key ID (kid)",
            )
        })?;
        let key = self.get_key(&config.jwks_uri, &kid).await?;

        // The algorithm is pinned rather than taken from `header.alg`, so a
        // token cannot choose a weaker or different algorithm for itself.
        let mut validation = Validation::new(Algorithm::RS256);
        validation.set_issuer(&[&config.issuer]);
        validation.set_audience(&[&config.audience]);
        validation.leeway = config.leeway_seconds;

        let token_data = decode::<serde_json::Value>(token, &key, &validation).map_err(|e| {
            let msg = match e.kind() {
                ErrorKind::ExpiredSignature => "Token expired",
                ErrorKind::InvalidIssuer => "Invalid issuer",
                ErrorKind::InvalidAudience => "Invalid audience",
                ErrorKind::InvalidSignature => "Invalid signature",
                ErrorKind::ImmatureSignature => "Token not yet valid",
                _ => "Token validation failed",
            };
            Error::protocol(ErrorCode::AUTHENTICATION_REQUIRED, msg)
        })?;
        let claims = &token_data.claims;

        // A configured `token_use` is mandatory. A token that omits the claim
        // is rejected just like one carrying a different value; previously a
        // missing claim skipped this check entirely.
        if let Some(required_use) = config.required_token_use.as_deref() {
            let token_use = claims.get("token_use").and_then(|v| v.as_str());
            if token_use != Some(required_use) {
                return Err(Error::protocol(
                    ErrorCode::AUTHENTICATION_REQUIRED,
                    format!(
                        "Invalid token_use: expected {}, got {}",
                        required_use,
                        token_use.unwrap_or("<missing>")
                    ),
                ));
            }
        }

        let normalized_claims = config.claim_mappings.normalize_claims(claims);
        let subject = normalized_claims
            .get("sub")
            .and_then(|v| v.as_str())
            .unwrap_or_default()
            .to_string();
        if subject.is_empty() {
            return Err(Error::protocol(
                ErrorCode::AUTHENTICATION_REQUIRED,
                "Token missing subject claim",
            ));
        }

        let scopes = parse_scopes(claims);
        // `azp` takes precedence; `client_id` is consulted only when `azp` is absent.
        let client_id = claims
            .get("azp")
            .or_else(|| claims.get("client_id"))
            .and_then(|v| v.as_str())
            .map(String::from);
        let expires_at = claims.get("exp").and_then(|v| v.as_u64());

        Ok(AuthContext {
            subject,
            scopes,
            claims: normalized_claims,
            token: Some(token.to_string()),
            client_id,
            expires_at,
            authenticated: true,
        })
    }

    /// Fallback for builds without JWT support: always fails with `METHOD_NOT_FOUND`.
    #[cfg(any(target_arch = "wasm32", not(feature = "jwt-auth")))]
    pub async fn validate(&self, _token: &str, _config: &ValidationConfig) -> Result<AuthContext> {
        Err(Error::protocol(
            ErrorCode::METHOD_NOT_FOUND,
            "JWT validation requires the 'jwt-auth' feature and non-WASM target",
        ))
    }

    /// Returns the decoding key for `kid` from the JWKS at `jwks_uri`.
    ///
    /// A missing or expired key set is always refetched. A fresh key set that
    /// lacks `kid` is refetched (to pick up a key rotation) only if it is
    /// older than `MIN_REFRESH_INTERVAL`.
    #[cfg(all(not(target_arch = "wasm32"), feature = "jwt-auth"))]
    async fn get_key(&self, jwks_uri: &str, kid: &str) -> Result<jsonwebtoken::DecodingKey> {
        let cache = self.jwks_cache.read().await;
        let needs_refresh = match cache.get(jwks_uri) {
            Some(cached) if !cached.is_expired() => {
                if let Some(key) = cached.keys.get(kid) {
                    return Ok(key.clone());
                }
                cached.fetched_at.elapsed() >= Self::MIN_REFRESH_INTERVAL
            },
            _ => true,
        };
        // The read guard must be released before `refresh_jwks` takes the
        // write lock, or this task would deadlock on itself.
        drop(cache);

        if needs_refresh {
            self.refresh_jwks(jwks_uri).await?;
            let cache = self.jwks_cache.read().await;
            if let Some(key) = cache.get(jwks_uri).and_then(|cached| cached.keys.get(kid)) {
                return Ok(key.clone());
            }
        }

        Err(Error::protocol(
            ErrorCode::AUTHENTICATION_REQUIRED,
            format!("Unknown key ID: {}", kid),
        ))
    }

    /// Downloads the JWKS at `jwks_uri` and replaces its cache entry.
    ///
    /// Keys lacking a `kid`, `n` or `e` are skipped. Keys whose RSA components
    /// fail to decode are skipped with a warning.
    ///
    /// # Errors
    ///
    /// Returns an internal error in these cases:
    /// - the request fails or returns a non-success status;
    /// - the body is not a JWKS;
    /// - no usable key is found.
    ///
    /// On error the existing cache entry is left unchanged.
    #[cfg(all(not(target_arch = "wasm32"), feature = "jwt-auth"))]
    async fn refresh_jwks(&self, jwks_uri: &str) -> Result<()> {
        tracing::debug!(jwks_uri = %jwks_uri, "Fetching JWKS");
        let response = self
            .http_client
            .get(jwks_uri)
            .send()
            .await
            .map_err(|e| Error::internal(format!("Failed to fetch JWKS: {}", e)))?;
        if !response.status().is_success() {
            return Err(Error::internal(format!(
                "JWKS endpoint returned status {}",
                response.status()
            )));
        }
        let jwks: JwksResponse = response
            .json()
            .await
            .map_err(|e| Error::internal(format!("Failed to parse JWKS: {}", e)))?;

        let mut keys = HashMap::new();
        for key in jwks.keys {
            if let (Some(kid), Some(n), Some(e)) = (&key.kid, &key.n, &key.e) {
                match jsonwebtoken::DecodingKey::from_rsa_components(n, e) {
                    Ok(decoding_key) => {
                        keys.insert(kid.clone(), decoding_key);
                    },
                    Err(err) => {
                        tracing::warn!(kid = %kid, error = %err, "Failed to parse JWK");
                    },
                }
            }
        }
        if keys.is_empty() {
            return Err(Error::internal("No valid keys found in JWKS"));
        }
        tracing::info!(jwks_uri = %jwks_uri, keys_count = keys.len(), "Cached JWKS keys");

        // Build the entry first so the write lock is held only for the insert.
        let cached = CachedJwks {
            keys,
            fetched_at: Instant::now(),
            ttl: self.cache_ttl,
        };
        self.jwks_cache
            .write()
            .await
            .insert(jwks_uri.to_string(), cached);
        Ok(())
    }

    /// Forgets every cached key set; the next validation for any issuer refetches.
    #[cfg(feature = "jwt-auth")]
    pub async fn clear_cache(&self) {
        let mut cache = self.jwks_cache.write().await;
        cache.clear();
    }

    /// Forgets the cached key set for `jwks_uri` only.
    #[cfg(feature = "jwt-auth")]
    pub async fn clear_issuer_cache(&self, jwks_uri: &str) {
        let mut cache = self.jwks_cache.write().await;
        cache.remove(jwks_uri);
    }

    /// Number of JWKS URIs with a cached key set. Expired entries still count,
    /// because they are replaced on refetch rather than evicted.
    #[cfg(feature = "jwt-auth")]
    pub async fn cache_size(&self) -> usize {
        let cache = self.jwks_cache.read().await;
        cache.len()
    }
}
/// Per-provider settings consumed by [`JwtValidator::validate`].
#[derive(Debug, Clone)]
pub struct ValidationConfig {
    /// Issuer value the token's `iss` claim is checked against.
    pub issuer: String,
    /// URL of the provider's JWKS document; also the key under which its
    /// downloaded key set is cached.
    pub jwks_uri: String,
    /// Audience value the token's `aud` claim is checked against.
    pub audience: String,
    /// Clock-skew allowance, in seconds, applied during token validation.
    pub leeway_seconds: u64,
    /// When set, a token whose `token_use` claim holds a different value is rejected.
    pub required_token_use: Option<String>,
    /// Provider-specific claim normalization applied before building the auth context.
    pub claim_mappings: ClaimMappings,
}
impl ValidationConfig {
    /// Builds a generic config.
    ///
    /// Defaults: 60 seconds of leeway, no `token_use` requirement, and
    /// `ClaimMappings::default()`. Every preset below starts from this.
    pub fn new(
        issuer: impl Into<String>,
        jwks_uri: impl Into<String>,
        audience: impl Into<String>,
    ) -> Self {
        Self {
            issuer: issuer.into(),
            jwks_uri: jwks_uri.into(),
            audience: audience.into(),
            leeway_seconds: 60,
            required_token_use: None,
            claim_mappings: ClaimMappings::default(),
        }
    }

    /// Preset for an AWS Cognito user pool.
    ///
    /// The issuer is `https://cognito-idp.{region}.amazonaws.com/{user_pool_id}`
    /// and the JWKS lives under its `/.well-known/jwks.json`. The preset
    /// requires `token_use == "access"`.
    ///
    /// NOTE(review): Cognito access tokens are documented to carry `client_id`
    /// rather than `aud`; confirm that audience validation accepts them.
    pub fn cognito(region: &str, user_pool_id: &str, client_id: &str) -> Self {
        let issuer = format!(
            "https://cognito-idp.{}.amazonaws.com/{}",
            region, user_pool_id
        );
        let jwks_uri = format!("{}/.well-known/jwks.json", issuer);
        Self::new(issuer, jwks_uri, client_id)
            .with_required_token_use("access")
            .with_claim_mappings(ClaimMappings::cognito())
    }

    /// Preset for Google-issued tokens whose audience is `client_id`.
    pub fn google(client_id: &str) -> Self {
        Self::new(
            "https://accounts.google.com",
            "https://www.googleapis.com/oauth2/v3/certs",
            client_id,
        )
        .with_claim_mappings(ClaimMappings::google())
    }

    /// Preset for an Auth0 tenant at `domain`.
    ///
    /// The issuer is `https://{domain}/`, keeping the trailing slash.
    pub fn auth0(domain: &str, client_id: &str) -> Self {
        Self::new(
            format!("https://{}/", domain),
            format!("https://{}/.well-known/jwks.json", domain),
            client_id,
        )
        .with_claim_mappings(ClaimMappings::auth0())
    }

    /// Preset for Okta at `domain`.
    ///
    /// The issuer is `https://{domain}` and the keys are at `/oauth2/v1/keys`.
    pub fn okta(domain: &str, client_id: &str) -> Self {
        Self::new(
            format!("https://{}", domain),
            format!("https://{}/oauth2/v1/keys", domain),
            client_id,
        )
        .with_claim_mappings(ClaimMappings::okta())
    }

    /// Preset for Microsoft Entra ID (v2.0 endpoints) for `tenant_id`.
    pub fn entra(tenant_id: &str, client_id: &str) -> Self {
        Self::new(
            format!("https://login.microsoftonline.com/{}/v2.0", tenant_id),
            format!(
                "https://login.microsoftonline.com/{}/discovery/v2.0/keys",
                tenant_id
            ),
            client_id,
        )
        .with_claim_mappings(ClaimMappings::entra())
    }

    /// Replaces the clock-skew allowance (seconds).
    pub fn with_leeway(self, seconds: u64) -> Self {
        Self {
            leeway_seconds: seconds,
            ..self
        }
    }

    /// Requires the token's `token_use` claim to equal `token_use`.
    pub fn with_required_token_use(self, token_use: impl Into<String>) -> Self {
        Self {
            required_token_use: Some(token_use.into()),
            ..self
        }
    }

    /// Replaces the claim mappings.
    pub fn with_claim_mappings(self, mappings: ClaimMappings) -> Self {
        Self {
            claim_mappings: mappings,
            ..self
        }
    }
}
#[cfg(all(feature = "jwt-auth", not(target_arch = "wasm32")))]
/// Extracts granted scopes from token claims.
///
/// `scope` is consulted first and `scp` second. Each may be a
/// whitespace-separated string or an array; non-string array items are
/// ignored. If `scope` is present with a usable shape it wins outright, even
/// when empty. Returns an empty list when neither claim is usable.
fn parse_scopes(claims: &serde_json::Value) -> Vec<String> {
    fn scope_list(value: &serde_json::Value) -> Option<Vec<String>> {
        match value {
            serde_json::Value::String(text) => {
                Some(text.split_whitespace().map(String::from).collect())
            },
            serde_json::Value::Array(items) => Some(
                items
                    .iter()
                    .filter_map(serde_json::Value::as_str)
                    .map(String::from)
                    .collect(),
            ),
            _ => None,
        }
    }

    claims
        .get("scope")
        .and_then(scope_list)
        .or_else(|| claims.get("scp").and_then(scope_list))
        .unwrap_or_default()
}
#[cfg(feature = "jwt-auth")]
/// Top-level JWKS document: an object holding a `keys` array.
#[derive(Debug, serde::Deserialize)]
struct JwksResponse {
    /// Every key published by the provider, usable or not.
    keys: Vec<JwkKey>,
}
#[cfg(feature = "jwt-auth")]
/// A single JSON Web Key as published in a JWKS document.
#[derive(Debug, serde::Deserialize)]
struct JwkKey {
    /// Key ID, matched against the token header's `kid`.
    kid: Option<String>,
    /// Key type. It is required by the schema (deserialization fails without
    /// it) but is never read, hence the `dead_code` allowance.
    #[allow(dead_code)]
    kty: String,
    /// RSA modulus, in the encoding expected by `DecodingKey::from_rsa_components`.
    n: Option<String>,
    /// RSA public exponent, in the same encoding as `n`.
    e: Option<String>,
    /// Declared algorithm. Deserialized for `Debug` output only, hence the allowance.
    #[allow(dead_code)]
    alg: Option<String>,
    /// Intended key use (JSON field `use`). Deserialized for `Debug` output only.
    #[serde(rename = "use")]
    #[allow(dead_code)]
    key_use: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;

    // ValidationConfig presets: the URLs each preset derives from its inputs.

    #[test]
    fn test_validation_config_cognito() {
        let config = ValidationConfig::cognito("us-east-1", "us-east-1_xxxxx", "client-123");
        let expected_issuer = "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx";
        let expected_jwks =
            "https://cognito-idp.us-east-1.amazonaws.com/us-east-1_xxxxx/.well-known/jwks.json";
        assert_eq!(config.issuer, expected_issuer);
        assert_eq!(config.jwks_uri, expected_jwks);
        assert_eq!(config.audience, "client-123");
        assert_eq!(config.required_token_use.as_deref(), Some("access"));
    }

    #[test]
    fn test_validation_config_google() {
        let config = ValidationConfig::google("client-123.apps.googleusercontent.com");
        assert_eq!(config.issuer, "https://accounts.google.com");
        assert_eq!(config.jwks_uri, "https://www.googleapis.com/oauth2/v3/certs");
        assert_eq!(config.required_token_use, None);
    }

    #[test]
    fn test_validation_config_auth0() {
        let config = ValidationConfig::auth0("tenant.auth0.com", "client-123");
        assert_eq!(config.issuer, "https://tenant.auth0.com/");
        assert_eq!(config.jwks_uri, "https://tenant.auth0.com/.well-known/jwks.json");
    }

    #[test]
    fn test_validation_config_okta() {
        let config = ValidationConfig::okta("dev-123.okta.com", "client-123");
        assert_eq!(config.issuer, "https://dev-123.okta.com");
        assert_eq!(config.jwks_uri, "https://dev-123.okta.com/oauth2/v1/keys");
    }

    #[test]
    fn test_validation_config_entra() {
        let config = ValidationConfig::entra("tenant-id-123", "client-123");
        let expected_issuer = "https://login.microsoftonline.com/tenant-id-123/v2.0";
        let expected_jwks =
            "https://login.microsoftonline.com/tenant-id-123/discovery/v2.0/keys";
        assert_eq!(config.issuer, expected_issuer);
        assert_eq!(config.jwks_uri, expected_jwks);
    }

    #[test]
    fn test_validation_config_builder() {
        let config = ValidationConfig::new(
            "https://issuer.example.com",
            "https://issuer.example.com/.well-known/jwks.json",
            "my-audience",
        )
        .with_leeway(120)
        .with_required_token_use("access");
        assert_eq!(config.leeway_seconds, 120);
        assert_eq!(config.required_token_use.as_deref(), Some("access"));
    }

    // JwtValidator construction and cloning.

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_jwt_validator_creation() {
        let validator = JwtValidator::new();
        assert_eq!(validator.cache_ttl, Duration::from_secs(3600));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_jwt_validator_custom_ttl() {
        let ttl = Duration::from_secs(7200);
        let validator = JwtValidator::with_cache_ttl(ttl);
        assert_eq!(validator.cache_ttl, ttl);
    }

    #[cfg(all(feature = "jwt-auth", not(target_arch = "wasm32")))]
    #[test]
    fn test_jwt_validator_clone() {
        use std::sync::Arc;
        let original = JwtValidator::new();
        let copy = original.clone();
        // Clones must share one cache, not copy it.
        assert!(Arc::ptr_eq(&original.jwks_cache, &copy.jwks_cache));
    }

    // Scope extraction from `scope` / `scp` claims.

    #[cfg(all(feature = "jwt-auth", not(target_arch = "wasm32")))]
    #[test]
    fn test_parse_scopes_space_separated() {
        let claims = serde_json::json!({ "scope": "read write admin" });
        assert_eq!(parse_scopes(&claims), ["read", "write", "admin"]);
    }

    #[cfg(all(feature = "jwt-auth", not(target_arch = "wasm32")))]
    #[test]
    fn test_parse_scopes_array() {
        let claims = serde_json::json!({ "scope": ["read", "write", "admin"] });
        assert_eq!(parse_scopes(&claims), ["read", "write", "admin"]);
    }

    #[cfg(all(feature = "jwt-auth", not(target_arch = "wasm32")))]
    #[test]
    fn test_parse_scopes_scp_array() {
        let claims = serde_json::json!({ "scp": ["User.Read", "User.Write"] });
        assert_eq!(parse_scopes(&claims), ["User.Read", "User.Write"]);
    }

    #[cfg(all(feature = "jwt-auth", not(target_arch = "wasm32")))]
    #[test]
    fn test_parse_scopes_scp_string() {
        let claims = serde_json::json!({ "scp": "User.Read User.Write" });
        assert_eq!(parse_scopes(&claims), ["User.Read", "User.Write"]);
    }

    #[cfg(all(feature = "jwt-auth", not(target_arch = "wasm32")))]
    #[test]
    fn test_parse_scopes_empty() {
        let claims = serde_json::json!({});
        assert!(parse_scopes(&claims).is_empty());
    }

    // Cache management.

    #[cfg(all(feature = "jwt-auth", not(target_arch = "wasm32")))]
    #[tokio::test]
    async fn test_clear_cache() {
        let validator = JwtValidator::new();
        assert_eq!(validator.cache_size().await, 0);
        validator.clear_cache().await;
        assert_eq!(validator.cache_size().await, 0);
    }
}