1use std::collections::BTreeMap;
6use std::process::Stdio;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::Duration;
10
11use anyhow::{Context, Result, bail};
12use async_trait::async_trait;
13use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout, Command};
15use tokio::sync::Mutex;
16
17use super::client::McpClient;
18use super::types::{CallToolResult, InitializeResult, ListToolsResult, ServerStatus};
19
/// Per-request response timeout (30 seconds) used when `StdioClient::new`
/// is given `None`.
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1000;

/// How many non-protocol stdout lines (log/banner text) may be skipped while
/// waiting for one response; exceeding it is reported as an error.
const MAX_SKIP_LINES: usize = 100;
26
/// MCP client that exchanges JSON-RPC messages with a server launched as a
/// child process, writing to its stdin and reading from its stdout.
///
/// `new` leaves every process/pipe handle as `None`; they are filled in by
/// `start`, which `initialize` calls.
pub struct StdioClient {
    /// Name returned by `McpClient::server_name`.
    server_name: String,
    /// Program to launch (on Windows it may be wrapped as `cmd.exe /C ...`).
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Extra environment variables set on the child process.
    env: BTreeMap<String, String>,
    /// Per-request response timeout in milliseconds.
    timeout_ms: u64,
    /// Connection state reported by `McpClient::status`.
    status: Arc<Mutex<ServerStatus>>,
    /// Source of JSON-RPC request ids; starts at 1.
    next_id: AtomicU64,
    /// The spawned server process, if any.
    process: Arc<Mutex<Option<Child>>>,
    /// Pipe used to write requests and notifications to the server.
    stdin: Arc<Mutex<Option<ChildStdin>>>,
    /// Buffered reader over the server's stdout.
    reader: Arc<Mutex<Option<BufReader<ChildStdout>>>>,
    /// A protocol line read while draining startup output; consumed first by
    /// the next response read.
    preread_line: Arc<Mutex<Option<String>>>,
    /// Held for a whole request/response round-trip so concurrent requests
    /// cannot interleave on the pipes.
    request_lock: Arc<Mutex<()>>,
}
48
49impl StdioClient {
50 pub fn new(
52 server_name: String,
53 command: String,
54 args: Vec<String>,
55 env: BTreeMap<String, String>,
56 timeout_ms: Option<u64>,
57 ) -> Self {
58 Self {
59 server_name,
60 command,
61 args,
62 env,
63 timeout_ms: timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS),
64 status: Arc::new(Mutex::new(ServerStatus::Disconnected)),
65 next_id: AtomicU64::new(1),
66 process: Arc::new(Mutex::new(None)),
67 stdin: Arc::new(Mutex::new(None)),
68 reader: Arc::new(Mutex::new(None)),
69 preread_line: Arc::new(Mutex::new(None)),
70 request_lock: Arc::new(Mutex::new(())),
71 }
72 }
73
74 async fn start(&self) -> Result<()> {
76 #[cfg(target_os = "windows")]
80 let (command, args) = windows_wrap_command(&self.command, &self.args);
81
82 #[cfg(not(target_os = "windows"))]
83 let (command, args) = (self.command.clone(), self.args.clone());
84
85 let mut cmd = Command::new(&command);
86 cmd.args(&args)
87 .stdin(Stdio::piped())
88 .stdout(Stdio::piped())
89 .stderr(Stdio::null());
90
91 for (key, value) in &self.env {
92 cmd.env(key, value);
93 }
94
95 crate::process_utils::suppress_console_window(&mut cmd);
96
97 let mut child = cmd.spawn().with_context(|| {
98 #[cfg(target_os = "windows")]
99 {
100 let msg = format!(
101 "Failed to spawn MCP server: {}. \
102 On Windows, commands like 'npx' are .cmd scripts and must \
103 be executed through 'cmd /C'. AtomCode wraps known commands \
104 automatically; if this is a custom .cmd/.bat, set command to \
105 'cmd' and add '/C' before the script name in args.",
106 self.command
107 );
108 msg
109 }
110 #[cfg(not(target_os = "windows"))]
111 {
112 format!("Failed to spawn MCP server: {}", self.command)
113 }
114 })?;
115
116 let stdin = child.stdin.take().context("Failed to get stdin")?;
117 let stdout = child.stdout.take().context("Failed to get stdout")?;
118 let reader = BufReader::new(stdout);
119
120 *self.process.lock().await = Some(child);
121 *self.stdin.lock().await = Some(stdin);
122 *self.reader.lock().await = Some(reader);
123
124 Ok(())
125 }
126
127 async fn send_request(
129 &self,
130 method: &str,
131 params: Option<serde_json::Value>,
132 ) -> Result<serde_json::Value> {
133 let _req_guard = self.request_lock.lock().await;
134 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
135
136 let mut request = serde_json::Map::new();
141 request.insert(
142 "jsonrpc".to_string(),
143 serde_json::Value::String("2.0".to_string()),
144 );
145 request.insert("id".to_string(), serde_json::Value::Number(id.into()));
146 request.insert(
147 "method".to_string(),
148 serde_json::Value::String(method.to_string()),
149 );
150 if let Some(p) = params {
151 request.insert("params".to_string(), p);
152 }
153 let request = serde_json::Value::Object(request);
154
155 let timeout = Duration::from_millis(self.timeout_ms);
156
157 {
159 let mut stdin = self.stdin.lock().await;
160 let stdin = stdin.as_mut().context("MCP server not connected (stdin)")?;
161
162 let mut body = serde_json::to_vec(&request)?;
163 body.push(b'\n');
164 stdin.write_all(&body).await?;
165 stdin.flush().await?;
166 }
167
168 let result = tokio::time::timeout(timeout, self.recv_jsonrpc_response())
170 .await
171 .with_context(|| {
172 format!(
173 "MCP request {} timed out after {}ms",
174 method, self.timeout_ms
175 )
176 })??;
177
178 if let Some(error) = result.error {
179 bail!("MCP error {} (code {}): {}", error.message, error.code, "");
180 }
181
182 result
183 .result
184 .ok_or_else(|| anyhow::anyhow!("MCP response missing result"))
185 }
186}
187
188#[async_trait]
189impl McpClient for StdioClient {
190 async fn initialize(&mut self) -> Result<InitializeResult> {
191 let mut status = self.status.lock().await;
192 *status = ServerStatus::Connecting;
193 drop(status);
194
195 self.start().await?;
196
197 self.drain_startup_messages().await?;
199
200 let params = serde_json::json!({
202 "protocolVersion": "2024-11-05",
203 "capabilities": {
204 "tools": {}
205 },
206 "clientInfo": {
207 "name": "atomcode",
208 "version": env!("CARGO_PKG_VERSION")
209 }
210 });
211
212 let result: InitializeResult =
213 serde_json::from_value(self.send_request("initialize", Some(params)).await?)
214 .context("Failed to parse initialize result")?;
215
216 {
218 let mut stdin = self.stdin.lock().await;
219 if let Some(stdin) = stdin.as_mut() {
220 let notification = serde_json::json!({
221 "jsonrpc": "2.0",
222 "method": "notifications/initialized"
223 });
224 let mut body = serde_json::to_vec(¬ification)?;
225 body.push(b'\n');
226 stdin.write_all(&body).await?;
227 stdin.flush().await?;
228 }
229 }
230
231 let mut status = self.status.lock().await;
232 *status = ServerStatus::Connected;
233
234 Ok(result)
235 }
236
237 async fn list_tools(&self) -> Result<ListToolsResult> {
238 let result = self.send_request("tools/list", None).await?;
239 serde_json::from_value(result).context("Failed to parse tools/list result")
240 }
241
242 async fn call_tool(
243 &self,
244 tool_name: &str,
245 arguments: serde_json::Value,
246 ) -> Result<CallToolResult> {
247 let params = serde_json::json!({
248 "name": tool_name,
249 "arguments": arguments
250 });
251
252 let result = self.send_request("tools/call", Some(params)).await?;
253 serde_json::from_value(result).context("Failed to parse tools/call result")
254 }
255
256 fn server_name(&self) -> &str {
257 &self.server_name
258 }
259
260 fn status(&self) -> ServerStatus {
261 self.status
262 .try_lock()
263 .map(|s| s.clone())
264 .unwrap_or(ServerStatus::Disconnected)
265 }
266}
267
impl StdioClient {
    /// Reads the next JSON-RPC message from the server's stdout.
    ///
    /// Two framings are accepted:
    /// * a single line starting with `{` or `[`, parsed directly as JSON
    ///   (newline-delimited JSON);
    /// * a `Content-Length:` header line (matched case-insensitively). The
    ///   remaining headers and the body are then read by
    ///   `read_content_length_message`.
    ///
    /// A line stashed in `preread_line` by `drain_startup_messages` is handled
    /// before anything new is read. Blank lines are skipped and not counted.
    /// Any other line is treated as noise and skipped, but seeing more than
    /// `MAX_SKIP_LINES` of them is an error. The `reader` lock is held for the
    /// entire call.
    ///
    /// # Errors
    /// Not connected, server closed stdout (EOF), a protocol line that fails
    /// to parse, or too many noise lines.
    ///
    /// NOTE(review): `send_request` wraps this call in a timeout. If that
    /// timeout fires while a message is half-read, the partial data is
    /// presumably lost and the stream left mid-message. Confirm whether a
    /// reconnect is needed in that case.
    async fn recv_jsonrpc_response(&self) -> Result<super::types::JsonRpcResponse> {
        let mut reader = self.reader.lock().await;
        let reader = reader
            .as_mut()
            .context("MCP server not connected (reader)")?;

        // Counts only non-blank, non-protocol lines.
        let mut skipped_lines = 0;
        loop {
            // Prefer a line that was already consumed during startup draining.
            let line = if let Some(s) = self.preread_line.lock().await.take() {
                s
            } else {
                let mut buf = String::new();
                // Read until a non-blank line; zero bytes read means EOF.
                loop {
                    buf.clear();
                    let n = reader.read_line(&mut buf).await?;
                    if n == 0 {
                        bail!("MCP server closed connection");
                    }
                    if !buf.trim().is_empty() {
                        break;
                    }
                }
                buf
            };

            let body = line.trim_end_matches(['\r', '\n']).trim_start();
            if body.starts_with('{') || body.starts_with('[') {
                return serde_json::from_str(body)
                    .context("Failed to parse NDJSON MCP message as JSON-RPC");
            }
            if strip_prefix_ci(body, "content-length:").is_some() {
                // Pass the whole header line along; the helper reads the rest
                // of the headers and the body.
                return read_content_length_message(reader, line).await;
            }

            // Anything else (banner or log text) is skipped, up to a limit.
            skipped_lines += 1;
            if skipped_lines > MAX_SKIP_LINES {
                bail!(
                    "MCP stdio: too many non-protocol lines (>{MAX_SKIP_LINES}), last line: {}",
                    body.chars().take(80).collect::<String>()
                );
            }
        }
    }

