1use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::fmt;
11
12use crate::protocol::McpToolDefinition;
13
14#[async_trait]
19pub trait McpTransport: Send + Sync {
20 async fn list_tools(&self) -> Result<Vec<McpToolDefinition>, McpTransportError>;
22
23 async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError>;
25
26 async fn shutdown(&self) -> Result<(), McpTransportError>;
28
29 fn is_alive(&self) -> bool;
31
32 fn transport_type(&self) -> TransportTypeId;
34}
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38#[serde(rename_all = "lowercase")]
39pub enum TransportTypeId {
40 Stdio,
42 Http,
44}
45
46impl fmt::Display for TransportTypeId {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 match self {
49 TransportTypeId::Stdio => write!(f, "stdio"),
50 TransportTypeId::Http => write!(f, "http"),
51 }
52 }
53}
54
55#[derive(Debug, thiserror::Error)]
57pub enum McpTransportError {
58 #[error("Unknown tool: {0}")]
59 UnknownTool(String),
60
61 #[error("Server not found: {0}")]
62 ServerNotFound(String),
63
64 #[error("Server error: {0}")]
65 ServerError(String),
66
67 #[error("Transport error: {0}")]
68 TransportError(String),
69
70 #[error("IO error: {0}")]
71 IoError(#[from] std::io::Error),
72
73 #[error("JSON error: {0}")]
74 JsonError(#[from] serde_json::Error),
75
76 #[error("Timeout: {0}")]
77 Timeout(String),
78
79 #[error("Protocol error: {0}")]
80 ProtocolError(String),
81
82 #[error("Not supported: {0}")]
83 NotSupported(String),
84
85 #[error("Connection closed")]
86 ConnectionClosed,
87
88 #[error("Server '{0}' is restarting")]
89 ServerRestarting(String),
90}
91
92impl From<String> for McpTransportError {
93 fn from(s: String) -> Self {
94 McpTransportError::TransportError(s)
95 }
96}
97
98impl From<&str> for McpTransportError {
99 fn from(s: &str) -> Self {
100 McpTransportError::TransportError(s.to_string())
101 }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct RestartPolicy {
109 pub enabled: bool,
111 #[serde(skip_serializing_if = "Option::is_none")]
113 pub max_attempts: Option<u32>,
114 pub delay_ms: u64,
116 pub backoff_multiplier: f64,
118 pub max_delay_ms: u64,
120}
121
122impl Default for RestartPolicy {
123 fn default() -> Self {
124 Self {
125 enabled: false,
126 max_attempts: None,
127 delay_ms: 1000,
128 backoff_multiplier: 2.0,
129 max_delay_ms: 30000,
130 }
131 }
132}
133
134impl RestartPolicy {
135 pub fn none() -> Self {
137 Self::default()
138 }
139
140 pub fn always() -> Self {
142 Self {
143 enabled: true,
144 max_attempts: None,
145 delay_ms: 1000,
146 backoff_multiplier: 2.0,
147 max_delay_ms: 30000,
148 }
149 }
150
151 pub fn max_retries(attempts: u32) -> Self {
153 Self {
154 enabled: true,
155 max_attempts: Some(attempts),
156 delay_ms: 1000,
157 backoff_multiplier: 2.0,
158 max_delay_ms: 30000,
159 }
160 }
161
162 pub fn with_delay_ms(mut self, ms: u64) -> Self {
164 self.delay_ms = ms;
165 self
166 }
167
168 pub fn with_backoff(mut self, multiplier: f64) -> Self {
170 self.backoff_multiplier = multiplier;
171 self
172 }
173
174 pub fn with_max_delay_ms(mut self, ms: u64) -> Self {
176 self.max_delay_ms = ms;
177 self
178 }
179
180 pub fn delay_for_attempt(&self, attempt: u32) -> u64 {
182 if self.backoff_multiplier > 1.0 {
183 let exp_delay = (self.delay_ms as f64) * self.backoff_multiplier.powi(attempt as i32);
184 (exp_delay as u64).min(self.max_delay_ms)
185 } else {
186 self.delay_ms
187 }
188 }
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct McpServerConnectionConfig {
194 pub name: String,
196
197 pub transport: TransportTypeId,
199
200 #[serde(skip_serializing_if = "Option::is_none")]
202 pub command: Option<String>,
203
204 #[serde(default)]
206 pub args: Vec<String>,
207
208 #[serde(skip_serializing_if = "Option::is_none")]
210 pub url: Option<String>,
211
212 #[serde(default)]
214 pub config: Value,
215
216 #[serde(default = "default_timeout")]
218 pub timeout_secs: u64,
219
220 #[serde(default)]
222 pub env: std::collections::HashMap<String, String>,
223
224 #[serde(default)]
226 pub restart_policy: RestartPolicy,
227}
228
/// Timeout, in seconds, used when a connection config does not specify one.
fn default_timeout() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
232
233impl McpServerConnectionConfig {
234 pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
236 Self {
237 name: name.into(),
238 transport: TransportTypeId::Stdio,
239 command: Some(command.into()),
240 args,
241 url: None,
242 config: Value::Object(serde_json::Map::new()),
243 timeout_secs: default_timeout(),
244 env: std::collections::HashMap::new(),
245 restart_policy: RestartPolicy::none(),
246 }
247 }
248
249 pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
251 Self {
252 name: name.into(),
253 transport: TransportTypeId::Http,
254 command: None,
255 args: Vec::new(),
256 url: Some(url.into()),
257 config: Value::Object(serde_json::Map::new()),
258 timeout_secs: default_timeout(),
259 env: std::collections::HashMap::new(),
260 restart_policy: RestartPolicy::none(),
261 }
262 }
263
264 pub fn with_config(mut self, config: Value) -> Self {
266 self.config = config;
267 self
268 }
269
270 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
272 self.timeout_secs = timeout_secs;
273 self
274 }
275
276 pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
278 self.env.insert(key.into(), value.into());
279 self
280 }
281
282 pub fn with_restart(mut self, policy: RestartPolicy) -> Self {
284 self.restart_policy = policy;
285 self
286 }
287
288 pub fn restart_on_failure(self) -> Self {
290 self.with_restart(RestartPolicy::always())
291 }
292
293 pub fn restart_max_attempts(self, attempts: u32) -> Self {
295 self.with_restart(RestartPolicy::max_retries(attempts))
296 }
297}
298
299#[derive(Debug, Clone, Serialize, Deserialize)]
301pub struct InitializeParams {
302 #[serde(rename = "protocolVersion")]
303 pub protocol_version: String,
304
305 pub capabilities: InitializeCapabilities,
306
307 #[serde(rename = "clientInfo")]
308 pub client_info: ClientInfo,
309
310 #[serde(skip_serializing_if = "Option::is_none")]
311 pub config: Option<Value>,
312}
313
314impl InitializeParams {
315 pub fn new(config: Option<Value>) -> Self {
316 Self {
317 protocol_version: crate::MCP_PROTOCOL_VERSION.to_string(),
318 capabilities: InitializeCapabilities::default(),
319 client_info: ClientInfo::new("mcp-rust", env!("CARGO_PKG_VERSION")),
320 config,
321 }
322 }
323}
324
325#[derive(Debug, Clone, Serialize, Deserialize)]
327pub struct InitializeResult {
328 #[serde(rename = "protocolVersion")]
329 pub protocol_version: String,
330
331 pub capabilities: ServerCapabilities,
332
333 #[serde(rename = "serverInfo")]
334 pub server_info: ServerInfo,
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize, Default)]
339pub struct InitializeCapabilities {
340 #[serde(skip_serializing_if = "Option::is_none")]
342 pub experimental: Option<Value>,
343
344 #[serde(skip_serializing_if = "Option::is_none")]
346 pub roots: Option<RootsCapabilities>,
347
348 #[serde(skip_serializing_if = "Option::is_none")]
350 pub sampling: Option<SamplingCapabilities>,
351
352 #[serde(skip_serializing_if = "Option::is_none")]
354 pub elicitation: Option<ElicitationCapabilities>,
355
356 #[serde(skip_serializing_if = "Option::is_none")]
358 pub tasks: Option<TasksCapabilities>,
359
360 #[serde(skip_serializing_if = "Option::is_none")]
362 pub tools: Option<ToolCapabilities>,
363}
364
365#[derive(Debug, Clone, Serialize, Deserialize, Default)]
367pub struct RootsCapabilities {
368 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
370 pub list_changed: Option<bool>,
371}
372
373#[derive(Debug, Clone, Serialize, Deserialize, Default)]
375pub struct SamplingCapabilities {
376 #[serde(skip_serializing_if = "Option::is_none")]
378 pub context: Option<Value>,
379
380 #[serde(skip_serializing_if = "Option::is_none")]
382 pub tools: Option<Value>,
383}
384
385#[derive(Debug, Clone, Serialize, Deserialize, Default)]
387pub struct ElicitationCapabilities {
388 #[serde(skip_serializing_if = "Option::is_none")]
390 pub form: Option<Value>,
391
392 #[serde(skip_serializing_if = "Option::is_none")]
394 pub url: Option<Value>,
395}
396
397#[derive(Debug, Clone, Serialize, Deserialize, Default)]
399pub struct TasksCapabilities {
400 #[serde(skip_serializing_if = "Option::is_none")]
402 pub list: Option<Value>,
403
404 #[serde(skip_serializing_if = "Option::is_none")]
406 pub cancel: Option<Value>,
407
408 #[serde(skip_serializing_if = "Option::is_none")]
410 pub requests: Option<Value>,
411}
412
413#[derive(Debug, Clone, Serialize, Deserialize, Default)]
415pub struct ServerCapabilities {
416 #[serde(skip_serializing_if = "Option::is_none")]
418 pub experimental: Option<Value>,
419
420 #[serde(skip_serializing_if = "Option::is_none")]
422 pub logging: Option<Value>,
423
424 #[serde(skip_serializing_if = "Option::is_none")]
426 pub completions: Option<Value>,
427
428 #[serde(skip_serializing_if = "Option::is_none")]
430 pub prompts: Option<PromptsCapabilities>,
431
432 #[serde(skip_serializing_if = "Option::is_none")]
434 pub resources: Option<ResourcesCapabilities>,
435
436 #[serde(skip_serializing_if = "Option::is_none")]
438 pub tools: Option<ServerToolCapabilities>,
439
440 #[serde(skip_serializing_if = "Option::is_none")]
442 pub tasks: Option<TasksCapabilities>,
443}
444
445#[derive(Debug, Clone, Serialize, Deserialize, Default)]
447pub struct PromptsCapabilities {
448 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
450 pub list_changed: Option<bool>,
451}
452
453#[derive(Debug, Clone, Serialize, Deserialize, Default)]
455pub struct ResourcesCapabilities {
456 #[serde(skip_serializing_if = "Option::is_none")]
458 pub subscribe: Option<bool>,
459
460 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
462 pub list_changed: Option<bool>,
463}
464
465#[derive(Debug, Clone, Serialize, Deserialize, Default)]
467pub struct ServerToolCapabilities {
468 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
469 pub list_changed: Option<bool>,
470}
471
472#[derive(Debug, Clone, Serialize, Deserialize, Default)]
474pub struct ToolCapabilities {
475 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
476 pub list_changed: Option<bool>,
477}
478
479#[derive(Debug, Clone, Serialize, Deserialize)]
481pub struct ClientInfo {
482 pub name: String,
484
485 pub version: String,
487
488 #[serde(skip_serializing_if = "Option::is_none")]
490 pub title: Option<String>,
491
492 #[serde(skip_serializing_if = "Option::is_none")]
494 pub description: Option<String>,
495
496 #[serde(skip_serializing_if = "Option::is_none")]
498 pub icons: Option<Vec<crate::protocol::Icon>>,
499
500 #[serde(rename = "websiteUrl", skip_serializing_if = "Option::is_none")]
502 pub website_url: Option<String>,
503}
504
505impl ClientInfo {
506 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
508 Self {
509 name: name.into(),
510 version: version.into(),
511 title: None,
512 description: None,
513 icons: None,
514 website_url: None,
515 }
516 }
517}
518
519#[derive(Debug, Clone, Serialize, Deserialize)]
521pub struct ServerInfo {
522 pub name: String,
524
525 pub version: String,
527
528 #[serde(skip_serializing_if = "Option::is_none")]
530 pub title: Option<String>,
531
532 #[serde(skip_serializing_if = "Option::is_none")]
534 pub description: Option<String>,
535
536 #[serde(skip_serializing_if = "Option::is_none")]
538 pub icons: Option<Vec<crate::protocol::Icon>>,
539
540 #[serde(rename = "websiteUrl", skip_serializing_if = "Option::is_none")]
542 pub website_url: Option<String>,
543}
544
545impl ServerInfo {
546 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
548 Self {
549 name: name.into(),
550 version: version.into(),
551 title: None,
552 description: None,
553 icons: None,
554 website_url: None,
555 }
556 }
557}
558
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transport_type_display() {
        assert_eq!(TransportTypeId::Stdio.to_string(), "stdio");
        assert_eq!(TransportTypeId::Http.to_string(), "http");
    }

    /// The serde wire name must agree with `Display` and round-trip.
    #[test]
    fn test_transport_type_serde_matches_display() {
        for kind in [TransportTypeId::Stdio, TransportTypeId::Http] {
            let json = serde_json::to_string(&kind).unwrap();
            assert_eq!(json, format!("\"{}\"", kind));
            let back: TransportTypeId = serde_json::from_str(&json).unwrap();
            assert_eq!(back, kind);
        }
    }

    #[test]
    fn test_connection_config_stdio() {
        let config =
            McpServerConnectionConfig::stdio("test", "node", vec!["server.js".to_string()])
                .with_timeout(60);

        assert_eq!(config.name, "test");
        assert_eq!(config.transport, TransportTypeId::Stdio);
        assert_eq!(config.command, Some("node".to_string()));
        assert_eq!(config.timeout_secs, 60);
    }

    #[test]
    fn test_connection_config_http() {
        let config = McpServerConnectionConfig::http("api", "http://localhost:8080/mcp");

        assert_eq!(config.name, "api");
        assert_eq!(config.transport, TransportTypeId::Http);
        assert_eq!(config.url, Some("http://localhost:8080/mcp".to_string()));
    }

    #[test]
    fn test_restart_policy_presets() {
        assert!(!RestartPolicy::none().enabled);

        let always = RestartPolicy::always();
        assert!(always.enabled);
        assert_eq!(always.max_attempts, None);

        let limited = RestartPolicy::max_retries(3);
        assert!(limited.enabled);
        assert_eq!(limited.max_attempts, Some(3));
    }

    #[test]
    fn test_restart_policy_backoff() {
        let policy = RestartPolicy::always();
        assert_eq!(policy.delay_for_attempt(0), 1000);
        assert_eq!(policy.delay_for_attempt(1), 2000);
        assert_eq!(policy.delay_for_attempt(3), 8000);
        // 1000 * 2^5 = 32000 exceeds the 30 s ceiling.
        assert_eq!(policy.delay_for_attempt(5), 30000);

        // A multiplier of 1.0 keeps the delay constant.
        let flat = RestartPolicy::max_retries(3)
            .with_backoff(1.0)
            .with_delay_ms(500);
        assert_eq!(flat.delay_for_attempt(0), 500);
        assert_eq!(flat.delay_for_attempt(10), 500);
    }

    #[test]
    fn test_initialize_params() {
        let params = InitializeParams::new(None);
        assert_eq!(params.protocol_version, "2025-11-25");
        assert_eq!(params.client_info.name, "mcp-rust");
    }

    /// MCP uses camelCase keys on the wire; `None` config must be omitted.
    #[test]
    fn test_initialize_params_wire_names() {
        let json = serde_json::to_value(InitializeParams::new(None)).unwrap();
        assert!(json.get("protocolVersion").is_some());
        assert!(json.get("clientInfo").is_some());
        assert!(json.get("capabilities").is_some());
        assert!(json.get("config").is_none());
    }
}