1use std::net::SocketAddr;
36use std::sync::Arc;
37
38use axum::{
39 extract::State,
40 http::{HeaderMap, StatusCode},
41 response::{
42 sse::{Event, KeepAlive, Sse},
43 IntoResponse, Json,
44 },
45 routing::{get, post},
46 Router,
47};
48use futures_util::stream::Stream;
49use serde_json::{json, Value};
50use std::collections::HashMap;
51use std::convert::Infallible;
52use std::time::Duration;
53use tokio::sync::{mpsc, Mutex};
54use tokio::task::JoinHandle;
55
56use car_mcp::error_codes::PARSE as E_PARSE;
57use car_mcp::{Request as McpRequest, Server as McpServer};
58
/// Request header a client sends on `GET /mcp` to name the SSE session it is
/// opening; when it is absent, `handle_mcp_get` generates an id instead.
const SESSION_HEADER: &str = "mcp-session-id";

/// Interval, in seconds, between keep-alive messages (text `ping`) on open
/// `GET /mcp` SSE streams.
const SSE_KEEPALIVE_SECS: u64 = 30;
68
/// A live SSE session registered by `GET /mcp`.
///
/// Holds the sending half of the channel whose receiver feeds the session's
/// event stream; `push_to_session` delivers JSON-encoded messages through it.
pub struct McpSession {
    /// Sender for JSON-encoded messages destined for this session's SSE stream.
    tx: mpsc::Sender<String>,
}
78
/// Registry of open SSE sessions keyed by session id, guarded by a Tokio
/// mutex and shared between the HTTP handlers and callers of
/// [`push_to_session`].
pub type SessionMap = Mutex<HashMap<String, McpSession>>;
85
/// Shared state handed to every handler on the MCP router.
#[derive(Clone)]
struct McpState {
    /// MCP request dispatcher used by `POST /mcp`.
    server: Arc<McpServer>,
    /// Open SSE sessions; `GET /mcp` registers entries here.
    sessions: Arc<SessionMap>,
}
91
92pub async fn start_mcp(
103 server: Arc<McpServer>,
104 addr: SocketAddr,
105) -> Result<(SocketAddr, JoinHandle<()>, Arc<SessionMap>), String> {
106 let listener = tokio::net::TcpListener::bind(addr)
107 .await
108 .map_err(|e| format!("bind {addr}: {e}"))?;
109 let bound = listener
110 .local_addr()
111 .map_err(|e| format!("local_addr: {e}"))?;
112
113 let sessions: Arc<SessionMap> = Arc::new(Mutex::new(HashMap::new()));
114 let state = McpState {
115 server,
116 sessions: sessions.clone(),
117 };
118 let app: Router = Router::new()
119 .route("/mcp", post(handle_mcp_post).get(handle_mcp_get))
120 .route("/mcp/health", get(handle_health))
121 .with_state(state);
122
123 let task = tokio::spawn(async move {
124 if let Err(e) = axum::serve(listener, app).await {
125 tracing::warn!(error = %e, "mcp HTTP server exited");
126 }
127 });
128
129 Ok((bound, task, sessions))
130}
131
132async fn handle_health() -> impl IntoResponse {
133 Json(json!({
134 "status": "ok",
135 "protocol_version": car_mcp::PROTOCOL_VERSION,
136 "server_name": car_mcp::SERVER_NAME,
137 }))
138}
139
140async fn handle_mcp_get(
153 State(state): State<McpState>,
154 headers: HeaderMap,
155) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
156 let session_id = headers
157 .get(SESSION_HEADER)
158 .and_then(|v| v.to_str().ok())
159 .map(|s| s.to_string())
160 .unwrap_or_else(|| uuid_v4_simple());
161 let (tx, rx) = mpsc::channel::<String>(64);
162 {
163 let mut sessions = state.sessions.lock().await;
164 sessions.insert(session_id.clone(), McpSession { tx });
165 }
166 tracing::debug!(%session_id, "MCP SSE stream opened");
167
168 let init_event = serde_json::to_string(&json!({
172 "jsonrpc": "2.0",
173 "method": "notifications/initialized",
174 "params": { "session_id": session_id.clone() },
175 }))
176 .unwrap_or_else(|_| "{}".to_string());
177
178 let stream =
179 async_stream::stream_init_event(init_event, rx, state.sessions.clone(), session_id.clone());
180
181 Sse::new(stream).keep_alive(
182 KeepAlive::new()
183 .interval(Duration::from_secs(SSE_KEEPALIVE_SECS))
184 .text("ping"),
185 )
186}
187
188pub async fn push_to_session(sessions: &SessionMap, session_id: &str, payload: &Value) -> bool {
197 let json = match serde_json::to_string(payload) {
198 Ok(s) => s,
199 Err(_) => return false,
200 };
201 let guard = sessions.lock().await;
202 let Some(session) = guard.get(session_id) else {
203 return false;
204 };
205 session.tx.send(json).await.is_ok()
206}
207
208fn uuid_v4_simple() -> String {
212 uuid::Uuid::new_v4().to_string()
213}
214
215mod async_stream {
216 use super::*;
217 use std::pin::Pin;
218 use std::task::{Context, Poll};
219
220 pub fn stream_init_event(
224 init: String,
225 rx: mpsc::Receiver<String>,
226 sessions: Arc<SessionMap>,
227 session_id: String,
228 ) -> McpEventStream {
229 McpEventStream {
230 init: Some(init),
231 rx,
232 cleanup: Some(SessionCleanup {
233 sessions,
234 session_id,
235 }),
236 }
237 }
238
239 pub struct McpEventStream {
240 init: Option<String>,
241 rx: mpsc::Receiver<String>,
242 cleanup: Option<SessionCleanup>,
243 }
244
245 struct SessionCleanup {
250 sessions: Arc<SessionMap>,
251 session_id: String,
252 }
253
254 impl Drop for McpEventStream {
255 fn drop(&mut self) {
256 if let Some(cleanup) = self.cleanup.take() {
259 tokio::spawn(async move {
260 let mut guard = cleanup.sessions.lock().await;
261 guard.remove(&cleanup.session_id);
262 tracing::debug!(session_id = %cleanup.session_id, "MCP SSE stream closed");
263 });
264 }
265 }
266 }
267
268 impl Stream for McpEventStream {
269 type Item = Result<Event, Infallible>;
270
271 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
272 if let Some(init) = self.init.take() {
274 return Poll::Ready(Some(Ok(Event::default().data(init))));
275 }
276 match self.rx.poll_recv(cx) {
277 Poll::Ready(Some(payload)) => Poll::Ready(Some(Ok(Event::default().data(payload)))),
278 Poll::Ready(None) => Poll::Ready(None),
282 Poll::Pending => Poll::Pending,
283 }
284 }
285 }
286}
287
288async fn handle_mcp_post(State(state): State<McpState>, body: String) -> impl IntoResponse {
289 let req: McpRequest = match serde_json::from_str(&body) {
293 Ok(req) => req,
294 Err(e) => {
295 let resp = json!({
296 "jsonrpc": "2.0",
297 "id": Value::Null,
298 "error": {
299 "code": E_PARSE,
300 "message": format!("parse error: {e}"),
301 },
302 });
303 return (StatusCode::OK, Json(resp));
304 }
305 };
306
307 match state.server.handle(req).await {
311 Some(resp) => match serde_json::to_value(&resp) {
312 Ok(v) => (StatusCode::OK, Json(v)),
313 Err(e) => (
314 StatusCode::INTERNAL_SERVER_ERROR,
315 Json(json!({
316 "jsonrpc": "2.0",
317 "id": Value::Null,
318 "error": {
319 "code": -32603,
320 "message": format!("response serialization failed: {e}"),
321 },
322 })),
323 ),
324 },
325 None => (StatusCode::OK, Json(json!({}))),
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::time::Duration;

    /// Boots the real server (via `start_mcp`) on an ephemeral port.
    ///
    /// Returns the bound address, the serving task, and the live session map.
    /// `start_mcp` binds the listener before returning, so clients can connect
    /// immediately; no startup sleep is needed.
    async fn boot_test_server_with_sessions() -> (SocketAddr, JoinHandle<()>, Arc<SessionMap>) {
        let server = Arc::new(McpServer::new());
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        start_mcp(server, addr).await.expect("start_mcp")
    }

    /// Same as [`boot_test_server_with_sessions`], for tests that never inspect sessions.
    async fn boot_test_server() -> (SocketAddr, JoinHandle<()>) {
        let (bound, task, _sessions) = boot_test_server_with_sessions().await;
        (bound, task)
    }

    /// POSTs `body` to `/mcp` and returns the status plus the decoded JSON body.
    async fn http_post(addr: SocketAddr, body: &str) -> (StatusCode, Value) {
        let url = format!("http://{}/mcp", addr);
        let client = reqwest::Client::new();
        let resp = client
            .post(&url)
            .header("Content-Type", "application/json")
            .body(body.to_string())
            .send()
            .await
            .expect("post");
        let status = resp.status();
        let value: Value = resp.json().await.expect("json");
        (status, value)
    }

    #[tokio::test]
    async fn health_endpoint_returns_ok() {
        let (addr, _task) = boot_test_server().await;
        let url = format!("http://{}/mcp/health", addr);
        let resp = reqwest::get(&url).await.expect("get");
        assert_eq!(resp.status(), StatusCode::OK);
        let body: Value = resp.json().await.expect("json");
        assert_eq!(body["status"], "ok");
        assert_eq!(body["protocol_version"], car_mcp::PROTOCOL_VERSION);
    }

    #[tokio::test]
    async fn initialize_round_trips_over_http() {
        let (addr, _task) = boot_test_server().await;
        let req = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{}}"#;
        let (status, body) = http_post(addr, req).await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["jsonrpc"], "2.0");
        assert_eq!(body["id"], 1);
        assert_eq!(body["result"]["protocolVersion"], car_mcp::PROTOCOL_VERSION);
    }

    #[tokio::test]
    async fn tools_list_round_trips_over_http() {
        let (addr, _task) = boot_test_server().await;
        let req = r#"{"jsonrpc":"2.0","id":2,"method":"tools/list"}"#;
        let (status, body) = http_post(addr, req).await;
        assert_eq!(status, StatusCode::OK);
        let tools = body["result"]["tools"].as_array().expect("tools array");
        assert_eq!(tools.len(), 6);
    }

    #[tokio::test]
    async fn malformed_json_returns_parse_error() {
        let (addr, _task) = boot_test_server().await;
        let (status, body) = http_post(addr, "{not valid").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body["error"]["code"], -32700);
    }

    #[tokio::test]
    async fn shared_memgine_lets_facts_persist_across_requests() {
        let memgine = Arc::new(tokio::sync::Mutex::new(car_memgine::MemgineEngine::new(
            None,
        )));
        let server = Arc::new(McpServer::with_memgine(memgine));
        let addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        let (addr, _task, _sessions) = start_mcp(server, addr).await.expect("start");

        let add = r#"{"jsonrpc":"2.0","id":1,"method":"tools/call","params":{"name":"memory_add_fact","arguments":{"subject":"daemon","body":"shared engine works"}}}"#;
        let (status, _) = http_post(addr, add).await;
        assert_eq!(status, StatusCode::OK);

        let query = r#"{"jsonrpc":"2.0","id":2,"method":"tools/call","params":{"name":"memory_query","arguments":{"query":"daemon","k":5}}}"#;
        let (_, body) = http_post(addr, query).await;
        let text = body["result"]["content"][0]["text"].as_str().expect("text");
        assert!(
            text.contains("daemon"),
            "expected query to find ingested fact: {text}"
        );
    }

    #[tokio::test]
    async fn sse_get_emits_init_event_and_registers_session() {
        let (addr, _task, sessions) = boot_test_server_with_sessions().await;
        let url = format!("http://{}/mcp", addr);
        let client = reqwest::Client::new();
        let resp = client
            .get(&url)
            .header("mcp-session-id", "test-session-1")
            .send()
            .await
            .expect("get");
        assert_eq!(resp.status(), StatusCode::OK);
        // The handler registers the session before it returns the SSE
        // response, so it is visible as soon as the headers arrive.
        {
            let guard = sessions.lock().await;
            assert!(guard.contains_key("test-session-1"));
        }
        let mut stream = resp.bytes_stream();
        let chunk = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("timeout")
            .expect("chunk")
            .expect("bytes");
        let body = String::from_utf8_lossy(&chunk).to_string();
        assert!(body.contains("notifications/initialized"));
        assert!(body.contains("test-session-1"));
    }

    #[tokio::test]
    async fn push_to_session_delivers_payload_to_connected_client() {
        let (addr, _task, sessions) = boot_test_server_with_sessions().await;
        let url = format!("http://{}/mcp", addr);
        let client = reqwest::Client::new();
        let resp = client
            .get(&url)
            .header("mcp-session-id", "push-session")
            .send()
            .await
            .expect("get");
        let mut stream = resp.bytes_stream();
        let _init = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("timeout")
            .expect("chunk")
            .expect("bytes");

        assert!(
            sessions.lock().await.contains_key("push-session"),
            "GET /mcp must register the session before responding"
        );

        let payload = json!({
            "jsonrpc": "2.0",
            "id": 99,
            "method": "tools/call",
            "params": { "name": "host_owned_tool", "arguments": {} }
        });
        let delivered = push_to_session(&sessions, "push-session", &payload).await;
        assert!(delivered, "push must succeed for connected session");

        let chunk = tokio::time::timeout(Duration::from_secs(2), stream.next())
            .await
            .expect("timeout")
            .expect("chunk")
            .expect("bytes");
        let body = String::from_utf8_lossy(&chunk).to_string();
        assert!(body.contains("host_owned_tool"));
        assert!(body.contains("\"id\":99"));
    }

    #[tokio::test]
    async fn push_to_session_returns_false_for_unknown_session() {
        let sessions: Arc<SessionMap> = Arc::new(Mutex::new(HashMap::new()));
        let delivered = push_to_session(&sessions, "nobody", &json!({"x":1})).await;
        assert!(!delivered);
    }
}