use std::collections::BTreeMap;
use std::io;
use std::process::Stdio;
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde_json::Value as JsonValue;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, ChildStdin, ChildStdout, Command};
use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
use super::types::{
JsonRpcId, JsonRpcRequest, JsonRpcResponse, McpInitializeClientInfo, McpInitializeParams,
McpInitializeResult, McpListResourcesParams, McpListResourcesResult, McpListToolsParams,
McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpToolCallParams,
McpToolCallResult,
};
/// A spawned MCP server child process together with its piped stdio handles.
///
/// Instances are created by [`McpStdioProcess::spawn`], which pipes the child's
/// stdin and stdout and wraps stdout in a buffered reader.
#[derive(Debug)]
pub struct McpStdioProcess {
// Handle to the spawned child; used to kill, poll and wait on the process.
child: Child,
// Write end of the child's stdin pipe; outgoing bytes are written here.
stdin: ChildStdin,
// Buffered reader over the child's stdout pipe; lines and frames are read from it.
stdout: BufReader<ChildStdout>,
}
impl McpStdioProcess {
/// Launches the command described by `transport` and captures its stdio pipes.
///
/// The child's stdin and stdout are piped, stderr is inherited from the parent,
/// and the entries of `transport.env` are applied through `apply_env`.
///
/// # Errors
/// Returns the spawn error from the OS, or an error if either the stdin or the
/// stdout pipe is missing on the spawned child.
pub fn spawn(transport: &McpStdioTransport) -> io::Result<Self> {
    let mut builder = Command::new(&transport.command);
    builder.args(&transport.args);
    builder.stdin(Stdio::piped());
    builder.stdout(Stdio::piped());
    builder.stderr(Stdio::inherit());
    apply_env(&mut builder, &transport.env);

    let mut child = builder.spawn()?;

    // Take ownership of both pipes so they can live beside the child handle.
    let Some(stdin) = child.stdin.take() else {
        return Err(io::Error::other("stdio MCP process missing stdin pipe"));
    };
    let Some(raw_stdout) = child.stdout.take() else {
        return Err(io::Error::other("stdio MCP process missing stdout pipe"));
    };

    let stdout = BufReader::new(raw_stdout);
    Ok(Self {
        child,
        stdin,
        stdout,
    })
}
pub async fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> {
self.stdin.write_all(bytes).await
}
pub async fn flush(&mut self) -> io::Result<()> {
self.stdin.flush().await
}
/// Writes `line` followed by a single `\n` to the child's stdin, then flushes.
///
/// # Errors
/// Returns the first I/O error raised by either write or by the flush.
pub async fn write_line(&mut self, line: &str) -> io::Result<()> {
    let writer = &mut self.stdin;
    writer.write_all(line.as_bytes()).await?;
    writer.write_all(b"\n").await?;
    writer.flush().await
}
/// Reads one line from the child's stdout and returns it as read, without trimming.
///
/// # Errors
/// Returns `UnexpectedEof` when zero bytes are read (the stream is closed),
/// or any I/O error reported by the underlying read.
pub async fn read_line(&mut self) -> io::Result<String> {
    let mut text = String::new();
    match self.stdout.read_line(&mut text).await? {
        // Zero bytes means the peer closed its stdout.
        0 => Err(io::Error::new(
            io::ErrorKind::UnexpectedEof,
            "MCP stdio stream closed while reading line",
        )),
        _ => Ok(text),
    }
}
/// Performs a single read of up to 4096 bytes from the child's stdout.
///
/// The returned vector holds exactly the bytes that the read produced, so it
/// is empty when the read returns zero bytes.
///
/// # Errors
/// Returns any I/O error reported by the underlying read.
pub async fn read_available(&mut self) -> io::Result<Vec<u8>> {
    let mut chunk = vec![0_u8; 4096];
    let filled = self.stdout.read(&mut chunk).await?;
    // Drop the unused, zero-filled tail.
    chunk.truncate(filled);
    Ok(chunk)
}
/// Wraps `payload` in a `Content-Length` frame, writes it to the child's stdin and flushes.
///
/// # Errors
/// Returns any I/O error raised by the write or the flush.
pub async fn write_frame(&mut self, payload: &[u8]) -> io::Result<()> {
    let frame = encode_frame(payload);
    self.stdin.write_all(&frame).await?;
    self.stdin.flush().await
}
/// Reads one `Content-Length`-framed message from the child's stdout and returns its payload.
///
/// Header lines are consumed until a blank line is seen. The `Content-Length`
/// header name is matched case-insensitively; all other headers are ignored.
/// If the header appears more than once, the last value wins.
///
/// # Errors
/// * `UnexpectedEof` if the stream closes while headers are being read.
/// * `InvalidData` if the `Content-Length` value is not a valid `usize`, if no
///   `Content-Length` header was seen, or if the declared length exceeds the
///   50 MiB frame limit.
/// * Any I/O error raised while reading the headers or the payload.
pub async fn read_frame(&mut self) -> io::Result<Vec<u8>> {
    // Upper bound on a single payload, checked before the buffer is allocated.
    const MAX_FRAME_SIZE: usize = 50 * 1024 * 1024;

    let mut content_length = None;
    // One buffer reused for every header line.
    let mut line = String::new();
    loop {
        line.clear();
        let bytes_read = self.stdout.read_line(&mut line).await?;
        if bytes_read == 0 {
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "MCP stdio stream closed while reading headers",
            ));
        }
        // A blank line ends the header section. A bare LF is accepted next to
        // CRLF so a lenient peer does not cause payload bytes to be parsed as headers.
        if line == "\r\n" || line == "\n" {
            break;
        }
        // Header names are compared case-insensitively, so `Content-length`
        // and `CONTENT-LENGTH` are recognized as well.
        if let Some((name, value)) = line.split_once(':') {
            if name.trim().eq_ignore_ascii_case("content-length") {
                let parsed = value
                    .trim()
                    .parse::<usize>()
                    .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
                content_length = Some(parsed);
            }
        }
    }

    let content_length = content_length.ok_or_else(|| {
        io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header")
    })?;
    if content_length > MAX_FRAME_SIZE {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!(
                "MCP frame too large: {content_length} bytes exceeds {MAX_FRAME_SIZE} limit"
            ),
        ));
    }

    let mut payload = vec![0_u8; content_length];
    self.stdout.read_exact(&mut payload).await?;
    Ok(payload)
}
/// Serializes `message` to JSON and sends it as one framed message.
///
/// # Errors
/// Returns `InvalidData` wrapping the serialization error, or any error from
/// [`Self::write_frame`].
pub async fn write_jsonrpc_message<T: Serialize>(&mut self, message: &T) -> io::Result<()> {
    match serde_json::to_vec(message) {
        Ok(body) => self.write_frame(&body).await,
        Err(error) => Err(io::Error::new(io::ErrorKind::InvalidData, error)),
    }
}
/// Reads one framed message and deserializes its payload from JSON into `T`.
///
/// # Errors
/// Returns any error from [`Self::read_frame`], or `InvalidData` wrapping the
/// deserialization error.
pub async fn read_jsonrpc_message<T: DeserializeOwned>(&mut self) -> io::Result<T> {
    let raw = self.read_frame().await?;
    match serde_json::from_slice::<T>(&raw) {
        Ok(message) => Ok(message),
        Err(error) => Err(io::Error::new(io::ErrorKind::InvalidData, error)),
    }
}
/// Sends `request` to the child as one framed JSON message.
///
/// # Errors
/// Propagates any error from [`Self::write_jsonrpc_message`].
pub async fn send_request<T: Serialize>(
    &mut self,
    request: &JsonRpcRequest<T>,
) -> io::Result<()> {
    let outcome = self.write_jsonrpc_message(request).await;
    outcome
}
/// Reads the next framed message and decodes it as a `JsonRpcResponse<T>`.
///
/// # Errors
/// Propagates any error from [`Self::read_jsonrpc_message`].
pub async fn read_response<T: DeserializeOwned>(&mut self) -> io::Result<JsonRpcResponse<T>> {
    self.read_jsonrpc_message::<JsonRpcResponse<T>>().await
}
/// Sends a JSON-RPC request and reads exactly one message back as its response.
///
/// The request is built from `id`, `method` and `params`. The response is
/// accepted only when its id equals `id`.
///
/// # Errors
/// Propagates any error from sending or reading, and returns `InvalidData`
/// when the id of the message read does not match the request id.
pub async fn request<TParams: Serialize, TResult: DeserializeOwned>(
    &mut self,
    id: JsonRpcId,
    method: impl Into<String>,
    params: Option<TParams>,
) -> io::Result<JsonRpcResponse<TResult>> {
    let method = method.into();
    // `id` and `method` are cloned because both are needed again for the id check below.
    let outgoing = JsonRpcRequest::new(id.clone(), method.clone(), params);
    self.send_request(&outgoing).await?;

    let reply: JsonRpcResponse<TResult> = self.read_response().await?;
    if reply.id == id {
        return Ok(reply);
    }

    let message = format!(
        "JSON-RPC response id mismatch for {method}: expected {:?}, got {:?}",
        id, reply.id
    );
    Err(io::Error::new(io::ErrorKind::InvalidData, message))
}
/// Issues an `initialize` request carrying `params` and returns the typed response.
///
/// # Errors
/// Propagates any error from [`Self::request`].
pub async fn initialize(
    &mut self,
    id: JsonRpcId,
    params: McpInitializeParams,
) -> io::Result<JsonRpcResponse<McpInitializeResult>> {
    const METHOD: &str = "initialize";
    self.request(id, METHOD, Some(params)).await
}
/// Issues a `tools/list` request with optional `params` and returns the typed response.
///
/// # Errors
/// Propagates any error from [`Self::request`].
pub async fn list_tools(
    &mut self,
    id: JsonRpcId,
    params: Option<McpListToolsParams>,
) -> io::Result<JsonRpcResponse<McpListToolsResult>> {
    const METHOD: &str = "tools/list";
    self.request(id, METHOD, params).await
}
/// Issues a `tools/call` request carrying `params` and returns the typed response.
///
/// # Errors
/// Propagates any error from [`Self::request`].
pub async fn call_tool(
    &mut self,
    id: JsonRpcId,
    params: McpToolCallParams,
) -> io::Result<JsonRpcResponse<McpToolCallResult>> {
    const METHOD: &str = "tools/call";
    self.request(id, METHOD, Some(params)).await
}
/// Issues a `resources/list` request with optional `params` and returns the typed response.
///
/// # Errors
/// Propagates any error from [`Self::request`].
pub async fn list_resources(
    &mut self,
    id: JsonRpcId,
    params: Option<McpListResourcesParams>,
) -> io::Result<JsonRpcResponse<McpListResourcesResult>> {
    const METHOD: &str = "resources/list";
    self.request(id, METHOD, params).await
}
/// Issues a `resources/read` request carrying `params` and returns the typed response.
///
/// # Errors
/// Propagates any error from [`Self::request`].
pub async fn read_resource(
    &mut self,
    id: JsonRpcId,
    params: McpReadResourceParams,
) -> io::Result<JsonRpcResponse<McpReadResourceResult>> {
    const METHOD: &str = "resources/read";
    self.request(id, METHOD, Some(params)).await
}
pub async fn terminate(&mut self) -> io::Result<()> {
self.child.kill().await
}
pub async fn wait(&mut self) -> io::Result<std::process::ExitStatus> {
self.child.wait().await
}
/// Stops the child if it is still running, then waits for it to exit.
///
/// The child is killed only when `try_wait` reports no exit status yet. The
/// exit status obtained from the final wait is discarded.
///
/// # Errors
/// Returns any I/O error from polling, killing or waiting on the child.
pub(crate) async fn shutdown(&mut self) -> io::Result<()> {
    let still_running = self.child.try_wait()?.is_none();
    if still_running {
        self.child.kill().await?;
    }
    self.child.wait().await.map(|_status| ())
}
}
/// Spawns the stdio process described by `bootstrap` when its transport is stdio.
///
/// # Errors
/// Returns `InvalidInput` naming the server and the transport when the
/// transport is not the stdio variant, or any error from
/// [`McpStdioProcess::spawn`].
pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
    if let McpClientTransport::Stdio(stdio) = &bootstrap.transport {
        return McpStdioProcess::spawn(stdio);
    }

    let unsupported = &bootstrap.transport;
    let message = format!(
        "MCP bootstrap transport for {} is not stdio: {unsupported:?}",
        bootstrap.server_name
    );
    Err(io::Error::new(io::ErrorKind::InvalidInput, message))
}
/// Adds every key/value pair in `env` to the environment of `command`.
fn apply_env(command: &mut Command, env: &BTreeMap<String, String>) {
    command.envs(env);
}
/// Builds a framed message: a `Content-Length` header stating the payload size
/// in bytes, a blank line, and then the payload bytes themselves.
fn encode_frame(payload: &[u8]) -> Vec<u8> {
    let prefix = format!("Content-Length: {}\r\n\r\n", payload.len());
    [prefix.as_bytes(), payload].concat()
}
/// Builds the default `initialize` parameters.
///
/// The protocol version is `2025-03-26`, the capabilities are an empty JSON
/// object, and the client is reported as `runtime` with this crate's package
/// version.
pub(crate) fn default_initialize_params() -> McpInitializeParams {
    let client_info = McpInitializeClientInfo {
        name: String::from("runtime"),
        version: String::from(env!("CARGO_PKG_VERSION")),
    };
    let capabilities = JsonValue::Object(serde_json::Map::new());

    McpInitializeParams {
        protocol_version: String::from("2025-03-26"),
        capabilities,
        client_info,
    }
}