use crate::core::error::McpResult;
use crate::protocol::types::{JsonRpcNotification, JsonRpcRequest, JsonRpcResponse};
use async_trait::async_trait;
/// Client-side channel for exchanging JSON-RPC messages with a peer.
///
/// Implementors must be `Send + Sync`. The message-passing methods are async
/// (through `async_trait`) and all of them borrow the transport mutably.
#[async_trait]
pub trait Transport: Send + Sync {
    /// Sends `request` and yields the `JsonRpcResponse` obtained for it, or an
    /// error through `McpResult`.
    async fn send_request(&mut self, request: JsonRpcRequest) -> McpResult<JsonRpcResponse>;

    /// Sends `notification`. Only success or failure is reported; no response
    /// value is returned.
    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()>;

    /// Attempts to obtain a notification from the peer.
    ///
    /// NOTE(review): `Ok(None)` presumably means that no notification is
    /// currently available; the exact meaning is defined by each implementor.
    async fn receive_notification(&mut self) -> McpResult<Option<JsonRpcNotification>>;

    /// Closes the transport, reporting any failure through `McpResult`.
    async fn close(&mut self) -> McpResult<()>;

    /// Reports whether the transport considers itself connected.
    ///
    /// The provided default always answers `true`, so implementors that track
    /// real connection state need to override it.
    fn is_connected(&self) -> bool {
        true
    }

