#[cfg(feature = "sso")]
use dashmap::DashMap;
#[cfg(feature = "sso")]
use jsonwebtoken::DecodingKey;
#[cfg(feature = "sso")]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "sso")]
use std::time::Duration;
#[cfg(feature = "sso")]
use tokio::sync::Mutex;
use super::TokenError;
#[cfg(feature = "sso")]
/// Cache of JWKS decoding keys indexed by key id (`kid`).
///
/// All operations take `&self`; shared state lives in a `DashMap`, an atomic timestamp and an
/// async mutex that serializes refreshes.
pub struct JwksCache {
    /// Endpoint the JWK set is fetched from.
    jwks_uri: String,
    /// HTTP client reused for every fetch.
    client: reqwest::Client,
    /// Decoded keys, indexed by `kid`.
    keys: DashMap<String, DecodingKey>,
    /// Upper bound on the number of keys cached from one JWKS response.
    max_keys: usize,
    /// Unix time (seconds) of the last successful refresh; `0` means "never refreshed".
    last_refresh: AtomicU64,
    /// Refresh interval; refetches are throttled to at most one per half of this interval.
    refresh_interval: Duration,
    /// Held for the duration of a refresh so concurrent callers do not fetch in parallel.
    refresh_lock: Mutex<()>,
}
#[cfg(feature = "sso")]
impl JwksCache {
    /// Creates a cache for the key set published at `jwks_uri`.
    ///
    /// Starts empty with a one-hour refresh interval and a 100-key limit. Nothing is fetched
    /// until [`JwksCache::get_key`] misses or [`JwksCache::refresh`] is called.
    pub fn new(jwks_uri: impl Into<String>) -> Self {
        Self {
            keys: DashMap::new(),
            jwks_uri: jwks_uri.into(),
            last_refresh: AtomicU64::new(0),
            refresh_interval: Duration::from_secs(3600),
            max_keys: 100,
            refresh_lock: Mutex::new(()),
            client: reqwest::Client::new(),
        }
    }

    /// Sets the refresh interval.
    ///
    /// The interval only throttles fetches. [`JwksCache::refresh`] is a no-op while the last
    /// successful refresh is younger than half of it (in whole seconds). Intervals under two
    /// seconds therefore disable throttling.
    pub fn with_refresh_interval(mut self, interval: Duration) -> Self {
        self.refresh_interval = interval;
        self
    }

    /// Limits how many keys from one JWKS response are cached; values below 1 are raised to 1.
    pub fn with_max_keys(mut self, max_keys: usize) -> Self {
        self.max_keys = max_keys.max(1);
        self
    }

    /// Returns the decoding key for `kid`, refreshing the key set once on a cache miss.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`JwksCache::refresh`]. Returns `TokenError::KeyNotFound`
    /// when `kid` is still unknown afterwards, including when the refresh was skipped by the
    /// throttle.
    pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, TokenError> {
        if let Some(entry) = self.keys.get(kid) {
            return Ok(entry.value().clone());
        }
        self.refresh().await?;
        self.keys
            .get(kid)
            .map(|entry| entry.value().clone())
            .ok_or_else(|| TokenError::KeyNotFound(kid.to_string()))
    }

