Skip to main content

dot/
mcp.rs

1use anyhow::{Context, Result, bail};
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::collections::HashMap;
5use std::io::{BufRead, BufReader, BufWriter, Write};
6use std::process::{Child, Command, Stdio};
7use std::sync::{Arc, Mutex};
8
9use crate::tools::Tool;
10
11const PROTOCOL_VERSION: &str = "2024-11-05";
12
13#[derive(Serialize)]
14struct JsonRpcRequest {
15    jsonrpc: &'static str,
16    id: u64,
17    method: String,
18    #[serde(skip_serializing_if = "Option::is_none")]
19    params: Option<Value>,
20}
21
22#[derive(Serialize)]
23struct JsonRpcNotification {
24    jsonrpc: &'static str,
25    method: String,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    params: Option<Value>,
28}
29
30#[derive(Deserialize)]
31struct JsonRpcResponse {
32    #[allow(dead_code)]
33    jsonrpc: String,
34    id: Option<u64>,
35    result: Option<Value>,
36    error: Option<JsonRpcError>,
37}
38
39#[derive(Deserialize)]
40struct JsonRpcError {
41    code: i64,
42    message: String,
43    #[allow(dead_code)]
44    data: Option<Value>,
45}
46
47#[derive(Debug, Clone, Deserialize)]
48pub struct McpToolDef {
49    pub name: String,
50    pub description: Option<String>,
51    #[serde(rename = "inputSchema")]
52    pub input_schema: Value,
53}
54
55#[derive(Debug, Deserialize)]
56struct ToolsListResult {
57    tools: Vec<McpToolDef>,
58}
59
60#[derive(Debug, Deserialize)]
61struct ToolCallContent {
62    #[allow(dead_code)]
63    #[serde(rename = "type")]
64    content_type: String,
65    text: Option<String>,
66}
67
68#[derive(Debug, Deserialize)]
69struct ToolCallResult {
70    content: Vec<ToolCallContent>,
71    #[serde(rename = "isError", default)]
72    is_error: bool,
73}
74
/// Per-connection I/O state, guarded as a unit by `McpClient::inner`.
///
/// A request holds this lock from the moment it writes until its response is
/// read, so traffic from concurrent callers cannot interleave.
struct ClientInner {
    /// Buffered writer over the server's stdin; one JSON message per line.
    stdin: BufWriter<std::process::ChildStdin>,
    /// Buffered reader over the server's stdout, consumed line by line.
    stdout: BufReader<std::process::ChildStdout>,
    /// Id to assign to the next outgoing request.
    next_id: u64,
}

