1use std::collections::HashMap;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9
10use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
11use tokio::process::{Child, Command};
12use tokio::sync::{Mutex, oneshot};
13use tracing::{debug, info, warn};
14
15use punch_types::{PunchError, PunchResult, ToolCategory, ToolDefinition};
16
/// Client for a single MCP server that runs as a child process and speaks
/// JSON-RPC 2.0 (one JSON message per line) over the child's stdin/stdout.
///
/// NOTE: fields are dropped in declaration order; keep that in mind before
/// reordering them (e.g. `child` vs. `stdin_tx`, or the senders in `pending`).
pub struct McpClient {
    /// Server name used in log fields, in `PunchError::Mcp` values, and as the
    /// tool namespace (`mcp_{server_name}_{tool}`).
    server_name: String,
    /// Handle to the spawned server process; `None` when no process is attached
    /// (the unit tests construct clients this way).
    child: Mutex<Option<Child>>,
    /// Write end of the server's stdin. `shutdown` sets this to `None` to close
    /// the pipe; writing while it is `None` yields an error.
    stdin_tx: Mutex<Option<tokio::process::ChildStdin>>,
    /// In-flight requests keyed by JSON-RPC id. Shared (via `Arc`) with the
    /// stdout reader task, which removes an entry and sends the matching response.
    pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>>,
    /// Next JSON-RPC request id to hand out (starts at 1, incremented atomically).
    next_id: AtomicU64,
    /// Full JSON-RPC envelope returned by the `initialize` request, once it succeeds.
    server_info: Mutex<Option<serde_json::Value>>,
}
32
33impl McpClient {
34 pub async fn spawn(
38 server_name: String,
39 command: &str,
40 args: &[String],
41 env: &HashMap<String, String>,
42 ) -> PunchResult<Self> {
43 let mut cmd = Command::new(command);
44 cmd.args(args)
45 .envs(env)
46 .stdin(std::process::Stdio::piped())
47 .stdout(std::process::Stdio::piped())
48 .stderr(std::process::Stdio::piped());
49
50 let mut child = cmd.spawn().map_err(|e| PunchError::Mcp {
51 server: server_name.clone(),
52 message: format!("failed to spawn: {}", e),
53 })?;
54
55 let stdout = child.stdout.take().ok_or_else(|| PunchError::Mcp {
56 server: server_name.clone(),
57 message: "failed to capture stdout".into(),
58 })?;
59 let stdin = child.stdin.take().ok_or_else(|| PunchError::Mcp {
60 server: server_name.clone(),
61 message: "failed to capture stdin".into(),
62 })?;
63
64 let pending: Arc<Mutex<HashMap<u64, oneshot::Sender<serde_json::Value>>>> =
65 Arc::new(Mutex::new(HashMap::new()));
66
67 let pending_clone = Arc::clone(&pending);
69 let name_clone = server_name.clone();
70 tokio::spawn(async move {
71 let reader = BufReader::new(stdout);
72 let mut lines = reader.lines();
73
74 while let Ok(Some(line)) = lines.next_line().await {
75 let line = line.trim().to_string();
76 if line.is_empty() {
77 continue;
78 }
79
80 match serde_json::from_str::<serde_json::Value>(&line) {
81 Ok(msg) => {
82 if let Some(id) = msg.get("id").and_then(|v| v.as_u64()) {
83 let mut pending = pending_clone.lock().await;
84 if let Some(tx) = pending.remove(&id) {
85 let _ = tx.send(msg);
86 }
87 } else {
88 debug!(server = %name_clone, "mcp notification: {}", line);
90 }
91 }
92 Err(e) => {
93 warn!(server = %name_clone, "failed to parse MCP message: {}", e);
94 }
95 }
96 }
97
98 debug!(server = %name_clone, "MCP stdout reader exited");
99 });
100
101 info!(server = %server_name, command = command, "MCP server spawned");
102
103 Ok(Self {
104 server_name,
105 child: Mutex::new(Some(child)),
106 stdin_tx: Mutex::new(Some(stdin)),
107 pending,
108 next_id: AtomicU64::new(1),
109 server_info: Mutex::new(None),
110 })
111 }
112
113 pub async fn initialize(&self) -> PunchResult<()> {
115 let params = serde_json::json!({
116 "protocolVersion": "2024-11-05",
117 "capabilities": {},
118 "clientInfo": {
119 "name": "punch-runtime",
120 "version": env!("CARGO_PKG_VERSION"),
121 }
122 });
123
124 let response = self.send_request("initialize", params).await?;
125
126 *self.server_info.lock().await = Some(response.clone());
128
129 self.send_notification("notifications/initialized", serde_json::json!({}))
131 .await?;
132
133 info!(server = %self.server_name, "MCP server initialized");
134 Ok(())
135 }
136
137 pub async fn list_tools(&self) -> PunchResult<Vec<ToolDefinition>> {
141 let response = self
142 .send_request("tools/list", serde_json::json!({}))
143 .await?;
144
145 let result = response.get("result").ok_or_else(|| PunchError::Mcp {
146 server: self.server_name.clone(),
147 message: "missing 'result' in tools/list response".into(),
148 })?;
149
150 let tools_array = result
151 .get("tools")
152 .and_then(|t| t.as_array())
153 .ok_or_else(|| PunchError::Mcp {
154 server: self.server_name.clone(),
155 message: "missing 'tools' array in response".into(),
156 })?;
157
158 let mut tools = Vec::new();
159 for tool in tools_array {
160 let raw_name = tool["name"].as_str().unwrap_or("unknown");
161 let namespaced = format!("mcp_{}_{}", self.server_name, raw_name);
162
163 tools.push(ToolDefinition {
164 name: namespaced,
165 description: tool["description"].as_str().unwrap_or("").to_string(),
166 input_schema: tool
167 .get("inputSchema")
168 .cloned()
169 .unwrap_or(serde_json::json!({"type": "object"})),
170 category: ToolCategory::Agent,
171 });
172 }
173
174 debug!(
175 server = %self.server_name,
176 count = tools.len(),
177 "discovered MCP tools"
178 );
179
180 Ok(tools)
181 }
182
183 pub async fn call_tool(
187 &self,
188 name: &str,
189 input: serde_json::Value,
190 ) -> PunchResult<serde_json::Value> {
191 let params = serde_json::json!({
192 "name": name,
193 "arguments": input,
194 });
195
196 let response = self.send_request("tools/call", params).await?;
197
198 let result = response.get("result").cloned().ok_or_else(|| {
199 let error_msg = response["error"]["message"]
201 .as_str()
202 .unwrap_or("unknown error");
203 PunchError::Mcp {
204 server: self.server_name.clone(),
205 message: format!("tool call '{}' failed: {}", name, error_msg),
206 }
207 })?;
208
209 Ok(result)
210 }
211
212 async fn send_request(
214 &self,
215 method: &str,
216 params: serde_json::Value,
217 ) -> PunchResult<serde_json::Value> {
218 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
219
220 let request = serde_json::json!({
221 "jsonrpc": "2.0",
222 "id": id,
223 "method": method,
224 "params": params,
225 });
226
227 let (tx, rx) = oneshot::channel();
228 {
229 let mut pending = self.pending.lock().await;
230 pending.insert(id, tx);
231 }
232
233 self.write_message(&request).await?;
234
235 let response = tokio::time::timeout(std::time::Duration::from_secs(30), rx)
236 .await
237 .map_err(|_| PunchError::Mcp {
238 server: self.server_name.clone(),
239 message: format!("timeout waiting for response to '{}'", method),
240 })?
241 .map_err(|_| PunchError::Mcp {
242 server: self.server_name.clone(),
243 message: format!("response channel closed for '{}'", method),
244 })?;
245
246 if let Some(error) = response.get("error") {
248 let code = error["code"].as_i64().unwrap_or(-1);
249 let message = error["message"].as_str().unwrap_or("unknown");
250 return Err(PunchError::Mcp {
251 server: self.server_name.clone(),
252 message: format!("JSON-RPC error {}: {}", code, message),
253 });
254 }
255
256 Ok(response)
257 }
258
259 async fn send_notification(&self, method: &str, params: serde_json::Value) -> PunchResult<()> {
261 let notification = serde_json::json!({
262 "jsonrpc": "2.0",
263 "method": method,
264 "params": params,
265 });
266
267 self.write_message(¬ification).await
268 }
269
270 async fn write_message(&self, msg: &serde_json::Value) -> PunchResult<()> {
272 let serialized = serde_json::to_string(msg).map_err(|e| PunchError::Mcp {
273 server: self.server_name.clone(),
274 message: format!("failed to serialize message: {}", e),
275 })?;
276
277 let mut stdin_guard = self.stdin_tx.lock().await;
278 let stdin = stdin_guard.as_mut().ok_or_else(|| PunchError::Mcp {
279 server: self.server_name.clone(),
280 message: "stdin not available (server may have exited)".into(),
281 })?;
282
283 stdin
284 .write_all(serialized.as_bytes())
285 .await
286 .map_err(|e| PunchError::Mcp {
287 server: self.server_name.clone(),
288 message: format!("failed to write to stdin: {}", e),
289 })?;
290 stdin.write_all(b"\n").await.map_err(|e| PunchError::Mcp {
291 server: self.server_name.clone(),
292 message: format!("failed to write newline: {}", e),
293 })?;
294 stdin.flush().await.map_err(|e| PunchError::Mcp {
295 server: self.server_name.clone(),
296 message: format!("failed to flush stdin: {}", e),
297 })?;
298
299 Ok(())
300 }
301
302 pub async fn shutdown(&self) -> PunchResult<()> {
304 {
306 let mut stdin = self.stdin_tx.lock().await;
307 *stdin = None;
308 }
309
310 let mut child_guard = self.child.lock().await;
311 if let Some(ref mut child) = *child_guard {
312 match tokio::time::timeout(std::time::Duration::from_secs(5), child.wait()).await {
313 Ok(Ok(status)) => {
314 info!(
315 server = %self.server_name,
316 exit_code = ?status.code(),
317 "MCP server exited"
318 );
319 }
320 Ok(Err(e)) => {
321 warn!(server = %self.server_name, "error waiting for MCP server: {}", e);
322 }
323 Err(_) => {
324 warn!(server = %self.server_name, "MCP server did not exit in time, killing");
325 let _ = child.kill().await;
326 }
327 }
328 }
329
330 Ok(())
331 }
332
333 pub fn strip_namespace<'a>(&self, namespaced_name: &'a str) -> Option<&'a str> {
337 let prefix = format!("mcp_{}_", self.server_name);
338 namespaced_name.strip_prefix(&prefix)
339 }
340
341 pub fn server_name(&self) -> &str {
343 &self.server_name
344 }
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a client with no child process or stdin attached. The tests below
    /// only exercise pure helpers, so nothing is ever spawned or written.
    fn detached_client(server_name: &str) -> McpClient {
        McpClient {
            server_name: server_name.to_string(),
            child: Mutex::new(None),
            stdin_tx: Mutex::new(None),
            pending: Arc::new(Mutex::new(HashMap::new())),
            next_id: AtomicU64::new(1),
            server_info: Mutex::new(None),
        }
    }

    #[test]
    fn test_strip_namespace_basic() {
        let client = detached_client("github");

        assert_eq!(
            client.strip_namespace("mcp_github_create_issue"),
            Some("create_issue")
        );
    }

    #[test]
    fn test_strip_namespace_no_match() {
        let client = detached_client("github");

        assert_eq!(client.strip_namespace("mcp_slack_send"), None);
        // The `_` after the server name is part of the prefix, so a server whose
        // name merely starts with "github" must not match.
        assert_eq!(client.strip_namespace("mcp_githubx_send"), None);
        assert_eq!(client.strip_namespace("github_create_issue"), None);
        assert_eq!(client.strip_namespace(""), None);
    }

    #[test]
    fn test_strip_namespace_exact_prefix() {
        let client = detached_client("fs");

        assert_eq!(
            client.strip_namespace("mcp_fs_read_file"),
            Some("read_file")
        );
        assert_eq!(client.strip_namespace("mcp_fs_"), Some(""));
    }

    #[test]
    fn test_server_name() {
        let client = detached_client("test-server");

        assert_eq!(client.server_name(), "test-server");
    }

    #[test]
    fn test_next_id_atomic() {
        let client = detached_client("test");

        let id1 = client.next_id.fetch_add(1, Ordering::Relaxed);
        let id2 = client.next_id.fetch_add(1, Ordering::Relaxed);
        assert_eq!(id1, 1);
        assert_eq!(id2, 2);
    }
}