use axum::{
extract::State,
http::{header, Method},
routing::{get, post},
Json, Router,
};
use http::StatusCode;
use serde::Deserialize;
use std::net::SocketAddr;
use tracing::info;
use serde_json::json;
use std::sync::Arc;
use tower_http::cors::{Any, CorsLayer};
use crate::protocol::JsonRpcRequest;
use crate::server::McpServer;
/// Settings for serving an `McpServer` over HTTP.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HttpServerConfig {
    /// Interface address the listener binds to (e.g. `"127.0.0.1"` or `"0.0.0.0"`).
    pub host: String,
    /// TCP port to listen on.
    pub port: u16,
    /// When `true`, the HTTP router adds a CORS layer that allows any origin.
    pub enable_cors: bool,
}
impl Default for HttpServerConfig {
    /// Listens on every IPv4 interface (`0.0.0.0`) on port 3001 with CORS enabled.
    fn default() -> Self {
        let host = String::from("0.0.0.0");
        let port = 3001;
        let enable_cors = true;
        Self {
            host,
            port,
            enable_cors,
        }
    }
}
impl HttpServerConfig {
    /// Config bound to the IPv4 loopback interface (`127.0.0.1`) on `port`.
    ///
    /// All other settings come from [`Default`] (CORS enabled).
    pub fn localhost(port: u16) -> Self {
        let mut config = Self::public(port);
        config.host = String::from("127.0.0.1");
        config
    }

