1use axum::{
13 extract::State,
14 http::{header, Method},
15 routing::{get, post},
16 Json, Router,
17};
18use http::StatusCode;
19use serde::Deserialize;
20use std::net::SocketAddr;
21use tracing::info;
22use serde_json::json;
23use std::sync::Arc;
24use tower_http::cors::{Any, CorsLayer};
25
26use crate::protocol::JsonRpcRequest;
27use crate::server::McpServer;
28
/// Settings for the MCP HTTP transport.
#[derive(Clone, Debug)]
pub struct HttpServerConfig {
    /// Interface address the listener binds to (for example `"127.0.0.1"`).
    pub host: String,
    /// TCP port the listener binds to.
    pub port: u16,
    /// Whether the router is wrapped in a CORS layer.
    pub enable_cors: bool,
}
39
40impl Default for HttpServerConfig {
41 fn default() -> Self {
42 Self {
43 host: "0.0.0.0".to_string(),
44 port: 3001,
45 enable_cors: true,
46 }
47 }
48}
49
50impl HttpServerConfig {
51 pub fn localhost(port: u16) -> Self {
53 Self {
54 host: "127.0.0.1".to_string(),
55 port,
56 ..Default::default()
57 }
58 }
59
60 pub fn public(port: u16) -> Self {
62 Self {
63 port,
64 ..Default::default()
65 }
66 }
67}
68
69impl McpServer {
74 pub fn http_router(self: Arc<Self>, config: HttpServerConfig) -> Router {
82 let mut router = Router::new()
83 .route("/health", get(handle_health))
84 .route("/tools", get(handle_tools))
85 .route("/mcp", post(handle_mcp_jsonrpc))
86 .route("/call-tool", post(handle_call_tool))
87 .with_state(self);
88
89 if config.enable_cors {
90 let cors = CorsLayer::new()
91 .allow_origin(Any)
92 .allow_methods([Method::GET, Method::POST])
93 .allow_headers([header::CONTENT_TYPE, header::ACCEPT]);
94 router = router.layer(cors);
95 }
96
97 router
98 }
99
100 pub async fn run_http(self: Arc<Self>, config: HttpServerConfig) -> Result<(), crate::error::McpError> {
105 let addr: SocketAddr = format!("{}:{}", config.host, config.port)
106 .parse()
107 .map_err(|e| crate::error::McpError::Transport(format!("Invalid address: {}", e)))?;
108
109 info!("Starting MCP HTTP server on http://{}", addr);
110
111 let router = self.http_router(config);
112
113 let listener = tokio::net::TcpListener::bind(addr)
114 .await
115 .map_err(|e| crate::error::McpError::Transport(format!("Failed to bind: {}", e)))?;
116
117 axum::serve(listener, router)
118 .await
119 .map_err(|e| crate::error::McpError::Transport(format!("Server error: {}", e)))?;
120
121 Ok(())
122 }
123}
124
125async fn handle_tools(
131 State(server): State<Arc<McpServer>>,
132) -> Json<serde_json::Value> {
133 let request = crate::protocol::JsonRpcRequest::new(1i64, "tools/list");
134 let response = server.handle_request(request).await;
135
136 match response.result {
137 Some(result) => Json(result["tools"].clone()),
138 None => Json(json!([])),
139 }
140}
141
142async fn handle_mcp_jsonrpc(
144 State(server): State<Arc<McpServer>>,
145 Json(request): Json<JsonRpcRequest>,
146) -> Json<serde_json::Value> {
147 let response = server.handle_request(request).await;
148 let response_json = serde_json::to_value(&response).unwrap_or_default();
149 Json(response_json)
150}
151
152#[derive(Debug, Deserialize)]
154struct CallToolBody {
155 name: String,
156 #[serde(default)]
157 arguments: serde_json::Value,
158}
159
160async fn handle_call_tool(
162 State(server): State<Arc<McpServer>>,
163 Json(body): Json<CallToolBody>,
164) -> (StatusCode, Json<serde_json::Value>) {
165 let rpc_request = JsonRpcRequest::new(1i64, "tools/call").with_params(json!({
166 "name": body.name,
167 "arguments": body.arguments
168 }));
169
170 let rpc_response = server.handle_request(rpc_request).await;
171
172 match rpc_response.result {
173 Some(result) => (StatusCode::OK, Json(result)),
174 None => {
175 let error_msg = rpc_response
176 .error
177 .map(|e| e.message)
178 .unwrap_or_else(|| "Unknown error".to_string());
179 (
180 StatusCode::BAD_REQUEST,
181 Json(json!({"error": error_msg})),
182 )
183 }
184 }
185}
186
187async fn handle_health() -> Json<serde_json::Value> {
189 Json(json!({
190 "status": "ok",
191 "server": "cortexai",
192 "version": "0.1.0"
193 }))
194}
195
#[cfg(test)]
mod tests {
    use super::*;

