mcp_core/transport/client/
stdio.rs1use crate::protocol::{Protocol, ProtocolBuilder, RequestOptions};
2use crate::transport::{
3 JsonRpcError, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, Message, RequestId,
4 Transport,
5};
6use crate::types::ErrorCode;
7use anyhow::Result;
8use async_trait::async_trait;
9use std::future::Future;
10use std::io::{BufRead, BufReader, BufWriter, Write};
11use std::pin::Pin;
12use std::process::Command;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15use tokio::time::timeout;
16use tracing::debug;
17
18#[derive(Clone)]
20pub struct ClientStdioTransport {
21 protocol: Protocol,
22 stdin: Arc<Mutex<Option<BufWriter<std::process::ChildStdin>>>>,
23 stdout: Arc<Mutex<Option<BufReader<std::process::ChildStdout>>>>,
24 child: Arc<Mutex<Option<std::process::Child>>>,
25 program: String,
26 args: Vec<String>,
27}
28
29impl ClientStdioTransport {
30 pub fn new(program: &str, args: &[&str]) -> Result<Self> {
31 Ok(ClientStdioTransport {
32 protocol: ProtocolBuilder::new().build(),
33 stdin: Arc::new(Mutex::new(None)),
34 stdout: Arc::new(Mutex::new(None)),
35 child: Arc::new(Mutex::new(None)),
36 program: program.to_string(),
37 args: args.iter().map(|&s| s.to_string()).collect(),
38 })
39 }
40}
41
42#[async_trait()]
43impl Transport for ClientStdioTransport {
44 async fn open(&self) -> Result<()> {
45 debug!("ClientStdioTransport: Opening transport");
46 let mut child = Command::new(&self.program)
47 .args(&self.args)
48 .stdin(std::process::Stdio::piped())
49 .stdout(std::process::Stdio::piped())
50 .spawn()?;
51
52 let stdin = child
53 .stdin
54 .take()
55 .ok_or_else(|| anyhow::anyhow!("Child process stdin not available"))?;
56 let stdout = child
57 .stdout
58 .take()
59 .ok_or_else(|| anyhow::anyhow!("Child process stdout not available"))?;
60
61 {
62 let mut stdin_lock = self.stdin.lock().await;
63 *stdin_lock = Some(BufWriter::new(stdin));
64 }
65 {
66 let mut stdout_lock = self.stdout.lock().await;
67 *stdout_lock = Some(BufReader::new(stdout));
68 }
69 {
70 let mut child_lock = self.child.lock().await;
71 *child_lock = Some(child);
72 }
73
74 let transport_clone = self.clone();
76 tokio::spawn(async move {
77 loop {
78 match transport_clone.poll_message().await {
79 Ok(Some(message)) => match message {
80 Message::Request(request) => {
81 let response = transport_clone.protocol.handle_request(request).await;
82 let _ = transport_clone
83 .send_response(response.id, response.result, response.error)
84 .await;
85 }
86 Message::Notification(notification) => {
87 let _ = transport_clone
88 .protocol
89 .handle_notification(notification)
90 .await;
91 }
92 Message::Response(response) => {
93 transport_clone.protocol.handle_response(response).await;
94 }
95 },
96 Ok(None) => break, Err(e) => {
98 debug!("ClientStdioTransport: Error polling message: {:?}", e);
99 break;
100 }
101 }
102 }
103 });
104 Ok(())
105 }
106
107 async fn close(&self) -> Result<()> {
108 let mut child_lock = self.child.lock().await;
109 if let Some(child) = child_lock.as_mut() {
110 let _ = child.kill();
111 }
112 *child_lock = None;
113
114 *self.stdin.lock().await = None;
116 *self.stdout.lock().await = None;
117
118 Ok(())
119 }
120
121 async fn poll_message(&self) -> Result<Option<Message>> {
122 debug!("ClientStdioTransport: Starting to receive message");
123
124 let mut stdout_guard = self.stdout.lock().await;
126 let mut stdout = stdout_guard
127 .take()
128 .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
129
130 drop(stdout_guard);
132
133 let (line_result, stdout) = tokio::task::spawn_blocking(move || {
135 let mut line = String::new();
136 let result = match stdout.read_line(&mut line) {
137 Ok(0) => Ok(None), Ok(_) => Ok(Some(line)),
139 Err(e) => Err(anyhow::anyhow!("Error reading line: {}", e)),
140 };
141 (result, stdout)
143 })
144 .await?;
145
146 let mut stdout_guard = self.stdout.lock().await;
148 *stdout_guard = Some(stdout);
149
150 match line_result? {
152 Some(line) => {
153 debug!(
154 "ClientStdioTransport: Received from process: {}",
155 line.trim()
156 );
157 let message: Message = serde_json::from_str(&line)?;
158 debug!("ClientStdioTransport: Successfully parsed message");
159 Ok(Some(message))
160 }
161 None => {
162 debug!("ClientStdioTransport: Received EOF from process");
163 Ok(None)
164 }
165 }
166 }
167
168 fn request(
169 &self,
170 method: &str,
171 params: Option<serde_json::Value>,
172 options: RequestOptions,
173 ) -> Pin<Box<dyn Future<Output = Result<JsonRpcResponse>> + Send + Sync>> {
174 let protocol = self.protocol.clone();
175 let stdin_arc = self.stdin.clone();
176 let method = method.to_owned();
177 Box::pin(async move {
178 let (id, rx) = protocol.create_request().await;
179 let request = JsonRpcRequest {
180 id,
181 method,
182 jsonrpc: Default::default(),
183 params,
184 };
185 let serialized = serde_json::to_string(&request)?;
186 debug!("ClientStdioTransport: Sending request: {}", serialized);
187
188 let mut stdin_guard = stdin_arc.lock().await;
190 let mut stdin = stdin_guard
191 .take()
192 .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
193
194 let stdin_result = tokio::task::spawn_blocking(move || {
196 stdin.write_all(serialized.as_bytes())?;
197 stdin.write_all(b"\n")?;
198 stdin.flush()?;
199 Ok::<_, anyhow::Error>(stdin)
200 })
201 .await??;
202
203 *stdin_guard = Some(stdin_result);
205
206 debug!("ClientStdioTransport: Request sent successfully");
207 let result = timeout(options.timeout, rx).await;
208 match result {
209 Ok(inner_result) => match inner_result {
210 Ok(response) => Ok(response),
211 Err(_) => {
212 protocol.cancel_response(id).await;
213 Ok(JsonRpcResponse {
214 id,
215 result: None,
216 error: Some(JsonRpcError {
217 code: ErrorCode::RequestTimeout as i32,
218 message: "Request cancelled".to_string(),
219 data: None,
220 }),
221 ..Default::default()
222 })
223 }
224 },
225 Err(_) => {
226 protocol.cancel_response(id).await;
227 Ok(JsonRpcResponse {
228 id,
229 result: None,
230 error: Some(JsonRpcError {
231 code: ErrorCode::RequestTimeout as i32,
232 message: "Request timed out".to_string(),
233 data: None,
234 }),
235 ..Default::default()
236 })
237 }
238 }
239 })
240 }
241
242 async fn send_response(
243 &self,
244 id: RequestId,
245 result: Option<serde_json::Value>,
246 error: Option<JsonRpcError>,
247 ) -> Result<()> {
248 let response = JsonRpcResponse {
249 id,
250 result,
251 error,
252 jsonrpc: Default::default(),
253 };
254 let serialized = serde_json::to_string(&response)?;
255 debug!("ClientStdioTransport: Sending response: {}", serialized);
256
257 let mut stdin_guard = self.stdin.lock().await;
259 let mut stdin = stdin_guard
260 .take()
261 .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
262
263 let stdin_result = tokio::task::spawn_blocking(move || {
265 stdin.write_all(serialized.as_bytes())?;
266 stdin.write_all(b"\n")?;
267 stdin.flush()?;
268 Ok::<_, anyhow::Error>(stdin)
269 })
270 .await??;
271
272 *stdin_guard = Some(stdin_result);
274
275 Ok(())
276 }
277
278 async fn send_notification(
279 &self,
280 method: &str,
281 params: Option<serde_json::Value>,
282 ) -> Result<()> {
283 let notification = JsonRpcNotification {
284 jsonrpc: Default::default(),
285 method: method.to_owned(),
286 params,
287 };
288 let serialized = serde_json::to_string(¬ification)?;
289 debug!("ClientStdioTransport: Sending notification: {}", serialized);
290
291 let mut stdin_guard = self.stdin.lock().await;
293 let mut stdin = stdin_guard
294 .take()
295 .ok_or_else(|| anyhow::anyhow!("Transport not opened"))?;
296
297 let stdin_result = tokio::task::spawn_blocking(move || {
299 stdin.write_all(serialized.as_bytes())?;
300 stdin.write_all(b"\n")?;
301 stdin.flush()?;
302 Ok::<_, anyhow::Error>(stdin)
303 })
304 .await??;
305
306 *stdin_guard = Some(stdin_result);
308
309 Ok(())
310 }
311}