#![allow(clippy::expect_used, reason = "tests")]
#![allow(clippy::unwrap_used, reason = "tests")]
#![allow(clippy::panic, reason = "tests")]
#![cfg(all(feature = "oauth", feature = "test-helpers"))]
use std::time::Duration;
use rmcp_server_kit::oauth::{OAuthConfig, OauthHttpClient};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpListener,
};
/// Registers the `ring`-backed rustls crypto provider as the process default.
///
/// Every test builds a client through [`build_client`], so this runs many
/// times per process. The outcome of the installation is intentionally
/// discarded so that repeated calls never fail a test.
fn install_crypto_provider() {
    let provider = rustls::crypto::ring::default_provider();
    // An `Err` here just means a default provider was already registered.
    let _install_outcome = provider.install_default();
}
/// Builds the [`OauthHttpClient`] used by every test in this file.
///
/// The client is created from a default [`OAuthConfig`] with
/// `allow_http_oauth_urls` switched on, so plain `http://` targets are
/// accepted by the configuration.
///
/// When `allow_loopback` is `true` the client is additionally passed through
/// the `__test_allow_loopback_ssrf` test helper (presumably lifting the
/// loopback SSRF block -- see the passthrough test); otherwise it is returned
/// as built.
///
/// # Panics
///
/// Panics if the client cannot be constructed from the configuration.
fn build_client(allow_loopback: bool) -> OauthHttpClient {
    install_crypto_provider();

    // Field assignment (rather than struct-update syntax) keeps this working
    // even if the config type cannot be constructed literally from outside
    // its crate.
    let mut oauth_config = OAuthConfig::default();
    oauth_config.allow_http_oauth_urls = true;

    let base = OauthHttpClient::with_config(&oauth_config).expect("client builds");
    if !allow_loopback {
        return base;
    }
    base.__test_allow_loopback_ssrf()
}
/// Flattens an error and all of its `source()` ancestors into one string.
///
/// The messages are listed outermost first and separated by `" :: "`, so the
/// tests can search the whole causal chain with a single `contains` call.
/// An error without a source renders as just its own message.
fn render_chain(err: &dyn std::error::Error) -> String {
    // `|&cause|` copies the inner reference out so that `source()` borrows
    // for the full lifetime of the chain, not just the closure call.
    std::iter::successors(Some(err), |&cause| cause.source())
        .map(|cause| cause.to_string())
        .collect::<Vec<_>>()
        .join(" :: ")
}
/// Without the loopback bypass, a request to `localhost` must fail, and the
/// rendered error chain must both carry the `ssrf:` tag and name the reason
/// (either "loopback" or "blocked IP").
#[tokio::test]
async fn resolver_contract_always_err_loopback_blocked() {
    let client = build_client(false);

    let outcome = client.__test_get("http://localhost/").await;
    let failure = outcome.expect_err("loopback must be blocked without bypass");
    let chain = render_chain(&failure);

    let has_ssrf_tag = chain.contains("ssrf:");
    assert!(
        has_ssrf_tag,
        "diagnostic must carry ssrf: prefix; got: {chain}"
    );

    // Either wording is accepted as naming the block reason.
    let names_reason = ["loopback", "blocked IP"]
        .iter()
        .any(|&needle| chain.contains(needle));
    assert!(
        names_reason,
        "diagnostic must name the block reason; got: {chain}"
    );
}
/// A host that cannot be resolved must surface as an error, but that error
/// must not carry the `ssrf:` tag: a DNS failure is not a policy denial.
#[tokio::test]
async fn resolver_contract_empty_dns_failure_not_classified_as_ssrf() {
    let client = build_client(false);

    // `.invalid` is used so the name is expected never to resolve.
    let outcome = client
        .__test_get("http://nonexistent-host-for-mcpx-tests.invalid/")
        .await;
    let failure = outcome.expect_err("unresolvable host must surface as error");
    let chain = render_chain(&failure);

    let tagged_as_policy_denial = chain.contains("ssrf:");
    assert!(
        !tagged_as_policy_denial,
        "DNS failure must not be tagged as ssrf policy denial; got: {chain}"
    );
}
/// With the loopback bypass enabled, a request to a local listener must go
/// through and yield a 2xx response.
///
/// A one-shot HTTP server is bound to an ephemeral port on `127.0.0.1`; it
/// accepts a single connection, reads (at most once, for up to two seconds)
/// whatever the client sent, and answers with a fixed `200 OK` / `ok` reply.
#[tokio::test]
async fn resolver_contract_verbatim_passthrough_with_bypass() {
    // Upper bound for each best-effort wait below.
    let grace = Duration::from_secs(2);

    let listener = TcpListener::bind("127.0.0.1:0").await.expect("bind");
    let bound_addr = listener.local_addr().expect("local_addr");
    let port = bound_addr.port();

    let server = tokio::spawn(async move {
        let (mut conn, _peer) = listener.accept().await.expect("accept");

        // The request content is irrelevant; read once so the client's write
        // has somewhere to go, and never wait longer than the grace period.
        let mut request_buf = [0u8; 1024];
        let _ = tokio::time::timeout(grace, conn.read(&mut request_buf)).await;

        // Write/shutdown failures are ignored: the client-side assertions
        // below decide whether the test passes.
        let reply = b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: close\r\n\r\nok";
        let _ = conn.write_all(reply).await;
        let _ = conn.shutdown().await;
    });

    let client = build_client(true);
    let url = format!("http://127.0.0.1:{port}/");
    let response = client
        .__test_get(&url)
        .await
        .expect("loopback must succeed with bypass enabled");

    let status = response.status();
    assert!(status.is_success(), "expected 2xx; got {}", status);

    // Drain the body, then give the server task a bounded chance to finish.
    let _ = response.bytes().await;
    let _ = tokio::time::timeout(grace, server).await;
}
/// Proxy URL placed into the proxy environment variables by the env tests.
///
/// `192.0.2.1` lies in the TEST-NET-1 documentation range (RFC 5737), so it is
/// presumably chosen so that a client honouring the variable could never
/// reach a real proxy.
const DECOY_PROXY: &str = "http://192.0.2.1:1";

