Skip to main content

aster/mcp/transport/
base.rs

//! MCP Transport Base Types
//!
//! This module defines the `Transport` trait and the related message and configuration
//! types used for MCP communication. It supports multiple transport types: stdio, HTTP,
//! SSE, and WebSocket.
//!
//! # Architecture
//!
//! The transport layer provides an abstraction over different communication mechanisms:
//!
//! - **Stdio**: Subprocess communication via stdin/stdout
//! - **HTTP**: HTTP POST requests for request/response
//! - **SSE**: Server-Sent Events for streaming
//! - **WebSocket**: Full-duplex WebSocket connections
//!
//! Each transport implements the `Transport` trait, which provides async send/receive
//! capabilities with proper error handling.
17
18use async_trait::async_trait;
19use serde::{Deserialize, Serialize};
20use std::collections::HashMap;
21use std::sync::Arc;
22use std::time::Duration;
23use tokio::sync::mpsc;
24
25use crate::mcp::error::{McpError, McpResult};
26use crate::mcp::types::{ConnectionOptions, TransportType};
27
28/// JSON-RPC request ID type
29pub type RequestId = serde_json::Value;
30
31/// MCP JSON-RPC Request
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct McpRequest {
34    /// JSON-RPC version (always "2.0")
35    pub jsonrpc: String,
36    /// Request ID for matching responses
37    pub id: RequestId,
38    /// Method name
39    pub method: String,
40    /// Optional parameters
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub params: Option<serde_json::Value>,
43}
44
45impl McpRequest {
46    /// Create a new MCP request
47    pub fn new(id: impl Into<RequestId>, method: impl Into<String>) -> Self {
48        Self {
49            jsonrpc: "2.0".to_string(),
50            id: id.into(),
51            method: method.into(),
52            params: None,
53        }
54    }
55
56    /// Create a new MCP request with parameters
57    pub fn with_params(
58        id: impl Into<RequestId>,
59        method: impl Into<String>,
60        params: serde_json::Value,
61    ) -> Self {
62        Self {
63            jsonrpc: "2.0".to_string(),
64            id: id.into(),
65            method: method.into(),
66            params: Some(params),
67        }
68    }
69}
70
71/// MCP JSON-RPC Response
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct McpResponse {
74    /// JSON-RPC version (always "2.0")
75    pub jsonrpc: String,
76    /// Request ID matching the request
77    pub id: RequestId,
78    /// Result on success
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub result: Option<serde_json::Value>,
81    /// Error on failure
82    #[serde(skip_serializing_if = "Option::is_none")]
83    pub error: Option<McpErrorData>,
84}
85
86impl McpResponse {
87    /// Create a success response
88    pub fn success(id: RequestId, result: serde_json::Value) -> Self {
89        Self {
90            jsonrpc: "2.0".to_string(),
91            id,
92            result: Some(result),
93            error: None,
94        }
95    }
96
97    /// Create an error response
98    pub fn error(id: RequestId, error: McpErrorData) -> Self {
99        Self {
100            jsonrpc: "2.0".to_string(),
101            id,
102            result: None,
103            error: Some(error),
104        }
105    }
106
107    /// Check if the response is an error
108    pub fn is_error(&self) -> bool {
109        self.error.is_some()
110    }
111
112    /// Convert to Result
113    pub fn into_result(self) -> McpResult<serde_json::Value> {
114        if let Some(error) = self.error {
115            Err(McpError::server(error.code, error.message, error.data))
116        } else {
117            self.result
118                .ok_or_else(|| McpError::protocol("Response contains neither result nor error"))
119        }
120    }
121}
122
123/// MCP JSON-RPC Error Data
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct McpErrorData {
126    /// Error code
127    pub code: i32,
128    /// Error message
129    pub message: String,
130    /// Optional additional data
131    #[serde(skip_serializing_if = "Option::is_none")]
132    pub data: Option<serde_json::Value>,
133}
134
135impl McpErrorData {
136    /// Create a new error data
137    pub fn new(code: i32, message: impl Into<String>) -> Self {
138        Self {
139            code,
140            message: message.into(),
141            data: None,
142        }
143    }
144
145    /// Create a new error data with additional data
146    pub fn with_data(code: i32, message: impl Into<String>, data: serde_json::Value) -> Self {
147        Self {
148            code,
149            message: message.into(),
150            data: Some(data),
151        }
152    }
153}
154
155/// MCP JSON-RPC Notification (no response expected)
156#[derive(Debug, Clone, Serialize, Deserialize)]
157pub struct McpNotification {
158    /// JSON-RPC version (always "2.0")
159    pub jsonrpc: String,
160    /// Method name
161    pub method: String,
162    /// Optional parameters
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub params: Option<serde_json::Value>,
165}
166
167impl McpNotification {
168    /// Create a new notification
169    pub fn new(method: impl Into<String>) -> Self {
170        Self {
171            jsonrpc: "2.0".to_string(),
172            method: method.into(),
173            params: None,
174        }
175    }
176
177    /// Create a new notification with parameters
178    pub fn with_params(method: impl Into<String>, params: serde_json::Value) -> Self {
179        Self {
180            jsonrpc: "2.0".to_string(),
181            method: method.into(),
182            params: Some(params),
183        }
184    }
185}
186
187/// Message that can be sent/received over transport
188#[derive(Debug, Clone, Serialize, Deserialize)]
189#[serde(untagged)]
190pub enum McpMessage {
191    /// A request expecting a response
192    Request(McpRequest),
193    /// A response to a request
194    Response(McpResponse),
195    /// A notification (no response expected)
196    Notification(McpNotification),
197}
198
199impl McpMessage {
200    /// Get the request ID if this is a request or response
201    pub fn id(&self) -> Option<&RequestId> {
202        match self {
203            McpMessage::Request(req) => Some(&req.id),
204            McpMessage::Response(resp) => Some(&resp.id),
205            McpMessage::Notification(_) => None,
206        }
207    }
208
209    /// Get the method name if this is a request or notification
210    pub fn method(&self) -> Option<&str> {
211        match self {
212            McpMessage::Request(req) => Some(&req.method),
213            McpMessage::Response(_) => None,
214            McpMessage::Notification(notif) => Some(&notif.method),
215        }
216    }
217}
218
/// Settings needed to construct a transport, one variant per transport kind.
#[derive(Clone, Debug)]
pub enum TransportConfig {
    /// Settings for a transport that talks to a spawned subprocess.
    Stdio {
        /// Program to run.
        command: String,
        /// Arguments passed to the program.
        args: Vec<String>,
        /// Environment variables for the program.
        env: HashMap<String, String>,
        /// Directory to run the program in, if specified.
        cwd: Option<String>,
    },
    /// Settings for the HTTP transport.
    Http {
        /// Address of the server.
        url: String,
        /// Headers to send with requests.
        headers: HashMap<String, String>,
    },
    /// Settings for the SSE transport.
    Sse {
        /// Address of the server.
        url: String,
        /// Headers to send with requests.
        headers: HashMap<String, String>,
    },
    /// Settings for the WebSocket transport.
    WebSocket {
        /// Address of the server.
        url: String,
        /// Headers to send with the upgrade request.
        headers: HashMap<String, String>,
    },
}
255
256impl TransportConfig {
257    /// Get the transport type
258    pub fn transport_type(&self) -> TransportType {
259        match self {
260            TransportConfig::Stdio { .. } => TransportType::Stdio,
261            TransportConfig::Http { .. } => TransportType::Http,
262            TransportConfig::Sse { .. } => TransportType::Sse,
263            TransportConfig::WebSocket { .. } => TransportType::WebSocket,
264        }
265    }
266}
267
/// Lifecycle state of a transport.
///
/// The default state is [`TransportState::Disconnected`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TransportState {
    /// No connection is established.
    #[default]
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// The connection is up and usable.
    Connected,
    /// The connection is being shut down.
    Closing,
    /// The transport hit an error.
    Error,
}
283
284/// Transport event for monitoring transport state changes
285#[derive(Debug, Clone)]
286pub enum TransportEvent {
287    /// Transport is connecting
288    Connecting,
289    /// Transport connected successfully
290    Connected,
291    /// Transport disconnected
292    Disconnected { reason: Option<String> },
293    /// Transport encountered an error
294    Error { error: String },
295    /// Message received from transport
296    MessageReceived(Box<McpMessage>),
297}
298
299/// Transport trait for MCP communication
300///
301/// This trait defines the interface for different transport implementations.
302/// All transports must be Send + Sync for use in async contexts.
303#[async_trait]
304pub trait Transport: Send + Sync {
305    /// Get the transport type
306    fn transport_type(&self) -> TransportType;
307
308    /// Get the current transport state
309    fn state(&self) -> TransportState;
310
311    /// Connect the transport
312    ///
313    /// This establishes the underlying connection (spawns process, opens socket, etc.)
314    async fn connect(&mut self) -> McpResult<()>;
315
316    /// Disconnect the transport
317    ///
318    /// This closes the underlying connection gracefully.
319    async fn disconnect(&mut self) -> McpResult<()>;
320
321    /// Send a message over the transport
322    ///
323    /// For request messages, use `send_request` instead to get the response.
324    async fn send(&mut self, message: McpMessage) -> McpResult<()>;
325
326    /// Send a request and wait for response
327    ///
328    /// This sends a request message and waits for the matching response.
329    async fn send_request(&mut self, request: McpRequest) -> McpResult<McpResponse>;
330
331    /// Send a request with timeout
332    ///
333    /// This sends a request and waits for response with a timeout.
334    async fn send_request_with_timeout(
335        &mut self,
336        request: McpRequest,
337        timeout: Duration,
338    ) -> McpResult<McpResponse>;
339
340    /// Subscribe to transport events
341    ///
342    /// Returns a receiver for transport events (state changes, incoming messages).
343    fn subscribe(&self) -> mpsc::Receiver<TransportEvent>;
344
345    /// Check if the transport is connected
346    fn is_connected(&self) -> bool {
347        self.state() == TransportState::Connected
348    }
349}
350
351/// Boxed transport type for dynamic dispatch
352pub type BoxedTransport = Box<dyn Transport>;
353
354/// Arc-wrapped transport for shared ownership
355pub type SharedTransport = Arc<tokio::sync::Mutex<BoxedTransport>>;
356
357/// Transport factory for creating transports from configuration
358pub struct TransportFactory;
359
360impl TransportFactory {
361    /// Create a transport from configuration
362    ///
363    /// This creates the appropriate transport implementation based on the config.
364    pub fn create(
365        config: TransportConfig,
366        options: ConnectionOptions,
367    ) -> McpResult<BoxedTransport> {
368        match config {
369            TransportConfig::Stdio {
370                command,
371                args,
372                env,
373                cwd,
374            } => {
375                use super::stdio::{StdioConfig, StdioTransport};
376                Ok(Box::new(StdioTransport::new(
377                    StdioConfig {
378                        command,
379                        args,
380                        env,
381                        cwd,
382                    },
383                    options,
384                )))
385            }
386            TransportConfig::Http { url, headers } => {
387                use super::http::{HttpConfig, HttpTransport};
388                Ok(Box::new(HttpTransport::new(
389                    HttpConfig { url, headers },
390                    options,
391                )))
392            }
393            TransportConfig::Sse { url, headers } => {
394                // SSE uses HTTP transport with streaming
395                use super::http::{HttpConfig, HttpTransport};
396                Ok(Box::new(HttpTransport::new(
397                    HttpConfig { url, headers },
398                    options,
399                )))
400            }
401            TransportConfig::WebSocket { url, headers } => {
402                use super::websocket::{WebSocketConfig, WebSocketTransport};
403                Ok(Box::new(WebSocketTransport::new(
404                    WebSocketConfig { url, headers },
405                    options,
406                )))
407            }
408        }
409    }
410}
411
/// Unit tests for the message types, transport configuration and state enums.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mcp_request_new() {
        let req = McpRequest::new(serde_json::json!(1), "test/method");
        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.id, serde_json::json!(1));
        assert_eq!(req.method, "test/method");
        assert!(req.params.is_none());
    }

    #[test]
    fn test_mcp_request_with_params() {
        let params = serde_json::json!({"key": "value"});
        let req =
            McpRequest::with_params(serde_json::json!("req-1"), "test/method", params.clone());
        assert_eq!(req.jsonrpc, "2.0");
        assert_eq!(req.id, serde_json::json!("req-1"));
        assert_eq!(req.method, "test/method");
        assert_eq!(req.params, Some(params));
    }

    #[test]
    fn test_mcp_response_success() {
        let result = serde_json::json!({"status": "ok"});
        let resp = McpResponse::success(serde_json::json!(1), result.clone());
        assert!(!resp.is_error());
        assert_eq!(resp.result, Some(result));
    }

    #[test]
    fn test_mcp_response_error() {
        let error = McpErrorData::new(-32600, "Invalid Request");
        let resp = McpResponse::error(serde_json::json!(1), error);
        assert!(resp.is_error());
        assert!(resp.result.is_none());
    }

    #[test]
    fn test_mcp_response_into_result() {
        let result = serde_json::json!({"data": 42});
        let resp = McpResponse::success(serde_json::json!(1), result.clone());
        let res = resp.into_result();
        assert!(res.is_ok());
        assert_eq!(res.unwrap(), result);
    }

    #[test]
    fn test_mcp_response_into_result_error() {
        let error = McpErrorData::new(-32600, "Invalid Request");
        let resp = McpResponse::error(serde_json::json!(1), error);
        let res = resp.into_result();
        assert!(res.is_err());
    }

    // A response with neither `result` nor `error` cannot be built through the
    // constructors, so it is assembled by hand to reach the protocol-error path.
    #[test]
    fn test_mcp_response_into_result_neither_result_nor_error() {
        let resp = McpResponse {
            jsonrpc: "2.0".to_string(),
            id: serde_json::json!(1),
            result: None,
            error: None,
        };
        assert!(!resp.is_error());
        assert!(resp.into_result().is_err());
    }

    #[test]
    fn test_mcp_notification() {
        let notif = McpNotification::new("notifications/test");
        assert_eq!(notif.jsonrpc, "2.0");
        assert_eq!(notif.method, "notifications/test");
        assert!(notif.params.is_none());
    }

    #[test]
    fn test_mcp_notification_with_params() {
        let params = serde_json::json!({"event": "update"});
        let notif = McpNotification::with_params("notifications/test", params.clone());
        assert_eq!(notif.jsonrpc, "2.0");
        assert_eq!(notif.method, "notifications/test");
        assert_eq!(notif.params, Some(params));
    }

    #[test]
    fn test_transport_config_type() {
        let stdio = TransportConfig::Stdio {
            command: "node".to_string(),
            args: vec![],
            env: HashMap::new(),
            cwd: None,
        };
        assert_eq!(stdio.transport_type(), TransportType::Stdio);

        let http = TransportConfig::Http {
            url: "http://localhost:8080".to_string(),
            headers: HashMap::new(),
        };
        assert_eq!(http.transport_type(), TransportType::Http);

        let sse = TransportConfig::Sse {
            url: "http://localhost:8080/events".to_string(),
            headers: HashMap::new(),
        };
        assert_eq!(sse.transport_type(), TransportType::Sse);

        let ws = TransportConfig::WebSocket {
            url: "ws://localhost:8080".to_string(),
            headers: HashMap::new(),
        };
        assert_eq!(ws.transport_type(), TransportType::WebSocket);
    }

    #[test]
    fn test_transport_state_default() {
        let state = TransportState::default();
        assert_eq!(state, TransportState::Disconnected);
    }

    #[test]
    fn test_mcp_message_id() {
        let req = McpRequest::new(serde_json::json!(1), "test");
        let msg = McpMessage::Request(req);
        assert_eq!(msg.id(), Some(&serde_json::json!(1)));

        let resp = McpResponse::success(serde_json::json!("abc"), serde_json::json!({}));
        let msg = McpMessage::Response(resp);
        assert_eq!(msg.id(), Some(&serde_json::json!("abc")));

        let notif = McpNotification::new("test");
        let msg = McpMessage::Notification(notif);
        assert!(msg.id().is_none());
    }

    #[test]
    fn test_mcp_message_method() {
        let req = McpRequest::new(serde_json::json!(1), "test/method");
        let msg = McpMessage::Request(req);
        assert_eq!(msg.method(), Some("test/method"));

        let resp = McpResponse::success(serde_json::json!(1), serde_json::json!({}));
        let msg = McpMessage::Response(resp);
        assert!(msg.method().is_none());

        let notif = McpNotification::new("notifications/test");
        let msg = McpMessage::Notification(notif);
        assert_eq!(msg.method(), Some("notifications/test"));
    }

    // `McpMessage` is untagged, so the variant is picked purely from which
    // fields are present; pin that each wire shape lands on the right variant.
    #[test]
    fn test_mcp_message_untagged_deserialization() {
        let request: McpMessage = serde_json::from_value(serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "method": "ping"
        }))
        .expect("request should deserialize");
        assert!(matches!(request, McpMessage::Request(_)));

        let response: McpMessage = serde_json::from_value(serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": {"ok": true}
        }))
        .expect("response should deserialize");
        assert!(matches!(response, McpMessage::Response(_)));

        let notification: McpMessage = serde_json::from_value(serde_json::json!({
            "jsonrpc": "2.0",
            "method": "notifications/test"
        }))
        .expect("notification should deserialize");
        assert!(matches!(notification, McpMessage::Notification(_)));
    }

    #[test]
    fn test_mcp_error_data() {
        let error = McpErrorData::new(-32600, "Invalid Request");
        assert_eq!(error.code, -32600);
        assert_eq!(error.message, "Invalid Request");
        assert!(error.data.is_none());

        let error_with_data = McpErrorData::with_data(
            -32602,
            "Invalid params",
            serde_json::json!({"field": "name"}),
        );
        assert_eq!(error_with_data.code, -32602);
        assert_eq!(error_with_data.message, "Invalid params");
        assert_eq!(
            error_with_data.data,
            Some(serde_json::json!({"field": "name"}))
        );
    }
}