    /// Describes the connection as text.
    ///
    /// The provided default returns the fixed string `"Unknown transport"`.
    fn connection_info(&self) -> String {
        String::from("Unknown transport")
    }
}
/// Shared callback type used to handle incoming requests on a server transport.
///
/// The handler is an `Arc`-wrapped closure that takes a `JsonRpcRequest` by
/// value and returns a pinned, boxed future resolving to
/// `McpResult<JsonRpcResponse>`. The closure must be `Send + Sync`, and the
/// future it returns must be `Send + 'static`.
pub type ServerRequestHandler = std::sync::Arc<
dyn Fn(
JsonRpcRequest,
) -> std::pin::Pin<
Box<dyn std::future::Future<Output = McpResult<JsonRpcResponse>> + Send + 'static>,
> + Send
+ Sync,
>;
/// Server-side counterpart of a transport: it is started, handed a request
/// handler, can push notifications, and is stopped again.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ServerTransport: Send + Sync {
    /// Starts the server transport, reporting any failure through `McpResult`.
    async fn start(&mut self) -> McpResult<()>;

    /// Installs the `ServerRequestHandler` callback on this transport.
    ///
    /// NOTE(review): whether a second call replaces the previous handler is up
    /// to the implementor; nothing in this trait enforces it.
    fn set_request_handler(&mut self, handler: ServerRequestHandler);

    /// Sends `notification`. Only success or failure is reported.
    async fn send_notification(&mut self, notification: JsonRpcNotification) -> McpResult<()>;

    /// Stops the server transport, reporting any failure through `McpResult`.
    async fn stop(&mut self) -> McpResult<()>;

    /// Reports whether the server transport considers itself running.
    ///
    /// The provided default always answers `true`, so implementors that track
    /// their lifecycle need to override it.
    fn is_running(&self) -> bool {
        true
    }

    /// Describes the server transport as text.
    ///
    /// The provided default returns the fixed string
    /// `"Unknown server transport"`.
    fn server_info(&self) -> String {
        String::from("Unknown server transport")
    }
}
/// Tunable settings for a transport.
///
/// Every numeric limit is optional. NOTE(review): `None` presumably means
/// "no limit / feature disabled" — confirm against the transports that read
/// this struct.
///
/// Derives `PartialEq`/`Eq` so that whole configurations can be compared
/// (for example in tests or to detect a changed configuration).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransportConfig {
    /// Connect timeout, in milliseconds.
    pub connect_timeout_ms: Option<u64>,
    /// Read timeout, in milliseconds.
    pub read_timeout_ms: Option<u64>,
    /// Write timeout, in milliseconds.
    pub write_timeout_ms: Option<u64>,
    /// Upper bound on message size. NOTE(review): presumably bytes — confirm.
    pub max_message_size: Option<usize>,
    /// Keep-alive setting, in milliseconds.
    pub keep_alive_ms: Option<u64>,
    /// Whether compression is requested.
    pub compression: bool,
    /// Additional string key/value headers. NOTE(review): presumably only
    /// meaningful for header-carrying transports such as HTTP — confirm.
    pub headers: std::collections::HashMap<String, String>,
}
impl Default for TransportConfig {
fn default() -> Self {
Self {
connect_timeout_ms: Some(30_000), read_timeout_ms: Some(60_000), write_timeout_ms: Some(30_000), max_message_size: Some(16 * 1024 * 1024), keep_alive_ms: Some(30_000), compression: false,
headers: std::collections::HashMap::new(),
}
}
}
/// Lifecycle state reported by a transport connection.
///
/// The enum only names the states; which transitions are legal is decided by
/// the transports that report them. `Eq` is derived alongside `PartialEq`
/// because every payload (`String`) is itself `Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConnectionState {
    /// Not connected.
    Disconnected,
    /// A connection attempt is in progress.
    Connecting,
    /// Connected.
    Connected,
    /// A reconnection attempt is in progress.
    Reconnecting,
    /// The connection is being closed.
    Closing,
    /// Error state carrying a `String` payload (presumably a human-readable
    /// description of the failure).
    Error(String),
}
/// Plain counter record describing a transport's activity.
///
/// All fields are `u64` and the derived `Default` zeroes every one of them.
/// `PartialEq`/`Eq` are derived so that two snapshots can be compared
/// directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TransportStats {
    /// Number of requests sent.
    pub requests_sent: u64,
    /// Number of responses received.
    pub responses_received: u64,
    /// Number of notifications sent.
    pub notifications_sent: u64,
    /// Number of notifications received.
    pub notifications_received: u64,
    /// Number of connection errors.
    pub connection_errors: u64,
    /// Number of protocol errors.
    pub protocol_errors: u64,
    /// Bytes sent.
    pub bytes_sent: u64,
    /// Bytes received.
    pub bytes_received: u64,
    /// Uptime, in milliseconds.
    pub uptime_ms: u64,
}
/// Implemented by types that expose a `TransportStats` record and can reset it.
///
/// NOTE(review): the trailing underscore presumably exists because a trait
/// named `TransportStats` would clash with the struct of the same name in this
/// module (traits and structs share one namespace).
pub trait TransportStats_: Send + Sync {
/// Returns the implementor's statistics as an owned `TransportStats` value.
fn stats(&self) -> TransportStats;
/// Resets the implementor's statistics.
/// NOTE(review): presumably back to all-zero counters — confirm with implementors.
fn reset_stats(&mut self);
}
/// Extension of `Transport` for implementations that can re-establish their
/// connection and report a `ConnectionState`.
#[async_trait]
pub trait ReconnectableTransport: Transport {
/// Attempts to reconnect, reporting any failure through `McpResult`.
/// NOTE(review): whether this honours the installed `ReconnectConfig` (attempt
/// limit, delays) or performs a single attempt is up to the implementor.
async fn reconnect(&mut self) -> McpResult<()>;
/// Supplies the `ReconnectConfig` this transport should use.
fn set_reconnect_config(&mut self, config: ReconnectConfig);
/// Returns the transport's current `ConnectionState` by value.
fn connection_state(&self) -> ConnectionState;
}
/// Settings that govern reconnection attempts.
///
/// `PartialEq` is derived so that configurations can be compared. `Eq` is not
/// available because of the `f64` fields, and a configuration holding `NaN`
/// in either of them does not compare equal to itself.
#[derive(Debug, Clone, PartialEq)]
pub struct ReconnectConfig {
    /// Whether reconnection is enabled.
    pub enabled: bool,
    /// Maximum number of attempts. NOTE(review): `None` presumably means
    /// unlimited — confirm.
    pub max_attempts: Option<u32>,
    /// Initial delay, in milliseconds.
    pub initial_delay_ms: u64,
    /// Maximum delay, in milliseconds.
    pub max_delay_ms: u64,
    /// Backoff multiplier. NOTE(review): presumably the factor applied to the
    /// delay after each failed attempt — confirm.
    pub backoff_multiplier: f64,
    /// Jitter factor. NOTE(review): presumably the fraction of the delay that
    /// is randomised — confirm.
    pub jitter_factor: f64,
}
impl Default for ReconnectConfig {
fn default() -> Self {
Self {
enabled: true,
max_attempts: Some(5),
initial_delay_ms: 1000, max_delay_ms: 30_000, backoff_multiplier: 2.0,
jitter_factor: 0.1,
}
}
}
/// Implemented by transports that accept a predicate over `JsonRpcRequest`s.
pub trait FilterableTransport: Send + Sync {
/// Installs `filter`, a boxed `Send + Sync` closure that inspects a request by
/// reference and returns a `bool`.
/// NOTE(review): whether `true` means "let the request through" or "drop it"
/// is not defined by this trait — confirm with implementors.
fn set_message_filter(&mut self, filter: Box<dyn Fn(&JsonRpcRequest) -> bool + Send + Sync>);
/// Removes the message filter.
fn clear_message_filter(&mut self);
}
/// Events describing transport activity.
///
/// `PartialEq`/`Eq` are derived so that events can be compared directly, for
/// example when a test asserts on the events a listener recorded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TransportEvent {
    /// The transport connected.
    Connected,
    /// The transport disconnected.
    Disconnected,
    /// A message was sent.
    MessageSent {
        /// Kind of message, as free-form text.
        message_type: String,
        /// Message size. NOTE(review): presumably bytes — confirm.
        size: usize,
    },
    /// A message was received.
    MessageReceived {
        /// Kind of message, as free-form text.
        message_type: String,
        /// Message size. NOTE(review): presumably bytes — confirm.
        size: usize,
    },
    /// An error occurred.
    Error {
        /// Textual description of the error.
        message: String,
    },
}
/// Implemented by transports that accept `TransportEvent` listeners.
pub trait EventEmittingTransport: Send + Sync {
/// Adds `listener`, a boxed `Send + Sync` closure that receives each
/// `TransportEvent` by value.
/// NOTE(review): invocation order and threading of listeners are up to the
/// implementor; this trait does not specify them.
fn add_event_listener(&mut self, listener: Box<dyn Fn(TransportEvent) + Send + Sync>);
/// Removes the event listeners.
fn clear_event_listeners(&mut self);
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins every field of the default `TransportConfig`, so a change to any
    /// default is caught (the timeouts, size limit, keep-alive, compression
    /// flag and the empty header map).
    #[test]
    fn test_transport_config_default() {
        let config = TransportConfig::default();
        assert_eq!(config.connect_timeout_ms, Some(30_000));
        assert_eq!(config.read_timeout_ms, Some(60_000));
        assert_eq!(config.write_timeout_ms, Some(30_000));
        assert_eq!(config.max_message_size, Some(16 * 1024 * 1024));
        assert_eq!(config.keep_alive_ms, Some(30_000));
        assert!(!config.compression);
        assert!(config.headers.is_empty());
    }

