#![cfg(feature = "mcp")]
use std::collections::{HashMap, VecDeque};
use std::convert::Infallible;
use std::pin::Pin;
use std::sync::Arc;
use axum::body::{Body, Bytes};
use axum::http::{HeaderMap, HeaderValue, StatusCode, header};
use axum::response::{IntoResponse, Response};
use futures::{Stream, StreamExt as _};
use serde_json::{Value, json};
use tower::ServiceExt as _;
use crate::sse::{Event, Sse};
use crate::openapi::{ApiDoc, schema_entry_to_value};
/// Protocol version reported by `initialize` when the client requests none,
/// or one that is not in [`SUPPORTED_PROTOCOL_VERSIONS`].
const DEFAULT_PROTOCOL_VERSION: &str = "2025-06-18";
/// Protocol versions accepted both in `initialize` params and in the
/// `mcp-protocol-version` request header.
const SUPPORTED_PROTOCOL_VERSIONS: &[&str] = &["2025-06-18", "2025-03-26", "2024-11-05"];
/// MCP request headers that the preflight handler always adds to
/// `Access-Control-Allow-Headers`, on top of the configured CORS list.
const MCP_REQUEST_HEADERS: &[&str] = &["mcp-protocol-version", "mcp-session-id"];
/// Maximum number of bytes (10 MiB) buffered from a handler response before
/// the tool call is reported as an error.
const MAX_TOOL_RESPONSE_BYTES: usize = 10 * 1024 * 1024;
/// Headers copied from the incoming MCP request onto the request that is
/// replayed against the tool's handler.
const FORWARDED_HEADERS: &[&str] = &[
"authorization",
"cookie",
"idempotency-key",
"host",
"forwarded",
"x-forwarded-for",
"x-forwarded-host",
"x-forwarded-proto",
"x-real-ip",
"accept-language",
];
/// One-shot hook that takes the MCP endpoint router and returns a (typically
/// layered) replacement.
///
/// `build_mcp_router` applies it to the GET/POST route after the server
/// extension is attached and before the host/origin guard is added.
pub(crate) type McpEndpointLayer = Box<
dyn FnOnce(axum::Router<crate::state::AppState>) -> axum::Router<crate::state::AppState> + Send,
>;
/// Settings describing how the MCP endpoint should be mounted.
pub struct McpRuntime {
/// Path of the MCP endpoint.
/// NOTE(review): presumably passed to `build_mcp_router` as `mount_path` — confirm at the call site.
pub mount_path: String,
/// Exposure switch; `false` unless set after construction.
/// NOTE(review): presumably forwarded to `derive_tools` as `expose_all` — confirm at the call site.
pub expose_all: bool,
/// Optional hook for wrapping the endpoint router; `None` after [`McpRuntime::new`].
pub(crate) endpoint_layer: Option<McpEndpointLayer>,
}
impl McpRuntime {
    /// Creates a runtime for `mount_path` with `expose_all` switched off and
    /// no endpoint layer installed.
    #[must_use]
    pub fn new(mount_path: impl Into<String>) -> Self {
        let mount_path = mount_path.into();
        Self {
            endpoint_layer: None,
            expose_all: false,
            mount_path,
        }
    }
}
/// A tool as held by the server: the descriptor fields advertised by
/// `tools/list` plus the routing data needed to replay a call as an HTTP request.
#[derive(Clone, Debug)]
struct McpTool {
/// Tool name; also the key of `McpServer::by_name`.
name: String,
/// Optional description; omitted from the descriptor when `None`.
description: Option<String>,
/// JSON schema emitted as `inputSchema` in the descriptor.
input_schema: Value,
/// JSON object emitted as `annotations` in the descriptor.
annotations: Value,
/// HTTP method used for the replayed request.
method: String,
/// Route template whose `{param}` captures are substituted from the call arguments.
path_template: String,
/// Names of the path captures; a leading `*` marks a catch-all capture.
path_params: Vec<String>,
/// When `true`, the call must carry a `body` argument, sent as a JSON request body.
has_body: bool,
/// When `true`, an optional `query` argument is encoded into the query string.
has_query: bool,
/// When `true`, a successful `text/event-stream` response may be streamed to the client.
streams: bool,
}
impl McpTool {
    /// Renders the JSON descriptor for this tool: `name`, `description`
    /// (only when present), `inputSchema` and `annotations`, in that order.
    fn descriptor(&self) -> Value {
        let description = self
            .description
            .as_ref()
            .map(|text| ("description".to_owned(), json!(text)));
        let fields: serde_json::Map<String, Value> =
            std::iter::once(("name".to_owned(), json!(self.name)))
                .chain(description)
                .chain([
                    ("inputSchema".to_owned(), self.input_schema.clone()),
                    ("annotations".to_owned(), self.annotations.clone()),
                ])
                .collect();
        Value::Object(fields)
    }
}
/// Application-level settings handed to `McpServer::new`, which copies each
/// field onto the server.
pub(crate) struct McpWiring {
/// CORS policy used for preflight replies and response headers.
pub cors: crate::config::CorsConfig,
/// Host policy consulted by the host/origin guard.
pub trusted_hosts: crate::router::TrustedHostPolicy,
/// Name of a tenant header to forward on replayed requests, if any.
pub tenant_header: Option<String>,
/// Name of the CSRF header to forward on replayed requests.
pub csrf_header: String,
/// When `true`, replayed requests are tagged with `RateLimitExempt`.
pub envelope_rate_limited: bool,
}
/// Shared state of the MCP endpoint: the tool catalogue, the router that tool
/// calls are replayed against, and the policies applied to incoming requests.
pub struct McpServer {
/// Registered tools, in registration order.
tools: Vec<McpTool>,
/// Tool name -> index into `tools`.
by_name: HashMap<String, usize>,
/// Router that replayed tool requests are dispatched to.
dispatch: axum::Router,
/// CORS policy for preflight replies and response headers.
cors: crate::config::CorsConfig,
/// Host policy used by the host/origin guard.
trusted_hosts: crate::router::TrustedHostPolicy,
/// Tenant header name forwarded on replayed requests, if configured.
tenant_header: Option<String>,
/// CSRF header name forwarded on replayed requests.
csrf_header: String,
/// When `true`, replayed requests carry the `RateLimitExempt` extension.
envelope_rate_limited: bool,
/// Name reported in `initialize` under `serverInfo.name`.
server_name: String,
/// Version reported in `initialize` under `serverInfo.version`.
server_version: String,
}
impl McpServer {
    /// Decides whether a request `Origin` may talk to the endpoint.
    ///
    /// An origin passes when it is same-origin with `host` (taking `scheme`
    /// into account) and that host is trusted, or when the CORS allow-list
    /// contains it verbatim or contains `*`.
    fn origin_allowed(&self, origin: &str, host: Option<&str>, scheme: Option<&str>) -> bool {
        let same_origin_on_trusted_host = host.is_some_and(|host| {
            is_same_origin(origin, host, scheme)
                && crate::router::extract_host_without_port(host)
                    .is_some_and(|h| self.trusted_hosts.allows_host(&h.to_ascii_lowercase()))
        });
        if same_origin_on_trusted_host {
            return true;
        }
        self.cors
            .allowed_origins
            .iter()
            .any(|allowed| allowed == "*" || allowed == origin)
    }
}
/// Returns `true` when `origin` (a `scheme://authority` string) names the
/// same origin as `host`.
///
/// A known request `scheme` must equal the origin's scheme, ignoring ASCII
/// case; when it is unknown, the origin's own scheme is assumed for the host.
fn is_same_origin(origin: &str, host: &str, scheme: Option<&str>) -> bool {
    match origin.split_once("://") {
        None => false,
        Some((origin_scheme, origin_authority)) => match scheme {
            Some(known) if !known.eq_ignore_ascii_case(origin_scheme) => false,
            _ => authority_matches(
                origin_authority,
                origin_scheme,
                host,
                scheme.unwrap_or(origin_scheme),
            ),
        },
    }
}
/// Compares two authorities: the hosts must be equal ignoring ASCII case and
/// the ports must be equal once a missing port is replaced by the default
/// port of the corresponding scheme.
fn authority_matches(a: &str, a_scheme: &str, b: &str, b_scheme: &str) -> bool {
    let (left_host, left_port) = split_host_port(a);
    let (right_host, right_port) = split_host_port(b);
    let left_effective = left_port.or_else(|| default_port(a_scheme));
    let right_effective = right_port.or_else(|| default_port(b_scheme));
    left_host.eq_ignore_ascii_case(right_host) && left_effective == right_effective
}
/// Splits an authority into its host and optional port.
///
/// Bracketed hosts (`[::1]:3000`) keep their brackets and take whatever
/// non-empty text follows `]:` as the port. For any other authority the text
/// after the last `:` counts as a port only when it is non-empty and all
/// ASCII digits; otherwise the whole input is returned as the host.
fn split_host_port(authority: &str) -> (&str, Option<&str>) {
    if authority.starts_with('[') {
        return match authority.find(']') {
            Some(close) => {
                let (host, tail) = authority.split_at(close + 1);
                let port = tail.strip_prefix(':').filter(|text| !text.is_empty());
                (host, port)
            }
            // No closing bracket: treat the whole input as the host.
            None => (authority, None),
        };
    }
    if let Some((host, port)) = authority.rsplit_once(':') {
        let numeric = !port.is_empty() && port.bytes().all(|b| b.is_ascii_digit());
        if numeric {
            return (host, Some(port));
        }
    }
    (authority, None)
}
/// Returns the default port for `http` (80) and `https` (443), matched
/// without regard to ASCII case; any other scheme has no default.
fn default_port(scheme: &str) -> Option<&'static str> {
    if scheme.eq_ignore_ascii_case("https") {
        Some("443")
    } else if scheme.eq_ignore_ascii_case("http") {
        Some("80")
    } else {
        None
    }
}
/// Decides whether an endpoint becomes an MCP tool.
///
/// Hidden or excluded endpoints never qualify. Otherwise the endpoint must
/// either be a streaming tool or declare a response schema, and it must be
/// opted in explicitly (`mcp_tool`) or be a read-only method while
/// `expose_all` is on.
fn should_expose(doc: &ApiDoc, expose_all: bool) -> bool {
    let suppressed = doc.hidden || doc.mcp_exclude;
    let eligible = doc.mcp_stream || doc.response.is_some();
    if suppressed || !eligible {
        return false;
    }
    doc.mcp_tool || (expose_all && is_read_only(doc.method))
}
/// Returns `true` for the `GET` and `HEAD` methods, ignoring ASCII case.
fn is_read_only(method: &str) -> bool {
    ["GET", "HEAD"]
        .iter()
        .any(|verb| method.eq_ignore_ascii_case(verb))
}
/// Builds the tool `annotations` object: the `title`, a `readOnlyHint`
/// derived from the HTTP method, and `destructiveHint: true` for `DELETE`.
fn annotations_for(method: &str, title: &str) -> Value {
    let mut hints = serde_json::Map::new();
    hints.insert("title".into(), json!(title));
    hints.insert("readOnlyHint".into(), json!(is_read_only(method)));
    if method.eq_ignore_ascii_case("DELETE") {
        hints.insert("destructiveHint".into(), json!(true));
    }
    Value::Object(hints)
}
/// Assembles the JSON schema for a tool's arguments.
///
/// Every path parameter becomes a required string property (a leading `*`
/// is dropped from its name). The query schema, when present, is exposed as
/// the optional `query` property and the request body as the required `body`
/// property. Component schemas referenced from either are collected under
/// `$defs`.
fn build_input_schema(doc: &ApiDoc, components: &serde_json::Map<String, Value>) -> Value {
    let mut props = serde_json::Map::new();
    let mut required_names: Vec<Value> = Vec::new();
    let mut local_defs = serde_json::Map::new();

    for param in doc.path_params {
        let key = param.strip_prefix('*').unwrap_or(param);
        props.insert(key.to_owned(), json!({ "type": "string" }));
        required_names.push(json!(key));
    }

    if let Some(query) = &doc.query_schema {
        let query_schema = rewrite_refs(schema_entry_to_value(query), components, &mut local_defs);
        props.insert("query".to_owned(), query_schema);
    }

    if let Some(body) = &doc.request_body {
        let body_schema = rewrite_refs(schema_entry_to_value(body), components, &mut local_defs);
        props.insert("body".to_owned(), body_schema);
        required_names.push(json!("body"));
    }

    let mut root = serde_json::Map::new();
    root.insert("type".into(), json!("object"));
    root.insert("properties".into(), Value::Object(props));
    // `required` and `$defs` are only emitted when they have content.
    if !required_names.is_empty() {
        root.insert("required".into(), Value::Array(required_names));
    }
    if !local_defs.is_empty() {
        root.insert("$defs".into(), Value::Object(local_defs));
    }
    Value::Object(root)
}
/// Recursively rewrites `#/components/schemas/<name>` references in `value`
/// to local `#/$defs/<name>` references, copying each referenced component
/// (itself rewritten) into `defs`.
///
/// A component missing from `components` is replaced by a stub object schema
/// titled with its name. Objects without a matching `$ref`, and arrays, are
/// rewritten element by element; all other values are returned unchanged.
fn rewrite_refs(
value: Value,
components: &serde_json::Map<String, Value>,
defs: &mut serde_json::Map<String, Value>,
) -> Value {
match value {
Value::Object(map) => {
if let Some(Value::String(reference)) = map.get("$ref")
&& let Some(name) = reference.strip_prefix("#/components/schemas/")
{
let name = name.to_owned();
let local = format!("#/$defs/{name}");
if !defs.contains_key(&name) {
// Reserve the slot before recursing so a self-referential
// schema finds the key present and does not recurse forever.
defs.insert(name.clone(), Value::Null);
let resolved = components
.get(&name)
.cloned()
.unwrap_or_else(|| json!({ "type": "object", "title": name.clone() }));
let resolved = rewrite_refs(resolved, components, defs);
// Replace the placeholder with the fully rewritten schema.
defs.insert(name, resolved);
}
// The whole referencing object is replaced by the local `$ref`.
return json!({ "$ref": local });
}
let rewritten: serde_json::Map<String, Value> = map
.into_iter()
.map(|(k, v)| (k, rewrite_refs(v, components, defs)))
.collect();
Value::Object(rewritten)
}
Value::Array(items) => Value::Array(
items
.into_iter()
.map(|v| rewrite_refs(v, components, defs))
.collect(),
),
other => other,
}
}
/// Derives the MCP tool list from the registered API docs.
///
/// An OpenAPI spec is generated from `docs` (using `openapi` when given,
/// otherwise a default `autumn-mcp` config) so that component schemas can be
/// inlined into each tool's input schema. Endpoints are then filtered with
/// `should_expose`; a later endpoint that reuses an `operation_id` is skipped
/// with a warning, so the first registration wins.
#[must_use]
pub fn derive_tools(
docs: &[ApiDoc],
expose_all: bool,
openapi: Option<&crate::openapi::OpenApiConfig>,
) -> Vec<McpToolInfo> {
let refs: Vec<&ApiDoc> = docs.iter().collect();
let config = openapi.cloned().unwrap_or_else(|| {
crate::openapi::OpenApiConfig::new("autumn-mcp", env!("CARGO_PKG_VERSION"))
});
let spec = crate::openapi::generate_spec(&config, &refs);
// Component schemas as a plain JSON map; empty when the spec has none or
// they do not serialize to an object.
let components = spec
.components
.as_ref()
.map(|c| serde_json::to_value(&c.schemas).unwrap_or(Value::Null))
.and_then(|v| v.as_object().cloned())
.unwrap_or_default();
let mut tools = Vec::new();
let mut seen: std::collections::HashSet<&str> = std::collections::HashSet::new();
for doc in docs {
// An endpoint that would otherwise be exposed but has no response
// schema (and is not a streaming tool) is skipped with a warning
// rather than silently.
if (doc.mcp_tool || (expose_all && is_read_only(doc.method)))
&& doc.response.is_none()
&& !doc.mcp_stream
&& !doc.mcp_exclude
&& !doc.hidden
{
tracing::warn!(
operation_id = doc.operation_id,
method = doc.method,
path = doc.path,
"skipping MCP exposure: endpoint has no JSON response schema \
(HTML/Maud routes are not eligible as MCP tools)"
);
continue;
}
if !should_expose(doc, expose_all) {
continue;
}
// Tool names must be unique; `insert` returns false for a repeat.
if !seen.insert(doc.operation_id) {
tracing::warn!(
operation_id = doc.operation_id,
method = doc.method,
path = doc.path,
"duplicate MCP tool name; keeping the first registration and \
skipping this duplicate (set a distinct operation_id to expose both)"
);
continue;
}
// Title falls back to the operation id; description falls back to the summary.
let title = doc.summary.unwrap_or(doc.operation_id);
tools.push(McpToolInfo {
name: doc.operation_id.to_owned(),
description: doc.description.or(doc.summary).map(str::to_owned),
input_schema: build_input_schema(doc, &components),
annotations: annotations_for(doc.method, title),
method: doc.method.to_owned(),
path_template: doc.path.to_owned(),
path_params: doc.path_params.iter().map(|p| (*p).to_owned()).collect(),
has_body: doc.request_body.is_some(),
has_query: doc.query_schema.is_some(),
streams: doc.mcp_stream,
});
}
tools
}
/// Tool description produced by `derive_tools` and consumed by
/// `McpServer::new`, which converts it field by field into an `McpTool`.
#[derive(Clone, Debug)]
pub struct McpToolInfo {
/// Tool name, taken from the endpoint's `operation_id`.
name: String,
/// Endpoint description, falling back to its summary.
description: Option<String>,
/// Argument schema built by `build_input_schema`.
input_schema: Value,
/// Annotations built by `annotations_for`.
annotations: Value,
/// HTTP method of the endpoint.
method: String,
/// Route path of the endpoint.
path_template: String,
/// Path parameter names of the endpoint.
path_params: Vec<String>,
/// Whether the endpoint declares a request body.
has_body: bool,
/// Whether the endpoint declares a query schema.
has_query: bool,
/// Whether the endpoint is marked as a streaming tool.
streams: bool,
}
impl McpServer {
#[must_use]
pub(crate) fn new(tools: Vec<McpToolInfo>, dispatch: axum::Router, wiring: McpWiring) -> Self {
let tools: Vec<McpTool> = tools
.into_iter()
.map(|t| McpTool {
name: t.name,
description: t.description,
input_schema: t.input_schema,
annotations: t.annotations,
method: t.method,
path_template: t.path_template,
path_params: t.path_params,
has_body: t.has_body,
has_query: t.has_query,
streams: t.streams,
})
.collect();
let by_name = tools
.iter()
.enumerate()
.map(|(i, t)| (t.name.clone(), i))
.collect();
Self {
tools,
by_name,
dispatch,
cors: wiring.cors,
trusted_hosts: wiring.trusted_hosts,
tenant_header: wiring.tenant_header,
csrf_header: wiring.csrf_header,
envelope_rate_limited: wiring.envelope_rate_limited,
server_name: "autumn-mcp".to_owned(),
server_version: env!("CARGO_PKG_VERSION").to_owned(),
}
}
}
/// Builds the router serving the MCP endpoint at `mount_path`.
///
/// GET and POST are served by `serve_mcp_get` / `serve_mcp` and wrapped, in
/// this order, by the server extension, the optional `endpoint_layer`, and the
/// host/origin guard. OPTIONS is served by a separate preflight route that
/// carries only the server extension, so neither the endpoint layer nor the
/// guard is applied to it.
pub(crate) fn build_mcp_router(
mount_path: &str,
tools: Vec<McpToolInfo>,
dispatch: axum::Router,
wiring: McpWiring,
endpoint_layer: Option<McpEndpointLayer>,
) -> axum::Router<crate::state::AppState> {
let server = Arc::new(McpServer::new(tools, dispatch, wiring));
tracing::debug!(
path = mount_path,
tools = server.tools.len(),
"Mounted MCP endpoint"
);
let mut rpc = axum::Router::<crate::state::AppState>::new()
.route(
mount_path,
axum::routing::get(serve_mcp_get).post(serve_mcp),
)
.layer(axum::extract::Extension(Arc::clone(&server)));
if let Some(layer_fn) = endpoint_layer {
rpc = layer_fn(rpc);
}
// The guard is added last. NOTE(review): with axum's layering order the
// last-added layer runs first, so the guard should see the request before
// the endpoint layer does — confirm against the axum middleware docs.
let guard_server = Arc::clone(&server);
rpc = rpc.layer(axum::middleware::from_fn(move |req, next| {
mcp_host_origin_guard(Arc::clone(&guard_server), req, next)
}));
// Preflight route: only the server extension is attached.
let preflight = axum::Router::<crate::state::AppState>::new()
.route(mount_path, axum::routing::options(serve_mcp_options))
.layer(axum::extract::Extension(server));
rpc.merge(preflight)
}
/// Wraps `router` in a middleware that stamps CORS response headers, as
/// computed by `apply_cors_headers` from a clone of `cors`, onto every
/// response the router produces.
pub(crate) fn apply_mcp_cors_layer(
router: axum::Router<crate::state::AppState>,
cors: &crate::config::CorsConfig,
) -> axum::Router<crate::state::AppState> {
let cors = cors.clone();
router.layer(axum::middleware::from_fn(
move |req: axum::extract::Request, next: axum::middleware::Next| {
// Each invocation gets its own copy so the returned future owns its data.
let cors = cors.clone();
async move {
// Capture the Origin before the request is consumed by `next.run`.
let origin = req
.headers()
.get(header::ORIGIN)
.and_then(|v| v.to_str().ok())
.map(str::to_owned);
let mut response = next.run(req).await;
apply_cors_headers(&cors, origin.as_deref(), &mut response);
response
}
},
))
}
/// Answers `OPTIONS` on the MCP endpoint with `204 No Content`.
///
/// When the request origin passes the CORS policy the reply carries the
/// CORS preflight headers (allowed origin, methods, headers including the MCP
/// request headers, max age, and credentials when enabled). Otherwise it only
/// advertises the supported methods through `Allow`. `Vary: origin` is always
/// set.
async fn serve_mcp_options(
    axum::extract::Extension(server): axum::extract::Extension<Arc<McpServer>>,
    headers: HeaderMap,
) -> Response {
    use axum::http::HeaderValue;
    let policy = &server.cors;
    let request_origin = headers.get(header::ORIGIN).and_then(|o| o.to_str().ok());

    let mut reply = HeaderMap::new();
    reply.insert(header::VARY, HeaderValue::from_static("origin"));

    match cors_allow_origin(policy, request_origin) {
        None => {
            reply.insert(
                header::ALLOW,
                HeaderValue::from_static("GET, POST, OPTIONS"),
            );
        }
        Some(allow_origin) => {
            // Configured headers plus any MCP header not already listed.
            let mut allow_headers = policy.allowed_headers.clone();
            for extra in MCP_REQUEST_HEADERS {
                if !allow_headers.iter().any(|h| h.eq_ignore_ascii_case(extra)) {
                    allow_headers.push((*extra).to_owned());
                }
            }
            let computed = [
                (header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin),
                (
                    header::ACCESS_CONTROL_ALLOW_METHODS,
                    policy.allowed_methods.join(", "),
                ),
                (
                    header::ACCESS_CONTROL_ALLOW_HEADERS,
                    allow_headers.join(", "),
                ),
                (
                    header::ACCESS_CONTROL_MAX_AGE,
                    policy.max_age_secs.to_string(),
                ),
            ];
            // Values that are not valid header text are silently left out.
            for (name, text) in computed {
                if let Ok(v) = HeaderValue::from_str(&text) {
                    reply.insert(name, v);
                }
            }
            if policy.allow_credentials {
                reply.insert(
                    header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
                    HeaderValue::from_static("true"),
                );
            }
        }
    }
    (StatusCode::NO_CONTENT, reply).into_response()
}
/// Computes the `Access-Control-Allow-Origin` value for a request origin.
///
/// Returns `None` when there is no origin or the policy does not allow it.
/// A wildcard policy answers `*` unless credentials are enabled, in which
/// case (as for explicitly listed origins) the origin itself is echoed.
fn cors_allow_origin(cors: &crate::config::CorsConfig, origin: Option<&str>) -> Option<String> {
    let origin = origin?;
    let wildcard = cors.allowed_origins.iter().any(|a| a == "*");
    let permitted = wildcard || cors.allowed_origins.iter().any(|a| a == origin);
    if !permitted {
        return None;
    }
    let echoed = if wildcard && !cors.allow_credentials {
        "*"
    } else {
        origin
    };
    Some(echoed.to_owned())
}
/// Adds CORS response headers for `origin` to `response`.
///
/// `Vary: origin` is always ensured. `Access-Control-Allow-Origin` (and
/// `Access-Control-Allow-Credentials` when enabled) is only set when the
/// policy allows the origin.
fn apply_cors_headers(
    cors: &crate::config::CorsConfig,
    origin: Option<&str>,
    response: &mut Response,
) {
    use axum::http::HeaderValue;
    let headers = response.headers_mut();

    // Preserve any `Vary` values already on the response (e.g. set by the
    // handler or an inner layer): add `origin` alongside them instead of
    // overwriting, and skip it when it is already covered so that applying
    // this function twice stays idempotent.
    let already_varies_on_origin = headers.get_all(header::VARY).iter().any(|value| {
        value.to_str().is_ok_and(|list| {
            list.split(',')
                .map(str::trim)
                .any(|token| token == "*" || token.eq_ignore_ascii_case("origin"))
        })
    });
    if !already_varies_on_origin {
        headers.append(header::VARY, HeaderValue::from_static("origin"));
    }

    if let Some(allow_origin) = cors_allow_origin(cors, origin)
        && let Ok(v) = HeaderValue::from_str(&allow_origin)
    {
        headers.insert(header::ACCESS_CONTROL_ALLOW_ORIGIN, v);
        if cors.allow_credentials {
            headers.insert(
                header::ACCESS_CONTROL_ALLOW_CREDENTIALS,
                HeaderValue::from_static("true"),
            );
        }
    }
}
/// Rejects `GET` on the MCP endpoint with `405 Method Not Allowed` and
/// `Allow: POST`; server-initiated streams are not offered.
async fn serve_mcp_get() -> Response {
    let allow = [(header::ALLOW, "POST")];
    let explanation = "MCP server-initiated streaming is not supported (POST JSON-RPC only)";
    (StatusCode::METHOD_NOT_ALLOWED, allow, explanation).into_response()
}
/// Per-request data from the incoming MCP request that is carried over to the
/// replayed handler request.
struct ReplayContext<'a> {
/// Headers of the incoming MCP request; source of the forwarded headers and `Accept`.
headers: &'a HeaderMap,
/// Resolved client identity, re-attached to the replayed request when present.
identity: Option<&'a crate::security::ResolvedClientIdentity>,
/// Peer socket address, re-attached as `ConnectInfo` when present.
peer: Option<std::net::SocketAddr>,
}
/// Middleware that rejects requests aimed at an untrusted host or coming from
/// a disallowed origin before they reach the MCP handlers.
///
/// Responds `400 Invalid Host header` when the host is not trusted and
/// `403 origin not allowed` when an `Origin` header is present but rejected by
/// `McpServer::origin_allowed`. Requests without an `Origin` header skip the
/// origin check.
async fn mcp_host_origin_guard(
server: Arc<McpServer>,
req: axum::extract::Request,
next: axum::middleware::Next,
) -> Response {
let identity = req
.extensions()
.get::<crate::security::ResolvedClientIdentity>();
// Host precedence: resolved identity, then the request URI authority,
// then the `Host` header.
let host = identity
.and_then(|id| id.host.as_deref())
.or_else(|| req.uri().authority().map(http::uri::Authority::as_str))
.or_else(|| {
req.headers()
.get(header::HOST)
.and_then(|h| h.to_str().ok())
});
// Normalise (strip port, trailing dots, lowercase) before consulting the
// policy; a missing or empty host is decided by `allows_missing_host`.
let host_trusted = host
.and_then(crate::router::extract_host_without_port)
.map(|h| h.trim_end_matches('.').to_ascii_lowercase())
.filter(|h| !h.is_empty())
.map_or_else(
|| server.trusted_hosts.allows_missing_host(),
|h| server.trusted_hosts.allows_host(&h),
);
if !host_trusted {
return (StatusCode::BAD_REQUEST, "Invalid Host header").into_response();
}
if let Some(origin) = req.headers().get(header::ORIGIN) {
// A non-UTF-8 Origin is checked as the empty string.
let origin = origin.to_str().unwrap_or("");
let scheme = identity.and_then(|id| id.scheme.as_deref());
if !server.origin_allowed(origin, host, scheme) {
return (StatusCode::FORBIDDEN, "origin not allowed").into_response();
}
}
next.run(req).await
}
/// POST handler for the MCP endpoint: parses the JSON-RPC payload, dispatches
/// it, and applies CORS headers to the reply.
///
/// Malformed JSON and an unsupported `mcp-protocol-version` are answered
/// immediately, before CORS headers are applied. Batches are answered
/// synchronously, but an empty batch or one containing `tools/call` is
/// rejected. A single `tools/call` is replayed against the application router
/// via `serve_tools_call`. Messages that produce no response (notifications)
/// yield `202 Accepted`.
async fn serve_mcp(
axum::extract::Extension(server): axum::extract::Extension<Arc<McpServer>>,
identity: Option<axum::extract::Extension<crate::security::ResolvedClientIdentity>>,
connect_info: Option<
axum::extract::Extension<axum::extract::ConnectInfo<std::net::SocketAddr>>,
>,
headers: HeaderMap,
body: Bytes,
) -> Response {
let identity = identity.as_ref().map(|ext| &ext.0);
let origin = headers
.get(header::ORIGIN)
.and_then(|o| o.to_str().ok())
.map(str::to_owned);
let parsed: Value = match serde_json::from_slice(&body) {
Ok(v) => v,
Err(e) => {
return json_response(&parse_error(&e.to_string()));
}
};
if let Some(rejection) = reject_unsupported_protocol_version(&headers, &parsed) {
return rejection;
}
let ctx = ReplayContext {
headers: &headers,
identity,
peer: connect_info.map(|ext| (ext.0).0),
};
// Arm order matters: the empty-batch and batched-`tools/call` guards must
// be tried before the general batch arm.
let mut response = match parsed {
Value::Array(batch) if batch.is_empty() => {
json_response(&error(Value::Null, -32600, "Invalid Request: empty batch"))
}
Value::Array(batch)
if batch
.iter()
.any(|msg| msg.get("method").and_then(Value::as_str) == Some("tools/call")) =>
{
json_response(&error(
Value::Null,
-32600,
"Invalid Request: batched tools/call is not supported; \
send each tools/call as a single JSON-RPC request",
))
}
Value::Array(batch) => {
let mut out = Vec::new();
for msg in batch {
if let Some(resp) = handle_message(&server, &msg) {
out.push(resp);
}
}
// A batch consisting only of notifications has nothing to return.
if out.is_empty() {
StatusCode::ACCEPTED.into_response()
} else {
json_response(&Value::Array(out))
}
}
msg @ Value::Object(_) => {
if let Some((id, params)) = single_tools_call(&msg) {
serve_tools_call(&server, &ctx, id, params).await
} else {
handle_message(&server, &msg).map_or_else(
|| StatusCode::ACCEPTED.into_response(),
|v| json_response(&v),
)
}
}
_ => json_response(&error(
Value::Null,
-32600,
"Invalid Request: expected a JSON object or array",
)),
};
apply_cors_headers(&server.cors, origin.as_deref(), &mut response);
response
}
/// Returns a `400 Bad Request` response when the `mcp-protocol-version`
/// header names a version this server does not support.
///
/// `initialize` requests are exempt, as are requests that do not send the
/// header. A header value that is not valid text is checked as the empty
/// string.
fn reject_unsupported_protocol_version(headers: &HeaderMap, parsed: &Value) -> Option<Response> {
    let method = parsed
        .as_object()
        .and_then(|o| o.get("method"))
        .and_then(Value::as_str);
    if matches!(method, Some("initialize")) {
        return None;
    }
    let version = headers.get("mcp-protocol-version")?.to_str().unwrap_or("");
    if SUPPORTED_PROTOCOL_VERSIONS.contains(&version) {
        return None;
    }
    let body = format!("unsupported MCP-Protocol-Version: {version}");
    Some((StatusCode::BAD_REQUEST, body).into_response())
}
/// Handles one JSON-RPC message that can be answered synchronously.
///
/// Returns the JSON-RPC response object, or `None` for a well-formed
/// notification (a message without `id`). A message is rejected with
/// `-32600 Invalid Request` unless it is an object with `jsonrpc == "2.0"`, a
/// string `method`, and an `id` that is absent or a string, number or null.
/// `tools/call` is refused here because it must arrive as a single request
/// and go through the replay path.
fn handle_message(server: &McpServer, msg: &Value) -> Option<Value> {
    let id = msg.get("id").cloned();
    let id_ok = id
        .as_ref()
        .is_none_or(|v| v.is_string() || v.is_number() || v.is_null());
    let is_valid = msg.is_object()
        && msg.get("jsonrpc").and_then(Value::as_str) == Some("2.0")
        && msg.get("method").and_then(Value::as_str).is_some()
        && id_ok;
    if !is_valid {
        // Echo the id only when it is a usable string/number; otherwise null.
        let err_id = match &id {
            Some(v) if v.is_string() || v.is_number() => v.clone(),
            _ => Value::Null,
        };
        return Some(error(err_id, -32600, "Invalid Request"));
    }
    // No id: this is a notification and gets no response.
    let id = id?;
    let method = msg.get("method").and_then(Value::as_str).unwrap_or("");
    let params = msg.get("params").cloned().unwrap_or(Value::Null);
    let result = match method {
        "initialize" => Ok(initialize_result(server, &params)),
        "ping" => Ok(json!({})),
        "tools/list" => Ok(tools_list_result(server)),
        "tools/call" => Err((
            -32600,
            "tools/call must be sent as a single JSON-RPC request".to_owned(),
        )),
        other => Err((-32601, format!("method not found: {other}"))),
    };
    Some(match result {
        Ok(value) => success(id, value),
        Err((code, message)) => error(id, code, &message),
    })
}
/// Builds the `initialize` result: the negotiated protocol version (the
/// requested one when supported, otherwise the default), the tool
/// capabilities, and the server name and version.
fn initialize_result(server: &McpServer, params: &Value) -> Value {
    let negotiated = params
        .get("protocolVersion")
        .and_then(Value::as_str)
        .filter(|requested| SUPPORTED_PROTOCOL_VERSIONS.contains(requested))
        .unwrap_or(DEFAULT_PROTOCOL_VERSION);
    json!({
        "protocolVersion": negotiated,
        "capabilities": { "tools": { "listChanged": false } },
        "serverInfo": {
            "name": server.server_name,
            "version": server.server_version,
        },
    })
}
fn tools_list_result(server: &McpServer) -> Value {
let tools: Vec<Value> = server.tools.iter().map(McpTool::descriptor).collect();
json!({ "tools": tools })
}
/// Recognises a well-formed single `tools/call` request and returns its
/// `(id, params)`.
///
/// Yields `None` unless `msg` is an object with `jsonrpc == "2.0"`,
/// `method == "tools/call"`, and an `id` that is a string, number or null.
/// Missing `params` are returned as JSON null.
fn single_tools_call(msg: &Value) -> Option<(Value, Value)> {
    let obj = msg.as_object()?;
    let version = obj.get("jsonrpc").and_then(Value::as_str);
    let method = obj.get("method").and_then(Value::as_str);
    if !matches!((version, method), (Some("2.0"), Some("tools/call"))) {
        return None;
    }
    let id = obj
        .get("id")
        .filter(|id| id.is_string() || id.is_number() || id.is_null())?;
    let params = obj.get("params").cloned().unwrap_or(Value::Null);
    Some((id.clone(), params))
}
/// Executes a single `tools/call` by replaying it as an HTTP request against
/// the application router and wrapping the outcome as a JSON-RPC response.
///
/// Invalid `arguments`, an unknown tool name, or arguments that cannot be
/// turned into a request yield a `-32602` JSON-RPC error. Failures of the
/// handler itself (dispatch error, non-success status, oversized body) are
/// reported as a successful JSON-RPC response whose tool result has
/// `isError: true`. `Set-Cookie` headers from the handler are copied onto
/// the reply. When the tool is a streaming tool, the handler answered
/// successfully with `text/event-stream`, and the client accepts SSE, the
/// response is streamed via `stream_tool_result`; otherwise it is buffered
/// (up to `MAX_TOOL_RESPONSE_BYTES`).
async fn serve_tools_call(
    server: &McpServer,
    ctx: &ReplayContext<'_>,
    id: Value,
    params: Value,
) -> Response {
    let name = params.get("name").and_then(Value::as_str).unwrap_or("");
    let arguments = match params.get("arguments") {
        None => json!({}),
        Some(value) if value.is_object() => value.clone(),
        Some(_) => return json_response(&error(id, -32602, "`arguments` must be a JSON object")),
    };
    let Some(&idx) = server.by_name.get(name) else {
        return json_response(&error(id, -32602, &format!("unknown tool: {name}")));
    };
    let tool = &server.tools[idx];
    let mut request = match build_request(
        tool,
        ctx.headers,
        &arguments,
        &server.csrf_header,
        server.tenant_header.as_deref(),
    ) {
        Ok(req) => req,
        Err(message) => return json_response(&error(id, -32602, &message)),
    };

    // Carry the caller's identity and peer address over to the replayed
    // request so downstream extractors see the original client.
    if let Some(identity) = ctx.identity {
        request.extensions_mut().insert(identity.clone());
    }
    if let Some(peer) = ctx.peer {
        request
            .extensions_mut()
            .insert(axum::extract::ConnectInfo(peer));
    }
    if server.envelope_rate_limited {
        request
            .extensions_mut()
            .insert(crate::security::RateLimitExempt);
    }

    let response = match server.dispatch.clone().oneshot(request).await {
        Ok(resp) => resp,
        Err(e) => {
            return json_response(&success(id, tool_error(&format!("dispatch failed: {e}"))));
        }
    };
    let status = response.status();
    let is_event_stream = response
        .headers()
        .get(header::CONTENT_TYPE)
        .and_then(|v| v.to_str().ok())
        .is_some_and(|c| {
            c.trim_start()
                .to_ascii_lowercase()
                .starts_with("text/event-stream")
        });
    let client_accepts_sse = ctx
        .headers
        .get(header::ACCEPT)
        .and_then(|v| v.to_str().ok())
        .is_some_and(accept_includes_event_stream);
    let cookies: Vec<HeaderValue> = response
        .headers()
        .get_all(header::SET_COOKIE)
        .iter()
        .cloned()
        .collect();

    if tool.streams && status.is_success() && is_event_stream && client_accepts_sse {
        return stream_tool_result(id, &params, response, cookies);
    }

    let Ok(bytes) = axum::body::to_bytes(response.into_body(), MAX_TOOL_RESPONSE_BYTES).await
    else {
        return json_response(&success(
            id,
            tool_error(&format!(
                "handler response exceeded the {MAX_TOOL_RESPONSE_BYTES}-byte MCP tool-result limit"
            )),
        ));
    };
    // A buffered SSE body is reduced to its payload text.
    let text = if is_event_stream {
        collapse_sse_body(&bytes)
    } else {
        String::from_utf8_lossy(&bytes).into_owned()
    };
    let value = if status.is_success() {
        success(id, tool_ok(&text))
    } else {
        success(
            id,
            tool_error(&format!(
                "handler returned HTTP {}: {text}",
                status.as_u16()
            )),
        )
    };
    let mut resp = json_response(&value);
    for cookie in cookies {
        resp.headers_mut().append(header::SET_COOKIE, cookie);
    }
    resp
}
/// Returns `true` when an `Accept` header value lists `text/event-stream`
/// or `text/*` as acceptable.
///
/// Media types are compared without regard to ASCII case. A matching range
/// whose quality value is zero (`;q=0`) means "not acceptable" and is
/// therefore not counted. `*/*` is deliberately not treated as opting in to
/// SSE.
fn accept_includes_event_stream(accept: &str) -> bool {
    accept.split(',').any(|part| {
        let mut pieces = part.split(';');
        let media = pieces.next().unwrap_or("").trim();
        let names_sse =
            media.eq_ignore_ascii_case("text/event-stream") || media.eq_ignore_ascii_case("text/*");
        if !names_sse {
            return false;
        }
        // Remaining pieces are parameters; an explicit weight of zero vetoes
        // this range. A weight that does not parse is ignored.
        let refused = pieces.any(|param| {
            param.split_once('=').is_some_and(|(key, weight)| {
                key.trim().eq_ignore_ascii_case("q")
                    && weight.trim().parse::<f32>().is_ok_and(|q| q <= 0.0)
            })
        });
        !refused
    })
}
/// Builds a `notifications/progress` JSON-RPC notification for `token`.
///
/// When `message` is itself a JSON object with a numeric `progress` field,
/// its `progress`, `total` and `message` fields are forwarded as given.
/// Otherwise the supplied `progress` counter and the raw `message` text are
/// used.
fn progress_notification(token: &Value, progress: f64, message: &str) -> Value {
    let structured = match serde_json::from_str::<Value>(message) {
        Ok(Value::Object(fields)) if fields.get("progress").is_some_and(Value::is_number) => {
            Some(fields)
        }
        _ => None,
    };
    let params = match structured {
        Some(fields) => {
            let mut params = serde_json::Map::new();
            params.insert("progressToken".into(), token.clone());
            for key in ["progress", "total", "message"] {
                if let Some(v) = fields.get(key) {
                    params.insert(key.to_owned(), v.clone());
                }
            }
            Value::Object(params)
        }
        None => json!({ "progressToken": token, "progress": progress, "message": message }),
    };
    json!({
        "jsonrpc": "2.0",
        "method": "notifications/progress",
        "params": params,
    })
}
/// Stage of a streamed tool call, as driven by `project_next`.
enum ProjectionPhase {
/// Still reading the handler's body and emitting queued events.
Streaming,
/// Body finished; the final JSON-RPC result is emitted next.
Final,
/// Final result emitted; the stream ends.
Done,
}
/// State threaded through `project_next` while a handler's SSE body is turned
/// into MCP progress notifications followed by one final result.
struct StreamProjection {
/// The handler response body as a stream of byte chunks.
body: Pin<Box<dyn Stream<Item = Result<Bytes, axum::Error>> + Send>>,
/// Incremental SSE parser fed with the body chunks.
parser: SseWireParser,
/// Outgoing events queued for emission.
ready: VecDeque<Event>,
/// `_meta.progressToken` from the call; progress notifications are only sent when set.
progress_token: Option<Value>,
/// JSON-RPC id of the call, used for the final result.
id: Value,
/// Count of progress notifications queued so far.
progress: f64,
/// Data of every non-`result` event; joined with newlines as the fallback final content.
progress_parts: Vec<String>,
/// Data of every `result` event; concatenated as the final content when non-empty.
result_parts: Vec<String>,
/// Current stage of the projection.
phase: ProjectionPhase,
}
/// Turns a handler's SSE response into an SSE reply for the MCP client.
///
/// The reply stream is produced by `project_next` over a fresh
/// `StreamProjection`. The progress token is read from
/// `params._meta.progressToken`. The given `Set-Cookie` values are appended
/// to the reply.
fn stream_tool_result(
    id: Value,
    params: &Value,
    response: Response,
    cookies: Vec<HeaderValue>,
) -> Response {
    let progress_token = params
        .get("_meta")
        .and_then(|meta| meta.get("progressToken"))
        .cloned();
    let projection = StreamProjection {
        phase: ProjectionPhase::Streaming,
        id,
        progress_token,
        progress: 0.0,
        progress_parts: Vec::new(),
        result_parts: Vec::new(),
        ready: VecDeque::new(),
        parser: SseWireParser::new(),
        body: Box::pin(response.into_body().into_data_stream()),
    };
    let events = futures::stream::unfold(projection, project_next);
    let mut reply = Sse::new(events)
        .keep_alive(crate::sse::keep_alive())
        .into_response();
    let reply_headers = reply.headers_mut();
    for cookie in cookies {
        reply_headers.append(header::SET_COOKIE, cookie);
    }
    reply
}
/// Step function for `futures::stream::unfold`: yields the next outgoing
/// event of a streamed tool call, or `None` once the final result has been
/// sent.
///
/// Queued events are drained first. While streaming, body chunks are fed to
/// the SSE parser and the parsed events are projected into the queue. Once the
/// body ends, the final JSON-RPC result is emitted, built from the `result`
/// events or, when there were none, from the other events' data joined with
/// newlines.
async fn project_next(
mut st: StreamProjection,
) -> Option<(Result<Event, Infallible>, StreamProjection)> {
loop {
if let Some(event) = st.ready.pop_front() {
return Some((Ok(event), st));
}
match st.phase {
ProjectionPhase::Done => return None,
ProjectionPhase::Final => {
let content = if st.result_parts.is_empty() {
st.progress_parts.join("\n")
} else {
st.result_parts.concat()
};
let value = success(st.id.clone(), tool_ok(&content));
st.phase = ProjectionPhase::Done;
return Some((Ok(Event::default().data(value.to_string())), st));
}
ProjectionPhase::Streaming => {
// NOTE(review): a body error (`Some(Err(_))`) takes the same
// branch as end-of-stream, so a stream that fails midway still
// ends with a non-error tool result built from the data seen so
// far — confirm this best-effort behaviour is intended.
if let Some(Ok(bytes)) = st.body.next().await {
let events = st.parser.push(&bytes);
enqueue_projected(&mut st, events);
} else {
// Flush whatever the parser still holds, then move on to the
// final result on the next loop iteration.
let trailing = st.parser.finish();
enqueue_projected(&mut st, trailing);
st.phase = ProjectionPhase::Final;
}
}
}
}
}
fn enqueue_projected(st: &mut StreamProjection, events: Vec<ParsedSseEvent>) {
for ev in events {
if ev.event.as_deref() == Some("result") {
st.result_parts.push(ev.data);
continue;
}
st.progress_parts.push(ev.data.clone());
if let Some(token) = &st.progress_token {
st.progress += 1.0;
let note = progress_notification(token, st.progress, &ev.data);
st.ready.push_back(Event::default().data(note.to_string()));
}
}
}
/// One event decoded from an SSE wire stream.
struct ParsedSseEvent {
    /// Value of the event's `event:` field, if it had one.
    event: Option<String>,
    /// The event's `data:` lines joined with `\n`.
    data: String,
}

