use std::sync::Arc;
use tracing::{debug, warn};
use crate::config::auth::{JwtAuthConfig, JwtProviderConfig};
use crate::control::security::identity::{AuthMethod, AuthenticatedIdentity, Role};
use crate::control::security::jwt::{JwtClaims, JwtError};
use crate::control::security::util::base64_url_decode;
use crate::types::TenantId;
use super::cache::JwksCache;
use super::key::{VerificationKey, verify_signature};
/// Validates JWTs against the JSON Web Key Sets of the configured providers.
///
/// Keys live in a shared [`JwksCache`] that is seeded from disk and by an
/// initial fetch at construction time. When at least one provider is
/// configured, a background refresh task also keeps the cache updated.
pub struct JwksRegistry {
    /// Statically configured providers (name, JWKS URL, issuer, audience).
    providers: Vec<JwtProviderConfig>,
    /// Key cache, shared with the background refresh task.
    cache: Arc<JwksCache>,
    /// JWT auth settings (allowed algorithms, clock skew, refetch limits, ...).
    config: JwtAuthConfig,
    /// Policy applied to every JWKS fetch.
    policy: Arc<super::url::JwksPolicy>,
    /// Handle to the background refresh task. The task is aborted when the
    /// registry is dropped (see the `Drop` impl below).
    _refresh_handle: Option<tokio::task::JoinHandle<()>>,
}

impl Drop for JwksRegistry {
    /// Stops the background JWKS refresh task.
    ///
    /// Dropping a tokio `JoinHandle` only detaches the task. Without this
    /// abort, the refresh loop would keep running after the registry is gone,
    /// for example after a config reload builds a replacement registry. It
    /// would presumably also keep holding its clone of the cache `Arc`.
    fn drop(&mut self) {
        if let Some(handle) = &self._refresh_handle {
            handle.abort();
        }
    }
}
/// A JWT that has been split into its segments and had its header and claims
/// decoded. Its signature has NOT been verified yet.
struct DecodedToken<'a> {
    /// Decoded JOSE header (`alg`, optional `kid`).
    header: JwtHeader,
    /// Decoded payload. Untrusted until the signature is verified.
    claims: JwtClaims,
    /// Raw base64url segments `[header, payload, signature]`, borrowed from
    /// the original token string.
    parts: [&'a str; 3],
}
impl JwksRegistry {
    /// Builds the registry from `config`.
    ///
    /// Loads persisted keys from disk and performs one JWKS fetch per
    /// configured provider. When at least one provider is configured, it also
    /// spawns the background refresh task and passes it `jwks_refresh_secs`.
    pub async fn init(config: JwtAuthConfig) -> Self {
        let cache = Arc::new(JwksCache::new(config.jwks_cache_path.clone()));
        // NOTE(review): if `jwks_policy()` yields no policy (None/Err), the
        // default policy is used silently. Consider surfacing that instead.
        let policy = Arc::new(config.jwks_policy().unwrap_or_default());
        cache.load_from_disk();
        for provider in &config.providers {
            super::fetch::fetch_and_cache(&provider.name, &provider.jwks_url, &cache, &policy)
                .await;
        }
        let refresh_handle = if config.providers.is_empty() {
            None
        } else {
            let pairs: Vec<(String, String)> = config
                .providers
                .iter()
                .map(|p| (p.name.clone(), p.jwks_url.clone()))
                .collect();
            Some(super::fetch::spawn_refresh_task(
                pairs,
                cache.clone(),
                config.jwks_refresh_secs,
                policy.clone(),
            ))
        };
        Self {
            providers: config.providers.clone(),
            cache,
            config,
            policy,
            _refresh_handle: refresh_handle,
        }
    }

    /// Validates a bearer token against the provider whose configured issuer
    /// matches the token's `iss` claim, and builds the caller's identity.
    ///
    /// # Errors
    /// Returns a [`JwtError`] in any of these cases:
    /// - the token is malformed or uses a disallowed algorithm;
    /// - no provider or key matches;
    /// - the signature is invalid;
    /// - the token is outside its `exp`/`nbf` window;
    /// - the issuer or audience does not match the provider's configuration.
    pub async fn validate(&self, token: &str) -> Result<AuthenticatedIdentity, JwtError> {
        let decoded = self.decode_unverified(token)?;
        let provider = self.find_provider(&decoded.claims.iss)?;
        let key = self.resolve_key(provider, &decoded).await?;
        self.verify_signature_and_time(&decoded, &key, &provider.name)?;
        Self::check_issuer_and_audience(provider, &decoded.claims)?;
        let kid = decoded.header.kid.as_deref().unwrap_or("");
        let identity = build_identity(&decoded.claims);
        debug!(
            username = %identity.username,
            tenant_id = decoded.claims.tenant_id,
            provider = %provider.name,
            kid = %kid,
            "JWKS JWT validated"
        );
        Ok(identity)
    }

