use std::convert::Infallible;
use std::net::SocketAddr;
use std::sync::Arc;
use axum::Json;
use axum::extract::State;
use axum::http::{HeaderMap, StatusCode};
use axum::response::IntoResponse;
use axum::response::sse::{Event, KeepAlive, Sse};
use axum::routing::post;
use serde_json::Value;
use tokio_stream::wrappers::UnboundedReceiverStream;
use crate::mcp::McpServer;
/// HTTP transport state wrapping a shared [`McpServer`].
///
/// The type is `Clone` so that a copy can be handed to the axum router as
/// handler state; cloning only clones the inner `Arc` handle.
#[derive(Clone)]
pub struct HttpTransport {
/// Shared handle to the MCP server that the HTTP handlers dispatch to.
server: Arc<McpServer>,
}
impl HttpTransport {
pub fn new(server: Arc<McpServer>) -> Self {
Self { server }
}
pub fn router(&self) -> axum::Router {
axum::Router::new()
.route("/mcp", post(rpc_handler))
.route("/mcp/sse", post(sse_handler))
.with_state(self.clone())
}
}
/// Binds a TCP listener on `addr` and serves the MCP HTTP routes on it.
///
/// The future resolves only when the underlying `axum::serve` future does.
///
/// # Errors
///
/// Returns the `std::io::Error` produced by binding the listener or by
/// `axum::serve`.
pub async fn serve(server: Arc<McpServer>, addr: SocketAddr) -> std::io::Result<()> {
    let listener = tokio::net::TcpListener::bind(addr).await?;
    let app = HttpTransport::new(server).router();
    axum::serve(listener, app).await
}
/// JSON-RPC 2.0 predefined code for "Method not found"; returned when the
/// request's `method` matches none of the dispatched methods.
const ERR_METHOD_NOT_FOUND: i32 = -32601;
/// JSON-RPC 2.0 predefined code for "Internal error"; used here when a
/// resource read or tool call returns an error.
const ERR_INTERNAL: i32 = -32603;
/// Code reported when the `Authorization` check fails. -32001 lies in the
/// -32000..=-32099 range that JSON-RPC 2.0 reserves for
/// implementation-defined server errors.
const ERR_UNAUTHORIZED: i32 = -32001;
async fn dispatch_async(server: &McpServer, req: &Value) -> Option<Value> {
let id = req.get("id")?.clone();
let method = req.get("method").and_then(|v| v.as_str()).unwrap_or("");
let result: Result<Value, (i32, String)> = match method {
"initialize" => Ok(server.initialize_response()),
"tools/list" => Ok(serde_json::json!({ "tools": server.tools() })),
"resources/list" => Ok(serde_json::json!({ "resources": server.resources() })),
"resources/read" => {
let uri = req
.pointer("/params/uri")
.and_then(|v| v.as_str())
.unwrap_or("");
server
.read_resource(uri, serde_json::json!({}))
.map(|content| {
serde_json::json!({
"contents": [{ "uri": uri, "text": content.to_string() }]
})
})
.map_err(|e| (ERR_INTERNAL, e.to_string()))
}
"tools/call" => {
let name = req
.pointer("/params/name")
.and_then(|v| v.as_str())
.unwrap_or("");
let params = req
.pointer("/params/arguments")
.cloned()
.unwrap_or(serde_json::json!(null));
server
.call_tool_async(name, params)
.await
.map(|r| {
serde_json::json!({
"content": [{ "type": "text", "text": r.to_string() }]
})
})
.map_err(|e| (ERR_INTERNAL, e.to_string()))
}
_ => Err((ERR_METHOD_NOT_FOUND, format!("Method not found: {method}"))),
};
Some(match result {
Ok(value) => serde_json::json!({ "jsonrpc": "2.0", "id": id, "result": value }),
Err((code, message)) => serde_json::json!({
"jsonrpc": "2.0", "id": id,
"error": { "code": code, "message": message }
}),
})
}
/// Checks the request's `authorization` header against `server`.
///
/// The header value is passed to `McpServer::check_auth`; when the header is
/// absent or its value cannot be read as a string, an empty string is passed
/// instead. Returns whatever `check_auth` returns.
fn authorized(server: &McpServer, headers: &HeaderMap) -> bool {
    let credentials = match headers.get("authorization").map(|raw| raw.to_str()) {
        Some(Ok(text)) => text,
        // Missing header, or a value that is not valid header text.
        _ => "",
    };
    server.check_auth(credentials)
}
fn unauthorized_response(id: Option<Value>) -> Json<Value> {
Json(serde_json::json!({
"jsonrpc": "2.0",
"id": id,
"error": { "code": ERR_UNAUTHORIZED, "message": "Unauthorized" }
}))
}
/// Handles `POST /mcp`: one JSON-RPC request in, one JSON response out.
///
/// Responses:
/// * `401 Unauthorized` with a JSON-RPC error body (echoing the request's
///   `id`, or `null`) when the `authorization` header is rejected;
/// * `200 OK` with the JSON-RPC response when the request carries an `id`;
/// * `204 No Content` with an empty body when the request is a notification
///   (no `id`), for which `dispatch_async` produces no response.
///
/// NOTE(review): the `Json` extractor presumably runs before this body, so a
/// malformed JSON payload is rejected by axum before the authorization check
/// is reached — confirm if that ordering matters.
async fn rpc_handler(
    State(state): State<HttpTransport>,
    headers: HeaderMap,
    Json(req): Json<Value>,
) -> impl IntoResponse {
    if !authorized(&state.server, &headers) {
        let id = req.get("id").cloned();
        return (StatusCode::UNAUTHORIZED, unauthorized_response(id)).into_response();
    }

    match dispatch_async(&state.server, &req).await {
        Some(resp) => (StatusCode::OK, Json(resp)).into_response(),
        // A 204 response must not carry content, so send the bare status
        // rather than a JSON `null` body with a JSON content type.
        None => StatusCode::NO_CONTENT.into_response(),
    }
}
async fn sse_handler(
State(state): State<HttpTransport>,
headers: HeaderMap,
Json(req): Json<Value>,
) -> impl IntoResponse {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let server = state.server.clone();
tokio::spawn(async move {
let event = if !authorized(&server, &headers) {
Event::default().event("error").data(
serde_json::to_string(&unauthorized_response(req.get("id").cloned()).0)
.unwrap_or_default(),
)
} else if let Some(resp) = dispatch_async(&server, &req).await {
let data = serde_json::to_string(&resp).unwrap_or_default();
Event::default().event("message").data(data)
} else {
Event::default().event("noop")
};
let _ = tx.send(Ok::<_, Infallible>(event));
});
Sse::new(UnboundedReceiverStream::new(rx)).keep_alive(KeepAlive::default())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::schema::{ResourceDescription, ToolDescription};

    /// Builds a server exposing one async `echo` tool that returns its
    /// arguments unchanged.
    fn server_with_echo() -> McpServer {
        let echo_tool = ToolDescription {
            name: "echo".into(),
            description: "Echo".into(),
            input_schema: serde_json::json!({"type": "object"}),
        };
        let mut server = McpServer::new("http-test", "1.0.0");
        server.register_tool(echo_tool);
        server.set_async_handler("echo", |params| async move { Ok(params) });
        server
    }

    /// `initialize` reports the configured server name.
    #[tokio::test]
    async fn dispatch_initialize() {
        let server = server_with_echo();
        let request =
            serde_json::json!({"jsonrpc":"2.0","id":1,"method":"initialize","params":{}});
        let response = dispatch_async(&server, &request).await.unwrap();
        assert_eq!(response["result"]["serverInfo"]["name"], "http-test");
    }

    /// `tools/call` awaits the async handler and wraps its output as text content.
    #[tokio::test]
    async fn dispatch_tools_call_async() {
        let server = server_with_echo();
        let request = serde_json::json!({
            "jsonrpc": "2.0", "id": 2, "method": "tools/call",
            "params": { "name": "echo", "arguments": { "msg": "hello" } }
        });
        let response = dispatch_async(&server, &request).await.unwrap();
        let text = response
            .pointer("/result/content/0/text")
            .and_then(Value::as_str)
            .unwrap();
        assert!(text.contains("hello"));
    }

    /// Unknown methods produce the JSON-RPC "method not found" code.
    #[tokio::test]
    async fn dispatch_unknown_method() {
        let server = server_with_echo();
        let request = serde_json::json!({"jsonrpc":"2.0","id":3,"method":"nope"});
        let response = dispatch_async(&server, &request).await.unwrap();
        assert_eq!(response["error"]["code"], ERR_METHOD_NOT_FOUND);
    }

    /// `resources/read` returns the registered handler's content as text.
    #[tokio::test]
    async fn dispatch_resources_read() {
        let resource = ResourceDescription {
            uri: "docs://x".into(),
            name: "X".into(),
            description: None,
            mime_type: None,
        };
        let mut server = McpServer::new("http-test", "1.0.0");
        server.register_resource(resource);
        server.set_resource_handler("docs://x", |_| Ok(serde_json::json!("# body")));

        let request = serde_json::json!({
            "jsonrpc": "2.0", "id": 4, "method": "resources/read",
            "params": { "uri": "docs://x" }
        });
        let response = dispatch_async(&server, &request).await.unwrap();
        let text = response
            .pointer("/result/contents/0/text")
            .and_then(Value::as_str)
            .unwrap();
        assert!(text.contains("body"));
    }

    /// End-to-end check: a raw HTTP/1.1 POST to `/mcp` reaches the echo tool.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn http_round_trip_calls_tool() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};

        // Bind to an ephemeral port and serve the router in the background.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        let transport = HttpTransport::new(Arc::new(server_with_echo()));
        tokio::spawn(async move {
            let _ = axum::serve(listener, transport.router()).await;
        });

        let payload = serde_json::json!({
            "jsonrpc": "2.0", "id": 9, "method": "tools/call",
            "params": { "name": "echo", "arguments": { "v": 42 } }
        })
        .to_string();
        let raw_request = format!(
            "POST /mcp HTTP/1.1\r\nHost: localhost\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
            payload.len(),
            payload
        );

        let mut socket = tokio::net::TcpStream::connect(addr).await.unwrap();
        socket.write_all(raw_request.as_bytes()).await.unwrap();

        // `Connection: close` lets us read until the server closes the socket.
        let mut raw_response = Vec::new();
        socket.read_to_end(&mut raw_response).await.unwrap();
        let response = String::from_utf8_lossy(&raw_response);

        assert!(response.contains("200 OK"), "response: {response}");
        assert!(response.contains("\"content\""), "response: {response}");
        assert!(response.contains("\\\"v\\\":42"), "response: {response}");
    }
}