async_mcp/transport/
stdio_transport.rs

1use super::{Message, Transport};
2use anyhow::Result;
3use async_trait::async_trait;
4use std::collections::HashMap;
5use std::io::{self, BufRead, Write};
6use std::process::Stdio;
7use std::sync::Arc;
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
9use tokio::process::Child;
10use tokio::sync::Mutex;
11use tracing::debug;
12
13/// Stdio transport for server with json serialization
14/// TODO: support for other binary serialzation formats
15#[derive(Default, Clone)]
16pub struct ServerStdioTransport;
17#[async_trait]
18impl Transport for ServerStdioTransport {
19    async fn receive(&self) -> Result<Option<Message>> {
20        let stdin = io::stdin();
21        let mut reader = stdin.lock();
22        let mut line = String::new();
23        reader.read_line(&mut line)?;
24        if line.is_empty() {
25            return Ok(None);
26        }
27
28        debug!("Received: {line}");
29        let message: Message = serde_json::from_str(&line)?;
30        Ok(Some(message))
31    }
32
33    async fn send(&self, message: &Message) -> Result<()> {
34        let stdout = io::stdout();
35        let mut writer = stdout.lock();
36        let serialized = serde_json::to_string(message)?;
37        debug!("Sending: {serialized}");
38        writer.write_all(serialized.as_bytes())?;
39        writer.write_all(b"\n")?;
40        writer.flush()?;
41        Ok(())
42    }
43
44    async fn open(&self) -> Result<()> {
45        Ok(())
46    }
47
48    async fn close(&self) -> Result<()> {
49        Ok(())
50    }
51}
52
53/// ClientStdioTransport launches a child process and communicates with it via stdio
54#[derive(Clone)]
55pub struct ClientStdioTransport {
56    stdin: Arc<Mutex<Option<BufWriter<tokio::process::ChildStdin>>>>,
57    stdout: Arc<Mutex<Option<BufReader<tokio::process::ChildStdout>>>>,
58    child: Arc<Mutex<Option<Child>>>,
59    program: String,
60    args: Vec<String>,
61    env: Option<HashMap<String, String>>,
62}
63
64impl ClientStdioTransport {
65    pub fn new(program: &str, args: &[&str], env: Option<HashMap<String, String>>) -> Result<Self> {
66        Ok(ClientStdioTransport {
67            stdin: Arc::new(Mutex::new(None)),
68            stdout: Arc::new(Mutex::new(None)),
69            child: Arc::new(Mutex::new(None)),
70            program: program.to_string(),
71            args: args.iter().map(|&s| s.to_string()).collect(),
72            env,
73        })
74    }
75}
76#[async_trait]
77impl Transport for ClientStdioTransport {
78    async fn receive(&self) -> Result<Option<Message>> {
79        debug!("ClientStdioTransport: Starting to receive message");
80        let mut stdout = self.stdout.lock().await;
81        let stdout = stdout
82            .as_mut()
83            .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
84
85        let mut line = String::new();
86        debug!("ClientStdioTransport: Reading line from process");
87        let bytes_read = stdout.read_line(&mut line).await?;
88        debug!("ClientStdioTransport: Read {} bytes", bytes_read);
89
90        if bytes_read == 0 {
91            debug!("ClientStdioTransport: Received EOF from process");
92            return Ok(None);
93        }
94
95        let row = if line.len() > 1000 {
96            let start = &line[..100];
97            let end = &line[line.len() - 100..];
98            format!("{}...{}", start, end)
99        } else {
100            line.clone()
101        };
102        
103        debug!("ClientStdioTransport: Received from process: {}", row);
104        let message: Message = serde_json::from_str(&line).map_err(|e| {
105            tracing::error!("Failed to parse message: {}", e);
106            e
107        })?;
108        debug!("ClientStdioTransport: Successfully parsed message");
109        Ok(Some(message))
110    }
111
112    async fn send(&self, message: &Message) -> Result<()> {
113        debug!("ClientStdioTransport: Starting to send message");
114        let mut stdin = self.stdin.lock().await;
115        let stdin = stdin
116            .as_mut()
117            .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
118
119        let serialized = serde_json::to_string(message)?;
120        debug!("ClientStdioTransport: Sending to process: {serialized}");
121        stdin.write_all(serialized.as_bytes()).await?;
122        stdin.write_all(b"\n").await?;
123        stdin.flush().await?;
124        debug!("ClientStdioTransport: Successfully sent and flushed message");
125        Ok(())
126    }
127
128    async fn open(&self) -> Result<()> {
129        debug!("ClientStdioTransport: Opening transport");
130        let mut command = tokio::process::Command::new(&self.program);
131
132        // Set up the command with args and stdio
133        command
134            .args(&self.args)
135            .stdin(Stdio::piped())
136            .stdout(Stdio::piped());
137
138        // Add environment variables
139        if let Some(env) = &self.env {
140            for (key, value) in env {
141                command.env(key, value);
142            }
143        }
144
145        let mut child = command.spawn()?;
146
147        debug!("ClientStdioTransport: Child process spawned");
148        let stdin = child
149            .stdin
150            .take()
151            .ok_or_else(|| anyhow::anyhow!("Child process stdin not available"))?;
152        let stdout = child
153            .stdout
154            .take()
155            .ok_or_else(|| anyhow::anyhow!("Child process stdout not available"))?;
156
157        *self.stdin.lock().await = Some(BufWriter::new(stdin));
158        *self.stdout.lock().await = Some(BufReader::new(stdout));
159        *self.child.lock().await = Some(child);
160
161        Ok(())
162    }
163
164    async fn close(&self) -> Result<()> {
165        const GRACEFUL_TIMEOUT_MS: u64 = 1000;
166        const SIGTERM_TIMEOUT_MS: u64 = 500;
167        debug!("Starting graceful shutdown");
168        {
169            let mut stdin_guard = self.stdin.lock().await;
170            if let Some(stdin) = stdin_guard.as_mut() {
171                debug!("Flushing stdin");
172                stdin.flush().await?;
173            }
174            *stdin_guard = None;
175        }
176
177        let mut child_guard = self.child.lock().await;
178        let Some(child) = child_guard.as_mut() else {
179            debug!("No child process to close");
180            return Ok(());
181        };
182
183        debug!("Attempting graceful shutdown");
184        match child.try_wait()? {
185            Some(status) => {
186                debug!("Process already exited with status: {}", status);
187                *child_guard = None;
188                return Ok(());
189            }
190            None => {
191                debug!("Waiting for process to exit gracefully");
192                tokio::time::sleep(tokio::time::Duration::from_millis(GRACEFUL_TIMEOUT_MS)).await;
193            }
194        }
195
196        if child.try_wait()?.is_none() {
197            debug!("Process still running, sending SIGTERM");
198            child.kill().await?;
199            tokio::time::sleep(tokio::time::Duration::from_millis(SIGTERM_TIMEOUT_MS)).await;
200        }
201
202        if child.try_wait()?.is_none() {
203            debug!("Process not responding to SIGTERM, forcing kill");
204            child.kill().await?;
205        }
206
207        match child.wait().await {
208            Ok(status) => debug!("Process exited with status: {}", status),
209            Err(e) => debug!("Error waiting for process exit: {}", e),
210        }
211
212        *child_guard = None;
213        debug!("Shutdown complete");
214        Ok(())
215    }
216}
217
#[cfg(test)]
mod tests {
    use crate::transport::{JsonRpcMessage, JsonRpcRequest, JsonRpcVersion};

