Skip to main content

punch_runtime/
mcp.rs

1//! MCP (Model Context Protocol) client.
2//!
3//! Manages connections to MCP servers via stdio transport, providing
4//! tool discovery and invocation over JSON-RPC 2.0.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, oneshot};
13use tracing::{debug, info, warn};
14
15use punch_types::{PunchError, PunchResult, ToolCategory, ToolDefinition};
16
/// A client connection to a single MCP server.
///
/// The server runs as a child process; requests are written to its stdin and
/// responses are read from its stdout as newline-delimited JSON-RPC 2.0
/// messages.
pub struct McpClient {
    /// Name of this MCP server (used for tool namespacing).
    server_name: String,
    /// The child process handle. `None` when no process is attached.
    child: Mutex<Option<Child>>,
    /// The child's stdin handle, used for writing requests. `shutdown` sets
    /// this to `None` so the child sees EOF.
    stdin_tx: Mutex<Option<tokio::process::ChildStdin>>,
    /// Pending requests awaiting responses, keyed by JSON-RPC id. Shared with
    /// the background task that reads the child's stdout.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
    /// Monotonic request ID counter (`spawn` starts it at 1).
    next_id: AtomicU64,
    /// The response to the `initialize` request, stored by `initialize`;
    /// `None` until then.
    server_info: Mutex<Option<serde_json::Value>>,
}
32
33impl McpClient {
34    /// Spawn an MCP server subprocess and prepare the client.
35    ///
36    /// Does NOT send the `initialize` request yet -- call [`initialize`] after.
37    pub async fn spawn(
38        server_name: String,
39        command: &str,
40        args: &[String],
41        env: &HashMap<String, String>,
42    ) -> PunchResult<Self> {
43        let mut cmd = Command::new(command);
44        cmd.args(args)
45            .envs(env)
46            .stdin(std::process::Stdio::piped())
47            .stdout(std::process::Stdio::piped())
48            .stderr(std::process::Stdio::piped());
49
50        let mut child = cmd.spawn().map_err(|e| PunchError::Mcp {
51            server: server_name.clone(),
52            message: format!("failed to spawn: {}", e),
53        })?;
54
55        let stdout = child.stdout.take().ok_or_else(|| PunchError::Mcp {
56            server: server_name.clone(),
57            message: "failed to capture stdout".into(),
58        })?;
59        let stdin = child.stdin.take().ok_or_else(|| PunchError::Mcp {
60            server: server_name.clone(),
61            message: "failed to capture stdin".into(),
62        })?;
63
64        let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
65            Arc::new(Mutex::new(HashMap::new()));
66
67        // Spawn a reader task to route responses to pending requests.
68        let pending_clone = Arc::clone(&pending);
69        let name_clone = server_name.clone();
70        tokio::spawn(async move {
71            let reader = BufReader::new(stdout);
72            let mut lines = reader.lines();
73
74            while let Ok(Some(line)) = lines.next_line().await {
75                let line = line.trim().to_string();
76                if line.is_empty() {
77                    continue;
78                }
79
80                match serde_json::from_str::<serde_json::Value>(&line) {
81                    Ok(msg) => {
82                        if let Some(id) = msg.get("id").and_then(|v| v.as_u64()) {
83                            let mut pending = pending_clone.lock().await;
84                            if let Some(tx) = pending.remove(&id) {
85                                let _ = tx.send(msg);
86                            }
87                        } else {
88                            // Notification from server (no id) -- log and discard.
89                            debug!(server = %name_clone, "mcp notification: {}", line);
90                        }
91                    }
92                    Err(e) => {
93                        warn!(server = %name_clone, "failed to parse MCP message: {}", e);
94                    }
95                }
96            }
97
98            debug!(server = %name_clone, "MCP stdout reader exited");
99        });
100
101        info!(server = %server_name, command = command, "MCP server spawned");
102
103        Ok(Self {
104            server_name,
105            child: Mutex::new(Some(child)),
106            stdin_tx: Mutex::new(Some(stdin)),
107            pending,
108            next_id: AtomicU64::new(1),
109            server_info: Mutex::new(None),
110        })
111    }
112
113    /// Send the JSON-RPC `initialize` handshake to the MCP server.
114    pub async fn initialize(&self) -> PunchResult<()> {
115        let params = serde_json::json!({
116            "protocolVersion": "2024-11-05",
117            "capabilities": {},
118            "clientInfo": {
119                "name": "punch-runtime",
120                "version": env!("CARGO_PKG_VERSION"),
121            }
122        });
123
124        let response = self.send_request("initialize", params).await?;
125
126        // Store the server info for later reference.
127        *self.server_info.lock().await = Some(response.clone());
128
129        // Send the `initialized` notification (no id, no response expected).
130        self.send_notification("notifications/initialized", serde_json::json!({}))
131            .await?;
132
133        info!(server = %self.server_name, "MCP server initialized");
134        Ok(())
135    }
136
137    /// Discover tools exposed by this MCP server.
138    ///
139    /// Tool names are namespaced as `mcp_{server_name}_{tool_name}`.
140    pub async fn list_tools(&self) -> PunchResult<Vec<ToolDefinition>> {
141        let response = self
142            .send_request("tools/list", serde_json::json!({}))
143            .await?;
144
145        let result = response.get("result").ok_or_else(|| PunchError::Mcp {
146            server: self.server_name.clone(),
147            message: "missing 'result' in tools/list response".into(),
148        })?;
149
150        let tools_array = result
151            .get("tools")
152            .and_then(|t| t.as_array())
153            .ok_or_else(|| PunchError::Mcp {
154                server: self.server_name.clone(),
155                message: "missing 'tools' array in response".into(),
156            })?;
157
158        let mut tools = Vec::new();
159        for tool in tools_array {
160            let raw_name = tool["name"].as_str().unwrap_or("unknown");
161            let namespaced = format!("mcp_{}_{}", self.server_name, raw_name);
162
163            tools.push(ToolDefinition {
164                name: namespaced,
165                description: tool["description"].as_str().unwrap_or("").to_string(),
166                input_schema: tool
167                    .get("inputSchema")
168                    .cloned()
169                    .unwrap_or(serde_json::json!({"type": "object"})),
170                category: ToolCategory::Agent,
171            });
172        }
173
174        debug!(
175            server = %self.server_name,
176            count = tools.len(),
177            "discovered MCP tools"
178        );
179
180        Ok(tools)
181    }
182
183    /// Call a tool on the MCP server.
184    ///
185    /// The `name` should be the raw tool name (without the mcp_ prefix).
186    pub async fn call_tool(
187        &self,
188        name: &str,
189        input: serde_json::Value,
190    ) -> PunchResult<serde_json::Value> {
191        let params = serde_json::json!({
192            "name": name,
193            "arguments": input,
194        });
195
196        let response = self.send_request("tools/call", params).await?;
197
198        let result = response.get("result").cloned().ok_or_else(|| {
199            // Check for error.
200            let error_msg = response["error"]["message"]
201                .as_str()
202                .unwrap_or("unknown error");
203            PunchError::Mcp {
204                server: self.server_name.clone(),
205                message: format!("tool call '{}' failed: {}", name, error_msg),
206            }
207        })?;
208
209        Ok(result)
210    }
211
212    /// Send a JSON-RPC 2.0 request and wait for the response.
213    async fn send_request(
214        &self,
215        method: &str,
216        params: serde_json::Value,
217    ) -> PunchResult<serde_json::Value> {
218        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
219
220        let request = serde_json::json!({
221            "jsonrpc": "2.0",
222            "id": id,
223            "method": method,
224            "params": params,
225        });
226
227        let (tx, rx) = oneshot::channel();
228        {
229            let mut pending = self.pending.lock().await;
230            pending.insert(id, tx);
231        }
232
233        self.write_message(&request).await?;
234
235        let response = tokio::time::timeout(std::time::Duration::from_secs(30), rx)
236            .await
237            .map_err(|_| PunchError::Mcp {
238                server: self.server_name.clone(),
239                message: format!("timeout waiting for response to '{}'", method),
240            })?
241            .map_err(|_| PunchError::Mcp {
242                server: self.server_name.clone(),
243                message: format!("response channel closed for '{}'", method),
244            })?;
245
246        // Check for JSON-RPC error.
247        if let Some(error) = response.get("error") {
248            let code = error["code"].as_i64().unwrap_or(-1);
249            let message = error["message"].as_str().unwrap_or("unknown");
250            return Err(PunchError::Mcp {
251                server: self.server_name.clone(),
252                message: format!("JSON-RPC error {}: {}", code, message),
253            });
254        }
255
256        Ok(response)
257    }
258
259    /// Send a JSON-RPC 2.0 notification (no id, no response expected).
260    async fn send_notification(&self, method: &str, params: serde_json::Value) -> PunchResult<()> {
261        let notification = serde_json::json!({
262            "jsonrpc": "2.0",
263            "method": method,
264            "params": params,
265        });
266
267        self.write_message(&notification).await
268    }
269
270    /// Write a JSON message to the child's stdin, followed by a newline.
271    async fn write_message(&self, msg: &serde_json::Value) -> PunchResult<()> {
272        let serialized = serde_json::to_string(msg).map_err(|e| PunchError::Mcp {
273            server: self.server_name.clone(),
274            message: format!("failed to serialize message: {}", e),
275        })?;
276
277        let mut stdin_guard = self.stdin_tx.lock().await;
278        let stdin = stdin_guard.as_mut().ok_or_else(|| PunchError::Mcp {
279            server: self.server_name.clone(),
280            message: "stdin not available (server may have exited)".into(),
281        })?;
282
283        stdin
284            .write_all(serialized.as_bytes())
285            .await
286            .map_err(|e| PunchError::Mcp {
287                server: self.server_name.clone(),
288                message: format!("failed to write to stdin: {}", e),
289            })?;
290        stdin.write_all(b"\n").await.map_err(|e| PunchError::Mcp {
291            server: self.server_name.clone(),
292            message: format!("failed to write newline: {}", e),
293        })?;
294        stdin.flush().await.map_err(|e| PunchError::Mcp {
295            server: self.server_name.clone(),
296            message: format!("failed to flush stdin: {}", e),
297        })?;
298
299        Ok(())
300    }
301
302    /// Shut down the MCP server process gracefully.
303    pub async fn shutdown(&self) -> PunchResult<()> {
304        // Drop stdin to signal EOF.
305        {
306            let mut stdin = self.stdin_tx.lock().await;
307            *stdin = None;
308        }
309
310        let mut child_guard = self.child.lock().await;
311        if let Some(ref mut child) = *child_guard {
312            match tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await {
313                Ok(Ok(status)) => {
314                    info!(
315                        server = %self.server_name,
316                        exit_code = ?status.code(),
317                        "MCP server exited"
318                    );
319                }
320                Ok(Err(e)) => {
321                    warn!(server = %self.server_name, "error waiting for MCP server: {}", e);
322                }
323                Err(_) => {
324                    warn!(server = %self.server_name, "MCP server did not exit in time, killing");
325                    let _ = child.kill().await;
326                }
327            }
328        }
329
330        Ok(())
331    }
332
333    /// Extract the raw tool name from a namespaced MCP tool name.
334    ///
335    /// E.g., `mcp_github_create_issue` with server_name `github` returns `create_issue`.
336    pub fn strip_namespace<'a>(&self, namespaced_name: &'a str) -> Option<&'a str> {
337        let prefix = format!("mcp_{}_", self.server_name);
338        namespaced_name.strip_prefix(&prefix)
339    }
340
341    /// The server name used for namespacing.
342    pub fn server_name(&self) -> &str {
343        &self.server_name
344    }
345}
346
347// ---------------------------------------------------------------------------
348// Tests
349// ---------------------------------------------------------------------------
350
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a client with the given server name and no attached process.
    fn detached_client(name: &str) -> McpClient {
        McpClient {
            server_name: name.to_string(),
            child: Mutex::new(None),
            stdin_tx: Mutex::new(None),
            pending: Arc::new(Mutex::new(HashMap::new())),
            next_id: AtomicU64::new(1),
            server_info: Mutex::new(None),
        }
    }

    #[test]
    fn test_strip_namespace_basic() {
        let github = detached_client("github");

        let stripped = github.strip_namespace("mcp_github_create_issue");
        assert_eq!(stripped, Some("create_issue"));
    }

    #[test]
    fn test_strip_namespace_no_match() {
        let github = detached_client("github");

        // A tool that belongs to a different server must not match.
        assert_eq!(github.strip_namespace("mcp_slack_send"), None);
    }

    #[test]
    fn test_strip_namespace_exact_prefix() {
        let fs = detached_client("fs");

        assert_eq!(fs.strip_namespace("mcp_fs_read_file"), Some("read_file"));
        // The bare prefix leaves an empty tool name rather than `None`.
        assert_eq!(fs.strip_namespace("mcp_fs_"), Some(""));
    }

    #[test]
    fn test_server_name() {
        let client = detached_client("test-server");

        assert_eq!(client.server_name(), "test-server");
    }

    #[test]
    fn test_next_id_atomic() {
        let client = detached_client("test");

        // `fetch_add` returns the previous value, so ids start at 1.
        let first = client.next_id.fetch_add(1, Ordering::Relaxed);
        let second = client.next_id.fetch_add(1, Ordering::Relaxed);
        assert_eq!(first, 1);
        assert_eq!(second, 2);
    }
}