use std::collections::HashMap;
use std::path::Path;
use std::sync::RwLock;
use std::time::Instant;
use super::key::VerificationKey;
struct ProviderKeys {
keys: HashMap<String, VerificationKey>,
fetched_at: Instant,
last_refetch_attempt: Instant,
}
pub struct JwksCache {
providers: RwLock<HashMap<String, ProviderKeys>>,
disk_path: Option<String>,
}
impl JwksCache {
pub fn new(disk_path: Option<String>) -> Self {
Self {
providers: RwLock::new(HashMap::new()),
disk_path,
}
}
pub fn get(&self, provider: &str, kid: &str) -> Option<VerificationKey> {
let providers = self.providers.read().unwrap_or_else(|p| p.into_inner());
providers
.get(provider)
.and_then(|pk| pk.keys.get(kid).cloned())
}
pub fn has_provider(&self, provider: &str) -> bool {
let providers = self.providers.read().unwrap_or_else(|p| p.into_inner());
providers.contains_key(provider)
}
pub fn update_provider(&self, provider: &str, keys: Vec<VerificationKey>) {
let now = Instant::now();
let key_map: HashMap<String, VerificationKey> =
keys.into_iter().map(|k| (k.kid.clone(), k)).collect();
let mut providers = self.providers.write().unwrap_or_else(|p| p.into_inner());
providers.insert(
provider.to_string(),
ProviderKeys {
keys: key_map,
fetched_at: now,
last_refetch_attempt: now,
},
);
if let Some(ref path) = self.disk_path {
let _ = self.persist_to_disk(path, &providers);
}
}
pub fn can_refetch(&self, provider: &str, min_interval_secs: u64) -> bool {
let providers = self.providers.read().unwrap_or_else(|p| p.into_inner());
match providers.get(provider) {
Some(pk) => pk.last_refetch_attempt.elapsed().as_secs() >= min_interval_secs,
None => true, }
}
pub fn mark_refetch_attempted(&self, provider: &str) {
let mut providers = self.providers.write().unwrap_or_else(|p| p.into_inner());
if let Some(pk) = providers.get_mut(provider) {
pk.last_refetch_attempt = Instant::now();
}
}
pub fn is_stale(&self, provider: &str, max_age_secs: u64) -> bool {
let providers = self.providers.read().unwrap_or_else(|p| p.into_inner());
match providers.get(provider) {
Some(pk) => pk.fetched_at.elapsed().as_secs() >= max_age_secs,
None => true,
}
}
pub fn load_from_disk(&self) {
let Some(ref path) = self.disk_path else {
return;
};
let disk_path = Path::new(path);
if !disk_path.exists() {
return;
}
let data = match std::fs::read_to_string(disk_path) {
Ok(d) => d,
Err(e) => {
tracing::warn!(path = %path, error = %e, "failed to read JWKS disk cache");
return;
}
};
let disk_cache: HashMap<String, Vec<DiskKeyEntry>> = match serde_json::from_str(&data) {
Ok(c) => c,
Err(e) => {
tracing::warn!(error = %e, "failed to parse JWKS disk cache");
return;
}
};
let now = Instant::now();
let mut providers = self.providers.write().unwrap_or_else(|p| p.into_inner());
for (provider_name, entries) in disk_cache {
let key_map: HashMap<String, VerificationKey> = entries
.into_iter()
.filter_map(|e| e.to_verification_key())
.map(|k| (k.kid.clone(), k))
.collect();
if !key_map.is_empty() {
tracing::info!(
provider = %provider_name,
keys = key_map.len(),
"loaded JWKS keys from disk cache"
);
providers.insert(
provider_name,
ProviderKeys {
keys: key_map,
fetched_at: now,
last_refetch_attempt: now,
},
);
}
}
}
fn persist_to_disk(
&self,
path: &str,
providers: &HashMap<String, ProviderKeys>,
) -> Result<(), std::io::Error> {
let disk_cache: HashMap<String, Vec<DiskKeyEntry>> = providers
.iter()
.map(|(name, pk)| {
let entries: Vec<DiskKeyEntry> = pk
.keys
.values()
.map(DiskKeyEntry::from_verification_key)
.collect();
(name.clone(), entries)
})
.collect();
let json = serde_json::to_string_pretty(&disk_cache).map_err(std::io::Error::other)?;
let disk_path = Path::new(path);
if let Some(parent) = disk_path.parent() {
std::fs::create_dir_all(parent)?;
}
let tmp = format!("{path}.tmp");
std::fs::write(&tmp, json)?;
std::fs::rename(&tmp, disk_path)?;
Ok(())
}
}
#[derive(Debug, serde::Serialize, serde::Deserialize)]
struct DiskKeyEntry {
kid: String,
algorithm: String,
key_type: String,
key_data: String,
}
impl DiskKeyEntry {
fn from_verification_key(key: &VerificationKey) -> Self {
use base64::Engine;
let (key_type, key_data) = match &key.key_type {
super::key::KeyType::Rsa(der) => (
"rsa".into(),
base64::engine::general_purpose::STANDARD.encode(der),
),
super::key::KeyType::EcP256(point) => (
"ec_p256".into(),
base64::engine::general_purpose::STANDARD.encode(point),
),
super::key::KeyType::EcP384(point) => (
"ec_p384".into(),
base64::engine::general_purpose::STANDARD.encode(point),
),
};
Self {
kid: key.kid.clone(),
algorithm: key.algorithm.clone(),
key_type,
key_data,
}
}
fn to_verification_key(&self) -> Option<VerificationKey> {
use base64::Engine;
let bytes = base64::engine::general_purpose::STANDARD
.decode(&self.key_data)
.ok()?;
let key_type = match self.key_type.as_str() {
"rsa" => super::key::KeyType::Rsa(bytes),
"ec_p256" => super::key::KeyType::EcP256(bytes),
"ec_p384" => super::key::KeyType::EcP384(bytes),
_ => return None,
};
Some(VerificationKey {
kid: self.kid.clone(),
algorithm: self.algorithm.clone(),
key_type,
})
}
}
#[cfg(test)]
mod tests {
use super::super::key::{KeyType, VerificationKey};
use super::*;
fn test_key(kid: &str) -> VerificationKey {
VerificationKey {
kid: kid.into(),
algorithm: "RS256".into(),
key_type: KeyType::Rsa(vec![0x30, 0x82, 0x01]),
}
}
#[test]
fn cache_get_and_update() {
let cache = JwksCache::new(None);
assert!(cache.get("provider1", "key1").is_none());
cache.update_provider("provider1", vec![test_key("key1"), test_key("key2")]);
assert!(cache.get("provider1", "key1").is_some());
assert!(cache.get("provider1", "key2").is_some());
assert!(cache.get("provider1", "key3").is_none());
assert!(cache.get("other", "key1").is_none());
}
#[test]
fn refetch_rate_limiting() {
let cache = JwksCache::new(None);
cache.update_provider("p1", vec![test_key("k1")]);
assert!(!cache.can_refetch("p1", 60));
assert!(cache.can_refetch("p1", 0));
assert!(cache.can_refetch("unknown", 60));
}
#[test]
fn staleness_check() {
let cache = JwksCache::new(None);
assert!(cache.is_stale("p1", 3600));
cache.update_provider("p1", vec![test_key("k1")]);
assert!(!cache.is_stale("p1", 3600)); assert!(cache.is_stale("p1", 0)); }
#[test]
fn disk_roundtrip() {
let key = test_key("disk-key-1");
let disk = DiskKeyEntry::from_verification_key(&key);
let restored = disk.to_verification_key().unwrap();
assert_eq!(restored.kid, "disk-key-1");
assert_eq!(restored.algorithm, "RS256");
}
}