    use super::*;
    use std::time::Duration;

    /// Round-trips a request through `cat`, which echoes each line back.
    #[tokio::test]
    #[cfg(unix)]
    async fn test_stdio_transport() -> Result<()> {
        // Create transport connected to cat command which will stay alive
        let transport = ClientStdioTransport::new("cat", &[], None)?;

        // Create a test message
        let test_message = JsonRpcMessage::Request(JsonRpcRequest {
            id: 1,
            method: "test".to_string(),
            params: Some(serde_json::json!({"hello": "world"})),
            jsonrpc: JsonRpcVersion::default(),
        });

        // Open transport
        transport.open().await?;

        // Send message
        transport.send(&test_message).await?;

        // Receive echoed message
        let response = transport.receive().await?;

        // Verify the response matches
        assert_eq!(Some(test_message), response);

        // Clean up
        transport.close().await?;

        Ok(())
    }

    /// Closing a transport whose child ignores stdin must terminate the child
    /// promptly and let a pending `receive` finish with EOF.
    #[tokio::test]
    #[cfg(unix)]
    async fn test_graceful_shutdown() -> Result<()> {
        // Create transport with a sleep command that runs for 5 seconds
        let transport = ClientStdioTransport::new("sleep", &["5"], None)?;
        transport.open().await?;

        // Spawn a task that will read from the transport
        let transport_clone = transport.clone();
        let read_handle = tokio::spawn(async move {
            let result = transport_clone.receive().await;
            debug!("Receive returned: {:?}", result);
            result
        });

        // Wait a bit to ensure the process is running
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Initiate graceful shutdown
        let start = std::time::Instant::now();
        transport.close().await?;
        let shutdown_duration = start.elapsed();

        // Verify that:
        // 1. The pending read observed EOF once the child exited (returned None)
        // 2. The shutdown completed in less than 5 seconds (didn't wait for sleep)
        // 3. The process was properly terminated
        let read_result = read_handle.await?;
        assert!(read_result.is_ok());
        assert_eq!(read_result.unwrap(), None);
        assert!(shutdown_duration < Duration::from_secs(5));

        // Verify process is no longer running
        let child_guard = transport.child.lock().await;
        assert!(child_guard.is_none());

        Ok(())
    }

    /// Shutting down while a message has been sent and a read is still pending
    /// must complete cleanly.
    #[tokio::test]
    #[cfg(unix)]
    async fn test_shutdown_with_pending_io() -> Result<()> {
        // `read` is a shell builtin, not an executable, so it cannot be spawned
        // directly. Instead use a child that consumes stdin without echoing
        // anything back and exits once its stdin is closed.
        let transport = ClientStdioTransport::new("sh", &["-c", "cat > /dev/null"], None)?;
        transport.open().await?;

        // Start a receive operation that will be pending
        let transport_clone = transport.clone();
        let read_handle = tokio::spawn(async move { transport_clone.receive().await });

        // Give some time for read operation to start
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Send a message (nothing will come back since the child does not echo)
        let test_message = JsonRpcMessage::Request(JsonRpcRequest {
            id: 1,
            method: "test".to_string(),
            params: Some(serde_json::json!({"hello": "world"})),
            jsonrpc: JsonRpcVersion::default(),
        });
        transport.send(&test_message).await?;

        // Initiate shutdown
        transport.close().await?;

        // Verify the pending read finished cleanly with EOF
        let read_result = read_handle.await?;
        assert!(read_result.is_ok());
        assert_eq!(read_result.unwrap(), None);

        Ok(())
    }
}