use std::{
collections::HashMap,
path::PathBuf,
time::{Duration, Instant},
};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
use serde::Deserialize;
use tokio::{net::lookup_host, sync::RwLock};
use crate::auth::{AuthIdentity, AuthMethod};
fn evaluate_oauth_redirect(
attempt: &reqwest::redirect::Attempt<'_>,
allow_http: bool,
) -> Result<(), String> {
let prev_https = attempt
.previous()
.last()
.is_some_and(|prev| prev.scheme() == "https");
let target_url = attempt.url();
let dest_scheme = target_url.scheme();
if dest_scheme != "https" {
if prev_https {
return Err("redirect downgrades https -> http".to_owned());
}
if !allow_http || dest_scheme != "http" {
return Err("redirect to non-HTTP(S) URL refused".to_owned());
}
}
if let Some(reason) = crate::ssrf::redirect_target_reason(target_url) {
return Err(format!("redirect target forbidden: {reason}"));
}
if attempt.previous().len() >= 2 {
return Err("too many redirects (max 2)".to_owned());
}
Ok(())
}
#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
async fn screen_oauth_target_with_test_override(
url: &str,
allow_http: bool,
#[cfg(any(test, feature = "test-helpers"))] test_allow_loopback_ssrf: bool,
) -> Result<(), crate::error::McpxError> {
let parsed = check_oauth_url("oauth target", url, allow_http)?;
#[cfg(any(test, feature = "test-helpers"))]
if test_allow_loopback_ssrf {
return Ok(());
}
if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
return Err(crate::error::McpxError::Config(format!(
"OAuth target forbidden ({reason}): {url}"
)));
}
let host = parsed.host_str().ok_or_else(|| {
crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
})?;
let port = parsed.port_or_known_default().ok_or_else(|| {
crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
})?;
let addrs = lookup_host((host, port)).await.map_err(|error| {
crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
})?;
let mut any_addr = false;
for addr in addrs {
any_addr = true;
if let Some(reason) = crate::ssrf::ip_block_reason(addr.ip()) {
return Err(crate::error::McpxError::Config(format!(
"OAuth target resolved to blocked IP ({reason}): {url}"
)));
}
}
if !any_addr {
return Err(crate::error::McpxError::Config(format!(
"OAuth target DNS resolution returned no addresses: {url}"
)));
}
Ok(())
}
async fn screen_oauth_target(url: &str, allow_http: bool) -> Result<(), crate::error::McpxError> {
#[cfg(any(test, feature = "test-helpers"))]
{
screen_oauth_target_with_test_override(url, allow_http, false).await
}
#[cfg(not(any(test, feature = "test-helpers")))]
{
let parsed = check_oauth_url("oauth target", url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
return Err(crate::error::McpxError::Config(format!(
"OAuth target forbidden ({reason}): {url}"
)));
}
let host = parsed.host_str().ok_or_else(|| {
crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
})?;
let port = parsed.port_or_known_default().ok_or_else(|| {
crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
})?;
let addrs = lookup_host((host, port)).await.map_err(|error| {
crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
})?;
let mut any_addr = false;
for addr in addrs {
any_addr = true;
if let Some(reason) = crate::ssrf::ip_block_reason(addr.ip()) {
return Err(crate::error::McpxError::Config(format!(
"OAuth target resolved to blocked IP ({reason}): {url}"
)));
}
}
if !any_addr {
return Err(crate::error::McpxError::Config(format!(
"OAuth target DNS resolution returned no addresses: {url}"
)));
}
Ok(())
}
}
#[derive(Clone)]
pub struct OauthHttpClient {
inner: reqwest::Client,
allow_http: bool,
#[cfg(any(test, feature = "test-helpers"))]
test_allow_loopback_ssrf: bool,
}
impl OauthHttpClient {
pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
Self::build(Some(config))
}
#[deprecated(
since = "1.2.1",
note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
)]
pub fn new() -> Result<Self, crate::error::McpxError> {
Self::build(None)
}
fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);
let mut builder = reqwest::Client::builder()
.connect_timeout(Duration::from_secs(10))
.timeout(Duration::from_secs(30))
.redirect(reqwest::redirect::Policy::custom(move |attempt| {
match evaluate_oauth_redirect(&attempt, allow_http) {
Ok(()) => attempt.follow(),
Err(reason) => {
tracing::warn!(
reason = %reason,
target = %attempt.url(),
"oauth redirect rejected"
);
attempt.error(reason)
}
}
}));
if let Some(cfg) = config
&& let Some(ref ca_path) = cfg.ca_cert_path
{
let pem = std::fs::read(ca_path).map_err(|e| {
crate::error::McpxError::Startup(format!(
"oauth http client: read ca_cert_path {}: {e}",
ca_path.display()
))
})?;
let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
crate::error::McpxError::Startup(format!(
"oauth http client: parse ca_cert_path {}: {e}",
ca_path.display()
))
})?;
builder = builder.add_root_certificate(cert);
}
let inner = builder.build().map_err(|e| {
crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
})?;
Ok(Self {
inner,
allow_http,
#[cfg(any(test, feature = "test-helpers"))]
test_allow_loopback_ssrf: false,
})
}
async fn send_screened(
&self,
url: &str,
request: reqwest::RequestBuilder,
) -> Result<reqwest::Response, crate::error::McpxError> {
#[cfg(any(test, feature = "test-helpers"))]
if self.test_allow_loopback_ssrf {
screen_oauth_target_with_test_override(url, self.allow_http, true).await?;
} else {
screen_oauth_target(url, self.allow_http).await?;
}
#[cfg(not(any(test, feature = "test-helpers")))]
screen_oauth_target(url, self.allow_http).await?;
request.send().await.map_err(|error| {
crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
})
}
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
#[must_use]
pub fn __test_allow_loopback_ssrf(mut self) -> Self {
self.test_allow_loopback_ssrf = true;
self
}
#[doc(hidden)]
pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
self.inner.get(url).send().await
}
}
impl std::fmt::Debug for OauthHttpClient {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("OauthHttpClient").finish_non_exhaustive()
}
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
pub issuer: String,
pub audience: String,
pub jwks_uri: String,
#[serde(default)]
pub scopes: Vec<ScopeMapping>,
pub role_claim: Option<String>,
#[serde(default)]
pub role_mappings: Vec<RoleMapping>,
#[serde(default = "default_jwks_cache_ttl")]
pub jwks_cache_ttl: String,
pub proxy: Option<OAuthProxyConfig>,
pub token_exchange: Option<TokenExchangeConfig>,
#[serde(default)]
pub ca_cert_path: Option<PathBuf>,
#[serde(default)]
pub allow_http_oauth_urls: bool,
#[serde(default = "default_max_jwks_keys")]
pub max_jwks_keys: usize,
#[serde(default)]
pub strict_audience_validation: bool,
#[serde(default = "default_jwks_max_bytes")]
pub jwks_max_response_bytes: u64,
}
fn default_jwks_cache_ttl() -> String {
"10m".into()
}
const fn default_max_jwks_keys() -> usize {
256
}
const fn default_jwks_max_bytes() -> u64 {
1024 * 1024
}
impl Default for OAuthConfig {
fn default() -> Self {
Self {
issuer: String::new(),
audience: String::new(),
jwks_uri: String::new(),
scopes: Vec::new(),
role_claim: None,
role_mappings: Vec::new(),
jwks_cache_ttl: default_jwks_cache_ttl(),
proxy: None,
token_exchange: None,
ca_cert_path: None,
allow_http_oauth_urls: false,
max_jwks_keys: default_max_jwks_keys(),
strict_audience_validation: false,
jwks_max_response_bytes: default_jwks_max_bytes(),
}
}
}
impl OAuthConfig {
pub fn builder(
issuer: impl Into<String>,
audience: impl Into<String>,
jwks_uri: impl Into<String>,
) -> OAuthConfigBuilder {
OAuthConfigBuilder {
inner: Self {
issuer: issuer.into(),
audience: audience.into(),
jwks_uri: jwks_uri.into(),
..Self::default()
},
}
}
pub fn validate(&self) -> Result<(), crate::error::McpxError> {
let allow_http = self.allow_http_oauth_urls;
let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.issuer forbidden ({reason})"
)));
}
let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.jwks_uri forbidden ({reason})"
)));
}
if let Some(proxy) = &self.proxy {
let url = check_oauth_url(
"oauth.proxy.authorize_url",
&proxy.authorize_url,
allow_http,
)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.proxy.authorize_url forbidden ({reason})"
)));
}
let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.proxy.token_url forbidden ({reason})"
)));
}
if let Some(url) = &proxy.introspection_url {
let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
return Err(crate::error::McpxError::Config(format!(
"oauth.proxy.introspection_url forbidden ({reason})"
)));
}
}
if let Some(url) = &proxy.revocation_url {
let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
return Err(crate::error::McpxError::Config(format!(
"oauth.proxy.revocation_url forbidden ({reason})"
)));
}
}
}
if let Some(tx) = &self.token_exchange {
let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.token_exchange.token_url forbidden ({reason})"
)));
}
}
Ok(())
}
}
fn check_oauth_url(
field: &str,
raw: &str,
allow_http: bool,
) -> Result<url::Url, crate::error::McpxError> {
let parsed = url::Url::parse(raw).map_err(|e| {
crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
})?;
if !parsed.username().is_empty() || parsed.password().is_some() {
return Err(crate::error::McpxError::Config(format!(
"{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
)));
}
match parsed.scheme() {
"https" => Ok(parsed),
"http" if allow_http => Ok(parsed),
"http" => Err(crate::error::McpxError::Config(format!(
"{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
to override - strongly discouraged in production)"
))),
other => Err(crate::error::McpxError::Config(format!(
"{field}: must use https scheme (got {other:?})"
))),
}
}
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
inner: OAuthConfig,
}
impl OAuthConfigBuilder {
pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
self.inner.scopes = scopes;
self
}
pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
self.inner.scopes.push(ScopeMapping {
scope: scope.into(),
role: role.into(),
});
self
}
pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
self.inner.role_claim = Some(claim.into());
self
}
pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
self.inner.role_mappings = mappings;
self
}
pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
self.inner.role_mappings.push(RoleMapping {
claim_value: claim_value.into(),
role: role.into(),
});
self
}
pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
self.inner.jwks_cache_ttl = ttl.into();
self
}
pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
self.inner.proxy = Some(proxy);
self
}
pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
self.inner.token_exchange = Some(token_exchange);
self
}
pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
self.inner.ca_cert_path = Some(path.into());
self
}
pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
self.inner.allow_http_oauth_urls = allow;
self
}
pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
self.inner.strict_audience_validation = strict;
self
}
pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
self.inner.jwks_max_response_bytes = bytes;
self
}
#[must_use]
pub fn build(self) -> OAuthConfig {
self.inner
}
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
pub scope: String,
pub role: String,
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
pub claim_value: String,
pub role: String,
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
pub token_url: String,
pub client_id: String,
pub client_secret: Option<secrecy::SecretString>,
pub client_cert: Option<ClientCertConfig>,
pub audience: String,
}
impl TokenExchangeConfig {
#[must_use]
pub fn new(
token_url: String,
client_id: String,
client_secret: Option<secrecy::SecretString>,
client_cert: Option<ClientCertConfig>,
audience: String,
) -> Self {
Self {
token_url,
client_id,
client_secret,
client_cert,
audience,
}
}
}
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
pub cert_path: PathBuf,
pub key_path: PathBuf,
}
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ExchangedToken {
pub access_token: String,
pub expires_in: Option<u64>,
pub issued_token_type: Option<String>,
}
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
pub authorize_url: String,
pub token_url: String,
pub client_id: String,
pub client_secret: Option<secrecy::SecretString>,
#[serde(default)]
pub introspection_url: Option<String>,
#[serde(default)]
pub revocation_url: Option<String>,
#[serde(default)]
pub expose_admin_endpoints: bool,
#[serde(default)]
pub require_auth_on_admin_endpoints: bool,
}
impl OAuthProxyConfig {
pub fn builder(
authorize_url: impl Into<String>,
token_url: impl Into<String>,
client_id: impl Into<String>,
) -> OAuthProxyConfigBuilder {
OAuthProxyConfigBuilder {
inner: Self {
authorize_url: authorize_url.into(),
token_url: token_url.into(),
client_id: client_id.into(),
..Self::default()
},
}
}
}
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
inner: OAuthProxyConfig,
}
impl OAuthProxyConfigBuilder {
pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
self.inner.client_secret = Some(secret);
self
}
pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
self.inner.introspection_url = Some(url.into());
self
}
pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
self.inner.revocation_url = Some(url.into());
self
}
pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
self.inner.expose_admin_endpoints = expose;
self
}
pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
self.inner.require_auth_on_admin_endpoints = require;
self
}
#[must_use]
pub fn build(self) -> OAuthProxyConfig {
self.inner
}
}
type JwksKeyCache = (
HashMap<String, (Algorithm, DecodingKey)>,
Vec<(Algorithm, DecodingKey)>,
);
struct CachedKeys {
keys: HashMap<String, (Algorithm, DecodingKey)>,
unnamed_keys: Vec<(Algorithm, DecodingKey)>,
fetched_at: Instant,
ttl: Duration,
}
impl CachedKeys {
fn is_expired(&self) -> bool {
self.fetched_at.elapsed() >= self.ttl
}
}
#[allow(
missing_debug_implementations,
reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
jwks_uri: String,
ttl: Duration,
max_jwks_keys: usize,
max_response_bytes: u64,
allow_http: bool,
inner: RwLock<Option<CachedKeys>>,
http: reqwest::Client,
validation_template: Validation,
expected_audience: String,
strict_audience_validation: bool,
scopes: Vec<ScopeMapping>,
role_claim: Option<String>,
role_mappings: Vec<RoleMapping>,
last_refresh_attempt: RwLock<Option<Instant>>,
refresh_lock: tokio::sync::Mutex<()>,
#[cfg(any(test, feature = "test-helpers"))]
test_allow_loopback_ssrf: bool,
}
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);
const ACCEPTED_ALGS: &[Algorithm] = &[
Algorithm::RS256,
Algorithm::RS384,
Algorithm::RS512,
Algorithm::ES256,
Algorithm::ES384,
Algorithm::PS256,
Algorithm::PS384,
Algorithm::PS512,
Algorithm::EdDSA,
];
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
Expired,
Invalid,
}
impl JwksCache {
pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
rustls::crypto::ring::default_provider()
.install_default()
.ok();
jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
.install_default()
.ok();
let ttl =
humantime::parse_duration(&config.jwks_cache_ttl).unwrap_or(Duration::from_mins(10));
let mut validation = Validation::new(Algorithm::RS256);
validation.validate_aud = false;
validation.set_issuer(&[&config.issuer]);
validation.set_required_spec_claims(&["exp", "iss"]);
validation.validate_exp = true;
validation.validate_nbf = true;
let allow_http = config.allow_http_oauth_urls;
let mut http_builder = reqwest::Client::builder()
.timeout(Duration::from_secs(10))
.connect_timeout(Duration::from_secs(3))
.redirect(reqwest::redirect::Policy::custom(move |attempt| {
match evaluate_oauth_redirect(&attempt, allow_http) {
Ok(()) => attempt.follow(),
Err(reason) => {
tracing::warn!(
reason = %reason,
target = %attempt.url(),
"oauth redirect rejected"
);
attempt.error(reason)
}
}
}));
if let Some(ref ca_path) = config.ca_cert_path {
let pem = std::fs::read(ca_path)?;
let cert = reqwest::tls::Certificate::from_pem(&pem)?;
http_builder = http_builder.add_root_certificate(cert);
}
let http = http_builder.build()?;
Ok(Self {
jwks_uri: config.jwks_uri.clone(),
ttl,
max_jwks_keys: config.max_jwks_keys,
max_response_bytes: config.jwks_max_response_bytes,
allow_http,
inner: RwLock::new(None),
http,
validation_template: validation,
expected_audience: config.audience.clone(),
strict_audience_validation: config.strict_audience_validation,
scopes: config.scopes.clone(),
role_claim: config.role_claim.clone(),
role_mappings: config.role_mappings.clone(),
last_refresh_attempt: RwLock::new(None),
refresh_lock: tokio::sync::Mutex::new(()),
#[cfg(any(test, feature = "test-helpers"))]
test_allow_loopback_ssrf: false,
})
}
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
#[must_use]
pub fn __test_allow_loopback_ssrf(mut self) -> Self {
self.test_allow_loopback_ssrf = true;
self
}
pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
self.validate_token_with_reason(token).await.ok()
}
pub async fn validate_token_with_reason(
&self,
token: &str,
) -> Result<AuthIdentity, JwtValidationFailure> {
let claims = self.decode_claims(token).await?;
self.check_audience(&claims)?;
let role = self.resolve_role(&claims)?;
let sub = claims.sub;
let name = claims
.extra
.get("preferred_username")
.and_then(|v| v.as_str())
.map(String::from)
.or_else(|| sub.clone())
.or(claims.azp)
.or(claims.client_id)
.unwrap_or_else(|| "oauth-client".into());
Ok(AuthIdentity {
name,
role,
method: AuthMethod::OAuthJwt,
raw_token: None,
sub,
})
}
async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
let (key, alg) = self.select_jwks_key(token).await?;
let mut validation = self.validation_template.clone();
validation.algorithms = vec![alg];
let token_owned = token.to_owned();
let join =
tokio::task::spawn_blocking(move || decode::<Claims>(&token_owned, &key, &validation))
.await;
let decode_result = match join {
Ok(r) => r,
Err(join_err) => {
core::hint::cold_path();
tracing::error!(
error = %join_err,
"JWT decode task panicked or was cancelled"
);
return Err(JwtValidationFailure::Invalid);
}
};
decode_result.map(|td| td.claims).map_err(|e| {
core::hint::cold_path();
let failure = if matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature) {
JwtValidationFailure::Expired
} else {
JwtValidationFailure::Invalid
};
tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
failure
})
}
#[allow(clippy::cognitive_complexity)]
async fn select_jwks_key(
&self,
token: &str,
) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
let Ok(header) = decode_header(token) else {
core::hint::cold_path();
tracing::debug!("JWT header decode failed");
return Err(JwtValidationFailure::Invalid);
};
let kid = header.kid.as_deref();
tracing::debug!(alg = ?header.alg, kid = kid.unwrap_or("-"), "JWT header decoded");
if !ACCEPTED_ALGS.contains(&header.alg) {
core::hint::cold_path();
tracing::debug!(alg = ?header.alg, "JWT algorithm not accepted");
return Err(JwtValidationFailure::Invalid);
}
let Some(key) = self.find_key(kid, header.alg).await else {
core::hint::cold_path();
tracing::debug!(kid = kid.unwrap_or("-"), alg = ?header.alg, "no matching JWKS key found");
return Err(JwtValidationFailure::Invalid);
};
Ok((key, header.alg))
}
fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
let aud_ok = claims.aud.contains(&self.expected_audience)
|| (!self.strict_audience_validation
&& claims
.azp
.as_deref()
.is_some_and(|azp| azp == self.expected_audience));
if aud_ok {
return Ok(());
}
core::hint::cold_path();
tracing::debug!(
aud = ?claims.aud.0,
azp = ?claims.azp,
expected = %self.expected_audience,
strict = self.strict_audience_validation,
"JWT rejected: audience mismatch"
);
Err(JwtValidationFailure::Invalid)
}
fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
if let Some(ref claim_path) = self.role_claim {
let values = resolve_claim_path(&claims.extra, claim_path);
return self
.role_mappings
.iter()
.find(|m| values.contains(&m.claim_value.as_str()))
.map(|m| m.role.clone())
.ok_or(JwtValidationFailure::Invalid);
}
let token_scopes: Vec<&str> = claims
.scope
.as_deref()
.unwrap_or("")
.split_whitespace()
.collect();
self.scopes
.iter()
.find(|m| token_scopes.contains(&m.scope.as_str()))
.map(|m| m.role.clone())
.ok_or(JwtValidationFailure::Invalid)
}
async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
{
let guard = self.inner.read().await;
if let Some(cached) = guard.as_ref()
&& !cached.is_expired()
&& let Some(key) = lookup_key(cached, kid, alg)
{
return Some(key);
}
}
self.refresh_with_cooldown().await;
let guard = self.inner.read().await;
guard
.as_ref()
.and_then(|cached| lookup_key(cached, kid, alg))
}
async fn refresh_with_cooldown(&self) {
let _guard = self.refresh_lock.lock().await;
{
let last = self.last_refresh_attempt.read().await;
if let Some(ts) = *last
&& ts.elapsed() < JWKS_REFRESH_COOLDOWN
{
tracing::debug!(
elapsed_ms = ts.elapsed().as_millis(),
cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
"JWKS refresh skipped (cooldown active)"
);
return;
}
}
{
let mut last = self.last_refresh_attempt.write().await;
*last = Some(Instant::now());
}
let _ = self.refresh_inner().await;
}
async fn refresh_inner(&self) -> Result<(), String> {
let Some(jwks) = self.fetch_jwks().await else {
return Ok(());
};
let (keys, unnamed_keys) = match build_key_cache(&jwks, self.max_jwks_keys) {
Ok(cache) => cache,
Err(msg) => {
tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
return Err(msg);
}
};
tracing::debug!(
named = keys.len(),
unnamed = unnamed_keys.len(),
"JWKS refreshed"
);
let mut guard = self.inner.write().await;
*guard = Some(CachedKeys {
keys,
unnamed_keys,
fetched_at: Instant::now(),
ttl: self.ttl,
});
Ok(())
}
#[allow(
clippy::cognitive_complexity,
reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
)]
async fn fetch_jwks(&self) -> Option<JwkSet> {
#[cfg(any(test, feature = "test-helpers"))]
let screening = if self.test_allow_loopback_ssrf {
screen_oauth_target_with_test_override(&self.jwks_uri, self.allow_http, true).await
} else {
screen_oauth_target(&self.jwks_uri, self.allow_http).await
};
#[cfg(not(any(test, feature = "test-helpers")))]
let screening = screen_oauth_target(&self.jwks_uri, self.allow_http).await;
if let Err(error) = screening {
tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
return None;
}
let mut resp = match self.http.get(&self.jwks_uri).send().await {
Ok(resp) => resp,
Err(e) => {
tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
return None;
}
};
let initial_capacity =
usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
let mut body = Vec::with_capacity(initial_capacity);
while let Some(chunk) = match resp.chunk().await {
Ok(chunk) => chunk,
Err(error) => {
tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
return None;
}
} {
let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
if body_len.saturating_add(chunk_len) > self.max_response_bytes {
tracing::warn!(
uri = %self.jwks_uri,
max_bytes = self.max_response_bytes,
"JWKS response exceeded configured size cap"
);
return None;
}
body.extend_from_slice(&chunk);
}
match serde_json::from_slice::<JwkSet>(&body) {
Ok(jwks) => Some(jwks),
Err(error) => {
tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
None
}
}
}
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
pub async fn __test_refresh_now(&self) -> Result<(), String> {
let jwks = self
.fetch_jwks()
.await
.ok_or_else(|| "failed to fetch or parse JWKS".to_owned())?;
let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;
let mut guard = self.inner.write().await;
*guard = Some(CachedKeys {
keys,
unnamed_keys,
fetched_at: Instant::now(),
ttl: self.ttl,
});
Ok(())
}
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
pub async fn __test_has_kid(&self, kid: &str) -> bool {
let guard = self.inner.read().await;
guard
.as_ref()
.is_some_and(|cache| cache.keys.contains_key(kid))
}
}
fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
if jwks.keys.len() > max_keys {
return Err(format!(
"jwks_key_count_exceeds_cap: got {} keys, max is {}",
jwks.keys.len(),
max_keys
));
}
let mut keys = HashMap::new();
let mut unnamed_keys = Vec::new();
for jwk in &jwks.keys {
let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
continue;
};
let Some(alg) = jwk_algorithm(jwk) else {
continue;
};
if let Some(ref kid) = jwk.common.key_id {
keys.insert(kid.clone(), (alg, decoding_key));
} else {
unnamed_keys.push((alg, decoding_key));
}
}
Ok((keys, unnamed_keys))
}
fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
if let Some(kid) = kid
&& let Some((cached_alg, key)) = cached.keys.get(kid)
&& *cached_alg == alg
{
return Some(key.clone());
}
cached
.unnamed_keys
.iter()
.find(|(a, _)| *a == alg)
.map(|(_, k)| k.clone())
}
#[allow(clippy::wildcard_enum_match_arm)]
fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
jwk.common.key_algorithm.and_then(|ka| match ka {
jsonwebtoken::jwk::KeyAlgorithm::RS256 => Some(Algorithm::RS256),
jsonwebtoken::jwk::KeyAlgorithm::RS384 => Some(Algorithm::RS384),
jsonwebtoken::jwk::KeyAlgorithm::RS512 => Some(Algorithm::RS512),
jsonwebtoken::jwk::KeyAlgorithm::ES256 => Some(Algorithm::ES256),
jsonwebtoken::jwk::KeyAlgorithm::ES384 => Some(Algorithm::ES384),
jsonwebtoken::jwk::KeyAlgorithm::PS256 => Some(Algorithm::PS256),
jsonwebtoken::jwk::KeyAlgorithm::PS384 => Some(Algorithm::PS384),
jsonwebtoken::jwk::KeyAlgorithm::PS512 => Some(Algorithm::PS512),
jsonwebtoken::jwk::KeyAlgorithm::EdDSA => Some(Algorithm::EdDSA),
_ => None,
})
}
fn resolve_claim_path<'a>(
extra: &'a HashMap<String, serde_json::Value>,
path: &str,
) -> Vec<&'a str> {
let mut segments = path.split('.');
let Some(first) = segments.next() else {
return Vec::new();
};
let mut current: Option<&serde_json::Value> = extra.get(first);
for segment in segments {
current = current.and_then(|v| v.get(segment));
}
match current {
Some(serde_json::Value::String(s)) => s.split_whitespace().collect(),
Some(serde_json::Value::Array(arr)) => arr.iter().filter_map(|v| v.as_str()).collect(),
_ => Vec::new(),
}
}
#[derive(Debug, Deserialize)]
struct Claims {
sub: Option<String>,
#[serde(default)]
aud: OneOrMany,
azp: Option<String>,
client_id: Option<String>,
scope: Option<String>,
#[serde(flatten)]
extra: HashMap<String, serde_json::Value>,
}
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);
impl OneOrMany {
fn contains(&self, value: &str) -> bool {
self.0.iter().any(|v| v == value)
}
}
impl<'de> Deserialize<'de> for OneOrMany {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
use serde::de;
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
type Value = OneOrMany;
fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("a string or array of strings")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
Ok(OneOrMany(vec![v.to_owned()]))
}
fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
let mut v = Vec::new();
while let Some(s) = seq.next_element::<String>()? {
v.push(s);
}
Ok(OneOrMany(v))
}
}
deserializer.deserialize_any(Visitor)
}
}
#[must_use]
pub fn looks_like_jwt(token: &str) -> bool {
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
let mut parts = token.splitn(4, '.');
let Some(header_b64) = parts.next() else {
return false;
};
if parts.next().is_none() || parts.next().is_none() || parts.next().is_some() {
return false;
}
let Ok(header_bytes) = URL_SAFE_NO_PAD.decode(header_b64) else {
return false;
};
let Ok(header) = serde_json::from_slice::<serde_json::Value>(&header_bytes) else {
return false;
};
header.get("alg").is_some()
}
#[must_use]
pub fn protected_resource_metadata(
resource_url: &str,
server_url: &str,
config: &OAuthConfig,
) -> serde_json::Value {
let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
let auth_server = server_url;
serde_json::json!({
"resource": resource_url,
"authorization_servers": [auth_server],
"scopes_supported": scopes,
"bearer_methods_supported": ["header"]
})
}
#[must_use]
pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
let scopes: Vec<&str> = config.scopes.iter().map(|s| s.scope.as_str()).collect();
let mut meta = serde_json::json!({
"issuer": &config.issuer,
"authorization_endpoint": format!("{server_url}/authorize"),
"token_endpoint": format!("{server_url}/token"),
"registration_endpoint": format!("{server_url}/register"),
"response_types_supported": ["code"],
"grant_types_supported": ["authorization_code", "refresh_token"],
"code_challenge_methods_supported": ["S256"],
"scopes_supported": scopes,
"token_endpoint_auth_methods_supported": ["none"],
});
if let Some(proxy) = &config.proxy
&& proxy.expose_admin_endpoints
&& let Some(obj) = meta.as_object_mut()
{
if proxy.introspection_url.is_some() {
obj.insert(
"introspection_endpoint".into(),
serde_json::Value::String(format!("{server_url}/introspect")),
);
}
if proxy.revocation_url.is_some() {
obj.insert(
"revocation_endpoint".into(),
serde_json::Value::String(format!("{server_url}/revoke")),
);
}
if proxy.require_auth_on_admin_endpoints {
obj.insert(
"introspection_endpoint_auth_methods_supported".into(),
serde_json::json!(["bearer"]),
);
obj.insert(
"revocation_endpoint_auth_methods_supported".into(),
serde_json::json!(["bearer"]),
);
}
}
meta
}
#[must_use]
pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
use axum::{
http::{StatusCode, header},
response::IntoResponse,
};
let upstream_query = replace_client_id(query, &proxy.client_id);
let redirect_url = format!("{}?{upstream_query}", proxy.authorize_url);
(StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
}
pub async fn handle_token(
http: &OauthHttpClient,
proxy: &OAuthProxyConfig,
body: &str,
) -> axum::response::Response {
use axum::{
http::{StatusCode, header},
response::IntoResponse,
};
let mut upstream_body = replace_client_id(body, &proxy.client_id);
if let Some(ref secret) = proxy.client_secret {
use std::fmt::Write;
use secrecy::ExposeSecret;
let _ = write!(
upstream_body,
"&client_secret={}",
urlencoding::encode(secret.expose_secret())
);
}
let result = http
.send_screened(
&proxy.token_url,
http.inner
.post(&proxy.token_url)
.header("Content-Type", "application/x-www-form-urlencoded")
.body(upstream_body),
)
.await;
match result {
Ok(resp) => {
let status =
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let body_bytes = resp.bytes().await.unwrap_or_default();
(
status,
[(header::CONTENT_TYPE, "application/json")],
body_bytes,
)
.into_response()
}
Err(e) => {
tracing::error!(error = %e, "OAuth token proxy request failed");
(
StatusCode::BAD_GATEWAY,
[(header::CONTENT_TYPE, "application/json")],
"{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
)
.into_response()
}
}
}
#[must_use]
pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
let mut resp = serde_json::json!({
"client_id": proxy.client_id,
"token_endpoint_auth_method": "none",
});
if let Some(uris) = body.get("redirect_uris")
&& let Some(obj) = resp.as_object_mut()
{
obj.insert("redirect_uris".into(), uris.clone());
}
if let Some(name) = body.get("client_name")
&& let Some(obj) = resp.as_object_mut()
{
obj.insert("client_name".into(), name.clone());
}
resp
}
pub async fn handle_introspect(
http: &OauthHttpClient,
proxy: &OAuthProxyConfig,
body: &str,
) -> axum::response::Response {
let Some(ref url) = proxy.introspection_url else {
return oauth_error_response(
axum::http::StatusCode::NOT_FOUND,
"not_supported",
"introspection endpoint is not configured",
);
};
proxy_oauth_admin_request(http, proxy, url, body).await
}
pub async fn handle_revoke(
http: &OauthHttpClient,
proxy: &OAuthProxyConfig,
body: &str,
) -> axum::response::Response {
let Some(ref url) = proxy.revocation_url else {
return oauth_error_response(
axum::http::StatusCode::NOT_FOUND,
"not_supported",
"revocation endpoint is not configured",
);
};
proxy_oauth_admin_request(http, proxy, url, body).await
}
async fn proxy_oauth_admin_request(
http: &OauthHttpClient,
proxy: &OAuthProxyConfig,
upstream_url: &str,
body: &str,
) -> axum::response::Response {
use axum::{
http::{StatusCode, header},
response::IntoResponse,
};
let mut upstream_body = replace_client_id(body, &proxy.client_id);
if let Some(ref secret) = proxy.client_secret {
use std::fmt::Write;
use secrecy::ExposeSecret;
let _ = write!(
upstream_body,
"&client_secret={}",
urlencoding::encode(secret.expose_secret())
);
}
let result = http
.send_screened(
upstream_url,
http.inner
.post(upstream_url)
.header("Content-Type", "application/x-www-form-urlencoded")
.body(upstream_body),
)
.await;
match result {
Ok(resp) => {
let status =
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let content_type = resp
.headers()
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("application/json")
.to_owned();
let body_bytes = resp.bytes().await.unwrap_or_default();
(status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
}
Err(e) => {
tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
oauth_error_response(
StatusCode::BAD_GATEWAY,
"server_error",
"upstream endpoint unreachable",
)
}
}
}
fn oauth_error_response(
status: axum::http::StatusCode,
error: &str,
description: &str,
) -> axum::response::Response {
use axum::{http::header, response::IntoResponse};
let body = serde_json::json!({
"error": error,
"error_description": description,
});
(
status,
[(header::CONTENT_TYPE, "application/json")],
body.to_string(),
)
.into_response()
}
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
error: String,
error_description: Option<String>,
}
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
match raw {
"invalid_request" => "invalid_request",
"invalid_client" => "invalid_client",
"invalid_grant" => "invalid_grant",
"unauthorized_client" => "unauthorized_client",
"unsupported_grant_type" => "unsupported_grant_type",
"invalid_scope" => "invalid_scope",
"temporarily_unavailable" => "temporarily_unavailable",
"invalid_target" => "invalid_target",
_ => "server_error",
}
}
pub async fn exchange_token(
http: &OauthHttpClient,
config: &TokenExchangeConfig,
subject_token: &str,
) -> Result<ExchangedToken, crate::error::McpxError> {
use secrecy::ExposeSecret;
let mut req = http
.inner
.post(&config.token_url)
.header("Content-Type", "application/x-www-form-urlencoded")
.header("Accept", "application/json");
if let Some(ref secret) = config.client_secret {
use base64::Engine;
let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
"{}:{}",
urlencoding::encode(&config.client_id),
urlencoding::encode(secret.expose_secret()),
));
req = req.header("Authorization", format!("Basic {credentials}"));
}
let form_body = build_exchange_form(config, subject_token);
let resp = http
.send_screened(&config.token_url, req.body(form_body))
.await
.map_err(|e| {
tracing::error!(error = %e, "token exchange request failed");
crate::error::McpxError::Auth("server_error".into())
})?;
let status = resp.status();
let body_bytes = resp.bytes().await.map_err(|e| {
tracing::error!(error = %e, "failed to read token exchange response");
crate::error::McpxError::Auth("server_error".into())
})?;
if !status.is_success() {
core::hint::cold_path();
let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
let short_code = parsed
.as_ref()
.map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
if let Some(ref e) = parsed {
tracing::warn!(
status = %status,
upstream_error = %e.error,
upstream_error_description = e.error_description.as_deref().unwrap_or(""),
client_code = %short_code,
"token exchange rejected by authorization server",
);
} else {
tracing::warn!(
status = %status,
client_code = %short_code,
"token exchange rejected (unparseable upstream body)",
);
}
return Err(crate::error::McpxError::Auth(short_code.into()));
}
let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
tracing::error!(error = %e, "failed to parse token exchange response");
crate::error::McpxError::Auth("server_error".into())
})?;
log_exchanged_token(&exchanged);
Ok(exchanged)
}
fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
let body = format!(
"grant_type={}&subject_token={}&subject_token_type={}&requested_token_type={}&audience={}",
urlencoding::encode("urn:ietf:params:oauth:grant-type:token-exchange"),
urlencoding::encode(subject_token),
urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
urlencoding::encode("urn:ietf:params:oauth:token-type:access_token"),
urlencoding::encode(&config.audience),
);
if config.client_secret.is_none() {
format!(
"{body}&client_id={}",
urlencoding::encode(&config.client_id)
)
} else {
body
}
}
fn log_exchanged_token(exchanged: &ExchangedToken) {
use base64::Engine;
if !looks_like_jwt(&exchanged.access_token) {
tracing::debug!(
token_len = exchanged.access_token.len(),
issued_token_type = ?exchanged.issued_token_type,
expires_in = exchanged.expires_in,
"exchanged token (opaque)",
);
return;
}
let Some(payload) = exchanged.access_token.split('.').nth(1) else {
return;
};
let Ok(decoded) = base64::engine::general_purpose::URL_SAFE_NO_PAD.decode(payload) else {
return;
};
let Ok(claims) = serde_json::from_slice::<serde_json::Value>(&decoded) else {
return;
};
tracing::debug!(
sub = ?claims.get("sub"),
aud = ?claims.get("aud"),
azp = ?claims.get("azp"),
iss = ?claims.get("iss"),
expires_in = exchanged.expires_in,
"exchanged token claims (JWT)",
);
}
fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
let encoded_id = urlencoding::encode(upstream_client_id);
let mut parts: Vec<String> = params
.split('&')
.filter(|p| !p.starts_with("client_id="))
.map(String::from)
.collect();
parts.push(format!("client_id={encoded_id}"));
parts.join("&")
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use super::*;
#[test]
fn looks_like_jwt_valid() {
let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
let payload = URL_SAFE_NO_PAD.encode(b"{}");
let token = format!("{header}.{payload}.signature");
assert!(looks_like_jwt(&token));
}
#[test]
fn looks_like_jwt_rejects_opaque_token() {
assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
}
#[test]
fn looks_like_jwt_rejects_two_segments() {
let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
let token = format!("{header}.payload");
assert!(!looks_like_jwt(&token));
}
#[test]
fn looks_like_jwt_rejects_four_segments() {
assert!(!looks_like_jwt("a.b.c.d"));
}
#[test]
fn looks_like_jwt_rejects_no_alg() {
let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
let payload = URL_SAFE_NO_PAD.encode(b"{}");
let token = format!("{header}.{payload}.sig");
assert!(!looks_like_jwt(&token));
}
#[test]
fn protected_resource_metadata_shape() {
let config = OAuthConfig {
issuer: "https://auth.example.com".into(),
audience: "https://mcp.example.com/mcp".into(),
jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
scopes: vec![
ScopeMapping {
scope: "mcp:read".into(),
role: "viewer".into(),
},
ScopeMapping {
scope: "mcp:admin".into(),
role: "ops".into(),
},
],
role_claim: None,
role_mappings: vec![],
jwks_cache_ttl: "10m".into(),
proxy: None,
token_exchange: None,
ca_cert_path: None,
allow_http_oauth_urls: false,
max_jwks_keys: default_max_jwks_keys(),
strict_audience_validation: false,
jwks_max_response_bytes: default_jwks_max_bytes(),
};
let meta = protected_resource_metadata(
"https://mcp.example.com/mcp",
"https://mcp.example.com",
&config,
);
assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
assert_eq!(meta["bearer_methods_supported"][0], "header");
}
fn validation_https_config() -> OAuthConfig {
OAuthConfig::builder(
"https://auth.example.com",
"mcp",
"https://auth.example.com/.well-known/jwks.json",
)
.build()
}
#[test]
fn validate_accepts_all_https_urls() {
let cfg = validation_https_config();
cfg.validate().expect("all-HTTPS config must validate");
}
#[test]
fn validate_rejects_http_jwks_uri() {
let mut cfg = validation_https_config();
cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
let err = cfg.validate().expect_err("http jwks_uri must be rejected");
let msg = err.to_string();
assert!(
msg.contains("oauth.jwks_uri") && msg.contains("https"),
"error must reference offending field + scheme requirement; got {msg:?}"
);
}
#[test]
fn validate_rejects_http_proxy_authorize_url() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"http://idp.example.com/authorize", "https://idp.example.com/token",
"client",
)
.build(),
);
let err = cfg
.validate()
.expect_err("http authorize_url must be rejected");
assert!(
err.to_string().contains("oauth.proxy.authorize_url"),
"error must reference proxy.authorize_url; got {err}"
);
}
#[test]
fn validate_rejects_http_proxy_token_url() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"http://idp.example.com/token", "client",
)
.build(),
);
let err = cfg.validate().expect_err("http token_url must be rejected");
assert!(
err.to_string().contains("oauth.proxy.token_url"),
"error must reference proxy.token_url; got {err}"
);
}
#[test]
fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"https://idp.example.com/token",
"client",
)
.introspection_url("http://idp.example.com/introspect")
.build(),
);
let err = cfg
.validate()
.expect_err("http introspection_url must be rejected");
assert!(err.to_string().contains("oauth.proxy.introspection_url"));
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"https://idp.example.com/token",
"client",
)
.revocation_url("http://idp.example.com/revoke")
.build(),
);
let err = cfg
.validate()
.expect_err("http revocation_url must be rejected");
assert!(err.to_string().contains("oauth.proxy.revocation_url"));
}
#[test]
fn validate_rejects_http_token_exchange_url() {
let mut cfg = validation_https_config();
cfg.token_exchange = Some(TokenExchangeConfig::new(
"http://idp.example.com/token".into(), "client".into(),
None,
None,
"downstream".into(),
));
let err = cfg
.validate()
.expect_err("http token_exchange.token_url must be rejected");
assert!(
err.to_string().contains("oauth.token_exchange.token_url"),
"error must reference token_exchange.token_url; got {err}"
);
}
#[test]
fn validate_rejects_unparseable_url() {
let mut cfg = validation_https_config();
cfg.jwks_uri = "not a url".into();
let err = cfg
.validate()
.expect_err("unparseable URL must be rejected");
assert!(err.to_string().contains("invalid URL"));
}
#[test]
fn validate_rejects_non_http_scheme() {
let mut cfg = validation_https_config();
cfg.jwks_uri = "file:///etc/passwd".into();
let err = cfg.validate().expect_err("file:// scheme must be rejected");
let msg = err.to_string();
assert!(
msg.contains("must use https scheme") && msg.contains("file"),
"error must reject non-http(s) schemes; got {msg:?}"
);
}
#[test]
fn validate_accepts_http_with_escape_hatch() {
let mut cfg = OAuthConfig::builder(
"http://auth.local",
"mcp",
"http://auth.local/.well-known/jwks.json",
)
.allow_http_oauth_urls(true)
.build();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"http://idp.local/authorize",
"http://idp.local/token",
"client",
)
.introspection_url("http://idp.local/introspect")
.revocation_url("http://idp.local/revoke")
.build(),
);
cfg.token_exchange = Some(TokenExchangeConfig::new(
"http://idp.local/token".into(),
"client".into(),
None,
None,
"downstream".into(),
));
cfg.validate()
.expect("escape hatch must permit http on all URL fields");
}
#[test]
fn validate_with_escape_hatch_still_rejects_unparseable() {
let mut cfg = validation_https_config();
cfg.allow_http_oauth_urls = true;
cfg.jwks_uri = "::not-a-url::".into();
cfg.validate()
.expect_err("escape hatch must NOT bypass URL parsing");
}
#[tokio::test]
async fn jwks_cache_rejects_redirect_downgrade_to_http() {
rustls::crypto::ring::default_provider()
.install_default()
.ok();
let policy = reqwest::redirect::Policy::custom(|attempt| {
if attempt.url().scheme() != "https" {
attempt.error("redirect to non-HTTPS URL refused")
} else if attempt.previous().len() >= 2 {
attempt.error("too many redirects (max 2)")
} else {
attempt.follow()
}
});
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(3))
.redirect(policy)
.build()
.expect("test client builds");
let mock = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(
wiremock::ResponseTemplate::new(302)
.insert_header("location", "http://example.invalid/jwks.json"),
)
.mount(&mock)
.await;
let url = format!("{}/jwks.json", mock.uri());
let err = client
.get(&url)
.send()
.await
.expect_err("redirect policy must reject scheme downgrade");
let chain = format!("{err:#}");
assert!(
chain.contains("redirect to non-HTTPS URL refused")
|| chain.to_lowercase().contains("redirect"),
"error must surface redirect-policy rejection; got {chain:?}"
);
}
use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
let mut rng = rsa::rand_core::OsRng;
let private_key = rsa::RsaPrivateKey::new(&mut rng, 2048).expect("keypair generation");
let private_pem = private_key
.to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
.expect("PKCS8 PEM export")
.to_string();
let public_key = private_key.to_public_key();
let n = URL_SAFE_NO_PAD.encode(public_key.n().to_bytes_be());
let e = URL_SAFE_NO_PAD.encode(public_key.e().to_bytes_be());
let jwks = serde_json::json!({
"keys": [{
"kty": "RSA",
"use": "sig",
"alg": "RS256",
"kid": kid,
"n": n,
"e": e
}]
});
(private_pem, jwks)
}
fn mint_token(
private_pem: &str,
kid: &str,
issuer: &str,
audience: &str,
subject: &str,
scope: &str,
) -> String {
let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
.expect("encoding key from PEM");
let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
header.kid = Some(kid.into());
let now = jsonwebtoken::get_current_timestamp();
let claims = serde_json::json!({
"iss": issuer,
"aud": audience,
"sub": subject,
"scope": scope,
"exp": now + 3600,
"iat": now,
});
jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
}
fn test_config(jwks_uri: &str) -> OAuthConfig {
OAuthConfig {
issuer: "https://auth.test.local".into(),
audience: "https://mcp.test.local/mcp".into(),
jwks_uri: jwks_uri.into(),
scopes: vec![
ScopeMapping {
scope: "mcp:read".into(),
role: "viewer".into(),
},
ScopeMapping {
scope: "mcp:admin".into(),
role: "ops".into(),
},
],
role_claim: None,
role_mappings: vec![],
jwks_cache_ttl: "5m".into(),
proxy: None,
token_exchange: None,
ca_cert_path: None,
allow_http_oauth_urls: true,
max_jwks_keys: default_max_jwks_keys(),
strict_audience_validation: false,
jwks_max_response_bytes: default_jwks_max_bytes(),
}
}
fn test_cache(config: &OAuthConfig) -> JwksCache {
JwksCache::new(config).unwrap().__test_allow_loopback_ssrf()
}
#[tokio::test]
async fn valid_jwt_returns_identity() {
let kid = "test-key-1";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"ci-bot",
"mcp:read mcp:other",
);
let identity = cache.validate_token(&token).await;
assert!(identity.is_some(), "valid JWT should authenticate");
let id = identity.unwrap();
assert_eq!(id.name, "ci-bot");
assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
}
#[tokio::test]
async fn wrong_issuer_rejected() {
let kid = "test-key-2";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
"attacker",
"mcp:admin",
);
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn wrong_audience_rejected() {
let kid = "test-key-3";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://wrong-audience.example.com", "attacker",
"mcp:admin",
);
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn expired_jwt_rejected() {
let kid = "test-key-4";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let encoding_key =
jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
header.kid = Some(kid.into());
let now = jsonwebtoken::get_current_timestamp();
let claims = serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://mcp.test.local/mcp",
"sub": "expired-bot",
"scope": "mcp:read",
"exp": now - 120,
"iat": now - 3720,
});
let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn no_matching_scope_rejected() {
let kid = "test-key-5";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"limited-bot",
"some:other:scope", );
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn wrong_signing_key_rejected() {
let kid = "test-key-6";
let (_pem, jwks) = generate_test_keypair(kid);
let (attacker_pem, _) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&attacker_pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"attacker",
"mcp:admin",
);
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn admin_scope_maps_to_ops_role() {
let kid = "test-key-7";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"admin-bot",
"mcp:admin",
);
let id = cache
.validate_token(&token)
.await
.expect("should authenticate");
assert_eq!(id.role, "ops");
assert_eq!(id.name, "admin-bot");
}
#[tokio::test]
async fn jwks_server_down_returns_none() {
let config = test_config("http://127.0.0.1:1/jwks.json");
let cache = test_cache(&config);
let kid = "orphan-key";
let (pem, _) = generate_test_keypair(kid);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"bot",
"mcp:read",
);
assert!(cache.validate_token(&token).await.is_none());
}
#[test]
fn resolve_claim_path_flat_string() {
let mut extra = HashMap::new();
extra.insert(
"scope".into(),
serde_json::Value::String("mcp:read mcp:admin".into()),
);
let values = resolve_claim_path(&extra, "scope");
assert_eq!(values, vec!["mcp:read", "mcp:admin"]);
}
#[test]
fn resolve_claim_path_flat_array() {
let mut extra = HashMap::new();
extra.insert(
"roles".into(),
serde_json::json!(["mcp-admin", "mcp-viewer"]),
);
let values = resolve_claim_path(&extra, "roles");
assert_eq!(values, vec!["mcp-admin", "mcp-viewer"]);
}
#[test]
fn resolve_claim_path_nested_keycloak() {
let mut extra = HashMap::new();
extra.insert(
"realm_access".into(),
serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
);
let values = resolve_claim_path(&extra, "realm_access.roles");
assert_eq!(values, vec!["uma_authorization", "mcp-admin"]);
}
#[test]
fn resolve_claim_path_missing_returns_empty() {
let extra = HashMap::new();
assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
}
#[test]
fn resolve_claim_path_numeric_leaf_returns_empty() {
let mut extra = HashMap::new();
extra.insert("count".into(), serde_json::json!(42));
assert!(resolve_claim_path(&extra, "count").is_empty());
}
fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
let encoding_key = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
.expect("encoding key from PEM");
let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
header.kid = Some(kid.into());
jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding")
}
fn test_config_with_role_claim(
jwks_uri: &str,
role_claim: &str,
role_mappings: Vec<RoleMapping>,
) -> OAuthConfig {
OAuthConfig {
issuer: "https://auth.test.local".into(),
audience: "https://mcp.test.local/mcp".into(),
jwks_uri: jwks_uri.into(),
scopes: vec![],
role_claim: Some(role_claim.into()),
role_mappings,
jwks_cache_ttl: "5m".into(),
proxy: None,
token_exchange: None,
ca_cert_path: None,
allow_http_oauth_urls: true,
max_jwks_keys: default_max_jwks_keys(),
strict_audience_validation: false,
jwks_max_response_bytes: default_jwks_max_bytes(),
}
}
#[tokio::test]
async fn screen_oauth_target_rejects_literal_ip() {
let err = screen_oauth_target("https://127.0.0.1/jwks.json", false)
.await
.expect_err("literal IPs must be rejected");
let msg = err.to_string();
assert!(msg.contains("literal IPv4 addresses are forbidden"));
}
#[tokio::test]
async fn screen_oauth_target_rejects_private_dns_resolution() {
let err = screen_oauth_target("https://localhost/jwks.json", false)
.await
.expect_err("localhost resolution must be rejected");
let msg = err.to_string();
assert!(
msg.contains("blocked IP") && msg.contains("loopback"),
"got {msg:?}"
);
}
#[tokio::test]
async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
let err = screen_oauth_target("http://127.0.0.1/jwks.json", true)
.await
.expect_err("literal IPs must still be rejected when http is allowed");
let msg = err.to_string();
assert!(msg.contains("literal IPv4 addresses are forbidden"));
}
#[tokio::test]
async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
let err = screen_oauth_target("http://localhost/jwks.json", true)
.await
.expect_err("private DNS resolution must still be rejected when http is allowed");
let msg = err.to_string();
assert!(
msg.contains("blocked IP") && msg.contains("loopback"),
"got {msg:?}"
);
}
#[tokio::test]
async fn screen_oauth_target_allows_public_hostname() {
screen_oauth_target("https://example.com/.well-known/jwks.json", false)
.await
.expect("public hostname should pass screening");
}
#[tokio::test]
async fn audience_falls_back_to_azp_by_default() {
let kid = "test-audience-azp-default";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://some-other-resource.example.com",
"azp": "https://mcp.test.local/mcp",
"sub": "compat-client",
"scope": "mcp:read",
"exp": now + 3600,
"iat": now,
}),
);
let identity = cache
.validate_token_with_reason(&token)
.await
.expect("azp fallback should remain enabled by default");
assert_eq!(identity.role, "viewer");
}
#[tokio::test]
async fn strict_audience_validation_rejects_azp_only_match() {
let kid = "test-audience-azp-strict";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let mut config = test_config(&jwks_uri);
config.strict_audience_validation = true;
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://some-other-resource.example.com",
"azp": "https://mcp.test.local/mcp",
"sub": "strict-client",
"scope": "mcp:read",
"exp": now + 3600,
"iat": now,
}),
);
let failure = cache
.validate_token_with_reason(&token)
.await
.expect_err("strict audience validation must ignore azp fallback");
assert_eq!(failure, JwtValidationFailure::Invalid);
}
#[derive(Clone, Default)]
struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);
impl CapturedLogs {
fn contents(&self) -> String {
let bytes = self.0.lock().map(|guard| guard.clone()).unwrap_or_default();
String::from_utf8(bytes).unwrap_or_default()
}
}
struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);
impl std::io::Write for CapturedLogsWriter {
fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
if let Ok(mut guard) = self.0.lock() {
guard.extend_from_slice(buf);
}
Ok(buf.len())
}
fn flush(&mut self) -> std::io::Result<()> {
Ok(())
}
}
impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
type Writer = CapturedLogsWriter;
fn make_writer(&'a self) -> Self::Writer {
CapturedLogsWriter(Arc::clone(&self.0))
}
}
#[tokio::test]
async fn jwks_response_size_cap_returns_none_and_logs_warning() {
let kid = "oversized-jwks";
let (_pem, jwks) = generate_test_keypair(kid);
let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
oversized_body.push_str(&" ".repeat(4096));
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(
wiremock::ResponseTemplate::new(200)
.insert_header("content-type", "application/json")
.set_body_string(oversized_body),
)
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let mut config = test_config(&jwks_uri);
config.jwks_max_response_bytes = 256;
let cache = test_cache(&config);
let logs = CapturedLogs::default();
let subscriber = tracing_subscriber::fmt()
.with_writer(logs.clone())
.with_ansi(false)
.without_time()
.finish();
let _guard = tracing::subscriber::set_default(subscriber);
let result = cache.fetch_jwks().await;
assert!(result.is_none(), "oversized JWKS must be dropped");
assert!(
logs.contents()
.contains("JWKS response exceeded configured size cap"),
"expected cap-exceeded warning in logs"
);
}
#[tokio::test]
async fn role_claim_keycloak_nested_array() {
let kid = "test-role-1";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config_with_role_claim(
&jwks_uri,
"realm_access.roles",
vec![
RoleMapping {
claim_value: "mcp-admin".into(),
role: "ops".into(),
},
RoleMapping {
claim_value: "mcp-viewer".into(),
role: "viewer".into(),
},
],
);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://mcp.test.local/mcp",
"sub": "keycloak-user",
"exp": now + 3600,
"iat": now,
"realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
}),
);
let id = cache
.validate_token(&token)
.await
.expect("should authenticate");
assert_eq!(id.name, "keycloak-user");
assert_eq!(id.role, "ops");
}
#[tokio::test]
async fn role_claim_flat_roles_array() {
let kid = "test-role-2";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config_with_role_claim(
&jwks_uri,
"roles",
vec![
RoleMapping {
claim_value: "MCP.Admin".into(),
role: "ops".into(),
},
RoleMapping {
claim_value: "MCP.Reader".into(),
role: "viewer".into(),
},
],
);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://mcp.test.local/mcp",
"sub": "azure-ad-user",
"exp": now + 3600,
"iat": now,
"roles": ["MCP.Reader", "OtherApp.Admin"]
}),
);
let id = cache
.validate_token(&token)
.await
.expect("should authenticate");
assert_eq!(id.name, "azure-ad-user");
assert_eq!(id.role, "viewer");
}
#[tokio::test]
async fn role_claim_no_matching_value_rejected() {
let kid = "test-role-3";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config_with_role_claim(
&jwks_uri,
"roles",
vec![RoleMapping {
claim_value: "mcp-admin".into(),
role: "ops".into(),
}],
);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://mcp.test.local/mcp",
"sub": "limited-user",
"exp": now + 3600,
"iat": now,
"roles": ["some-other-role"]
}),
);
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn role_claim_space_separated_string() {
let kid = "test-role-4";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config_with_role_claim(
&jwks_uri,
"custom_scope",
vec![
RoleMapping {
claim_value: "write".into(),
role: "ops".into(),
},
RoleMapping {
claim_value: "read".into(),
role: "viewer".into(),
},
],
);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://mcp.test.local/mcp",
"sub": "custom-client",
"exp": now + 3600,
"iat": now,
"custom_scope": "read audit"
}),
);
let id = cache
.validate_token(&token)
.await
.expect("should authenticate");
assert_eq!(id.name, "custom-client");
assert_eq!(id.role, "viewer");
}
#[tokio::test]
async fn scope_backward_compat_without_role_claim() {
let kid = "test-compat-1";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri); let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"legacy-bot",
"mcp:admin other:scope",
);
let id = cache
.validate_token(&token)
.await
.expect("should authenticate");
assert_eq!(id.name, "legacy-bot");
assert_eq!(id.role, "ops"); }
#[tokio::test]
async fn jwks_refresh_deduplication() {
let kid = "test-dedup";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.expect(1) .mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = Arc::new(test_cache(&config));
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"concurrent-bot",
"mcp:read",
);
let mut handles = Vec::new();
for _ in 0..5 {
let c = Arc::clone(&cache);
let t = token.clone();
handles.push(tokio::spawn(async move { c.validate_token(&t).await }));
}
for h in handles {
let result = h.await.unwrap();
assert!(result.is_some(), "all concurrent requests should succeed");
}
}
#[tokio::test]
async fn jwks_refresh_cooldown_blocks_rapid_requests() {
let kid = "test-cooldown";
let (_pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.expect(1) .mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let fake_token1 =
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig";
let _ = cache.validate_token(fake_token1).await;
let fake_token2 =
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig";
let _ = cache.validate_token(fake_token2).await;
let fake_token3 =
"eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig";
let _ = cache.validate_token(fake_token3).await;
}
fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
OAuthProxyConfig {
authorize_url: "https://example.invalid/auth".into(),
token_url: token_url.into(),
client_id: "mcp-client".into(),
client_secret: Some(secrecy::SecretString::from("shh".to_owned())),
introspection_url: None,
revocation_url: None,
expose_admin_endpoints: false,
require_auth_on_admin_endpoints: false,
}
}
fn test_http_client() -> OauthHttpClient {
rustls::crypto::ring::default_provider()
.install_default()
.ok();
let config = OAuthConfig::builder(
"https://auth.test.local",
"https://mcp.test.local/mcp",
"https://auth.test.local/.well-known/jwks.json",
)
.allow_http_oauth_urls(true)
.build();
OauthHttpClient::with_config(&config)
.expect("build test http client")
.__test_allow_loopback_ssrf()
}
#[tokio::test]
async fn introspect_proxies_and_injects_client_credentials() {
use wiremock::matchers::{body_string_contains, method, path};
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(method("POST"))
.and(path("/introspect"))
.and(body_string_contains("client_id=mcp-client"))
.and(body_string_contains("client_secret=shh"))
.and(body_string_contains("token=abc"))
.respond_with(
wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
"active": true,
"scope": "read"
})),
)
.expect(1)
.mount(&mock_server)
.await;
let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
proxy.introspection_url = Some(format!("{}/introspect", mock_server.uri()));
let http = test_http_client();
let resp = handle_introspect(&http, &proxy, "token=abc").await;
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn introspect_returns_404_when_not_configured() {
let proxy = proxy_cfg("https://example.invalid/token");
let http = test_http_client();
let resp = handle_introspect(&http, &proxy, "token=abc").await;
assert_eq!(resp.status(), 404);
}
#[tokio::test]
async fn revoke_proxies_and_returns_upstream_status() {
use wiremock::matchers::{method, path};
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(method("POST"))
.and(path("/revoke"))
.respond_with(wiremock::ResponseTemplate::new(200))
.expect(1)
.mount(&mock_server)
.await;
let mut proxy = proxy_cfg(&format!("{}/token", mock_server.uri()));
proxy.revocation_url = Some(format!("{}/revoke", mock_server.uri()));
let http = test_http_client();
let resp = handle_revoke(&http, &proxy, "token=abc").await;
assert_eq!(resp.status(), 200);
}
#[tokio::test]
async fn revoke_returns_404_when_not_configured() {
let proxy = proxy_cfg("https://example.invalid/token");
let http = test_http_client();
let resp = handle_revoke(&http, &proxy, "token=abc").await;
assert_eq!(resp.status(), 404);
}
#[test]
fn metadata_advertises_endpoints_only_when_configured() {
let mut cfg = test_config("https://auth.test.local/jwks.json");
let m = authorization_server_metadata("https://mcp.local", &cfg);
assert!(m.get("introspection_endpoint").is_none());
assert!(m.get("revocation_endpoint").is_none());
let mut proxy = proxy_cfg("https://upstream.local/token");
proxy.introspection_url = Some("https://upstream.local/introspect".into());
proxy.revocation_url = Some("https://upstream.local/revoke".into());
cfg.proxy = Some(proxy);
let m = authorization_server_metadata("https://mcp.local", &cfg);
assert!(
m.get("introspection_endpoint").is_none(),
"introspection must not be advertised when expose_admin_endpoints=false"
);
assert!(
m.get("revocation_endpoint").is_none(),
"revocation must not be advertised when expose_admin_endpoints=false"
);
if let Some(p) = cfg.proxy.as_mut() {
p.expose_admin_endpoints = true;
p.revocation_url = None;
}
let m = authorization_server_metadata("https://mcp.local", &cfg);
assert_eq!(
m["introspection_endpoint"],
serde_json::Value::String("https://mcp.local/introspect".into())
);
assert!(m.get("revocation_endpoint").is_none());
if let Some(p) = cfg.proxy.as_mut() {
p.revocation_url = Some("https://upstream.local/revoke".into());
}
let m = authorization_server_metadata("https://mcp.local", &cfg);
assert_eq!(
m["revocation_endpoint"],
serde_json::Value::String("https://mcp.local/revoke".into())
);
}
}