mcp_client_fishcode2025/transport/
stdio.rs1use std::collections::HashMap;
2use std::sync::Arc;
3use tokio::process::{Child, ChildStderr, ChildStdin, ChildStdout, Command};
4
5use async_trait::async_trait;
6use mcp_core_fishcode2025::protocol::JsonRpcMessage;
7use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
8use tokio::sync::{mpsc, Mutex};
9
10use super::{send_message, Error, PendingRequests, Transport, TransportHandle, TransportMessage};
11
12pub struct StdioActor {
16 receiver: mpsc::Receiver<TransportMessage>,
17 pending_requests: Arc<PendingRequests>,
18 _process: Child, error_sender: mpsc::Sender<Error>,
20 stdin: ChildStdin,
21 stdout: ChildStdout,
22 stderr: ChildStderr,
23}
24
25impl StdioActor {
26 pub async fn run(mut self) {
27 use tokio::pin;
28
29 let incoming = Self::handle_incoming_messages(self.stdout, self.pending_requests.clone());
30 let outgoing = Self::handle_outgoing_messages(
31 self.receiver,
32 self.stdin,
33 self.pending_requests.clone(),
34 );
35
36 pin!(incoming);
38 pin!(outgoing);
39
40 tokio::select! {
42 result = &mut incoming => {
43 tracing::debug!("Stdin handler completed: {:?}", result);
44 }
45 result = &mut outgoing => {
46 tracing::debug!("Stdout handler completed: {:?}", result);
47 }
48 status = self._process.wait() => {
50 tracing::debug!("Process exited with status: {:?}", status);
51 }
52 }
53
54 let mut stderr_buffer = Vec::new();
56 if let Ok(bytes) = self.stderr.read_to_end(&mut stderr_buffer).await {
57 let err_msg = if bytes > 0 {
58 String::from_utf8_lossy(&stderr_buffer).to_string()
59 } else {
60 "Process ended unexpectedly".to_string()
61 };
62
63 tracing::info!("Process stderr: {}", err_msg);
64 let _ = self
65 .error_sender
66 .send(Error::StdioProcessError(err_msg))
67 .await;
68 }
69
70 self.pending_requests.clear().await;
72 }
73
74 async fn handle_incoming_messages(stdout: ChildStdout, pending_requests: Arc<PendingRequests>) {
75 let mut reader = BufReader::new(stdout);
76 let mut line = String::new();
77 loop {
78 match reader.read_line(&mut line).await {
79 Ok(0) => {
80 tracing::error!("Child process ended (EOF on stdout)");
81 break;
82 } Ok(_) => {
84 if let Ok(message) = serde_json::from_str::<JsonRpcMessage>(&line) {
85 tracing::debug!(
86 message = ?message,
87 "Received incoming message"
88 );
89
90 match &message {
91 JsonRpcMessage::Response(response) => {
92 if let Some(id) = &response.id {
93 pending_requests.respond(&id.to_string(), Ok(message)).await;
94 }
95 }
96 JsonRpcMessage::Error(error) => {
97 if let Some(id) = &error.id {
98 pending_requests.respond(&id.to_string(), Ok(message)).await;
99 }
100 }
101 _ => {} }
103 }
104 line.clear();
105 }
106 Err(e) => {
107 tracing::error!(error = ?e, "Error reading line");
108 break;
109 }
110 }
111 }
112 }
113
114 async fn handle_outgoing_messages(
115 mut receiver: mpsc::Receiver<TransportMessage>,
116 mut stdin: ChildStdin,
117 pending_requests: Arc<PendingRequests>,
118 ) {
119 while let Some(mut transport_msg) = receiver.recv().await {
120 let message_str = match serde_json::to_string(&transport_msg.message) {
121 Ok(s) => s,
122 Err(e) => {
123 if let Some(tx) = transport_msg.response_tx.take() {
124 let _ = tx.send(Err(Error::Serialization(e)));
125 }
126 continue;
127 }
128 };
129
130 tracing::debug!(message = ?transport_msg.message, "Sending outgoing message");
131
132 if let Some(response_tx) = transport_msg.response_tx.take() {
133 if let JsonRpcMessage::Request(request) = &transport_msg.message {
134 if let Some(id) = &request.id {
135 pending_requests.insert(id.to_string(), response_tx).await;
136 }
137 }
138 }
139
140 if let Err(e) = stdin
141 .write_all(format!("{}\n", message_str).as_bytes())
142 .await
143 {
144 tracing::error!(error = ?e, "Error writing message to child process");
145 pending_requests.clear().await;
146 break;
147 }
148
149 if let Err(e) = stdin.flush().await {
150 tracing::error!(error = ?e, "Error flushing message to child process");
151 pending_requests.clear().await;
152 break;
153 }
154 }
155 }
156}
157
158#[derive(Clone)]
159pub struct StdioTransportHandle {
160 sender: mpsc::Sender<TransportMessage>,
161 error_receiver: Arc<Mutex<mpsc::Receiver<Error>>>,
162}
163
164#[async_trait::async_trait]
165impl TransportHandle for StdioTransportHandle {
166 async fn send(&self, message: JsonRpcMessage) -> Result<JsonRpcMessage, Error> {
167 let result = send_message(&self.sender, message).await;
168 self.check_for_errors().await?;
170 result
171 }
172}
173
174impl StdioTransportHandle {
175 pub async fn check_for_errors(&self) -> Result<(), Error> {
177 match self.error_receiver.lock().await.try_recv() {
178 Ok(error) => {
179 tracing::debug!("Found error: {:?}", error);
180 Err(error)
181 }
182 Err(_) => Ok(()),
183 }
184 }
185}
186
/// Transport that launches an MCP server as a child process and exchanges
/// newline-delimited JSON-RPC messages over its stdin/stdout.
pub struct StdioTransport {
    /// Program to execute.
    command: String,
    /// Arguments passed to `command`.
    args: Vec<String>,
    /// Extra environment variables for the child, set on top of the inherited environment.
    env: HashMap<String, String>,
}
192
193impl StdioTransport {
194 pub fn new<S: Into<String>>(
195 command: S,
196 args: Vec<String>,
197 env: HashMap<String, String>,
198 ) -> Self {
199 Self {
200 command: command.into(),
201 args,
202 env,
203 }
204 }
205
206 async fn spawn_process(&self) -> Result<(Child, ChildStdin, ChildStdout, ChildStderr), Error> {
207 let mut command = Command::new(&self.command);
208 command
209 .envs(&self.env)
210 .args(&self.args)
211 .stdin(std::process::Stdio::piped())
212 .stdout(std::process::Stdio::piped())
213 .stderr(std::process::Stdio::piped())
214 .kill_on_drop(true);
215
216 #[cfg(unix)]
218 command.process_group(0); #[cfg(windows)]
222 command.creation_flags(0x08000000); let mut process = command
225 .spawn()
226 .map_err(|e| Error::StdioProcessError(e.to_string()))?;
227
228 let stdin = process
229 .stdin
230 .take()
231 .ok_or_else(|| Error::StdioProcessError("Failed to get stdin".into()))?;
232
233 let stdout = process
234 .stdout
235 .take()
236 .ok_or_else(|| Error::StdioProcessError("Failed to get stdout".into()))?;
237
238 let stderr = process
239 .stderr
240 .take()
241 .ok_or_else(|| Error::StdioProcessError("Failed to get stderr".into()))?;
242
243 Ok((process, stdin, stdout, stderr))
244 }
245}
246
247#[async_trait]
248impl Transport for StdioTransport {
249 type Handle = StdioTransportHandle;
250
251 async fn start(&self) -> Result<Self::Handle, Error> {
252 let (process, stdin, stdout, stderr) = self.spawn_process().await?;
253 let (message_tx, message_rx) = mpsc::channel(32);
254 let (error_tx, error_rx) = mpsc::channel(1);
255
256 let actor = StdioActor {
257 receiver: message_rx,
258 pending_requests: Arc::new(PendingRequests::new()),
259 _process: process,
260 error_sender: error_tx,
261 stdin,
262 stdout,
263 stderr,
264 };
265
266 tokio::spawn(actor.run());
267
268 let handle = StdioTransportHandle {
269 sender: message_tx,
270 error_receiver: Arc::new(Mutex::new(error_rx)),
271 };
272 Ok(handle)
273 }
274
275 async fn close(&self) -> Result<(), Error> {
276 Ok(())
277 }
278}