mcprotocol_rs/transport/stdio/
client.rs

1use crate::{protocol::Message, Result};
2use async_trait::async_trait;
3use std::{path::PathBuf, process::Stdio};
4use tokio::{
5    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
6    process::{Child, Command},
7    sync::Mutex,
8};
9
/// Configuration for the stdio client.
///
/// Describes which server executable to launch, the arguments to pass to it,
/// and how the client treats the server's output streams.
#[derive(Debug, Clone)]
pub struct StdioClientConfig {
    /// Path of the server executable to spawn.
    pub server_path: PathBuf,
    /// Command-line arguments passed to the server executable.
    pub server_args: Vec<String>,
    /// Initial capacity, in bytes, of the line buffer used when receiving a message.
    pub buffer_size: usize,
    /// When `true`, the server's stderr is piped to the client for log capture;
    /// when `false`, the server inherits the client's stderr.
    pub capture_logs: bool,
}

impl Default for StdioClientConfig {
    /// Returns a configuration that launches `mcp-server` with no arguments,
    /// a 4096-byte receive buffer, and log capture enabled.
    fn default() -> Self {
        Self {
            server_path: PathBuf::from("mcp-server"),
            server_args: Vec::new(),
            buffer_size: 4096,
            capture_logs: true,
        }
    }
}
32
33/// Stdio client implementation
34pub struct StdioClient {
35    config: StdioClientConfig,
36    child: Mutex<Option<Child>>,
37    stdin: Mutex<Option<tokio::process::ChildStdin>>,
38    stdout: Mutex<Option<BufReader<tokio::process::ChildStdout>>>,
39    stderr: Mutex<Option<BufReader<tokio::process::ChildStderr>>>,
40}
41
42impl StdioClient {
43    /// Create a new Stdio client
44    pub fn new(config: StdioClientConfig) -> Self {
45        Self {
46            config,
47            child: Mutex::new(None),
48            stdin: Mutex::new(None),
49            stdout: Mutex::new(None),
50            stderr: Mutex::new(None),
51        }
52    }
53
54    /// Start log capture
55    async fn start_log_capture(&self, mut stderr: tokio::process::ChildStderr) {
56        tokio::spawn(async move {
57            let mut reader = BufReader::new(stderr);
58            let mut line = String::new();
59            while let Ok(n) = reader.read_line(&mut line).await {
60                if n == 0 {
61                    break;
62                }
63                // Here you can handle logs as needed, such as forwarding to a specific logging system
64                eprintln!("[MCP Server] {}", line.trim());
65                line.clear();
66            }
67        });
68    }
69}
70
71#[async_trait]
72impl super::StdioTransport for StdioClient {
73    async fn initialize(&mut self) -> Result<()> {
74        let mut child = Command::new(&self.config.server_path)
75            .args(&self.config.server_args)
76            .stdin(Stdio::piped())
77            .stdout(Stdio::piped())
78            .stderr(if self.config.capture_logs {
79                Stdio::piped()
80            } else {
81                Stdio::inherit()
82            })
83            .spawn()
84            .map_err(|e| crate::Error::Transport(format!("Failed to start server: {}", e)))?;
85
86        let stdin = child
87            .stdin
88            .take()
89            .ok_or_else(|| crate::Error::Transport("Failed to get server stdin handle".into()))?;
90        let stdout = child
91            .stdout
92            .take()
93            .ok_or_else(|| crate::Error::Transport("Failed to get server stdout handle".into()))?;
94
95        if self.config.capture_logs {
96            if let Some(stderr) = child.stderr.take() {
97                self.start_log_capture(stderr).await;
98            }
99        }
100
101        *self.stdin.lock().await = Some(stdin);
102        *self.stdout.lock().await = Some(BufReader::new(stdout));
103        *self.child.lock().await = Some(child);
104
105        Ok(())
106    }
107
108    async fn send(&self, message: Message) -> Result<()> {
109        let mut stdin = self.stdin.lock().await;
110        let stdin = stdin
111            .as_mut()
112            .ok_or_else(|| crate::Error::Transport("Server process not initialized".into()))?;
113
114        let json = serde_json::to_string(&message)?;
115        if json.contains('\n') {
116            return Err(crate::Error::Transport(
117                "Message contains embedded newlines".into(),
118            ));
119        }
120
121        stdin.write_all(json.as_bytes()).await?;
122        stdin.write_all(b"\n").await?;
123        stdin.flush().await?;
124        Ok(())
125    }
126
127    async fn receive(&self) -> Result<Message> {
128        let mut stdout = self.stdout.lock().await;
129        let stdout = stdout
130            .as_mut()
131            .ok_or_else(|| crate::Error::Transport("Server process not initialized".into()))?;
132
133        let mut line = String::with_capacity(self.config.buffer_size);
134        stdout.read_line(&mut line).await?;
135
136        if line.is_empty() {
137            return Err(crate::Error::Transport("Server process terminated".into()));
138        }
139
140        let message = serde_json::from_str(&line)?;
141        Ok(message)
142    }
143
144    async fn close(&mut self) -> Result<()> {
145        let mut child = self.child.lock().await;
146        if let Some(mut child) = child.take() {
147            // First close stdin to let the server know there will be no more input
148            drop(self.stdin.lock().await.take());
149
150            // Wait for the server process to end
151            match child.wait().await {
152                Ok(status) => {
153                    if !status.success() {
154                        return Err(crate::Error::Transport(format!(
155                            "Server process exited with status: {}",
156                            status
157                        )));
158                    }
159                }
160                Err(e) => {
161                    return Err(crate::Error::Transport(format!(
162                        "Failed to wait for server process: {}",
163                        e
164                    )));
165                }
166            }
167        }
168
169        *self.stdout.lock().await = None;
170        *self.stderr.lock().await = None;
171        Ok(())
172    }
173}
174
175/// Default Stdio client type
176pub type DefaultStdioClient = StdioClient;