// model_context_protocol/stdio.rs

//! Stdio transport for MCP servers.
//!
//! Communicates with an MCP server over its standard input/output using
//! newline-delimited JSON-RPC messages. This transport is used for MCP servers
//! that run as child processes of the client.
5
6use async_trait::async_trait;
7use serde_json::Value;
8use std::io::{BufRead, BufReader, Write};
9use std::process::{Child, Command, Stdio};
10use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
11use std::sync::{Arc, Mutex};
12use std::time::Duration;
13use tokio::sync::oneshot;
14use tokio::time::timeout;
15
16use crate::protocol::*;
17use crate::transport::{InitializeParams, McpTransport, McpTransportError, TransportTypeId};
18
/// MCP transport that exchanges JSON-RPC messages with a child process over
/// the child's stdin/stdout pipes.
pub struct StdioTransport {
    /// Handle to the spawned server process. The mutex serializes every use of
    /// the child's pipes, so only one request/response exchange runs at a time.
    process: Arc<Mutex<Child>>,
    /// Counter that hands out numeric JSON-RPC request ids (first id is 1).
    next_id: Arc<AtomicI64>,
    /// Cached liveness flag; cleared when stdout reaches EOF, when the process
    /// is found to have exited, or when the transport is stopped.
    alive: Arc<AtomicBool>,
}
25
26impl StdioTransport {
27    /// Spawn a new MCP server process.
28    pub fn spawn(command: &str, args: &[String]) -> Result<Self, McpTransportError> {
29        Self::spawn_with_env(command, args, std::collections::HashMap::new())
30    }
31
32    /// Spawn a new MCP server process with environment variables.
33    pub fn spawn_with_env(
34        command: &str,
35        args: &[String],
36        env: std::collections::HashMap<String, String>,
37    ) -> Result<Self, McpTransportError> {
38        let mut cmd = Command::new(command);
39        cmd.args(args)
40            .stdin(Stdio::piped())
41            .stdout(Stdio::piped())
42            .stderr(Stdio::piped());
43
44        for (key, value) in env {
45            cmd.env(key, value);
46        }
47
48        let child = cmd.spawn().map_err(|e| {
49            McpTransportError::TransportError(format!(
50                "Failed to spawn process '{}': {}",
51                command, e
52            ))
53        })?;
54
55        // Verify process is running
56        let mut process = child;
57        if let Some(status) = process.try_wait().map_err(|e| {
58            McpTransportError::TransportError(format!("Process check failed: {}", e))
59        })? {
60            return Err(McpTransportError::TransportError(format!(
61                "Process exited immediately with status: {}",
62                status
63            )));
64        }
65
66        Ok(Self {
67            process: Arc::new(Mutex::new(process)),
68            next_id: Arc::new(AtomicI64::new(1)),
69            alive: Arc::new(AtomicBool::new(true)),
70        })
71    }
72
73    /// Send a JSON-RPC request and wait for response (blocking).
74    pub fn send_request_sync(
75        &self,
76        method: &str,
77        params: Option<Value>,
78    ) -> Result<Value, McpTransportError> {
79        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
80        let request = JsonRpcRequest::new(JsonRpcId::Number(id), method, params);
81
82        let mut process = self
83            .process
84            .lock()
85            .map_err(|e| McpTransportError::TransportError(format!("Lock error: {}", e)))?;
86
87        // Get stdin
88        let stdin = process
89            .stdin
90            .as_mut()
91            .ok_or_else(|| McpTransportError::TransportError("Failed to get stdin".to_string()))?;
92
93        // Serialize and send request
94        let request_json = serde_json::to_string(&request)?;
95
96        writeln!(stdin, "{}", request_json).map_err(|e| McpTransportError::IoError(e))?;
97
98        stdin.flush().map_err(|e| McpTransportError::IoError(e))?;
99
100        // Read response from stdout
101        let stdout = process
102            .stdout
103            .as_mut()
104            .ok_or_else(|| McpTransportError::TransportError("Failed to get stdout".to_string()))?;
105
106        let mut reader = BufReader::new(stdout);
107        let mut response_line = String::new();
108
109        reader
110            .read_line(&mut response_line)
111            .map_err(|e| McpTransportError::IoError(e))?;
112
113        if response_line.is_empty() {
114            self.alive.store(false, Ordering::SeqCst);
115            return Err(McpTransportError::ConnectionClosed);
116        }
117
118        // Parse response
119        let response: JsonRpcResponse = serde_json::from_str(&response_line)?;
120
121        // Extract result or error
122        match response.payload {
123            JsonRpcPayload::Success { result } => Ok(result),
124            JsonRpcPayload::Error { error } => Err(McpTransportError::ServerError(format!(
125                "MCP Error: {}",
126                error
127            ))),
128        }
129    }
130
131    /// Check if the process is still running.
132    pub fn is_alive(&self) -> bool {
133        if !self.alive.load(Ordering::SeqCst) {
134            return false;
135        }
136
137        if let Ok(mut process) = self.process.lock() {
138            let alive = process.try_wait().ok().flatten().is_none();
139            self.alive.store(alive, Ordering::SeqCst);
140            alive
141        } else {
142            false
143        }
144    }
145
146    /// Stop the process.
147    pub fn stop(&self) -> Result<(), McpTransportError> {
148        self.alive.store(false, Ordering::SeqCst);
149
150        let mut process = self
151            .process
152            .lock()
153            .map_err(|e| McpTransportError::TransportError(format!("Lock error: {}", e)))?;
154
155        process.kill().map_err(|e| McpTransportError::IoError(e))?;
156
157        process.wait().map_err(|e| McpTransportError::IoError(e))?;
158
159        Ok(())
160    }
161}
162
163impl Drop for StdioTransport {
164    fn drop(&mut self) {
165        let _ = self.stop();
166    }
167}
168
169/// Async-friendly stdio transport with timeout support.
170pub struct AsyncStdioTransport {
171    inner: StdioTransport,
172}
173
174impl AsyncStdioTransport {
175    /// Spawn a new MCP server process.
176    pub fn spawn(command: &str, args: &[String]) -> Result<Self, McpTransportError> {
177        Ok(Self {
178            inner: StdioTransport::spawn(command, args)?,
179        })
180    }
181
182    /// Spawn with environment variables.
183    pub fn spawn_with_env(
184        command: &str,
185        args: &[String],
186        env: std::collections::HashMap<String, String>,
187    ) -> Result<Self, McpTransportError> {
188        Ok(Self {
189            inner: StdioTransport::spawn_with_env(command, args, env)?,
190        })
191    }
192
193    /// Send a request with a timeout.
194    pub async fn send_request_with_timeout(
195        &self,
196        method: &str,
197        params: Option<Value>,
198        timeout_duration: Duration,
199    ) -> Result<Value, McpTransportError> {
200        let method = method.to_string();
201        let process = Arc::clone(&self.inner.process);
202        let next_id = Arc::clone(&self.inner.next_id);
203        let alive = Arc::clone(&self.inner.alive);
204
205        let (tx, rx) = oneshot::channel();
206
207        // Spawn blocking task
208        tokio::task::spawn_blocking(move || {
209            let id = next_id.fetch_add(1, Ordering::SeqCst);
210            let request = JsonRpcRequest::new(JsonRpcId::Number(id), method, params);
211
212            let result: Result<Value, McpTransportError> = (|| {
213                let mut process = process
214                    .lock()
215                    .map_err(|e| McpTransportError::TransportError(format!("Lock error: {}", e)))?;
216
217                let stdin = process.stdin.as_mut().ok_or_else(|| {
218                    McpTransportError::TransportError("Failed to get stdin".to_string())
219                })?;
220
221                let request_json = serde_json::to_string(&request)?;
222
223                writeln!(stdin, "{}", request_json).map_err(|e| McpTransportError::IoError(e))?;
224
225                stdin.flush().map_err(|e| McpTransportError::IoError(e))?;
226
227                let stdout = process.stdout.as_mut().ok_or_else(|| {
228                    McpTransportError::TransportError("Failed to get stdout".to_string())
229                })?;
230
231                let mut reader = BufReader::new(stdout);
232                let mut response_line = String::new();
233
234                reader
235                    .read_line(&mut response_line)
236                    .map_err(|e| McpTransportError::IoError(e))?;
237
238                if response_line.is_empty() {
239                    alive.store(false, Ordering::SeqCst);
240                    return Err(McpTransportError::ConnectionClosed);
241                }
242
243                let response: JsonRpcResponse = serde_json::from_str(&response_line)?;
244
245                match response.payload {
246                    JsonRpcPayload::Success { result } => Ok(result),
247                    JsonRpcPayload::Error { error } => Err(McpTransportError::ServerError(
248                        format!("MCP Error: {}", error),
249                    )),
250                }
251            })();
252
253            let _ = tx.send(result);
254        });
255
256        // Wait with timeout
257        match timeout(timeout_duration, rx).await {
258            Ok(Ok(result)) => result,
259            Ok(Err(_)) => Err(McpTransportError::TransportError(
260                "Channel closed".to_string(),
261            )),
262            Err(_) => Err(McpTransportError::Timeout(format!(
263                "Request timed out after {:?}",
264                timeout_duration
265            ))),
266        }
267    }
268
269    /// Check if alive.
270    pub fn is_alive(&self) -> bool {
271        self.inner.is_alive()
272    }
273
274    /// Stop the transport.
275    pub fn stop(&self) -> Result<(), McpTransportError> {
276        self.inner.stop()
277    }
278}
279
280/// Adapter that wraps AsyncStdioTransport and implements McpTransport.
281pub struct StdioTransportAdapter {
282    inner: AsyncStdioTransport,
283    timeout: Duration,
284}
285
286impl StdioTransportAdapter {
287    /// Create and initialize a new stdio transport.
288    pub async fn connect(
289        command: &str,
290        args: &[String],
291        config: Option<Value>,
292        timeout: Duration,
293    ) -> Result<Self, McpTransportError> {
294        Self::connect_with_env(
295            command,
296            args,
297            std::collections::HashMap::new(),
298            config,
299            timeout,
300        )
301        .await
302    }
303
304    /// Create and initialize with environment variables.
305    pub async fn connect_with_env(
306        command: &str,
307        args: &[String],
308        env: std::collections::HashMap<String, String>,
309        config: Option<Value>,
310        timeout: Duration,
311    ) -> Result<Self, McpTransportError> {
312        let inner = AsyncStdioTransport::spawn_with_env(command, args, env)?;
313
314        let adapter = Self { inner, timeout };
315
316        // Send initialize request
317        let init_params = InitializeParams::new(config);
318        let _init_result = adapter
319            .inner
320            .send_request_with_timeout(
321                "initialize",
322                Some(serde_json::to_value(&init_params)?),
323                adapter.timeout,
324            )
325            .await?;
326
327        // Send initialized notification (no response expected, but we send it)
328        // Some servers expect this
329        let _ = adapter
330            .inner
331            .send_request_with_timeout(
332                "notifications/initialized",
333                Some(serde_json::json!({})),
334                adapter.timeout,
335            )
336            .await;
337
338        Ok(adapter)
339    }
340}
341
342#[async_trait]
343impl McpTransport for StdioTransportAdapter {
344    async fn list_tools(&self) -> Result<Vec<ToolDefinition>, McpTransportError> {
345        let result = self
346            .inner
347            .send_request_with_timeout("tools/list", Some(serde_json::json!({})), self.timeout)
348            .await?;
349
350        let list_result: ListToolsResult = serde_json::from_value(result)?;
351
352        Ok(list_result
353            .tools
354            .into_iter()
355            .map(ToolDefinition::from)
356            .collect())
357    }
358
359    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
360        let params = CallToolParams {
361            name: name.to_string(),
362            arguments: Some(args),
363        };
364
365        let result = self
366            .inner
367            .send_request_with_timeout(
368                "tools/call",
369                Some(serde_json::to_value(&params)?),
370                self.timeout,
371            )
372            .await?;
373
374        let call_result: CallToolResult = serde_json::from_value(result)?;
375
376        if call_result.is_error == Some(true) {
377            let error_text = call_result
378                .content
379                .first()
380                .and_then(|c| c.as_text())
381                .unwrap_or("Unknown error");
382            return Err(McpTransportError::ServerError(error_text.to_string()));
383        }
384
385        let text = call_result
386            .content
387            .iter()
388            .filter_map(|c| c.as_text())
389            .collect::<Vec<_>>()
390            .join("\n");
391
392        Ok(Value::String(text))
393    }
394
395    async fn shutdown(&self) -> Result<(), McpTransportError> {
396        self.inner.stop()
397    }
398
399    fn is_alive(&self) -> bool {
400        self.inner.is_alive()
401    }
402
403    fn transport_type(&self) -> TransportTypeId {
404        TransportTypeId::Stdio
405    }
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    // Spawning a transport needs a real MCP server executable, so only the
    // process-independent parts are covered here. Use a mock server (or skip
    // in CI) for end-to-end tests.

    #[test]
    fn test_transport_type() {
        // The stdio transport id must render as the string "stdio".
        let rendered = TransportTypeId::Stdio.to_string();
        assert_eq!(rendered, "stdio");
    }
}
421}