use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
/// Failure modes reported by MCP transports.
///
/// Every variant renders a human-readable message via `thiserror`; `Io` is
/// produced automatically from `std::io::Error` by the `?` operator.
#[derive(Error, Debug)]
pub enum TransportError {
    /// Establishing the underlying connection failed.
    #[error("Connection failed: {0}")]
    ConnectionFailed(String),

    /// Writing a message failed.
    #[error("Send failed: {0}")]
    SendFailed(String),

    /// Reading a message failed.
    #[error("Receive failed: {0}")]
    ReceiveFailed(String),

    /// An operation did not complete within the given duration.
    #[error("Timeout after {0:?}")]
    Timeout(Duration),

    /// The transport has been closed.
    #[error("Transport closed")]
    Closed,

    /// Incoming bytes could not be interpreted as a valid message.
    #[error("Invalid message format: {0}")]
    InvalidFormat(String),

    /// Underlying I/O error (converted via `#[from]`).
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),

    /// A message could not be serialized.
    #[error("Serialization error: {0}")]
    Serialization(String),
}
/// Shorthand for results whose error type is [`TransportError`].
pub type TransportResult<T> = std::result::Result<T, TransportError>;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportConfig {
pub read_timeout: Duration,
pub write_timeout: Duration,
pub max_message_size: usize,
pub compression: bool,
pub buffer_size: usize,
}
impl Default for TransportConfig {
fn default() -> Self {
Self {
read_timeout: Duration::from_secs(30),
write_timeout: Duration::from_secs(30),
max_message_size: 10 * 1024 * 1024, compression: false,
buffer_size: 8192,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpMessage {
pub jsonrpc: String,
pub id: Option<serde_json::Value>,
pub method: Option<String>,
pub params: Option<serde_json::Value>,
pub result: Option<serde_json::Value>,
pub error: Option<McpError>,
}
impl McpMessage {
pub fn request(
id: impl Into<serde_json::Value>,
method: &str,
params: serde_json::Value,
) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id: Some(id.into()),
method: Some(method.to_string()),
params: Some(params),
result: None,
error: None,
}
}
pub fn response(id: impl Into<serde_json::Value>, result: serde_json::Value) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id: Some(id.into()),
method: None,
params: None,
result: Some(result),
error: None,
}
}
pub fn error_response(id: impl Into<serde_json::Value>, error: McpError) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id: Some(id.into()),
method: None,
params: None,
result: None,
error: Some(error),
}
}
pub fn notification(method: &str, params: serde_json::Value) -> Self {
Self {
jsonrpc: "2.0".to_string(),
id: None,
method: Some(method.to_string()),
params: Some(params),
result: None,
error: None,
}
}
pub fn is_request(&self) -> bool {
self.method.is_some() && self.id.is_some()
}
pub fn is_notification(&self) -> bool {
self.method.is_some() && self.id.is_none()
}
pub fn is_response(&self) -> bool {
self.id.is_some() && (self.result.is_some() || self.error.is_some())
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpError {
pub code: i32,
pub message: String,
pub data: Option<serde_json::Value>,
}
impl McpError {
pub fn parse_error(message: impl Into<String>) -> Self {
Self {
code: -32700,
message: message.into(),
data: None,
}
}
pub fn invalid_request(message: impl Into<String>) -> Self {
Self {
code: -32600,
message: message.into(),
data: None,
}
}
pub fn method_not_found(method: &str) -> Self {
Self {
code: -32601,
message: format!("Method not found: {}", method),
data: None,
}
}
pub fn invalid_params(message: impl Into<String>) -> Self {
Self {
code: -32602,
message: message.into(),
data: None,
}
}
pub fn internal_error(message: impl Into<String>) -> Self {
Self {
code: -32603,
message: message.into(),
data: None,
}
}
}
/// Traffic counters for a transport; `Default` yields all zeros.
#[derive(Debug, Default, Clone, Serialize, Deserialize)]
pub struct TransportStats {
    /// Count of messages sent.
    pub messages_sent: u64,
    /// Count of messages received.
    pub messages_received: u64,
    /// Count of bytes sent.
    pub bytes_sent: u64,
    /// Count of bytes received.
    pub bytes_received: u64,
    /// Count of failed sends.
    pub send_errors: u64,
    /// Count of failed receives.
    pub receive_errors: u64,
    /// Average send latency; the `_us` suffix suggests microseconds.
    pub avg_send_latency_us: u64,
    /// Average receive latency; the `_us` suffix suggests microseconds.
    pub avg_receive_latency_us: u64,
}
/// A channel that carries [`McpMessage`]s to and from a peer.
///
/// The `async` methods are desugared by `#[async_trait]`. Implementors must be
/// `Send + Sync`.
#[async_trait]
pub trait McpTransport: Send + Sync {
    /// Static name of the transport kind (for example `"stdio"`).
    fn transport_type(&self) -> &'static str;

    /// Intended to establish the connection.
    async fn connect(&mut self) -> TransportResult<()>;

    /// Intended to tear the connection down.
    async fn disconnect(&mut self) -> TransportResult<()>;

    /// Whether the transport currently considers itself connected.
    fn is_connected(&self) -> bool;

    /// Intended to deliver one message to the peer.
    async fn send(&mut self, message: &McpMessage) -> TransportResult<()>;

    /// Intended to wait for and return the next incoming message.
    async fn receive(&mut self) -> TransportResult<McpMessage>;

    /// Like `receive`, but intended to give up after `timeout`.
    /// NOTE(review): implementations may not enforce the bound — check each one.
    async fn receive_timeout(&mut self, timeout: Duration) -> TransportResult<McpMessage>;

    /// Snapshot of the traffic counters.
    fn stats(&self) -> TransportStats;

    /// Resets the traffic counters.
    fn reset_stats(&mut self);
}
/// Kinds of transport known to [`TransportFactory`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransportType {
    /// Displayed as `stdio`.
    Stdio,
    /// Displayed as `sse`.
    Sse,
    /// Displayed as `websocket`.
    WebSocket,
    /// Displayed as `http`.
    Http,
}
impl std::fmt::Display for TransportType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Stdio => write!(f, "stdio"),
Self::Sse => write!(f, "sse"),
Self::WebSocket => write!(f, "websocket"),
Self::Http => write!(f, "http"),
}
}
}
/// Transport that exchanges framed messages over the process's standard input
/// and standard output.
pub struct StdioTransport {
    /// Configuration supplied at construction.
    // `dead_code` is allowed in case no code path reads the configuration.
    #[allow(dead_code)]
    config: TransportConfig,
    /// Connection flag reported by `is_connected`.
    connected: bool,
    /// Running traffic counters.
    stats: TransportStats,
}
impl StdioTransport {
pub fn new() -> Self {
Self::with_config(TransportConfig::default())
}
pub fn with_config(config: TransportConfig) -> Self {
Self {
config,
connected: false,
stats: TransportStats::default(),
}
}
}
impl Default for StdioTransport {
fn default() -> Self {
Self::new()
}
}
/// Reads the header section of one `Content-Length`-framed message.
///
/// Header lines are read until an empty line. `Content-Length` is matched
/// case-insensitively. Other well-formed headers (e.g. `Content-Type`) are
/// ignored, and headers may appear in any order.
///
/// Returns `Ok(None)` when the stream is at EOF before any header byte is read.
/// Otherwise returns the declared body length and the number of header bytes
/// consumed, including the terminating blank line.
///
/// # Errors
/// * `TransportError::InvalidFormat` for any of these cases:
///   - a line that is not `Name: value`;
///   - a `Content-Length` that does not parse;
///   - a header section with no `Content-Length`;
///   - EOF in the middle of the headers.
/// * `TransportError::Io` when the read itself fails.
fn read_stdio_frame_header<R: std::io::BufRead>(
    reader: &mut R,
) -> TransportResult<Option<(usize, usize)>> {
    let mut line = String::new();
    let mut content_length: Option<usize> = None;
    let mut consumed = 0usize;
    loop {
        line.clear();
        let read = reader.read_line(&mut line)?;
        if read == 0 {
            // Clean EOF between messages means the peer closed the stream.
            if consumed == 0 {
                return Ok(None);
            }
            return Err(TransportError::InvalidFormat(
                "Unexpected end of input in message header".into(),
            ));
        }
        consumed += read;

        let header = line.trim();
        if header.is_empty() {
            break;
        }
        // Header names must be tokens. This makes e.g. a raw JSON line fail fast
        // instead of being skipped while we wait for a blank line.
        let (name, value) = header
            .split_once(':')
            .map(|(name, value)| (name.trim(), value.trim()))
            .filter(|(name, _)| {
                !name.is_empty() && name.chars().all(|c| c.is_ascii_alphanumeric() || c == '-')
            })
            .ok_or_else(|| TransportError::InvalidFormat("Malformed header line".into()))?;
        if name.eq_ignore_ascii_case("Content-Length") {
            let parsed = value
                .parse()
                .map_err(|_| TransportError::InvalidFormat("Invalid Content-Length".into()))?;
            content_length = Some(parsed);
        }
    }
    match content_length {
        Some(length) => Ok(Some((length, consumed))),
        None => Err(TransportError::InvalidFormat(
            "Missing Content-Length header".into(),
        )),
    }
}

