use std::{
collections::HashMap,
path::PathBuf,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
time::{Duration, Instant},
};
use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode, decode_header, jwk::JwkSet};
use serde::Deserialize;
use tokio::{net::lookup_host, sync::RwLock};
use crate::auth::{AuthIdentity, AuthMethod};
/// Decides whether one redirect hop of an OAuth HTTP request may be followed.
///
/// The checks run in this order:
/// 1. Scheme: an `https` target passes. A non-`https` target is refused when
///    the most recent URL in the chain was `https`; otherwise only `http` is
///    accepted, and only when `allow_http` is set.
/// 2. SSRF: the target URL is handed to
///    `crate::ssrf::redirect_target_reason_with_allowlist`; any reason it
///    returns rejects the hop.
/// 3. Chain length: the hop is refused once `attempt.previous()` already
///    holds two or more URLs.
///
/// Returns `Ok(())` to follow the redirect, or `Err` with a human-readable
/// rejection reason.
fn evaluate_oauth_redirect(
    attempt: &reqwest::redirect::Attempt<'_>,
    allow_http: bool,
    allowlist: &crate::ssrf::CompiledSsrfAllowlist,
) -> Result<(), String> {
    let chain = attempt.previous();
    let came_from_https = matches!(chain.last(), Some(prev) if prev.scheme() == "https");
    let destination = attempt.url();

    match destination.scheme() {
        "https" => {}
        // Any non-https hop after an https one is treated as a downgrade.
        _ if came_from_https => return Err("redirect downgrades https -> http".to_owned()),
        "http" if allow_http => {}
        _ => return Err("redirect to non-HTTP(S) URL refused".to_owned()),
    }

    if let Some(reason) =
        crate::ssrf::redirect_target_reason_with_allowlist(destination, allowlist)
    {
        return Err(format!("redirect target forbidden: {reason}"));
    }

    if chain.len() >= 2 {
        return Err("too many redirects (max 2)".to_owned());
    }
    Ok(())
}
#[cfg_attr(not(any(test, feature = "test-helpers")), allow(dead_code))]
async fn screen_oauth_target_with_test_override(
url: &str,
allow_http: bool,
allowlist: &crate::ssrf::CompiledSsrfAllowlist,
#[cfg(any(test, feature = "test-helpers"))] test_allow_loopback_ssrf: bool,
) -> Result<(), crate::error::McpxError> {
let parsed = check_oauth_url("oauth target", url, allow_http)?;
#[cfg(any(test, feature = "test-helpers"))]
if test_allow_loopback_ssrf {
return Ok(());
}
if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
return Err(crate::error::McpxError::Config(format!(
"OAuth target forbidden ({reason}): {url}"
)));
}
let host = parsed.host_str().ok_or_else(|| {
crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
})?;
let port = parsed.port_or_known_default().ok_or_else(|| {
crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
})?;
let addrs = lookup_host((host, port)).await.map_err(|error| {
crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
})?;
let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
let mut any_addr = false;
for addr in addrs {
any_addr = true;
let ip = addr.ip();
if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
if reason == "cloud_metadata" {
return Err(crate::error::McpxError::Config(format!(
"OAuth target resolved to blocked IP ({reason}): {url}"
)));
}
if allowlist.is_empty() {
return Err(crate::error::McpxError::Config(format!(
"OAuth target resolved to blocked IP ({reason}): {url}"
)));
}
if host_allowed || allowlist.ip_allowed(ip) {
continue;
}
return Err(crate::error::McpxError::Config(format!(
"OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
URL: {url}"
)));
}
}
if !any_addr {
return Err(crate::error::McpxError::Config(format!(
"OAuth target DNS resolution returned no addresses: {url}"
)));
}
Ok(())
}
async fn screen_oauth_target(
url: &str,
allow_http: bool,
allowlist: &crate::ssrf::CompiledSsrfAllowlist,
) -> Result<(), crate::error::McpxError> {
#[cfg(any(test, feature = "test-helpers"))]
{
screen_oauth_target_with_test_override(url, allow_http, allowlist, false).await
}
#[cfg(not(any(test, feature = "test-helpers")))]
{
let parsed = check_oauth_url("oauth target", url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
return Err(crate::error::McpxError::Config(format!(
"OAuth target forbidden ({reason}): {url}"
)));
}
let host = parsed.host_str().ok_or_else(|| {
crate::error::McpxError::Config(format!("OAuth target URL has no host: {url}"))
})?;
let port = parsed.port_or_known_default().ok_or_else(|| {
crate::error::McpxError::Config(format!("OAuth target URL has no known port: {url}"))
})?;
let addrs = lookup_host((host, port)).await.map_err(|error| {
crate::error::McpxError::Config(format!("OAuth target DNS resolution {url}: {error}"))
})?;
let host_allowed = !allowlist.is_empty() && allowlist.host_allowed(host);
let mut any_addr = false;
for addr in addrs {
any_addr = true;
let ip = addr.ip();
if let Some(reason) = crate::ssrf::ip_block_reason(ip) {
if reason == "cloud_metadata" {
return Err(crate::error::McpxError::Config(format!(
"OAuth target resolved to blocked IP ({reason}): {url}"
)));
}
if allowlist.is_empty() {
return Err(crate::error::McpxError::Config(format!(
"OAuth target resolved to blocked IP ({reason}): {url}"
)));
}
if host_allowed || allowlist.ip_allowed(ip) {
continue;
}
return Err(crate::error::McpxError::Config(format!(
"OAuth target blocked: hostname {host} resolved to {ip} ({reason}). \
To allow, add the hostname to oauth.ssrf_allowlist.hosts or the CIDR \
to oauth.ssrf_allowlist.cidrs (operators only -- see SECURITY.md). \
URL: {url}"
)));
}
}
if !any_addr {
return Err(crate::error::McpxError::Config(format!(
"OAuth target DNS resolution returned no addresses: {url}"
)));
}
Ok(())
}
}
/// HTTP client wrapper used for outbound OAuth requests.
///
/// Bundles a `reqwest::Client` with the SSRF policy (`allow_http` and the
/// compiled allowlist) that `send_screened` applies before each request.
#[derive(Clone)]
pub struct OauthHttpClient {
/// Default client; also returned by `client_for` when no mTLS client matches.
inner: reqwest::Client,
/// Whether plain `http` OAuth URLs are accepted during screening.
allow_http: bool,
/// Compiled SSRF allowlist shared with the redirect policy and DNS resolver.
allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
/// Clients carrying a TLS client identity, keyed by cert/key path pair.
#[cfg(feature = "oauth-mtls-client")]
mtls_clients: Arc<HashMap<MtlsClientKey, reqwest::Client>>,
/// Test-only flag; when set, `send_screened` skips IP screening.
#[cfg(any(test, feature = "test-helpers"))]
test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
/// Lookup key for the per-identity mTLS client map: the certificate and
/// private-key file paths exactly as configured.
#[cfg(feature = "oauth-mtls-client")]
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
struct MtlsClientKey {
/// Path of the client certificate file.
cert_path: PathBuf,
/// Path of the client private-key file.
key_path: PathBuf,
}
impl OauthHttpClient {
pub fn with_config(config: &OAuthConfig) -> Result<Self, crate::error::McpxError> {
Self::build(Some(config))
}
#[deprecated(
since = "1.2.1",
note = "use OauthHttpClient::with_config(&OAuthConfig) so token/introspect/revoke/exchange traffic inherits ca_cert_path and the allow_http_oauth_urls toggle"
)]
pub fn new() -> Result<Self, crate::error::McpxError> {
Self::build(None)
}
fn build(config: Option<&OAuthConfig>) -> Result<Self, crate::error::McpxError> {
let allow_http = config.is_some_and(|c| c.allow_http_oauth_urls);
let allowlist = match config.and_then(|c| c.ssrf_allowlist.as_ref()) {
Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
crate::error::McpxError::Startup(format!("oauth http client: {e}"))
})?),
None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
};
let redirect_allowlist = Arc::clone(&allowlist);
#[cfg(any(test, feature = "test-helpers"))]
let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
Arc::new(AtomicBool::new(false));
#[cfg(not(any(test, feature = "test-helpers")))]
let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();
let resolver: Arc<dyn reqwest::dns::Resolve> =
Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
Arc::clone(&allowlist),
#[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
test_bypass.clone(),
));
let mut builder = reqwest::Client::builder()
.no_proxy()
.dns_resolver(Arc::clone(&resolver))
.connect_timeout(Duration::from_secs(10))
.timeout(Duration::from_secs(30))
.redirect(reqwest::redirect::Policy::custom(move |attempt| {
match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
Ok(()) => attempt.follow(),
Err(reason) => {
tracing::warn!(
reason = %reason,
target = %attempt.url(),
"oauth redirect rejected"
);
attempt.error(reason)
}
}
}));
if let Some(cfg) = config
&& let Some(ref ca_path) = cfg.ca_cert_path
{
let pem = std::fs::read(ca_path).map_err(|e| {
crate::error::McpxError::Startup(format!(
"oauth http client: read ca_cert_path {}: {e}",
ca_path.display()
))
})?;
let cert = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
crate::error::McpxError::Startup(format!(
"oauth http client: parse ca_cert_path {}: {e}",
ca_path.display()
))
})?;
builder = builder.add_root_certificate(cert);
}
let inner = builder.build().map_err(|e| {
crate::error::McpxError::Startup(format!("oauth http client init: {e}"))
})?;
#[cfg(feature = "oauth-mtls-client")]
let mtls_clients = build_mtls_clients(config, &allowlist, &test_bypass)?;
Ok(Self {
inner,
allow_http,
allowlist,
#[cfg(feature = "oauth-mtls-client")]
mtls_clients,
#[cfg(any(test, feature = "test-helpers"))]
test_allow_loopback_ssrf: test_bypass,
})
}
async fn send_screened(
&self,
url: &str,
request: reqwest::RequestBuilder,
) -> Result<reqwest::Response, crate::error::McpxError> {
#[cfg(any(test, feature = "test-helpers"))]
if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
screen_oauth_target_with_test_override(url, self.allow_http, &self.allowlist, true)
.await?;
} else {
screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
}
#[cfg(not(any(test, feature = "test-helpers")))]
screen_oauth_target(url, self.allow_http, &self.allowlist).await?;
request.send().await.map_err(|error| {
crate::error::McpxError::Config(format!("oauth request {url}: {error}"))
})
}
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
#[must_use]
pub fn __test_allow_loopback_ssrf(self) -> Self {
self.test_allow_loopback_ssrf.store(true, Ordering::Relaxed);
self
}
#[doc(hidden)]
pub async fn __test_get(&self, url: &str) -> reqwest::Result<reqwest::Response> {
self.inner.get(url).send().await
}
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
#[must_use]
pub fn __test_inner_client(&self) -> &reqwest::Client {
&self.inner
}
#[cfg(feature = "oauth-mtls-client")]
fn client_for(&self, cfg: &TokenExchangeConfig) -> &reqwest::Client {
if let Some(cc) = &cfg.client_cert {
let key = MtlsClientKey {
cert_path: cc.cert_path.clone(),
key_path: cc.key_path.clone(),
};
if let Some(client) = self.mtls_clients.get(&key) {
return client;
}
}
&self.inner
}
#[cfg(not(feature = "oauth-mtls-client"))]
fn client_for(&self, _cfg: &TokenExchangeConfig) -> &reqwest::Client {
&self.inner
}
}
/// Opaque `Debug` output: prints only the type name and never any field.
impl std::fmt::Debug for OauthHttpClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("OauthHttpClient");
        out.finish_non_exhaustive()
    }
}
/// Raw (uncompiled) operator allowlist for OAuth SSRF screening, as read
/// from configuration. Both lists default to empty.
#[derive(Debug, Clone, Default, Deserialize)]
#[non_exhaustive]
pub struct OAuthSsrfAllowlist {
/// Bare DNS hostnames; literal IPs and anything containing
/// `:`, `/`, `@`, `?` or `#` are rejected when the list is compiled.
#[serde(default)]
pub hosts: Vec<String>,
/// CIDR entries, parsed with `crate::ssrf::CidrEntry::parse` at compile time.
#[serde(default)]
pub cidrs: Vec<String>,
}
/// Validates and compiles a raw `OAuthSsrfAllowlist`.
///
/// Host entries are trimmed, must be non-empty bare DNS names (no scheme,
/// port, path, userinfo, query or fragment, and no literal IPs), and are
/// stored lower-cased, sorted and de-duplicated. CIDR entries are parsed
/// with `crate::ssrf::CidrEntry::parse`.
///
/// # Errors
///
/// Returns a message naming the offending `hosts[idx]` / `cidrs[idx]` entry.
fn compile_oauth_ssrf_allowlist(
    raw: &OAuthSsrfAllowlist,
) -> Result<crate::ssrf::CompiledSsrfAllowlist, String> {
    let mut hosts: Vec<String> = Vec::with_capacity(raw.hosts.len());

    for (idx, entry) in raw.hosts.iter().enumerate() {
        let candidate = entry.trim();

        if candidate.is_empty() {
            return Err(format!("oauth.ssrf_allowlist.hosts[{idx}]: empty entry"));
        }
        if candidate.contains([':', '/', '@', '?', '#']) {
            return Err(format!(
                "oauth.ssrf_allowlist.hosts[{idx}] = {candidate:?}: must be a bare DNS hostname \
                 (no scheme, port, path, userinfo, query, or fragment)"
            ));
        }

        match url::Host::parse(candidate) {
            Ok(url::Host::Domain(_)) => hosts.push(candidate.to_ascii_lowercase()),
            Ok(url::Host::Ipv4(_) | url::Host::Ipv6(_)) => {
                return Err(format!(
                    "oauth.ssrf_allowlist.hosts[{idx}] = {candidate:?}: literal IPs are forbidden \
                     here -- list them via oauth.ssrf_allowlist.cidrs instead"
                ));
            }
            Err(e) => {
                return Err(format!(
                    "oauth.ssrf_allowlist.hosts[{idx}] = {candidate:?}: invalid hostname: {e}"
                ));
            }
        }
    }
    hosts.sort();
    hosts.dedup();

    let cidrs = raw
        .cidrs
        .iter()
        .enumerate()
        .map(|(idx, entry)| {
            crate::ssrf::CidrEntry::parse(entry)
                .map_err(|e| format!("oauth.ssrf_allowlist.cidrs[{idx}]: {e}"))
        })
        .collect::<Result<Vec<_>, String>>()?;

    Ok(crate::ssrf::CompiledSsrfAllowlist::new(hosts, cidrs))
}
/// OAuth / JWKS configuration, deserialized from the server config.
///
/// `validate` checks every URL field and `jwks_cache_ttl`.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct OAuthConfig {
/// Issuer URL; installed as the only accepted `iss` when validating tokens.
pub issuer: String,
/// Expected audience for validated tokens.
pub audience: String,
/// URL of the JWKS document.
pub jwks_uri: String,
/// Scope mappings (scope name plus role).
#[serde(default)]
pub scopes: Vec<ScopeMapping>,
/// Optional role claim name.
// NOTE(review): presumably the JWT claim roles are read from -- confirm in `resolve_role`.
pub role_claim: Option<String>,
/// Role mappings (claim value plus role).
#[serde(default)]
pub role_mappings: Vec<RoleMapping>,
/// JWKS cache lifetime as a humantime string; defaults to `"10m"`.
#[serde(default = "default_jwks_cache_ttl")]
pub jwks_cache_ttl: String,
/// Optional OAuth proxy settings.
pub proxy: Option<OAuthProxyConfig>,
/// Optional token-exchange settings.
pub token_exchange: Option<TokenExchangeConfig>,
/// Optional PEM file added as an extra root certificate to the HTTP clients.
#[serde(default)]
pub ca_cert_path: Option<PathBuf>,
/// When `true`, `http` OAuth URLs are accepted in addition to `https`.
#[serde(default)]
pub allow_http_oauth_urls: bool,
/// Optional operator allowlist for SSRF screening.
#[serde(default)]
pub ssrf_allowlist: Option<OAuthSsrfAllowlist>,
/// Limit on JWKS keys; defaults to 256.
// NOTE(review): enforcement site is not in this part of the file -- confirm.
#[serde(default = "default_max_jwks_keys")]
pub max_jwks_keys: usize,
/// Legacy strictness flag; see the deprecation note.
#[serde(default)]
#[deprecated(
since = "1.7.0",
note = "use `audience_validation_mode` instead; this field is consulted only when `audience_validation_mode` is None"
)]
pub strict_audience_validation: bool,
/// Explicit audience validation mode; overrides the legacy flag when set.
#[serde(default)]
pub audience_validation_mode: Option<AudienceValidationMode>,
/// Limit on the JWKS response size in bytes; defaults to 1 MiB.
#[serde(default = "default_jwks_max_bytes")]
pub jwks_max_response_bytes: u64,
}
/// Serde default for the `jwks_cache_ttl` field: ten minutes, as a humantime string.
fn default_jwks_cache_ttl() -> String {
    String::from("10m")
}
/// Serde default for the `max_jwks_keys` field.
const fn default_max_jwks_keys() -> usize {
    const DEFAULT_MAX_KEYS: usize = 256;
    DEFAULT_MAX_KEYS
}
/// Serde default for the `jwks_max_response_bytes` field: 1 MiB.
const fn default_jwks_max_bytes() -> u64 {
    const ONE_MIB: u64 = 1 << 20;
    ONE_MIB
}
/// How strictly the token audience is validated.
///
/// Deserialized from the snake_case names `permissive`, `warn` and `strict`.
// NOTE(review): the per-variant behaviour lives in `check_audience`, which is
// not visible in this part of the file -- confirm before documenting further.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum AudienceValidationMode {
/// Least strict mode (by name).
Permissive,
/// Default mode; also used when neither the explicit mode nor the legacy
/// strict flag is set.
#[default]
Warn,
/// Strictest mode; also selected by the legacy `strict_audience_validation` flag.
Strict,
}
/// Empty configuration: blank issuer / audience / JWKS URI, no optional
/// sections, and the same defaults serde applies to omitted fields.
impl Default for OAuthConfig {
    fn default() -> Self {
        let jwks_cache_ttl = default_jwks_cache_ttl();
        let max_jwks_keys = default_max_jwks_keys();
        let jwks_max_response_bytes = default_jwks_max_bytes();

        Self {
            issuer: String::new(),
            audience: String::new(),
            jwks_uri: String::new(),
            scopes: Vec::new(),
            role_claim: None,
            role_mappings: Vec::new(),
            jwks_cache_ttl,
            proxy: None,
            token_exchange: None,
            ca_cert_path: None,
            allow_http_oauth_urls: false,
            ssrf_allowlist: None,
            max_jwks_keys,
            #[allow(
                deprecated,
                reason = "default-construct deprecated field for backward compat"
            )]
            strict_audience_validation: false,
            audience_validation_mode: None,
            jwks_max_response_bytes,
        }
    }
}
impl OAuthConfig {
#[must_use]
pub fn effective_audience_validation_mode(&self) -> AudienceValidationMode {
if let Some(mode) = self.audience_validation_mode {
return mode;
}
#[allow(deprecated, reason = "intentional: legacy flag resolution path")]
if self.strict_audience_validation {
AudienceValidationMode::Strict
} else {
AudienceValidationMode::Warn
}
}
pub fn builder(
issuer: impl Into<String>,
audience: impl Into<String>,
jwks_uri: impl Into<String>,
) -> OAuthConfigBuilder {
OAuthConfigBuilder {
inner: Self {
issuer: issuer.into(),
audience: audience.into(),
jwks_uri: jwks_uri.into(),
..Self::default()
},
}
}
pub fn validate(&self) -> Result<(), crate::error::McpxError> {
let allow_http = self.allow_http_oauth_urls;
let url = check_oauth_url("oauth.issuer", &self.issuer, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.issuer forbidden ({reason})"
)));
}
let url = check_oauth_url("oauth.jwks_uri", &self.jwks_uri, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.jwks_uri forbidden ({reason})"
)));
}
if let Some(proxy) = &self.proxy {
let url = check_oauth_url(
"oauth.proxy.authorize_url",
&proxy.authorize_url,
allow_http,
)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.proxy.authorize_url forbidden ({reason})"
)));
}
let url = check_oauth_url("oauth.proxy.token_url", &proxy.token_url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.proxy.token_url forbidden ({reason})"
)));
}
if let Some(url) = &proxy.introspection_url {
let parsed = check_oauth_url("oauth.proxy.introspection_url", url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
return Err(crate::error::McpxError::Config(format!(
"oauth.proxy.introspection_url forbidden ({reason})"
)));
}
}
if let Some(url) = &proxy.revocation_url {
let parsed = check_oauth_url("oauth.proxy.revocation_url", url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&parsed) {
return Err(crate::error::McpxError::Config(format!(
"oauth.proxy.revocation_url forbidden ({reason})"
)));
}
}
if proxy.expose_admin_endpoints
&& !proxy.require_auth_on_admin_endpoints
&& !proxy.allow_unauthenticated_admin_endpoints
{
return Err(crate::error::McpxError::Config(
"oauth.proxy: expose_admin_endpoints = true requires \
require_auth_on_admin_endpoints = true (recommended) \
or allow_unauthenticated_admin_endpoints = true \
(explicit opt-out, only safe behind an authenticated \
reverse proxy)"
.into(),
));
}
}
if let Some(tx) = &self.token_exchange {
let url = check_oauth_url("oauth.token_exchange.token_url", &tx.token_url, allow_http)?;
if let Some(reason) = crate::ssrf::check_url_literal_ip(&url) {
return Err(crate::error::McpxError::Config(format!(
"oauth.token_exchange.token_url forbidden ({reason})"
)));
}
validate_token_exchange_client_auth(tx)?;
}
if let Some(raw) = &self.ssrf_allowlist {
let compiled = compile_oauth_ssrf_allowlist(raw).map_err(|e| {
crate::error::McpxError::Config(format!("oauth.ssrf_allowlist: {e}"))
})?;
if !compiled.is_empty() {
tracing::warn!(
host_count = compiled.host_count(),
cidr_count = compiled.cidr_count(),
"oauth.ssrf_allowlist is configured: private/loopback OAuth/JWKS targets \
are now reachable. Cloud-metadata addresses remain blocked. \
See SECURITY.md \"Operator allowlist\"."
);
}
}
humantime::parse_duration(&self.jwks_cache_ttl).map_err(|e| {
crate::error::McpxError::Config(format!(
"oauth.jwks_cache_ttl {:?} is not a valid humantime duration (e.g. \"10m\", \"1h30m\"): {e}",
self.jwks_cache_ttl
))
})?;
Ok(())
}
}
fn validate_token_exchange_client_auth(
tx: &TokenExchangeConfig,
) -> Result<(), crate::error::McpxError> {
match (&tx.client_cert, tx.client_secret.is_some()) {
(Some(_), true) => Err(crate::error::McpxError::Config(
"oauth.token_exchange: client_cert and client_secret are mutually \
exclusive (RFC 8705 §2). Set exactly one."
.into(),
)),
(None, false) => Err(crate::error::McpxError::Config(
"oauth.token_exchange: token exchange requires client authentication. \
Set either client_secret (RFC 6749 §2.3.1) or client_cert (RFC 8705 §2)."
.into(),
)),
(Some(cc), false) => validate_client_cert_config(cc),
(None, true) => Ok(()),
}
}
fn validate_client_cert_config(cc: &ClientCertConfig) -> Result<(), crate::error::McpxError> {
#[cfg(not(feature = "oauth-mtls-client"))]
{
let _ = cc;
Err(crate::error::McpxError::Config(
"oauth.token_exchange.client_cert requires the `oauth-mtls-client` cargo feature; \
rebuild rmcp-server-kit with --features oauth-mtls-client (or have your \
application crate enable it via `rmcp-server-kit/oauth-mtls-client`), or remove \
the field"
.into(),
))
}
#[cfg(feature = "oauth-mtls-client")]
{
let cert_bytes = std::fs::read(&cc.cert_path).map_err(|e| {
tracing::warn!(error = %e, path = %cc.cert_path.display(), "client cert read failed");
crate::error::McpxError::Config(format!(
"oauth.token_exchange.client_cert.cert_path unreadable: {}",
cc.cert_path.display()
))
})?;
let key_bytes = std::fs::read(&cc.key_path).map_err(|e| {
tracing::warn!(error = %e, path = %cc.key_path.display(), "client cert key read failed");
crate::error::McpxError::Config(format!(
"oauth.token_exchange.client_cert.key_path unreadable: {}",
cc.key_path.display()
))
})?;
let mut combined = Vec::with_capacity(cert_bytes.len() + 1 + key_bytes.len());
combined.extend_from_slice(&cert_bytes);
if !cert_bytes.ends_with(b"\n") {
combined.push(b'\n');
}
combined.extend_from_slice(&key_bytes);
let _identity = reqwest::Identity::from_pem(&combined).map_err(|e| {
tracing::warn!(
error = %e,
cert_path = %cc.cert_path.display(),
key_path = %cc.key_path.display(),
"client cert PEM parse failed"
);
crate::error::McpxError::Config(format!(
"oauth.token_exchange.client_cert: PEM parse failed (cert={}, key={})",
cc.cert_path.display(),
cc.key_path.display()
))
})?;
Ok(())
}
}
/// Builds the map of mTLS-capable HTTP clients for token exchange.
///
/// The map is empty unless `config` has a `token_exchange` section with a
/// `client_cert`. In that case a single client is built with the TLS
/// identity loaded from the cert/key files, the SSRF-screening resolver,
/// no proxy, redirects disabled, 10 s connect / 30 s total timeouts, and
/// the optional extra root certificate from `ca_cert_path`.
///
/// # Errors
///
/// Returns `McpxError::Startup` when a file cannot be read, PEM data cannot
/// be parsed, or the client cannot be constructed.
#[cfg(feature = "oauth-mtls-client")]
fn build_mtls_clients(
    config: Option<&OAuthConfig>,
    allowlist: &Arc<crate::ssrf::CompiledSsrfAllowlist>,
    test_bypass: &crate::ssrf_resolver::TestLoopbackBypass,
) -> Result<Arc<HashMap<MtlsClientKey, reqwest::Client>>, crate::error::McpxError> {
    let configured = config.and_then(|cfg| {
        let cc = cfg.token_exchange.as_ref()?.client_cert.as_ref()?;
        Some((cfg, cc))
    });
    let Some((cfg, cc)) = configured else {
        return Ok(Arc::new(HashMap::new()));
    };

    let cert_pem = std::fs::read(&cc.cert_path).map_err(|e| {
        crate::error::McpxError::Startup(format!(
            "oauth http client mTLS: read cert_path {}: {e}",
            cc.cert_path.display()
        ))
    })?;
    let key_pem = std::fs::read(&cc.key_path).map_err(|e| {
        crate::error::McpxError::Startup(format!(
            "oauth http client mTLS: read key_path {}: {e}",
            cc.key_path.display()
        ))
    })?;

    // Certificate first, then key, separated by a newline.
    let mut bundle = Vec::with_capacity(cert_pem.len() + 1 + key_pem.len());
    bundle.extend_from_slice(&cert_pem);
    if !cert_pem.ends_with(b"\n") {
        bundle.push(b'\n');
    }
    bundle.extend_from_slice(&key_pem);

    let identity = reqwest::Identity::from_pem(&bundle).map_err(|e| {
        crate::error::McpxError::Startup(format!(
            "oauth http client mTLS: PEM parse (cert={}, key={}): {e}",
            cc.cert_path.display(),
            cc.key_path.display()
        ))
    })?;

    let resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(allowlist),
            #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
            test_bypass.clone(),
        ));

    let mut builder = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&resolver))
        .connect_timeout(Duration::from_secs(10))
        .timeout(Duration::from_secs(30))
        .redirect(reqwest::redirect::Policy::none())
        .identity(identity);

    if let Some(ca_path) = cfg.ca_cert_path.as_ref() {
        let pem = std::fs::read(ca_path).map_err(|e| {
            crate::error::McpxError::Startup(format!(
                "oauth http client mTLS: read ca_cert_path {}: {e}",
                ca_path.display()
            ))
        })?;
        let root = reqwest::tls::Certificate::from_pem(&pem).map_err(|e| {
            crate::error::McpxError::Startup(format!(
                "oauth http client mTLS: parse ca_cert_path {}: {e}",
                ca_path.display()
            ))
        })?;
        builder = builder.add_root_certificate(root);
    }

    let client = builder.build().map_err(|e| {
        crate::error::McpxError::Startup(format!("oauth http client mTLS init: {e}"))
    })?;

    let key = MtlsClientKey {
        cert_path: cc.cert_path.clone(),
        key_path: cc.key_path.clone(),
    };
    let mut clients: HashMap<MtlsClientKey, reqwest::Client> = HashMap::new();
    clients.insert(key, client);
    Ok(Arc::new(clients))
}
fn check_oauth_url(
field: &str,
raw: &str,
allow_http: bool,
) -> Result<url::Url, crate::error::McpxError> {
let parsed = url::Url::parse(raw).map_err(|e| {
crate::error::McpxError::Config(format!("{field}: invalid URL {raw:?}: {e}"))
})?;
if !parsed.username().is_empty() || parsed.password().is_some() {
return Err(crate::error::McpxError::Config(format!(
"{field} rejected: URL contains userinfo (credentials in URL are forbidden)"
)));
}
match parsed.scheme() {
"https" => Ok(parsed),
"http" if allow_http => Ok(parsed),
"http" => Err(crate::error::McpxError::Config(format!(
"{field}: must use https scheme (got http; set allow_http_oauth_urls=true \
to override - strongly discouraged in production)"
))),
other => Err(crate::error::McpxError::Config(format!(
"{field}: must use https scheme (got {other:?})"
))),
}
}
/// Fluent builder for `OAuthConfig`; call `build()` to obtain the config.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthConfigBuilder {
/// Configuration being assembled.
inner: OAuthConfig,
}
impl OAuthConfigBuilder {
/// Replaces the whole scope mapping list.
pub fn scopes(mut self, scopes: Vec<ScopeMapping>) -> Self {
self.inner.scopes = scopes;
self
}
/// Appends one scope mapping.
pub fn scope(mut self, scope: impl Into<String>, role: impl Into<String>) -> Self {
self.inner.scopes.push(ScopeMapping {
scope: scope.into(),
role: role.into(),
});
self
}
/// Sets the role claim name.
pub fn role_claim(mut self, claim: impl Into<String>) -> Self {
self.inner.role_claim = Some(claim.into());
self
}
/// Replaces the whole role mapping list.
pub fn role_mappings(mut self, mappings: Vec<RoleMapping>) -> Self {
self.inner.role_mappings = mappings;
self
}
/// Appends one role mapping.
pub fn role_mapping(mut self, claim_value: impl Into<String>, role: impl Into<String>) -> Self {
self.inner.role_mappings.push(RoleMapping {
claim_value: claim_value.into(),
role: role.into(),
});
self
}
/// Sets the JWKS cache TTL string; it is not parsed here.
pub fn jwks_cache_ttl(mut self, ttl: impl Into<String>) -> Self {
self.inner.jwks_cache_ttl = ttl.into();
self
}
/// Sets the OAuth proxy section.
pub fn proxy(mut self, proxy: OAuthProxyConfig) -> Self {
self.inner.proxy = Some(proxy);
self
}
/// Sets the token-exchange section.
pub fn token_exchange(mut self, token_exchange: TokenExchangeConfig) -> Self {
self.inner.token_exchange = Some(token_exchange);
self
}
/// Sets the path of the extra CA certificate.
pub fn ca_cert_path(mut self, path: impl Into<PathBuf>) -> Self {
self.inner.ca_cert_path = Some(path.into());
self
}
/// Sets whether plain `http` OAuth URLs are accepted.
pub const fn allow_http_oauth_urls(mut self, allow: bool) -> Self {
self.inner.allow_http_oauth_urls = allow;
self
}
/// Sets the legacy strict flag and clears any explicit
/// `audience_validation_mode`, so the legacy flag takes effect.
#[deprecated(since = "1.7.0", note = "use `audience_validation_mode` instead")]
pub const fn strict_audience_validation(mut self, strict: bool) -> Self {
#[allow(
deprecated,
reason = "intentional: deprecated builder forwards to deprecated field"
)]
{
self.inner.strict_audience_validation = strict;
}
self.inner.audience_validation_mode = None;
self
}
/// Sets the explicit audience validation mode.
pub const fn audience_validation_mode(mut self, mode: AudienceValidationMode) -> Self {
self.inner.audience_validation_mode = Some(mode);
self
}
/// Sets the JWKS response size limit in bytes.
pub const fn jwks_max_response_bytes(mut self, bytes: u64) -> Self {
self.inner.jwks_max_response_bytes = bytes;
self
}
/// Sets the raw SSRF allowlist; it is not compiled or validated here.
pub fn ssrf_allowlist(mut self, allowlist: OAuthSsrfAllowlist) -> Self {
self.inner.ssrf_allowlist = Some(allowlist);
self
}
/// Returns the assembled configuration; no validation is performed here.
#[must_use]
pub fn build(self) -> OAuthConfig {
self.inner
}
}
/// One scope-to-role mapping entry.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ScopeMapping {
/// OAuth scope name.
pub scope: String,
/// Role associated with that scope.
pub role: String,
}
/// One claim-value-to-role mapping entry.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct RoleMapping {
/// Claim value to match.
pub claim_value: String,
/// Role associated with that claim value.
pub role: String,
}
/// Token-exchange settings.
///
/// Validation requires exactly one of `client_secret` and `client_cert`.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct TokenExchangeConfig {
/// Token endpoint URL; checked by `OAuthConfig::validate`.
pub token_url: String,
/// OAuth client identifier.
pub client_id: String,
/// Client secret; mutually exclusive with `client_cert`.
pub client_secret: Option<secrecy::SecretString>,
/// Client certificate for mTLS; mutually exclusive with `client_secret`.
pub client_cert: Option<ClientCertConfig>,
/// Audience setting for the exchange.
// NOTE(review): presumably the audience requested for the exchanged token -- confirm at the call site.
pub audience: String,
}
impl TokenExchangeConfig {
/// Creates a token-exchange configuration from its five fields.
///
/// No validation happens here; the "exactly one of secret or certificate"
/// rule is enforced later by `validate_token_exchange_client_auth`.
#[must_use]
pub fn new(
token_url: String,
client_id: String,
client_secret: Option<secrecy::SecretString>,
client_cert: Option<ClientCertConfig>,
audience: String,
) -> Self {
Self {
token_url,
client_id,
client_secret,
client_cert,
audience,
}
}
}
/// File locations of the TLS client identity used for token exchange.
#[derive(Debug, Clone, Deserialize)]
#[non_exhaustive]
pub struct ClientCertConfig {
/// Path of the PEM certificate file.
pub cert_path: PathBuf,
/// Path of the PEM private-key file.
pub key_path: PathBuf,
}
impl ClientCertConfig {
/// Creates a client certificate configuration; the files are not read here.
#[must_use]
pub fn new(cert_path: PathBuf, key_path: PathBuf) -> Self {
Self {
cert_path,
key_path,
}
}
}
/// Token-exchange response body, as deserialized.
// NOTE(review): the derived `Debug` prints `access_token` in clear text --
// consider a redacting `Debug` impl or a secret wrapper type.
#[derive(Debug, Deserialize)]
#[non_exhaustive]
pub struct ExchangedToken {
/// The issued access token.
pub access_token: String,
/// Token lifetime, if present.
// NOTE(review): presumably seconds, as in OAuth token responses -- confirm.
pub expires_in: Option<u64>,
/// Issued token type identifier, if present.
pub issued_token_type: Option<String>,
}
/// OAuth proxy settings.
///
/// All URL fields are checked by `OAuthConfig::validate`.
#[derive(Debug, Clone, Deserialize, Default)]
#[non_exhaustive]
pub struct OAuthProxyConfig {
/// Authorization endpoint URL.
pub authorize_url: String,
/// Token endpoint URL.
pub token_url: String,
/// OAuth client identifier.
pub client_id: String,
/// Optional client secret.
pub client_secret: Option<secrecy::SecretString>,
/// Optional introspection endpoint URL.
#[serde(default)]
pub introspection_url: Option<String>,
/// Optional revocation endpoint URL.
#[serde(default)]
pub revocation_url: Option<String>,
/// Whether admin endpoints are exposed. When `true`, validation requires
/// one of the two flags below to be set as well.
#[serde(default)]
pub expose_admin_endpoints: bool,
/// Require authentication on the admin endpoints (the recommended option).
#[serde(default)]
pub require_auth_on_admin_endpoints: bool,
/// Explicit opt-out that lets admin endpoints be exposed without authentication.
#[serde(default)]
pub allow_unauthenticated_admin_endpoints: bool,
}
impl OAuthProxyConfig {
/// Starts a builder from the three mandatory settings; every other field
/// begins at its `Default` value (no secret, no optional URLs, all flags `false`).
pub fn builder(
authorize_url: impl Into<String>,
token_url: impl Into<String>,
client_id: impl Into<String>,
) -> OAuthProxyConfigBuilder {
OAuthProxyConfigBuilder {
inner: Self {
authorize_url: authorize_url.into(),
token_url: token_url.into(),
client_id: client_id.into(),
..Self::default()
},
}
}
}
/// Fluent builder for `OAuthProxyConfig`; call `build()` to obtain the config.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `.build()` is called"]
pub struct OAuthProxyConfigBuilder {
/// Configuration being assembled.
inner: OAuthProxyConfig,
}
impl OAuthProxyConfigBuilder {
/// Sets the client secret.
pub fn client_secret(mut self, secret: secrecy::SecretString) -> Self {
self.inner.client_secret = Some(secret);
self
}
/// Sets the introspection endpoint URL.
pub fn introspection_url(mut self, url: impl Into<String>) -> Self {
self.inner.introspection_url = Some(url.into());
self
}
/// Sets the revocation endpoint URL.
pub fn revocation_url(mut self, url: impl Into<String>) -> Self {
self.inner.revocation_url = Some(url.into());
self
}
/// Sets whether admin endpoints are exposed.
pub const fn expose_admin_endpoints(mut self, expose: bool) -> Self {
self.inner.expose_admin_endpoints = expose;
self
}
/// Sets whether admin endpoints require authentication.
pub const fn require_auth_on_admin_endpoints(mut self, require: bool) -> Self {
self.inner.require_auth_on_admin_endpoints = require;
self
}
/// Sets the explicit opt-out allowing unauthenticated admin endpoints.
pub const fn allow_unauthenticated_admin_endpoints(mut self, allow: bool) -> Self {
self.inner.allow_unauthenticated_admin_endpoints = allow;
self
}
/// Returns the assembled configuration; no validation is performed here.
#[must_use]
pub fn build(self) -> OAuthProxyConfig {
self.inner
}
}
/// Pair of key collections: a map from string key to `(Algorithm, DecodingKey)`
/// and a plain list of `(Algorithm, DecodingKey)`.
// NOTE(review): presumably keys indexed by `kid` plus keys published without
// a `kid`, mirroring the `CachedKeys` fields -- confirm where this alias is used.
type JwksKeyCache = (
HashMap<String, (Algorithm, DecodingKey)>,
Vec<(Algorithm, DecodingKey)>,
);
/// Snapshot of decoding keys together with the time it was fetched and how
/// long it stays valid.
struct CachedKeys {
/// Keys looked up by a string identifier.
// NOTE(review): presumably the JWK `kid` -- confirm in the key-selection code.
keys: HashMap<String, (Algorithm, DecodingKey)>,
/// Keys stored without an identifier.
unnamed_keys: Vec<(Algorithm, DecodingKey)>,
/// When this snapshot was fetched.
fetched_at: Instant,
/// Lifetime of the snapshot; see `is_expired`.
ttl: Duration,
}
impl CachedKeys {
    /// Returns `true` once at least `ttl` has elapsed since `fetched_at`.
    fn is_expired(&self) -> bool {
        let age = self.fetched_at.elapsed();
        age >= self.ttl
    }
}
/// JWKS-backed JWT validator: owns the HTTP client, the cached keys and the
/// settings copied from `OAuthConfig` when it is constructed.
#[allow(
missing_debug_implementations,
reason = "contains reqwest::Client and DecodingKey cache with no Debug impl"
)]
#[non_exhaustive]
pub struct JwksCache {
/// URL of the JWKS document (copied from the config).
jwks_uri: String,
/// Lifetime of a fetched key set, parsed from `jwks_cache_ttl`.
ttl: Duration,
/// Copied from `OAuthConfig::max_jwks_keys`.
max_jwks_keys: usize,
/// Copied from `OAuthConfig::jwks_max_response_bytes`.
max_response_bytes: u64,
/// Whether plain `http` OAuth URLs are accepted.
allow_http: bool,
/// Currently cached keys; `None` until keys have been stored.
inner: RwLock<Option<CachedKeys>>,
/// HTTP client with the SSRF-screening resolver and redirect policy.
http: reqwest::Client,
/// Base validation settings, cloned for each token.
validation_template: Validation,
/// Audience configured for this server.
expected_audience: String,
/// Effective audience validation mode.
audience_mode: AudienceValidationMode,
/// Flag initialised to `false`.
// NOTE(review): looks like a log-once latch for an `azp` fallback warning -- confirm in `check_audience`.
azp_fallback_warned: AtomicBool,
/// Scope mappings copied from the config.
scopes: Vec<ScopeMapping>,
/// Role claim name copied from the config.
role_claim: Option<String>,
/// Role mappings copied from the config.
role_mappings: Vec<RoleMapping>,
/// Time of the last refresh attempt; starts as `None`.
last_refresh_attempt: RwLock<Option<Instant>>,
/// Async mutex with no payload.
// NOTE(review): presumably serialises JWKS refreshes -- confirm in the refresh path.
refresh_lock: tokio::sync::Mutex<()>,
/// Compiled SSRF allowlist shared with the resolver and redirect policy.
allowlist: Arc<crate::ssrf::CompiledSsrfAllowlist>,
/// Test-only loopback bypass flag shared with the resolver.
#[cfg(any(test, feature = "test-helpers"))]
test_allow_loopback_ssrf: crate::ssrf_resolver::TestLoopbackBypass,
}
/// Ten-second cooldown constant for JWKS refreshes.
// NOTE(review): presumably the minimum gap between refresh attempts, checked
// against `last_refresh_attempt` -- confirm in the refresh path.
const JWKS_REFRESH_COOLDOWN: Duration = Duration::from_secs(10);
/// Asymmetric signature algorithms (RSA, ECDSA, RSA-PSS and EdDSA families).
/// No HMAC (`HS*`) algorithm is listed.
// NOTE(review): presumably the allowlist applied when selecting a key or
// reading the token header -- confirm at the use site.
const ACCEPTED_ALGS: &[Algorithm] = &[
Algorithm::RS256,
Algorithm::RS384,
Algorithm::RS512,
Algorithm::ES256,
Algorithm::ES384,
Algorithm::PS256,
Algorithm::PS384,
Algorithm::PS512,
Algorithm::EdDSA,
];
/// Reason returned by `validate_token_with_reason` when a token is rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum JwtValidationFailure {
/// The token was rejected as expired.
// NOTE(review): the mapping from decode errors to this variant is not visible here -- confirm.
Expired,
/// The token was rejected for another reason.
Invalid,
}
impl JwksCache {
/// Builds a JWKS cache and validator from `config`.
///
/// Installs the process-wide `rustls` and `jsonwebtoken` crypto providers
/// (a failed install, e.g. because one is already set, is ignored), prepares
/// the validation template (issuer pinned, `exp` / `nbf` checked, `exp` and
/// `iss` required, audience check disabled in the template), and builds an
/// HTTP client with the SSRF-screening resolver, no proxy, a 3 s connect /
/// 10 s total timeout and the OAuth redirect policy.
///
/// # Errors
///
/// Returns an error when `jwks_cache_ttl` is not a valid humantime
/// duration, the SSRF allowlist does not compile, the CA certificate cannot
/// be read or parsed, or the HTTP client cannot be built. The TTL case is
/// reported as an error rather than a panic, so a config that never went
/// through `OAuthConfig::validate` cannot abort the process.
pub fn new(config: &OAuthConfig) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
    rustls::crypto::ring::default_provider()
        .install_default()
        .ok();
    jsonwebtoken::crypto::rust_crypto::DEFAULT_PROVIDER
        .install_default()
        .ok();

