use super::{Message, Transport};
use super::error::{TransportError, TransportErrorCode};
use super::Result;
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
use tracing::debug;
#[derive(Clone)]
pub struct ServerInMemoryTransport {
rx: Arc<Mutex<Option<Receiver<Message>>>>,
tx: Sender<Message>,
}
impl Default for ServerInMemoryTransport {
fn default() -> Self {
let (tx, rx) = mpsc::channel(100); Self {
rx: Arc::new(Mutex::new(Some(rx))),
tx,
}
}
}
#[async_trait]
impl Transport for ServerInMemoryTransport {
async fn receive(&self) -> Result<Option<Message>> {
let mut rx_guard = self.rx.lock().await;
let rx = rx_guard
.as_mut()
.ok_or_else(|| TransportError::new(TransportErrorCode::InvalidState, "Transport not opened"))?;
match rx.recv().await {
Some(message) => {
debug!("Server received: {:?}", message);
Ok(Some(message))
}
None => {
debug!("Client channel closed");
Ok(None)
}
}
}
async fn send(&self, message: &Message) -> Result<()> {
debug!("Server sending: {:?}", message);
self.tx
.send(message.clone())
.await
.map_err(|e| TransportError::new(TransportErrorCode::MessageSendFailed, format!("Failed to send message: {}", e)))?;
Ok(())
}
async fn open(&self) -> Result<()> {
Ok(())
}
async fn close(&self) -> Result<()> {
*self.rx.lock().await = None;
Ok(())
}
}
#[derive(Clone)]
pub struct ClientInMemoryTransport {
tx: Arc<Mutex<Option<Sender<Message>>>>,
rx: Arc<Mutex<Option<Receiver<Message>>>>,
server_handle: Arc<Mutex<Option<JoinHandle<()>>>>,
server_factory: Arc<dyn Fn(ServerInMemoryTransport) -> JoinHandle<()> + Send + Sync>,
}
impl ClientInMemoryTransport {
pub fn new<F>(server_factory: F) -> Self
where
F: Fn(ServerInMemoryTransport) -> JoinHandle<()> + Send + Sync + 'static,
{
Self {
tx: Arc::new(Mutex::new(None)),
rx: Arc::new(Mutex::new(None)),
server_handle: Arc::new(Mutex::new(None)),
server_factory: Arc::new(server_factory),
}
}
}
#[async_trait]
impl Transport for ClientInMemoryTransport {
async fn receive(&self) -> Result<Option<Message>> {
let mut rx_guard = self.rx.lock().await;
let rx = rx_guard
.as_mut()
.ok_or_else(|| TransportError::new(TransportErrorCode::InvalidState, "Transport not opened"))?;
match rx.recv().await {
Some(message) => {
debug!("Client received: {:?}", message);
Ok(Some(message))
}
None => {
debug!("Server channel closed");
Ok(None)
}
}
}
async fn send(&self, message: &Message) -> Result<()> {
let tx_guard = self.tx.lock().await;
let tx = tx_guard
.as_ref()
.ok_or_else(|| TransportError::new(TransportErrorCode::InvalidState, "Transport not opened"))?;
debug!("Client sending: {:?}", message);
tx.send(message.clone())
.await
.map_err(|e| TransportError::new(TransportErrorCode::MessageSendFailed, format!("Failed to send message: {}", e)))?;
Ok(())
}
async fn open(&self) -> Result<()> {
let (client_tx, server_rx) = mpsc::channel(100);
let (server_tx, client_rx) = mpsc::channel(100);
let server_transport = ServerInMemoryTransport {
rx: Arc::new(Mutex::new(Some(server_rx))),
tx: server_tx,
};
let server_handle = (self.server_factory)(server_transport);
*self.rx.lock().await = Some(client_rx);
*self.tx.lock().await = Some(client_tx);
*self.server_handle.lock().await = Some(server_handle);
Ok(())
}
async fn close(&self) -> Result<()> {
*self.tx.lock().await = None;
*self.rx.lock().await = None;
if let Some(handle) = self.server_handle.lock().await.take() {
handle.await.map_err(|e| TransportError::new(TransportErrorCode::InternalError, format!("Server task failed: {}", e)))?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::transport::{JsonRpcMessage, JsonRpcRequest, JsonRpcVersion};
use std::time::Duration;
async fn echo_server(transport: ServerInMemoryTransport) {
while let Ok(Some(message)) = transport.receive().await {
if transport.send(&message).await.is_err() {
break;
}
}
}
#[tokio::test]
async fn test_async_transport() -> Result<()> {
let transport = ClientInMemoryTransport::new(|t| tokio::spawn(echo_server(t)));
let test_message = JsonRpcMessage::Request(JsonRpcRequest {
id: 1,
method: "test".to_string(),
params: Some(serde_json::json!({"hello": "world"})),
jsonrpc: JsonRpcVersion::default(),
});
transport.open().await?;
transport.send(&test_message).await?;
let response = transport.receive().await?;
assert_eq!(Some(test_message), response);
transport.close().await?;
Ok(())
}
#[tokio::test]
async fn test_graceful_shutdown() -> Result<()> {
let transport = ClientInMemoryTransport::new(|t| {
tokio::spawn(async move {
tokio::time::sleep(Duration::from_secs(5)).await;
drop(t);
})
});
transport.open().await?;
let transport_clone = transport.clone();
let read_handle = tokio::spawn(async move {
let result = transport_clone.receive().await;
debug!("Receive returned: {:?}", result);
result
});
tokio::time::sleep(Duration::from_millis(100)).await;
let start = std::time::Instant::now();
transport.close().await?;
let shutdown_duration = start.elapsed();
assert!(shutdown_duration < Duration::from_secs(5));
let read_result = read_handle.await?;
assert!(read_result.is_ok());
assert_eq!(read_result.unwrap(), None);
Ok(())
}
#[tokio::test]
async fn test_multiple_messages() -> Result<()> {
let transport = ClientInMemoryTransport::new(|t| tokio::spawn(echo_server(t)));
transport.open().await?;
let messages: Vec<_> = (0..5)
.map(|i| {
JsonRpcMessage::Request(JsonRpcRequest {
id: i,
method: format!("test_{}", i),
params: Some(serde_json::json!({"index": i})),
jsonrpc: JsonRpcVersion::default(),
})
})
.collect();
for msg in &messages {
transport.send(msg).await?;
}
for expected in &messages {
let received = transport.receive().await?;
assert_eq!(Some(expected.clone()), received);
}
transport.close().await?;
Ok(())
}
}