use {
async_channel::{
Receiver,
Sender,
},
async_lock::Mutex,
futures_lite::{
io::{
AsyncRead,
AsyncWrite,
},
stream::Stream,
},
std::{
collections::HashMap,
io,
pin::Pin,
sync::Arc,
task::{
Context,
Poll,
},
},
};
/// Capacity, in messages, of each listener's queue of pending connection
/// requests and of each direction of a connected byte stream.
const CHANNEL_BUFFER_SIZE: usize = 64;

/// Registry of in-memory endpoints that listeners bind to and clients
/// connect to.
///
/// All clones share one registry: the state lives behind an
/// `Arc<Mutex<_>>`, so cloning only bumps a reference count.
#[derive(Clone)]
pub struct MemoryTransport {
    inner: Arc<Mutex<TransportInner>>,
}

/// Shared, lock-protected state of a [`MemoryTransport`].
struct TransportInner {
    /// Endpoint name -> sender feeding that endpoint's listener queue.
    listeners: HashMap<String, Sender<ConnectionRequest>>,
}

/// Message `connect` sends to a listener to open a connection.
///
/// `client_tx` and `client_rx` are the two ends of one client-to-server
/// byte channel: the listener keeps `client_rx` for its own stream and
/// hands `client_tx` back to the client inside the [`ConnectionResponse`].
struct ConnectionRequest {
    /// Sending end of the client-to-server byte channel.
    client_tx: Sender<Vec<u8>>,
    /// Receiving end of the client-to-server byte channel.
    client_rx: Receiver<Vec<u8>>,
    /// Reply channel on which the listener returns the client's half.
    response: Sender<ConnectionResponse>,
}

