#![allow(clippy::expect_used, reason = "tests")]
#![allow(clippy::unwrap_used, reason = "tests")]
#![allow(clippy::panic, reason = "tests")]
#![allow(clippy::print_stdout, reason = "tests")]
#![allow(clippy::print_stderr, reason = "tests")]
#![allow(clippy::indexing_slicing, reason = "tests")]
#![allow(dead_code, reason = "PEM fields kept for symmetry / future tests")]
#![cfg(all(feature = "oauth", feature = "test-helpers"))]
use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration};
use rcgen::{
BasicConstraints, CertificateParams, CertifiedIssuer, DnType, IsCa, KeyPair, KeyUsagePurpose,
};
use rmcp_server_kit::oauth::{OAuthConfig, OauthHttpClient};
use rustls::{
ServerConfig,
pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::TcpListener,
};
use tokio_rustls::TlsAcceptor;
use wiremock::{
Mock, MockServer, ResponseTemplate,
matchers::{method, path},
};
/// Throwaway PKI fixture: one self-signed CA plus one leaf certificate signed by it.
///
/// Both DER (fed to the rustls test server) and PEM (the CA is written to disk for the
/// client's `ca_cert_path`) encodings are kept.
struct TestPki {
    /// DER-encoded leaf certificate presented by the test TLS server.
    leaf_cert_der: Vec<u8>,
    /// DER-encoded leaf private key (loaded as PKCS#8 by the server config).
    leaf_key_der: Vec<u8>,
    /// PEM-encoded CA certificate.
    ca_pem: String,
    /// PEM-encoded leaf certificate (currently unused by tests).
    leaf_cert_pem: String,
    /// PEM-encoded leaf private key (currently unused by tests).
    leaf_key_pem: String,
}
/// Generates a fresh [`TestPki`].
///
/// The CA is self-signed, marked as an unconstrained CA and allowed to sign certificates.
/// The leaf covers the single DNS name `localhost` and is signed by that CA.
fn build_test_pki() -> TestPki {
    let ca_issuer: CertifiedIssuer<'static, KeyPair> = {
        let mut params = CertificateParams::new(Vec::<String>::new()).expect("ca params");
        params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        params.key_usages = vec![
            KeyUsagePurpose::KeyCertSign,
            KeyUsagePurpose::CrlSign,
            KeyUsagePurpose::DigitalSignature,
        ];
        params
            .distinguished_name
            .push(DnType::CommonName, "oauth-test-ca");
        let signing_key = KeyPair::generate().expect("ca key");
        CertifiedIssuer::self_signed(params, signing_key).expect("ca self-signed")
    };

    // The leaf name must match the host in the URLs built by `spawn_one_shot_tls`.
    let mut leaf_params =
        CertificateParams::new(vec!["localhost".to_owned()]).expect("leaf params");
    leaf_params
        .distinguished_name
        .push(DnType::CommonName, "oauth-test-leaf");
    let leaf_key = KeyPair::generate().expect("leaf key");
    let leaf_cert = leaf_params
        .signed_by(&leaf_key, &ca_issuer)
        .expect("leaf signed");

    TestPki {
        ca_pem: ca_issuer.as_ref().pem(),
        leaf_cert_pem: leaf_cert.pem(),
        leaf_key_pem: leaf_key.serialize_pem(),
        leaf_cert_der: leaf_cert.der().to_vec(),
        leaf_key_der: leaf_key.serialize_der(),
    }
}
/// Installs `ring` as the process-wide default rustls crypto provider.
///
/// Idempotent from the caller's point of view. Once a default provider exists,
/// `install_default` returns an error, and that error is deliberately discarded.
fn install_crypto_provider() {
    let ring_provider = rustls::crypto::ring::default_provider();
    // `Err` only signals that some provider (e.g. from an earlier test) is already installed.
    let _ = ring_provider.install_default();
}
/// Builds a rustls server config that presents the fixture leaf certificate and does not
/// request client certificates.
fn build_server_config(pki: &TestPki) -> Arc<ServerConfig> {
    // rustls takes owned ('static) DER buffers, so copy them out of the borrowed fixture.
    let chain = vec![CertificateDer::from(pki.leaf_cert_der.clone())];
    let pkcs8 = PrivatePkcs8KeyDer::from(pki.leaf_key_der.clone());
    Arc::new(
        ServerConfig::builder()
            .with_no_client_auth()
            .with_single_cert(chain, PrivateKeyDer::Pkcs8(pkcs8))
            .expect("server config"),
    )
}
/// Starts a one-connection HTTPS server on `127.0.0.1` with an ephemeral port and returns
/// `https://localhost:{port}/`.
///
/// The server runs in a detached task. It accepts a single TCP connection and performs the
/// TLS handshake with the fixture leaf certificate. It then reads until the end of the request
/// head (`\r\n\r\n`), EOF, or a full 16 KiB buffer, writes `response_bytes` verbatim and shuts
/// the stream down. Every step is bounded by the same timeout, so a stalled or misbehaving
/// client cannot wedge the task. Failures are reported on stderr, and the task then returns.
///
/// The URL uses `localhost` so the hostname matches the leaf certificate's DNS name.
async fn spawn_one_shot_tls(pki: &TestPki, response_bytes: Vec<u8>) -> String {
    /// Upper bound applied to each server-side step (accept, handshake, read, write, shutdown).
    const STEP_TIMEOUT: Duration = Duration::from_secs(5);

    install_crypto_provider();
    let acceptor = TlsAcceptor::from(build_server_config(pki));
    let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0)))
        .await
        .expect("bind 127.0.0.1:0");
    let port = listener.local_addr().expect("local_addr").port();
    tokio::spawn(async move {
        let (tcp, _peer) = match tokio::time::timeout(STEP_TIMEOUT, listener.accept()).await {
            Ok(Ok(pair)) => pair,
            Ok(Err(e)) => {
                eprintln!("test tls accept error: {e}");
                return;
            }
            Err(_) => {
                eprintln!("test tls accept timed out");
                return;
            }
        };
        // Previously unbounded: a peer that connects but never finishes the handshake would
        // have parked this task until the runtime shut down.
        let mut tls_stream = match tokio::time::timeout(STEP_TIMEOUT, acceptor.accept(tcp)).await
        {
            Ok(Ok(s)) => s,
            Ok(Err(e)) => {
                eprintln!("test tls handshake error: {e}");
                return;
            }
            Err(_) => {
                eprintln!("test tls handshake timed out");
                return;
            }
        };
        // Drain the request head; the reply does not depend on its contents.
        let mut buf = vec![0u8; 16 * 1024];
        let mut filled = 0usize;
        while filled < buf.len() {
            let n = match tokio::time::timeout(STEP_TIMEOUT, tls_stream.read(&mut buf[filled..]))
                .await
            {
                Ok(Ok(0)) => break,
                Ok(Ok(n)) => n,
                Ok(Err(e)) => {
                    eprintln!("test tls read error: {e}");
                    return;
                }
                Err(_) => {
                    eprintln!("test tls read timed out");
                    return;
                }
            };
            filled += n;
            if buf[..filled].windows(4).any(|w| w == b"\r\n\r\n") {
                break;
            }
        }
        match tokio::time::timeout(STEP_TIMEOUT, tls_stream.write_all(&response_bytes)).await {
            Ok(Ok(())) => {}
            Ok(Err(e)) => {
                eprintln!("test tls write error: {e}");
                return;
            }
            Err(_) => {
                eprintln!("test tls write timed out");
                return;
            }
        }
        // Best-effort close_notify; the client may already have hung up.
        let _ = tokio::time::timeout(STEP_TIMEOUT, tls_stream.shutdown()).await;
    });
    format!("https://localhost:{port}/")
}
/// Writes the fixture CA (PEM) to a fresh temp file and builds an [`OauthHttpClient`] whose
/// config points `ca_cert_path` at it.
///
/// `allow_http` is forwarded to `OAuthConfig::allow_http_oauth_urls`. The client is also passed
/// through `__test_allow_loopback_ssrf()`, presumably so it may reach the local test servers.
/// Confirm against that hook's docs.
///
/// Returns the client and the CA file path. The file is not deleted here; callers own cleanup.
///
/// # Panics
/// Panics if the temp file cannot be written or the client fails to build.
fn build_client_with_ca(pki: &TestPki, allow_http: bool) -> (OauthHttpClient, PathBuf) {
    // Tests in one binary run in parallel and share the pid. On coarse clocks they can also read
    // identical timestamps, so `pid + nanos` alone may collide and let one test overwrite
    // another's CA file. A process-wide sequence number makes every name unique. `Relaxed` is
    // enough because `fetch_add` is atomic and only the uniqueness of the value matters.
    static NEXT_CA_FILE_ID: std::sync::atomic::AtomicU64 = std::sync::atomic::AtomicU64::new(0);
    let seq = NEXT_CA_FILE_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let pid = std::process::id();
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |d| d.as_nanos());
    let ca_path = std::env::temp_dir().join(format!("rmcp-oauth-ca-{pid}-{nanos}-{seq}.pem"));
    std::fs::write(&ca_path, pki.ca_pem.as_bytes()).expect("write ca pem");
    let mut config = OAuthConfig::default();
    config.ca_cert_path = Some(ca_path.clone());
    config.allow_http_oauth_urls = allow_http;
    let client = OauthHttpClient::with_config(&config)
        .expect("client builds")
        .__test_allow_loopback_ssrf();
    (client, ca_path)
}
fn consume_pem(pki: &TestPki) {
let _ = (&pki.leaf_cert_pem, &pki.leaf_key_pem);
}
/// Renders `err` followed by every `source()` cause, joined with `" :: "` and lowercased.
///
/// Lowercasing lets tests match message fragments without caring about capitalization.
fn render_error_chain(err: &dyn std::error::Error) -> String {
    let causes = std::iter::successors(err.source(), |cause| cause.source());
    causes
        .fold(err.to_string(), |mut acc, cause| {
            acc.push_str(" :: ");
            acc.push_str(&cause.to_string());
            acc
        })
        .to_lowercase()
}
/// An HTTPS endpoint that 302-redirects to a plain `http://` URL must be refused as a
/// downgrade. This holds even though the client is built with `allow_http_oauth_urls = true`.
#[tokio::test]
async fn redirect_downgrade_https_to_http_is_rejected() {
    let pki = build_test_pki();
    consume_pem(&pki);
    let response_bytes = b"HTTP/1.1 302 Found\r\n\
        Location: http://attacker.invalid/exfil\r\n\
        Content-Length: 0\r\n\
        Connection: close\r\n\r\n"
        .to_vec();
    let url = spawn_one_shot_tls(&pki, response_bytes).await;
    let (client, ca_path) = build_client_with_ca(&pki, true);
    // The CA file is loaded while the client is built (a missing `ca_cert_path` already fails
    // in `with_config`), so the temp file can be removed now. Doing it before any assertion
    // means a failing test does not leak it either. Best-effort: ignore removal errors.
    let _ = std::fs::remove_file(&ca_path);
    let result = client.__test_get(&url).await;
    let err = result.expect_err("downgrade must be rejected");
    let rendered = render_error_chain(&err);
    assert!(
        rendered.contains("downgrade") || rendered.contains("https -> http"),
        "expected downgrade error, got: {rendered}"
    );
    assert!(
        err.is_redirect(),
        "expected reqwest::Error::is_redirect()=true, got {err:?}"
    );
}
/// A redirect whose `Location` uses a non-HTTP(S) scheme (`ftp://`) must be refused by the
/// client's redirect handling.
#[tokio::test]
async fn redirect_to_non_http_scheme_is_rejected() {
    install_crypto_provider();
    let server = MockServer::start().await;
    let ftp_redirect =
        ResponseTemplate::new(302).insert_header("location", "ftp://attacker.invalid/loot");
    Mock::given(method("GET"))
        .and(path("/redir"))
        .respond_with(ftp_redirect)
        .mount(&server)
        .await;

    let mut config = OAuthConfig::default();
    config.allow_http_oauth_urls = true;
    let client = OauthHttpClient::with_config(&config)
        .expect("client builds")
        .__test_allow_loopback_ssrf();

    let target = format!("{}/redir", server.uri());
    let err = client
        .__test_get(&target)
        .await
        .expect_err("non-HTTP(S) redirect must be rejected");
    let rendered = render_error_chain(&err);
    let mentions_scheme_refusal =
        rendered.contains("non-http") || rendered.contains("refused") || rendered.contains("ftp");
    assert!(
        mentions_scheme_refusal,
        "expected non-HTTP(S) error, got: {rendered}"
    );
    assert!(err.is_redirect(), "expected redirect-error, got {err:?}");
}
/// A chain `/a -> /b -> /c -> /d` (three redirects) must be cut off by the redirect cap
/// instead of reaching the final 200 at `/d`.
#[tokio::test]
async fn redirect_chain_capped_at_two_hops() {
    install_crypto_provider();
    let server = MockServer::start().await;
    // Redirect targets use `localhost` instead of the literal `127.0.0.1`. Presumably this is
    // because literal loopback IPs in redirect targets are rejected per hop (see the loopback
    // redirect test), which would mask the cap being tested — confirm.
    let base = server.uri().replace("127.0.0.1", "localhost");
    let hops = [
        ("/a", format!("{base}/b")),
        ("/b", format!("{base}/c")),
        ("/c", format!("{base}/d")),
    ];
    for (from, to) in &hops {
        Mock::given(method("GET"))
            .and(path(*from))
            .respond_with(ResponseTemplate::new(302).insert_header("location", to.as_str()))
            .mount(&server)
            .await;
    }
    Mock::given(method("GET"))
        .and(path("/d"))
        .respond_with(ResponseTemplate::new(200))
        .mount(&server)
        .await;

    let mut config = OAuthConfig::default();
    config.allow_http_oauth_urls = true;
    let client = OauthHttpClient::with_config(&config)
        .expect("client builds")
        .__test_allow_loopback_ssrf();

    let start = format!("{base}/a");
    let err = client
        .__test_get(&start)
        .await
        .expect_err("3-hop redirect must be rejected");
    let rendered = render_error_chain(&err);
    let hit_cap = rendered.contains("too many redirects") || rendered.contains("max 2");
    assert!(hit_cap, "expected redirect-cap error, got: {rendered}");
    assert!(err.is_redirect(), "expected redirect-error, got {err:?}");
}
/// With the fixture CA supplied through `ca_cert_path`, a request to the TLS server presenting
/// the CA-signed leaf must succeed with status 200.
#[tokio::test]
async fn ca_cert_path_is_applied_to_oauth_http_client() {
    let pki = build_test_pki();
    let response_bytes = b"HTTP/1.1 200 OK\r\n\
        Content-Length: 0\r\n\
        Connection: close\r\n\r\n"
        .to_vec();
    let url = spawn_one_shot_tls(&pki, response_bytes).await;
    let (client, ca_path) = build_client_with_ca(&pki, false);
    // The CA file is loaded while the client is built (a missing `ca_cert_path` already fails
    // in `with_config`), so the temp file can be removed now. Doing it before any assertion
    // means a failing test does not leak it either. Best-effort: ignore removal errors.
    let _ = std::fs::remove_file(&ca_path);
    let response = client
        .__test_get(&url)
        .await
        .expect("request with custom CA must succeed");
    assert_eq!(response.status().as_u16(), 200);
}
/// Without `ca_cert_path` configured, the fixture leaf (signed by a throwaway CA) must not be
/// trusted, so the request fails.
#[tokio::test]
async fn missing_ca_cert_path_makes_self_signed_request_fail() {
    let pki = build_test_pki();
    let ok_reply = b"HTTP/1.1 200 OK\r\n\
        Content-Length: 0\r\n\
        Connection: close\r\n\r\n";
    let url = spawn_one_shot_tls(&pki, ok_reply.to_vec()).await;

    // Default config: no `ca_cert_path` is set here.
    let config = OAuthConfig::default();
    let client = OauthHttpClient::with_config(&config)
        .expect("client builds")
        .__test_allow_loopback_ssrf();

    let err = client
        .__test_get(&url)
        .await
        .expect_err("untrusted self-signed leaf must be rejected");
    let looks_like_tls_failure =
        err.is_connect() || err.is_request() || err.is_builder() || err.is_decode();
    assert!(
        looks_like_tls_failure,
        "expected TLS-layer failure, got: {err:?}"
    );
}
/// Pointing `ca_cert_path` at a file that does not exist must make client construction fail
/// with an error that mentions `ca_cert_path` or the failed read.
#[test]
fn nonexistent_ca_cert_path_returns_startup_error() {
    // Every other client-building test installs the rustls provider first. Doing the same
    // here keeps this test's outcome independent of which tests happened to run before it.
    install_crypto_provider();
    let mut config = OAuthConfig::default();
    // Absolute on Windows, relative on Unix; either way not expected to exist on a test host.
    config.ca_cert_path = Some(PathBuf::from(
        "Z:/this/path/definitely/does/not/exist/ca.pem",
    ));
    let err = OauthHttpClient::with_config(&config).expect_err("must fail to read CA");
    let rendered = format!("{err:#}");
    assert!(
        rendered.contains("ca_cert_path") || rendered.contains("read"),
        "expected ca_cert_path read error, got: {rendered}"
    );
}
/// The per-hop redirect guard must refuse a redirect to an RFC 1918 address
/// (`10.0.0.1`).
#[tokio::test]
async fn rejects_per_hop_redirect_to_private_ip_oauth_client() {
    install_crypto_provider();
    let server = MockServer::start().await;
    let private_redirect =
        ResponseTemplate::new(302).insert_header("location", "https://10.0.0.1/internal");
    Mock::given(method("GET"))
        .and(path("/redir"))
        .respond_with(private_redirect)
        .mount(&server)
        .await;

    let mut config = OAuthConfig::default();
    config.allow_http_oauth_urls = true;
    let client = OauthHttpClient::with_config(&config)
        .expect("client builds")
        .__test_allow_loopback_ssrf();

    let target = format!("{}/redir", server.uri());
    let err = client
        .__test_get(&target)
        .await
        .expect_err("redirect to private IP must be rejected");
    let rendered = render_error_chain(&err);
    let forbidden = ["redirect target forbidden", "private", "rfc1918"]
        .iter()
        .any(|needle| rendered.contains(needle));
    assert!(
        forbidden,
        "expected redirect-target-forbidden error (per-hop SSRF guard), got: {rendered}"
    );
    assert!(
        err.is_redirect(),
        "expected reqwest::Error::is_redirect()=true, got {err:?}"
    );
}
/// The per-hop redirect guard must refuse a redirect to a loopback address (`127.0.0.1`),
/// even on a client that went through `__test_allow_loopback_ssrf()`.
#[tokio::test]
async fn rejects_per_hop_redirect_to_loopback_oauth_client() {
    install_crypto_provider();
    let server = MockServer::start().await;
    let loopback_redirect =
        ResponseTemplate::new(302).insert_header("location", "https://127.0.0.1/admin");
    Mock::given(method("GET"))
        .and(path("/redir"))
        .respond_with(loopback_redirect)
        .mount(&server)
        .await;

    let mut config = OAuthConfig::default();
    config.allow_http_oauth_urls = true;
    let client = OauthHttpClient::with_config(&config)
        .expect("client builds")
        .__test_allow_loopback_ssrf();

    let target = format!("{}/redir", server.uri());
    let err = client
        .__test_get(&target)
        .await
        .expect_err("redirect to loopback must be rejected");
    let rendered = render_error_chain(&err);
    let forbidden =
        rendered.contains("redirect target forbidden") || rendered.contains("loopback");
    assert!(forbidden, "expected loopback redirect rejection, got: {rendered}");
    assert!(err.is_redirect(), "expected redirect-error, got {err:?}");
}
/// A redirect whose target embeds userinfo (`https://evil@example.com/...`) must be refused.
#[tokio::test]
async fn rejects_redirect_with_userinfo_oauth_client() {
    install_crypto_provider();
    let server = MockServer::start().await;
    let userinfo_redirect =
        ResponseTemplate::new(302).insert_header("location", "https://evil@example.com/pwn");
    Mock::given(method("GET"))
        .and(path("/redir"))
        .respond_with(userinfo_redirect)
        .mount(&server)
        .await;

    let mut config = OAuthConfig::default();
    config.allow_http_oauth_urls = true;
    // NOTE(review): unlike the other redirect tests, this client skips
    // `__test_allow_loopback_ssrf()`. Confirm that is intentional.
    let client = OauthHttpClient::with_config(&config).expect("client builds");

    let target = format!("{}/redir", server.uri());
    let err = client
        .__test_get(&target)
        .await
        .expect_err("redirect with userinfo must be rejected");
    let rendered = render_error_chain(&err);
    let forbidden = ["redirect target forbidden", "userinfo", "credentials"]
        .iter()
        .any(|needle| rendered.contains(needle));
    assert!(
        forbidden,
        "expected userinfo redirect rejection, got: {rendered}"
    );
    assert!(err.is_redirect(), "expected redirect-error, got {err:?}");
}
/// Even with plain-HTTP OAuth URLs allowed, a redirect to an `http://` target carrying
/// `user:pass@` credentials must be refused.
#[tokio::test]
async fn redirect_to_http_with_userinfo_rejected_when_http_allowed() {
    install_crypto_provider();
    let server = MockServer::start().await;
    let credentialed_redirect = ResponseTemplate::new(302)
        .insert_header("location", "http://user:pass@example.com/pwn");
    Mock::given(method("GET"))
        .and(path("/redir"))
        .respond_with(credentialed_redirect)
        .mount(&server)
        .await;

    let mut config = OAuthConfig::default();
    config.allow_http_oauth_urls = true;
    // NOTE(review): unlike the other redirect tests, this client skips
    // `__test_allow_loopback_ssrf()`. Confirm that is intentional.
    let client = OauthHttpClient::with_config(&config).expect("client builds");

    let target = format!("{}/redir", server.uri());
    let err = client
        .__test_get(&target)
        .await
        .expect_err("http redirect with userinfo must be rejected");
    let rendered = render_error_chain(&err);
    let forbidden = ["redirect target forbidden", "userinfo", "credentials"]
        .iter()
        .any(|needle| rendered.contains(needle));
    assert!(
        forbidden,
        "expected userinfo redirect rejection, got: {rendered}"
    );
    assert!(err.is_redirect(), "expected redirect-error, got {err:?}");
}