Skip to main content

commonware_utils/channel/
ring.rs

1//! A bounded mpsc channel that drops the oldest item when full instead of applying backpressure.
2//!
3//! This is useful for scenarios where you want to keep the most recent items and can
4//! tolerate losing older ones, such as real-time data streams or status updates where
5//! only the latest values matter.
6//!
7//! # Example
8//!
9//! ```
10//! use futures::executor::block_on;
11//! use futures::{SinkExt, StreamExt};
12//! use commonware_utils::{NZUsize, channel::ring};
13//!
14//! block_on(async {
15//!     let (mut sender, mut receiver) = ring::channel::<u32>(NZUsize!(2));
16//!
17//!     // Fill the channel
18//!     sender.send(1).await.unwrap();
19//!     sender.send(2).await.unwrap();
20//!
21//!     // This will drop the oldest item (1) and insert 3
22//!     sender.send(3).await.unwrap();
23//!
24//!     // Receive the remaining items
25//!     assert_eq!(receiver.next().await, Some(2));
26//!     assert_eq!(receiver.next().await, Some(3));
27//! });
28//! ```
29
30use crate::sync::Mutex;
31use core::num::NonZeroUsize;
32use futures::{stream::FusedStream, Sink, Stream};
33use std::{
34    collections::VecDeque,
35    pin::Pin,
36    sync::Arc,
37    task::{Context, Poll, Waker},
38};
39use thiserror::Error;
40
41/// Error returned when sending to a channel whose receiver has been dropped.
42#[derive(Debug, Error)]
43#[error("channel closed")]
44pub struct ChannelClosed;
45
46/// Error returned by [`Receiver::try_recv`].
47#[derive(Debug, Error, PartialEq, Eq)]
48pub enum TryRecvError {
49    /// The channel currently has no buffered items, but senders still exist.
50    #[error("channel empty")]
51    Empty,
52    /// The channel is empty and all senders have been dropped.
53    #[error("channel closed")]
54    Disconnected,
55}
56
/// State shared (behind a mutex) by every [`Sender`] and the [`Receiver`].
#[derive(Debug)]
struct Shared<T: Send + Sync> {
    /// Buffered items, oldest at the front.
    buffer: VecDeque<T>,
    /// Maximum number of items kept in `buffer` before the oldest is evicted.
    capacity: usize,
    /// Waker registered by the receiver while it waits for an item.
    receiver_waker: Option<Waker>,
    /// Number of live [`Sender`] handles.
    sender_count: usize,
    /// Set once the [`Receiver`] has been dropped.
    receiver_dropped: bool,
}
65
66/// The sending half of a ring channel.
67///
68/// Implements [`Sink`] for sending items. Use [`SinkExt::send`](futures::SinkExt::send)
69/// to send items asynchronously.
70///
71/// This type can be cloned to create multiple producers for the same channel.
72/// The channel remains open until all senders are dropped.
73pub struct Sender<T: Send + Sync> {
74    shared: Arc<Mutex<Shared<T>>>,
75}
76
77impl<T: Send + Sync> Sender<T> {
78    /// Returns whether the receiver has been dropped.
79    ///
80    /// If this returns `true`, subsequent sends will fail with [`ChannelClosed`].
81    pub fn is_closed(&self) -> bool {
82        let shared = self.shared.lock();
83        shared.receiver_dropped
84    }
85}
86
87impl<T: Send + Sync> Clone for Sender<T> {
88    fn clone(&self) -> Self {
89        let mut shared = self.shared.lock();
90        shared.sender_count += 1;
91        drop(shared);
92
93        Self {
94            shared: self.shared.clone(),
95        }
96    }
97}
98
99impl<T: Send + Sync> Drop for Sender<T> {
100    fn drop(&mut self) {
101        let mut shared = self.shared.lock();
102        shared.sender_count -= 1;
103        let waker = if shared.sender_count == 0 {
104            shared.receiver_waker.take()
105        } else {
106            None
107        };
108        drop(shared);
109
110        if let Some(w) = waker {
111            w.wake();
112        }
113    }
114}
115
116impl<T: Send + Sync> Sink<T> for Sender<T> {
117    type Error = ChannelClosed;
118
119    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120        let shared = self.shared.lock();
121        if shared.receiver_dropped {
122            return Poll::Ready(Err(ChannelClosed));
123        }
124
125        Poll::Ready(Ok(()))
126    }
127
128    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
129        let mut shared = self.shared.lock();
130
131        if shared.receiver_dropped {
132            return Err(ChannelClosed);
133        }
134
135        let old_item = if shared.buffer.len() >= shared.capacity {
136            shared.buffer.pop_front()
137        } else {
138            None
139        };
140
141        shared.buffer.push_back(item);
142        let waker = shared.receiver_waker.take();
143        drop(shared);
144
145        // Drop the old item after the lock is released to avoid potential mutex poisoning
146        drop(old_item);
147
148        if let Some(w) = waker {
149            w.wake();
150        }
151
152        Ok(())
153    }
154
155    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156        // No buffering in the sender - items are sent immediately to the shared buffer
157        Poll::Ready(Ok(()))
158    }
159
160    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161        // Closing is handled by Drop
162        Poll::Ready(Ok(()))
163    }
164}
165
166/// The receiving half of a ring channel.
167///
168/// Implements [`Stream`] and [`FusedStream`] for receiving items. Use
169/// [`StreamExt::next`](futures::StreamExt::next) to receive items asynchronously.
170///
171/// The stream terminates (returns `None`) when all senders have been dropped
172/// and all buffered items have been consumed.
173#[derive(Debug)]
174pub struct Receiver<T: Send + Sync> {
175    shared: Arc<Mutex<Shared<T>>>,
176}
177
178impl<T: Send + Sync> Receiver<T> {
179    /// Receives the next item from the channel.
180    pub async fn recv(&mut self) -> Option<T> {
181        futures::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
182    }
183
184    /// Attempts to receive an item without waiting.
185    pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
186        let mut shared = self.shared.lock();
187        if let Some(item) = shared.buffer.pop_front() {
188            return Ok(item);
189        }
190        if shared.sender_count == 0 {
191            return Err(TryRecvError::Disconnected);
192        }
193        Err(TryRecvError::Empty)
194    }
195}
196
197impl<T: Send + Sync> Stream for Receiver<T> {
198    type Item = T;
199
200    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
201        let mut shared = self.shared.lock();
202
203        if let Some(item) = shared.buffer.pop_front() {
204            return Poll::Ready(Some(item));
205        }
206
207        if shared.sender_count == 0 {
208            return Poll::Ready(None);
209        }
210
211        if !shared
212            .receiver_waker
213            .as_ref()
214            .is_some_and(|w| w.will_wake(cx.waker()))
215        {
216            shared.receiver_waker = Some(cx.waker().clone());
217        }
218        Poll::Pending
219    }
220}
221
222impl<T: Send + Sync> FusedStream for Receiver<T> {
223    fn is_terminated(&self) -> bool {
224        let shared = self.shared.lock();
225        shared.sender_count == 0 && shared.buffer.is_empty()
226    }
227}
228
229impl<T: Send + Sync> Drop for Receiver<T> {
230    fn drop(&mut self) {
231        let mut shared = self.shared.lock();
232        shared.receiver_dropped = true;
233    }
234}
235
236/// Creates a new ring channel with the specified capacity.
237///
238/// Returns a ([`Sender`], [`Receiver`]) pair. The sender can be cloned to create
239/// multiple producers.
240pub fn channel<T: Send + Sync>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
241    let shared = Arc::new(Mutex::new(Shared {
242        buffer: VecDeque::with_capacity(capacity.get()),
243        capacity: capacity.get(),
244        receiver_waker: None,
245        sender_count: 1,
246        receiver_dropped: false,
247    }));
248
249    let sender = Sender {
250        shared: shared.clone(),
251    };
252    let receiver = Receiver { shared };
253
254    (sender, receiver)
255}
256
/// Unit tests for the ring channel: ordering, overflow eviction, closure on
/// either side, sender cloning, fused-stream behavior, and the inherent
/// `recv` / `try_recv` receiver methods.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZUsize;
    use futures::{executor::block_on, SinkExt, StreamExt};

    #[test]
    fn test_basic_send_recv() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_overflow_drops_oldest() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap(); // Should drop 1
            sender.send(4).await.unwrap(); // Should drop 2

            assert_eq!(receiver.next().await, Some(3));
            assert_eq!(receiver.next().await, Some(4));
        });
    }

    #[test]
    fn test_send_after_receiver_dropped() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));
            drop(receiver);

            let err = sender.send(1).await.unwrap_err();
            assert!(matches!(err, ChannelClosed));
        });
    }

    #[test]
    fn test_recv_after_sender_dropped() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            drop(sender);

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_stream_collect() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            drop(sender);

            let items: Vec<_> = receiver.collect().await;
            assert_eq!(items, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_clone_sender() {
        block_on(async {
            let (mut sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            sender1.send(1).await.unwrap();
            sender2.send(2).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
        });
    }

    #[test]
    fn test_sender_drop_with_clones() {
        block_on(async {
            let (sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            drop(sender1);

            // Channel should still be open because sender2 exists
            sender2.send(1).await.unwrap();
            assert_eq!(receiver.next().await, Some(1));

            drop(sender2);
            // Now channel should be closed
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_capacity_one() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(1));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap(); // Drops 1

            assert_eq!(receiver.next().await, Some(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap(); // Drops 1
            sender.send(3).await.unwrap(); // Drops 2

            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_send_all() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            let items = futures::stream::iter(vec![1, 2, 3]);
            sender.send_all(&mut items.map(Ok)).await.unwrap();
            drop(sender);

            let received: Vec<_> = receiver.collect().await;
            assert_eq!(received, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_fused_stream() {
        use futures::stream::FusedStream;

        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            assert!(!receiver.is_terminated());

            sender.send(1).await.unwrap();
            assert!(!receiver.is_terminated());

            drop(sender);
            assert!(!receiver.is_terminated()); // Still has item in buffer

            assert_eq!(receiver.next().await, Some(1));
            assert!(receiver.is_terminated()); // Now terminated

            // Calling next after termination returns None
            assert_eq!(receiver.next().await, None);
            assert!(receiver.is_terminated());
        });
    }

    #[test]
    fn test_is_closed() {
        // No awaiting is involved, so no executor is needed here.
        let (sender, receiver) = channel::<i32>(NZUsize!(10));

        assert!(!sender.is_closed());

        drop(receiver);
        assert!(sender.is_closed());
    }

    #[test]
    fn test_recv() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            assert_eq!(receiver.recv().await, Some(1));
            assert_eq!(receiver.recv().await, Some(2));

            // With no senders left and nothing buffered, recv resolves to None.
            drop(sender);
            assert_eq!(receiver.recv().await, None);
        });
    }

    #[test]
    fn test_try_recv() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            // Nothing buffered, but a sender is alive.
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));

            sender.send(1).await.unwrap();
            assert_eq!(receiver.try_recv(), Ok(1));
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));

            // A surviving clone keeps the channel connected.
            let clone = sender.clone();
            drop(sender);
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));

            drop(clone);
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Disconnected));
        });
    }

    #[test]
    fn test_try_recv_drains_after_senders_dropped() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap(); // Drops 1
            drop(sender);

            // Buffered items stay available until drained.
            assert_eq!(receiver.try_recv(), Ok(2));
            assert_eq!(receiver.try_recv(), Ok(3));
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Disconnected));
        });
    }
}
433}