mod common;
use std::collections::HashMap;
use common::{TestRig, UpstreamSpec};
use rmcp::ServiceExt;
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use test_upstream::{TestTool, TestToolBehavior};
/// Builds a [`TestTool`] named `name` that uses the `Echo` behavior and has no description.
fn echo(name: &str) -> TestTool {
    let behavior = TestToolBehavior::Echo;
    TestTool {
        behavior,
        description: None,
        name: name.into(),
    }
}
/// The upstream "private" requires `Bearer secret`.
/// The client only sends `X-MCP-Servers` and supplies no per-upstream headers,
/// so `initialize` through the proxy must fail.
/// The failure must carry JSON-RPC code -32603 and mention "upstream connect failed".
#[tokio::test]
async fn missing_authorization_fails_initialize() {
    let rig = TestRig::new(vec![
        UpstreamSpec::new("private")
            .with_tools(vec![echo("hidden")])
            .with_require_auth("Bearer secret"),
    ])
    .await;

    use reqwest::header::{HeaderName, HeaderValue};
    let transport = StreamableHttpClientTransport::from_config(
        StreamableHttpClientTransportConfig::with_uri(rig.proxy.url.clone()).custom_headers(
            // Build the header pairs directly. Going through a `HeaderMap` would yield
            // `Option<HeaderName>` keys, and dropping the `None` entries silently loses
            // any repeated-header values.
            [(
                HeaderName::from_static("x-mcp-servers"),
                HeaderValue::from_str(&rig.x_mcp_servers())
                    .expect("X-MCP-Servers value should be a valid header value"),
            )]
            .into_iter()
            .collect(),
        ),
    );

    // Use `err().expect` rather than `expect_err`, which would require the
    // success type to implement `Debug`.
    let result = client_info_for_test().serve(transport).await;
    let err = result.err().expect("initialize should fail with -32603");
    let msg = format!("{err:?}");
    assert!(
        msg.contains("-32603") && msg.contains("upstream connect failed"),
        "unexpected error: {msg}",
    );
}
/// Returns the `ClientInfo` that these tests send during `initialize`.
///
/// The value is deserialized from a JSON literal: protocol version
/// `2025-06-18`, empty capabilities, and client identity `auth-test` / `0.1.0`.
///
/// # Panics
/// Panics if the JSON literal does not deserialize into `ClientInfo`.
fn client_info_for_test() -> rmcp::model::ClientInfo {
    serde_json::from_value(serde_json::json!({
        "protocolVersion": "2025-06-18",
        "capabilities": {},
        "clientInfo": { "name": "auth-test", "version": "0.1.0" },
    }))
    .expect("ClientInfo deserialize")
}
/// The client sends the correct `Authorization` value for the "private" upstream
/// via `X-MCP-Headers`.
/// With that value, the proxy lists the upstream's single tool as `private_hidden`.
#[tokio::test]
async fn correct_authorization_lets_the_upstream_in() {
    let rig = TestRig::new(vec![
        UpstreamSpec::new("private")
            .with_tools(vec![echo("hidden")])
            .with_require_auth("Bearer secret"),
    ])
    .await;

    // The JSON object is keyed by the upstream's URL; the value holds the headers for it.
    let per_url_headers = serde_json::json!({
        rig.upstreams[0].url.clone(): { "Authorization": "Bearer secret" },
    });
    let headers = HashMap::from([
        ("X-MCP-Servers", rig.x_mcp_servers()),
        ("X-MCP-Headers", per_url_headers.to_string()),
    ]);
    let client = rig.connect_client(headers).await;

    let listed = client.peer().list_all_tools().await.expect("list_all_tools");
    let mut names: Vec<String> = Vec::new();
    for tool in listed {
        names.push(tool.name.into());
    }
    assert_eq!(names, vec!["private_hidden".to_string()]);

    // Best-effort shutdown; a cancel error does not affect the assertion above.
    client.cancel().await.ok();
}