1use super::types::*;
4use anyhow::{Context, Result};
5use std::sync::atomic::{AtomicU64, Ordering};
6use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
7use tokio::process::{Child, Command};
8use tokio::sync::Mutex;
9use tracing::{debug, info, warn};
10
/// Client for a single MCP server child process, exchanging newline-delimited
/// JSON-RPC 2.0 messages over the child's stdin/stdout.
///
/// Each pipe and piece of mutable state sits behind its own `tokio::sync::Mutex`
/// so that all methods can take `&self`.
pub struct McpClient {
    /// The spawned server process; locked by `shutdown` (to wait/kill) and by `Drop`.
    child: Mutex<Child>,
    /// Write side: requests and notifications are written here, one JSON object per line.
    stdin: Mutex<tokio::process::ChildStdin>,
    /// Read side, buffered so responses can be consumed line by line.
    stdout: Mutex<BufReader<tokio::process::ChildStdout>>,
    /// Source of JSON-RPC request ids; starts at 1 and is bumped once per request.
    next_id: AtomicU64,
    /// Label used in log messages; taken from the configured command.
    server_name: String,
    /// Capabilities from the server's `initialize` result; `None` until that completes.
    capabilities: Mutex<Option<McpServerCapabilities>>,
}
20
21impl McpClient {
22 pub async fn connect(config: &McpConfig) -> Result<Self> {
24 let mut cmd = Command::new(&config.command);
25 cmd.args(&config.args)
26 .stdin(std::process::Stdio::piped())
27 .stdout(std::process::Stdio::piped())
28 .stderr(std::process::Stdio::piped());
29
30 for (k, v) in &config.env {
31 cmd.env(k, v);
32 }
33
34 if let Some(cwd) = &config.cwd {
35 cmd.current_dir(cwd);
36 }
37
38 let mut child = cmd.spawn().context("Failed to spawn MCP server")?;
39
40 let stdin = child.stdin.take().context("No stdin on MCP server")?;
41 let stdout = child.stdout.take().context("No stdout on MCP server")?;
42
43 let client = Self {
44 child: Mutex::new(child),
45 stdin: Mutex::new(stdin),
46 stdout: Mutex::new(BufReader::new(stdout)),
47 next_id: AtomicU64::new(1),
48 server_name: config.command.clone(),
49 capabilities: Mutex::new(None),
50 };
51
52 client.initialize().await?;
54
55 Ok(client)
56 }
57
58 async fn initialize(&self) -> Result<()> {
60 let params = serde_json::json!({
61 "protocolVersion": "2024-11-05",
62 "capabilities": {},
63 "clientInfo": {
64 "name": "mur-commander",
65 "version": env!("CARGO_PKG_VERSION")
66 }
67 });
68
69 let response = self.request("initialize", Some(params)).await?;
70 let init_result: InitializeResult = serde_json::from_value(
71 response.context("Empty initialize response")?
72 ).context("Invalid initialize response")?;
73
74 info!(
75 "MCP server initialized: {} (protocol {})",
76 init_result.server_info.as_ref().map(|s| s.name.as_str()).unwrap_or("unknown"),
77 init_result.protocol_version
78 );
79
80 *self.capabilities.lock().await = Some(init_result.capabilities);
81
82 self.notify("notifications/initialized", None).await?;
84
85 Ok(())
86 }
87
88 pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
90 let response = self.request("tools/list", None).await?;
91 let result: ToolsListResult = serde_json::from_value(
92 response.context("Empty tools/list response")?
93 ).context("Invalid tools/list response")?;
94 Ok(result.tools)
95 }
96
97 pub async fn call_tool(&self, name: &str, arguments: serde_json::Value) -> Result<McpToolResult> {
99 let params = serde_json::json!({
100 "name": name,
101 "arguments": arguments
102 });
103
104 let response = self.request("tools/call", Some(params)).await?;
105 let result: McpToolResult = serde_json::from_value(
106 response.context("Empty tools/call response")?
107 ).context("Invalid tools/call response")?;
108
109 Ok(result)
110 }
111
112 pub async fn capabilities(&self) -> Option<McpServerCapabilities> {
114 self.capabilities.lock().await.clone()
115 }
116
117 async fn request(&self, method: &str, params: Option<serde_json::Value>) -> Result<Option<serde_json::Value>> {
119 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
120
121 let request = JsonRpcRequest {
122 jsonrpc: "2.0",
123 id,
124 method: method.to_string(),
125 params,
126 };
127
128 let mut line = serde_json::to_string(&request)?;
129 line.push('\n');
130
131 debug!("MCP -> {}: method={} id={}", self.server_name, method, id);
134
135 {
136 let mut stdin = self.stdin.lock().await;
137 stdin.write_all(line.as_bytes()).await?;
138 stdin.flush().await?;
139 }
140
141 let mut stdout = self.stdout.lock().await;
143 let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(60);
144 loop {
145 let mut buf = String::new();
146 let read = tokio::time::timeout_at(deadline, stdout.read_line(&mut buf)).await;
147 let n = match read {
148 Ok(result) => result?,
149 Err(_) => anyhow::bail!("MCP request '{}' timed out after 60s", method),
150 };
151 if n == 0 {
152 anyhow::bail!("MCP server closed stdout");
153 }
154
155 let buf = buf.trim();
156 if buf.is_empty() {
157 continue;
158 }
159
160 debug!("MCP ← {}: [response received]", self.server_name);
161
162 let response: JsonRpcResponse = match serde_json::from_str(buf) {
163 Ok(r) => r,
164 Err(_) => {
165 warn!("Skipping non-response line from MCP server");
167 continue;
168 }
169 };
170
171 if response.id == Some(id) {
173 if let Some(error) = response.error {
174 anyhow::bail!("MCP error ({}): {}", error.code, error.message);
175 }
176 return Ok(response.result);
177 }
178
179 }
181 }
182
183 async fn notify(&self, method: &str, params: Option<serde_json::Value>) -> Result<()> {
185 let notification = serde_json::json!({
186 "jsonrpc": "2.0",
187 "method": method,
188 "params": params.unwrap_or(serde_json::json!({}))
189 });
190
191 let mut line = serde_json::to_string(¬ification)?;
192 line.push('\n');
193
194 let mut stdin = self.stdin.lock().await;
195 stdin.write_all(line.as_bytes()).await?;
196 stdin.flush().await?;
197
198 Ok(())
199 }
200
201 pub async fn shutdown(&self) -> Result<()> {
203 let _ = self.request("shutdown", None).await;
205 let _ = self.notify("exit", None).await;
206
207 let mut child = self.child.lock().await;
209 tokio::select! {
210 _ = child.wait() => {},
211 _ = tokio::time::sleep(std::time::Duration::from_secs(5)) => {
212 let _ = child.kill().await;
213 }
214 }
215
216 Ok(())
217 }
218}
219
220impl Drop for McpClient {
221 fn drop(&mut self) {
222 if let Ok(mut child) = self.child.try_lock() {
224 let _ = child.start_kill();
225 }
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// A request without params must serialise with no `params` key at all.
    #[test]
    fn test_json_rpc_request_serialization() {
        let serialized = serde_json::to_string(&JsonRpcRequest {
            jsonrpc: "2.0",
            id: 1,
            method: "tools/list".to_string(),
            params: None,
        })
        .unwrap();

        for required in ["\"jsonrpc\":\"2.0\"", "\"method\":\"tools/list\""] {
            assert!(serialized.contains(required), "missing {}", required);
        }
        assert!(!serialized.contains("params"));
    }

    /// Tool definitions deserialise from the camelCase wire format.
    #[test]
    fn test_mcp_tool_deserialization() {
        let raw = r#"{
            "name": "read_file",
            "description": "Read a file",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "path": {"type": "string"}
                }
            }
        }"#;
        let parsed: McpTool = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.name, "read_file");
        assert_eq!(parsed.description.as_deref(), Some("Read a file"));
    }

    /// `{"type": "text"}` content maps onto the `Text` variant.
    #[test]
    fn test_mcp_content_text() {
        let raw = r#"{"type": "text", "text": "hello"}"#;
        let parsed: McpContent = serde_json::from_str(raw).unwrap();
        if let McpContent::Text { text } = parsed {
            assert_eq!(text, "hello");
        } else {
            panic!("Expected text content");
        }
    }

    /// A tool result carries its content list and the `isError` flag.
    #[test]
    fn test_mcp_tool_result_deserialization() {
        let raw = r#"{
            "content": [{"type": "text", "text": "result"}],
            "isError": false
        }"#;
        let parsed: McpToolResult = serde_json::from_str(raw).unwrap();
        assert_eq!(parsed.content.len(), 1);
        assert!(!parsed.is_error);
    }
}