Skip to main content

model_context_protocol/client/
stdio.rs

1//! Stdio Transport for connecting to MCP server processes.
2//!
3//! Communicates with MCP servers via standard input/output using JSON-RPC.
4//! This is used to connect to MCP servers that run as child processes.
5//!
6//! ## Architecture
7//!
8//! Uses true async I/O with `tokio::process` for non-blocking communication:
9//! - Separate async tasks for stdin writing and stdout reading
10//! - Pending request tracking with oneshot channels for responses
11//! - No mutex held across I/O operations, enabling concurrent requests
12
13use async_trait::async_trait;
14use dashmap::DashMap;
15use serde_json::Value;
16use std::collections::HashMap;
17use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
21use tokio::process::{Child, Command};
22use tokio::sync::{mpsc, oneshot};
23
24use crate::protocol::*;
25use crate::transport::{InitializeParams, McpTransport, McpTransportError, TransportTypeId};
26
27// =============================================================================
28// True Async Stdio Transport (using tokio::process)
29// =============================================================================
30
/// A single outgoing JSON-RPC line, queued for the task that owns the
/// child's stdin.
struct WriteRequest {
    /// The exact text to write, already serialized and newline-terminated.
    request_line: String,
}
35
36/// True async stdio transport using tokio::process.
37///
38/// This implementation:
39/// - Uses `tokio::process::Command` for async process spawning
40/// - Separate tasks for reading stdout and writing to stdin
41/// - No mutex held during I/O, enabling true concurrent requests
42/// - Pending request map for matching responses to requests
43pub struct TokioStdioTransport {
44    /// Sender for write requests
45    write_tx: mpsc::Sender<WriteRequest>,
46    /// Pending requests: request ID -> response sender
47    pending: Arc<DashMap<i64, oneshot::Sender<Result<Value, McpTransportError>>>>,
48    /// Next request ID
49    next_id: AtomicI64,
50    /// Whether the transport is alive
51    alive: Arc<AtomicBool>,
52    /// Handle to the child process (for cleanup)
53    child: Arc<tokio::sync::Mutex<Child>>,
54}
55
56impl TokioStdioTransport {
57    /// Spawn a new MCP server process with true async I/O.
58    pub async fn spawn(command: &str, args: &[String]) -> Result<Self, McpTransportError> {
59        Self::spawn_with_env(command, args, HashMap::new()).await
60    }
61
62    /// Spawn with environment variables.
63    pub async fn spawn_with_env(
64        command: &str,
65        args: &[String],
66        env: HashMap<String, String>,
67    ) -> Result<Self, McpTransportError> {
68        let mut cmd = Command::new(command);
69        cmd.args(args)
70            .stdin(std::process::Stdio::piped())
71            .stdout(std::process::Stdio::piped())
72            .stderr(std::process::Stdio::piped())
73            .kill_on_drop(true);
74
75        for (key, value) in env {
76            cmd.env(key, value);
77        }
78
79        let mut child = cmd.spawn().map_err(|e| {
80            McpTransportError::TransportError(format!(
81                "Failed to spawn process '{}': {}",
82                command, e
83            ))
84        })?;
85
86        // Take ownership of stdin and stdout
87        let stdin = child
88            .stdin
89            .take()
90            .ok_or_else(|| McpTransportError::TransportError("Failed to get stdin".to_string()))?;
91        let stdout = child
92            .stdout
93            .take()
94            .ok_or_else(|| McpTransportError::TransportError("Failed to get stdout".to_string()))?;
95
96        let alive = Arc::new(AtomicBool::new(true));
97        let pending: Arc<DashMap<i64, oneshot::Sender<Result<Value, McpTransportError>>>> =
98            Arc::new(DashMap::new());
99
100        // Channel for write requests
101        let (write_tx, mut write_rx) = mpsc::channel::<WriteRequest>(256);
102
103        // Spawn writer task
104        let alive_writer = Arc::clone(&alive);
105        let mut stdin = stdin;
106        tokio::spawn(async move {
107            while let Some(req) = write_rx.recv().await {
108                if !alive_writer.load(Ordering::SeqCst) {
109                    break;
110                }
111                if let Err(e) = stdin.write_all(req.request_line.as_bytes()).await {
112                    eprintln!("Stdio write error: {}", e);
113                    alive_writer.store(false, Ordering::SeqCst);
114                    break;
115                }
116                if let Err(e) = stdin.flush().await {
117                    eprintln!("Stdio flush error: {}", e);
118                    alive_writer.store(false, Ordering::SeqCst);
119                    break;
120                }
121            }
122        });
123
124        // Spawn reader task
125        let pending_reader = Arc::clone(&pending);
126        let alive_reader = Arc::clone(&alive);
127        let mut reader = BufReader::new(stdout);
128        tokio::spawn(async move {
129            let mut line = String::new();
130            loop {
131                line.clear();
132                match reader.read_line(&mut line).await {
133                    Ok(0) => {
134                        // EOF - process closed
135                        alive_reader.store(false, Ordering::SeqCst);
136                        break;
137                    }
138                    Ok(_) => {
139                        // Parse response
140                        match serde_json::from_str::<JsonRpcResponse>(&line) {
141                            Ok(response) => {
142                                if let JsonRpcId::Number(id) = &response.id {
143                                    if let Some((_, tx)) = pending_reader.remove(id) {
144                                        let result = match response.payload {
145                                            JsonRpcPayload::Success { result } => Ok(result),
146                                            JsonRpcPayload::Error { error } => {
147                                                Err(McpTransportError::ServerError(format!(
148                                                    "MCP Error: {}",
149                                                    error
150                                                )))
151                                            }
152                                        };
153                                        let _ = tx.send(result);
154                                    }
155                                }
156                            }
157                            Err(e) => {
158                                eprintln!(
159                                    "Failed to parse response: {} - line: {}",
160                                    e,
161                                    line.trim()
162                                );
163                            }
164                        }
165                    }
166                    Err(e) => {
167                        eprintln!("Stdio read error: {}", e);
168                        alive_reader.store(false, Ordering::SeqCst);
169                        break;
170                    }
171                }
172            }
173
174            // Clean up pending requests on shutdown - receivers will get
175            // a channel closed error when the senders are dropped
176            pending_reader.clear();
177        });
178
179        Ok(Self {
180            write_tx,
181            pending,
182            next_id: AtomicI64::new(1),
183            alive,
184            child: Arc::new(tokio::sync::Mutex::new(child)),
185        })
186    }
187
188    /// Send a request and wait for response with timeout.
189    pub async fn send_request(
190        &self,
191        method: &str,
192        params: Option<Value>,
193        timeout_duration: Duration,
194    ) -> Result<Value, McpTransportError> {
195        if !self.alive.load(Ordering::SeqCst) {
196            return Err(McpTransportError::ConnectionClosed);
197        }
198
199        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
200        let request = JsonRpcRequest::new(JsonRpcId::Number(id), method.to_string(), params);
201        let request_json = serde_json::to_string(&request)?;
202        let request_line = format!("{}\n", request_json);
203
204        // Create response channel and register pending request
205        let (tx, rx) = oneshot::channel();
206        self.pending.insert(id, tx);
207
208        // Send write request
209        if self
210            .write_tx
211            .send(WriteRequest { request_line })
212            .await
213            .is_err()
214        {
215            self.pending.remove(&id);
216            return Err(McpTransportError::ConnectionClosed);
217        }
218
219        // Wait for response with timeout
220        match tokio::time::timeout(timeout_duration, rx).await {
221            Ok(Ok(result)) => result,
222            Ok(Err(_)) => {
223                self.pending.remove(&id);
224                Err(McpTransportError::ConnectionClosed)
225            }
226            Err(_) => {
227                self.pending.remove(&id);
228                Err(McpTransportError::Timeout(format!(
229                    "Request timed out after {:?}",
230                    timeout_duration
231                )))
232            }
233        }
234    }
235
236    /// Check if the transport is alive.
237    pub fn is_alive(&self) -> bool {
238        self.alive.load(Ordering::SeqCst)
239    }
240
241    /// Stop the transport and kill the process.
242    pub async fn stop(&self) -> Result<(), McpTransportError> {
243        self.alive.store(false, Ordering::SeqCst);
244
245        // Kill the child process
246        let mut child = self.child.lock().await;
247        if let Err(e) = child.kill().await {
248            // Process may have already exited
249            if e.kind() != std::io::ErrorKind::InvalidInput {
250                return Err(McpTransportError::TransportError(format!(
251                    "Failed to kill process: {}",
252                    e
253                )));
254            }
255        }
256
257        Ok(())
258    }
259}
260
261// =============================================================================
262// Async Stdio Transport (legacy wrapper, now uses TokioStdioTransport)
263// =============================================================================
264
265/// Async-friendly stdio transport with timeout support.
266///
267/// This is now a wrapper around `TokioStdioTransport` for backwards compatibility.
268pub struct AsyncStdioTransport {
269    inner: Arc<TokioStdioTransport>,
270}
271
272impl AsyncStdioTransport {
273    /// Spawn a new MCP server process.
274    pub async fn spawn(command: &str, args: &[String]) -> Result<Self, McpTransportError> {
275        Ok(Self {
276            inner: Arc::new(TokioStdioTransport::spawn(command, args).await?),
277        })
278    }
279
280    /// Spawn with environment variables.
281    pub async fn spawn_with_env(
282        command: &str,
283        args: &[String],
284        env: HashMap<String, String>,
285    ) -> Result<Self, McpTransportError> {
286        Ok(Self {
287            inner: Arc::new(TokioStdioTransport::spawn_with_env(command, args, env).await?),
288        })
289    }
290
291    /// Send a request with a timeout.
292    pub async fn send_request_with_timeout(
293        &self,
294        method: &str,
295        params: Option<Value>,
296        timeout_duration: Duration,
297    ) -> Result<Value, McpTransportError> {
298        self.inner
299            .send_request(method, params, timeout_duration)
300            .await
301    }
302
303    /// Check if alive.
304    pub fn is_alive(&self) -> bool {
305        self.inner.is_alive()
306    }
307
308    /// Stop the transport.
309    pub async fn stop(&self) -> Result<(), McpTransportError> {
310        self.inner.stop().await
311    }
312}
313
314/// Adapter that wraps AsyncStdioTransport and implements McpTransport.
315pub struct StdioTransportAdapter {
316    inner: AsyncStdioTransport,
317    timeout: Duration,
318}
319
320impl StdioTransportAdapter {
321    /// Create and initialize a new stdio transport.
322    pub async fn connect(
323        command: &str,
324        args: &[String],
325        config: Option<Value>,
326        timeout: Duration,
327    ) -> Result<Self, McpTransportError> {
328        Self::connect_with_env(command, args, HashMap::new(), config, timeout).await
329    }
330
331    /// Create and initialize with environment variables.
332    pub async fn connect_with_env(
333        command: &str,
334        args: &[String],
335        env: HashMap<String, String>,
336        config: Option<Value>,
337        timeout: Duration,
338    ) -> Result<Self, McpTransportError> {
339        let inner = AsyncStdioTransport::spawn_with_env(command, args, env).await?;
340
341        let adapter = Self { inner, timeout };
342
343        // Send initialize request
344        let init_params = InitializeParams::new(config);
345        let _init_result = adapter
346            .inner
347            .send_request_with_timeout(
348                "initialize",
349                Some(serde_json::to_value(&init_params)?),
350                adapter.timeout,
351            )
352            .await?;
353
354        // Send initialized notification
355        let _ = adapter
356            .inner
357            .send_request_with_timeout(
358                "notifications/initialized",
359                Some(serde_json::json!({})),
360                adapter.timeout,
361            )
362            .await;
363
364        Ok(adapter)
365    }
366}
367
368#[async_trait]
369impl McpTransport for StdioTransportAdapter {
370    async fn list_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError> {
371        let result = self
372            .inner
373            .send_request_with_timeout("tools/list", Some(serde_json::json!({})), self.timeout)
374            .await?;
375
376        let list_result: ListToolsResult = serde_json::from_value(result)?;
377        Ok(list_result.tools)
378    }
379
380    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
381        let params = CallToolParams {
382            name: name.to_string(),
383            arguments: Some(args),
384            task: None,
385            meta: None,
386        };
387
388        let result = self
389            .inner
390            .send_request_with_timeout(
391                "tools/call",
392                Some(serde_json::to_value(&params)?),
393                self.timeout,
394            )
395            .await?;
396
397        let call_result: CallToolResult = serde_json::from_value(result)?;
398
399        if call_result.is_error == Some(true) {
400            let error_text = call_result
401                .content
402                .first()
403                .and_then(|c| c.as_text())
404                .unwrap_or("Unknown error");
405            return Err(McpTransportError::ServerError(error_text.to_string()));
406        }
407
408        let text = call_result
409            .content
410            .iter()
411            .filter_map(|c| c.as_text())
412            .collect::<Vec<_>>()
413            .join("\n");
414
415        Ok(Value::String(text))
416    }
417
418    async fn shutdown(&self) -> Result<(), McpTransportError> {
419        self.inner.stop().await
420    }
421
422    fn is_alive(&self) -> bool {
423        self.inner.is_alive()
424    }
425
426    fn transport_type(&self) -> TransportTypeId {
427        TransportTypeId::Stdio
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// The stdio transport identifier renders as the literal "stdio".
    #[test]
    fn test_transport_type() {
        let rendered = TransportTypeId::Stdio.to_string();
        assert_eq!(rendered, "stdio");
    }
}