model_context_protocol/server/stdio.rs

//! Stdio transport for MCP Server.
//!
//! This module provides `McpStdioServer`, which wraps the core `McpServer`
//! and handles stdin/stdout I/O.
//!
//! # Example
//!
//! ```ignore
//! use mcp::server::{McpServerConfig, stdio::McpStdioServer};
//!
//! let config = McpServerConfig::builder()
//!     .name("my-server")
//!     .version("1.0.0")
//!     .with_tool(MyTool)
//!     .build();
//!
//! McpStdioServer::run(config).await?;
//! ```
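//!
//! # Framing
//!
//! Messages are newline-delimited JSON: each line read from stdin is parsed
//! as a single JSON-RPC message, and each outbound message is serialized to
//! one line, terminated with `\n`, and flushed.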

use std::sync::Arc;

use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};

use super::{McpServer, McpServerConfig, ServerError, ServerStatus};
use crate::protocol::JsonRpcMessage;

/// MCP Server with stdio transport.
///
/// This server reads JSON-RPC messages from stdin and writes responses
/// to stdout. It wraps the core `McpServer` and bridges stdio I/O to
/// the internal channel-based communication.
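///
/// A single spawned task owns stdout, so responses and notifications are
/// written one complete, newline-terminated line at a time and cannot
/// interleave.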
pub struct McpStdioServer {
    server: Arc<McpServer>,
}

impl McpStdioServer {
    /// Runs an MCP server with stdio transport.
    ///
    /// This is the main entry point for running a stdio-based MCP server.
    /// The returned future resolves only when the server stops (stdin is
    /// closed or an I/O error occurs).
    ///
    /// # Example
    ///
    /// ```ignore
    /// let config = McpServerConfig::builder()
    ///     .name("my-server")
    ///     .version("1.0.0")
    ///     .with_tool(MyTool)
    ///     .build();
    ///
    /// McpStdioServer::run(config).await?;
    /// ```
    pub async fn run(config: McpServerConfig) -> Result<(), ServerError> {
        let (server, channels) = McpServer::new(config);

        let stdio_server = Self {
            server: Arc::clone(&server),
        };

        // Take the receiving half out of `channels` before the spawn below:
        // an `async move` block captures the whole `channels` value, which
        // would otherwise leave the sender halves unusable in the reader loop.
        let mut outbound_rx = channels.outbound_rx;

        // Spawn the stdout writer task. A single task owns stdout, so each
        // outbound message is written and flushed as one complete line.
        let stdout_handle = tokio::spawn(async move {
            let mut stdout = tokio::io::stdout();

            while let Some(outbound) = outbound_rx.recv().await {
                let json = match outbound.to_json() {
                    Ok(j) => j,
                    Err(e) => {
                        eprintln!("Failed to serialize outbound message: {}", e);
                        continue;
                    }
                };

                if let Err(e) = stdout.write_all(json.as_bytes()).await {
                    eprintln!("Failed to write to stdout: {}", e);
                    break;
                }
                if let Err(e) = stdout.write_all(b"\n").await {
                    eprintln!("Failed to write newline to stdout: {}", e);
                    break;
                }
                if let Err(e) = stdout.flush().await {
                    eprintln!("Failed to flush stdout: {}", e);
                    break;
                }
            }
        });

        // Run stdin reader in current task
        let stdin = tokio::io::stdin();
        let mut reader = BufReader::new(stdin);
        let mut line = String::new();

        loop {
            line.clear();

            match reader.read_line(&mut line).await {
                Ok(0) => {
                    // EOF - stdin closed
                    break;
                }
                Ok(_) => {
                    let trimmed = line.trim();
                    if trimmed.is_empty() {
                        continue;
                    }

                    // Parse the incoming message
                    match JsonRpcMessage::parse(trimmed) {
                        Ok(message) => {
                            let inbound = message.into_client_inbound();
                            if channels.inbound_tx.send(inbound).await.is_err() {
                                // Server stopped
                                break;
                            }
                        }
                        Err(e) => {
                            // Send parse error response through the channel
                            // to ensure synchronization with other outbound messages
                            let error_response = crate::protocol::JsonRpcResponse::error(
                                crate::protocol::JsonRpcId::Null,
                                -32700,
                                format!("Parse error: {}", e),
                                None,
                            );
                            let outbound =
                                crate::protocol::ServerOutbound::Response(error_response);
                            if channels.outbound_tx.send(outbound).await.is_err() {
                                // Channel closed, server stopped
                                break;
                            }
                        }
                    }
                }
                Err(e) => {
                    return Err(ServerError::Io(e));
                }
            }

            // Check if server is still running
            if stdio_server.server.status() != ServerStatus::Running {
                break;
            }
        }

        // Stop the server
        server.stop();

        // Wait for stdout writer to finish
        let _ = stdout_handle.await;

        Ok(())
    }

    /// Returns a reference to the underlying `McpServer`.
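    ///
    /// # Example
    ///
    /// ```ignore
    /// // Hypothetical: given an `McpStdioServer` value `stdio_server`, the
    /// // wrapped server can be inspected, e.g. for its current status.
    /// let status = stdio_server.server().status();
    /// ```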
    pub fn server(&self) -> &Arc<McpServer> {
        &self.server
    }
}

#[cfg(test)]
mod tests {
    use crate::protocol::{JsonRpcId, ServerOutbound};
    use tokio::sync::mpsc;

    #[test]
    fn test_stdio_server_module_exists() {
        // Basic module existence test
        // Full integration tests would require stdin/stdout mocking
    }
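
    /// Sketch of a parse-error check mirroring the error path in
    /// `McpStdioServer::run`: malformed input should fail to parse, which the
    /// run loop then maps to a -32700 "Parse error" response. Assumes
    /// `JsonRpcMessage::parse` rejects non-JSON input.
    #[test]
    fn test_malformed_line_fails_to_parse() {
        let result = crate::protocol::JsonRpcMessage::parse("this is not json");
        assert!(result.is_err(), "malformed input should not parse");
    }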

    /// Test that all outbound messages are synchronized through a single channel.
    ///
    /// This test verifies that:
    /// 1. Parse errors are routed through the outbound channel (not directly to stdout)
    /// 2. Messages sent concurrently from multiple tasks all arrive complete through the channel
    /// 3. No message interleaving can occur because there's only one writer
    #[tokio::test]
    async fn test_outbound_message_synchronization() {
        // Create a channel to simulate the outbound message flow
        let (outbound_tx, mut outbound_rx) = mpsc::channel::<ServerOutbound>(256);

        // Simulate sending multiple messages concurrently
        let tx1 = outbound_tx.clone();
        let tx2 = outbound_tx.clone();
        let tx3 = outbound_tx.clone();

        // Spawn tasks that send messages "simultaneously"
        let handles = vec![
            tokio::spawn(async move {
                for i in 0..10 {
                    let response = crate::protocol::JsonRpcResponse::success(
                        JsonRpcId::Number(i),
                        serde_json::json!({"msg": format!("response_{}", i)}),
                    );
                    tx1.send(ServerOutbound::Response(response)).await.unwrap();
                }
            }),
            tokio::spawn(async move {
                for i in 10..20 {
                    let response = crate::protocol::JsonRpcResponse::error(
                        JsonRpcId::Number(i),
                        -32700,
                        format!("Parse error {}", i),
                        None,
                    );
                    tx2.send(ServerOutbound::Response(response)).await.unwrap();
                }
            }),
            tokio::spawn(async move {
                for i in 20..30 {
                    let notification =
                        crate::protocol::JsonRpcNotification::new(format!("notify_{}", i), None);
                    tx3.send(ServerOutbound::Notification(notification))
                        .await
                        .unwrap();
                }
            }),
        ];

        // Wait for all senders to complete
        for handle in handles {
            handle.await.unwrap();
        }

        // Drop the original sender so the channel closes
        drop(outbound_tx);

        // Collect all messages - they should be complete (not interleaved)
        let mut messages = Vec::new();
        while let Some(msg) = outbound_rx.recv().await {
            let json = msg.to_json().unwrap();
            // Verify each message is valid JSON (not corrupted by interleaving)
            let parsed: serde_json::Value = serde_json::from_str(&json)
                .expect("Each message should be valid JSON - no interleaving");
            messages.push(parsed);
        }

        // We should have received all 30 messages
        assert_eq!(messages.len(), 30, "All messages should be received");

        // Verify message integrity - each should be a complete, valid JSON-RPC message
        for msg in &messages {
            assert!(
                msg.get("jsonrpc").is_some(),
                "Each message should have jsonrpc field"
            );
        }
    }

    /// Test that the single-writer pattern prevents interleaving.
    ///
    /// By using a single channel receiver that writes to output, we guarantee
    /// that messages are written atomically one at a time.
    #[tokio::test]
    async fn test_single_writer_pattern() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let (outbound_tx, mut outbound_rx) = mpsc::channel::<ServerOutbound>(256);
        let write_count = Arc::new(AtomicUsize::new(0));
        let concurrent_writes = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        // Simulate the single writer task
        let write_count_clone = Arc::clone(&write_count);
        let concurrent_clone = Arc::clone(&concurrent_writes);
        let max_clone = Arc::clone(&max_concurrent);

        let writer_handle = tokio::spawn(async move {
            while let Some(outbound) = outbound_rx.recv().await {
                // Track concurrent writes
                let current = concurrent_clone.fetch_add(1, Ordering::SeqCst) + 1;

                // Update max concurrent if this is higher
                let mut max = max_clone.load(Ordering::SeqCst);
                while current > max {
                    match max_clone.compare_exchange(
                        max,
                        current,
                        Ordering::SeqCst,
                        Ordering::SeqCst,
                    ) {
                        Ok(_) => break,
                        Err(m) => max = m,
                    }
                }

                // Simulate write operation
                let _json = outbound.to_json().unwrap();

                // Small delay to increase chance of detecting concurrency issues
                tokio::task::yield_now().await;

                write_count_clone.fetch_add(1, Ordering::SeqCst);
                concurrent_clone.fetch_sub(1, Ordering::SeqCst);
            }
        });

        // Send messages from multiple tasks
        let mut send_handles = Vec::new();
        for batch in 0..5 {
            let tx = outbound_tx.clone();
            send_handles.push(tokio::spawn(async move {
                for i in 0..10 {
                    let response = crate::protocol::JsonRpcResponse::success(
                        JsonRpcId::Number(batch * 10 + i),
                        serde_json::json!({}),
                    );
                    tx.send(ServerOutbound::Response(response)).await.unwrap();
                }
            }));
        }

        // Wait for all senders
        for handle in send_handles {
            handle.await.unwrap();
        }
        drop(outbound_tx);

        // Wait for writer to finish
        writer_handle.await.unwrap();

        // Verify all messages were written
        assert_eq!(write_count.load(Ordering::SeqCst), 50);

        // A single task drains the channel, so at most one write is ever in
        // flight at a time.
        assert!(
            max_concurrent.load(Ordering::SeqCst) <= 1,
            "Single writer should never have more than 1 concurrent write"
        );
    }
}