1use async_trait::async_trait;
8use serde::{Deserialize, Serialize};
9use serde_json::Value;
10use std::fmt;
11
12use crate::protocol::ToolDefinition;
13
14#[async_trait]
19pub trait McpTransport: Send + Sync {
20 async fn list_tools(&self) -> Result<Vec<ToolDefinition>, McpTransportError>;
22
23 async fn call_tool(&self, name: &str, args: Value) -> Result<Value, McpTransportError>;
25
26 async fn shutdown(&self) -> Result<(), McpTransportError>;
28
29 fn is_alive(&self) -> bool;
31
32 fn transport_type(&self) -> TransportTypeId;
34}
35
36#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38#[serde(rename_all = "lowercase")]
39pub enum TransportTypeId {
40 Stdio,
42 Http,
44 Sse,
46}
47
48impl fmt::Display for TransportTypeId {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 match self {
51 TransportTypeId::Stdio => write!(f, "stdio"),
52 TransportTypeId::Http => write!(f, "http"),
53 TransportTypeId::Sse => write!(f, "sse"),
54 }
55 }
56}
57
58#[derive(Debug, thiserror::Error)]
60pub enum McpTransportError {
61 #[error("Unknown tool: {0}")]
62 UnknownTool(String),
63
64 #[error("Server not found: {0}")]
65 ServerNotFound(String),
66
67 #[error("Server error: {0}")]
68 ServerError(String),
69
70 #[error("Transport error: {0}")]
71 TransportError(String),
72
73 #[error("IO error: {0}")]
74 IoError(#[from] std::io::Error),
75
76 #[error("JSON error: {0}")]
77 JsonError(#[from] serde_json::Error),
78
79 #[error("Timeout: {0}")]
80 Timeout(String),
81
82 #[error("Protocol error: {0}")]
83 ProtocolError(String),
84
85 #[error("Not supported: {0}")]
86 NotSupported(String),
87
88 #[error("Connection closed")]
89 ConnectionClosed,
90}
91
92impl From<String> for McpTransportError {
93 fn from(s: String) -> Self {
94 McpTransportError::TransportError(s)
95 }
96}
97
98impl From<&str> for McpTransportError {
99 fn from(s: &str) -> Self {
100 McpTransportError::TransportError(s.to_string())
101 }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct McpServerConnectionConfig {
107 pub name: String,
109
110 pub transport: TransportTypeId,
112
113 #[serde(skip_serializing_if = "Option::is_none")]
115 pub command: Option<String>,
116
117 #[serde(default)]
119 pub args: Vec<String>,
120
121 #[serde(skip_serializing_if = "Option::is_none")]
123 pub url: Option<String>,
124
125 #[serde(default)]
127 pub config: Value,
128
129 #[serde(default = "default_timeout")]
131 pub timeout_secs: u64,
132
133 #[serde(default)]
135 pub env: std::collections::HashMap<String, String>,
136}
137
/// Timeout, in seconds, applied when a connection config does not specify one.
///
/// Referenced by name from the `timeout_secs` serde attribute, and used by the
/// builder constructors.
const fn default_timeout() -> u64 {
    const SECONDS: u64 = 30;
    SECONDS
}
141
142impl McpServerConnectionConfig {
143 pub fn stdio(name: impl Into<String>, command: impl Into<String>, args: Vec<String>) -> Self {
145 Self {
146 name: name.into(),
147 transport: TransportTypeId::Stdio,
148 command: Some(command.into()),
149 args,
150 url: None,
151 config: Value::Object(serde_json::Map::new()),
152 timeout_secs: default_timeout(),
153 env: std::collections::HashMap::new(),
154 }
155 }
156
157 pub fn http(name: impl Into<String>, url: impl Into<String>) -> Self {
159 Self {
160 name: name.into(),
161 transport: TransportTypeId::Http,
162 command: None,
163 args: Vec::new(),
164 url: Some(url.into()),
165 config: Value::Object(serde_json::Map::new()),
166 timeout_secs: default_timeout(),
167 env: std::collections::HashMap::new(),
168 }
169 }
170
171 pub fn with_config(mut self, config: Value) -> Self {
173 self.config = config;
174 self
175 }
176
177 pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
179 self.timeout_secs = timeout_secs;
180 self
181 }
182
183 pub fn with_env(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
185 self.env.insert(key.into(), value.into());
186 self
187 }
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct InitializeParams {
193 #[serde(rename = "protocolVersion")]
194 pub protocol_version: String,
195
196 pub capabilities: InitializeCapabilities,
197
198 #[serde(rename = "clientInfo")]
199 pub client_info: ClientInfo,
200
201 #[serde(skip_serializing_if = "Option::is_none")]
202 pub config: Option<Value>,
203}
204
205impl InitializeParams {
206 pub fn new(config: Option<Value>) -> Self {
207 Self {
208 protocol_version: "2024-11-05".to_string(),
209 capabilities: InitializeCapabilities::default(),
210 client_info: ClientInfo {
211 name: "mcp-rust".to_string(),
212 version: env!("CARGO_PKG_VERSION").to_string(),
213 },
214 config,
215 }
216 }
217}
218
219#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct InitializeResult {
222 #[serde(rename = "protocolVersion")]
223 pub protocol_version: String,
224
225 pub capabilities: ServerCapabilities,
226
227 #[serde(rename = "serverInfo")]
228 pub server_info: ServerInfo,
229}
230
231#[derive(Debug, Clone, Serialize, Deserialize, Default)]
233pub struct InitializeCapabilities {
234 #[serde(skip_serializing_if = "Option::is_none")]
235 pub tools: Option<ToolCapabilities>,
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize, Default)]
240pub struct ServerCapabilities {
241 #[serde(skip_serializing_if = "Option::is_none")]
242 pub tools: Option<ServerToolCapabilities>,
243
244 #[serde(skip_serializing_if = "Option::is_none")]
245 pub resources: Option<Value>,
246
247 #[serde(skip_serializing_if = "Option::is_none")]
248 pub prompts: Option<Value>,
249}
250
251#[derive(Debug, Clone, Serialize, Deserialize, Default)]
253pub struct ServerToolCapabilities {
254 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
255 pub list_changed: Option<bool>,
256}
257
258#[derive(Debug, Clone, Serialize, Deserialize, Default)]
260pub struct ToolCapabilities {
261 #[serde(rename = "listChanged", skip_serializing_if = "Option::is_none")]
262 pub list_changed: Option<bool>,
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
267pub struct ClientInfo {
268 pub name: String,
269 pub version: String,
270}
271
272#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct ServerInfo {
275 pub name: String,
276 pub version: String,
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_transport_type_display() {
        assert_eq!(TransportTypeId::Stdio.to_string(), "stdio");
        assert_eq!(TransportTypeId::Http.to_string(), "http");
        assert_eq!(TransportTypeId::Sse.to_string(), "sse");
    }

    #[test]
    fn test_transport_type_serde_matches_display() {
        // The serde spelling and the `Display` spelling are both user-visible
        // names for the same transport; they must stay in sync.
        for kind in [TransportTypeId::Stdio, TransportTypeId::Http, TransportTypeId::Sse] {
            let encoded = serde_json::to_string(&kind).unwrap();
            assert_eq!(encoded, format!("\"{}\"", kind));
            let decoded: TransportTypeId = serde_json::from_str(&encoded).unwrap();
            assert_eq!(decoded, kind);
        }
    }

    #[test]
    fn test_connection_config_stdio() {
        let config =
            McpServerConnectionConfig::stdio("test", "node", vec!["server.js".to_string()])
                .with_timeout(60);

        assert_eq!(config.name, "test");
        assert_eq!(config.transport, TransportTypeId::Stdio);
        assert_eq!(config.command, Some("node".to_string()));
        assert_eq!(config.args, vec!["server.js".to_string()]);
        assert_eq!(config.url, None);
        assert_eq!(config.timeout_secs, 60);
    }

    #[test]
    fn test_connection_config_http() {
        let config = McpServerConnectionConfig::http("api", "http://localhost:8080/mcp");

        assert_eq!(config.name, "api");
        assert_eq!(config.transport, TransportTypeId::Http);
        assert_eq!(config.url, Some("http://localhost:8080/mcp".to_string()));
        assert_eq!(config.command, None);
        assert_eq!(config.timeout_secs, 30);
    }

    #[test]
    fn test_connection_config_builders() {
        let config = McpServerConnectionConfig::stdio("t", "cmd", Vec::new())
            .with_env("KEY", "VALUE")
            .with_config(serde_json::json!({ "a": 1 }));

        assert_eq!(config.env.get("KEY").map(String::as_str), Some("VALUE"));
        assert_eq!(config.config["a"], serde_json::json!(1));
    }

    #[test]
    fn test_connection_config_deserialize_defaults() {
        let config: McpServerConnectionConfig =
            serde_json::from_str(r#"{"name":"srv","transport":"stdio","command":"node"}"#)
                .unwrap();

        assert_eq!(config.transport, TransportTypeId::Stdio);
        assert_eq!(config.command, Some("node".to_string()));
        assert_eq!(config.timeout_secs, 30);
        assert!(config.args.is_empty());
        assert!(config.env.is_empty());
        assert_eq!(config.url, None);
    }

    #[test]
    fn test_connection_config_skips_none_fields() {
        let json =
            serde_json::to_value(McpServerConnectionConfig::http("api", "http://x")).unwrap();

        assert!(json.get("command").is_none());
        assert_eq!(json["transport"], "http");
        assert_eq!(json["url"], "http://x");
    }

    #[test]
    fn test_initialize_params() {
        let params = InitializeParams::new(None);
        assert_eq!(params.protocol_version, "2024-11-05");
        assert_eq!(params.client_info.name, "mcp-rust");
    }

    #[test]
    fn test_initialize_params_wire_format() {
        let json = serde_json::to_value(InitializeParams::new(None)).unwrap();

        assert_eq!(json["protocolVersion"], "2024-11-05");
        assert_eq!(json["clientInfo"]["name"], "mcp-rust");
        assert!(json.get("config").is_none());
    }

    #[test]
    fn test_initialize_result_deserialize() {
        let result: InitializeResult = serde_json::from_str(
            r#"{"protocolVersion":"2024-11-05",
                "capabilities":{"tools":{"listChanged":true}},
                "serverInfo":{"name":"s","version":"1"}}"#,
        )
        .unwrap();

        assert_eq!(result.protocol_version, "2024-11-05");
        assert_eq!(result.server_info.name, "s");
        assert_eq!(result.capabilities.tools.and_then(|t| t.list_changed), Some(true));
    }

    #[test]
    fn test_string_conversions_become_transport_errors() {
        let from_str = McpTransportError::from("boom");
        assert!(matches!(from_str, McpTransportError::TransportError(ref m) if m == "boom"));

        let from_string: McpTransportError = String::from("bad").into();
        assert_eq!(from_string.to_string(), "Transport error: bad");
    }
}