    let ttl = humantime::parse_duration(&config.jwks_cache_ttl).map_err(|e| {
        Box::<dyn std::error::Error + Send + Sync>::from(format!(
            "oauth.jwks_cache_ttl {:?} is not a valid humantime duration: {e}",
            config.jwks_cache_ttl
        ))
    })?;

    // Audience validation is switched off in the template; the configured
    // audience and mode are stored on the cache instead.
    let mut validation = Validation::new(Algorithm::RS256);
    validation.validate_aud = false;
    validation.set_issuer(&[&config.issuer]);
    validation.set_required_spec_claims(&["exp", "iss"]);
    validation.validate_exp = true;
    validation.validate_nbf = true;

    let allow_http = config.allow_http_oauth_urls;
    let allowlist = match config.ssrf_allowlist.as_ref() {
        Some(raw) => Arc::new(compile_oauth_ssrf_allowlist(raw).map_err(|e| {
            Box::<dyn std::error::Error + Send + Sync>::from(format!(
                "oauth.ssrf_allowlist: {e}"
            ))
        })?),
        None => Arc::new(crate::ssrf::CompiledSsrfAllowlist::default()),
    };
    // The redirect closure needs its own handle to the allowlist.
    let redirect_allowlist = Arc::clone(&allowlist);

    // Shared atomic flag in test builds, `()` otherwise.
    #[cfg(any(test, feature = "test-helpers"))]
    let test_bypass: crate::ssrf_resolver::TestLoopbackBypass =
        Arc::new(AtomicBool::new(false));
    #[cfg(not(any(test, feature = "test-helpers")))]
    let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = ();

