// a3s_code_core/mcp/transport/stdio.rs
use std::collections::HashMap;
use std::process::Stdio;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;

use anyhow::{anyhow, Context, Result};
use async_trait::async_trait;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::{Child, Command};
use tokio::sync::{mpsc, oneshot, RwLock};

use super::McpTransport;
use crate::mcp::protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, McpNotification};

/// Per-request response timeout, in seconds, used by `StdioTransport::spawn` when the
/// caller does not pick one via `StdioTransport::spawn_with_timeout`.
const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 60;

20pub struct StdioTransport {
22 child: RwLock<Option<Child>>,
24 stdin_tx: mpsc::Sender<String>,
26 pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
28 notification_rx: RwLock<Option<mpsc::Receiver<McpNotification>>>,
30 connected: AtomicBool,
32 request_timeout_secs: u64,
34}
35
36impl StdioTransport {
37 pub async fn spawn(
39 command: &str,
40 args: &[String],
41 env: &HashMap<String, String>,
42 ) -> Result<Self> {
43 Self::spawn_with_timeout(command, args, env, DEFAULT_REQUEST_TIMEOUT_SECS).await
44 }
45
46 pub async fn spawn_with_timeout(
48 command: &str,
49 args: &[String],
50 env: &HashMap<String, String>,
51 request_timeout_secs: u64,
52 ) -> Result<Self> {
53 let mut cmd = Command::new(command);
55 cmd.args(args)
56 .stdin(Stdio::piped())
57 .stdout(Stdio::piped())
58 .stderr(Stdio::piped())
59 .kill_on_drop(true);
60
61 for (key, value) in env {
63 cmd.env(key, value);
64 }
65
66 let mut child = cmd
67 .spawn()
68 .with_context(|| format!("Failed to spawn MCP server: {} {:?}", command, args))?;
69
70 let stdin = child.stdin.take().ok_or_else(|| anyhow!("No stdin"))?;
71 let stdout = child.stdout.take().ok_or_else(|| anyhow!("No stdout"))?;
72
73 let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(100);
75 let (notification_tx, notification_rx) = mpsc::channel::<McpNotification>(100);
76 let pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>> =
77 Arc::new(RwLock::new(HashMap::new()));
78
79 let mut stdin_writer = stdin;
81 tokio::spawn(async move {
82 while let Some(msg) = stdin_rx.recv().await {
83 if let Err(e) = stdin_writer.write_all(msg.as_bytes()).await {
84 tracing::error!("Failed to write to MCP stdin: {}", e);
85 break;
86 }
87 if let Err(e) = stdin_writer.flush().await {
88 tracing::error!("Failed to flush MCP stdin: {}", e);
89 break;
90 }
91 }
92 });
93
94 let pending_clone = pending.clone();
96 tokio::spawn(async move {
97 let mut reader = BufReader::new(stdout);
98 let mut line = String::new();
99
100 loop {
101 line.clear();
102 match reader.read_line(&mut line).await {
103 Ok(0) => {
104 tracing::debug!("MCP stdout closed");
105 break;
106 }
107 Ok(_) => {
108 let trimmed = line.trim();
109 if trimmed.is_empty() {
110 continue;
111 }
112
113 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
115 if let Some(id) = response.id {
116 let mut pending = pending_clone.write().await;
117 if let Some(tx) = pending.remove(&id) {
118 let _ = tx.send(response);
119 }
120 }
121 continue;
122 }
123
124 if let Ok(notification) =
126 serde_json::from_str::<JsonRpcNotification>(trimmed)
127 {
128 let mcp_notif = McpNotification::from_json_rpc(¬ification);
129 let _ = notification_tx.send(mcp_notif).await;
130 continue;
131 }
132
133 tracing::warn!("Unknown MCP message: {}", trimmed);
134 }
135 Err(e) => {
136 tracing::error!("Failed to read MCP stdout: {}", e);
137 break;
138 }
139 }
140 }
141 });
142
143 Ok(Self {
144 child: RwLock::new(Some(child)),
145 stdin_tx,
146 pending,
147 notification_rx: RwLock::new(Some(notification_rx)),
148 connected: AtomicBool::new(true),
149 request_timeout_secs,
150 })
151 }
152}
153
154#[async_trait]
155impl McpTransport for StdioTransport {
156 async fn request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
157 if !self.connected.load(Ordering::SeqCst) {
158 return Err(anyhow!("Transport not connected"));
159 }
160
161 let (tx, rx) = oneshot::channel();
163
164 {
166 let mut pending = self.pending.write().await;
167 pending.insert(request.id, tx);
168 }
169
170 let msg = serde_json::to_string(&request)? + "\n";
172 self.stdin_tx
173 .send(msg)
174 .await
175 .map_err(|_| anyhow!("Failed to send request"))?;
176
177 let response = tokio::time::timeout(
179 std::time::Duration::from_secs(self.request_timeout_secs),
180 rx,
181 )
182 .await
183 .map_err(|_| anyhow!("MCP request timed out after {}s", self.request_timeout_secs))?
184 .map_err(|_| anyhow!("Response channel closed"))?;
185
186 Ok(response)
187 }
188
189 async fn notify(&self, notification: JsonRpcNotification) -> Result<()> {
190 if !self.connected.load(Ordering::SeqCst) {
191 return Err(anyhow!("Transport not connected"));
192 }
193
194 let msg = serde_json::to_string(¬ification)? + "\n";
195 self.stdin_tx
196 .send(msg)
197 .await
198 .map_err(|_| anyhow!("Failed to send notification"))?;
199
200 Ok(())
201 }
202
203 fn notifications(&self) -> mpsc::Receiver<McpNotification> {
204 let mut rx_guard = self.notification_rx.blocking_write();
207 rx_guard.take().unwrap_or_else(|| {
208 let (_, rx) = mpsc::channel(1);
209 rx
210 })
211 }
212
213 async fn close(&self) -> Result<()> {
214 self.connected.store(false, Ordering::SeqCst);
215
216 let mut child_guard = self.child.write().await;
218 if let Some(mut child) = child_guard.take() {
219 let _ = child.kill().await;
220 }
221
222 Ok(())
223 }
224
225 fn is_connected(&self) -> bool {
226 self.connected.load(Ordering::SeqCst)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns `cat` as a stand-in MCP server process with no args and no extra env.
    async fn spawn_cat() -> Result<StdioTransport> {
        StdioTransport::spawn("cat", &[], &HashMap::new()).await
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_invalid_command() {
        let outcome =
            StdioTransport::spawn("nonexistent_command_12345", &[], &HashMap::new()).await;
        assert!(outcome.is_err());
    }

    // The tests below silently skip when `cat` cannot be spawned on the host.

    #[tokio::test]
    async fn test_stdio_transport_spawn_echo() {
        if let Ok(server) = spawn_cat().await {
            assert!(server.is_connected());
            server.close().await.unwrap();
            assert!(!server.is_connected());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_is_connected_initial() {
        if let Ok(server) = spawn_cat().await {
            assert!(server.is_connected());
            let _ = server.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_close_disconnects() {
        if let Ok(server) = spawn_cat().await {
            assert!(server.is_connected());
            server.close().await.unwrap();
            assert!(!server.is_connected());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_with_args() {
        // Only checks that spawning with extra arguments does not panic.
        let extra_args = vec!["--version".to_string()];
        let outcome = StdioTransport::spawn("cat", &extra_args, &HashMap::new()).await;
        let _ = outcome;
    }

    #[tokio::test]
    async fn test_stdio_transport_spawn_with_env() {
        let vars: HashMap<String, String> =
            [("TEST_VAR".to_string(), "test_value".to_string())].into_iter().collect();
        if let Ok(server) = StdioTransport::spawn("cat", &[], &vars).await {
            let _ = server.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_double_close() {
        if let Ok(server) = spawn_cat().await {
            server.close().await.unwrap();
            let second_close = server.close().await;
            assert!(second_close.is_ok());
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_request_after_close() {
        if let Ok(server) = spawn_cat().await {
            server.close().await.unwrap();

            let reply = server.request(JsonRpcRequest::new(1, "test", None)).await;
            assert!(reply.is_err());
            assert!(reply.unwrap_err().to_string().contains("not connected"));
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_notify_after_close() {
        if let Ok(server) = spawn_cat().await {
            server.close().await.unwrap();

            let sent = server.notify(JsonRpcNotification::new("test", None)).await;
            assert!(sent.is_err());
            assert!(sent.unwrap_err().to_string().contains("not connected"));
        }
    }

    #[test]
    fn test_json_rpc_request_creation() {
        let params = serde_json::json!({"key": "value"});
        let built = JsonRpcRequest::new(1, "test_method", Some(params));
        assert_eq!(built.id, 1);
        assert_eq!(built.method, "test_method");
        assert!(built.params.is_some());
    }

    #[test]
    fn test_json_rpc_notification_creation() {
        let built = JsonRpcNotification::new("test_notification", None);
        assert_eq!(built.method, "test_notification");
        assert!(built.params.is_none());
    }

    #[tokio::test]
    async fn test_stdio_transport_custom_timeout() {
        let outcome = StdioTransport::spawn_with_timeout("cat", &[], &HashMap::new(), 1).await;
        if let Ok(server) = outcome {
            assert_eq!(server.request_timeout_secs, 1);
            let _ = server.close().await;
        }
    }

    #[tokio::test]
    async fn test_stdio_transport_default_timeout() {
        if let Ok(server) = spawn_cat().await {
            assert_eq!(server.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
            let _ = server.close().await;
        }
    }
}