use tokio::sync::mpsc;
use crate::traits::{AsyncFrameTransport, TransportError};
/// In-memory, bidirectional [`AsyncFrameTransport`] backed by two tokio mpsc channels.
///
/// Transports are created in connected pairs with [`MemoryTransport::pair`]: every
/// chunk passed to `write_all` on one side is delivered, in order, to `read` on the
/// other side.
#[derive(Debug)]
pub struct MemoryTransport {
    /// Outgoing half of the channel to the peer; `None` once `close` has been called.
    tx: Option<mpsc::Sender<Vec<u8>>>,
    /// Incoming half of the channel; receives whatever the peer writes.
    rx: mpsc::Receiver<Vec<u8>>,
    /// Unconsumed tail of the most recently received chunk (kept when the caller's
    /// buffer was smaller than the chunk).
    read_buf: Vec<u8>,
    /// Index into `read_buf` of the next byte to hand out.
    read_pos: usize,
}

impl MemoryTransport {
    /// Creates two connected transports: bytes written to one are read from the other.
    ///
    /// `buffer_size` is the number of pending `write_all` chunks (not bytes) each
    /// direction can hold before `write_all` waits. A value of `0` is treated as `1`,
    /// because tokio's bounded channel panics on a zero capacity.
    pub fn pair(buffer_size: usize) -> (Self, Self) {
        let capacity = buffer_size.max(1);
        let (tx_a, rx_b) = mpsc::channel(capacity);
        let (tx_b, rx_a) = mpsc::channel(capacity);
        (Self::from_channels(tx_a, rx_a), Self::from_channels(tx_b, rx_b))
    }

    /// Wraps one sender/receiver pair in a fresh transport with an empty read buffer.
    fn from_channels(tx: mpsc::Sender<Vec<u8>>, rx: mpsc::Receiver<Vec<u8>>) -> Self {
        Self {
            tx: Some(tx),
            rx,
            read_buf: Vec::new(),
            read_pos: 0,
        }
    }
}

#[async_trait::async_trait]
impl AsyncFrameTransport for MemoryTransport {
    /// Reads up to `buf.len()` bytes written by the peer.
    ///
    /// Each call returns bytes from at most one peer `write_all` chunk. If `buf` is
    /// smaller than that chunk, the rest is kept and returned by later calls.
    /// Returns `Ok(0)` only when `buf` is empty, so a zero count never masquerades as
    /// end-of-stream.
    ///
    /// # Errors
    ///
    /// [`TransportError::ConnectionClosed`] once the peer has closed or been dropped
    /// (or this side was closed) and no buffered bytes remain.
    async fn read(&mut self, buf: &mut [u8]) -> Result<usize, TransportError> {
        if buf.is_empty() {
            // Nothing can be copied; don't block waiting for (and then stash) a chunk.
            return Ok(0);
        }

        if self.read_pos >= self.read_buf.len() {
            // No locally buffered bytes: wait for the next non-empty chunk from the peer.
            loop {
                match self.rx.recv().await {
                    // An empty chunk carries no data; skipping it keeps `Ok(0)` unambiguous.
                    Some(chunk) if chunk.is_empty() => continue,
                    Some(chunk) => {
                        self.read_buf = chunk;
                        self.read_pos = 0;
                        break;
                    }
                    None => return Err(TransportError::ConnectionClosed),
                }
            }
        }

        let pending = &self.read_buf[self.read_pos..];
        let n = pending.len().min(buf.len());
        buf[..n].copy_from_slice(&pending[..n]);
        self.read_pos += n;

        if self.read_pos == self.read_buf.len() {
            // Fully consumed: release the chunk's allocation rather than holding it
            // until the next message arrives.
            self.read_buf = Vec::new();
            self.read_pos = 0;
        }
        Ok(n)
    }

    /// Sends `data` to the peer as a single chunk, waiting while the channel is full.
    ///
    /// Empty `data` is a no-op. Sending a zero-length chunk would make the peer's
    /// `read` return `Ok(0)`, which callers conventionally treat as EOF.
    ///
    /// # Errors
    ///
    /// [`TransportError::ConnectionClosed`] if this side has been closed, or if the
    /// peer has closed or dropped its receiving half.
    async fn write_all(&mut self, data: &[u8]) -> Result<(), TransportError> {
        let tx = self.tx.as_ref().ok_or(TransportError::ConnectionClosed)?;
        if data.is_empty() {
            return Ok(());
        }
        tx.send(data.to_vec())
            .await
            .map_err(|_| TransportError::ConnectionClosed)
    }

    /// Closes both directions, with the same effect on the peer as dropping this transport.
    ///
    /// Dropping our sender lets the peer drain what was already sent and then observe
    /// `ConnectionClosed`. Closing our receiver makes the peer's further writes fail,
    /// while bytes already buffered for us can still be read. Calling `close` more
    /// than once is harmless.
    async fn close(&mut self) -> Result<(), TransportError> {
        self.tx = None;
        self.rx.close();
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Performs a single `read` into a scratch buffer of `cap` bytes and returns
    // exactly the bytes that call produced.
    async fn read_once(transport: &mut MemoryTransport, cap: usize) -> Vec<u8> {
        let mut scratch = vec![0u8; cap];
        let got = transport.read(&mut scratch).await.unwrap();
        scratch.truncate(got);
        scratch
    }

    // Each side's write must arrive at the other side only.
    #[tokio::test]
    async fn bidirectional_communication() {
        let (mut left, mut right) = MemoryTransport::pair(32);

        left.write_all(b"hello from A").await.unwrap();
        right.write_all(b"hello from B").await.unwrap();

        assert_eq!(read_once(&mut right, 64).await, b"hello from A");
        assert_eq!(read_once(&mut left, 64).await, b"hello from B");
    }

    // A chunk larger than the caller's buffer is handed out across several reads.
    #[tokio::test]
    async fn partial_reads() {
        let (mut writer, mut reader) = MemoryTransport::pair(32);
        writer.write_all(b"hello world!").await.unwrap();

        let expected: [&[u8]; 3] = [b"hello", b" worl", b"d!"];
        for piece in expected {
            assert_eq!(read_once(&mut reader, 5).await, piece);
        }
    }

    // Dropping one end must surface as `ConnectionClosed` on the other end's read.
    #[tokio::test]
    async fn connection_closed_on_drop() {
        let (dropped, mut survivor) = MemoryTransport::pair(32);
        drop(dropped);

        let mut scratch = [0u8; 64];
        let outcome = survivor.read(&mut scratch).await;
        assert!(matches!(outcome, Err(TransportError::ConnectionClosed)));
    }
}