Skip to main content

commonware_runtime/
mocks.rs

//! A mock in-memory channel whose [`Sink`] and [`Stream`] halves implement the runtime's `Sink` and `Stream` traits.
2
3use crate::{BufMut, Error, IoBufs, Sink as SinkTrait, Stream as StreamTrait};
4use bytes::{Bytes, BytesMut};
5use commonware_utils::{channel::oneshot, sync::Mutex};
6use std::sync::Arc;
7
/// Read buffer size used by [`Channel::init`]: 64 KiB (2^16 bytes).
const DEFAULT_READ_BUFFER_SIZE: usize = 1 << 16;
10
11/// A mock channel struct that is used internally by Sink and Stream.
12pub struct Channel {
13    /// Stores the bytes sent by the sink that are not yet read by the stream.
14    buffer: BytesMut,
15
16    /// If the stream is waiting to read bytes, the waiter stores the number of
17    /// bytes that the stream is waiting for, as well as the oneshot sender that
18    /// the sink uses to send the bytes to the stream directly.
19    waiter: Option<(usize, oneshot::Sender<Bytes>)>,
20
21    /// Target size for the stream's local buffer, used to bound buffering.
22    read_buffer_size: usize,
23
24    /// Tracks whether the sink is still alive and able to send messages.
25    sink_alive: bool,
26
27    /// Tracks whether the stream is still alive and able to receive messages.
28    stream_alive: bool,
29}
30
31impl Channel {
32    /// Returns an async-safe Sink/Stream pair with default read buffer size.
33    pub fn init() -> (Sink, Stream) {
34        Self::init_with_read_buffer_size(DEFAULT_READ_BUFFER_SIZE)
35    }
36
37    /// Returns an async-safe Sink/Stream pair with the specified buffer capacity.
38    pub fn init_with_read_buffer_size(read_buffer_size: usize) -> (Sink, Stream) {
39        let channel = Arc::new(Mutex::new(Self {
40            buffer: BytesMut::new(),
41            waiter: None,
42            read_buffer_size,
43            sink_alive: true,
44            stream_alive: true,
45        }));
46        (
47            Sink {
48                channel: channel.clone(),
49            },
50            Stream {
51                channel,
52                buffer: BytesMut::new(),
53            },
54        )
55    }
56}
57
58/// A mock sink that implements the Sink trait.
59pub struct Sink {
60    channel: Arc<Mutex<Channel>>,
61}
62
63impl SinkTrait for Sink {
64    async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
65        let (os_send, data) = {
66            let mut channel = self.channel.lock();
67
68            // If the receiver is dead, we cannot send any more messages.
69            if !channel.stream_alive {
70                return Err(Error::Closed);
71            }
72
73            channel.buffer.put(bufs.into());
74
75            // If there is a waiter and the buffer is large enough,
76            // return the waiter (while clearing the waiter field).
77            // Otherwise, return early.
78            if channel
79                .waiter
80                .as_ref()
81                .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
82            {
83                // Send up to read_buffer_size bytes (but at least requested amount)
84                let (requested, os_send) = channel.waiter.take().unwrap();
85                let send_amount = channel
86                    .buffer
87                    .len()
88                    .min(requested.max(channel.read_buffer_size));
89                let data = channel.buffer.split_to(send_amount).freeze();
90                (os_send, data)
91            } else {
92                return Ok(());
93            }
94        };
95
96        // Resolve the waiter.
97        os_send.send(data).map_err(|_| Error::SendFailed)?;
98        Ok(())
99    }
100}
101
102impl Drop for Sink {
103    fn drop(&mut self) {
104        let mut channel = self.channel.lock();
105        channel.sink_alive = false;
106
107        // If there is a waiter, resolve it by dropping the oneshot sender.
108        channel.waiter.take();
109    }
110}
111
112/// A mock stream that implements the Stream trait.
113pub struct Stream {
114    channel: Arc<Mutex<Channel>>,
115    /// Local buffer for data that has been received but not yet consumed.
116    buffer: BytesMut,
117}
118
119impl StreamTrait for Stream {
120    async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
121        let os_recv = {
122            let mut channel = self.channel.lock();
123
124            // Pull data from channel buffer into local buffer.
125            if !channel.buffer.is_empty() {
126                let target = len.max(channel.read_buffer_size);
127                let pull_amount = channel
128                    .buffer
129                    .len()
130                    .min(target.saturating_sub(self.buffer.len()));
131                if pull_amount > 0 {
132                    let data = channel.buffer.split_to(pull_amount);
133                    self.buffer.extend_from_slice(&data);
134                }
135            }
136
137            // If we have enough, return immediately.
138            if self.buffer.len() >= len {
139                return Ok(IoBufs::from(self.buffer.split_to(len).freeze()));
140            }
141
142            // If the sink is dead, we cannot receive any more messages.
143            if !channel.sink_alive {
144                return Err(Error::Closed);
145            }
146
147            // Set up waiter for remaining amount.
148            let remaining = len - self.buffer.len();
149            assert!(channel.waiter.is_none());
150            let (os_send, os_recv) = oneshot::channel();
151            channel.waiter = Some((remaining, os_send));
152            os_recv
153        };
154
155        // Wait for the waiter to be resolved.
156        let data = os_recv.await.map_err(|_| Error::Closed)?;
157        self.buffer.extend_from_slice(&data);
158
159        assert!(self.buffer.len() >= len);
160        Ok(IoBufs::from(self.buffer.split_to(len).freeze()))
161    }
162
163    fn peek(&self, max_len: usize) -> &[u8] {
164        let len = max_len.min(self.buffer.len());
165        &self.buffer[..len]
166    }
167}
168
169impl Drop for Stream {
170    fn drop(&mut self) {
171        let mut channel = self.channel.lock();
172        channel.stream_alive = false;
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;

    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            let received = stream.recv(data.len()).await.unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello";
        let data2 = b" world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            sink.send(data2.as_slice()).await.unwrap();
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b" worl");
            let received = stream.recv(1).await.unwrap();
            assert_eq!(received.coalesce(), b"d");
        });
    }

    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // The recv is polled first and registers a waiter; the send follows
            // after a simulated delay. Use the runtime clock: a blocking
            // `std::thread::sleep` would stall the executor thread in real
            // time without advancing the deterministic clock.
            let (received, _) = futures::try_join!(stream.recv(data.len()), async {
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.as_slice()).await
            })
            .unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(5).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    // Wait for the stream to start waiting
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink); // Drop sink immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(5).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Send some bytes
            assert!(sink.send(b"7 bytes".as_slice()).await.is_ok());

            // Spawn a task that issues two recvs: the first succeeds, the second
            // blocks until the task is aborted (dropping the stream).
            let handle = context.clone().spawn(|_| async move {
                let _ = stream.recv(5).await;
                let _ = stream.recv(5).await;
            });

            // Give the async task a moment to start
            context.sleep(Duration::from_millis(50)).await;

            // Drop the stream by aborting the handle
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            // Try to send a message. The stream is dropped, so this should fail.
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream); // Drop stream immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        // If there is no data to read, test that the recv function just blocks.
        // The timeout should return first.
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            select! {
                v = stream.recv(5) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => "timeout",
            };
        });
    }

    #[test]
    fn test_peek_empty() {
        let (_sink, stream) = Channel::init();

        // Peek on a fresh stream should return empty slice
        assert!(stream.peek(10).is_empty());
    }

    #[test]
    fn test_peek_after_partial_recv() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send more data than we'll consume
            sink.send(b"hello world".as_slice()).await.unwrap();

            // Recv only part of it
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");

            // Peek should show the remaining data
            assert_eq!(stream.peek(100), b" world");

            // Peek with smaller max_len
            assert_eq!(stream.peek(3), b" wo");

            // Peek doesn't consume - can peek again
            assert_eq!(stream.peek(100), b" world");

            // Recv consumes the peeked data
            let received = stream.recv(6).await.unwrap();
            assert_eq!(received.coalesce(), b" world");

            // Peek is now empty
            assert!(stream.peek(100).is_empty());
        });
    }

    #[test]
    fn test_peek_after_recv_wakeup() {
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(64);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Spawn recv that will block waiting
            let (tx, rx) = oneshot::channel();
            let recv_handle = context.clone().spawn(|_| async move {
                let data = stream.recv(3).await.unwrap();
                tx.send(stream).ok();
                data
            });

            // Let recv set up waiter
            context.sleep(Duration::from_millis(10)).await;

            // Send more than requested
            sink.send(b"ABCDEFGHIJ".as_slice()).await.unwrap();

            // Recv gets its 3 bytes
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");

            // Get stream back and verify peek sees remaining data
            let stream = rx.await.unwrap();
            assert_eq!(stream.peek(100), b"DEFGHIJ");
        });
    }

    #[test]
    fn test_peek_multiple_sends() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send multiple chunks
            sink.send(b"aaa".as_slice()).await.unwrap();
            sink.send(b"bbb".as_slice()).await.unwrap();
            sink.send(b"ccc".as_slice()).await.unwrap();

            // Recv less than total
            let received = stream.recv(4).await.unwrap();
            assert_eq!(received.coalesce(), b"aaab");

            // Peek should show remaining
            assert_eq!(stream.peek(100), b"bbccc");
        });
    }

    #[test]
    fn test_read_buffer_size_limit() {
        // Use a small buffer capacity for testing
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send more than buffer capacity
            sink.send(b"0123456789ABCDEF".as_slice()).await.unwrap();

            // Recv a small amount - should only pull up to capacity (10 bytes)
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"01");

            // Peek should show remaining buffered data (8 bytes, not 14)
            assert_eq!(stream.peek(100), b"23456789");

            // The remaining 6 bytes are still in the channel buffer; this recv
            // is served from the local buffer (topped up from the channel).
            let received = stream.recv(8).await.unwrap();
            assert_eq!(received.coalesce(), b"23456789");

            // The next recv pulls the rest of the channel buffer (up to capacity)
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"AB");

            assert_eq!(stream.peek(100), b"CDEF");
        });
    }

    #[test]
    fn test_recv_before_send() {
        // Use a small buffer capacity for testing
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Start recv before send (will wait)
            let recv_handle = context
                .clone()
                .spawn(|_| async move { stream.recv(3).await.unwrap() });

            // Give recv time to set up waiter
            context.sleep(Duration::from_millis(10)).await;

            // Send more than capacity
            sink.send(b"ABCDEFGHIJKLMNOP".as_slice()).await.unwrap();

            // Recv should get its 3 bytes
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");
        });
    }
}