model_context_protocol/server/stdio.rs
//! Stdio transport for the MCP server.
//!
//! This module provides `McpStdioServer`, which wraps the core `McpServer`
//! and handles stdin/stdout I/O.
//!
//! # Example
//!
//! ```ignore
//! use mcp::server::{McpServerConfig, stdio::McpStdioServer};
//!
//! let config = McpServerConfig::builder()
//!     .name("my-server")
//!     .version("1.0.0")
//!     .with_tool(MyTool)
//!     .build();
//!
//! McpStdioServer::run(config).await?;
//! ```
19
20use std::sync::Arc;
21
22use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
23
24use super::{McpServer, McpServerConfig, ServerError, ServerStatus};
25use crate::protocol::JsonRpcMessage;
26
27/// MCP Server with stdio transport.
28///
29/// This server reads JSON-RPC messages from stdin and writes responses
30/// to stdout. It wraps the core `McpServer` and bridges stdio I/O to
31/// the internal channel-based communication.
32pub struct McpStdioServer {
33 server: Arc<McpServer>,
34}
35
36impl McpStdioServer {
37 /// Runs an MCP server with stdio transport.
38 ///
39 /// This is the main entry point for running a stdio-based MCP server.
40 /// The function blocks until the server stops (stdin closed or error).
41 ///
42 /// # Example
43 ///
44 /// ```ignore
45 /// let config = McpServerConfig::builder()
46 /// .name("my-server")
47 /// .version("1.0.0")
48 /// .with_tool(MyTool)
49 /// .build();
50 ///
51 /// McpStdioServer::run(config).await?;
52 /// ```
53 pub async fn run(config: McpServerConfig) -> Result<(), ServerError> {
54 let (server, mut channels) = McpServer::new(config);
55
56 let stdio_server = Self {
57 server: Arc::clone(&server),
58 };
59
60 // Spawn stdout writer task
61 let stdout_handle = tokio::spawn(async move {
62 let mut stdout = tokio::io::stdout();
63
64 while let Some(outbound) = channels.outbound_rx.recv().await {
65 let json = match outbound.to_json() {
66 Ok(j) => j,
67 Err(e) => {
68 eprintln!("Failed to serialize outbound message: {}", e);
69 continue;
70 }
71 };
72
73 if let Err(e) = stdout.write_all(json.as_bytes()).await {
74 eprintln!("Failed to write to stdout: {}", e);
75 break;
76 }
77 if let Err(e) = stdout.write_all(b"\n").await {
78 eprintln!("Failed to write newline to stdout: {}", e);
79 break;
80 }
81 if let Err(e) = stdout.flush().await {
82 eprintln!("Failed to flush stdout: {}", e);
83 break;
84 }
85 }
86 });
87
88 // Run stdin reader in current task
89 let stdin = tokio::io::stdin();
90 let mut reader = BufReader::new(stdin);
91 let mut line = String::new();
92
93 loop {
94 line.clear();
95
96 match reader.read_line(&mut line).await {
97 Ok(0) => {
98 // EOF - stdin closed
99 break;
100 }
101 Ok(_) => {
102 let trimmed = line.trim();
103 if trimmed.is_empty() {
104 continue;
105 }
106
107 // Parse the incoming message
108 match JsonRpcMessage::parse(trimmed) {
109 Ok(message) => {
110 let inbound = message.into_client_inbound();
111 if channels.inbound_tx.send(inbound).await.is_err() {
112 // Server stopped
113 break;
114 }
115 }
116 Err(e) => {
117 // Send parse error response through the channel
118 // to ensure synchronization with other outbound messages
119 let error_response = crate::protocol::JsonRpcResponse::error(
120 crate::protocol::JsonRpcId::Null,
121 -32700,
122 format!("Parse error: {}", e),
123 None,
124 );
125 let outbound =
126 crate::protocol::ServerOutbound::Response(error_response);
127 if channels.outbound_tx.send(outbound).await.is_err() {
128 // Channel closed, server stopped
129 break;
130 }
131 }
132 }
133 }
134 Err(e) => {
135 return Err(ServerError::Io(e));
136 }
137 }
138
139 // Check if server is still running
140 if stdio_server.server.status() != ServerStatus::Running {
141 break;
142 }
143 }
144
145 // Stop the server
146 server.stop();
147
148 // Wait for stdout writer to finish
149 let _ = stdout_handle.await;
150
151 Ok(())
152 }
153
154 /// Returns the underlying server reference.
155 pub fn server(&self) -> &Arc<McpServer> {
156 &self.server
157 }
158}
159
#[cfg(test)]
mod tests {
    use crate::protocol::{JsonRpcId, ServerOutbound};
    use tokio::sync::mpsc;

    #[test]
    fn test_stdio_server_module_exists() {
        // Intentionally empty: this only proves the module compiles.
        // End-to-end coverage of `run` would need stdin/stdout mocking.
    }

