mcprotocol_rs/transport/stdio/
server.rs

1use crate::{protocol::Message, Result};
2use async_trait::async_trait;
3use tokio::{
4    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
5    sync::Mutex,
6};
7
/// Configuration for the stdio server transport.
///
/// Derives `Debug` and `Clone` so a configuration can be printed in
/// diagnostics and duplicated without hand-written boilerplate.
#[derive(Debug, Clone)]
pub struct StdioServerConfig {
    /// Initial capacity, in bytes, of the line buffer allocated for each
    /// received message.
    pub buffer_size: usize,
}

impl Default for StdioServerConfig {
    /// Returns a configuration with a 4096-byte buffer size.
    fn default() -> Self {
        Self { buffer_size: 4096 }
    }
}
19
20/// Stdio server implementation
21pub struct StdioServer {
22    config: StdioServerConfig,
23    stdin: Mutex<BufReader<tokio::io::Stdin>>,
24    stdout: Mutex<tokio::io::Stdout>,
25}
26
27impl StdioServer {
28    /// Create a new Stdio server
29    pub fn new(config: StdioServerConfig) -> Self {
30        let stdin = BufReader::new(tokio::io::stdin());
31        let stdout = tokio::io::stdout();
32
33        Self {
34            config,
35            stdin: Mutex::new(stdin),
36            stdout: Mutex::new(stdout),
37        }
38    }
39
40    /// Log a message (using stderr)
41    pub async fn log(&self, message: &str) -> Result<()> {
42        let mut stderr = tokio::io::stderr();
43        stderr.write_all(message.as_bytes()).await?;
44        stderr.write_all(b"\n").await?;
45        stderr.flush().await?;
46        Ok(())
47    }
48}
49
50#[async_trait]
51impl super::StdioTransport for StdioServer {
52    async fn initialize(&mut self) -> Result<()> {
53        self.log("MCP server initialized").await?;
54        Ok(())
55    }
56
57    async fn send(&self, message: Message) -> Result<()> {
58        let mut stdout = self.stdout.lock().await;
59        let json = serde_json::to_string(&message)?;
60
61        // Check if the message contains a newline
62        if json.contains('\n') {
63            self.log("Warning: Message contains embedded newlines")
64                .await?;
65            return Err(crate::Error::Transport(
66                "Message contains embedded newlines".into(),
67            ));
68        }
69
70        stdout.write_all(json.as_bytes()).await?;
71        stdout.write_all(b"\n").await?;
72        stdout.flush().await?;
73        Ok(())
74    }
75
76    async fn receive(&self) -> Result<Message> {
77        let mut stdin = self.stdin.lock().await;
78        let mut line = String::with_capacity(self.config.buffer_size);
79
80        if stdin.read_line(&mut line).await? == 0 {
81            self.log("Client connection closed").await?;
82            return Err(crate::Error::Transport("Client connection closed".into()));
83        }
84
85        match serde_json::from_str(&line) {
86            Ok(message) => Ok(message),
87            Err(e) => {
88                self.log(&format!("Error parsing message: {}", e)).await?;
89                Err(crate::Error::Transport(format!(
90                    "Invalid message format: {}",
91                    e
92                )))
93            }
94        }
95    }
96
97    async fn close(&mut self) -> Result<()> {
98        self.log("MCP server shutting down").await?;
99        Ok(())
100    }
101}
102
103/// Default Stdio server type
104pub type DefaultStdioServer = StdioServer;