mcp_protocol_sdk/transport/
stdio.rs1use async_trait::async_trait;
7use serde_json::Value;
8use std::collections::HashMap;
9use std::process::Stdio;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::{mpsc, Mutex};
14use tokio::time::{timeout, Duration};
15
16use crate::core::error::{McpError, McpResult};
17use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
18use crate::transport::traits::{ConnectionState, ServerTransport, Transport, TransportConfig};
19
20pub struct StdioClientTransport {
25 child: Option<Child>,
26 stdin_writer: Option<BufWriter<tokio::process::ChildStdin>>,
27 #[allow(dead_code)]
28 stdout_reader: Option<BufReader<tokio::process::ChildStdout>>,
29 notification_receiver: Option<mpsc::UnboundedReceiver<JsonRpcNotification>>,
30 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
31 config: TransportConfig,
32 state: ConnectionState,
33}
34
35impl StdioClientTransport {
36 pub async fn new<S: AsRef<str>>(command: S, args: Vec<S>) -> McpResult<Self> {
45 Self::with_config(command, args, TransportConfig::default()).await
46 }
47
48 pub async fn with_config<S: AsRef<str>>(
58 command: S,
59 args: Vec<S>,
60 config: TransportConfig,
61 ) -> McpResult<Self> {
62 let command_str = command.as_ref();
63 let args_str: Vec<&str> = args.iter().map(|s| s.as_ref()).collect();
64
65 tracing::debug!("Starting MCP server: {} {:?}", command_str, args_str);
66
67 let mut child = Command::new(command_str)
68 .args(&args_str)
69 .stdin(Stdio::piped())
70 .stdout(Stdio::piped())
71 .stderr(Stdio::piped())
72 .spawn()
73 .map_err(|e| McpError::transport(format!("Failed to start server process: {}", e)))?;
74
75 let stdin = child
76 .stdin
77 .take()
78 .ok_or_else(|| McpError::transport("Failed to get stdin handle"))?;
79 let stdout = child
80 .stdout
81 .take()
82 .ok_or_else(|| McpError::transport("Failed to get stdout handle"))?;
83
84 let stdin_writer = BufWriter::new(stdin);
85 let stdout_reader = BufReader::new(stdout);
86
87 let (notification_sender, notification_receiver) = mpsc::unbounded_channel();
88 let pending_requests = Arc::new(Mutex::new(HashMap::new()));
89
90 let reader_pending_requests = pending_requests.clone();
92 let reader = stdout_reader;
93 tokio::spawn(async move {
94 Self::message_processor(reader, notification_sender, reader_pending_requests).await;
95 });
96
97 Ok(Self {
98 child: Some(child),
99 stdin_writer: Some(stdin_writer),
100 stdout_reader: None, notification_receiver: Some(notification_receiver),
102 pending_requests,
103 config,
104 state: ConnectionState::Connected,
105 })
106 }
107
108 async fn message_processor(
109 mut reader: BufReader<tokio::process::ChildStdout>,
110 notification_sender: mpsc::UnboundedSender<JsonRpcNotification>,
111 pending_requests: Arc<Mutex<HashMap<Value, tokio::sync::oneshot::Sender<JsonRpcResponse>>>>,
112 ) {
113 let mut line = String::new();
114
115 loop {
116 line.clear();
117 match reader.read_line(&mut line).await {
118 Ok(0) => {
119 tracing::debug!("STDIO reader reached EOF");
120 break;
121 }
122 Ok(_) => {
123 let line = line.trim();
124 if line.is_empty() {
125 continue;
126 }
127
128 tracing::trace!("Received: {}", line);
129
130 if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(line) {
132 let mut pending = pending_requests.lock().await;
133 if let Some(sender) = pending.remove(&response.id) {
134 let _ = sender.send(response);
135 } else {
136 tracing::warn!(
137 "Received response for unknown request ID: {:?}",
138 response.id
139 );
140 }
141 }
142 else if let Ok(notification) =
144 serde_json::from_str::<JsonRpcNotification>(line)
145 {
146 if notification_sender.send(notification).is_err() {
147 tracing::debug!("Notification receiver dropped");
148 break;
149 }
150 } else {
151 tracing::warn!("Failed to parse message: {}", line);
152 }
153 }
154 Err(e) => {
155 tracing::error!("Error reading from stdout: {}", e);
156 break;
157 }
158 }
159 }
160 }
161}
162
163#[async_trait]
164impl Transport for StdioClientTransport {
165 async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
166 let writer = self
167 .stdin_writer
168 .as_mut()
169 .ok_or_else(|| McpError::transport("Transport not connected"))?;
170
171 let (sender, receiver) = tokio::sync::oneshot::channel();
172
173 {
175 let mut pending = self.pending_requests.lock().await;
176 pending.insert(request.id.clone(), sender);
177 }
178
179 let request_line = serde_json::to_string(&request).map_err(McpError::serialization)?;
181
182 tracing::trace!("Sending: {}", request_line);
183
184 writer
185 .write_all(request_line.as_bytes())
186 .await
187 .map_err(|e| McpError::transport(format!("Failed to write request: {}", e)))?;
188 writer
189 .write_all(b"\n")
190 .await
191 .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
192 writer
193 .flush()
194 .await
195 .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
196
197 let timeout_duration = Duration::from_millis(self.config.read_timeout_ms.unwrap_or(60_000));
199
200 let response = timeout(timeout_duration, receiver)
201 .await
202 .map_err(|_| McpError::timeout("Request timeout"))?
203 .map_err(|_| McpError::transport("Response channel closed"))?;
204
205 Ok(response)
206 }
207
208 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
209 let writer = self
210 .stdin_writer
211 .as_mut()
212 .ok_or_else(|| McpError::transport("Transport not connected"))?;
213
214 let notification_line =
215 serde_json::to_string(¬ification).map_err(McpError::serialization)?;
216
217 tracing::trace!("Sending notification: {}", notification_line);
218
219 writer
220 .write_all(notification_line.as_bytes())
221 .await
222 .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
223 writer
224 .write_all(b"\n")
225 .await
226 .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
227 writer
228 .flush()
229 .await
230 .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
231
232 Ok(())
233 }
234
235 async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>> {
236 if let Some(ref mut receiver) = self.notification_receiver {
237 match receiver.try_recv() {
238 Ok(notification) => Ok(Some(notification)),
239 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
240 Err(mpsc::error::TryRecvError::Disconnected) => {
241 Err(McpError::transport("Notification channel disconnected"))
242 }
243 }
244 } else {
245 Ok(None)
246 }
247 }
248
249 async fn close(&mut self) -> McpResult<()> {
250 tracing::debug!("Closing STDIO transport");
251
252 self.state = ConnectionState::Closing;
253
254 if let Some(mut writer) = self.stdin_writer.take() {
256 let _ = writer.shutdown().await;
257 }
258
259 if let Some(mut child) = self.child.take() {
261 match timeout(Duration::from_secs(5), child.wait()).await {
262 Ok(Ok(status)) => {
263 tracing::debug!("Server process exited with status: {}", status);
264 }
265 Ok(Err(e)) => {
266 tracing::warn!("Error waiting for server process: {}", e);
267 }
268 Err(_) => {
269 tracing::warn!("Timeout waiting for server process, killing it");
270 let _ = child.kill().await;
271 }
272 }
273 }
274
275 self.state = ConnectionState::Disconnected;
276 Ok(())
277 }
278
279 fn is_connected(&self) -> bool {
280 matches!(self.state, ConnectionState::Connected)
281 }
282
283 fn connection_info(&self) -> String {
284 format!("STDIO transport (state: {:?})", self.state)
285 }
286}
287
288pub struct StdioServerTransport {
293 stdin_reader: Option<BufReader<tokio::io::Stdin>>,
294 stdout_writer: Option<BufWriter<tokio::io::Stdout>>,
295 #[allow(dead_code)]
296 config: TransportConfig,
297 running: bool,
298 request_handler: Option<
299 Box<
300 dyn Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse> + Send + Sync,
301 >,
302 >,
303}
304
305impl StdioServerTransport {
306 pub fn new() -> Self {
311 Self::with_config(TransportConfig::default())
312 }
313
314 pub fn with_config(config: TransportConfig) -> Self {
322 let stdin_reader = BufReader::new(tokio::io::stdin());
323 let stdout_writer = BufWriter::new(tokio::io::stdout());
324
325 Self {
326 stdin_reader: Some(stdin_reader),
327 stdout_writer: Some(stdout_writer),
328 config,
329 running: false,
330 request_handler: None,
331 }
332 }
333
334 pub fn set_request_handler<F>(&mut self, handler: F)
339 where
340 F: Fn(JsonRpcRequest) -> tokio::sync::oneshot::Receiver<JsonRpcResponse>
341 + Send
342 + Sync
343 + 'static,
344 {
345 self.request_handler = Some(Box::new(handler));
346 }
347}
348
349#[async_trait]
350impl ServerTransport for StdioServerTransport {
351 async fn start(&mut self) -> McpResult<()> {
352 tracing::debug!("Starting STDIO server transport");
353
354 let mut reader = self
355 .stdin_reader
356 .take()
357 .ok_or_else(|| McpError::transport("STDIN reader already taken"))?;
358 let mut writer = self
359 .stdout_writer
360 .take()
361 .ok_or_else(|| McpError::transport("STDOUT writer already taken"))?;
362
363 self.running = true;
364
365 let mut line = String::new();
366 while self.running {
367 line.clear();
368
369 match reader.read_line(&mut line).await {
370 Ok(0) => {
371 tracing::debug!("STDIN closed, stopping server");
372 break;
373 }
374 Ok(_) => {
375 let line = line.trim();
376 if line.is_empty() {
377 continue;
378 }
379
380 tracing::trace!("Received: {}", line);
381
382 match serde_json::from_str::<JsonRpcRequest>(line) {
384 Ok(request) => {
385 let response = self.handle_request(request).await?;
386
387 let response_line = serde_json::to_string(&response)
388 .map_err(McpError::serialization)?;
389
390 tracing::trace!("Sending: {}", response_line);
391
392 writer
393 .write_all(response_line.as_bytes())
394 .await
395 .map_err(|e| {
396 McpError::transport(format!("Failed to write response: {}", e))
397 })?;
398 writer.write_all(b"\n").await.map_err(|e| {
399 McpError::transport(format!("Failed to write newline: {}", e))
400 })?;
401 writer.flush().await.map_err(|e| {
402 McpError::transport(format!("Failed to flush: {}", e))
403 })?;
404 }
405 Err(e) => {
406 tracing::warn!("Failed to parse request: {} - Error: {}", line, e);
407 }
410 }
411 }
412 Err(e) => {
413 tracing::error!("Error reading from stdin: {}", e);
414 return Err(McpError::io(e));
415 }
416 }
417 }
418
419 Ok(())
420 }
421
422 async fn handle_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse> {
423 Ok(JsonRpcResponse {
425 jsonrpc: "2.0".to_string(),
426 id: request.id,
427 result: None,
428 error: Some(crate::protocol::types::JsonRpcError {
429 code: crate::protocol::types::METHOD_NOT_FOUND,
430 message: format!("Method '{}' not found", request.method),
431 data: None,
432 }),
433 })
434 }
435
436 async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()> {
437 let writer = self
438 .stdout_writer
439 .as_mut()
440 .ok_or_else(|| McpError::transport("STDOUT writer not available"))?;
441
442 let notification_line =
443 serde_json::to_string(¬ification).map_err(McpError::serialization)?;
444
445 tracing::trace!("Sending notification: {}", notification_line);
446
447 writer
448 .write_all(notification_line.as_bytes())
449 .await
450 .map_err(|e| McpError::transport(format!("Failed to write notification: {}", e)))?;
451 writer
452 .write_all(b"\n")
453 .await
454 .map_err(|e| McpError::transport(format!("Failed to write newline: {}", e)))?;
455 writer
456 .flush()
457 .await
458 .map_err(|e| McpError::transport(format!("Failed to flush: {}", e)))?;
459
460 Ok(())
461 }
462
463 async fn stop(&mut self) -> McpResult<()> {
464 tracing::debug!("Stopping STDIO server transport");
465 self.running = false;
466 Ok(())
467 }
468
469 fn is_running(&self) -> bool {
470 self.running
471 }
472
473 fn server_info(&self) -> String {
474 format!("STDIO server transport (running: {})", self.running)
475 }
476}
477
478impl Default for StdioServerTransport {
479 fn default() -> Self {
480 Self::new()
481 }
482}
483
484impl Drop for StdioClientTransport {
485 fn drop(&mut self) {
486 if let Some(mut child) = self.child.take() {
487 let _ = child.start_kill();
489 }
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a parameterless JSON-RPC 2.0 request with id 1 for `method`.
    fn request_for(method: &str) -> JsonRpcRequest {
        JsonRpcRequest {
            jsonrpc: "2.0".to_string(),
            id: json!(1),
            method: method.to_string(),
            params: None,
        }
    }

    #[test]
    fn test_stdio_server_creation() {
        let transport = StdioServerTransport::new();

        assert!(!transport.is_running());
        assert!(transport.stdin_reader.is_some());
        assert!(transport.stdout_writer.is_some());
    }

    #[test]
    fn test_stdio_server_with_config() {
        let transport = StdioServerTransport::with_config(TransportConfig {
            read_timeout_ms: Some(30_000),
            ..Default::default()
        });

        assert_eq!(transport.config.read_timeout_ms, Some(30_000));
    }

    #[tokio::test]
    async fn test_stdio_server_handle_request() {
        let mut transport = StdioServerTransport::new();

        let response = transport
            .handle_request(request_for("unknown_method"))
            .await
            .expect("handle_request should not fail");

        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, json!(1));
        assert!(response.result.is_none());

        let error = response.error.expect("expected an error response");
        assert_eq!(error.code, crate::protocol::types::METHOD_NOT_FOUND);
    }
}