/// Loopback URL requested by the env tests; it is expected to be rejected.
const TARGET_URL: &str = "http://localhost/";
/// Requests [`TARGET_URL`] while the environment variable `var` is
/// temporarily set to `value`, and returns the rendered error chain.
///
/// The Tokio runtime and the client are both created inside the
/// `temp_env::with_var` scope, so the client is constructed while the
/// variable is in effect.
///
/// # Panics
///
/// Panics if the runtime cannot be built or if the request unexpectedly
/// succeeds.
fn run_with_env(var: &str, value: &str) -> String {
    let probe = || {
        // A dedicated current-thread runtime lets this helper be driven from
        // a plain (non-async) `#[test]`.
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("rt");

        runtime.block_on(async {
            let client = build_client(false);
            let outcome = client.__test_get(TARGET_URL).await;
            let rejection = outcome.expect_err("loopback target must be rejected");
            render_chain(&rejection)
        })
    };

    temp_env::with_var(var, Some(value), probe)
}
/// Setting any of the conventional proxy environment variables (upper- or
/// lower-case) to [`DECOY_PROXY`] must not change the outcome of a loopback
/// request: the error chain must still carry the `ssrf:` tag.
#[test]
fn no_proxy_defeats_all_env_proxy_variants() {
    const PROXY_ENV_VARS: [&str; 6] = [
        "HTTP_PROXY",
        "HTTPS_PROXY",
        "ALL_PROXY",
        "http_proxy",
        "https_proxy",
        "all_proxy",
    ];

    // Variables are exercised one at a time so a failure names the culprit.
    for var in PROXY_ENV_VARS {
        let chain = run_with_env(var, DECOY_PROXY);
        let still_blocked = chain.contains("ssrf:");
        assert!(
            still_blocked,
            "{var} must not bypass resolver; got: {chain}"
        );
    }
}