/// The listener's reply to a [`ConnectionRequest`]; the client builds its
/// stream from these two channel ends.
///
/// Field names are from the client's point of view: "tx toward the server"
/// and "rx from the server".
struct ConnectionResponse {
    /// Client's write end: carries bytes to the server.
    server_tx: Sender<Vec<u8>>,
    /// Client's read end: carries bytes written by the server.
    server_rx: Receiver<Vec<u8>>,
}
impl MemoryTransport {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(TransportInner {
listeners: HashMap::new(),
})),
}
}
pub async fn bind(&self, endpoint: &str) -> io::Result<MemoryListener> {
let mut inner = self.inner.lock().await;
if inner.listeners.contains_key(endpoint) {
return Err(io::Error::new(
io::ErrorKind::AddrInUse,
format!("endpoint '{}' already bound", endpoint),
));
}
let (tx, rx) = async_channel::bounded(CHANNEL_BUFFER_SIZE);
inner.listeners.insert(endpoint.to_string(), tx);
Ok(MemoryListener {
endpoint: endpoint.to_string(),
transport: self.clone(),
incoming: rx,
})
}
pub async fn connect(&self, endpoint: &str) -> io::Result<MemoryStream> {
let listener_tx = {
let inner = self.inner.lock().await;
inner.listeners.get(endpoint).cloned()
};
let listener_tx = listener_tx.ok_or_else(|| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
format!("no listener at endpoint '{}'", endpoint),
)
})?;
let (client_tx, client_rx) =
async_channel::bounded::<Vec<u8>>(CHANNEL_BUFFER_SIZE);
let (response_tx, response_rx) =
async_channel::bounded::<ConnectionResponse>(1);
let request = ConnectionRequest {
client_tx,
client_rx,
response: response_tx,
};
listener_tx.send(request).await.map_err(|_| {
io::Error::new(io::ErrorKind::ConnectionRefused, "listener closed")
})?;
let response = response_rx.recv().await.map_err(|_| {
io::Error::new(
io::ErrorKind::ConnectionRefused,
"listener did not respond",
)
})?;
Ok(MemoryStream::new(response.server_tx, response.server_rx))
}
async fn unbind(&self, endpoint: &str) {
let mut inner = self.inner.lock().await;
inner.listeners.remove(endpoint);
}
}
impl Default for MemoryTransport {
fn default() -> Self {
Self::new()
}
}
/// Accepts connections made to one endpoint of a [`MemoryTransport`].
///
/// Dropping the listener unregisters its endpoint (see its `Drop` impl).
pub struct MemoryListener {
    /// Endpoint name this listener was bound to.
    endpoint: String,
    /// Transport the endpoint is registered with; used to unregister on drop.
    transport: MemoryTransport,
    /// Queue of connection requests sent by `MemoryTransport::connect`.
    incoming: Receiver<ConnectionRequest>,
}
impl MemoryListener {
    /// Waits for the next client and returns the server side of the new
    /// connection.
    ///
    /// A request whose client has already stopped waiting (its `connect`
    /// future was dropped, so the reply cannot be delivered) is skipped.
    /// `accept` then waits for the next request, so one abandoned `connect`
    /// does not surface as an `accept` error.
    ///
    /// # Errors
    ///
    /// Fails with [`io::ErrorKind::BrokenPipe`] if the incoming-request
    /// channel is closed.
    pub async fn accept(&mut self) -> io::Result<MemoryStream> {
        loop {
            let request = self.incoming.recv().await.map_err(|_| {
                io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "listener channel closed",
                )
            })?;
            // Server -> client byte channel. The server keeps the sender;
            // the client gets the receiver.
            let (server_tx, server_rx) =
                async_channel::bounded::<Vec<u8>>(CHANNEL_BUFFER_SIZE);
            let reply = ConnectionResponse {
                // Hand the client back its own write end (client -> server).
                server_tx: request.client_tx,
                server_rx,
            };
            if request.response.send(reply).await.is_ok() {
                return Ok(MemoryStream::new(server_tx, request.client_rx));
            }
            // The client's reply channel is gone: nobody would ever read
            // from or write to this connection, so drop it and keep waiting.
        }
    }

    /// Returns the endpoint name this listener is bound to.
    pub fn endpoint(&self) -> &str {
        &self.endpoint
    }
}
impl Drop for MemoryListener {
    /// Unregisters this listener's endpoint from its transport.
    ///
    /// Unregistration normally happens synchronously, so the endpoint can be
    /// bound again as soon as the listener is gone. Only when the transport
    /// lock is currently held elsewhere is the removal deferred to a
    /// detached task; `drop` never blocks waiting for the lock.
    fn drop(&mut self) {
        // Close the request queue first so concurrent `connect` calls fail
        // ("listener closed") instead of queueing requests nobody accepts.
        self.incoming.close();

        let endpoint = std::mem::take(&mut self.endpoint);

        // Fast path: the lock is only held briefly, so it is usually free.
        if let Some(mut registry) = self.transport.inner.try_lock() {
            registry.listeners.remove(&endpoint);
            return;
        }

        // Slow path: someone holds the lock right now; finish asynchronously.
        let transport = self.transport.clone();
        smol::spawn(async move {
            transport.unbind(&endpoint).await;
        })
        .detach();
    }
}
/// One end of an in-memory, bidirectional byte stream.
///
/// Data is exchanged as `Vec<u8>` messages over bounded `async_channel`
/// channels, one channel per direction.
pub struct MemoryStream {
    /// Sends messages to the peer.
    tx: Sender<Vec<u8>>,
    /// Receives messages written by the peer; polled as a `Stream` through
    /// `Pin::as_mut`. NOTE(review): presumably boxed because `Receiver` is
    /// not `Unpin` in the async-channel version in use — confirm.
    rx: Pin<Box<Receiver<Vec<u8>>>>,
    /// Bytes of an already-received message that did not yet fit into a
    /// caller's read buffer.
    read_buf: Vec<u8>,
}
impl MemoryStream {
    /// Builds a stream that writes into `tx`, reads from `rx` and starts
    /// with nothing buffered.
    fn new(tx: Sender<Vec<u8>>, rx: Receiver<Vec<u8>>) -> Self {
        let rx = Box::pin(rx);
        let read_buf = Vec::new();
        Self { tx, rx, read_buf }
    }

    /// Splits the stream into independently owned read and write halves.
    ///
    /// Bytes already buffered for reading move into the read half, so no
    /// received data is lost by splitting.
    pub fn into_split(self) -> (MemoryReadHalf, MemoryWriteHalf) {
        let Self { tx, rx, read_buf } = self;
        let reader = MemoryReadHalf { rx, read_buf };
        let writer = MemoryWriteHalf { tx };
        (reader, writer)
    }
}
impl AsyncRead for MemoryStream {
    /// Reads bytes written by the peer.
    ///
    /// Bytes left over from a previously received message are returned
    /// first. Otherwise the next message is taken from the channel, and
    /// whatever does not fit into `buf` is kept for the next call.
    /// `Ok(0)` is returned only for an empty `buf` or once the peer's
    /// channel yields no more messages (end of stream).
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // A zero-length read needs no data: don't wait for, or consume, a
        // message just to report 0 bytes.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let this = &mut *self;

