use crate::{BufMut, Error, IoBufs, Sink as SinkTrait, Stream as StreamTrait};
use bytes::{Bytes, BytesMut};
use commonware_utils::{channel::oneshot, sync::Mutex};
use std::sync::Arc;
/// Read-ahead size (64 KiB) used by `Channel::init`: a `recv(len)` tops up the
/// stream's local buffer to at most `max(len, read_buffer_size)` bytes.
const DEFAULT_READ_BUFFER_SIZE: usize = 64 << 10;

/// Shared state of an in-memory, one-directional byte pipe from a [`Sink`] to a
/// [`Stream`].
pub struct Channel {
    /// Bytes written by the sink that the stream has not pulled yet.
    buffer: BytesMut,
    /// A suspended `recv`: how many more bytes it needs, and the sender used to
    /// hand it data once that many bytes are buffered.
    waiter: Option<(usize, oneshot::Sender<Bytes>)>,
    /// Read-ahead size: how many bytes a read prefers to move to the stream at
    /// once (a larger request still moves at least what was requested).
    read_buffer_size: usize,
    /// Becomes `false` once the [`Sink`] is dropped.
    sink_alive: bool,
    /// Becomes `false` once the [`Stream`] is dropped.
    stream_alive: bool,
}
impl Channel {
    /// Creates a connected ([`Sink`], [`Stream`]) pair that uses the default
    /// read-ahead size of 64 KiB.
    pub fn init() -> (Sink, Stream) {
        Self::init_with_read_buffer_size(DEFAULT_READ_BUFFER_SIZE)
    }

    /// Creates a connected ([`Sink`], [`Stream`]) pair sharing one fresh channel.
    ///
    /// `read_buffer_size` controls read-ahead: a `recv(len)` moves bytes from the
    /// shared buffer until the stream holds up to `max(len, read_buffer_size)`
    /// bytes locally.
    pub fn init_with_read_buffer_size(read_buffer_size: usize) -> (Sink, Stream) {
        // Both halves start alive, with nothing buffered and no pending reader.
        let state = Self {
            buffer: BytesMut::new(),
            waiter: None,
            read_buffer_size,
            sink_alive: true,
            stream_alive: true,
        };
        let shared = Arc::new(Mutex::new(state));

        let sink = Sink {
            channel: Arc::clone(&shared),
        };
        let stream = Stream {
            channel: shared,
            buffer: BytesMut::new(),
        };
        (sink, stream)
    }
}
pub struct Sink {
channel: Arc<Mutex<Channel>>,
}
impl SinkTrait for Sink {
async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
let (os_send, data) = {
let mut channel = self.channel.lock();
if !channel.stream_alive {
return Err(Error::Closed);
}
channel.buffer.put(bufs.into());
if channel
.waiter
.as_ref()
.is_some_and(|(requested, _)| *requested <= channel.buffer.len())
{
let (requested, os_send) = channel.waiter.take().unwrap();
let send_amount = channel
.buffer
.len()
.min(requested.max(channel.read_buffer_size));
let data = channel.buffer.split_to(send_amount).freeze();
(os_send, data)
} else {
return Ok(());
}
};
os_send.send(data).map_err(|_| Error::SendFailed)?;
Ok(())
}
}
impl Drop for Sink {
    /// Marks the writer as gone and discards any pending reader registration.
    fn drop(&mut self) {
        let mut state = self.channel.lock();
        state.sink_alive = false;
        // Dropping the waiter's sender makes a blocked `recv` fail its await
        // (which it reports as `Error::Closed`).
        state.waiter = None;
    }
}
/// Reading half of an in-memory [`Channel`].
pub struct Stream {
    channel: Arc<Mutex<Channel>>,
    /// Bytes already moved out of the shared buffer but not yet returned by
    /// `recv`. This is what `peek` exposes.
    buffer: BytesMut,
}

impl StreamTrait for Stream {
    /// Receives exactly `len` bytes.
    ///
    /// First tops up the local buffer from the shared buffer, up to
    /// `max(len, read_buffer_size)` bytes held locally. If that yields at
    /// least `len` bytes, they are returned immediately. Otherwise the call
    /// registers itself as the channel's waiter and suspends until the
    /// [`Sink`] supplies the rest.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Closed`] if the sink has been dropped and `len` bytes
    /// are not already available, or if the sink is dropped while waiting.
    ///
    /// # Cancellation
    ///
    /// The returned future may be dropped before completion (e.g. by a
    /// timeout), and `recv` may be called again afterwards. Bytes the sink had
    /// already handed to the dropped future are lost with it.
    async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
        let os_recv = {
            let mut channel = self.channel.lock();

            // Read ahead from the shared buffer, capping the local buffer at
            // `max(len, read_buffer_size)` bytes.
            if !channel.buffer.is_empty() {
                let target = len.max(channel.read_buffer_size);
                let pull_amount = channel
                    .buffer
                    .len()
                    .min(target.saturating_sub(self.buffer.len()));
                if pull_amount > 0 {
                    let data = channel.buffer.split_to(pull_amount);
                    self.buffer.extend_from_slice(&data);
                }
            }

            if self.buffer.len() >= len {
                return Ok(IoBufs::from(self.buffer.split_to(len).freeze()));
            }
            if !channel.sink_alive {
                return Err(Error::Closed);
            }

            // `recv` borrows `self` mutably and each channel has exactly one
            // `Stream`. So a waiter still registered here can only belong to an
            // earlier `recv` future that was dropped before completing.
            // Replacing it (and dropping its stale sender) keeps the stream
            // usable instead of panicking.
            let remaining = len - self.buffer.len();
            let (os_send, os_recv) = oneshot::channel();
            channel.waiter = Some((remaining, os_send));
            os_recv
        };

