1use anyhow::{anyhow, Result};
2use serde::{Deserialize, Serialize};
3use serde_json::Value as JsonValue;
4use std::collections::VecDeque;
5use std::path::{Path, PathBuf};
6use std::process::Stdio;
7use std::sync::{Arc, Mutex};
8use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
9use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
10use tokio::task::JoinHandle;
11
/// Wire framing used for JSON-RPC messages on the server's stdio.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum McpFraming {
    /// One JSON message per line, terminated by `\n`.
    NewlineDelimited,
    /// Each message is preceded by a `Content-Length: N` header and a blank line.
    ContentLength,
}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
19#[serde(untagged)]
20pub enum JsonRpcId {
21 Number(u64),
22 String(String),
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct JsonRpcRequest<T = JsonValue> {
27 pub jsonrpc: String,
28 pub id: JsonRpcId,
29 pub method: String,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 pub params: Option<T>,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct JsonRpcNotification<T = JsonValue> {
36 pub jsonrpc: String,
37 pub method: String,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 pub params: Option<T>,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct JsonRpcResponse<T = JsonValue> {
44 pub jsonrpc: String,
45 pub id: JsonRpcId,
46 pub result: Option<T>,
47 pub error: Option<JsonRpcError>,
48}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct JsonRpcError {
52 pub code: i64,
53 pub message: String,
54 pub data: Option<JsonValue>,
55}
56
57#[derive(Debug, Clone, Serialize, Deserialize)]
58#[serde(rename_all = "camelCase")]
59pub struct McpTool {
60 pub name: String,
61 pub description: Option<String>,
62 pub input_schema: JsonValue,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
66#[serde(rename_all = "camelCase")]
67pub struct McpListToolsResult {
68 pub tools: Vec<McpTool>,
69}
70
71#[derive(Debug, Clone, Serialize, Deserialize)]
72#[serde(rename_all = "camelCase")]
73pub struct McpCallToolResult {
74 pub content: Vec<McpContent>,
75 pub is_error: Option<bool>,
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79#[serde(tag = "type")]
80pub enum McpContent {
81 #[serde(rename = "text")]
82 Text { text: String },
83 #[serde(rename = "image")]
84 Image { data: String, mime_type: String },
85}
86
87pub struct McpProcess {
88 _child: Child,
89 stdin: ChildStdin,
90 stdout: BufReader<ChildStdout>,
91 framing: McpFraming,
92 stderr_lines: Arc<Mutex<VecDeque<String>>>,
93 _stderr_task: JoinHandle<()>,
94}
95
96impl McpProcess {
97 pub fn spawn(
98 command: &str,
99 args: &[String],
100 env: &std::collections::HashMap<String, String>,
101 ) -> Result<Self> {
102 Self::spawn_with_framing(command, args, env, McpFraming::NewlineDelimited)
103 }
104
105 pub fn spawn_with_framing(
106 command: &str,
107 args: &[String],
108 env: &std::collections::HashMap<String, String>,
109 framing: McpFraming,
110 ) -> Result<Self> {
111 let resolved_command =
112 resolve_command_path(command).unwrap_or_else(|| PathBuf::from(command));
113 let mut cmd = if is_cmd_wrapper(&resolved_command) {
114 let mut wrapper = Command::new("cmd");
115 wrapper.arg("/C").arg(&resolved_command);
116 wrapper
117 } else {
118 Command::new(&resolved_command)
119 };
120 cmd.args(args)
121 .stdin(Stdio::piped())
122 .stdout(Stdio::piped())
123 .stderr(Stdio::piped());
124
125 for (k, v) in env {
126 cmd.env(k, v);
127 }
128
129 let mut child = cmd.spawn()?;
130 let stdin = child
131 .stdin
132 .take()
133 .ok_or_else(|| anyhow!("Failed to capture stdin"))?;
134 let stdout = child
135 .stdout
136 .take()
137 .ok_or_else(|| anyhow!("Failed to capture stdout"))?;
138 let stderr = child
139 .stderr
140 .take()
141 .ok_or_else(|| anyhow!("Failed to capture stderr"))?;
142 let stderr_lines = Arc::new(Mutex::new(VecDeque::with_capacity(16)));
143 let stderr_task = spawn_stderr_drain(stderr, Arc::clone(&stderr_lines));
144
145 Ok(Self {
146 _child: child,
147 stdin,
148 stdout: BufReader::new(stdout),
149 framing,
150 stderr_lines,
151 _stderr_task: stderr_task,
152 })
153 }
154
155 pub async fn request<P: Serialize, R: for<'de> Deserialize<'de>>(
156 &mut self,
157 id: u64,
158 method: &str,
159 params: Option<P>,
160 ) -> Result<R> {
161 let req = JsonRpcRequest {
162 jsonrpc: "2.0".to_string(),
163 id: JsonRpcId::Number(id),
164 method: method.to_string(),
165 params: params.map(serde_json::to_value).transpose()?,
166 };
167
168 self.write_message(&req).await?;
169
170 loop {
171 let payload = self.read_message_payload().await?;
172 let value: JsonValue = serde_json::from_slice(&payload).map_err(|e| {
173 anyhow!(
174 "Failed to parse MCP response: {}. Raw: {}",
175 e,
176 String::from_utf8_lossy(&payload)
177 )
178 })?;
179
180 if value.get("id").is_none() {
182 continue;
183 }
184
185 let resp: JsonRpcResponse<R> = serde_json::from_value(value)
186 .map_err(|e| anyhow!("Failed to decode MCP response: {}", e))?;
187
188 if let Some(error) = resp.error {
189 return Err(anyhow!("MCP Error ({}): {}", error.code, error.message));
190 }
191
192 return resp
193 .result
194 .ok_or_else(|| anyhow!("Missing result in MCP response"));
195 }
196 }
197
198 pub async fn notify<P: Serialize>(&mut self, method: &str, params: Option<P>) -> Result<()> {
199 let notification = JsonRpcNotification {
200 jsonrpc: "2.0".to_string(),
201 method: method.to_string(),
202 params: params.map(serde_json::to_value).transpose()?,
203 };
204
205 self.write_message(¬ification).await
206 }
207
208 pub async fn initialize(&mut self, id: u64) -> Result<()> {
209 let params = serde_json::json!({
210 "protocolVersion": "2024-11-05",
211 "capabilities": {},
212 "clientInfo": { "name": "hematite", "version": env!("CARGO_PKG_VERSION") }
213 });
214 let _: JsonValue = self.request(id, "initialize", Some(params)).await?;
215 self.notify("notifications/initialized", Some(serde_json::json!({})))
216 .await?;
217 Ok(())
218 }
219
220 pub async fn list_tools(&mut self, id: u64) -> Result<Vec<McpTool>> {
221 let res: McpListToolsResult = self.request(id, "tools/list", None::<()>).await?;
222 Ok(res.tools)
223 }
224
225 pub async fn call_tool(
226 &mut self,
227 id: u64,
228 name: &str,
229 arguments: JsonValue,
230 ) -> Result<McpCallToolResult> {
231 let params = serde_json::json!({
232 "name": name,
233 "arguments": arguments
234 });
235 self.request(id, "tools/call", Some(params)).await
236 }
237
238 pub async fn shutdown(mut self) {
239 let _ = self._child.kill().await;
240 self._stderr_task.abort();
241 }
242
243 pub fn stderr_summary(&self) -> Option<String> {
244 let lines = self.stderr_lines.lock().ok()?;
245 if lines.is_empty() {
246 None
247 } else {
248 Some(lines.iter().cloned().collect::<Vec<_>>().join(" | "))
249 }
250 }
251
252 async fn write_message<T: Serialize>(&mut self, message: &T) -> Result<()> {
253 let payload = serde_json::to_vec(message)?;
254 match self.framing {
255 McpFraming::NewlineDelimited => {
256 self.stdin.write_all(&payload).await?;
257 self.stdin.write_all(b"\n").await?;
258 }
259 McpFraming::ContentLength => {
260 let header = format!("Content-Length: {}\r\n\r\n", payload.len());
261 self.stdin.write_all(header.as_bytes()).await?;
262 self.stdin.write_all(&payload).await?;
263 }
264 }
265 self.stdin.flush().await?;
266 Ok(())
267 }
268
269 async fn read_message_payload(&mut self) -> Result<Vec<u8>> {
270 match self.framing {
271 McpFraming::NewlineDelimited => {
272 let mut line = String::new();
273 self.stdout.read_line(&mut line).await?;
274 if line.is_empty() {
275 return Err(anyhow!("MCP server closed connection unexpectedly"));
276 }
277 Ok(line.into_bytes())
278 }
279 McpFraming::ContentLength => {
280 let mut first_line = String::new();
281 self.stdout.read_line(&mut first_line).await?;
282 if first_line.is_empty() {
283 return Err(anyhow!("MCP server closed connection unexpectedly"));
284 }
285
286 if !first_line.starts_with("Content-Length:") {
287 return Ok(first_line.into_bytes());
288 }
289
290 let content_length = first_line["Content-Length:".len()..]
291 .trim()
292 .parse::<usize>()
293 .map_err(|e| anyhow!("Invalid MCP Content-Length header: {}", e))?;
294
295 loop {
296 let mut header_line = String::new();
297 self.stdout.read_line(&mut header_line).await?;
298 if header_line.is_empty() {
299 return Err(anyhow!(
300 "MCP server closed connection while reading headers"
301 ));
302 }
303 if header_line == "\r\n" || header_line == "\n" {
304 break;
305 }
306 }
307
308 let mut payload = vec![0_u8; content_length];
309 self.stdout.read_exact(&mut payload).await?;
310 Ok(payload)
311 }
312 }
313 }
314}
315
316fn spawn_stderr_drain(
317 stderr: ChildStderr,
318 stderr_lines: Arc<Mutex<VecDeque<String>>>,
319) -> JoinHandle<()> {
320 tokio::spawn(async move {
321 let mut reader = BufReader::new(stderr);
322
323 loop {
324 let mut line = String::new();
325 match reader.read_line(&mut line).await {
326 Ok(0) | Err(_) => break,
327 Ok(_) => {
328 let trimmed = line.trim();
329 if trimmed.is_empty() {
330 continue;
331 }
332
333 if let Ok(mut lines) = stderr_lines.lock() {
334 lines.push_back(trimmed.to_string());
335 while lines.len() > 20 {
336 lines.pop_front();
337 }
338 }
339 }
340 }
341 }
342 })
343}
344
345#[cfg(windows)]
346fn resolve_command_path(command: &str) -> Option<PathBuf> {
347 let candidate = PathBuf::from(command);
348 let has_extension = Path::new(command).extension().is_some();
349 if candidate.is_absolute() || command.contains('\\') || command.contains('/') {
350 if !has_extension {
351 for ext in [".exe", ".cmd", ".bat", ".com"] {
352 let with_ext = PathBuf::from(format!("{command}{ext}"));
353 if with_ext.exists() {
354 return Some(with_ext);
355 }
356 }
357 }
358 if candidate.exists() {
359 return Some(candidate);
360 }
361 return None;
362 }
363
364 let path_var = std::env::var_os("PATH")?;
365 for dir in std::env::split_paths(&path_var) {
366 if !has_extension {
367 for ext in [".exe", ".cmd", ".bat", ".com"] {
368 let with_ext = dir.join(format!("{command}{ext}"));
369 if with_ext.exists() {
370 return Some(with_ext);
371 }
372 }
373 }
374 let direct = dir.join(command);
375 if direct.exists() {
376 return Some(direct);
377 }
378 }
379
380 None
381}
382
383#[cfg(not(windows))]
384fn resolve_command_path(command: &str) -> Option<PathBuf> {
385 Some(PathBuf::from(command))
386}
387
388#[cfg(windows)]
389fn is_cmd_wrapper(path: &Path) -> bool {
390 matches!(
391 path.extension().and_then(|ext| ext.to_str()).map(|ext| ext.to_ascii_lowercase()),
392 Some(ext) if ext == "cmd" || ext == "bat"
393 )
394}
395
396#[cfg(not(windows))]
397fn is_cmd_wrapper(_path: &Path) -> bool {
398 false
399}