use serde::{Deserialize, Serialize};
use std::cell::RefCell;
use std::rc::Rc;
use turbomcp_core::auth::{AuthError, JwtAlgorithm};
use wasm_bindgen::prelude::*;
use wasm_bindgen_futures::JsFuture;
/// One JSON Web Key (RFC 7517) taken from a JWKS document.
///
/// Only the members this module inspects are modelled. The derive does not use
/// `deny_unknown_fields`, so other members (e.g. `x5c`) are ignored when
/// deserializing, and every `None` member is omitted when serializing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Jwk {
/// Key type. The helpers on this type recognise `"RSA"`, `"EC"` and `"oct"`.
pub kty: String,
/// Key ID; matched by `JwkSet::find_by_kid`.
#[serde(skip_serializing_if = "Option::is_none")]
pub kid: Option<String>,
/// Declared algorithm (e.g. `"RS256"`); parsed by `Jwk::algorithm`.
#[serde(skip_serializing_if = "Option::is_none")]
pub alg: Option<String>,
/// Intended key use, carried in the JSON member `use` (a Rust keyword, hence
/// the trailing underscore). `None` or `"sig"` counts as a signing key.
#[serde(rename = "use", skip_serializing_if = "Option::is_none")]
pub use_: Option<String>,
/// RSA modulus; `is_rsa` requires it together with `e`.
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<String>,
/// RSA public exponent; `is_rsa` requires it together with `n`.
#[serde(skip_serializing_if = "Option::is_none")]
pub e: Option<String>,
/// EC curve name; `"P-256"` is paired with ES256 and `"P-384"` with ES384.
#[serde(skip_serializing_if = "Option::is_none")]
pub crv: Option<String>,
/// EC public point x coordinate; required by `is_ec`.
#[serde(skip_serializing_if = "Option::is_none")]
pub x: Option<String>,
/// EC public point y coordinate; required by `is_ec`.
#[serde(skip_serializing_if = "Option::is_none")]
pub y: Option<String>,
/// Symmetric key value; with `kty == "oct"` its presence makes `is_symmetric` true.
#[serde(skip_serializing_if = "Option::is_none")]
pub k: Option<String>,
}
impl Jwk {
    /// Parses the declared `alg` member.
    ///
    /// Returns `None` when `alg` is absent or is not a supported [`JwtAlgorithm`].
    pub fn algorithm(&self) -> Option<JwtAlgorithm> {
        self.alg.as_ref().and_then(|a| a.parse().ok())
    }

    /// `true` for an RSA key (`kty == "RSA"`) that carries both `n` and `e`.
    pub fn is_rsa(&self) -> bool {
        self.kty == "RSA" && self.n.is_some() && self.e.is_some()
    }

    /// `true` for an EC key (`kty == "EC"`) that carries `crv`, `x` and `y`.
    pub fn is_ec(&self) -> bool {
        self.kty == "EC" && self.crv.is_some() && self.x.is_some() && self.y.is_some()
    }

    /// `true` for a symmetric key (`kty == "oct"`) that carries `k`.
    pub fn is_symmetric(&self) -> bool {
        self.kty == "oct" && self.k.is_some()
    }

    /// `true` unless the key's `use` restricts it to something other than `"sig"`.
    pub fn is_signing_key(&self) -> bool {
        self.use_.as_ref().is_none_or(|u| u == "sig")
    }

    /// Whether this key may verify a token signed with `algorithm`.
    ///
    /// Two conditions must hold:
    /// * the key material fits the algorithm family: RSA for `RS*`, EC on the
    ///   matching curve for `ES*`, symmetric `oct` for `HS*`;
    /// * if the key declares an `alg`, it names exactly `algorithm`. A key is
    ///   bound to the single algorithm it declares (RFC 7517 §4.4, RFC 8725 §3.1).
    pub fn is_compatible_with_algorithm(&self, algorithm: JwtAlgorithm) -> bool {
        self.key_type_matches(algorithm) && self.declared_algorithm_matches(algorithm)
    }