        // The sink only wakes us with at least `remaining` bytes. The await
        // fails if the sender is dropped unsent, which happens when the sink is
        // dropped.
        let data = os_recv.await.map_err(|_| Error::Closed)?;
        self.buffer.extend_from_slice(&data);
        assert!(self.buffer.len() >= len);
        Ok(IoBufs::from(self.buffer.split_to(len).freeze()))
    }

    /// Returns up to `max_len` bytes from the local buffer without consuming them.
    ///
    /// Only bytes already pulled in by a previous `recv` are visible. Bytes
    /// still in the shared buffer are not.
    fn peek(&self, max_len: usize) -> &[u8] {
        let len = max_len.min(self.buffer.len());
        &self.buffer[..len]
    }
}
impl Drop for Stream {
    /// Records that the reader is gone, so that later sends fail with
    /// [`Error::Closed`].
    fn drop(&mut self) {
        self.channel.lock().stream_alive = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;

    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            let received = stream.recv(data.len()).await.unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello";
        let data2 = b" world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            sink.send(data2.as_slice()).await.unwrap();
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b" worl");
            let received = stream.recv(1).await.unwrap();
            assert_eq!(received.coalesce(), b"d");
        });
    }

    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // The receiver registers as a waiter first. The sender writes only
            // after a simulated delay, which exercises the wake-up path without
            // blocking the executor thread.
            let (received, _) = futures::try_join!(stream.recv(data.len()), async {
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.as_slice()).await
            })
            .unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(5).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(5).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            assert!(sink.send(b"7 bytes".as_slice()).await.is_ok());
            let handle = context.clone().spawn(|_| async move {
                let _ = stream.recv(5).await;
                let _ = stream.recv(5).await;
            });
            context.sleep(Duration::from_millis(50)).await;
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            select! {
                v = stream.recv(5) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => "timeout",
            };
        });
    }

    #[test]
    fn test_peek_empty() {
        let (_sink, stream) = Channel::init();
        assert!(stream.peek(10).is_empty());
    }

    #[test]
    fn test_peek_after_partial_recv() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"hello world".as_slice()).await.unwrap();
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");
            assert_eq!(stream.peek(100), b" world");
            assert_eq!(stream.peek(3), b" wo");
            assert_eq!(stream.peek(100), b" world");
            let received = stream.recv(6).await.unwrap();
            assert_eq!(received.coalesce(), b" world");
            assert!(stream.peek(100).is_empty());
        });
    }

    #[test]
    fn test_peek_after_recv_wakeup() {
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(64);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let (tx, rx) = oneshot::channel();
            let recv_handle = context.clone().spawn(|_| async move {
                let data = stream.recv(3).await.unwrap();
                tx.send(stream).ok();
                data
            });
            context.sleep(Duration::from_millis(10)).await;
            sink.send(b"ABCDEFGHIJ".as_slice()).await.unwrap();
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");
            let stream = rx.await.unwrap();
            assert_eq!(stream.peek(100), b"DEFGHIJ");
        });
    }

    #[test]
    fn test_peek_multiple_sends() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"aaa".as_slice()).await.unwrap();
            sink.send(b"bbb".as_slice()).await.unwrap();
            sink.send(b"ccc".as_slice()).await.unwrap();
            let received = stream.recv(4).await.unwrap();
            assert_eq!(received.coalesce(), b"aaab");
            assert_eq!(stream.peek(100), b"bbccc");
        });
    }

    #[test]
    fn test_read_buffer_size_limit() {
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"0123456789ABCDEF".as_slice()).await.unwrap();
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"01");
            assert_eq!(stream.peek(100), b"23456789");
            let received = stream.recv(8).await.unwrap();
            assert_eq!(received.coalesce(), b"23456789");
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"AB");
            assert_eq!(stream.peek(100), b"CDEF");
        });
    }

    #[test]
    fn test_recv_before_send() {
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let recv_handle = context
                .clone()
                .spawn(|_| async move { stream.recv(3).await.unwrap() });
            context.sleep(Duration::from_millis(10)).await;
            sink.send(b"ABCDEFGHIJKLMNOP".as_slice()).await.unwrap();
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");
        });
    }
}