    /// Config bound to all IPv4 interfaces (`0.0.0.0`, the default host) on `port`.
    ///
    /// All other settings come from [`Default`] (CORS enabled).
    pub fn public(port: u16) -> Self {
        Self {
            port,
            ..Self::default()
        }
    }
}
impl McpServer {
pub fn http_router(self: Arc<Self>, config: HttpServerConfig) -> Router {
let mut router = Router::new()
.route("/health", get(handle_health))
.route("/tools", get(handle_tools))
.route("/mcp", post(handle_mcp_jsonrpc))
.route("/call-tool", post(handle_call_tool))
.with_state(self);
if config.enable_cors {
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods([Method::GET, Method::POST])
.allow_headers([header::CONTENT_TYPE, header::ACCEPT]);
router = router.layer(cors);
}
router
}
pub async fn run_http(self: Arc<Self>, config: HttpServerConfig) -> Result<(), crate::error::McpError> {
let addr: SocketAddr = format!("{}:{}", config.host, config.port)
.parse()
.map_err(|e| crate::error::McpError::Transport(format!("Invalid address: {}", e)))?;
info!("Starting MCP HTTP server on http://{}", addr);
let router = self.http_router(config);
let listener = tokio::net::TcpListener::bind(addr)
.await
.map_err(|e| crate::error::McpError::Transport(format!("Failed to bind: {}", e)))?;
axum::serve(listener, router)
.await
.map_err(|e| crate::error::McpError::Transport(format!("Server error: {}", e)))?;
Ok(())
}
}
async fn handle_tools(
State(server): State<Arc<McpServer>>,
) -> Json<serde_json::Value> {
let request = crate::protocol::JsonRpcRequest::new(1i64, "tools/list");
let response = server.handle_request(request).await;
match response.result {
Some(result) => Json(result["tools"].clone()),
None => Json(json!([])),
}
}
async fn handle_mcp_jsonrpc(
State(server): State<Arc<McpServer>>,
Json(request): Json<JsonRpcRequest>,
) -> Json<serde_json::Value> {
let response = server.handle_request(request).await;
let response_json = serde_json::to_value(&response).unwrap_or_default();
Json(response_json)
}
/// JSON body accepted by `POST /call-tool`.
///
/// The handler repackages it as the params of a JSON-RPC `tools/call` request.
#[derive(Debug, Deserialize)]
struct CallToolBody {
/// Name of the tool to invoke; forwarded as `params.name`.
name: String,
/// Tool arguments, forwarded as `params.arguments`.
/// When the field is omitted, `#[serde(default)]` yields `serde_json::Value::Null`.
#[serde(default)]
arguments: serde_json::Value,
}
async fn handle_call_tool(
State(server): State<Arc<McpServer>>,
Json(body): Json<CallToolBody>,
) -> (StatusCode, Json<serde_json::Value>) {
let rpc_request = JsonRpcRequest::new(1i64, "tools/call").with_params(json!({
"name": body.name,
"arguments": body.arguments
}));
let rpc_response = server.handle_request(rpc_request).await;
match rpc_response.result {
Some(result) => (StatusCode::OK, Json(result)),
None => {
let error_msg = rpc_response
.error
.map(|e| e.message)
.unwrap_or_else(|| "Unknown error".to_string());
(
StatusCode::BAD_REQUEST,
Json(json!({"error": error_msg})),
)
}
}
}
/// `GET /health`: liveness probe returning a fixed status and server identity payload.
async fn handle_health() -> Json<serde_json::Value> {
    const SERVER_NAME: &str = "cortexai";
    const SERVER_VERSION: &str = "0.1.0";

    let payload = json!({
        "status": "ok",
        "server": SERVER_NAME,
        "version": SERVER_VERSION,
    });
    Json(payload)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::FnTool;
    use axum::body::to_bytes;
    use serde_json::{json, Value};
    use std::sync::Arc;
    use tower::util::ServiceExt;

    type TestRequest = axum::http::Request<axum::body::Body>;

    /// Builds a server exposing a single `echo` tool that answers `{"echoed": <message>}`.
    fn create_test_server() -> Arc<crate::server::McpServer> {
        let echo_schema = json!({
            "type": "object",
            "properties": {
                "message": {"type": "string"}
            }
        });
        let echo_tool = FnTool::new("echo", "Echoes input", echo_schema, |args| {
            let msg = args["message"].as_str().unwrap_or("no message");
            Ok(json!({"echoed": msg}))
        });
        crate::server::McpServer::builder()
            .name("test-http-server")
            .version("1.0.0")
            .add_tool(echo_tool)
            .build()
    }

    /// Builds a body-less `GET` request for `uri`.
    fn get_request(uri: &str) -> TestRequest {
        axum::http::Request::builder()
            .method("GET")
            .uri(uri)
            .body(axum::body::Body::empty())
            .unwrap()
    }

    /// Builds a `POST` request for `uri` carrying `payload` as a JSON body.
    fn post_json_request(uri: &str, payload: &Value) -> TestRequest {
        let encoded = serde_json::to_vec(payload).unwrap();
        axum::http::Request::builder()
            .method("POST")
            .uri(uri)
            .header("content-type", "application/json")
            .body(axum::body::Body::from(encoded))
            .unwrap()
    }

    /// Routes `req` in-process (no socket) through a fresh default-config router and
    /// returns the response status together with the body decoded as JSON.
    async fn dispatch(req: TestRequest) -> (http::StatusCode, Value) {
        let app = create_test_server().http_router(HttpServerConfig::default());
        let response = app.oneshot(req).await.unwrap();
        let status = response.status();
        let raw = to_bytes(response.into_body(), usize::MAX).await.unwrap();
        let decoded = serde_json::from_slice(&raw).unwrap();
        (status, decoded)
    }

    #[test]
    fn test_http_server_config_defaults() {
        let HttpServerConfig {
            host,
            port,
            enable_cors,
            ..
        } = HttpServerConfig::default();
        assert_eq!(host, "0.0.0.0");
        assert_eq!(port, 3001);
        assert!(enable_cors);
    }

    #[tokio::test]
    async fn test_get_health() {
        let (status, body) = dispatch(get_request("/health")).await;
        assert_eq!(status, http::StatusCode::OK);
        assert_eq!(body["status"], "ok");
        assert_eq!(body["server"], "cortexai");
        assert_eq!(body["version"], "0.1.0");
    }

    #[tokio::test]
    async fn test_get_tools() {
        let (status, body) = dispatch(get_request("/tools")).await;
        assert_eq!(status, http::StatusCode::OK);
        let tools = body
            .as_array()
            .expect("GET /tools should return a JSON array");
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0]["name"], "echo");
    }

    #[tokio::test]
    async fn test_post_mcp_tools_list() {
        let rpc = json!({"jsonrpc": "2.0", "id": 1, "method": "tools/list"});
        let (status, body) = dispatch(post_json_request("/mcp", &rpc)).await;
        assert_eq!(status, http::StatusCode::OK);
        assert_eq!(body["jsonrpc"], "2.0");
        assert!(body["result"]["tools"].is_array());
    }

    #[tokio::test]
    async fn test_post_mcp_tools_call() {
        let rpc = json!({
            "jsonrpc": "2.0",
            "id": 2,
            "method": "tools/call",
            "params": {
                "name": "echo",
                "arguments": {"message": "hello"}
            }
        });
        let (status, body) = dispatch(post_json_request("/mcp", &rpc)).await;
        assert_eq!(status, http::StatusCode::OK);
        assert!(body["error"].is_null());
        let text = body["result"]["content"][0]["text"]
            .as_str()
            .expect("tool result should carry text content");
        assert!(text.contains("hello"));
    }

    #[tokio::test]
    async fn test_post_call_tool_simplified() {
        let payload = json!({"name": "echo", "arguments": {"message": "hi there"}});
        let (status, body) = dispatch(post_json_request("/call-tool", &payload)).await;
        assert_eq!(status, http::StatusCode::OK);
        // `isError` must be present and explicitly false.
        assert_eq!(body["isError"].as_bool(), Some(false));
        let text = body["content"][0]["text"].as_str().unwrap();
        assert!(text.contains("hi there"));
    }
}