impl StdioTransport {
    /// Serializes `message` and writes it to stdout as
    /// `Content-Length: N\r\n\r\n<json>`, then flushes.
    /// Returns the number of bytes written (header plus body).
    fn write_frame(message: &McpMessage) -> TransportResult<u64> {
        use std::io::Write;
        let json = serde_json::to_string(message)
            .map_err(|e| TransportError::Serialization(e.to_string()))?;
        // `json.len()` is the UTF-8 byte length, which is what Content-Length counts.
        let frame = format!("Content-Length: {}\r\n\r\n{}", json.len(), json);
        let mut stdout = std::io::stdout().lock();
        stdout.write_all(frame.as_bytes())?;
        stdout.flush()?;
        Ok(frame.len() as u64)
    }

    /// Reads one framed message from stdin.
    ///
    /// Returns the message and the total bytes consumed (headers plus body).
    /// `TransportError::Closed` is returned on EOF before a new message starts.
    /// A body larger than `max_message_size` is rejected with `InvalidFormat`
    /// before any buffer is allocated for it.
    fn read_frame(max_message_size: usize) -> TransportResult<(McpMessage, u64)> {
        use std::io::Read;
        let mut reader = std::io::stdin().lock();
        let (content_length, header_bytes) = match read_stdio_frame_header(&mut reader)? {
            Some(frame) => frame,
            None => return Err(TransportError::Closed),
        };
        if content_length > max_message_size {
            // Discard the oversized body without buffering it, so the next read
            // starts on a frame boundary.
            std::io::copy(
                &mut reader.by_ref().take(content_length as u64),
                &mut std::io::sink(),
            )?;
            return Err(TransportError::InvalidFormat(format!(
                "Message of {} bytes exceeds max_message_size ({} bytes)",
                content_length, max_message_size
            )));
        }
        let mut body = vec![0u8; content_length];
        reader.read_exact(&mut body)?;
        let message: McpMessage = serde_json::from_slice(&body)
            .map_err(|e| TransportError::InvalidFormat(e.to_string()))?;
        Ok((message, (header_bytes + content_length) as u64))
    }
}

