use crate::error::AuthError;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, LazyLock};
use tokio::sync::{Mutex, RwLock};
/// Process-wide HTTP client used for JWKS requests, built lazily on first use.
///
/// Each request times out after 5 seconds; the connection pool keeps at most
/// 10 idle connections per host and drops idle connections after 30 seconds.
static HTTP_CLIENT: LazyLock<reqwest::Client> = LazyLock::new(|| {
    use std::time::Duration;

    let request_timeout = Duration::from_secs(5);
    let idle_timeout = Duration::from_secs(30);
    let builder = reqwest::Client::builder()
        .timeout(request_timeout)
        .pool_max_idle_per_host(10)
        .pool_idle_timeout(idle_timeout);
    builder.build().expect("Failed to create HTTP client")
});
/// One JSON Web Key as it appears in a JWK Set.
///
/// Field names mirror the JWK member names. Members not listed here are
/// ignored when deserializing (no `deny_unknown_fields`), and every member
/// other than `kid` and `kty` may be absent.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Jwk {
    /// Key ID (`kid` member).
    ///
    /// RFC 7517 makes `kid` optional. Without `#[serde(default)]`, a single
    /// published key lacking it would make the whole key set fail to parse;
    /// with it, such a key gets an empty string and simply never matches a
    /// non-empty lookup.
    #[serde(default)]
    pub kid: String,
    /// Key type (`kty` member), e.g. `"RSA"` or `"EC"`.
    pub kty: String,
    /// Intended algorithm (`alg` member), if published.
    pub alg: Option<String>,
    /// Public key use (`use` member; renamed because `use` is a Rust keyword).
    #[serde(rename = "use")]
    pub key_use: Option<String>,
    /// Permitted key operations (`key_ops` member), if published.
    pub key_ops: Option<Vec<String>>,
    /// Elliptic curve name (`crv` member); EC keys only.
    pub crv: Option<String>,
    /// EC public point x coordinate (`x` member), base64url text as received.
    pub x: Option<String>,
    /// EC public point y coordinate (`y` member), base64url text as received.
    pub y: Option<String>,
    /// RSA modulus (`n` member), base64url text as received.
    pub n: Option<String>,
    /// RSA public exponent (`e` member), base64url text as received.
    pub e: Option<String>,
    /// `ext` member, if published.
    pub ext: Option<bool>,
}
/// A JWK Set document: the top-level JSON object that carries a `keys` array.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct JwksResponse {
    /// Every key published by the endpoint, in document order.
    pub keys: Vec<Jwk>,
}
/// Seconds a freshly fetched key set is served from cache before a refetch is
/// attempted (24 hours).
const JWKS_CACHE_DURATION: u64 = 24 * 60 * 60;
/// Maximum age, in seconds, of a cached key set that may still be served as a
/// fallback when a refetch fails (7 days).
const JWKS_CACHE_MAX_AGE: u64 = 7 * 24 * 60 * 60;
/// Cache of the key set published at a single JWKS URL.
///
/// All mutable state sits behind `Arc`, so clones of a `JwksCache` share the
/// same cached keys, timestamps, and fetch lock.
#[derive(Clone, Debug)]
pub struct JwksCache {
    /// Most recently fetched key set; `None` until the first successful fetch.
    cache: Arc<RwLock<Option<JwksResponse>>>,
    /// Unix time (seconds) before which `cache` is served without refetching.
    expires_at: Arc<RwLock<Option<u64>>>,
    /// Unix time (seconds) at which `cache` was last filled; used to bound the
    /// age of a stale copy served as a fallback.
    cached_at: Arc<RwLock<Option<u64>>>,
    /// Endpoint the key set is fetched from.
    jwks_url: String,
    /// Held for the duration of a network refresh so only one fetch runs at a time.
    fetch_mutex: Arc<Mutex<()>>,
}
impl JwksCache {
    /// Creates an empty cache that fetches keys from `jwks_url` on first use.
    ///
    /// No network request is made here. A URL that does not start with
    /// `https://` is still accepted, but a warning is logged.
    pub fn new(jwks_url: &str) -> Self {
        if !jwks_url.starts_with("https://") {
            tracing::warn!("JWKS URL should use HTTPS: {}", jwks_url);
        }
        Self {
            cache: Arc::new(RwLock::new(None)),
            expires_at: Arc::new(RwLock::new(None)),
            cached_at: Arc::new(RwLock::new(None)),
            jwks_url: jwks_url.to_string(),
            fetch_mutex: Arc::new(Mutex::new(())),
        }
    }

    /// Returns the current key set.
    ///
    /// Serves the cached set while it is younger than `JWKS_CACHE_DURATION`.
    /// Otherwise it refetches from the endpoint, falling back to a stale cached
    /// set no older than `JWKS_CACHE_MAX_AGE` if the refetch fails.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::JwksError` when the refetch fails and no cached set
    /// within `JWKS_CACHE_MAX_AGE` is available.
    pub async fn get_jwks(&self) -> Result<JwksResponse, AuthError> {
        self.get_jwks_with_fallback().await
    }

    /// Tries, in order: fresh cache, network refetch, stale cache.
    async fn get_jwks_with_fallback(&self) -> Result<JwksResponse, AuthError> {
        if let Some(cached) = self.get_cached_jwks().await {
            tracing::debug!("Using valid cached JWKS data");
            return Ok(cached);
        }
        match self.fetch_fresh_jwks().await {
            Ok(jwks) => {
                tracing::info!("Successfully refreshed JWKS cache");
                Ok(jwks)
            }
            Err(e) => {
                tracing::warn!("Failed to refresh JWKS, attempting fallback: {:?}", e);
                self.get_stale_cache().await
            }
        }
    }

