use std::{
collections::HashMap,
sync::RwLock,
time::{Duration, Instant},
};
use jsonwebtoken::DecodingKey;
use serde::Deserialize;
use tracing::debug;
/// Top-level body of a JWKS response: `{ "keys": [ ... ] }`.
///
/// Only the `keys` member is read; serde ignores any other top-level members
/// because unknown fields are not denied.
#[derive(Deserialize, Debug)]
struct JwksDocument {
    /// Every key entry published by the endpoint, in document order.
    keys: Vec<JwkKey>,
}
/// A single entry of the JWKS `keys` array.
///
/// Only the members needed to build a verification key are captured; any other
/// JWK members present in the document are ignored during deserialization.
#[derive(Deserialize, Debug)]
struct JwkKey {
    /// Key identifier; entries without one cannot be looked up and are skipped.
    kid: Option<String>,
    /// Key type, e.g. `"RSA"` or `"EC"`.
    kty: String,
    /// RSA modulus (RSA keys only).
    n: Option<String>,
    /// RSA public exponent (RSA keys only).
    e: Option<String>,
    /// Elliptic-curve x coordinate (EC keys only).
    x: Option<String>,
    /// Elliptic-curve y coordinate (EC keys only).
    y: Option<String>,
}
/// Cache of JWKS decoding keys indexed by `kid`, refreshed from a remote JWKS
/// endpoint once its time-to-live has elapsed.
///
/// All mutable state sits behind `RwLock`s, so every operation takes `&self`.
pub struct JwksCache {
    /// Decoded keys from the most recent successful fetch, keyed by `kid`.
    keys: RwLock<HashMap<String, DecodingKey>>,
    /// URL of the JWKS document.
    jwks_uri: String,
    /// When the last successful fetch completed; `None` until the first one.
    last_fetched: RwLock<Option<Instant>>,
    /// How long fetched keys are served before a refresh is attempted.
    ttl: Duration,
    /// HTTP client reused for every fetch.
    client: reqwest::Client,
}
impl JwksCache {
    /// Upper bound on a single JWKS HTTP request, from connecting until the
    /// response body has been read.
    ///
    /// `reqwest::Client::new()` applies no timeout by default, so without this a
    /// hung JWKS endpoint would stall every lookup that needs a refresh.
    const FETCH_TIMEOUT: Duration = Duration::from_secs(10);

    /// Creates an empty cache that loads keys from `jwks_uri`.
    ///
    /// No request is made here. Keys are fetched lazily by [`get_key`](Self::get_key)
    /// or explicitly by [`force_refresh`](Self::force_refresh). They are treated as
    /// stale once `ttl` has elapsed since the last successful fetch; a zero `ttl`
    /// makes every lookup refetch.
    pub fn new(jwks_uri: &str, ttl: Duration) -> Self {
        Self {
            keys: RwLock::new(HashMap::new()),
            jwks_uri: jwks_uri.to_string(),
            last_fetched: RwLock::new(None),
            ttl,
            client: reqwest::Client::new(),
        }
    }

    /// Returns the decoding key for `kid`, fetching the JWKS document if needed.
    ///
    /// A cached key is returned directly while the cache is fresh. The document
    /// is re-fetched when the cache is stale, or when `kid` is not cached, so
    /// keys added to the JWKS since the last fetch can be found before the TTL
    /// expires.
    ///
    /// # Errors
    ///
    /// Returns an error string if a required fetch fails. Causes include:
    /// - a network error, timeout or non-success HTTP status;
    /// - an unparseable body;
    /// - a poisoned lock.
    ///
    /// A stale cached key is not returned when the refresh fails. `Ok(None)`
    /// means the fetch succeeded but no usable key with this `kid` was published.
    pub async fn get_key(&self, kid: &str) -> Result<Option<DecodingKey>, String> {
        if let Some(key) = self.get_key_from_cache(kid) {
            if !self.is_stale() {
                return Ok(Some(key));
            }
        }
        // NOTE(review): every lookup of an unknown `kid` triggers a network fetch.
        // If `kid` comes from untrusted tokens, consider rate-limiting refetches.
        self.fetch_keys().await?;
        Ok(self.get_key_from_cache(kid))
    }

    /// Returns the cached key for `kid` without any network access or staleness
    /// check. Returns `None` when the key is absent or the key lock is poisoned.
    pub fn get_key_from_cache(&self, kid: &str) -> Option<DecodingKey> {
        self.keys.read().ok()?.get(kid).cloned()
    }