        if !this.read_buf.is_empty() {
            let n = buf.len().min(this.read_buf.len());
            buf[..n].copy_from_slice(&this.read_buf[..n]);
            // Note: `drain` shifts the remaining bytes down.
            this.read_buf.drain(..n);
            return Poll::Ready(Ok(n));
        }

        loop {
            match this.rx.as_mut().poll_next(cx) {
                // An empty message carries no bytes; returning Ok(0) for it
                // would look like end-of-stream to the caller.
                | Poll::Ready(Some(data)) if data.is_empty() => continue,
                | Poll::Ready(Some(data)) => {
                    let n = buf.len().min(data.len());
                    buf[..n].copy_from_slice(&data[..n]);
                    // No-op when the whole message fit.
                    this.read_buf.extend_from_slice(&data[n..]);
                    return Poll::Ready(Ok(n));
                },
                | Poll::Ready(None) => return Poll::Ready(Ok(0)),
                | Poll::Pending => return Poll::Pending,
            }
        }
    }
}
impl AsyncWrite for MemoryStream {
    /// Sends the whole of `buf` to the peer as a single message.
    ///
    /// An empty `buf` is accepted without sending anything, because an empty
    /// message would surface on the peer as a zero-byte read, which readers
    /// treat as end-of-stream.
    ///
    /// NOTE(review): while the channel is full this wakes its own task and
    /// returns `Pending`, i.e. it busy-polls until the peer frees a slot.
    /// Parking properly would require keeping a pending send future in the
    /// struct.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        // Don't copy `buf` just to have it bounced back while the channel is
        // full. A closed channel still goes through `try_send` below so the
        // caller gets `BrokenPipe` rather than spinning forever.
        if self.tx.is_full() && !self.tx.is_closed() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        match self.tx.try_send(buf.to_vec()) {
            | Ok(()) => Poll::Ready(Ok(buf.len())),
            | Err(async_channel::TrySendError::Full(_)) => {
                cx.waker().wake_by_ref();
                Poll::Pending
            },
            | Err(async_channel::TrySendError::Closed(_)) => {
                Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "channel closed",
                )))
            },
        }
    }

    /// Always ready: `poll_write` hands data straight to the channel, so
    /// nothing is buffered on this side.
    fn poll_flush(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Closes the outgoing channel; subsequent non-empty writes fail with
    /// `BrokenPipe`.
    fn poll_close(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        self.tx.close();
        Poll::Ready(Ok(()))
    }
}
/// Read half of a [`MemoryStream`], produced by [`MemoryStream::into_split`].
pub struct MemoryReadHalf {
    /// Receives messages written by the peer.
    rx: Pin<Box<Receiver<Vec<u8>>>>,
    /// Received bytes not yet handed to a reader; carried over from the
    /// stream when it was split.
    read_buf: Vec<u8>,
}
impl AsyncRead for MemoryReadHalf {
    /// Reads bytes written by the peer.
    ///
    /// Bytes left over from a previously received message are returned
    /// first. Otherwise the next message is taken from the channel, and
    /// whatever does not fit into `buf` is kept for the next call.
    /// `Ok(0)` is returned only for an empty `buf` or once the peer's
    /// channel yields no more messages (end of stream).
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        // A zero-length read needs no data: don't wait for, or consume, a
        // message just to report 0 bytes.
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        let this = &mut *self;

        if !this.read_buf.is_empty() {
            let n = buf.len().min(this.read_buf.len());
            buf[..n].copy_from_slice(&this.read_buf[..n]);
            // Note: `drain` shifts the remaining bytes down.
            this.read_buf.drain(..n);
            return Poll::Ready(Ok(n));
        }

