use std::sync::Arc;
use tracing::{debug, warn};
use crate::config::auth::{JwtAuthConfig, JwtProviderConfig};
use crate::control::security::identity::{AuthMethod, AuthenticatedIdentity, Role};
use crate::control::security::jwt::{JwtClaims, JwtError};
use crate::control::security::util::base64_url_decode;
use crate::types::TenantId;
use super::cache::JwksCache;
use super::key::verify_signature;
/// Validates JWTs against signing keys published by one or more configured JWKS providers.
///
/// Keys live in a shared [`JwksCache`]. When providers are configured, [`JwksRegistry::init`]
/// also spawns a background task via `fetch::spawn_refresh_task` (presumably re-fetching the
/// providers' key sets every `jwks_refresh_secs`).
pub struct JwksRegistry {
    /// Providers copied from `config.providers` at construction time.
    providers: Vec<JwtProviderConfig>,
    /// Key cache, shared (via `Arc`) with the background refresh task.
    cache: Arc<JwksCache>,
    /// Full auth configuration: algorithm allow-list, clock skew, refetch limits, etc.
    config: JwtAuthConfig,
    /// Handle of the background refresh task, if one was spawned.
    // NOTE(review): dropping a tokio `JoinHandle` detaches the task rather than stopping it,
    // so holding the handle here does not by itself bound the task's lifetime.
    _refresh_handle: Option<tokio::task::JoinHandle<()>>,
}
impl JwksRegistry {
    /// Builds a registry from `config`, warming the key cache before returning.
    ///
    /// Initialization proceeds in three steps:
    /// 1. Calls `JwksCache::load_from_disk` on a cache created for `config.jwks_cache_path`.
    /// 2. Fetches every configured provider's JWKS once, sequentially. Nothing is propagated
    ///    from `fetch_and_cache`, so a failing provider does not abort initialization.
    /// 3. If at least one provider is configured, spawns the background refresh task with
    ///    `config.jwks_refresh_secs`. The task is aborted when the registry is dropped.
    pub async fn init(config: JwtAuthConfig) -> Self {
        let cache = Arc::new(JwksCache::new(config.jwks_cache_path.clone()));
        cache.load_from_disk();

        for provider in &config.providers {
            super::fetch::fetch_and_cache(&provider.name, &provider.jwks_url, &cache).await;
        }

        let refresh_handle = if config.providers.is_empty() {
            None
        } else {
            let pairs: Vec<(String, String)> = config
                .providers
                .iter()
                .map(|p| (p.name.clone(), p.jwks_url.clone()))
                .collect();
            Some(super::fetch::spawn_refresh_task(
                pairs,
                Arc::clone(&cache),
                config.jwks_refresh_secs,
            ))
        };

        Self {
            providers: config.providers.clone(),
            cache,
            config,
            _refresh_handle: refresh_handle,
        }
    }

