use crate::error::Result;
use crate::shared::{Transport as TransportTrait, TransportMessage};
#[cfg(test)]
use crate::types::{JSONRPCResponse, Request, RequestId};
use async_trait::async_trait;
use std::sync::Arc;
#[cfg(target_arch = "wasm32")]
use futures::lock::RwLock;
#[cfg(not(target_arch = "wasm32"))]
use tokio::sync::RwLock;
use super::core::ProtocolHandler;
/// A server-side transport that feeds incoming protocol messages to a handler.
///
/// Native builds require `Send + Sync` and `Send` futures.
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
pub trait TransportAdapter: Send + Sync {
    /// Short static identifier of the transport kind (e.g. `"stdio"`, `"http"`).
    fn transport_type(&self) -> &'static str;

    /// Serves protocol traffic for this transport, dispatching it to `handler`.
    async fn serve(&self, handler: Arc<dyn ProtocolHandler>) -> Result<()>;
}

/// A server-side transport that feeds incoming protocol messages to a handler.
///
/// On `wasm32` the futures are not required to be `Send` (`async_trait(?Send)`).
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
pub trait TransportAdapter {
    /// Short static identifier of the transport kind (e.g. `"stdio"`, `"http"`).
    fn transport_type(&self) -> &'static str;

    /// Serves protocol traffic for this transport, dispatching it to `handler`.
    async fn serve(&self, handler: Arc<dyn ProtocolHandler>) -> Result<()>;
}
/// Adapter that drives any [`TransportTrait`] implementation with a protocol handler.
///
/// The transport sits behind `Arc<RwLock<_>>`. This lets the message loop and the
/// shutdown path in `serve` each take exclusive (write) access to it.
#[derive(Debug)]
pub struct GenericTransportAdapter<T: TransportTrait> {
    /// Shared, lock-protected underlying transport.
    transport: Arc<RwLock<T>>,
}
impl<T: TransportTrait> GenericTransportAdapter<T> {
pub fn new(transport: T) -> Self {
Self {
transport: Arc::new(RwLock::new(transport)),
}
}
async fn process_messages(
transport: Arc<RwLock<T>>,
handler: Arc<dyn ProtocolHandler>,
) -> Result<()> {
loop {
let message = {
let mut t = transport.write().await;
if !t.is_connected() {
break;
}
if let Ok(msg) = t.receive().await {
msg
} else {
if !t.is_connected() {
break;
}
return Err(crate::error::Error::internal("Transport receive failed"));
}
};
match message {
TransportMessage::Request { id, request } => {
let response = handler.handle_request(id, request, None).await;
let mut t = transport.write().await;
t.send(TransportMessage::Response(response)).await?;
},
TransportMessage::Notification(notification) => {
handler.handle_notification(notification).await?;
},
TransportMessage::Response(_) => {
tracing::warn!("Server received unexpected response message");
},
}
}
Ok(())
}
}
#[async_trait]
impl<T: TransportTrait + 'static> TransportAdapter for GenericTransportAdapter<T> {
    /// Runs the message loop, then closes the transport.
    ///
    /// The loop's outcome is returned. Any error from `close` is deliberately
    /// ignored so it cannot mask that outcome.
    async fn serve(&self, handler: Arc<dyn ProtocolHandler>) -> Result<()> {
        let shared = Arc::clone(&self.transport);
        let outcome = Self::process_messages(shared, handler).await;

        // Best-effort shutdown; the guard is dropped at the end of this scope.
        {
            let mut guard = self.transport.write().await;
            let _ = guard.close().await;
        }

        outcome
    }

    fn transport_type(&self) -> &'static str {
        "generic"
    }
}
/// Adapter serving the protocol over the process's standard input/output.
#[cfg(not(target_arch = "wasm32"))]
#[derive(Debug)]
pub struct StdioAdapter {
    /// Generic loop driving the stdio transport.
    inner: GenericTransportAdapter<crate::shared::stdio::StdioTransport>,
}