    /// Fetches the JWKS and replaces the cached keys.
    ///
    /// Returns `Ok(())` without fetching when the last successful refresh is younger than half
    /// of the refresh interval. Concurrent callers are serialized on an internal lock. A caller
    /// that waited re-checks the throttle, so a burst of misses causes a single fetch.
    ///
    /// A key is cached only if it has a `kid`, is not marked `"use": "enc"`, and converts to a
    /// [`DecodingKey`], up to the configured key limit. Kids missing from the new response
    /// are evicted.
    ///
    /// # Errors
    ///
    /// Returns an error when:
    /// - the request fails;
    /// - the endpoint answers with an error status (`TokenError::JwksFetchError`);
    /// - the body is not a valid JWKS (`TokenError::JwksParseError`).
    ///
    /// On error the previously cached keys are left untouched.
    pub async fn refresh(&self) -> Result<(), TokenError> {
        if self.recently_refreshed(Self::unix_now()) {
            return Ok(());
        }
        let _refresh_guard = self.refresh_lock.lock().await;
        // Read the clock again: another task may have finished a refresh while we waited on
        // the lock, storing a timestamp later than any `now` read before we queued.
        let now = Self::unix_now();
        if self.recently_refreshed(now) {
            return Ok(());
        }

        tracing::debug!(jwks.uri = %self.jwks_uri, "fetching jwks");
        let response = self
            .client
            .get(&self.jwks_uri)
            .send()
            .await?
            .error_for_status()
            .map_err(|e| TokenError::JwksFetchError(e.to_string()))?;
        let jwks: Jwks =
            response.json().await.map_err(|e| TokenError::JwksParseError(e.to_string()))?;

        let total_keys = jwks.keys.len();
        let (fresh, dropped) = self.decode_keys(jwks.keys);
        let cached = fresh.len();
        // Evict stale kids first, then upsert, instead of clear-then-insert. Kids present in
        // both the old and new sets never disappear from the map, so a concurrent `get_key`
        // cannot observe a transient miss for a still-valid key.
        self.keys.retain(|kid, _| fresh.contains_key(kid));
        for (kid, key) in fresh {
            self.keys.insert(kid, key);
        }
        if dropped > 0 {
            tracing::warn!(
                jwks.uri = %self.jwks_uri,
                jwks.keys_total = total_keys,
                jwks.keys_cached = cached,
                "jwks response exceeded cache key limit"
            );
        }
        self.last_refresh.store(now, Ordering::Relaxed);
        tracing::debug!(jwks.keys_cached = self.keys.len(), "jwks cache refreshed");
        Ok(())
    }

    /// Number of keys currently cached.
    pub fn len(&self) -> usize {
        self.keys.len()
    }

    /// Whether no keys are currently cached.
    pub fn is_empty(&self) -> bool {
        self.keys.is_empty()
    }

