commonware_utils/channels/
ring.rs

1//! A bounded mpsc channel that drops the oldest item when full instead of applying backpressure.
2//!
3//! This is useful for scenarios where you want to keep the most recent items and can
4//! tolerate losing older ones, such as real-time data streams or status updates where
5//! only the latest values matter.
6//!
7//! # Example
8//!
9//! ```
10//! use futures::executor::block_on;
11//! use futures::{SinkExt, StreamExt};
12//! use commonware_utils::{NZUsize, channels::ring};
13//!
14//! block_on(async {
15//!     let (mut sender, mut receiver) = ring::channel::<u32>(NZUsize!(2));
16//!
17//!     // Fill the channel
18//!     sender.send(1).await.unwrap();
19//!     sender.send(2).await.unwrap();
20//!
21//!     // This will drop the oldest item (1) and insert 3
22//!     sender.send(3).await.unwrap();
23//!
24//!     // Receive the remaining items
25//!     assert_eq!(receiver.next().await, Some(2));
26//!     assert_eq!(receiver.next().await, Some(3));
27//! });
28//! ```
29
30use core::num::NonZeroUsize;
31use futures::{stream::FusedStream, Sink, Stream};
32use std::{
33    collections::VecDeque,
34    pin::Pin,
35    sync::{Arc, Mutex},
36    task::{Context, Poll, Waker},
37};
38use thiserror::Error;
39
40/// Error returned when sending to a channel whose receiver has been dropped.
41#[derive(Debug, Error)]
42#[error("channel closed")]
43pub struct ChannelClosed;
44
/// State shared between every [`Sender`] and the [`Receiver`], guarded by a single mutex.
///
/// No trait bounds are needed on the definition itself; the `Send + Sync` requirements are
/// declared on the public handle types and their impls.
struct Shared<T> {
    /// Buffered items, oldest at the front.
    buffer: VecDeque<T>,
    /// Maximum number of buffered items; once reached, each new send evicts the oldest.
    capacity: usize,
    /// Waker registered by the receiver while it is parked in `poll_next`, if any.
    receiver_waker: Option<Waker>,
    /// Number of live senders; the stream ends once this is zero and `buffer` is empty.
    sender_count: usize,
    /// Set when the receiver is dropped; every later send is rejected.
    receiver_dropped: bool,
}
52
53/// The sending half of a ring channel.
54///
55/// Implements [`Sink`] for sending items. Use [`SinkExt::send`](futures::SinkExt::send)
56/// to send items asynchronously.
57///
58/// This type can be cloned to create multiple producers for the same channel.
59/// The channel remains open until all senders are dropped.
60pub struct Sender<T: Send + Sync> {
61    shared: Arc<Mutex<Shared<T>>>,
62}
63
64impl<T: Send + Sync> Sender<T> {
65    /// Returns whether the receiver has been dropped.
66    ///
67    /// If this returns `true`, subsequent sends will fail with [`ChannelClosed`].
68    pub fn is_closed(&self) -> bool {
69        let shared = self.shared.lock().unwrap();
70        shared.receiver_dropped
71    }
72}
73
74impl<T: Send + Sync> Clone for Sender<T> {
75    fn clone(&self) -> Self {
76        let mut shared = self.shared.lock().unwrap();
77        shared.sender_count += 1;
78        drop(shared);
79
80        Self {
81            shared: self.shared.clone(),
82        }
83    }
84}
85
86impl<T: Send + Sync> Drop for Sender<T> {
87    fn drop(&mut self) {
88        let Ok(mut shared) = self.shared.lock() else {
89            return;
90        };
91        shared.sender_count -= 1;
92        let waker = if shared.sender_count == 0 {
93            shared.receiver_waker.take()
94        } else {
95            None
96        };
97        drop(shared);
98
99        if let Some(w) = waker {
100            w.wake();
101        }
102    }
103}
104
105impl<T: Send + Sync> Sink<T> for Sender<T> {
106    type Error = ChannelClosed;
107
108    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109        let shared = self.shared.lock().unwrap();
110        if shared.receiver_dropped {
111            return Poll::Ready(Err(ChannelClosed));
112        }
113
114        Poll::Ready(Ok(()))
115    }
116
117    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
118        let mut shared = self.shared.lock().unwrap();
119
120        if shared.receiver_dropped {
121            return Err(ChannelClosed);
122        }
123
124        let old_item = if shared.buffer.len() >= shared.capacity {
125            shared.buffer.pop_front()
126        } else {
127            None
128        };
129
130        shared.buffer.push_back(item);
131        let waker = shared.receiver_waker.take();
132        drop(shared);
133
134        // Drop the old item after the lock is released to avoid potential mutex poisoning
135        drop(old_item);
136
137        if let Some(w) = waker {
138            w.wake();
139        }
140
141        Ok(())
142    }
143
144    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145        // No buffering in the sender - items are sent immediately to the shared buffer
146        Poll::Ready(Ok(()))
147    }
148
149    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150        // Closing is handled by Drop
151        Poll::Ready(Ok(()))
152    }
153}
154
155/// The receiving half of a ring channel.
156///
157/// Implements [`Stream`] and [`FusedStream`] for receiving items. Use
158/// [`StreamExt::next`](futures::StreamExt::next) to receive items asynchronously.
159///
160/// The stream terminates (returns `None`) when all senders have been dropped
161/// and all buffered items have been consumed.
162pub struct Receiver<T: Send + Sync> {
163    shared: Arc<Mutex<Shared<T>>>,
164}
165
166impl<T: Send + Sync> Stream for Receiver<T> {
167    type Item = T;
168
169    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
170        let mut shared = self.shared.lock().unwrap();
171
172        if let Some(item) = shared.buffer.pop_front() {
173            return Poll::Ready(Some(item));
174        }
175
176        if shared.sender_count == 0 {
177            return Poll::Ready(None);
178        }
179
180        if !shared
181            .receiver_waker
182            .as_ref()
183            .is_some_and(|w| w.will_wake(cx.waker()))
184        {
185            shared.receiver_waker = Some(cx.waker().clone());
186        }
187        Poll::Pending
188    }
189}
190
191impl<T: Send + Sync> FusedStream for Receiver<T> {
192    fn is_terminated(&self) -> bool {
193        let shared = self.shared.lock().unwrap();
194        shared.sender_count == 0 && shared.buffer.is_empty()
195    }
196}
197
198impl<T: Send + Sync> Drop for Receiver<T> {
199    fn drop(&mut self) {
200        let Ok(mut shared) = self.shared.lock() else {
201            return;
202        };
203        shared.receiver_dropped = true;
204    }
205}
206
207/// Creates a new ring channel with the specified capacity.
208///
209/// Returns a ([`Sender`], [`Receiver`]) pair. The sender can be cloned to create
210/// multiple producers.
211pub fn channel<T: Send + Sync>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
212    let shared = Arc::new(Mutex::new(Shared {
213        buffer: VecDeque::with_capacity(capacity.get()),
214        capacity: capacity.get(),
215        receiver_waker: None,
216        sender_count: 1,
217        receiver_dropped: false,
218    }));
219
220    let sender = Sender {
221        shared: shared.clone(),
222    };
223    let receiver = Receiver { shared };
224
225    (sender, receiver)
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZUsize;
    use futures::{executor::block_on, SinkExt, StreamExt};

    #[test]
    fn test_basic_send_recv() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_overflow_drops_oldest() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap(); // Should drop 1
            sender.send(4).await.unwrap(); // Should drop 2

            assert_eq!(receiver.next().await, Some(3));
            assert_eq!(receiver.next().await, Some(4));
        });
    }

    #[test]
    fn test_send_after_receiver_dropped() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));
            drop(receiver);

            let err = sender.send(1).await.unwrap_err();
            assert!(matches!(err, ChannelClosed));
        });
    }

    #[test]
    fn test_recv_after_sender_dropped() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            drop(sender);

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_stream_collect() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            drop(sender);

            let items: Vec<_> = receiver.collect().await;
            assert_eq!(items, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_clone_sender() {
        block_on(async {
            let (mut sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            sender1.send(1).await.unwrap();
            sender2.send(2).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
        });
    }

    #[test]
    fn test_sender_drop_with_clones() {
        block_on(async {
            let (sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            drop(sender1);

            // Channel should still be open because sender2 exists
            sender2.send(1).await.unwrap();
            assert_eq!(receiver.next().await, Some(1));

            drop(sender2);
            // Now channel should be closed
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_capacity_one() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(1));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap(); // Drops 1

            assert_eq!(receiver.next().await, Some(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap(); // Drops 1
            sender.send(3).await.unwrap(); // Drops 2

            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_send_all() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            let items = futures::stream::iter(vec![1, 2, 3]);
            sender.send_all(&mut items.map(Ok)).await.unwrap();
            drop(sender);

            let received: Vec<_> = receiver.collect().await;
            assert_eq!(received, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_fused_stream() {
        use futures::stream::FusedStream;

        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            assert!(!receiver.is_terminated());

            sender.send(1).await.unwrap();
            assert!(!receiver.is_terminated());

            drop(sender);
            assert!(!receiver.is_terminated()); // Still has item in buffer

            assert_eq!(receiver.next().await, Some(1));
            assert!(receiver.is_terminated()); // Now terminated

            // Calling next after termination returns None
            assert_eq!(receiver.next().await, None);
            assert!(receiver.is_terminated());
        });
    }

    #[test]
    fn test_is_closed() {
        block_on(async {
            let (sender, receiver) = channel::<i32>(NZUsize!(10));

            assert!(!sender.is_closed());

            drop(receiver);
            assert!(sender.is_closed());
        });
    }

    #[test]
    fn test_send_wakes_pending_receiver() {
        let (mut sender, mut receiver) = channel::<i32>(NZUsize!(1));

        // The consumer may or may not park before the send lands; either way the item
        // must become visible, and a parked consumer must be woken.
        let consumer = std::thread::spawn(move || block_on(receiver.next()));
        block_on(sender.send(7)).unwrap();

        assert_eq!(consumer.join().unwrap(), Some(7));
    }

    #[test]
    fn test_last_sender_drop_wakes_pending_receiver() {
        let (sender, mut receiver) = channel::<i32>(NZUsize!(1));
        let clone = sender.clone();

        let consumer = std::thread::spawn(move || block_on(receiver.next()));

        // Dropping a non-final sender must not end the stream; dropping the last one must
        // wake a parked consumer so it observes `None`.
        drop(sender);
        drop(clone);

        assert_eq!(consumer.join().unwrap(), None);
    }

    #[test]
    fn test_evicted_item_is_dropped() {
        block_on(async {
            let tracker = std::sync::Arc::new(());
            let (mut sender, mut receiver) = channel::<std::sync::Arc<()>>(NZUsize!(1));

            sender.send(tracker.clone()).await.unwrap();
            sender.send(tracker.clone()).await.unwrap(); // Evicts (and drops) the first clone
            assert_eq!(std::sync::Arc::strong_count(&tracker), 2);

            drop(receiver.next().await);
            assert_eq!(std::sync::Arc::strong_count(&tracker), 1);
        });
    }
}