    /// Checks only the key material against the algorithm family; ignores `alg`.
    fn key_type_matches(&self, algorithm: JwtAlgorithm) -> bool {
        match algorithm {
            JwtAlgorithm::RS256 | JwtAlgorithm::RS384 | JwtAlgorithm::RS512 => self.is_rsa(),
            JwtAlgorithm::ES256 => self.is_ec() && self.crv.as_deref() == Some("P-256"),
            JwtAlgorithm::ES384 => self.is_ec() && self.crv.as_deref() == Some("P-384"),
            JwtAlgorithm::HS256 | JwtAlgorithm::HS384 | JwtAlgorithm::HS512 => self.is_symmetric(),
        }
    }

    /// `true` when the key declares no `alg`, or declares exactly `algorithm`.
    fn declared_algorithm_matches(&self, algorithm: JwtAlgorithm) -> bool {
        self.alg
            .as_deref()
            .is_none_or(|declared| declared == algorithm.as_str())
    }

    /// Like [`is_compatible_with_algorithm`](Self::is_compatible_with_algorithm),
    /// but explains why the key cannot be used.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::InvalidCredentialFormat`] in two cases:
    /// * the key type does not fit `algorithm` (the classic algorithm-confusion
    ///   case, e.g. an RSA public key offered for HS256);
    /// * the key declares a different `alg` than `algorithm`.
    pub fn validate_algorithm_compatibility(
        &self,
        algorithm: JwtAlgorithm,
    ) -> Result<(), AuthError> {
        if !self.key_type_matches(algorithm) {
            let key_type = if self.is_rsa() {
                "RSA".to_string()
            } else if self.is_ec() {
                format!("EC ({})", self.crv.as_deref().unwrap_or("unknown curve"))
            } else if self.is_symmetric() {
                "symmetric (oct)".to_string()
            } else {
                "unknown".to_string()
            };
            return Err(AuthError::InvalidCredentialFormat(format!(
                "Key type '{}' is not compatible with algorithm {}. \
                 This may indicate an algorithm confusion attack.",
                key_type, algorithm
            )));
        }
        if !self.declared_algorithm_matches(algorithm) {
            return Err(AuthError::InvalidCredentialFormat(format!(
                "Key algorithm '{}' is not compatible with algorithm {}. \
                 This may indicate an algorithm confusion attack.",
                self.alg.as_deref().unwrap_or_default(),
                algorithm
            )));
        }
        Ok(())
    }

    /// Rejects keys that must never be published in a public JWKS.
    ///
    /// # Errors
    ///
    /// Returns [`AuthError::InvalidCredentialFormat`] for symmetric (`oct`) keys.
    /// Publishing such a key would disclose the shared secret.
    pub fn validate_for_public_jwks(&self) -> Result<(), AuthError> {
        if self.is_symmetric() {
            return Err(AuthError::InvalidCredentialFormat(
                "Symmetric keys (kty: oct) must not appear in public JWKS endpoints. \
                 This would expose the shared secret and allow anyone to forge JWTs. \
                 Use asymmetric keys (RSA, EC) for public JWKS."
                    .to_string(),
            ));
        }
        Ok(())
    }

    /// Converts this key into a `web_sys::JsonWebKey` dictionary for Web Crypto.
    ///
    /// Copies `kty`, `alg` and the key-material members, and sets
    /// `key_ops: ["verify"]`. `kid` and `use` are not copied.
    pub fn to_web_sys_jwk(&self) -> web_sys::JsonWebKey {
        let jwk = web_sys::JsonWebKey::new(&self.kty);
        if let Some(ref alg) = self.alg {
            jwk.set_alg(alg);
        }
        if let Some(ref n) = self.n {
            jwk.set_n(n);
        }
        if let Some(ref e) = self.e {
            jwk.set_e(e);
        }
        if let Some(ref crv) = self.crv {
            jwk.set_crv(crv);
        }
        if let Some(ref x) = self.x {
            jwk.set_x(x);
        }
        if let Some(ref y) = self.y {
            jwk.set_y(y);
        }
        if let Some(ref k) = self.k {
            jwk.set_k(k);
        }
        let key_ops = js_sys::Array::new();
        key_ops.push(&JsValue::from_str("verify"));
        jwk.set_key_ops(&key_ops);
        jwk
    }
}
/// A JWK Set document: the `{"keys": [...]}` object served by a JWKS endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JwkSet {
/// All keys in the set, in document order.
pub keys: Vec<Jwk>,
}
impl JwkSet {
    /// Returns the first key whose `kid` equals `kid`.
    pub fn find_by_kid(&self, kid: &str) -> Option<&Jwk> {
        self.keys.iter().find(|k| k.kid.as_deref() == Some(kid))
    }

