commonware_utils/channels/
ring.rs1use core::num::NonZeroUsize;
31use futures::{stream::FusedStream, Sink, Stream};
32use std::{
33 collections::VecDeque,
34 pin::Pin,
35 sync::{Arc, Mutex},
36 task::{Context, Poll, Waker},
37};
38use thiserror::Error;
39
40#[derive(Debug, Error)]
42#[error("channel closed")]
43pub struct ChannelClosed;
44
/// State shared between every [`Sender`] and the [`Receiver`] of one channel.
///
/// Always accessed through an `Arc<Mutex<_>>` held by each handle.
struct Shared<T: Send + Sync> {
    /// Set when the receiver is dropped; from then on sends fail with [`ChannelClosed`].
    receiver_dropped: bool,
    /// Number of live [`Sender`] handles. The stream ends once this is zero and `buffer` is empty.
    sender_count: usize,
    /// Maximum number of buffered messages; a send into a full buffer evicts the oldest one.
    capacity: usize,
    /// Pending messages, oldest at the front.
    buffer: VecDeque<T>,
    /// Waker of the receiver task that last returned `Pending` from `poll_next`, if any.
    receiver_waker: Option<Waker>,
}

53pub struct Sender<T: Send + Sync> {
61 shared: Arc<Mutex<Shared<T>>>,
62}
63
64impl<T: Send + Sync> Sender<T> {
65 pub fn is_closed(&self) -> bool {
69 let shared = self.shared.lock().unwrap();
70 shared.receiver_dropped
71 }
72}
73
74impl<T: Send + Sync> Clone for Sender<T> {
75 fn clone(&self) -> Self {
76 let mut shared = self.shared.lock().unwrap();
77 shared.sender_count += 1;
78 drop(shared);
79
80 Self {
81 shared: self.shared.clone(),
82 }
83 }
84}
85
86impl<T: Send + Sync> Drop for Sender<T> {
87 fn drop(&mut self) {
88 let Ok(mut shared) = self.shared.lock() else {
89 return;
90 };
91 shared.sender_count -= 1;
92 let waker = if shared.sender_count == 0 {
93 shared.receiver_waker.take()
94 } else {
95 None
96 };
97 drop(shared);
98
99 if let Some(w) = waker {
100 w.wake();
101 }
102 }
103}
104
105impl<T: Send + Sync> Sink<T> for Sender<T> {
106 type Error = ChannelClosed;
107
108 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 let shared = self.shared.lock().unwrap();
110 if shared.receiver_dropped {
111 return Poll::Ready(Err(ChannelClosed));
112 }
113
114 Poll::Ready(Ok(()))
115 }
116
117 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
118 let mut shared = self.shared.lock().unwrap();
119
120 if shared.receiver_dropped {
121 return Err(ChannelClosed);
122 }
123
124 let old_item = if shared.buffer.len() >= shared.capacity {
125 shared.buffer.pop_front()
126 } else {
127 None
128 };
129
130 shared.buffer.push_back(item);
131 let waker = shared.receiver_waker.take();
132 drop(shared);
133
134 drop(old_item);
136
137 if let Some(w) = waker {
138 w.wake();
139 }
140
141 Ok(())
142 }
143
144 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145 Poll::Ready(Ok(()))
147 }
148
149 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150 Poll::Ready(Ok(()))
152 }
153}
154
155pub struct Receiver<T: Send + Sync> {
163 shared: Arc<Mutex<Shared<T>>>,
164}
165
166impl<T: Send + Sync> Stream for Receiver<T> {
167 type Item = T;
168
169 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
170 let mut shared = self.shared.lock().unwrap();
171
172 if let Some(item) = shared.buffer.pop_front() {
173 return Poll::Ready(Some(item));
174 }
175
176 if shared.sender_count == 0 {
177 return Poll::Ready(None);
178 }
179
180 if !shared
181 .receiver_waker
182 .as_ref()
183 .is_some_and(|w| w.will_wake(cx.waker()))
184 {
185 shared.receiver_waker = Some(cx.waker().clone());
186 }
187 Poll::Pending
188 }
189}
190
191impl<T: Send + Sync> FusedStream for Receiver<T> {
192 fn is_terminated(&self) -> bool {
193 let shared = self.shared.lock().unwrap();
194 shared.sender_count == 0 && shared.buffer.is_empty()
195 }
196}
197
198impl<T: Send + Sync> Drop for Receiver<T> {
199 fn drop(&mut self) {
200 let Ok(mut shared) = self.shared.lock() else {
201 return;
202 };
203 shared.receiver_dropped = true;
204 }
205}
206
207pub fn channel<T: Send + Sync>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
212 let shared = Arc::new(Mutex::new(Shared {
213 buffer: VecDeque::with_capacity(capacity.get()),
214 capacity: capacity.get(),
215 receiver_waker: None,
216 sender_count: 1,
217 receiver_dropped: false,
218 }));
219
220 let sender = Sender {
221 shared: shared.clone(),
222 };
223 let receiver = Receiver { shared };
224
225 (sender, receiver)
226}
227
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZUsize;
    use futures::{executor::block_on, SinkExt, StreamExt};
    use std::sync::atomic::{AtomicBool, Ordering};

    /// Waker that records whether it has been woken, so tests can observe wake-ups directly.
    #[derive(Default)]
    struct WakeFlag(AtomicBool);

    impl WakeFlag {
        fn is_set(&self) -> bool {
            self.0.load(Ordering::SeqCst)
        }
    }

    impl std::task::Wake for WakeFlag {
        fn wake(self: Arc<Self>) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_basic_send_recv() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_overflow_drops_oldest() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            // The buffer is full: each further send evicts the oldest message.
            sender.send(3).await.unwrap();
            sender.send(4).await.unwrap();

            assert_eq!(receiver.next().await, Some(3));
            assert_eq!(receiver.next().await, Some(4));
        });
    }

    #[test]
    fn test_send_after_receiver_dropped() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));
            drop(receiver);

            let err = sender.send(1).await.unwrap_err();
            assert!(matches!(err, ChannelClosed));
        });
    }

    #[test]
    fn test_recv_after_sender_dropped() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            drop(sender);

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_stream_collect() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            drop(sender);

            let items: Vec<_> = receiver.collect().await;
            assert_eq!(items, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_clone_sender() {
        block_on(async {
            let (mut sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            sender1.send(1).await.unwrap();
            sender2.send(2).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
        });
    }

    #[test]
    fn test_sender_drop_with_clones() {
        block_on(async {
            let (sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            drop(sender1);

            // The clone keeps the channel open.
            sender2.send(1).await.unwrap();
            assert_eq!(receiver.next().await, Some(1));

            drop(sender2);
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_capacity_one() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(1));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            assert_eq!(receiver.next().await, Some(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_send_all() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            let items = futures::stream::iter(vec![1, 2, 3]);
            sender.send_all(&mut items.map(Ok)).await.unwrap();
            drop(sender);

            let received: Vec<_> = receiver.collect().await;
            assert_eq!(received, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_fused_stream() {
        use futures::stream::FusedStream;

        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            assert!(!receiver.is_terminated());

            sender.send(1).await.unwrap();
            assert!(!receiver.is_terminated());

            // Senders are gone but a message is still buffered.
            drop(sender);
            assert!(!receiver.is_terminated());

            assert_eq!(receiver.next().await, Some(1));
            assert!(receiver.is_terminated());

            assert_eq!(receiver.next().await, None);
            assert!(receiver.is_terminated());
        });
    }

    #[test]
    fn test_is_closed() {
        block_on(async {
            let (sender, receiver) = channel::<i32>(NZUsize!(10));

            assert!(!sender.is_closed());

            drop(receiver);
            assert!(sender.is_closed());
        });
    }

    #[test]
    fn test_send_wakes_pending_receiver() {
        let (mut sender, mut receiver) = channel::<i32>(NZUsize!(4));
        let flag = Arc::new(WakeFlag::default());
        let waker = Waker::from(flag.clone());
        let mut cx = Context::from_waker(&waker);

        // Nothing is buffered yet, so the receiver parks and registers the waker.
        assert!(receiver.poll_next_unpin(&mut cx).is_pending());
        assert!(!flag.is_set());

        block_on(sender.send(5)).unwrap();
        assert!(flag.is_set());
        assert_eq!(receiver.poll_next_unpin(&mut cx), Poll::Ready(Some(5)));
    }

    #[test]
    fn test_last_sender_drop_wakes_pending_receiver() {
        let (sender1, mut receiver) = channel::<i32>(NZUsize!(4));
        let sender2 = sender1.clone();
        let flag = Arc::new(WakeFlag::default());
        let waker = Waker::from(flag.clone());
        let mut cx = Context::from_waker(&waker);

        assert!(receiver.poll_next_unpin(&mut cx).is_pending());

        // Dropping one of two senders must neither end the stream nor wake the receiver.
        drop(sender1);
        assert!(!flag.is_set());

        drop(sender2);
        assert!(flag.is_set());
        assert_eq!(receiver.poll_next_unpin(&mut cx), Poll::Ready(None));
        assert!(receiver.is_terminated());
    }
}