Skip to main content

hematite/agent/
mcp.rs

1use anyhow::{anyhow, Result};
2use serde::{Deserialize, Serialize};
3use serde_json::Value as JsonValue;
4use std::collections::VecDeque;
5use std::path::{Path, PathBuf};
6use std::process::Stdio;
7use std::sync::{Arc, Mutex};
8use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
9use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
10use tokio::task::JoinHandle;
11
/// Wire framing used to delimit JSON-RPC messages on the server's stdio pipes.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum McpFraming {
    /// One JSON document per line, terminated by `\n`.
    NewlineDelimited,
    /// `Content-Length: N` header block, a blank line, then exactly N body bytes.
    ContentLength,
}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
19#[serde(untagged)]
20pub enum JsonRpcId {
21    Number(u64),
22    String(String),
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct JsonRpcRequest<T = JsonValue> {
27    pub jsonrpc: String,
28    pub id: JsonRpcId,
29    pub method: String,
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub params: Option<T>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct JsonRpcNotification<T = JsonValue> {
36    pub jsonrpc: String,
37    pub method: String,
38    #[serde(skip_serializing_if = "Option::is_none")]
39    pub params: Option<T>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct JsonRpcResponse<T = JsonValue> {
44    pub jsonrpc: String,
45    pub id: JsonRpcId,
46    pub result: Option<T>,
47    pub error: Option<JsonRpcError>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct JsonRpcError {
52    pub code: i64,
53    pub message: String,
54    pub data: Option<JsonValue>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58#[serde(rename_all = "camelCase")]
59pub struct McpTool {
60    pub name: String,
61    pub description: Option<String>,
62    pub input_schema: JsonValue,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(rename_all = "camelCase")]
67pub struct McpListToolsResult {
68    pub tools: Vec<McpTool>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72#[serde(rename_all = "camelCase")]
73pub struct McpCallToolResult {
74    pub content: Vec<McpContent>,
75    pub is_error: Option<bool>,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79#[serde(tag = "type")]
80pub enum McpContent {
81    #[serde(rename = "text")]
82    Text { text: String },
83    #[serde(rename = "image")]
84    Image { data: String, mime_type: String },
85}
86
/// A spawned MCP server process plus the stdio channel used to talk JSON-RPC to it.
///
/// NOTE: Rust drops fields in declaration order, so reordering them changes teardown order.
pub struct McpProcess {
    /// Handle to the server process; `shutdown` kills it.
    _child: Child,
    /// Server's stdin; framed JSON-RPC messages are written here.
    stdin: ChildStdin,
    /// Buffered server stdout; framed JSON-RPC messages are read from here.
    stdout: BufReader<ChildStdout>,
    /// Framing used for both directions.
    framing: McpFraming,
    /// Rolling tail of non-empty stderr lines, filled by the drain task.
    stderr_lines: Arc<Mutex<VecDeque<String>>>,
    /// Background task that drains stderr; `shutdown` aborts it.
    _stderr_task: JoinHandle<()>,
}
95
96impl McpProcess {
97    pub fn spawn(
98        command: &str,
99        args: &[String],
100        env: &std::collections::HashMap<String, String>,
101    ) -> Result<Self> {
102        Self::spawn_with_framing(command, args, env, McpFraming::NewlineDelimited)
103    }
104
105    pub fn spawn_with_framing(
106        command: &str,
107        args: &[String],
108        env: &std::collections::HashMap<String, String>,
109        framing: McpFraming,
110    ) -> Result<Self> {
111        let resolved_command =
112            resolve_command_path(command).unwrap_or_else(|| PathBuf::from(command));
113        let mut cmd = if is_cmd_wrapper(&resolved_command) {
114            let mut wrapper = Command::new("cmd");
115            wrapper.arg("/C").arg(&resolved_command);
116            wrapper
117        } else {
118            Command::new(&resolved_command)
119        };
120        cmd.args(args)
121            .stdin(Stdio::piped())
122            .stdout(Stdio::piped())
123            .stderr(Stdio::piped());
124
125        for (k, v) in env {
126            cmd.env(k, v);
127        }
128
129        let mut child = cmd.spawn()?;
130        let stdin = child
131            .stdin
132            .take()
133            .ok_or_else(|| anyhow!("Failed to capture stdin"))?;
134        let stdout = child
135            .stdout
136            .take()
137            .ok_or_else(|| anyhow!("Failed to capture stdout"))?;
138        let stderr = child
139            .stderr
140            .take()
141            .ok_or_else(|| anyhow!("Failed to capture stderr"))?;
142        let stderr_lines = Arc::new(Mutex::new(VecDeque::with_capacity(16)));
143        let stderr_task = spawn_stderr_drain(stderr, Arc::clone(&stderr_lines));
144
145        Ok(Self {
146            _child: child,
147            stdin,
148            stdout: BufReader::new(stdout),
149            framing,
150            stderr_lines,
151            _stderr_task: stderr_task,
152        })
153    }
154
155    pub async fn request<P: Serialize, R: for<'de> Deserialize<'de>>(
156        &mut self,
157        id: u64,
158        method: &str,
159        params: Option<P>,
160    ) -> Result<R> {
161        let req = JsonRpcRequest {
162            jsonrpc: "2.0".to_string(),
163            id: JsonRpcId::Number(id),
164            method: method.to_string(),
165            params: params.map(serde_json::to_value).transpose()?,
166        };
167
168        self.write_message(&req).await?;
169
170        loop {
171            let payload = self.read_message_payload().await?;
172            let value: JsonValue = serde_json::from_slice(&payload).map_err(|e| {
173                anyhow!(
174                    "Failed to parse MCP response: {}. Raw: {}",
175                    e,
176                    String::from_utf8_lossy(&payload)
177                )
178            })?;
179
180            // Ignore notifications or server-initiated events while waiting for a response.
181            if value.get("id").is_none() {
182                continue;
183            }
184
185            let resp: JsonRpcResponse<R> = serde_json::from_value(value)
186                .map_err(|e| anyhow!("Failed to decode MCP response: {}", e))?;
187
188            if let Some(error) = resp.error {
189                return Err(anyhow!("MCP Error ({}): {}", error.code, error.message));
190            }
191
192            return resp
193                .result
194                .ok_or_else(|| anyhow!("Missing result in MCP response"));
195        }
196    }
197
198    pub async fn notify<P: Serialize>(&mut self, method: &str, params: Option<P>) -> Result<()> {
199        let notification = JsonRpcNotification {
200            jsonrpc: "2.0".to_string(),
201            method: method.to_string(),
202            params: params.map(serde_json::to_value).transpose()?,
203        };
204
205        self.write_message(&notification).await
206    }
207
208    pub async fn initialize(&mut self, id: u64) -> Result<()> {
209        let params = serde_json::json!({
210            "protocolVersion": "2024-11-05",
211            "capabilities": {},
212            "clientInfo": { "name": "hematite", "version": env!("CARGO_PKG_VERSION") }
213        });
214        let _: JsonValue = self.request(id, "initialize", Some(params)).await?;
215        self.notify("notifications/initialized", Some(serde_json::json!({})))
216            .await?;
217        Ok(())
218    }
219
220    pub async fn list_tools(&mut self, id: u64) -> Result<Vec<McpTool>> {
221        let res: McpListToolsResult = self.request(id, "tools/list", None::<()>).await?;
222        Ok(res.tools)
223    }
224
225    pub async fn call_tool(
226        &mut self,
227        id: u64,
228        name: &str,
229        arguments: JsonValue,
230    ) -> Result<McpCallToolResult> {
231        let params = serde_json::json!({
232            "name": name,
233            "arguments": arguments
234        });
235        self.request(id, "tools/call", Some(params)).await
236    }
237
238    pub async fn shutdown(mut self) {
239        let _ = self._child.kill().await;
240        self._stderr_task.abort();
241    }
242
243    pub fn stderr_summary(&self) -> Option<String> {
244        let lines = self.stderr_lines.lock().ok()?;
245        if lines.is_empty() {
246            None
247        } else {
248            Some({
249                let cap = lines.iter().map(|l| l.len()).sum::<usize>()
250                    + lines.len().saturating_sub(1) * 3;
251                let mut out = String::with_capacity(cap);
252                for (i, line) in lines.iter().enumerate() {
253                    if i > 0 {
254                        out.push_str(" | ");
255                    }
256                    out.push_str(line);
257                }
258                out
259            })
260        }
261    }
262
263    async fn write_message<T: Serialize>(&mut self, message: &T) -> Result<()> {
264        let payload = serde_json::to_vec(message)?;
265        match self.framing {
266            McpFraming::NewlineDelimited => {
267                self.stdin.write_all(&payload).await?;
268                self.stdin.write_all(b"\n").await?;
269            }
270            McpFraming::ContentLength => {
271                let header = format!("Content-Length: {}\r\n\r\n", payload.len());
272                self.stdin.write_all(header.as_bytes()).await?;
273                self.stdin.write_all(&payload).await?;
274            }
275        }
276        self.stdin.flush().await?;
277        Ok(())
278    }
279
280    async fn read_message_payload(&mut self) -> Result<Vec<u8>> {
281        match self.framing {
282            McpFraming::NewlineDelimited => {
283                let mut line = String::new();
284                self.stdout.read_line(&mut line).await?;
285                if line.is_empty() {
286                    return Err(anyhow!("MCP server closed connection unexpectedly"));
287                }
288                Ok(line.into_bytes())
289            }
290            McpFraming::ContentLength => {
291                let mut first_line = String::new();
292                self.stdout.read_line(&mut first_line).await?;
293                if first_line.is_empty() {
294                    return Err(anyhow!("MCP server closed connection unexpectedly"));
295                }
296
297                if !first_line.starts_with("Content-Length:") {
298                    return Ok(first_line.into_bytes());
299                }
300
301                let content_length = first_line["Content-Length:".len()..]
302                    .trim()
303                    .parse::<usize>()
304                    .map_err(|e| anyhow!("Invalid MCP Content-Length header: {}", e))?;
305
306                loop {
307                    let mut header_line = String::new();
308                    self.stdout.read_line(&mut header_line).await?;
309                    if header_line.is_empty() {
310                        return Err(anyhow!(
311                            "MCP server closed connection while reading headers"
312                        ));
313                    }
314                    if header_line == "\r\n" || header_line == "\n" {
315                        break;
316                    }
317                }
318
319                let mut payload = vec![0_u8; content_length];
320                self.stdout.read_exact(&mut payload).await?;
321                Ok(payload)
322            }
323        }
324    }
325}
326
327fn spawn_stderr_drain(
328    stderr: ChildStderr,
329    stderr_lines: Arc<Mutex<VecDeque<String>>>,
330) -> JoinHandle<()> {
331    tokio::spawn(async move {
332        let mut reader = BufReader::new(stderr);
333
334        loop {
335            let mut line = String::new();
336            match reader.read_line(&mut line).await {
337                Ok(0) | Err(_) => break,
338                Ok(_) => {
339                    let trimmed = line.trim();
340                    if trimmed.is_empty() {
341                        continue;
342                    }
343
344                    if let Ok(mut lines) = stderr_lines.lock() {
345                        lines.push_back(trimmed.to_string());
346                        while lines.len() > 20 {
347                            lines.pop_front();
348                        }
349                    }
350                }
351            }
352        }
353    })
354}
355
/// Windows: resolves `command` to an existing file path.
///
/// If `command` has no extension, `.exe`, `.cmd`, `.bat` and `.com` are probed, in
/// that order, before the bare name. Path-like commands (absolute, or containing
/// `\` or `/`) are checked in place only. Bare names are searched in each `PATH`
/// directory in turn. Returns `None` when nothing exists.
#[cfg(windows)]
fn resolve_command_path(command: &str) -> Option<PathBuf> {
    const PROBE_EXTENSIONS: [&str; 4] = [".exe", ".cmd", ".bat", ".com"];

    let explicit_path = PathBuf::from(command);
    let probe_extensions = Path::new(command).extension().is_none();
    let is_path_like =
        explicit_path.is_absolute() || command.chars().any(|c| c == '\\' || c == '/');

    if is_path_like {
        if probe_extensions {
            let with_extension = PROBE_EXTENSIONS
                .iter()
                .map(|ext| PathBuf::from(format!("{command}{ext}")))
                .find(|candidate| candidate.exists());
            if with_extension.is_some() {
                return with_extension;
            }
        }
        return explicit_path.exists().then(|| explicit_path);
    }

    let search_path = std::env::var_os("PATH")?;
    std::env::split_paths(&search_path).find_map(|dir| {
        let with_extension = if probe_extensions {
            PROBE_EXTENSIONS
                .iter()
                .map(|ext| dir.join(format!("{command}{ext}")))
                .find(|candidate| candidate.exists())
        } else {
            None
        };
        with_extension.or_else(|| {
            let direct = dir.join(command);
            direct.exists().then(|| direct)
        })
    })
}
393
/// Non-Windows: returns `command` unchanged as a path.
///
/// NOTE(review): this presumably relies on the `PATH` lookup performed by `Command`
/// at spawn time.
#[cfg(not(windows))]
fn resolve_command_path(command: &str) -> Option<PathBuf> {
    let verbatim: PathBuf = command.into();
    Some(verbatim)
}
398
/// Windows: `true` when `path` is a batch script (`.cmd` or `.bat`, any case).
///
/// Such scripts must be launched via `cmd /C`.
#[cfg(windows)]
fn is_cmd_wrapper(path: &Path) -> bool {
    path.extension()
        .and_then(|ext| ext.to_str())
        .map_or(false, |ext| {
            ext.eq_ignore_ascii_case("cmd") || ext.eq_ignore_ascii_case("bat")
        })
}
406
/// Non-Windows: never wrap the command in `cmd /C`.
#[cfg(not(windows))]
fn is_cmd_wrapper(_path: &Path) -> bool {
    // The `cmd.exe` wrapper only exists on Windows.
    false
}