cortexai-mcp 0.1.0

Model Context Protocol (MCP) support for Cortex: stdio, SSE, and server transports
Documentation
//! HTTP Server Transport for MCP
//!
//! Implements a plain HTTP transport compatible with Meridian's `toolFromMCP()`.
//!
//! # Endpoints
//!
//! - `POST /mcp`       — JSON-RPC requests (initialize, tools/list, tools/call)
//! - `POST /call-tool`  — Simplified `{ "name": "...", "arguments": {...} }` execution
//! - `GET  /tools`      — Plain JSON tool list
//! - `GET  /health`     — Health check

use axum::{
    extract::State,
    http::{header, Method},
    routing::{get, post},
    Json, Router,
};
use http::StatusCode;
use serde::Deserialize;
use std::net::SocketAddr;
use tracing::info;
use serde_json::json;
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};

use crate::protocol::JsonRpcRequest;
use crate::server::McpServer;

/// Configuration for the HTTP server transport.
///
/// Read by `McpServer::http_router` (which uses `enable_cors`) and by
/// `McpServer::run_http` (which binds to `host` and `port`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpServerConfig {
    /// Host/interface address to bind to, e.g. `"0.0.0.0"` or `"127.0.0.1"`.
    pub host: String,
    /// TCP port to bind to.
    pub port: u16,
    /// When `true`, the router adds CORS headers (any origin, `GET`/`POST`).
    pub enable_cors: bool,
}

impl Default for HttpServerConfig {
    /// All interfaces (`0.0.0.0`), port `3001`, CORS enabled.
    fn default() -> Self {
        Self {
            host: "0.0.0.0".to_string(),
            port: 3001,
            enable_cors: true,
        }
    }
}

impl HttpServerConfig {
    /// Create config for localhost (`127.0.0.1`) on the specified port.
    ///
    /// All other fields take their [`Default`] values.
    pub fn localhost(port: u16) -> Self {
        Self {
            host: "127.0.0.1".to_string(),
            port,
            ..Default::default()
        }
    }

    /// Create config that binds to all interfaces (`0.0.0.0`) on the specified port.
    ///
    /// The host is set explicitly rather than inherited from [`Default`], so this
    /// constructor keeps its meaning even if the default host is ever changed
    /// (for example, hardened to loopback).
    pub fn public(port: u16) -> Self {
        Self {
            host: "0.0.0.0".to_string(),
            port,
            ..Default::default()
        }
    }
}

// =============================================================================
// HTTP Router for McpServer
// =============================================================================

impl McpServer {
    /// Build an Axum router with HTTP endpoints for Meridian integration.
    ///
    /// Endpoints:
    /// - `GET  /health`     — Health check
    /// - `GET  /tools`      — Plain JSON tool list
    /// - `POST /mcp`        — JSON-RPC endpoint
    /// - `POST /call-tool`  — Simplified tool call
    pub fn http_router(self: Arc<Self>, config: HttpServerConfig) -> Router {
        let mut router = Router::new()
            .route("/health", get(handle_health))
            .route("/tools", get(handle_tools))
            .route("/mcp", post(handle_mcp_jsonrpc))
            .route("/call-tool", post(handle_call_tool))
            .with_state(self);

        if config.enable_cors {
            let cors = CorsLayer::new()
                .allow_origin(Any)
                .allow_methods([Method::GET, Method::POST])
                .allow_headers([header::CONTENT_TYPE, header::ACCEPT]);
            router = router.layer(cors);
        }

        router
    }

    /// Run the server with HTTP transport (binds and serves).
    ///
    /// This is a convenience wrapper around [`http_router`] that binds to the
    /// configured address and serves until the process is interrupted.
    pub async fn run_http(self: Arc<Self>, config: HttpServerConfig) -> Result<(), crate::error::McpError> {
        let addr: SocketAddr = format!("{}:{}", config.host, config.port)
            .parse()
            .map_err(|e| crate::error::McpError::Transport(format!("Invalid address: {}", e)))?;

        info!("Starting MCP HTTP server on http://{}", addr);

        let router = self.http_router(config);

        let listener = tokio::net::TcpListener::bind(addr)
            .await
            .map_err(|e| crate::error::McpError::Transport(format!("Failed to bind: {}", e)))?;

        axum::serve(listener, router)
            .await
            .map_err(|e| crate::error::McpError::Transport(format!("Server error: {}", e)))?;

        Ok(())
    }
}

// =============================================================================
// HTTP Handlers
// =============================================================================

/// GET /tools — plain JSON array of tool definitions
async fn handle_tools(
    State(server): State<Arc<McpServer>>,
) -> Json<serde_json::Value> {
    let request = crate::protocol::JsonRpcRequest::new(1i64, "tools/list");
    let response = server.handle_request(request).await;

    match response.result {
        Some(result) => Json(result["tools"].clone()),
        None => Json(json!([])),
    }
}

/// POST /mcp — JSON-RPC endpoint (initialize, tools/list, tools/call, etc.)
async fn handle_mcp_jsonrpc(
    State(server): State<Arc<McpServer>>,
    Json(request): Json<JsonRpcRequest>,
) -> Json<serde_json::Value> {
    let response = server.handle_request(request).await;
    let response_json = serde_json::to_value(&response).unwrap_or_default();
    Json(response_json)
}

