use std::collections::BTreeMap;
use std::process::Stdio;
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::Duration;
use anyhow::{Context, Result, bail};
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use tokio::sync::Mutex;
use super::client::McpClient;
use super::types::{CallToolResult, InitializeResult, ListToolsResult, ServerStatus};
/// Per-request timeout (30 seconds) used when `StdioClient::new` receives `None`.
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;
/// Maximum number of non-protocol stdout lines skipped while waiting for one message.
const MAX_SKIP_LINES: usize = 100;
/// MCP client that exchanges JSON-RPC messages with a server subprocess over
/// the child's stdin/stdout.
///
/// Shared state sits behind `tokio::sync::Mutex`es so the `&self` methods can
/// await while holding it; `request_lock` serializes whole request/response
/// exchanges.
pub struct StdioClient {
    /// Name returned by `server_name()`.
    server_name: String,
    /// Executable to launch (wrapped in `cmd.exe /C` for batch launchers on
    /// Windows).
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Extra environment variables set on the child process.
    env: BTreeMap<String, String>,
    /// How long, in milliseconds, a request waits for its reply.
    timeout_ms: u64,
    /// Current connection state, exposed through `status()`.
    status: Arc<Mutex<ServerStatus>>,
    /// Next JSON-RPC request id; starts at 1 and is bumped once per request.
    next_id: AtomicU64,
    /// Handle to the spawned server; `None` until the process is started.
    process: Arc<Mutex<Option<Child>>>,
    /// Write end used for requests and notifications; `None` until started.
    stdin: Arc<Mutex<Option<ChildStdin>>>,
    /// Buffered read end of the server's stdout; `None` until started.
    reader: Arc<Mutex<Option<BufReader<ChildStdout>>>>,
    /// One protocol line read ahead while draining startup output; it is
    /// consumed before the pipe is read again.
    preread_line: Arc<Mutex<Option<String>>>,
    /// Held across a full write + read cycle so concurrent callers cannot
    /// interleave requests or consume each other's replies.
    request_lock: Arc<Mutex<()>>,
}
impl StdioClient {
pub fn new(
server_name: String,
command: String,
args: Vec<String>,
env: BTreeMap<String, String>,
timeout_ms: Option<u64>,
) -> Self {
Self {
server_name,
command,
args,
env,
timeout_ms: timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS),
status: Arc::new(Mutex::new(ServerStatus::Disconnected)),
next_id: AtomicU64::new(1),
process: Arc::new(Mutex::new(None)),
stdin: Arc::new(Mutex::new(None)),
reader: Arc::new(Mutex::new(None)),
preread_line: Arc::new(Mutex::new(None)),
request_lock: Arc::new(Mutex::new(())),
}
}
async fn start(&self) -> Result<()> {
#[cfg(target_os = "windows")]
let (command, args) = windows_wrap_command(&self.command, &self.args);
#[cfg(not(target_os = "windows"))]
let (command, args) = (self.command.clone(), self.args.clone());
let mut cmd = Command::new(&command);
cmd.args(&args)
.stdin(Stdio::piped())
.stdout(Stdio::piped())
.stderr(Stdio::null());
for (key, value) in &self.env {
cmd.env(key, value);
}
crate::process_utils::suppress_console_window(&mut cmd);
let mut child = cmd.spawn().with_context(|| {
#[cfg(target_os = "windows")]
{
let msg = format!(
"Failed to spawn MCP server: {}. \
On Windows, commands like 'npx' are .cmd scripts and must \
be executed through 'cmd /C'. AtomCode wraps known commands \
automatically; if this is a custom .cmd/.bat, set command to \
'cmd' and add '/C' before the script name in args.",
self.command
);
msg
}
#[cfg(not(target_os = "windows"))]
{
format!("Failed to spawn MCP server: {}", self.command)
}
})?;
let stdin = child.stdin.take().context("Failed to get stdin")?;
let stdout = child.stdout.take().context("Failed to get stdout")?;
let reader = BufReader::new(stdout);
*self.process.lock().await = Some(child);
*self.stdin.lock().await = Some(stdin);
*self.reader.lock().await = Some(reader);
Ok(())
}
async fn send_request(
&self,
method: &str,
params: Option<serde_json::Value>,
) -> Result<serde_json::Value> {
let _req_guard = self.request_lock.lock().await;
let id = self.next_id.fetch_add(1, Ordering::SeqCst);
let mut request = serde_json::Map::new();
request.insert(
"jsonrpc".to_string(),
serde_json::Value::String("2.0".to_string()),
);
request.insert("id".to_string(), serde_json::Value::Number(id.into()));
request.insert(
"method".to_string(),
serde_json::Value::String(method.to_string()),
);
if let Some(p) = params {
request.insert("params".to_string(), p);
}
let request = serde_json::Value::Object(request);
let timeout = Duration::from_millis(self.timeout_ms);
{
let mut stdin = self.stdin.lock().await;
let stdin = stdin.as_mut().context("MCP server not connected (stdin)")?;
let mut body = serde_json::to_vec(&request)?;
body.push(b'\n');
stdin.write_all(&body).await?;
stdin.flush().await?;
}
let result = tokio::time::timeout(timeout, self.recv_jsonrpc_response())
.await
.with_context(|| {
format!(
"MCP request {} timed out after {}ms",
method, self.timeout_ms
)
})??;
if let Some(error) = result.error {
bail!("MCP error {} (code {}): {}", error.message, error.code, "");
}
result
.result
.ok_or_else(|| anyhow::anyhow!("MCP response missing result"))
}
}
#[async_trait]
impl McpClient for StdioClient {
async fn initialize(&mut self) -> Result<InitializeResult> {
let mut status = self.status.lock().await;
*status = ServerStatus::Connecting;
drop(status);
self.start().await?;
self.drain_startup_messages().await?;
let params = serde_json::json!({
"protocolVersion": "2024-11-05",
"capabilities": {
"tools": {}
},
"clientInfo": {
"name": "atomcode",
"version": env!("CARGO_PKG_VERSION")
}
});
let result: InitializeResult =
serde_json::from_value(self.send_request("initialize", Some(params)).await?)
.context("Failed to parse initialize result")?;
{
let mut stdin = self.stdin.lock().await;
if let Some(stdin) = stdin.as_mut() {
let notification = serde_json::json!({
"jsonrpc": "2.0",
"method": "notifications/initialized"
});
let mut body = serde_json::to_vec(¬ification)?;
body.push(b'\n');
stdin.write_all(&body).await?;
stdin.flush().await?;
}
}
let mut status = self.status.lock().await;
*status = ServerStatus::Connected;
Ok(result)
}
async fn list_tools(&self) -> Result<ListToolsResult> {
let result = self.send_request("tools/list", None).await?;
serde_json::from_value(result).context("Failed to parse tools/list result")
}
async fn call_tool(
&self,
tool_name: &str,
arguments: serde_json::Value,
) -> Result<CallToolResult> {
let params = serde_json::json!({
"name": tool_name,
"arguments": arguments
});
let result = self.send_request("tools/call", Some(params)).await?;
serde_json::from_value(result).context("Failed to parse tools/call result")
}
fn server_name(&self) -> &str {
&self.server_name
}
fn status(&self) -> ServerStatus {
self.status
.try_lock()
.map(|s| s.clone())
.unwrap_or(ServerStatus::Disconnected)
}
}
impl StdioClient {
    /// Reads the next JSON-RPC message from the server's stdout.
    ///
    /// Two framings are accepted:
    /// - newline-delimited JSON: a line whose first non-whitespace character
    ///   is `{` or `[` is parsed directly;
    /// - a `Content-Length:` header line (matched case-insensitively), which
    ///   is handed to `read_content_length_message` to read the rest of the
    ///   headers and the body.
    ///
    /// A line stashed in `preread_line` by `drain_startup_messages` is used
    /// before anything is read from the pipe. Blank lines are ignored. Any
    /// other line (e.g. log output) is skipped, up to `MAX_SKIP_LINES` per call.
    ///
    /// NOTE(review): the message's `id` is not checked here. Whatever JSON
    /// message arrives next is returned, whether or not it answers the pending
    /// request.
    ///
    /// NOTE(review): `send_request` calls this under `tokio::time::timeout`.
    /// tokio documents `read_line` as not cancellation safe, so a line that is
    /// partially read when the timeout fires may be lost.
    ///
    /// # Errors
    /// Returns an error when the client is not connected or the server closes
    /// stdout (EOF). It also fails on I/O errors, on JSON parse errors, and
    /// once more than `MAX_SKIP_LINES` non-protocol lines have been seen.
    async fn recv_jsonrpc_response(&self) -> Result<super::types::JsonRpcResponse> {
        let mut reader = self.reader.lock().await;
        let reader = reader
            .as_mut()
            .context("MCP server not connected (reader)")?;
        let mut skipped_lines = 0;
        loop {
            // Prefer the line read ahead during startup draining, if any.
            let line = if let Some(s) = self.preread_line.lock().await.take() {
                s
            } else {
                let mut buf = String::new();
                // Keep reading until a non-blank line; 0 bytes read means EOF.
                loop {
                    buf.clear();
                    let n = reader.read_line(&mut buf).await?;
                    if n == 0 {
                        bail!("MCP server closed connection");
                    }
                    if !buf.trim().is_empty() {
                        break;
                    }
                }
                buf
            };
            let body = line.trim_end_matches(['\r', '\n']).trim_start();
            // NDJSON framing: the whole message is on this one line.
            if body.starts_with('{') || body.starts_with('[') {
                return serde_json::from_str(body)
                    .context("Failed to parse NDJSON MCP message as JSON-RPC");
            }
            // Header framing: pass the untrimmed first header line along.
            if strip_prefix_ci(body, "content-length:").is_some() {
                return read_content_length_message(reader, line).await;
            }
            // Anything else is treated as noise (e.g. log output) and skipped.
            skipped_lines += 1;
            if skipped_lines > MAX_SKIP_LINES {
                bail!(
                    "MCP stdio: too many non-protocol lines (>{MAX_SKIP_LINES}), last line: {}",
                    body.chars().take(80).collect::<String>()
                );
            }
        }
    }

