use std::{
error::Error,
io,
sync::{Arc, Mutex},
};
use axum::{
Router,
body::Body,
extract::{Request, State},
http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
middleware::Next,
response::Response,
};
use mimobox_mcp::MimoboxServer;
use rmcp::transport::{
StreamableHttpServerConfig, StreamableHttpService,
streamable_http_server::session::local::LocalSessionManager,
};
use tokio::signal::unix::{SignalKind, signal};
/// Upper bound on `MimoboxServer` handles kept in the shutdown-cleanup registry.
const MAX_CONCURRENT_SESSIONS: usize = 100;
/// Allow-list entry meaning "accept every origin".
const WILDCARD_ORIGIN: &str = "*";

/// Result of the HTTP entry point; boxed so `?` works across bind, signal and serve errors.
type HttpResult<T> = Result<T, Box<dyn Error + Send + Sync>>;
/// rmcp streamable-HTTP service producing `MimoboxServer` instances, with local sessions.
type McpHttpService = StreamableHttpService<MimoboxServer, LocalSessionManager>;
/// Shared list of server handles whose sandboxes are cleaned up on shutdown.
type ServerRegistry = Arc<Mutex<Vec<MimoboxServer>>>;
/// Immutable CORS allow-list shared with the middleware as axum state.
type AllowedOrigins = Arc<Vec<String>>;
pub async fn run_http_server(
bind_addr: &str,
port: u16,
allowed_origins: Option<String>,
) -> HttpResult<()> {
tracing::warn!("HTTP 模式未启用认证,请勿在公网环境直接暴露。仅限本地开发和受信网络使用。");
if !is_local_bind_addr(bind_addr) {
tracing::warn!(
bind_addr,
"MCP HTTP 绑定地址不是本地回环地址,可能暴露到不受信网络"
);
}
let allowed_origins = Arc::new(parse_allowed_origins(allowed_origins));
let server_registry = Arc::new(Mutex::new(Vec::new()));
let service = create_mcp_service(server_registry.clone(), bind_addr);
let app =
Router::new()
.route_service("/mcp", service)
.layer(axum::middleware::from_fn_with_state(
allowed_origins,
cors_middleware,
));
let listener = tokio::net::TcpListener::bind((bind_addr, port)).await?;
let local_addr = listener.local_addr()?;
let mut sigterm = signal(SignalKind::terminate())?;
let mut sigint = signal(SignalKind::interrupt())?;
tracing::info!("MCP HTTP server listening on {local_addr}");
tracing::info!("MCP endpoint: http://{local_addr}/mcp");
axum::serve(listener, app)
.with_graceful_shutdown(async move {
tokio::select! {
_ = sigterm.recv() => {
tracing::info!("Received SIGTERM, cleaning up sandboxes...");
}
_ = sigint.recv() => {
tracing::info!("Received SIGINT, cleaning up sandboxes...");
}
}
cleanup_registered_servers(server_registry).await;
})
.await?;
Ok(())
}
/// Builds the stateful rmcp streamable-HTTP service mounted at `/mcp`.
///
/// The factory closure handed to rmcp creates a fresh `MimoboxServer` each time it is
/// invoked (presumably once per session — confirm against rmcp). It records a clone in
/// `server_registry` so shutdown can clean it up. A registration failure is returned to
/// rmcp as the factory's `io::Error`. Accepted `Host` values come from `allowed_hosts`.
fn create_mcp_service(server_registry: ServerRegistry, bind_addr: &str) -> McpHttpService {
    let config = StreamableHttpServerConfig::default()
        .with_stateful_mode(true)
        .with_allowed_hosts(allowed_hosts(bind_addr));
    let sessions = Arc::new(LocalSessionManager::default());
    let make_server = move || {
        let server = MimoboxServer::new();
        register_server(&server_registry, server.clone()).map(|()| server)
    };
    StreamableHttpService::new(make_server, sessions, config)
}
/// Records `server` in the shutdown-cleanup registry.
///
/// The registry holds at most `MAX_CONCURRENT_SESSIONS` handles. When it is full, the
/// oldest handle (index 0) is evicted before the new one is appended.
///
/// # Errors
///
/// Returns an `io::Error` if the registry mutex is poisoned.
fn register_server(server_registry: &ServerRegistry, server: MimoboxServer) -> io::Result<()> {
    let Ok(mut registry) = server_registry.lock() else {
        return Err(io::Error::other("MCP HTTP server registry lock poisoned"));
    };
    if registry.len() >= MAX_CONCURRENT_SESSIONS {
        tracing::warn!(
            max_sessions = MAX_CONCURRENT_SESSIONS,
            "MCP HTTP session registry 已达到上限,移除最早的 server handle"
        );
        // NOTE(review): the evicted handle is only dropped; `cleanup_all` is never called
        // on it, so its sandboxes are no longer cleaned up at shutdown. Nothing in this
        // file removes entries when a session ends, so eviction may hit live sessions.
        let evicted = registry.remove(0);
        drop(evicted);
    }
    registry.push(server);
    Ok(())
}
/// Reports whether `bind_addr` only accepts connections from the local machine.
///
/// Local means either of these:
/// - the hostname `localhost`, compared ASCII case-insensitively like DNS names; it
///   conventionally resolves to loopback;
/// - an IP literal in the loopback range (`127.0.0.0/8` or `::1`, in any textual form).
///
/// Everything else is treated as non-local so the caller errs on the side of warning. That
/// includes unspecified addresses such as `0.0.0.0` / `::` and any other hostname.
fn is_local_bind_addr(bind_addr: &str) -> bool {
    bind_addr.eq_ignore_ascii_case("localhost")
        || matches!(bind_addr.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback())
}
/// Host values accepted by the MCP service: the loopback names plus `bind_addr`.
///
/// The result is always `localhost`, `127.0.0.1`, `::1` in that order. `bind_addr` is
/// appended only when it is not already one of those three.
fn allowed_hosts(bind_addr: &str) -> Vec<String> {
    const LOOPBACK_HOSTS: [&str; 3] = ["localhost", "127.0.0.1", "::1"];
    let extra_host = (!LOOPBACK_HOSTS.contains(&bind_addr)).then_some(bind_addr);
    LOOPBACK_HOSTS
        .iter()
        .map(|host| (*host).to_owned())
        .chain(extra_host.map(str::to_owned))
        .collect()
}
/// Turns the optional comma-separated CORS origin list into the middleware's allow-list.
///
/// - `None` gives the defaults `http://localhost` and `http://127.0.0.1`.
/// - Otherwise entries are split on `,`, trimmed, and empty entries are dropped.
/// - If any entry is exactly `*`, the result is just `["*"]` (allow every origin) and a
///   warning is logged.
/// - Entries that merely *contain* `*` (e.g. `https://*.example.com`) are not patterns.
///   They are kept as literal strings and logged as such, so they fail closed instead of
///   opening CORS to everything.
///
/// A warning is also logged for each remaining non-local origin.
///
/// NOTE(review): origins are compared exactly against the request's `Origin` header. The
/// defaults therefore cover only origins without an explicit port; confirm this is intended.
fn parse_allowed_origins(allowed_origins: Option<String>) -> Vec<String> {
    let Some(raw) = allowed_origins else {
        return vec![
            "http://localhost".to_string(),
            "http://127.0.0.1".to_string(),
        ];
    };
    let parsed = raw
        .split(',')
        .map(str::trim)
        .filter(|origin| !origin.is_empty())
        .map(ToOwned::to_owned)
        .collect::<Vec<_>>();
    // Only a standalone `*` entry means "allow all"; a substring match would let any
    // entry containing `*` silently disable origin checking.
    if parsed.iter().any(|origin| origin == WILDCARD_ORIGIN) {
        tracing::warn!("CORS 配置为完全开放模式(*),请勿在生产环境使用");
        return vec![WILDCARD_ORIGIN.to_string()];
    }
    for origin in parsed
        .iter()
        .filter(|origin| origin.contains(WILDCARD_ORIGIN))
    {
        tracing::warn!(
            origin,
            "CORS origin 中的通配符仅支持单独的 *,该条目将按字面值匹配"
        );
    }
    warn_non_local_origins(&parsed);
    parsed
}
/// Logs one warning for every configured CORS origin that is not a local origin.
fn warn_non_local_origins(origins: &[String]) {
    origins
        .iter()
        .filter(|origin| !is_local_origin(origin))
        .for_each(|origin| {
            tracing::warn!(origin, "CORS 允许非本地 origin,请确认仅用于受信客户端");
        });
}
/// Reports whether a configured CORS origin points at the local machine.
///
/// Only `http://` and `https://` origins are considered. The host is the part of the
/// authority before the first `/`, with any `:port` removed. Bracketed IPv6 literals such
/// as `[::1]:8080` are unwrapped; an unterminated bracket is left as-is and never matches.
/// The host is local when it is exactly `localhost` or an IP literal in the loopback range
/// (`127.0.0.0/8` or `::1`).
///
/// In this file the result only decides whether a configuration warning is logged. It
/// does not affect which origins the middleware accepts.
fn is_local_origin(origin: &str) -> bool {
    let Some(authority) = origin
        .strip_prefix("http://")
        .or_else(|| origin.strip_prefix("https://"))
    else {
        return false;
    };
    // `split` always yields at least one piece, so the fallback is never taken.
    let host_with_port = authority.split('/').next().unwrap_or(authority);
    let host = match host_with_port.strip_prefix('[') {
        Some(bracketed) => bracketed
            .split_once(']')
            .map_or(host_with_port, |(host, _)| host),
        None => host_with_port
            .split_once(':')
            .map_or(host_with_port, |(host, _)| host),
    };
    host == "localhost"
        || matches!(host.parse::<std::net::IpAddr>(), Ok(ip) if ip.is_loopback())
}
async fn cleanup_registered_servers(server_registry: ServerRegistry) {
let servers = match server_registry.lock() {
Ok(mut servers) => std::mem::take(&mut *servers),
Err(_) => {
tracing::error!("MCP HTTP server registry lock poisoned, skip sandbox cleanup");
return;
}
};
for server in servers {
server.cleanup_all().await;
}
}
async fn cors_middleware(
State(allowed_origins): State<AllowedOrigins>,
req: Request,
next: Next,
) -> Response {
let allowed_origin = allowed_origin_header(req.headers(), allowed_origins.as_slice()).cloned();
if req.method() == Method::OPTIONS {
return cors_response(allowed_origin.as_ref());
}
let mut response = next.run(req).await;
apply_cors_headers(response.headers_mut(), allowed_origin.as_ref());
response
}
/// Builds the empty `200 OK` response returned for `OPTIONS` (preflight) requests,
/// carrying the CORS headers.
fn cors_response(allowed_origin: Option<&HeaderValue>) -> Response {
    let mut preflight = Response::new(Body::empty());
    apply_cors_headers(preflight.headers_mut(), allowed_origin);
    *preflight.status_mut() = StatusCode::OK;
    preflight
}
/// Returns the request's `Origin` header value if it passes the allow-list, otherwise
/// `None`. A missing `Origin` header also yields `None`.
fn allowed_origin_header<'a>(
    headers: &'a HeaderMap,
    allowed_origins: &[String],
) -> Option<&'a HeaderValue> {
    headers
        .get(HeaderName::from_static("origin"))
        .filter(|origin| is_origin_allowed(origin, allowed_origins))
}
/// Checks an `Origin` header value against the allow-list.
///
/// A `*` entry in the allow-list accepts any origin, including values that are not valid
/// visible ASCII. Otherwise the origin must be representable as a string and equal one of
/// the entries exactly.
fn is_origin_allowed(origin: &HeaderValue, allowed_origins: &[String]) -> bool {
    let allows_any = allowed_origins
        .iter()
        .any(|candidate| candidate.as_str() == WILDCARD_ORIGIN);
    // Short-circuits, so the header is only decoded when no wildcard is configured.
    allows_any
        || matches!(
            origin.to_str(),
            Ok(requested)
                if allowed_origins
                    .iter()
                    .any(|candidate| candidate.as_str() == requested)
        )
}
/// Writes the CORS headers shared by preflight and regular responses.
///
/// `Access-Control-Allow-Origin` is set only when `allowed_origin` is `Some`; that is the
/// request's own `Origin` value echoed back. The allowed methods, the allowed request
/// headers and the MCP headers exposed to browser scripts are always set, replacing any
/// existing values. `Vary: Origin` is always appended, because the allow-origin header
/// depends on the request's `Origin`.
fn apply_cors_headers(headers: &mut HeaderMap, allowed_origin: Option<&HeaderValue>) {
    if let Some(origin) = allowed_origin {
        headers.insert(
            HeaderName::from_static("access-control-allow-origin"),
            origin.clone(),
        );
    }
    // Because allow-origin is reflected per request, caches must key on `Origin` or they
    // could serve one origin's response to another. `append` keeps any `Vary` values the
    // inner service already set.
    headers.append(
        HeaderName::from_static("vary"),
        HeaderValue::from_static("Origin"),
    );
    headers.insert(
        HeaderName::from_static("access-control-allow-methods"),
        HeaderValue::from_static("GET, POST, DELETE, OPTIONS"),
    );
    headers.insert(
        HeaderName::from_static("access-control-allow-headers"),
        HeaderValue::from_static(
            "Content-Type, Accept, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID",
        ),
    );
    headers.insert(
        HeaderName::from_static("access-control-expose-headers"),
        HeaderValue::from_static("Mcp-Session-Id, Mcp-Protocol-Version"),
    );
}