        loop {
            match this.rx.as_mut().poll_next(cx) {
                // An empty message carries no bytes; returning Ok(0) for it
                // would look like end-of-stream to the caller.
                | Poll::Ready(Some(data)) if data.is_empty() => continue,
                | Poll::Ready(Some(data)) => {
                    let n = buf.len().min(data.len());
                    buf[..n].copy_from_slice(&data[..n]);
                    // No-op when the whole message fit.
                    this.read_buf.extend_from_slice(&data[n..]);
                    return Poll::Ready(Ok(n));
                },
                | Poll::Ready(None) => return Poll::Ready(Ok(0)),
                | Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Write half of a [`MemoryStream`], produced by [`MemoryStream::into_split`].
pub struct MemoryWriteHalf {
    /// Sends messages to the peer.
    tx: Sender<Vec<u8>>,
}
impl AsyncWrite for MemoryWriteHalf {
    /// Sends the whole of `buf` to the peer as a single message.
    ///
    /// An empty `buf` is accepted without sending anything, because an empty
    /// message would surface on the peer as a zero-byte read, which readers
    /// treat as end-of-stream.
    ///
    /// NOTE(review): while the channel is full this wakes its own task and
    /// returns `Pending`, i.e. it busy-polls until the peer frees a slot.
    /// Parking properly would require keeping a pending send future in the
    /// struct.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        if buf.is_empty() {
            return Poll::Ready(Ok(0));
        }
        // Don't copy `buf` just to have it bounced back while the channel is
        // full. A closed channel still goes through `try_send` below so the
        // caller gets `BrokenPipe` rather than spinning forever.
        if self.tx.is_full() && !self.tx.is_closed() {
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        match self.tx.try_send(buf.to_vec()) {
            | Ok(()) => Poll::Ready(Ok(buf.len())),
            | Err(async_channel::TrySendError::Full(_)) => {
                cx.waker().wake_by_ref();
                Poll::Pending
            },
            | Err(async_channel::TrySendError::Closed(_)) => {
                Poll::Ready(Err(io::Error::new(
                    io::ErrorKind::BrokenPipe,
                    "channel closed",
                )))
            },
        }
    }

    /// Always ready: `poll_write` hands data straight to the channel, so
    /// nothing is buffered on this side.
    fn poll_flush(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }

    /// Closes the outgoing channel; subsequent non-empty writes fail with
    /// `BrokenPipe`.
    fn poll_close(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        self.tx.close();
        Poll::Ready(Ok(()))
    }
}
impl std::fmt::Debug for MemoryStream {
    /// Shows only how many bytes are waiting in the read buffer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buffered = self.read_buf.len();
        let mut out = f.debug_struct("MemoryStream");
        out.field("read_buf_len", &buffered);
        out.finish_non_exhaustive()
    }
}

impl std::fmt::Debug for MemoryListener {
    /// Shows the bound endpoint name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryListener");
        out.field("endpoint", &self.endpoint);
        out.finish_non_exhaustive()
    }
}

impl std::fmt::Debug for MemoryTransport {
    /// Opaque: the registry sits behind an async lock and is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryTransport");
        out.finish_non_exhaustive()
    }
}

impl std::fmt::Debug for MemoryReadHalf {
    /// Shows only how many bytes are waiting in the read buffer.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let buffered = self.read_buf.len();
        let mut out = f.debug_struct("MemoryReadHalf");
        out.field("read_buf_len", &buffered);
        out.finish_non_exhaustive()
    }
}

impl std::fmt::Debug for MemoryWriteHalf {
    /// Opaque: the channel sender is not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryWriteHalf");
        out.finish_non_exhaustive()
    }
}
#[cfg(test)]
mod tests {
    use {
        super::*,
        futures_lite::io::{
            AsyncReadExt,
            AsyncWriteExt,
        },
    };