    /// Best-effort removal of non-protocol output the server prints at
    /// startup. It is called from `initialize` before the first request.
    ///
    /// Lines are read for at most 500 ms in total, waiting at most 80 ms for
    /// each line. Non-empty lines that do not look like protocol traffic are
    /// discarded. The first line that does look like protocol traffic (it
    /// starts with `{` or `[`, or is a `Content-Length:` header) is stored in
    /// `preread_line`, so `recv_jsonrpc_response` handles it next. Draining
    /// stops silently on timeout, EOF or a read error, and the function always
    /// returns `Ok(())`.
    ///
    /// The reader lock is released after each read attempt.
    ///
    /// NOTE(review): when either timeout cancels a `read_line` mid-line, the
    /// partially read bytes are discarded (tokio's `read_line` is not
    /// cancellation safe). A protocol line split across that boundary would be
    /// lost. TODO confirm this cannot happen with the servers in use.
    async fn drain_startup_messages(&self) -> Result<()> {
        let _ = tokio::time::timeout(Duration::from_millis(500), async {
            loop {
                let mut line = String::new();
                let mut reader = self.reader.lock().await;
                let Some(r) = reader.as_mut() else {
                    return;
                };
                let read_res =
                    tokio::time::timeout(Duration::from_millis(80), r.read_line(&mut line)).await;
                drop(reader);
                match read_res {
                    // Per-line timeout, I/O error, or EOF: stop draining.
                    Err(_) | Ok(Err(_)) | Ok(Ok(0)) => return,
                    Ok(Ok(_)) => {
                        let t = line.trim();
                        if t.is_empty() {
                            continue;
                        }
                        // `t` is already fully trimmed; this is a no-op kept as-is.
                        let js = t.trim_start();
                        if js.starts_with('{')
                            || js.starts_with('[')
                            || strip_prefix_ci(js, "content-length:").is_some()
                        {
                            // Hand the untrimmed line to the next receive.
                            *self.preread_line.lock().await = Some(line);
                            return;
                        }
                    }
                }
            }
        })
        .await;
        Ok(())
    }
}
/// Returns the rest of `s` after `prefix_lower` if `s` starts with it,
/// ignoring ASCII case; returns `None` otherwise.
///
/// ASCII case is folded on both sides, so `"Content-Length: 5"` matches
/// `"content-length:"` and yields `" 5"`. Non-ASCII bytes must match exactly.
fn strip_prefix_ci<'a>(s: &'a str, prefix_lower: &'static str) -> Option<&'a str> {
    let cut = prefix_lower.len();
    match s.get(..cut) {
        // `get` yields `None` when `s` is too short or `cut` splits a char.
        Some(head) if head.eq_ignore_ascii_case(prefix_lower) => s.get(cut..),
        _ => None,
    }
}
/// Reads one `Content-Length`-framed JSON-RPC message (LSP-style headers).
///
/// `line` is the first header line, which the caller has already consumed.
/// Header lines are read until a blank line. `Content-Length` is matched
/// case-insensitively; other headers are ignored. Exactly that many body bytes
/// are then read and parsed.
///
/// The body is read through `take` + `read_to_end` instead of a pre-sized
/// `vec![0; length]`. A bogus or hostile `Content-Length` therefore cannot
/// force a huge up-front allocation; the buffer only grows with bytes the
/// server actually sends.
///
/// # Errors
/// Returns an error on EOF inside the headers or body, a missing or invalid
/// `Content-Length`, I/O errors, or a body that is not a JSON-RPC response.
async fn read_content_length_message(
    reader: &mut BufReader<ChildStdout>,
    mut line: String,
) -> Result<super::types::JsonRpcResponse> {
    let mut content_length: Option<usize> = None;
    loop {
        let header = line.trim_end_matches(['\r', '\n']).trim();
        if header.is_empty() {
            break;
        }
        if let Some(rest) = strip_prefix_ci(header, "content-length:") {
            content_length = Some(rest.trim().parse().context("Invalid Content-Length")?);
        }
        line.clear();
        if reader.read_line(&mut line).await? == 0 {
            bail!("MCP server closed connection while reading headers");
        }
    }
    let length = content_length.context("Missing Content-Length header")?;

    let mut body = Vec::new();
    // usize -> u64 is lossless on every platform Rust supports.
    (&mut *reader)
        .take(length as u64)
        .read_to_end(&mut body)
        .await?;
    if body.len() != length {
        bail!(
            "MCP server closed connection while reading body ({} of {} bytes)",
            body.len(),
            length
        );
    }
    serde_json::from_slice(&body).context("Failed to parse JSON-RPC response")
}
/// Rewrites batch-script commands so they run through `shell /C`.
///
/// These commands are wrapped:
/// - the npm-family launchers `npx`, `npm`, `yarn` and `pnpm`;
/// - any command ending in `.cmd` or `.bat`.
///
/// The check ignores ASCII case, but the caller's original spelling of
/// `command` is kept in the rewritten argument list. Any other command is
/// returned unchanged together with a copy of `args`.
#[cfg_attr(not(target_os = "windows"), allow(dead_code))]
fn wrap_cmd_script(command: &str, args: &[String], shell: &str) -> (String, Vec<String>) {
    let lowered = command.to_ascii_lowercase();
    // Bare launcher names; their `.cmd` spellings are caught by the extension test.
    let is_launcher = matches!(lowered.as_str(), "npx" | "npm" | "yarn" | "pnpm");
    let is_batch = lowered.ends_with(".cmd") || lowered.ends_with(".bat");

    if !(is_launcher || is_batch) {
        return (command.to_owned(), args.to_vec());
    }

    let mut wrapped = Vec::with_capacity(args.len() + 2);
    wrapped.push(String::from("/C"));
    wrapped.push(command.to_owned());
    wrapped.extend_from_slice(args);
    (shell.to_owned(), wrapped)
}
/// Windows only: routes batch-script launchers (e.g. `npx`) through
/// `cmd.exe /C`, leaving other commands untouched.
#[cfg(target_os = "windows")]
fn windows_wrap_command(command: &str, args: &[String]) -> (String, Vec<String>) {
    const SHELL: &str = "cmd.exe";
    wrap_cmd_script(command, args, SHELL)
}
impl Drop for StdioClient {
    /// Best-effort kill of the server process when the client is dropped.
    ///
    /// `drop` cannot await, so the process lock is only *tried*; if it is
    /// held, the kill is skipped. `start_kill` requests termination without
    /// waiting for the process to exit, and any error it returns is ignored.
    fn drop(&mut self) {
        let Ok(mut slot) = self.process.try_lock() else {
            return;
        };
        if let Some(mut child) = slot.take() {
            let _ = child.start_kill();
        }
    }
}
// Unit tests for `wrap_cmd_script`. A `"cmd.exe"` shell is passed explicitly,
// so these cases run on every platform, not just Windows.
#[cfg(test)]
mod tests {
    use super::*;

