Skip to main content

atomcode_core/mcp/
transport_stdio.rs

1//! stdio transport for MCP servers.
2//!
3//! Communicates with MCP servers via subprocess stdin/stdout using JSON-RPC.
4
5use std::collections::BTreeMap;
6use std::process::Stdio;
7use std::sync::Arc;
8use std::sync::atomic::{AtomicU64, Ordering};
9use std::time::Duration;
10
11use anyhow::{Context, Result, bail};
12use async_trait::async_trait;
13use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, ChildStdin, ChildStdout, Command};
15use tokio::sync::Mutex;
16
17use super::client::McpClient;
18use super::types::{CallToolResult, InitializeResult, ListToolsResult, ServerStatus};
19
/// Per-request timeout used when the caller passes `None`, in milliseconds
/// (30 seconds).
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1000;

/// Upper bound on plain-text (non-protocol) stdout lines skipped while waiting
/// for a single response; protects against servers that spam stdout with logs.
const MAX_SKIP_LINES: usize = 100;
26
/// stdio-based MCP client.
///
/// Spawns the MCP server as a subprocess and exchanges JSON-RPC messages over
/// its stdin/stdout. Mutable state sits behind `Arc<Mutex<_>>` (tokio mutexes),
/// which lets the request methods take `&self`.
///
/// NOTE(review): field declaration order is also drop order (after `Drop::drop`
/// runs); keep that in mind before reordering fields.
pub struct StdioClient {
    /// Name identifying this server; returned by `server_name()`.
    server_name: String,
    /// Executable used to spawn the server (wrapped in `cmd.exe /C` for
    /// `.cmd`/`.bat` launchers on Windows at spawn time).
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Extra environment variables set on the child process.
    env: BTreeMap<String, String>,
    /// Timeout for one request/response round-trip, in milliseconds.
    timeout_ms: u64,
    /// Connection state reported by `status()`.
    status: Arc<Mutex<ServerStatus>>,
    /// Next JSON-RPC request id (starts at 1).
    next_id: AtomicU64,
    /// The spawned server process, once started.
    process: Arc<Mutex<Option<Child>>>,
    /// Child's stdin; requests and notifications are written here as NDJSON.
    stdin: Arc<Mutex<Option<ChildStdin>>>,
    /// Buffered reader over the child's stdout; responses are read from here.
    reader: Arc<Mutex<Option<BufReader<ChildStdout>>>>,
    /// First response line peeked during startup drain (NDJSON or `Content-Length:`), not yet consumed.
    preread_line: Arc<Mutex<Option<String>>>,
    /// Serialize request/response round-trips.
    ///
    /// MCP over stdio is a single ordered byte stream. Allowing concurrent
    /// in-flight requests can lead to response mix-ups or one caller
    /// consuming the other's response, causing timeouts.
    request_lock: Arc<Mutex<()>>,
}
48
49impl StdioClient {
50    /// Create a new stdio client.
51    pub fn new(
52        server_name: String,
53        command: String,
54        args: Vec<String>,
55        env: BTreeMap<String, String>,
56        timeout_ms: Option<u64>,
57    ) -> Self {
58        Self {
59            server_name,
60            command,
61            args,
62            env,
63            timeout_ms: timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS),
64            status: Arc::new(Mutex::new(ServerStatus::Disconnected)),
65            next_id: AtomicU64::new(1),
66            process: Arc::new(Mutex::new(None)),
67            stdin: Arc::new(Mutex::new(None)),
68            reader: Arc::new(Mutex::new(None)),
69            preread_line: Arc::new(Mutex::new(None)),
70            request_lock: Arc::new(Mutex::new(())),
71        }
72    }
73
74    /// Start the subprocess and set up communication.
75    async fn start(&self) -> Result<()> {
76        // On Windows, commands like `npx`, `npm` are .cmd/.bat scripts
77        // that cannot be spawned directly via Command::new(). Wrap them
78        // through `cmd.exe /C` so the OS can locate and execute them.
79        #[cfg(target_os = "windows")]
80        let (command, args) = windows_wrap_command(&self.command, &self.args);
81
82        #[cfg(not(target_os = "windows"))]
83        let (command, args) = (self.command.clone(), self.args.clone());
84
85        let mut cmd = Command::new(&command);
86        cmd.args(&args)
87            .stdin(Stdio::piped())
88            .stdout(Stdio::piped())
89            .stderr(Stdio::null());
90
91        for (key, value) in &self.env {
92            cmd.env(key, value);
93        }
94
95        crate::process_utils::suppress_console_window(&mut cmd);
96
97        let mut child = cmd.spawn().with_context(|| {
98            #[cfg(target_os = "windows")]
99            {
100                let msg = format!(
101                    "Failed to spawn MCP server: {}. \
102                     On Windows, commands like 'npx' are .cmd scripts and must \
103                     be executed through 'cmd /C'. AtomCode wraps known commands \
104                     automatically; if this is a custom .cmd/.bat, set command to \
105                     'cmd' and add '/C' before the script name in args.",
106                    self.command
107                );
108                msg
109            }
110            #[cfg(not(target_os = "windows"))]
111            {
112                format!("Failed to spawn MCP server: {}", self.command)
113            }
114        })?;
115
116        let stdin = child.stdin.take().context("Failed to get stdin")?;
117        let stdout = child.stdout.take().context("Failed to get stdout")?;
118        let reader = BufReader::new(stdout);
119
120        *self.process.lock().await = Some(child);
121        *self.stdin.lock().await = Some(stdin);
122        *self.reader.lock().await = Some(reader);
123
124        Ok(())
125    }
126
127    /// Send a request and wait for response.
128    async fn send_request(
129        &self,
130        method: &str,
131        params: Option<serde_json::Value>,
132    ) -> Result<serde_json::Value> {
133        let _req_guard = self.request_lock.lock().await;
134        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
135
136        // IMPORTANT: omit `params` when it's None.
137        //
138        // The official JS MCP SDK's stdio transport can hang when it receives
139        // `"params": null` for methods that expect params to be absent.
140        let mut request = serde_json::Map::new();
141        request.insert(
142            "jsonrpc".to_string(),
143            serde_json::Value::String("2.0".to_string()),
144        );
145        request.insert("id".to_string(), serde_json::Value::Number(id.into()));
146        request.insert(
147            "method".to_string(),
148            serde_json::Value::String(method.to_string()),
149        );
150        if let Some(p) = params {
151            request.insert("params".to_string(), p);
152        }
153        let request = serde_json::Value::Object(request);
154
155        let timeout = Duration::from_millis(self.timeout_ms);
156
157        // Write request (NDJSON).
158        {
159            let mut stdin = self.stdin.lock().await;
160            let stdin = stdin.as_mut().context("MCP server not connected (stdin)")?;
161
162            let mut body = serde_json::to_vec(&request)?;
163            body.push(b'\n');
164            stdin.write_all(&body).await?;
165            stdin.flush().await?;
166        }
167
168        // Read response with timeout
169        let result = tokio::time::timeout(timeout, self.recv_jsonrpc_response())
170            .await
171            .with_context(|| {
172                format!(
173                    "MCP request {} timed out after {}ms",
174                    method, self.timeout_ms
175                )
176            })??;
177
178        if let Some(error) = result.error {
179            bail!("MCP error {} (code {}): {}", error.message, error.code, "");
180        }
181
182        result
183            .result
184            .ok_or_else(|| anyhow::anyhow!("MCP response missing result"))
185    }
186}
187
188#[async_trait]
189impl McpClient for StdioClient {
190    async fn initialize(&mut self) -> Result<InitializeResult> {
191        let mut status = self.status.lock().await;
192        *status = ServerStatus::Connecting;
193        drop(status);
194
195        self.start().await?;
196
197        // Drain any startup messages before JSON-RPC begins
198        self.drain_startup_messages().await?;
199
200        // Send initialize request
201        let params = serde_json::json!({
202            "protocolVersion": "2024-11-05",
203            "capabilities": {
204                "tools": {}
205            },
206            "clientInfo": {
207                "name": "atomcode",
208                "version": env!("CARGO_PKG_VERSION")
209            }
210        });
211
212        let result: InitializeResult =
213            serde_json::from_value(self.send_request("initialize", Some(params)).await?)
214                .context("Failed to parse initialize result")?;
215
216        // Send initialized notification
217        {
218            let mut stdin = self.stdin.lock().await;
219            if let Some(stdin) = stdin.as_mut() {
220                let notification = serde_json::json!({
221                    "jsonrpc": "2.0",
222                    "method": "notifications/initialized"
223                });
224                let mut body = serde_json::to_vec(&notification)?;
225                body.push(b'\n');
226                stdin.write_all(&body).await?;
227                stdin.flush().await?;
228            }
229        }
230
231        let mut status = self.status.lock().await;
232        *status = ServerStatus::Connected;
233
234        Ok(result)
235    }
236
237    async fn list_tools(&self) -> Result<ListToolsResult> {
238        let result = self.send_request("tools/list", None).await?;
239        serde_json::from_value(result).context("Failed to parse tools/list result")
240    }
241
242    async fn call_tool(
243        &self,
244        tool_name: &str,
245        arguments: serde_json::Value,
246    ) -> Result<CallToolResult> {
247        let params = serde_json::json!({
248            "name": tool_name,
249            "arguments": arguments
250        });
251
252        let result = self.send_request("tools/call", Some(params)).await?;
253        serde_json::from_value(result).context("Failed to parse tools/call result")
254    }
255
256    fn server_name(&self) -> &str {
257        &self.server_name
258    }
259
260    fn status(&self) -> ServerStatus {
261        self.status
262            .try_lock()
263            .map(|s| s.clone())
264            .unwrap_or(ServerStatus::Disconnected)
265    }
266}
267
268impl StdioClient {
269    /// Read one JSON-RPC response (NDJSON per MCP stdio spec, or legacy `Content-Length` framing).
270    async fn recv_jsonrpc_response(&self) -> Result<super::types::JsonRpcResponse> {
271        let mut reader = self.reader.lock().await;
272        let reader = reader
273            .as_mut()
274            .context("MCP server not connected (reader)")?;
275
276        let mut skipped_lines = 0;
277        loop {
278            let line = if let Some(s) = self.preread_line.lock().await.take() {
279                s
280            } else {
281                let mut buf = String::new();
282                loop {
283                    buf.clear();
284                    let n = reader.read_line(&mut buf).await?;
285                    if n == 0 {
286                        bail!("MCP server closed connection");
287                    }
288                    if !buf.trim().is_empty() {
289                        break;
290                    }
291                }
292                buf
293            };
294
295            let body = line.trim_end_matches(['\r', '\n']).trim_start();
296            if body.starts_with('{') || body.starts_with('[') {
297                return serde_json::from_str(body)
298                    .context("Failed to parse NDJSON MCP message as JSON-RPC");
299            }
300            if strip_prefix_ci(body, "content-length:").is_some() {
301                return read_content_length_message(reader, line).await;
302            }
303
304            // Some third-party MCP servers incorrectly print status logs to stdout
305            // after initialization. MCP requires stdout to contain only protocol
306            // messages, but skipping plain-text lines keeps otherwise usable tools
307            // available while still failing on malformed JSON-RPC frames above.
308            skipped_lines += 1;
309            if skipped_lines > MAX_SKIP_LINES {
310                bail!(
311                    "MCP stdio: too many non-protocol lines (>{MAX_SKIP_LINES}), last line: {}",
312                    body.chars().take(80).collect::<String>()
313                );
314            }
315        }
316    }
317
318    /// Drain non-protocol lines the server may print to stdout before the first MCP message.
319    ///
320    /// Lines that look like NDJSON or `Content-Length` are **not** consumed; they are moved to
321    /// [`Self::preread_line`] for [`Self::recv_jsonrpc_response`].
322    async fn drain_startup_messages(&self) -> Result<()> {
323        let _ = tokio::time::timeout(Duration::from_millis(500), async {
324            loop {
325                let mut line = String::new();
326                let mut reader = self.reader.lock().await;
327                let Some(r) = reader.as_mut() else {
328                    return;
329                };
330                let read_res =
331                    tokio::time::timeout(Duration::from_millis(80), r.read_line(&mut line)).await;
332                drop(reader);
333
334                match read_res {
335                    Err(_) | Ok(Err(_)) | Ok(Ok(0)) => return,
336                    Ok(Ok(_)) => {
337                        let t = line.trim();
338                        if t.is_empty() {
339                            continue;
340                        }
341                        let js = t.trim_start();
342                        if js.starts_with('{')
343                            || js.starts_with('[')
344                            || strip_prefix_ci(js, "content-length:").is_some()
345                        {
346                            *self.preread_line.lock().await = Some(line);
347                            return;
348                        }
349                    }
350                }
351            }
352        })
353        .await;
354
355        Ok(())
356    }
357}
358
/// ASCII case-insensitive counterpart of [`str::strip_prefix`].
///
/// Returns the rest of `s` after `prefix_lower` when `s` starts with it
/// (ignoring ASCII case), otherwise `None`. `prefix_lower` must be ASCII lower case.
fn strip_prefix_ci<'a>(s: &'a str, prefix_lower: &'static str) -> Option<&'a str> {
    let split_at = prefix_lower.len();
    // `get` yields `None` when `s` is too short or `split_at` is not a char
    // boundary; an ASCII prefix can never match in either case.
    let head = s.get(..split_at)?;
    head.eq_ignore_ascii_case(prefix_lower)
        .then(|| &s[split_at..])
}
371
372async fn read_content_length_message(
373    reader: &mut BufReader<ChildStdout>,
374    mut line: String,
375) -> Result<super::types::JsonRpcResponse> {
376    let mut content_length: Option<usize> = None;
377    loop {
378        let t = line.trim_end_matches(['\r', '\n']).trim();
379        if t.is_empty() {
380            break;
381        }
382        if let Some(rest) = strip_prefix_ci(t, "content-length:") {
383            content_length = Some(rest.trim().parse().context("Invalid Content-Length")?);
384        }
385        line.clear();
386        let n = reader.read_line(&mut line).await?;
387        if n == 0 {
388            bail!("MCP server closed connection while reading headers");
389        }
390    }
391
392    let length = content_length.context("Missing Content-Length header")?;
393    let mut body = vec![0u8; length];
394    reader.read_exact(&mut body).await?;
395    serde_json::from_slice(&body).context("Failed to parse JSON-RPC response")
396}
397
/// Wrap Windows script launchers in a shell so they can be spawned.
///
/// On Windows, `npx`, `npm`, `yarn` and `pnpm` are really `.cmd`/`.bat`
/// scripts. `Command::new()` goes through `CreateProcess`, which only launches
/// executables directly. When `command` is one of those launchers (matched
/// case-insensitively), or ends in `.cmd`/`.bat`, the call becomes
/// `shell /C <command> <args...>`. The original casing of `command` is kept.
/// Any other command is returned unchanged.
///
/// Commands the user already wrapped (e.g. `command: "cmd"`,
/// `args: ["/C", "npx", ...]`) pass through untouched, because `cmd`/`cmd.exe`
/// are neither launchers nor scripts.
///
/// `shell` is a parameter (`"cmd.exe"` on Windows) so the logic stays testable
/// on every platform.
// Off Windows only the unit tests call this, hence the `dead_code` allowance.
#[cfg_attr(not(target_os = "windows"), allow(dead_code))]
fn wrap_cmd_script(command: &str, args: &[String], shell: &str) -> (String, Vec<String>) {
    /// Bare launcher names that ship as `.cmd` scripts on Windows.
    /// (`npx.cmd` and friends are covered by the extension check below.)
    const CMD_SCRIPTS: &[&str] = &["npx", "npm", "yarn", "pnpm"];

    let lowered = command.to_ascii_lowercase();
    let has_script_extension = lowered.ends_with(".cmd") || lowered.ends_with(".bat");
    let is_known_launcher = CMD_SCRIPTS.contains(&lowered.as_str());

    if !has_script_extension && !is_known_launcher {
        return (command.to_string(), args.to_vec());
    }

    let mut wrapped = Vec::with_capacity(args.len() + 2);
    wrapped.push("/C".to_owned());
    wrapped.push(command.to_owned());
    wrapped.extend_from_slice(args);
    (shell.to_owned(), wrapped)
}
438
/// Windows entry point for [`wrap_cmd_script`], using `cmd.exe` as the shell.
#[cfg(target_os = "windows")]
fn windows_wrap_command(command: &str, args: &[String]) -> (String, Vec<String>) {
    const SHELL: &str = "cmd.exe";
    wrap_cmd_script(command, args, SHELL)
}
444
445impl Drop for StdioClient {
446    fn drop(&mut self) {
447        // Try to kill the subprocess gracefully
448        if let Ok(mut process) = self.process.try_lock() {
449            if let Some(mut child) = process.take() {
450                let _ = child.start_kill();
451            }
452        }
453    }
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    // --- Platform-independent tests for wrap_cmd_script logic ---
    // These run on ALL platforms (macOS, Linux, Windows) so we can
    // verify the wrapping logic locally without a Windows machine.

    #[test]
    fn wrap_npx() {
        let (cmd, args) = wrap_cmd_script("npx", &["-y".into(), "@pkg/server".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "npx", "-y", "@pkg/server"]);
    }

    #[test]
    fn wrap_npx_cmd_suffix() {
        let (cmd, args) = wrap_cmd_script("npx.cmd", &["-y".into(), "@pkg/server".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "npx.cmd", "-y", "@pkg/server"]);
    }

    #[test]
    fn wrap_npm() {
        let (cmd, args) = wrap_cmd_script("npm", &["install".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "npm", "install"]);
    }

    #[test]
    fn wrap_yarn() {
        let (cmd, args) = wrap_cmd_script("yarn", &["add".into(), "lodash".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "yarn", "add", "lodash"]);
    }

    #[test]
    fn wrap_pnpm() {
        let (cmd, args) = wrap_cmd_script("pnpm", &["install".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "pnpm", "install"]);
    }

    #[test]
    fn wrap_custom_bat() {
        let (cmd, args) = wrap_cmd_script("my-script.bat", &["--flag".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "my-script.bat", "--flag"]);
    }

    #[test]
    fn wrap_uppercase_bat_extension() {
        let (cmd, args) = wrap_cmd_script("RUN.BAT", &["x".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "RUN.BAT", "x"]);
    }

    #[test]
    fn wrap_custom_cmd_suffix() {
        let (cmd, args) = wrap_cmd_script("build.cmd", &[], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "build.cmd"]);
    }

    #[test]
    fn no_wrap_exe() {
        let (cmd, args) = wrap_cmd_script("node", &["server.js".into()], "cmd.exe");
        assert_eq!(cmd, "node");
        assert_eq!(args, vec!["server.js"]);
    }

    #[test]
    fn no_wrap_already_wrapped() {
        // If user already set command to "cmd", don't double-wrap
        let (cmd, args) =
            wrap_cmd_script("cmd", &["/C".into(), "npx".into(), "-y".into()], "cmd.exe");
        assert_eq!(cmd, "cmd");
        assert_eq!(args, vec!["/C", "npx", "-y"]);
    }

    #[test]
    fn no_wrap_name_containing_launcher() {
        // Only exact launcher names (or script extensions) trigger wrapping.
        let (cmd, args) = wrap_cmd_script("npx-helper", &["--x".into()], "cmd.exe");
        assert_eq!(cmd, "npx-helper");
        assert_eq!(args, vec!["--x"]);
    }

    #[test]
    fn wrap_case_insensitive() {
        let (cmd, args) = wrap_cmd_script("NPX", &["-y".into(), "@pkg/server".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "NPX", "-y", "@pkg/server"]);
    }

    #[test]
    fn wrap_preserves_original_command_in_args() {
        // The original command (with original casing) should appear in args
        let (cmd, args) = wrap_cmd_script("Npx", &["-y".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args[1], "Npx"); // original casing preserved
    }

    #[test]
    fn no_wrap_python() {
        let (cmd, args) = wrap_cmd_script("python", &["-m".into(), "server".into()], "cmd.exe");
        assert_eq!(cmd, "python");
        assert_eq!(args, vec!["-m", "server"]);
    }

    // --- strip_prefix_ci: used to detect legacy `Content-Length` framing ---

    #[test]
    fn strip_prefix_ci_matches_any_ascii_case() {
        assert_eq!(strip_prefix_ci("Content-Length: 42", "content-length:"), Some(" 42"));
        assert_eq!(strip_prefix_ci("CONTENT-LENGTH:7", "content-length:"), Some("7"));
        assert_eq!(strip_prefix_ci("content-length:", "content-length:"), Some(""));
    }

    #[test]
    fn strip_prefix_ci_rejects_non_matching_input() {
        assert_eq!(strip_prefix_ci("", "content-length:"), None);
        assert_eq!(strip_prefix_ci("content", "content-length:"), None);
        assert_eq!(strip_prefix_ci("{\"jsonrpc\":\"2.0\"}", "content-length:"), None);
    }

    #[test]
    fn strip_prefix_ci_handles_non_ascii_without_panicking() {
        // The prefix is 15 bytes; byte 15 of this input falls inside a 2-byte 'é'.
        assert_eq!(strip_prefix_ci("ééééééééééé", "content-length:"), None);
    }
}