async_mcp/transport/
stdio_transport.rs1use super::{Message, Transport};
2use anyhow::Result;
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::io::{self, BufRead, Write};
6use std::process::Stdio;
7use std::sync::Arc;
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
9use tokio::process::Child;
10use tokio::sync::Mutex;
11use tracing::debug;
12
/// Server side of the stdio transport: exchanges newline-delimited JSON messages
/// over this process's own stdin and stdout.
///
/// The type carries no state, so any instance is interchangeable with any other.
#[derive(Debug, Default, Clone)]
pub struct ServerStdioTransport;
17#[async_trait]
18impl Transport for ServerStdioTransport {
19 async fn receive(&self) -> Result<Option<Message>> {
20 let stdin = io::stdin();
21 let mut reader = stdin.lock();
22 let mut line = String::new();
23 reader.read_line(&mut line)?;
24 if line.is_empty() {
25 return Ok(None);
26 }
27
28 debug!("Received: {line}");
29 let message: Message = serde_json::from_str(&line)?;
30 Ok(Some(message))
31 }
32
33 async fn send(&self, message: &Message) -> Result<()> {
34 let stdout = io::stdout();
35 let mut writer = stdout.lock();
36 let serialized = serde_json::to_string(message)?;
37 debug!("Sending: {serialized}");
38 writer.write_all(serialized.as_bytes())?;
39 writer.write_all(b"\n")?;
40 writer.flush()?;
41 Ok(())
42 }
43
44 async fn open(&self) -> Result<()> {
45 Ok(())
46 }
47
48 async fn close(&self) -> Result<()> {
49 Ok(())
50 }
51}
52
53#[derive(Clone)]
55pub struct ClientStdioTransport {
56 stdin: Arc<Mutex<Option<BufWriter<tokio::process::ChildStdin>>>>,
57 stdout: Arc<Mutex<Option<BufReader<tokio::process::ChildStdout>>>>,
58 child: Arc<Mutex<Option<Child>>>,
59 program: String,
60 args: Vec<String>,
61 env: Option<HashMap<String, String>>,
62}
63
64impl ClientStdioTransport {
65 pub fn new(program: &str, args: &[&str], env: Option<HashMap<String, String>>) -> Result<Self> {
66 Ok(ClientStdioTransport {
67 stdin: Arc::new(Mutex::new(None)),
68 stdout: Arc::new(Mutex::new(None)),
69 child: Arc::new(Mutex::new(None)),
70 program: program.to_string(),
71 args: args.iter().map(|&s| s.to_string()).collect(),
72 env,
73 })
74 }
75}
76#[async_trait]
77impl Transport for ClientStdioTransport {
78 async fn receive(&self) -> Result<Option<Message>> {
79 debug!("ClientStdioTransport: Starting to receive message");
80 let mut stdout = self.stdout.lock().await;
81 let stdout = stdout
82 .as_mut()
83 .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
84
85 let mut line = String::new();
86 debug!("ClientStdioTransport: Reading line from process");
87 let bytes_read = stdout.read_line(&mut line).await?;
88 debug!("ClientStdioTransport: Read {} bytes", bytes_read);
89
90 if bytes_read == 0 {
91 debug!("ClientStdioTransport: Received EOF from process");
92 return Ok(None);
93 }
94
95 let row = if line.len() > 1000 {
96 let start = &line[..100];
97 let end = &line[line.len() - 100..];
98 format!("{}...{}", start, end)
99 } else {
100 line.clone()
101 };
102
103 debug!("ClientStdioTransport: Received from process: {}", row);
104 let message: Message = serde_json::from_str(&line).map_err(|e| {
105 tracing::error!("Failed to parse message: {}", e);
106 e
107 })?;
108 debug!("ClientStdioTransport: Successfully parsed message");
109 Ok(Some(message))
110 }
111
112 async fn send(&self, message: &Message) -> Result<()> {
113 debug!("ClientStdioTransport: Starting to send message");
114 let mut stdin = self.stdin.lock().await;
115 let stdin = stdin
116 .as_mut()
117 .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
118
119 let serialized = serde_json::to_string(message)?;
120 debug!("ClientStdioTransport: Sending to process: {serialized}");
121 stdin.write_all(serialized.as_bytes()).await?;
122 stdin.write_all(b"\n").await?;
123 stdin.flush().await?;
124 debug!("ClientStdioTransport: Successfully sent and flushed message");
125 Ok(())
126 }
127
128 async fn open(&self) -> Result<()> {
129 debug!("ClientStdioTransport: Opening transport");
130 let mut command = tokio::process::Command::new(&self.program);
131
132 command
134 .args(&self.args)
135 .stdin(Stdio::piped())
136 .stdout(Stdio::piped());
137
138 if let Some(env) = &self.env {
140 for (key, value) in env {
141 command.env(key, value);
142 }
143 }
144
145 let mut child = command.spawn()?;
146
147 debug!("ClientStdioTransport: Child process spawned");
148 let stdin = child
149 .stdin
150 .take()
151 .ok_or_else(|| anyhow::anyhow!("Child process stdin not available"))?;
152 let stdout = child
153 .stdout
154 .take()
155 .ok_or_else(|| anyhow::anyhow!("Child process stdout not available"))?;
156
157 *self.stdin.lock().await = Some(BufWriter::new(stdin));
158 *self.stdout.lock().await = Some(BufReader::new(stdout));
159 *self.child.lock().await = Some(child);
160
161 Ok(())
162 }
163
164 async fn close(&self) -> Result<()> {
165 const GRACEFUL_TIMEOUT_MS: u64 = 1000;
166 const SIGTERM_TIMEOUT_MS: u64 = 500;
167 debug!("Starting graceful shutdown");
168 {
169 let mut stdin_guard = self.stdin.lock().await;
170 if let Some(stdin) = stdin_guard.as_mut() {
171 debug!("Flushing stdin");
172 stdin.flush().await?;
173 }
174 *stdin_guard = None;
175 }
176
177 let mut child_guard = self.child.lock().await;
178 let Some(child) = child_guard.as_mut() else {
179 debug!("No child process to close");
180 return Ok(());
181 };
182
183 debug!("Attempting graceful shutdown");
184 match child.try_wait()? {
185 Some(status) => {
186 debug!("Process already exited with status: {}", status);
187 *child_guard = None;
188 return Ok(());
189 }
190 None => {
191 debug!("Waiting for process to exit gracefully");
192 tokio::time::sleep(tokio::time::Duration::from_millis(GRACEFUL_TIMEOUT_MS)).await;
193 }
194 }
195
196 if child.try_wait()?.is_none() {
197 debug!("Process still running, sending SIGTERM");
198 child.kill().await?;
199 tokio::time::sleep(tokio::time::Duration::from_millis(SIGTERM_TIMEOUT_MS)).await;
200 }
201
202 if child.try_wait()?.is_none() {
203 debug!("Process not responding to SIGTERM, forcing kill");
204 child.kill().await?;
205 }
206
207 match child.wait().await {
208 Ok(status) => debug!("Process exited with status: {}", status),
209 Err(e) => debug!("Error waiting for process exit: {}", e),
210 }
211
212 *child_guard = None;
213 debug!("Shutdown complete");
214 Ok(())
215 }
216}
217
#[cfg(test)]
mod tests {
    use crate::transport::{JsonRpcMessage, JsonRpcRequest, JsonRpcVersion};

