pulseengine_mcp_transport/
stdio.rs1use crate::{
4    batch::{create_error_response, process_batch, JsonRpcMessage},
5    validation::{extract_id_from_malformed, validate_message_string},
6    RequestHandler, Transport, TransportError,
7};
8use async_trait::async_trait;
9use pulseengine_mcp_protocol::Response;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tracing::{debug, error, info, warn};
13
/// Settings for [`StdioTransport`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StdioConfig {
    /// Maximum size of a single message, in bytes.
    ///
    /// Passed to `validate_message_string` for every incoming line and every outgoing line.
    /// It is only enforced when `validate_messages` is `true`.
    pub max_message_size: usize,
    /// Whether to run `validate_message_string` on incoming lines before parsing them and on
    /// outgoing lines before writing them.
    pub validate_messages: bool,
}

impl Default for StdioConfig {
    /// Returns a 10 MiB message limit with validation enabled.
    fn default() -> Self {
        Self {
            max_message_size: 10 * 1024 * 1024,
            validate_messages: true,
        }
    }
}
31
32pub struct StdioTransport {
41    running: Arc<std::sync::atomic::AtomicBool>,
42    config: StdioConfig,
43}
44
45impl StdioTransport {
46    pub fn new() -> Self {
48        Self {
49            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
50            config: StdioConfig::default(),
51        }
52    }
53
54    pub fn with_config(config: StdioConfig) -> Self {
56        Self {
57            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
58            config,
59        }
60    }
61
62    async fn process_line(
64        &self,
65        line: &str,
66        handler: &RequestHandler,
67        stdout: &mut tokio::io::Stdout,
68    ) -> Result<(), TransportError> {
69        if self.config.validate_messages {
71            if let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) {
72                warn!("Message validation failed: {}", e);
73
74                let request_id = extract_id_from_malformed(line);
76                let error_response = create_error_response(
77                    pulseengine_mcp_protocol::Error::invalid_request(format!(
78                        "Message validation failed: {e}"
79                    )),
80                    request_id,
81                );
82
83                self.send_response(stdout, &error_response).await?;
84                return Ok(());
85            }
86        }
87
88        debug!("Processing message: {}", line);
89
90        let message = match JsonRpcMessage::parse(line) {
92            Ok(msg) => msg,
93            Err(e) => {
94                error!("Failed to parse JSON: {}", e);
95
96                let request_id = extract_id_from_malformed(line);
98                let error_response = create_error_response(
99                    pulseengine_mcp_protocol::Error::parse_error(format!("Invalid JSON: {e}")),
100                    request_id,
101                );
102
103                self.send_response(stdout, &error_response).await?;
104                return Ok(());
105            }
106        };
107
108        if let Err(e) = message.validate() {
110            warn!("JSON-RPC validation failed: {}", e);
111
112            let error_response = create_error_response(
114                pulseengine_mcp_protocol::Error::invalid_request(format!("Invalid JSON-RPC: {e}")),
115                serde_json::Value::Null,
116            );
117
118            self.send_response(stdout, &error_response).await?;
119            return Ok(());
120        }
121
122        match process_batch(message, handler).await {
124            Ok(Some(response_message)) => {
125                let response_json = response_message.to_string().map_err(|e| {
127                    TransportError::Protocol(format!("Failed to serialize response: {e}"))
128                })?;
129
130                self.send_line(stdout, &response_json).await?;
131            }
132            Ok(None) => {
133                debug!("No response needed for message");
135            }
136            Err(e) => {
137                error!("Failed to process message: {}", e);
138
139                let error_response = create_error_response(
141                    pulseengine_mcp_protocol::Error::internal_error(format!(
142                        "Processing failed: {e}"
143                    )),
144                    serde_json::Value::Null,
145                );
146
147                self.send_response(stdout, &error_response).await?;
148            }
149        }
150
151        Ok(())
152    }
153
154    async fn send_response(
156        &self,
157        stdout: &mut tokio::io::Stdout,
158        response: &Response,
159    ) -> Result<(), TransportError> {
160        let response_json = serde_json::to_string(response)
161            .map_err(|e| TransportError::Protocol(format!("Failed to serialize response: {e}")))?;
162
163        self.send_line(stdout, &response_json).await
164    }
165
166    async fn send_line(
168        &self,
169        stdout: &mut tokio::io::Stdout,
170        line: &str,
171    ) -> Result<(), TransportError> {
172        if self.config.validate_messages {
174            if let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) {
175                return Err(TransportError::Protocol(format!(
176                    "Outgoing message validation failed: {e}"
177                )));
178            }
179        }
180
181        debug!("Sending response: {}", line);
182
183        let line_with_newline = format!("{line}\n");
185
186        if let Err(e) = stdout.write_all(line_with_newline.as_bytes()).await {
187            return Err(TransportError::Connection(format!(
188                "Failed to write to stdout: {e}"
189            )));
190        }
191
192        if let Err(e) = stdout.flush().await {
193            return Err(TransportError::Connection(format!(
194                "Failed to flush stdout: {e}"
195            )));
196        }
197
198        Ok(())
199    }
200}
201
202impl Default for StdioTransport {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208#[async_trait]
209impl Transport for StdioTransport {
210    async fn start(&mut self, handler: RequestHandler) -> Result<(), TransportError> {
211        info!("Starting MCP-compliant stdio transport");
212        info!("Max message size: {} bytes", self.config.max_message_size);
213        info!("Message validation: {}", self.config.validate_messages);
214
215        self.running
216            .store(true, std::sync::atomic::Ordering::Relaxed);
217
218        let stdin = tokio::io::stdin();
219        let mut stdout = tokio::io::stdout();
220        let mut reader = BufReader::new(stdin);
221        let mut line = String::new();
222
223        while self.running.load(std::sync::atomic::Ordering::Relaxed) {
224            line.clear();
225
226            match reader.read_line(&mut line).await {
227                Ok(0) => {
228                    debug!("EOF reached, stopping stdio transport");
229                    break;
230                }
231                Ok(_) => {
232                    let trimmed_line = line.trim_end_matches(['\n', '\r']);
234
235                    if trimmed_line.is_empty() {
237                        continue;
238                    }
239
240                    if let Err(e) = self.process_line(trimmed_line, &handler, &mut stdout).await {
242                        error!("Failed to process line: {}", e);
243                        }
245                }
246                Err(e) => {
247                    error!("Failed to read from stdin: {}", e);
248                    return Err(TransportError::Connection(format!("Stdin read error: {e}")));
249                }
250            }
251        }
252
253        info!("Stdio transport stopped");
254        Ok(())
255    }
256
257    async fn stop(&mut self) -> Result<(), TransportError> {
258        info!("Stopping stdio transport");
259        self.running
260            .store(false, std::sync::atomic::Ordering::Relaxed);
261        Ok(())
262    }
263
264    async fn health_check(&self) -> Result<(), TransportError> {
265        if self.running.load(std::sync::atomic::Ordering::Relaxed) {
266            Ok(())
267        } else {
268            Err(TransportError::Connection(
269                "Transport not running".to_string(),
270            ))
271        }
272    }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
    use serde_json::json;

    /// Echo handler with the `RequestHandler` shape.
    ///
    /// `error_method` yields a method-not-found error. Any other method is echoed back as
    /// `{"echo": <method>}`.
    fn mock_handler(
        request: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> {
        Box::pin(async move {
            if request.method == "error_method" {
                Response {
                    jsonrpc: "2.0".to_string(),
                    id: request.id,
                    result: None,
                    error: Some(McpError::method_not_found("Method not found")),
                }
            } else {
                Response {
                    jsonrpc: "2.0".to_string(),
                    id: request.id,
                    result: Some(json!({"echo": request.method})),
                    error: None,
                }
            }
        })
    }

    #[test]
    fn test_stdio_config() {
        let config = StdioConfig {
            max_message_size: 1024,
            validate_messages: true,
        };

        let transport = StdioTransport::with_config(config.clone());
        assert_eq!(transport.config.max_message_size, 1024);
        assert!(transport.config.validate_messages);
    }

    #[test]
    fn test_message_validation() {
        // Compile-time check that the mock handler fits the `RequestHandler` signature.
        let _handler: RequestHandler = Box::new(mock_handler);

        // Contains a raw (unescaped) newline inside the method string.
        let invalid_line = "{\"jsonrpc\": \"2.0\", \"method\": \"test\n\", \"id\": 1}";
        assert!(validate_message_string(invalid_line, Some(1024)).is_err());

        let valid_line = r#"{"jsonrpc": "2.0", "method": "test", "id": 1}"#;
        assert!(validate_message_string(valid_line, Some(1024)).is_ok());
        // The same well-formed line must be rejected once it exceeds the size limit.
        assert!(validate_message_string(valid_line, Some(8)).is_err());
    }

    #[test]
    fn test_extract_id_from_malformed() {
        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": 123}"#;
        let id = extract_id_from_malformed(text);
        assert_eq!(id, json!(123));

        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": "abc"}"#;
        let id = extract_id_from_malformed(text);
        assert_eq!(id, json!("abc"));

        // Truncated JSON (missing closing brace): the id must still be recovered.
        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": 456"#;
        let id = extract_id_from_malformed(text);
        assert_eq!(id, json!(456));

        let text = r#"{"jsonrpc": "2.0", "method": "test"}"#;
        let id = extract_id_from_malformed(text);
        assert_eq!(id, serde_json::Value::Null);
    }

    #[test]
    fn test_default_config() {
        let config = StdioConfig::default();
        assert_eq!(config.max_message_size, 10 * 1024 * 1024);
        assert!(config.validate_messages);
    }

    #[test]
    fn test_default_transport_matches_new() {
        let transport = StdioTransport::default();
        assert_eq!(
            transport.config.max_message_size,
            StdioConfig::default().max_message_size
        );
        assert_eq!(
            transport.config.validate_messages,
            StdioConfig::default().validate_messages
        );
        assert!(!transport
            .running
            .load(std::sync::atomic::Ordering::Relaxed));
    }

    #[tokio::test]
    async fn test_health_check() {
        let mut transport = StdioTransport::new();

        assert!(transport.health_check().await.is_err());

        transport
            .running
            .store(true, std::sync::atomic::Ordering::Relaxed);
        assert!(transport.health_check().await.is_ok());

        transport.stop().await.expect("stop does not fail");
        assert!(transport.health_check().await.is_err());
    }
}