use crate::{Error, Sink as SinkTrait, Stream as StreamTrait};
use bytes::Bytes;
use commonware_utils::StableBuf;
use futures::channel::oneshot;
use std::{
collections::VecDeque,
sync::{Arc, Mutex},
};
/// Unbounded in-memory byte pipe shared by exactly one [`Sink`] and one [`Stream`].
pub struct Channel {
    /// Bytes written by the sink that the stream has not consumed yet.
    buffer: VecDeque<u8>,
    /// Outstanding `recv`: how many bytes it asked for, and where to deliver them.
    waiter: Option<(usize, oneshot::Sender<Bytes>)>,
    /// Set to `false` once the [`Sink`] is dropped.
    sink_alive: bool,
    /// Set to `false` once the [`Stream`] is dropped.
    stream_alive: bool,
}
impl Channel {
    /// Builds a fresh, empty channel and returns its writing and reading halves.
    ///
    /// Both halves share one [`Channel`] behind an `Arc<Mutex<_>>`. Each starts out
    /// marked alive, with an empty buffer and no pending receiver.
    pub fn init() -> (Sink, Stream) {
        let shared = Arc::new(Mutex::new(Self {
            buffer: VecDeque::new(),
            waiter: None,
            sink_alive: true,
            stream_alive: true,
        }));
        let sink = Sink {
            channel: Arc::clone(&shared),
        };
        let stream = Stream { channel: shared };
        (sink, stream)
    }
}
/// Writing half of a [`Channel`], obtained from [`Channel::init`].
pub struct Sink {
    /// State shared with the paired [`Stream`].
    channel: Arc<Mutex<Channel>>,
}
impl SinkTrait for Sink {
    /// Appends `msg` to the shared buffer. If a pending `recv` on the paired
    /// [`Stream`] asked for no more bytes than are now buffered, this delivers
    /// exactly that many bytes to it.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Closed`] if the [`Stream`] has been dropped.
    async fn send(&mut self, msg: impl Into<StableBuf> + Send) -> Result<(), Error> {
        let msg = msg.into();
        // This body contains no `.await`, and the oneshot `send` below does not
        // block, so holding the guard until the end is safe. Holding it also keeps
        // "deliver" and "remove from buffer" atomic with respect to `recv`.
        let mut channel = self.channel.lock().unwrap();
        if !channel.stream_alive {
            return Err(Error::Closed);
        }
        channel.buffer.extend(msg.as_ref());

        // Nothing more to do unless a receiver is waiting and its request can be met.
        let requested = match &channel.waiter {
            Some((requested, _)) if *requested <= channel.buffer.len() => *requested,
            _ => return Ok(()),
        };
        let (_, os_send) = channel
            .waiter
            .take()
            .expect("waiter was present in the match above");

        // Copy the bytes out, and remove them from the buffer only once delivery
        // succeeds. If the waiting `recv` future was dropped (e.g. it lost a
        // `select!` race), its receiver is gone. Keeping the bytes buffered lets the
        // next `recv` read them instead of losing them.
        let data: Vec<u8> = channel.buffer.range(..requested).copied().collect();
        if os_send.send(Bytes::from(data)).is_ok() {
            channel.buffer.drain(..requested);
        }
        Ok(())
    }
}
impl Drop for Sink {
    /// Marks the sink as gone and drops any pending waiter's sender. A `recv`
    /// awaiting that sender then observes a cancelled oneshot.
    fn drop(&mut self) {
        // Recover the guard even if the mutex is poisoned. Panicking here while
        // the sink is being dropped during unwinding would abort the process.
        let mut channel = self
            .channel
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        channel.sink_alive = false;
        channel.waiter = None;
    }
}
/// Reading half of a [`Channel`], obtained from [`Channel::init`].
pub struct Stream {
    /// State shared with the paired [`Sink`].
    channel: Arc<Mutex<Channel>>,
}
impl StreamTrait for Stream {
    /// Fills `buf` with exactly `buf.len()` bytes from the channel. If not enough
    /// bytes are buffered yet, waits for the [`Sink`] to send more.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Closed`] if fewer than `buf.len()` bytes are buffered and
    /// the [`Sink`] has been dropped, either before this call or while waiting.
    async fn recv(&mut self, buf: impl Into<StableBuf> + Send) -> Result<StableBuf, Error> {
        let mut buf = buf.into();
        let os_recv = {
            let mut channel = self.channel.lock().unwrap();
            // Serve from the buffer whenever possible, even if the sink is gone.
            if channel.buffer.len() >= buf.len() {
                let b: Vec<u8> = channel.buffer.drain(0..buf.len()).collect();
                buf.put_slice(&b);
                return Ok(buf);
            }
            if !channel.sink_alive {
                return Err(Error::Closed);
            }
            // `recv` takes `&mut self`, so no other call on this stream can be in
            // flight. Any waiter still registered here was left by an earlier `recv`
            // whose future was dropped before being fulfilled (e.g. it lost a
            // `select!` race). Overwrite that stale request instead of panicking.
            let (os_send, os_recv) = oneshot::channel();
            channel.waiter = Some((buf.len(), os_send));
            os_recv
        };
        // An `Err` means the sender was dropped without sending (the sink's
        // `Drop` clears the waiter).
        let data = os_recv.await.map_err(|_| Error::Closed)?;
        assert_eq!(data.len(), buf.len());
        buf.put_slice(&data);
        Ok(buf)
    }
}
impl Drop for Stream {
    /// Marks the stream as gone, so later `send` calls on the paired [`Sink`]
    /// fail with [`Error::Closed`].
    fn drop(&mut self) {
        // Recover the guard even if the mutex is poisoned. Panicking here while
        // the stream is being dropped during unwinding would abort the process.
        let mut channel = self
            .channel
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        channel.stream_alive = false;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;
    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world".to_vec();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.clone()).await.unwrap();
            let buf = stream.recv(vec![0; data.len()]).await.unwrap();
            assert_eq!(buf.as_ref(), data);
        });
    }
    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello".to_vec();
        let data2 = b" world".to_vec();
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data).await.unwrap();
            sink.send(data2).await.unwrap();
            let buf = stream.recv(vec![0; 5]).await.unwrap();
            assert_eq!(buf.as_ref(), b"hello");
            let buf = stream.recv(buf).await.unwrap();
            assert_eq!(buf.as_ref(), b" worl");
            let buf = stream.recv(vec![0; 1]).await.unwrap();
            assert_eq!(buf.as_ref(), b"d");
        });
    }
    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let (buf, _) = futures::try_join!(stream.recv(vec![0; data.len()]), async {
                // Sleep on the runtime clock rather than `std::thread::sleep`, which
                // would block the executor. The send then happens only after `recv`
                // has suspended waiting for data.
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.to_vec()).await
            })
            .unwrap();
            assert_eq!(buf.as_ref(), data);
        });
    }
    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(vec![0; 5]).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }
    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink);
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(vec![0; 5]).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }
    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            assert!(sink.send(b"7 bytes".to_vec()).await.is_ok());
            let handle = context.clone().spawn(|_| async move {
                let _ = stream.recv(vec![0; 5]).await;
                let _ = stream.recv(vec![0; 5]).await;
            });
            context.sleep(Duration::from_millis(50)).await;
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));
            let result = sink.send(b"hello world".to_vec()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }
    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream);
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".to_vec()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }
    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            select! {
                v = stream.recv(vec![0;5]) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => {
                    "timeout"
                },
            };
        });
    }
}