commonware_utils/channel/
ring.rs1use crate::sync::Mutex;
31use core::num::NonZeroUsize;
32use futures::{stream::FusedStream, Sink, Stream};
33use std::{
34 collections::VecDeque,
35 pin::Pin,
36 sync::Arc,
37 task::{Context, Poll, Waker},
38};
39use thiserror::Error;
40
41#[derive(Debug, Error)]
43#[error("channel closed")]
44pub struct ChannelClosed;
45
/// State shared by every [`Sender`] and the [`Receiver`] of one channel, guarded by a
/// single mutex (`Arc<Mutex<Shared<T>>>`).
#[derive(Debug)]
struct Shared<T: Send + Sync> {
    /// Queued items; the oldest item is at the front.
    buffer: VecDeque<T>,
    /// Maximum number of queued items. When a send finds the buffer at (or above) this
    /// length, the oldest item is evicted to make room.
    capacity: usize,
    /// Waker stored by the receiver when a poll returns `Pending`; taken and woken by each
    /// send and by the drop of the last sender.
    receiver_waker: Option<Waker>,
    /// Number of live [`Sender`] handles: starts at 1, incremented by `clone`, decremented
    /// by `drop`.
    sender_count: usize,
    /// Set when the [`Receiver`] is dropped; sends fail with [`ChannelClosed`] afterwards.
    receiver_dropped: bool,
}
54
55pub struct Sender<T: Send + Sync> {
63 shared: Arc<Mutex<Shared<T>>>,
64}
65
66impl<T: Send + Sync> Sender<T> {
67 pub fn is_closed(&self) -> bool {
71 let shared = self.shared.lock();
72 shared.receiver_dropped
73 }
74}
75
76impl<T: Send + Sync> Clone for Sender<T> {
77 fn clone(&self) -> Self {
78 let mut shared = self.shared.lock();
79 shared.sender_count += 1;
80 drop(shared);
81
82 Self {
83 shared: self.shared.clone(),
84 }
85 }
86}
87
88impl<T: Send + Sync> Drop for Sender<T> {
89 fn drop(&mut self) {
90 let mut shared = self.shared.lock();
91 shared.sender_count -= 1;
92 let waker = if shared.sender_count == 0 {
93 shared.receiver_waker.take()
94 } else {
95 None
96 };
97 drop(shared);
98
99 if let Some(w) = waker {
100 w.wake();
101 }
102 }
103}
104
105impl<T: Send + Sync> Sink<T> for Sender<T> {
106 type Error = ChannelClosed;
107
108 fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
109 let shared = self.shared.lock();
110 if shared.receiver_dropped {
111 return Poll::Ready(Err(ChannelClosed));
112 }
113
114 Poll::Ready(Ok(()))
115 }
116
117 fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
118 let mut shared = self.shared.lock();
119
120 if shared.receiver_dropped {
121 return Err(ChannelClosed);
122 }
123
124 let old_item = if shared.buffer.len() >= shared.capacity {
125 shared.buffer.pop_front()
126 } else {
127 None
128 };
129
130 shared.buffer.push_back(item);
131 let waker = shared.receiver_waker.take();
132 drop(shared);
133
134 drop(old_item);
136
137 if let Some(w) = waker {
138 w.wake();
139 }
140
141 Ok(())
142 }
143
144 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
145 Poll::Ready(Ok(()))
147 }
148
149 fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
150 Poll::Ready(Ok(()))
152 }
153}
154
/// Receiving half of a ring channel.
///
/// Implements [`Stream`], yielding buffered items oldest-first. The stream ends (yields
/// `None`) once every [`Sender`] has been dropped and the buffer has been drained.
#[derive(Debug)]
pub struct Receiver<T: Send + Sync> {
    shared: Arc<Mutex<Shared<T>>>,
}
166
167impl<T: Send + Sync> Stream for Receiver<T> {
168 type Item = T;
169
170 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
171 let mut shared = self.shared.lock();
172
173 if let Some(item) = shared.buffer.pop_front() {
174 return Poll::Ready(Some(item));
175 }
176
177 if shared.sender_count == 0 {
178 return Poll::Ready(None);
179 }
180
181 if !shared
182 .receiver_waker
183 .as_ref()
184 .is_some_and(|w| w.will_wake(cx.waker()))
185 {
186 shared.receiver_waker = Some(cx.waker().clone());
187 }
188 Poll::Pending
189 }
190}
191
192impl<T: Send + Sync> FusedStream for Receiver<T> {
193 fn is_terminated(&self) -> bool {
194 let shared = self.shared.lock();
195 shared.sender_count == 0 && shared.buffer.is_empty()
196 }
197}
198
199impl<T: Send + Sync> Drop for Receiver<T> {
200 fn drop(&mut self) {
201 let mut shared = self.shared.lock();
202 shared.receiver_dropped = true;
203 }
204}
205
206pub fn channel<T: Send + Sync>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
211 let shared = Arc::new(Mutex::new(Shared {
212 buffer: VecDeque::with_capacity(capacity.get()),
213 capacity: capacity.get(),
214 receiver_waker: None,
215 sender_count: 1,
216 receiver_dropped: false,
217 }));
218
219 let sender = Sender {
220 shared: shared.clone(),
221 };
222 let receiver = Receiver { shared };
223
224 (sender, receiver)
225}
226
#[cfg(test)]
mod tests {
    use super::*;
    use crate::NZUsize;
    use futures::{executor::block_on, SinkExt, StreamExt};
    use std::thread;

