use rmcp::RoleClient;
use rmcp::RoleServer;
use rmcp::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage};
use rmcp::transport::Transport;
use std::fmt;
use std::future::Future;
use std::sync::Arc;
use tokio::sync::{Mutex, mpsc};
#[derive(Debug)]
pub enum InMemoryTransportError {
ChannelClosed,
}
impl fmt::Display for InMemoryTransportError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
InMemoryTransportError::ChannelClosed => write!(f, "Channel closed"),
}
}
}
impl std::error::Error for InMemoryTransportError {}
pub struct InMemoryTransport<R: ServiceRole> {
tx: Arc<Mutex<mpsc::Sender<TxJsonRpcMessage<R>>>>,
rx: Arc<Mutex<mpsc::Receiver<RxJsonRpcMessage<R>>>>,
}
impl<R: ServiceRole> InMemoryTransport<R> {
fn new(tx: mpsc::Sender<TxJsonRpcMessage<R>>, rx: mpsc::Receiver<RxJsonRpcMessage<R>>) -> Self {
Self { tx: Arc::new(Mutex::new(tx)), rx: Arc::new(Mutex::new(rx)) }
}
}
pub fn create_in_memory_transport() -> (InMemoryTransport<RoleClient>, InMemoryTransport<RoleServer>) {
let (client_tx, server_rx) = mpsc::channel(1000); let (server_tx, client_rx) = mpsc::channel(1000);
let client_transport = InMemoryTransport::new(client_tx, client_rx);
let server_transport = InMemoryTransport::new(server_tx, server_rx);
(client_transport, server_transport)
}
impl<R: ServiceRole> Transport<R> for InMemoryTransport<R> {
type Error = InMemoryTransportError;
fn send(&mut self, item: TxJsonRpcMessage<R>) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
let tx = self.tx.clone();
async move {
let tx = tx.lock().await;
tx.send(item).await.map_err(|_| InMemoryTransportError::ChannelClosed)?;
Ok(())
}
}
fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<R>>> + Send {
let rx = self.rx.clone();
async move {
let mut rx = rx.lock().await;
rx.recv().await
}
}
async fn close(&mut self) -> Result<(), Self::Error> {
Ok(())
}
}