Skip to main content

mcp_utils/
transport.rs

1use rmcp::RoleClient;
2use rmcp::RoleServer;
3use rmcp::service::{RxJsonRpcMessage, ServiceRole, TxJsonRpcMessage};
4use rmcp::transport::Transport;
5use std::fmt;
6use std::future::Future;
7use std::sync::Arc;
8use tokio::sync::{Mutex, mpsc};
9
/// Error produced by [`InMemoryTransport`] operations.
#[derive(Debug)]
pub enum InMemoryTransportError {
    /// The underlying channel was closed, so a message could not be delivered.
    ChannelClosed,
}

impl fmt::Display for InMemoryTransportError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::ChannelClosed => "Channel closed",
        };
        formatter.write_str(message)
    }
}

/// Marker impl so the error works as `Box<dyn Error>` and with generic `?` conversions;
/// there is no underlying `source`.
impl std::error::Error for InMemoryTransportError {}
24
25/// In-memory transport for connecting `McpServer` and `McpClient` in tests
26pub struct InMemoryTransport<R: ServiceRole> {
27    tx: Arc<Mutex<mpsc::Sender<TxJsonRpcMessage<R>>>>,
28    rx: Arc<Mutex<mpsc::Receiver<RxJsonRpcMessage<R>>>>,
29}
30
31impl<R: ServiceRole> InMemoryTransport<R> {
32    fn new(tx: mpsc::Sender<TxJsonRpcMessage<R>>, rx: mpsc::Receiver<RxJsonRpcMessage<R>>) -> Self {
33        Self { tx: Arc::new(Mutex::new(tx)), rx: Arc::new(Mutex::new(rx)) }
34    }
35}
36
37/// Create a pair of transports for client and server
38pub fn create_in_memory_transport() -> (InMemoryTransport<RoleClient>, InMemoryTransport<RoleServer>) {
39    // Client sends ClientRequest/ClientResult, receives ServerRequest/ServerResult
40    // Server sends ServerRequest/ServerResult, receives ClientRequest/ClientResult
41    let (client_tx, server_rx) = mpsc::channel(1000); // Client -> Server
42    let (server_tx, client_rx) = mpsc::channel(1000); // Server -> Client
43
44    let client_transport = InMemoryTransport::new(client_tx, client_rx);
45    let server_transport = InMemoryTransport::new(server_tx, server_rx);
46
47    (client_transport, server_transport)
48}
49
50impl<R: ServiceRole> Transport<R> for InMemoryTransport<R> {
51    type Error = InMemoryTransportError;
52
53    fn send(&mut self, item: TxJsonRpcMessage<R>) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
54        let tx = self.tx.clone();
55        async move {
56            let tx = tx.lock().await;
57            tx.send(item).await.map_err(|_| InMemoryTransportError::ChannelClosed)?;
58            Ok(())
59        }
60    }
61
62    fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<R>>> + Send {
63        let rx = self.rx.clone();
64        async move {
65            let mut rx = rx.lock().await;
66            rx.recv().await
67        }
68    }
69
70    async fn close(&mut self) -> Result<(), Self::Error> {
71        // Channels will be closed when dropped
72        Ok(())
73    }
74}