use jsonwebtoken::jwk::JwkSet;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use turbomcp_protocol::{Error as McpError, Result as McpResult};
use url::Url;
/// A fetched key set plus the bookkeeping needed to decide when it has gone stale.
#[derive(Clone, Debug)]
struct CachedJwks {
    /// The key set exactly as parsed from the endpoint response.
    jwks: JwkSet,
    /// Wall-clock time at which `jwks` was stored.
    cached_at: SystemTime,
    /// How long after `cached_at` the entry is still considered usable.
    ttl: Duration,
}
impl CachedJwks {
    /// Reports whether this entry is younger than its TTL.
    ///
    /// If `cached_at` lies in the future (the wall clock moved backwards),
    /// `duration_since` fails and the entry is treated as expired, which forces a
    /// re-fetch rather than trusting a timestamp that cannot be ordered.
    fn is_valid(&self) -> bool {
        SystemTime::now()
            .duration_since(self.cached_at)
            .is_ok_and(|elapsed| elapsed < self.ttl)
    }
}
/// Fetches a JSON Web Key Set from a single endpoint and caches it in memory.
///
/// Clones share the same cache and last-refresh state, because both live
/// behind `Arc`.
#[derive(Clone, Debug)]
pub struct JwksClient {
    /// Endpoint the key set is fetched from.
    jwks_uri: String,
    /// Most recently fetched key set, if any.
    cache: Arc<RwLock<Option<CachedJwks>>>,
    /// HTTP client used for every fetch.
    http_client: reqwest::Client,
    /// Lifetime given to each newly cached key set.
    cache_ttl: Duration,
    /// Minimum spacing enforced between forced refreshes.
    min_refresh_interval: Duration,
    /// Time of the last successful fetch, used for refresh rate limiting.
    last_refresh: Arc<RwLock<Option<SystemTime>>>,
    /// Optional SSRF check applied to `jwks_uri` before each request.
    ssrf_validator: Option<Arc<crate::ssrf::SsrfValidator>>,
}
impl JwksClient {
pub fn new(jwks_uri: String) -> Self {
Self {
jwks_uri,
cache: Arc::new(RwLock::new(None)),
http_client: reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.build()
.expect("Failed to create HTTP client"),
cache_ttl: Duration::from_secs(600), min_refresh_interval: Duration::from_secs(5), last_refresh: Arc::new(RwLock::new(None)),
ssrf_validator: None,
}
}
pub fn with_ssrf_validator(
jwks_uri: String,
ssrf_validator: Arc<crate::ssrf::SsrfValidator>,
) -> Self {
let mut client = Self::new(jwks_uri);
client.ssrf_validator = Some(ssrf_validator);
client
}
pub fn with_ttl(jwks_uri: String, cache_ttl: Duration) -> Self {
let mut client = Self::new(jwks_uri);
client.cache_ttl = cache_ttl;
client
}
pub async fn get_jwks(&self) -> McpResult<JwkSet> {
{
let cache = self.cache.read().await;
if let Some(cached) = cache.as_ref()
&& cached.is_valid()
{
debug!(jwks_uri = %self.jwks_uri, "Using cached JWKS");
return Ok(cached.jwks.clone());
}
}
self.fetch_and_cache().await
}
pub async fn refresh(&self) -> McpResult<JwkSet> {
{
let last_refresh = self.last_refresh.read().await;
if let Some(last) = *last_refresh
&& let Ok(since_last) = SystemTime::now().duration_since(last)
&& since_last < self.min_refresh_interval
{
warn!(
jwks_uri = %self.jwks_uri,
since_last_ms = since_last.as_millis(),
"JWKS refresh rate limited, using cache"
);
return self.get_jwks().await;
}
}
self.fetch_and_cache().await
}
async fn fetch_and_cache(&self) -> McpResult<JwkSet> {
info!(jwks_uri = %self.jwks_uri, "Fetching JWKS from endpoint");
if !Self::is_allowed_jwks_uri(&self.jwks_uri) {
return Err(McpError::invalid_params(
"JWKS endpoint must use HTTPS (HTTP only allowed for localhost)".to_string(),
));
}
if let Some(ref validator) = self.ssrf_validator {
validator.validate_url(&self.jwks_uri).map_err(|e| {
McpError::authentication(format!("SSRF validation failed for JWKS URI: {e}"))
})?;
}
let response = self
.http_client
.get(&self.jwks_uri)
.send()
.await
.map_err(|e| {
error!(jwks_uri = %self.jwks_uri, error = %e, "Failed to fetch JWKS");
McpError::internal(format!("JWKS fetch failed: {e}"))
})?;
if !response.status().is_success() {
error!(
jwks_uri = %self.jwks_uri,
status = %response.status(),
"JWKS endpoint returned error status"
);
return Err(McpError::internal(format!(
"JWKS endpoint returned status {}",
response.status()
)));
}
const MAX_JWKS_RESPONSE_SIZE: usize = 65_536; let bytes = response.bytes().await.map_err(|e| {
error!(jwks_uri = %self.jwks_uri, error = %e, "Failed to read JWKS response body");
McpError::internal(format!("Failed to read JWKS response: {e}"))
})?;
if bytes.len() > MAX_JWKS_RESPONSE_SIZE {
error!(
jwks_uri = %self.jwks_uri,
size = bytes.len(),
max = MAX_JWKS_RESPONSE_SIZE,
"JWKS response exceeds size limit"
);
return Err(McpError::internal(format!(
"JWKS response too large: {} bytes (max: {} bytes)",
bytes.len(),
MAX_JWKS_RESPONSE_SIZE
)));
}
let jwks: JwkSet = serde_json::from_slice(&bytes).map_err(|e| {
error!(jwks_uri = %self.jwks_uri, error = %e, "Failed to parse JWKS JSON");
McpError::internal(format!("Invalid JWKS format: {e}"))
})?;
info!(
jwks_uri = %self.jwks_uri,
key_count = jwks.keys.len(),
"Successfully fetched JWKS"
);
{
let mut cache = self.cache.write().await;
*cache = Some(CachedJwks {
jwks: jwks.clone(),
cached_at: SystemTime::now(),
ttl: self.cache_ttl,
});
}
{
let mut last_refresh = self.last_refresh.write().await;
*last_refresh = Some(SystemTime::now());
}
Ok(jwks)
}
pub fn jwks_uri(&self) -> &str {
&self.jwks_uri
}
fn is_allowed_jwks_uri(jwks_uri: &str) -> bool {
let Ok(parsed) = Url::parse(jwks_uri) else {
return false;
};
match parsed.scheme() {
"https" => true,
"http" => matches!(parsed.host_str(), Some("localhost" | "127.0.0.1" | "::1")),
_ => false,
}
}
pub async fn clear_cache(&self) {
let mut cache = self.cache.write().await;
*cache = None;
debug!(jwks_uri = %self.jwks_uri, "JWKS cache cleared");
}
}
/// Registry of per-issuer [`JwksClient`]s, so repeated lookups for an issuer
/// reuse one client and its cache.
#[derive(Default, Debug)]
pub struct JwksCache {
    /// Clients keyed by the issuer string they were created for.
    clients: Arc<RwLock<std::collections::HashMap<String, Arc<JwksClient>>>>,
    /// Optional SSRF validator handed to every client this registry creates.
    ssrf_validator: Option<Arc<crate::ssrf::SsrfValidator>>,
}
impl JwksCache {
    /// Creates an empty registry whose clients perform no SSRF validation.
    pub fn new() -> Self {
        Self {
            clients: Arc::new(RwLock::new(std::collections::HashMap::new())),
            ssrf_validator: None,
        }
    }

