modelcontextprotocol_client/transport/stdio.rs

1// mcp-client/src/transport/stdio.rs
2use anyhow::Result;
3use async_trait::async_trait;
4use mcp_protocol::messages::JsonRpcMessage;
5use std::process::Stdio;
6use tokio::process::{Child, Command};
7use std::sync::Arc;
8use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
9use tokio::sync::{mpsc, Mutex};
10
11/// Transport implementation that uses stdio to communicate with a child process
12pub struct StdioTransport {
13    child_process: Arc<Mutex<Option<Child>>>,
14    tx: mpsc::Sender<JsonRpcMessage>,
15    command: String,
16    args: Vec<String>,
17    // Add a shared stdin channel for writing
18    stdin: Arc<Mutex<Option<tokio::process::ChildStdin>>>,
19}
20
21impl StdioTransport {
22    /// Create a new stdio transport with the given command and arguments
23    pub fn new(command: &str, args: Vec<String>) -> (Self, mpsc::Receiver<JsonRpcMessage>) {
24        let (tx, rx) = mpsc::channel(100);
25
26        let transport = Self {
27            child_process: Arc::new(Mutex::new(None)),
28            tx,
29            command: command.to_string(),
30            args,
31            stdin: Arc::new(Mutex::new(None)),
32        };
33
34        (transport, rx)
35    }
36}
37
38#[async_trait]
39impl super::Transport for StdioTransport {
40    async fn start(&self) -> Result<()> {
41        let mut child = Command::new(&self.command)
42            .args(&self.args)
43            .stdin(Stdio::piped())
44            .stdout(Stdio::piped())
45            .stderr(Stdio::inherit())
46            .spawn()?;
47
48        let stdout = child.stdout.take().expect("Failed to get stdout");
49        let stdin = child.stdin.take().expect("Failed to get stdin");
50
51        // Store child process
52        {
53            let mut guard = self.child_process.lock().await;
54            *guard = Some(child);
55        }
56
57        // Store stdin for writing messages
58        {
59            let mut stdin_guard = self.stdin.lock().await;
60            *stdin_guard = Some(stdin);
61        }
62
63        let tx = self.tx.clone();
64
65        // Spawn a task to read from stdout
66        tokio::spawn(async move {
67            let mut reader = BufReader::new(stdout);
68            let mut line = String::new();
69
70            while reader.read_line(&mut line).await.unwrap_or(0) > 0 {
71                match serde_json::from_str::<JsonRpcMessage>(&line) {
72                    Ok(message) => {
73                        if tx.send(message).await.is_err() {
74                            break;
75                        }
76                    }
77                    Err(err) => {
78                        tracing::error!("Failed to parse JSON-RPC message: {}", err);
79                    }
80                }
81
82                line.clear();
83            }
84        });
85
86        Ok(())
87    }
88
89    async fn send(&self, message: JsonRpcMessage) -> Result<()> {
90        // Get stdin from our stored mutex
91        let mut stdin_guard = self.stdin.lock().await;
92        let stdin = stdin_guard
93            .as_mut()
94            .ok_or_else(|| anyhow::anyhow!("Child process not started"))?;
95
96        let serialized = serde_json::to_string(&message)?;
97        
98        // Now we can directly use AsyncWriteExt methods on stdin
99        stdin.write_all(serialized.as_bytes()).await?;
100        stdin.write_all(b"\n").await?;
101        stdin.flush().await?;
102
103        Ok(())
104    }
105
106    async fn close(&self) -> Result<()> {
107        // First close stdin
108        {
109            let mut stdin_guard = self.stdin.lock().await;
110            *stdin_guard = None;
111        }
112        
113        // Then close the child process
114        let mut guard = self.child_process.lock().await;
115
116        if let Some(mut child) = guard.take() {
117            // Wait for a short time for the process to exit gracefully
118            let wait_future = child.wait();
119            match tokio::time::timeout(std::time::Duration::from_secs(1), wait_future).await {
120                Ok(Ok(_)) => return Ok(()),
121                _ => {
122                    // If it doesn't exit, kill it
123                    child.kill().await?;
124                    child.wait().await?;
125                }
126            }
127        }
128
129        Ok(())
130    }
131    
132    fn box_clone(&self) -> Box<dyn super::Transport> {
133        Box::new(self.clone())
134    }
135}
136
137impl Clone for StdioTransport {
138    fn clone(&self) -> Self {
139        Self {
140            child_process: self.child_process.clone(),
141            tx: self.tx.clone(),
142            command: self.command.clone(),
143            args: self.args.clone(),
144            stdin: self.stdin.clone(),
145        }
146    }
147}