    /// Validates `token` against the configured provider named
    /// `provider_name` and returns the verified claims.
    ///
    /// The provider is chosen by name rather than by the token's `iss`. Its
    /// configured issuer and audience, when non-empty, are still enforced,
    /// exactly as in [`Self::validate`].
    ///
    /// # Errors
    /// Returns `JwtError::InvalidIssuer` if no provider has that name.
    /// Otherwise returns the same errors as [`Self::validate`].
    pub async fn validate_with_provider(
        &self,
        provider_name: &str,
        token: &str,
    ) -> Result<JwtClaims, JwtError> {
        let decoded = self.decode_unverified(token)?;
        let provider = self
            .providers
            .iter()
            .find(|p| p.name == provider_name)
            .ok_or(JwtError::InvalidIssuer)?;
        let key = self.resolve_key(provider, &decoded).await?;
        self.verify_signature_and_time(&decoded, &key, provider_name)?;
        // A valid signature from this provider's keys does not prove the token
        // was minted for us. Bind it to the configured iss/aud as well.
        Self::check_issuer_and_audience(provider, &decoded.claims)?;
        debug!(
            provider = %provider_name,
            kid = %decoded.header.kid.as_deref().unwrap_or(""),
            sub = %decoded.claims.sub,
            "JWKS JWT validated via validate_with_provider"
        );
        Ok(decoded.claims)
    }

    /// Validates `token` using keys from an ad-hoc JWKS endpoint `jwks_uri`,
    /// cached under `provider_name`.
    ///
    /// Only the signature, algorithm binding, and `exp`/`nbf` are checked here.
    /// There is no configured issuer/audience for such providers, so the
    /// caller is responsible for validating those claims.
    pub async fn validate_with_catalog_provider(
        &self,
        provider_name: &str,
        jwks_uri: &str,
        token: &str,
    ) -> Result<JwtClaims, JwtError> {
        let decoded = self.decode_unverified(token)?;
        let kid = decoded.header.kid.as_deref().unwrap_or("");
        let key = match self.cache.get(provider_name, kid) {
            Some(k) => k,
            None => {
                self.refetch_catalog_key(provider_name, jwks_uri, kid)
                    .await?
            }
        };
        self.verify_signature_and_time(&decoded, &key, provider_name)?;
        debug!(
            provider = %provider_name,
            kid = %kid,
            sub = %decoded.claims.sub,
            "JWKS JWT validated via catalog provider"
        );
        Ok(decoded.claims)
    }

    /// Decodes the payload of `token` WITHOUT checking its header or
    /// signature.
    ///
    /// The returned claims are attacker-controllable. Treat them as untrusted
    /// input.
    pub fn decode_claims(&self, token: &str) -> Result<JwtClaims, JwtError> {
        let [_, payload, _] = Self::split_token(token)?;
        Self::decode_payload(payload)
    }

    /// Returns `true` when at least one provider is configured.
    pub fn is_configured(&self) -> bool {
        !self.providers.is_empty()
    }

    /// Splits a compact JWS into exactly three `.`-separated segments.
    fn split_token(token: &str) -> Result<[&str; 3], JwtError> {
        let mut segments = token.split('.');
        match (
            segments.next(),
            segments.next(),
            segments.next(),
            segments.next(),
        ) {
            (Some(header), Some(payload), Some(signature), None) => {
                Ok([header, payload, signature])
            }
            _ => Err(JwtError::MalformedToken),
        }
    }

    /// Base64url-decodes and deserializes a JWT payload segment.
    fn decode_payload(segment: &str) -> Result<JwtClaims, JwtError> {
        let bytes = base64_url_decode(segment).ok_or(JwtError::DecodingError)?;
        sonic_rs::from_slice(&bytes).map_err(|_| JwtError::InvalidClaims)
    }