    /// Returns the first signing key (`use` absent or `"sig"`) that declares
    /// exactly `alg`.
    ///
    /// Keys without an `alg` member never match.
    pub fn find_by_algorithm(&self, alg: JwtAlgorithm) -> Option<&Jwk> {
        let alg_str = alg.as_str();
        self.keys
            .iter()
            .find(|k| k.alg.as_deref() == Some(alg_str) && k.is_signing_key())
    }

    /// Returns the first key usable for signatures (`use` absent or `"sig"`).
    pub fn first_signing_key(&self) -> Option<&Jwk> {
        self.keys.iter().find(|k| k.is_signing_key())
    }

    /// Drops every key that must not appear in a public JWKS and keeps the rest
    /// in their original order.
    ///
    /// "Must not appear" is decided by [`Jwk::validate_for_public_jwks`]
    /// (currently: symmetric `oct` keys). On wasm32, each rejected key is
    /// reported with a console warning.
    pub fn filter_for_public_jwks(self) -> Self {
        let mut keys = self.keys;
        keys.retain(|key| match key.validate_for_public_jwks() {
            Ok(()) => true,
            Err(reason) => {
                #[cfg(target_arch = "wasm32")]
                web_sys::console::warn_1(
                    &format!(
                        "⚠️ Security: Rejecting symmetric key from public JWKS (kid: {:?}): {}",
                        key.kid, reason
                    )
                    .into(),
                );
                // Without the wasm32-only log above, `reason` would be an unused
                // binding (and a compiler warning) on native/test builds.
                #[cfg(not(target_arch = "wasm32"))]
                let _ = reason;
                false
            }
        });
        Self { keys }
    }
}
/// A fetched key set together with the moment it was fetched.
#[derive(Debug, Clone)]
struct CacheEntry {
    /// Key set returned by the last successful fetch.
    jwks: JwkSet,
    /// Fetch time in milliseconds, as reported by `js_sys::Date::now()`.
    fetched_at: f64,
}
/// A failed JWKS fetch attempt, classified by whether retrying can help.
///
/// `JwksCache::fetch_jwks` retries `Transient` failures (network errors,
/// HTTP 5xx) and returns `Permanent` ones immediately.
enum FetchError {
    /// The same request may succeed if repeated.
    Transient(AuthError),
    /// Repeating the request will not change the outcome.
    Permanent(AuthError),
}

impl FetchError {
    /// Drops the retry classification and yields the wrapped error.
    fn into_auth_error(self) -> AuthError {
        let (Self::Transient(inner) | Self::Permanent(inner)) = self;
        inner
    }

