Skip to main content

model_context_protocol/
transport.rs

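//! McpTransport - Abstract transport interface for MCP servers.
//!
//! This module defines the core transport trait that all MCP communication
//! methods must implement, enabling uniform handling of stdio, HTTP, and
//! other transport types. It also defines the shared transport error type,
//! the server connection configuration, and the payload types exchanged
//! during initialization.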
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::fmt;
11
12use crate::protocol::ToolDefinition;
13
14/// Abstract transport interface for MCP server communication.
15///
16/// All MCP transports (stdio, HTTP, SSE) implement this trait to provide
17/// a uniform interface for tool discovery, execution, and shutdown.
18#[async_trait]
19pub trait McpTransport: Send + Sync {
20    /// Get the list of available tools from the server.
21    async fn list_tools(&self) -> Result<Vec<ToolDefinition>, McpTransportError>;
22
23    /// Execute a tool with the given arguments.
24    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError>;
25
26    /// Perform a clean shutdown of the transport.
27    async fn shutdown(&self) -> Result<(), McpTransportError>;
28
29    /// Check if the transport is still connected/alive.
30    fn is_alive(&self) -> bool;
31
32    /// Get the transport type identifier.
33    fn transport_type(&self) -> TransportTypeId;
34}
35
36/// Transport type identifier.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38#[serde(rename_all = "lowercase")]
39pub enum TransportTypeId {
40    /// Standard I/O transport (recommended)
41    Stdio,
42    /// HTTP/REST transport
43    Http,
44    /// Server-Sent Events transport
45    Sse,
46}
47
48impl fmt::Display for TransportTypeId {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        match self {
51            TransportTypeId::Stdio => write!(f, "stdio"),
52            TransportTypeId::Http => write!(f, "http"),
53            TransportTypeId::Sse => write!(f, "sse"),
54        }
55    }
56}
57
58/// MCP transport errors.
59#[derive(Debug, thiserror::Error)]
60pub enum McpTransportError {
61    #[error("Unknown tool: {0}")]
62    UnknownTool(String),
63
64    #[error("Server not found: {0}")]
65    ServerNotFound(String),
66
67    #[error("Server error: {0}")]
68    ServerError(String),
69
70    #[error("Transport error: {0}")]
71    TransportError(String),
72
73    #[error("IO error: {0}")]
74    IoError(#[from] std::io::Error),
75
76    #[error("JSON error: {0}")]
77    JsonError(#[from] serde_json::Error),
78
79    #[error("Timeout: {0}")]
80    Timeout(String),
81
82    #[error("Protocol error: {0}")]
83    ProtocolError(String),
84
85    #[error("Not supported: {0}")]
86    NotSupported(String),
87
88    #[error("Connection closed")]
89    ConnectionClosed,
90}
91
92impl From<String> for McpTransportError {
93    fn from(s: String) -> Self {
94        McpTransportError::TransportError(s)
95    }
96}
97
98impl From<&str> for McpTransportError {
99    fn from(s: &str) -> Self {
100        McpTransportError::TransportError(s.to_string())
101    }
102}
103
104/// Configuration for connecting to an MCP server.
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct McpServerConnectionConfig {
107    /// Server name (identifier)
108    pub name: String,
109
110    /// Transport type
111    pub transport: TransportTypeId,
112
113    /// Command to run (for stdio)
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub command: Option<String>,
116
117    /// Command arguments (for stdio)
118    #[serde(default)]
119    pub args: Vec<String>,
120
121    /// URL endpoint (for HTTP/SSE)
122    #[serde(skip_serializing_if = "Option::is_none")]
123    pub url: Option<String>,
124
125    /// Server-specific configuration
126    #[serde(default)]
127    pub config: Value,
128
129    /// Connection timeout in seconds
130    #[serde(default = "default_timeout")]
131    pub timeout_secs: u64,
132
133    /// Environment variables to set for stdio transport
134    #[serde(default)]
135    pub env: std::collections::HashMap<String, String>,
136}
137
/// Returns the default connection timeout in seconds (30).
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
141
142impl McpServerConnectionConfig {
143    /// Create a stdio server configuration.
144    pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
145        Self {
146            name: name.into(),
147            transport: TransportTypeId::Stdio,
148            command: Some(command.into()),
149            args,
150            url: None,
151            config: Value::Object(serde_json::Map::new()),
152            timeout_secs: default_timeout(),
153            env: std::collections::HashMap::new(),
154        }
155    }
156
157    /// Create an HTTP server configuration.
158    pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
159        Self {
160            name: name.into(),
161            transport: TransportTypeId::Http,
162            command: None,
163            args: Vec::new(),
164            url: Some(url.into()),
165            config: Value::Object(serde_json::Map::new()),
166            timeout_secs: default_timeout(),
167            env: std::collections::HashMap::new(),
168        }
169    }
170
171    /// Set server-specific configuration.
172    pub fn with_config(mut self, config: Value) -> Self {
173        self.config = config;
174        self
175    }
176
177    /// Set connection timeout.
178    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
179        self.timeout_secs = timeout_secs;
180        self
181    }
182
183    /// Add an environment variable (for stdio transport).
184    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
185        self.env.insert(key.into(), value.into());
186        self
187    }
188}
189
190/// Initialize request for MCP protocol.
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct InitializeParams {
193    #[serde(rename = "protocolVersion")]
194    pub protocol_version: String,
195
196    pub capabilities: InitializeCapabilities,
197
198    #[serde(rename = "clientInfo")]
199    pub client_info: ClientInfo,
200
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub config: Option<Value>,
203}
204
205impl InitializeParams {
206    pub fn new(config: Option<Value>) -> Self {
207        Self {
208            protocol_version: "2024-11-05".to_string(),
209            capabilities: InitializeCapabilities::default(),
210            client_info: ClientInfo {
211                name: "mcp-rust".to_string(),
212                version: env!("CARGO_PKG_VERSION").to_string(),
213            },
214            config,
215        }
216    }
217}
218
219/// Initialize response from MCP server.
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct InitializeResult {
222    #[serde(rename = "protocolVersion")]
223    pub protocol_version: String,
224
225    pub capabilities: ServerCapabilities,
226
227    #[serde(rename = "serverInfo")]
228    pub server_info: ServerInfo,
229}
230
231/// Client capabilities for initialization.
232#[derive(Debug, Clone, Serialize, Deserialize, Default)]
233pub struct InitializeCapabilities {
234    #[serde(skip_serializing_if = "Option::is_none")]
235    pub tools: Option<ToolCapabilities>,
236}
237
238/// Server capabilities returned during initialization.
239#[derive(Debug, Clone, Serialize, Deserialize, Default)]
240pub struct ServerCapabilities {
241    #[serde(skip_serializing_if = "Option::is_none")]
242    pub tools: Option<ServerToolCapabilities>,
243
244    #[serde(skip_serializing_if = "Option::is_none")]
245    pub resources: Option<Value>,
246
247    #[serde(skip_serializing_if = "Option::is_none")]
248    pub prompts: Option<Value>,
249}
250
251/// Server tool capabilities.
252#[derive(Debug, Clone, Serialize, Deserialize, Default)]
253pub struct ServerToolCapabilities {
254    #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
255    pub list_changed: Option<bool>,
256}
257
258/// Tool-related capabilities.
259#[derive(Debug, Clone, Serialize, Deserialize, Default)]
260pub struct ToolCapabilities {
261    #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
262    pub list_changed: Option<bool>,
263}
264
265/// Client information for initialization.
266#[derive(Debug, Clone, Serialize, Deserialize)]
267pub struct ClientInfo {
268    pub name: String,
269    pub version: String,
270}
271
272/// Server information returned during initialization.
273#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct ServerInfo {
275    pub name: String,
276    pub version: String,
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Each transport id renders as its lowercase name.
    #[test]
    fn test_transport_type_display() {
        let expected = [
            (TransportTypeId::Stdio, "stdio"),
            (TransportTypeId::Http, "http"),
            (TransportTypeId::Sse, "sse"),
        ];

        for &(id, text) in &expected {
            assert_eq!(id.to_string(), text);
        }
    }

    /// The stdio constructor records name, transport and command, and
    /// `with_timeout` overrides the default timeout.
    #[test]
    fn test_connection_config_stdio() {
        let args = vec![String::from("server.js")];
        let config = McpServerConnectionConfig::stdio("test", "node", args).with_timeout(60);

        assert_eq!(config.name, "test");
        assert_eq!(config.transport, TransportTypeId::Stdio);
        assert_eq!(config.command, Some(String::from("node")));
        assert_eq!(config.timeout_secs, 60);
    }

    /// The HTTP constructor records name, transport and URL.
    #[test]
    fn test_connection_config_http() {
        let endpoint = "http://localhost:8080/mcp";
        let config = McpServerConnectionConfig::http("api", endpoint);

        assert_eq!(config.name, "api");
        assert_eq!(config.transport, TransportTypeId::Http);
        assert_eq!(config.url, Some(String::from(endpoint)));
    }

    /// Freshly built initialize parameters carry the fixed protocol version
    /// and client name.
    #[test]
    fn test_initialize_params() {
        let InitializeParams {
            protocol_version,
            client_info,
            ..
        } = InitializeParams::new(None);

        assert_eq!(protocol_version, "2024-11-05");
        assert_eq!(client_info.name, "mcp-rust");
    }
}