    /// Verifies `token` and maps its claims to an [`AuthenticatedIdentity`].
    ///
    /// The token is checked against the JWKS of the provider that matches its `iss` claim.
    /// Checks run in this order:
    /// 1. Token shape.
    /// 2. Header `alg`: `none` is rejected in any letter case, and so is anything outside
    ///    `allowed_algorithms` when that list is non-empty.
    /// 3. Provider lookup by issuer.
    /// 4. Key lookup by `kid`, with a rate-limited JWKS re-fetch on a miss.
    /// 5. Key/header algorithm agreement.
    /// 6. Signature.
    /// 7. `exp`/`nbf`, with `clock_skew_secs` leeway.
    /// 8. Issuer and audience.
    ///
    /// # Errors
    /// - `MalformedToken`: the token does not have exactly three `.`-separated segments.
    /// - `DecodingError`: a segment is not valid base64url.
    /// - `InvalidClaims`: the header or payload JSON does not parse.
    /// - `UnsupportedAlgorithm`: `alg` is forbidden, not allowed, or differs from the key's.
    /// - `InvalidIssuer`: no provider matches the issuer, or the issuer does not match the
    ///   provider's.
    /// - `InvalidSignature`: the key is unknown (even after re-fetch) or the signature is bad.
    /// - `Expired`, `NotYetValid`, `InvalidAudience`: the corresponding claim checks fail.
    pub async fn validate(&self, token: &str) -> Result<AuthenticatedIdentity, JwtError> {
        let (header_b64, payload_b64, signature_b64) = Self::split_token(token)?;

        let header = decode_jwt_header(header_b64)?;
        let kid = header.kid.as_deref().unwrap_or("");
        let alg = &header.alg;

        // `none` must never be accepted. Comparing case-insensitively rejects "None"/"NONE"
        // up front, instead of letting them reach key lookup (and a possible JWKS re-fetch).
        if alg.eq_ignore_ascii_case("none") {
            return Err(JwtError::UnsupportedAlgorithm);
        }
        if !self.config.allowed_algorithms.is_empty()
            && !self.config.allowed_algorithms.iter().any(|a| a == alg)
        {
            return Err(JwtError::UnsupportedAlgorithm);
        }

        // The payload is parsed before verification only to pick the provider by `iss`.
        // None of its other contents are acted on until the signature check below passes.
        let claims = Self::parse_claims(payload_b64)?;
        let provider = self.find_provider(&claims.iss)?;

        let key = match self.cache.get(&provider.name, kid) {
            Some(k) => k,
            None => self.refetch_for_unknown_kid(provider, kid).await?,
        };

        // The header's `alg` must match the algorithm recorded for the key. Otherwise a
        // token could steer verification onto a different algorithm than the key was
        // published for.
        if key.algorithm != *alg {
            warn!(
                expected = %key.algorithm,
                actual = %alg,
                kid = %kid,
                "JWT algorithm mismatch — possible algorithm confusion attack"
            );
            return Err(JwtError::UnsupportedAlgorithm);
        }

        let signing_input = format!("{}.{}", header_b64, payload_b64);
        let signature = base64_url_decode(signature_b64).ok_or(JwtError::DecodingError)?;
        if !verify_signature(&key, signing_input.as_bytes(), &signature) {
            return Err(JwtError::InvalidSignature);
        }

        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let skew = self.config.clock_skew_secs;
        // A value of 0 skips the check. This presumably means the claim was absent; confirm
        // against `JwtClaims`' serde defaults. `saturating_add` stops an extreme `exp` or
        // skew from overflowing, which would panic in debug builds and, in release, wrap
        // around and reject a far-future token as expired.
        if claims.exp > 0 && now > claims.exp.saturating_add(skew) {
            return Err(JwtError::Expired);
        }
        if claims.nbf > 0 && now.saturating_add(skew) < claims.nbf {
            return Err(JwtError::NotYetValid);
        }
        if !provider.issuer.is_empty() && claims.iss != provider.issuer {
            return Err(JwtError::InvalidIssuer);
        }
        if !provider.audience.is_empty() && claims.aud != provider.audience {
            return Err(JwtError::InvalidAudience);
        }

        // Role names the `Role` parser does not recognise are kept as `Role::Custom`. The
        // fallback is built lazily so only unrecognised names are cloned.
        let roles: Vec<Role> = claims
            .roles
            .iter()
            .map(|r| r.parse::<Role>().unwrap_or_else(|_| Role::Custom(r.clone())))
            .collect();

        let username = if claims.sub.is_empty() {
            format!("jwt_user_{}", claims.user_id)
        } else {
            claims.sub.clone()
        };

        debug!(
            username = %username,
            tenant_id = claims.tenant_id,
            provider = %provider.name,
            kid = %kid,
            "JWKS JWT validated"
        );

        Ok(AuthenticatedIdentity {
            user_id: claims.user_id,
            username,
            tenant_id: TenantId::new(claims.tenant_id),
            // NOTE(review): identities from JWKS-verified JWTs are tagged `ApiKey`. Confirm
            // whether `AuthMethod` has a JWT-specific variant that should be used instead.
            auth_method: AuthMethod::ApiKey,
            roles,
            is_superuser: claims.is_superuser,
        })
    }

    /// Decodes the payload of `token` **without** verifying anything.
    ///
    /// The signature, algorithm, expiry, issuer and audience are not checked. Never base an
    /// authorization decision on the result; use [`Self::validate`] for that.
    ///
    /// # Errors
    /// - `MalformedToken`: the token does not have exactly three `.`-separated segments.
    /// - `DecodingError`: the payload is not valid base64url.
    /// - `InvalidClaims`: the payload does not parse as [`JwtClaims`].
    pub fn decode_claims(&self, token: &str) -> Result<JwtClaims, JwtError> {
        let (_, payload_b64, _) = Self::split_token(token)?;
        Self::parse_claims(payload_b64)
    }

    /// Returns `true` when at least one JWKS provider is configured.
    pub fn is_configured(&self) -> bool {
        !self.providers.is_empty()
    }

