Skip to main content

a3s_code_core/mcp/transport/
stdio.rs

//! Stdio Transport for MCP
//!
//! Implements the MCP transport over a child process's standard input/output,
//! exchanging newline-delimited JSON-RPC messages for local process communication.

5use super::McpTransport;
6use crate::mcp::protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, McpNotification};
7use anyhow::{anyhow, Context, Result};
8use async_trait::async_trait;
9use std::collections::HashMap;
10use std::process::Stdio;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, Command};
15use tokio::sync::{mpsc, oneshot, RwLock};
16
/// Per-request timeout, in seconds, applied by `StdioTransport::spawn` when the
/// caller does not pick one explicitly (one minute).
const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 60;
19
20/// Stdio transport for MCP servers
21pub struct StdioTransport {
22    /// Child process
23    child: RwLock<Option<Child>>,
24    /// Stdin writer
25    stdin_tx: mpsc::Sender<String>,
26    /// Pending requests (id -> response sender)
27    pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
28    /// Notification receiver
29    notification_rx: RwLock<Option<mpsc::Receiver<McpNotification>>>,
30    /// Connected flag
31    connected: AtomicBool,
32    /// Per-request timeout in seconds
33    request_timeout_secs: u64,
34}
35
36impl StdioTransport {
37    /// Create a new stdio transport by spawning a process
38    pub async fn spawn(
39        command: &str,
40        args: &[String],
41        env: &HashMap<String, String>,
42    ) -> Result<Self> {
43        Self::spawn_with_timeout(command, args, env, DEFAULT_REQUEST_TIMEOUT_SECS).await
44    }
45
46    /// Create a new stdio transport with a custom request timeout
47    pub async fn spawn_with_timeout(
48        command: &str,
49        args: &[String],
50        env: &HashMap<String, String>,
51        request_timeout_secs: u64,
52    ) -> Result<Self> {
53        // Spawn the process
54        let mut cmd = Command::new(command);
55        cmd.args(args)
56            .stdin(Stdio::piped())
57            .stdout(Stdio::piped())
58            .stderr(Stdio::piped())
59            .kill_on_drop(true);
60
61        // Add environment variables
62        for (key, value) in env {
63            cmd.env(key, value);
64        }
65
66        let mut child = cmd
67            .spawn()
68            .with_context(|| format!("Failed to spawn MCP server: {} {:?}", command, args))?;
69
70        let stdin = child.stdin.take().ok_or_else(|| anyhow!("No stdin"))?;
71        let stdout = child.stdout.take().ok_or_else(|| anyhow!("No stdout"))?;
72
73        // Create channels
74        let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(100);
75        let (notification_tx, notification_rx) = mpsc::channel::<McpNotification>(100);
76        let pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>> =
77            Arc::new(RwLock::new(HashMap::new()));
78
79        // Spawn stdin writer task
80        let mut stdin_writer = stdin;
81        tokio::spawn(async move {
82            while let Some(msg) = stdin_rx.recv().await {
83                if let Err(e) = stdin_writer.write_all(msg.as_bytes()).await {
84                    tracing::error!("Failed to write to MCP stdin: {}", e);
85                    break;
86                }
87                if let Err(e) = stdin_writer.flush().await {
88                    tracing::error!("Failed to flush MCP stdin: {}", e);
89                    break;
90                }
91            }
92        });
93
94        // Spawn stdout reader task
95        let pending_clone = pending.clone();
96        tokio::spawn(async move {
97            let mut reader = BufReader::new(stdout);
98            let mut line = String::new();
99
100            loop {
101                line.clear();
102                match reader.read_line(&mut line).await {
103                    Ok(0) => {
104                        tracing::debug!("MCP stdout closed");
105                        break;
106                    }
107                    Ok(_) => {
108                        let trimmed = line.trim();
109                        if trimmed.is_empty() {
110                            continue;
111                        }
112
113                        // Try to parse as response
114                        if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
115                            if let Some(id) = response.id {
116                                let mut pending = pending_clone.write().await;
117                                if let Some(tx) = pending.remove(&id) {
118                                    let _ = tx.send(response);
119                                }
120                            }
121                            continue;
122                        }
123
124                        // Try to parse as notification
125                        if let Ok(notification) =
126                            serde_json::from_str::<JsonRpcNotification>(trimmed)
127                        {
128                            let mcp_notif = McpNotification::from_json_rpc(&notification);
129                            let _ = notification_tx.send(mcp_notif).await;
130                            continue;
131                        }
132
133                        tracing::warn!("Unknown MCP message: {}", trimmed);
134                    }
135                    Err(e) => {
136                        tracing::error!("Failed to read MCP stdout: {}", e);
137                        break;
138                    }
139                }
140            }
141        });
142
143        Ok(Self {
144            child: RwLock::new(Some(child)),
145            stdin_tx,
146            pending,
147            notification_rx: RwLock::new(Some(notification_rx)),
148            connected: AtomicBool::new(true),
149            request_timeout_secs,
150        })
151    }
152}
153
154#[async_trait]
155impl McpTransport for StdioTransport {
156    async fn request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
157        if !self.connected.load(Ordering::SeqCst) {
158            return Err(anyhow!("Transport not connected"));
159        }
160
161        // Create response channel
162        let (tx, rx) = oneshot::channel();
163
164        // Register pending request
165        {
166            let mut pending = self.pending.write().await;
167            pending.insert(request.id, tx);
168        }
169
170        // Serialize and send request
171        let msg = serde_json::to_string(&request)? + "\n";
172        self.stdin_tx
173            .send(msg)
174            .await
175            .map_err(|_| anyhow!("Failed to send request"))?;
176
177        // Wait for response with timeout
178        let response = tokio::time::timeout(
179            std::time::Duration::from_secs(self.request_timeout_secs),
180            rx,
181        )
182        .await
183        .map_err(|_| anyhow!("MCP request timed out after {}s", self.request_timeout_secs))?
184        .map_err(|_| anyhow!("Response channel closed"))?;
185
186        Ok(response)
187    }
188
189    async fn notify(&self, notification: JsonRpcNotification) -> Result<()> {
190        if !self.connected.load(Ordering::SeqCst) {
191            return Err(anyhow!("Transport not connected"));
192        }
193
194        let msg = serde_json::to_string(&notification)? + "\n";
195        self.stdin_tx
196            .send(msg)
197            .await
198            .map_err(|_| anyhow!("Failed to send notification"))?;
199
200        Ok(())
201    }
202
203    fn notifications(&self) -> mpsc::Receiver<McpNotification> {
204        // This is a bit awkward - we need to take ownership of the receiver
205        // In practice, this should only be called once
206        let mut rx_guard = self.notification_rx.blocking_write();
207        rx_guard.take().unwrap_or_else(|| {
208            let (_, rx) = mpsc::channel(1);
209            rx
210        })
211    }
212
213    async fn close(&self) -> Result<()> {
214        self.connected.store(false, Ordering::SeqCst);
215
216        // Kill the child process
217        let mut child_guard = self.child.write().await;
218        if let Some(mut child) = child_guard.take() {
219            let _ = child.kill().await;
220        }
221
222        Ok(())
223    }
224
225    fn is_connected(&self) -> bool {
226        self.connected.load(Ordering::SeqCst)
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    // Most tests use `cat` as a stand-in server and silently skip when it is
    // unavailable on the host (spawn returns `Err`).

    #[tokio::test]
    async fn test_stdio_transport_spawn_invalid_command() {
        let result = StdioTransport::spawn("nonexistent_command_12345", &[], &HashMap::new()).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_echo() {
        // Use a simple command that exists on most systems
        let result = StdioTransport::spawn("cat", &[], &HashMap::new()).await;

        if let Ok(transport) = result {
            assert!(transport.is_connected());
            transport.close().await.unwrap();
            assert!(!transport.is_connected());
        }
        // If cat doesn't exist, that's fine - skip the test
    }

    #[tokio::test]
    async fn test_stdio_transport_is_connected_initial() {
        let result = StdioTransport::spawn("cat", &[], &HashMap::new()).await;
        if let Ok(transport) = result {
            assert!(transport.is_connected());
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_close_disconnects() {
        let result = StdioTransport::spawn("cat", &[], &HashMap::new()).await;
        if let Ok(transport) = result {
            assert!(transport.is_connected());
            transport.close().await.unwrap();
            assert!(!transport.is_connected());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_with_args() {
        let args = vec!["--version".to_string()];
        let result = StdioTransport::spawn("cat", &args, &HashMap::new()).await;
        // May fail depending on system, but should not panic
        if let Ok(transport) = result {
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_with_env() {
        let mut env = HashMap::new();
        env.insert("TEST_VAR".to_string(), "test_value".to_string());
        let result = StdioTransport::spawn("cat", &[], &env).await;
        if let Ok(transport) = result {
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_double_close() {
        let result = StdioTransport::spawn("cat", &[], &HashMap::new()).await;
        if let Ok(transport) = result {
            transport.close().await.unwrap();
            // Second close should not panic
            let result = transport.close().await;
            assert!(result.is_ok());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_request_after_close() {
        let result = StdioTransport::spawn("cat", &[], &HashMap::new()).await;
        if let Ok(transport) = result {
            transport.close().await.unwrap();

            let request = JsonRpcRequest::new(1, "test", None);
            let result = transport.request(request).await;
            assert!(result.is_err());
            assert!(result.unwrap_err().to_string().contains("not connected"));
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_notify_after_close() {
        let result = StdioTransport::spawn("cat", &[], &HashMap::new()).await;
        if let Ok(transport) = result {
            transport.close().await.unwrap();

            let notification = JsonRpcNotification::new("test", None);
            let result = transport.notify(notification).await;
            assert!(result.is_err());
            assert!(result.unwrap_err().to_string().contains("not connected"));
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_request_times_out() {
        // `sleep` never writes to stdout, so no response can arrive; with a
        // 1-second timeout the request must fail with the timeout error well
        // before the process exits on its own.
        let args = vec!["5".to_string()];
        let result = StdioTransport::spawn_with_timeout("sleep", &args, &HashMap::new(), 1).await;
        if let Ok(transport) = result {
            let request = JsonRpcRequest::new(1, "test", None);
            let err = transport.request(request).await.unwrap_err();
            assert!(err.to_string().contains("timed out"));
            let _ = transport.close().await;
        }
    }

    #[test]
    fn test_json_rpc_request_creation() {
        let request =
            JsonRpcRequest::new(1, "test_method", Some(serde_json::json!({"key": "value"})));
        assert_eq!(request.id, 1);
        assert_eq!(request.method, "test_method");
        assert!(request.params.is_some());
    }

    #[test]
    fn test_json_rpc_notification_creation() {
        let notification = JsonRpcNotification::new("test_notification", None);
        assert_eq!(notification.method, "test_notification");
        assert!(notification.params.is_none());
    }

    #[tokio::test]
    async fn test_stdio_transport_custom_timeout() {
        // Spawn with a very short timeout (1 second)
        let result = StdioTransport::spawn_with_timeout("cat", &[], &HashMap::new(), 1).await;
        if let Ok(transport) = result {
            assert_eq!(transport.request_timeout_secs, 1);
            let _ = transport.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_default_timeout() {
        let result = StdioTransport::spawn("cat", &[], &HashMap::new()).await;
        if let Ok(transport) = result {
            assert_eq!(transport.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
            let _ = transport.close().await;
        }
    }
}