    /// Pins every field of the default `ReconnectConfig`.
    #[test]
    fn test_reconnect_config_default() {
        let config = ReconnectConfig::default();
        assert!(config.enabled);
        assert_eq!(config.max_attempts, Some(5));
        assert_eq!(config.initial_delay_ms, 1000);
        assert_eq!(config.max_delay_ms, 30_000);
        // Exact float comparison is intentional: the defaults are literals.
        assert_eq!(config.backoff_multiplier, 2.0);
        assert_eq!(config.jitter_factor, 0.1);
    }

    /// Checks that `ConnectionState` equality distinguishes variants and
    /// compares the `Error` payload.
    #[test]
    fn test_connection_state_equality() {
        assert_eq!(ConnectionState::Connected, ConnectionState::Connected);
        assert_eq!(ConnectionState::Disconnected, ConnectionState::Disconnected);
        assert_ne!(ConnectionState::Connected, ConnectionState::Disconnected);

        let error1 = ConnectionState::Error("test".to_string());
        let error2 = ConnectionState::Error("test".to_string());
        let error3 = ConnectionState::Error("other".to_string());
        assert_eq!(error1, error2);
        assert_ne!(error1, error3);
        assert_ne!(error1, ConnectionState::Connected);
    }

    /// Checks that every counter in a default `TransportStats` starts at zero.
    #[test]
    fn test_transport_stats_default() {
        let stats = TransportStats::default();
        assert_eq!(stats.requests_sent, 0);
        assert_eq!(stats.responses_received, 0);
        assert_eq!(stats.notifications_sent, 0);
        assert_eq!(stats.notifications_received, 0);
        assert_eq!(stats.connection_errors, 0);
        assert_eq!(stats.protocol_errors, 0);
        assert_eq!(stats.bytes_sent, 0);
        assert_eq!(stats.bytes_received, 0);
        assert_eq!(stats.uptime_ms, 0);
    }
}