use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use jsonwebtoken::DecodingKey;
use serde::Deserialize;
use tokio::sync::RwLock;
use tracing::{debug, warn};
/// Top-level JWKS document: the JSON object whose `keys` member lists the published keys.
#[derive(Deserialize, Debug)]
pub struct JwksResponse {
    /// Every key entry in the document, in the order the endpoint returned them.
    pub keys: Vec<JsonWebKey>,
}
/// One key entry of a JWKS document.
///
/// Only the members this module reads are modelled; any other JSON members are ignored by
/// serde's default (non-`deny_unknown_fields`) behaviour.
#[derive(Deserialize, Debug)]
pub struct JsonWebKey {
    /// Key ID (`kid`), the identifier `JwksClient::get_key` looks keys up by.
    pub kid: Option<String>,
    /// Key type (`kty`), e.g. `"RSA"`.
    pub kty: String,
    /// Declared algorithm (`alg`); deserialized but not read by `JwksClient`.
    pub alg: Option<String>,
    /// Intended public key use (`use`); `JwksClient` skips entries whose value is not `"sig"`.
    #[serde(rename = "use")]
    pub key_use: Option<String>,
    /// RSA modulus (`n`), base64url-encoded.
    pub n: Option<String>,
    /// RSA public exponent (`e`), base64url-encoded.
    pub e: Option<String>,
    /// X.509 certificate chain (`x5c`); `JwksClient` wraps the first entry in a PEM
    /// certificate envelope.
    pub x5c: Option<Vec<String>>,
}
struct CachedJwks {
keys: HashMap<String, DecodingKey>,
fetched_at: Instant,
}
pub struct JwksClient {
url: String,
http_client: reqwest::Client,
cache: Arc<RwLock<Option<CachedJwks>>>,
cache_ttl: Duration,
}
impl std::fmt::Debug for JwksClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("JwksClient")
.field("url", &self.url)
.field("cache_ttl", &self.cache_ttl)
.finish_non_exhaustive()
}
}
impl JwksClient {
pub fn new(url: String, cache_ttl_secs: u64) -> Result<Self, JwksError> {
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.map_err(|e| JwksError::HttpClientError(e.to_string()))?;
Ok(Self {
url,
http_client,
cache: Arc::new(RwLock::new(None)),
cache_ttl: Duration::from_secs(cache_ttl_secs),
})
}
pub async fn get_key(&self, kid: &str) -> Result<DecodingKey, JwksError> {
{
let cache = self.cache.read().await;
if let Some(ref cached) = *cache
&& cached.fetched_at.elapsed() < self.cache_ttl
&& let Some(key) = cached.keys.get(kid)
{
debug!(kid = %kid, "Using cached JWKS key");
return Ok(key.clone());
}
}
debug!(kid = %kid, "JWKS cache miss, refreshing");
self.refresh().await?;
let cache = self.cache.read().await;
if let Some(ref cached) = *cache {
cached
.keys
.get(kid)
.cloned()
.ok_or_else(|| JwksError::KeyNotFound(kid.to_string()))
} else {
Err(JwksError::FetchFailed(
"Cache empty after refresh".to_string(),
))
}
}
pub async fn get_any_key(&self) -> Result<DecodingKey, JwksError> {
{
let cache = self.cache.read().await;
if let Some(ref cached) = *cache
&& cached.fetched_at.elapsed() < self.cache_ttl
&& let Some(key) = cached.keys.values().next()
{
debug!("Using first cached JWKS key (no kid specified)");
return Ok(key.clone());
}
}
debug!("JWKS cache miss for any key, refreshing");
self.refresh().await?;
let cache = self.cache.read().await;
if let Some(ref cached) = *cache {
cached
.keys
.values()
.next()
.cloned()
.ok_or(JwksError::NoKeysAvailable)
} else {
Err(JwksError::FetchFailed("No keys in JWKS".to_string()))
}
}
pub async fn refresh(&self) -> Result<(), JwksError> {
debug!(url = %self.url, "Fetching JWKS");
let response = self
.http_client
.get(&self.url)
.send()
.await
.map_err(|e| JwksError::FetchFailed(e.to_string()))?;
if !response.status().is_success() {
return Err(JwksError::FetchFailed(format!(
"HTTP {} from JWKS endpoint",
response.status()
)));
}
let jwks: JwksResponse = response
.json()
.await
.map_err(|e| JwksError::ParseFailed(e.to_string()))?;
let mut keys = HashMap::new();
for jwk in jwks.keys {
if let Some(ref key_use) = jwk.key_use
&& key_use != "sig"
{
continue;
}
let kid = jwk.kid.clone().unwrap_or_else(|| "default".to_string());
match self.parse_jwk(&jwk) {
Ok(Some(key)) => {
debug!(kid = %kid, kty = %jwk.kty, "Parsed JWKS key");
keys.insert(kid, key);
}
Ok(None) => {
debug!(kid = %kid, kty = %jwk.kty, "Skipping unsupported key type");
}
Err(e) => {
warn!(kid = %kid, error = %e, "Failed to parse JWKS key");
}
}
}
if keys.is_empty() {
return Err(JwksError::NoKeysAvailable);
}
debug!(count = keys.len(), "Cached JWKS keys");
let mut cache = self.cache.write().await;
*cache = Some(CachedJwks {
keys,
fetched_at: Instant::now(),
});
Ok(())
}
fn parse_jwk(&self, jwk: &JsonWebKey) -> Result<Option<DecodingKey>, JwksError> {
match jwk.kty.as_str() {
"RSA" => {
if let Some(ref x5c) = jwk.x5c
&& let Some(cert) = x5c.first()
{
let pem = format!(
"-----BEGIN CERTIFICATE-----\n{}\n-----END CERTIFICATE-----",
cert
);
return DecodingKey::from_rsa_pem(pem.as_bytes()).map(Some).map_err(
|e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
);
}
if let (Some(n), Some(e)) = (&jwk.n, &jwk.e) {
return DecodingKey::from_rsa_components(n, e).map(Some).map_err(
|e: jsonwebtoken::errors::Error| JwksError::KeyParseFailed(e.to_string()),
);
}
Ok(None)
}
_ => {
Ok(None)
}
}
}
pub fn url(&self) -> &str {
&self.url
}
}
#[derive(Debug, thiserror::Error)]
pub enum JwksError {
#[error("Failed to fetch JWKS: {0}")]
FetchFailed(String),
#[error("Failed to parse JWKS: {0}")]
ParseFailed(String),
#[error("Failed to parse key: {0}")]
KeyParseFailed(String),
#[error("Key not found: {0}")]
KeyNotFound(String),
#[error("No keys available in JWKS")]
NoKeysAvailable,
#[error("Failed to create HTTP client: {0}")]
HttpClientError(String),
}
#[cfg(test)]
// Panicking on unexpected values is how a test reports failure, so these lints don't apply here.
#[allow(clippy::unwrap_used, clippy::indexing_slicing, clippy::panic)]
mod tests {
    use super::*;