    let resolver: Arc<dyn reqwest::dns::Resolve> =
        Arc::new(crate::ssrf_resolver::SsrfScreeningResolver::new(
            Arc::clone(&allowlist),
            #[allow(clippy::clone_on_ref_ptr, reason = "type alias varies per feature")]
            test_bypass.clone(),
        ));

    let mut http_builder = reqwest::Client::builder()
        .no_proxy()
        .dns_resolver(Arc::clone(&resolver))
        .timeout(Duration::from_secs(10))
        .connect_timeout(Duration::from_secs(3))
        .redirect(reqwest::redirect::Policy::custom(move |attempt| {
            match evaluate_oauth_redirect(&attempt, allow_http, &redirect_allowlist) {
                Ok(()) => attempt.follow(),
                Err(reason) => {
                    tracing::warn!(
                        reason = %reason,
                        target = %attempt.url(),
                        "oauth redirect rejected"
                    );
                    attempt.error(reason)
                }
            }
        }));

    if let Some(ref ca_path) = config.ca_cert_path {
        let pem = std::fs::read(ca_path)?;
        let cert = reqwest::tls::Certificate::from_pem(&pem)?;
        http_builder = http_builder.add_root_certificate(cert);
    }
    let http = http_builder.build()?;

    Ok(Self {
        jwks_uri: config.jwks_uri.clone(),
        ttl,
        max_jwks_keys: config.max_jwks_keys,
        max_response_bytes: config.jwks_max_response_bytes,
        allow_http,
        inner: RwLock::new(None),
        http,
        validation_template: validation,
        expected_audience: config.audience.clone(),
        audience_mode: config.effective_audience_validation_mode(),
        azp_fallback_warned: AtomicBool::new(false),
        scopes: config.scopes.clone(),
        role_claim: config.role_claim.clone(),
        role_mappings: config.role_mappings.clone(),
        last_refresh_attempt: RwLock::new(None),
        refresh_lock: tokio::sync::Mutex::new(()),
        allowlist,
        #[cfg(any(test, feature = "test-helpers"))]
        test_allow_loopback_ssrf: test_bypass,
    })
}
/// Test-only builder step: sets the loopback SSRF bypass flag and hands the
/// cache back to the caller.
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
#[must_use]
pub fn __test_allow_loopback_ssrf(self) -> Self {
    let cache = self;
    cache
        .test_allow_loopback_ssrf
        .store(true, Ordering::Relaxed);
    cache
}
/// Validates `token` and returns the resulting identity, or `None` on any
/// failure. The failure reason is discarded; use
/// `validate_token_with_reason` to obtain it.
pub async fn validate_token(&self, token: &str) -> Option<AuthIdentity> {
    let outcome = self.validate_token_with_reason(token).await;
    outcome.ok()
}
pub async fn validate_token_with_reason(
&self,
token: &str,
) -> Result<AuthIdentity, JwtValidationFailure> {
let claims = self.decode_claims(token).await?;
self.check_audience(&claims)?;
let role = self.resolve_role(&claims)?;
let sub = claims.sub;
let name = claims
.extra
.get("preferred_username")
.and_then(|v| v.as_str())
.map(String::from)
.or_else(|| sub.clone())
.or(claims.azp)
.or(claims.client_id)
.unwrap_or_else(|| "oauth-client".into());
Ok(AuthIdentity {
name,
role,
method: AuthMethod::OAuthJwt,
raw_token: None,
sub,
})
}
/// Selects the verification key for `token`, then decodes and validates it on
/// a blocking worker thread.
///
/// Validation uses a clone of the template restricted to the algorithm that
/// was chosen together with the key.
///
/// # Errors
///
/// `Expired` when the library reports an expired signature; `Invalid` for any
/// other decode error, for key selection failure, or when the blocking task
/// does not complete.
async fn decode_claims(&self, token: &str) -> Result<Claims, JwtValidationFailure> {
    let (key, alg) = self.select_jwks_key(token).await?;
    let validation = {
        let mut restricted = self.validation_template.clone();
        restricted.algorithms = vec![alg];
        restricted
    };
    let owned_token = token.to_owned();
    let verify = move || decode::<Claims>(&owned_token, &key, &validation);
    let decoded = match tokio::task::spawn_blocking(verify).await {
        Ok(outcome) => outcome,
        Err(join_err) => {
            core::hint::cold_path();
            tracing::error!(
                error = %join_err,
                "JWT decode task panicked or was cancelled"
            );
            return Err(JwtValidationFailure::Invalid);
        }
    };
    match decoded {
        Ok(token_data) => Ok(token_data.claims),
        Err(e) => {
            core::hint::cold_path();
            let expired = matches!(e.kind(), jsonwebtoken::errors::ErrorKind::ExpiredSignature);
            let failure = if expired {
                JwtValidationFailure::Expired
            } else {
                JwtValidationFailure::Invalid
            };
            tracing::debug!(error = %e, ?alg, ?failure, "JWT decode failed");
            Err(failure)
        }
    }
}
/// Reads the JWT header and picks the matching decoding key.
///
/// Returns the key together with the header's algorithm.
///
/// # Errors
///
/// `Invalid` when the header cannot be decoded, when its algorithm is not in
/// `ACCEPTED_ALGS`, or when no key matches the `kid`/algorithm pair.
#[allow(clippy::cognitive_complexity)]
async fn select_jwks_key(
    &self,
    token: &str,
) -> Result<(DecodingKey, Algorithm), JwtValidationFailure> {
    let Ok(header) = decode_header(token) else {
        core::hint::cold_path();
        tracing::debug!("JWT header decode failed");
        return Err(JwtValidationFailure::Invalid);
    };
    let alg = header.alg;
    let kid = header.kid.as_deref();
    tracing::debug!(alg = ?alg, kid = kid.unwrap_or("-"), "JWT header decoded");

    let accepted = ACCEPTED_ALGS.contains(&alg);
    if !accepted {
        core::hint::cold_path();
        tracing::debug!(alg = ?alg, "JWT algorithm not accepted");
        return Err(JwtValidationFailure::Invalid);
    }

    match self.find_key(kid, alg).await {
        Some(key) => Ok((key, alg)),
        None => {
            core::hint::cold_path();
            tracing::debug!(kid = kid.unwrap_or("-"), alg = ?alg, "no matching JWKS key found");
            Err(JwtValidationFailure::Invalid)
        }
    }
}
/// Checks that the token was issued for the expected audience.
///
/// A token passes when `aud` contains the expected audience. If only `azp`
/// matches, the outcome depends on the audience mode: `Permissive` accepts,
/// `Warn` accepts and logs a warning the first time this happens, `Strict`
/// rejects.
///
/// # Errors
///
/// `Invalid` when neither rule accepts the token.
fn check_audience(&self, claims: &Claims) -> Result<(), JwtValidationFailure> {
    let expected: &str = &self.expected_audience;
    if claims.aud.contains(expected) {
        return Ok(());
    }

    let azp_matches = claims.azp.as_deref() == Some(expected);
    let accepted_via_azp = azp_matches
        && match self.audience_mode {
            AudienceValidationMode::Permissive => true,
            AudienceValidationMode::Warn => {
                // `swap` returns the previous flag value, so only the first caller logs.
                let already_warned = self.azp_fallback_warned.swap(true, Ordering::Relaxed);
                if !already_warned {
                    tracing::warn!(
                        expected = %self.expected_audience,
                        azp = ?claims.azp,
                        "JWT accepted via deprecated `azp`-only audience fallback. \
                         Configure your IdP to populate `aud`, or set \
                         `audience_validation_mode = \"strict\"` once tokens carry `aud` correctly. \
                         To silence this warning without changing acceptance, \
                         set `audience_validation_mode = \"permissive\"`. \
                         This warning logs once per process."
                    );
                }
                true
            }
            AudienceValidationMode::Strict => false,
        };
    if accepted_via_azp {
        return Ok(());
    }

    core::hint::cold_path();
    tracing::debug!(
        aud = ?claims.aud.0,
        azp = ?claims.azp,
        expected = %self.expected_audience,
        mode = ?self.audience_mode,
        "JWT rejected: audience mismatch"
    );
    Err(JwtValidationFailure::Invalid)
}
/// Maps the token's claims to a role.
///
/// With a configured role claim, the candidate values come from the
/// first-class claim of that name plus the matching path in the extra claims;
/// the first role mapping (in configured order) whose `claim_value` is among
/// them wins. Without a role claim, the first scope mapping whose scope
/// appears in the whitespace-separated `scope` claim wins.
///
/// # Errors
///
/// `Invalid` when no mapping matches.
fn resolve_role(&self, claims: &Claims) -> Result<String, JwtValidationFailure> {
    match &self.role_claim {
        Some(claim_path) => {
            let first_class: Vec<String> = first_class_claim_values(claims, claim_path);
            let nested = resolve_claim_path(&claims.extra, claim_path);
            let candidates: Vec<&str> = first_class
                .iter()
                .map(String::as_str)
                .chain(nested)
                .collect();
            for mapping in self.role_mappings.iter() {
                if candidates.contains(&mapping.claim_value.as_str()) {
                    return Ok(mapping.role.clone());
                }
            }
            Err(JwtValidationFailure::Invalid)
        }
        None => {
            let granted: Vec<&str> = claims
                .scope
                .as_deref()
                .map(|raw| raw.split_whitespace().collect())
                .unwrap_or_default();
            for mapping in self.scopes.iter() {
                if granted.contains(&mapping.scope.as_str()) {
                    return Ok(mapping.role.clone());
                }
            }
            Err(JwtValidationFailure::Invalid)
        }
    }
}
/// Looks up a decoding key for `kid`/`alg`.
///
/// A hit in an unexpired cache is returned directly. Otherwise a refresh is
/// attempted (subject to the cooldown) and the lookup is repeated against
/// whatever the cache holds afterwards, without re-checking expiry.
async fn find_key(&self, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
    let fresh_hit = {
        // Scoped so the read guard is released before the refresh runs.
        let cache = self.inner.read().await;
        cache
            .as_ref()
            .filter(|cached| !cached.is_expired())
            .and_then(|cached| lookup_key(cached, kid, alg))
    };
    if fresh_hit.is_some() {
        return fresh_hit;
    }

    self.refresh_with_cooldown().await;

    let cache = self.inner.read().await;
    let cached = cache.as_ref()?;
    lookup_key(cached, kid, alg)
}
/// Refreshes the JWKS cache unless the previous attempt was more recent than
/// `JWKS_REFRESH_COOLDOWN`.
///
/// The refresh lock is held for the whole call. The attempt timestamp is
/// recorded before fetching, and the result of the refresh is ignored.
async fn refresh_with_cooldown(&self) {
    let _serialized = self.refresh_lock.lock().await;

    let previous_attempt = *self.last_refresh_attempt.read().await;
    if let Some(ts) = previous_attempt {
        if ts.elapsed() < JWKS_REFRESH_COOLDOWN {
            tracing::debug!(
                elapsed_ms = ts.elapsed().as_millis(),
                cooldown_ms = JWKS_REFRESH_COOLDOWN.as_millis(),
                "JWKS refresh skipped (cooldown active)"
            );
            return;
        }
    }

    *self.last_refresh_attempt.write().await = Some(Instant::now());

    let _ = self.refresh_inner().await;
}
/// Fetches the JWKS and replaces the cached key set.
///
/// A failed fetch leaves the cache untouched and still returns `Ok(())`.
///
/// # Errors
///
/// Returns the message from `build_key_cache` when it refuses the key set;
/// the cache is not modified in that case.
async fn refresh_inner(&self) -> Result<(), String> {
    let Some(jwks) = self.fetch_jwks().await else {
        return Ok(());
    };

    let (keys, unnamed_keys) =
        build_key_cache(&jwks, self.max_jwks_keys).inspect_err(|msg| {
            tracing::warn!(reason = %msg, "JWKS key cap exceeded; refusing to populate cache");
        })?;
    tracing::debug!(
        named = keys.len(),
        unnamed = unnamed_keys.len(),
        "JWKS refreshed"
    );

    let mut slot = self.inner.write().await;
    let refreshed = CachedKeys {
        keys,
        unnamed_keys,
        fetched_at: Instant::now(),
        ttl: self.ttl,
    };
    *slot = Some(refreshed);
    Ok(())
}
/// Downloads and parses the JWKS document from `jwks_uri`.
///
/// The target is screened first, then fetched with the configured HTTP
/// client. Non-success HTTP statuses are rejected before the body is read, so
/// an error payload can never be mistaken for a key set. The body is streamed
/// chunk by chunk and the fetch is aborted as soon as it would exceed
/// `max_response_bytes`.
///
/// Returns `None` (after logging a warning) on screening failure, transport
/// error, non-success status, oversized body, or JSON parse failure.
#[allow(
    clippy::cognitive_complexity,
    reason = "screening, bounded streaming, and parse logging are intentionally kept in one fetch path"
)]
async fn fetch_jwks(&self) -> Option<JwkSet> {
    #[cfg(any(test, feature = "test-helpers"))]
    let screening = if self.test_allow_loopback_ssrf.load(Ordering::Relaxed) {
        screen_oauth_target_with_test_override(
            &self.jwks_uri,
            self.allow_http,
            &self.allowlist,
            true,
        )
        .await
    } else {
        screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await
    };
    #[cfg(not(any(test, feature = "test-helpers")))]
    let screening = screen_oauth_target(&self.jwks_uri, self.allow_http, &self.allowlist).await;
    if let Err(error) = screening {
        tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to screen JWKS target");
        return None;
    }
    let mut resp = match self.http.get(&self.jwks_uri).send().await {
        Ok(resp) => resp,
        Err(e) => {
            tracing::warn!(error = %e, uri = %self.jwks_uri, "failed to fetch JWKS");
            return None;
        }
    };
    // A 4xx/5xx body that happens to be valid JSON (e.g. `{"keys":[]}`) must
    // not replace the cached key set.
    let status = resp.status();
    if !status.is_success() {
        tracing::warn!(
            status = %status,
            uri = %self.jwks_uri,
            "JWKS endpoint returned non-success status"
        );
        return None;
    }
    // Pre-allocate at most 64 KiB regardless of how large the cap is.
    let initial_capacity =
        usize::try_from(self.max_response_bytes.min(64 * 1024)).unwrap_or(64 * 1024);
    let mut body = Vec::with_capacity(initial_capacity);
    while let Some(chunk) = match resp.chunk().await {
        Ok(chunk) => chunk,
        Err(error) => {
            tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to read JWKS response");
            return None;
        }
    } {
        // Enforce the cap before appending so the buffer never exceeds it.
        let chunk_len = u64::try_from(chunk.len()).unwrap_or(u64::MAX);
        let body_len = u64::try_from(body.len()).unwrap_or(u64::MAX);
        if body_len.saturating_add(chunk_len) > self.max_response_bytes {
            tracing::warn!(
                uri = %self.jwks_uri,
                max_bytes = self.max_response_bytes,
                "JWKS response exceeded configured size cap"
            );
            return None;
        }
        body.extend_from_slice(&chunk);
    }
    match serde_json::from_slice::<JwkSet>(&body) {
        Ok(jwks) => Some(jwks),
        Err(error) => {
            tracing::warn!(error = %error, uri = %self.jwks_uri, "failed to parse JWKS");
            None
        }
    }
}
/// Test-only: fetches the JWKS immediately, skipping the refresh lock and
/// cooldown, and stores the resulting keys in the cache.
///
/// Unlike the regular refresh path, a failed fetch is reported as an error.
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
pub async fn __test_refresh_now(&self) -> Result<(), String> {
    let Some(jwks) = self.fetch_jwks().await else {
        return Err("failed to fetch or parse JWKS".to_owned());
    };
    let (keys, unnamed_keys) = build_key_cache(&jwks, self.max_jwks_keys)?;

    let mut slot = self.inner.write().await;
    let refreshed = CachedKeys {
        keys,
        unnamed_keys,
        fetched_at: Instant::now(),
        ttl: self.ttl,
    };
    *slot = Some(refreshed);
    Ok(())
}
/// Test-only: reports whether the cache currently holds a named key with
/// this `kid`. An empty (never populated) cache yields `false`.
#[cfg(any(test, feature = "test-helpers"))]
#[doc(hidden)]
pub async fn __test_has_kid(&self, kid: &str) -> bool {
    let slot = self.inner.read().await;
    let Some(cache) = slot.as_ref() else {
        return false;
    };
    cache.keys.contains_key(kid)
}
}
fn build_key_cache(jwks: &JwkSet, max_keys: usize) -> Result<JwksKeyCache, String> {
if jwks.keys.len() > max_keys {
return Err(format!(
"jwks_key_count_exceeds_cap: got {} keys, max is {}",
jwks.keys.len(),
max_keys
));
}
let mut keys = HashMap::new();
let mut unnamed_keys = Vec::new();
for jwk in &jwks.keys {
let Ok(decoding_key) = DecodingKey::from_jwk(jwk) else {
continue;
};
let Some(alg) = jwk_algorithm(jwk) else {
continue;
};
if let Some(ref kid) = jwk.common.key_id {
keys.insert(kid.clone(), (alg, decoding_key));
} else {
unnamed_keys.push((alg, decoding_key));
}
}
Ok((keys, unnamed_keys))
}
fn lookup_key(cached: &CachedKeys, kid: Option<&str>, alg: Algorithm) -> Option<DecodingKey> {
if let Some(kid) = kid
&& let Some((cached_alg, key)) = cached.keys.get(kid)
&& *cached_alg == alg
{
return Some(key.clone());
}
cached
.unnamed_keys
.iter()
.find(|(a, _)| *a == alg)
.map(|(_, k)| k.clone())
}
/// Translates a JWK's declared `alg` into the signature [`Algorithm`] used
/// for verification.
///
/// Returns `None` when the JWK declares no algorithm or one outside the
/// RS*/ES256/ES384/PS*/EdDSA set listed below.
#[allow(clippy::wildcard_enum_match_arm)]
fn jwk_algorithm(jwk: &jsonwebtoken::jwk::Jwk) -> Option<Algorithm> {
    use jsonwebtoken::jwk::KeyAlgorithm;

    let declared = jwk.common.key_algorithm?;
    let alg = match declared {
        KeyAlgorithm::RS256 => Algorithm::RS256,
        KeyAlgorithm::RS384 => Algorithm::RS384,
        KeyAlgorithm::RS512 => Algorithm::RS512,
        KeyAlgorithm::ES256 => Algorithm::ES256,
        KeyAlgorithm::ES384 => Algorithm::ES384,
        KeyAlgorithm::PS256 => Algorithm::PS256,
        KeyAlgorithm::PS384 => Algorithm::PS384,
        KeyAlgorithm::PS512 => Algorithm::PS512,
        KeyAlgorithm::EdDSA => Algorithm::EdDSA,
        _ => return None,
    };
    Some(alg)
}
/// Returns the values of a claim that has a dedicated field on `Claims`.
///
/// `sub`, `azp` and `client_id` yield zero or one value, `aud` yields all of
/// its entries, and `scope` is split on whitespace. Any other `path` yields
/// an empty vector.
fn first_class_claim_values(claims: &Claims, path: &str) -> Vec<String> {
    // Zero-or-one value from an optional string claim.
    fn single(value: Option<&String>) -> Vec<String> {
        value.cloned().into_iter().collect()
    }

    match path {
        "sub" => single(claims.sub.as_ref()),
        "azp" => single(claims.azp.as_ref()),
        "client_id" => single(claims.client_id.as_ref()),
        "aud" => claims.aud.0.clone(),
        "scope" => claims
            .scope
            .as_deref()
            .map(|raw| raw.split_whitespace().map(str::to_owned).collect())
            .unwrap_or_default(),
        _ => Vec::new(),
    }
}
/// Walks a dot-separated `path` through the extra claims and returns the
/// string values found at the end.
///
/// A string leaf is split on whitespace; an array leaf contributes its string
/// elements (non-strings are ignored). A missing segment or any other leaf
/// type yields an empty vector.
fn resolve_claim_path<'a>(
    extra: &'a HashMap<String, serde_json::Value>,
    path: &str,
) -> Vec<&'a str> {
    let mut segments = path.split('.');
    let root = segments.next().and_then(|first| extra.get(first));
    let leaf = segments.fold(root, |node, segment| {
        node.and_then(|value| value.get(segment))
    });
    match leaf {
        Some(serde_json::Value::String(text)) => text.split_whitespace().collect(),
        Some(serde_json::Value::Array(items)) => {
            items.iter().filter_map(|item| item.as_str()).collect()
        }
        _ => Vec::new(),
    }
}
/// JWT payload claims read by the validator.
///
/// Claims without a dedicated field are gathered into `extra` through
/// `#[serde(flatten)]`.
#[derive(Debug, Deserialize)]
struct Claims {
/// `sub` claim, if present.
sub: Option<String>,
/// `aud` claim as a list; empty when the claim is absent (`serde(default)`).
#[serde(default)]
aud: OneOrMany,
/// `azp` claim, if present.
azp: Option<String>,
/// `client_id` claim, if present.
client_id: Option<String>,
/// `scope` claim, if present; consumers split it on whitespace.
scope: Option<String>,
/// Every remaining claim, keyed by claim name.
#[serde(flatten)]
extra: HashMap<String, serde_json::Value>,
}
/// A list of strings; the default value is the empty list.
#[derive(Debug, Default)]
struct OneOrMany(Vec<String>);

