mcp_probe_core/transport/
mod.rs

1//! MCP transport layer abstraction and implementations.
2//!
3//! This module provides a unified interface for all MCP transport mechanisms:
4//! - **stdio**: Local process communication via stdin/stdout
5//! - **HTTP+SSE**: Remote servers using HTTP requests + Server-Sent Events
6//! - **HTTP Streaming**: Full-duplex HTTP streaming for bidirectional communication
7//!
8//! The transport layer is designed to be:
9//! - **Transport-agnostic**: Same interface for all transport types
10//! - **Async-first**: Built on tokio for high-performance async I/O
11//! - **Type-safe**: Leverages Rust's type system to prevent protocol violations
12//! - **Extensible**: Easy to add new transport mechanisms
13//! - **Robust**: Comprehensive error handling and recovery
14//!
15//! # Examples
16//!
17//! ```rust,no_run
18//! use mcp_probe_core::transport::{Transport, TransportFactory, TransportConfig};
19//! use mcp_probe_core::messages::JsonRpcRequest;
20//! use serde_json::json;
21//!
22//! #[tokio::main]
23//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
24//!     // Create a transport from configuration
25//!     let config = TransportConfig::stdio("python", &["server.py"]);
26//!         
27//!     let mut transport = TransportFactory::create(config).await?;
28//!     
29//!     // Connect transport
30//!     transport.connect().await?;
31//!     
32//!     // Send a request  
33//!     let request = JsonRpcRequest::new("1", "initialize", json!({}));
34//!     let response = transport.send_request(request, Some(std::time::Duration::from_secs(30))).await?;
35//!     println!("Received: {:?}", response);
36//!     
37//!     Ok(())
38//! }
39//! ```
40
41pub mod config;
42pub mod factory;
43
44#[cfg(feature = "stdio")]
45pub mod stdio;
46
47#[cfg(feature = "http-sse")]
48pub mod http_sse;
49
50#[cfg(feature = "http-stream")]
51pub mod http_stream;
52
53pub use config::*;
54pub use factory::*;
55
56use crate::error::{McpResult, TransportError};
57use crate::messages::{JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
58use async_trait::async_trait;
59use std::time::Duration;
60use tokio::sync::mpsc;
61
/// Core transport trait for MCP communication.
///
/// Every MCP transport implements this trait. It exposes async methods for
/// managing the connection lifecycle and for exchanging JSON-RPC messages,
/// plus two synchronous accessors for statistics and configuration.
///
/// Implementors must be `Send + Sync`. All I/O methods take `&mut self`, so a
/// caller needs exclusive access to the transport while an operation is
/// in flight.
///
/// # Design Principles
///
/// - **Bidirectional**: Support both client-to-server and server-to-client messages
/// - **Message-oriented**: Work with high-level MCP messages, not raw bytes
/// - **Async**: All operations are async for maximum concurrency
/// - **Reliable**: Handle connection failures and provide retry mechanisms
/// - **Observable**: Provide hooks for monitoring and debugging
#[async_trait]
pub trait Transport: Send + Sync {
    /// Connect to the MCP server.
    ///
    /// This establishes the underlying connection (process spawn, HTTP connection, etc.)
    /// but does not perform MCP protocol initialization.
    ///
    /// # Errors
    ///
    /// Returns an error through [`McpResult`] when the connection cannot be
    /// established.
    async fn connect(&mut self) -> McpResult<()>;

    /// Disconnect from the MCP server.
    ///
    /// This cleanly closes the connection and releases any resources.
    /// Should be called when the MCP session is complete.
    ///
    /// # Errors
    ///
    /// Returns an error through [`McpResult`] when the shutdown fails.
    async fn disconnect(&mut self) -> McpResult<()>;

    /// Check if the transport is currently connected.
    ///
    /// This is a synchronous, non-mutating query.
    fn is_connected(&self) -> bool;

    /// Send a JSON-RPC request and wait for the response.
    ///
    /// This is the primary method for client-initiated request/response interactions.
    /// The method handles request correlation and timeout management.
    ///
    /// # Arguments
    ///
    /// * `request` - The JSON-RPC request to send (taken by value)
    /// * `timeout` - Optional timeout for the request (uses default if None)
    ///
    /// # Returns
    ///
    /// The corresponding JSON-RPC response, or an error if the request fails.
    async fn send_request(
        &mut self,
        request: JsonRpcRequest,
        timeout: Option<Duration>,
    ) -> McpResult<JsonRpcResponse>;

    /// Send a JSON-RPC notification (fire-and-forget).
    ///
    /// Notifications don't expect responses and are used for events,
    /// logging, and other one-way communications. The returned result only
    /// reports whether sending succeeded; no reply is awaited.
    ///
    /// # Arguments
    ///
    /// * `notification` - The JSON-RPC notification to send (taken by value)
    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()>;

    /// Receive the next message from the server.
    ///
    /// This method waits asynchronously until a message is received or an
    /// error occurs. It can return requests (from server to client),
    /// responses (to previous client requests), or notifications, all
    /// wrapped in [`JsonRpcMessage`].
    ///
    /// # Arguments
    ///
    /// * `timeout` - Optional timeout for receiving (waits indefinitely if None)
    async fn receive_message(&mut self, timeout: Option<Duration>) -> McpResult<JsonRpcMessage>;

    /// Get transport-specific metadata and statistics.
    ///
    /// This can include connection info, performance metrics, error counts, etc.
    /// The exact contents depend on the transport implementation. The value is
    /// returned by value, so it is a snapshot owned by the caller.
    fn get_info(&self) -> TransportInfo;

    /// Get the transport configuration used for this instance.
    ///
    /// The configuration is returned as a borrow tied to `&self`.
    fn get_config(&self) -> &TransportConfig;
}
141
/// Transport information and statistics.
///
/// This structure provides insight into the transport's current state,
/// performance characteristics, and any relevant metadata.
///
/// All fields are public. The struct does not update itself: the connection
/// flags and counters only change when the corresponding `mark_*` /
/// `increment_*` methods are called, or when a field is written directly.
/// It derives `Clone` and `serde::Serialize`, so a snapshot can be copied
/// and serialized.
#[derive(Debug, Clone, serde::Serialize)]
pub struct TransportInfo {
    /// Type of transport (stdio, http-sse, http-stream)
    pub transport_type: String,

    /// Whether the transport is currently connected.
    ///
    /// Set by `mark_connected` and cleared by `mark_disconnected`.
    pub connected: bool,

    /// Connection establishment time (if connected).
    ///
    /// A wall-clock `SystemTime` recorded by `mark_connected`; reset to
    /// `None` by `mark_disconnected`.
    pub connected_since: Option<std::time::SystemTime>,

    /// Number of requests sent (bumped by `increment_requests_sent`)
    pub requests_sent: u64,

    /// Number of responses received (bumped by `increment_responses_received`)
    pub responses_received: u64,

    /// Number of notifications sent (bumped by `increment_notifications_sent`)
    pub notifications_sent: u64,

    /// Number of notifications received (bumped by `increment_notifications_received`)
    pub notifications_received: u64,

    /// Number of errors encountered (bumped by `increment_errors`)
    pub errors: u64,

    /// Transport-specific metadata: free-form JSON values keyed by name,
    /// populated through `add_metadata`.
    pub metadata: std::collections::HashMap<String, serde_json::Value>,
}
175
176impl TransportInfo {
177    /// Create a new transport info structure.
178    pub fn new(transport_type: impl Into<String>) -> Self {
179        Self {
180            transport_type: transport_type.into(),
181            connected: false,
182            connected_since: None,
183            requests_sent: 0,
184            responses_received: 0,
185            notifications_sent: 0,
186            notifications_received: 0,
187            errors: 0,
188            metadata: std::collections::HashMap::new(),
189        }
190    }
191
192    /// Mark the transport as connected.
193    pub fn mark_connected(&mut self) {
194        self.connected = true;
195        self.connected_since = Some(std::time::SystemTime::now());
196    }
197
198    /// Mark the transport as disconnected.
199    pub fn mark_disconnected(&mut self) {
200        self.connected = false;
201        self.connected_since = None;
202    }
203
204    /// Increment the request counter.
205    pub fn increment_requests_sent(&mut self) {
206        self.requests_sent += 1;
207    }
208
209    /// Increment the response counter.
210    pub fn increment_responses_received(&mut self) {
211        self.responses_received += 1;
212    }
213
214    /// Increment the notification sent counter.
215    pub fn increment_notifications_sent(&mut self) {
216        self.notifications_sent += 1;
217    }
218
219    /// Increment the notification received counter.
220    pub fn increment_notifications_received(&mut self) {
221        self.notifications_received += 1;
222    }
223
224    /// Increment the error counter.
225    pub fn increment_errors(&mut self) {
226        self.errors += 1;
227    }
228
229    /// Add transport-specific metadata.
230    pub fn add_metadata(&mut self, key: impl Into<String>, value: serde_json::Value) {
231        self.metadata.insert(key.into(), value);
232    }
233
234    /// Get the duration since connection was established.
235    pub fn connection_duration(&self) -> Option<Duration> {
236        self.connected_since.map(|since| {
237            std::time::SystemTime::now()
238                .duration_since(since)
239                .unwrap_or_default()
240        })
241    }
242}
243
244/// Message sender for internal transport communication.
245///
246/// This type is used internally by transport implementations to send
247/// messages between different async tasks (e.g., reader and writer tasks).
248pub type MessageSender = mpsc::UnboundedSender<JsonRpcMessage>;
249
250/// Message receiver for internal transport communication.
251///
252/// This type is used internally by transport implementations to receive
253/// messages from different async tasks.
254pub type MessageReceiver = mpsc::UnboundedReceiver<JsonRpcMessage>;
255
256/// Helper trait for transport implementations.
257///
258/// This trait provides common functionality that most transport implementations
259/// will need, such as message correlation, timeout handling, etc.
260pub trait TransportHelper {
261    /// Generate a unique request ID.
262    fn generate_request_id() -> String {
263        uuid::Uuid::new_v4().to_string()
264    }
265
266    /// Create a timeout future for the given duration.
267    fn timeout_future(duration: Duration) -> tokio::time::Sleep {
268        tokio::time::sleep(duration)
269    }
270
271    /// Validate that a JSON-RPC message is well-formed.
272    fn validate_message(message: &JsonRpcMessage) -> McpResult<()> {
273        match message {
274            JsonRpcMessage::Request(req) => {
275                if req.jsonrpc != "2.0" {
276                    return Err(TransportError::InvalidConfig {
277                        transport_type: "generic".to_string(),
278                        reason: format!("Invalid jsonrpc version: {}", req.jsonrpc),
279                    }
280                    .into());
281                }
282            }
283            JsonRpcMessage::Response(resp) => {
284                if resp.jsonrpc != "2.0" {
285                    return Err(TransportError::InvalidConfig {
286                        transport_type: "generic".to_string(),
287                        reason: format!("Invalid jsonrpc version: {}", resp.jsonrpc),
288                    }
289                    .into());
290                }
291            }
292            JsonRpcMessage::Notification(notif) => {
293                if notif.jsonrpc != "2.0" {
294                    return Err(TransportError::InvalidConfig {
295                        transport_type: "generic".to_string(),
296                        reason: format!("Invalid jsonrpc version: {}", notif.jsonrpc),
297                    }
298                    .into());
299                }
300            }
301        }
302        Ok(())
303    }
304}
305
#[cfg(test)]
mod tests {
    use super::*;

    /// A new record starts disconnected with zeroed counters, then reflects
    /// `mark_connected` and `increment_requests_sent`.
    #[test]
    fn test_transport_info_creation() {
        let mut stats = TransportInfo::new("test");
        assert_eq!(stats.transport_type, "test");
        assert!(!stats.connected);
        assert_eq!(stats.requests_sent, 0);

        stats.mark_connected();
        assert!(stats.connected);
        assert!(stats.connected_since.is_some());

        stats.increment_requests_sent();
        assert_eq!(stats.requests_sent, 1);
    }

    /// Metadata added under a key can be read back unchanged.
    #[test]
    fn test_transport_info_metadata() {
        let expected = serde_json::json!("1.0.0");

        let mut stats = TransportInfo::new("test");
        stats.add_metadata("version", serde_json::json!("1.0.0"));

        assert_eq!(stats.metadata.get("version"), Some(&expected));
    }

    /// The duration is absent while disconnected and small right after connecting.
    #[test]
    fn test_connection_duration() {
        let mut stats = TransportInfo::new("test");
        assert!(stats.connection_duration().is_none());

        stats.mark_connected();
        let elapsed = stats.connection_duration();
        assert!(elapsed.is_some());
        // Measured immediately after connecting, so it should be very small.
        assert!(elapsed.unwrap().as_millis() < 100);

        stats.mark_disconnected();
        assert!(stats.connection_duration().is_none());
    }
}