/// Incremental parser for the `text/event-stream` wire format.
///
/// Only the `event` and `data` fields are interpreted; comment lines and
/// other fields are ignored. Lines are terminated by `\n`, with any trailing
/// `\r` stripped.
struct SseWireParser {
    /// Raw bytes received so far that do not yet form a complete line.
    ///
    /// Kept as bytes (not text) so that a multi-byte UTF-8 character split
    /// across two chunks is decoded intact once its line is complete.
    buffer: Vec<u8>,
    /// `event:` value of the event currently being assembled.
    event_type: Option<String>,
    /// `data:` values of the event currently being assembled.
    data_lines: Vec<String>,
    /// Whether the event being assembled has seen any `event`/`data` field.
    has_fields: bool,
}

impl SseWireParser {
    /// Creates an empty parser.
    const fn new() -> Self {
        Self {
            buffer: Vec::new(),
            event_type: None,
            data_lines: Vec::new(),
            has_fields: false,
        }
    }

    /// Feeds a chunk of wire bytes and returns the events completed by it.
    ///
    /// Bytes after the last newline stay buffered until more input arrives or
    /// [`Self::finish`] is called. Invalid UTF-8 within a line is replaced
    /// with U+FFFD.
    fn push(&mut self, bytes: &[u8]) -> Vec<ParsedSseEvent> {
        self.buffer.extend_from_slice(bytes);
        // Take the buffer so lines can be borrowed from it while `self` is
        // mutated by `process_line`.
        let pending = std::mem::take(&mut self.buffer);
        let mut out = Vec::new();
        let mut start = 0;
        while let Some(offset) = pending[start..].iter().position(|&b| b == b'\n') {
            let end = start + offset;
            // Decode per complete line: `\n` never occurs inside a multi-byte
            // UTF-8 sequence, so a line boundary is always a safe split point.
            let line = String::from_utf8_lossy(&pending[start..end]);
            if let Some(event) = self.process_line(line.trim_end_matches(['\n', '\r'])) {
                out.push(event);
            }
            start = end + 1;
        }
        // Drop the consumed prefix once, instead of shifting the buffer after
        // every line.
        self.buffer = pending;
        self.buffer.drain(..start);
        out
    }

