commonware_runtime/
mocks.rs

1//! A mock implementation of a channel that implements the Sink and Stream traits.
2
3use crate::{Error, Sink as SinkTrait, Stream as StreamTrait};
4use bytes::Bytes;
5use commonware_utils::{StableBuf, StableBufMut};
6use futures::channel::oneshot;
7use std::{
8    collections::VecDeque,
9    sync::{Arc, Mutex},
10};
11
12/// A mock channel struct that is used internally by Sink and Stream.
13pub struct Channel {
14    /// Stores the bytes sent by the sink that are not yet read by the stream.
15    buffer: VecDeque<u8>,
16
17    /// If the stream is waiting to read bytes, the waiter stores the number of
18    /// bytes that the stream is waiting for, as well as the oneshot sender that
19    /// the sink uses to send the bytes to the stream directly.
20    waiter: Option<(usize, oneshot::Sender<Bytes>)>,
21}
22
23impl Channel {
24    /// Returns an async-safe Sink/Stream pair that share an underlying buffer of bytes.
25    pub fn init() -> (Sink, Stream) {
26        let channel = Arc::new(Mutex::new(Channel {
27            buffer: VecDeque::new(),
28            waiter: None,
29        }));
30        (
31            Sink {
32                channel: channel.clone(),
33            },
34            Stream { channel },
35        )
36    }
37}
38
39/// A mock sink that implements the Sink trait.
40pub struct Sink {
41    channel: Arc<Mutex<Channel>>,
42}
43
44impl SinkTrait for Sink {
45    async fn send<B: StableBuf>(&mut self, msg: B) -> Result<(), Error> {
46        let (os_send, data) = {
47            let mut channel = self.channel.lock().unwrap();
48            channel.buffer.extend(msg.as_ref());
49
50            // If there is a waiter and the buffer is large enough,
51            // return the waiter (while clearing the waiter field).
52            // Otherwise, return early.
53            if channel
54                .waiter
55                .as_ref()
56                .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
57            {
58                let (requested, os_send) = channel.waiter.take().unwrap();
59                let data: Vec<u8> = channel.buffer.drain(0..requested).collect();
60                (os_send, Bytes::from(data))
61            } else {
62                return Ok(());
63            }
64        };
65
66        // Resolve the waiter.
67        os_send.send(data).map_err(|_| Error::SendFailed)?;
68        Ok(())
69    }
70}
71
72/// A mock stream that implements the Stream trait.
73pub struct Stream {
74    channel: Arc<Mutex<Channel>>,
75}
76
77impl StreamTrait for Stream {
78    async fn recv<B: StableBufMut>(&mut self, mut buf: B) -> Result<B, Error> {
79        let os_recv = {
80            let mut channel = self.channel.lock().unwrap();
81
82            // If the message is fully available in the buffer,
83            // drain the value into buf and return.
84            if channel.buffer.len() >= buf.len() {
85                let b: Vec<u8> = channel.buffer.drain(0..buf.len()).collect();
86                buf.put_slice(&b);
87                return Ok(buf);
88            }
89
90            // Otherwise, populate the waiter.
91            assert!(channel.waiter.is_none());
92            let (os_send, os_recv) = oneshot::channel();
93            channel.waiter = Some((buf.len(), os_send));
94            os_recv
95        };
96
97        // Wait for the waiter to be resolved.
98        let data = os_recv.await.map_err(|_| Error::RecvFailed)?;
99        assert_eq!(data.len(), buf.len());
100        buf.put_slice(&data);
101        Ok(buf)
102    }
103}
104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner};
    use commonware_macros::select;
    use futures::{executor::block_on, join};
    use std::time::Duration;

    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world".to_vec();

        block_on(async {
            sink.send(data.clone()).await.unwrap();
            let buf = stream.recv(vec![0; data.len()]).await.unwrap();
            assert_eq!(buf, data);
        });
    }

    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello".to_vec();
        let data2 = b" world".to_vec();

        block_on(async {
            sink.send(data).await.unwrap();
            sink.send(data2).await.unwrap();
            let buf = stream.recv(vec![0; 5]).await.unwrap();
            assert_eq!(buf, b"hello");
            let buf = stream.recv(buf).await.unwrap();
            assert_eq!(buf, b" worl");
            let buf = stream.recv(vec![0; 1]).await.unwrap();
            assert_eq!(buf, b"d");
        });
    }

    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();

        let data = b"hello world";
        let buf = block_on(async {
            futures::try_join!(stream.recv(vec![0; data.len()]), async {
                // `try_join!` polls its futures in the order they are listed, so by
                // the time this block runs `recv` has already found the buffer empty
                // and parked a waiter. No delay is needed here, and a blocking
                // `std::thread::sleep` would only stall the executor.
                assert!(sink.channel.lock().unwrap().waiter.is_some());
                sink.send(data.to_vec()).await
            })
            .unwrap()
            .0
        });

        assert_eq!(buf, data);
    }

    #[test]
    fn test_recv_error() {
        let (sink, mut stream) = Channel::init();
        let executor = deterministic::Runner::default();

        // If the oneshot sender is dropped before the oneshot receiver is resolved,
        // the recv function should return an error.
        executor.start(|_| async move {
            let (v, _) = join!(stream.recv(vec![0; 5]), async {
                // Take the waiter and drop it.
                sink.channel.lock().unwrap().waiter.take();
            },);
            assert!(matches!(v, Err(Error::RecvFailed)));
        });
    }

    #[test]
    fn test_send_error() {
        let (mut sink, mut stream) = Channel::init();
        let executor = deterministic::Runner::default();

        // If the waiter value has a min, but the oneshot receiver is dropped,
        // the send function should return an error when attempting to send the data.
        executor.start(|context| async move {
            // Create a waiter using a recv call.
            // But then drop the receiver.
            select! {
                v = stream.recv(vec![0; 5]) => {
                    panic!("unexpected value: {:?}", v);
                },
                _ = context.sleep(Duration::from_millis(100)) => {
                    "timeout"
                },
            };
            drop(stream);

            // Try to send a message (longer than the requested amount), but the receiver is dropped.
            let result = sink.send(b"hello world".to_vec()).await;
            assert!(matches!(result, Err(Error::SendFailed)));
        });
    }

    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();
        let executor = deterministic::Runner::default();

        // If there is no data to read, test that the recv function just blocks. A timeout should return first.
        executor.start(|context| async move {
            select! {
                v = stream.recv(vec![0; 5]) => {
                    panic!("unexpected value: {:?}", v);
                },
                _ = context.sleep(Duration::from_millis(100)) => {
                    "timeout"
                },
            };
        });
    }
}