commonware_runtime/
mocks.rs

1//! A mock implementation of a channel that implements the Sink and Stream traits.
2
3use crate::{Error, Sink as SinkTrait, Stream as StreamTrait};
4use bytes::{Buf, BufMut, Bytes, BytesMut};
5use futures::channel::oneshot;
6use std::sync::{Arc, Mutex};
7
8/// A mock channel struct that is used internally by Sink and Stream.
9pub struct Channel {
10    /// Stores the bytes sent by the sink that are not yet read by the stream.
11    buffer: BytesMut,
12
13    /// If the stream is waiting to read bytes, the waiter stores the number of
14    /// bytes that the stream is waiting for, as well as the oneshot sender that
15    /// the sink uses to send the bytes to the stream directly.
16    waiter: Option<(usize, oneshot::Sender<Bytes>)>,
17
18    /// Tracks whether the sink is still alive and able to send messages.
19    sink_alive: bool,
20
21    /// Tracks whether the stream is still alive and able to receive messages.
22    stream_alive: bool,
23}
24
25impl Channel {
26    /// Returns an async-safe Sink/Stream pair that share an underlying buffer of bytes.
27    pub fn init() -> (Sink, Stream) {
28        let channel = Arc::new(Mutex::new(Self {
29            buffer: BytesMut::new(),
30            waiter: None,
31            sink_alive: true,
32            stream_alive: true,
33        }));
34        (
35            Sink {
36                channel: channel.clone(),
37            },
38            Stream { channel },
39        )
40    }
41}
42
43/// A mock sink that implements the Sink trait.
44pub struct Sink {
45    channel: Arc<Mutex<Channel>>,
46}
47
48impl SinkTrait for Sink {
49    async fn send(&mut self, buf: impl Buf + Send) -> Result<(), Error> {
50        let (os_send, data) = {
51            let mut channel = self.channel.lock().unwrap();
52
53            // If the receiver is dead, we cannot send any more messages.
54            if !channel.stream_alive {
55                return Err(Error::Closed);
56            }
57
58            channel.buffer.put(buf);
59
60            // If there is a waiter and the buffer is large enough,
61            // return the waiter (while clearing the waiter field).
62            // Otherwise, return early.
63            if channel
64                .waiter
65                .as_ref()
66                .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
67            {
68                let (requested, os_send) = channel.waiter.take().unwrap();
69                let data = channel.buffer.copy_to_bytes(requested);
70                (os_send, data)
71            } else {
72                return Ok(());
73            }
74        };
75
76        // Resolve the waiter.
77        os_send.send(data).map_err(|_| Error::SendFailed)?;
78        Ok(())
79    }
80}
81
82impl Drop for Sink {
83    fn drop(&mut self) {
84        let mut channel = self.channel.lock().unwrap();
85        channel.sink_alive = false;
86
87        // If there is a waiter, resolve it by dropping the oneshot sender.
88        channel.waiter.take();
89    }
90}
91
92/// A mock stream that implements the Stream trait.
93pub struct Stream {
94    channel: Arc<Mutex<Channel>>,
95}
96
97impl StreamTrait for Stream {
98    async fn recv(&mut self, mut buf: impl BufMut + Send) -> Result<(), Error> {
99        let os_recv = {
100            let mut channel = self.channel.lock().unwrap();
101
102            // If the message is fully available in the buffer,
103            // drain the value into buf and return.
104            if channel.buffer.len() >= buf.remaining_mut() {
105                let b = channel.buffer.copy_to_bytes(buf.remaining_mut());
106                buf.put_slice(&b);
107                return Ok(());
108            }
109
110            // At this point, there is not enough data in the buffer.
111            // If the stream is dead, we cannot receive any more messages.
112            if !channel.sink_alive {
113                return Err(Error::Closed);
114            }
115
116            // Otherwise, populate the waiter.
117            assert!(channel.waiter.is_none());
118            let (os_send, os_recv) = oneshot::channel();
119            channel.waiter = Some((buf.remaining_mut(), os_send));
120            os_recv
121        };
122
123        // Wait for the waiter to be resolved.
124        // If the oneshot sender was dropped, it means the sink is closed.
125        let data = os_recv.await.map_err(|_| Error::Closed)?;
126        assert_eq!(data.len(), buf.remaining_mut());
127        buf.put_slice(&data);
128        Ok(())
129    }
130}
131
132impl Drop for Stream {
133    fn drop(&mut self) {
134        let mut channel = self.channel.lock().unwrap();
135        channel.stream_alive = false;
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;

    /// Data sent before the read is returned by a single `recv`.
    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            let mut buf = vec![0u8; data.len()];
            stream.recv(&mut buf[..]).await.unwrap();
            assert_eq!(&buf[..], data);
        });
    }

    /// Reads do not have to line up with the boundaries of the sends.
    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello";
        let data2 = b" world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            sink.send(data2.as_slice()).await.unwrap();
            let mut buf = [0u8; 5];
            stream.recv(&mut buf[..]).await.unwrap();
            assert_eq!(&buf[..], b"hello");
            stream.recv(&mut buf[..]).await.unwrap();
            assert_eq!(&buf[..], b" worl");
            let mut buf = [0u8; 1];
            stream.recv(&mut buf[..]).await.unwrap();
            assert_eq!(&buf[..], b"d");
        });
    }

    /// A `recv` that starts before the data exists is resolved by a later `send`.
    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut buf = vec![0; data.len()];
            let (_, _) = futures::try_join!(stream.recv(&mut buf[..]), async {
                // Yield through the runtime clock instead of blocking the
                // executor thread, so the recv is parked before the send runs.
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.as_slice()).await
            })
            .unwrap();
            assert_eq!(&buf[..], data);
        });
    }

    /// Dropping the sink wakes a parked `recv` with `Error::Closed`.
    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let mut buf = [0u8; 5];
                    let result = stream.recv(&mut buf[..]).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    // Wait for the stream to start waiting
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    /// With no buffered data and no sink, `recv` fails immediately.
    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink); // Drop sink immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let mut buf = [0u8; 5];
            let result = stream.recv(&mut buf[..]).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// Bytes buffered before the sink was dropped can still be read; only
    /// once they are drained does `recv` report the channel as closed.
    #[test]
    fn test_recv_buffered_after_sink_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(b"hello".as_slice()).await.unwrap();
            drop(sink);

            let mut buf = [0u8; 5];
            stream.recv(&mut buf[..]).await.unwrap();
            assert_eq!(&buf[..], b"hello");

            let result = stream.recv(&mut buf[..]).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// Once the task owning the stream is aborted, `send` fails with `Error::Closed`.
    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Send some bytes
            assert!(sink.send(b"7 bytes".as_slice()).await.is_ok());

            // Spawn a task to initiate recv's where the first one will succeed and then will drop.
            let handle = context.clone().spawn(|_| async move {
                let mut buf = [0u8; 5];
                let _ = stream.recv(&mut buf[..]).await;
                let _ = stream.recv(&mut buf[..]).await;
            });

            // Give the async task a moment to start
            context.sleep(Duration::from_millis(50)).await;

            // Drop the stream by aborting the handle
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            // Try to send a message. The stream is dropped, so this should fail.
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// Sending to a stream that was dropped up front fails with `Error::Closed`.
    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream); // Drop stream immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    /// With a live sink and no data, `recv` stays pending and the timeout wins.
    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        // If there is no data to read, test that the recv function just blocks.
        // The timeout should return first.
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let mut buf = [0u8; 5];
            select! {
                v = stream.recv(&mut buf[..]) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => {
                    "timeout"
                },
            };
        });
    }
}