    /// Creates an empty registry whose clients all share `ssrf_validator`.
    pub fn with_ssrf_validator(ssrf_validator: Arc<crate::ssrf::SsrfValidator>) -> Self {
        Self {
            clients: Arc::new(RwLock::new(std::collections::HashMap::new())),
            ssrf_validator: Some(ssrf_validator),
        }
    }

    /// Returns the client for `issuer`, creating and registering one on first use.
    ///
    /// Repeated calls with the same issuer string return the same `Arc`. The
    /// JWKS URI is `<issuer>/.well-known/jwks.json` (see `jwks_uri_for_issuer`).
    ///
    /// NOTE(review): entries are never evicted. If `issuer` can come from an
    /// unverified token claim, confirm that callers restrict it (for example to
    /// an allow-list); otherwise this map can grow without bound.
    pub async fn get_client_for_issuer(&self, issuer: &str) -> Arc<JwksClient> {
        // Fast path under the shared lock, so lookups for known issuers do not
        // serialize behind each other.
        {
            let clients = self.clients.read().await;
            if let Some(client) = clients.get(issuer) {
                return Arc::clone(client);
            }
        }

        let mut clients = self.clients.write().await;
        // Another task may have registered this issuer while we waited for the
        // write lock. `entry` re-checks, so only one client per issuer is created.
        let client = clients.entry(issuer.to_string()).or_insert_with(|| {
            let jwks_uri = Self::jwks_uri_for_issuer(issuer);
            Arc::new(match &self.ssrf_validator {
                Some(validator) => JwksClient::with_ssrf_validator(jwks_uri, Arc::clone(validator)),
                None => JwksClient::new(jwks_uri),
            })
        });
        Arc::clone(client)
    }

