Skip to main content

codineer_runtime/mcp_stdio/
process.rs

1use std::collections::BTreeMap;
2use std::io;
3use std::process::Stdio;
4
5use serde::de::DeserializeOwned;
6use serde::Serialize;
7use serde_json::Value as JsonValue;
8use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
9use tokio::process::{Child, ChildStdin, ChildStdout, Command};
10
11use crate::mcp_client::{McpClientBootstrap, McpClientTransport, McpStdioTransport};
12
13use super::types::{
14    JsonRpcId, JsonRpcRequest, JsonRpcResponse, McpInitializeClientInfo, McpInitializeParams,
15    McpInitializeResult, McpListResourcesParams, McpListResourcesResult, McpListToolsParams,
16    McpListToolsResult, McpReadResourceParams, McpReadResourceResult, McpToolCallParams,
17    McpToolCallResult,
18};
19
/// A spawned stdio MCP server process together with its piped stdin/stdout.
///
/// `stdout` is wrapped in a `BufReader` so the same reader backs both the
/// line-oriented reads and the `Content-Length`-framed reads in the impl.
#[derive(Debug)]
pub struct McpStdioProcess {
    /// Handle to the spawned child; used to kill and reap it.
    child: Child,
    /// Write side of the connection (the child's stdin).
    stdin: ChildStdin,
    /// Buffered read side of the connection (the child's stdout).
    stdout: BufReader<ChildStdout>,
}
26
27impl McpStdioProcess {
28    pub fn spawn(transport: &McpStdioTransport) -> io::Result<Self> {
29        let mut command = Command::new(&transport.command);
30        command
31            .args(&transport.args)
32            .stdin(Stdio::piped())
33            .stdout(Stdio::piped())
34            .stderr(Stdio::inherit());
35        apply_env(&mut command, &transport.env);
36
37        let mut child = command.spawn()?;
38        let stdin = child
39            .stdin
40            .take()
41            .ok_or_else(|| io::Error::other("stdio MCP process missing stdin pipe"))?;
42        let stdout = child
43            .stdout
44            .take()
45            .ok_or_else(|| io::Error::other("stdio MCP process missing stdout pipe"))?;
46
47        Ok(Self {
48            child,
49            stdin,
50            stdout: BufReader::new(stdout),
51        })
52    }
53
54    pub async fn write_all(&mut self, bytes: &[u8]) -> io::Result<()> {
55        self.stdin.write_all(bytes).await
56    }
57
58    pub async fn flush(&mut self) -> io::Result<()> {
59        self.stdin.flush().await
60    }
61
62    pub async fn write_line(&mut self, line: &str) -> io::Result<()> {
63        self.write_all(line.as_bytes()).await?;
64        self.write_all(b"\n").await?;
65        self.flush().await
66    }
67
68    pub async fn read_line(&mut self) -> io::Result<String> {
69        let mut line = String::new();
70        let bytes_read = self.stdout.read_line(&mut line).await?;
71        if bytes_read == 0 {
72            return Err(io::Error::new(
73                io::ErrorKind::UnexpectedEof,
74                "MCP stdio stream closed while reading line",
75            ));
76        }
77        Ok(line)
78    }
79
80    pub async fn read_available(&mut self) -> io::Result<Vec<u8>> {
81        let mut buffer = vec![0_u8; 4096];
82        let read = self.stdout.read(&mut buffer).await?;
83        buffer.truncate(read);
84        Ok(buffer)
85    }
86
87    pub async fn write_frame(&mut self, payload: &[u8]) -> io::Result<()> {
88        let encoded = encode_frame(payload);
89        self.write_all(&encoded).await?;
90        self.flush().await
91    }
92
93    pub async fn read_frame(&mut self) -> io::Result<Vec<u8>> {
94        const MAX_FRAME_SIZE: usize = 50 * 1024 * 1024; // 50 MiB
95        let mut content_length = None;
96        loop {
97            let mut line = String::new();
98            let bytes_read = self.stdout.read_line(&mut line).await?;
99            if bytes_read == 0 {
100                return Err(io::Error::new(
101                    io::ErrorKind::UnexpectedEof,
102                    "MCP stdio stream closed while reading headers",
103                ));
104            }
105            if line == "\r\n" {
106                break;
107            }
108            if let Some(value) = line
109                .strip_prefix("Content-Length:")
110                .or_else(|| line.strip_prefix("content-length:"))
111            {
112                let parsed = value
113                    .trim()
114                    .parse::<usize>()
115                    .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
116                content_length = Some(parsed);
117            }
118        }
119
120        let content_length = content_length.ok_or_else(|| {
121            io::Error::new(io::ErrorKind::InvalidData, "missing Content-Length header")
122        })?;
123        if content_length > MAX_FRAME_SIZE {
124            return Err(io::Error::new(
125                io::ErrorKind::InvalidData,
126                format!(
127                    "MCP frame too large: {content_length} bytes exceeds {MAX_FRAME_SIZE} limit"
128                ),
129            ));
130        }
131        let mut payload = vec![0_u8; content_length];
132        self.stdout.read_exact(&mut payload).await?;
133        Ok(payload)
134    }
135
136    pub async fn write_jsonrpc_message<T: Serialize>(&mut self, message: &T) -> io::Result<()> {
137        let body = serde_json::to_vec(message)
138            .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))?;
139        self.write_frame(&body).await
140    }
141
142    pub async fn read_jsonrpc_message<T: DeserializeOwned>(&mut self) -> io::Result<T> {
143        let payload = self.read_frame().await?;
144        serde_json::from_slice(&payload)
145            .map_err(|error| io::Error::new(io::ErrorKind::InvalidData, error))
146    }
147
148    pub async fn send_request<T: Serialize>(
149        &mut self,
150        request: &JsonRpcRequest<T>,
151    ) -> io::Result<()> {
152        self.write_jsonrpc_message(request).await
153    }
154
155    pub async fn read_response<T: DeserializeOwned>(&mut self) -> io::Result<JsonRpcResponse<T>> {
156        self.read_jsonrpc_message().await
157    }
158
159    pub async fn request<TParams: Serialize, TResult: DeserializeOwned>(
160        &mut self,
161        id: JsonRpcId,
162        method: impl Into<String>,
163        params: Option<TParams>,
164    ) -> io::Result<JsonRpcResponse<TResult>> {
165        let method = method.into();
166        let request = JsonRpcRequest::new(id.clone(), method.clone(), params);
167        self.send_request(&request).await?;
168        let response: JsonRpcResponse<TResult> = self.read_response().await?;
169        if response.id != id {
170            return Err(io::Error::new(
171                io::ErrorKind::InvalidData,
172                format!(
173                    "JSON-RPC response id mismatch for {method}: expected {:?}, got {:?}",
174                    id, response.id
175                ),
176            ));
177        }
178        Ok(response)
179    }
180
181    pub async fn initialize(
182        &mut self,
183        id: JsonRpcId,
184        params: McpInitializeParams,
185    ) -> io::Result<JsonRpcResponse<McpInitializeResult>> {
186        self.request(id, "initialize", Some(params)).await
187    }
188
189    pub async fn list_tools(
190        &mut self,
191        id: JsonRpcId,
192        params: Option<McpListToolsParams>,
193    ) -> io::Result<JsonRpcResponse<McpListToolsResult>> {
194        self.request(id, "tools/list", params).await
195    }
196
197    pub async fn call_tool(
198        &mut self,
199        id: JsonRpcId,
200        params: McpToolCallParams,
201    ) -> io::Result<JsonRpcResponse<McpToolCallResult>> {
202        self.request(id, "tools/call", Some(params)).await
203    }
204
205    pub async fn list_resources(
206        &mut self,
207        id: JsonRpcId,
208        params: Option<McpListResourcesParams>,
209    ) -> io::Result<JsonRpcResponse<McpListResourcesResult>> {
210        self.request(id, "resources/list", params).await
211    }
212
213    pub async fn read_resource(
214        &mut self,
215        id: JsonRpcId,
216        params: McpReadResourceParams,
217    ) -> io::Result<JsonRpcResponse<McpReadResourceResult>> {
218        self.request(id, "resources/read", Some(params)).await
219    }
220
221    pub async fn terminate(&mut self) -> io::Result<()> {
222        self.child.kill().await
223    }
224
225    pub async fn wait(&mut self) -> io::Result<std::process::ExitStatus> {
226        self.child.wait().await
227    }
228
229    pub(crate) async fn shutdown(&mut self) -> io::Result<()> {
230        if self.child.try_wait()?.is_none() {
231            self.child.kill().await?;
232        }
233        let _ = self.child.wait().await?;
234        Ok(())
235    }
236}
237
238pub fn spawn_mcp_stdio_process(bootstrap: &McpClientBootstrap) -> io::Result<McpStdioProcess> {
239    match &bootstrap.transport {
240        McpClientTransport::Stdio(transport) => McpStdioProcess::spawn(transport),
241        other => Err(io::Error::new(
242            io::ErrorKind::InvalidInput,
243            format!(
244                "MCP bootstrap transport for {} is not stdio: {other:?}",
245                bootstrap.server_name
246            ),
247        )),
248    }
249}
250
/// Adds every `key = value` pair from `env` to the child's environment.
///
/// Variables inherited from the parent process are left in place.
fn apply_env(command: &mut Command, env: &BTreeMap<String, String>) {
    command.envs(env);
}
256
/// Prefixes `payload` with a `Content-Length: N\r\n\r\n` header, where `N` is the
/// payload length in bytes.
fn encode_frame(payload: &[u8]) -> Vec<u8> {
    let prefix = format!("Content-Length: {}\r\n\r\n", payload.len());
    [prefix.as_bytes(), payload].concat()
}
263
264pub(crate) fn default_initialize_params() -> McpInitializeParams {
265    McpInitializeParams {
266        protocol_version: "2025-03-26".to_string(),
267        capabilities: JsonValue::Object(serde_json::Map::new()),
268        client_info: McpInitializeClientInfo {
269            name: "runtime".to_string(),
270            version: env!("CARGO_PKG_VERSION").to_string(),
271        },
272    }
273}