#[cfg(not(target_arch = "wasm32"))]
impl StdioAdapter {
    /// Builds an adapter around a freshly created `StdioTransport`.
    pub fn new() -> Self {
        let transport = crate::shared::stdio::StdioTransport::new();
        StdioAdapter {
            inner: GenericTransportAdapter::new(transport),
        }
    }
}

#[cfg(not(target_arch = "wasm32"))]
impl Default for StdioAdapter {
    /// Equivalent to [`StdioAdapter::new`].
    fn default() -> Self {
        StdioAdapter::new()
    }
}

#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
impl TransportAdapter for StdioAdapter {
    /// Delegates to the wrapped generic adapter.
    async fn serve(&self, handler: Arc<dyn ProtocolHandler>) -> Result<()> {
        let inner = &self.inner;
        inner.serve(handler).await
    }

    fn transport_type(&self) -> &'static str {
        "stdio"
    }
}
/// Stateless adapter handling one HTTP request body at a time.
///
/// Unlike the stream-based adapters, this one does not run a loop. Callers pass each
/// body to [`HttpAdapter::handle_http_request`].
#[cfg(feature = "http")]
#[derive(Debug)]
pub struct HttpAdapter {}

#[cfg(feature = "http")]
impl HttpAdapter {
    /// Creates a new (stateless) HTTP adapter.
    pub fn new() -> Self {
        HttpAdapter {}
    }

    /// Decodes one JSON message from `body`, dispatches it, and returns the reply body.
    ///
    /// - request → the JSON-serialized `TransportMessage::Response`
    /// - notification → an empty string (nothing to send back)
    /// - response → a protocol error with `INVALID_REQUEST`
    ///
    /// # Errors
    ///
    /// Propagates JSON (de)serialization failures and errors from
    /// `handle_notification`. Also returns the protocol error above for response messages.
    pub async fn handle_http_request(
        &self,
        handler: Arc<dyn ProtocolHandler>,
        body: String,
    ) -> Result<String> {
        let incoming: TransportMessage = serde_json::from_str(&body)?;
        let reply = match incoming {
            TransportMessage::Request { id, request } => {
                let response = handler.handle_request(id, request, None).await;
                serde_json::to_string(&TransportMessage::Response(response))?
            },
            TransportMessage::Notification(notification) => {
                handler.handle_notification(notification).await?;
                String::new()
            },
            TransportMessage::Response(_) => {
                return Err(crate::error::Error::protocol(
                    crate::error::ErrorCode::INVALID_REQUEST,
                    "HTTP adapter only accepts requests and notifications",
                ));
            },
        };
        Ok(reply)
    }
}

#[cfg(feature = "http")]
impl Default for HttpAdapter {
    /// Equivalent to [`HttpAdapter::new`].
    fn default() -> Self {
        HttpAdapter::new()
    }
}

#[cfg(feature = "http")]
#[async_trait]
impl TransportAdapter for HttpAdapter {
    /// Always fails: HTTP traffic must go through `handle_http_request` instead.
    async fn serve(&self, _handler: Arc<dyn ProtocolHandler>) -> Result<()> {
        let message = "HTTP adapter should be used with handle_http_request method";
        Err(crate::error::Error::internal(message))
    }

    fn transport_type(&self) -> &'static str {
        "http"
    }
}
/// Adapter serving the protocol over a caller-supplied WebSocket transport.
#[cfg(feature = "websocket")]
#[derive(Debug)]
pub struct WebSocketAdapter<T: TransportTrait> {
    /// Generic loop driving the supplied transport.
    inner: GenericTransportAdapter<T>,
}

#[cfg(feature = "websocket")]
impl<T: TransportTrait + 'static> WebSocketAdapter<T> {
    /// Wraps an already-constructed transport.
    pub fn new(transport: T) -> Self {
        let inner = GenericTransportAdapter::new(transport);
        WebSocketAdapter { inner }
    }
}

