Skip to main content

commonware_utils/channel/
ring.rs

//! A bounded mpsc channel that drops the oldest item when full instead of applying backpressure.
//!
//! This is useful for scenarios where you want to keep the most recent items and can
//! tolerate losing older ones, such as real-time data streams or status updates where
//! only the latest values matter.
//!
//! # Example
//!
//! ```
//! use futures::executor::block_on;
//! use futures::{SinkExt, StreamExt};
//! use commonware_utils::{NZUsize, channel::ring};
//!
//! block_on(async {
//!     let (mut sender, mut receiver) = ring::channel::<u32>(NZUsize!(2));
//!
//!     // Fill the channel
//!     sender.send(1).await.unwrap();
//!     sender.send(2).await.unwrap();
//!
//!     // This will drop the oldest item (1) and insert 3
//!     sender.send(3).await.unwrap();
//!
//!     // Receive the remaining items
//!     assert_eq!(receiver.next().await, Some(2));
//!     assert_eq!(receiver.next().await, Some(3));
//! });
//! ```
29
30use crate::sync::Mutex;
31use core::num::NonZeroUsize;
32use futures::{stream::FusedStream, Sink, Stream};
33use std::{
34    collections::VecDeque,
35    pin::Pin,
36    sync::Arc,
37    task::{Context, Poll, Waker},
38};
39use thiserror::Error;
40
41/// Error returned when sending to a channel whose receiver has been dropped.
42#[derive(Debug, Error)]
43#[error("channel closed")]
44pub struct ChannelClosed;
45
/// State shared by every [`Sender`] clone and the [`Receiver`] of one channel.
///
/// No `Send + Sync` bound is placed on this private data definition: the public
/// handle types carry the bound, and keeping it off the struct avoids having to
/// repeat it on every use.
#[derive(Debug)]
struct Shared<T> {
    /// Buffered items, oldest at the front.
    buffer: VecDeque<T>,
    /// Maximum number of buffered items; a send beyond this evicts the front item.
    capacity: usize,
    /// Waker stored by the receiver when it last returned `Pending`, if any.
    receiver_waker: Option<Waker>,
    /// Number of live `Sender` handles.
    sender_count: usize,
    /// Set once the `Receiver` has been dropped; never cleared afterwards.
    receiver_dropped: bool,
}
54
55/// The sending half of a ring channel.
56///
57/// Implements [`Sink`] for sending items. Use [`SinkExt::send`](futures::SinkExt::send)
58/// to send items asynchronously.
59///
60/// This type can be cloned to create multiple producers for the same channel.
61/// The channel remains open until all senders are dropped.
62pub struct Sender<T: Send + Sync> {
63    shared: Arc<Mutex<Shared<T>>>,
64}
65
66impl<T: Send + Sync> Sender<T> {
67    /// Returns whether the receiver has been dropped.
68    ///
69    /// If this returns `true`, subsequent sends will fail with [`ChannelClosed`].
70    pub fn is_closed(&self) -> bool {
71        let shared = self.shared.lock();
72        shared.receiver_dropped
73    }
74}
75
76impl<T: Send + Sync> Clone for Sender<T> {
77    fn clone(&self) -> Self {
78        let mut shared = self.shared.lock();
79        shared.sender_count += 1;
80        drop(shared);
81
82        Self {
83            shared: self.shared.clone(),
84        }
85    }
86}
87
88impl<T: Send + Sync> Drop for Sender<T> {
89    fn drop(&mut self) {
90        let mut shared = self.shared.lock();
91        shared.sender_count -= 1;
92        let waker = if shared.sender_count == 0 {
93            shared.receiver_waker.take()
94        } else {
95            None
96        };
97        drop(shared);
98
99        if let Some(w) = waker {
100            w.wake();
101        }
102    }
103}
104
105impl<T: Send + Sync> Sink<T> for Sender<T> {
106    type Error = ChannelClosed;
107
108    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109        let shared = self.shared.lock();
110        if shared.receiver_dropped {
111            return Poll::Ready(Err(ChannelClosed));
112        }
113
114        Poll::Ready(Ok(()))
115    }
116
117    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
118        let mut shared = self.shared.lock();
119
120        if shared.receiver_dropped {
121            return Err(ChannelClosed);
122        }
123
124        let old_item = if shared.buffer.len() >= shared.capacity {
125            shared.buffer.pop_front()
126        } else {
127            None
128        };
129
130        shared.buffer.push_back(item);
131        let waker = shared.receiver_waker.take();
132        drop(shared);
133
134        // Drop the old item after the lock is released to avoid potential mutex poisoning
135        drop(old_item);
136
137        if let Some(w) = waker {
138            w.wake();
139        }
140
141        Ok(())
142    }
143
144    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145        // No buffering in the sender - items are sent immediately to the shared buffer
146        Poll::Ready(Ok(()))
147    }
148
149    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150        // Closing is handled by Drop
151        Poll::Ready(Ok(()))
152    }
153}
154
155/// The receiving half of a ring channel.
156///
157/// Implements [`Stream`] and [`FusedStream`] for receiving items. Use
158/// [`StreamExt::next`](futures::StreamExt::next) to receive items asynchronously.
159///
160/// The stream terminates (returns `None`) when all senders have been dropped
161/// and all buffered items have been consumed.
162#[derive(Debug)]
163pub struct Receiver<T: Send + Sync> {
164    shared: Arc<Mutex<Shared<T>>>,
165}
166
167impl<T: Send + Sync> Stream for Receiver<T> {
168    type Item = T;
169
170    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
171        let mut shared = self.shared.lock();
172
173        if let Some(item) = shared.buffer.pop_front() {
174            return Poll::Ready(Some(item));
175        }
176
177        if shared.sender_count == 0 {
178            return Poll::Ready(None);
179        }
180
181        if !shared
182            .receiver_waker
183            .as_ref()
184            .is_some_and(|w| w.will_wake(cx.waker()))
185        {
186            shared.receiver_waker = Some(cx.waker().clone());
187        }
188        Poll::Pending
189    }
190}
191
192impl<T: Send + Sync> FusedStream for Receiver<T> {
193    fn is_terminated(&self) -> bool {
194        let shared = self.shared.lock();
195        shared.sender_count == 0 && shared.buffer.is_empty()
196    }
197}
198
199impl<T: Send + Sync> Drop for Receiver<T> {
200    fn drop(&mut self) {
201        let mut shared = self.shared.lock();
202        shared.receiver_dropped = true;
203    }
204}
205
206/// Creates a new ring channel with the specified capacity.
207///
208/// Returns a ([`Sender`], [`Receiver`]) pair. The sender can be cloned to create
209/// multiple producers.
210pub fn channel<T: Send + Sync>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
211    let shared = Arc::new(Mutex::new(Shared {
212        buffer: VecDeque::with_capacity(capacity.get()),
213        capacity: capacity.get(),
214        receiver_waker: None,
215        sender_count: 1,
216        receiver_dropped: false,
217    }));
218
219    let sender = Sender {
220        shared: shared.clone(),
221    };
222    let receiver = Receiver { shared };
223
224    (sender, receiver)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZUsize;
    use futures::{executor::block_on, SinkExt, StreamExt};

    #[test]
    fn test_basic_send_recv() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_overflow_drops_oldest() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap(); // Should drop 1
            sender.send(4).await.unwrap(); // Should drop 2

            assert_eq!(receiver.next().await, Some(3));
            assert_eq!(receiver.next().await, Some(4));
        });
    }

    #[test]
    fn test_send_after_receiver_dropped() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));
            drop(receiver);

            let err = sender.send(1).await.unwrap_err();
            assert!(matches!(err, ChannelClosed));
        });
    }

    #[test]
    fn test_recv_after_sender_dropped() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            drop(sender);

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_stream_collect() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            drop(sender);

            let items: Vec<_> = receiver.collect().await;
            assert_eq!(items, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_clone_sender() {
        block_on(async {
            let (mut sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            sender1.send(1).await.unwrap();
            sender2.send(2).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
        });
    }

    #[test]
    fn test_sender_drop_with_clones() {
        block_on(async {
            let (sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            drop(sender1);

            // Channel should still be open because sender2 exists
            sender2.send(1).await.unwrap();
            assert_eq!(receiver.next().await, Some(1));

            drop(sender2);
            // Now channel should be closed
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_capacity_one() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(1));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap(); // Drops 1

            assert_eq!(receiver.next().await, Some(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap(); // Drops 1
            sender.send(3).await.unwrap(); // Drops 2

            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_send_all() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            let items = futures::stream::iter(vec![1, 2, 3]);
            sender.send_all(&mut items.map(Ok)).await.unwrap();
            drop(sender);

            let received: Vec<_> = receiver.collect().await;
            assert_eq!(received, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_fused_stream() {
        use futures::stream::FusedStream;

        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            assert!(!receiver.is_terminated());

            sender.send(1).await.unwrap();
            assert!(!receiver.is_terminated());

            drop(sender);
            assert!(!receiver.is_terminated()); // Still has item in buffer

            assert_eq!(receiver.next().await, Some(1));
            assert!(receiver.is_terminated()); // Now terminated

            // Calling next after termination returns None
            assert_eq!(receiver.next().await, None);
            assert!(receiver.is_terminated());
        });
    }

    #[test]
    fn test_is_closed() {
        block_on(async {
            let (sender, receiver) = channel::<i32>(NZUsize!(10));

            assert!(!sender.is_closed());

            drop(receiver);
            assert!(sender.is_closed());
        });
    }

    #[test]
    fn test_send_wakes_pending_receiver() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(1));

            // `join` polls the receive first. It finds an empty buffer, stores its
            // waker and returns `Pending`, so the item can only arrive through the
            // wake issued by `start_send`.
            let (received, sent) =
                futures::future::join(receiver.next(), sender.send(7)).await;
            sent.unwrap();
            assert_eq!(received, Some(7));
        });
    }

    #[test]
    fn test_last_sender_drop_wakes_pending_receiver() {
        block_on(async {
            let (sender, mut receiver) = channel::<i32>(NZUsize!(1));

            // The receive parks first; dropping the only sender must wake it so it
            // can observe end-of-stream.
            let close = async move {
                drop(sender);
            };
            let (received, ()) = futures::future::join(receiver.next(), close).await;
            assert_eq!(received, None);
        });
    }
}