    /// Whether this failure is worth retrying.
    fn is_transient(&self) -> bool {
        let transient = matches!(self, Self::Transient(_));
        transient
    }
}
/// Hard upper bound (6 hours, in milliseconds) on how long a cached key set is
/// served. It applies even when a longer TTL is configured.
const MAX_CACHE_AGE_MS: f64 = 6.0 * 60.0 * 60.0 * 1000.0;
/// A single-threaded cache for a remote JSON Web Key Set.
///
/// Clones are cheap and share one cache slot (`Rc<RefCell<..>>`). A refresh made
/// through any clone is therefore visible to all of them.
#[derive(Clone)]
pub struct JwksCache {
    /// Endpoint the key set is fetched from.
    url: String,
    /// When set, the HTTPS requirement is skipped (see `allow_insecure_http`).
    allow_insecure: bool,
    /// Configured cache lifetime in milliseconds; reads cap it at `MAX_CACHE_AGE_MS`.
    ttl_ms: f64,
    /// Extra attempts made after a transient fetch failure.
    max_retries: u32,
    /// First backoff delay in milliseconds; it doubles for each further retry.
    retry_base_delay_ms: f64,
    /// Most recently fetched key set, shared by all clones.
    cache: Rc<RefCell<Option<CacheEntry>>>,
}
impl JwksCache {
pub fn new(url: impl Into<String>) -> Self {
Self {
url: url.into(),
ttl_ms: 3600.0 * 1000.0, cache: Rc::new(RefCell::new(None)),
max_retries: 3,
retry_base_delay_ms: 100.0,
allow_insecure: false,
}
}
pub fn allow_insecure_http(mut self) -> Self {
self.allow_insecure = true;
self
}
fn validate_url(&self) -> Result<(), AuthError> {
let url_lower = self.url.to_lowercase();
let is_localhost = url_lower.contains("://localhost")
|| url_lower.contains("://127.0.0.1")
|| url_lower.contains("://[::1]");
if self.allow_insecure || is_localhost {
return Ok(());
}
if !url_lower.starts_with("https://") {
return Err(AuthError::KeyFetchError(
"JWKS URL must use HTTPS to prevent man-in-the-middle attacks. \
Use allow_insecure_http() only for local development."
.to_string(),
));
}
Ok(())
}
pub fn with_ttl_seconds(mut self, seconds: u64) -> Self {
self.ttl_ms = seconds as f64 * 1000.0;
self
}
pub fn with_retry_config(mut self, max_retries: u32, base_delay_ms: f64) -> Self {
self.max_retries = max_retries;
self.retry_base_delay_ms = base_delay_ms;
self
}
pub async fn get_jwks(&self) -> Result<JwkSet, AuthError> {
if let Some(ref entry) = *self.cache.borrow() {
let now = js_sys::Date::now();
let age_ms = now - entry.fetched_at;
let effective_ttl = self.ttl_ms.min(MAX_CACHE_AGE_MS);
if age_ms < effective_ttl {
return Ok(entry.jwks.clone());
}
}
let jwks = self.fetch_jwks().await?;
*self.cache.borrow_mut() = Some(CacheEntry {
jwks: jwks.clone(),
fetched_at: js_sys::Date::now(),
});
Ok(jwks)
}
pub async fn refresh(&self) -> Result<JwkSet, AuthError> {
let jwks = self.fetch_jwks().await?;
*self.cache.borrow_mut() = Some(CacheEntry {
jwks: jwks.clone(),
fetched_at: js_sys::Date::now(),
});
Ok(jwks)
}
pub async fn find_key(&self, kid: &str) -> Result<Jwk, AuthError> {
let jwks = self.get_jwks().await?;
jwks.find_by_kid(kid)
.cloned()
.ok_or_else(|| AuthError::KeyNotFound(kid.to_string()))
}
async fn fetch_jwks(&self) -> Result<JwkSet, AuthError> {
let mut last_error = None;
for attempt in 0..=self.max_retries {
match self.fetch_jwks_internal().await {
Ok(jwks) => return Ok(jwks),
Err(fetch_err) => {
if !fetch_err.is_transient() {
return Err(fetch_err.into_auth_error());
}
last_error = Some(fetch_err.into_auth_error());
if attempt < self.max_retries {
let delay_ms = self.retry_base_delay_ms * 2_f64.powi(attempt as i32);
Self::sleep_ms(delay_ms).await;
}
}
}
}
Err(last_error.unwrap_or_else(|| {
AuthError::KeyFetchError("Failed to fetch JWKS from authorization server".to_string())
}))
}
async fn fetch_jwks_internal(&self) -> Result<JwkSet, FetchError> {
self.validate_url().map_err(FetchError::Permanent)?;
let window = web_sys::window().ok_or_else(|| {
FetchError::Permanent(AuthError::Internal(
"No window object available".to_string(),
))
})?;
let request = web_sys::Request::new_with_str(&self.url).map_err(|_| {
FetchError::Permanent(AuthError::KeyFetchError(
"Failed to fetch JWKS from authorization server".to_string(),
))
})?;
let promise = window.fetch_with_request(&request);
let response = JsFuture::from(promise).await.map_err(|e| {
#[cfg(target_arch = "wasm32")]
web_sys::console::warn_1(&format!("JWKS network fetch failed: {:?}", e).into());
FetchError::Transient(AuthError::KeyFetchError(
"Failed to fetch JWKS from authorization server".to_string(),
))
})?;
let response: web_sys::Response = response.dyn_into().map_err(|_| {
FetchError::Permanent(AuthError::KeyFetchError(
"Failed to fetch JWKS from authorization server".to_string(),
))
})?;
if !response.ok() {
let status = response.status();
#[cfg(target_arch = "wasm32")]
web_sys::console::warn_1(
&format!("JWKS fetch failed with HTTP status {}", status).into(),
);
let error = AuthError::KeyFetchError(
"Failed to fetch JWKS from authorization server".to_string(),
);
return if status >= 500 {
Err(FetchError::Transient(error))
} else {
Err(FetchError::Permanent(error))
};
}
let json_promise = response.json().map_err(|_| {
FetchError::Permanent(AuthError::KeyFetchError(
"Failed to fetch JWKS from authorization server".to_string(),
))
})?;
let json_value = JsFuture::from(json_promise).await.map_err(|_| {
FetchError::Permanent(AuthError::KeyFetchError(
"Failed to fetch JWKS from authorization server".to_string(),
))
})?;
let jwks: JwkSet = serde_wasm_bindgen::from_value(json_value).map_err(|e| {
#[cfg(target_arch = "wasm32")]
web_sys::console::warn_1(&format!("JWKS format invalid: {:?}", e).into());
FetchError::Permanent(AuthError::KeyFetchError(
"Failed to fetch JWKS from authorization server".to_string(),
))
})?;
let jwks = jwks.filter_for_public_jwks();
Ok(jwks)
}
async fn sleep_ms(ms: f64) {
let promise = js_sys::Promise::new(&mut |resolve, _| {
let global = js_sys::global();
let set_timeout = js_sys::Reflect::get(&global, &"setTimeout".into())
.ok()
.and_then(|v| v.dyn_into::<js_sys::Function>().ok());
if let Some(timeout_fn) = set_timeout {
let _ = timeout_fn.call2(&global, &resolve, &(ms as i32).into());
} else {
let _ = resolve.call0(&JsValue::undefined());
}
});
let _ = JsFuture::from(promise).await;
}
pub fn clear(&self) {
*self.cache.borrow_mut() = None;
}
}
impl std::fmt::Debug for JwksCache {
    /// Shows only the endpoint and the configured TTL. Cached keys, retry
    /// settings and the insecure flag are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = formatter.debug_struct("JwksCache");
        view.field("url", &self.url);
        view.field("ttl_ms", &self.ttl_ms);
        view.finish()
    }
}
/// Fetches the JWKS at `url` once.
///
/// Uses a throwaway [`JwksCache`] with default settings (HTTPS enforcement,
/// retries). Nothing is kept between calls; hold a `JwksCache` when the key set
/// is needed repeatedly.
pub async fn fetch_jwks(url: &str) -> Result<JwkSet, AuthError> {
    JwksCache::new(url).get_jwks().await
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Some(value.to_string())`, to keep the key literals below short.
    fn some(value: &str) -> Option<String> {
        Some(value.to_string())
    }

    /// A JWK of type `kty` with every optional member unset.
    fn bare(kty: &str) -> Jwk {
        Jwk {
            kty: kty.to_string(),
            kid: None,
            alg: None,
            use_: None,
            n: None,
            e: None,
            crv: None,
            x: None,
            y: None,
            k: None,
        }
    }

    /// An RSA key with the given `n` and `e` and nothing else set.
    fn rsa(n: &str, e: &str) -> Jwk {
        Jwk {
            n: some(n),
            e: some(e),
            ..bare("RSA")
        }
    }

    /// An EC key with the given curve and coordinates and nothing else set.
    fn ec(crv: &str, x: &str, y: &str) -> Jwk {
        Jwk {
            crv: some(crv),
            x: some(x),
            y: some(y),
            ..bare("EC")
        }
    }

    /// A symmetric (`oct`) key with secret `k` and nothing else set.
    fn oct(k: &str) -> Jwk {
        Jwk {
            k: some(k),
            ..bare("oct")
        }
    }

    /// `key` with `kid`, `alg` and `use: "sig"` filled in.
    fn signing(kid: &str, alg: &str, key: Jwk) -> Jwk {
        Jwk {
            kid: some(kid),
            alg: some(alg),
            use_: some("sig"),
            ..key
        }
    }

    #[test]
    fn test_jwk_is_rsa() {
        let jwk = signing("key1", "RS256", rsa("modulus", "AQAB"));
        assert!(jwk.is_rsa());
        assert!(!jwk.is_ec());
        assert!(jwk.is_signing_key());
        assert_eq!(jwk.algorithm(), Some(JwtAlgorithm::RS256));
    }

    #[test]
    fn test_jwk_is_ec() {
        let jwk = signing("key2", "ES256", ec("P-256", "x-coord", "y-coord"));
        assert!(!jwk.is_rsa());
        assert!(jwk.is_ec());
        assert!(jwk.is_signing_key());
        assert_eq!(jwk.algorithm(), Some(JwtAlgorithm::ES256));
    }

    #[test]
    fn test_jwks_find_by_kid() {
        let jwks = JwkSet {
            keys: vec![
                Jwk {
                    kid: some("key1"),
                    alg: some("RS256"),
                    ..rsa("n", "e")
                },
                Jwk {
                    kid: some("key2"),
                    alg: some("ES256"),
                    ..ec("P-256", "x", "y")
                },
            ],
        };
        assert!(jwks.find_by_kid("key1").is_some());
        assert!(jwks.find_by_kid("key2").is_some());
        assert!(jwks.find_by_kid("key3").is_none());
    }

    #[test]
    fn test_jwk_is_symmetric() {
        let jwk = signing("hmac-key", "HS256", oct("c2VjcmV0"));
        assert!(jwk.is_symmetric());
        assert!(!jwk.is_rsa());
        assert!(!jwk.is_ec());
    }

    #[test]
    fn test_rsa_key_compatible_with_rs_algorithms() {
        let rsa_jwk = rsa("modulus", "AQAB");
        for compatible in [JwtAlgorithm::RS256, JwtAlgorithm::RS384, JwtAlgorithm::RS512] {
            assert!(rsa_jwk.is_compatible_with_algorithm(compatible));
        }
        for incompatible in [
            JwtAlgorithm::ES256,
            JwtAlgorithm::ES384,
            JwtAlgorithm::HS256,
            JwtAlgorithm::HS384,
            JwtAlgorithm::HS512,
        ] {
            assert!(!rsa_jwk.is_compatible_with_algorithm(incompatible));
        }
    }

    #[test]
    fn test_ec_key_compatible_with_es_algorithms() {
        let ec_p256_jwk = ec("P-256", "x-coord", "y-coord");
        assert!(ec_p256_jwk.is_compatible_with_algorithm(JwtAlgorithm::ES256));
        for incompatible in [JwtAlgorithm::ES384, JwtAlgorithm::RS256, JwtAlgorithm::HS256] {
            assert!(!ec_p256_jwk.is_compatible_with_algorithm(incompatible));
        }

        let ec_p384_jwk = ec("P-384", "x-coord", "y-coord");
        assert!(!ec_p384_jwk.is_compatible_with_algorithm(JwtAlgorithm::ES256));
        assert!(ec_p384_jwk.is_compatible_with_algorithm(JwtAlgorithm::ES384));
    }

    #[test]
    fn test_symmetric_key_compatible_with_hs_algorithms() {
        let hmac_jwk = oct("c2VjcmV0");
        for compatible in [JwtAlgorithm::HS256, JwtAlgorithm::HS384, JwtAlgorithm::HS512] {
            assert!(hmac_jwk.is_compatible_with_algorithm(compatible));
        }
        for incompatible in [JwtAlgorithm::RS256, JwtAlgorithm::ES256] {
            assert!(!hmac_jwk.is_compatible_with_algorithm(incompatible));
        }
    }

    #[test]
    fn test_algorithm_confusion_attack_prevention() {
        let rsa_public_key = Jwk {
            alg: some("RS256"),
            use_: some("sig"),
            ..rsa("modulus", "AQAB")
        };
        let err_msg = rsa_public_key
            .validate_algorithm_compatibility(JwtAlgorithm::HS256)
            .expect_err("RSA key should not be usable with HS256")
            .to_string();
        assert!(
            err_msg.contains("not compatible") && err_msg.contains("algorithm confusion"),
            "Error should mention algorithm confusion attack: {}",
            err_msg
        );
    }

    #[test]
    fn test_jwks_url_https_required() {
        let https_cache = JwksCache::new("https://auth.example.com/.well-known/jwks.json");
        assert!(https_cache.validate_url().is_ok());
    }

    #[test]
    fn test_jwks_url_http_rejected() {
        let err = JwksCache::new("http://auth.example.com/.well-known/jwks.json")
            .validate_url()
            .expect_err("plain HTTP to a remote host must be rejected");
        assert!(
            err.to_string().contains("HTTPS"),
            "Error should mention HTTPS requirement: {}",
            err
        );
    }

    #[test]
    fn test_jwks_url_localhost_allowed() {
        for url in [
            "http://localhost:8080/.well-known/jwks.json",
            "http://127.0.0.1:8080/.well-known/jwks.json",
            "http://[::1]:8080/.well-known/jwks.json",
        ] {
            assert!(JwksCache::new(url).validate_url().is_ok(), "{url} should be allowed");
        }
    }

    #[test]
    fn test_jwks_url_insecure_mode() {
        let cache =
            JwksCache::new("http://test-server/.well-known/jwks.json").allow_insecure_http();
        assert!(cache.validate_url().is_ok());
    }

    #[test]
    fn test_jwk_validate_for_public_jwks_rejects_symmetric() {
        let hmac_key = signing("hmac-key-1", "HS256", oct("c2VjcmV0"));
        let err_msg = hmac_key
            .validate_for_public_jwks()
            .expect_err("Symmetric key should be rejected")
            .to_string();
        assert!(
            err_msg.contains("Symmetric keys") && err_msg.contains("must not appear"),
            "Error should explain symmetric key prohibition: {}",
            err_msg
        );
    }

    #[test]
    fn test_jwk_validate_for_public_jwks_accepts_asymmetric() {
        let rsa_key = signing("rsa-key-1", "RS256", rsa("modulus", "AQAB"));
        assert!(rsa_key.validate_for_public_jwks().is_ok());

        let ec_key = signing("ec-key-1", "ES256", ec("P-256", "x-coord", "y-coord"));
        assert!(ec_key.validate_for_public_jwks().is_ok());
    }

    #[test]
    fn test_jwks_filter_for_public_jwks() {
        let jwks = JwkSet {
            keys: vec![
                signing("rsa-1", "RS256", rsa("n", "e")),
                signing("hmac-1", "HS256", oct("secret")),
                signing("ec-1", "ES256", ec("P-256", "x", "y")),
            ],
        };
        let filtered = jwks.filter_for_public_jwks();
        assert_eq!(filtered.keys.len(), 2);
        assert!(filtered.find_by_kid("rsa-1").is_some());
        assert!(filtered.find_by_kid("hmac-1").is_none());
        assert!(filtered.find_by_kid("ec-1").is_some());
    }
}