/// A client for one MCP server running as a child process and speaking
/// line-delimited JSON-RPC over its stdin/stdout.
pub struct McpClient {
    /// Configured server name, used in logs, errors and tool prefixes.
    server_name: String,
    /// Pipes and id counter; see `ClientInner`.
    inner: Mutex<ClientInner>,
    // Only kept so `Drop` can kill and reap the process.
    _child: Mutex<Child>,
}
86
87impl McpClient {
88    pub fn start(
89        server_name: &str,
90        command: &[String],
91        env: &HashMap<String, String>,
92    ) -> Result<Self> {
93        if command.is_empty() {
94            bail!("MCP server '{}' has empty command", server_name);
95        }
96
97        let mut cmd = Command::new(&command[0]);
98        if command.len() > 1 {
99            cmd.args(&command[1..]);
100        }
101        cmd.stdin(Stdio::piped())
102            .stdout(Stdio::piped())
103            .stderr(Stdio::null());
104
105        for (k, v) in env {
106            cmd.env(k, v);
107        }
108
109        let mut child = cmd
110            .spawn()
111            .with_context(|| format!("Failed to start MCP server '{}'", server_name))?;
112
113        let stdin = child.stdin.take().context("Failed to get stdin")?;
114        let stdout = child.stdout.take().context("Failed to get stdout")?;
115
116        Ok(McpClient {
117            server_name: server_name.to_string(),
118            inner: Mutex::new(ClientInner {
119                stdin: BufWriter::new(stdin),
120                stdout: BufReader::new(stdout),
121                next_id: 1,
122            }),
123            _child: Mutex::new(child),
124        })
125    }
126
127    fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value> {
128        let mut inner = self.inner.lock().map_err(|e| anyhow::anyhow!("{}", e))?;
129
130        let id = inner.next_id;
131        inner.next_id += 1;
132
133        let request = JsonRpcRequest {
134            jsonrpc: "2.0",
135            id,
136            method: method.to_string(),
137            params,
138        };
139
140        let msg = serde_json::to_string(&request)?;
141        writeln!(inner.stdin, "{}", msg)?;
142        inner.stdin.flush()?;
143
144        loop {
145            let mut line = String::new();
146            let bytes_read = inner.stdout.read_line(&mut line)?;
147            if bytes_read == 0 {
148                bail!(
149                    "MCP server '{}' closed connection unexpectedly",
150                    self.server_name
151                );
152            }
153            let line = line.trim();
154            if line.is_empty() {
155                continue;
156            }
157
158            let response: JsonRpcResponse = match serde_json::from_str(line) {
159                Ok(r) => r,
160                Err(_) => continue,
161            };
162
163            if response.id == Some(id) {
164                if let Some(error) = response.error {
165                    bail!(
166                        "MCP error from '{}': {} (code {})",
167                        self.server_name,
168                        error.message,
169                        error.code
170                    );
171                }
172                return response
173                    .result
174                    .ok_or_else(|| anyhow::anyhow!("Empty result from '{}'", self.server_name));
175            }
176        }
177    }
178
179    fn send_notification(&self, method: &str, params: Option<Value>) -> Result<()> {
180        let mut inner = self.inner.lock().map_err(|e| anyhow::anyhow!("{}", e))?;
181
182        let notification = JsonRpcNotification {
183            jsonrpc: "2.0",
184            method: method.to_string(),
185            params,
186        };
187
188        let msg = serde_json::to_string(&notification)?;
189        writeln!(inner.stdin, "{}", msg)?;
190        inner.stdin.flush()?;
191        Ok(())
192    }
193
194    pub fn initialize(&self) -> Result<()> {
195        let params = serde_json::json!({
196            "protocolVersion": PROTOCOL_VERSION,
197            "capabilities": {},
198            "clientInfo": {
199                "name": "dot",
200                "version": "0.1.0"
201            }
202        });
203
204        let _result = self.send_request("initialize", Some(params))?;
205        self.send_notification("notifications/initialized", None)?;
206        tracing::info!("MCP server '{}' initialized", self.server_name);
207        Ok(())
208    }
209
210    pub fn list_tools(&self) -> Result<Vec<McpToolDef>> {
211        let result = self.send_request("tools/list", Some(serde_json::json!({})))?;
212        let tools_result: ToolsListResult = serde_json::from_value(result)?;
213        Ok(tools_result.tools)
214    }
215
216    pub fn call_tool(&self, name: &str, arguments: Value) -> Result<String> {
217        let params = serde_json::json!({
218            "name": name,
219            "arguments": arguments
220        });
221
222        let result = self.send_request("tools/call", Some(params))?;
223        let call_result: ToolCallResult = serde_json::from_value(result)?;
224
225        let text: Vec<String> = call_result
226            .content
227            .iter()
228            .filter_map(|c| c.text.clone())
229            .collect();
230        let output = text.join("\n");
231
232        if call_result.is_error {
233            bail!("{}", output);
234        }
235        Ok(output)
236    }
237
238    pub fn server_name(&self) -> &str {
239        &self.server_name
240    }
241}
242
243impl Drop for McpClient {
244    fn drop(&mut self) {
245        if let Ok(child) = self._child.get_mut() {
246            let _ = child.kill();
247            let _ = child.wait();
248        }
249    }
250}
251
252/// Wraps an MCP server tool as a native `Tool` implementation.
253pub struct McpToolBridge {
254    tool_name: String,
255    prefixed_name: String,
256    description: String,
257    input_schema: Value,
258    client: Arc<McpClient>,
259}
260
261impl McpToolBridge {
262    pub fn new(client: Arc<McpClient>, server_name: &str, tool_def: &McpToolDef) -> Self {
263        McpToolBridge {
264            tool_name: tool_def.name.clone(),
265            prefixed_name: format!("{}_{}", server_name, tool_def.name),
266            description: tool_def
267                .description
268                .clone()
269                .unwrap_or_else(|| format!("[{}] {}", server_name, tool_def.name)),
270            input_schema: tool_def.input_schema.clone(),
271            client,
272        }
273    }
274}
275
276impl Tool for McpToolBridge {
277    fn name(&self) -> &str {
278        &self.prefixed_name
279    }
280
281    fn description(&self) -> &str {
282        &self.description
283    }
284
285    fn input_schema(&self) -> Value {
286        self.input_schema.clone()
287    }
288
289    fn execute(&self, input: Value) -> Result<String> {
290        tracing::debug!("MCP {}:{}", self.client.server_name(), self.tool_name);
291        self.client.call_tool(&self.tool_name, input)
292    }
293}
294
295/// Manages connections to all configured MCP servers.
296pub struct McpManager {
297    clients: Vec<Arc<McpClient>>,
298}
299
300impl Default for McpManager {
301    fn default() -> Self {
302        Self::new()
303    }
304}
305
306impl McpManager {
307    pub fn new() -> Self {
308        McpManager {
309            clients: Vec::new(),
310        }
311    }
312
313    pub fn start_server(
314        &mut self,
315        name: &str,
316        command: &[String],
317        env: &HashMap<String, String>,
318    ) -> Result<()> {
319        let client = McpClient::start(name, command, env)?;
320        client.initialize()?;
321        self.clients.push(Arc::new(client));
322        Ok(())
323    }
324
325    pub fn discover_tools(&self) -> Vec<Box<dyn Tool>> {
326        let mut tools: Vec<Box<dyn Tool>> = Vec::new();
327
328        for client in &self.clients {
329            match client.list_tools() {
330                Ok(tool_defs) => {
331                    tracing::info!("MCP '{}': {} tools", client.server_name(), tool_defs.len());
332                    for td in &tool_defs {
333                        tools.push(Box::new(McpToolBridge::new(
334                            client.clone(),
335                            client.server_name(),
336                            td,
337                        )));
338                    }
339                }
340                Err(e) => {
341                    tracing::warn!(
342                        "Failed to list tools from '{}': {}",
343                        client.server_name(),
344                        e
345                    );
346                }
347            }
348        }
349
350        tools
351    }
352
353    pub fn server_count(&self) -> usize {
354        self.clients.len()
355    }
356}