Skip to main content

commonware_runtime/
mocks.rs

1//! A mock implementation of a channel that implements the Sink and Stream traits.
2
3use crate::{BufMut, Error, IoBufs};
4use bytes::{Bytes, BytesMut};
5use commonware_utils::{
6    channel::{fallible::OneshotExt, oneshot},
7    sync::Mutex,
8};
9use std::sync::Arc;
10
/// Default buffer size: 64 KiB (2^16 bytes).
///
/// Used both as the amount a stream tops its local buffer up to on each
/// `recv` and as the shared-buffer size above which `send` applies
/// backpressure.
const DEFAULT_BUFFER_SIZE: usize = 1 << 16;
14
15/// A mock channel struct that is used internally by Sink and Stream.
16pub struct Channel {
17    /// Stores the bytes sent by the sink that are not yet read by the stream.
18    buffer: BytesMut,
19
20    /// If the stream is waiting to read bytes, the waiter stores the number of
21    /// bytes that the stream is waiting for, as well as the oneshot sender that
22    /// the sink uses to send the bytes to the stream directly.
23    waiter: Option<(usize, oneshot::Sender<Bytes>)>,
24
25    /// Target buffer size, used to bound both the stream's local buffer
26    /// and the shared buffer (backpressure threshold).
27    buffer_size: usize,
28
29    /// If the sink is blocked waiting for the buffer to drain, this holds
30    /// the oneshot sender that the stream uses to wake the sink.
31    drain_waiter: Option<oneshot::Sender<()>>,
32
33    /// Tracks whether the sink is still alive and able to send messages.
34    sink_alive: bool,
35
36    /// Tracks whether the stream is still alive and able to receive messages.
37    stream_alive: bool,
38}
39
40impl Channel {
41    /// Returns an async-safe Sink/Stream pair with default buffer size.
42    pub fn init() -> (Sink, Stream) {
43        Self::init_with_buffer_size(DEFAULT_BUFFER_SIZE)
44    }
45
46    /// Returns an async-safe Sink/Stream pair with the specified buffer size.
47    pub fn init_with_buffer_size(buffer_size: usize) -> (Sink, Stream) {
48        let channel = Arc::new(Mutex::new(Self {
49            buffer: BytesMut::new(),
50            waiter: None,
51            buffer_size,
52            drain_waiter: None,
53            sink_alive: true,
54            stream_alive: true,
55        }));
56        (
57            Sink {
58                channel: channel.clone(),
59                state: SinkState::Open,
60            },
61            Stream {
62                channel,
63                buffer: BytesMut::new(),
64                poisoned: false,
65            },
66        )
67    }
68
69    /// Restores bytes that were detached from the front of the shared buffer.
70    fn restore_front(&mut self, data: Bytes) {
71        if data.is_empty() {
72            return;
73        }
74
75        let mut restored = BytesMut::with_capacity(data.len() + self.buffer.len());
76        restored.extend_from_slice(&data);
77        restored.extend_from_slice(&self.buffer);
78        self.buffer = restored;
79    }
80
81    /// Marks the sink as closed and wakes any waiter.
82    fn close_sink(&mut self) {
83        self.sink_alive = false;
84
85        // If there is a waiter, resolve it by dropping the oneshot sender.
86        self.waiter.take();
87    }
88}
89
90struct RecvWaiterGuard {
91    channel: Arc<Mutex<Channel>>,
92    active: bool,
93}
94
95impl RecvWaiterGuard {
96    const fn new(channel: Arc<Mutex<Channel>>) -> Self {
97        Self {
98            channel,
99            active: true,
100        }
101    }
102
103    const fn disarm(&mut self) {
104        self.active = false;
105    }
106}
107
108impl Drop for RecvWaiterGuard {
109    fn drop(&mut self) {
110        if !self.active {
111            return;
112        }
113
114        self.channel.lock().waiter.take();
115    }
116}
117
118/// A mock sink that implements the Sink trait.
119pub struct Sink {
120    channel: Arc<Mutex<Channel>>,
121    state: SinkState,
122}
123
/// Lifecycle of a mock [`Sink`].
enum SinkState {
    /// Ready: the next `send` proceeds normally.
    Open,
    /// A `send` is parked waiting for the stream to drain. Seeing this state
    /// at the start of another `send` means the earlier one was cancelled.
    Sending,
    /// Terminal: the sink is shut and further sends are rejected.
    Closed,
}
133
134impl Sink {
135    fn close(&mut self) {
136        if matches!(self.state, SinkState::Closed) {
137            return;
138        }
139        self.channel.lock().close_sink();
140        self.state = SinkState::Closed;
141    }
142}
143
144impl crate::Sink for Sink {
145    async fn send(&mut self, bufs: impl Into<IoBufs> + Send) -> Result<(), Error> {
146        match self.state {
147            SinkState::Open => {}
148            SinkState::Sending => {
149                self.close();
150                return Err(Error::Closed);
151            }
152            SinkState::Closed => return Err(Error::Closed),
153        }
154
155        let drain_recv = {
156            let mut channel = self.channel.lock();
157
158            // If the receiver is dead, we cannot send any more messages.
159            if !channel.stream_alive {
160                channel.close_sink();
161                self.state = SinkState::Closed;
162                return Err(Error::SendFailed);
163            }
164
165            channel.buffer.put(bufs.into());
166
167            // If there is a waiter and the buffer is large enough,
168            // resolve the waiter (while clearing the waiter field).
169            if channel
170                .waiter
171                .as_ref()
172                .is_some_and(|(requested, _)| *requested <= channel.buffer.len())
173            {
174                // Send up to buffer_size bytes (but at least requested amount)
175                let (requested, os_send) = channel.waiter.take().unwrap();
176                let send_amount = channel.buffer.len().min(requested.max(channel.buffer_size));
177                let data = channel.buffer.split_to(send_amount).freeze();
178
179                // A canceled recv should behave like a buffered transport:
180                // preserve the bytes and allow a subsequent recv to consume them.
181                if let Err(data) = os_send.send(data) {
182                    channel.restore_front(data);
183                    if !channel.stream_alive {
184                        channel.close_sink();
185                        self.state = SinkState::Closed;
186                        return Err(Error::SendFailed);
187                    }
188                }
189            }
190
191            // If the buffer exceeds the write limit, block until the
192            // receiver drains enough data.
193            if channel.buffer.len() > channel.buffer_size {
194                assert!(channel.drain_waiter.is_none());
195                let (os_send, os_recv) = oneshot::channel();
196                channel.drain_waiter = Some(os_send);
197                os_recv
198            } else {
199                return Ok(());
200            }
201        };
202
203        // Mark the sink as sending before awaiting so cancellation can be
204        // detected by the next send.
205        self.state = SinkState::Sending;
206
207        // Wait for the receiver to drain the buffer.
208        match drain_recv.await {
209            Ok(()) => {
210                self.state = SinkState::Open;
211                Ok(())
212            }
213            Err(_) => {
214                self.close();
215                Err(Error::SendFailed)
216            }
217        }
218    }
219}
220
221impl Drop for Sink {
222    fn drop(&mut self) {
223        self.close();
224    }
225}
226
227/// A mock stream that implements the Stream trait.
228pub struct Stream {
229    channel: Arc<Mutex<Channel>>,
230    /// Local buffer for data that has been received but not yet consumed.
231    buffer: BytesMut,
232    poisoned: bool,
233}
234
235impl crate::Stream for Stream {
236    async fn recv(&mut self, len: usize) -> Result<IoBufs, Error> {
237        if self.poisoned {
238            return Err(Error::Closed);
239        }
240
241        let os_recv = {
242            let mut channel = self.channel.lock();
243
244            // Pull data from channel buffer into local buffer.
245            let target = len.max(channel.buffer_size);
246            let pull_amount = channel
247                .buffer
248                .len()
249                .min(target.saturating_sub(self.buffer.len()));
250            if pull_amount > 0 {
251                let data = channel.buffer.split_to(pull_amount);
252                self.buffer.extend_from_slice(&data);
253
254                // Wake a blocked sender if the buffer drained below the limit.
255                if channel.buffer.len() <= channel.buffer_size {
256                    if let Some(sender) = channel.drain_waiter.take() {
257                        sender.send_lossy(());
258                    }
259                }
260            }
261
262            // If we have enough, return immediately.
263            if self.buffer.len() >= len {
264                return Ok(IoBufs::from(self.buffer.split_to(len).freeze()));
265            }
266
267            // If the sink is dead, we cannot receive any more messages.
268            if !channel.sink_alive {
269                self.poisoned = true;
270                return Err(Error::RecvFailed);
271            }
272
273            // Set up waiter for remaining amount.
274            let remaining = len - self.buffer.len();
275            assert!(channel.waiter.is_none());
276            let (os_send, os_recv) = oneshot::channel();
277            channel.waiter = Some((remaining, os_send));
278            os_recv
279        };
280
281        let mut waiter_guard = RecvWaiterGuard::new(self.channel.clone());
282
283        // Pre-poison so that cancellation  leaves the stream permanently closed.
284        self.poisoned = true;
285
286        // Wait for the waiter to be resolved.
287        let data = match os_recv.await {
288            Ok(data) => {
289                waiter_guard.disarm();
290                self.poisoned = false;
291                data
292            }
293            Err(_) => {
294                waiter_guard.disarm();
295                return Err(Error::RecvFailed);
296            }
297        };
298        self.buffer.extend_from_slice(&data);
299
300        assert!(self.buffer.len() >= len);
301        Ok(IoBufs::from(self.buffer.split_to(len).freeze()))
302    }
303
304    fn peek(&self, max_len: usize) -> &[u8] {
305        let len = max_len.min(self.buffer.len());
306        &self.buffer[..len]
307    }
308}
309
310impl Drop for Stream {
311    fn drop(&mut self) {
312        let mut channel = self.channel.lock();
313        channel.stream_alive = false;
314
315        // Wake a blocked sender so it can observe the closed stream.
316        channel.drain_waiter.take();
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Clock, Runner, Sink, Spawner, Stream, Supervisor as _};
    use commonware_macros::select;
    use std::time::Duration;

    #[test]
    fn test_send_recv() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            let received = stream.recv(data.len()).await.unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_send_recv_partial_multiple() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello";
        let data2 = b" world";

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            sink.send(data.as_slice()).await.unwrap();
            sink.send(data2.as_slice()).await.unwrap();
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b" worl");
            let received = stream.recv(1).await.unwrap();
            assert_eq!(received.coalesce(), b"d");
        });
    }

    #[test]
    fn test_send_recv_async() {
        let (mut sink, mut stream) = Channel::init();
        let data = b"hello world";

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // `try_join!` polls `recv` first, so it registers a waiter before
            // the delayed send delivers the bytes.
            let (received, _) = futures::try_join!(stream.recv(data.len()), async {
                // Delay through the runtime clock rather than blocking the
                // executor thread with `std::thread::sleep`.
                context.sleep(Duration::from_millis(50)).await;
                sink.send(data.as_slice()).await
            })
            .unwrap();
            assert_eq!(received.coalesce(), data);
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_while_waiting() {
        let (sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            futures::join!(
                async {
                    let result = stream.recv(5).await;
                    assert!(matches!(result, Err(Error::RecvFailed)));
                    let result = stream.recv(5).await;
                    assert!(matches!(result, Err(Error::Closed)));
                },
                async {
                    // Wait for the stream to start waiting
                    context.sleep(Duration::from_millis(50)).await;
                    drop(sink);
                }
            );
        });
    }

    #[test]
    fn test_recv_error_sink_dropped_before_recv() {
        let (sink, mut stream) = Channel::init();
        drop(sink); // Drop sink immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = stream.recv(5).await;
            assert!(matches!(result, Err(Error::RecvFailed)));
            let result = stream.recv(5).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Send some bytes
            assert!(sink.send(b"7 bytes".as_slice()).await.is_ok());

            // Spawn a task to initiate recv's where the first one will succeed and then will drop.
            let handle = context.child("recv").spawn(|_| async move {
                let _ = stream.recv(5).await;
                let _ = stream.recv(5).await;
            });

            // Give the async task a moment to start
            context.sleep(Duration::from_millis(50)).await;

            // Drop the stream by aborting the handle
            handle.abort();
            assert!(matches!(handle.await, Err(Error::Closed)));

            // Try to send a message. The stream is dropped, so this should fail.
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::SendFailed)));
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_send_error_stream_dropped_before_send() {
        let (mut sink, stream) = Channel::init();
        drop(stream); // Drop stream immediately

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::SendFailed)));
            let result = sink.send(b"hello world".as_slice()).await;
            assert!(matches!(result, Err(Error::Closed)));
        });
    }

    #[test]
    fn test_recv_timeout() {
        let (_sink, mut stream) = Channel::init();

        // If there is no data to read, test that the recv function just blocks.
        // The timeout should return first.
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            select! {
                v = stream.recv(5) => {
                    panic!("unexpected value: {v:?}");
                },
                _ = context.sleep(Duration::from_millis(100)) => "timeout",
            };
        });
    }

    #[test]
    fn test_peek_empty() {
        let (_sink, stream) = Channel::init();

        // Peek on a fresh stream should return empty slice
        assert!(stream.peek(10).is_empty());
    }

    #[test]
    fn test_peek_after_partial_recv() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send more data than we'll consume
            sink.send(b"hello world".as_slice()).await.unwrap();

            // Recv only part of it
            let received = stream.recv(5).await.unwrap();
            assert_eq!(received.coalesce(), b"hello");

            // Peek should show the remaining data
            assert_eq!(stream.peek(100), b" world");

            // Peek with smaller max_len
            assert_eq!(stream.peek(3), b" wo");

            // Peek doesn't consume - can peek again
            assert_eq!(stream.peek(100), b" world");

            // Recv consumes the peeked data
            let received = stream.recv(6).await.unwrap();
            assert_eq!(received.coalesce(), b" world");

            // Peek is now empty
            assert!(stream.peek(100).is_empty());
        });
    }

    #[test]
    fn test_peek_after_recv_wakeup() {
        let (mut sink, mut stream) = Channel::init_with_buffer_size(64);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Spawn recv that will block waiting
            let (tx, rx) = oneshot::channel();
            let recv_handle = context.child("recv").spawn(|_| async move {
                let data = stream.recv(3).await.unwrap();
                tx.send(stream).ok();
                data
            });

            // Let recv set up waiter
            context.sleep(Duration::from_millis(10)).await;

            // Send more than requested
            sink.send(b"ABCDEFGHIJ".as_slice()).await.unwrap();

            // Recv gets its 3 bytes
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");

            // Get stream back and verify peek sees remaining data
            let stream = rx.await.unwrap();
            assert_eq!(stream.peek(100), b"DEFGHIJ");
        });
    }

    #[test]
    fn test_peek_multiple_sends() {
        let (mut sink, mut stream) = Channel::init();

        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Send multiple chunks
            sink.send(b"aaa".as_slice()).await.unwrap();
            sink.send(b"bbb".as_slice()).await.unwrap();
            sink.send(b"ccc".as_slice()).await.unwrap();

            // Recv less than total
            let received = stream.recv(4).await.unwrap();
            assert_eq!(received.coalesce(), b"aaab");

            // Peek should show remaining
            assert_eq!(stream.peek(100), b"bbccc");
        });
    }

    #[test]
    fn test_buffer_size_limit() {
        // Use a small buffer capacity for testing
        let (mut sink, mut stream) = Channel::init_with_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Send more than buffer capacity concurrently with recv
            // so the sender can drain via backpressure.
            let send_handle = context.child("sender").spawn(|_| async move {
                sink.send(b"0123456789ABCDEF".as_slice()).await.unwrap();
                sink
            });

            // Recv a small amount - should only pull up to capacity (10 bytes)
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"01");

            // Peek should show remaining buffered data (8 bytes, not 14)
            assert_eq!(stream.peek(100), b"23456789");

            // The rest should still be in the channel buffer
            // Recv more to pull the remaining data
            let received = stream.recv(8).await.unwrap();
            assert_eq!(received.coalesce(), b"23456789");

            // Pull the next chunk from the channel (up to capacity); the
            // leftover stays locally buffered and visible to peek.
            let received = stream.recv(2).await.unwrap();
            assert_eq!(received.coalesce(), b"AB");

            assert_eq!(stream.peek(100), b"CDEF");

            // Ensure the sender completes
            send_handle.await.unwrap();
        });
    }

    #[test]
    fn test_recv_before_send() {
        // Use a small buffer capacity for testing
        let (mut sink, mut stream) = Channel::init_with_buffer_size(10);

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Start recv before send (will wait)
            let recv_handle = context
                .child("recv")
                .spawn(|_| async move { stream.recv(3).await.unwrap() });

            // Give recv time to set up waiter
            context.sleep(Duration::from_millis(10)).await;

            // Send more than capacity
            sink.send(b"ABCDEFGHIJKLMNOP".as_slice()).await.unwrap();

            // Recv should get its 3 bytes
            let received = recv_handle.await.unwrap();
            assert_eq!(received.coalesce(), b"ABC");
        });
    }
}