mimobox-mcp 0.1.0-alpha

MimoBox MCP Server for AI agent sandbox integration
Documentation
//! HTTP 传输模块,提供 MCP Streamable HTTP 端点。

use std::{
    error::Error,
    io,
    sync::{Arc, Mutex},
};

use axum::{
    Router,
    body::Body,
    extract::{Request, State},
    http::{HeaderMap, HeaderName, HeaderValue, Method, StatusCode},
    middleware::Next,
    response::Response,
};
use mimobox_mcp::MimoboxServer;
use rmcp::transport::{
    StreamableHttpServerConfig, StreamableHttpService,
    streamable_http_server::session::local::LocalSessionManager,
};
use tokio::signal::unix::{SignalKind, signal};

/// Result type used by the HTTP transport; errors are boxed so `?` can lift any
/// `Send + Sync` error (bind, local-address lookup, signal setup, serve) into it.
type HttpResult<T> = Result<T, Box<dyn Error + Send + Sync>>;
/// rmcp Streamable HTTP service producing `MimoboxServer` instances, with sessions tracked
/// by rmcp's `LocalSessionManager` (presumably in-process state — confirm against rmcp docs).
type McpHttpService = StreamableHttpService<MimoboxServer, LocalSessionManager>;
/// Shared list of server handles created by the session factory; it is drained on shutdown
/// so each handle's sandboxes can be cleaned up.
type ServerRegistry = Arc<Mutex<Vec<MimoboxServer>>>;
/// Parsed CORS origin allow-list shared with the middleware; `["*"]` means "allow any origin".
type AllowedOrigins = Arc<Vec<String>>;

/// Maximum number of server handles kept in the registry. When it is reached, the oldest
/// handle is evicted from tracking (new sessions are NOT rejected).
const MAX_CONCURRENT_SESSIONS: usize = 100;
/// Allow-list entry that enables every origin.
const WILDCARD_ORIGIN: &str = "*";

/// 启动 MCP HTTP 服务器。
pub async fn run_http_server(
    bind_addr: &str,
    port: u16,
    allowed_origins: Option<String>,
) -> HttpResult<()> {
    tracing::warn!("HTTP 模式未启用认证,请勿在公网环境直接暴露。仅限本地开发和受信网络使用。");
    if !is_local_bind_addr(bind_addr) {
        tracing::warn!(
            bind_addr,
            "MCP HTTP 绑定地址不是本地回环地址,可能暴露到不受信网络"
        );
    }

    let allowed_origins = Arc::new(parse_allowed_origins(allowed_origins));
    let server_registry = Arc::new(Mutex::new(Vec::new()));
    let service = create_mcp_service(server_registry.clone(), bind_addr);
    let app =
        Router::new()
            .route_service("/mcp", service)
            .layer(axum::middleware::from_fn_with_state(
                allowed_origins,
                cors_middleware,
            ));

    let listener = tokio::net::TcpListener::bind((bind_addr, port)).await?;
    let local_addr = listener.local_addr()?;
    let mut sigterm = signal(SignalKind::terminate())?;
    let mut sigint = signal(SignalKind::interrupt())?;

    tracing::info!("MCP HTTP server listening on {local_addr}");
    tracing::info!("MCP endpoint: http://{local_addr}/mcp");

    axum::serve(listener, app)
        .with_graceful_shutdown(async move {
            tokio::select! {
                _ = sigterm.recv() => {
                    tracing::info!("Received SIGTERM, cleaning up sandboxes...");
                }
                _ = sigint.recv() => {
                    tracing::info!("Received SIGINT, cleaning up sandboxes...");
                }
            }
            cleanup_registered_servers(server_registry).await;
        })
        .await?;

    Ok(())
}

/// Builds the stateful rmcp Streamable HTTP service mounted at `/mcp`.
///
/// The factory handed to rmcp (presumably invoked once per new session; confirm against rmcp
/// docs) creates a fresh `MimoboxServer` and records a clone of it in `server_registry`, so
/// shutdown can clean up its sandboxes. If registration fails (poisoned registry lock), the
/// factory returns that `io::Error` to rmcp. The config's allowed hosts come from
/// `allowed_hosts(bind_addr)`.
fn create_mcp_service(server_registry: ServerRegistry, bind_addr: &str) -> McpHttpService {
    let hosts = allowed_hosts(bind_addr);
    let config = StreamableHttpServerConfig::default()
        .with_stateful_mode(true)
        .with_allowed_hosts(hosts);

    let make_session_server = move || {
        let fresh = MimoboxServer::new();
        register_server(&server_registry, fresh.clone()).map(|()| fresh)
    };

    StreamableHttpService::new(
        make_session_server,
        Arc::new(LocalSessionManager::default()),
        config,
    )
}