    /// Base64url-encoded RSA modulus used as a known-good public key fixture.
    const SAMPLE_MODULUS: &str = "0vx7agoebGcQSuuPiLJXZptN9nndrQmbXEps2aiAFbWhM78LhWx4cbbfAAtVT86zwu1RK7aPFFxuhDR1L6tSoc_BJECPebWKRXjBZCiFV4n3oknjhMstn64tZ_2W-5JsGY4Hc5n9yBXArwl93lqt7_RN5w6Cf0h4QyQ5v-65YGjQR0_FDW2QvzqY368QQMicAtaSqzs8KJZgnYb9c7d0zgdAZHzu6qMQvRL5hajrn1n91CbOpbISD08qNLyrdkt-bFTWhAI4vMQFh6WeZu0fM4lFd2NcRwr3XPksINHaQ-G_xBniIqbw0Ls1jF44-csFCur-kEgU8awapJzKnqDKgw";

    /// Builds a client pointing at a dummy URL; `parse_jwk` never performs a request.
    fn offline_client() -> JwksClient {
        JwksClient::new("http://example.com".to_string(), 3600).unwrap()
    }

    /// Builds a signing JWK (`kid = "test-key"`, `use = "sig"`, no `x5c`) of the given
    /// type and algorithm, with optional RSA components.
    fn signing_jwk(kty: &str, alg: &str, n: Option<&str>, e: Option<&str>) -> JsonWebKey {
        JsonWebKey {
            kid: Some("test-key".to_string()),
            kty: kty.to_string(),
            alg: Some(alg.to_string()),
            key_use: Some("sig".to_string()),
            n: n.map(str::to_string),
            e: e.map(str::to_string),
            x5c: None,
        }
    }

    #[test]
    fn test_parse_jwk_with_n_e() {
        let jwk = signing_jwk("RSA", "RS256", Some(SAMPLE_MODULUS), Some("AQAB"));
        let parsed = offline_client().parse_jwk(&jwk);
        assert!(matches!(parsed, Ok(Some(_))));
    }

    #[test]
    fn test_parse_jwk_unsupported_type() {
        let jwk = signing_jwk("EC", "ES256", None, None);
        let parsed = offline_client().parse_jwk(&jwk);
        assert!(matches!(parsed, Ok(None)));
    }

    #[test]
    fn test_parse_jwk_missing_components() {
        let jwk = signing_jwk("RSA", "RS256", None, None);
        let parsed = offline_client().parse_jwk(&jwk);
        assert!(matches!(parsed, Ok(None)));
    }
}