Skip to main content

commonware_utils/channel/
ring.rs

1//! A bounded mpsc channel that drops the oldest item when full instead of applying backpressure.
2//!
3//! This is useful for scenarios where you want to keep the most recent items and can
4//! tolerate losing older ones, such as real-time data streams or status updates where
5//! only the latest values matter.
6//!
7//! # Example
8//!
9//! ```
10//! use futures::executor::block_on;
11//! use futures::{SinkExt, StreamExt};
12//! use commonware_utils::{NZUsize, channel::ring};
13//!
14//! block_on(async {
15//!     let (mut sender, mut receiver) = ring::channel::<u32>(NZUsize!(2));
16//!
17//!     // Fill the channel
18//!     sender.send(1).await.unwrap();
19//!     sender.send(2).await.unwrap();
20//!
21//!     // This will drop the oldest item (1) and insert 3
22//!     sender.send(3).await.unwrap();
23//!
24//!     // Receive the remaining items
25//!     assert_eq!(receiver.next().await, Some(2));
26//!     assert_eq!(receiver.next().await, Some(3));
27//! });
28//! ```
29
30use core::num::NonZeroUsize;
31use futures::{stream::FusedStream, Sink, Stream};
32use std::{
33    collections::VecDeque,
34    pin::Pin,
35    sync::{Arc, Mutex},
36    task::{Context, Poll, Waker},
37};
38use thiserror::Error;
39
40/// Error returned when sending to a channel whose receiver has been dropped.
41#[derive(Debug, Error)]
42#[error("channel closed")]
43pub struct ChannelClosed;
44
/// Channel state shared by the sending and receiving halves.
///
/// Every access goes through the mutex that wraps this struct.
#[derive(Debug)]
struct Shared<T>
where
    T: Send + Sync,
{
    /// Queued items in arrival order; the front is the oldest.
    buffer: VecDeque<T>,
    /// Number of items the buffer may hold before the oldest is evicted.
    capacity: usize,
    /// Waker stored by a receiver poll that found no item to return.
    receiver_waker: Option<Waker>,
    /// Number of sender handles that are currently alive.
    sender_count: usize,
    /// Set to `true` once the receiver has been dropped.
    receiver_dropped: bool,
}
53
54/// The sending half of a ring channel.
55///
56/// Implements [`Sink`] for sending items. Use [`SinkExt::send`](futures::SinkExt::send)
57/// to send items asynchronously.
58///
59/// This type can be cloned to create multiple producers for the same channel.
60/// The channel remains open until all senders are dropped.
61pub struct Sender<T: Send + Sync> {
62    shared: Arc<Mutex<Shared<T>>>,
63}
64
65impl<T: Send + Sync> Sender<T> {
66    /// Returns whether the receiver has been dropped.
67    ///
68    /// If this returns `true`, subsequent sends will fail with [`ChannelClosed`].
69    pub fn is_closed(&self) -> bool {
70        let shared = self.shared.lock().unwrap();
71        shared.receiver_dropped
72    }
73}
74
75impl<T: Send + Sync> Clone for Sender<T> {
76    fn clone(&self) -> Self {
77        let mut shared = self.shared.lock().unwrap();
78        shared.sender_count += 1;
79        drop(shared);
80
81        Self {
82            shared: self.shared.clone(),
83        }
84    }
85}
86
87impl<T: Send + Sync> Drop for Sender<T> {
88    fn drop(&mut self) {
89        let Ok(mut shared) = self.shared.lock() else {
90            return;
91        };
92        shared.sender_count -= 1;
93        let waker = if shared.sender_count == 0 {
94            shared.receiver_waker.take()
95        } else {
96            None
97        };
98        drop(shared);
99
100        if let Some(w) = waker {
101            w.wake();
102        }
103    }
104}
105
106impl<T: Send + Sync> Sink<T> for Sender<T> {
107    type Error = ChannelClosed;
108
109    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110        let shared = self.shared.lock().unwrap();
111        if shared.receiver_dropped {
112            return Poll::Ready(Err(ChannelClosed));
113        }
114
115        Poll::Ready(Ok(()))
116    }
117
118    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
119        let mut shared = self.shared.lock().unwrap();
120
121        if shared.receiver_dropped {
122            return Err(ChannelClosed);
123        }
124
125        let old_item = if shared.buffer.len() >= shared.capacity {
126            shared.buffer.pop_front()
127        } else {
128            None
129        };
130
131        shared.buffer.push_back(item);
132        let waker = shared.receiver_waker.take();
133        drop(shared);
134
135        // Drop the old item after the lock is released to avoid potential mutex poisoning
136        drop(old_item);
137
138        if let Some(w) = waker {
139            w.wake();
140        }
141
142        Ok(())
143    }
144
145    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146        // No buffering in the sender - items are sent immediately to the shared buffer
147        Poll::Ready(Ok(()))
148    }
149
150    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
151        // Closing is handled by Drop
152        Poll::Ready(Ok(()))
153    }
154}
155
156/// The receiving half of a ring channel.
157///
158/// Implements [`Stream`] and [`FusedStream`] for receiving items. Use
159/// [`StreamExt::next`](futures::StreamExt::next) to receive items asynchronously.
160///
161/// The stream terminates (returns `None`) when all senders have been dropped
162/// and all buffered items have been consumed.
163#[derive(Debug)]
164pub struct Receiver<T: Send + Sync> {
165    shared: Arc<Mutex<Shared<T>>>,
166}
167
168impl<T: Send + Sync> Stream for Receiver<T> {
169    type Item = T;
170
171    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
172        let mut shared = self.shared.lock().unwrap();
173
174        if let Some(item) = shared.buffer.pop_front() {
175            return Poll::Ready(Some(item));
176        }
177
178        if shared.sender_count == 0 {
179            return Poll::Ready(None);
180        }
181
182        if !shared
183            .receiver_waker
184            .as_ref()
185            .is_some_and(|w| w.will_wake(cx.waker()))
186        {
187            shared.receiver_waker = Some(cx.waker().clone());
188        }
189        Poll::Pending
190    }
191}
192
193impl<T: Send + Sync> FusedStream for Receiver<T> {
194    fn is_terminated(&self) -> bool {
195        let shared = self.shared.lock().unwrap();
196        shared.sender_count == 0 && shared.buffer.is_empty()
197    }
198}
199
200impl<T: Send + Sync> Drop for Receiver<T> {
201    fn drop(&mut self) {
202        let Ok(mut shared) = self.shared.lock() else {
203            return;
204        };
205        shared.receiver_dropped = true;
206    }
207}
208
209/// Creates a new ring channel with the specified capacity.
210///
211/// Returns a ([`Sender`], [`Receiver`]) pair. The sender can be cloned to create
212/// multiple producers.
213pub fn channel<T: Send + Sync>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
214    let shared = Arc::new(Mutex::new(Shared {
215        buffer: VecDeque::with_capacity(capacity.get()),
216        capacity: capacity.get(),
217        receiver_waker: None,
218        sender_count: 1,
219        receiver_dropped: false,
220    }));
221
222    let sender = Sender {
223        shared: shared.clone(),
224    };
225    let receiver = Receiver { shared };
226
227    (sender, receiver)
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZUsize;
    use futures::{executor::block_on, SinkExt, StreamExt};

    /// Items sent below capacity are received in the order they were sent.
    #[test]
    fn test_basic_send_recv() {
        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(10));

            for value in 1..=3 {
                tx.send(value).await.unwrap();
            }

            for expected in 1..=3 {
                assert_eq!(rx.next().await, Some(expected));
            }
        });
    }

    /// Sending past capacity evicts the oldest items first.
    #[test]
    fn test_overflow_drops_oldest() {
        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(2));

            // With room for two items, sending 3 evicts 1 and sending 4 evicts 2.
            for value in 1..=4 {
                tx.send(value).await.unwrap();
            }

            for expected in 3..=4 {
                assert_eq!(rx.next().await, Some(expected));
            }
        });
    }

    /// A send fails once the receiver is gone.
    #[test]
    fn test_send_after_receiver_dropped() {
        block_on(async {
            let (mut tx, rx) = channel::<i32>(NZUsize!(10));
            drop(rx);

            let outcome = tx.send(1).await;
            assert!(matches!(outcome, Err(ChannelClosed)));
        });
    }

    /// Buffered items are still delivered after the sender is dropped, then the stream ends.
    #[test]
    fn test_recv_after_sender_dropped() {
        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(10));

            for value in 1..=2 {
                tx.send(value).await.unwrap();
            }
            drop(tx);

            for expected in [Some(1), Some(2), None] {
                assert_eq!(rx.next().await, expected);
            }
        });
    }

    /// Collecting the receiver yields every buffered item once the sender is dropped.
    #[test]
    fn test_stream_collect() {
        block_on(async {
            let (mut tx, rx) = channel::<i32>(NZUsize!(10));

            for value in 1..=3 {
                tx.send(value).await.unwrap();
            }
            drop(tx);

            let collected: Vec<_> = rx.collect().await;
            assert_eq!(collected, vec![1, 2, 3]);
        });
    }

    /// A cloned sender feeds the same receiver as the original.
    #[test]
    fn test_clone_sender() {
        block_on(async {
            let (mut first, mut rx) = channel::<i32>(NZUsize!(10));
            let mut second = first.clone();

            first.send(1).await.unwrap();
            second.send(2).await.unwrap();

            for expected in 1..=2 {
                assert_eq!(rx.next().await, Some(expected));
            }
        });
    }

    /// The channel stays open until the last sender clone is dropped.
    #[test]
    fn test_sender_drop_with_clones() {
        block_on(async {
            let (first, mut rx) = channel::<i32>(NZUsize!(10));
            let mut second = first.clone();

            drop(first);

            // One sender is still alive, so the channel must remain usable.
            second.send(1).await.unwrap();
            assert_eq!(rx.next().await, Some(1));

            // Dropping the last sender ends the stream.
            drop(second);
            assert_eq!(rx.next().await, None);
        });
    }

    /// With capacity one, only the most recent item survives.
    #[test]
    fn test_capacity_one() {
        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(1));

            // Each send after the first evicts the previous item.
            for value in 1..=2 {
                tx.send(value).await.unwrap();
            }
            assert_eq!(rx.next().await, Some(2));

            for value in 1..=3 {
                tx.send(value).await.unwrap();
            }
            assert_eq!(rx.next().await, Some(3));
        });
    }

    /// `send_all` forwards a whole stream into the channel.
    #[test]
    fn test_send_all() {
        block_on(async {
            let (mut tx, rx) = channel::<i32>(NZUsize!(10));

            let source = futures::stream::iter(vec![1, 2, 3]);
            tx.send_all(&mut source.map(Ok)).await.unwrap();
            drop(tx);

            let collected: Vec<_> = rx.collect().await;
            assert_eq!(collected, vec![1, 2, 3]);
        });
    }

    /// `is_terminated` is true only when there are no senders and no buffered items.
    #[test]
    fn test_fused_stream() {
        use futures::stream::FusedStream;

        block_on(async {
            let (mut tx, mut rx) = channel::<i32>(NZUsize!(10));

            // Open and empty.
            assert!(!rx.is_terminated());

            // Open with one buffered item.
            tx.send(1).await.unwrap();
            assert!(!rx.is_terminated());

            // No senders left, but the buffered item keeps the stream alive.
            drop(tx);
            assert!(!rx.is_terminated());

            // Draining the last item terminates the stream.
            assert_eq!(rx.next().await, Some(1));
            assert!(rx.is_terminated());

            // Polling after termination keeps returning `None`.
            assert_eq!(rx.next().await, None);
            assert!(rx.is_terminated());
        });
    }

    /// `is_closed` becomes true once the receiver is dropped.
    #[test]
    fn test_is_closed() {
        let (tx, rx) = channel::<i32>(NZUsize!(10));

        assert!(!tx.is_closed());

        drop(rx);
        assert!(tx.is_closed());
    }
}