    use crate::server::FnTool;
    use axum::body::{to_bytes, Body};
    use axum::http::{Request, Response};
    use serde_json::{json, Value};
    use std::sync::Arc;
    use tower::util::ServiceExt;

    /// Builds a server exposing a single `echo` tool that replies `{"echoed": <message>}`.
    fn create_test_server() -> Arc<crate::server::McpServer> {
        crate::server::McpServer::builder()
            .name("test-http-server")
            .version("1.0.0")
            .add_tool(FnTool::new(
                "echo",
                "Echoes input",
                json!({
                    "type": "object",
                    "properties": {
                        "message": {"type": "string"}
                    }
                }),
                |args| {
                    let msg = args["message"].as_str().unwrap_or("no message");
                    Ok(json!({"echoed": msg}))
                },
            ))
            .build()
    }

    /// Drives one request through `app` and returns the response.
    async fn send_request(app: axum::Router, req: Request<Body>) -> Response<Body> {
        app.oneshot(req).await.unwrap()
    }

    /// A bodiless `GET` request for `uri`.
    fn get_request(uri: &str) -> Request<Body> {
        Request::builder()
            .method("GET")
            .uri(uri)
            .body(Body::empty())
            .unwrap()
    }

    /// A `POST` request for `uri` whose body is `payload` encoded as JSON.
    fn post_json(uri: &str, payload: &Value) -> Request<Body> {
        Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(Body::from(serde_json::to_vec(payload).unwrap()))
            .unwrap()
    }

    /// Sends `req` to a fresh default-config router and asserts `200 OK`.
    ///
    /// Returns the response body decoded as JSON.
    async fn fetch_ok_json(req: Request<Body>) -> Value {
        let app = create_test_server().http_router(HttpServerConfig::default());
        let response = send_request(app, req).await;
        assert_eq!(response.status(), http::StatusCode::OK);
        let bytes = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        serde_json::from_slice(&bytes).unwrap()
    }

    #[test]
    fn test_http_server_config_defaults() {
        let HttpServerConfig {
            host,
            port,
            enable_cors,
        } = HttpServerConfig::default();
        assert_eq!(host, "0.0.0.0");
        assert_eq!(port, 3001);
        assert!(enable_cors);
    }

    #[tokio::test]
    async fn test_get_health() {
        let value = fetch_ok_json(get_request("/health")).await;
        assert_eq!(value["status"], "ok");
        assert_eq!(value["server"], "cortexai");
        assert_eq!(value["version"], "0.1.0");
    }

    #[tokio::test]
    async fn test_get_tools() {
        let value = fetch_ok_json(get_request("/tools")).await;
        let tools = value
            .as_array()
            .expect("GET /tools should return a JSON array");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "echo");
    }

    #[tokio::test]
    async fn test_post_mcp_tools_list() {
        let rpc = json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "tools/list"
        });
        let value = fetch_ok_json(post_json("/mcp", &rpc)).await;
        assert_eq!(value["jsonrpc"], "2.0");
        assert!(value["result"]["tools"].is_array());
    }

    #[tokio::test]
    async fn test_post_mcp_tools_call() {
        let rpc = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {
                "name": "echo",
                "arguments": { "message": "hello" }
            }
        });
        let value = fetch_ok_json(post_json("/mcp", &rpc)).await;
        assert!(value["error"].is_null());
        let text = value["result"]["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("hello"));
    }

    #[tokio::test]
    async fn test_post_call_tool_simplified() {
        let payload = json!({
            "name": "echo",
            "arguments": { "message": "hi there" }
        });
        let value = fetch_ok_json(post_json("/call-tool", &payload)).await;
        assert_eq!(value["isError"].as_bool(), Some(false));
        assert!(value["content"][0]["text"]
            .as_str()
            .unwrap()
            .contains("hi there"));
    }
}