    /// Best-effort discard of output the server writes to stdout right after
    /// it is spawned. `initialize` calls this before sending its request.
    ///
    /// Lines are read for at most 500 ms in total. Reading stops early when:
    /// * a single read yields no complete line within 80 ms;
    /// * EOF or a read error occurs;
    /// * the client is not connected.
    ///
    /// The first line that looks like protocol traffic (`{`, `[`, or a
    /// `Content-Length:` header) is not discarded. It is stored in
    /// `preread_line` for `recv_jsonrpc_response`. All errors and timeouts are
    /// swallowed; this always returns `Ok(())`.
    ///
    /// NOTE(review): when the 80 ms timer fires mid-line, whatever was read
    /// into the local `line` is dropped. That is presumably acceptable for
    /// startup noise; confirm.
    async fn drain_startup_messages(&self) -> Result<()> {
        // Hitting the 500 ms overall cap is not an error, so the result is ignored.
        let _ = tokio::time::timeout(Duration::from_millis(500), async {
            loop {
                let mut line = String::new();
                let mut reader = self.reader.lock().await;
                let Some(r) = reader.as_mut() else {
                    return;
                };
                let read_res =
                    tokio::time::timeout(Duration::from_millis(80), r.read_line(&mut line)).await;
                // Release the reader lock before inspecting the line.
                drop(reader);

                match read_res {
                    // Idle for 80 ms, I/O error, or EOF: stop draining.
                    Err(_) | Ok(Err(_)) | Ok(Ok(0)) => return,
                    Ok(Ok(_)) => {
                        let t = line.trim();
                        if t.is_empty() {
                            continue;
                        }
                        let js = t.trim_start();
                        if js.starts_with('{')
                            || js.starts_with('[')
                            || strip_prefix_ci(js, "content-length:").is_some()
                        {
                            // Keep the first protocol line for the real reader.
                            *self.preread_line.lock().await = Some(line);
                            return;
                        }
                        // Otherwise the line is noise and is discarded.
                    }
                }
            }
        })
        .await;

        Ok(())
    }
}
358
/// ASCII case-insensitive counterpart of `str::strip_prefix`.
///
/// If `s` starts with `prefix_lower` (ignoring ASCII case), returns the rest of
/// `s` after it; otherwise returns `None`. Non-ASCII bytes must match exactly.
fn strip_prefix_ci<'a>(s: &'a str, prefix_lower: &'static str) -> Option<&'a str> {
    let split = prefix_lower.len();
    // `get` yields `None` when `s` is too short, or when `split` is not a char
    // boundary. A case-insensitive match implies a boundary, so no match is lost.
    match s.get(..split) {
        Some(head) if head.eq_ignore_ascii_case(prefix_lower) => s.get(split..),
        _ => None,
    }
}
371
372async fn read_content_length_message(
373 reader: &mut BufReader<ChildStdout>,
374 mut line: String,
375) -> Result<super::types::JsonRpcResponse> {
376 let mut content_length: Option<usize> = None;
377 loop {
378 let t = line.trim_end_matches(['\r', '\n']).trim();
379 if t.is_empty() {
380 break;
381 }
382 if let Some(rest) = strip_prefix_ci(t, "content-length:") {
383 content_length = Some(rest.trim().parse().context("Invalid Content-Length")?);
384 }
385 line.clear();
386 let n = reader.read_line(&mut line).await?;
387 if n == 0 {
388 bail!("MCP server closed connection while reading headers");
389 }
390 }
391
392 let length = content_length.context("Missing Content-Length header")?;
393 let mut body = vec![0u8; length];
394 reader.read_exact(&mut body).await?;
395 serde_json::from_slice(&body).context("Failed to parse JSON-RPC response")
396}
397
/// Rewrites `command` as `shell /C command args...` when it names a Windows
/// batch-script launcher; otherwise returns `command`/`args` unchanged.
///
/// A command is wrapped when either condition holds:
/// * its file name (the text after the last `/` or `\`), compared
///   case-insensitively, is one of the package-manager shims in
///   `CMD_SCRIPTS`;
/// * the command ends in `.cmd` or `.bat`.
///
/// The caller's original spelling of `command` is kept in the wrapped
/// argument list. An explicit `cmd /C ...` is left untouched. `shell` is a
/// parameter so the logic can be unit-tested on any platform; the Windows
/// caller passes `"cmd.exe"`.
#[cfg_attr(not(target_os = "windows"), allow(dead_code))]
fn wrap_cmd_script(command: &str, args: &[String], shell: &str) -> (String, Vec<String>) {
    // Bare shim names. Their `.cmd` spellings are covered by the suffix check.
    const CMD_SCRIPTS: &[&str] = &["npx", "npm", "yarn", "pnpm"];

    let lower = command.to_ascii_lowercase();
    // Match on the file name so a directory-qualified shim such as
    // `C:\nodejs\npx` is recognised as well as a bare `npx`.
    let file_name = lower.rsplit(['/', '\\']).next().unwrap_or(lower.as_str());
    let needs_wrap = CMD_SCRIPTS.contains(&file_name)
        || lower.ends_with(".cmd")
        || lower.ends_with(".bat");

    if needs_wrap {
        let mut wrapped_args = vec!["/C".to_string(), command.to_string()];
        wrapped_args.extend(args.iter().cloned());
        (shell.to_string(), wrapped_args)
    } else {
        (command.to_string(), args.to_vec())
    }
}
438
/// Windows-only entry point: routes batch-script launchers through
/// `cmd.exe /C` using the rules implemented by `wrap_cmd_script`.
#[cfg(target_os = "windows")]
fn windows_wrap_command(command: &str, args: &[String]) -> (String, Vec<String>) {
    const SHELL: &str = "cmd.exe";
    wrap_cmd_script(command, args, SHELL)
}
444
445impl Drop for StdioClient {
446 fn drop(&mut self) {
447 if let Ok(mut process) = self.process.try_lock() {
449 if let Some(mut child) = process.take() {
450 let _ = child.start_kill();
451 }
452 }
453 }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `wrap_cmd_script` with the Windows shell name.
    fn wrap(command: &str, args: &[&str]) -> (String, Vec<String>) {
        let owned: Vec<String> = args.iter().map(|a| a.to_string()).collect();
        wrap_cmd_script(command, &owned, "cmd.exe")
    }

    /// Asserts that `command` becomes `cmd.exe /C command args...`.
    fn assert_wrapped(command: &str, args: &[&str]) {
        let (cmd, wrapped) = wrap(command, args);
        assert_eq!(cmd, "cmd.exe");
        let mut expected = vec!["/C", command];
        expected.extend_from_slice(args);
        assert_eq!(wrapped, expected);
    }

    /// Asserts that `command` and `args` pass through unchanged.
    fn assert_unwrapped(command: &str, args: &[&str]) {
        let (cmd, passed) = wrap(command, args);
        assert_eq!(cmd, command);
        assert_eq!(passed, args);
    }

    #[test]
    fn wrap_npx() {
        assert_wrapped("npx", &["-y", "@pkg/server"]);
    }

    #[test]
    fn wrap_npx_cmd_suffix() {
        assert_wrapped("npx.cmd", &["-y", "@pkg/server"]);
    }

    #[test]
    fn wrap_npm() {
        assert_wrapped("npm", &["install"]);
    }

    #[test]
    fn wrap_yarn() {
        assert_wrapped("yarn", &["add", "lodash"]);
    }

    #[test]
    fn wrap_pnpm() {
        assert_wrapped("pnpm", &["install"]);
    }

    #[test]
    fn wrap_custom_bat() {
        assert_wrapped("my-script.bat", &["--flag"]);
    }

    #[test]
    fn wrap_custom_cmd_suffix() {
        assert_wrapped("build.cmd", &[]);
    }

    #[test]
    fn no_wrap_exe() {
        assert_unwrapped("node", &["server.js"]);
    }

    #[test]
    fn no_wrap_already_wrapped() {
        // An explicit `cmd /C ...` must not be wrapped a second time.
        assert_unwrapped("cmd", &["/C", "npx", "-y"]);
    }

    #[test]
    fn wrap_case_insensitive() {
        assert_wrapped("NPX", &["-y", "@pkg/server"]);
    }

    #[test]
    fn wrap_preserves_original_command_in_args() {
        let (cmd, args) = wrap("Npx", &["-y"]);
        assert_eq!(cmd, "cmd.exe");
        // The caller's spelling is kept, not the lowercased form.
        assert_eq!(args[1], "Npx");
    }

    #[test]
    fn no_wrap_python() {
        assert_unwrapped("python", &["-m", "server"]);
    }
}