modelcontextprotocol_client/transport/
stdio.rs1use anyhow::Result;
3use async_trait::async_trait;
4use mcp_protocol::messages::JsonRpcMessage;
5use std::process::Stdio;
6use tokio::process::{Child, Command};
7use std::sync::Arc;
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
9use tokio::sync::{mpsc, Mutex};
10
11pub struct StdioTransport {
13 child_process: Arc<Mutex<Option<Child>>>,
14 tx: mpsc::Sender<JsonRpcMessage>,
15 command: String,
16 args: Vec<String>,
17 stdin: Arc<Mutex<Option<tokio::process::ChildStdin>>>,
19}
20
21impl StdioTransport {
22 pub fn new(command: &str, args: Vec<String>) -> (Self, mpsc::Receiver<JsonRpcMessage>) {
24 let (tx, rx) = mpsc::channel(100);
25
26 let transport = Self {
27 child_process: Arc::new(Mutex::new(None)),
28 tx,
29 command: command.to_string(),
30 args,
31 stdin: Arc::new(Mutex::new(None)),
32 };
33
34 (transport, rx)
35 }
36}
37
38#[async_trait]
39impl super::Transport for StdioTransport {
40 async fn start(&self) -> Result<()> {
41 let mut child = Command::new(&self.command)
42 .args(&self.args)
43 .stdin(Stdio::piped())
44 .stdout(Stdio::piped())
45 .stderr(Stdio::inherit())
46 .spawn()?;
47
48 let stdout = child.stdout.take().expect("Failed to get stdout");
49 let stdin = child.stdin.take().expect("Failed to get stdin");
50
51 {
53 let mut guard = self.child_process.lock().await;
54 *guard = Some(child);
55 }
56
57 {
59 let mut stdin_guard = self.stdin.lock().await;
60 *stdin_guard = Some(stdin);
61 }
62
63 let tx = self.tx.clone();
64
65 tokio::spawn(async move {
67 let mut reader = BufReader::new(stdout);
68 let mut line = String::new();
69
70 while reader.read_line(&mut line).await.unwrap_or(0) > 0 {
71 match serde_json::from_str::<JsonRpcMessage>(&line) {
72 Ok(message) => {
73 if tx.send(message).await.is_err() {
74 break;
75 }
76 }
77 Err(err) => {
78 tracing::error!("Failed to parse JSON-RPC message: {}", err);
79 }
80 }
81
82 line.clear();
83 }
84 });
85
86 Ok(())
87 }
88
89 async fn send(&self, message: JsonRpcMessage) -> Result<()> {
90 let mut stdin_guard = self.stdin.lock().await;
92 let stdin = stdin_guard
93 .as_mut()
94 .ok_or_else(|| anyhow::anyhow!("Child process not started"))?;
95
96 let serialized = serde_json::to_string(&message)?;
97
98 stdin.write_all(serialized.as_bytes()).await?;
100 stdin.write_all(b"\n").await?;
101 stdin.flush().await?;
102
103 Ok(())
104 }
105
106 async fn close(&self) -> Result<()> {
107 {
109 let mut stdin_guard = self.stdin.lock().await;
110 *stdin_guard = None;
111 }
112
113 let mut guard = self.child_process.lock().await;
115
116 if let Some(mut child) = guard.take() {
117 let wait_future = child.wait();
119 match tokio::time::timeout(std::time::Duration::from_secs(1), wait_future).await {
120 Ok(Ok(_)) => return Ok(()),
121 _ => {
122 child.kill().await?;
124 child.wait().await?;
125 }
126 }
127 }
128
129 Ok(())
130 }
131
132 fn box_clone(&self) -> Box<dyn super::Transport> {
133 Box::new(self.clone())
134 }
135}
136
137impl Clone for StdioTransport {
138 fn clone(&self) -> Self {
139 Self {
140 child_process: self.child_process.clone(),
141 tx: self.tx.clone(),
142 command: self.command.clone(),
143 args: self.args.clone(),
144 stdin: self.stdin.clone(),
145 }
146 }
147}