Skip to main content

a3s_code_core/mcp/transport/
stdio.rs

1//! Stdio Transport for MCP
2//!
3//! Implements MCP transport over standard input/output for local process communication.
4
5use super::McpTransport;
6use crate::mcp::protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, McpNotification};
7use anyhow::{anyhow, Context, Result};
8use async_trait::async_trait;
9use std::collections::HashMap;
10use std::process::Stdio;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, Command};
15use tokio::sync::{mpsc, oneshot, RwLock};
16
17/// Default request timeout for MCP tool calls
18const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 60;
19
20/// Stdio transport for MCP servers
21pub struct StdioTransport {
22    /// Child process
23    child: RwLock<Option<Child>>,
24    /// Stdin writer
25    stdin_tx: mpsc::Sender<String>,
26    /// Pending requests (id -> response sender)
27    pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
28    /// Notification receiver
29    notification_rx: RwLock<Option<mpsc::Receiver<McpNotification>>>,
30    /// Connected flag
31    connected: AtomicBool,
32    /// Per-request timeout in seconds
33    request_timeout_secs: u64,
34}
35
36impl StdioTransport {
37    /// Create a new stdio transport by spawning a process
38    pub async fn spawn(
39        command: &str,
40        args: &[String],
41        env: &HashMap<String, String>,
42    ) -> Result<Self> {
43        Self::spawn_with_timeout(command, args, env, DEFAULT_REQUEST_TIMEOUT_SECS).await
44    }
45
46    /// Create a new stdio transport with a custom request timeout
47    pub async fn spawn_with_timeout(
48        command: &str,
49        args: &[String],
50        env: &HashMap<String, String>,
51        request_timeout_secs: u64,
52    ) -> Result<Self> {
53        // Spawn the process
54        let mut cmd = Command::new(command);
55        cmd.args(args)
56            .stdin(Stdio::piped())
57            .stdout(Stdio::piped())
58            .stderr(Stdio::piped())
59            .kill_on_drop(true);
60
61        // Add environment variables
62        for (key, value) in env {
63            cmd.env(key, value);
64        }
65
66        let mut child = cmd
67            .spawn()
68            .with_context(|| format!("Failed to spawn MCP server: {} {:?}", command, args))?;
69
70        let stdin = child.stdin.take().ok_or_else(|| anyhow!("No stdin"))?;
71        let stdout = child.stdout.take().ok_or_else(|| anyhow!("No stdout"))?;
72
73        // Create channels
74        let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(100);
75        let (notification_tx, notification_rx) = mpsc::channel::<McpNotification>(100);
76        let pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>> =
77            Arc::new(RwLock::new(HashMap::new()));
78
79        // Spawn stdin writer task
80        let mut stdin_writer = stdin;
81        tokio::spawn(async move {
82            while let Some(msg) = stdin_rx.recv().await {
83                if let Err(e) = stdin_writer.write_all(msg.as_bytes()).await {
84                    tracing::error!("Failed to write to MCP stdin: {}", e);
85                    break;
86                }
87                if let Err(e) = stdin_writer.flush().await {
88                    tracing::error!("Failed to flush MCP stdin: {}", e);
89                    break;
90                }
91            }
92        });
93
94        // Spawn stdout reader task
95        let pending_clone = pending.clone();
96        tokio::spawn(async move {
97            let mut reader = BufReader::new(stdout);
98            let mut line = String::new();
99
100            loop {
101                line.clear();
102                match reader.read_line(&mut line).await {
103                    Ok(0) => {
104                        tracing::debug!("MCP stdout closed");
105                        break;
106                    }
107                    Ok(_) => {
108                        let trimmed = line.trim();
109                        if trimmed.is_empty() {
110                            continue;
111                        }
112
113                        // Try to parse as response
114                        if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
115                            if let Some(id) = response.id {
116                                let mut pending = pending_clone.write().await;
117                                if let Some(tx) = pending.remove(&id) {
118                                    let _ = tx.send(response);
119                                }
120                            }
121                            continue;
122                        }
123
124                        // Try to parse as notification
125                        if let Ok(notification) =
126                            serde_json::from_str::<JsonRpcNotification>(trimmed)
127                        {
128                            let mcp_notif = McpNotification::from_json_rpc(&notification);
129                            let _ = notification_tx.send(mcp_notif).await;
130                            continue;
131                        }
132
133                        tracing::warn!("Unknown MCP message: {}", trimmed);
134                    }
135                    Err(e) => {
136                        tracing::error!("Failed to read MCP stdout: {}", e);
137                        break;
138                    }
139                }
140            }
141        });
142
143        Ok(Self {
144            child: RwLock::new(Some(child)),
145            stdin_tx,
146            pending,
147            notification_rx: RwLock::new(Some(notification_rx)),
148            connected: AtomicBool::new(true),
149            request_timeout_secs,
150        })
151    }
152}
153
154#[async_trait]
155impl McpTransport for StdioTransport {
156    async fn request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
157        if !self.connected.load(Ordering::SeqCst) {
158            return Err(anyhow!("Transport not connected"));
159        }
160
161        // Create response channel
162        let (tx, rx) = oneshot::channel();
163        let request_id = request.id;
164
165        // Register pending request
166        {
167            let mut pending = self.pending.write().await;
168            pending.insert(request_id, tx);
169        }
170
171        // Serialize and send request
172        let msg = serde_json::to_string(&request)? + "\n";
173        self.stdin_tx
174            .send(msg)
175            .await
176            .map_err(|_| anyhow!("Failed to send request"))?;
177
178        // Wait for response with timeout
179        let response = match tokio::time::timeout(
180            std::time::Duration::from_secs(self.request_timeout_secs),
181            rx,
182        )
183        .await
184        {
185            Ok(Ok(resp)) => resp,
186            Ok(Err(_)) => {
187                // Channel closed — clean up pending entry
188                self.pending.write().await.remove(&request_id);
189                return Err(anyhow!("Response channel closed"));
190            }
191            Err(_) => {
192                // Timeout — clean up pending entry to prevent memory leak
193                self.pending.write().await.remove(&request_id);
194                return Err(anyhow!(
195                    "MCP request timed out after {}s",
196                    self.request_timeout_secs
197                ));
198            }
199        };
200
201        Ok(response)
202    }
203
204    async fn notify(&self, notification: JsonRpcNotification) -> Result<()> {
205        if !self.connected.load(Ordering::SeqCst) {
206            return Err(anyhow!("Transport not connected"));
207        }
208
209        let msg = serde_json::to_string(&notification)? + "\n";
210        self.stdin_tx
211            .send(msg)
212            .await
213            .map_err(|_| anyhow!("Failed to send notification"))?;
214
215        Ok(())
216    }
217
218    fn notifications(&self) -> mpsc::Receiver<McpNotification> {
219        // This is a bit awkward - we need to take ownership of the receiver
220        // In practice, this should only be called once
221        let mut rx_guard = self.notification_rx.blocking_write();
222        rx_guard.take().unwrap_or_else(|| {
223            let (_, rx) = mpsc::channel(1);
224            rx
225        })
226    }
227
228    async fn close(&self) -> Result<()> {
229        self.connected.store(false, Ordering::SeqCst);
230
231        // Kill the child process
232        let mut child_guard = self.child.write().await;
233        if let Some(mut child) = child_guard.take() {
234            let _ = child.kill().await;
235        }
236
237        Ok(())
238    }
239
240    fn is_connected(&self) -> bool {
241        self.connected.load(Ordering::SeqCst)
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns `cat` as a trivial stand-in server.
    ///
    /// Returns `None` when `cat` cannot be spawned on this host; callers treat
    /// that as "skip" rather than as a failure.
    async fn spawn_cat() -> Option<StdioTransport> {
        StdioTransport::spawn("cat", &[], &HashMap::new()).await.ok()
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_invalid_command() {
        let spawned =
            StdioTransport::spawn("nonexistent_command_12345", &[], &HashMap::new()).await;
        assert!(spawned.is_err());
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_echo() {
        // `cat` exists on most systems; without it there is nothing to check.
        if let Some(transport) = spawn_cat().await {
            assert!(transport.is_connected());
            transport.close().await.unwrap();
            assert!(!transport.is_connected());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_is_connected_initial() {
        if let Some(transport) = spawn_cat().await {
            assert!(transport.is_connected());
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_close_disconnects() {
        if let Some(transport) = spawn_cat().await {
            assert!(transport.is_connected());
            transport.close().await.unwrap();
            assert!(!transport.is_connected());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_with_args() {
        // Outcome is platform-dependent; the only requirement is no panic.
        let args = ["--version".to_string()];
        let _ = StdioTransport::spawn("cat", &args, &HashMap::new()).await;
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_with_env() {
        let env: HashMap<String, String> =
            [("TEST_VAR".to_string(), "test_value".to_string())].into_iter().collect();
        if let Ok(transport) = StdioTransport::spawn("cat", &[], &env).await {
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_double_close() {
        if let Some(transport) = spawn_cat().await {
            transport.close().await.unwrap();
            // A second close must be a harmless no-op.
            assert!(transport.close().await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_request_after_close() {
        if let Some(transport) = spawn_cat().await {
            transport.close().await.unwrap();

            let err = transport
                .request(JsonRpcRequest::new(1, "test", None))
                .await
                .expect_err("request on a closed transport must fail");
            assert!(err.to_string().contains("not connected"));
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_notify_after_close() {
        if let Some(transport) = spawn_cat().await {
            transport.close().await.unwrap();

            let err = transport
                .notify(JsonRpcNotification::new("test", None))
                .await
                .expect_err("notify on a closed transport must fail");
            assert!(err.to_string().contains("not connected"));
        }
    }

    #[test]
    fn test_json_rpc_request_creation() {
        let params = serde_json::json!({"key": "value"});
        let request = JsonRpcRequest::new(1, "test_method", Some(params));
        assert_eq!(request.id, 1);
        assert_eq!(request.method, "test_method");
        assert!(request.params.is_some());
    }

    #[test]
    fn test_json_rpc_notification_creation() {
        let notification = JsonRpcNotification::new("test_notification", None);
        assert_eq!(notification.method, "test_notification");
        assert!(notification.params.is_none());
    }

    #[tokio::test]
    async fn test_stdio_transport_custom_timeout() {
        // A deliberately short (1 second) timeout must be stored verbatim.
        let spawned = StdioTransport::spawn_with_timeout("cat", &[], &HashMap::new(), 1).await;
        if let Ok(transport) = spawned {
            assert_eq!(transport.request_timeout_secs, 1);
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_default_timeout() {
        if let Some(transport) = spawn_cat().await {
            assert_eq!(transport.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
            let _ = transport.close().await;
        }
    }
}