pulseengine_mcp_transport/
stdio.rs1use crate::{
4 RequestHandler, Transport, TransportError,
5 batch::{JsonRpcMessage, create_error_response, process_batch},
6 validation::{extract_id_from_malformed, validate_message_string},
7};
8use async_trait::async_trait;
9use pulseengine_mcp_protocol::Response;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
12use tracing::{debug, error, info, warn};
13
/// Configuration for the stdio transport.
#[derive(Debug, Clone)]
pub struct StdioConfig {
    /// Size limit handed to `validate_message_string` for both incoming and
    /// outgoing lines. Only consulted when `validate_messages` is `true`.
    /// Presumably measured in bytes (the startup log labels it "bytes").
    pub max_message_size: usize,
    /// When `true`, every incoming and outgoing line is passed through
    /// `validate_message_string` before it is processed or written.
    pub validate_messages: bool,
}
22
23impl Default for StdioConfig {
24 fn default() -> Self {
25 Self {
26 max_message_size: 10 * 1024 * 1024, validate_messages: true,
28 }
29 }
30}
31
/// Transport that exchanges newline-delimited messages over the process's
/// standard input and standard output.
#[derive(Debug)]
pub struct StdioTransport {
    /// Run flag: set by `start`, cleared by `stop`, and read by the receive
    /// loop, `is_running` and `health_check`.
    running: Arc<std::sync::atomic::AtomicBool>,
    /// Message-size and validation settings applied to every line.
    config: StdioConfig,
}
45
46impl StdioTransport {
47 pub fn new() -> Self {
49 Self {
50 running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
51 config: StdioConfig::default(),
52 }
53 }
54
55 pub fn with_config(config: StdioConfig) -> Self {
57 Self {
58 running: Arc::new(std::sync::atomic::AtomicBool::new(false)),
59 config,
60 }
61 }
62
63 pub fn config(&self) -> &StdioConfig {
65 &self.config
66 }
67
68 pub fn is_running(&self) -> bool {
70 self.running.load(std::sync::atomic::Ordering::Relaxed)
71 }
72
73 #[cfg(test)]
75 pub fn set_running(&self, running: bool) {
76 self.running
77 .store(running, std::sync::atomic::Ordering::Relaxed);
78 }
79
80 async fn process_line(
82 &self,
83 line: &str,
84 handler: &RequestHandler,
85 stdout: &mut tokio::io::Stdout,
86 ) -> Result<(), TransportError> {
87 if self.config.validate_messages {
89 if let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) {
90 warn!("Message validation failed: {}", e);
91
92 let request_id = extract_id_from_malformed(line);
94 let error_response = create_error_response(
95 pulseengine_mcp_protocol::Error::invalid_request(format!(
96 "Message validation failed: {e}"
97 )),
98 request_id,
99 );
100
101 self.send_response(stdout, &error_response).await?;
102 return Ok(());
103 }
104 }
105
106 debug!("Processing message: {}", line);
107
108 let message = match JsonRpcMessage::parse(line) {
110 Ok(msg) => msg,
111 Err(e) => {
112 error!("Failed to parse JSON: {}", e);
113
114 let request_id = extract_id_from_malformed(line);
116 let error_response = create_error_response(
117 pulseengine_mcp_protocol::Error::parse_error(format!("Invalid JSON: {e}")),
118 request_id,
119 );
120
121 self.send_response(stdout, &error_response).await?;
122 return Ok(());
123 }
124 };
125
126 if let Err(e) = message.validate() {
128 warn!("JSON-RPC validation failed: {}", e);
129
130 let error_response = create_error_response(
132 pulseengine_mcp_protocol::Error::invalid_request(format!("Invalid JSON-RPC: {e}")),
133 None,
134 );
135
136 self.send_response(stdout, &error_response).await?;
137 return Ok(());
138 }
139
140 match process_batch(message, handler).await {
142 Ok(Some(response_message)) => {
143 let response_json = response_message.to_string().map_err(|e| {
145 TransportError::Protocol(format!("Failed to serialize response: {e}"))
146 })?;
147
148 self.send_line(stdout, &response_json).await?;
149 }
150 Ok(None) => {
151 debug!("No response needed for message");
153 }
154 Err(e) => {
155 error!("Failed to process message: {}", e);
156
157 let error_response = create_error_response(
159 pulseengine_mcp_protocol::Error::internal_error(format!(
160 "Processing failed: {e}"
161 )),
162 None,
163 );
164
165 self.send_response(stdout, &error_response).await?;
166 }
167 }
168
169 Ok(())
170 }
171
172 async fn send_response(
174 &self,
175 stdout: &mut tokio::io::Stdout,
176 response: &Response,
177 ) -> Result<(), TransportError> {
178 let response_json = serde_json::to_string(response)
179 .map_err(|e| TransportError::Protocol(format!("Failed to serialize response: {e}")))?;
180
181 self.send_line(stdout, &response_json).await
182 }
183
184 async fn send_line(
186 &self,
187 stdout: &mut tokio::io::Stdout,
188 line: &str,
189 ) -> Result<(), TransportError> {
190 if self.config.validate_messages {
192 if let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) {
193 return Err(TransportError::Protocol(format!(
194 "Outgoing message validation failed: {e}"
195 )));
196 }
197 }
198
199 debug!("Sending response: {}", line);
200
201 let line_with_newline = format!("{line}\n");
203
204 if let Err(e) = stdout.write_all(line_with_newline.as_bytes()).await {
205 return Err(TransportError::Connection(format!(
206 "Failed to write to stdout: {e}"
207 )));
208 }
209
210 if let Err(e) = stdout.flush().await {
211 return Err(TransportError::Connection(format!(
212 "Failed to flush stdout: {e}"
213 )));
214 }
215
216 Ok(())
217 }
218}
219
220impl Default for StdioTransport {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
226#[async_trait]
227impl Transport for StdioTransport {
228 async fn start(&mut self, handler: RequestHandler) -> Result<(), TransportError> {
229 info!("Starting MCP-compliant stdio transport");
230 info!("Max message size: {} bytes", self.config.max_message_size);
231 info!("Message validation: {}", self.config.validate_messages);
232
233 self.running
234 .store(true, std::sync::atomic::Ordering::Relaxed);
235
236 let stdin = tokio::io::stdin();
237 let mut stdout = tokio::io::stdout();
238 let mut reader = BufReader::new(stdin);
239 let mut line = String::new();
240
241 while self.running.load(std::sync::atomic::Ordering::Relaxed) {
242 line.clear();
243
244 match reader.read_line(&mut line).await {
245 Ok(0) => {
246 debug!("EOF reached, stopping stdio transport");
247 break;
248 }
249 Ok(_) => {
250 let trimmed_line = line.trim_end_matches(['\n', '\r']);
252
253 if trimmed_line.is_empty() {
255 continue;
256 }
257
258 if let Err(e) = self.process_line(trimmed_line, &handler, &mut stdout).await {
260 error!("Failed to process line: {}", e);
261 }
263 }
264 Err(e) => {
265 error!("Failed to read from stdin: {}", e);
266 return Err(TransportError::Connection(format!("Stdin read error: {e}")));
267 }
268 }
269 }
270
271 info!("Stdio transport stopped");
272 Ok(())
273 }
274
275 async fn stop(&mut self) -> Result<(), TransportError> {
276 info!("Stopping stdio transport");
277 self.running
278 .store(false, std::sync::atomic::Ordering::Relaxed);
279 Ok(())
280 }
281
282 async fn health_check(&self) -> Result<(), TransportError> {
283 if self.running.load(std::sync::atomic::Ordering::Relaxed) {
284 Ok(())
285 } else {
286 Err(TransportError::Connection(
287 "Transport not running".to_string(),
288 ))
289 }
290 }
291}
292
#[cfg(test)]
mod tests {
    //! Unit tests for the stdio transport's configuration, run-flag handling
    //! and the validation / parsing helpers it relies on.

    use super::*;
    use pulseengine_mcp_protocol::{Error as McpError, NumberOrString, Request, Response};
    use serde_json::json;

    /// Handler stub: answers `error_method` with a "method not found" error
    /// and echoes the method name for everything else.
    fn mock_handler(
        request: Request,
    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> {
        Box::pin(async move {
            let (result, error) = if request.method == "error_method" {
                (None, Some(McpError::method_not_found("Method not found")))
            } else {
                (Some(json!({"echo": request.method})), None)
            };

            Response {
                jsonrpc: "2.0".to_string(),
                id: request.id,
                result,
                error,
            }
        })
    }

    #[test]
    fn test_stdio_config() {
        let config = StdioConfig {
            max_message_size: 1024,
            validate_messages: true,
        };

        let transport = StdioTransport::with_config(config);
        assert_eq!(transport.config.max_message_size, 1024);
        assert!(transport.config.validate_messages);
    }

    #[test]
    fn test_message_validation() {
        // A raw newline embedded in the message must be rejected.
        let invalid_line = "{\"jsonrpc\": \"2.0\", \"method\": \"test\n\", \"id\": 1}";

        assert!(validate_message_string(invalid_line, Some(1024)).is_err());
    }

    #[test]
    fn test_extract_id_from_malformed() {
        let cases = vec![
            (
                r#"{"jsonrpc": "2.0", "method": "test", "id": 123}"#,
                Some(NumberOrString::Number(123)),
            ),
            (
                r#"{"jsonrpc": "2.0", "method": "test", "id": "abc"}"#,
                Some(NumberOrString::String(std::sync::Arc::from("abc"))),
            ),
            // Closing brace missing: the id must still be recovered.
            (
                r#"{"jsonrpc": "2.0", "method": "test", "id": 456"#,
                Some(NumberOrString::Number(456)),
            ),
            (r#"{"jsonrpc": "2.0", "method": "test"}"#, None),
        ];

        for (text, expected) in cases {
            assert_eq!(extract_id_from_malformed(text), expected, "input: {text}");
        }
    }

    #[test]
    fn test_default_config() {
        let config = StdioConfig::default();
        assert_eq!(config.max_message_size, 10 * 1024 * 1024);
        assert!(config.validate_messages);
    }

    #[tokio::test]
    async fn test_health_check() {
        let transport = StdioTransport::new();

        // A freshly built transport is not running.
        assert!(transport.health_check().await.is_err());

        transport.set_running(true);
        assert!(transport.health_check().await.is_ok());
    }

    #[test]
    fn test_transport_creation() {
        let transport = StdioTransport::new();
        assert!(!transport.is_running());
        assert_eq!(transport.config().max_message_size, 10 * 1024 * 1024);
        assert!(transport.config().validate_messages);
    }

    #[test]
    fn test_transport_with_custom_config() {
        let transport = StdioTransport::with_config(StdioConfig {
            max_message_size: 2048,
            validate_messages: false,
        });

        assert!(!transport.is_running());
        assert_eq!(transport.config().max_message_size, 2048);
        assert!(!transport.config().validate_messages);
    }

    #[test]
    fn test_default_transport() {
        let transport = StdioTransport::default();
        assert!(!transport.is_running());
        assert_eq!(transport.config().max_message_size, 10 * 1024 * 1024);
        assert!(transport.config().validate_messages);
    }

    #[test]
    fn test_running_state() {
        let transport = StdioTransport::new();
        assert!(!transport.is_running());

        transport.set_running(true);
        assert!(transport.is_running());

        transport.set_running(false);
        assert!(!transport.is_running());
    }

    #[tokio::test]
    async fn test_stop_transport() {
        let mut transport = StdioTransport::new();

        transport.set_running(true);
        assert!(transport.is_running());

        assert!(transport.stop().await.is_ok());
        assert!(!transport.is_running());
    }

    #[test]
    fn test_stdio_config_clone() {
        let original = StdioConfig {
            max_message_size: 1024,
            validate_messages: true,
        };

        let copy = original.clone();
        assert_eq!(original.max_message_size, copy.max_message_size);
        assert_eq!(original.validate_messages, copy.validate_messages);
    }

    #[test]
    fn test_config_debug() {
        let debug_str = format!("{:?}", StdioConfig::default());
        for needle in ["StdioConfig", "max_message_size", "validate_messages"].iter() {
            assert!(debug_str.contains(needle), "missing `{needle}` in {debug_str}");
        }
    }

    #[test]
    fn test_transport_debug() {
        let debug_str = format!("{:?}", StdioTransport::new());
        for needle in ["StdioTransport", "running", "config"].iter() {
            assert!(debug_str.contains(needle), "missing `{needle}` in {debug_str}");
        }
    }

    #[test]
    fn test_message_size_validation() {
        let limit = Some(50);

        let large_message = "x".repeat(100);
        assert!(validate_message_string(&large_message, limit).is_err());

        let small_message = "x".repeat(10);
        assert!(validate_message_string(&small_message, limit).is_ok());
    }

    #[test]
    fn test_json_rpc_message_parsing() {
        let valid_msg = r#"{"jsonrpc": "2.0", "method": "test", "id": 1}"#;
        assert!(JsonRpcMessage::parse(valid_msg).is_ok());

        // Closing brace missing.
        let invalid_msg = r#"{"jsonrpc": "2.0", "method": "test""#;
        assert!(JsonRpcMessage::parse(invalid_msg).is_err());
    }

    #[test]
    fn test_message_validation_edge_cases() {
        // (message, whether validation is expected to accept it)
        let cases = vec![
            ("line1\nline2", false),
            ("line1\rline2", false),
            ("", true),
            ("valid message", true),
        ];

        for (message, should_pass) in cases {
            assert_eq!(
                validate_message_string(message, Some(1024)).is_ok(),
                should_pass,
                "message: {message:?}"
            );
        }
    }

    #[test]
    fn test_extract_id_edge_cases() {
        let inputs = vec![
            r#"{"jsonrpc": "2.0", "method": "test", "id": null}"#,
            r#"{"jsonrpc": "2.0", "method": "test", "id": true}"#,
            "not json at all",
            "",
        ];

        for text in inputs {
            assert_eq!(extract_id_from_malformed(text), None, "input: {text:?}");
        }
    }

    #[test]
    fn test_response_serialization() {
        let response = Response {
            jsonrpc: "2.0".to_string(),
            id: Some(NumberOrString::Number(1)),
            result: Some(json!({"status": "ok"})),
            error: None,
        };

        let json_str =
            serde_json::to_string(&response).expect("a plain response must serialize");
        assert!(json_str.contains("jsonrpc"));
        assert!(json_str.contains("2.0"));
        assert!(json_str.contains("status"));
    }

    #[test]
    fn test_error_response_creation() {
        let response = create_error_response(
            McpError::invalid_request("Test error"),
            Some(NumberOrString::Number(42)),
        );

        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, Some(NumberOrString::Number(42)));
        assert!(response.result.is_none());

        let error_obj = response.error.expect("error response must carry an error");
        assert!(error_obj.message.contains("Test error"));
    }

    #[tokio::test]
    async fn test_mock_handler_functionality() {
        // Compile-time check: the mock has the shape the transport expects.
        let _typed: RequestHandler = Box::new(mock_handler);

        let success = mock_handler(Request {
            jsonrpc: "2.0".to_string(),
            method: "test_method".to_string(),
            params: json!({}),
            id: Some(NumberOrString::Number(1)),
        })
        .await;
        assert_eq!(success.jsonrpc, "2.0");
        assert_eq!(success.id, Some(NumberOrString::Number(1)));
        assert!(success.result.is_some());
        assert!(success.error.is_none());

        let failure = mock_handler(Request {
            jsonrpc: "2.0".to_string(),
            method: "error_method".to_string(),
            params: json!({}),
            id: Some(NumberOrString::Number(2)),
        })
        .await;
        assert_eq!(failure.jsonrpc, "2.0");
        assert_eq!(failure.id, Some(NumberOrString::Number(2)));
        assert!(failure.result.is_none());
        assert!(failure.error.is_some());
    }
}