Skip to main content

mur_core/mcp/
client.rs

1//! MCP stdio client — connects to MCP servers via stdin/stdout.
2
3use super::types::*;
4use anyhow::{Context, Result};
5use std::sync::atomic::{AtomicU64, Ordering};
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::process::{Child, Command};
8use tokio::sync::Mutex;
9use tracing::{debug, info, warn};
10
11/// MCP client that communicates with a server process via stdio.
12pub struct McpClient {
13    child: Mutex<Child>,
14    stdin: Mutex<tokio::process::ChildStdin>,
15    stdout: Mutex<BufReader<tokio::process::ChildStdout>>,
16    next_id: AtomicU64,
17    server_name: String,
18    capabilities: Mutex<Option<McpServerCapabilities>>,
19}
20
21impl McpClient {
22    /// Spawn an MCP server process and perform the initialize handshake.
23    pub async fn connect(config: &McpConfig) -> Result<Self> {
24        let mut cmd = Command::new(&config.command);
25        cmd.args(&config.args)
26            .stdin(std::process::Stdio::piped())
27            .stdout(std::process::Stdio::piped())
28            .stderr(std::process::Stdio::piped());
29
30        for (k, v) in &config.env {
31            cmd.env(k, v);
32        }
33
34        if let Some(cwd) = &config.cwd {
35            cmd.current_dir(cwd);
36        }
37
38        let mut child = cmd.spawn().context("Failed to spawn MCP server")?;
39
40        let stdin = child.stdin.take().context("No stdin on MCP server")?;
41        let stdout = child.stdout.take().context("No stdout on MCP server")?;
42
43        let client = Self {
44            child: Mutex::new(child),
45            stdin: Mutex::new(stdin),
46            stdout: Mutex::new(BufReader::new(stdout)),
47            next_id: AtomicU64::new(1),
48            server_name: config.command.clone(),
49            capabilities: Mutex::new(None),
50        };
51
52        // Perform initialize handshake
53        client.initialize().await?;
54
55        Ok(client)
56    }
57
58    /// Send the initialize request and initialized notification.
59    async fn initialize(&self) -> Result<()> {
60        let params = serde_json::json!({
61            "protocolVersion": "2024-11-05",
62            "capabilities": {},
63            "clientInfo": {
64                "name": "mur-commander",
65                "version": env!("CARGO_PKG_VERSION")
66            }
67        });
68
69        let response = self.request("initialize", Some(params)).await?;
70        let init_result: InitializeResult = serde_json::from_value(
71            response.context("Empty initialize response")?
72        ).context("Invalid initialize response")?;
73
74        info!(
75            "MCP server initialized: {} (protocol {})",
76            init_result.server_info.as_ref().map(|s| s.name.as_str()).unwrap_or("unknown"),
77            init_result.protocol_version
78        );
79
80        *self.capabilities.lock().await = Some(init_result.capabilities);
81
82        // Send initialized notification (no id, no response expected)
83        self.notify("notifications/initialized", None).await?;
84
85        Ok(())
86    }
87
88    /// List available tools from the server.
89    pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
90        let response = self.request("tools/list", None).await?;
91        let result: ToolsListResult = serde_json::from_value(
92            response.context("Empty tools/list response")?
93        ).context("Invalid tools/list response")?;
94        Ok(result.tools)
95    }
96
97    /// Call a tool with the given arguments.
98    pub async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<McpToolResult> {
99        let params = serde_json::json!({
100            "name": name,
101            "arguments": arguments
102        });
103
104        let response = self.request("tools/call", Some(params)).await?;
105        let result: McpToolResult = serde_json::from_value(
106            response.context("Empty tools/call response")?
107        ).context("Invalid tools/call response")?;
108
109        Ok(result)
110    }
111
112    /// Get server capabilities.
113    pub async fn capabilities(&self) -> Option<McpServerCapabilities> {
114        self.capabilities.lock().await.clone()
115    }
116
117    /// Send a JSON-RPC request and wait for the response.
118    async fn request(&self, method: &str, params: Option<serde_json::Value>) -> Result<Option<serde_json::Value>> {
119        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
120
121        let request = JsonRpcRequest {
122            jsonrpc: "2.0",
123            id,
124            method: method.to_string(),
125            params,
126        };
127
128        let mut line = serde_json::to_string(&request)?;
129        line.push('\n');
130
131        // NOTE: Only log the method name, never the full request payload,
132        // because params may contain secrets (API keys, tokens, etc.).
133        debug!("MCP -> {}: method={} id={}", self.server_name, method, id);
134
135        {
136            let mut stdin = self.stdin.lock().await;
137            stdin.write_all(line.as_bytes()).await?;
138            stdin.flush().await?;
139        }
140
141        // Read response lines until we get one with our id (with timeout)
142        let mut stdout = self.stdout.lock().await;
143        let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(60);
144        loop {
145            let mut buf = String::new();
146            let read = tokio::time::timeout_at(deadline, stdout.read_line(&mut buf)).await;
147            let n = match read {
148                Ok(result) => result?,
149                Err(_) => anyhow::bail!("MCP request '{}' timed out after 60s", method),
150            };
151            if n == 0 {
152                anyhow::bail!("MCP server closed stdout");
153            }
154
155            let buf = buf.trim();
156            if buf.is_empty() {
157                continue;
158            }
159
160            debug!("MCP ← {}: [response received]", self.server_name);
161
162            let response: JsonRpcResponse = match serde_json::from_str(buf) {
163                Ok(r) => r,
164                Err(_) => {
165                    // Could be a notification, skip
166                    warn!("Skipping non-response line from MCP server");
167                    continue;
168                }
169            };
170
171            // Check if this is our response
172            if response.id == Some(id) {
173                if let Some(error) = response.error {
174                    anyhow::bail!("MCP error ({}): {}", error.code, error.message);
175                }
176                return Ok(response.result);
177            }
178
179            // Not our response, skip (could be notification or out-of-order)
180        }
181    }
182
183    /// Send a JSON-RPC notification (no id, no response).
184    async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
185        let notification = serde_json::json!({
186            "jsonrpc": "2.0",
187            "method": method,
188            "params": params.unwrap_or(serde_json::json!({}))
189        });
190
191        let mut line = serde_json::to_string(&notification)?;
192        line.push('\n');
193
194        let mut stdin = self.stdin.lock().await;
195        stdin.write_all(line.as_bytes()).await?;
196        stdin.flush().await?;
197
198        Ok(())
199    }
200
201    /// Shutdown the MCP server gracefully.
202    pub async fn shutdown(&self) -> Result<()> {
203        // Try to send shutdown request
204        let _ = self.request("shutdown", None).await;
205        let _ = self.notify("exit", None).await;
206
207        // Give it a moment, then kill
208        let mut child = self.child.lock().await;
209        tokio::select! {
210            _ = child.wait() => {},
211            _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
212                let _ = child.kill().await;
213            }
214        }
215
216        Ok(())
217    }
218}
219
220impl Drop for McpClient {
221    fn drop(&mut self) {
222        // Best-effort kill on drop
223        if let Ok(mut child) = self.child.try_lock() {
224            let _ = child.start_kill();
225        }
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_rpc_request_serialization() {
        let req = JsonRpcRequest {
            jsonrpc: "2.0",
            id: 1,
            method: "tools/list".to_string(),
            params: None,
        };
        // Compare the parsed structure instead of substrings. This checks every
        // field (including `id`) and confirms `params` is omitted when `None`
        // (skip_serializing_if).
        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(
            value,
            serde_json::json!({"jsonrpc": "2.0", "id": 1, "method": "tools/list"})
        );
    }

    #[test]
    fn test_json_rpc_request_serialization_with_params() {
        let req = JsonRpcRequest {
            jsonrpc: "2.0",
            id: 7,
            method: "tools/call".to_string(),
            params: Some(serde_json::json!({"name": "read_file"})),
        };
        let value = serde_json::to_value(&req).unwrap();
        assert_eq!(value["id"], 7);
        assert_eq!(value["method"], "tools/call");
        assert_eq!(value["params"], serde_json::json!({"name": "read_file"}));
    }

    #[test]
    fn test_mcp_tool_deserialization() {
        let json = r#"{
            "name": "read_file",
            "description": "Read a file",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "path": {"type": "string"}
                }
            }
        }"#;
        let tool: McpTool = serde_json::from_str(json).unwrap();
        assert_eq!(tool.name, "read_file");
        assert_eq!(tool.description.as_deref(), Some("Read a file"));
    }

    #[test]
    fn test_mcp_content_text() {
        let json = r#"{"type": "text", "text": "hello"}"#;
        let content: McpContent = serde_json::from_str(json).unwrap();
        match content {
            McpContent::Text { text } => assert_eq!(text, "hello"),
            _ => panic!("Expected text content"),
        }
    }

    #[test]
    fn test_mcp_tool_result_deserialization() {
        let json = r#"{
            "content": [{"type": "text", "text": "result"}],
            "isError": false
        }"#;
        let result: McpToolResult = serde_json::from_str(json).unwrap();
        assert_eq!(result.content.len(), 1);
        assert!(!result.is_error);
    }
}