commonware_runtime/
mocks.rs

1//! A mock implementation of a channel that implements the Sink and Stream traits.
2
3use crate::{Error, Sink as SinkTrait, Stream as StreamTrait};
4use bytes::Bytes;
5use commonware_utils::StableBuf;
6use futures::channel::oneshot;
7use std::{
8    collections::VecDeque,
9    sync::{Arc, Mutex},
10};
11
12/// A mock channel struct that is used internally by Sink and Stream.
13pub struct Channel {
14    /// Stores the bytes sent by the sink that are not yet read by the stream.
15    buffer: VecDeque<u8>,
16
17    /// If the stream is waiting to read bytes, the waiter stores the number of
18    /// bytes that the stream is waiting for, as well as the oneshot sender that
19    /// the sink uses to send the bytes to the stream directly.
20    waiter: Option<(usize, oneshot::Sender<Bytes>)>,
21
22    /// Tracks whether the sink is still alive and able to send messages.
23    sink_alive: bool,
24
25    /// Tracks whether the stream is still alive and able to receive messages.
26    stream_alive: bool,
27}
28
29impl Channel {
30    /// Returns an async-safe Sink/Stream pair that share an underlying buffer of bytes.
31    pub fn init() -> (Sink, Stream) {
32        let channel = Arc::new(Mutex::new(Channel {
33            buffer: VecDeque::new(),
34            waiter: None,
35            sink_alive: true,
36            stream_alive: true,
37        }));
38        (
39            Sink {
40                channel: channel.clone(),
41            },
42            Stream { channel },
43        )
44    }
45}
46
47/// A mock sink that implements the Sink trait.
48pub struct Sink {
49    channel: Arc<Mutex<Channel>>,
50}
51
52impl SinkTrait for Sink {
53    async fn send(&mut self, msg: impl Into<StableBuf> + Send) -> Result<(), Error> {
54        let msg = msg.into();
55        let (os_send, data) = {
56            let mut channel = self.channel.lock().unwrap();
57
58            // If the receiver is dead, we cannot send any more messages.
59            if !channel.stream_alive {
60                return Err(Error::Closed);
61            }
62
63            // Add the data to the buffer.
64            channel.buffer.extend(msg.as_ref());
65
66            // If there is a waiter and the buffer is large enough,
67            // return the waiter (while clearing the waiter field).
68            // Otherwise, return early.
69            if channel
70                .waiter
71                .as_ref()
72                .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
73            {
74                let (requested, os_send) = channel.waiter.take().unwrap();
75                let data: Vec<u8> = channel.buffer.drain(0..requested).collect();
76                (os_send, Bytes::from(data))
77            } else {
78                return Ok(());
79            }
80        };
81
82        // Resolve the waiter.
83        os_send.send(data).map_err(|_| Error::SendFailed)?;
84        Ok(())
85    }
86}
87
88impl Drop for Sink {
89    fn drop(&mut self) {
90        let mut channel = self.channel.lock().unwrap();
91        channel.sink_alive = false;
92
93        // If there is a waiter, resolve it by dropping the oneshot sender.
94        channel.waiter.take();
95    }
96}
97
98/// A mock stream that implements the Stream trait.
99pub struct Stream {
100    channel: Arc<Mutex<Channel>>,
101}
102
103impl StreamTrait for Stream {
104    async fn recv(&mut self, buf: impl Into<StableBuf> + Send) -> Result<StableBuf, Error> {
105        let mut buf = buf.into();
106        let os_recv = {
107            let mut channel = self.channel.lock().unwrap();
108
109            // If the message is fully available in the buffer,
110            // drain the value into buf and return.
111            if channel.buffer.len() >= buf.len() {
112                let b: Vec<u8> = channel.buffer.drain(0..buf.len()).collect();
113                buf.put_slice(&b);
114                return Ok(buf);
115            }
116
117            // At this point, there is not enough data in the buffer.
118            // If the stream is dead, we cannot receive any more messages.
119            if !channel.sink_alive {
120                return Err(Error::Closed);
121            }
122
123            // Otherwise, populate the waiter.
124            assert!(channel.waiter.is_none());
125            let (os_send, os_recv) = oneshot::channel();
126            channel.waiter = Some((buf.len(), os_send));
127            os_recv
128        };
129
130        // Wait for the waiter to be resolved.
131        // If the oneshot sender was dropped, it means the sink is closed.
132        let data = os_recv.await.map_err(|_| Error::Closed)?;
133        assert_eq!(data.len(), buf.len());
134        buf.put_slice(&data);
135        Ok(buf)
136    }
137}
138
139impl Drop for Stream {
140    fn drop(&mut self) {
141        let mut channel = self.channel.lock().unwrap();
142        channel.stream_alive = false;
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;

    /// A single send followed by a recv of the same size returns the bytes.
    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world".to_vec();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.clone()).await.unwrap();
            let buf = stream.recv(vec![0; data.len()]).await.unwrap();
            assert_eq!(buf.as_ref(), data);
        });
    }

    /// Reads may span and split the boundaries of individual sends.
    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello".to_vec();
        let data2 = b" world".to_vec();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data).await.unwrap();
            sink.send(data2).await.unwrap();
            let buf = stream.recv(vec![0; 5]).await.unwrap();
            assert_eq!(buf.as_ref(), b"hello");
            let buf = stream.recv(buf).await.unwrap();
            assert_eq!(buf.as_ref(), b" worl");
            let buf = stream.recv(vec![0; 1]).await.unwrap();
            assert_eq!(buf.as_ref(), b"d");
        });
    }

    /// A recv that starts before the data is sent completes once it arrives.
    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let (buf, _) = futures::try_join!(stream.recv(vec![0; data.len()]), async {
                // Yield through the runtime clock rather than blocking the
                // executor thread, so the recv is parked before the send runs.
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.to_vec()).await
            })
            .unwrap();
            assert_eq!(buf.as_ref(), data);
        });
    }

    /// Dropping the sink while a recv is pending fails the recv with `Closed`.
    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(vec![0; 5]).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    // Wait for the stream to start waiting
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    /// A recv issued after the sink is gone fails immediately with `Closed`.
    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink); // Drop sink immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(vec![0; 5]).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// Sending after the stream's owning task was aborted fails with `Closed`.
    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Send some bytes
            assert!(sink.send(b"7 bytes".to_vec()).await.is_ok());

            // Spawn a task to initiate recv's where the first one will succeed and then will drop.
            let handle = context.clone().spawn(|_| async move {
                let _ = stream.recv(vec![0; 5]).await;
                let _ = stream.recv(vec![0; 5]).await;
            });

            // Give the async task a moment to start
            context.sleep(Duration::from_millis(50)).await;

            // Drop the stream by aborting the handle
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            // Try to send a message. The stream is dropped, so this should fail.
            let result = sink.send(b"hello world".to_vec()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// Sending after the stream was dropped up front fails with `Closed`.
    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream); // Drop stream immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".to_vec()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// With no data available, recv stays pending and the timeout wins.
    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        // If there is no data to read, test that the recv function just blocks.
        // The timeout should return first.
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            select! {
                v = stream.recv(vec![0;5]) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => {
                    "timeout"
                },
            };
        });
    }
}