    /// Derives the JWKS URI for `issuer` by appending `.well-known/jwks.json`
    /// to its path.
    ///
    /// The issuer's path is treated as a directory: a trailing `/` is added if
    /// missing. A path-scoped issuer such as `https://host/realms/demo` therefore
    /// maps to `https://host/realms/demo/.well-known/jwks.json`. A bare
    /// `Url::join` would replace the last segment instead and yield
    /// `https://host/realms/.well-known/jwks.json`. Issuers that cannot be
    /// parsed, or cannot serve as a base URL, fall back to string concatenation.
    fn jwks_uri_for_issuer(issuer: &str) -> String {
        let concatenated = || format!("{}/.well-known/jwks.json", issuer.trim_end_matches('/'));

        let Ok(mut base) = Url::parse(issuer) else {
            return concatenated();
        };
        if !base.path().ends_with('/') {
            let directory = format!("{}/", base.path());
            base.set_path(&directory);
        }
        base.join(".well-known/jwks.json")
            .map(|uri| uri.to_string())
            .unwrap_or_else(|_| concatenated())
    }

    /// Fetches (or serves from cache) the key set for `issuer`.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`JwksClient::get_jwks`].
    pub async fn get_jwks_for_issuer(&self, issuer: &str) -> McpResult<JwkSet> {
        let client = self.get_client_for_issuer(issuer).await;
        client.get_jwks().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint shared by the single-client tests.
    const EXAMPLE_JWKS_URI: &str = "https://auth.example.com/jwks";

    /// Builds an empty cache entry stored `age` ago, with a 600-second TTL.
    fn empty_entry_aged(age: Duration) -> CachedJwks {
        CachedJwks {
            jwks: JwkSet { keys: Vec::new() },
            cached_at: SystemTime::now() - age,
            ttl: Duration::from_secs(600),
        }
    }

    #[test]
    fn test_jwks_client_creation() {
        let client = JwksClient::new(EXAMPLE_JWKS_URI.to_owned());
        assert_eq!(client.jwks_uri(), EXAMPLE_JWKS_URI);
        assert_eq!(client.cache_ttl, Duration::from_secs(600));
    }

    #[test]
    fn test_jwks_client_with_custom_ttl() {
        let five_minutes = Duration::from_secs(300);
        let client = JwksClient::with_ttl(EXAMPLE_JWKS_URI.to_owned(), five_minutes);
        assert_eq!(client.cache_ttl, five_minutes);
    }

    #[test]
    fn test_cached_jwks_validity() {
        let fresh = empty_entry_aged(Duration::ZERO);
        assert!(fresh.is_valid());
    }

    #[test]
    fn test_cached_jwks_expired() {
        let stale = empty_entry_aged(Duration::from_secs(700));
        assert!(!stale.is_valid());
    }

    #[tokio::test]
    async fn test_jwks_cache_creation() {
        let registry = JwksCache::new();
        let issuer = "https://auth.example.com";
        let first = registry.get_client_for_issuer(issuer).await;
        let second = registry.get_client_for_issuer(issuer).await;
        assert!(Arc::ptr_eq(&first, &second));
    }

    #[tokio::test]
    async fn test_jwks_cache_different_issuers() {
        let registry = JwksCache::new();
        let a = registry
            .get_client_for_issuer("https://auth1.example.com")
            .await;
        let b = registry
            .get_client_for_issuer("https://auth2.example.com")
            .await;
        assert!(!Arc::ptr_eq(&a, &b));
    }

    #[tokio::test]
    async fn test_clear_cache() {
        let client = JwksClient::new(EXAMPLE_JWKS_URI.to_owned());
        client.clear_cache().await;
        assert!(client.cache.read().await.is_none());
    }
}