    /// Fetches the JWKS document now, regardless of the TTL.
    ///
    /// # Errors
    ///
    /// Same failure modes as [`get_key`](Self::get_key). On error, the previously
    /// cached keys are kept.
    pub async fn force_refresh(&self) -> Result<(), String> {
        self.fetch_keys().await
    }

    /// Returns `true` if any of the following holds:
    /// - no fetch has succeeded yet;
    /// - the timestamp lock is poisoned;
    /// - at least `ttl` has elapsed since the last successful fetch.
    fn is_stale(&self) -> bool {
        self.last_fetched
            .read()
            .ok()
            .and_then(|guard| *guard)
            // `>=` (not `>`) so a zero TTL is always stale, even when the clock
            // has not visibly advanced since the previous fetch.
            .is_none_or(|fetched_at| fetched_at.elapsed() >= self.ttl)
    }

    /// Downloads the JWKS document and replaces the cached key set with it.
    ///
    /// Entries are skipped if they have no `kid`, have an unsupported `kty`, or
    /// lack the components their `kty` requires. The new set is swapped in under
    /// a single write lock, so readers see either the old set or the new one.
    /// If an error is returned before the swap, the cache and its timestamp are
    /// left untouched.
    async fn fetch_keys(&self) -> Result<(), String> {
        debug!(uri = %self.jwks_uri, "Fetching JWKS keys");
        let jwks: JwksDocument = self
            .client
            .get(&self.jwks_uri)
            .timeout(Self::FETCH_TIMEOUT)
            .send()
            .await
            // Reject non-2xx responses up front. Otherwise an error page would
            // surface as a misleading "parse failed" message.
            .and_then(reqwest::Response::error_for_status)
            .map_err(|e| format!("JWKS fetch failed: {e}"))?
            .json()
            .await
            .map_err(|e| format!("JWKS parse failed: {e}"))?;

        // Build the replacement map before taking the write lock, so readers are
        // only blocked for the swap and not for key conversion. If two entries
        // share a `kid`, the later one wins, as with sequential inserts.
        let fresh: HashMap<String, DecodingKey> = jwks
            .keys
            .iter()
            .filter_map(|jwk| {
                let kid = jwk.kid.as_ref()?;
                Self::convert_jwk(jwk).map(|key| (kid.clone(), key))
            })
            .collect();
        let key_count = fresh.len();

        *self.keys.write().map_err(|e| format!("JWKS lock poisoned: {e}"))? = fresh;
        if let Ok(mut last) = self.last_fetched.write() {
            *last = Some(Instant::now());
        }

        debug!(key_count, "JWKS cache refreshed");
        Ok(())
    }