    /// Signals end of input: processes any unterminated last line and
    /// dispatches the event still being assembled, if any.
    fn finish(&mut self) -> Vec<ParsedSseEvent> {
        let mut out = Vec::new();
        if !self.buffer.is_empty() {
            let rest = std::mem::take(&mut self.buffer);
            let line = String::from_utf8_lossy(&rest);
            if let Some(event) = self.process_line(line.trim_end_matches(['\n', '\r'])) {
                out.push(event);
            }
        }
        if let Some(event) = self.dispatch() {
            out.push(event);
        }
        out
    }

    /// Interprets one line (without its terminator). A blank line dispatches
    /// the pending event; a line starting with `:` is a comment.
    fn process_line(&mut self, line: &str) -> Option<ParsedSseEvent> {
        if line.is_empty() {
            return self.dispatch();
        }
        if line.starts_with(':') {
            return None;
        }
        // `field: value` — a single space after the colon is not part of the
        // value; a line without a colon is a field with an empty value.
        let (field, value) = match line.split_once(':') {
            Some((f, v)) => (f, v.strip_prefix(' ').unwrap_or(v)),
            None => (line, ""),
        };
        match field {
            "event" => {
                self.event_type = Some(value.to_owned());
                self.has_fields = true;
            }
            "data" => {
                self.data_lines.push(value.to_owned());
                self.has_fields = true;
            }
            _ => {}
        }
        None
    }