#[cfg(feature = "websocket")]
#[async_trait]
impl<T: TransportTrait + 'static> TransportAdapter for WebSocketAdapter<T> {
    /// Delegates to the wrapped generic adapter.
    async fn serve(&self, handler: Arc<dyn ProtocolHandler>) -> Result<()> {
        let inner = &self.inner;
        inner.serve(handler).await
    }

    fn transport_type(&self) -> &'static str {
        "websocket"
    }
}
/// Test-only adapter that replays queued requests and records the responses.
#[cfg(test)]
#[derive(Debug)]
pub struct MockAdapter {
    /// Requests to feed to the handler on `serve`, in insertion order.
    requests: Arc<RwLock<Vec<(RequestId, Request)>>>,
    /// Responses produced by the handler, in the order they were returned.
    responses: Arc<RwLock<Vec<JSONRPCResponse>>>,
}

#[cfg(test)]
impl Default for MockAdapter {
    /// Equivalent to [`MockAdapter::new`].
    fn default() -> Self {
        MockAdapter::new()
    }
}

#[cfg(test)]
impl MockAdapter {
    /// Creates an adapter with no queued requests and no recorded responses.
    pub fn new() -> Self {
        let requests = Arc::new(RwLock::new(Vec::new()));
        let responses = Arc::new(RwLock::new(Vec::new()));
        MockAdapter {
            requests,
            responses,
        }
    }

    /// Queues a request to be dispatched on the next `serve` call.
    pub async fn add_request(&self, id: RequestId, request: Request) {
        let mut queued = self.requests.write().await;
        queued.push((id, request));
    }

    /// Returns a copy of every response recorded so far.
    pub async fn get_responses(&self) -> Vec<JSONRPCResponse> {
        let recorded = self.responses.read().await;
        recorded.clone()
    }
}

#[cfg(test)]
#[async_trait]
impl TransportAdapter for MockAdapter {
    /// Dispatches a snapshot of the queued requests and appends each response.
    ///
    /// The request list is copied first, so its read lock is not held while the
    /// handler runs. The queue itself is left untouched.
    async fn serve(&self, handler: Arc<dyn ProtocolHandler>) -> Result<()> {
        let snapshot = {
            let queued = self.requests.read().await;
            queued.clone()
        };
        for (request_id, req) in snapshot {
            let reply = handler.handle_request(request_id, req, None).await;
            let mut recorded = self.responses.write().await;
            recorded.push(reply);
        }
        Ok(())
    }

    fn transport_type(&self) -> &'static str {
        "mock"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::server::core::ServerCore;
    use crate::types::{ClientRequest, Implementation, InitializeRequest, ServerCapabilities};
    use std::collections::HashMap;

    /// Serving one queued `initialize` request through `MockAdapter` must record
    /// exactly one response, carrying the same request id.
    #[tokio::test]
    async fn test_mock_adapter() {
        use crate::runtime::RwLock;
        use crate::server::tool_middleware::ToolMiddlewareChain;
        use crate::shared::middleware::EnhancedMiddlewareChain;

        let core = ServerCore::new(
            Implementation::new("test-server", "1.0.0"),
            ServerCapabilities::tools_only(),
            HashMap::new(),
            HashMap::new(),
            HashMap::new(),
            HashMap::new(),
            None,
            None,
            None,
            None,
            Arc::new(RwLock::new(EnhancedMiddlewareChain::new())),
            Arc::new(RwLock::new(ToolMiddlewareChain::new())),
            None,
            None,
            false,
            crate::server::limits::PayloadLimits::default(),
        );
        let handler = Arc::new(core);
        let mock = MockAdapter::new();

        let initialize = InitializeRequest {
            protocol_version: "2024-11-05".to_string(),
            capabilities: crate::types::ClientCapabilities::default(),
            client_info: Implementation::new("test-client", "1.0.0"),
        };
        let request = Request::Client(Box::new(ClientRequest::Initialize(initialize)));
        mock.add_request(RequestId::from(1i64), request).await;

        mock.serve(handler).await.unwrap();

        let recorded = mock.get_responses().await;
        assert_eq!(recorded.len(), 1);
        let first = &recorded[0];
        assert_eq!(first.id, RequestId::from(1i64));
    }
}