mcp_core/transport/client/
stdio.rs

1use crate::protocol::{Protocol, ProtocolBuilder, RequestOptions};
2use crate::transport::{
3    JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Message, RequestId,
4    Transport,
5};
6use crate::types::ErrorCode;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::future::Future;
10use std::io::{BufRead, BufReader, BufWriter, Write};
11use std::pin::Pin;
12use std::process::Command;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15use tokio::time::timeout;
16use tracing::debug;
17
18/// ClientStdioTransport launches a child process and communicates with it via stdio.
19#[derive(Clone)]
20pub struct ClientStdioTransport {
21    protocol: Protocol,
22    stdin: Arc<Mutex<Option<BufWriter<std::process::ChildStdin>>>>,
23    stdout: Arc<Mutex<Option<BufReader<std::process::ChildStdout>>>>,
24    child: Arc<Mutex<Option<std::process::Child>>>,
25    program: String,
26    args: Vec<String>,
27}
28
29impl ClientStdioTransport {
30    pub fn new(program: &str, args: &[&str]) -> Result<Self> {
31        Ok(ClientStdioTransport {
32            protocol: ProtocolBuilder::new().build(),
33            stdin: Arc::new(Mutex::new(None)),
34            stdout: Arc::new(Mutex::new(None)),
35            child: Arc::new(Mutex::new(None)),
36            program: program.to_string(),
37            args: args.iter().map(|&s| s.to_string()).collect(),
38        })
39    }
40}
41
42#[async_trait()]
43impl Transport for ClientStdioTransport {
44    async fn open(&self) -> Result<()> {
45        debug!("ClientStdioTransport: Opening transport");
46        let mut child = Command::new(&self.program)
47            .args(&self.args)
48            .stdin(std::process::Stdio::piped())
49            .stdout(std::process::Stdio::piped())
50            .spawn()?;
51
52        let stdin = child
53            .stdin
54            .take()
55            .ok_or_else(|| anyhow::anyhow!("Child process stdin not available"))?;
56        let stdout = child
57            .stdout
58            .take()
59            .ok_or_else(|| anyhow::anyhow!("Child process stdout not available"))?;
60
61        {
62            let mut stdin_lock = self.stdin.lock().await;
63            *stdin_lock = Some(BufWriter::new(stdin));
64        }
65        {
66            let mut stdout_lock = self.stdout.lock().await;
67            *stdout_lock = Some(BufReader::new(stdout));
68        }
69        {
70            let mut child_lock = self.child.lock().await;
71            *child_lock = Some(child);
72        }
73
74        // Spawn a background task to continuously poll messages.
75        let transport_clone = self.clone();
76        tokio::spawn(async move {
77            loop {
78                match transport_clone.poll_message().await {
79                    Ok(Some(message)) => match message {
80                        Message::Request(request) => {
81                            let response = transport_clone.protocol.handle_request(request).await;
82                            let _ = transport_clone
83                                .send_response(response.id, response.result, response.error)
84                                .await;
85                        }
86                        Message::Notification(notification) => {
87                            let _ = transport_clone
88                                .protocol
89                                .handle_notification(notification)
90                                .await;
91                        }
92                        Message::Response(response) => {
93                            transport_clone.protocol.handle_response(response).await;
94                        }
95                    },
96                    Ok(None) => break, // EOF encountered.
97                    Err(e) => {
98                        debug!("ClientStdioTransport: Error polling message: {:?}", e);
99                        break;
100                    }
101                }
102            }
103        });
104        Ok(())
105    }
106
107    async fn close(&self) -> Result<()> {
108        let mut child_lock = self.child.lock().await;
109        if let Some(child) = child_lock.as_mut() {
110            let _ = child.kill();
111        }
112        *child_lock = None;
113
114        // Clear stdin and stdout
115        *self.stdin.lock().await = None;
116        *self.stdout.lock().await = None;
117
118        Ok(())
119    }
120
121    async fn poll_message(&self) -> Result<Option<Message>> {
122        debug!("ClientStdioTransport: Starting to receive message");
123
124        // Take ownership of stdout temporarily
125        let mut stdout_guard = self.stdout.lock().await;
126        let mut stdout = stdout_guard
127            .take()
128            .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
129
130        // Drop the lock before spawning the blocking task
131        drop(stdout_guard);
132
133        // Use a blocking operation in a spawn_blocking task
134        let (line_result, stdout) = tokio::task::spawn_blocking(move || {
135            let mut line = String::new();
136            let result = match stdout.read_line(&mut line) {
137                Ok(0) => Ok(None), // EOF
138                Ok(_) => Ok(Some(line)),
139                Err(e) => Err(anyhow::anyhow!("Error reading line: {}", e)),
140            };
141            // Return both the result and the stdout so we can put it back
142            (result, stdout)
143        })
144        .await?;
145
146        // Put stdout back
147        let mut stdout_guard = self.stdout.lock().await;
148        *stdout_guard = Some(stdout);
149
150        // Process the result
151        match line_result? {
152            Some(line) => {
153                debug!(
154                    "ClientStdioTransport: Received from process: {}",
155                    line.trim()
156                );
157                let message: Message = serde_json::from_str(&line)?;
158                debug!("ClientStdioTransport: Successfully parsed message");
159                Ok(Some(message))
160            }
161            None => {
162                debug!("ClientStdioTransport: Received EOF from process");
163                Ok(None)
164            }
165        }
166    }
167
168    fn request(
169        &self,
170        method: &str,
171        params: Option<serde_json::Value>,
172        options: RequestOptions,
173    ) -> Pin<Box<dyn Future<Output = Result<JsonRpcResponse>> + Send + Sync>> {
174        let protocol = self.protocol.clone();
175        let stdin_arc = self.stdin.clone();
176        let method = method.to_owned();
177        Box::pin(async move {
178            let (id, rx) = protocol.create_request().await;
179            let request = JsonRpcRequest {
180                id,
181                method,
182                jsonrpc: Default::default(),
183                params,
184            };
185            let serialized = serde_json::to_string(&request)?;
186            debug!("ClientStdioTransport: Sending request: {}", serialized);
187
188            // Get the stdin writer
189            let mut stdin_guard = stdin_arc.lock().await;
190            let mut stdin = stdin_guard
191                .take()
192                .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
193
194            // Use a blocking operation in a spawn_blocking task
195            let stdin_result = tokio::task::spawn_blocking(move || {
196                stdin.write_all(serialized.as_bytes())?;
197                stdin.write_all(b"\n")?;
198                stdin.flush()?;
199                Ok::<_, anyhow::Error>(stdin)
200            })
201            .await??;
202
203            // Put the writer back
204            *stdin_guard = Some(stdin_result);
205
206            debug!("ClientStdioTransport: Request sent successfully");
207            let result = timeout(options.timeout, rx).await;
208            match result {
209                Ok(inner_result) => match inner_result {
210                    Ok(response) => Ok(response),
211                    Err(_) => {
212                        protocol.cancel_response(id).await;
213                        Ok(JsonRpcResponse {
214                            id,
215                            result: None,
216                            error: Some(JsonRpcError {
217                                code: ErrorCode::RequestTimeout as i32,
218                                message: "Request cancelled".to_string(),
219                                data: None,
220                            }),
221                            ..Default::default()
222                        })
223                    }
224                },
225                Err(_) => {
226                    protocol.cancel_response(id).await;
227                    Ok(JsonRpcResponse {
228                        id,
229                        result: None,
230                        error: Some(JsonRpcError {
231                            code: ErrorCode::RequestTimeout as i32,
232                            message: "Request timed out".to_string(),
233                            data: None,
234                        }),
235                        ..Default::default()
236                    })
237                }
238            }
239        })
240    }
241
242    async fn send_response(
243        &self,
244        id: RequestId,
245        result: Option<serde_json::Value>,
246        error: Option<JsonRpcError>,
247    ) -> Result<()> {
248        let response = JsonRpcResponse {
249            id,
250            result,
251            error,
252            jsonrpc: Default::default(),
253        };
254        let serialized = serde_json::to_string(&response)?;
255        debug!("ClientStdioTransport: Sending response: {}", serialized);
256
257        // Get the stdin writer
258        let mut stdin_guard = self.stdin.lock().await;
259        let mut stdin = stdin_guard
260            .take()
261            .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
262
263        // Use a blocking operation in a spawn_blocking task
264        let stdin_result = tokio::task::spawn_blocking(move || {
265            stdin.write_all(serialized.as_bytes())?;
266            stdin.write_all(b"\n")?;
267            stdin.flush()?;
268            Ok::<_, anyhow::Error>(stdin)
269        })
270        .await??;
271
272        // Put the writer back
273        *stdin_guard = Some(stdin_result);
274
275        Ok(())
276    }
277
278    async fn send_notification(
279        &self,
280        method: &str,
281        params: Option<serde_json::Value>,
282    ) -> Result<()> {
283        let notification = JsonRpcNotification {
284            jsonrpc: Default::default(),
285            method: method.to_owned(),
286            params,
287        };
288        let serialized = serde_json::to_string(&notification)?;
289        debug!("ClientStdioTransport: Sending notification: {}", serialized);
290
291        // Get the stdin writer
292        let mut stdin_guard = self.stdin.lock().await;
293        let mut stdin = stdin_guard
294            .take()
295            .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
296
297        // Use a blocking operation in a spawn_blocking task
298        let stdin_result = tokio::task::spawn_blocking(move || {
299            stdin.write_all(serialized.as_bytes())?;
300            stdin.write_all(b"\n")?;
301            stdin.flush()?;
302            Ok::<_, anyhow::Error>(stdin)
303        })
304        .await??;
305
306        // Put the writer back
307        *stdin_guard = Some(stdin_result);
308
309        Ok(())
310    }
311}