use crate::{McpToolsError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Kind of transport selected by a [`TransportConfig`].
///
/// Note that [`create_transport`] currently builds only the `WebSocket` and
/// `Http` variants; `Stdio` and `Tcp` are rejected there with an error.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum TransportType {
/// Handled by [`WebSocketTransport`].
WebSocket,
/// Handled by [`HttpTransport`].
Http,
/// Declared, but no implementation exists in this file.
Stdio,
/// Declared, but no implementation exists in this file.
Tcp,
}
/// Configuration used to construct a transport via [`create_transport`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportConfig {
/// Which transport implementation to build.
pub transport_type: TransportType,
/// Free-form settings keyed by name. [`HttpTransport`] reads the `"url"`
/// entry as a string; no other key is read by the code in this file.
pub config: HashMap<String, serde_json::Value>,
}
/// Message envelope passed to and returned from a [`Transport`].
///
/// The whole struct is serializable; [`HttpTransport`] posts it as a JSON body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TransportMessage {
/// Message identifier (presumably unique per message — confirm with callers).
pub id: String,
/// Caller-defined message kind; not interpreted anywhere in this file.
pub message_type: String,
/// Arbitrary JSON payload.
pub payload: serde_json::Value,
/// Additional key/value data; starts empty when built with `TransportMessage::new`.
pub metadata: HashMap<String, serde_json::Value>,
}
/// Asynchronous message transport.
///
/// Implementations must be `Send + Sync` so they can be used behind
/// `Box<dyn Transport>` (see [`create_transport`]).
#[async_trait::async_trait]
pub trait Transport: Send + Sync {
/// Brings the transport up. Needs exclusive access because implementations
/// may update their own state.
async fn start(&mut self) -> Result<()>;
/// Shuts the transport down.
async fn stop(&mut self) -> Result<()>;
/// Sends one message, consuming it.
async fn send(&self, message: TransportMessage) -> Result<()>;
/// Receives one message. Both implementations in this file currently
/// return an error from this method.
async fn receive(&self) -> Result<TransportMessage>;
/// Reports whether the transport considers itself connected.
async fn is_connected(&self) -> bool;
}
/// WebSocket transport placeholder.
///
/// The implementation visible in this file only tracks a `connected` flag;
/// it performs no socket I/O.
pub struct WebSocketTransport {
// Stored at construction; not read by any code visible in this file.
config: TransportConfig,
// Set by `start`, cleared by `stop`.
connected: bool,
}
impl WebSocketTransport {
    /// Builds a transport in the disconnected state, keeping `config` for later use.
    pub fn new(config: TransportConfig) -> Self {
        // A fresh transport is never connected until `start` is called.
        let connected = false;
        Self { config, connected }
    }
}
#[async_trait::async_trait]
impl Transport for WebSocketTransport {
    /// Marks the transport as connected. No network connection is opened.
    async fn start(&mut self) -> Result<()> {
        self.connected = true;
        Ok(())
    }

    /// Marks the transport as disconnected.
    async fn stop(&mut self) -> Result<()> {
        self.connected = false;
        Ok(())
    }

    /// Accepts (and discards) the message when connected.
    ///
    /// # Errors
    /// Returns `McpToolsError::Client` if `start` has not been called or
    /// `stop` was called since.
    async fn send(&self, _message: TransportMessage) -> Result<()> {
        if self.connected {
            Ok(())
        } else {
            Err(McpToolsError::Client("Transport not connected".to_string()))
        }
    }

    /// Always fails: with "Transport not connected" when disconnected,
    /// otherwise with "Receive not implemented".
    async fn receive(&self) -> Result<TransportMessage> {
        let reason = if self.connected {
            "Receive not implemented"
        } else {
            "Transport not connected"
        };
        Err(McpToolsError::Client(reason.to_string()))
    }

    /// Returns the flag maintained by `start` and `stop`.
    async fn is_connected(&self) -> bool {
        self.connected
    }
}
/// HTTP transport that delivers each message with a POST request.
pub struct HttpTransport {
// The `"url"` entry of `config.config` is the POST target used by `send`.
config: TransportConfig,
// Created once in `new` and used for every request.
client: reqwest::Client,
}
impl HttpTransport {
    /// Builds an HTTP transport with a freshly created `reqwest::Client`.
    ///
    /// `config` is not validated here; a missing `"url"` entry is only
    /// reported when `send` is called.
    pub fn new(config: TransportConfig) -> Self {
        let client = reqwest::Client::new();
        Self { config, client }
    }
}
#[async_trait::async_trait]
impl Transport for HttpTransport {
    /// No-op: there is nothing to set up before the first request.
    async fn start(&mut self) -> Result<()> {
        Ok(())
    }