    /// Current wall-clock time in whole seconds since the Unix epoch.
    ///
    /// A clock reading before the epoch yields 0 rather than panicking, so a
    /// misconfigured clock cannot crash the authentication path.
    fn unix_now() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |elapsed| elapsed.as_secs())
    }

    /// Returns a clone of the cached key set if its expiry is still in the future.
    async fn get_cached_jwks(&self) -> Option<JwksResponse> {
        let now = Self::unix_now();
        let expires_at = (*self.expires_at.read().await)?;
        if now < expires_at {
            self.cache.read().await.clone()
        } else {
            None
        }
    }

    /// Returns the cached key set even if expired, provided it was filled no
    /// more than `JWKS_CACHE_MAX_AGE` seconds ago.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::JwksError` when nothing was ever cached or the
    /// cached copy is older than `JWKS_CACHE_MAX_AGE`.
    async fn get_stale_cache(&self) -> Result<JwksResponse, AuthError> {
        let now = Self::unix_now();
        let cached_at = *self.cached_at.read().await;
        if let Some(cache_time) = cached_at {
            // saturating_sub: if the wall clock stepped backwards since the cache
            // was filled, treat the age as 0. Plain `-` would panic in debug
            // builds and wrap to a huge age (rejecting valid keys) in release.
            let age = now.saturating_sub(cache_time);
            if age <= JWKS_CACHE_MAX_AGE {
                if let Some(cached) = self.cache.read().await.clone() {
                    tracing::warn!(
                        "Using stale JWKS cache as fallback (age: {} hours)",
                        age / 3600
                    );
                    return Ok(cached);
                }
            }
        }
        let error_msg = "No valid JWKS cache available and network fetch failed";
        tracing::error!("{}", error_msg);
        Err(AuthError::JwksError(error_msg.to_string()))
    }

    /// Fetches the key set from the endpoint and stores it in the cache.
    ///
    /// Holds `fetch_mutex` for the whole request so concurrent callers wait for
    /// one fetch. After acquiring it, the method re-checks the cache in case
    /// another task refreshed it in the meantime.
    ///
    /// NOTE(review): while the endpoint is unreachable, each waiter still issues
    /// its own request in turn (each bounded by `HTTP_CLIENT`'s 5 s timeout), so
    /// callers can queue up behind the lock for the duration of an outage.
    ///
    /// # Errors
    ///
    /// Returns `AuthError::JwksError` on a transport error, a non-success HTTP
    /// status, an unparseable body, or a key set with no keys.
    async fn fetch_fresh_jwks(&self) -> Result<JwksResponse, AuthError> {
        let _fetch_guard = self.fetch_mutex.lock().await;
        if let Some(cached) = self.get_cached_jwks().await {
            tracing::debug!("JWKS cache was updated while waiting for lock");
            return Ok(cached);
        }
        tracing::info!("Fetching fresh JWKS from: {}", self.jwks_url);
        let response = HTTP_CLIENT.get(&self.jwks_url).send().await.map_err(|e| {
            let error_msg = format!("Failed to fetch JWKS: {e:?}");
            tracing::error!("{}", error_msg);
            AuthError::JwksError(error_msg)
        })?;
        if !response.status().is_success() {
            let error_msg = format!("JWKS endpoint returned status: {}", response.status());
            tracing::error!("{}", error_msg);
            return Err(AuthError::JwksError(error_msg));
        }
        let jwks: JwksResponse = response.json().await.map_err(|e| {
            let error_msg = format!("Failed to parse JWKS response: {e:?}");
            tracing::error!("{}", error_msg);
            AuthError::JwksError(error_msg)
        })?;
        if jwks.keys.is_empty() {
            let error_msg = "JWKS response contains no keys";
            tracing::error!("{}", error_msg);
            return Err(AuthError::JwksError(error_msg.to_string()));
        }
        // Stamp the entry once the keys have actually arrived, so the TTL is not
        // shortened by however long the request took.
        let now = Self::unix_now();
        let expires = now + JWKS_CACHE_DURATION;
        // Write `cache` before `expires_at`: `get_cached_jwks` reads the expiry
        // first, so a reader that observes the new expiry also observes the new keys.
        *self.cache.write().await = Some(jwks.clone());
        *self.expires_at.write().await = Some(expires);
        *self.cached_at.write().await = Some(now);
        tracing::info!(
            "JWKS cache updated, expires at: {} (cached at: {})",
            expires,
            now
        );
        Ok(jwks)
    }

    /// Looks up the key whose `kid` equals `kid` in the current key set.
    ///
    /// NOTE(review): a miss does not trigger a refetch, so a newly rotated key
    /// is not found until the cached set expires (`JWKS_CACHE_DURATION`).
    ///
    /// # Errors
    ///
    /// Propagates any error from `get_jwks`, and returns
    /// `AuthError::NoMatchingKey` when no key carries that `kid`.
    pub async fn find_key(&self, kid: &str) -> Result<Jwk, AuthError> {
        let jwks = self.get_jwks().await?;
        // `jwks` is already an owned copy, so move the match out instead of cloning it.
        jwks.keys
            .into_iter()
            .find(|key| key.kid == kid)
            .ok_or_else(|| {
                tracing::warn!("Key with kid '{}' not found in JWKS", kid);
                AuthError::NoMatchingKey
            })
    }
}