/// Records `server` in the shared registry so shutdown can call `cleanup_all` on it.
///
/// At most `MAX_CONCURRENT_SESSIONS` handles are kept. When the registry is full, the oldest
/// handle is removed (with a warning) before the new one is appended. The evicted handle is
/// only dropped from tracking, so `cleanup_registered_servers` will no longer see it.
///
/// # Errors
/// Returns an `io::Error` of kind `Other` if the registry mutex is poisoned.
fn register_server(server_registry: &ServerRegistry, server: MimoboxServer) -> io::Result<()> {
    let Ok(mut tracked) = server_registry.lock() else {
        return Err(io::Error::other("MCP HTTP server registry lock poisoned"));
    };

    let at_capacity = tracked.len() >= MAX_CONCURRENT_SESSIONS;
    if at_capacity {
        tracing::warn!(
            max_sessions = MAX_CONCURRENT_SESSIONS,
            "MCP HTTP session registry 已达到上限,移除最早的 server handle"
        );
        // Evict the oldest entry; `let _ =` drops it immediately.
        let _ = tracked.remove(0);
    }

    tracked.push(server);
    Ok(())
}

/// Returns `true` when `bind_addr` only listens on the loopback interface.
///
/// Accepts the `localhost` hostname (ASCII case-insensitive) and any IPv4 or IPv6 loopback
/// literal (`127.0.0.0/8`, `::1`). Anything else, including wildcard addresses such as
/// `0.0.0.0` / `::`, LAN addresses and other hostnames, is treated as potentially exposed.
fn is_local_bind_addr(bind_addr: &str) -> bool {
    bind_addr.eq_ignore_ascii_case("localhost")
        || bind_addr
            .parse::<std::net::IpAddr>()
            .is_ok_and(|ip| ip.is_loopback())
}

/// Builds the list of accepted `Host` values for the MCP service.
///
/// The list always contains `localhost`, `127.0.0.1` and `::1` in that order. `bind_addr` is
/// appended when it is not already exactly one of those.
fn allowed_hosts(bind_addr: &str) -> Vec<String> {
    const LOOPBACK_HOSTS: [&str; 3] = ["localhost", "127.0.0.1", "::1"];

    let mut accepted: Vec<String> = LOOPBACK_HOSTS
        .iter()
        .map(|name| (*name).to_owned())
        .collect();
    if !LOOPBACK_HOSTS.contains(&bind_addr) {
        accepted.push(bind_addr.to_owned());
    }
    accepted
}

/// Parses the optional comma-separated origin list into the CORS allow-list.
///
/// - `None` yields the local defaults `http://localhost` and `http://127.0.0.1`. Origins are
///   matched by exact string equality, so these defaults only match origins without an
///   explicit port.
/// - Each entry is trimmed of surrounding whitespace and trailing `/`, because an `Origin`
///   header never ends in a slash. Empty entries are dropped.
/// - Only an entry that is exactly `*` turns on fully open mode, returned as `["*"]`. Any
///   other entry containing `*` (e.g. `https://*.example.com`) is kept as a literal and
///   triggers a warning. Previously any `*` anywhere in the string silently opened CORS to
///   every origin.
/// - Non-local entries are reported via `warn_non_local_origins`.
fn parse_allowed_origins(allowed_origins: Option<String>) -> Vec<String> {
    let Some(origins) = allowed_origins else {
        return vec![
            "http://localhost".to_string(),
            "http://127.0.0.1".to_string(),
        ];
    };

    let parsed = origins
        .split(',')
        .map(|origin| origin.trim().trim_end_matches('/'))
        .filter(|origin| !origin.is_empty())
        .map(ToOwned::to_owned)
        .collect::<Vec<_>>();

    if parsed.iter().any(|origin| origin == WILDCARD_ORIGIN) {
        tracing::warn!("CORS 配置为完全开放模式(*),请勿在生产环境使用");
        return vec![WILDCARD_ORIGIN.to_string()];
    }

    // Wildcard patterns are not supported; such entries only ever match a literal identical
    // Origin header, which browsers never send.
    for origin in parsed
        .iter()
        .filter(|origin| origin.contains(WILDCARD_ORIGIN))
    {
        tracing::warn!(
            origin = origin.as_str(),
            "CORS origin 中的 * 仅支持单独使用,该条目将按字面量精确匹配"
        );
    }

    warn_non_local_origins(&parsed);
    parsed
}

/// Logs one warning for each configured CORS origin that `is_local_origin` does not
/// recognise as local.
fn warn_non_local_origins(origins: &[String]) {
    origins
        .iter()
        .filter(|candidate| !is_local_origin(candidate))
        .for_each(|origin| {
            tracing::warn!(origin, "CORS 允许非本地 origin,请确认仅用于受信客户端");
        });
}