    /// Splits and decodes `token` and enforces the algorithm allow-list. The
    /// signature is NOT verified here.
    fn decode_unverified<'a>(&self, token: &'a str) -> Result<DecodedToken<'a>, JwtError> {
        let parts = Self::split_token(token)?;
        let header = decode_jwt_header(parts[0])?;
        // Reject `none` in any letter case so that variants such as "None"
        // fail fast, before any key lookup or JWKS refetch is triggered.
        if header.alg.eq_ignore_ascii_case("none") {
            return Err(JwtError::UnsupportedAlgorithm);
        }
        let allowed = &self.config.allowed_algorithms;
        if !allowed.is_empty() && !allowed.iter().any(|a| a == &header.alg) {
            return Err(JwtError::UnsupportedAlgorithm);
        }
        let claims = Self::decode_payload(parts[1])?;
        Ok(DecodedToken {
            parts,
            header,
            claims,
        })
    }

    /// Checks, in order: that the header `alg` matches the key's algorithm,
    /// that the signature over `header.payload` is valid, and that the token
    /// is within its `exp`/`nbf` window (both widened by `clock_skew_secs`).
    fn verify_signature_and_time(
        &self,
        decoded: &DecodedToken<'_>,
        key: &VerificationKey,
        provider_name: &str,
    ) -> Result<(), JwtError> {
        let kid = decoded.header.kid.as_deref().unwrap_or("");
        if key.algorithm != decoded.header.alg {
            warn!(
                expected = %key.algorithm,
                actual = %decoded.header.alg,
                kid = %kid,
                provider = %provider_name,
                "JWT algorithm mismatch — possible algorithm confusion attack"
            );
            return Err(JwtError::UnsupportedAlgorithm);
        }
        let signing_input = format!("{}.{}", decoded.parts[0], decoded.parts[1]);
        let signature = base64_url_decode(decoded.parts[2]).ok_or(JwtError::DecodingError)?;
        if !verify_signature(key, signing_input.as_bytes(), &signature) {
            return Err(JwtError::InvalidSignature);
        }
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let skew = self.config.clock_skew_secs;
        // `exp`/`nbf` come from the token and `skew` from config. Saturating
        // adds keep extreme values (e.g. u64::MAX) from overflowing, which
        // would panic in debug builds and wrap (inverting the check) in release.
        if decoded.claims.exp > 0 && now > decoded.claims.exp.saturating_add(skew) {
            return Err(JwtError::Expired);
        }
        if decoded.claims.nbf > 0 && now.saturating_add(skew) < decoded.claims.nbf {
            return Err(JwtError::NotYetValid);
        }
        Ok(())
    }

    /// Enforces the provider's configured issuer and audience. Each check is
    /// skipped when the corresponding config value is empty.
    fn check_issuer_and_audience(
        provider: &JwtProviderConfig,
        claims: &JwtClaims,
    ) -> Result<(), JwtError> {
        if !provider.issuer.is_empty() && claims.iss != provider.issuer {
            return Err(JwtError::InvalidIssuer);
        }
        if !provider.audience.is_empty() && claims.aud != provider.audience {
            return Err(JwtError::InvalidAudience);
        }
        Ok(())
    }

    /// Looks up the key for the token's `kid` (`""` when absent) in the cache.
    /// On a miss, it refetches the provider's JWKS, subject to rate limiting.
    async fn resolve_key(
        &self,
        provider: &JwtProviderConfig,
        decoded: &DecodedToken<'_>,
    ) -> Result<VerificationKey, JwtError> {
        let kid = decoded.header.kid.as_deref().unwrap_or("");
        match self.cache.get(&provider.name, kid) {
            Some(k) => Ok(k),
            None => self.refetch_for_unknown_kid(provider, kid).await,
        }
    }

    /// Finds the configured provider whose non-empty issuer equals `issuer`.
    fn find_provider(&self, issuer: &str) -> Result<&JwtProviderConfig, JwtError> {
        if issuer.is_empty() {
            return Err(JwtError::InvalidIssuer);
        }
        self.providers
            .iter()
            .find(|p| !p.issuer.is_empty() && p.issuer == issuer)
            .ok_or(JwtError::InvalidIssuer)
    }

    /// Refetches a configured provider's JWKS for an unknown `kid`, unless a
    /// refetch happened within `jwks_min_refetch_secs`.
    async fn refetch_for_unknown_kid(
        &self,
        provider: &JwtProviderConfig,
        kid: &str,
    ) -> Result<VerificationKey, JwtError> {
        if !self
            .cache
            .can_refetch(&provider.name, self.config.jwks_min_refetch_secs)
        {
            warn!(
                provider = %provider.name,
                kid = %kid,
                "unknown kid — re-fetch rate-limited"
            );
            return Err(JwtError::InvalidSignature);
        }
        self.refetch_and_lookup(&provider.name, &provider.jwks_url, kid)
            .await
    }

    /// Same as `refetch_for_unknown_kid`, but for catalog providers, which
    /// are identified by name plus an explicit JWKS URI.
    async fn refetch_catalog_key(
        &self,
        provider_name: &str,
        jwks_uri: &str,
        kid: &str,
    ) -> Result<VerificationKey, JwtError> {
        if !self
            .cache
            .can_refetch(provider_name, self.config.jwks_min_refetch_secs)
        {
            warn!(
                provider = %provider_name,
                kid = %kid,
                "unknown kid — re-fetch rate-limited (catalog provider)"
            );
            return Err(JwtError::InvalidSignature);
        }
        self.refetch_and_lookup(provider_name, jwks_uri, kid).await
    }

    /// Records the refetch attempt, fetches the JWKS at `jwks_url` into the
    /// cache, and looks `kid` up again. A key that is still missing is
    /// reported as `InvalidSignature`.
    async fn refetch_and_lookup(
        &self,
        provider_name: &str,
        jwks_url: &str,
        kid: &str,
    ) -> Result<VerificationKey, JwtError> {
        self.cache.mark_refetch_attempted(provider_name);
        super::fetch::fetch_and_cache(provider_name, jwks_url, &self.cache, &self.policy).await;
        self.cache
            .get(provider_name, kid)
            .ok_or(JwtError::InvalidSignature)
    }
}
/// Converts verified JWT claims into an [`AuthenticatedIdentity`] whose auth
/// method is `AuthMethod::OidcBearer`.
///
/// When `sub` is empty, the username falls back to `jwt_user_<user_id>`. No
/// default database is set, and the accessible-database set is derived from
/// `is_superuser`.
fn build_identity(claims: &JwtClaims) -> AuthenticatedIdentity {
    let username = match claims.sub.as_str() {
        "" => format!("jwt_user_{}", claims.user_id),
        subject => subject.to_owned(),
    };
    let roles = claims
        .roles
        .iter()
        .map(|name| parse_role_infallible(name))
        .collect::<Vec<Role>>();
    let superuser = claims.is_superuser;
    AuthenticatedIdentity {
        user_id: claims.user_id,
        username,
        tenant_id: TenantId::new(claims.tenant_id),
        auth_method: AuthMethod::OidcBearer,
        roles,
        is_superuser: superuser,
        default_database: None,
        accessible_databases: AuthenticatedIdentity::default_database_set(superuser),
    }
}
/// Parses a role name into a [`Role`].
///
/// `Role`'s `FromStr` error type is uninhabited, so parsing cannot fail. The
/// empty `match` proves this to the compiler without a panic path.
fn parse_role_infallible(s: &str) -> Role {
    s.parse::<Role>().unwrap_or_else(|never| match never {})
}
/// The subset of the JOSE header that this module reads.
#[derive(Debug, serde::Deserialize)]
struct JwtHeader {
    /// Signing algorithm (`alg`). It must equal the resolved key's algorithm
    /// for the signature to be checked.
    alg: String,
    /// Key ID (`kid`). When absent, `""` is used for cache lookups.
    #[serde(default)]
    kid: Option<String>,
}
/// Base64url-decodes and parses the JOSE header segment of a JWT.
///
/// # Errors
/// Returns `JwtError::DecodingError` for invalid base64url and
/// `JwtError::InvalidClaims` for invalid JSON.
fn decode_jwt_header(encoded: &str) -> Result<JwtHeader, JwtError> {
    match base64_url_decode(encoded) {
        Some(raw) => sonic_rs::from_slice(&raw).map_err(|_| JwtError::InvalidClaims),
        None => Err(JwtError::DecodingError),
    }
}