    use super::*;
    use std::time::Duration;

    /// Builds the request shared by the round-trip tests.
    fn test_request() -> JsonRpcMessage {
        JsonRpcMessage::Request(JsonRpcRequest {
            id: 1,
            method: "test".to_string(),
            params: Some(serde_json::json!({"hello": "world"})),
            jsonrpc: JsonRpcVersion::default(),
        })
    }

    /// `cat` echoes each line back, so a sent message must be received
    /// unchanged.
    #[tokio::test]
    #[cfg(unix)]
    async fn test_stdio_transport() -> Result<()> {
        let transport = ClientStdioTransport::new("cat", &[], None)?;
        let test_message = test_request();

        transport.open().await?;
        transport.send(&test_message).await?;
        let response = transport.receive().await?;
        assert_eq!(Some(test_message), response);

        transport.close().await?;
        Ok(())
    }

    /// A child that ignores stdin EOF must still be terminated by `close`,
    /// and a concurrently pending `receive` must then observe EOF.
    #[tokio::test]
    #[cfg(unix)]
    async fn test_graceful_shutdown() -> Result<()> {
        let transport = ClientStdioTransport::new("sleep", &["5"], None)?;
        transport.open().await?;

        let transport_clone = transport.clone();
        let read_handle = tokio::spawn(async move {
            let result = transport_clone.receive().await;
            debug!("Receive returned: {:?}", result);
            result
        });

        // Give the reader task time to block on the child's stdout.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let start = std::time::Instant::now();
        transport.close().await?;
        let shutdown_duration = start.elapsed();

        let read_result = read_handle.await?;
        assert!(
            matches!(read_result, Ok(None)),
            "expected EOF after shutdown, got {read_result:?}"
        );
        assert!(shutdown_duration < Duration::from_secs(5));

        let child_guard = transport.child.lock().await;
        assert!(child_guard.is_none());

        Ok(())
    }

    /// A child that consumes one line and exits without output: the pending
    /// `receive` must observe EOF.
    #[tokio::test]
    #[cfg(unix)]
    async fn test_shutdown_with_pending_io() -> Result<()> {
        // `read` is a shell builtin and is not guaranteed to exist as a
        // standalone executable, so run it through `sh`.
        let transport = ClientStdioTransport::new("sh", &["-c", "read line"], None)?;
        transport.open().await?;

        let transport_clone = transport.clone();
        let read_handle = tokio::spawn(async move { transport_clone.receive().await });

        // Give the reader task time to block on the child's stdout.
        tokio::time::sleep(Duration::from_millis(100)).await;

        transport.send(&test_request()).await?;
        transport.close().await?;

        let read_result = read_handle.await?;
        assert!(
            matches!(read_result, Ok(None)),
            "expected EOF after shutdown, got {read_result:?}"
        );

        Ok(())
    }
}