    /// Splits a compact JWT into its header, payload and signature segments.
    ///
    /// Fails with `MalformedToken` unless there are exactly three `.`-separated segments.
    /// The segments themselves may be empty.
    fn split_token(token: &str) -> Result<(&str, &str, &str), JwtError> {
        let mut segments = token.split('.');
        match (segments.next(), segments.next(), segments.next(), segments.next()) {
            (Some(header), Some(payload), Some(signature), None) => {
                Ok((header, payload, signature))
            }
            _ => Err(JwtError::MalformedToken),
        }
    }

    /// Base64url-decodes and JSON-parses a payload segment. No verification is performed.
    fn parse_claims(payload_b64: &str) -> Result<JwtClaims, JwtError> {
        let payload_bytes = base64_url_decode(payload_b64).ok_or(JwtError::DecodingError)?;
        serde_json::from_slice(&payload_bytes).map_err(|_| JwtError::InvalidClaims)
    }

    /// Selects the provider responsible for `issuer`.
    ///
    /// Returns the first provider with a non-empty `issuer` equal to `issuer`. If none
    /// matches and exactly one provider is configured, returns that provider; `validate`
    /// still rejects the token later if that provider's issuer is non-empty and differs.
    /// Otherwise fails with `InvalidIssuer`.
    fn find_provider(&self, issuer: &str) -> Result<&JwtProviderConfig, JwtError> {
        self.providers
            .iter()
            .find(|p| !p.issuer.is_empty() && p.issuer == issuer)
            .or_else(|| match self.providers.as_slice() {
                [only] => Some(only),
                _ => None,
            })
            .ok_or(JwtError::InvalidIssuer)
    }

    /// Handles a `kid` missing from the cache by re-fetching the provider's JWKS once.
    ///
    /// The re-fetch is subject to the per-provider `jwks_min_refetch_secs` rate limit.
    /// Fails with `InvalidSignature` if the re-fetch is rate-limited, or if the key is
    /// still unknown afterwards.
    // NOTE(review): `can_refetch` and `mark_refetch_attempted` are separate calls, so
    // concurrent validations may both pass the check before either marks the attempt.
    // Confirm whether `JwksCache` makes this atomic.
    async fn refetch_for_unknown_kid(
        &self,
        provider: &JwtProviderConfig,
        kid: &str,
    ) -> Result<super::key::VerificationKey, JwtError> {
        if !self
            .cache
            .can_refetch(&provider.name, self.config.jwks_min_refetch_secs)
        {
            warn!(
                provider = %provider.name,
                kid = %kid,
                "unknown kid — re-fetch rate-limited"
            );
            return Err(JwtError::InvalidSignature);
        }
        self.cache.mark_refetch_attempted(&provider.name);
        super::fetch::fetch_and_cache(&provider.name, &provider.jwks_url, &self.cache).await;
        self.cache
            .get(&provider.name, kid)
            .ok_or(JwtError::InvalidSignature)
    }
}

/// Stops the background JWKS refresh task when the registry is dropped.
///
/// Dropping a tokio `JoinHandle` only detaches its task. Without this, the task and the
/// `Arc<JwksCache>` it holds would outlive the registry.
impl Drop for JwksRegistry {
    fn drop(&mut self) {
        if let Some(handle) = &self._refresh_handle {
            handle.abort();
        }
    }
}
/// The subset of the JWT (JOSE) header that `JwksRegistry::validate` inspects.
/// Other header fields are ignored during deserialization.
#[derive(Debug, serde::Deserialize)]
struct JwtHeader {
    /// Signing algorithm named by the token. It is checked against the configured
    /// allow-list and against the cached key's algorithm.
    alg: String,
    /// Key ID used to select the verification key. It may be missing from the header;
    /// `validate` then looks up the empty string.
    #[serde(default)]
    kid: Option<String>,
}
/// Base64url-decodes the first JWT segment and parses it as a [`JwtHeader`].
///
/// Errors: `DecodingError` when `encoded` is not valid base64url. `InvalidClaims` when the
/// decoded JSON does not match `JwtHeader`; the header reuses the claims error variant.
fn decode_jwt_header(encoded: &str) -> Result<JwtHeader, JwtError> {
    let raw = match base64_url_decode(encoded) {
        Some(bytes) => bytes,
        None => return Err(JwtError::DecodingError),
    };
    serde_json::from_slice::<JwtHeader>(&raw).map_err(|_| JwtError::InvalidClaims)
}