    /// No-op: there is nothing to tear down.
    async fn stop(&mut self) -> Result<()> {
        Ok(())
    }

    /// POSTs `message` as a JSON body to the URL stored under the `"url"`
    /// key of the transport configuration.
    ///
    /// # Errors
    /// - `McpToolsError::Config` if `"url"` is absent or is not a JSON string.
    /// - The crate's conversion of `reqwest::Error` if the request cannot be
    ///   sent, or if the server answers with a 4xx/5xx status.
    async fn send(&self, message: TransportMessage) -> Result<()> {
        let url = self
            .config
            .config
            .get("url")
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                McpToolsError::Config("Missing URL in HTTP transport config".to_string())
            })?;

        // A completed request is not a delivered message: without this check
        // a 4xx/5xx reply would be silently reported as success.
        self.client
            .post(url)
            .json(&message)
            .send()
            .await?
            .error_for_status()?;
        Ok(())
    }

    /// Always fails: this transport only sends.
    async fn receive(&self) -> Result<TransportMessage> {
        Err(McpToolsError::Client(
            "HTTP transport doesn't support receive".to_string(),
        ))
    }

    /// Always `true`; the server is not probed.
    async fn is_connected(&self) -> bool {
        true
    }
}
/// Builds the transport selected by `config.transport_type`.
///
/// # Errors
/// Returns `McpToolsError::Client` for `TransportType::Stdio` and
/// `TransportType::Tcp`, which have no implementation yet.
pub fn create_transport(config: TransportConfig) -> Result<Box<dyn Transport>> {
    let transport: Box<dyn Transport> = match config.transport_type {
        TransportType::WebSocket => Box::new(WebSocketTransport::new(config)),
        TransportType::Http => Box::new(HttpTransport::new(config)),
        // Listed explicitly (no `_` arm) so adding a variant forces a decision here.
        TransportType::Stdio | TransportType::Tcp => {
            return Err(McpToolsError::Client(format!(
                "Transport type {:?} not yet implemented",
                config.transport_type
            )));
        }
    };
    Ok(transport)
}
impl Default for TransportConfig {
    /// WebSocket transport with `"url"` set to `ws://127.0.0.1:3000`.
    fn default() -> Self {
        let default_url = serde_json::Value::String("ws://127.0.0.1:3000".to_string());
        Self {
            transport_type: TransportType::WebSocket,
            config: HashMap::from([("url".to_string(), default_url)]),
        }
    }
}
impl TransportMessage {
    /// Creates a message with the given id, type and payload and no metadata.
    pub fn new(
        id: impl Into<String>,
        message_type: impl Into<String>,
        payload: serde_json::Value,
    ) -> Self {
        let (id, message_type) = (id.into(), message_type.into());
        Self {
            id,
            message_type,
            payload,
            metadata: HashMap::new(),
        }
    }

    /// Builder-style helper: sets one metadata entry and returns the message.
    ///
    /// An existing entry under the same key is replaced.
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        let _replaced = self.metadata.insert(key.into(), value);
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration selects WebSocket and carries a `url` entry.
    #[test]
    fn test_transport_config_default() {
        let defaults = TransportConfig::default();
        assert_eq!(defaults.transport_type, TransportType::WebSocket);
        assert!(defaults.config.contains_key("url"));
    }

    /// `TransportMessage::new` stores its arguments unchanged.
    #[test]
    fn test_transport_message_creation() {
        let payload = serde_json::json!({"test": "data"});
        let msg = TransportMessage::new("test-id", "test-type", payload);
        assert_eq!(msg.id, "test-id");
        assert_eq!(msg.message_type, "test-type");
        assert_eq!(msg.payload["test"], "data");
    }

    /// `start` and `stop` toggle the connection flag.
    #[tokio::test]
    async fn test_websocket_transport() {
        let mut ws = WebSocketTransport::new(TransportConfig::default());
        assert!(!ws.is_connected().await);

        ws.start().await.unwrap();
        assert!(ws.is_connected().await);

        ws.stop().await.unwrap();
        assert!(!ws.is_connected().await);
    }
}