/// Returns `true` if a CORS origin string points at the local machine.
///
/// The origin must start with `http://` or `https://`. The host is taken from the text up to
/// the first `/`. A bracketed IPv6 literal is unwrapped, and otherwise any `:port` suffix is
/// removed. The host counts as local when it is `localhost` (ASCII case-insensitive, like DNS
/// names) or a loopback IP literal (`127.0.0.0/8`, `::1`).
fn is_local_origin(origin: &str) -> bool {
    let Some(rest) = origin
        .strip_prefix("http://")
        .or_else(|| origin.strip_prefix("https://"))
    else {
        return false;
    };

    // `split` always yields at least one piece, so this is simply "text before the first '/'".
    let authority = rest.split('/').next().unwrap_or(rest);
    let host = match authority.strip_prefix('[') {
        // Bracketed IPv6 literal such as `[::1]:8080`. An unterminated bracket is kept as-is
        // and fails both checks below.
        Some(bracketed) => bracketed
            .split_once(']')
            .map_or(authority, |(ipv6, _)| ipv6),
        None => authority
            .split_once(':')
            .map_or(authority, |(name, _)| name),
    };

    host.eq_ignore_ascii_case("localhost")
        || host
            .parse::<std::net::IpAddr>()
            .is_ok_and(|ip| ip.is_loopback())
}

/// Drains the server registry and awaits `cleanup_all` on each drained server, one after
/// another.
///
/// If the registry mutex is poisoned, an error is logged and no cleanup is attempted.
async fn cleanup_registered_servers(server_registry: ServerRegistry) {
    // Move the handles out inside a scope so the std mutex guard is released before any `.await`.
    let drained = {
        let Ok(mut guard) = server_registry.lock() else {
            tracing::error!("MCP HTTP server registry lock poisoned, skip sandbox cleanup");
            return;
        };
        std::mem::take(&mut *guard)
    };

    for tracked in drained {
        tracked.cleanup_all().await;
    }
}

async fn cors_middleware(
    State(allowed_origins): State<AllowedOrigins>,
    req: Request,
    next: Next,
) -> Response {
    let allowed_origin = allowed_origin_header(req.headers(), allowed_origins.as_slice()).cloned();

    if req.method() == Method::OPTIONS {
        return cors_response(allowed_origin.as_ref());
    }

    let mut response = next.run(req).await;
    apply_cors_headers(response.headers_mut(), allowed_origin.as_ref());
    response
}

/// Builds the bodiless `200 OK` reply used for CORS preflight (`OPTIONS`) requests. It
/// carries the same CORS headers that `apply_cors_headers` puts on regular responses.
fn cors_response(allowed_origin: Option<&HeaderValue>) -> Response {
    let mut preflight = Response::new(Body::empty());
    apply_cors_headers(preflight.headers_mut(), allowed_origin);
    *preflight.status_mut() = StatusCode::OK;
    preflight
}

/// Returns the request's `Origin` header value when it is permitted by `allowed_origins`.
///
/// Yields `None` both when the header is missing and when it is not allowed.
fn allowed_origin_header<'a>(
    headers: &'a HeaderMap,
    allowed_origins: &[String],
) -> Option<&'a HeaderValue> {
    headers
        .get(HeaderName::from_static("origin"))
        .filter(|value| is_origin_allowed(value, allowed_origins))
}

/// Decides whether an `Origin` header value is on the allow-list.
///
/// A `*` entry allows everything, including values that are not valid visible ASCII.
/// Otherwise the header must convert to `&str` and equal one of the entries exactly
/// (case-sensitive).
fn is_origin_allowed(origin: &HeaderValue, allowed_origins: &[String]) -> bool {
    let wildcard = allowed_origins
        .iter()
        .any(|entry| entry == WILDCARD_ORIGIN);

    wildcard
        || origin
            .to_str()
            .is_ok_and(|requested| allowed_origins.iter().any(|entry| entry == requested))
}

/// Writes the CORS response headers shared by preflight and regular responses.
///
/// - `Access-Control-Allow-Origin` echoes `allowed_origin` when the request origin was
///   accepted, and is omitted otherwise.
/// - `Vary: Origin` is always appended. The presence and value of the allow-origin header
///   depend on the request's `Origin`, so without it a shared cache could replay one
///   origin's response to another. `append` keeps any `Vary` values already set by the
///   inner service.
/// - The allowed methods and headers cover the MCP Streamable HTTP verbs and headers
///   (`Mcp-Session-Id`, `Mcp-Protocol-Version`, `Last-Event-ID`). The session and protocol
///   headers are exposed so browser clients can read them.
fn apply_cors_headers(headers: &mut HeaderMap, allowed_origin: Option<&HeaderValue>) {
    if let Some(origin) = allowed_origin {
        headers.insert(
            HeaderName::from_static("access-control-allow-origin"),
            origin.clone(),
        );
    }
    headers.append(
        HeaderName::from_static("vary"),
        HeaderValue::from_static("Origin"),
    );
    headers.insert(
        HeaderName::from_static("access-control-allow-methods"),
        HeaderValue::from_static("GET, POST, DELETE, OPTIONS"),
    );
    headers.insert(
        HeaderName::from_static("access-control-allow-headers"),
        HeaderValue::from_static(
            "Content-Type, Accept, Mcp-Session-Id, Mcp-Protocol-Version, Last-Event-ID",
        ),
    );
    headers.insert(
        HeaderName::from_static("access-control-expose-headers"),
        HeaderValue::from_static("Mcp-Session-Id, Mcp-Protocol-Version"),
    );
}