use std::net::SocketAddr;
use std::sync::Arc;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse, Json,
},
routing::{get, post},
Router,
};
use futures_util::stream::Stream;
use serde_json::{json, Value};
use std::collections::HashMap;
use std::convert::Infallible;
use std::time::Duration;
use tokio::sync::{mpsc, Mutex};
use tokio::task::JoinHandle;
use car_mcp::error_codes::PARSE as E_PARSE;
use car_mcp::{Request as McpRequest, Server as McpServer};
/// Request header carrying the client-chosen MCP session id on `GET /mcp`.
const SESSION_HEADER: &str = "mcp-session-id";
/// Interval, in seconds, between SSE keep-alive messages.
const SSE_KEEPALIVE_SECS: u64 = 30;

/// One connected SSE client.
///
/// Strings sent on `tx` are emitted as SSE `data` events on that client's
/// `GET /mcp` stream.
#[derive(Debug)]
pub struct McpSession {
    tx: mpsc::Sender<String>,
}

/// Live SSE sessions keyed by session id, shared between the HTTP handlers
/// and [`push_to_session`].
pub type SessionMap = Mutex<HashMap<String, McpSession>>;

/// Axum state shared by all `/mcp` handlers.
#[derive(Clone)]
struct McpState {
    server: Arc<McpServer>,
    sessions: Arc<SessionMap>,
}
/// Binds the MCP HTTP transport on `addr` and serves it on a background task.
///
/// Routes:
/// - `POST /mcp`: one JSON-RPC message per request body.
/// - `GET /mcp`: server-to-client SSE stream.
/// - `GET /mcp/health`: liveness probe.
///
/// Returns the address actually bound (useful when `addr` uses port 0), the
/// handle of the serving task, and the shared session map. Callers can use the
/// map with [`push_to_session`].
///
/// # Errors
/// Returns a human-readable message if binding the listener or reading its
/// local address fails.
pub async fn start_mcp(
    server: Arc<McpServer>,
    addr: SocketAddr,
) -> Result<(SocketAddr, JoinHandle<()>, Arc<SessionMap>), String> {
    let listener = match tokio::net::TcpListener::bind(addr).await {
        Ok(listener) => listener,
        Err(e) => return Err(format!("bind {addr}: {e}")),
    };
    let local = listener
        .local_addr()
        .map_err(|e| format!("local_addr: {e}"))?;

    let session_map: Arc<SessionMap> = Arc::new(Mutex::new(HashMap::new()));
    let app: Router = Router::new()
        .route("/mcp", post(handle_mcp_post).get(handle_mcp_get))
        .route("/mcp/health", get(handle_health))
        .with_state(McpState {
            server,
            sessions: Arc::clone(&session_map),
        });

    let serve_task = tokio::spawn(async move {
        let outcome = axum::serve(listener, app).await;
        if let Err(e) = outcome {
            tracing::warn!(error = %e, "mcp HTTP server exited");
        }
    });

    Ok((local, serve_task, session_map))
}
/// Liveness probe for `GET /mcp/health`.
///
/// Reports `status: "ok"` together with `car_mcp`'s protocol version and
/// server name.
async fn handle_health() -> impl IntoResponse {
    let body = json!({
        "status": "ok",
        "protocol_version": car_mcp::PROTOCOL_VERSION,
        "server_name": car_mcp::SERVER_NAME,
    });
    Json(body)
}
async fn handle_mcp_get(
State(state): State<McpState>,
headers: HeaderMap,
) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
let session_id = headers
.get(SESSION_HEADER)
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| uuid_v4_simple());
let (tx, rx) = mpsc::channel::<String>(64);
{
let mut sessions = state.sessions.lock().await;
sessions.insert(session_id.clone(), McpSession { tx });
}
tracing::debug!(%session_id, "MCP SSE stream opened");
let init_event = serde_json::to_string(&json!({
"jsonrpc": "2.0",
"method": "notifications/initialized",
"params": { "session_id": session_id.clone() },
}))
.unwrap_or_else(|_| "{}".to_string());
let stream =
async_stream::stream_init_event(init_event, rx, state.sessions.clone(), session_id.clone());
Sse::new(stream).keep_alive(
KeepAlive::new()
.interval(Duration::from_secs(SSE_KEEPALIVE_SECS))
.text("ping"),
)
}
/// Pushes `payload` to the SSE stream of `session_id` as a single `data` event.
///
/// Returns `true` once the message has been queued on the session's channel.
/// Returns `false` in any of these cases:
/// - `payload` cannot be serialized;
/// - no session with that id is registered;
/// - the session's stream has gone away.
///
/// If the session's channel is full, this waits for room. The session-map lock
/// is *not* held while waiting.
pub async fn push_to_session(sessions: &SessionMap, session_id: &str, payload: &Value) -> bool {
    let Ok(json) = serde_json::to_string(payload) else {
        return false;
    };
    // Clone the sender and release the map lock *before* awaiting the send.
    // `send` waits while the channel is full. Holding the lock during that
    // wait would stall every other session's registration, pushes and cleanup
    // behind one slow client.
    let tx = {
        let guard = sessions.lock().await;
        match guard.get(session_id) {
            Some(session) => session.tx.clone(),
            None => return false,
        }
    };
    tx.send(json).await.is_ok()
}
/// Mints a fresh random (version 4) UUID as a string.
///
/// Used as the session id when the client does not supply one.
fn uuid_v4_simple() -> String {
    let id = uuid::Uuid::new_v4();
    id.to_string()
}
mod async_stream {
use super::*;
use std::pin::Pin;
use std::task::{Context, Poll};
pub fn stream_init_event(
init: String,
rx: mpsc::Receiver<String>,
sessions: Arc<SessionMap>,
session_id: String,
) -> McpEventStream {
McpEventStream {
init: Some(init),
rx,
cleanup: Some(SessionCleanup {
sessions,
session_id,
}),
}
}
pub struct McpEventStream {
init: Option<String>,
rx: mpsc::Receiver<String>,
cleanup: Option<SessionCleanup>,
}
struct SessionCleanup {
sessions: Arc<SessionMap>,
session_id: String,
}
impl Drop for McpEventStream {
fn drop(&mut self) {
if let Some(cleanup) = self.cleanup.take() {
tokio::spawn(async move {
let mut guard = cleanup.sessions.lock().await;
guard.remove(&cleanup.session_id);
tracing::debug!(session_id = %cleanup.session_id, "MCP SSE stream closed");
});
}
}
}
impl Stream for McpEventStream {
type Item = Result<Event, Infallible>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
if let Some(init) = self.init.take() {
return Poll::Ready(Some(Ok(Event::default().data(init))));
}
match self.rx.poll_recv(cx) {
Poll::Ready(Some(payload)) => Poll::Ready(Some(Ok(Event::default().data(payload)))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
}
async fn handle_mcp_post(State(state): State<McpState>, body: String) -> impl IntoResponse {
let req: McpRequest = match serde_json::from_str(&body) {
Ok(req) => req,
Err(e) => {
let resp = json!({
"jsonrpc": "2.0",
"id": Value::Null,
"error": {
"code": E_PARSE,
"message": format!("parse error: {e}"),
},
});
return (StatusCode::OK, Json(resp));
}
};
match state.server.handle(req).await {
Some(resp) => match serde_json::to_value(&resp) {
Ok(v) => (StatusCode::OK, Json(v)),
Err(e) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"jsonrpc": "2.0",
"id": Value::Null,
"error": {
"code": -32603,
"message": format!("response serialization failed: {e}"),
},
})),
),
},
None => (StatusCode::OK, Json(json!({}))),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Boots the production `/mcp` router (via `start_mcp`) on an ephemeral
    /// localhost port.
    ///
    /// Also returns the session map, so SSE and push tests exercise exactly
    /// the routes `start_mcp` serves rather than a hand-built copy.
    async fn boot_test_server_with_sessions() -> (SocketAddr, JoinHandle<()>, Arc<SessionMap>) {
        let server = Arc::new(McpServer::new());
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        start_mcp(server, addr).await.expect("start_mcp")
    }

