Skip to main content

cortexai_mcp/
http_server.rs

1//! HTTP Server Transport for MCP
2//!
3//! Implements a plain HTTP transport compatible with Meridian's `toolFromMCP()`.
4//!
5//! # Endpoints
6//!
7//! - `POST /mcp`       — JSON-RPC requests (initialize, tools/list, tools/call)
8//! - `POST /call-tool`  — Simplified `{ "name": "...", "arguments": {...} }` execution
9//! - `GET  /tools`      — Plain JSON tool list
10//! - `GET  /health`     — Health check
11
12use axum::{
13    extract::State,
14    http::{header, Method},
15    routing::{get, post},
16    Json, Router,
17};
18use http::StatusCode;
19use serde::Deserialize;
20use std::net::SocketAddr;
21use tracing::info;
22use serde_json::json;
23use std::sync::Arc;
24use tower_http::cors::{Any, CorsLayer};
25
26use crate::protocol::JsonRpcRequest;
27use crate::server::McpServer;
28
/// Settings for the plain-HTTP MCP transport.
///
/// Build one with [`HttpServerConfig::default`] (every IPv4 interface, port 3001,
/// CORS enabled), [`HttpServerConfig::localhost`], or [`HttpServerConfig::public`].
#[derive(Debug, Clone)]
pub struct HttpServerConfig {
    /// Address the listener binds to, e.g. `"127.0.0.1"` or `"0.0.0.0"`.
    pub host: String,
    /// TCP port the listener binds to.
    pub port: u16,
    /// When `true`, the router is wrapped in a CORS layer.
    pub enable_cors: bool,
}

impl Default for HttpServerConfig {
    /// All IPv4 interfaces (`0.0.0.0`), port 3001, CORS enabled.
    fn default() -> Self {
        HttpServerConfig {
            host: String::from("0.0.0.0"),
            port: 3001,
            enable_cors: true,
        }
    }
}

impl HttpServerConfig {
    /// Loopback-only configuration (`127.0.0.1`) on `port`.
    ///
    /// Every other setting keeps its [`Default`] value.
    pub fn localhost(port: u16) -> Self {
        Self {
            host: String::from("127.0.0.1"),
            port,
            ..Self::default()
        }
    }