/// Stdio implementation using `Content-Length` framing.
///
/// NOTE(review): the MCP specification's stdio transport uses newline-delimited
/// JSON, while this implementation uses LSP-style `Content-Length` framing.
/// Confirm which framing the peer expects.
///
/// NOTE(review): reads and writes are blocking calls made inside `async` fns.
/// On a multi-threaded runtime, consider running them via a blocking-task API.
#[async_trait]
impl McpTransport for StdioTransport {
    fn transport_type(&self) -> &'static str {
        "stdio"
    }

    /// Marks the transport connected; stdio needs no handshake.
    async fn connect(&mut self) -> TransportResult<()> {
        self.connected = true;
        Ok(())
    }

    /// Marks the transport disconnected.
    async fn disconnect(&mut self) -> TransportResult<()> {
        self.connected = false;
        Ok(())
    }

    fn is_connected(&self) -> bool {
        self.connected
    }

    /// Writes one framed message to stdout.
    ///
    /// On success, updates `messages_sent` and `bytes_sent` (header plus body).
    /// On failure, increments `send_errors`.
    async fn send(&mut self, message: &McpMessage) -> TransportResult<()> {
        match Self::write_frame(message) {
            Ok(bytes) => {
                self.stats.messages_sent += 1;
                self.stats.bytes_sent += bytes;
                Ok(())
            }
            Err(err) => {
                self.stats.send_errors += 1;
                Err(err)
            }
        }
    }

    /// Reads one framed message from stdin.
    ///
    /// On success, updates `messages_received` and `bytes_received` (header plus
    /// body, matching `bytes_sent`).
    ///
    /// EOF before a message yields `TransportError::Closed` and marks the
    /// transport disconnected. Any other failure increments `receive_errors`.
    async fn receive(&mut self) -> TransportResult<McpMessage> {
        match Self::read_frame(self.config.max_message_size) {
            Ok((message, bytes)) => {
                self.stats.messages_received += 1;
                self.stats.bytes_received += bytes;
                Ok(message)
            }
            Err(TransportError::Closed) => {
                self.connected = false;
                Err(TransportError::Closed)
            }
            Err(err) => {
                self.stats.receive_errors += 1;
                Err(err)
            }
        }
    }

    /// Same as `receive`.
    ///
    /// NOTE(review): the timeout is not enforced, because the blocking stdin read
    /// cannot be interrupted here.
    async fn receive_timeout(&mut self, _timeout: Duration) -> TransportResult<McpMessage> {
        self.receive().await
    }

    fn stats(&self) -> TransportStats {
        self.stats.clone()
    }

    fn reset_stats(&mut self) {
        self.stats = TransportStats::default();
    }
}
pub struct TransportFactory;
impl TransportFactory {
pub fn create(transport_type: TransportType) -> Box<dyn McpTransport> {
match transport_type {
TransportType::Stdio => Box::new(StdioTransport::new()),
TransportType::Sse | TransportType::WebSocket | TransportType::Http => {
Box::new(StdioTransport::new())
}
}
}
pub fn create_with_config(
transport_type: TransportType,
config: TransportConfig,
) -> Box<dyn McpTransport> {
match transport_type {
TransportType::Stdio => Box::new(StdioTransport::with_config(config)),
_ => Box::new(StdioTransport::with_config(config)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_request() {
        let msg = McpMessage::request(1, "tools/list", serde_json::json!({}));
        assert!(msg.is_request());
        assert!(!msg.is_notification());
        assert!(!msg.is_response());
    }

    #[test]
    fn test_message_notification() {
        let msg = McpMessage::notification("progress", serde_json::json!({ "percent": 50 }));
        assert!(!msg.is_request());
        assert!(msg.is_notification());
        assert!(!msg.is_response());
    }

    #[test]
    fn test_message_response() {
        let msg = McpMessage::response(1, serde_json::json!({ "tools": [] }));
        assert!(!msg.is_request());
        assert!(!msg.is_notification());
        assert!(msg.is_response());
    }

    #[test]
    fn test_message_error_response() {
        let msg = McpMessage::error_response(7, McpError::internal_error("boom"));
        assert!(!msg.is_request());
        assert!(!msg.is_notification());
        assert!(msg.is_response());
    }

    #[test]
    fn test_error_codes() {
        let err = McpError::method_not_found("unknown");
        assert_eq!(err.code, -32601);
        let err = McpError::invalid_params("bad param");
        assert_eq!(err.code, -32602);
    }

    #[test]
    fn test_all_standard_error_codes() {
        assert_eq!(McpError::parse_error("x").code, -32700);
        assert_eq!(McpError::invalid_request("x").code, -32600);
        assert_eq!(McpError::internal_error("x").code, -32603);
        let err = McpError::method_not_found("unknown");
        assert_eq!(err.message, "Method not found: unknown");
        assert!(err.data.is_none());
    }

    #[test]
    fn test_request_json_round_trip() {
        let msg = McpMessage::request(1, "tools/list", serde_json::json!({}));
        let text = serde_json::to_string(&msg).expect("request serializes");
        let back: McpMessage = serde_json::from_str(&text).expect("request deserializes");
        assert!(back.is_request());
        assert_eq!(back.method.as_deref(), Some("tools/list"));
        assert_eq!(back.id, Some(serde_json::json!(1)));
    }

    #[test]
    fn test_transport_config_default() {
        let config = TransportConfig::default();
        assert_eq!(config.read_timeout, Duration::from_secs(30));
        assert_eq!(config.write_timeout, Duration::from_secs(30));
        assert_eq!(config.max_message_size, 10 * 1024 * 1024);
        assert_eq!(config.buffer_size, 8192);
        assert!(!config.compression);
    }

    #[test]
    fn test_transport_type_display() {
        assert_eq!(TransportType::Stdio.to_string(), "stdio");
        assert_eq!(TransportType::Sse.to_string(), "sse");
        assert_eq!(TransportType::WebSocket.to_string(), "websocket");
        assert_eq!(TransportType::Http.to_string(), "http");
    }

    #[test]
    fn test_stdio_transport_initial_state() {
        let transport = StdioTransport::new();
        assert_eq!(transport.transport_type(), "stdio");
        assert!(!transport.is_connected());
        let stats = transport.stats();
        assert_eq!(stats.messages_sent, 0);
        assert_eq!(stats.messages_received, 0);
    }

    #[test]
    fn test_factory_falls_back_to_stdio() {
        let kinds = [
            TransportType::Stdio,
            TransportType::Sse,
            TransportType::WebSocket,
            TransportType::Http,
        ];
        for kind in kinds.iter().copied() {
            assert_eq!(TransportFactory::create(kind).transport_type(), "stdio");
        }
    }
}