1use anyhow::{Context, Result, bail};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::io::{BufRead, BufReader, BufWriter, Write};
6use std::process::{Child, Command, Stdio};
7use std::sync::{Arc, Mutex};
8
9use crate::tools::Tool;
10
/// MCP protocol revision this client sends as `protocolVersion` in its
/// `initialize` request.
const PROTOCOL_VERSION: &str = "2024-11-05";
12
/// Outgoing JSON-RPC request, serialized to a single line and written to the
/// server's stdin by `McpClient::send_request`.
///
/// Keys appear in the serialized JSON in field-declaration order.
#[derive(Serialize)]
struct JsonRpcRequest {
    /// Protocol tag; this file always sets it to `"2.0"`.
    jsonrpc: &'static str,
    /// Client-assigned id; `send_request` waits for a response carrying the same id.
    id: u64,
    /// Name of the remote method, e.g. `tools/list`.
    method: String,
    /// Method parameters; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
}
21
/// Outgoing JSON-RPC notification: like a request but without an `id`, so
/// no response is awaited (see `McpClient::send_notification`).
///
/// Keys appear in the serialized JSON in field-declaration order.
#[derive(Serialize)]
struct JsonRpcNotification {
    /// Protocol tag; this file always sets it to `"2.0"`.
    jsonrpc: &'static str,
    /// Notification name, e.g. `notifications/initialized`.
    method: String,
    /// Notification parameters; the key is omitted entirely when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    params: Option<Value>,
}
29
30#[derive(Deserialize)]
31struct JsonRpcResponse {
32 #[allow(dead_code)]
33 jsonrpc: String,
34 id: Option<u64>,
35 result: Option<Value>,
36 error: Option<JsonRpcError>,
37}
38
/// The `error` member of a JSON-RPC response.
#[derive(Deserialize)]
struct JsonRpcError {
    /// Numeric error code, included in the message built by `send_request`.
    code: i64,
    /// Human-readable error text supplied by the server.
    message: String,
    /// Optional extra payload. It is deserialized but never read, so
    /// `dead_code` is allowed to keep the build warning-free.
    #[allow(dead_code)]
    data: Option<Value>,
}
46
/// One tool definition from a server's `tools/list` result.
#[derive(Debug, Clone, Deserialize)]
pub struct McpToolDef {
    /// Server-local tool name; sent unchanged in `tools/call` requests.
    pub name: String,
    /// Optional human-readable description (`None` when the key is absent).
    pub description: Option<String>,
    /// Argument schema from the `inputSchema` key, passed through unchanged.
    /// Presumably a JSON Schema object — the code here does not inspect it.
    #[serde(rename = "inputSchema")]
    pub input_schema: Value,
}
54
/// Result payload of `tools/list`. Only the `tools` array is deserialized;
/// any other keys in the result are ignored.
#[derive(Debug, Deserialize)]
struct ToolsListResult {
    /// Tool definitions advertised by the server.
    tools: Vec<McpToolDef>,
}
59
/// One item of the `content` array in a `tools/call` result.
#[derive(Debug, Deserialize)]
struct ToolCallContent {
    /// Content kind, deserialized from the `type` key. `dead_code` is allowed
    /// so the build stays warning-free even if nothing reads this field.
    #[allow(dead_code)]
    #[serde(rename = "type")]
    content_type: String,
    /// Text payload; `None` for items that carry no `text` key.
    text: Option<String>,
}
67
68#[derive(Debug, Deserialize)]
69struct ToolCallResult {
70 content: Vec<ToolCallContent>,
71 #[serde(rename = "isError", default)]
72 is_error: bool,
73}
74
/// Mutable connection state, kept behind `McpClient::inner` so a whole
/// request/response exchange happens under one lock.
struct ClientInner {
    /// Buffered writer for the server's stdin; every message is flushed right
    /// after it is written.
    stdin: BufWriter<std::process::ChildStdin>,
    /// Buffered reader over the server's stdout, consumed line by line.
    stdout: BufReader<std::process::ChildStdout>,
    /// Id for the next outgoing request; starts at 1 and increments per request.
    next_id: u64,
}
80
/// Blocking client for a single MCP server that runs as a child process and
/// is spoken to over its stdin/stdout.
pub struct McpClient {
    /// Name used in log lines, error messages, and tool-name prefixes.
    server_name: String,
    /// Pipes plus the request-id counter.
    inner: Mutex<ClientInner>,
    /// The server process; only touched in `Drop`, which kills and reaps it.
    _child: Mutex<Child>,
}
86
87impl McpClient {
88 pub fn start(
89 server_name: &str,
90 command: &[String],
91 env: &HashMap<String, String>,
92 ) -> Result<Self> {
93 if command.is_empty() {
94 bail!("MCP server '{}' has empty command", server_name);
95 }
96
97 let mut cmd = Command::new(&command[0]);
98 if command.len() > 1 {
99 cmd.args(&command[1..]);
100 }
101 cmd.stdin(Stdio::piped())
102 .stdout(Stdio::piped())
103 .stderr(Stdio::null());
104
105 for (k, v) in env {
106 cmd.env(k, v);
107 }
108
109 let mut child = cmd
110 .spawn()
111 .with_context(|| format!("Failed to start MCP server '{}'", server_name))?;
112
113 let stdin = child.stdin.take().context("Failed to get stdin")?;
114 let stdout = child.stdout.take().context("Failed to get stdout")?;
115
116 Ok(McpClient {
117 server_name: server_name.to_string(),
118 inner: Mutex::new(ClientInner {
119 stdin: BufWriter::new(stdin),
120 stdout: BufReader::new(stdout),
121 next_id: 1,
122 }),
123 _child: Mutex::new(child),
124 })
125 }
126
127 fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value> {
128 let mut inner = self.inner.lock().map_err(|e| anyhow::anyhow!("{}", e))?;
129
130 let id = inner.next_id;
131 inner.next_id += 1;
132
133 let request = JsonRpcRequest {
134 jsonrpc: "2.0",
135 id,
136 method: method.to_string(),
137 params,
138 };
139
140 let msg = serde_json::to_string(&request)?;
141 writeln!(inner.stdin, "{}", msg)?;
142 inner.stdin.flush()?;
143
144 loop {
145 let mut line = String::new();
146 let bytes_read = inner.stdout.read_line(&mut line)?;
147 if bytes_read == 0 {
148 bail!(
149 "MCP server '{}' closed connection unexpectedly",
150 self.server_name
151 );
152 }
153 let line = line.trim();
154 if line.is_empty() {
155 continue;
156 }
157
158 let response: JsonRpcResponse = match serde_json::from_str(line) {
159 Ok(r) => r,
160 Err(_) => continue,
161 };
162
163 if response.id == Some(id) {
164 if let Some(error) = response.error {
165 bail!(
166 "MCP error from '{}': {} (code {})",
167 self.server_name,
168 error.message,
169 error.code
170 );
171 }
172 return response
173 .result
174 .ok_or_else(|| anyhow::anyhow!("Empty result from '{}'", self.server_name));
175 }
176 }
177 }
178
179 fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
180 let mut inner = self.inner.lock().map_err(|e| anyhow::anyhow!("{}", e))?;
181
182 let notification = JsonRpcNotification {
183 jsonrpc: "2.0",
184 method: method.to_string(),
185 params,
186 };
187
188 let msg = serde_json::to_string(¬ification)?;
189 writeln!(inner.stdin, "{}", msg)?;
190 inner.stdin.flush()?;
191 Ok(())
192 }
193
194 pub fn initialize(&self) -> Result<()> {
195 let params = serde_json::json!({
196 "protocolVersion": PROTOCOL_VERSION,
197 "capabilities": {},
198 "clientInfo": {
199 "name": "dot",
200 "version": "0.1.0"
201 }
202 });
203
204 let _result = self.send_request("initialize", Some(params))?;
205 self.send_notification("notifications/initialized", None)?;
206 tracing::info!("MCP server '{}' initialized", self.server_name);
207 Ok(())
208 }
209
210 pub fn list_tools(&self) -> Result<Vec<McpToolDef>> {
211 let result = self.send_request("tools/list", Some(serde_json::json!({})))?;
212 let tools_result: ToolsListResult = serde_json::from_value(result)?;
213 Ok(tools_result.tools)
214 }
215
216 pub fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
217 let params = serde_json::json!({
218 "name": name,
219 "arguments": arguments
220 });
221
222 let result = self.send_request("tools/call", Some(params))?;
223 let call_result: ToolCallResult = serde_json::from_value(result)?;
224
225 let text: Vec<String> = call_result
226 .content
227 .iter()
228 .filter_map(|c| c.text.clone())
229 .collect();
230 let output = text.join("\n");
231
232 if call_result.is_error {
233 bail!("{}", output);
234 }
235 Ok(output)
236 }
237
238 pub fn server_name(&self) -> &str {
239 &self.server_name
240 }
241}
242
243impl Drop for McpClient {
244 fn drop(&mut self) {
245 if let Ok(child) = self._child.get_mut() {
246 let _ = child.kill();
247 let _ = child.wait();
248 }
249 }
250}
251
252pub struct McpToolBridge {
254 tool_name: String,
255 prefixed_name: String,
256 description: String,
257 input_schema: Value,
258 client: Arc<McpClient>,
259}
260
261impl McpToolBridge {
262 pub fn new(client: Arc<McpClient>, server_name: &str, tool_def: &McpToolDef) -> Self {
263 McpToolBridge {
264 tool_name: tool_def.name.clone(),
265 prefixed_name: format!("{}_{}", server_name, tool_def.name),
266 description: tool_def
267 .description
268 .clone()
269 .unwrap_or_else(|| format!("[{}] {}", server_name, tool_def.name)),
270 input_schema: tool_def.input_schema.clone(),
271 client,
272 }
273 }
274}
275
276impl Tool for McpToolBridge {
277 fn name(&self) -> &str {
278 &self.prefixed_name
279 }
280
281 fn description(&self) -> &str {
282 &self.description
283 }
284
285 fn input_schema(&self) -> Value {
286 self.input_schema.clone()
287 }
288
289 fn execute(&self, input: Value) -> Result<String> {
290 tracing::debug!("MCP {}:{}", self.client.server_name(), self.tool_name);
291 self.client.call_tool(&self.tool_name, input)
292 }
293}
294
/// Owns the MCP server connections started through `start_server`.
pub struct McpManager {
    /// Successfully initialized clients, in start order; shared (via `Arc`)
    /// with the tool bridges created by `discover_tools`.
    clients: Vec<Arc<McpClient>>,
}
299
300impl Default for McpManager {
301 fn default() -> Self {
302 Self::new()
303 }
304}
305
306impl McpManager {
307 pub fn new() -> Self {
308 McpManager {
309 clients: Vec::new(),
310 }
311 }
312
313 pub fn start_server(
314 &mut self,
315 name: &str,
316 command: &[String],
317 env: &HashMap<String, String>,
318 ) -> Result<()> {
319 let client = McpClient::start(name, command, env)?;
320 client.initialize()?;
321 self.clients.push(Arc::new(client));
322 Ok(())
323 }
324
325 pub fn discover_tools(&self) -> Vec<Box<dyn Tool>> {
326 let mut tools: Vec<Box<dyn Tool>> = Vec::new();
327
328 for client in &self.clients {
329 match client.list_tools() {
330 Ok(tool_defs) => {
331 tracing::info!("MCP '{}': {} tools", client.server_name(), tool_defs.len());
332 for td in &tool_defs {
333 tools.push(Box::new(McpToolBridge::new(
334 client.clone(),
335 client.server_name(),
336 td,
337 )));
338 }
339 }
340 Err(e) => {
341 tracing::warn!(
342 "Failed to list tools from '{}': {}",
343 client.server_name(),
344 e
345 );
346 }
347 }
348 }
349
350 tools
351 }
352
353 pub fn server_count(&self) -> usize {
354 self.clients.len()
355 }
356}