model_context_protocol/client/
stdio.rs1use async_trait::async_trait;
14use dashmap::DashMap;
15use serde_json::Value;
16use std::collections::HashMap;
17use std::sync::atomic::{AtomicBool, AtomicI64, Ordering};
18use std::sync::Arc;
19use std::time::Duration;
20use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
21use tokio::process::{Child, Command};
22use tokio::sync::{mpsc, oneshot};
23
24use crate::protocol::*;
25use crate::transport::{InitializeParams, McpTransport, McpTransportError, TransportTypeId};
26
/// One outbound message for the background stdin writer task.
///
/// The payload is already serialized JSON-RPC and already newline-terminated,
/// so the writer can copy its bytes to the child verbatim.
struct WriteRequest {
    /// Full wire line, trailing `\n` included.
    request_line: String,
}
35
36pub struct TokioStdioTransport {
44 write_tx: mpsc::Sender<WriteRequest>,
46 pending: Arc<DashMap<i64, oneshot::Sender<Result<Value, McpTransportError>>>>,
48 next_id: AtomicI64,
50 alive: Arc<AtomicBool>,
52 child: Arc<tokio::sync::Mutex<Child>>,
54}
55
56impl TokioStdioTransport {
57 pub async fn spawn(command: &str, args: &[String]) -> Result<Self, McpTransportError> {
59 Self::spawn_with_env(command, args, HashMap::new()).await
60 }
61
62 pub async fn spawn_with_env(
64 command: &str,
65 args: &[String],
66 env: HashMap<String, String>,
67 ) -> Result<Self, McpTransportError> {
68 let mut cmd = Command::new(command);
69 cmd.args(args)
70 .stdin(std::process::Stdio::piped())
71 .stdout(std::process::Stdio::piped())
72 .stderr(std::process::Stdio::piped())
73 .kill_on_drop(true);
74
75 for (key, value) in env {
76 cmd.env(key, value);
77 }
78
79 let mut child = cmd.spawn().map_err(|e| {
80 McpTransportError::TransportError(format!(
81 "Failed to spawn process '{}': {}",
82 command, e
83 ))
84 })?;
85
86 let stdin = child
88 .stdin
89 .take()
90 .ok_or_else(|| McpTransportError::TransportError("Failed to get stdin".to_string()))?;
91 let stdout = child
92 .stdout
93 .take()
94 .ok_or_else(|| McpTransportError::TransportError("Failed to get stdout".to_string()))?;
95
96 let alive = Arc::new(AtomicBool::new(true));
97 let pending: Arc<DashMap<i64, oneshot::Sender<Result<Value, McpTransportError>>>> =
98 Arc::new(DashMap::new());
99
100 let (write_tx, mut write_rx) = mpsc::channel::<WriteRequest>(256);
102
103 let alive_writer = Arc::clone(&alive);
105 let mut stdin = stdin;
106 tokio::spawn(async move {
107 while let Some(req) = write_rx.recv().await {
108 if !alive_writer.load(Ordering::SeqCst) {
109 break;
110 }
111 if let Err(e) = stdin.write_all(req.request_line.as_bytes()).await {
112 eprintln!("Stdio write error: {}", e);
113 alive_writer.store(false, Ordering::SeqCst);
114 break;
115 }
116 if let Err(e) = stdin.flush().await {
117 eprintln!("Stdio flush error: {}", e);
118 alive_writer.store(false, Ordering::SeqCst);
119 break;
120 }
121 }
122 });
123
124 let pending_reader = Arc::clone(&pending);
126 let alive_reader = Arc::clone(&alive);
127 let mut reader = BufReader::new(stdout);
128 tokio::spawn(async move {
129 let mut line = String::new();
130 loop {
131 line.clear();
132 match reader.read_line(&mut line).await {
133 Ok(0) => {
134 alive_reader.store(false, Ordering::SeqCst);
136 break;
137 }
138 Ok(_) => {
139 match serde_json::from_str::<JsonRpcResponse>(&line) {
141 Ok(response) => {
142 if let JsonRpcId::Number(id) = &response.id {
143 if let Some((_, tx)) = pending_reader.remove(id) {
144 let result = match response.payload {
145 JsonRpcPayload::Success { result } => Ok(result),
146 JsonRpcPayload::Error { error } => {
147 Err(McpTransportError::ServerError(format!(
148 "MCP Error: {}",
149 error
150 )))
151 }
152 };
153 let _ = tx.send(result);
154 }
155 }
156 }
157 Err(e) => {
158 eprintln!(
159 "Failed to parse response: {} - line: {}",
160 e,
161 line.trim()
162 );
163 }
164 }
165 }
166 Err(e) => {
167 eprintln!("Stdio read error: {}", e);
168 alive_reader.store(false, Ordering::SeqCst);
169 break;
170 }
171 }
172 }
173
174 pending_reader.clear();
177 });
178
179 Ok(Self {
180 write_tx,
181 pending,
182 next_id: AtomicI64::new(1),
183 alive,
184 child: Arc::new(tokio::sync::Mutex::new(child)),
185 })
186 }
187
188 pub async fn send_request(
190 &self,
191 method: &str,
192 params: Option<Value>,
193 timeout_duration: Duration,
194 ) -> Result<Value, McpTransportError> {
195 if !self.alive.load(Ordering::SeqCst) {
196 return Err(McpTransportError::ConnectionClosed);
197 }
198
199 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
200 let request = JsonRpcRequest::new(JsonRpcId::Number(id), method.to_string(), params);
201 let request_json = serde_json::to_string(&request)?;
202 let request_line = format!("{}\n", request_json);
203
204 let (tx, rx) = oneshot::channel();
206 self.pending.insert(id, tx);
207
208 if self
210 .write_tx
211 .send(WriteRequest { request_line })
212 .await
213 .is_err()
214 {
215 self.pending.remove(&id);
216 return Err(McpTransportError::ConnectionClosed);
217 }
218
219 match tokio::time::timeout(timeout_duration, rx).await {
221 Ok(Ok(result)) => result,
222 Ok(Err(_)) => {
223 self.pending.remove(&id);
224 Err(McpTransportError::ConnectionClosed)
225 }
226 Err(_) => {
227 self.pending.remove(&id);
228 Err(McpTransportError::Timeout(format!(
229 "Request timed out after {:?}",
230 timeout_duration
231 )))
232 }
233 }
234 }
235
236 pub fn is_alive(&self) -> bool {
238 self.alive.load(Ordering::SeqCst)
239 }
240
241 pub async fn stop(&self) -> Result<(), McpTransportError> {
243 self.alive.store(false, Ordering::SeqCst);
244
245 let mut child = self.child.lock().await;
247 if let Err(e) = child.kill().await {
248 if e.kind() != std::io::ErrorKind::InvalidInput {
250 return Err(McpTransportError::TransportError(format!(
251 "Failed to kill process: {}",
252 e
253 )));
254 }
255 }
256
257 Ok(())
258 }
259}
260
261pub struct AsyncStdioTransport {
269 inner: Arc<TokioStdioTransport>,
270}
271
272impl AsyncStdioTransport {
273 pub async fn spawn(command: &str, args: &[String]) -> Result<Self, McpTransportError> {
275 Ok(Self {
276 inner: Arc::new(TokioStdioTransport::spawn(command, args).await?),
277 })
278 }
279
280 pub async fn spawn_with_env(
282 command: &str,
283 args: &[String],
284 env: HashMap<String, String>,
285 ) -> Result<Self, McpTransportError> {
286 Ok(Self {
287 inner: Arc::new(TokioStdioTransport::spawn_with_env(command, args, env).await?),
288 })
289 }
290
291 pub async fn send_request_with_timeout(
293 &self,
294 method: &str,
295 params: Option<Value>,
296 timeout_duration: Duration,
297 ) -> Result<Value, McpTransportError> {
298 self.inner
299 .send_request(method, params, timeout_duration)
300 .await
301 }
302
303 pub fn is_alive(&self) -> bool {
305 self.inner.is_alive()
306 }
307
308 pub async fn stop(&self) -> Result<(), McpTransportError> {
310 self.inner.stop().await
311 }
312}
313
314pub struct StdioTransportAdapter {
316 inner: AsyncStdioTransport,
317 timeout: Duration,
318}
319
320impl StdioTransportAdapter {
321 pub async fn connect(
323 command: &str,
324 args: &[String],
325 config: Option<Value>,
326 timeout: Duration,
327 ) -> Result<Self, McpTransportError> {
328 Self::connect_with_env(command, args, HashMap::new(), config, timeout).await
329 }
330
331 pub async fn connect_with_env(
333 command: &str,
334 args: &[String],
335 env: HashMap<String, String>,
336 config: Option<Value>,
337 timeout: Duration,
338 ) -> Result<Self, McpTransportError> {
339 let inner = AsyncStdioTransport::spawn_with_env(command, args, env).await?;
340
341 let adapter = Self { inner, timeout };
342
343 let init_params = InitializeParams::new(config);
345 let _init_result = adapter
346 .inner
347 .send_request_with_timeout(
348 "initialize",
349 Some(serde_json::to_value(&init_params)?),
350 adapter.timeout,
351 )
352 .await?;
353
354 let _ = adapter
356 .inner
357 .send_request_with_timeout(
358 "notifications/initialized",
359 Some(serde_json::json!({})),
360 adapter.timeout,
361 )
362 .await;
363
364 Ok(adapter)
365 }
366}
367
368#[async_trait]
369impl McpTransport for StdioTransportAdapter {
370 async fn list_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError> {
371 let result = self
372 .inner
373 .send_request_with_timeout("tools/list", Some(serde_json::json!({})), self.timeout)
374 .await?;
375
376 let list_result: ListToolsResult = serde_json::from_value(result)?;
377 Ok(list_result.tools)
378 }
379
380 async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError> {
381 let params = CallToolParams {
382 name: name.to_string(),
383 arguments: Some(args),
384 task: None,
385 meta: None,
386 };
387
388 let result = self
389 .inner
390 .send_request_with_timeout(
391 "tools/call",
392 Some(serde_json::to_value(¶ms)?),
393 self.timeout,
394 )
395 .await?;
396
397 let call_result: CallToolResult = serde_json::from_value(result)?;
398
399 if call_result.is_error == Some(true) {
400 let error_text = call_result
401 .content
402 .first()
403 .and_then(|c| c.as_text())
404 .unwrap_or("Unknown error");
405 return Err(McpTransportError::ServerError(error_text.to_string()));
406 }
407
408 let text = call_result
409 .content
410 .iter()
411 .filter_map(|c| c.as_text())
412 .collect::<Vec<_>>()
413 .join("\n");
414
415 Ok(Value::String(text))
416 }
417
418 async fn shutdown(&self) -> Result<(), McpTransportError> {
419 self.inner.stop().await
420 }
421
422 fn is_alive(&self) -> bool {
423 self.inner.is_alive()
424 }
425
426 fn transport_type(&self) -> TransportTypeId {
427 TransportTypeId::Stdio
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// The stdio transport identifier renders as the lowercase tag `"stdio"`.
    #[test]
    fn test_transport_type() {
        let rendered = TransportTypeId::Stdio.to_string();
        assert_eq!(rendered, "stdio");
    }
}