/// Request body for POST /call-tool
///
/// Shape: `{ "name": "echo", "arguments": { "message": "hi" } }`.
#[derive(Debug, Deserialize)]
struct CallToolBody {
    /// Name of the tool to invoke; sent as `params.name` of an internal `tools/call` request.
    name: String,
    /// Tool arguments; sent as `params.arguments`. If the field is omitted,
    /// `#[serde(default)]` yields `serde_json::Value::default()`, which is `Value::Null`.
    #[serde(default)]
    arguments: serde_json::Value,
}

/// POST /call-tool — simplified tool execution
async fn handle_call_tool(
    State(server): State<Arc<McpServer>>,
    Json(body): Json<CallToolBody>,
) -> (StatusCode, Json<serde_json::Value>) {
    let rpc_request = JsonRpcRequest::new(1i64, "tools/call").with_params(json!({
        "name": body.name,
        "arguments": body.arguments
    }));

    let rpc_response = server.handle_request(rpc_request).await;

    match rpc_response.result {
        Some(result) => (StatusCode::OK, Json(result)),
        None => {
            let error_msg = rpc_response
                .error
                .map(|e| e.message)
                .unwrap_or_else(|| "Unknown error".to_string());
            (
                StatusCode::BAD_REQUEST,
                Json(json!({"error": error_msg})),
            )
        }
    }
}

/// GET /health — health check.
///
/// Returns `{"status": "ok", "server": "cortexai", "version": <crate version>}`.
///
/// The version is taken from Cargo's `CARGO_PKG_VERSION` at compile time, so it
/// tracks the crate manifest instead of a hand-maintained literal. It falls back to
/// `"unknown"` if the crate is compiled outside Cargo.
async fn handle_health() -> Json<serde_json::Value> {
    let version = option_env!("CARGO_PKG_VERSION").unwrap_or("unknown");
    Json(json!({
        "status": "ok",
        "server": "cortexai",
        "version": version
    }))
}

#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;
    use crate::server::FnTool;
    use axum::body::{to_bytes, Body};
    use axum::http::{Request, Response};
    use serde_json::{json, Value};
    use tower::util::ServiceExt;

    /// Server exposing a single `echo` tool that returns `{"echoed": <message>}`.
    fn create_test_server() -> Arc<crate::server::McpServer> {
        let echo_schema = json!({
            "type": "object",
            "properties": {
                "message": {"type": "string"}
            }
        });

        crate::server::McpServer::builder()
            .name("test-http-server")
            .version("1.0.0")
            .add_tool(FnTool::new(
                "echo",
                "Echoes input",
                echo_schema,
                |args| {
                    let msg = args["message"].as_str().unwrap_or("no message");
                    Ok(json!({"echoed": msg}))
                },
            ))
            .build()
    }

    /// Routes `req` through a freshly built router using the default config.
    async fn dispatch(req: Request<Body>) -> Response<Body> {
        let router = create_test_server().http_router(HttpServerConfig::default());
        router.oneshot(req).await.unwrap()
    }

    /// Body-less GET request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// POST request for `uri` carrying `payload` as `application/json`.
    fn post_json(uri: &str, payload: &Value) -> Request<Body> {
        let encoded = serde_json::to_vec(payload).unwrap();
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(encoded))
            .unwrap()
    }

    /// Reads the whole response body and decodes it as JSON.
    async fn json_body(response: Response<Body>) -> Value {
        let raw = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        serde_json::from_slice(&raw).unwrap()
    }

    #[test]
    fn test_http_server_config_defaults() {
        let cfg = HttpServerConfig::default();
        assert_eq!(cfg.host, "0.0.0.0");
        assert_eq!(cfg.port, 3001);
        assert!(cfg.enable_cors);
    }

    #[tokio::test]
    async fn test_get_health() {
        let response = dispatch(get_request("/health")).await;
        assert_eq!(response.status(), http::StatusCode::OK);

        let health = json_body(response).await;
        assert_eq!(health["status"], "ok");
        assert_eq!(health["server"], "cortexai");
        assert_eq!(health["version"], "0.1.0");
    }

    #[tokio::test]
    async fn test_get_tools() {
        let response = dispatch(get_request("/tools")).await;
        assert_eq!(response.status(), http::StatusCode::OK);

        let listing = json_body(response).await;
        assert!(listing.is_array());
        let entries = listing.as_array().unwrap();
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0]["name"], "echo");
    }

    #[tokio::test]
    async fn test_post_mcp_tools_list() {
        let payload = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list"
        });

        let response = dispatch(post_json("/mcp", &payload)).await;
        assert_eq!(response.status(), http::StatusCode::OK);

        let reply = json_body(response).await;
        assert_eq!(reply["jsonrpc"], "2.0");
        assert!(reply["result"]["tools"].is_array());
    }

    #[tokio::test]
    async fn test_post_mcp_tools_call() {
        let payload = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {
                "name": "echo",
                "arguments": { "message": "hello" }
            }
        });

        let response = dispatch(post_json("/mcp", &payload)).await;
        assert_eq!(response.status(), http::StatusCode::OK);

        let reply = json_body(response).await;
        assert!(reply["error"].is_null());
        let first_text = reply["result"]["content"][0]["text"].as_str().unwrap();
        assert!(first_text.contains("hello"));
    }

    #[tokio::test]
    async fn test_post_call_tool_simplified() {
        let payload = json!({
            "name": "echo",
            "arguments": { "message": "hi there" }
        });

        let response = dispatch(post_json("/call-tool", &payload)).await;
        assert_eq!(response.status(), http::StatusCode::OK);

        let outcome = json_body(response).await;
        assert!(!outcome["isError"].as_bool().unwrap_or(true));
        let first_text = outcome["content"][0]["text"].as_str().unwrap();
        assert!(first_text.contains("hi there"));
    }
}