    /// Emits the pending event and resets the assembly state; returns `None`
    /// when no `event`/`data` field was seen since the last dispatch.
    fn dispatch(&mut self) -> Option<ParsedSseEvent> {
        if !self.has_fields {
            return None;
        }
        let event = ParsedSseEvent {
            event: self.event_type.take(),
            data: self.data_lines.join("\n"),
        };
        self.data_lines.clear();
        self.has_fields = false;
        Some(event)
    }
}
fn collapse_sse_body(bytes: &[u8]) -> String {
let mut parser = SseWireParser::new();
let mut events = parser.push(bytes);
events.extend(parser.finish());
let (results, progress): (Vec<_>, Vec<_>) = events
.into_iter()
.partition(|e| e.event.as_deref() == Some("result"));
if results.is_empty() {
progress
.into_iter()
.map(|e| e.data)
.collect::<Vec<_>>()
.join("\n")
} else {
results.into_iter().map(|e| e.data).collect()
}
}
/// Translates a tool call into the HTTP request replayed against the tool's
/// handler.
///
/// Path captures are filled from same-named arguments, the optional `query`
/// argument becomes the query string, selected headers are copied from the
/// incoming request, and the `body` argument (required when the tool has a
/// body) is sent as JSON.
///
/// # Errors
///
/// Returns a human-readable message when a path parameter is missing or is
/// not a string, number or bool, when `query` is not an object or cannot be
/// encoded, when a required `body` is missing, or when the request cannot be
/// built.
fn build_request(
tool: &McpTool,
headers: &HeaderMap,
arguments: &Value,
csrf_header: &str,
tenant_header: Option<&str>,
) -> Result<axum::http::Request<Body>, String> {
let mut path = tool.path_template.clone();
for param in &tool.path_params {
// A leading `*` marks a catch-all capture; the argument key omits it.
let is_catch_all = param.starts_with('*');
let arg_key = param.strip_prefix('*').unwrap_or(param);
let raw = arguments
.get(arg_key)
.ok_or_else(|| format!("missing required path parameter `{arg_key}`"))?;
let value = match raw {
Value::String(s) => s.clone(),
Value::Number(_) | Value::Bool(_) => raw.to_string(),
_ => return Err(format!("path parameter `{arg_key}` must be a string")),
};
// Catch-all values keep their `/` separators; each segment is encoded
// on its own. Ordinary values are encoded as a single segment.
let encoded = if is_catch_all {
value
.split('/')
.map(crate::paths::encode_path_segment)
.collect::<Vec<_>>()
.join("/")
} else {
crate::paths::encode_path_segment(&value)
};
path = replace_path_param(&path, param, &encoded);
}
if tool.has_query
&& let Some(query) = arguments.get("query")
{
let Value::Object(map) = query else {
return Err("`query` must be a JSON object".to_owned());
};
// Arrays expand to one `key=value` pair per element.
let mut pairs: Vec<(String, String)> = Vec::new();
for (key, value) in map {
match value {
Value::Array(items) => {
for item in items {
pairs.push((key.clone(), query_scalar(item)));
}
}
other => pairs.push((key.clone(), query_scalar(other))),
}
}
if !pairs.is_empty() {
let qs = serde_urlencoded::to_string(&pairs)
.map_err(|e| format!("invalid query arguments: {e}"))?;
path = format!("{path}?{qs}");
}
}
let mut builder = axum::http::Request::builder()
.method(tool.method.as_str())
.uri(&path);
// Copy every value of each forwarded header from the incoming request.
for name in FORWARDED_HEADERS {
for value in headers.get_all(*name) {
builder = builder.header(*name, value);
}
}
// The CSRF and tenant headers have configurable names, so they are
// forwarded separately.
if let Some(value) = headers.get(csrf_header) {
builder = builder.header(csrf_header, value);
}
if let Some(name) = tenant_header
&& let Some(value) = headers.get(name)
{
builder = builder.header(name, value);
}
let body = if tool.has_body {
let payload = arguments
.get("body")
.ok_or_else(|| "missing required `body` argument".to_owned())?;
builder = builder.header(header::CONTENT_TYPE, "application/json");
// A serialization failure falls back to an empty body.
Body::from(serde_json::to_vec(payload).unwrap_or_default())
} else {
Body::empty()
};
builder
.body(body)
.map_err(|e| format!("invalid request: {e}"))
}
fn query_scalar(value: &Value) -> String {
match value {
Value::String(s) => s.clone(),
other => other.to_string(),
}
}
/// Substitutes `value` for every `{name}` capture in a route template.
///
/// A capture matches when the text before its first `:` (trimmed) equals
/// `name`, so `{id:[0-9]+}` matches `id`. Other captures are copied through
/// untouched, as is an unterminated `{`.
fn replace_path_param(path: &str, name: &str, value: &str) -> String {
    let mut rendered = String::with_capacity(path.len());
    let mut remaining = path;
    loop {
        let Some((literal, after_open)) = remaining.split_once('{') else {
            rendered.push_str(remaining);
            return rendered;
        };
        rendered.push_str(literal);

        let Some((inner, after_close)) = after_open.split_once('}') else {
            // No closing brace: keep the rest of the template as it is.
            rendered.push('{');
            rendered.push_str(after_open);
            return rendered;
        };

        let capture = inner.split_once(':').map_or(inner, |(head, _)| head).trim();
        if capture == name {
            rendered.push_str(value);
        } else {
            rendered.push('{');
            rendered.push_str(inner);
            rendered.push('}');
        }
        remaining = after_close;
    }
}
/// Wraps `text` as a tool result with one text content item and
/// `isError: false`.
fn tool_ok(text: &str) -> Value {
    let item = json!({ "type": "text", "text": text });
    json!({ "content": [item], "isError": false })
}
/// Wraps `text` as a tool result with one text content item and
/// `isError: true`.
fn tool_error(text: &str) -> Value {
    let item = json!({ "type": "text", "text": text });
    json!({ "content": [item], "isError": true })
}
fn success(id: Value, result: Value) -> Value {
let mut obj = serde_json::Map::new();
obj.insert("jsonrpc".into(), json!("2.0"));
obj.insert("id".into(), id);
obj.insert("result".into(), result);
Value::Object(obj)
}
fn error(id: Value, code: i64, message: &str) -> Value {
let mut obj = serde_json::Map::new();
obj.insert("jsonrpc".into(), json!("2.0"));
obj.insert("id".into(), id);
obj.insert("error".into(), json!({ "code": code, "message": message }));
Value::Object(obj)
}
/// Builds the JSON-RPC parse-error response (`-32700`, null id) for a body
/// that is not valid JSON.
fn parse_error(message: &str) -> Value {
    let detail = format!("parse error: {message}");
    json!({
        "jsonrpc": "2.0",
        "id": null,
        "error": { "code": -32700, "message": detail },
    })
}
/// Serializes `value` into an `application/json` response; falls back to the
/// body `{}` if serialization fails.
fn json_response(value: &Value) -> Response {
    let payload = match serde_json::to_string(value) {
        Ok(text) => text,
        Err(_) => "{}".to_owned(),
    };
    let content_type = [(header::CONTENT_TYPE, "application/json")];
    (content_type, payload).into_response()
}
#[cfg(test)]
mod tests {
use super::*;
use crate::openapi::{SchemaEntry, SchemaKind};
/// Test fixture: an [`ApiDoc`] for `method path` with operation id `op`, a 200
/// success status and a `Todo` schema reference as its response.
fn doc(method: &'static str, path: &'static str, op: &'static str) -> ApiDoc {
    let todo_ref = SchemaEntry {
        name: "Todo",
        kind: SchemaKind::Ref,
    };
    ApiDoc {
        operation_id: op,
        method,
        path,
        response: Some(todo_ref),
        success_status: 200,
        ..Default::default()
    }
}
/// Without the expose-all hatch a route is exposed only once `mcp_tool` is set.
#[test]
fn opt_in_required_without_hatch() {
    let plain = doc("GET", "/a", "a");
    assert!(!should_expose(&plain, false), "no opt-in => not exposed");

    let opted_in = ApiDoc {
        mcp_tool: true,
        ..plain
    };
    assert!(should_expose(&opted_in, false));
}
/// `mcp_exclude` hides a route even when it opted in, with or without the hatch.
#[test]
fn exclude_always_wins() {
    let excluded = ApiDoc {
        mcp_tool: true,
        mcp_exclude: true,
        ..doc("GET", "/a", "a")
    };
    for expose_all in [false, true] {
        assert!(!should_expose(&excluded, expose_all));
    }
}
/// Under the hatch a GET is exposed automatically, a POST still needs opt-in.
#[test]
fn hatch_includes_reads_excludes_unopted_writes() {
    let (read, write) = (doc("GET", "/a", "a"), doc("POST", "/a", "b"));
    let read_exposed = should_expose(&read, true);
    let write_exposed = should_expose(&write, true);
    assert!(read_exposed);
    assert!(!write_exposed, "mutating verb needs opt-in");
}
/// An opted-in POST stays exposed when the hatch is enabled.
#[test]
fn hatch_still_allows_opted_in_writes() {
    let opted_in_write = ApiDoc {
        mcp_tool: true,
        ..doc("POST", "/a", "b")
    };
    assert!(should_expose(&opted_in_write, true));
}
/// A route without a response schema is not exposed even when it opts in.
#[test]
fn html_routes_are_ineligible() {
    let page = ApiDoc {
        response: None,
        mcp_tool: true,
        ..doc("GET", "/page", "page")
    };
    assert!(!should_expose(&page, false));
}
/// A streaming route with no response schema follows the usual opt-in and
/// exclude rules: hidden by default, exposed once opted in, hidden again when
/// excluded.
#[test]
fn streaming_tool_is_eligible_without_response_schema() {
    let mut stream_doc = ApiDoc {
        response: None,
        mcp_stream: true,
        ..doc("GET", "/api/search", "search")
    };
    let exposed_before_opt_in = should_expose(&stream_doc, false);
    assert!(!exposed_before_opt_in, "stream without opt-in stays hidden");

    stream_doc.mcp_tool = true;
    let exposed_after_opt_in = should_expose(&stream_doc, false);
    assert!(exposed_after_opt_in, "opted-in streaming tool is exposed");

    stream_doc.mcp_exclude = true;
    assert!(!should_expose(&stream_doc, false));
}
/// Under the hatch a streaming GET is exposed while a streaming POST is not.
#[test]
fn streaming_get_is_included_under_the_hatch() {
    let streaming_get = ApiDoc {
        response: None,
        mcp_stream: true,
        ..doc("GET", "/api/search", "search")
    };
    assert!(should_expose(&streaming_get, true));

    let streaming_post = ApiDoc {
        response: None,
        mcp_stream: true,
        ..doc("POST", "/api/search", "search2")
    };
    assert!(!should_expose(&streaming_post, true));
}
/// `text/event-stream` (optionally with parameters) and `text/*` count as
/// accepting SSE; `application/json` and `*/*` do not.
#[test]
fn accept_header_gates_sse() {
    let accepting = [
        "application/json, text/event-stream",
        "text/event-stream;q=1.0",
        "text/*",
    ];
    for accept in accepting {
        assert!(accept_includes_event_stream(accept));
    }

    let rejecting = ["application/json", "*/*"];
    for accept in rejecting {
        assert!(!accept_includes_event_stream(accept));
    }
}
/// Multiple `data:` lines in one frame join with `\n`, a comment-only frame
/// yields no event, and `event:` names are captured.
#[test]
fn sse_parser_splits_frames_and_joins_data() {
    let mut parser = SseWireParser::new();
    let mut frames = parser.push(b"data: line1\ndata: line2\n\n");
    let later = parser.push(b": keep-alive\n\nevent: result\ndata: final\n\n");
    frames.extend(later);

    assert_eq!(frames.len(), 2);

    // First frame: unnamed, two data lines joined.
    assert_eq!(frames[0].event, None);
    assert_eq!(frames[0].data, "line1\nline2");

    // Second frame: the named result event following the keep-alive comment.
    assert_eq!(frames[1].event.as_deref(), Some("result"));
    assert_eq!(frames[1].data, "final");
}
/// A frame split across several pushes is emitted only once its terminating
/// blank line arrives.
#[test]
fn sse_parser_handles_chunk_boundaries_mid_frame() {
    let mut parser = SseWireParser::new();

    let first_half = parser.push(b"data: hel");
    assert!(first_half.is_empty());
    let second_half = parser.push(b"lo\n");
    assert!(second_half.is_empty());

    let completed = parser.push(b"\n");
    assert_eq!(completed.len(), 1);
    assert_eq!(completed[0].data, "hello");
}
/// Plain text becomes the notification message with the supplied progress;
/// a JSON object payload overrides progress, total and message.
#[test]
fn progress_notification_plain_and_structured() {
    let progress_token = json!("tok");

    let from_text = progress_notification(&progress_token, 2.0, "working");
    assert_eq!(from_text["method"], "notifications/progress");
    let text_params = &from_text["params"];
    assert_eq!(text_params["progressToken"], "tok");
    assert_eq!(text_params["progress"], 2.0);
    assert_eq!(text_params["message"], "working");

    let payload = r#"{"progress":50,"total":100,"message":"half"}"#;
    let from_json = progress_notification(&progress_token, 99.0, payload);
    let json_params = &from_json["params"];
    assert_eq!(json_params["progress"], 50);
    assert_eq!(json_params["total"], 100);
    assert_eq!(json_params["message"], "half");
}
/// Without a `result` event all data frames are joined with `\n`; with one,
/// only the result frame's data is returned.
#[test]
fn collapse_sse_body_prefers_result_frames() {
    let only_data = b"data: a\n\ndata: b\n\n";
    assert_eq!(collapse_sse_body(only_data), "a\nb");

    let with_result = b"data: progress\n\nevent: result\ndata: done\n\n";
    assert_eq!(collapse_sse_body(with_result), "done");
}
/// `readOnlyHint` is true for GET and false for POST; `destructiveHint` is
/// true for DELETE and absent for POST.
#[test]
fn annotations_track_method() {
    let for_get = annotations_for("GET", "t");
    let for_post = annotations_for("POST", "t");
    let for_delete = annotations_for("DELETE", "t");

    assert_eq!(for_get["readOnlyHint"], json!(true));
    assert_eq!(for_post["readOnlyHint"], json!(false));
    assert_eq!(for_delete["destructiveHint"], json!(true));
    assert!(for_post.get("destructiveHint").is_none());
}
/// The generated input schema is an object with `id` and `body` properties,
/// both listed as required.
#[test]
fn input_schema_includes_path_param_and_body() {
    let mut create = doc("POST", "/users/{id}", "create");
    create.path_params = &["id"];
    create.request_body = Some(SchemaEntry {
        name: "NewUser",
        kind: SchemaKind::Ref,
    });

    let no_components = serde_json::Map::new();
    let schema = build_input_schema(&create, &no_components);

    assert_eq!(schema["type"], "object");
    let properties = &schema["properties"];
    assert!(properties["id"].is_object());
    assert!(properties["body"].is_object());

    let required = schema["required"].as_array().unwrap();
    for key in ["id", "body"] {
        assert!(required.contains(&json!(key)));
    }
}
/// Plain and constrained placeholders are substituted; placeholders for other
/// names are left in place.
#[test]
fn replace_path_param_handles_regex_captures() {
    // (template, name, value, expected)
    let cases = [
        ("/u/{id}", "id", "7", "/u/7"),
        ("/u/{id:[0-9]+}", "id", "7", "/u/7"),
        ("/u/{id}/p/{pid}", "pid", "9", "/u/{id}/p/9"),
    ];
    for (template, name, value, expected) in cases {
        assert_eq!(replace_path_param(template, name, value), expected);
    }
}
/// Test fixture: a non-streaming [`McpTool`] named `t` with empty schema and
/// annotations and no path parameters.
fn tool(method: &str, path: &str, has_body: bool, has_query: bool) -> McpTool {
    let empty_object = json!({});
    McpTool {
        name: String::from("t"),
        description: None,
        input_schema: empty_object.clone(),
        annotations: empty_object,
        method: String::from(method),
        path_template: String::from(path),
        path_params: vec![],
        has_body,
        has_query,
        streams: false,
    }
}
/// A body-taking tool called without a `body` argument fails with an error
/// that mentions `body`.
#[test]
fn build_request_rejects_missing_required_body() {
    let create = tool("POST", "/api/todos", true, false);
    let no_headers = HeaderMap::new();
    let outcome = build_request(&create, &no_headers, &json!({}), "x-csrf-token", None);
    let err = outcome.unwrap_err();
    assert!(err.contains("body"), "got: {err}");
}
/// Array query values become repeated `key=value` pairs rather than a
/// percent-encoded JSON array.
#[test]
fn build_request_explodes_array_query_into_repeated_keys() {
    let search = tool("GET", "/api/search", false, true);
    let args = json!({ "query": { "tags": ["a", "b"], "q": "x" } });
    let req = build_request(&search, &HeaderMap::new(), &args, "x-csrf-token", None)
        .expect("request builds");

    let query = req.uri().query().unwrap_or_default();
    for pair in ["tags=a", "tags=b", "q=x"] {
        assert!(query.contains(pair), "got: {query}");
    }
    // `%5B` is an encoded `[`, i.e. the array leaked through as JSON text.
    assert!(
        !query.contains("%5B"),
        "array must explode, not serialize as JSON: {query}"
    );
}
/// `Authorization` and `Cookie` from the incoming request reach the built one.
#[test]
fn build_request_forwards_authorization_and_cookie() {
    let secure = tool("GET", "/secure", false, false);
    let mut incoming = HeaderMap::new();
    incoming.insert(header::AUTHORIZATION, HeaderValue::from_static("Bearer tok"));
    incoming.insert(header::COOKIE, HeaderValue::from_static("autumn.sid=abc"));

    let req = build_request(&secure, &incoming, &json!({}), "x-csrf-token", None)
        .expect("request builds");
    let forwarded = req.headers();

    assert_eq!(forwarded.get(header::AUTHORIZATION).unwrap(), "Bearer tok");
    assert_eq!(forwarded.get(header::COOKIE).unwrap(), "autumn.sid=abc");
}
/// The header named by the CSRF argument (`x-csrf-token` here) is forwarded.
#[test]
fn build_request_forwards_csrf_token() {
    let create = tool("POST", "/api/todos", true, false);
    let mut incoming = HeaderMap::new();
    incoming.insert("x-csrf-token", HeaderValue::from_static("csrf123"));

    let args = json!({ "body": { "x": 1 } });
    let req = build_request(&create, &incoming, &args, "x-csrf-token", None)
        .expect("request builds");

    assert_eq!(req.headers().get("x-csrf-token").unwrap(), "csrf123");
}
/// A non-default CSRF header name (`x-xsrf-token`) is forwarded when it is the
/// one passed to `build_request`.
#[test]
fn build_request_forwards_configured_csrf_header() {
    let create = tool("POST", "/api/todos", true, false);
    let mut incoming = HeaderMap::new();
    incoming.insert("x-xsrf-token", HeaderValue::from_static("csrf123"));

    let args = json!({ "body": { "x": 1 } });
    let req = build_request(&create, &incoming, &args, "x-xsrf-token", None)
        .expect("request builds");

    assert_eq!(req.headers().get("x-xsrf-token").unwrap(), "csrf123");
}
/// A catch-all (`{*path}`) argument keeps its `/` separators while other
/// characters (the space here) are percent-encoded.
#[test]
fn build_request_preserves_slashes_for_catch_all_param() {
    let files = McpTool {
        path_params: vec!["*path".to_owned()],
        ..tool("GET", "/files/{*path}", false, false)
    };
    let args = json!({ "path": "a/b c/d.txt" });

    let req = build_request(&files, &HeaderMap::new(), &args, "x-csrf-token", None)
        .expect("request builds");

    assert_eq!(req.uri().path(), "/files/a/b%20c/d.txt");
}
/// The tenant header is forwarded only when its name is passed to
/// `build_request`; with `None` it is dropped.
#[test]
fn build_request_forwards_configured_tenant_header() {
    let list = tool("GET", "/api/todos", false, false);
    let mut incoming = HeaderMap::new();
    incoming.insert("x-tenant-id", HeaderValue::from_static("acme"));
    let no_args = json!({});

    let with_tenant = build_request(
        &list,
        &incoming,
        &no_args,
        "x-csrf-token",
        Some("x-tenant-id"),
    )
    .expect("request builds");
    assert_eq!(with_tenant.headers().get("x-tenant-id").unwrap(), "acme");

    let without_tenant = build_request(&list, &incoming, &no_args, "x-csrf-token", None)
        .expect("request builds");
    assert!(without_tenant.headers().get("x-tenant-id").is_none());
}
/// A `query` argument that is null, a string or an array is rejected with an
/// error mentioning `query`; omitting it entirely is accepted.
#[test]
fn build_request_rejects_non_object_query() {
    let search = tool("GET", "/api/search", false, true);
    let no_headers = HeaderMap::new();

    let rejected = [
        json!({ "query": null }),
        json!({ "query": "all" }),
        json!({ "query": [1, 2] }),
    ];
    for args in &rejected {
        let err = build_request(&search, &no_headers, args, "x-csrf-token", None).unwrap_err();
        assert!(err.contains("query"), "got: {err}");
    }

    let omitted = build_request(&search, &no_headers, &json!({}), "x-csrf-token", None);
    assert!(omitted.is_ok());
}
/// Host, proxy-identity and idempotency headers from the incoming request are
/// all carried onto the built request.
///
/// Every header placed in the fixture is asserted, including
/// `x-forwarded-host`, which was previously set up but never checked.
#[test]
fn build_request_forwards_identity_and_idempotency_headers() {
    let t = tool("POST", "/api/todos", true, false);
    let mut headers = HeaderMap::new();
    headers.insert(header::HOST, "tenant1.example.com".parse().unwrap());
    headers.insert("x-forwarded-for", "203.0.113.7".parse().unwrap());
    headers.insert("x-forwarded-host", "tenant1.example.com".parse().unwrap());
    headers.insert("x-real-ip", "203.0.113.7".parse().unwrap());
    headers.insert("idempotency-key", "abc-123".parse().unwrap());
    let req = build_request(
        &t,
        &headers,
        &json!({ "body": { "x": 1 } }),
        "x-csrf-token",
        None,
    )
    .expect("request builds");
    assert_eq!(
        req.headers().get(header::HOST).unwrap(),
        "tenant1.example.com"
    );
    assert_eq!(req.headers().get("x-forwarded-for").unwrap(), "203.0.113.7");
    // `x-forwarded-host` is in FORWARDED_HEADERS, so it must survive as well.
    assert_eq!(
        req.headers().get("x-forwarded-host").unwrap(),
        "tenant1.example.com"
    );
    assert_eq!(req.headers().get("x-real-ip").unwrap(), "203.0.113.7");
    assert_eq!(req.headers().get("idempotency-key").unwrap(), "abc-123");
}
/// `Accept-Language` is forwarded unchanged.
#[test]
fn build_request_forwards_accept_language() {
    let list = tool("GET", "/api/todos", false, false);
    let mut incoming = HeaderMap::new();
    incoming.insert("accept-language", HeaderValue::from_static("fr-CA,fr;q=0.9"));

    let req = build_request(&list, &incoming, &json!({}), "x-csrf-token", None)
        .expect("request builds");
    let forwarded = req.headers().get("accept-language").unwrap();

    assert_eq!(forwarded, "fr-CA,fr;q=0.9");
}
/// Several separate `Cookie` header lines are all forwarded, in order, rather
/// than collapsed to one.
#[test]
fn build_request_preserves_repeated_cookie_headers() {
    let create = tool("POST", "/api/todos", true, false);
    let mut incoming = HeaderMap::new();
    for cookie in ["session=abc", "csrf=dup1", "csrf=dup2"] {
        incoming.append("cookie", HeaderValue::from_static(cookie));
    }

    let args = json!({ "body": { "x": 1 } });
    let req = build_request(&create, &incoming, &args, "x-csrf-token", None)
        .expect("request builds");

    let mut cookies = Vec::new();
    for value in req.headers().get_all("cookie") {
        cookies.push(value.to_str().unwrap().to_owned());
    }
    assert_eq!(cookies, ["session=abc", "csrf=dup1", "csrf=dup2"]);
}
/// Test fixture: a `TrustedHostPolicy` built from a default config whose
/// trusted-host list is set to `hosts`.
fn trusted(hosts: &[&str]) -> crate::router::TrustedHostPolicy {
    let mut cfg = crate::config::AutumnConfig::default();
    cfg.security.trusted_hosts.hosts = hosts.iter().copied().map(str::to_owned).collect();
    crate::router::TrustedHostPolicy::from_config(&cfg)
}
/// Test fixture: an [`McpServer`] with the given CORS allowlist and no
/// explicitly trusted hosts.
fn server(allowed_origins: Vec<String>) -> McpServer {
    let no_trusted_hosts: &[&str] = &[];
    server_with_trusted(allowed_origins, no_trusted_hosts)
}
/// Test fixture: an [`McpServer`] with no tools and an empty router, wired
/// with the given CORS allowlist and trusted hosts, the `x-csrf-token` CSRF
/// header, no tenant header and envelope rate limiting off.
fn server_with_trusted(allowed_origins: Vec<String>, hosts: &[&str]) -> McpServer {
    let wiring = McpWiring {
        cors: crate::config::CorsConfig {
            allowed_origins,
            ..crate::config::CorsConfig::default()
        },
        trusted_hosts: trusted(hosts),
        tenant_header: None,
        csrf_header: String::from("x-csrf-token"),
        envelope_rate_limited: false,
    };
    let no_tools = Vec::new();
    McpServer::new(no_tools, axum::Router::new(), wiring)
}
/// With no Host information, an origin is accepted only if the CORS allowlist
/// names it or contains `*`; an empty allowlist accepts nothing.
#[test]
fn origin_allowlist_enforced() {
    let allowlisted = server(vec!["https://ok.example".to_owned()]);
    assert!(allowlisted.origin_allowed("https://ok.example", None, None));
    assert!(!allowlisted.origin_allowed("https://evil.example", None, None));

    let empty = server(Vec::new());
    assert!(!empty.origin_allowed("https://any.example", None, None));

    let wildcard = server(vec!["*".to_owned()]);
    assert!(wildcard.origin_allowed("https://any.example", None, None));
}
/// With an empty CORS allowlist, an origin matching the request's own host
/// (and scheme, when supplied) is accepted; mismatches are rejected.
#[test]
fn same_origin_allowed_without_cors_allowlist() {
    let srv = server_with_trusted(Vec::new(), &["app.example"]);
    let app_host = Some("app.example");
    let https = Some("https");

    // Matching host, with and without a known scheme.
    assert!(srv.origin_allowed("https://app.example", app_host, None));
    assert!(srv.origin_allowed("https://app.example", app_host, https));
    assert!(srv.origin_allowed(
        "http://localhost:8080",
        Some("localhost:8080"),
        Some("http")
    ));

    // Different host, then same host with a different scheme.
    assert!(!srv.origin_allowed("https://evil.example", app_host, None));
    assert!(!srv.origin_allowed("http://app.example", app_host, https));
}
/// An explicit default port (443 for https, 80 for http) on either side is
/// treated as equal to no port; any other port mismatch is rejected.
#[test]
fn same_origin_normalizes_default_ports() {
    let srv = server_with_trusted(Vec::new(), &["app.example"]);
    let (http, https) = (Some("http"), Some("https"));

    // Default port spelled out on the Host side, then on the Origin side.
    assert!(srv.origin_allowed("https://app.example", Some("app.example:443"), https));
    assert!(srv.origin_allowed("https://app.example:443", Some("app.example"), https));
    assert!(srv.origin_allowed("http://app.example", Some("app.example:80"), http));

    // Non-default port on the host, and 443 used with http.
    assert!(!srv.origin_allowed("https://app.example", Some("app.example:8443"), https));
    assert!(!srv.origin_allowed("http://app.example:443", Some("app.example"), http));
}
/// A same-origin match against a host outside the trusted list is rejected,
/// unless the CORS allowlist names that origin explicitly.
#[test]
fn same_origin_rejected_for_untrusted_host() {
    let origin = "http://attacker.example";
    let host = Some("attacker.example");
    let scheme = Some("http");

    let trusted_only = server_with_trusted(Vec::new(), &["app.example"]);
    assert!(!trusted_only.origin_allowed(origin, host, scheme));

    let allowlisted = server_with_trusted(vec![origin.to_owned()], &["app.example"]);
    assert!(allowlisted.origin_allowed(origin, host, scheme));
}
#[tokio::test]
async fn options_preflight_grants_only_allowlisted_origin() {
let s = Arc::new(server_with_trusted(
vec!["https://app.example".to_owned()],
&[],
));
let mut headers = HeaderMap::new();
headers.insert(header::ORIGIN, "https://app.example".parse().unwrap());
let resp = serve_mcp_options(axum::extract::Extension(s.clone()), headers).await;
assert_eq!(resp.status(), StatusCode::NO_CONTENT);
assert_eq!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.unwrap(),
"https://app.example"
);
assert!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_METHODS)
.is_some()
);
let allow_headers = resp
.headers()
.get(header::ACCESS_CONTROL_ALLOW_HEADERS)
.and_then(|v| v.to_str().ok())
.unwrap()
.to_ascii_lowercase();
assert!(
allow_headers.contains("mcp-protocol-version"),
"allow-headers missing MCP-Protocol-Version: {allow_headers}"
);
let mut headers = HeaderMap::new();
headers.insert(header::ORIGIN, "https://evil.example".parse().unwrap());
let resp = serve_mcp_options(axum::extract::Extension(s), headers).await;
assert_eq!(resp.status(), StatusCode::NO_CONTENT);
assert!(
resp.headers()
.get(header::ACCESS_CONTROL_ALLOW_ORIGIN)
.is_none()
);
}
/// A supported requested protocol version is echoed back; an unsupported one
/// falls back to `DEFAULT_PROTOCOL_VERSION`.
#[test]
fn initialize_negotiates_supported_protocol_version() {
    let srv = server(Vec::new());

    let supported = json!({ "protocolVersion": "2024-11-05" });
    let echoed = initialize_result(&srv, &supported);
    assert_eq!(echoed["protocolVersion"], "2024-11-05");

    let unsupported = json!({ "protocolVersion": "3999-01-01" });
    let fallback = initialize_result(&srv, &unsupported);
    assert_eq!(fallback["protocolVersion"], DEFAULT_PROTOCOL_VERSION);
}
}