pulseengine_mcp_transport/
stdio.rs

1//! MCP-compliant Standard I/O transport implementation
2
3use crate::{
4    RequestHandler, Transport, TransportError,
5    batch::{JsonRpcMessage, create_error_response, process_batch},
6    validation::{extract_id_from_malformed, validate_message_string},
7};
8use async_trait::async_trait;
9use pulseengine_mcp_protocol::Response;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tracing::{debug, error, info, warn};
13
/// Tuning knobs for [`StdioTransport`].
#[derive(Debug, Clone)]
pub struct StdioConfig {
    /// Upper bound, in bytes, on a single message (defaults to 10 MiB).
    pub max_message_size: usize,
    /// When `true`, incoming lines and outgoing responses are checked with
    /// `validate_message_string` before being processed or written.
    pub validate_messages: bool,
}

impl StdioConfig {
    /// Size limit used by [`StdioConfig::default`]: 10 MiB.
    const DEFAULT_MAX_MESSAGE_SIZE: usize = 10 * 1024 * 1024;
}

impl Default for StdioConfig {
    /// A 10 MiB message limit with validation switched on.
    fn default() -> Self {
        let max_message_size = Self::DEFAULT_MAX_MESSAGE_SIZE;
        Self {
            max_message_size,
            validate_messages: true,
        }
    }
}
31
/// Standard I/O transport for MCP protocol
///
/// Implements the MCP stdio transport specification:
/// - Messages are delimited by newlines
/// - Messages MUST NOT contain embedded newlines
/// - Messages must be valid UTF-8
/// - Supports JSON-RPC batching
/// - Error responses echo the request ID whenever it can be recovered from
///   the offending message
#[derive(Debug)]
pub struct StdioTransport {
    /// Set to `true` when `start` begins and cleared by `stop`; the read loop
    /// keeps iterating only while this flag is set.
    running: Arc<std::sync::atomic::AtomicBool>,
    /// Size limit and validation switch applied to incoming and outgoing messages.
    config: StdioConfig,
}
45
46impl StdioTransport {
47    /// Create a new stdio transport with default configuration
48    pub fn new() -> Self {
49        Self {
50            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
51            config: StdioConfig::default(),
52        }
53    }
54
55    /// Create a new stdio transport with custom configuration
56    pub fn with_config(config: StdioConfig) -> Self {
57        Self {
58            running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
59            config,
60        }
61    }
62
63    /// Get the configuration
64    pub fn config(&self) -> &StdioConfig {
65        &self.config
66    }
67
68    /// Check if the transport is running
69    pub fn is_running(&self) -> bool {
70        self.running.load(std::sync::atomic::Ordering::Relaxed)
71    }
72
73    /// Set running state (for testing purposes)
74    #[cfg(test)]
75    pub fn set_running(&self, running: bool) {
76        self.running
77            .store(running, std::sync::atomic::Ordering::Relaxed);
78    }
79
80    /// Process a single line from stdin
81    async fn process_line(
82        &self,
83        line: &str,
84        handler: &RequestHandler,
85        stdout: &mut tokio::io::Stdout,
86    ) -> Result<(), TransportError> {
87        // Validate message according to MCP spec
88        if self.config.validate_messages {
89            if let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) {
90                warn!("Message validation failed: {}", e);
91
92                // Try to extract ID for error response
93                let request_id = extract_id_from_malformed(line);
94                let error_response = create_error_response(
95                    pulseengine_mcp_protocol::Error::invalid_request(format!(
96                        "Message validation failed: {e}"
97                    )),
98                    request_id,
99                );
100
101                self.send_response(stdout, &error_response).await?;
102                return Ok(());
103            }
104        }
105
106        debug!("Processing message: {}", line);
107
108        // Parse JSON-RPC message (single or batch)
109        let message = match JsonRpcMessage::parse(line) {
110            Ok(msg) => msg,
111            Err(e) => {
112                error!("Failed to parse JSON: {}", e);
113
114                // Try to extract ID for error response
115                let request_id = extract_id_from_malformed(line);
116                let error_response = create_error_response(
117                    pulseengine_mcp_protocol::Error::parse_error(format!("Invalid JSON: {e}")),
118                    request_id,
119                );
120
121                self.send_response(stdout, &error_response).await?;
122                return Ok(());
123            }
124        };
125
126        // Validate JSON-RPC structure
127        if let Err(e) = message.validate() {
128            warn!("JSON-RPC validation failed: {}", e);
129
130            // For invalid structure, we can't reliably extract ID, use None
131            let error_response = create_error_response(
132                pulseengine_mcp_protocol::Error::invalid_request(format!("Invalid JSON-RPC: {e}")),
133                None,
134            );
135
136            self.send_response(stdout, &error_response).await?;
137            return Ok(());
138        }
139
140        // Process the message (handles both single and batch)
141        match process_batch(message, handler).await {
142            Ok(Some(response_message)) => {
143                // Send response(s)
144                let response_json = response_message.to_string().map_err(|e| {
145                    TransportError::Protocol(format!("Failed to serialize response: {e}"))
146                })?;
147
148                self.send_line(stdout, &response_json).await?;
149            }
150            Ok(None) => {
151                // No response needed (notifications only)
152                debug!("No response needed for message");
153            }
154            Err(e) => {
155                error!("Failed to process message: {}", e);
156
157                // Send generic error response
158                let error_response = create_error_response(
159                    pulseengine_mcp_protocol::Error::internal_error(format!(
160                        "Processing failed: {e}"
161                    )),
162                    None,
163                );
164
165                self.send_response(stdout, &error_response).await?;
166            }
167        }
168
169        Ok(())
170    }
171
172    /// Send a response to stdout
173    async fn send_response(
174        &self,
175        stdout: &mut tokio::io::Stdout,
176        response: &Response,
177    ) -> Result<(), TransportError> {
178        let response_json = serde_json::to_string(response)
179            .map_err(|e| TransportError::Protocol(format!("Failed to serialize response: {e}")))?;
180
181        self.send_line(stdout, &response_json).await
182    }
183
184    /// Send a line to stdout with proper newline handling
185    async fn send_line(
186        &self,
187        stdout: &mut tokio::io::Stdout,
188        line: &str,
189    ) -> Result<(), TransportError> {
190        // Validate outgoing message
191        if self.config.validate_messages {
192            if let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) {
193                return Err(TransportError::Protocol(format!(
194                    "Outgoing message validation failed: {e}"
195                )));
196            }
197        }
198
199        debug!("Sending response: {}", line);
200
201        // Write with newline
202        let line_with_newline = format!("{line}\n");
203
204        if let Err(e) = stdout.write_all(line_with_newline.as_bytes()).await {
205            return Err(TransportError::Connection(format!(
206                "Failed to write to stdout: {e}"
207            )));
208        }
209
210        if let Err(e) = stdout.flush().await {
211            return Err(TransportError::Connection(format!(
212                "Failed to flush stdout: {e}"
213            )));
214        }
215
216        Ok(())
217    }
218}
219
220impl Default for StdioTransport {
221    fn default() -> Self {
222        Self::new()
223    }
224}
225
226#[async_trait]
227impl Transport for StdioTransport {
228    async fn start(&mut self, handler: RequestHandler) -> Result<(), TransportError> {
229        info!("Starting MCP-compliant stdio transport");
230        info!("Max message size: {} bytes", self.config.max_message_size);
231        info!("Message validation: {}", self.config.validate_messages);
232
233        self.running
234            .store(true, std::sync::atomic::Ordering::Relaxed);
235
236        let stdin = tokio::io::stdin();
237        let mut stdout = tokio::io::stdout();
238        let mut reader = BufReader::new(stdin);
239        let mut line = String::new();
240
241        while self.running.load(std::sync::atomic::Ordering::Relaxed) {
242            line.clear();
243
244            match reader.read_line(&mut line).await {
245                Ok(0) => {
246                    debug!("EOF reached, stopping stdio transport");
247                    break;
248                }
249                Ok(_) => {
250                    // Remove trailing newline for processing
251                    let trimmed_line = line.trim_end_matches(['\n', '\r']);
252
253                    // Skip empty lines
254                    if trimmed_line.is_empty() {
255                        continue;
256                    }
257
258                    // Process the line
259                    if let Err(e) = self.process_line(trimmed_line, &handler, &mut stdout).await {
260                        error!("Failed to process line: {}", e);
261                        // Continue processing other messages
262                    }
263                }
264                Err(e) => {
265                    error!("Failed to read from stdin: {}", e);
266                    return Err(TransportError::Connection(format!("Stdin read error: {e}")));
267                }
268            }
269        }
270
271        info!("Stdio transport stopped");
272        Ok(())
273    }
274
275    async fn stop(&mut self) -> Result<(), TransportError> {
276        info!("Stopping stdio transport");
277        self.running
278            .store(false, std::sync::atomic::Ordering::Relaxed);
279        Ok(())
280    }
281
282    async fn health_check(&self) -> Result<(), TransportError> {
283        if self.running.load(std::sync::atomic::Ordering::Relaxed) {
284            Ok(())
285        } else {
286            Err(TransportError::Connection(
287                "Transport not running".to_string(),
288            ))
289        }
290    }
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use pulseengine_mcp_protocol::{Error as McpError, Request, Response};
    use serde_json::json;

    // Mock handler for testing: echoes the method name, or returns
    // "method not found" for `error_method`.
    fn mock_handler(
        request: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> {
        Box::pin(async move {
            if request.method == "error_method" {
                Response {
                    jsonrpc: "2.0".to_string(),
                    id: request.id,
                    result: None,
                    error: Some(McpError::method_not_found("Method not found")),
                }
            } else {
                Response {
                    jsonrpc: "2.0".to_string(),
                    id: request.id,
                    result: Some(json!({"echo": request.method})),
                    error: None,
                }
            }
        })
    }

    #[test]
    fn test_stdio_config() {
        let config = StdioConfig {
            max_message_size: 1024,
            validate_messages: true,
        };

        let transport = StdioTransport::with_config(config.clone());
        assert_eq!(transport.config.max_message_size, 1024);
        assert!(transport.config.validate_messages);
    }

    #[test]
    fn test_message_validation() {
        // An embedded newline would split the message on the wire, so the
        // stdio validator must reject it.
        let invalid_line = "{\"jsonrpc\": \"2.0\", \"method\": \"test\n\", \"id\": 1}";
        assert!(validate_message_string(invalid_line, Some(1024)).is_err());

        // The same message without the newline is accepted.
        let valid_line = r#"{"jsonrpc": "2.0", "method": "test", "id": 1}"#;
        assert!(validate_message_string(valid_line, Some(1024)).is_ok());
    }

    #[test]
    fn test_extract_id_from_malformed() {
        // Test valid JSON with ID
        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": 123}"#;
        let id = extract_id_from_malformed(text);
        assert_eq!(
            id,
            Some(pulseengine_mcp_protocol::NumberOrString::Number(123))
        );

        // Test string ID
        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": "abc"}"#;
        let id = extract_id_from_malformed(text);
        assert_eq!(
            id,
            Some(pulseengine_mcp_protocol::NumberOrString::String(
                std::sync::Arc::from("abc")
            ))
        );

        // Test malformed JSON
        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": 456"#; // Missing closing brace
        let id = extract_id_from_malformed(text);
        assert_eq!(
            id,
            Some(pulseengine_mcp_protocol::NumberOrString::Number(456))
        );

        // Test no ID
        let text = r#"{"jsonrpc": "2.0", "method": "test"}"#;
        let id = extract_id_from_malformed(text);
        assert_eq!(id, None);
    }

    #[test]
    fn test_default_config() {
        let config = StdioConfig::default();
        assert_eq!(config.max_message_size, 10 * 1024 * 1024);
        assert!(config.validate_messages);
    }

    #[tokio::test]
    async fn test_health_check() {
        let transport = StdioTransport::new();

        // Initially not running
        assert!(transport.health_check().await.is_err());

        // Set as running
        transport.set_running(true);
        assert!(transport.health_check().await.is_ok());
    }

    #[test]
    fn test_transport_creation() {
        let transport = StdioTransport::new();
        assert!(!transport.is_running());
        assert_eq!(transport.config().max_message_size, 10 * 1024 * 1024);
        assert!(transport.config().validate_messages);
    }

    #[test]
    fn test_transport_with_custom_config() {
        let config = StdioConfig {
            max_message_size: 2048,
            validate_messages: false,
        };
        let transport = StdioTransport::with_config(config);

        assert!(!transport.is_running());
        assert_eq!(transport.config().max_message_size, 2048);
        assert!(!transport.config().validate_messages);
    }

    #[test]
    fn test_default_transport() {
        let transport = StdioTransport::default();
        assert!(!transport.is_running());
        assert_eq!(transport.config().max_message_size, 10 * 1024 * 1024);
        assert!(transport.config().validate_messages);
    }

    #[test]
    fn test_running_state() {
        let transport = StdioTransport::new();

        // Initially not running
        assert!(!transport.is_running());

        // Set running
        transport.set_running(true);
        assert!(transport.is_running());

        // Set not running
        transport.set_running(false);
        assert!(!transport.is_running());
    }

    #[tokio::test]
    async fn test_stop_transport() {
        let mut transport = StdioTransport::new();

        // Set as running first
        transport.set_running(true);
        assert!(transport.is_running());

        // Stop the transport
        assert!(transport.stop().await.is_ok());
        assert!(!transport.is_running());
    }

    #[test]
    fn test_stdio_config_clone() {
        let config1 = StdioConfig {
            max_message_size: 1024,
            validate_messages: true,
        };

        let config2 = config1.clone();
        assert_eq!(config1.max_message_size, config2.max_message_size);
        assert_eq!(config1.validate_messages, config2.validate_messages);
    }

    #[test]
    fn test_config_debug() {
        let config = StdioConfig::default();
        let debug_str = format!("{config:?}");
        assert!(debug_str.contains("StdioConfig"));
        assert!(debug_str.contains("max_message_size"));
        assert!(debug_str.contains("validate_messages"));
    }

    #[test]
    fn test_transport_debug() {
        let transport = StdioTransport::new();
        let debug_str = format!("{transport:?}");
        assert!(debug_str.contains("StdioTransport"));
        assert!(debug_str.contains("running"));
        assert!(debug_str.contains("config"));
    }

    #[test]
    fn test_message_size_validation() {
        let config = StdioConfig {
            max_message_size: 50, // Very small for testing
            validate_messages: true,
        };
        let transport = StdioTransport::with_config(config);
        // Use the transport's own limit rather than repeating the literal.
        let limit = Some(transport.config().max_message_size);

        // Large message should fail validation
        let large_message = "x".repeat(100);
        assert!(validate_message_string(&large_message, limit).is_err());

        // Small message should pass
        let small_message = "x".repeat(10);
        assert!(validate_message_string(&small_message, limit).is_ok());
    }

    #[test]
    fn test_json_rpc_message_parsing() {
        // Valid JSON-RPC request
        let valid_msg = r#"{"jsonrpc": "2.0", "method": "test", "id": 1}"#;
        let parsed = JsonRpcMessage::parse(valid_msg);
        assert!(parsed.is_ok());

        // Invalid JSON
        let invalid_msg = r#"{"jsonrpc": "2.0", "method": "test""#; // Missing closing brace
        let parsed = JsonRpcMessage::parse(invalid_msg);
        assert!(parsed.is_err());
    }

    #[test]
    fn test_message_validation_edge_cases() {
        // Message with newline (should fail)
        let newline_msg = "line1\nline2";
        assert!(validate_message_string(newline_msg, Some(1024)).is_err());

        // Message with carriage return (should fail)
        let cr_msg = "line1\rline2";
        assert!(validate_message_string(cr_msg, Some(1024)).is_err());

        // Empty message (should pass)
        let empty_msg = "";
        assert!(validate_message_string(empty_msg, Some(1024)).is_ok());

        // Normal message (should pass)
        let normal_msg = "valid message";
        assert!(validate_message_string(normal_msg, Some(1024)).is_ok());
    }

    #[test]
    fn test_extract_id_edge_cases() {
        // Null ID
        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": null}"#;
        let id = extract_id_from_malformed(text);
        assert_eq!(id, None);

        // Boolean ID (not standard but should handle)
        let text = r#"{"jsonrpc": "2.0", "method": "test", "id": true}"#;
        let id = extract_id_from_malformed(text);
        assert_eq!(id, None);

        // Completely invalid JSON
        let text = "not json at all";
        let id = extract_id_from_malformed(text);
        assert_eq!(id, None);

        // Empty string
        let text = "";
        let id = extract_id_from_malformed(text);
        assert_eq!(id, None);
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            jsonrpc: "2.0".to_string(),
            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(1)),
            result: Some(json!({"status": "ok"})),
            error: None,
        };

        let serialized = serde_json::to_string(&response);
        assert!(serialized.is_ok());

        let json_str = serialized.unwrap();
        assert!(json_str.contains("jsonrpc"));
        assert!(json_str.contains("2.0"));
        assert!(json_str.contains("status"));
    }

    #[test]
    fn test_error_response_creation() {
        let error = McpError::invalid_request("Test error");
        let request_id = Some(pulseengine_mcp_protocol::NumberOrString::Number(42));

        let response = create_error_response(error, request_id);

        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(
            response.id,
            Some(pulseengine_mcp_protocol::NumberOrString::Number(42))
        );
        assert!(response.error.is_some());
        assert!(response.result.is_none());

        let error_obj = response.error.unwrap();
        assert!(error_obj.message.contains("Test error"));
    }

    #[tokio::test]
    async fn test_mock_handler_functionality() {
        // The mock must be usable wherever the transport expects a RequestHandler.
        let _as_request_handler: RequestHandler = Box::new(mock_handler);

        // Test normal method
        let request = Request {
            jsonrpc: "2.0".to_string(),
            method: "test_method".to_string(),
            params: json!({}),
            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(1)),
        };

        let response = mock_handler(request).await;
        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(
            response.id,
            Some(pulseengine_mcp_protocol::NumberOrString::Number(1))
        );
        assert!(response.result.is_some());
        assert!(response.error.is_none());

        // Test error method
        let error_request = Request {
            jsonrpc: "2.0".to_string(),
            method: "error_method".to_string(),
            params: json!({}),
            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(2)),
        };

        let error_response = mock_handler(error_request).await;
        assert_eq!(error_response.jsonrpc, "2.0");
        assert_eq!(
            error_response.id,
            Some(pulseengine_mcp_protocol::NumberOrString::Number(2))
        );
        assert!(error_response.result.is_none());
        assert!(error_response.error.is_some());
    }
}