use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;

use async_trait::async_trait;
use bytes::Bytes;
use tokio::sync::{mpsc, Mutex, Notify};

use crate::error::{EnigmaProtocolError, Result};
/// An asynchronous channel that exchanges whole [`Bytes`] messages with a peer.
///
/// Implementors must be `Sync + Send`, so one transport can be shared between
/// tasks (for example behind an `Arc<dyn Transport>`).
#[async_trait]
pub trait Transport: Sync + Send {
    /// Sends one message.
    async fn send(&self, data: Bytes) -> Result<()>;

    /// Waits for and returns the next message.
    async fn recv(&self) -> Result<Bytes>;

    /// Shuts the transport down; exact semantics are defined by each implementor.
    async fn close(&self) -> Result<()>;
}
pub struct InMemoryDuplexTransport {
sender: mpsc::Sender<Bytes>,
receiver: Mutex<mpsc::Receiver<Bytes>>,
}
impl InMemoryDuplexTransport {
pub fn new(sender: mpsc::Sender<Bytes>, receiver: mpsc::Receiver<Bytes>) -> Self {
Self {
sender,
receiver: Mutex::new(receiver),
}
}
}
#[async_trait]
impl Transport for InMemoryDuplexTransport {
async fn send(&self, data: Bytes) -> Result<()> {
self.sender
.send(data)
.await
.map_err(|_| EnigmaProtocolError::Transport)
}
async fn recv(&self) -> Result<Bytes> {
let mut guard = self.receiver.lock().await;
guard.recv().await.ok_or(EnigmaProtocolError::Transport)
}
async fn close(&self) -> Result<()> {
let mut guard = self.receiver.lock().await;
guard.close();
Ok(())
}
}
/// Creates two connected [`InMemoryDuplexTransport`] ends, type-erased to
/// `Arc<dyn Transport>`.
///
/// Messages sent on the first end are received by the second and vice versa.
/// Each direction is its own bounded channel with capacity `buffer`.
///
/// # Panics
/// Panics under the same conditions as [`mpsc::channel`], notably when
/// `buffer` is 0.
pub fn in_memory_duplex_pair(buffer: usize) -> (Arc<dyn Transport>, Arc<dyn Transport>) {
    // `to_second`/`from_first` carry first -> second traffic; the other pair
    // carries second -> first.
    let (to_second, from_first) = mpsc::channel(buffer);
    let (to_first, from_second) = mpsc::channel(buffer);
    let end = |tx, rx| -> Arc<dyn Transport> { Arc::new(InMemoryDuplexTransport::new(tx, rx)) };
    (end(to_second, from_second), end(to_first, from_first))
}