    /// One request/response exchange over a single stream.
    #[test]
    fn test_basic_communication() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let mut listener = transport.bind("test").await.unwrap();
            let transport_clone = transport.clone();
            let client_handle = smol::spawn(async move {
                let mut stream = transport_clone.connect("test").await.unwrap();
                stream.write_all(b"hello").await.unwrap();
                let mut buf = [0u8; 5];
                stream.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"world");
            });
            let mut server = listener.accept().await.unwrap();
            let mut buf = [0u8; 5];
            server.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"hello");
            server.write_all(b"world").await.unwrap();
            client_handle.await;
        });
    }

    /// Connecting to an unbound endpoint is refused.
    #[test]
    fn test_connection_refused() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let result = transport.connect("nonexistent").await;
            assert!(result.is_err());
            assert_eq!(
                result.unwrap_err().kind(),
                io::ErrorKind::ConnectionRefused
            );
        });
    }

    /// Binding an endpoint twice fails while the first listener is alive.
    #[test]
    fn test_address_in_use() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let _listener1 = transport.bind("test").await.unwrap();
            let result = transport.bind("test").await;
            assert!(result.is_err());
            assert_eq!(result.unwrap_err().kind(), io::ErrorKind::AddrInUse);
        });
    }

    /// Two endpoints on one transport route to their own listeners.
    #[test]
    fn test_multiple_endpoints() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let mut listener1 = transport.bind("endpoint1").await.unwrap();
            let mut listener2 = transport.bind("endpoint2").await.unwrap();
            let t = transport.clone();
            // Keep the handle and join it below: a detached task would
            // swallow any panic or failed `unwrap` on the client side.
            let client_handle = smol::spawn(async move {
                let mut s1 = t.connect("endpoint1").await.unwrap();
                s1.write_all(b"to-1").await.unwrap();
                let mut s2 = t.connect("endpoint2").await.unwrap();
                s2.write_all(b"to-2").await.unwrap();
            });
            let mut server1 = listener1.accept().await.unwrap();
            let mut server2 = listener2.accept().await.unwrap();
            let mut buf1 = [0u8; 4];
            let mut buf2 = [0u8; 4];
            server1.read_exact(&mut buf1).await.unwrap();
            server2.read_exact(&mut buf2).await.unwrap();
            assert_eq!(&buf1, b"to-1");
            assert_eq!(&buf2, b"to-2");
            client_handle.await;
        });
    }

    /// Split halves read and write independently on both ends.
    #[test]
    fn test_split_stream() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let mut listener = transport.bind("test").await.unwrap();
            let t = transport.clone();
            let client_handle = smol::spawn(async move {
                let stream = t.connect("test").await.unwrap();
                let (mut reader, mut writer) = stream.into_split();
                writer.write_all(b"ping").await.unwrap();
                let mut buf = [0u8; 4];
                reader.read_exact(&mut buf).await.unwrap();
                assert_eq!(&buf, b"pong");
            });
            let server = listener.accept().await.unwrap();
            let (mut reader, mut writer) = server.into_split();
            let mut buf = [0u8; 4];
            reader.read_exact(&mut buf).await.unwrap();
            assert_eq!(&buf, b"ping");
            writer.write_all(b"pong").await.unwrap();
            client_handle.await;
        });
    }

    /// A 1 MB payload arrives intact.
    #[test]
    fn test_large_transfer() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let mut listener = transport.bind("test").await.unwrap();
            let data: Vec<u8> =
                (0..1_000_000).map(|i| (i % 256) as u8).collect();
            let expected = data.clone();
            let t = transport.clone();
            let client_handle = smol::spawn(async move {
                let mut stream = t.connect("test").await.unwrap();
                stream.write_all(&data).await.unwrap();
            });
            let mut server = listener.accept().await.unwrap();
            let mut received = vec![0u8; 1_000_000];
            server.read_exact(&mut received).await.unwrap();
            assert_eq!(received, expected);
            client_handle.await;
        });
    }

    /// Alternating single-byte exchanges in both directions.
    #[test]
    fn test_bidirectional() {
        smol::block_on(async {
            let transport = MemoryTransport::new();
            let mut listener = transport.bind("test").await.unwrap();
            let t = transport.clone();
            let client_handle = smol::spawn(async move {
                let mut stream = t.connect("test").await.unwrap();
                for i in 0u8..10 {
                    stream.write_all(&[i]).await.unwrap();
                    let mut buf = [0u8; 1];
                    stream.read_exact(&mut buf).await.unwrap();
                    assert_eq!(buf[0], i + 100);
                }
            });
            let mut server = listener.accept().await.unwrap();
            for i in 0u8..10 {
                let mut buf = [0u8; 1];
                server.read_exact(&mut buf).await.unwrap();
                assert_eq!(buf[0], i);
                server.write_all(&[i + 100]).await.unwrap();
            }
            client_handle.await;
        });
    }
}