use async_trait::async_trait;
use serde_json::Value;
use crate::transport::{http::HttpTransport, stdio::StdioTransport};
use crate::{protocol::Message, Result};
pub mod http;
pub mod stdio;
pub use http::{client::DefaultHttpClient as HttpClient, server::DefaultHttpServer as HttpServer};
pub use stdio::{
client::DefaultStdioClient as StdioClient, server::DefaultStdioServer as StdioServer,
};
/// Input to [`ClientTransportFactory::create`] and [`ServerTransportFactory::create`].
#[derive(Clone, Debug)]
pub struct TransportConfig {
    /// Which transport to build, together with the settings that variant needs.
    pub transport_type: TransportType,
    /// Optional free-form JSON settings.
    ///
    /// NOTE(review): neither factory in this module reads this field; confirm
    /// whether other code consumes it.
    pub parameters: Option<Value>,
}
/// Selects a transport implementation together with the settings it needs.
///
/// `Debug` is implemented by hand rather than derived so that
/// `Http::auth_token` never ends up in logs: a present token is rendered as
/// `Some("<redacted>")`. Every other field is printed exactly as the derived
/// implementation would print it.
#[derive(Clone)]
pub enum TransportType {
    /// Standard-I/O transport. Only the client factory consumes these fields;
    /// the server factory ignores them.
    Stdio {
        /// Path of the server executable for the client to use.
        server_path: Option<String>,
        /// Arguments for the server executable.
        server_args: Option<Vec<String>>,
    },
    /// HTTP transport.
    Http {
        /// Client side: passed to the HTTP client config.
        /// Server side: parsed into the address to listen on.
        base_url: String,
        /// Optional authentication token handed to the HTTP client/server config.
        auth_token: Option<String>,
    },
}

impl std::fmt::Debug for TransportType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            TransportType::Stdio {
                server_path,
                server_args,
            } => f
                .debug_struct("Stdio")
                .field("server_path", server_path)
                .field("server_args", server_args)
                .finish(),
            TransportType::Http {
                base_url,
                auth_token,
            } => f
                .debug_struct("Http")
                .field("base_url", base_url)
                // Keep whether a token is present, but never its value.
                .field("auth_token", &auth_token.as_ref().map(|_| "<redacted>"))
                .finish(),
        }
    }
}
/// Common interface over every concrete transport in this module.
///
/// Implementors must be `Send + Sync` so boxed transports
/// (`Box<dyn Transport>`, as returned by the factories) can be shared across
/// async tasks.
#[async_trait]
pub trait Transport: Send + Sync {
    // --- lifecycle -------------------------------------------------------

    /// Prepares the transport before messages are exchanged.
    async fn initialize(&mut self) -> Result<()>;

    /// Shuts the transport down.
    async fn close(&mut self) -> Result<()>;

    // --- message I/O ----------------------------------------------------

    /// Sends a single [`Message`].
    async fn send(&self, message: Message) -> Result<()>;

    /// Receives a single [`Message`].
    async fn receive(&self) -> Result<Message>;
}
pub struct ClientTransportFactory;
impl ClientTransportFactory {
pub fn create(&self, config: TransportConfig) -> Result<Box<dyn Transport>> {
match config.transport_type {
TransportType::Stdio {
server_path,
server_args,
} => {
use stdio::client::{StdioClient, StdioClientConfig};
let config = StdioClientConfig {
server_path: server_path
.map(std::path::PathBuf::from)
.unwrap_or_default(),
server_args: server_args.unwrap_or_default(),
..Default::default()
};
let client = StdioClient::new(config);
Ok(Box::new(StdioClientTransport(client)))
}
TransportType::Http {
base_url,
auth_token,
} => {
use http::client::{HttpClient, HttpClientConfig};
let config = HttpClientConfig {
base_url,
auth_token,
};
let client = HttpClient::new(config)?;
Ok(Box::new(HttpClientTransport(client)))
}
}
}
}
pub struct ServerTransportFactory;
impl ServerTransportFactory {
pub fn create(&self, config: TransportConfig) -> Result<Box<dyn Transport>> {
match config.transport_type {
TransportType::Stdio { .. } => {
use stdio::server::{StdioServer, StdioServerConfig};
let server = StdioServer::new(StdioServerConfig::default());
Ok(Box::new(StdioServerTransport(server)))
}
TransportType::Http {
base_url,
auth_token,
} => {
use http::server::{AxumHttpServer, HttpServerConfig};
let addr = base_url
.parse()
.map_err(|e| crate::Error::Transport(format!("Invalid address: {}", e)))?;
let config = HttpServerConfig { addr, auth_token };
let server = AxumHttpServer::new(config);
Ok(Box::new(HttpServerTransport(server)))
}
}
}
}
struct StdioClientTransport(stdio::client::StdioClient);
struct StdioServerTransport(stdio::server::StdioServer);
struct HttpClientTransport(http::client::HttpClient);
struct HttpServerTransport(http::server::AxumHttpServer);
/// Implements [`Transport`] for a single-field tuple wrapper by forwarding each
/// trait method to the method of the same name on the wrapped value (`self.0`).
///
/// The second argument is accepted for readability at call sites only; the
/// expansion never references it.
macro_rules! impl_transport {
    ($target:ident, $_wrapped:ident) => {
        #[async_trait]
        impl Transport for $target {
            async fn initialize(&mut self) -> Result<()> {
                let inner = &mut self.0;
                inner.initialize().await
            }

            async fn close(&mut self) -> Result<()> {
                let inner = &mut self.0;
                inner.close().await
            }

            async fn send(&self, message: Message) -> Result<()> {
                let inner = &self.0;
                inner.send(message).await
            }

            async fn receive(&self) -> Result<Message> {
                let inner = &self.0;
                inner.receive().await
            }
        }
    };
}
// Generate the forwarding `Transport` impl for every wrapper newtype.
// The second name documents the wrapped type; the macro does not use it.
impl_transport! { StdioClientTransport, StdioClient }
impl_transport! { StdioServerTransport, StdioServer }
impl_transport! { HttpClientTransport, HttpClient }
impl_transport! { HttpServerTransport, AxumHttpServer }