use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use dashmap::DashMap;
use jsonwebtoken::DecodingKey;
use serde::Deserialize;
use tokio::sync::{Mutex, RwLock};
use tracing::{debug, warn};
/// How long a `kid` that was still missing after a JWKS refresh is rejected
/// outright by `JwksClient::get_key`, without triggering another fetch.
const NEGATIVE_CACHE_TTL: Duration = Duration::from_millis(30_000);
/// A JWKS document as returned by the endpoint: a JSON object whose `keys`
/// array lists the individual JSON Web Keys.
#[derive(Debug, Deserialize)]
pub struct JwksResponse {
/// Every key listed by the endpoint, before any filtering by `use` or `kty`.
pub keys: Vec<JsonWebKey>,
}
/// One entry of a JWKS `keys` array.
///
/// Only the members this client inspects are modelled; other JWK members in
/// the document are ignored during deserialization.
#[derive(Debug, Deserialize)]
pub struct JsonWebKey {
/// Key ID (`"kid"`). `JwksClient::refresh` caches keys without one under the
/// id `"default"`.
pub kid: Option<String>,
/// Key type (`"kty"`). Only `"RSA"` is turned into a decoding key.
pub kty: String,
/// Declared algorithm (`"alg"`); not read by `JwksClient` when building keys.
pub alg: Option<String>,
/// Intended key use, from the JSON member `"use"` (a Rust keyword, hence the
/// rename). Keys whose `use` is present and not `"sig"` are skipped.
#[serde(rename = "use")]
pub key_use: Option<String>,
/// RSA modulus; used together with `e` when `x5c` is absent or empty.
pub n: Option<String>,
/// RSA public exponent; see `n`.
pub e: Option<String>,
/// Certificate chain. When it has at least one entry, the first one is wrapped
/// in PEM `CERTIFICATE` armor and takes precedence over `n`/`e`.
pub x5c: Option<Vec<String>>,
}
/// A parsed key set together with the moment it was stored.
struct CachedJwks {
/// Decoding keys indexed by `kid`.
keys: HashMap<String, DecodingKey>,
/// When this key set was stored; compared against the client's `cache_ttl`.
fetched_at: Instant,
}
/// Fetches a JSON Web Key Set over HTTP and caches the resulting decoding keys.
pub struct JwksClient {
/// JWKS endpoint URL.
url: String,
/// HTTP client used for fetches (built with a 10-second timeout in `new`).
http_client: reqwest::Client,
/// Most recently fetched key set, or `None` before the first successful fetch.
cache: Arc<RwLock<Option<CachedJwks>>>,
/// How long a fetched key set is served before a lookup fetches it again.
cache_ttl: Duration,
/// Held across the freshness re-check and fetch in `refresh_if_needed`,
/// serializing the fetches that go through that path.
refresh_lock: Arc<Mutex<()>>,
/// `kid`s found missing after a refresh, mapped to when that happened; lookups
/// for them within `NEGATIVE_CACHE_TTL` fail without fetching.
negative_cache: Arc<DashMap<String, Instant>>,
}
impl std::fmt::Debug for JwksClient {
    /// Shows only the endpoint URL and cache TTL; the HTTP client, cached keys,
    /// locks and negative cache are elided and rendered as `..`.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = formatter.debug_struct("JwksClient");
        out.field("url", &self.url);
        out.field("cache_ttl", &self.cache_ttl);
        out.finish_non_exhaustive()
    }
}
impl JwksClient {
pub fn new(url: String, cache_ttl_secs: u64) -> Result<Self, JwksError> {
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.map_err(|e| JwksError::HttpClientError(e.to_string()))?;
Ok(Self {
url,
http_client,
cache: Arc::new(RwLock::new(None)),
cache_ttl: Duration::from_secs(cache_ttl_secs),
refresh_lock: Arc::new(Mutex::new(())),
negative_cache: Arc::new(DashMap::new()),
})
}
pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, JwksError> {
{
let cache = self.cache.read().await;
if let Some(ref cached) = *cache
&& cached.fetched_at.elapsed() < self.cache_ttl
&& let Some(key) = cached.keys.get(kid)
{
debug!(kid = %kid, "Using cached JWKS key");
return Ok(key.clone());
}
}
if let Some(entry) = self.negative_cache.get(kid) {
if entry.value().elapsed() < NEGATIVE_CACHE_TTL {
return Err(JwksError::KeyNotFound(kid.to_string()));
}
drop(entry);
self.negative_cache.remove(kid);
}
debug!(kid = %kid, "JWKS cache miss, refreshing");
self.refresh_if_needed().await?;
let cache = self.cache.read().await;
if let Some(ref cached) = *cache {
match cached.keys.get(kid).cloned() {
Some(key) => Ok(key),
None => {
drop(cache);
self.negative_cache.insert(kid.to_string(), Instant::now());
Err(JwksError::KeyNotFound(kid.to_string()))
}
}
} else {
Err(JwksError::FetchFailed(
"Cache empty after refresh".to_string(),
))
}
}
async fn refresh_if_needed(&self) -> Result<(), JwksError> {
let _guard = self.refresh_lock.lock().await;
{
let cache = self.cache.read().await;
if let Some(ref cached) = *cache
&& cached.fetched_at.elapsed() < self.cache_ttl
{
return Ok(());
}
}
self.refresh().await
}
pub async fn get_any_key(&self) -> Result<DecodingKey, JwksError> {
{
let cache = self.cache.read().await;
if let Some(ref cached) = *cache
&& cached.fetched_at.elapsed() < self.cache_ttl
&& let Some(key) = cached.keys.values().next()
{
debug!("Using first cached JWKS key (no kid specified)");
return Ok(key.clone());
}
}
debug!("JWKS cache miss for any key, refreshing");
self.refresh_if_needed().await?;
let cache = self.cache.read().await;
if let Some(ref cached) = *cache {
cached
.keys
.values()
.next()
.cloned()
.ok_or(JwksError::NoKeysAvailable)
} else {
Err(JwksError::FetchFailed("No keys in JWKS".to_string()))
}
}
pub async fn refresh(&self) -> Result<(), JwksError> {
debug!(url = %self.url, "Fetching JWKS");
let response = self
.http_client
.get(&self.url)
.send()
.await
.map_err(|e| JwksError::FetchFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(JwksError::FetchFailed(format!(
"HTTP {} from JWKS endpoint",
response.status()
)));
}
let jwks: JwksResponse = response
.json()
.await
.map_err(|e| JwksError::ParseFailed(e.to_string()))?;
let mut keys = HashMap::new();
for jwk in jwks.keys {
if let Some(ref key_use) = jwk.key_use
&& key_use != "sig"
{
continue;
}
let kid = jwk.kid.clone().unwrap_or_else(|| "default".to_string());
match self.parse_jwk(&jwk) {
Ok(Some(key)) => {
debug!(kid = %kid, kty = %jwk.kty, "Parsed JWKS key");
keys.insert(kid, key);
}
Ok(None) => {
debug!(kid = %kid, kty = %jwk.kty, "Skipping unsupported key type");
}
Err(e) => {
warn!(kid = %kid, error = %e, "Failed to parse JWKS key");
}
}
}
if keys.is_empty() {
return Err(JwksError::NoKeysAvailable);
}
debug!(count = keys.len(), "Cached JWKS keys");
for kid in keys.keys() {
self.negative_cache.remove(kid);
}
let mut cache = self.cache.write().await;
*cache = Some(CachedJwks {
keys,
fetched_at: Instant::now(),
});
Ok(())
}
fn parse_jwk(&self, jwk: &JsonWebKey) -> Result<Option<DecodingKey>, JwksError> {
match jwk.kty.as_str() {
"RSA" => {
if let Some(ref x5c) = jwk.x5c
&& let Some(cert) = x5c.first()
{
let pem = format!(
"-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----",
cert
);
return DecodingKey::from_rsa_pem(pem.as_bytes()).map(Some).map_err(
|e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
);
}
if let (Some(n), Some(e)) = (&jwk.n, &jwk.e) {
return DecodingKey::from_rsa_components(n, e).map(Some).map_err(
|e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
);
}
Ok(None)
}
_ => {
Ok(None)
}
}
}
pub fn url(&self) -> &str {
&self.url
}
}
/// Errors returned by [`JwksClient`].
#[derive(Debug, thiserror::Error)]
pub enum JwksError {
/// The JWKS request failed, returned a non-success HTTP status, or left no
/// cached key set behind.
#[error("Failed to fetch JWKS: {0}")]
FetchFailed(String),
/// The response body could not be deserialized as a JWKS document.
#[error("Failed to parse JWKS: {0}")]
ParseFailed(String),
/// An individual key's material was rejected while building a decoding key.
#[error("Failed to parse key: {0}")]
KeyParseFailed(String),
/// No available key matches the requested `kid`, which is the payload.
#[error("Key not found: {0}")]
KeyNotFound(String),
/// The JWKS document yielded no usable signing keys, or the cached set is empty.
#[error("No keys available in JWKS")]
NoKeysAvailable,
/// The underlying HTTP client could not be constructed.
#[error("Failed to create HTTP client: {0}")]
HttpClientError(String),
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Base64url-encoded RSA modulus shared by the key-parsing tests.
    const RSA_MODULUS: &str = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw";