impl OneOrMany {
    /// Returns `true` when any entry equals `value` exactly.
    fn contains(&self, value: &str) -> bool {
        self.0
            .iter()
            .map(String::as_str)
            .any(|entry| entry == value)
    }
}
impl<'de> Deserialize<'de> for OneOrMany {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
use serde::de;
struct Visitor;
impl<'de> de::Visitor<'de> for Visitor {
type Value = OneOrMany;
fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str("a string or array of strings")
}
fn visit_str<E: de::Error>(self, v: &str) -> Result<OneOrMany, E> {
Ok(OneOrMany(vec![v.to_owned()]))
}
fn visit_seq<A: de::SeqAccess<'de>>(self, mut seq: A) -> Result<OneOrMany, A::Error> {
let mut v = Vec::new();
while let Some(s) = seq.next_element::<String>()? {
v.push(s);
}
Ok(OneOrMany(v))
}
}
deserializer.deserialize_any(Visitor)
}
}
/// Cheap structural check for whether `token` has the shape of a JWT.
///
/// Returns `true` only when the token has exactly three dot-separated
/// segments and the first one is unpadded URL-safe base64 of a JSON value
/// containing an `alg` member. No signature or claim is verified.
#[must_use]
pub fn looks_like_jwt(token: &str) -> bool {
    use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};

    let mut segments = token.split('.');
    let (Some(header_b64), Some(_), Some(_), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return false;
    };

    URL_SAFE_NO_PAD
        .decode(header_b64)
        .ok()
        .and_then(|bytes| serde_json::from_slice::<serde_json::Value>(&bytes).ok())
        .is_some_and(|header| header.get("alg").is_some())
}
/// Builds the protected-resource metadata document.
///
/// `resource_url` is reported as the resource, `server_url` as the only
/// authorization server, the configured scope names as supported scopes, and
/// `header` as the only bearer method.
#[must_use]
pub fn protected_resource_metadata(
    resource_url: &str,
    server_url: &str,
    config: &OAuthConfig,
) -> serde_json::Value {
    let scopes_supported: Vec<&str> = config
        .scopes
        .iter()
        .map(|mapping| mapping.scope.as_str())
        .collect();
    serde_json::json!({
        "resource": resource_url,
        "authorization_servers": [server_url],
        "scopes_supported": scopes_supported,
        "bearer_methods_supported": ["header"]
    })
}
/// Builds the authorization-server metadata document.
///
/// The authorize, token and register endpoints are derived from `server_url`.
/// Introspection and revocation entries are added only when a proxy is
/// configured with `expose_admin_endpoints`, and then only for the endpoints
/// that have an upstream URL. The `*_auth_methods_supported` entries are
/// added when `require_auth_on_admin_endpoints` is set.
#[must_use]
pub fn authorization_server_metadata(server_url: &str, config: &OAuthConfig) -> serde_json::Value {
    let scopes_supported: Vec<&str> = config
        .scopes
        .iter()
        .map(|mapping| mapping.scope.as_str())
        .collect();
    let mut meta = serde_json::json!({
        "issuer": &config.issuer,
        "authorization_endpoint": format!("{server_url}/authorize"),
        "token_endpoint": format!("{server_url}/token"),
        "registration_endpoint": format!("{server_url}/register"),
        "response_types_supported": ["code"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "code_challenge_methods_supported": ["S256"],
        "scopes_supported": scopes_supported,
        "token_endpoint_auth_methods_supported": ["none"],
    });

    let admin_proxy = config
        .proxy
        .as_ref()
        .filter(|proxy| proxy.expose_admin_endpoints);
    if let Some(proxy) = admin_proxy
        && let Some(obj) = meta.as_object_mut()
    {
        let local_endpoint =
            |suffix: &str| serde_json::Value::String(format!("{server_url}/{suffix}"));
        if proxy.introspection_url.is_some() {
            obj.insert("introspection_endpoint".into(), local_endpoint("introspect"));
        }
        if proxy.revocation_url.is_some() {
            obj.insert("revocation_endpoint".into(), local_endpoint("revoke"));
        }
        if proxy.require_auth_on_admin_endpoints {
            for key in [
                "introspection_endpoint_auth_methods_supported",
                "revocation_endpoint_auth_methods_supported",
            ] {
                obj.insert(key.into(), serde_json::json!(["bearer"]));
            }
        }
    }
    meta
}
/// Redirects an authorization request to the upstream authorize endpoint.
///
/// The incoming `query` is forwarded with its `client_id` replaced by the
/// proxy's upstream client id. The response is `302 Found` with the upstream
/// URL in `Location`.
///
/// The query is joined to `authorize_url` with `?`, or with `&` when the
/// configured URL already carries a query string, so the redirect target
/// never contains two `?` separators.
#[must_use]
pub fn handle_authorize(proxy: &OAuthProxyConfig, query: &str) -> axum::response::Response {
    use axum::{
        http::{StatusCode, header},
        response::IntoResponse,
    };
    let rewritten = replace_client_id(query, &proxy.client_id);
    // Avoid a dangling `&` right after the separator when the incoming query was empty.
    let upstream_query = rewritten.trim_start_matches('&');
    let base: &str = &proxy.authorize_url;
    let separator = if !base.contains('?') {
        "?"
    } else if base.ends_with('?') || base.ends_with('&') {
        ""
    } else {
        "&"
    };
    let redirect_url = format!("{base}{separator}{upstream_query}");
    (StatusCode::FOUND, [(header::LOCATION, redirect_url)]).into_response()
}
pub async fn handle_token(
http: &OauthHttpClient,
proxy: &OAuthProxyConfig,
body: &str,
) -> axum::response::Response {
use axum::{
http::{StatusCode, header},
response::IntoResponse,
};
let mut upstream_body = replace_client_id(body, &proxy.client_id);
if let Some(ref secret) = proxy.client_secret {
use std::fmt::Write;
use secrecy::ExposeSecret;
let _ = write!(
upstream_body,
"&client_secret={}",
urlencoding::encode(secret.expose_secret())
);
}
let result = http
.send_screened(
&proxy.token_url,
http.inner
.post(&proxy.token_url)
.header("Content-Type", "application/x-www-form-urlencoded")
.body(upstream_body),
)
.await;
match result {
Ok(resp) => {
let status =
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let body_bytes = resp.bytes().await.unwrap_or_default();
(
status,
[(header::CONTENT_TYPE, "application/json")],
body_bytes,
)
.into_response()
}
Err(e) => {
tracing::error!(error = %e, "OAuth token proxy request failed");
(
StatusCode::BAD_GATEWAY,
[(header::CONTENT_TYPE, "application/json")],
"{\"error\":\"server_error\",\"error_description\":\"token endpoint unreachable\"}",
)
.into_response()
}
}
}
/// Answers a client registration request without contacting the upstream.
///
/// The response always carries the proxy's client id and
/// `token_endpoint_auth_method = "none"`; `redirect_uris` and `client_name`
/// are echoed back unchanged when the request body contains them.
#[must_use]
pub fn handle_register(proxy: &OAuthProxyConfig, body: &serde_json::Value) -> serde_json::Value {
    let mut resp = serde_json::json!({
        "client_id": proxy.client_id,
        "token_endpoint_auth_method": "none",
    });
    if let Some(obj) = resp.as_object_mut() {
        for field in ["redirect_uris", "client_name"] {
            if let Some(value) = body.get(field) {
                obj.insert(field.into(), value.clone());
            }
        }
    }
    resp
}
/// Proxies a token introspection request to the configured upstream.
///
/// Responds with `404` and an OAuth `not_supported` error when no
/// introspection URL is configured.
pub async fn handle_introspect(
    http: &OauthHttpClient,
    proxy: &OAuthProxyConfig,
    body: &str,
) -> axum::response::Response {
    match &proxy.introspection_url {
        Some(url) => proxy_oauth_admin_request(http, proxy, url, body).await,
        None => oauth_error_response(
            axum::http::StatusCode::NOT_FOUND,
            "not_supported",
            "introspection endpoint is not configured",
        ),
    }
}
/// Proxies a token revocation request to the configured upstream.
///
/// Responds with `404` and an OAuth `not_supported` error when no revocation
/// URL is configured.
pub async fn handle_revoke(
    http: &OauthHttpClient,
    proxy: &OAuthProxyConfig,
    body: &str,
) -> axum::response::Response {
    match &proxy.revocation_url {
        Some(url) => proxy_oauth_admin_request(http, proxy, url, body).await,
        None => oauth_error_response(
            axum::http::StatusCode::NOT_FOUND,
            "not_supported",
            "revocation endpoint is not configured",
        ),
    }
}
async fn proxy_oauth_admin_request(
http: &OauthHttpClient,
proxy: &OAuthProxyConfig,
upstream_url: &str,
body: &str,
) -> axum::response::Response {
use axum::{
http::{StatusCode, header},
response::IntoResponse,
};
let mut upstream_body = replace_client_id(body, &proxy.client_id);
if let Some(ref secret) = proxy.client_secret {
use std::fmt::Write;
use secrecy::ExposeSecret;
let _ = write!(
upstream_body,
"&client_secret={}",
urlencoding::encode(secret.expose_secret())
);
}
let result = http
.send_screened(
upstream_url,
http.inner
.post(upstream_url)
.header("Content-Type", "application/x-www-form-urlencoded")
.body(upstream_body),
)
.await;
match result {
Ok(resp) => {
let status =
StatusCode::from_u16(resp.status().as_u16()).unwrap_or(StatusCode::BAD_GATEWAY);
let content_type = resp
.headers()
.get(header::CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.unwrap_or("application/json")
.to_owned();
let body_bytes = resp.bytes().await.unwrap_or_default();
(status, [(header::CONTENT_TYPE, content_type)], body_bytes).into_response()
}
Err(e) => {
tracing::error!(error = %e, url = %upstream_url, "OAuth admin proxy request failed");
oauth_error_response(
StatusCode::BAD_GATEWAY,
"server_error",
"upstream endpoint unreachable",
)
}
}
}
/// Builds a JSON error response with `error` and `error_description` members
/// and the given HTTP status.
fn oauth_error_response(
    status: axum::http::StatusCode,
    error: &str,
    description: &str,
) -> axum::response::Response {
    use axum::{http::header, response::IntoResponse};

    let payload = serde_json::json!({
        "error": error,
        "error_description": description,
    })
    .to_string();
    let headers = [(header::CONTENT_TYPE, "application/json")];
    (status, headers, payload).into_response()
}
/// Error body returned by an upstream authorization server, as parsed from a
/// non-success token exchange response.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
/// Upstream error code; sanitized before being handed back to callers.
error: String,
/// Optional human-readable detail; only used for logging.
error_description: Option<String>,
}
/// Maps an upstream error code onto a fixed allowlist.
///
/// Codes on the list are returned as their static equivalent; anything else,
/// including the empty string and case variants, collapses to
/// `"server_error"`.
fn sanitize_oauth_error_code(raw: &str) -> &'static str {
    const PASSTHROUGH: [&str; 8] = [
        "invalid_request",
        "invalid_client",
        "invalid_grant",
        "unauthorized_client",
        "unsupported_grant_type",
        "invalid_scope",
        "temporarily_unavailable",
        "invalid_target",
    ];
    PASSTHROUGH
        .iter()
        .copied()
        .find(|code| *code == raw)
        .unwrap_or("server_error")
}
pub async fn exchange_token(
http: &OauthHttpClient,
config: &TokenExchangeConfig,
subject_token: &str,
) -> Result<ExchangedToken, crate::error::McpxError> {
use secrecy::ExposeSecret;
let client = http.client_for(config);
let mut req = client
.post(&config.token_url)
.header("Content-Type", "application/x-www-form-urlencoded")
.header("Accept", "application/json");
if config.client_cert.is_none()
&& let Some(ref secret) = config.client_secret
{
use base64::Engine;
let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
"{}:{}",
urlencoding::encode(&config.client_id),
urlencoding::encode(secret.expose_secret()),
));
req = req.header("Authorization", format!("Basic {credentials}"));
}
let form_body = build_exchange_form(config, subject_token);
let resp = http
.send_screened(&config.token_url, req.body(form_body))
.await
.map_err(|e| {
tracing::error!(error = %e, "token exchange request failed");
crate::error::McpxError::Auth("server_error".into())
})?;
let status = resp.status();
let body_bytes = resp.bytes().await.map_err(|e| {
tracing::error!(error = %e, "failed to read token exchange response");
crate::error::McpxError::Auth("server_error".into())
})?;
if !status.is_success() {
core::hint::cold_path();
let parsed = serde_json::from_slice::<OAuthErrorResponse>(&body_bytes).ok();
let short_code = parsed
.as_ref()
.map_or("server_error", |e| sanitize_oauth_error_code(&e.error));
if let Some(ref e) = parsed {
tracing::warn!(
status = %status,
upstream_error = %e.error,
upstream_error_description = e.error_description.as_deref().unwrap_or(""),
client_code = %short_code,
"token exchange rejected by authorization server",
);
} else {
tracing::warn!(
status = %status,
client_code = %short_code,
"token exchange rejected (unparseable upstream body)",
);
}
return Err(crate::error::McpxError::Auth(short_code.into()));
}
let exchanged = serde_json::from_slice::<ExchangedToken>(&body_bytes).map_err(|e| {
tracing::error!(error = %e, "failed to parse token exchange response");
crate::error::McpxError::Auth("server_error".into())
})?;
log_exchanged_token(&exchanged);
Ok(exchanged)
}
/// Builds the URL-encoded form body for a token-exchange request.
///
/// Fields, in order: `grant_type`, `subject_token`, `subject_token_type`,
/// `requested_token_type`, `audience`. `client_id` is appended only when no
/// client secret is configured.
fn build_exchange_form(config: &TokenExchangeConfig, subject_token: &str) -> String {
    const GRANT_TYPE: &str = "urn:ietf:params:oauth:grant-type:token-exchange";
    const ACCESS_TOKEN_TYPE: &str = "urn:ietf:params:oauth:token-type:access_token";

    let audience: &str = &config.audience;
    let client_id: &str = &config.client_id;

    let mut fields: Vec<(&str, &str)> = Vec::with_capacity(6);
    fields.push(("grant_type", GRANT_TYPE));
    fields.push(("subject_token", subject_token));
    fields.push(("subject_token_type", ACCESS_TOKEN_TYPE));
    fields.push(("requested_token_type", ACCESS_TOKEN_TYPE));
    fields.push(("audience", audience));
    if config.client_secret.is_none() {
        fields.push(("client_id", client_id));
    }

    fields
        .iter()
        .map(|(key, value)| format!("{key}={}", urlencoding::encode(value)))
        .collect::<Vec<String>>()
        .join("&")
}
/// Emits a debug log describing an exchanged token without logging the token
/// itself.
///
/// Opaque tokens are logged by length, issued type and lifetime. For tokens
/// that look like JWTs, the payload is decoded without verification and
/// `sub`, `aud`, `azp` and `iss` are logged; nothing is logged if the payload
/// cannot be decoded or parsed.
fn log_exchanged_token(exchanged: &ExchangedToken) {
    use base64::Engine;

    let token = &exchanged.access_token;
    if !looks_like_jwt(token) {
        tracing::debug!(
            token_len = exchanged.access_token.len(),
            issued_token_type = ?exchanged.issued_token_type,
            expires_in = exchanged.expires_in,
            "exchanged token (opaque)",
        );
        return;
    }

    let parsed_payload = token
        .split('.')
        .nth(1)
        .and_then(|payload| {
            base64::engine::general_purpose::URL_SAFE_NO_PAD
                .decode(payload)
                .ok()
        })
        .and_then(|decoded| serde_json::from_slice::<serde_json::Value>(&decoded).ok());
    let Some(claims) = parsed_payload else {
        return;
    };

    tracing::debug!(
        sub = ?claims.get("sub"),
        aud = ?claims.get("aud"),
        azp = ?claims.get("azp"),
        iss = ?claims.get("iss"),
        expires_in = exchanged.expires_in,
        "exchanged token claims (JWT)",
    );
}
fn replace_client_id(params: &str, upstream_client_id: &str) -> String {
let encoded_id = urlencoding::encode(upstream_client_id);
let mut parts: Vec<String> = params
.split('&')
.filter(|p| !p.starts_with("client_id="))
.map(String::from)
.collect();
parts.push(format!("client_id={encoded_id}"));
parts.join("&")
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use super::*;
#[test]
fn looks_like_jwt_valid() {
let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\",\"typ\":\"JWT\"}");
let payload = URL_SAFE_NO_PAD.encode(b"{}");
let token = format!("{header}.{payload}.signature");
assert!(looks_like_jwt(&token));
}
#[test]
fn looks_like_jwt_rejects_opaque_token() {
assert!(!looks_like_jwt("dGhpcyBpcyBhbiBvcGFxdWUgdG9rZW4"));
}
#[test]
fn looks_like_jwt_rejects_two_segments() {
let header = URL_SAFE_NO_PAD.encode(b"{\"alg\":\"RS256\"}");
let token = format!("{header}.payload");
assert!(!looks_like_jwt(&token));
}
#[test]
fn looks_like_jwt_rejects_four_segments() {
assert!(!looks_like_jwt("a.b.c.d"));
}
#[test]
fn looks_like_jwt_rejects_no_alg() {
let header = URL_SAFE_NO_PAD.encode(b"{\"typ\":\"JWT\"}");
let payload = URL_SAFE_NO_PAD.encode(b"{}");
let token = format!("{header}.{payload}.sig");
assert!(!looks_like_jwt(&token));
}
#[test]
fn protected_resource_metadata_shape() {
let config = OAuthConfig {
issuer: "https://auth.example.com".into(),
audience: "https://mcp.example.com/mcp".into(),
jwks_uri: "https://auth.example.com/.well-known/jwks.json".into(),
scopes: vec![
ScopeMapping {
scope: "mcp:read".into(),
role: "viewer".into(),
},
ScopeMapping {
scope: "mcp:admin".into(),
role: "ops".into(),
},
],
role_claim: None,
role_mappings: vec![],
jwks_cache_ttl: "10m".into(),
proxy: None,
token_exchange: None,
ca_cert_path: None,
allow_http_oauth_urls: false,
max_jwks_keys: default_max_jwks_keys(),
#[allow(
deprecated,
reason = "test fixture: explicit value for the deprecated field"
)]
strict_audience_validation: false,
audience_validation_mode: None,
jwks_max_response_bytes: default_jwks_max_bytes(),
ssrf_allowlist: None,
};
let meta = protected_resource_metadata(
"https://mcp.example.com/mcp",
"https://mcp.example.com",
&config,
);
assert_eq!(meta["resource"], "https://mcp.example.com/mcp");
assert_eq!(meta["authorization_servers"][0], "https://mcp.example.com");
assert_eq!(meta["scopes_supported"].as_array().unwrap().len(), 2);
assert_eq!(meta["bearer_methods_supported"][0], "header");
}
fn validation_https_config() -> OAuthConfig {
OAuthConfig::builder(
"https://auth.example.com",
"mcp",
"https://auth.example.com/.well-known/jwks.json",
)
.build()
}
#[test]
fn validate_accepts_all_https_urls() {
let cfg = validation_https_config();
cfg.validate().expect("all-HTTPS config must validate");
}
#[test]
fn validate_rejects_unparseable_jwks_cache_ttl() {
let mut cfg = validation_https_config();
cfg.jwks_cache_ttl = "not-a-duration".into();
let err = cfg
.validate()
.expect_err("malformed jwks_cache_ttl must be rejected");
let msg = err.to_string();
assert!(
msg.contains("jwks_cache_ttl"),
"error must reference offending field; got {msg:?}"
);
}
#[test]
fn validate_rejects_http_jwks_uri() {
let mut cfg = validation_https_config();
cfg.jwks_uri = "http://auth.example.com/.well-known/jwks.json".into();
let err = cfg.validate().expect_err("http jwks_uri must be rejected");
let msg = err.to_string();
assert!(
msg.contains("oauth.jwks_uri") && msg.contains("https"),
"error must reference offending field + scheme requirement; got {msg:?}"
);
}
#[test]
fn validate_rejects_http_proxy_authorize_url() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"http://idp.example.com/authorize", "https://idp.example.com/token",
"client",
)
.build(),
);
let err = cfg
.validate()
.expect_err("http authorize_url must be rejected");
assert!(
err.to_string().contains("oauth.proxy.authorize_url"),
"error must reference proxy.authorize_url; got {err}"
);
}
#[test]
fn validate_rejects_http_proxy_token_url() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"http://idp.example.com/token", "client",
)
.build(),
);
let err = cfg.validate().expect_err("http token_url must be rejected");
assert!(
err.to_string().contains("oauth.proxy.token_url"),
"error must reference proxy.token_url; got {err}"
);
}
#[test]
fn validate_rejects_http_proxy_introspection_and_revocation_urls() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"https://idp.example.com/token",
"client",
)
.introspection_url("http://idp.example.com/introspect")
.build(),
);
let err = cfg
.validate()
.expect_err("http introspection_url must be rejected");
assert!(err.to_string().contains("oauth.proxy.introspection_url"));
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"https://idp.example.com/token",
"client",
)
.revocation_url("http://idp.example.com/revoke")
.build(),
);
let err = cfg
.validate()
.expect_err("http revocation_url must be rejected");
assert!(err.to_string().contains("oauth.proxy.revocation_url"));
}
#[test]
fn validate_rejects_exposed_admin_endpoints_without_auth() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"https://idp.example.com/token",
"client",
)
.introspection_url("https://idp.example.com/introspect")
.expose_admin_endpoints(true)
.build(),
);
let err = cfg
.validate()
.expect_err("expose_admin_endpoints without auth must fail");
let msg = err.to_string();
assert!(msg.contains("require_auth_on_admin_endpoints"), "{msg}");
assert!(
msg.contains("allow_unauthenticated_admin_endpoints"),
"{msg}"
);
}
#[test]
fn validate_accepts_exposed_admin_endpoints_with_auth() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"https://idp.example.com/token",
"client",
)
.introspection_url("https://idp.example.com/introspect")
.expose_admin_endpoints(true)
.require_auth_on_admin_endpoints(true)
.build(),
);
cfg.validate()
.expect("authed admin endpoints must validate");
}
#[test]
fn validate_accepts_exposed_admin_endpoints_with_explicit_unauth_optout() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"https://idp.example.com/token",
"client",
)
.introspection_url("https://idp.example.com/introspect")
.expose_admin_endpoints(true)
.allow_unauthenticated_admin_endpoints(true)
.build(),
);
cfg.validate()
.expect("explicit unauth opt-out must validate");
}
#[test]
fn validate_accepts_unexposed_admin_endpoints_without_auth() {
let mut cfg = validation_https_config();
cfg.proxy = Some(
OAuthProxyConfig::builder(
"https://idp.example.com/authorize",
"https://idp.example.com/token",
"client",
)
.introspection_url("https://idp.example.com/introspect")
.build(),
);
cfg.validate()
.expect("unexposed admin endpoints must validate");
}
#[test]
fn validate_rejects_http_token_exchange_url() {
let mut cfg = validation_https_config();
cfg.token_exchange = Some(TokenExchangeConfig::new(
"http://idp.example.com/token".into(), "client".into(),
None,
None,
"downstream".into(),
));
let err = cfg
.validate()
.expect_err("http token_exchange.token_url must be rejected");
assert!(
err.to_string().contains("oauth.token_exchange.token_url"),
"error must reference token_exchange.token_url; got {err}"
);
}
#[test]
fn validate_rejects_unparseable_url() {
let mut cfg = validation_https_config();
cfg.jwks_uri = "not a url".into();
let err = cfg
.validate()
.expect_err("unparseable URL must be rejected");
assert!(err.to_string().contains("invalid URL"));
}
#[test]
fn validate_rejects_non_http_scheme() {
let mut cfg = validation_https_config();
cfg.jwks_uri = "file:///etc/passwd".into();
let err = cfg.validate().expect_err("file:// scheme must be rejected");
let msg = err.to_string();
assert!(
msg.contains("must use https scheme") && msg.contains("file"),
"error must reject non-http(s) schemes; got {msg:?}"
);
}
/// With `allow_http_oauth_urls(true)` set on the builder, `validate()` must
/// accept `http://` values on the issuer, JWKS, proxy and token-exchange URLs.
#[test]
fn validate_accepts_http_with_escape_hatch() {
    let builder = OAuthConfig::builder(
        "http://auth.local",
        "mcp",
        "http://auth.local/.well-known/jwks.json",
    );
    let mut config = builder.allow_http_oauth_urls(true).build();

    let proxy = OAuthProxyConfig::builder(
        "http://idp.local/authorize",
        "http://idp.local/token",
        "client",
    )
    .introspection_url("http://idp.local/introspect")
    .revocation_url("http://idp.local/revoke")
    .build();
    config.proxy = Some(proxy);

    let exchange = TokenExchangeConfig::new(
        "http://idp.local/token".into(),
        "client".into(),
        Some(secrecy::SecretString::new("dev-secret".into())),
        None,
        "downstream".into(),
    );
    config.token_exchange = Some(exchange);

    let outcome = config.validate();
    outcome.expect("escape hatch must permit http on all URL fields");
}
/// Setting `allow_http_oauth_urls` must not let an unparseable `jwks_uri`
/// through validation.
#[test]
fn validate_with_escape_hatch_still_rejects_unparseable() {
    let mut config = validation_https_config();
    config.jwks_uri = "::not-a-url::".into();
    config.allow_http_oauth_urls = true;

    let outcome = config.validate();
    outcome.expect_err("escape hatch must NOT bypass URL parsing");
}
#[tokio::test]
async fn jwks_cache_rejects_redirect_downgrade_to_http() {
rustls::crypto::ring::default_provider()
.install_default()
.ok();
let policy = reqwest::redirect::Policy::custom(|attempt| {
if attempt.url().scheme() != "https" {
attempt.error("redirect to non-HTTPS URL refused")
} else if attempt.previous().len() >= 2 {
attempt.error("too many redirects (max 2)")
} else {
attempt.follow()
}
});
let test_bypass: crate::ssrf_resolver::TestLoopbackBypass = Arc::new(AtomicBool::new(true));
let allowlist = Arc::new(crate::ssrf::CompiledSsrfAllowlist::default());
let resolver: Arc<dyn reqwest::dns::Resolve> = Arc::new(
crate::ssrf_resolver::SsrfScreeningResolver::new(Arc::clone(&allowlist), test_bypass),
);
let client = reqwest::Client::builder()
.no_proxy()
.dns_resolver(Arc::clone(&resolver))
.timeout(Duration::from_secs(5))
.connect_timeout(Duration::from_secs(3))
.redirect(policy)
.build()
.expect("test client builds");
let mock = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(
wiremock::ResponseTemplate::new(302)
.insert_header("location", "http://example.invalid/jwks.json"),
)
.mount(&mock)
.await;
let url = format!("{}/jwks.json", mock.uri());
let err = client
.get(&url)
.send()
.await
.expect_err("redirect policy must reject scheme downgrade");
let chain = format!("{err:#}");
assert!(
chain.contains("redirect to non-HTTPS URL refused")
|| chain.to_lowercase().contains("redirect"),
"error must surface redirect-policy rejection; got {chain:?}"
);
}
use rsa::{pkcs8::EncodePrivateKey, traits::PublicKeyParts};
/// Generates a fresh 2048-bit RSA key pair for tests.
///
/// Returns the private key as a PKCS#8 PEM string together with a JWKS JSON
/// document holding the public key as a single RS256 signing key tagged with
/// `kid`.
///
/// # Panics
///
/// Panics if key generation or PEM export fails.
fn generate_test_keypair(kid: &str) -> (String, serde_json::Value) {
    let secret = rsa::RsaPrivateKey::new(&mut rsa::rand_core::OsRng, 2048)
        .expect("keypair generation");
    let pem = secret
        .to_pkcs8_pem(rsa::pkcs8::LineEnding::LF)
        .expect("PKCS8 PEM export")
        .to_string();

    // The JWK carries modulus and exponent as unpadded base64url of their
    // big-endian bytes.
    let public = secret.to_public_key();
    let modulus = URL_SAFE_NO_PAD.encode(public.n().to_bytes_be());
    let exponent = URL_SAFE_NO_PAD.encode(public.e().to_bytes_be());
    let document = serde_json::json!({
        "keys": [{
            "kty": "RSA",
            "use": "sig",
            "alg": "RS256",
            "kid": kid,
            "n": modulus,
            "e": exponent
        }]
    });

    (pem, document)
}
/// Signs an RS256 JWT with the given PEM private key.
///
/// The token carries `iss`, `aud`, `sub` and `scope` from the arguments,
/// `iat` set to the current timestamp and `exp` one hour later. `kid` is
/// placed in the JWT header.
///
/// # Panics
///
/// Panics if the PEM cannot be parsed or encoding fails.
fn mint_token(
    private_pem: &str,
    kid: &str,
    issuer: &str,
    audience: &str,
    subject: &str,
    scope: &str,
) -> String {
    let signer = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
        .expect("encoding key from PEM");

    let mut jwt_header = jsonwebtoken::Header::new(Algorithm::RS256);
    jwt_header.kid = Some(kid.into());

    let issued_at = jsonwebtoken::get_current_timestamp();
    let payload = serde_json::json!({
        "iss": issuer,
        "aud": audience,
        "sub": subject,
        "scope": scope,
        "exp": issued_at + 3600,
        "iat": issued_at,
    });

    jsonwebtoken::encode(&jwt_header, &payload, &signer).expect("JWT encoding")
}
/// Builds the scope-based `OAuthConfig` fixture used by the JWT tests.
///
/// Issuer and audience are fixed test values, `jwks_uri` comes from the
/// caller, and two scope mappings are installed: `mcp:read` -> `viewer` and
/// `mcp:admin` -> `ops`. `allow_http_oauth_urls` is set to `true`.
fn test_config(jwks_uri: &str) -> OAuthConfig {
    let scopes = [("mcp:read", "viewer"), ("mcp:admin", "ops")]
        .into_iter()
        .map(|(scope, role)| ScopeMapping {
            scope: scope.into(),
            role: role.into(),
        })
        .collect();

    OAuthConfig {
        issuer: "https://auth.test.local".into(),
        audience: "https://mcp.test.local/mcp".into(),
        jwks_uri: jwks_uri.into(),
        scopes,
        role_claim: None,
        role_mappings: Vec::new(),
        jwks_cache_ttl: "5m".into(),
        proxy: None,
        token_exchange: None,
        ca_cert_path: None,
        allow_http_oauth_urls: true,
        max_jwks_keys: default_max_jwks_keys(),
        #[allow(
            deprecated,
            reason = "test fixture: explicit value for the deprecated field"
        )]
        strict_audience_validation: false,
        audience_validation_mode: None,
        jwks_max_response_bytes: default_jwks_max_bytes(),
        ssrf_allowlist: None,
    }
}
/// Creates a `JwksCache` from `config` and applies
/// `__test_allow_loopback_ssrf()` to it.
///
/// # Panics
///
/// Panics if `JwksCache::new` returns an error.
fn test_cache(config: &OAuthConfig) -> JwksCache {
    let cache = JwksCache::new(config).unwrap();
    cache.__test_allow_loopback_ssrf()
}
#[tokio::test]
async fn valid_jwt_returns_identity() {
let kid = "test-key-1";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"ci-bot",
"mcp:read mcp:other",
);
let identity = cache.validate_token(&token).await;
assert!(identity.is_some(), "valid JWT should authenticate");
let id = identity.unwrap();
assert_eq!(id.name, "ci-bot");
assert_eq!(id.role, "viewer"); assert_eq!(id.method, AuthMethod::OAuthJwt);
}
#[tokio::test]
async fn wrong_issuer_rejected() {
let kid = "test-key-2";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://wrong-issuer.example.com", "https://mcp.test.local/mcp",
"attacker",
"mcp:admin",
);
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn wrong_audience_rejected() {
let kid = "test-key-3";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://wrong-audience.example.com", "attacker",
"mcp:admin",
);
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn expired_jwt_rejected() {
let kid = "test-key-4";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let encoding_key =
jsonwebtoken::EncodingKey::from_rsa_pem(pem.as_bytes()).expect("encoding key");
let mut header = jsonwebtoken::Header::new(Algorithm::RS256);
header.kid = Some(kid.into());
let now = jsonwebtoken::get_current_timestamp();
let claims = serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://mcp.test.local/mcp",
"sub": "expired-bot",
"scope": "mcp:read",
"exp": now - 120,
"iat": now - 3720,
});
let token = jsonwebtoken::encode(&header, &claims, &encoding_key).expect("JWT encoding");
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn no_matching_scope_rejected() {
let kid = "test-key-5";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"limited-bot",
"some:other:scope", );
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn wrong_signing_key_rejected() {
let kid = "test-key-6";
let (_pem, jwks) = generate_test_keypair(kid);
let (attacker_pem, _) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&attacker_pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"attacker",
"mcp:admin",
);
assert!(cache.validate_token(&token).await.is_none());
}
#[tokio::test]
async fn admin_scope_maps_to_ops_role() {
let kid = "test-key-7";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let token = mint_token(
&pem,
kid,
"https://auth.test.local",
"https://mcp.test.local/mcp",
"admin-bot",
"mcp:admin",
);
let id = cache
.validate_token(&token)
.await
.expect("should authenticate");
assert_eq!(id.role, "ops");
assert_eq!(id.name, "admin-bot");
}
/// With the JWKS URI pointing at `127.0.0.1:1` (no mock server is started),
/// validation must return `None` instead of an identity.
#[tokio::test]
async fn jwks_server_down_returns_none() {
    let config = test_config("http://127.0.0.1:1/jwks.json");
    let cache = test_cache(&config);

    let key_id = "orphan-key";
    let (private_pem, _) = generate_test_keypair(key_id);
    let token = mint_token(
        &private_pem,
        key_id,
        "https://auth.test.local",
        "https://mcp.test.local/mcp",
        "bot",
        "mcp:read",
    );

    assert!(cache.validate_token(&token).await.is_none());
}
/// A top-level string claim is split into its space-separated values.
#[test]
fn resolve_claim_path_flat_string() {
    let extra = HashMap::from([(
        "scope".into(),
        serde_json::Value::String("mcp:read mcp:admin".into()),
    )]);

    let resolved = resolve_claim_path(&extra, "scope");
    assert_eq!(resolved, vec!["mcp:read", "mcp:admin"]);
}
/// A top-level array claim yields each of its string elements.
#[test]
fn resolve_claim_path_flat_array() {
    let extra = HashMap::from([(
        "roles".into(),
        serde_json::json!(["mcp-admin", "mcp-viewer"]),
    )]);

    let resolved = resolve_claim_path(&extra, "roles");
    assert_eq!(resolved, vec!["mcp-admin", "mcp-viewer"]);
}
/// A dotted path (`realm_access.roles`) descends into a nested object and
/// yields the array found there.
#[test]
fn resolve_claim_path_nested_keycloak() {
    let extra = HashMap::from([(
        "realm_access".into(),
        serde_json::json!({"roles": ["uma_authorization", "mcp-admin"]}),
    )]);

    let resolved = resolve_claim_path(&extra, "realm_access.roles");
    assert_eq!(resolved, vec!["uma_authorization", "mcp-admin"]);
}
/// Looking up a path in an empty claim map yields no values.
#[test]
fn resolve_claim_path_missing_returns_empty() {
    // No claims at all, so no path can resolve.
    let extra = HashMap::new();

    assert!(resolve_claim_path(&extra, "nonexistent.path").is_empty());
}
/// A claim whose value is a number yields no values.
#[test]
fn resolve_claim_path_numeric_leaf_returns_empty() {
    let extra = HashMap::from([("count".into(), serde_json::json!(42))]);

    assert!(resolve_claim_path(&extra, "count").is_empty());
}
/// Deserializes a JSON value into `Claims`.
///
/// # Panics
///
/// Panics if the value does not deserialize.
fn make_claims(json: serde_json::Value) -> Claims {
    let parsed = serde_json::from_value::<Claims>(json);
    parsed.expect("test claims must deserialize")
}
/// The `scope` claim is returned as its individual space-separated entries.
#[test]
fn first_class_scope_claim_splits_on_whitespace() {
    let payload = serde_json::json!({
        "iss": "https://issuer.example.com",
        "exp": 9_999_999_999_u64,
        "scope": "read write admin",
    });
    let claims = make_claims(payload);

    let scopes = first_class_claim_values(&claims, "scope");
    assert_eq!(scopes, vec!["read", "write", "admin"]);
}
/// The `sub` claim is returned as a single value.
#[test]
fn first_class_sub_claim_returns_single_value() {
    let payload = serde_json::json!({
        "iss": "https://issuer.example.com",
        "exp": 9_999_999_999_u64,
        "sub": "service-account-orders",
    });
    let claims = make_claims(payload);

    let subjects = first_class_claim_values(&claims, "sub");
    assert_eq!(subjects, vec!["service-account-orders"]);
}
/// An array-valued `aud` claim is returned with every audience entry.
#[test]
fn first_class_aud_claim_returns_every_audience() {
    let payload = serde_json::json!({
        "iss": "https://issuer.example.com",
        "exp": 9_999_999_999_u64,
        "aud": ["api-a", "api-b"],
    });
    let claims = make_claims(payload);

    let audiences = first_class_claim_values(&claims, "aud");
    assert_eq!(audiences, vec!["api-a", "api-b"]);
}
/// `first_class_claim_values` yields nothing for `realm_access.roles` when
/// the claims contain only `iss` and `exp`.
#[test]
fn first_class_unknown_path_returns_empty() {
    let payload = serde_json::json!({
        "iss": "https://issuer.example.com",
        "exp": 9_999_999_999_u64,
    });
    let claims = make_claims(payload);

    assert!(first_class_claim_values(&claims, "realm_access.roles").is_empty());
}
/// Signs an RS256 JWT whose payload is exactly `claims`, with `kid` placed
/// in the header.
///
/// # Panics
///
/// Panics if the PEM cannot be parsed or encoding fails.
fn mint_token_with_claims(private_pem: &str, kid: &str, claims: &serde_json::Value) -> String {
    let mut jwt_header = jsonwebtoken::Header::new(Algorithm::RS256);
    jwt_header.kid = Some(kid.into());

    let signer = jsonwebtoken::EncodingKey::from_rsa_pem(private_pem.as_bytes())
        .expect("encoding key from PEM");

    jsonwebtoken::encode(&jwt_header, claims, &signer).expect("JWT encoding")
}
/// Builds the role-claim `OAuthConfig` fixture.
///
/// It uses the same fixed issuer and audience as the scope-based fixture,
/// but leaves `scopes` empty and sets `role_claim` and `role_mappings` from
/// the arguments.
fn test_config_with_role_claim(
    jwks_uri: &str,
    role_claim: &str,
    role_mappings: Vec<RoleMapping>,
) -> OAuthConfig {
    let claim_path = Some(role_claim.into());

    OAuthConfig {
        issuer: "https://auth.test.local".into(),
        audience: "https://mcp.test.local/mcp".into(),
        jwks_uri: jwks_uri.into(),
        scopes: Vec::new(),
        role_claim: claim_path,
        role_mappings,
        jwks_cache_ttl: "5m".into(),
        proxy: None,
        token_exchange: None,
        ca_cert_path: None,
        allow_http_oauth_urls: true,
        max_jwks_keys: default_max_jwks_keys(),
        #[allow(
            deprecated,
            reason = "test fixture: explicit value for the deprecated field"
        )]
        strict_audience_validation: false,
        audience_validation_mode: None,
        jwks_max_response_bytes: default_jwks_max_bytes(),
        ssrf_allowlist: None,
    }
}
/// An HTTPS target given as a literal IPv4 address is rejected with the
/// default (empty) allowlist.
#[tokio::test]
async fn screen_oauth_target_rejects_literal_ip() {
    let allowlist = crate::ssrf::CompiledSsrfAllowlist::default();

    let verdict = screen_oauth_target("https://127.0.0.1/jwks.json", false, &allowlist).await;
    let msg = verdict
        .expect_err("literal IPs must be rejected")
        .to_string();

    assert!(msg.contains("literal IPv4 addresses are forbidden"));
}
/// An HTTPS target on `localhost` is rejected with a message mentioning
/// "blocked IP" and "loopback".
#[tokio::test]
async fn screen_oauth_target_rejects_private_dns_resolution() {
    let allowlist = crate::ssrf::CompiledSsrfAllowlist::default();

    let verdict = screen_oauth_target("https://localhost/jwks.json", false, &allowlist).await;
    let msg = verdict
        .expect_err("localhost resolution must be rejected")
        .to_string();

    let blocked = msg.contains("blocked IP");
    let loopback = msg.contains("loopback");
    assert!(blocked && loopback, "got {msg:?}");
}
/// Passing `allow_http = true` does not lift the literal-IP rejection.
#[tokio::test]
async fn screen_oauth_target_rejects_literal_ip_even_with_allow_http() {
    let allowlist = crate::ssrf::CompiledSsrfAllowlist::default();

    let verdict = screen_oauth_target("http://127.0.0.1/jwks.json", true, &allowlist).await;
    let msg = verdict
        .expect_err("literal IPs must still be rejected when http is allowed")
        .to_string();

    assert!(msg.contains("literal IPv4 addresses are forbidden"));
}
/// Passing `allow_http = true` does not lift the loopback rejection for
/// `localhost`.
#[tokio::test]
async fn screen_oauth_target_rejects_private_dns_even_with_allow_http() {
    let allowlist = crate::ssrf::CompiledSsrfAllowlist::default();

    let verdict = screen_oauth_target("http://localhost/jwks.json", true, &allowlist).await;
    let msg = verdict
        .expect_err("private DNS resolution must still be rejected when http is allowed")
        .to_string();

    let blocked = msg.contains("blocked IP");
    let loopback = msg.contains("loopback");
    assert!(blocked && loopback, "got {msg:?}");
}
/// An HTTPS target on `example.com` passes screening with the default
/// allowlist.
#[tokio::test]
async fn screen_oauth_target_allows_public_hostname() {
    let allowlist = crate::ssrf::CompiledSsrfAllowlist::default();

    let verdict = screen_oauth_target(
        "https://example.com/.well-known/jwks.json",
        false,
        &allowlist,
    )
    .await;

    verdict.expect("public hostname should pass screening");
}
/// Compiles an SSRF allowlist from borrowed host and CIDR strings.
///
/// # Panics
///
/// Panics if `compile_oauth_ssrf_allowlist` rejects the input.
fn make_allowlist(hosts: &[&str], cidrs: &[&str]) -> crate::ssrf::CompiledSsrfAllowlist {
    let owned_hosts = hosts.iter().map(|host| String::from(*host)).collect();
    let owned_cidrs = cidrs.iter().map(|cidr| String::from(*cidr)).collect();
    let raw = OAuthSsrfAllowlist {
        hosts: owned_hosts,
        cidrs: owned_cidrs,
    };

    compile_oauth_ssrf_allowlist(&raw).expect("test allowlist compiles")
}
/// Two spellings of one host that differ only in case compile to a single
/// entry, and lookups succeed for either spelling.
#[test]
fn compile_oauth_ssrf_allowlist_lowercases_and_dedupes_hosts() {
    let hosts = vec!["RHBK.ops.example.com".into(), "rhbk.ops.example.com".into()];
    let raw = OAuthSsrfAllowlist {
        hosts,
        cidrs: Vec::new(),
    };

    let compiled = compile_oauth_ssrf_allowlist(&raw).expect("compiles");

    assert_eq!(compiled.host_count(), 1);
    assert!(compiled.host_allowed("rhbk.ops.example.com"));
    assert!(compiled.host_allowed("RHBK.OPS.EXAMPLE.COM"));
}
/// A literal IP address in the `hosts` list fails compilation.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_literal_ip_in_hosts() {
    let raw = OAuthSsrfAllowlist {
        hosts: vec!["10.0.0.1".into()],
        cidrs: Vec::new(),
    };

    let outcome = compile_oauth_ssrf_allowlist(&raw);
    let err = outcome.expect_err("literal IP in hosts");

    assert!(err.contains("literal IPs are forbidden"), "got {err:?}");
}
/// A `host:port` entry in the `hosts` list fails compilation.
#[test]
fn compile_oauth_ssrf_allowlist_rejects_host_with_port() {
    let raw = OAuthSsrfAllowlist {
        hosts: vec!["rhbk.ops.example.com:8443".into()],
        cidrs: Vec::new(),
    };

    let outcome = compile_oauth_ssrf_allowlist(&raw);
    let err = outcome.expect_err("host:port");

    assert!(err.contains("must be a bare DNS hostname"), "got {err:?}");
}
/// A malformed CIDR entry fails compilation, and the error names the
/// offending index (`cidrs[0]`).
#[test]
fn compile_oauth_ssrf_allowlist_rejects_invalid_cidr() {
    let raw = OAuthSsrfAllowlist {
        hosts: Vec::new(),
        cidrs: vec!["not-a-cidr".into()],
    };

    let outcome = compile_oauth_ssrf_allowlist(&raw);
    let err = outcome.expect_err("invalid CIDR");

    assert!(err.contains("oauth.ssrf_allowlist.cidrs[0]"), "got {err:?}");
}
/// `OAuthConfig::validate` rejects an allowlist whose `hosts` contains a
/// literal IP, with an error that mentions `oauth.ssrf_allowlist`.
#[test]
fn validate_rejects_misconfigured_allowlist() {
    let builder = OAuthConfig::builder(
        "https://auth.example.com/",
        "mcp",
        "https://auth.example.com/jwks.json",
    );
    let mut config = builder.build();
    let allowlist = OAuthSsrfAllowlist {
        hosts: vec!["10.0.0.1".into()],
        cidrs: Vec::new(),
    };
    config.ssrf_allowlist = Some(allowlist);

    let err = config
        .validate()
        .expect_err("literal IP host must be rejected");
    let names_section = err.to_string().contains("oauth.ssrf_allowlist");
    assert!(names_section, "got {err}");
}
/// With a non-empty allowlist that does not cover the target, the rejection
/// message mentions "OAuth target blocked", `oauth.ssrf_allowlist` and
/// `SECURITY.md`.
#[tokio::test]
async fn screen_oauth_target_with_allowlist_emits_helpful_error() {
    let allowlist = make_allowlist(&["other.example.com"], &["10.0.0.0/8"]);

    let verdict = screen_oauth_target("https://localhost/jwks.json", false, &allowlist).await;
    let msg = verdict
        .expect_err("loopback must still be blocked when not in allowlist")
        .to_string();

    assert!(msg.contains("OAuth target blocked"), "got {msg:?}");
    assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
    assert!(msg.contains("SECURITY.md"), "got {msg:?}");
}
/// With the default (empty) allowlist, the loopback rejection message
/// mentions "blocked IP" and "loopback" but not `oauth.ssrf_allowlist`.
#[tokio::test]
async fn screen_oauth_target_empty_allowlist_uses_legacy_message() {
    let allowlist = crate::ssrf::CompiledSsrfAllowlist::default();

    let verdict = screen_oauth_target("https://localhost/jwks.json", false, &allowlist).await;
    let msg = verdict.expect_err("loopback rejection").to_string();

    assert!(msg.contains("blocked IP"), "got {msg:?}");
    assert!(msg.contains("loopback"), "got {msg:?}");
    assert!(!msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
}
/// Listing `localhost` in the allowlist hosts lets the target pass screening.
#[tokio::test]
async fn screen_oauth_target_allows_loopback_when_host_allowlisted() {
    let allowlist = make_allowlist(&["localhost"], &[]);

    let verdict = screen_oauth_target("https://localhost/jwks.json", false, &allowlist).await;

    verdict.expect("allowlisted host must pass");
}
/// Listing the IPv4 and IPv6 loopback ranges in the allowlist CIDRs lets a
/// `localhost` target pass screening.
#[tokio::test]
async fn screen_oauth_target_allows_loopback_when_cidr_allowlisted() {
    let allowlist = make_allowlist(&[], &["127.0.0.0/8", "::1/128"]);

    let verdict = screen_oauth_target("https://localhost/jwks.json", false, &allowlist).await;

    verdict.expect("allowlisted CIDR must pass");
}
/// `JwksCache::new` fails for a config whose allowlist contains a malformed
/// CIDR, and the error mentions `oauth.ssrf_allowlist`.
#[tokio::test]
async fn jwks_cache_rejects_misconfigured_allowlist_at_startup() {
    let builder = OAuthConfig::builder(
        "https://auth.example.com/",
        "mcp",
        "https://auth.example.com/jwks.json",
    );
    let mut config = builder.build();
    let allowlist = OAuthSsrfAllowlist {
        hosts: Vec::new(),
        cidrs: vec!["bad-cidr".into()],
    };
    config.ssrf_allowlist = Some(allowlist);

    let err = match JwksCache::new(&config) {
        Err(err) => err,
        Ok(_) => panic!("invalid CIDR must fail JwksCache::new"),
    };

    let msg = err.to_string();
    assert!(msg.contains("oauth.ssrf_allowlist"), "got {msg:?}");
}
#[tokio::test]
async fn audience_falls_back_to_azp_by_default() {
let kid = "test-audience-azp-default";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config(&jwks_uri);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://some-other-resource.example.com",
"azp": "https://mcp.test.local/mcp",
"sub": "compat-client",
"scope": "mcp:read",
"exp": now + 3600,
"iat": now,
}),
);
let identity = cache
.validate_token_with_reason(&token)
.await
.expect("azp fallback should remain enabled by default");
assert_eq!(identity.role, "viewer");
}
#[tokio::test]
async fn strict_audience_validation_rejects_azp_only_match() {
let kid = "test-audience-azp-strict";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let mut config = test_config(&jwks_uri);
#[allow(deprecated, reason = "covers the legacy bool resolution path")]
{
config.strict_audience_validation = true;
}
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://some-other-resource.example.com",
"azp": "https://mcp.test.local/mcp",
"sub": "strict-client",
"scope": "mcp:read",
"exp": now + 3600,
"iat": now,
}),
);
let failure = cache
.validate_token_with_reason(&token)
.await
.expect_err("strict audience validation must ignore azp fallback");
assert_eq!(failure, JwtValidationFailure::Invalid);
}
#[tokio::test]
async fn warn_mode_accepts_azp_only_match_and_warns_once() {
let kid = "test-audience-warn-mode";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let mut config = test_config(&jwks_uri);
config.audience_validation_mode = Some(AudienceValidationMode::Warn);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let claims = serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://some-other-resource.example.com",
"azp": "https://mcp.test.local/mcp",
"sub": "warn-client",
"scope": "mcp:read",
"exp": now + 3600,
"iat": now,
});
let token = mint_token_with_claims(&pem, kid, &claims);
let identity = cache
.validate_token_with_reason(&token)
.await
.expect("warn mode must accept azp-only match");
assert_eq!(identity.role, "viewer");
assert!(
cache.azp_fallback_warned.load(Ordering::Relaxed),
"warn-once flag should be set after first azp-only match"
);
let token2 = mint_token_with_claims(&pem, kid, &claims);
cache
.validate_token_with_reason(&token2)
.await
.expect("warn mode must continue accepting subsequent matches");
assert!(
cache.azp_fallback_warned.load(Ordering::Relaxed),
"warn-once flag must remain set; the assertion guards against accidental clearing"
);
}
#[tokio::test]
async fn permissive_mode_accepts_azp_only_match_silently() {
let kid = "test-audience-permissive-mode";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let mut config = test_config(&jwks_uri);
config.audience_validation_mode = Some(AudienceValidationMode::Permissive);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://some-other-resource.example.com",
"azp": "https://mcp.test.local/mcp",
"sub": "permissive-client",
"scope": "mcp:read",
"exp": now + 3600,
"iat": now,
}),
);
cache
.validate_token_with_reason(&token)
.await
.expect("permissive mode must accept azp-only match");
assert!(
!cache.azp_fallback_warned.load(Ordering::Relaxed),
"permissive mode must not flip the warn-once flag"
);
}
/// An explicit `audience_validation_mode` takes precedence over the
/// deprecated `strict_audience_validation` bool, whichever way the bool is
/// set.
#[test]
fn audience_validation_mode_overrides_legacy_bool() {
    // Case 1: legacy bool is false, explicit mode is Strict.
    {
        let mut lenient_legacy = OAuthConfig::default();
        #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
        {
            lenient_legacy.strict_audience_validation = false;
        }
        lenient_legacy.audience_validation_mode = Some(AudienceValidationMode::Strict);

        let resolved = lenient_legacy.effective_audience_validation_mode();
        assert_eq!(
            resolved,
            AudienceValidationMode::Strict,
            "explicit mode must override legacy false"
        );
    }

    // Case 2: legacy bool is true, explicit mode is Permissive.
    {
        let mut strict_legacy = OAuthConfig::default();
        #[allow(deprecated, reason = "covers the precedence rule for the legacy bool")]
        {
            strict_legacy.strict_audience_validation = true;
        }
        strict_legacy.audience_validation_mode = Some(AudienceValidationMode::Permissive);

        let resolved = strict_legacy.effective_audience_validation_mode();
        assert_eq!(
            resolved,
            AudienceValidationMode::Permissive,
            "explicit mode must override legacy true"
        );
    }
}
/// A default `OAuthConfig` resolves its effective audience validation mode
/// to `Warn`.
#[test]
fn audience_validation_mode_default_is_warn_when_unset() {
    let resolved = OAuthConfig::default().effective_audience_validation_mode();

    assert_eq!(
        resolved,
        AudienceValidationMode::Warn,
        "unset mode + unset bool must resolve to Warn (the new default)"
    );
}
/// With only the deprecated bool set to `true` on a default config, the
/// effective mode resolves to `Strict`.
#[test]
fn audience_validation_legacy_bool_true_resolves_to_strict() {
    let mut legacy_only = OAuthConfig::default();
    #[allow(deprecated, reason = "covers the legacy bool resolution path")]
    {
        legacy_only.strict_audience_validation = true;
    }

    let resolved = legacy_only.effective_audience_validation_mode();
    assert_eq!(
        resolved,
        AudienceValidationMode::Strict,
        "legacy bool=true must resolve to Strict for backward compat"
    );
}
/// Shared in-memory byte buffer used to collect log output in tests.
///
/// Cloning yields another handle to the same underlying buffer.
#[derive(Clone, Default)]
struct CapturedLogs(Arc<std::sync::Mutex<Vec<u8>>>);

