Skip to main content

reasonkit_web/mcp/
transport.rs

//! MCP Transport Abstraction
//!
//! This module provides the [`McpTransport`] trait for pluggable MCP transport
//! backends, enabling flexible deployment options and performance tuning.
//!
//! # Backends
//!
//! [`TransportType`] names the following backends:
//!
//! - **Stdio**: standard I/O for CLI tool integration (implemented here by
//!   [`StdioTransport`])
//! - **SSE**: Server-Sent Events for web clients
//! - **WebSocket**: bidirectional real-time communication
//! - **HTTP**: REST-style request/response
//!
//! Only the stdio backend is implemented in this module; [`TransportFactory`]
//! currently falls back to [`StdioTransport`] for the other types.
//!
//! # Performance Optimization
//!
//! The transport layer supports pluggable backends for performance tuning:
//! - Default `rmcp` SDK backend
//! - High-performance `pmcp` backend (feature-gated)
//!
//! # Usage
//!
//! ```rust,ignore
//! use reasonkit_web::mcp::transport::{McpMessage, McpTransport, StdioTransport};
//!
//! let mut transport = StdioTransport::new();
//! transport.connect().await?;
//! let request = transport.receive().await?;
//! let response = McpMessage::response(
//!     request.id.clone().unwrap_or_default(),
//!     serde_json::json!({}),
//! );
//! transport.send(&response).await?;
//! ```
28
29use async_trait::async_trait;
30use serde::{Deserialize, Serialize};
31use std::time::Duration;
32use thiserror::Error;
33
34/// MCP Transport errors
35#[derive(Error, Debug)]
36pub enum TransportError {
37    /// Connection establishment failed
38    #[error("Connection failed: {0}")]
39    ConnectionFailed(String),
40
41    /// Failed to send message
42    #[error("Send failed: {0}")]
43    SendFailed(String),
44
45    /// Failed to receive message
46    #[error("Receive failed: {0}")]
47    ReceiveFailed(String),
48
49    /// Operation timed out
50    #[error("Timeout after {0:?}")]
51    Timeout(Duration),
52
53    /// Transport connection closed
54    #[error("Transport closed")]
55    Closed,
56
57    /// Invalid message format error
58    #[error("Invalid message format: {0}")]
59    InvalidFormat(String),
60
61    /// IO operation error
62    #[error("IO error: {0}")]
63    Io(#[from] std::io::Error),
64
65    /// Serialization/deserialization error
66    #[error("Serialization error: {0}")]
67    Serialization(String),
68}
69
70/// Result type for transport operations
71pub type TransportResult<T> = Result<T, TransportError>;
72
73/// Transport configuration
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct TransportConfig {
76    /// Read timeout
77    pub read_timeout: Duration,
78    /// Write timeout
79    pub write_timeout: Duration,
80    /// Maximum message size in bytes
81    pub max_message_size: usize,
82    /// Enable compression
83    pub compression: bool,
84    /// Buffer size for reading
85    pub buffer_size: usize,
86}
87
88impl Default for TransportConfig {
89    fn default() -> Self {
90        Self {
91            read_timeout: Duration::from_secs(30),
92            write_timeout: Duration::from_secs(30),
93            max_message_size: 10 * 1024 * 1024, // 10 MB
94            compression: false,
95            buffer_size: 8192,
96        }
97    }
98}
99
100/// MCP message envelope
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct McpMessage {
103    /// JSON-RPC version
104    pub jsonrpc: String,
105    /// Message ID (for request/response correlation)
106    pub id: Option<serde_json::Value>,
107    /// Method name (for requests)
108    pub method: Option<String>,
109    /// Parameters (for requests)
110    pub params: Option<serde_json::Value>,
111    /// Result (for responses)
112    pub result: Option<serde_json::Value>,
113    /// Error (for error responses)
114    pub error: Option<McpError>,
115}
116
117impl McpMessage {
118    /// Create a new request message
119    pub fn request(
120        id: impl Into<serde_json::Value>,
121        method: &str,
122        params: serde_json::Value,
123    ) -> Self {
124        Self {
125            jsonrpc: "2.0".to_string(),
126            id: Some(id.into()),
127            method: Some(method.to_string()),
128            params: Some(params),
129            result: None,
130            error: None,
131        }
132    }
133
134    /// Create a new response message
135    pub fn response(id: impl Into<serde_json::Value>, result: serde_json::Value) -> Self {
136        Self {
137            jsonrpc: "2.0".to_string(),
138            id: Some(id.into()),
139            method: None,
140            params: None,
141            result: Some(result),
142            error: None,
143        }
144    }
145
146    /// Create a new error response
147    pub fn error_response(id: impl Into<serde_json::Value>, error: McpError) -> Self {
148        Self {
149            jsonrpc: "2.0".to_string(),
150            id: Some(id.into()),
151            method: None,
152            params: None,
153            result: None,
154            error: Some(error),
155        }
156    }
157
158    /// Create a notification (no ID)
159    pub fn notification(method: &str, params: serde_json::Value) -> Self {
160        Self {
161            jsonrpc: "2.0".to_string(),
162            id: None,
163            method: Some(method.to_string()),
164            params: Some(params),
165            result: None,
166            error: None,
167        }
168    }
169
170    /// Check if this is a request
171    pub fn is_request(&self) -> bool {
172        self.method.is_some() && self.id.is_some()
173    }
174
175    /// Check if this is a notification
176    pub fn is_notification(&self) -> bool {
177        self.method.is_some() && self.id.is_none()
178    }
179
180    /// Check if this is a response
181    pub fn is_response(&self) -> bool {
182        self.id.is_some() && (self.result.is_some() || self.error.is_some())
183    }
184}
185
186/// MCP error object
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct McpError {
189    /// Error code
190    pub code: i32,
191    /// Error message
192    pub message: String,
193    /// Additional error data
194    pub data: Option<serde_json::Value>,
195}
196
197impl McpError {
198    /// Parse error (-32700)
199    pub fn parse_error(message: impl Into<String>) -> Self {
200        Self {
201            code: -32700,
202            message: message.into(),
203            data: None,
204        }
205    }
206
207    /// Invalid request (-32600)
208    pub fn invalid_request(message: impl Into<String>) -> Self {
209        Self {
210            code: -32600,
211            message: message.into(),
212            data: None,
213        }
214    }
215
216    /// Method not found (-32601)
217    pub fn method_not_found(method: &str) -> Self {
218        Self {
219            code: -32601,
220            message: format!("Method not found: {}", method),
221            data: None,
222        }
223    }
224
225    /// Invalid params (-32602)
226    pub fn invalid_params(message: impl Into<String>) -> Self {
227        Self {
228            code: -32602,
229            message: message.into(),
230            data: None,
231        }
232    }
233
234    /// Internal error (-32603)
235    pub fn internal_error(message: impl Into<String>) -> Self {
236        Self {
237            code: -32603,
238            message: message.into(),
239            data: None,
240        }
241    }
242}
243
244/// Transport statistics for monitoring
245#[derive(Debug, Clone, Default, Serialize, Deserialize)]
246pub struct TransportStats {
247    /// Messages sent
248    pub messages_sent: u64,
249    /// Messages received
250    pub messages_received: u64,
251    /// Bytes sent
252    pub bytes_sent: u64,
253    /// Bytes received
254    pub bytes_received: u64,
255    /// Send errors
256    pub send_errors: u64,
257    /// Receive errors
258    pub receive_errors: u64,
259    /// Average send latency in microseconds
260    pub avg_send_latency_us: u64,
261    /// Average receive latency in microseconds
262    pub avg_receive_latency_us: u64,
263}
264
265/// MCP Transport trait for pluggable backends
266///
267/// Implement this trait to create custom MCP transport backends.
268#[async_trait]
269pub trait McpTransport: Send + Sync {
270    /// Get the transport type name
271    fn transport_type(&self) -> &'static str;
272
273    /// Connect the transport
274    async fn connect(&mut self) -> TransportResult<()>;
275
276    /// Disconnect the transport
277    async fn disconnect(&mut self) -> TransportResult<()>;
278
279    /// Check if connected
280    fn is_connected(&self) -> bool;
281
282    /// Send a message
283    async fn send(&mut self, message: &McpMessage) -> TransportResult<()>;
284
285    /// Receive a message
286    async fn receive(&mut self) -> TransportResult<McpMessage>;
287
288    /// Receive with timeout
289    async fn receive_timeout(&mut self, timeout: Duration) -> TransportResult<McpMessage>;
290
291    /// Get transport statistics
292    fn stats(&self) -> TransportStats;
293
294    /// Reset statistics
295    fn reset_stats(&mut self);
296}
297
298/// Transport type enumeration
299#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
300pub enum TransportType {
301    /// Standard I/O
302    Stdio,
303    /// Server-Sent Events
304    Sse,
305    /// WebSocket
306    WebSocket,
307    /// HTTP
308    Http,
309}
310
311impl std::fmt::Display for TransportType {
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        match self {
314            Self::Stdio => write!(f, "stdio"),
315            Self::Sse => write!(f, "sse"),
316            Self::WebSocket => write!(f, "websocket"),
317            Self::Http => write!(f, "http"),
318        }
319    }
320}
321
322/// Standard I/O transport implementation
323pub struct StdioTransport {
324    #[allow(dead_code)]
325    config: TransportConfig,
326    connected: bool,
327    stats: TransportStats,
328}
329
330impl StdioTransport {
331    /// Create a new stdio transport
332    pub fn new() -> Self {
333        Self::with_config(TransportConfig::default())
334    }
335
336    /// Create with configuration
337    pub fn with_config(config: TransportConfig) -> Self {
338        Self {
339            config,
340            connected: false,
341            stats: TransportStats::default(),
342        }
343    }
344}
345
346impl Default for StdioTransport {
347    fn default() -> Self {
348        Self::new()
349    }
350}
351
352#[async_trait]
353impl McpTransport for StdioTransport {
354    fn transport_type(&self) -> &'static str {
355        "stdio"
356    }
357
358    async fn connect(&mut self) -> TransportResult<()> {
359        self.connected = true;
360        Ok(())
361    }
362
363    async fn disconnect(&mut self) -> TransportResult<()> {
364        self.connected = false;
365        Ok(())
366    }
367
368    fn is_connected(&self) -> bool {
369        self.connected
370    }
371
372    async fn send(&mut self, message: &McpMessage) -> TransportResult<()> {
373        let json = serde_json::to_string(message)
374            .map_err(|e| TransportError::Serialization(e.to_string()))?;
375
376        // Write to stdout with Content-Length header (LSP style)
377        let content = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
378
379        use std::io::Write;
380        let mut stdout = std::io::stdout().lock();
381        stdout
382            .write_all(content.as_bytes())
383            .map_err(TransportError::Io)?;
384        stdout.flush().map_err(TransportError::Io)?;
385
386        self.stats.messages_sent += 1;
387        self.stats.bytes_sent += content.len() as u64;
388
389        Ok(())
390    }
391
392    async fn receive(&mut self) -> TransportResult<McpMessage> {
393        use std::io::{BufRead, Read};
394
395        let stdin = std::io::stdin();
396        let mut reader = stdin.lock();
397
398        // Read Content-Length header
399        let mut header_line = String::new();
400        reader
401            .read_line(&mut header_line)
402            .map_err(TransportError::Io)?;
403
404        let content_length: usize = header_line
405            .trim()
406            .strip_prefix("Content-Length: ")
407            .ok_or_else(|| TransportError::InvalidFormat("Missing Content-Length header".into()))?
408            .parse()
409            .map_err(|_| TransportError::InvalidFormat("Invalid Content-Length".into()))?;
410
411        // Skip empty line
412        let mut empty = String::new();
413        reader.read_line(&mut empty).map_err(TransportError::Io)?;
414
415        // Read content
416        let mut content = vec![0u8; content_length];
417        reader
418            .read_exact(&mut content)
419            .map_err(TransportError::Io)?;
420
421        let message: McpMessage = serde_json::from_slice(&content)
422            .map_err(|e| TransportError::InvalidFormat(e.to_string()))?;
423
424        self.stats.messages_received += 1;
425        self.stats.bytes_received += content_length as u64;
426
427        Ok(message)
428    }
429
430    async fn receive_timeout(&mut self, _timeout: Duration) -> TransportResult<McpMessage> {
431        // For stdio, timeout handling would require threading
432        // This is a simplified implementation
433        self.receive().await
434    }
435
436    fn stats(&self) -> TransportStats {
437        self.stats.clone()
438    }
439
440    fn reset_stats(&mut self) {
441        self.stats = TransportStats::default();
442    }
443}
444
445/// Transport factory for creating transport instances
446pub struct TransportFactory;
447
448impl TransportFactory {
449    /// Create a transport of the specified type
450    pub fn create(transport_type: TransportType) -> Box<dyn McpTransport> {
451        match transport_type {
452            TransportType::Stdio => Box::new(StdioTransport::new()),
453            // Other transports would be implemented similarly
454            TransportType::Sse | TransportType::WebSocket | TransportType::Http => {
455                // Placeholder - would need full implementations
456                Box::new(StdioTransport::new())
457            }
458        }
459    }
460
461    /// Create a transport with configuration
462    pub fn create_with_config(
463        transport_type: TransportType,
464        config: TransportConfig,
465    ) -> Box<dyn McpTransport> {
466        match transport_type {
467            TransportType::Stdio => Box::new(StdioTransport::with_config(config)),
468            _ => Box::new(StdioTransport::with_config(config)),
469        }
470    }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// `(is_request, is_notification, is_response)`, for compact assertions.
    fn kind(msg: &McpMessage) -> (bool, bool, bool) {
        (msg.is_request(), msg.is_notification(), msg.is_response())
    }

    #[test]
    fn test_message_request() {
        let msg = McpMessage::request(1, "tools/list", serde_json::json!({}));
        assert_eq!(kind(&msg), (true, false, false));
    }

    #[test]
    fn test_message_notification() {
        let msg = McpMessage::notification("progress", serde_json::json!({ "percent": 50 }));
        assert_eq!(kind(&msg), (false, true, false));
    }

    #[test]
    fn test_message_response() {
        let msg = McpMessage::response(1, serde_json::json!({ "tools": [] }));
        assert_eq!(kind(&msg), (false, false, true));
    }

    #[test]
    fn test_error_response_is_response() {
        let msg = McpMessage::error_response(7, McpError::internal_error("boom"));
        assert_eq!(kind(&msg), (false, false, true));
        assert_eq!(msg.error.as_ref().map(|e| e.code), Some(-32603));
    }

    #[test]
    fn test_request_round_trips_through_json() {
        let msg = McpMessage::request(1, "tools/call", serde_json::json!({ "name": "x" }));
        let text = serde_json::to_string(&msg).unwrap();
        let back: McpMessage = serde_json::from_str(&text).unwrap();
        assert_eq!(kind(&back), (true, false, false));
        assert_eq!(back.method.as_deref(), Some("tools/call"));
        assert_eq!(back.id, Some(serde_json::json!(1)));
    }

    #[test]
    fn test_error_codes() {
        let cases = [
            (McpError::parse_error("bad json"), -32700),
            (McpError::invalid_request("bad request"), -32600),
            (McpError::method_not_found("unknown"), -32601),
            (McpError::invalid_params("bad param"), -32602),
            (McpError::internal_error("oops"), -32603),
        ];
        for (err, code) in cases {
            assert_eq!(err.code, code, "{}", err.message);
            assert!(err.data.is_none());
        }
        assert_eq!(
            McpError::method_not_found("unknown").message,
            "Method not found: unknown"
        );
    }

    #[test]
    fn test_transport_config_default() {
        let config = TransportConfig::default();
        assert_eq!(config.read_timeout, Duration::from_secs(30));
        assert_eq!(config.write_timeout, Duration::from_secs(30));
        assert_eq!(config.max_message_size, 10 * 1024 * 1024);
        assert_eq!(config.buffer_size, 8192);
        assert!(!config.compression);
    }

    #[test]
    fn test_transport_type_display() {
        let names: Vec<String> = [
            TransportType::Stdio,
            TransportType::Sse,
            TransportType::WebSocket,
            TransportType::Http,
        ]
        .iter()
        .map(ToString::to_string)
        .collect();
        assert_eq!(names, ["stdio", "sse", "websocket", "http"]);
    }

    #[test]
    fn test_stdio_transport_initial_state() {
        let transport = StdioTransport::new();
        assert_eq!(transport.transport_type(), "stdio");
        assert!(!transport.is_connected());
        let stats = transport.stats();
        assert_eq!(stats.messages_sent, 0);
        assert_eq!(stats.messages_received, 0);
        assert_eq!(stats.send_errors + stats.receive_errors, 0);
    }

    #[test]
    fn test_factory_creates_stdio() {
        let transport = TransportFactory::create(TransportType::Stdio);
        assert_eq!(transport.transport_type(), "stdio");
        assert!(!transport.is_connected());
    }
}