1use async_trait::async_trait;
30use serde::{Deserialize, Serialize};
31use std::time::Duration;
32use thiserror::Error;
33
34#[derive(Error, Debug)]
36pub enum TransportError {
37 #[error("Connection failed: {0}")]
39 ConnectionFailed(String),
40
41 #[error("Send failed: {0}")]
43 SendFailed(String),
44
45 #[error("Receive failed: {0}")]
47 ReceiveFailed(String),
48
49 #[error("Timeout after {0:?}")]
51 Timeout(Duration),
52
53 #[error("Transport closed")]
55 Closed,
56
57 #[error("Invalid message format: {0}")]
59 InvalidFormat(String),
60
61 #[error("IO error: {0}")]
63 Io(#[from] std::io::Error),
64
65 #[error("Serialization error: {0}")]
67 Serialization(String),
68}
69
70pub type TransportResult<T> = Result<T, TransportError>;
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct TransportConfig {
76 pub read_timeout: Duration,
78 pub write_timeout: Duration,
80 pub max_message_size: usize,
82 pub compression: bool,
84 pub buffer_size: usize,
86}
87
88impl Default for TransportConfig {
89 fn default() -> Self {
90 Self {
91 read_timeout: Duration::from_secs(30),
92 write_timeout: Duration::from_secs(30),
93 max_message_size: 10 * 1024 * 1024, compression: false,
95 buffer_size: 8192,
96 }
97 }
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct McpMessage {
103 pub jsonrpc: String,
105 pub id: Option<serde_json::Value>,
107 pub method: Option<String>,
109 pub params: Option<serde_json::Value>,
111 pub result: Option<serde_json::Value>,
113 pub error: Option<McpError>,
115}
116
117impl McpMessage {
118 pub fn request(
120 id: impl Into<serde_json::Value>,
121 method: &str,
122 params: serde_json::Value,
123 ) -> Self {
124 Self {
125 jsonrpc: "2.0".to_string(),
126 id: Some(id.into()),
127 method: Some(method.to_string()),
128 params: Some(params),
129 result: None,
130 error: None,
131 }
132 }
133
134 pub fn response(id: impl Into<serde_json::Value>, result: serde_json::Value) -> Self {
136 Self {
137 jsonrpc: "2.0".to_string(),
138 id: Some(id.into()),
139 method: None,
140 params: None,
141 result: Some(result),
142 error: None,
143 }
144 }
145
146 pub fn error_response(id: impl Into<serde_json::Value>, error: McpError) -> Self {
148 Self {
149 jsonrpc: "2.0".to_string(),
150 id: Some(id.into()),
151 method: None,
152 params: None,
153 result: None,
154 error: Some(error),
155 }
156 }
157
158 pub fn notification(method: &str, params: serde_json::Value) -> Self {
160 Self {
161 jsonrpc: "2.0".to_string(),
162 id: None,
163 method: Some(method.to_string()),
164 params: Some(params),
165 result: None,
166 error: None,
167 }
168 }
169
170 pub fn is_request(&self) -> bool {
172 self.method.is_some() && self.id.is_some()
173 }
174
175 pub fn is_notification(&self) -> bool {
177 self.method.is_some() && self.id.is_none()
178 }
179
180 pub fn is_response(&self) -> bool {
182 self.id.is_some() && (self.result.is_some() || self.error.is_some())
183 }
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct McpError {
189 pub code: i32,
191 pub message: String,
193 pub data: Option<serde_json::Value>,
195}
196
197impl McpError {
198 pub fn parse_error(message: impl Into<String>) -> Self {
200 Self {
201 code: -32700,
202 message: message.into(),
203 data: None,
204 }
205 }
206
207 pub fn invalid_request(message: impl Into<String>) -> Self {
209 Self {
210 code: -32600,
211 message: message.into(),
212 data: None,
213 }
214 }
215
216 pub fn method_not_found(method: &str) -> Self {
218 Self {
219 code: -32601,
220 message: format!("Method not found: {}", method),
221 data: None,
222 }
223 }
224
225 pub fn invalid_params(message: impl Into<String>) -> Self {
227 Self {
228 code: -32602,
229 message: message.into(),
230 data: None,
231 }
232 }
233
234 pub fn internal_error(message: impl Into<String>) -> Self {
236 Self {
237 code: -32603,
238 message: message.into(),
239 data: None,
240 }
241 }
242}
243
244#[derive(Debug, Clone, Default, Serialize, Deserialize)]
246pub struct TransportStats {
247 pub messages_sent: u64,
249 pub messages_received: u64,
251 pub bytes_sent: u64,
253 pub bytes_received: u64,
255 pub send_errors: u64,
257 pub receive_errors: u64,
259 pub avg_send_latency_us: u64,
261 pub avg_receive_latency_us: u64,
263}
264
265#[async_trait]
269pub trait McpTransport: Send + Sync {
270 fn transport_type(&self) -> &'static str;
272
273 async fn connect(&mut self) -> TransportResult<()>;
275
276 async fn disconnect(&mut self) -> TransportResult<()>;
278
279 fn is_connected(&self) -> bool;
281
282 async fn send(&mut self, message: &McpMessage) -> TransportResult<()>;
284
285 async fn receive(&mut self) -> TransportResult<McpMessage>;
287
288 async fn receive_timeout(&mut self, timeout: Duration) -> TransportResult<McpMessage>;
290
291 fn stats(&self) -> TransportStats;
293
294 fn reset_stats(&mut self);
296}
297
298#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
300pub enum TransportType {
301 Stdio,
303 Sse,
305 WebSocket,
307 Http,
309}
310
311impl std::fmt::Display for TransportType {
312 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313 match self {
314 Self::Stdio => write!(f, "stdio"),
315 Self::Sse => write!(f, "sse"),
316 Self::WebSocket => write!(f, "websocket"),
317 Self::Http => write!(f, "http"),
318 }
319 }
320}
321
322pub struct StdioTransport {
324 #[allow(dead_code)]
325 config: TransportConfig,
326 connected: bool,
327 stats: TransportStats,
328}
329
330impl StdioTransport {
331 pub fn new() -> Self {
333 Self::with_config(TransportConfig::default())
334 }
335
336 pub fn with_config(config: TransportConfig) -> Self {
338 Self {
339 config,
340 connected: false,
341 stats: TransportStats::default(),
342 }
343 }
344}
345
346impl Default for StdioTransport {
347 fn default() -> Self {
348 Self::new()
349 }
350}
351
352#[async_trait]
353impl McpTransport for StdioTransport {
354 fn transport_type(&self) -> &'static str {
355 "stdio"
356 }
357
358 async fn connect(&mut self) -> TransportResult<()> {
359 self.connected = true;
360 Ok(())
361 }
362
363 async fn disconnect(&mut self) -> TransportResult<()> {
364 self.connected = false;
365 Ok(())
366 }
367
368 fn is_connected(&self) -> bool {
369 self.connected
370 }
371
372 async fn send(&mut self, message: &McpMessage) -> TransportResult<()> {
373 let json = serde_json::to_string(message)
374 .map_err(|e| TransportError::Serialization(e.to_string()))?;
375
376 let content = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
378
379 use std::io::Write;
380 let mut stdout = std::io::stdout().lock();
381 stdout
382 .write_all(content.as_bytes())
383 .map_err(TransportError::Io)?;
384 stdout.flush().map_err(TransportError::Io)?;
385
386 self.stats.messages_sent += 1;
387 self.stats.bytes_sent += content.len() as u64;
388
389 Ok(())
390 }
391
392 async fn receive(&mut self) -> TransportResult<McpMessage> {
393 use std::io::{BufRead, Read};
394
395 let stdin = std::io::stdin();
396 let mut reader = stdin.lock();
397
398 let mut header_line = String::new();
400 reader
401 .read_line(&mut header_line)
402 .map_err(TransportError::Io)?;
403
404 let content_length: usize = header_line
405 .trim()
406 .strip_prefix("Content-Length: ")
407 .ok_or_else(|| TransportError::InvalidFormat("Missing Content-Length header".into()))?
408 .parse()
409 .map_err(|_| TransportError::InvalidFormat("Invalid Content-Length".into()))?;
410
411 let mut empty = String::new();
413 reader.read_line(&mut empty).map_err(TransportError::Io)?;
414
415 let mut content = vec![0u8; content_length];
417 reader
418 .read_exact(&mut content)
419 .map_err(TransportError::Io)?;
420
421 let message: McpMessage = serde_json::from_slice(&content)
422 .map_err(|e| TransportError::InvalidFormat(e.to_string()))?;
423
424 self.stats.messages_received += 1;
425 self.stats.bytes_received += content_length as u64;
426
427 Ok(message)
428 }
429
430 async fn receive_timeout(&mut self, _timeout: Duration) -> TransportResult<McpMessage> {
431 self.receive().await
434 }
435
436 fn stats(&self) -> TransportStats {
437 self.stats.clone()
438 }
439
440 fn reset_stats(&mut self) {
441 self.stats = TransportStats::default();
442 }
443}
444
445pub struct TransportFactory;
447
448impl TransportFactory {
449 pub fn create(transport_type: TransportType) -> Box<dyn McpTransport> {
451 match transport_type {
452 TransportType::Stdio => Box::new(StdioTransport::new()),
453 TransportType::Sse | TransportType::WebSocket | TransportType::Http => {
455 Box::new(StdioTransport::new())
457 }
458 }
459 }
460
461 pub fn create_with_config(
463 transport_type: TransportType,
464 config: TransportConfig,
465 ) -> Box<dyn McpTransport> {
466 match transport_type {
467 TransportType::Stdio => Box::new(StdioTransport::with_config(config)),
468 _ => Box::new(StdioTransport::with_config(config)),
469 }
470 }
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_request() {
        let msg = McpMessage::request(1, "tools/list", serde_json::json!({}));
        assert_eq!(msg.jsonrpc, "2.0");
        assert!(msg.is_request());
        assert!(!msg.is_notification());
        assert!(!msg.is_response());
    }

    #[test]
    fn test_message_notification() {
        let msg = McpMessage::notification("progress", serde_json::json!({ "percent": 50 }));
        assert!(msg.id.is_none());
        assert!(!msg.is_request());
        assert!(msg.is_notification());
        assert!(!msg.is_response());
    }

    #[test]
    fn test_message_response() {
        let msg = McpMessage::response(1, serde_json::json!({ "tools": [] }));
        assert!(!msg.is_request());
        assert!(!msg.is_notification());
        assert!(msg.is_response());
    }

    #[test]
    fn test_message_error_response() {
        let msg = McpMessage::error_response(7, McpError::internal_error("boom"));
        assert!(msg.is_response());
        assert!(!msg.is_request());
        assert!(!msg.is_notification());
        assert!(msg.result.is_none());
        assert_eq!(msg.error.as_ref().map(|e| e.code), Some(-32603));
    }

    #[test]
    fn test_error_codes() {
        assert_eq!(McpError::parse_error("bad json").code, -32700);
        assert_eq!(McpError::invalid_request("bad request").code, -32600);

        let err = McpError::method_not_found("unknown");
        assert_eq!(err.code, -32601);
        assert_eq!(err.message, "Method not found: unknown");
        assert!(err.data.is_none());

        let err = McpError::invalid_params("bad param");
        assert_eq!(err.code, -32602);
        assert_eq!(err.message, "bad param");

        assert_eq!(McpError::internal_error("oops").code, -32603);
    }

    #[test]
    fn test_transport_config_default() {
        let config = TransportConfig::default();
        assert_eq!(config.read_timeout, Duration::from_secs(30));
        assert_eq!(config.write_timeout, Duration::from_secs(30));
        assert_eq!(config.max_message_size, 10 * 1024 * 1024);
        assert!(!config.compression);
        assert_eq!(config.buffer_size, 8192);
    }

    #[test]
    fn test_transport_type_display() {
        assert_eq!(TransportType::Stdio.to_string(), "stdio");
        assert_eq!(TransportType::Sse.to_string(), "sse");
        assert_eq!(TransportType::WebSocket.to_string(), "websocket");
        assert_eq!(TransportType::Http.to_string(), "http");
    }

    #[test]
    fn test_factory_builds_fresh_stdio_transport() {
        let transport = TransportFactory::create(TransportType::Stdio);
        assert_eq!(transport.transport_type(), "stdio");
        assert!(!transport.is_connected());

        let stats = transport.stats();
        assert_eq!(stats.messages_sent, 0);
        assert_eq!(stats.messages_received, 0);
    }
}