impl CapturedLogs {
    /// Returns everything captured so far as a `String`.
    ///
    /// Yields an empty string if the mutex is poisoned or the bytes are not
    /// valid UTF-8.
    fn contents(&self) -> String {
        let snapshot = match self.0.lock() {
            Ok(buffer) => buffer.to_vec(),
            Err(_) => Vec::new(),
        };
        String::from_utf8(snapshot).unwrap_or_default()
    }
}
/// `std::io::Write` adapter that appends into a shared byte buffer.
struct CapturedLogsWriter(Arc<std::sync::Mutex<Vec<u8>>>);

impl std::io::Write for CapturedLogsWriter {
    /// Appends `buf` to the shared buffer and reports the whole slice as
    /// written. If the mutex is poisoned the bytes are dropped, but the
    /// call still reports success.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        let written = buf.len();
        let Ok(mut sink) = self.0.lock() else {
            return Ok(written);
        };
        sink.extend_from_slice(buf);
        Ok(written)
    }

    /// Does nothing and always succeeds; `write` appends straight to the
    /// buffer, so nothing is pending.
    fn flush(&mut self) -> std::io::Result<()> {
        Ok(())
    }
}
/// Lets `CapturedLogs` be passed as the writer of a `tracing_subscriber`
/// fmt subscriber.
impl<'a> tracing_subscriber::fmt::MakeWriter<'a> for CapturedLogs {
    type Writer = CapturedLogsWriter;

