commonware_utils/channel/
ring.rs1use crate::sync::Mutex;
31use core::num::NonZeroUsize;
32use futures::{stream::FusedStream, Sink, Stream};
33use std::{
34 collections::VecDeque,
35 pin::Pin,
36 sync::Arc,
37 task::{Context, Poll, Waker},
38};
39use thiserror::Error;
40
/// Error produced by the [`Sender`] half when the [`Receiver`] has been dropped.
///
/// Returned from `poll_ready` and `start_send` of the `Sink` implementation.
#[derive(Debug, Error)]
#[error("channel closed")]
pub struct ChannelClosed;
45
/// Error returned by [`Receiver::try_recv`] when no item can be returned.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum TryRecvError {
    /// The buffer holds no item right now, but at least one sender is still alive.
    #[error("channel empty")]
    Empty,
    /// The buffer holds no item and every sender has been dropped.
    #[error("channel closed")]
    Disconnected,
}
56
/// State shared between every [`Sender`] clone and the single [`Receiver`].
///
/// Always accessed through the surrounding `Mutex`.
#[derive(Debug)]
struct Shared<T: Send + Sync> {
    /// Queued items, oldest at the front.
    buffer: VecDeque<T>,
    /// Maximum number of items kept in `buffer`; a send at this limit evicts the front item.
    capacity: usize,
    /// Waker registered by the receiver the last time it polled and found nothing to return.
    receiver_waker: Option<Waker>,
    /// Number of live `Sender` handles (starts at 1, adjusted by `Clone` and `Drop`).
    sender_count: usize,
    /// Set once the `Receiver` has been dropped; senders then fail with `ChannelClosed`.
    receiver_dropped: bool,
}
65
66pub struct Sender<T: Send + Sync> {
74 shared: Arc<Mutex<Shared<T>>>,
75}
76
77impl<T: Send + Sync> Sender<T> {
78 pub fn is_closed(&self) -> bool {
82 let shared = self.shared.lock();
83 shared.receiver_dropped
84 }
85}
86
87impl<T: Send + Sync> Clone for Sender<T> {
88 fn clone(&self) -> Self {
89 let mut shared = self.shared.lock();
90 shared.sender_count += 1;
91 drop(shared);
92
93 Self {
94 shared: self.shared.clone(),
95 }
96 }
97}
98
99impl<T: Send + Sync> Drop for Sender<T> {
100 fn drop(&mut self) {
101 let mut shared = self.shared.lock();
102 shared.sender_count -= 1;
103 let waker = if shared.sender_count == 0 {
104 shared.receiver_waker.take()
105 } else {
106 None
107 };
108 drop(shared);
109
110 if let Some(w) = waker {
111 w.wake();
112 }
113 }
114}
115
116impl<T: Send + Sync> Sink<T> for Sender<T> {
117 type Error = ChannelClosed;
118
119 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
120 let shared = self.shared.lock();
121 if shared.receiver_dropped {
122 return Poll::Ready(Err(ChannelClosed));
123 }
124
125 Poll::Ready(Ok(()))
126 }
127
128 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
129 let mut shared = self.shared.lock();
130
131 if shared.receiver_dropped {
132 return Err(ChannelClosed);
133 }
134
135 let old_item = if shared.buffer.len() >= shared.capacity {
136 shared.buffer.pop_front()
137 } else {
138 None
139 };
140
141 shared.buffer.push_back(item);
142 let waker = shared.receiver_waker.take();
143 drop(shared);
144
145 drop(old_item);
147
148 if let Some(w) = waker {
149 w.wake();
150 }
151
152 Ok(())
153 }
154
155 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
156 Poll::Ready(Ok(()))
158 }
159
160 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
161 Poll::Ready(Ok(()))
163 }
164}
165
/// Receiving half of the ring channel.
///
/// Yields queued items oldest-first through its `Stream` implementation,
/// [`Receiver::recv`] or [`Receiver::try_recv`]. The stream ends once the buffer
/// is empty and every sender has been dropped.
#[derive(Debug)]
pub struct Receiver<T: Send + Sync> {
    /// Channel state shared with all senders.
    shared: Arc<Mutex<Shared<T>>>,
}
177
178impl<T: Send + Sync> Receiver<T> {
179 pub async fn recv(&mut self) -> Option<T> {
181 futures::future::poll_fn(|cx| Pin::new(&mut *self).poll_next(cx)).await
182 }
183
184 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
186 let mut shared = self.shared.lock();
187 if let Some(item) = shared.buffer.pop_front() {
188 return Ok(item);
189 }
190 if shared.sender_count == 0 {
191 return Err(TryRecvError::Disconnected);
192 }
193 Err(TryRecvError::Empty)
194 }
195}
196
197impl<T: Send + Sync> Stream for Receiver<T> {
198 type Item = T;
199
200 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
201 let mut shared = self.shared.lock();
202
203 if let Some(item) = shared.buffer.pop_front() {
204 return Poll::Ready(Some(item));
205 }
206
207 if shared.sender_count == 0 {
208 return Poll::Ready(None);
209 }
210
211 if !shared
212 .receiver_waker
213 .as_ref()
214 .is_some_and(|w| w.will_wake(cx.waker()))
215 {
216 shared.receiver_waker = Some(cx.waker().clone());
217 }
218 Poll::Pending
219 }
220}
221
222impl<T: Send + Sync> FusedStream for Receiver<T> {
223 fn is_terminated(&self) -> bool {
224 let shared = self.shared.lock();
225 shared.sender_count == 0 && shared.buffer.is_empty()
226 }
227}
228
229impl<T: Send + Sync> Drop for Receiver<T> {
230 fn drop(&mut self) {
231 let mut shared = self.shared.lock();
232 shared.receiver_dropped = true;
233 }
234}
235
236pub fn channel<T: Send + Sync>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
241 let shared = Arc::new(Mutex::new(Shared {
242 buffer: VecDeque::with_capacity(capacity.get()),
243 capacity: capacity.get(),
244 receiver_waker: None,
245 sender_count: 1,
246 receiver_dropped: false,
247 }));
248
249 let sender = Sender {
250 shared: shared.clone(),
251 };
252 let receiver = Receiver { shared };
253
254 (sender, receiver)
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZUsize;
    use futures::{executor::block_on, SinkExt, StreamExt};

    // Items sent below capacity are received in the order they were sent.
    #[test]
    fn test_basic_send_recv() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    // Sending into a full buffer evicts the oldest queued item.
    #[test]
    fn test_overflow_drops_oldest() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            // Buffer is full: this send evicts 1.
            sender.send(3).await.unwrap();
            // Still full: this send evicts 2.
            sender.send(4).await.unwrap();

            assert_eq!(receiver.next().await, Some(3));
            assert_eq!(receiver.next().await, Some(4));
        });
    }

    // Once the receiver is gone, sending fails with `ChannelClosed`.
    #[test]
    fn test_send_after_receiver_dropped() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));
            drop(receiver);

            let err = sender.send(1).await.unwrap_err();
            assert!(matches!(err, ChannelClosed));
        });
    }

    // Items queued before the sender was dropped are still delivered, then
    // the stream ends.
    #[test]
    fn test_recv_after_sender_dropped() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            drop(sender);

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, None);
        });
    }

    // `collect` terminates because dropping the only sender ends the stream.
    #[test]
    fn test_stream_collect() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            drop(sender);

            let items: Vec<_> = receiver.collect().await;
            assert_eq!(items, vec![1, 2, 3]);
        });
    }

    // Cloned senders feed the same receiver.
    #[test]
    fn test_clone_sender() {
        block_on(async {
            let (mut sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            sender1.send(1).await.unwrap();
            sender2.send(2).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
        });
    }

    // The channel stays open until the last sender clone is dropped.
    #[test]
    fn test_sender_drop_with_clones() {
        block_on(async {
            let (sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            drop(sender1);

            // One sender remains, so sending still works.
            sender2.send(1).await.unwrap();
            assert_eq!(receiver.next().await, Some(1));

            // With the last sender gone the stream ends.
            drop(sender2);
            assert_eq!(receiver.next().await, None);
        });
    }

    // With capacity 1 only the most recently sent item is retained.
    #[test]
    fn test_capacity_one() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(1));

            sender.send(1).await.unwrap();
            // Evicts 1.
            sender.send(2).await.unwrap();

            assert_eq!(receiver.next().await, Some(2));

            sender.send(1).await.unwrap();
            // Evicts 1.
            sender.send(2).await.unwrap();
            // Evicts 2.
            sender.send(3).await.unwrap();

            assert_eq!(receiver.next().await, Some(3));
        });
    }

    // `SinkExt::send_all` forwards every item of a stream into the channel.
    #[test]
    fn test_send_all() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            let items = futures::stream::iter(vec![1, 2, 3]);
            sender.send_all(&mut items.map(Ok)).await.unwrap();
            drop(sender);

            let received: Vec<_> = receiver.collect().await;
            assert_eq!(received, vec![1, 2, 3]);
        });
    }

    // `is_terminated` becomes true only when the buffer is empty and no
    // sender remains.
    #[test]
    fn test_fused_stream() {
        use futures::stream::FusedStream;

        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            assert!(!receiver.is_terminated());

            sender.send(1).await.unwrap();
            assert!(!receiver.is_terminated());

            drop(sender);
            // No sender remains, but one item is still queued.
            assert!(!receiver.is_terminated());

            assert_eq!(receiver.next().await, Some(1));
            // Queue drained and no sender remains.
            assert!(receiver.is_terminated());

            assert_eq!(receiver.next().await, None);
            assert!(receiver.is_terminated());
        });
    }

    // `Sender::is_closed` reflects whether the receiver has been dropped.
    #[test]
    fn test_is_closed() {
        block_on(async {
            let (sender, receiver) = channel::<i32>(NZUsize!(10));

            assert!(!sender.is_closed());

            drop(receiver);
            assert!(sender.is_closed());
        });
    }
}