commonware_utils/channel/
ring.rs1use core::num::NonZeroUsize;
31use futures::{stream::FusedStream, Sink, Stream};
32use std::{
33 collections::VecDeque,
34 pin::Pin,
35 sync::{Arc, Mutex},
36 task::{Context, Poll, Waker},
37};
38use thiserror::Error;
39
40#[derive(Debug, Error)]
42#[error("channel closed")]
43pub struct ChannelClosed;
44
/// State shared by every `Sender` handle and the single `Receiver`, always
/// accessed through one `Arc<Mutex<_>>`.
#[derive(Debug)]
struct Shared<T>
where
    T: Send + Sync,
{
    /// Items waiting to be received, oldest at the front.
    buffer: VecDeque<T>,
    /// Maximum number of buffered items; once reached, `start_send` evicts
    /// from the front before pushing.
    capacity: usize,
    /// Waker registered by `poll_next` when it returned `Pending`, if any.
    receiver_waker: Option<Waker>,
    /// Number of live `Sender` handles; the stream ends once this reaches zero
    /// and the buffer is empty.
    sender_count: usize,
    /// Set when the `Receiver` is dropped; sends then fail with `ChannelClosed`.
    receiver_dropped: bool,
}
53
54pub struct Sender<T: Send + Sync> {
62 shared: Arc<Mutex<Shared<T>>>,
63}
64
65impl<T: Send + Sync> Sender<T> {
66 pub fn is_closed(&self) -> bool {
70 let shared = self.shared.lock().unwrap();
71 shared.receiver_dropped
72 }
73}
74
75impl<T: Send + Sync> Clone for Sender<T> {
76 fn clone(&self) -> Self {
77 let mut shared = self.shared.lock().unwrap();
78 shared.sender_count += 1;
79 drop(shared);
80
81 Self {
82 shared: self.shared.clone(),
83 }
84 }
85}
86
87impl<T: Send + Sync> Drop for Sender<T> {
88 fn drop(&mut self) {
89 let Ok(mut shared) = self.shared.lock() else {
90 return;
91 };
92 shared.sender_count -= 1;
93 let waker = if shared.sender_count == 0 {
94 shared.receiver_waker.take()
95 } else {
96 None
97 };
98 drop(shared);
99
100 if let Some(w) = waker {
101 w.wake();
102 }
103 }
104}
105
106impl<T: Send + Sync> Sink<T> for Sender<T> {
107 type Error = ChannelClosed;
108
109 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
110 let shared = self.shared.lock().unwrap();
111 if shared.receiver_dropped {
112 return Poll::Ready(Err(ChannelClosed));
113 }
114
115 Poll::Ready(Ok(()))
116 }
117
118 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
119 let mut shared = self.shared.lock().unwrap();
120
121 if shared.receiver_dropped {
122 return Err(ChannelClosed);
123 }
124
125 let old_item = if shared.buffer.len() >= shared.capacity {
126 shared.buffer.pop_front()
127 } else {
128 None
129 };
130
131 shared.buffer.push_back(item);
132 let waker = shared.receiver_waker.take();
133 drop(shared);
134
135 drop(old_item);
137
138 if let Some(w) = waker {
139 w.wake();
140 }
141
142 Ok(())
143 }
144
145 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
146 Poll::Ready(Ok(()))
148 }
149
150 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
151 Poll::Ready(Ok(()))
153 }
154}
155
156#[derive(Debug)]
164pub struct Receiver<T: Send + Sync> {
165 shared: Arc<Mutex<Shared<T>>>,
166}
167
168impl<T: Send + Sync> Stream for Receiver<T> {
169 type Item = T;
170
171 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
172 let mut shared = self.shared.lock().unwrap();
173
174 if let Some(item) = shared.buffer.pop_front() {
175 return Poll::Ready(Some(item));
176 }
177
178 if shared.sender_count == 0 {
179 return Poll::Ready(None);
180 }
181
182 if !shared
183 .receiver_waker
184 .as_ref()
185 .is_some_and(|w| w.will_wake(cx.waker()))
186 {
187 shared.receiver_waker = Some(cx.waker().clone());
188 }
189 Poll::Pending
190 }
191}
192
193impl<T: Send + Sync> FusedStream for Receiver<T> {
194 fn is_terminated(&self) -> bool {
195 let shared = self.shared.lock().unwrap();
196 shared.sender_count == 0 && shared.buffer.is_empty()
197 }
198}
199
200impl<T: Send + Sync> Drop for Receiver<T> {
201 fn drop(&mut self) {
202 let Ok(mut shared) = self.shared.lock() else {
203 return;
204 };
205 shared.receiver_dropped = true;
206 }
207}
208
209pub fn channel<T: Send + Sync>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
214 let shared = Arc::new(Mutex::new(Shared {
215 buffer: VecDeque::with_capacity(capacity.get()),
216 capacity: capacity.get(),
217 receiver_waker: None,
218 sender_count: 1,
219 receiver_dropped: false,
220 }));
221
222 let sender = Sender {
223 shared: shared.clone(),
224 };
225 let receiver = Receiver { shared };
226
227 (sender, receiver)
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZUsize;
    use futures::{executor::block_on, SinkExt, StreamExt};

    #[test]
    fn test_basic_send_recv() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_overflow_drops_oldest() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            // Each of these evicts the oldest buffered item (1, then 2).
            sender.send(3).await.unwrap();
            sender.send(4).await.unwrap();

            assert_eq!(receiver.next().await, Some(3));
            assert_eq!(receiver.next().await, Some(4));
        });
    }

    #[test]
    fn test_send_after_receiver_dropped() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));
            drop(receiver);

            let err = sender.send(1).await.unwrap_err();
            assert!(matches!(err, ChannelClosed));
        });
    }

    #[test]
    fn test_recv_after_sender_dropped() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            drop(sender);

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_stream_collect() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            drop(sender);

            let items: Vec<_> = receiver.collect().await;
            assert_eq!(items, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_clone_sender() {
        block_on(async {
            let (mut sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            sender1.send(1).await.unwrap();
            sender2.send(2).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
        });
    }

    #[test]
    fn test_sender_drop_with_clones() {
        block_on(async {
            let (sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            drop(sender1);

            sender2.send(1).await.unwrap();
            assert_eq!(receiver.next().await, Some(1));

            drop(sender2);
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_capacity_one() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(1));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            assert_eq!(receiver.next().await, Some(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_send_all() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            let items = futures::stream::iter(vec![1, 2, 3]);
            sender.send_all(&mut items.map(Ok)).await.unwrap();
            drop(sender);

            let received: Vec<_> = receiver.collect().await;
            assert_eq!(received, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_fused_stream() {
        use futures::stream::FusedStream;

        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            assert!(!receiver.is_terminated());

            sender.send(1).await.unwrap();
            assert!(!receiver.is_terminated());

            // No senders left, but an item is still buffered.
            drop(sender);
            assert!(!receiver.is_terminated());
            assert_eq!(receiver.next().await, Some(1));

            assert!(receiver.is_terminated());
            assert_eq!(receiver.next().await, None);
            assert!(receiver.is_terminated());
        });
    }

    #[test]
    fn test_is_closed() {
        block_on(async {
            let (sender, receiver) = channel::<i32>(NZUsize!(10));

            assert!(!sender.is_closed());

            drop(receiver);
            assert!(sender.is_closed());
        });
    }

    #[test]
    fn test_pending_receiver_woken_by_send() {
        let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

        // The receiver may park before the value arrives; the send must wake it.
        let handle = std::thread::spawn(move || block_on(receiver.next()));
        block_on(sender.send(7)).unwrap();

        assert_eq!(handle.join().unwrap(), Some(7));
    }

    #[test]
    fn test_pending_receiver_woken_by_last_sender_drop() {
        let (sender, mut receiver) = channel::<i32>(NZUsize!(10));

        // The receiver may park first; dropping the last sender must wake it
        // so it can observe the end of the stream.
        let handle = std::thread::spawn(move || block_on(receiver.next()));
        drop(sender);

        assert_eq!(handle.join().unwrap(), None);
    }
}