    /// Builds a client whose cache already holds one placeholder key per entry
    /// of `kids`, stamped "now", so lookups are answered without network access.
    async fn client_with_cached_kids(kids: &[&str]) -> JwksClient {
        let client = JwksClient::new("http://example.invalid".into(), 3600).unwrap();
        let keys = kids
            .iter()
            .map(|kid| (String::from(*kid), DecodingKey::from_secret(b"placeholder")))
            .collect();
        *client.cache.write().await = Some(CachedJwks {
            keys,
            fetched_at: Instant::now(),
        });
        client
    }

    #[test]
    fn test_parse_jwk_with_n_e() {
        let client = JwksClient::new("http://example.com".to_string(), 3600).unwrap();
        let jwk = JsonWebKey {
            kid: Some("test-key".to_string()),
            kty: "RSA".to_string(),
            alg: Some("RS256".to_string()),
            key_use: Some("sig".to_string()),
            n: Some(RSA_MODULUS.to_string()),
            e: Some("AQAB".to_string()),
            x5c: None,
        };
        let result = client.parse_jwk(&jwk);
        assert!(result.is_ok());
        assert!(result.unwrap().is_some());
    }

    #[test]
    fn test_parse_jwk_unsupported_type() {
        let client = JwksClient::new("http://example.com".to_string(), 3600).unwrap();
        let jwk = JsonWebKey {
            kid: Some("test-key".to_string()),
            kty: "EC".to_string(),
            alg: Some("ES256".to_string()),
            key_use: Some("sig".to_string()),
            n: None,
            e: None,
            x5c: None,
        };
        let result = client.parse_jwk(&jwk);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn test_parse_jwk_missing_components() {
        let client = JwksClient::new("http://example.com".to_string(), 3600).unwrap();
        let jwk = JsonWebKey {
            kid: Some("test-key".to_string()),
            kty: "RSA".to_string(),
            alg: Some("RS256".to_string()),
            key_use: Some("sig".to_string()),
            n: None,
            e: None,
            x5c: None,
        };
        let result = client.parse_jwk(&jwk);
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }

    #[test]
    fn jwks_client_exposes_configured_url() {
        let client =
            JwksClient::new("https://issuer.example.com/.well-known/jwks".into(), 60).unwrap();
        assert_eq!(client.url(), "https://issuer.example.com/.well-known/jwks");
    }

    #[test]
    fn debug_output_lists_only_url_and_ttl() {
        let client = JwksClient::new("http://example.com".into(), 60).unwrap();
        assert_eq!(
            format!("{client:?}"),
            r#"JwksClient { url: "http://example.com", cache_ttl: 60s, .. }"#
        );
    }

    #[test]
    fn parse_jwk_returns_none_for_non_rsa_kty() {
        let client = JwksClient::new("http://example.com".into(), 60).unwrap();
        let jwk = JsonWebKey {
            kid: Some("sym".into()),
            kty: "oct".into(),
            alg: None,
            key_use: Some("sig".into()),
            n: None,
            e: None,
            x5c: None,
        };
        assert!(client.parse_jwk(&jwk).unwrap().is_none());
    }

    #[test]
    fn parse_jwk_returns_none_when_only_modulus_present() {
        let client = JwksClient::new("http://example.com".into(), 60).unwrap();
        let jwk = JsonWebKey {
            kid: Some("partial".into()),
            kty: "RSA".into(),
            alg: Some("RS256".into()),
            key_use: Some("sig".into()),
            n: Some("AQAB".into()),
            e: None,
            x5c: None,
        };
        assert!(client.parse_jwk(&jwk).unwrap().is_none());
    }

    #[test]
    fn parse_jwk_empty_x5c_falls_back_to_n_e() {
        let client = JwksClient::new("http://example.com".into(), 60).unwrap();
        let jwk = JsonWebKey {
            kid: Some("empty-x5c".into()),
            kty: "RSA".into(),
            alg: Some("RS256".into()),
            key_use: Some("sig".into()),
            n: Some(RSA_MODULUS.into()),
            e: Some("AQAB".into()),
            x5c: Some(Vec::new()),
        };
        assert!(client.parse_jwk(&jwk).unwrap().is_some());
    }

    #[test]
    fn parse_jwk_x5c_takes_precedence_over_n_e_and_fails_loudly_on_bad_cert() {
        let client = JwksClient::new("http://example.com".into(), 60).unwrap();
        let jwk = JsonWebKey {
            kid: Some("bad-x5c".into()),
            kty: "RSA".into(),
            alg: Some("RS256".into()),
            key_use: Some("sig".into()),
            n: Some(RSA_MODULUS.into()),
            e: Some("AQAB".into()),
            x5c: Some(vec!["not-a-real-cert".into()]),
        };
        let err = client.parse_jwk(&jwk).unwrap_err();
        assert!(matches!(err, JwksError::KeyParseFailed(_)), "got {err:?}");
    }

    #[test]
    fn jwks_error_display_messages_are_descriptive() {
        assert_eq!(
            JwksError::FetchFailed("HTTP 500".into()).to_string(),
            "Failed to fetch JWKS: HTTP 500"
        );
        assert_eq!(
            JwksError::ParseFailed("eof".into()).to_string(),
            "Failed to parse JWKS: eof"
        );
        assert_eq!(
            JwksError::KeyParseFailed("bad PEM".into()).to_string(),
            "Failed to parse key: bad PEM"
        );
        assert_eq!(
            JwksError::KeyNotFound("abc".into()).to_string(),
            "Key not found: abc"
        );
        assert_eq!(
            JwksError::NoKeysAvailable.to_string(),
            "No keys available in JWKS"
        );
        assert_eq!(
            JwksError::HttpClientError("tls".into()).to_string(),
            "Failed to create HTTP client: tls"
        );
    }

    #[tokio::test]
    async fn get_key_returns_cached_match_without_network() {
        let client = client_with_cached_kids(&["kid-1"]).await;
        let got = client.get_key("kid-1").await;
        assert!(got.is_ok());
    }

    #[tokio::test]
    async fn get_key_unknown_kid_with_fresh_cache_is_negatively_cached() {
        let client = client_with_cached_kids(&["kid-1"]).await;
        // The cache was filled moments ago, so the miss is answered without a fetch.
        let err = client.get_key("missing").await.unwrap_err();
        assert!(
            matches!(err, JwksError::KeyNotFound(ref k) if k == "missing"),
            "got {err:?}"
        );
        assert!(client.negative_cache.contains_key("missing"));
        // The second lookup is rejected straight from the negative cache.
        let err = client.get_key("missing").await.unwrap_err();
        assert!(matches!(err, JwksError::KeyNotFound(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn get_any_key_returns_first_cached_when_kid_absent() {
        let client = client_with_cached_kids(&["only"]).await;
        let got = client.get_any_key().await;
        assert!(got.is_ok());
    }
}