    // Bare npm-family launcher names are wrapped as `cmd.exe /C <cmd> <args...>`.
    #[test]
    fn wrap_npx() {
        let (cmd, args) = wrap_cmd_script("npx", &["-y".into(), "@pkg/server".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "npx", "-y", "@pkg/server"]);
    }
    #[test]
    fn wrap_npx_cmd_suffix() {
        let (cmd, args) = wrap_cmd_script("npx.cmd", &["-y".into(), "@pkg/server".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "npx.cmd", "-y", "@pkg/server"]);
    }
    #[test]
    fn wrap_npm() {
        let (cmd, args) = wrap_cmd_script("npm", &["install".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "npm", "install"]);
    }
    #[test]
    fn wrap_yarn() {
        let (cmd, args) = wrap_cmd_script("yarn", &["add".into(), "lodash".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "yarn", "add", "lodash"]);
    }
    #[test]
    fn wrap_pnpm() {
        let (cmd, args) = wrap_cmd_script("pnpm", &["install".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "pnpm", "install"]);
    }
    // Any `.bat` / `.cmd` script is wrapped, not only the known launchers.
    #[test]
    fn wrap_custom_bat() {
        let (cmd, args) = wrap_cmd_script("my-script.bat", &["--flag".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "my-script.bat", "--flag"]);
    }
    #[test]
    fn wrap_custom_cmd_suffix() {
        let (cmd, args) = wrap_cmd_script("build.cmd", &[], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "build.cmd"]);
    }
    // Regular executables pass through unchanged.
    #[test]
    fn no_wrap_exe() {
        let (cmd, args) = wrap_cmd_script("node", &["server.js".into()], "cmd.exe");
        assert_eq!(cmd, "node");
        assert_eq!(args, vec!["server.js"]);
    }
    // A command the user already routed through `cmd /C` is not double-wrapped.
    #[test]
    fn no_wrap_already_wrapped() {
        let (cmd, args) =
            wrap_cmd_script("cmd", &["/C".into(), "npx".into(), "-y".into()], "cmd.exe");
        assert_eq!(cmd, "cmd");
        assert_eq!(args, vec!["/C", "npx", "-y"]);
    }
    // Matching ignores case, but the original spelling is preserved in args.
    #[test]
    fn wrap_case_insensitive() {
        let (cmd, args) = wrap_cmd_script("NPX", &["-y".into(), "@pkg/server".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args, vec!["/C", "NPX", "-y", "@pkg/server"]);
    }
    #[test]
    fn wrap_preserves_original_command_in_args() {
        let (cmd, args) = wrap_cmd_script("Npx", &["-y".into()], "cmd.exe");
        assert_eq!(cmd, "cmd.exe");
        assert_eq!(args[1], "Npx");
    }
    #[test]
    fn no_wrap_python() {
        let (cmd, args) = wrap_cmd_script("python", &["-m".into(), "server".into()], "cmd.exe");
        assert_eq!(cmd, "python");
        assert_eq!(args, vec!["-m", "server"]);
    }
}