    /// Hands out a writer backed by the same buffer as `self`.
    fn make_writer(&'a self) -> Self::Writer {
        let shared = Arc::clone(&self.0);
        CapturedLogsWriter(shared)
    }
}
#[tokio::test]
async fn jwks_response_size_cap_returns_none_and_logs_warning() {
let kid = "oversized-jwks";
let (_pem, jwks) = generate_test_keypair(kid);
let mut oversized_body = serde_json::to_string(&jwks).expect("jwks json");
oversized_body.push_str(&" ".repeat(4096));
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(
wiremock::ResponseTemplate::new(200)
.insert_header("content-type", "application/json")
.set_body_string(oversized_body),
)
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let mut config = test_config(&jwks_uri);
config.jwks_max_response_bytes = 256;
let cache = test_cache(&config);
let logs = CapturedLogs::default();
let subscriber = tracing_subscriber::fmt()
.with_writer(logs.clone())
.with_ansi(false)
.without_time()
.finish();
let _guard = tracing::subscriber::set_default(subscriber);
let result = cache.fetch_jwks().await;
assert!(result.is_none(), "oversized JWKS must be dropped");
assert!(
logs.contents()
.contains("JWKS response exceeded configured size cap"),
"expected cap-exceeded warning in logs"
);
}
#[tokio::test]
async fn role_claim_keycloak_nested_array() {
let kid = "test-role-1";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config_with_role_claim(
&jwks_uri,
"realm_access.roles",
vec![
RoleMapping {
claim_value: "mcp-admin".into(),
role: "ops".into(),
},
RoleMapping {
claim_value: "mcp-viewer".into(),
role: "viewer".into(),
},
],
);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://mcp.test.local/mcp",
"sub": "keycloak-user",
"exp": now + 3600,
"iat": now,
"realm_access": { "roles": ["uma_authorization", "mcp-admin"] }
}),
);
let id = cache
.validate_token(&token)
.await
.expect("should authenticate");
assert_eq!(id.name, "keycloak-user");
assert_eq!(id.role, "ops");
}
#[tokio::test]
async fn role_claim_flat_roles_array() {
let kid = "test-role-2";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config_with_role_claim(
&jwks_uri,
"roles",
vec![
RoleMapping {
claim_value: "MCP.Admin".into(),
role: "ops".into(),
},
RoleMapping {
claim_value: "MCP.Reader".into(),
role: "viewer".into(),
},
],
);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://mcp.test.local/mcp",
"sub": "azure-ad-user",
"exp": now + 3600,
"iat": now,
"roles": ["MCP.Reader", "OtherApp.Admin"]
}),
);
let id = cache
.validate_token(&token)
.await
.expect("should authenticate");
assert_eq!(id.name, "azure-ad-user");
assert_eq!(id.role, "viewer");
}
#[tokio::test]
async fn role_claim_no_matching_value_rejected() {
let kid = "test-role-3";
let (pem, jwks) = generate_test_keypair(kid);
let mock_server = wiremock::MockServer::start().await;
wiremock::Mock::given(wiremock::matchers::method("GET"))
.and(wiremock::matchers::path("/jwks.json"))
.respond_with(wiremock::ResponseTemplate::new(200).set_body_json(&jwks))
.mount(&mock_server)
.await;
let jwks_uri = format!("{}/jwks.json", mock_server.uri());
let config = test_config_with_role_claim(
&jwks_uri,
"roles",
vec![RoleMapping {
claim_value: "mcp-admin".into(),
role: "ops".into(),
}],
);
let cache = test_cache(&config);
let now = jsonwebtoken::get_current_timestamp();
let token = mint_token_with_claims(
&pem,
kid,
&serde_json::json!({
"iss": "https://auth.test.local",
"aud": "https://mcp.test.local/mcp",
"sub": "limited-user",
"exp": now + 3600,
"iat": now,
"roles": ["some-other-role"]
}),
);
assert!(cache.validate_token(&token).await.is_none());
}
/// A role claim supplied as one space-separated string (`"read audit"`) is
/// matched against the configured mappings; the `read` entry is expected to
/// produce the `viewer` role.
#[tokio::test]
async fn role_claim_space_separated_string() {
    let key_id = "test-role-4";
    let (signing_pem, key_set) = generate_test_keypair(key_id);

    // Serve the generated key set from a local mock endpoint.
    let idp = wiremock::MockServer::start().await;
    let jwks_reply = wiremock::ResponseTemplate::new(200).set_body_json(&key_set);
    wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::path("/jwks.json"))
        .respond_with(jwks_reply)
        .mount(&idp)
        .await;

    let mappings = vec![
        RoleMapping {
            claim_value: "write".into(),
            role: "ops".into(),
        },
        RoleMapping {
            claim_value: "read".into(),
            role: "viewer".into(),
        },
    ];
    let endpoint = format!("{}/jwks.json", idp.uri());
    let oauth_cfg = test_config_with_role_claim(&endpoint, "custom_scope", mappings);
    let jwks_cache = test_cache(&oauth_cfg);

    // `custom_scope` is a plain string here, not a JSON array.
    let issued_at = jsonwebtoken::get_current_timestamp();
    let claims = serde_json::json!({
        "iss": "https://auth.test.local",
        "aud": "https://mcp.test.local/mcp",
        "sub": "custom-client",
        "exp": issued_at + 3600,
        "iat": issued_at,
        "custom_scope": "read audit"
    });
    let jwt = mint_token_with_claims(&signing_pem, key_id, &claims);

    let identity = jwks_cache
        .validate_token(&jwt)
        .await
        .expect("should authenticate");
    assert_eq!(identity.name, "custom-client");
    assert_eq!(identity.role, "viewer");
}
/// With the configuration produced by `test_config` (no role-claim mappings
/// are passed in this test), a token minted with the scope string
/// `mcp:admin other:scope` must authenticate as `legacy-bot` with role `ops`.
#[tokio::test]
async fn scope_backward_compat_without_role_claim() {
    let key_id = "test-compat-1";
    let (signing_pem, key_set) = generate_test_keypair(key_id);

    // Serve the generated key set from a local mock endpoint.
    let idp = wiremock::MockServer::start().await;
    let jwks_reply = wiremock::ResponseTemplate::new(200).set_body_json(&key_set);
    wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::path("/jwks.json"))
        .respond_with(jwks_reply)
        .mount(&idp)
        .await;

    let endpoint = format!("{}/jwks.json", idp.uri());
    let oauth_cfg = test_config(&endpoint);
    let jwks_cache = test_cache(&oauth_cfg);

    let jwt = mint_token(
        &signing_pem,
        key_id,
        "https://auth.test.local",
        "https://mcp.test.local/mcp",
        "legacy-bot",
        "mcp:admin other:scope",
    );

    let identity = jwks_cache
        .validate_token(&jwt)
        .await
        .expect("should authenticate");
    assert_eq!(identity.name, "legacy-bot");
    assert_eq!(identity.role, "ops");
}
/// Five tasks validate the same token concurrently against a fresh cache.
/// Every task must succeed, and the JWKS mock is declared with `.expect(1)`,
/// i.e. it expects exactly one GET for the key set.
#[tokio::test]
async fn jwks_refresh_deduplication() {
    let key_id = "test-dedup";
    let (signing_pem, key_set) = generate_test_keypair(key_id);

    let idp = wiremock::MockServer::start().await;
    let jwks_reply = wiremock::ResponseTemplate::new(200).set_body_json(&key_set);
    let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::path("/jwks.json"))
        .respond_with(jwks_reply)
        .expect(1)
        .mount(&idp)
        .await;

    let endpoint = format!("{}/jwks.json", idp.uri());
    let oauth_cfg = test_config(&endpoint);
    let cache = Arc::new(test_cache(&oauth_cfg));
    let token = mint_token(
        &signing_pem,
        key_id,
        "https://auth.test.local",
        "https://mcp.test.local/mcp",
        "concurrent-bot",
        "mcp:read",
    );

    // Spawn all five validations up front (the collect is eager), then await
    // them one by one.
    let workers: Vec<_> = (0..5)
        .map(|_| {
            let shared = Arc::clone(&cache);
            let jwt = token.clone();
            tokio::spawn(async move { shared.validate_token(&jwt).await })
        })
        .collect();

    for worker in workers {
        let outcome = worker.await.unwrap();
        assert!(outcome.is_some(), "all concurrent requests should succeed");
    }
}
/// Three bogus tokens are validated back to back against the same cache while
/// the JWKS mock is declared with `.expect(1)`, i.e. only a single key-set
/// fetch is expected across all three attempts. The validation results
/// themselves are deliberately ignored.
#[tokio::test]
async fn jwks_refresh_cooldown_blocks_rapid_requests() {
    let key_id = "test-cooldown";
    let (_signing_pem, key_set) = generate_test_keypair(key_id);

    let idp = wiremock::MockServer::start().await;
    let jwks_reply = wiremock::ResponseTemplate::new(200).set_body_json(&key_set);
    let _mock = wiremock::Mock::given(wiremock::matchers::method("GET"))
        .and(wiremock::matchers::path("/jwks.json"))
        .respond_with(jwks_reply)
        .expect(1)
        .mount(&idp)
        .await;

    let endpoint = format!("{}/jwks.json", idp.uri());
    let oauth_cfg = test_config(&endpoint);
    let jwks_cache = test_cache(&oauth_cfg);

    // NOTE(review): the three headers appear to decode to kids of the form
    // `unknown-kid-N`, none of which is the served `test-cooldown` key; confirm.
    let bogus_tokens = [
        "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTEifQ.e30.sig",
        "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTIifQ.e30.sig",
        "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6InVua25vd24ta2lkLTMifQ.e30.sig",
    ];
    for bogus in bogus_tokens {
        let _ = jwks_cache.validate_token(bogus).await;
    }
}
/// Builds an `OAuthProxyConfig` fixture pointing at `token_url`.
///
/// The client is `mcp-client` with secret `shh`; introspection and revocation
/// URLs are unset and all three admin-endpoint flags are `false`, so tests
/// opt in to those features explicitly.
fn proxy_cfg(token_url: &str) -> OAuthProxyConfig {
    let secret = secrecy::SecretString::from("shh".to_owned());
    OAuthProxyConfig {
        client_id: "mcp-client".into(),
        client_secret: Some(secret),
        authorize_url: "https://example.invalid/auth".into(),
        token_url: token_url.into(),
        introspection_url: None,
        revocation_url: None,
        expose_admin_endpoints: false,
        require_auth_on_admin_endpoints: false,
        allow_unauthenticated_admin_endpoints: false,
    }
}
/// Builds an `OauthHttpClient` for tests.
///
/// Installs the ring crypto provider as the rustls default (the result is
/// ignored), builds a config with `allow_http_oauth_urls(true)`, and applies
/// `__test_allow_loopback_ssrf()` to the resulting client.
fn test_http_client() -> OauthHttpClient {
    // Best-effort: the installation result is intentionally discarded.
    let _ = rustls::crypto::ring::default_provider().install_default();

    let builder = OAuthConfig::builder(
        "https://auth.test.local",
        "https://mcp.test.local/mcp",
        "https://auth.test.local/.well-known/jwks.json",
    );
    let oauth_cfg = builder.allow_http_oauth_urls(true).build();

    let client = OauthHttpClient::with_config(&oauth_cfg).expect("build test http client");
    client.__test_allow_loopback_ssrf()
}
/// `handle_introspect` must forward the request to the configured upstream
/// introspection URL with the proxy's `client_id` and `client_secret` present
/// in the body next to the caller's `token=abc`, and answer with status 200.
/// The mock matches only when all three body fragments are present and
/// expects exactly one call.
#[tokio::test]
async fn introspect_proxies_and_injects_client_credentials() {
    use wiremock::matchers::{body_string_contains, method, path};

    let upstream = wiremock::MockServer::start().await;
    let active_reply = wiremock::ResponseTemplate::new(200).set_body_json(serde_json::json!({
        "active": true,
        "scope": "read"
    }));
    wiremock::Mock::given(method("POST"))
        .and(path("/introspect"))
        .and(body_string_contains("client_id=mcp-client"))
        .and(body_string_contains("client_secret=shh"))
        .and(body_string_contains("token=abc"))
        .respond_with(active_reply)
        .expect(1)
        .mount(&upstream)
        .await;

    let token_endpoint = format!("{}/token", upstream.uri());
    let introspect_endpoint = format!("{}/introspect", upstream.uri());
    let mut proxy = proxy_cfg(&token_endpoint);
    proxy.introspection_url = Some(introspect_endpoint);

    let client = test_http_client();
    let response = handle_introspect(&client, &proxy, "token=abc").await;
    assert_eq!(response.status(), 200);
}
/// With no `introspection_url` in the proxy config (the `proxy_cfg` default),
/// `handle_introspect` must answer with status 404.
#[tokio::test]
async fn introspect_returns_404_when_not_configured() {
    let client = test_http_client();
    let unconfigured = proxy_cfg("https://example.invalid/token");
    let response = handle_introspect(&client, &unconfigured, "token=abc").await;
    assert_eq!(response.status(), 404);
}
/// `handle_revoke` must call the configured upstream revocation URL exactly
/// once (the mock is declared with `.expect(1)`) and answer with the status
/// the upstream returned, here 200.
#[tokio::test]
async fn revoke_proxies_and_returns_upstream_status() {
    use wiremock::matchers::{method, path};

    let upstream = wiremock::MockServer::start().await;
    let ok_reply = wiremock::ResponseTemplate::new(200);
    wiremock::Mock::given(method("POST"))
        .and(path("/revoke"))
        .respond_with(ok_reply)
        .expect(1)
        .mount(&upstream)
        .await;

    let token_endpoint = format!("{}/token", upstream.uri());
    let revoke_endpoint = format!("{}/revoke", upstream.uri());
    let mut proxy = proxy_cfg(&token_endpoint);
    proxy.revocation_url = Some(revoke_endpoint);

    let client = test_http_client();
    let response = handle_revoke(&client, &proxy, "token=abc").await;
    assert_eq!(response.status(), 200);
}
/// With no `revocation_url` in the proxy config (the `proxy_cfg` default),
/// `handle_revoke` must answer with status 404.
#[tokio::test]
async fn revoke_returns_404_when_not_configured() {
    let client = test_http_client();
    let unconfigured = proxy_cfg("https://example.invalid/token");
    let response = handle_revoke(&client, &unconfigured, "token=abc").await;
    assert_eq!(response.status(), 404);
}
/// Walks `authorization_server_metadata` through four configurations and
/// checks which admin endpoints it advertises:
///
/// 1. no proxy at all: neither endpoint is present;
/// 2. proxy with both upstream URLs but `expose_admin_endpoints = false`:
///    still neither;
/// 3. `expose_admin_endpoints = true` with only an introspection URL: only
///    `introspection_endpoint`, rooted at the local base URL;
/// 4. revocation URL restored: `revocation_endpoint` is advertised too.
#[test]
fn metadata_advertises_endpoints_only_when_configured() {
    let base = "https://mcp.local";
    let mut cfg = test_config("https://auth.test.local/jwks.json");

    // 1. No proxy configured.
    let without_proxy = authorization_server_metadata(base, &cfg);
    assert!(without_proxy.get("introspection_endpoint").is_none());
    assert!(without_proxy.get("revocation_endpoint").is_none());

    // 2. Proxy has both upstream URLs, but the expose flag is still off.
    let mut proxy = proxy_cfg("https://upstream.local/token");
    proxy.introspection_url = Some("https://upstream.local/introspect".into());
    proxy.revocation_url = Some("https://upstream.local/revoke".into());
    cfg.proxy = Some(proxy);
    let not_exposed = authorization_server_metadata(base, &cfg);
    assert!(
        not_exposed.get("introspection_endpoint").is_none(),
        "introspection must not be advertised when expose_admin_endpoints=false"
    );
    assert!(
        not_exposed.get("revocation_endpoint").is_none(),
        "revocation must not be advertised when expose_admin_endpoints=false"
    );

    // 3. Expose flag on, revocation URL removed.
    if let Some(proxy) = &mut cfg.proxy {
        proxy.expose_admin_endpoints = true;
        proxy.revocation_url = None;
    }
    let introspect_only = authorization_server_metadata(base, &cfg);
    let expected_introspect = serde_json::Value::String("https://mcp.local/introspect".into());
    assert_eq!(introspect_only["introspection_endpoint"], expected_introspect);
    assert!(introspect_only.get("revocation_endpoint").is_none());

    // 4. Revocation URL restored.
    if let Some(proxy) = &mut cfg.proxy {
        proxy.revocation_url = Some("https://upstream.local/revoke".into());
    }
    let with_revoke = authorization_server_metadata(base, &cfg);
    let expected_revoke = serde_json::Value::String("https://mcp.local/revoke".into());
    assert_eq!(with_revoke["revocation_endpoint"], expected_revoke);
}
/// Returns the config from `validation_https_config()` with its
/// `token_exchange` field set to `Some(tx)`.
fn https_cfg_with_tx(tx: TokenExchangeConfig) -> OAuthConfig {
    let mut oauth_cfg = validation_https_config();
    oauth_cfg.token_exchange = Some(tx);
    oauth_cfg
}
/// Builds a `TokenExchangeConfig` fixture via `TokenExchangeConfig::new` with
/// fixed string arguments (`https://idp.example.com/token`, `client`,
/// `downstream`) and the given client-authentication material.
///
/// `client_secret`, when present, is wrapped in a `secrecy::SecretString`;
/// `client_cert` is passed through unchanged.
fn tx_with(
    client_secret: Option<&str>,
    client_cert: Option<ClientCertConfig>,
) -> TokenExchangeConfig {
    let secret = client_secret.map(|raw| secrecy::SecretString::new(raw.into()));
    TokenExchangeConfig::new(
        "https://idp.example.com/token".into(),
        "client".into(),
        secret,
        client_cert,
        "downstream".into(),
    )
}
/// A token-exchange config with neither a client secret nor a client
/// certificate must fail `validate()`, and the error text must mention
/// "requires client authentication".
#[test]
fn validate_rejects_token_exchange_without_client_auth() {
    let no_auth = tx_with(None, None);
    let oauth_cfg = https_cfg_with_tx(no_auth);
    let failure = oauth_cfg
        .validate()
        .expect_err("token_exchange without client auth must be rejected");
    let msg = failure.to_string();
    let explains_missing_auth = msg.contains("requires client authentication");
    assert!(
        explains_missing_auth,
        "error must explain missing client auth; got {msg:?}"
    );
}
/// Supplying both a client secret and a client certificate must fail
/// `validate()`, and the error text must contain both "mutually" and
/// "exclusive".
#[test]
fn validate_rejects_token_exchange_with_both_secret_and_cert() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let both = tx_with(Some("s"), Some(client_cert));
    let oauth_cfg = https_cfg_with_tx(both);
    let failure = oauth_cfg
        .validate()
        .expect_err("client_secret + client_cert must be rejected");
    let msg = failure.to_string();
    let explains_exclusion = msg.contains("mutually") && msg.contains("exclusive");
    assert!(
        explains_exclusion,
        "error must explain mutual exclusion; got {msg:?}"
    );
}
/// Compiled only when the `oauth-mtls-client` feature is off: a config that
/// relies on a client certificate must fail `validate()`, and the error text
/// must name the `oauth-mtls-client` feature.
#[cfg(not(feature = "oauth-mtls-client"))]
#[test]
fn validate_rejects_client_cert_without_feature() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let cert_only = tx_with(None, Some(client_cert));
    let oauth_cfg = https_cfg_with_tx(cert_only);
    let err = oauth_cfg
        .validate()
        .expect_err("client_cert without feature must be rejected");
    let names_feature = err.to_string().contains("oauth-mtls-client");
    assert!(
        names_feature,
        "error must reference the cargo feature; got {err}"
    );
}
/// Compiled only with the `oauth-mtls-client` feature: a client certificate
/// config whose paths point at `/nonexistent/...` must fail `validate()`,
/// and the error text must contain "unreadable".
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_missing_client_cert_files() {
    let client_cert = ClientCertConfig {
        cert_path: PathBuf::from("/nonexistent/cert.pem"),
        key_path: PathBuf::from("/nonexistent/key.pem"),
    };
    let cert_only = tx_with(None, Some(client_cert));
    let oauth_cfg = https_cfg_with_tx(cert_only);
    let err = oauth_cfg
        .validate()
        .expect_err("missing cert file must be rejected");
    let mentions_unreadable = err.to_string().contains("unreadable");
    assert!(
        mentions_unreadable,
        "error must call out unreadable file; got {err}"
    );
}
/// Compiled only with the `oauth-mtls-client` feature: certificate and key
/// files that exist but do not contain PEM data must fail `validate()`, and
/// the error text must contain "PEM parse failed".
///
/// The two files are written to the system temp directory with the process
/// id in their names and are removed (best-effort) before the final
/// assertion runs.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_rejects_malformed_client_cert_pem() {
    let tmp = std::env::temp_dir();
    let pid = std::process::id();
    let cert = tmp.join(format!("rmcp-mtls-bad-cert-{pid}.pem"));
    let key = tmp.join(format!("rmcp-mtls-bad-key-{pid}.pem"));
    std::fs::write(&cert, b"not a real PEM").expect("write tmp cert");
    std::fs::write(&key, b"not a real PEM either").expect("write tmp key");

    let client_cert = ClientCertConfig {
        cert_path: cert.clone(),
        key_path: key.clone(),
    };
    let oauth_cfg = https_cfg_with_tx(tx_with(None, Some(client_cert)));
    let err = oauth_cfg
        .validate()
        .expect_err("malformed PEM must be rejected");

    // Clean up before asserting on the message; removal errors are ignored.
    for leftover in [&cert, &key] {
        let _ = std::fs::remove_file(leftover);
    }

    let mentions_parse_failure = err.to_string().contains("PEM parse failed");
    assert!(
        mentions_parse_failure,
        "error must call out PEM parse failure; got {err}"
    );
}
/// Generates a self-signed certificate for `client.test` with `rcgen` and
/// writes the certificate and its signing key as PEM files into the system
/// temp directory.
///
/// File names embed the process id and a random `u64`. Returns
/// `(cert_path, key_path)`; this function does not remove the files.
#[cfg(feature = "oauth-mtls-client")]
fn write_self_signed_pem() -> (PathBuf, PathBuf) {
    let generated =
        rcgen::generate_simple_self_signed(vec!["client.test".into()]).expect("rcgen");

    let tmp = std::env::temp_dir();
    let pid = std::process::id();
    let nonce = rand::random::<u64>();
    let suffix = format!("{pid}-{nonce}.pem");
    let cert_path = tmp.join(format!("rmcp-mtls-cert-{suffix}"));
    let key_path = tmp.join(format!("rmcp-mtls-key-{suffix}"));

    std::fs::write(&cert_path, generated.cert.pem()).expect("write cert");
    std::fs::write(&key_path, generated.signing_key.serialize_pem()).expect("write key");
    (cert_path, key_path)
}
/// Installs the ring crypto provider as the rustls default, ignoring the
/// result of the installation call.
#[cfg(feature = "oauth-mtls-client")]
fn install_test_crypto_provider() {
    rustls::crypto::ring::default_provider()
        .install_default()
        .ok();
}
/// Compiled only with the `oauth-mtls-client` feature: a freshly generated
/// self-signed certificate/key pair must pass `validate()`.
///
/// The temp files are removed (best-effort) before the result is unwrapped.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn validate_accepts_well_formed_client_cert() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();

    let client_cert = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let oauth_cfg = https_cfg_with_tx(tx_with(None, Some(client_cert)));
    let outcome = oauth_cfg.validate();

    // Clean up first; removal errors are ignored.
    for leftover in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(leftover);
    }

    outcome.expect("well-formed cert+key must validate");
}
/// Compiled only with the `oauth-mtls-client` feature: for an HTTP client
/// built from a config that carries a client certificate, `client_for` must
/// hand out different client instances (compared by pointer) for the
/// cert-based token-exchange config and for a secret-only one.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_returns_cached_mtls_client() {
    install_test_crypto_provider();
    let (cert_path, key_path) = write_self_signed_pem();

    let client_cert = ClientCertConfig {
        cert_path: cert_path.clone(),
        key_path: key_path.clone(),
    };
    let oauth_cfg = https_cfg_with_tx(tx_with(None, Some(client_cert)));
    let http = OauthHttpClient::with_config(&oauth_cfg).expect("build mtls client");

    let cert_tx = oauth_cfg.token_exchange.as_ref().expect("tx set");
    let secret_tx = tx_with(Some("s"), None);
    let cert_client = http.client_for(cert_tx);
    let inner_client = http.client_for(&secret_tx);

    // Clean up before asserting; removal errors are ignored.
    for leftover in [&cert_path, &key_path] {
        let _ = std::fs::remove_file(leftover);
    }

    let same_instance = std::ptr::eq(cert_client, inner_client);
    assert!(
        !same_instance,
        "client_for must return distinct clients for cert vs no-cert configs"
    );
}
/// Compiled only with the `oauth-mtls-client` feature: when `client_for` is
/// asked about a client-certificate config the HTTP client was not built
/// with, it must return the same instance (compared by pointer) that it
/// returns for a secret-only config.
#[cfg(feature = "oauth-mtls-client")]
#[test]
fn client_for_falls_back_to_inner_when_cache_miss() {
    install_test_crypto_provider();
    let oauth_cfg = validation_https_config();
    let http = OauthHttpClient::with_config(&oauth_cfg).expect("build client");

    // These paths were never part of the config the client was built from.
    let unrelated_cert = ClientCertConfig {
        cert_path: PathBuf::from("/cache/miss/cert.pem"),
        key_path: PathBuf::from("/cache/miss/key.pem"),
    };
    let unknown_tx = tx_with(None, Some(unrelated_cert));
    let secret_tx = tx_with(Some("s"), None);

    let fallback = http.client_for(&unknown_tx);
    let inner = http.client_for(&secret_tx);
    let same_instance = std::ptr::eq(fallback, inner);
    assert!(same_instance, "cache miss must fall back to inner client");
}
}