Skip to main content

commonware_runtime/
mocks.rs

1//! A mock implementation of a channel that implements the Sink and Stream traits.
2
3use crate::{BufMut, Error, IoBufs, Sink as SinkTrait, Stream as StreamTrait};
4use bytes::{Bytes, BytesMut};
5use commonware_utils::channel::oneshot;
6use std::sync::{Arc, Mutex};
7
/// Read buffer size used by `Channel::init` for the stream's local buffer:
/// 64 KiB (65,536 bytes).
const DEFAULT_READ_BUFFER_SIZE: usize = 1 << 16;
10
11/// A mock channel struct that is used internally by Sink and Stream.
12pub struct Channel {
13    /// Stores the bytes sent by the sink that are not yet read by the stream.
14    buffer: BytesMut,
15
16    /// If the stream is waiting to read bytes, the waiter stores the number of
17    /// bytes that the stream is waiting for, as well as the oneshot sender that
18    /// the sink uses to send the bytes to the stream directly.
19    waiter: Option<(usize, oneshot::Sender<Bytes>)>,
20
21    /// Target size for the stream's local buffer, used to bound buffering.
22    read_buffer_size: usize,
23
24    /// Tracks whether the sink is still alive and able to send messages.
25    sink_alive: bool,
26
27    /// Tracks whether the stream is still alive and able to receive messages.
28    stream_alive: bool,
29}
30
31impl Channel {
32    /// Returns an async-safe Sink/Stream pair with default read buffer size.
33    pub fn init() -> (Sink, Stream) {
34        Self::init_with_read_buffer_size(DEFAULT_READ_BUFFER_SIZE)
35    }
36
37    /// Returns an async-safe Sink/Stream pair with the specified buffer capacity.
38    pub fn init_with_read_buffer_size(read_buffer_size: usize) -> (Sink, Stream) {
39        let channel = Arc::new(Mutex::new(Self {
40            buffer: BytesMut::new(),
41            waiter: None,
42            read_buffer_size,
43            sink_alive: true,
44            stream_alive: true,
45        }));
46        (
47            Sink {
48                channel: channel.clone(),
49            },
50            Stream {
51                channel,
52                buffer: BytesMut::new(),
53            },
54        )
55    }
56}
57
58/// A mock sink that implements the Sink trait.
59pub struct Sink {
60    channel: Arc<Mutex<Channel>>,
61}
62
63impl SinkTrait for Sink {
64    async fn send(&mut self, buf: impl Into<IoBufs> + Send) -> Result<(), Error> {
65        let (os_send, data) = {
66            let mut channel = self.channel.lock().unwrap();
67
68            // If the receiver is dead, we cannot send any more messages.
69            if !channel.stream_alive {
70                return Err(Error::Closed);
71            }
72
73            channel.buffer.put(buf.into());
74
75            // If there is a waiter and the buffer is large enough,
76            // return the waiter (while clearing the waiter field).
77            // Otherwise, return early.
78            if channel
79                .waiter
80                .as_ref()
81                .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
82            {
83                // Send up to read_buffer_size bytes (but at least requested amount)
84                let (requested, os_send) = channel.waiter.take().unwrap();
85                let send_amount = channel
86                    .buffer
87                    .len()
88                    .min(requested.max(channel.read_buffer_size));
89                let data = channel.buffer.split_to(send_amount).freeze();
90                (os_send, data)
91            } else {
92                return Ok(());
93            }
94        };
95
96        // Resolve the waiter.
97        os_send.send(data).map_err(|_| Error::SendFailed)?;
98        Ok(())
99    }
100}
101
102impl Drop for Sink {
103    fn drop(&mut self) {
104        let mut channel = self.channel.lock().unwrap();
105        channel.sink_alive = false;
106
107        // If there is a waiter, resolve it by dropping the oneshot sender.
108        channel.waiter.take();
109    }
110}
111
112/// A mock stream that implements the Stream trait.
113pub struct Stream {
114    channel: Arc<Mutex<Channel>>,
115    /// Local buffer for data that has been received but not yet consumed.
116    buffer: BytesMut,
117}
118
119impl StreamTrait for Stream {
120    async fn recv(&mut self, len: u64) -> Result<IoBufs, Error> {
121        let len = len as usize;
122
123        let os_recv = {
124            let mut channel = self.channel.lock().unwrap();
125
126            // Pull data from channel buffer into local buffer.
127            if !channel.buffer.is_empty() {
128                let target = len.max(channel.read_buffer_size);
129                let pull_amount = channel
130                    .buffer
131                    .len()
132                    .min(target.saturating_sub(self.buffer.len()));
133                if pull_amount > 0 {
134                    let data = channel.buffer.split_to(pull_amount);
135                    self.buffer.extend_from_slice(&data);
136                }
137            }
138
139            // If we have enough, return immediately.
140            if self.buffer.len() >= len {
141                return Ok(IoBufs::from(self.buffer.split_to(len).freeze()));
142            }
143
144            // If the sink is dead, we cannot receive any more messages.
145            if !channel.sink_alive {
146                return Err(Error::Closed);
147            }
148
149            // Set up waiter for remaining amount.
150            let remaining = len - self.buffer.len();
151            assert!(channel.waiter.is_none());
152            let (os_send, os_recv) = oneshot::channel();
153            channel.waiter = Some((remaining, os_send));
154            os_recv
155        };
156
157        // Wait for the waiter to be resolved.
158        let data = os_recv.await.map_err(|_| Error::Closed)?;
159        self.buffer.extend_from_slice(&data);
160
161        assert!(self.buffer.len() >= len);
162        Ok(IoBufs::from(self.buffer.split_to(len).freeze()))
163    }
164
165    fn peek(&self, max_len: u64) -> &[u8] {
166        let len = (max_len as usize).min(self.buffer.len());
167        &self.buffer[..len]
168    }
169}
170
171impl Drop for Stream {
172    fn drop(&mut self) {
173        let mut channel = self.channel.lock().unwrap();
174        channel.stream_alive = false;
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Spawner};
    use commonware_macros::select;
    use std::time::Duration;

    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            let received = stream.recv(data.len() as u64).await.unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello";
        let data2 = b" world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            sink.send(data2.as_slice()).await.unwrap();
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b" worl");
            let received = stream.recv(1).await.unwrap();
            assert_eq!(received.coalesce(), b"d");
        });
    }

    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            let (received, _) = futures::try_join!(stream.recv(data.len() as u64), async {
                // Sleep on the runtime clock rather than blocking the thread;
                // `recv` is polled first, so its waiter is registered before
                // the send happens.
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.as_slice()).await
            })
            .unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(5).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    // Wait for the stream to start waiting
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink); // Drop sink immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(5).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Send some bytes
            assert!(sink.send(b"7 bytes".as_slice()).await.is_ok());

            // Spawn a task that issues two recvs: the first succeeds, the
            // second blocks until the task is aborted (dropping the stream).
            let handle = context.clone().spawn(|_| async move {
                let _ = stream.recv(5).await;
                let _ = stream.recv(5).await;
            });

            // Give the async task a moment to start
            context.sleep(Duration::from_millis(50)).await;

            // Drop the stream by aborting the handle
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            // Try to send a message. The stream is dropped, so this should fail.
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream); // Drop stream immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        // If there is no data to read, test that the recv function just blocks.
        // The timeout should return first.
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            select! {
                v = stream.recv(5) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => "timeout",
            };
        });
    }

    #[test]
    fn test_peek_empty() {
        let (_sink, stream) = Channel::init();

        // Peek on a fresh stream should return empty slice
        assert!(stream.peek(10).is_empty());
    }

    #[test]
    fn test_peek_after_partial_recv() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send more data than we'll consume
            sink.send(b"hello world".as_slice()).await.unwrap();

            // Recv only part of it
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");

            // Peek should show the remaining data
            assert_eq!(stream.peek(100), b" world");

            // Peek with smaller max_len
            assert_eq!(stream.peek(3), b" wo");

            // Peek doesn't consume - can peek again
            assert_eq!(stream.peek(100), b" world");

            // Recv consumes the peeked data
            let received = stream.recv(6).await.unwrap();
            assert_eq!(received.coalesce(), b" world");

            // Peek is now empty
            assert!(stream.peek(100).is_empty());
        });
    }

    #[test]
    fn test_peek_after_recv_wakeup() {
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(64);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Spawn recv that will block waiting
            let (tx, rx) = oneshot::channel();
            let recv_handle = context.clone().spawn(|_| async move {
                let data = stream.recv(3).await.unwrap();
                tx.send(stream).ok();
                data
            });

            // Let recv set up waiter
            context.sleep(Duration::from_millis(10)).await;

            // Send more than requested
            sink.send(b"ABCDEFGHIJ".as_slice()).await.unwrap();

            // Recv gets its 3 bytes
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");

            // Get stream back and verify peek sees remaining data
            let stream = rx.await.unwrap();
            assert_eq!(stream.peek(100), b"DEFGHIJ");
        });
    }

    #[test]
    fn test_peek_multiple_sends() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send multiple chunks
            sink.send(b"aaa".as_slice()).await.unwrap();
            sink.send(b"bbb".as_slice()).await.unwrap();
            sink.send(b"ccc".as_slice()).await.unwrap();

            // Recv less than total
            let received = stream.recv(4).await.unwrap();
            assert_eq!(received.coalesce(), b"aaab");

            // Peek should show remaining
            assert_eq!(stream.peek(100), b"bbccc");
        });
    }

    #[test]
    fn test_read_buffer_size_limit() {
        // Use a small buffer capacity for testing
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send more than buffer capacity
            sink.send(b"0123456789ABCDEF".as_slice()).await.unwrap();

            // Recv a small amount - should only pull up to capacity (10 bytes)
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"01");

            // Peek should show remaining buffered data (8 bytes, not 14)
            assert_eq!(stream.peek(100), b"23456789");

            // The rest should still be in the channel buffer
            // Recv more to pull the remaining data
            let received = stream.recv(8).await.unwrap();
            assert_eq!(received.coalesce(), b"23456789");

            // The next recv pulls the next chunk from the channel (up to capacity)
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"AB");

            assert_eq!(stream.peek(100), b"CDEF");
        });
    }

    #[test]
    fn test_recv_before_send() {
        // Use a small buffer capacity for testing
        let (mut sink, mut stream) = Channel::init_with_read_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Start recv before send (will wait)
            let recv_handle = context
                .clone()
                .spawn(|_| async move { stream.recv(3).await.unwrap() });

            // Give recv time to set up waiter
            context.sleep(Duration::from_millis(10)).await;

            // Send more than capacity
            sink.send(b"ABCDEFGHIJKLMNOP".as_slice()).await.unwrap();

            // Recv should get its 3 bytes
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");
        });
    }
}
464}