    /// Current Unix time in whole seconds; a clock set before the epoch reads as `0`.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs())
    }

    /// Whether a successful refresh happened less than half a refresh interval before `now`.
    fn recently_refreshed(&self, now: u64) -> bool {
        let last = self.last_refresh.load(Ordering::Relaxed);
        // `last` can exceed `now` (clock stepped back, or a racing refresh stored a later
        // time); saturate instead of underflowing, which panics in debug builds.
        last > 0 && now.saturating_sub(last) < self.refresh_interval.as_secs() / 2
    }

    /// Converts usable JWKs into decoding keys, keeping at most `max_keys` distinct kids.
    ///
    /// Returns the kept keys and the number of otherwise-usable keys rejected by the limit.
    /// A later entry with a duplicate `kid` replaces the earlier one.
    fn decode_keys(
        &self,
        entries: Vec<Jwk>,
    ) -> (std::collections::HashMap<String, DecodingKey>, usize) {
        let mut decoded: std::collections::HashMap<String, DecodingKey> =
            std::collections::HashMap::new();
        let mut dropped = 0usize;
        for jwk in entries {
            // RFC 7517 §4.2: a key published for encryption must not verify signatures.
            if jwk.use_.as_deref() == Some("enc") {
                continue;
            }
            let kid = match jwk.kid.as_ref() {
                Some(kid) => kid,
                None => continue,
            };
            let key = match jwk.to_decoding_key() {
                Ok(key) => key,
                Err(_) => {
                    tracing::debug!(jwks.kid = %kid, jwk.kty = %jwk.kty, "skipping unusable jwk");
                    continue;
                }
            };
            // Apply the limit to usable keys only, so unusable entries earlier in the
            // document cannot crowd out valid ones.
            if decoded.len() >= self.max_keys && !decoded.contains_key(kid) {
                dropped += 1;
                continue;
            }
            decoded.insert(kid.clone(), key);
        }
        (decoded, dropped)
    }
}
#[cfg(feature = "sso")]
/// Top-level JWK Set document, i.e. `{ "keys": [ ... ] }`.
#[derive(serde::Deserialize, Debug)]
struct Jwks {
    /// Every key listed in the response, in document order.
    keys: Vec<Jwk>,
}
#[cfg(feature = "sso")]
/// A single JSON Web Key as it appears in a JWKS `keys` array.
///
/// Only the members declared below are captured. Other members are ignored, because serde
/// does not reject unknown fields unless `deny_unknown_fields` is set.
#[derive(Debug, serde::Deserialize)]
// Not every deserialized member (e.g. `alg`, `crv`) is read by code in this module; they are
// kept so the struct mirrors the JWK and shows up in `Debug` output.
#[allow(dead_code)]
struct Jwk {
    /// Key type (`kty`); `to_decoding_key` handles "RSA" and "EC".
    kty: String,
    /// Key id (`kid`) used as the cache key; keys without one are not cached.
    kid: Option<String>,
    /// Algorithm hint (`alg`).
    alg: Option<String>,
    /// Public key use (`use`); renamed because `use` is a Rust keyword.
    #[serde(rename = "use")]
    use_: Option<String>,
    /// RSA modulus (`n`), passed to `DecodingKey::from_rsa_components`.
    n: Option<String>,
    /// RSA public exponent (`e`), passed to `DecodingKey::from_rsa_components`.
    e: Option<String>,
    /// EC x coordinate (`x`), passed to `DecodingKey::from_ec_components`.
    x: Option<String>,
    /// EC y coordinate (`y`), passed to `DecodingKey::from_ec_components`.
    y: Option<String>,
    /// EC curve name (`crv`); not consulted when building the decoding key.
    crv: Option<String>,
}
#[cfg(feature = "sso")]
impl Jwk {
    /// Builds a `jsonwebtoken` [`DecodingKey`] from this JWK's public components.
    ///
    /// # Errors
    ///
    /// - `TokenError::JwksParseError` if a required component (`n`/`e` for RSA, `x`/`y` for
    ///   EC) is missing, or if `jsonwebtoken` rejects the component values.
    /// - `TokenError::UnsupportedAlgorithm` carrying `kty` for any key type other than
    ///   "RSA" or "EC".
    fn to_decoding_key(&self) -> Result<DecodingKey, TokenError> {
        /// Borrows a required component, or reports `missing` as a parse error.
        fn required<'a>(
            component: &'a Option<String>,
            missing: &'static str,
        ) -> Result<&'a str, TokenError> {
            component.as_deref().ok_or_else(|| TokenError::JwksParseError(missing.into()))
        }

        let built = if self.kty == "RSA" {
            let modulus = required(&self.n, "Missing 'n' in RSA key")?;
            let exponent = required(&self.e, "Missing 'e' in RSA key")?;
            DecodingKey::from_rsa_components(modulus, exponent)
        } else if self.kty == "EC" {
            let x_coord = required(&self.x, "Missing 'x' in EC key")?;
            let y_coord = required(&self.y, "Missing 'y' in EC key")?;
            DecodingKey::from_ec_components(x_coord, y_coord)
        } else {
            return Err(TokenError::UnsupportedAlgorithm(self.kty.clone()));
        };

        built.map_err(|err| TokenError::JwksParseError(err.to_string()))
    }
}
/// Stateless placeholder used when the `sso` feature is disabled; it never fetches keys.
#[cfg(not(feature = "sso"))]
pub struct JwksCache;

#[cfg(not(feature = "sso"))]
impl JwksCache {
    /// Mirrors the `sso`-enabled constructor's signature; the URI is accepted and discarded.
    pub fn new(_jwks_uri: impl Into<String>) -> Self {
        JwksCache
    }
}