Skip to main content

model_context_protocol/
transport.rs

//! `McpTransport` — abstract transport interface for MCP servers.
//!
//! This module defines the core transport trait that every MCP communication
//! method must implement, enabling uniform handling of stdio, HTTP, and
//! other transport types.
6
7use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::fmt;
11
12use crate::protocol::McpToolDefinition;
13
14/// Abstract transport interface for MCP server communication.
15///
16/// All MCP transports (stdio, HTTP, SSE) implement this trait to provide
17/// a uniform interface for tool discovery, execution, and shutdown.
18#[async_trait]
19pub trait McpTransport: Send + Sync {
20    /// Get the list of available tools from the server.
21    async fn list_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError>;
22
23    /// Execute a tool with the given arguments.
24    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError>;
25
26    /// Perform a clean shutdown of the transport.
27    async fn shutdown(&self) -> Result<(), McpTransportError>;
28
29    /// Check if the transport is still connected/alive.
30    fn is_alive(&self) -> bool;
31
32    /// Get the transport type identifier.
33    fn transport_type(&self) -> TransportTypeId;
34}
35
36/// Transport type identifier.
37#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38#[serde(rename_all = "lowercase")]
39pub enum TransportTypeId {
40    /// Standard I/O transport (recommended)
41    Stdio,
42    /// HTTP/REST transport
43    Http,
44}
45
46impl fmt::Display for TransportTypeId {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        match self {
49            TransportTypeId::Stdio => write!(f, "stdio"),
50            TransportTypeId::Http => write!(f, "http"),
51        }
52    }
53}
54
55/// MCP transport errors.
56#[derive(Debug, thiserror::Error)]
57pub enum McpTransportError {
58    #[error("Unknown tool: {0}")]
59    UnknownTool(String),
60
61    #[error("Server not found: {0}")]
62    ServerNotFound(String),
63
64    #[error("Server error: {0}")]
65    ServerError(String),
66
67    #[error("Transport error: {0}")]
68    TransportError(String),
69
70    #[error("IO error: {0}")]
71    IoError(#[from] std::io::Error),
72
73    #[error("JSON error: {0}")]
74    JsonError(#[from] serde_json::Error),
75
76    #[error("Timeout: {0}")]
77    Timeout(String),
78
79    #[error("Protocol error: {0}")]
80    ProtocolError(String),
81
82    #[error("Not supported: {0}")]
83    NotSupported(String),
84
85    #[error("Connection closed")]
86    ConnectionClosed,
87
88    #[error("Server '{0}' is restarting")]
89    ServerRestarting(String),
90}
91
92impl From<String> for McpTransportError {
93    fn from(s: String) -> Self {
94        McpTransportError::TransportError(s)
95    }
96}
97
98impl From<&str> for McpTransportError {
99    fn from(s: &str) -> Self {
100        McpTransportError::TransportError(s.to_string())
101    }
102}
103
104/// Restart policy for server connections.
105///
106/// Controls how the hub handles server disconnections and failures.
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct RestartPolicy {
109    /// Whether to automatically restart on failure
110    pub enabled: bool,
111    /// Maximum number of restart attempts (None = unlimited)
112    #[serde(skip_serializing_if = "Option::is_none")]
113    pub max_attempts: Option<u32>,
114    /// Delay between restart attempts in milliseconds
115    pub delay_ms: u64,
116    /// Exponential backoff multiplier (1.0 = no backoff)
117    pub backoff_multiplier: f64,
118    /// Maximum delay between restarts in milliseconds
119    pub max_delay_ms: u64,
120}
121
122impl Default for RestartPolicy {
123    fn default() -> Self {
124        Self {
125            enabled: false,
126            max_attempts: None,
127            delay_ms: 1000,
128            backoff_multiplier: 2.0,
129            max_delay_ms: 30000,
130        }
131    }
132}
133
134impl RestartPolicy {
135    /// No automatic restarts (default).
136    pub fn none() -> Self {
137        Self::default()
138    }
139
140    /// Always restart with default settings (unlimited retries).
141    pub fn always() -> Self {
142        Self {
143            enabled: true,
144            max_attempts: None,
145            delay_ms: 1000,
146            backoff_multiplier: 2.0,
147            max_delay_ms: 30000,
148        }
149    }
150
151    /// Restart up to N times before giving up.
152    pub fn max_retries(attempts: u32) -> Self {
153        Self {
154            enabled: true,
155            max_attempts: Some(attempts),
156            delay_ms: 1000,
157            backoff_multiplier: 2.0,
158            max_delay_ms: 30000,
159        }
160    }
161
162    /// Set initial delay between restarts (milliseconds).
163    pub fn with_delay_ms(mut self, ms: u64) -> Self {
164        self.delay_ms = ms;
165        self
166    }
167
168    /// Set backoff multiplier (e.g., 2.0 doubles delay each attempt).
169    pub fn with_backoff(mut self, multiplier: f64) -> Self {
170        self.backoff_multiplier = multiplier;
171        self
172    }
173
174    /// Set maximum delay cap (milliseconds).
175    pub fn with_max_delay_ms(mut self, ms: u64) -> Self {
176        self.max_delay_ms = ms;
177        self
178    }
179
180    /// Calculate delay for a given attempt number (0-indexed).
181    pub fn delay_for_attempt(&self, attempt: u32) -> u64 {
182        if self.backoff_multiplier > 1.0 {
183            let exp_delay = (self.delay_ms as f64) * self.backoff_multiplier.powi(attempt as i32);
184            (exp_delay as u64).min(self.max_delay_ms)
185        } else {
186            self.delay_ms
187        }
188    }
189}
190
191/// Configuration for connecting to an MCP server.
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct McpServerConnectionConfig {
194    /// Server name (identifier)
195    pub name: String,
196
197    /// Transport type
198    pub transport: TransportTypeId,
199
200    /// Command to run (for stdio)
201    #[serde(skip_serializing_if = "Option::is_none")]
202    pub command: Option<String>,
203
204    /// Command arguments (for stdio)
205    #[serde(default)]
206    pub args: Vec<String>,
207
208    /// URL endpoint (for HTTP/SSE)
209    #[serde(skip_serializing_if = "Option::is_none")]
210    pub url: Option<String>,
211
212    /// Server-specific configuration
213    #[serde(default)]
214    pub config: Value,
215
216    /// Connection timeout in seconds
217    #[serde(default = "default_timeout")]
218    pub timeout_secs: u64,
219
220    /// Environment variables to set for stdio transport
221    #[serde(default)]
222    pub env: std::collections::HashMap<String, String>,
223
224    /// Restart policy for handling server failures
225    #[serde(default)]
226    pub restart_policy: RestartPolicy,
227}
228
/// Serde default for `McpServerConnectionConfig::timeout_secs`, in seconds.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
232
233impl McpServerConnectionConfig {
234    /// Create a stdio server configuration.
235    pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
236        Self {
237            name: name.into(),
238            transport: TransportTypeId::Stdio,
239            command: Some(command.into()),
240            args,
241            url: None,
242            config: Value::Object(serde_json::Map::new()),
243            timeout_secs: default_timeout(),
244            env: std::collections::HashMap::new(),
245            restart_policy: RestartPolicy::none(),
246        }
247    }
248
249    /// Create an HTTP server configuration.
250    pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
251        Self {
252            name: name.into(),
253            transport: TransportTypeId::Http,
254            command: None,
255            args: Vec::new(),
256            url: Some(url.into()),
257            config: Value::Object(serde_json::Map::new()),
258            timeout_secs: default_timeout(),
259            env: std::collections::HashMap::new(),
260            restart_policy: RestartPolicy::none(),
261        }
262    }
263
264    /// Set server-specific configuration.
265    pub fn with_config(mut self, config: Value) -> Self {
266        self.config = config;
267        self
268    }
269
270    /// Set connection timeout.
271    pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
272        self.timeout_secs = timeout_secs;
273        self
274    }
275
276    /// Add an environment variable (for stdio transport).
277    pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
278        self.env.insert(key.into(), value.into());
279        self
280    }
281
282    /// Set a custom restart policy.
283    pub fn with_restart(mut self, policy: RestartPolicy) -> Self {
284        self.restart_policy = policy;
285        self
286    }
287
288    /// Enable restart on failure with default settings (unlimited retries).
289    pub fn restart_on_failure(self) -> Self {
290        self.with_restart(RestartPolicy::always())
291    }
292
293    /// Enable restart with a maximum number of attempts.
294    pub fn restart_max_attempts(self, attempts: u32) -> Self {
295        self.with_restart(RestartPolicy::max_retries(attempts))
296    }
297}
298
299/// Initialize request for MCP protocol.
300#[derive(Debug, Clone, Serialize, Deserialize)]
301pub struct InitializeParams {
302    #[serde(rename = "protocolVersion")]
303    pub protocol_version: String,
304
305    pub capabilities: InitializeCapabilities,
306
307    #[serde(rename = "clientInfo")]
308    pub client_info: ClientInfo,
309
310    #[serde(skip_serializing_if = "Option::is_none")]
311    pub config: Option<Value>,
312}
313
314impl InitializeParams {
315    pub fn new(config: Option<Value>) -> Self {
316        Self {
317            protocol_version: crate::MCP_PROTOCOL_VERSION.to_string(),
318            capabilities: InitializeCapabilities::default(),
319            client_info: ClientInfo::new("mcp-rust", env!("CARGO_PKG_VERSION")),
320            config,
321        }
322    }
323}
324
325/// Initialize response from MCP server.
326#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct InitializeResult {
328    #[serde(rename = "protocolVersion")]
329    pub protocol_version: String,
330
331    pub capabilities: ServerCapabilities,
332
333    #[serde(rename = "serverInfo")]
334    pub server_info: ServerInfo,
335}
336
337/// Client capabilities for initialization (2025-11-25).
338#[derive(Debug, Clone, Serialize, Deserialize, Default)]
339pub struct InitializeCapabilities {
340    /// Experimental, non-standard capabilities.
341    #[serde(skip_serializing_if = "Option::is_none")]
342    pub experimental: Option<Value>,
343
344    /// Present if the client supports listing roots.
345    #[serde(skip_serializing_if = "Option::is_none")]
346    pub roots: Option<RootsCapabilities>,
347
348    /// Present if the client supports sampling from an LLM.
349    #[serde(skip_serializing_if = "Option::is_none")]
350    pub sampling: Option<SamplingCapabilities>,
351
352    /// Present if the client supports elicitation from the server.
353    #[serde(skip_serializing_if = "Option::is_none")]
354    pub elicitation: Option<ElicitationCapabilities>,
355
356    /// Present if the client supports task-augmented requests.
357    #[serde(skip_serializing_if = "Option::is_none")]
358    pub tasks: Option<TasksCapabilities>,
359
360    /// Tool-related capabilities.
361    #[serde(skip_serializing_if = "Option::is_none")]
362    pub tools: Option<ToolCapabilities>,
363}
364
365/// Roots capabilities.
366#[derive(Debug, Clone, Serialize, Deserialize, Default)]
367pub struct RootsCapabilities {
368    /// Whether the client supports notifications for changes to the roots list.
369    #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
370    pub list_changed: Option<bool>,
371}
372
373/// Sampling capabilities.
374#[derive(Debug, Clone, Serialize, Deserialize, Default)]
375pub struct SamplingCapabilities {
376    /// Whether the client supports context inclusion via includeContext parameter.
377    #[serde(skip_serializing_if = "Option::is_none")]
378    pub context: Option<Value>,
379
380    /// Whether the client supports tool use via tools and toolChoice parameters.
381    #[serde(skip_serializing_if = "Option::is_none")]
382    pub tools: Option<Value>,
383}
384
385/// Elicitation capabilities.
386#[derive(Debug, Clone, Serialize, Deserialize, Default)]
387pub struct ElicitationCapabilities {
388    /// Whether the client supports form-based elicitation.
389    #[serde(skip_serializing_if = "Option::is_none")]
390    pub form: Option<Value>,
391
392    /// Whether the client supports URL-based elicitation.
393    #[serde(skip_serializing_if = "Option::is_none")]
394    pub url: Option<Value>,
395}
396
397/// Tasks capabilities.
398#[derive(Debug, Clone, Serialize, Deserialize, Default)]
399pub struct TasksCapabilities {
400    /// Whether this party supports tasks/list.
401    #[serde(skip_serializing_if = "Option::is_none")]
402    pub list: Option<Value>,
403
404    /// Whether this party supports tasks/cancel.
405    #[serde(skip_serializing_if = "Option::is_none")]
406    pub cancel: Option<Value>,
407
408    /// Specifies which request types can be augmented with tasks.
409    #[serde(skip_serializing_if = "Option::is_none")]
410    pub requests: Option<Value>,
411}
412
413/// Server capabilities returned during initialization (2025-11-25).
414#[derive(Debug, Clone, Serialize, Deserialize, Default)]
415pub struct ServerCapabilities {
416    /// Experimental, non-standard capabilities.
417    #[serde(skip_serializing_if = "Option::is_none")]
418    pub experimental: Option<Value>,
419
420    /// Present if the server supports sending log messages to the client.
421    #[serde(skip_serializing_if = "Option::is_none")]
422    pub logging: Option<Value>,
423
424    /// Present if the server supports argument autocompletion suggestions.
425    #[serde(skip_serializing_if = "Option::is_none")]
426    pub completions: Option<Value>,
427
428    /// Present if the server offers any prompt templates.
429    #[serde(skip_serializing_if = "Option::is_none")]
430    pub prompts: Option<PromptsCapabilities>,
431
432    /// Present if the server offers any resources to read.
433    #[serde(skip_serializing_if = "Option::is_none")]
434    pub resources: Option<ResourcesCapabilities>,
435
436    /// Present if the server offers any tools to call.
437    #[serde(skip_serializing_if = "Option::is_none")]
438    pub tools: Option<ServerToolCapabilities>,
439
440    /// Present if the server supports task-augmented requests.
441    #[serde(skip_serializing_if = "Option::is_none")]
442    pub tasks: Option<TasksCapabilities>,
443}
444
445/// Prompts capabilities.
446#[derive(Debug, Clone, Serialize, Deserialize, Default)]
447pub struct PromptsCapabilities {
448    /// Whether this server supports notifications for changes to the prompt list.
449    #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
450    pub list_changed: Option<bool>,
451}
452
453/// Resources capabilities.
454#[derive(Debug, Clone, Serialize, Deserialize, Default)]
455pub struct ResourcesCapabilities {
456    /// Whether this server supports subscribing to resource updates.
457    #[serde(skip_serializing_if = "Option::is_none")]
458    pub subscribe: Option<bool>,
459
460    /// Whether this server supports notifications for changes to the resource list.
461    #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
462    pub list_changed: Option<bool>,
463}
464
465/// Server tool capabilities.
466#[derive(Debug, Clone, Serialize, Deserialize, Default)]
467pub struct ServerToolCapabilities {
468    #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
469    pub list_changed: Option<bool>,
470}
471
472/// Tool-related capabilities.
473#[derive(Debug, Clone, Serialize, Deserialize, Default)]
474pub struct ToolCapabilities {
475    #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
476    pub list_changed: Option<bool>,
477}
478
479/// Client information for initialization (2025-11-25).
480#[derive(Debug, Clone, Serialize, Deserialize)]
481pub struct ClientInfo {
482    /// Intended for programmatic or logical use.
483    pub name: String,
484
485    /// The version of the client implementation.
486    pub version: String,
487
488    /// Intended for UI and end-user contexts — optimized to be human-readable.
489    #[serde(skip_serializing_if = "Option::is_none")]
490    pub title: Option<String>,
491
492    /// An optional human-readable description.
493    #[serde(skip_serializing_if = "Option::is_none")]
494    pub description: Option<String>,
495
496    /// Optional set of sized icons that can be displayed.
497    #[serde(skip_serializing_if = "Option::is_none")]
498    pub icons: Option<Vec<crate::protocol::Icon>>,
499
500    /// An optional URL of the website for this implementation.
501    #[serde(rename = "websiteUrl", skip_serializing_if = "Option::is_none")]
502    pub website_url: Option<String>,
503}
504
505impl ClientInfo {
506    /// Create a new ClientInfo with just name and version.
507    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
508        Self {
509            name: name.into(),
510            version: version.into(),
511            title: None,
512            description: None,
513            icons: None,
514            website_url: None,
515        }
516    }
517}
518
519/// Server information returned during initialization (2025-11-25).
520#[derive(Debug, Clone, Serialize, Deserialize)]
521pub struct ServerInfo {
522    /// Intended for programmatic or logical use.
523    pub name: String,
524
525    /// The version of the server implementation.
526    pub version: String,
527
528    /// Intended for UI and end-user contexts — optimized to be human-readable.
529    #[serde(skip_serializing_if = "Option::is_none")]
530    pub title: Option<String>,
531
532    /// An optional human-readable description.
533    #[serde(skip_serializing_if = "Option::is_none")]
534    pub description: Option<String>,
535
536    /// Optional set of sized icons that can be displayed.
537    #[serde(skip_serializing_if = "Option::is_none")]
538    pub icons: Option<Vec<crate::protocol::Icon>>,
539
540    /// An optional URL of the website for this implementation.
541    #[serde(rename = "websiteUrl", skip_serializing_if = "Option::is_none")]
542    pub website_url: Option<String>,
543}
544
545impl ServerInfo {
546    /// Create a new ServerInfo with just name and version.
547    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
548        Self {
549            name: name.into(),
550            version: version.into(),
551            title: None,
552            description: None,
553            icons: None,
554            website_url: None,
555        }
556    }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transport_type_display() {
        assert_eq!(TransportTypeId::Stdio.to_string(), "stdio");
        assert_eq!(TransportTypeId::Http.to_string(), "http");
    }

    #[test]
    fn test_transport_type_display_matches_serde() {
        // Display and the serde `rename_all = "lowercase"` form should spell
        // each transport the same way.
        for transport in &[TransportTypeId::Stdio, TransportTypeId::Http] {
            let json = serde_json::to_string(transport).unwrap();
            assert_eq!(json, format!("\"{}\"", transport));
        }
    }

    #[test]
    fn test_connection_config_stdio() {
        let config =
            McpServerConnectionConfig::stdio("test", "node", vec!["server.js".to_string()])
                .with_timeout(60);

        assert_eq!(config.name, "test");
        assert_eq!(config.transport, TransportTypeId::Stdio);
        assert_eq!(config.command, Some("node".to_string()));
        assert_eq!(config.timeout_secs, 60);
    }

    #[test]
    fn test_connection_config_http() {
        let config = McpServerConnectionConfig::http("api", "http://localhost:8080/mcp");

        assert_eq!(config.name, "api");
        assert_eq!(config.transport, TransportTypeId::Http);
        assert_eq!(config.url, Some("http://localhost:8080/mcp".to_string()));
    }

    #[test]
    fn test_connection_config_deserialize_defaults() {
        let config: McpServerConnectionConfig =
            serde_json::from_str(r#"{"name":"fs","transport":"stdio","command":"node"}"#)
                .unwrap();

        assert_eq!(config.transport, TransportTypeId::Stdio);
        assert!(config.args.is_empty());
        assert_eq!(config.timeout_secs, 30);
        assert!(config.env.is_empty());
        assert!(!config.restart_policy.enabled);
    }

    #[test]
    fn test_connection_config_restart_builders() {
        let config =
            McpServerConnectionConfig::stdio("s", "cmd", Vec::new()).restart_max_attempts(5);
        assert!(config.restart_policy.enabled);
        assert_eq!(config.restart_policy.max_attempts, Some(5));

        let config = McpServerConnectionConfig::http("h", "http://x").restart_on_failure();
        assert!(config.restart_policy.enabled);
        assert_eq!(config.restart_policy.max_attempts, None);
    }

    #[test]
    fn test_restart_policy_delay_grows_and_caps() {
        let policy = RestartPolicy::always();
        assert_eq!(policy.delay_for_attempt(0), 1000);
        assert_eq!(policy.delay_for_attempt(1), 2000);
        assert_eq!(policy.delay_for_attempt(3), 8000);
        assert_eq!(policy.delay_for_attempt(10), 30000);
    }

    #[test]
    fn test_initialize_params() {
        let params = InitializeParams::new(None);
        assert_eq!(params.protocol_version, "2025-11-25");
        assert_eq!(params.client_info.name, "mcp-rust");
    }

    #[test]
    fn test_initialize_params_wire_names() {
        let value = serde_json::to_value(InitializeParams::new(None)).unwrap();
        assert_eq!(value["protocolVersion"], "2025-11-25");
        assert!(value.get("clientInfo").is_some());
        assert!(value.get("config").is_none());
    }
}