// mcp_client_fishcode2025/transport/stdio.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
4
5use async_trait::async_trait;
6use mcp_core_fishcode2025::protocol::JsonRpcMessage;
7use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
8use tokio::sync::{mpsc, Mutex};
9
10use super::{send_message, Error, PendingRequests, Transport, TransportHandle, TransportMessage};
11
12/// A `StdioTransport` uses a child process's stdin/stdout as a communication channel.
13///
14/// It uses channels for message passing and handles responses asynchronously through a background task.
15pub struct StdioActor {
16    receiver: mpsc::Receiver<TransportMessage>,
17    pending_requests: Arc<PendingRequests>,
18    _process: Child, // we store the process to keep it alive
19    error_sender: mpsc::Sender<Error>,
20    stdin: ChildStdin,
21    stdout: ChildStdout,
22    stderr: ChildStderr,
23}
24
25impl StdioActor {
26    pub async fn run(mut self) {
27        use tokio::pin;
28
29        let incoming = Self::handle_incoming_messages(self.stdout, self.pending_requests.clone());
30        let outgoing = Self::handle_outgoing_messages(
31            self.receiver,
32            self.stdin,
33            self.pending_requests.clone(),
34        );
35
36        // take ownership of futures for tokio::select
37        pin!(incoming);
38        pin!(outgoing);
39
40        // Use select! to wait for either I/O completion or process exit
41        tokio::select! {
42            result = &mut incoming => {
43                tracing::debug!("Stdin handler completed: {:?}", result);
44            }
45            result = &mut outgoing => {
46                tracing::debug!("Stdout handler completed: {:?}", result);
47            }
48            // capture the status so we don't need to wait for a timeout
49            status = self._process.wait() => {
50                tracing::debug!("Process exited with status: {:?}", status);
51            }
52        }
53
54        // Then always try to read stderr before cleaning up
55        let mut stderr_buffer = Vec::new();
56        if let Ok(bytes) = self.stderr.read_to_end(&mut stderr_buffer).await {
57            let err_msg = if bytes > 0 {
58                String::from_utf8_lossy(&stderr_buffer).to_string()
59            } else {
60                "Process ended unexpectedly".to_string()
61            };
62
63            tracing::info!("Process stderr: {}", err_msg);
64            let _ = self
65                .error_sender
66                .send(Error::StdioProcessError(err_msg))
67                .await;
68        }
69
70        // Clean up regardless of which path we took
71        self.pending_requests.clear().await;
72    }
73
74    async fn handle_incoming_messages(stdout: ChildStdout, pending_requests: Arc<PendingRequests>) {
75        let mut reader = BufReader::new(stdout);
76        let mut line = String::new();
77        loop {
78            match reader.read_line(&mut line).await {
79                Ok(0) => {
80                    tracing::error!("Child process ended (EOF on stdout)");
81                    break;
82                } // EOF
83                Ok(_) => {
84                    if let Ok(message) = serde_json::from_str::<JsonRpcMessage>(&line) {
85                        tracing::debug!(
86                            message = ?message,
87                            "Received incoming message"
88                        );
89
90                        match &message {
91                            JsonRpcMessage::Response(response) => {
92                                if let Some(id) = &response.id {
93                                    pending_requests.respond(&id.to_string(), Ok(message)).await;
94                                }
95                            }
96                            JsonRpcMessage::Error(error) => {
97                                if let Some(id) = &error.id {
98                                    pending_requests.respond(&id.to_string(), Ok(message)).await;
99                                }
100                            }
101                            _ => {} // TODO: Handle other variants (Request, etc.)
102                        }
103                    }
104                    line.clear();
105                }
106                Err(e) => {
107                    tracing::error!(error = ?e, "Error reading line");
108                    break;
109                }
110            }
111        }
112    }
113
114    async fn handle_outgoing_messages(
115        mut receiver: mpsc::Receiver<TransportMessage>,
116        mut stdin: ChildStdin,
117        pending_requests: Arc<PendingRequests>,
118    ) {
119        while let Some(mut transport_msg) = receiver.recv().await {
120            let message_str = match serde_json::to_string(&transport_msg.message) {
121                Ok(s) => s,
122                Err(e) => {
123                    if let Some(tx) = transport_msg.response_tx.take() {
124                        let _ = tx.send(Err(Error::Serialization(e)));
125                    }
126                    continue;
127                }
128            };
129
130            tracing::debug!(message = ?transport_msg.message, "Sending outgoing message");
131
132            if let Some(response_tx) = transport_msg.response_tx.take() {
133                if let JsonRpcMessage::Request(request) = &transport_msg.message {
134                    if let Some(id) = &request.id {
135                        pending_requests.insert(id.to_string(), response_tx).await;
136                    }
137                }
138            }
139
140            if let Err(e) = stdin
141                .write_all(format!("{}\n", message_str).as_bytes())
142                .await
143            {
144                tracing::error!(error = ?e, "Error writing message to child process");
145                pending_requests.clear().await;
146                break;
147            }
148
149            if let Err(e) = stdin.flush().await {
150                tracing::error!(error = ?e, "Error flushing message to child process");
151                pending_requests.clear().await;
152                break;
153            }
154        }
155    }
156}
157
158#[derive(Clone)]
159pub struct StdioTransportHandle {
160    sender: mpsc::Sender<TransportMessage>,
161    error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
162}
163
164#[async_trait::async_trait]
165impl TransportHandle for StdioTransportHandle {
166    async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
167        let result = send_message(&self.sender, message).await;
168        // Check for any pending errors even if send is successful
169        self.check_for_errors().await?;
170        result
171    }
172}
173
174impl StdioTransportHandle {
175    /// Check if there are any process errors
176    pub async fn check_for_errors(&self) -> Result<(), Error> {
177        match self.error_receiver.lock().await.try_recv() {
178            Ok(error) => {
179                tracing::debug!("Found error: {:?}", error);
180                Err(error)
181            }
182            Err(_) => Ok(()),
183        }
184    }
185}
186
/// Configuration for launching a server as a child process and talking to it over stdio.
pub struct StdioTransport {
    /// Program to execute.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Extra environment variables for the child, added on top of the inherited environment.
    env: HashMap<String, String>,
}
192
193impl StdioTransport {
194    pub fn new<S: Into<String>>(
195        command: S,
196        args: Vec<String>,
197        env: HashMap<String, String>,
198    ) -> Self {
199        Self {
200            command: command.into(),
201            args,
202            env,
203        }
204    }
205
206    async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr), Error> {
207        let mut command = Command::new(&self.command);
208        command
209            .envs(&self.env)
210            .args(&self.args)
211            .stdin(std::process::Stdio::piped())
212            .stdout(std::process::Stdio::piped())
213            .stderr(std::process::Stdio::piped())
214            .kill_on_drop(true);
215
216        // Set process group only on Unix systems
217        #[cfg(unix)]
218        command.process_group(0); // don't inherit signal handling from parent process
219
220        // Hide console window on Windows
221        #[cfg(windows)]
222        command.creation_flags(0x08000000); // CREATE_NO_WINDOW flag
223
224        let mut process = command
225            .spawn()
226            .map_err(|e| Error::StdioProcessError(e.to_string()))?;
227
228        let stdin = process
229            .stdin
230            .take()
231            .ok_or_else(|| Error::StdioProcessError("Failed to get stdin".into()))?;
232
233        let stdout = process
234            .stdout
235            .take()
236            .ok_or_else(|| Error::StdioProcessError("Failed to get stdout".into()))?;
237
238        let stderr = process
239            .stderr
240            .take()
241            .ok_or_else(|| Error::StdioProcessError("Failed to get stderr".into()))?;
242
243        Ok((process, stdin, stdout, stderr))
244    }
245}
246
247#[async_trait]
248impl Transport for StdioTransport {
249    type Handle = StdioTransportHandle;
250
251    async fn start(&self) -> Result<Self::Handle, Error> {
252        let (process, stdin, stdout, stderr) = self.spawn_process().await?;
253        let (message_tx, message_rx) = mpsc::channel(32);
254        let (error_tx, error_rx) = mpsc::channel(1);
255
256        let actor = StdioActor {
257            receiver: message_rx,
258            pending_requests: Arc::new(PendingRequests::new()),
259            _process: process,
260            error_sender: error_tx,
261            stdin,
262            stdout,
263            stderr,
264        };
265
266        tokio::spawn(actor.run());
267
268        let handle = StdioTransportHandle {
269            sender: message_tx,
270            error_receiver: Arc::new(Mutex::new(error_rx)),
271        };
272        Ok(handle)
273    }
274
275    async fn close(&self) -> Result<(), Error> {
276        Ok(())
277    }
278}