Skip to main content

car_engine/
mcp.rs

//! MCP (Model Context Protocol) server integration.
//!
//! Discovers tools from MCP servers via stdin/stdout JSON-RPC and registers
//! them into the canonical tool registry. MCP tools participate in the same
//! capability/permission/policy flow as all other tools.
6
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9use std::collections::HashMap;
10use std::sync::atomic::{AtomicBool, Ordering};
11use std::sync::{Arc, Mutex as StdMutex};
12use std::time::Duration;
13use tokio::io::{AsyncBufRead, AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStderr, Command};
15use tokio::sync::{oneshot, Mutex};
16
17/// Configuration for an MCP server.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct McpServerConfig {
20    /// Display name for this server.
21    pub name: String,
22    /// Command to launch the server.
23    pub command: String,
24    /// Arguments for the command.
25    #[serde(default)]
26    pub args: Vec<String>,
27    /// Environment variables.
28    #[serde(default)]
29    pub env: HashMap<String, String>,
30    /// Working directory.
31    pub cwd: Option<String>,
32}
33
34/// Map of in-flight request id → the waiter to deliver its response to.
35type Pending = Arc<StdMutex<HashMap<u64, oneshot::Sender<McpResponse>>>>;
36
37/// A running MCP server connection.
38///
39/// A background **reader task** owns stdout and demultiplexes responses by id
40/// into per-request `oneshot` channels. `send_request` writes to stdin and then
41/// awaits its channel — never `read_line` directly. So a request timeout (or any
42/// cancellation of the caller's future) only drops a receiver; the reader keeps
43/// consuming the stream and discards the now-orphaned response. This makes the
44/// transport cancel-safe by construction: there is no read to interrupt
45/// mid-line, hence no desync and no poison/recovery dance.
46pub struct McpServer {
47    config: McpServerConfig,
48    child: Child,
49    stdin: tokio::io::BufWriter<tokio::process::ChildStdin>,
50    next_id: u64,
51    pending: Pending,
52    /// Background reader task handle (aborted on reconnect/drop).
53    reader: tokio::task::JoinHandle<()>,
54    /// Background stderr-drain task (aborted on reconnect/drop). Must exist:
55    /// stderr is piped, so an undrained chatty server fills the pipe buffer and
56    /// deadlocks its own stdout writes.
57    stderr_reader: tokio::task::JoinHandle<()>,
58    /// Cleared when the reader exits (EOF / read error) — i.e. the connection is
59    /// dead. The next `send_request` reconnects.
60    alive: Arc<AtomicBool>,
61    /// Backstop timeout for awaiting any single response. Callers usually impose
62    /// their own (per-action `timeout_ms` in the executor); this bounds requests
63    /// that don't, so a silent server can't hang a call forever.
64    request_timeout: Duration,
65}
66
67impl Drop for McpServer {
68    fn drop(&mut self) {
69        // tokio doesn't kill children or abort tasks on drop by default. Do it
70        // explicitly so a dropped server (e.g. `shutdown_all` draining the map)
71        // doesn't leak the reader/stderr tasks or the child process.
72        self.reader.abort();
73        self.stderr_reader.abort();
74        let _ = self.child.start_kill();
75    }
76}
77
78/// An MCP tool discovered from a server.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct McpToolInfo {
81    pub name: String,
82    pub description: Option<String>,
83    #[serde(rename = "inputSchema")]
84    pub input_schema: Option<Value>,
85}
86
87/// MCP JSON-RPC request.
88#[derive(Debug, Serialize)]
89struct McpRequest {
90    jsonrpc: &'static str,
91    method: String,
92    #[serde(skip_serializing_if = "Option::is_none")]
93    params: Option<Value>,
94    id: u64,
95}
96
97/// MCP JSON-RPC response.
98#[derive(Debug, Deserialize)]
99struct McpResponse {
100    result: Option<Value>,
101    error: Option<McpError>,
102    /// Echoed request id — used to route the response to its waiter.
103    id: Option<u64>,
104}
105
106#[derive(Debug, Deserialize)]
107struct McpError {
108    #[allow(dead_code)]
109    code: Option<i64>,
110    message: String,
111}
112
113/// Route one stdout line to its waiting request, by id. Unparseable lines and
114/// notifications (no id) are ignored; an id with no waiter (a late response from
115/// a request whose caller already gave up) is discarded — the stream stays
116/// synchronized either way. Pure over `pending`, so the demux is unit-testable
117/// without spawning a subprocess.
118fn route_line(line: &str, pending: &StdMutex<HashMap<u64, oneshot::Sender<McpResponse>>>) {
119    let resp: McpResponse = match serde_json::from_str(line) {
120        Ok(r) => r,
121        Err(_) => return, // notification / log noise / partial line — ignore
122    };
123    if let Some(id) = resp.id {
124        if let Some(tx) = pending.lock().unwrap().remove(&id) {
125            // Receiver may be gone (caller timed out / was cancelled) — fine.
126            let _ = tx.send(resp);
127        }
128        // Unknown id → orphaned/duplicate response; discard.
129    }
130}
131
132/// Background task: read newline-delimited responses and route each by id until
133/// the stream closes. On exit, mark the connection dead and drop all waiters
134/// (their `recv()` then errors out).
135async fn reader_loop<R: AsyncBufRead + Unpin>(
136    mut stdout: R,
137    pending: Pending,
138    alive: Arc<AtomicBool>,
139    server_name: String,
140) {
141    let mut line = String::new();
142    loop {
143        line.clear();
144        match stdout.read_line(&mut line).await {
145            Ok(0) | Err(_) => break, // EOF or read error → connection dead
146            Ok(_) => route_line(&line, &pending),
147        }
148    }
149    alive.store(false, Ordering::SeqCst);
150    pending.lock().unwrap().clear(); // dropping senders wakes waiters with an error
151    tracing::debug!(server = %server_name, "MCP reader exited; connection closed");
152}
153
154/// Drain the child's stderr to the log so a verbose server can't fill the pipe
155/// buffer and deadlock its own stdout writes.
156async fn stderr_drain_loop(stderr: ChildStderr, server_name: String) {
157    let mut lines = BufReader::new(stderr).lines();
158    while let Ok(Some(line)) = lines.next_line().await {
159        tracing::debug!(server = %server_name, "mcp stderr: {line}");
160    }
161}
162
163impl McpServer {
164    /// Start an MCP server and initialize the connection.
165    pub async fn start(config: McpServerConfig) -> Result<Self, String> {
166        let mut cmd = Command::new(&config.command);
167        cmd.args(&config.args)
168            .stdin(std::process::Stdio::piped())
169            .stdout(std::process::Stdio::piped())
170            .stderr(std::process::Stdio::piped());
171
172        if let Some(ref cwd) = config.cwd {
173            cmd.current_dir(cwd);
174        }
175        for (k, v) in &config.env {
176            cmd.env(k, v);
177        }
178
179        let mut child = cmd
180            .spawn()
181            .map_err(|e| format!("failed to start MCP server '{}': {}", config.name, e))?;
182
183        let stdin = child
184            .stdin
185            .take()
186            .ok_or_else(|| "MCP server has no stdin".to_string())?;
187        let stdout = child
188            .stdout
189            .take()
190            .ok_or_else(|| "MCP server has no stdout".to_string())?;
191        let stderr = child
192            .stderr
193            .take()
194            .ok_or_else(|| "MCP server has no stderr".to_string())?;
195
196        let pending: Pending = Arc::new(StdMutex::new(HashMap::new()));
197        let alive = Arc::new(AtomicBool::new(true));
198        let reader = tokio::spawn(reader_loop(
199            BufReader::new(stdout),
200            Arc::clone(&pending),
201            Arc::clone(&alive),
202            config.name.clone(),
203        ));
204        let stderr_reader = tokio::spawn(stderr_drain_loop(stderr, config.name.clone()));
205
206        let mut server = Self {
207            config,
208            child,
209            stdin: tokio::io::BufWriter::new(stdin),
210            next_id: 1,
211            pending,
212            reader,
213            stderr_reader,
214            alive,
215            request_timeout: Duration::from_secs(120),
216        };
217
218        // Send initialize
219        server
220            .send_request(
221                "initialize",
222                Some(serde_json::json!({
223                    "protocolVersion": "2024-11-05",
224                    "capabilities": {},
225                    "clientInfo": {
226                        "name": "car-runtime",
227                        "version": env!("CARGO_PKG_VERSION")
228                    }
229                })),
230            )
231            .await?;
232
233        // Send initialized notification (no id, per MCP spec)
234        let notification = serde_json::json!({
235            "jsonrpc": "2.0",
236            "method": "notifications/initialized"
237        });
238        let msg =
239            serde_json::to_string(&notification).map_err(|e| format!("serialize error: {e}"))?;
240        server.write_message(&msg).await?;
241
242        Ok(server)
243    }
244
245    /// Respawn the child and reader, replacing a dead session in place.
246    async fn reconnect(&mut self) -> Result<(), String> {
247        tracing::warn!(server = %self.config.name, "MCP connection closed; reconnecting (server-side state is lost)");
248        self.reader.abort(); // stop the old reader + stderr drain tasks
249        self.stderr_reader.abort();
250        let _ = self.child.kill().await;
251        // Box the recursive future to break the *type-size* recursion: start()
252        // handshakes via send_request. This cannot loop at runtime — start()
253        // builds a fresh, alive server, so its handshake never re-enters
254        // reconnect; a failing handshake propagates Err instead.
255        let fresh = Box::pin(Self::start(self.config.clone())).await?;
256        *self = fresh;
257        Ok(())
258    }
259
260    /// Write one newline-delimited JSON message to the server.
261    async fn write_message(&mut self, msg: &str) -> Result<(), String> {
262        self.stdin
263            .write_all(msg.as_bytes())
264            .await
265            .map_err(|e| format!("write to MCP server: {e}"))?;
266        self.stdin
267            .write_all(b"\n")
268            .await
269            .map_err(|e| format!("write newline: {e}"))?;
270        self.stdin.flush().await.map_err(|e| format!("flush: {e}"))?;
271        Ok(())
272    }
273
274    async fn send_request(&mut self, method: &str, params: Option<Value>) -> Result<Value, String> {
275        // Reconnect if the reader has exited (connection died).
276        if !self.alive.load(Ordering::SeqCst) {
277            self.reconnect().await.map_err(|e| {
278                format!(
279                    "MCP session '{}' is dead and reconnect failed: {e}",
280                    self.config.name
281                )
282            })?;
283        }
284
285        let id = self.next_id;
286        self.next_id += 1;
287
288        // Register our waiter BEFORE writing, so a fast response can't arrive
289        // before the reader knows where to route it.
290        let (tx, rx) = oneshot::channel();
291        self.pending.lock().unwrap().insert(id, tx);
292
293        let req = McpRequest {
294            jsonrpc: "2.0",
295            method: method.to_string(),
296            params,
297            id,
298        };
299        let msg = serde_json::to_string(&req).map_err(|e| format!("serialize error: {e}"))?;
300
301        if let Err(e) = self.write_message(&msg).await {
302            self.pending.lock().unwrap().remove(&id);
303            self.alive.store(false, Ordering::SeqCst); // broken pipe → dead
304            return Err(e);
305        }
306
307        // Await the channel — NOT a read. A timeout (or upstream cancellation
308        // dropping this future) just drops the receiver; the reader still
309        // consumes and discards the eventual response, so the stream never
310        // desyncs. We only clean up our pending entry on timeout.
311        let resp = match tokio::time::timeout(self.request_timeout, rx).await {
312            Ok(Ok(resp)) => resp,
313            Ok(Err(_)) => {
314                return Err(format!(
315                    "MCP server '{}' closed the connection",
316                    self.config.name
317                ))
318            }
319            Err(_) => {
320                self.pending.lock().unwrap().remove(&id);
321                return Err(format!("MCP request '{method}' timed out"));
322            }
323        };
324
325        if let Some(err) = resp.error {
326            return Err(format!("MCP error: {}", err.message));
327        }
328        resp.result
329            .ok_or_else(|| "MCP server returned no result".to_string())
330    }
331
332    /// Discover tools from this MCP server.
333    pub async fn list_tools(&mut self) -> Result<Vec<McpToolInfo>, String> {
334        let result = self.send_request("tools/list", None).await?;
335        let tools = result
336            .get("tools")
337            .and_then(|t| t.as_array())
338            .cloned()
339            .unwrap_or_default();
340
341        tools
342            .into_iter()
343            .map(|t| serde_json::from_value(t).map_err(|e| format!("invalid tool definition: {e}")))
344            .collect()
345    }
346
347    /// Call a tool on this MCP server.
348    pub async fn call_tool(&mut self, name: &str, arguments: Value) -> Result<Value, String> {
349        let result = self
350            .send_request(
351                "tools/call",
352                Some(serde_json::json!({
353                    "name": name,
354                    "arguments": arguments,
355                })),
356            )
357            .await?;
358
359        // Extract text content from MCP response format
360        if let Some(content) = result.get("content").and_then(|c| c.as_array()) {
361            let texts: Vec<&str> = content
362                .iter()
363                .filter_map(|block| {
364                    if block.get("type").and_then(|t| t.as_str()) == Some("text") {
365                        block.get("text").and_then(|t| t.as_str())
366                    } else {
367                        None
368                    }
369                })
370                .collect();
371            if !texts.is_empty() {
372                return Ok(Value::String(texts.join("\n")));
373            }
374        }
375
376        Ok(result)
377    }
378
379    /// Shut down the MCP server gracefully.
380    pub async fn shutdown(mut self) {
381        let _ = self.stdin.shutdown().await;
382        let _ = self.child.kill().await;
383        let _ = self.child.wait().await;
384    }
385
386    /// Get the server name.
387    pub fn name(&self) -> &str {
388        &self.config.name
389    }
390}
391
392/// MCP tool executor -- routes tool calls to the appropriate MCP server.
393pub struct McpToolExecutor {
394    servers: Arc<Mutex<HashMap<String, Arc<Mutex<McpServer>>>>>,
395    /// Maps tool_name -> server_name for routing.
396    tool_routes: Arc<Mutex<HashMap<String, String>>>,
397    /// Optional fallback for non-MCP tools.
398    fallback: Option<Arc<dyn super::ToolExecutor>>,
399}
400
401impl McpToolExecutor {
402    pub fn new() -> Self {
403        Self {
404            servers: Arc::new(Mutex::new(HashMap::new())),
405            tool_routes: Arc::new(Mutex::new(HashMap::new())),
406            fallback: None,
407        }
408    }
409
410    pub fn with_fallback(mut self, fallback: Arc<dyn super::ToolExecutor>) -> Self {
411        self.fallback = Some(fallback);
412        self
413    }
414
415    /// Add an MCP server and discover its tools.
416    /// Returns the list of discovered tool names (canonical form: `mcp_{server}_{tool}`).
417    pub async fn add_server(&self, mut server: McpServer) -> Result<Vec<String>, String> {
418        let server_name = server.config.name.clone();
419        let tools = server.list_tools().await?;
420
421        let tool_names: Vec<String> = tools
422            .iter()
423            .map(|t| format!("mcp_{}_{}", server_name, t.name))
424            .collect();
425
426        // Register tool routes
427        {
428            let mut routes = self.tool_routes.lock().await;
429            for (info, canonical_name) in tools.iter().zip(tool_names.iter()) {
430                routes.insert(canonical_name.clone(), server_name.clone());
431                // Also register the bare name for convenience
432                routes.insert(info.name.clone(), server_name.clone());
433            }
434        }
435
436        // Store server
437        self.servers
438            .lock()
439            .await
440            .insert(server_name, Arc::new(Mutex::new(server)));
441
442        Ok(tool_names)
443    }
444
445    /// Get tool schemas from all connected MCP servers.
446    pub async fn tool_schemas(&self) -> Vec<(String, car_ir::ToolSchema)> {
447        let mut schemas = Vec::new();
448        let servers = self.servers.lock().await;
449        for (server_name, server) in servers.iter() {
450            let mut srv = server.lock().await;
451            if let Ok(tools) = srv.list_tools().await {
452                for tool in tools {
453                    let canonical_name = format!("mcp_{}_{}", server_name, tool.name);
454                    schemas.push((
455                        server_name.clone(),
456                        car_ir::ToolSchema {
457                            name: canonical_name,
458                            description: tool.description.unwrap_or_default(),
459                            parameters: tool
460                                .input_schema
461                                .unwrap_or(serde_json::json!({"type": "object"})),
462                            returns: None,
463                            idempotent: false,
464                            cache_ttl_secs: None,
465                            rate_limit: None,
466                        },
467                    ));
468                }
469            }
470        }
471        schemas
472    }
473
474    /// Shut down all MCP servers.
475    pub async fn shutdown_all(&self) {
476        let mut servers = self.servers.lock().await;
477        // Dropping the Arc<Mutex<McpServer>> will drop the Child, killing the process.
478        servers.drain();
479    }
480}
481
482impl Default for McpToolExecutor {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
488#[async_trait::async_trait]
489impl super::ToolExecutor for McpToolExecutor {
490    async fn execute(&self, tool: &str, params: &Value) -> Result<Value, String> {
491        self.execute_with_action(tool, params, "", None).await
492    }
493
494    async fn execute_with_action(
495        &self,
496        tool: &str,
497        params: &Value,
498        action_id: &str,
499        timeout_ms: Option<u64>,
500    ) -> Result<Value, String> {
501        // Find which server handles this tool
502        let server_name = {
503            let routes = self.tool_routes.lock().await;
504            routes.get(tool).cloned()
505        };
506
507        if let Some(server_name) = server_name {
508            let servers = self.servers.lock().await;
509            if let Some(server) = servers.get(&server_name) {
510                let mut srv = server.lock().await;
511                // Strip the mcp_{server}_ prefix to get the bare tool name
512                let bare_name = tool
513                    .strip_prefix(&format!("mcp_{}_", server_name))
514                    .unwrap_or(tool);
515                return srv.call_tool(bare_name, params.clone()).await;
516            }
517        }
518
519        // Fallback
520        if let Some(ref fallback) = self.fallback {
521            return fallback
522                .execute_with_action(tool, params, action_id, timeout_ms)
523                .await;
524        }
525
526        Err(format!("unknown MCP tool: '{}'", tool))
527    }
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// The un-shared form of the waiter table that `route_line` operates on.
    type WaiterMap = StdMutex<HashMap<u64, oneshot::Sender<McpResponse>>>;

    /// Fresh, empty waiter table.
    fn pending() -> WaiterMap {
        StdMutex::new(HashMap::new())
    }

    /// Register a waiter under `id` and hand back its receiving end.
    fn expect_reply(map: &WaiterMap, id: u64) -> oneshot::Receiver<McpResponse> {
        let (tx, rx) = oneshot::channel();
        map.lock().unwrap().insert(id, tx);
        rx
    }

    #[tokio::test]
    async fn routes_response_to_matching_waiter() {
        let table = pending();
        let reply = expect_reply(&table, 7);

        route_line(r#"{"jsonrpc":"2.0","id":7,"result":{"value":42}}"#, &table);

        let delivered = reply.await.expect("waiter delivered");
        assert!(delivered.result.is_some());
        // Routing consumes the entry.
        assert!(table.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn unknown_id_is_discarded_without_disturbing_other_waiters() {
        let table = pending();
        let _still_waiting = expect_reply(&table, 1);

        // A late/orphaned response for an id nobody is waiting on.
        route_line(r#"{"jsonrpc":"2.0","id":999,"result":{}}"#, &table);

        // The id-1 waiter is untouched — stream stays synchronized.
        assert!(table.lock().unwrap().contains_key(&1));
    }

    #[test]
    fn notifications_and_garbage_are_ignored() {
        let table = pending();
        let noise = [
            r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{}}"#,
            "not json at all",
        ];
        // Neither an id-less notification nor unparseable text may panic or route.
        for line in noise {
            route_line(line, &table);
        }
        assert!(table.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn error_response_is_routed_for_send_request_to_surface() {
        let table = pending();
        let reply = expect_reply(&table, 3);

        route_line(
            r#"{"jsonrpc":"2.0","id":3,"error":{"code":-1,"message":"tool failed"}}"#,
            &table,
        );

        let delivered = reply.await.unwrap();
        assert!(delivered.error.is_some());
        assert_eq!(delivered.error.unwrap().message, "tool failed");
    }

    #[tokio::test]
    async fn reader_loop_routes_then_marks_dead_and_clears_on_eof() {
        let table: Pending = Arc::new(StdMutex::new(HashMap::new()));
        let alive = Arc::new(AtomicBool::new(true));
        let answered = expect_reply(&table, 1);
        // A second waiter that never gets a response — must be swept on EOF.
        let unanswered = expect_reply(&table, 2);

        let input: &[u8] = b"{\"jsonrpc\":\"2.0\",\"id\":1,\"result\":{\"ok\":true}}\n";
        reader_loop(
            BufReader::new(input),
            Arc::clone(&table),
            Arc::clone(&alive),
            "t".into(),
        )
        .await;

        assert!(answered.await.unwrap().result.is_some(), "id 1 routed");
        assert!(!alive.load(Ordering::SeqCst), "EOF marks the session dead");
        assert!(table.lock().unwrap().is_empty(), "waiters swept on EOF");
        // The unanswered waiter's receiver now errors (sender dropped).
        assert!(unanswered.await.is_err());
    }

    #[tokio::test]
    async fn reader_loop_skips_noise_without_desync() {
        // Garbage and an id-less notification precede the real reply — the reader
        // must still deliver id 5 (no stream desync). This is the regression the
        // reader-task design fixes vs. the old cancel-unsafe read_line.
        let table: Pending = Arc::new(StdMutex::new(HashMap::new()));
        let alive = Arc::new(AtomicBool::new(true));
        let reply = expect_reply(&table, 5);

        let input: &[u8] = concat!(
            "garbage not json\n",
            r#"{"jsonrpc":"2.0","method":"notifications/progress","params":{}}"#,
            "\n",
            r#"{"jsonrpc":"2.0","id":5,"result":{"done":true}}"#,
            "\n",
        )
        .as_bytes();
        reader_loop(BufReader::new(input), Arc::clone(&table), alive, "t".into()).await;

        assert!(
            reply.await.unwrap().result.is_some(),
            "id 5 delivered past noise"
        );
    }
}