    /// Converts one JWK entry into a [`DecodingKey`].
    ///
    /// Supported key types:
    /// - `kty = "RSA"`, which needs `n` and `e`;
    /// - `kty = "EC"`, which needs `x` and `y`.
    ///
    /// Returns `None` for any other key type, for missing components, or when
    /// `jsonwebtoken` rejects the components.
    /// NOTE(review): the EC curve (`crv`) is not read here, so it is not checked
    /// against the key.
    fn convert_jwk(jwk: &JwkKey) -> Option<DecodingKey> {
        match jwk.kty.as_str() {
            "RSA" => {
                let n = jwk.n.as_ref()?;
                let e = jwk.e.as_ref()?;
                DecodingKey::from_rsa_components(n, e).ok()
            },
            "EC" => {
                let x = jwk.x.as_ref()?;
                let y = jwk.y.as_ref()?;
                DecodingKey::from_ec_components(x, y).ok()
            },
            _ => None,
        }
    }
}
impl std::fmt::Debug for JwksCache {
    /// Prints only the endpoint URI, the TTL and the number of cached keys.
    ///
    /// The keys themselves and the HTTP client are omitted. A poisoned key lock
    /// is reported as zero cached keys rather than as a formatting error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cached_keys = match self.keys.read() {
            Ok(guard) => guard.len(),
            Err(_) => 0,
        };

        let mut out = f.debug_struct("JwksCache");
        out.field("jwks_uri", &self.jwks_uri);
        out.field("ttl", &self.ttl);
        out.field("cached_keys", &cached_keys);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use wiremock::{
        Mock, MockServer, ResponseTemplate,
        matchers::{method, path},
    };

    use super::*;

    /// Path under which every mock server in this module serves its JWKS document.
    const JWKS_PATH: &str = "/.well-known/jwks.json";

    /// A TTL long enough that the cache never goes stale within a single test.
    const LONG_TTL: Duration = Duration::from_secs(3600);

    /// A JWKS document holding one RSA signing key with `kid = "test-key-1"`.
    fn jwks_fixture() -> serde_json::Value {
        serde_json::json!({
            "keys": [
                {
                    "kty": "RSA",
                    "kid": "test-key-1",
                    "use": "sig",
                    "n": "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw",
                    "e": "AQAB"
                }
            ]
        })
    }

    /// A `GET JWKS_PATH` mock that answers with `response`. Callers add any call
    /// expectations and then mount it.
    fn jwks_mock(response: ResponseTemplate) -> Mock {
        Mock::given(method("GET"))
            .and(path(JWKS_PATH))
            .respond_with(response)
    }

    /// A 200 response carrying [`jwks_fixture`].
    fn jwks_ok() -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(jwks_fixture())
    }

    /// Builds a cache that points at `server`'s JWKS endpoint.
    fn cache_for(server: &MockServer, ttl: Duration) -> JwksCache {
        JwksCache::new(&format!("{}{}", server.uri(), JWKS_PATH), ttl)
    }

    #[tokio::test]
    async fn test_jwks_cache_empty() {
        let cache = JwksCache::new("https://example.com/.well-known/jwks.json", LONG_TTL);
        assert!(cache.get_key_from_cache("nonexistent_kid").is_none());
    }

    #[tokio::test]
    async fn test_jwks_cache_fetch_and_retrieve() {
        let server = MockServer::start().await;
        jwks_mock(jwks_ok()).mount(&server).await;
        let cache = cache_for(&server, LONG_TTL);

        let key = cache.get_key("test-key-1").await.unwrap();
        assert!(key.is_some());
    }

    #[tokio::test]
    async fn test_jwks_cache_missing_kid_returns_none() {
        let server = MockServer::start().await;
        jwks_mock(jwks_ok()).mount(&server).await;
        let cache = cache_for(&server, LONG_TTL);

        let key = cache.get_key("nonexistent-kid").await.unwrap();
        assert!(key.is_none());
    }

    #[tokio::test]
    async fn test_jwks_cache_serves_fresh_key_without_refetch() {
        let server = MockServer::start().await;
        // While the TTL has not elapsed, a cached kid must not hit the network again.
        jwks_mock(jwks_ok()).expect(1).mount(&server).await;
        let cache = cache_for(&server, LONG_TTL);

        assert!(cache.get_key("test-key-1").await.unwrap().is_some());
        assert!(cache.get_key("test-key-1").await.unwrap().is_some());
        server.verify().await;
    }

    #[tokio::test]
    async fn test_jwks_cache_ttl_refresh() {
        let server = MockServer::start().await;
        // A zero TTL makes every lookup stale, so both calls must refetch.
        jwks_mock(jwks_ok()).expect(2).mount(&server).await;
        let cache = cache_for(&server, Duration::ZERO);

        assert!(cache.get_key("test-key-1").await.unwrap().is_some());
        assert!(cache.get_key("test-key-1").await.unwrap().is_some());
        server.verify().await;
    }

    #[tokio::test]
    async fn test_jwks_cache_force_refresh() {
        let server = MockServer::start().await;
        jwks_mock(jwks_ok()).mount(&server).await;
        let cache = cache_for(&server, LONG_TTL);

        cache.force_refresh().await.unwrap();
        assert!(cache.get_key_from_cache("test-key-1").is_some());
    }

    #[tokio::test]
    async fn test_jwks_cache_failed_refresh_keeps_previous_keys() {
        let server = MockServer::start().await;
        // First request succeeds; every later request gets a server error.
        jwks_mock(jwks_ok()).up_to_n_times(1).mount(&server).await;
        jwks_mock(ResponseTemplate::new(500)).mount(&server).await;
        let cache = cache_for(&server, LONG_TTL);

        cache.force_refresh().await.unwrap();
        assert!(cache.force_refresh().await.is_err());
        assert!(cache.get_key_from_cache("test-key-1").is_some());
    }

    #[tokio::test]
    async fn test_jwks_cache_network_error() {
        let cache = JwksCache::new("http://127.0.0.1:1/nonexistent", LONG_TTL);
        let result = cache.get_key("any-kid").await;
        assert!(result.is_err());
    }
}