a3s_code_core/mcp/transport/
stdio.rs1use super::McpTransport;
6use crate::mcp::protocol::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, McpNotification};
7use anyhow::{anyhow, Context, Result};
8use async_trait::async_trait;
9use std::collections::HashMap;
10use std::process::Stdio;
11use std::sync::atomic::{AtomicBool, Ordering};
12use std::sync::Arc;
13use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
14use tokio::process::{Child, Command};
15use tokio::sync::{mpsc, oneshot, RwLock};
16
17const DEFAULT_REQUEST_TIMEOUT_SECS: u64 = 60;
19
20pub struct StdioTransport {
22 child: RwLock<Option<Child>>,
24 stdin_tx: mpsc::Sender<String>,
26 pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>>,
28 notification_rx: RwLock<Option<mpsc::Receiver<McpNotification>>>,
30 connected: AtomicBool,
32 request_timeout_secs: u64,
34}
35
36impl StdioTransport {
37 pub async fn spawn(
39 command: &str,
40 args: &[String],
41 env: &HashMap<String, String>,
42 ) -> Result<Self> {
43 Self::spawn_with_timeout(command, args, env, DEFAULT_REQUEST_TIMEOUT_SECS).await
44 }
45
46 pub async fn spawn_with_timeout(
48 command: &str,
49 args: &[String],
50 env: &HashMap<String, String>,
51 request_timeout_secs: u64,
52 ) -> Result<Self> {
53 let mut cmd = Command::new(command);
55 cmd.args(args)
56 .stdin(Stdio::piped())
57 .stdout(Stdio::piped())
58 .stderr(Stdio::piped())
59 .kill_on_drop(true);
60
61 for (key, value) in env {
63 cmd.env(key, value);
64 }
65
66 let mut child = cmd
67 .spawn()
68 .with_context(|| format!("Failed to spawn MCP server: {} {:?}", command, args))?;
69
70 let stdin = child.stdin.take().ok_or_else(|| anyhow!("No stdin"))?;
71 let stdout = child.stdout.take().ok_or_else(|| anyhow!("No stdout"))?;
72
73 let (stdin_tx, mut stdin_rx) = mpsc::channel::<String>(100);
75 let (notification_tx, notification_rx) = mpsc::channel::<McpNotification>(100);
76 let pending: Arc<RwLock<HashMap<u64, oneshot::Sender<JsonRpcResponse>>>> =
77 Arc::new(RwLock::new(HashMap::new()));
78
79 let mut stdin_writer = stdin;
81 tokio::spawn(async move {
82 while let Some(msg) = stdin_rx.recv().await {
83 if let Err(e) = stdin_writer.write_all(msg.as_bytes()).await {
84 tracing::error!("Failed to write to MCP stdin: {}", e);
85 break;
86 }
87 if let Err(e) = stdin_writer.flush().await {
88 tracing::error!("Failed to flush MCP stdin: {}", e);
89 break;
90 }
91 }
92 });
93
94 let pending_clone = pending.clone();
96 tokio::spawn(async move {
97 let mut reader = BufReader::new(stdout);
98 let mut line = String::new();
99
100 loop {
101 line.clear();
102 match reader.read_line(&mut line).await {
103 Ok(0) => {
104 tracing::debug!("MCP stdout closed");
105 break;
106 }
107 Ok(_) => {
108 let trimmed = line.trim();
109 if trimmed.is_empty() {
110 continue;
111 }
112
113 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
115 if let Some(id) = response.id {
116 let mut pending = pending_clone.write().await;
117 if let Some(tx) = pending.remove(&id) {
118 let _ = tx.send(response);
119 }
120 }
121 continue;
122 }
123
124 if let Ok(notification) =
126 serde_json::from_str::<JsonRpcNotification>(trimmed)
127 {
128 let mcp_notif = McpNotification::from_json_rpc(¬ification);
129 let _ = notification_tx.send(mcp_notif).await;
130 continue;
131 }
132
133 tracing::warn!("Unknown MCP message: {}", trimmed);
134 }
135 Err(e) => {
136 tracing::error!("Failed to read MCP stdout: {}", e);
137 break;
138 }
139 }
140 }
141 });
142
143 Ok(Self {
144 child: RwLock::new(Some(child)),
145 stdin_tx,
146 pending,
147 notification_rx: RwLock::new(Some(notification_rx)),
148 connected: AtomicBool::new(true),
149 request_timeout_secs,
150 })
151 }
152}
153
154#[async_trait]
155impl McpTransport for StdioTransport {
156 async fn request(&self, request: JsonRpcRequest) -> Result<JsonRpcResponse> {
157 if !self.connected.load(Ordering::SeqCst) {
158 return Err(anyhow!("Transport not connected"));
159 }
160
161 let (tx, rx) = oneshot::channel();
163 let request_id = request.id;
164
165 {
167 let mut pending = self.pending.write().await;
168 pending.insert(request_id, tx);
169 }
170
171 let msg = serde_json::to_string(&request)? + "\n";
173 self.stdin_tx
174 .send(msg)
175 .await
176 .map_err(|_| anyhow!("Failed to send request"))?;
177
178 let response = match tokio::time::timeout(
180 std::time::Duration::from_secs(self.request_timeout_secs),
181 rx,
182 )
183 .await
184 {
185 Ok(Ok(resp)) => resp,
186 Ok(Err(_)) => {
187 self.pending.write().await.remove(&request_id);
189 return Err(anyhow!("Response channel closed"));
190 }
191 Err(_) => {
192 self.pending.write().await.remove(&request_id);
194 return Err(anyhow!(
195 "MCP request timed out after {}s",
196 self.request_timeout_secs
197 ));
198 }
199 };
200
201 Ok(response)
202 }
203
204 async fn notify(&self, notification: JsonRpcNotification) -> Result<()> {
205 if !self.connected.load(Ordering::SeqCst) {
206 return Err(anyhow!("Transport not connected"));
207 }
208
209 let msg = serde_json::to_string(¬ification)? + "\n";
210 self.stdin_tx
211 .send(msg)
212 .await
213 .map_err(|_| anyhow!("Failed to send notification"))?;
214
215 Ok(())
216 }
217
218 fn notifications(&self) -> mpsc::Receiver<McpNotification> {
219 let mut rx_guard = self.notification_rx.blocking_write();
222 rx_guard.take().unwrap_or_else(|| {
223 let (_, rx) = mpsc::channel(1);
224 rx
225 })
226 }
227
228 async fn close(&self) -> Result<()> {
229 self.connected.store(false, Ordering::SeqCst);
230
231 let mut child_guard = self.child.write().await;
233 if let Some(mut child) = child_guard.take() {
234 let _ = child.kill().await;
235 }
236
237 Ok(())
238 }
239
240 fn is_connected(&self) -> bool {
241 self.connected.load(Ordering::SeqCst)
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns `cat` with no arguments and an empty environment.
    ///
    /// Returns `None` when the spawn fails, so that the tests below skip
    /// their assertions rather than fail when `cat` cannot be started.
    async fn spawn_cat() -> Option<StdioTransport> {
        StdioTransport::spawn("cat", &[], &HashMap::new()).await.ok()
    }

    /// Spawning a command that does not exist reports an error.
    #[tokio::test]
    async fn test_stdio_transport_spawn_invalid_command() {
        let outcome = StdioTransport::spawn("nonexistent_command_12345", &[], &HashMap::new()).await;
        assert!(outcome.is_err());
    }

    /// A freshly spawned transport is connected and `close` disconnects it.
    #[tokio::test]
    async fn test_stdio_transport_spawn_echo() {
        if let Some(transport) = spawn_cat().await {
            assert!(transport.is_connected());
            transport.close().await.unwrap();
            assert!(!transport.is_connected());
        }
    }

    /// A freshly spawned transport reports itself as connected.
    #[tokio::test]
    async fn test_stdio_transport_is_connected_initial() {
        if let Some(transport) = spawn_cat().await {
            assert!(transport.is_connected());
            let _ = transport.close().await;
        }
    }

    /// `close` flips the connected flag to false.
    #[tokio::test]
    async fn test_stdio_transport_close_disconnects() {
        if let Some(transport) = spawn_cat().await {
            assert!(transport.is_connected());
            transport.close().await.unwrap();
            assert!(!transport.is_connected());
        }
    }

    /// Spawning with arguments does not panic; the outcome itself is ignored.
    #[tokio::test]
    async fn test_stdio_transport_spawn_with_args() {
        let args = vec!["--version".to_string()];
        let _ = StdioTransport::spawn("cat", &args, &HashMap::new()).await;
    }

    /// Spawning with an extra environment variable does not panic.
    #[tokio::test]
    async fn test_stdio_transport_spawn_with_env() {
        let mut env = HashMap::new();
        env.insert("TEST_VAR".to_string(), "test_value".to_string());
        if let Ok(transport) = StdioTransport::spawn("cat", &[], &env).await {
            let _ = transport.close().await;
        }
    }

    /// Closing twice succeeds both times.
    #[tokio::test]
    async fn test_stdio_transport_double_close() {
        if let Some(transport) = spawn_cat().await {
            transport.close().await.unwrap();
            let second = transport.close().await;
            assert!(second.is_ok());
        }
    }

    /// Requests are rejected once the transport has been closed.
    #[tokio::test]
    async fn test_stdio_transport_request_after_close() {
        if let Some(transport) = spawn_cat().await {
            transport.close().await.unwrap();

            let request = JsonRpcRequest::new(1, "test", None);
            let outcome = transport.request(request).await;
            assert!(outcome.is_err());
            assert!(outcome.unwrap_err().to_string().contains("not connected"));
        }
    }

    /// Notifications are rejected once the transport has been closed.
    #[tokio::test]
    async fn test_stdio_transport_notify_after_close() {
        if let Some(transport) = spawn_cat().await {
            transport.close().await.unwrap();

            let notification = JsonRpcNotification::new("test", None);
            let outcome = transport.notify(notification).await;
            assert!(outcome.is_err());
            assert!(outcome.unwrap_err().to_string().contains("not connected"));
        }
    }

    /// `JsonRpcRequest::new` stores the id, method and params it is given.
    #[test]
    fn test_json_rpc_request_creation() {
        let params = Some(serde_json::json!({"key": "value"}));
        let request = JsonRpcRequest::new(1, "test_method", params);
        assert_eq!(request.id, 1);
        assert_eq!(request.method, "test_method");
        assert!(request.params.is_some());
    }

    /// `JsonRpcNotification::new` stores the method and leaves params empty.
    #[test]
    fn test_json_rpc_notification_creation() {
        let notification = JsonRpcNotification::new("test_notification", None);
        assert_eq!(notification.method, "test_notification");
        assert!(notification.params.is_none());
    }

    /// `spawn_with_timeout` records the timeout it was given.
    #[tokio::test]
    async fn test_stdio_transport_custom_timeout() {
        let spawned = StdioTransport::spawn_with_timeout("cat", &[], &HashMap::new(), 1).await;
        if let Ok(transport) = spawned {
            assert_eq!(transport.request_timeout_secs, 1);
            let _ = transport.close().await;
        }
    }

    /// `spawn` uses the default timeout.
    #[tokio::test]
    async fn test_stdio_transport_default_timeout() {
        if let Some(transport) = spawn_cat().await {
            assert_eq!(transport.request_timeout_secs, DEFAULT_REQUEST_TIMEOUT_SECS);
            let _ = transport.close().await;
        }
    }
}