    /// Same as `boot_test_server_with_sessions`, for tests that never touch
    /// the session map.
    async fn boot_test_server() -> (SocketAddr, JoinHandle<()>) {
        let (bound, task, _sessions) = boot_test_server_with_sessions().await;
        (bound, task)
    }

    async fn http_post(addr: SocketAddr, body: &str) -> (StatusCode, Value) {
        let url = format!("http://{}/mcp", addr);
        let client = reqwest::Client::new();
        let resp = client
            .post(&url)
            .header("Content-Type", "application/json")
            .body(body.to_string())
            .send()
            .await
            .expect("post");
        let status = resp.status();
        let value: Value = resp.json().await.expect("json");
        (status, value)
    }

    #[tokio::test]
    async fn health_endpoint_returns_ok() {
        let (addr, _task) = boot_test_server().await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        let url = format!("http://{}/mcp/health", addr);
        let resp = reqwest::get(&url).await.expect("get");
        assert_eq!(resp.status(), StatusCode::OK);
        let body: Value = resp.json().await.expect("json");
        assert_eq!(body["status"], "ok");
        assert_eq!(body["protocol_version"], car_mcp::PROTOCOL_VERSION);
    }

    #[tokio::test]
    async fn initialize_round_trips_over_http() {
        let (addr, _task) = boot_test_server().await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        let req = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#;
        let (status, body) = http_post(addr, req).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["jsonrpc"], "2.0");
        assert_eq!(body["id"], 1);
        assert_eq!(body["result"]["protocolVersion"], car_mcp::PROTOCOL_VERSION);
    }

    #[tokio::test]
    async fn tools_list_round_trips_over_http() {
        let (addr, _task) = boot_test_server().await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        let req = r#"{"jsonrpc":"2.0","id":2,"method":"tools/list"}"#;
        let (status, body) = http_post(addr, req).await;
        assert_eq!(status, StatusCode::OK);
        let tools = body["result"]["tools"].as_array().expect("tools array");
        assert_eq!(tools.len(), 6);
    }

    #[tokio::test]
    async fn malformed_json_returns_parse_error() {
        let (addr, _task) = boot_test_server().await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        let (status, body) = http_post(addr, "{not valid").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["error"]["code"], -32700);
    }

    #[tokio::test]
    async fn shared_memgine_lets_facts_persist_across_requests() {
        let memgine = Arc::new(tokio::sync::Mutex::new(car_memgine::MemgineEngine::new(
            None,
        )));
        let server = Arc::new(McpServer::with_memgine(memgine));
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let (addr, _task, _sessions) = start_mcp(server, addr).await.expect("start");
        tokio::time::sleep(Duration::from_millis(50)).await;
        let add = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"memory_add_fact","arguments":{"subject":"daemon","body":"shared engine works"}}}"#;
        let (_, _) = http_post(addr, add).await;
        let query = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"memory_query","arguments":{"query":"daemon","k":5}}}"#;
        let (_, body) = http_post(addr, query).await;
        let text = body["result"]["content"][0]["text"].as_str().expect("text");
        assert!(
            text.contains("daemon"),
            "expected query to find ingested fact: {text}"
        );
    }

    #[tokio::test]
    async fn sse_get_emits_init_event_and_registers_session() {
        let (addr, _task, sessions) = boot_test_server_with_sessions().await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        let url = format!("http://{}/mcp", addr);
        let client = reqwest::Client::new();
        let resp = client
            .get(&url)
            .header("mcp-session-id", "test-session-1")
            .send()
            .await
            .expect("get");
        assert_eq!(resp.status(), StatusCode::OK);
        tokio::time::sleep(Duration::from_millis(50)).await;
        {
            let guard = sessions.lock().await;
            assert!(guard.contains_key("test-session-1"));
        }
        let mut stream = resp.bytes_stream();
        use futures_util::StreamExt;
        let chunk = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("timeout")
            .expect("chunk")
            .expect("bytes");
        let body = String::from_utf8_lossy(&chunk).to_string();
        assert!(body.contains("notifications/initialized"));
        assert!(body.contains("test-session-1"));
    }

    #[tokio::test]
    async fn push_to_session_delivers_payload_to_connected_client() {
        let (addr, _task, sessions) = boot_test_server_with_sessions().await;
        tokio::time::sleep(Duration::from_millis(50)).await;
        let url = format!("http://{}/mcp", addr);
        let client = reqwest::Client::new();
        let resp = client
            .get(&url)
            .header("mcp-session-id", "push-session")
            .send()
            .await
            .expect("get");
        let mut stream = resp.bytes_stream();
        use futures_util::StreamExt;
        let _init = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("timeout")
            .expect("chunk")
            .expect("bytes");
        for _ in 0..20 {
            let guard = sessions.lock().await;
            if guard.contains_key("push-session") {
                break;
            }
            drop(guard);
            tokio::time::sleep(Duration::from_millis(20)).await;
        }
        let payload = json!({
            "jsonrpc": "2.0",
            "id": 99,
            "method": "tools/call",
            "params": { "name": "host_owned_tool", "arguments": {} }
        });
        let delivered = push_to_session(&sessions, "push-session", &payload).await;
        assert!(delivered, "push must succeed for connected session");
        let chunk = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("timeout")
            .expect("chunk")
            .expect("bytes");
        let body = String::from_utf8_lossy(&chunk).to_string();
        assert!(body.contains("host_owned_tool"));
        assert!(body.contains("\"id\":99"));
    }

    #[tokio::test]
    async fn push_to_session_returns_false_for_unknown_session() {
        let sessions: Arc<SessionMap> = Arc::new(Mutex::new(HashMap::new()));
        let delivered = push_to_session(&sessions, "nobody", &json!({"x":1})).await;
        assert!(!delivered);
    }
}