    /// Checks that messages funnelled through one outbound channel arrive
    /// complete.
    ///
    /// Three tasks push success responses, parse-error responses and
    /// notifications into the same channel. Every received message must
    /// serialize to valid JSON carrying a `jsonrpc` field, and none may be
    /// lost.
    #[tokio::test]
    async fn test_outbound_message_synchronization() {
        // Stand-in for the server's outbound channel.
        let (outbound_tx, mut outbound_rx) = mpsc::channel::<ServerOutbound>(256);

        let success_tx = outbound_tx.clone();
        let error_tx = outbound_tx.clone();
        let notify_tx = outbound_tx.clone();

        // Producer 1: successful responses with ids 0..10.
        let success_task = tokio::spawn(async move {
            for i in 0..10 {
                let response = crate::protocol::JsonRpcResponse::success(
                    JsonRpcId::Number(i),
                    serde_json::json!({"msg": format!("response_{}", i)}),
                );
                success_tx
                    .send(ServerOutbound::Response(response))
                    .await
                    .unwrap();
            }
        });

        // Producer 2: parse-error responses with ids 10..20.
        let error_task = tokio::spawn(async move {
            for i in 10..20 {
                let response = crate::protocol::JsonRpcResponse::error(
                    JsonRpcId::Number(i),
                    -32700,
                    format!("Parse error {}", i),
                    None,
                );
                error_tx
                    .send(ServerOutbound::Response(response))
                    .await
                    .unwrap();
            }
        });

        // Producer 3: notifications notify_20..notify_29.
        let notify_task = tokio::spawn(async move {
            for i in 20..30 {
                let notification =
                    crate::protocol::JsonRpcNotification::new(format!("notify_{}", i), None);
                notify_tx
                    .send(ServerOutbound::Notification(notification))
                    .await
                    .unwrap();
            }
        });

        // Let every producer finish before draining.
        success_task.await.unwrap();
        error_task.await.unwrap();
        notify_task.await.unwrap();

        // With the last sender gone, `recv` returns `None` once drained.
        drop(outbound_tx);

        let mut received = Vec::new();
        while let Some(outbound) = outbound_rx.recv().await {
            let raw = outbound.to_json().unwrap();
            // Each message must stand alone as a valid JSON document.
            let value: serde_json::Value = serde_json::from_str(&raw)
                .expect("Each message should be valid JSON - no interleaving");
            received.push(value);
        }

        // 10 + 10 + 10 messages were sent.
        assert_eq!(received.len(), 30, "All messages should be received");

        // Every message must look like a JSON-RPC message.
        for value in &received {
            assert!(
                value.get("jsonrpc").is_some(),
                "Each message should have jsonrpc field"
            );
        }
    }

    /// Checks that a single receiver task handles messages one at a time.
    ///
    /// The simulated writer counts how many writes are in flight at once and
    /// records the peak; with one consumer that peak can never exceed one.
    #[tokio::test]
    async fn test_single_writer_pattern() {
        use std::sync::atomic::{AtomicUsize, Ordering};
        use std::sync::Arc;

        let (outbound_tx, mut outbound_rx) = mpsc::channel::<ServerOutbound>(256);
        let write_count = Arc::new(AtomicUsize::new(0));
        let concurrent_writes = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        // Handles moved into the writer task.
        let completed = Arc::clone(&write_count);
        let in_flight = Arc::clone(&concurrent_writes);
        let peak = Arc::clone(&max_concurrent);

        // The lone consumer, standing in for the stdout writer task.
        let writer_handle = tokio::spawn(async move {
            while let Some(outbound) = outbound_rx.recv().await {
                // Enter the "write" and record the highest in-flight count.
                let in_flight_now = in_flight.fetch_add(1, Ordering::SeqCst) + 1;
                peak.fetch_max(in_flight_now, Ordering::SeqCst);

                // Simulate the write itself.
                let _json = outbound.to_json().unwrap();

                // Yield so other tasks get a chance to run mid-"write".
                tokio::task::yield_now().await;

                completed.fetch_add(1, Ordering::SeqCst);
                in_flight.fetch_sub(1, Ordering::SeqCst);
            }
        });

        // Five producers, ten messages each.
        let send_handles: Vec<_> = (0..5)
            .map(|batch| {
                let tx = outbound_tx.clone();
                tokio::spawn(async move {
                    for i in 0..10 {
                        let response = crate::protocol::JsonRpcResponse::success(
                            JsonRpcId::Number(batch * 10 + i),
                            serde_json::json!({}),
                        );
                        tx.send(ServerOutbound::Response(response)).await.unwrap();
                    }
                })
            })
            .collect();

        // Wait for all producers, then close the channel.
        for handle in send_handles {
            handle.await.unwrap();
        }
        drop(outbound_tx);

        // The writer exits once the channel is drained and closed.
        writer_handle.await.unwrap();

        // Every message was processed.
        assert_eq!(write_count.load(Ordering::SeqCst), 50);

        // One consumer means at most one write in flight at any moment.
        assert!(
            max_concurrent.load(Ordering::SeqCst) <= 1,
            "Single writer should never have more than 1 concurrent write"
        );
    }
}
334}