    /// Configuration listening on every IPv4 interface (`0.0.0.0`) on `port`.
    ///
    /// Every other setting keeps its [`Default`] value.
    pub fn public(port: u16) -> Self {
        Self {
            host: String::from("0.0.0.0"),
            port,
            ..Self::default()
        }
    }
}
68
69// =============================================================================
70// HTTP Router for McpServer
71// =============================================================================
72
73impl McpServer {
74    /// Build an Axum router with HTTP endpoints for Meridian integration.
75    ///
76    /// Endpoints:
77    /// - `GET  /health`     — Health check
78    /// - `GET  /tools`      — Plain JSON tool list
79    /// - `POST /mcp`        — JSON-RPC endpoint
80    /// - `POST /call-tool`  — Simplified tool call
81    pub fn http_router(self: Arc<Self>, config: HttpServerConfig) -> Router {
82        let mut router = Router::new()
83            .route("/health", get(handle_health))
84            .route("/tools", get(handle_tools))
85            .route("/mcp", post(handle_mcp_jsonrpc))
86            .route("/call-tool", post(handle_call_tool))
87            .with_state(self);
88
89        if config.enable_cors {
90            let cors = CorsLayer::new()
91                .allow_origin(Any)
92                .allow_methods([Method::GET, Method::POST])
93                .allow_headers([header::CONTENT_TYPE, header::ACCEPT]);
94            router = router.layer(cors);
95        }
96
97        router
98    }
99
100    /// Run the server with HTTP transport (binds and serves).
101    ///
102    /// This is a convenience wrapper around [`http_router`] that binds to the
103    /// configured address and serves until the process is interrupted.
104    pub async fn run_http(self: Arc<Self>, config: HttpServerConfig) -> Result<(), crate::error::McpError> {
105        let addr: SocketAddr = format!("{}:{}", config.host, config.port)
106            .parse()
107            .map_err(|e| crate::error::McpError::Transport(format!("Invalid address: {}", e)))?;
108
109        info!("Starting MCP HTTP server on http://{}", addr);
110
111        let router = self.http_router(config);
112
113        let listener = tokio::net::TcpListener::bind(addr)
114            .await
115            .map_err(|e| crate::error::McpError::Transport(format!("Failed to bind: {}", e)))?;
116
117        axum::serve(listener, router)
118            .await
119            .map_err(|e| crate::error::McpError::Transport(format!("Server error: {}", e)))?;
120
121        Ok(())
122    }
123}
124
125// =============================================================================
126// HTTP Handlers
127// =============================================================================
128
129/// GET /tools — plain JSON array of tool definitions
130async fn handle_tools(
131    State(server): State<Arc<McpServer>>,
132) -> Json<serde_json::Value> {
133    let request = crate::protocol::JsonRpcRequest::new(1i64, "tools/list");
134    let response = server.handle_request(request).await;
135
136    match response.result {
137        Some(result) => Json(result["tools"].clone()),
138        None => Json(json!([])),
139    }
140}
141
142/// POST /mcp — JSON-RPC endpoint (initialize, tools/list, tools/call, etc.)
143async fn handle_mcp_jsonrpc(
144    State(server): State<Arc<McpServer>>,
145    Json(request): Json<JsonRpcRequest>,
146) -> Json<serde_json::Value> {
147    let response = server.handle_request(request).await;
148    let response_json = serde_json::to_value(&response).unwrap_or_default();
149    Json(response_json)
150}
151
152/// Request body for POST /call-tool
153#[derive(Debug, Deserialize)]
154struct CallToolBody {
155    name: String,
156    #[serde(default)]
157    arguments: serde_json::Value,
158}
159
160/// POST /call-tool — simplified tool execution
161async fn handle_call_tool(
162    State(server): State<Arc<McpServer>>,
163    Json(body): Json<CallToolBody>,
164) -> (StatusCode, Json<serde_json::Value>) {
165    let rpc_request = JsonRpcRequest::new(1i64, "tools/call").with_params(json!({
166        "name": body.name,
167        "arguments": body.arguments
168    }));
169
170    let rpc_response = server.handle_request(rpc_request).await;
171
172    match rpc_response.result {
173        Some(result) => (StatusCode::OK, Json(result)),
174        None => {
175            let error_msg = rpc_response
176                .error
177                .map(|e| e.message)
178                .unwrap_or_else(|| "Unknown error".to_string());
179            (
180                StatusCode::BAD_REQUEST,
181                Json(json!({"error": error_msg})),
182            )
183        }
184    }
185}
186
187/// GET /health — health check
188async fn handle_health() -> Json<serde_json::Value> {
189    Json(json!({
190        "status": "ok",
191        "server": "cortexai",
192        "version": "0.1.0"
193    }))
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use axum::body::{to_bytes, Body};
    use axum::http::{Request, StatusCode};
    use serde_json::{json, Value};
    use tower::util::ServiceExt;

    use crate::server::FnTool;

    /// Server exposing a single `echo` tool that answers `{"echoed": <message>}`.
    fn create_test_server() -> Arc<crate::server::McpServer> {
        let echo_tool = FnTool::new(
            "echo",
            "Echoes input",
            json!({
                "type": "object",
                "properties": {
                    "message": {"type": "string"}
                }
            }),
            |args| {
                let msg = args["message"].as_str().unwrap_or("no message");
                Ok(json!({"echoed": msg}))
            },
        );

        crate::server::McpServer::builder()
            .name("test-http-server")
            .version("1.0.0")
            .add_tool(echo_tool)
            .build()
    }

    /// Routes `req` through a fresh default-config router; returns status and JSON body.
    async fn send(req: Request<Body>) -> (StatusCode, Value) {
        let app = create_test_server().http_router(HttpServerConfig::default());
        let response = app.oneshot(req).await.unwrap();
        let status = response.status();
        let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        (status, serde_json::from_slice(&bytes).unwrap())
    }

    /// Empty-bodied GET request for `uri`.
    fn get_req(uri: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// POST request for `uri` carrying `payload` as `application/json`.
    fn post_req(uri: &str, payload: &Value) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_vec(payload).unwrap()))
            .unwrap()
    }

    #[test]
    fn test_http_server_config_defaults() {
        let config = HttpServerConfig::default();
        assert_eq!(config.host, "0.0.0.0");
        assert_eq!(config.port, 3001);
        assert!(config.enable_cors);
    }

    #[tokio::test]
    async fn test_get_health() {
        let (status, value) = send(get_req("/health")).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(value["status"], "ok");
        assert_eq!(value["server"], "cortexai");
        assert_eq!(value["version"], "0.1.0");
    }

    #[tokio::test]
    async fn test_get_tools() {
        let (status, value) = send(get_req("/tools")).await;
        assert_eq!(status, StatusCode::OK);
        let tools = value
            .as_array()
            .expect("GET /tools must return a JSON array");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "echo");
    }

    #[tokio::test]
    async fn test_post_mcp_tools_list() {
        let rpc = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list"
        });
        let (status, value) = send(post_req("/mcp", &rpc)).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(value["jsonrpc"], "2.0");
        assert!(value["result"]["tools"].is_array());
    }

    #[tokio::test]
    async fn test_post_mcp_tools_call() {
        let rpc = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {
                "name": "echo",
                "arguments": { "message": "hello" }
            }
        });
        let (status, value) = send(post_req("/mcp", &rpc)).await;
        assert_eq!(status, StatusCode::OK);
        assert!(value["error"].is_null());
        let text = value["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("hello"));
    }

    #[tokio::test]
    async fn test_post_call_tool_simplified() {
        let payload = json!({
            "name": "echo",
            "arguments": { "message": "hi there" }
        });
        let (status, value) = send(post_req("/call-tool", &payload)).await;
        assert_eq!(status, StatusCode::OK);
        assert!(!value["isError"].as_bool().unwrap_or(true));
        assert!(value["content"][0]["text"]
            .as_str()
            .unwrap()
            .contains("hi there"));
    }
}