    #[test]
    fn test_basic_send_recv() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_overflow_drops_oldest() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(2));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            // The buffer is full: each further send evicts the oldest item (1, then 2).
            sender.send(3).await.unwrap();
            sender.send(4).await.unwrap();

            assert_eq!(receiver.next().await, Some(3));
            assert_eq!(receiver.next().await, Some(4));
        });
    }

    #[test]
    fn test_send_after_receiver_dropped() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));
            drop(receiver);

            let err = sender.send(1).await.unwrap_err();
            assert!(matches!(err, ChannelClosed));
        });
    }

    #[test]
    fn test_send_fails_after_receiver_dropped_with_buffered_items() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(4));
            sender.send(1).await.unwrap();
            drop(receiver);

            assert!(sender.is_closed());
            assert!(matches!(sender.send(2).await, Err(ChannelClosed)));
        });
    }

    #[test]
    fn test_recv_after_sender_dropped() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            drop(sender);

            // Buffered items are still delivered after the last sender is gone.
            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_stream_collect() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            sender.send(1).await.unwrap();
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            drop(sender);

            let items: Vec<_> = receiver.collect().await;
            assert_eq!(items, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_clone_sender() {
        block_on(async {
            let (mut sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            sender1.send(1).await.unwrap();
            sender2.send(2).await.unwrap();

            assert_eq!(receiver.next().await, Some(1));
            assert_eq!(receiver.next().await, Some(2));
        });
    }

    #[test]
    fn test_sender_drop_with_clones() {
        block_on(async {
            let (sender1, mut receiver) = channel::<i32>(NZUsize!(10));
            let mut sender2 = sender1.clone();

            // Dropping one of two senders must not end the stream.
            drop(sender1);

            sender2.send(1).await.unwrap();
            assert_eq!(receiver.next().await, Some(1));

            // Dropping the last sender ends the stream.
            drop(sender2);
            assert_eq!(receiver.next().await, None);
        });
    }

    #[test]
    fn test_capacity_one() {
        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(1));

            sender.send(1).await.unwrap();
            // Evicts 1.
            sender.send(2).await.unwrap();
            assert_eq!(receiver.next().await, Some(2));

            sender.send(1).await.unwrap();
            // Each send replaces the single buffered item.
            sender.send(2).await.unwrap();
            sender.send(3).await.unwrap();
            assert_eq!(receiver.next().await, Some(3));
        });
    }

    #[test]
    fn test_send_all() {
        block_on(async {
            let (mut sender, receiver) = channel::<i32>(NZUsize!(10));

            let items = futures::stream::iter(vec![1, 2, 3]);
            sender.send_all(&mut items.map(Ok)).await.unwrap();
            drop(sender);

            let received: Vec<_> = receiver.collect().await;
            assert_eq!(received, vec![1, 2, 3]);
        });
    }

    #[test]
    fn test_fused_stream() {
        use futures::stream::FusedStream;

        block_on(async {
            let (mut sender, mut receiver) = channel::<i32>(NZUsize!(10));

            assert!(!receiver.is_terminated());

            sender.send(1).await.unwrap();
            assert!(!receiver.is_terminated());

            drop(sender);
            // The senders are gone, but an item is still buffered.
            assert!(!receiver.is_terminated());
            assert_eq!(receiver.next().await, Some(1));

            // The buffer is drained and no senders remain.
            assert!(receiver.is_terminated());
            assert_eq!(receiver.next().await, None);
            assert!(receiver.is_terminated());
        });
    }

    #[test]
    fn test_is_closed() {
        block_on(async {
            let (sender, receiver) = channel::<i32>(NZUsize!(10));

            assert!(!sender.is_closed());

            drop(receiver);
            assert!(sender.is_closed());
        });
    }

    #[test]
    fn test_pending_receiver_woken_by_send() {
        let (mut sender, mut receiver) = channel::<i32>(NZUsize!(4));
        // Whether the receiver polls before or after the send, it must observe the item.
        let handle = thread::spawn(move || block_on(receiver.next()));
        block_on(sender.send(7)).unwrap();
        assert_eq!(handle.join().unwrap(), Some(7));
    }

    #[test]
    fn test_pending_receiver_woken_by_last_sender_drop() {
        let (sender, mut receiver) = channel::<i32>(NZUsize!(4));
        // A receiver parked on an empty channel must be woken when the last sender drops.
        let handle = thread::spawn(move || block_on(receiver.next()));
        drop(sender);
        assert_eq!(handle.join().unwrap(), None);
    }
}