1use crate::Feedback;
44use std::{
45 collections::VecDeque,
46 fmt,
47 future::poll_fn,
48 num::NonZeroUsize,
49 sync::mpsc::TryRecvError,
50 task::{Context, Poll},
51};
52
/// Decides what happens to a message that could not be delivered through the bounded ready
/// queue.
///
/// [`Policy::handle`] is only consulted on the sender's locked slow path, after the channel
/// was found open. This happens either when the overflow queue already holds messages, or
/// when it is empty but the ready queue rejected the message.
pub trait Policy: Sized {
    /// Incorporates `message` into `overflow`, which is the locked overflow queue.
    ///
    /// The implementation may append `message`, replace or remove existing entries, clear
    /// the queue, or discard `message` entirely. Whatever remains in `overflow` is later
    /// moved into the ready queue front-to-back as capacity frees up.
    ///
    /// Returns `true` to report [`Feedback::Backoff`] to the sender and `false` to report
    /// [`Feedback::Dropped`]. The return value is only the sender's feedback: messages left
    /// in `overflow` are delivered either way, and the receiver is woken in both cases.
    fn handle(overflow: &mut VecDeque<Self>, message: Self) -> bool;
}
76
/// Bit 0 of the overflow `activity` word: set while the overflow queue was last published
/// as non-empty.
const OVERFLOW_HAS_MESSAGES: usize = 0b01;
/// Amount added to the overflow `activity` word per in-flight mutation. The mutation count
/// therefore lives in the bits above [`OVERFLOW_HAS_MESSAGES`].
const OVERFLOW_MUTATION: usize = OVERFLOW_HAS_MESSAGES << 1;
104
// Selects the synchronization backend. Under the `loom` feature, loom's instrumented
// primitives and a mutex-based model of the ready queue are used. Otherwise the production
// lock-free queue and `parking_lot`/`std` primitives are used. Both branches expose the
// same names: `AtomicWaker`, `AtomicBool`, `AtomicUsize`, `Ordering`, `Arc`, `Mutex`,
// `MutexGuard`, `register_waker`, `lock` and `Ready`.
cfg_if::cfg_if! {
    if #[cfg(feature = "loom")] {
        use loom::future::AtomicWaker;
        use loom::sync::{
            atomic::{AtomicBool, AtomicUsize, Ordering},
            Arc, Mutex, MutexGuard,
        };

        // Registers `task` with the receiver's waker slot (loom's API is `register_by_ref`).
        fn register_waker(waker: &AtomicWaker, task: &std::task::Waker) {
            waker.register_by_ref(task);
        }

        // Locks `mutex`, panicking (via `unwrap`) if it was poisoned.
        fn lock<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
            mutex.lock().unwrap()
        }

        // Interior state of the loom ready-queue model.
        struct ReadyState<T> {
            // Messages whose push has completed; `pop` takes from the front.
            published: VecDeque<T>,
            // Pushes that have claimed capacity but have not yet published their message.
            reserved: usize,
        }

        // Bounded ready queue modelled with a mutex so loom can explore interleavings around
        // a push that is in progress. NOTE(review): presumably mirrors `ArrayQueue`, where a
        // pop may observe a push that has claimed a slot but not finished writing it;
        // confirm against crossbeam.
        struct Ready<T> {
            state: Mutex<ReadyState<T>>,
            // Maximum of `published.len() + reserved`.
            capacity: usize,
        }

        impl<T> Ready<T> {
            // Creates an empty queue that holds at most `capacity` messages.
            fn new(capacity: usize) -> Self {
                Self {
                    state: Mutex::new(ReadyState {
                        published: VecDeque::new(),
                        reserved: 0,
                    }),
                    capacity,
                }
            }

            // Maximum number of messages the queue can hold.
            const fn capacity(&self) -> usize {
                self.capacity
            }

            // Pushes `message`, or returns it in `Err` if the queue (including in-progress
            // pushes) is at capacity.
            fn push(&self, message: T) -> Result<(), T> {
                // Phase 1: claim a slot under the lock.
                {
                    let mut state = lock(&self.state);
                    if state.published.len() + state.reserved >= self.capacity {
                        return Err(message);
                    }
                    state.reserved += 1;
                }

                // Give loom a chance to schedule other threads while the slot is claimed but
                // the message is not yet visible.
                loom::thread::yield_now();

                // Phase 2: publish the message and release the reservation.
                let mut state = lock(&self.state);
                state.reserved -= 1;
                state.published.push_back(message);
                Ok(())
            }

            // Pops the oldest published message. If none is published but a push is still
            // in progress, yields and retries instead of reporting the queue empty.
            fn pop(&self) -> Option<T> {
                loop {
                    let mut state = lock(&self.state);
                    if let Some(message) = state.published.pop_front() {
                        return Some(message);
                    }
                    if state.reserved == 0 {
                        return None;
                    }
                    drop(state);
                    loom::thread::yield_now();
                }
            }
        }
    } else {
        use crossbeam_queue::ArrayQueue;
        use futures_util::task::AtomicWaker;
        use parking_lot::{Mutex, MutexGuard};
        use std::sync::{
            atomic::{AtomicBool, AtomicUsize, Ordering},
            Arc,
        };

        // Registers `task` with the receiver's waker slot.
        fn register_waker(waker: &AtomicWaker, task: &std::task::Waker) {
            waker.register(task);
        }

        // Locks `mutex` (`parking_lot` mutexes do not poison).
        fn lock<T>(mutex: &Mutex<T>) -> MutexGuard<'_, T> {
            mutex.lock()
        }

        // Production ready queue: a bounded `crossbeam_queue::ArrayQueue`.
        struct Ready<T> {
            queue: ArrayQueue<T>,
        }

        impl<T> Ready<T> {
            // Creates an empty queue that holds at most `capacity` messages.
            fn new(capacity: usize) -> Self {
                Self {
                    queue: ArrayQueue::new(capacity),
                }
            }

            // Maximum number of messages the queue can hold.
            fn capacity(&self) -> usize {
                self.queue.capacity()
            }

            // Pushes `message`, returning it in `Err` when the queue is full.
            fn push(&self, message: T) -> Result<(), T> {
                self.queue.push(message)
            }

            // Pops the oldest message, if any.
            fn pop(&self) -> Option<T> {
                self.queue.pop()
            }
        }
    }
}
219
/// Spill area for messages that did not go straight into the ready queue, plus a packed
/// activity word that lets senders skip the lock when the spill area is idle.
struct Overflow<T> {
    /// Messages retained by the [`Policy`], refilled into the ready queue front-to-back.
    queue: Mutex<VecDeque<T>>,
    /// Bit `OVERFLOW_HAS_MESSAGES` records whether `queue` was non-empty at its last publish.
    /// The higher bits count in-flight mutations in units of `OVERFLOW_MUTATION`.
    activity: AtomicUsize,
}
224
impl<T> Overflow<T> {
    /// Creates an empty overflow queue with no recorded activity.
    // NOTE(review): presumably allowed because the loom `Mutex::new` is not `const`, so the
    // lint's `const fn` suggestion cannot hold in every build configuration. Confirm.
    #[allow(clippy::missing_const_for_fn)]
    fn new() -> Self {
        Self {
            queue: Mutex::new(VecDeque::new()),
            activity: AtomicUsize::new(0),
        }
    }

    /// Lock-free fast path: pushes `message` into `ready` only while `activity` is zero,
    /// i.e. no mutation is in flight and the overflow queue was last published as empty.
    ///
    /// Returns the message in `Err` when the fast path is not allowed or `ready` is full.
    /// The sender then falls back to [`Overflow::enqueue`].
    fn try_ready(&self, ready: &Ready<T>, message: T) -> Result<(), T> {
        if self.activity.load(Ordering::Relaxed) != 0 {
            // Messages may be waiting in overflow. Pushing into `ready` now could let this
            // message overtake them, so defer to the locked path.
            return Err(message);
        }
        ready.push(message)
    }

    /// Locked slow path for a message that could not take the fast path.
    ///
    /// While holding the overflow lock, this returns:
    /// - [`Feedback::Closed`] if `is_closed()` reports the channel closed.
    /// - [`Feedback::Ok`] if the overflow queue is empty and the retried `ready.push`
    ///   succeeds.
    /// - Otherwise the result of [`Policy::handle`]: [`Feedback::Backoff`] when it returns
    ///   `true`, [`Feedback::Dropped`] when it returns `false`.
    ///
    /// `OVERFLOW_HAS_MESSAGES` is republished from the queue contents before every return.
    fn enqueue(&self, ready: &Ready<T>, message: T, is_closed: impl Fn() -> bool) -> Feedback
    where
        T: Policy,
    {
        // Begin before locking so concurrent `try_ready` callers see non-zero activity for the
        // whole critical section. `queue` is declared after `mutation`, so the lock is
        // released before the mutation count is decremented.
        let mutation = Mutation::begin(&self.activity);
        let mut queue = lock(&self.queue);
        if is_closed() {
            mutation.publish(&queue);
            return Feedback::Closed;
        }

        let message = if queue.is_empty() {
            // Nothing is waiting ahead of this message, and `ready` may have gained room since
            // the fast path failed (for example, the receiver popped). Retry it before
            // consulting the policy.
            match ready.push(message) {
                Ok(()) => {
                    mutation.publish(&queue);
                    return Feedback::Ok;
                }
                Err(message) => message,
            }
        } else {
            message
        };

        let feedback = if T::handle(&mut queue, message) {
            // The policy accepted responsibility for the message; ask the sender to back off.
            Feedback::Backoff
        } else {
            Feedback::Dropped
        };
        mutation.publish(&queue);
        feedback
    }

    /// Moves messages from the front of the overflow queue into `ready` until the overflow is
    /// empty or `ready` is full, preserving their order. A message that does not fit is put
    /// back at the front.
    fn refill(&self, ready: &Ready<T>) {
        if self.activity.load(Ordering::Relaxed) & OVERFLOW_HAS_MESSAGES == 0 {
            // Overflow was last published as empty: skip taking the lock.
            return;
        }

        let mutation = Mutation::begin(&self.activity);
        let mut queue = lock(&self.queue);
        while let Some(message) = queue.pop_front() {
            match ready.push(message) {
                Ok(()) => {}
                Err(message) => {
                    queue.push_front(message);
                    break;
                }
            }
        }
        mutation.publish(&queue);
    }
}
299
300struct Mutation<'a> {
301 activity: &'a AtomicUsize,
302}
303
304impl<'a> Mutation<'a> {
305 fn begin(activity: &'a AtomicUsize) -> Self {
306 activity.fetch_add(OVERFLOW_MUTATION, Ordering::Relaxed);
307 Self { activity }
308 }
309
310 fn publish<T>(&self, queue: &VecDeque<T>) {
311 if queue.is_empty() {
312 self.activity
313 .fetch_and(!OVERFLOW_HAS_MESSAGES, Ordering::Relaxed);
314 } else {
315 self.activity
316 .fetch_or(OVERFLOW_HAS_MESSAGES, Ordering::Relaxed);
317 }
318 }
319}
320
321impl Drop for Mutation<'_> {
322 fn drop(&mut self) {
323 let previous = self
324 .activity
325 .fetch_sub(OVERFLOW_MUTATION, Ordering::Relaxed);
326 assert!(previous >= OVERFLOW_MUTATION);
327 }
328}
329
/// State shared, through an `Arc`, by every sender and the receiver of one channel.
struct State<T> {
    /// Bounded queue that the receiver pops from.
    ready: Ready<T>,
    /// Spill area whose contents are governed by the message type's [`Policy`].
    overflow: Overflow<T>,
    /// Set to `true` when the receiver is dropped.
    closed: AtomicBool,
    /// Number of live senders; the channel is disconnected once this reaches zero.
    senders: AtomicUsize,
    /// Waker registered by a pending receive, woken by senders.
    waker: AtomicWaker,
}
337
/// Sending half of the channel.
///
/// It can be cloned freely. The number of live senders is tracked so the receiver can tell
/// when the last one has been dropped.
pub struct Sender<T: Policy> {
    /// State shared with the receiver and all other senders.
    state: Arc<State<T>>,
}
342
343impl<T: Policy> Clone for Sender<T> {
344 fn clone(&self) -> Self {
345 self.state.senders.fetch_add(1, Ordering::Relaxed);
347 Self {
348 state: self.state.clone(),
349 }
350 }
351}
352
353impl<T: Policy> Drop for Sender<T> {
354 fn drop(&mut self) {
355 let previous = self.state.senders.fetch_sub(1, Ordering::AcqRel);
356 assert!(previous > 0);
357 if previous == 1 {
359 self.state.waker.wake();
360 }
361 }
362}
363
364impl<T: Policy> fmt::Debug for Sender<T> {
365 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
366 f.debug_struct("Sender")
367 .field("capacity", &self.state.ready.capacity())
368 .field("closed", &self.state.closed.load(Ordering::Acquire))
369 .finish()
370 }
371}
372
373impl<T: Policy> Sender<T> {
374 #[must_use = "caller must handle enqueue feedback"]
376 pub fn enqueue(&self, message: T) -> Feedback {
377 if self.state.closed.load(Ordering::Acquire) {
379 return Feedback::Closed;
380 }
381
382 let message = match self.state.overflow.try_ready(&self.state.ready, message) {
384 Ok(()) => {
385 self.state.waker.wake();
386 return Feedback::Ok;
387 }
388 Err(message) => message,
389 };
390
391 let feedback = self.state.overflow.enqueue(&self.state.ready, message, || {
393 self.state.closed.load(Ordering::Acquire)
394 });
395
396 if feedback != Feedback::Closed {
402 self.state.waker.wake();
403 }
404 feedback
405 }
406}
407
/// Receiving half of the channel.
///
/// Messages are taken from the ready queue, which is topped up from the overflow queue
/// after each pop. Dropping the receiver marks the channel closed for senders.
pub struct Receiver<T> {
    /// State shared with all senders.
    state: Arc<State<T>>,
}
419
420impl<T> Receiver<T> {
421 fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<T>> {
422 if let Some(message) = self.pop() {
424 return Poll::Ready(Some(message));
425 }
426
427 if self.is_disconnected() {
428 return Poll::Ready(self.pop());
429 }
430
431 register_waker(&self.state.waker, cx.waker());
432
433 if let Some(message) = self.pop() {
436 return Poll::Ready(Some(message));
437 }
438
439 if self.is_disconnected() {
440 Poll::Ready(self.pop())
441 } else {
442 Poll::Pending
443 }
444 }
445
446 fn pop(&mut self) -> Option<T> {
447 if let Some(message) = self.state.ready.pop() {
448 self.state.overflow.refill(&self.state.ready);
450 return Some(message);
451 }
452
453 self.state.overflow.refill(&self.state.ready);
456 self.state.ready.pop()
457 }
458
459 fn is_disconnected(&self) -> bool {
460 self.state.closed.load(Ordering::Acquire) || self.state.senders.load(Ordering::Acquire) == 0
461 }
462
463 pub async fn recv(&mut self) -> Option<T> {
468 poll_fn(|cx| self.poll_recv(cx)).await
469 }
470
471 pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
476 if let Some(message) = self.pop() {
477 return Ok(message);
478 }
479 if self.is_disconnected() {
480 return self.pop().ok_or(TryRecvError::Disconnected);
481 }
482 Err(TryRecvError::Empty)
483 }
484}
485
486impl<T> Drop for Receiver<T> {
487 fn drop(&mut self) {
488 self.state.closed.store(true, Ordering::Release);
490 }
491}
492
493pub fn new<T: Policy>(capacity: NonZeroUsize) -> (Sender<T>, Receiver<T>) {
495 let state = Arc::new(State {
496 ready: Ready::new(capacity.get()),
497 overflow: Overflow::new(),
498 closed: AtomicBool::new(false),
499 senders: AtomicUsize::new(1),
500 waker: AtomicWaker::new(),
501 });
502 (
503 Sender {
504 state: state.clone(),
505 },
506 Receiver { state },
507 )
508}
509
// Deterministic single-threaded behaviour tests; excluded from loom builds.
#[cfg(all(test, not(feature = "loom")))]
mod tests {
    use super::*;
    use commonware_macros::test_async;
    use commonware_utils::NZUsize;
    use futures::{
        pin_mut,
        task::{waker_ref, ArcWake},
        FutureExt,
    };
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        mpsc::TryRecvError,
        Arc,
    };

    // Test message type: each variant exercises a different overflow decision.
    #[derive(Debug, PartialEq, Eq)]
    enum Message {
        Update(u64),
        Vote(u64),
        Required(u64),
        Buffered(u64),
        Hint(u64),
    }

    impl Policy for Message {
        fn handle(overflow: &mut VecDeque<Self>, message: Self) -> bool {
            match message {
                // Remove the last pending `Update` (if any) and append this one at the back.
                Self::Update(value) => {
                    if let Some(index) = overflow
                        .iter()
                        .rposition(|pending| matches!(pending, Self::Update(_)))
                    {
                        overflow.remove(index);
                    }
                    overflow.push_back(Self::Update(value));
                    true
                }
                // Always retained.
                Self::Required(_) | Self::Buffered(_) => {
                    overflow.push_back(message);
                    true
                }
                // Takes the place of the last pending `Update`; dropped if there is none.
                Self::Hint(value) => {
                    let Some(index) = overflow
                        .iter()
                        .rposition(|pending| matches!(pending, Self::Update(_)))
                    else {
                        return false;
                    };
                    overflow.remove(index);
                    overflow.push_back(Self::Hint(value));
                    true
                }
                // Never retained.
                Self::Vote(_) => false,
            }
        }
    }

    // Waker that only counts how many times it was woken.
    #[derive(Default)]
    struct WakeCounter {
        wakes: AtomicUsize,
    }

    impl WakeCounter {
        fn count(&self) -> usize {
            self.wakes.load(Ordering::Acquire)
        }
    }

    impl ArcWake for WakeCounter {
        fn wake_by_ref(arc_self: &Arc<Self>) {
            arc_self.wakes.fetch_add(1, Ordering::AcqRel);
        }
    }

    // With `ready` full, a newer `Update` replaces the one waiting in overflow.
    #[test_async]
    async fn full_inbox_replaces_stale_overflow_message() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Update(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Update(2)), Feedback::Backoff);
        assert_eq!(sender.enqueue(Message::Update(3)), Feedback::Backoff);

        assert_eq!(receiver.recv().await, Some(Message::Update(1)));
        assert_eq!(receiver.recv().await, Some(Message::Update(3)));
    }

    // A replacing `Update` is appended at the back, behind `Required(3)`.
    #[test_async]
    async fn policy_can_replace_stale_overflow_at_back() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Update(2)), Feedback::Backoff);
        assert_eq!(sender.enqueue(Message::Required(3)), Feedback::Backoff);
        assert_eq!(sender.enqueue(Message::Update(4)), Feedback::Backoff);

        assert_eq!(receiver.recv().await, Some(Message::Vote(1)));
        assert_eq!(receiver.recv().await, Some(Message::Required(3)));
        assert_eq!(receiver.recv().await, Some(Message::Update(4)));
    }

    // A `Vote` that does not fit in `ready` is dropped.
    #[test_async]
    async fn full_inbox_rejects_non_replaceable_message() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Vote(2)), Feedback::Dropped);

        assert_eq!(receiver.recv().await, Some(Message::Vote(1)));
    }

    // A retained message is delivered after the ready message.
    #[test_async]
    async fn full_inbox_retains_required_message() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Buffered(2)), Feedback::Backoff);

        assert_eq!(receiver.recv().await, Some(Message::Vote(1)));
        assert_eq!(receiver.recv().await, Some(Message::Buffered(2)));
    }

    // `try_recv` refills `ready` from overflow after popping.
    #[test]
    fn try_recv_refills_from_overflow() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Buffered(2)), Feedback::Backoff);

        assert_eq!(receiver.try_recv(), Ok(Message::Vote(1)));
        assert_eq!(receiver.try_recv(), Ok(Message::Buffered(2)));
    }

    // Disconnection is reported only after ready and overflow are fully drained.
    #[test]
    fn try_recv_drains_buffered_messages_after_senders_drop() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Buffered(2)), Feedback::Backoff);
        drop(sender);

        assert_eq!(receiver.try_recv(), Ok(Message::Vote(1)));
        assert_eq!(receiver.try_recv(), Ok(Message::Buffered(2)));
        assert_eq!(receiver.try_recv(), Err(TryRecvError::Disconnected));
    }

    // Same as above, through `poll_recv`.
    #[test]
    fn poll_recv_drains_buffered_messages_after_senders_drop() {
        let (sender, mut receiver) = new(NZUsize!(1));
        let wakes = Arc::new(WakeCounter::default());
        let waker = waker_ref(&wakes);
        let mut cx = Context::from_waker(&waker);

        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Buffered(2)), Feedback::Backoff);
        drop(sender);

        assert_eq!(
            receiver.poll_recv(&mut cx),
            Poll::Ready(Some(Message::Vote(1)))
        );
        assert_eq!(
            receiver.poll_recv(&mut cx),
            Poll::Ready(Some(Message::Buffered(2)))
        );
        assert_eq!(receiver.poll_recv(&mut cx), Poll::Ready(None));
    }

    // The overflow message is refilled into `ready` on receive, so a later fast-path send
    // queues behind it.
    #[test]
    fn enqueue_uses_ready_capacity_after_partial_drain() {
        let (sender, mut receiver) = new(NZUsize!(2));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Vote(2)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Required(3)), Feedback::Backoff);

        assert_eq!(receiver.try_recv(), Ok(Message::Vote(1)));
        assert_eq!(receiver.try_recv(), Ok(Message::Vote(2)));

        assert_eq!(sender.enqueue(Message::Vote(4)), Feedback::Ok);
        assert_eq!(receiver.try_recv(), Ok(Message::Required(3)));
        assert_eq!(receiver.try_recv(), Ok(Message::Vote(4)));
    }

    // Refill after a partial drain keeps the overflow message ahead of newer sends.
    #[test]
    fn receiver_refills_overflow_after_partial_drain() {
        let (sender, mut receiver) = new(NZUsize!(3));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Vote(2)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Vote(3)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Required(4)), Feedback::Backoff);

        assert_eq!(receiver.try_recv(), Ok(Message::Vote(1)));
        assert_eq!(receiver.try_recv(), Ok(Message::Vote(2)));

        assert_eq!(sender.enqueue(Message::Vote(5)), Feedback::Ok);
        assert_eq!(receiver.try_recv(), Ok(Message::Vote(3)));
        assert_eq!(receiver.try_recv(), Ok(Message::Required(4)));
        assert_eq!(receiver.try_recv(), Ok(Message::Vote(5)));
    }

    // A retained message with nothing to replace is simply kept.
    #[test_async]
    async fn full_inbox_retains_unmatched_replaceable_message() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Required(2)), Feedback::Backoff);

        assert_eq!(receiver.recv().await, Some(Message::Vote(1)));
        assert_eq!(receiver.recv().await, Some(Message::Required(2)));
    }

    // Replacement applies only to overflow; the `Update` already in `ready` is untouched.
    #[test_async]
    async fn full_inbox_replaces_stale_overflow_after_ready_fills() {
        let (sender, mut receiver) = new(NZUsize!(2));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Update(2)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Update(3)), Feedback::Backoff);
        assert_eq!(sender.enqueue(Message::Update(4)), Feedback::Backoff);

        assert_eq!(receiver.recv().await, Some(Message::Vote(1)));
        assert_eq!(receiver.recv().await, Some(Message::Update(2)));
        assert_eq!(receiver.recv().await, Some(Message::Update(4)));
    }

    // The overflow queue may hold more messages than the ready capacity.
    #[test_async]
    async fn mailbox_capacity_is_soft_limit_for_required_messages() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Required(2)), Feedback::Backoff);
        assert_eq!(sender.enqueue(Message::Required(3)), Feedback::Backoff);

        assert_eq!(receiver.recv().await, Some(Message::Vote(1)));
        assert_eq!(receiver.recv().await, Some(Message::Required(2)));
        assert_eq!(receiver.recv().await, Some(Message::Required(3)));
    }

    // A `Hint` with no pending `Update` to replace is dropped.
    #[test_async]
    async fn full_inbox_rejects_hint() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Hint(2)), Feedback::Dropped);

        assert_eq!(receiver.recv().await, Some(Message::Vote(1)));
    }

    // A `Hint` replaces the pending `Update`.
    #[test_async]
    async fn full_inbox_can_replace_or_drop_by_message() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(sender.enqueue(Message::Update(2)), Feedback::Backoff);
        assert_eq!(sender.enqueue(Message::Hint(3)), Feedback::Backoff);

        assert_eq!(receiver.recv().await, Some(Message::Vote(1)));
        assert_eq!(receiver.recv().await, Some(Message::Hint(3)));
    }

    // A pending `recv` completes after a later enqueue.
    #[test_async]
    async fn empty_inbox_wakes_on_enqueue() {
        let (sender, mut receiver) = new(NZUsize!(1));

        let next = receiver.recv();
        pin_mut!(next);
        assert!(next.as_mut().now_or_never().is_none());

        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Ok);
        assert_eq!(next.await, Some(Message::Vote(1)));
    }

    // Dropping the last sender wakes a pending receiver, which then sees disconnection.
    #[test]
    fn pending_recv_wakes_when_senders_drop() {
        let (sender, mut receiver) = new::<Message>(NZUsize!(1));
        let wakes = Arc::new(WakeCounter::default());
        let waker = waker_ref(&wakes);
        let mut cx = Context::from_waker(&waker);

        assert_eq!(receiver.poll_recv(&mut cx), Poll::Pending);
        assert_eq!(wakes.count(), 0);

        drop(sender);

        assert_eq!(wakes.count(), 1);
        assert_eq!(receiver.poll_recv(&mut cx), Poll::Ready(None));
    }

    // `ready` is filled directly (no wake), so the single wake comes from the spilled send.
    #[test]
    fn pending_recv_wakes_on_handled_overflow_enqueue() {
        let (sender, mut receiver) = new(NZUsize!(1));
        let wakes = Arc::new(WakeCounter::default());
        let waker = waker_ref(&wakes);
        let mut cx = Context::from_waker(&waker);

        assert_eq!(receiver.poll_recv(&mut cx), Poll::Pending);
        assert_eq!(wakes.count(), 0);

        assert_eq!(sender.state.ready.push(Message::Vote(1)), Ok(()));
        assert_eq!(sender.enqueue(Message::Buffered(2)), Feedback::Backoff);

        assert_eq!(wakes.count(), 1);
        assert_eq!(receiver.try_recv(), Ok(Message::Vote(1)));
        assert_eq!(receiver.try_recv(), Ok(Message::Buffered(2)));
    }

    // After the receiver drops, the fast path reports `Closed` and wakes nobody.
    #[test]
    fn receiver_drop_blocks_ready_fast_path_feedback() {
        let (sender, mut receiver) = new(NZUsize!(1));
        let wakes = Arc::new(WakeCounter::default());
        let waker = waker_ref(&wakes);
        let mut cx = Context::from_waker(&waker);

        assert_eq!(receiver.poll_recv(&mut cx), Poll::Pending);
        drop(receiver);

        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Closed);
        assert_eq!(wakes.count(), 0);
    }

    // An empty channel with no senders is disconnected for both receive APIs.
    #[test_async]
    async fn empty_inbox_closes_when_senders_drop() {
        let (sender, mut receiver) = new::<Message>(NZUsize!(1));
        drop(sender);

        assert_eq!(receiver.try_recv(), Err(TryRecvError::Disconnected));
        assert_eq!(receiver.recv().await, None);
    }

    #[test]
    fn enqueue_after_receiver_drop_returns_closed() {
        let (sender, receiver) = new(NZUsize!(1));
        drop(receiver);

        assert_eq!(sender.enqueue(Message::Vote(1)), Feedback::Closed);
    }

    // Policy that appends and then clears everything, still reporting `Backoff`.
    #[derive(Debug, PartialEq, Eq)]
    enum ClearingMessage {
        FillReady,
        ClearOverflow,
    }

    impl Policy for ClearingMessage {
        fn handle(overflow: &mut VecDeque<Self>, message: Self) -> bool {
            overflow.push_back(message);
            overflow.clear();
            true
        }
    }

    // Clearing the overflow leaves the receiver with only the ready message.
    #[test]
    fn policy_can_clear_overflow_and_request_backoff() {
        let (sender, mut receiver) = new(NZUsize!(1));
        assert_eq!(sender.enqueue(ClearingMessage::FillReady), Feedback::Ok);
        assert_eq!(
            sender.enqueue(ClearingMessage::ClearOverflow),
            Feedback::Backoff
        );

        assert!(matches!(
            receiver.try_recv(),
            Ok(ClearingMessage::FillReady)
        ));
        assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));
    }

    // Policy that retains the message but reports `Dropped`.
    #[derive(Debug, PartialEq, Eq)]
    enum SpillAndDropMessage {
        FillReady,
        SpillAndDrop,
    }

    impl Policy for SpillAndDropMessage {
        fn handle(overflow: &mut VecDeque<Self>, message: Self) -> bool {
            overflow.push_back(message);
            false
        }
    }

    // A `Dropped` outcome still wakes the receiver, and the retained message is delivered.
    #[test]
    fn pending_recv_wakes_when_policy_spills_and_reports_dropped() {
        let (sender, mut receiver) = new(NZUsize!(1));
        let wakes = Arc::new(WakeCounter::default());
        let waker = waker_ref(&wakes);
        let mut cx = Context::from_waker(&waker);

        assert_eq!(receiver.poll_recv(&mut cx), Poll::Pending);
        assert_eq!(wakes.count(), 0);

        assert_eq!(
            sender.state.ready.push(SpillAndDropMessage::FillReady),
            Ok(())
        );
        assert_eq!(
            sender.enqueue(SpillAndDropMessage::SpillAndDrop),
            Feedback::Dropped
        );

        assert_eq!(wakes.count(), 1);
        assert_eq!(receiver.try_recv(), Ok(SpillAndDropMessage::FillReady));
        assert_eq!(receiver.try_recv(), Ok(SpillAndDropMessage::SpillAndDrop));
    }
}
904
// Loom model-checking tests for races between senders, the receiver and drops.
#[cfg(all(test, feature = "loom"))]
mod loom_tests {
    use super::*;
    use commonware_utils::NZUsize;
    use futures::pin_mut;
    use loom::{
        sync::{
            atomic::{AtomicUsize, Ordering},
            Arc,
        },
        thread,
    };
    use std::{
        future::Future,
        task::{RawWaker, RawWakerVTable, Waker},
    };

    // `Drop` is never retained in overflow; `Spill` always is.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum Message {
        Drop(u8),
        Spill(u8),
    }

    // `Coordinated` carries a gate that `handle` uses to pause inside the overflow lock.
    #[derive(Clone, Debug)]
    enum OrderedMessage {
        Item(u8),
        Coordinated(u8, Arc<AtomicUsize>),
    }

    // `Replace` overwrites the newest pending `Replace` in overflow.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    enum ReplacingMessage {
        FillReady,
        Replace(u8),
    }

    impl Policy for Message {
        fn handle(overflow: &mut VecDeque<Self>, message: Self) -> bool {
            match message {
                Self::Drop(_) => false,
                Self::Spill(_) => {
                    overflow.push_back(message);
                    true
                }
            }
        }
    }

    impl Policy for OrderedMessage {
        fn handle(overflow: &mut VecDeque<Self>, message: Self) -> bool {
            let gate = match &message {
                Self::Item(_) => None,
                Self::Coordinated(_, gate) => Some(gate.clone()),
            };
            overflow.push_back(message);
            if let Some(gate) = gate {
                // Signal arrival (1), then spin until another thread changes the gate.
                gate.store(1, Ordering::Release);
                while gate.load(Ordering::Acquire) == 1 {
                    thread::yield_now();
                }
            }
            true
        }
    }

    impl Policy for ReplacingMessage {
        fn handle(overflow: &mut VecDeque<Self>, message: Self) -> bool {
            match message {
                Self::FillReady => false,
                Self::Replace(_) => {
                    if let Some(pending) = overflow
                        .iter_mut()
                        .rev()
                        .find(|pending| matches!(pending, Self::Replace(_)))
                    {
                        *pending = message;
                    } else {
                        overflow.push_back(message);
                    }
                    true
                }
            }
        }
    }

    // Sets bit `value` in `seen` to record that message `value` was received.
    fn record(seen: &AtomicUsize, message: Message) {
        let value = match message {
            Message::Drop(value) | Message::Spill(value) => value,
        };
        seen.fetch_or(1usize << usize::from(value), Ordering::AcqRel);
    }

    fn value(message: OrderedMessage) -> u8 {
        match message {
            OrderedMessage::Item(value) | OrderedMessage::Coordinated(value, _) => value,
        }
    }

    const fn replacement_value(message: ReplacingMessage) -> Option<u8> {
        match message {
            ReplacingMessage::FillReady => None,
            ReplacingMessage::Replace(value) => Some(value),
        }
    }

    // Raw-waker vtable over an `Arc<AtomicUsize>` counter; each wake increments the counter.

    /// `RawWakerVTable::clone`: returns a new `RawWaker` holding its own strong reference.
    ///
    /// # Safety
    ///
    /// `data` must come from `Arc::<AtomicUsize>::into_raw` (see `counting_waker`) and the
    /// reference it represents must still be live.
    unsafe fn clone_counter(data: *const ()) -> RawWaker {
        // SAFETY: per the contract, `data` is a live pointer produced by `Arc::into_raw`.
        let wakes = unsafe { Arc::<AtomicUsize>::from_raw(data.cast()) };
        let cloned = wakes.clone();
        // Hand the borrowed reference back so the original waker still owns it.
        let _ = Arc::into_raw(wakes);
        RawWaker::new(Arc::into_raw(cloned).cast(), &COUNTER_WAKER_VTABLE)
    }

    /// `RawWakerVTable::wake`: increments the counter and consumes the waker's reference.
    ///
    /// # Safety
    ///
    /// Same contract as `clone_counter`; ownership of the reference passes to this call.
    unsafe fn wake_counter(data: *const ()) {
        // SAFETY: `data` is a live `Arc::into_raw` pointer whose reference we now own; it is
        // released when `wakes` goes out of scope.
        let wakes = unsafe { Arc::<AtomicUsize>::from_raw(data.cast()) };
        wakes.fetch_add(1, Ordering::AcqRel);
    }

    /// `RawWakerVTable::wake_by_ref`: increments the counter without consuming the reference.
    ///
    /// # Safety
    ///
    /// Same contract as `clone_counter`.
    unsafe fn wake_counter_by_ref(data: *const ()) {
        // SAFETY: `data` is a live `Arc::into_raw` pointer; the reference is given back below.
        let wakes = unsafe { Arc::<AtomicUsize>::from_raw(data.cast()) };
        wakes.fetch_add(1, Ordering::AcqRel);
        let _ = Arc::into_raw(wakes);
    }

    /// `RawWakerVTable::drop`: releases the waker's reference.
    ///
    /// # Safety
    ///
    /// Same contract as `clone_counter`; the pointer must not be used afterwards.
    unsafe fn drop_counter(data: *const ()) {
        unsafe {
            // SAFETY: `data` is a live `Arc::into_raw` pointer owned by the waker being dropped.
            drop(Arc::<AtomicUsize>::from_raw(data.cast()));
        }
    }

    static COUNTER_WAKER_VTABLE: RawWakerVTable = RawWakerVTable::new(
        clone_counter,
        wake_counter,
        wake_counter_by_ref,
        drop_counter,
    );

    // Builds a `Waker` that increments `wakes` each time it is woken.
    fn counting_waker(wakes: Arc<AtomicUsize>) -> Waker {
        let raw = RawWaker::new(Arc::into_raw(wakes).cast(), &COUNTER_WAKER_VTABLE);
        // SAFETY: `raw` holds a pointer from `Arc::into_raw`, which is exactly what every
        // `COUNTER_WAKER_VTABLE` entry expects, and `AtomicUsize` is safe to share across
        // threads.
        unsafe { Waker::from_raw(raw) }
    }

    // Dropping the last sender while the receiver registers its waker: the receiver either
    // sees the disconnect immediately or is woken and sees it on the next poll.
    #[test]
    fn sender_drop_racing_waker_registration_wakes_or_disconnects() {
        loom::model(|| {
            let (sender, mut receiver) = new::<Message>(NZUsize!(1));
            let wakes = Arc::new(AtomicUsize::new(0));
            let waker = counting_waker(wakes.clone());
            let mut cx = Context::from_waker(&waker);

            let close = thread::spawn(move || {
                drop(sender);
            });

            let poll = receiver.poll_recv(&mut cx);
            close.join().unwrap();

            match poll {
                Poll::Ready(None) => {}
                Poll::Pending => {
                    assert!(wakes.load(Ordering::Acquire) > 0);
                    assert_eq!(receiver.poll_recv(&mut cx), Poll::Ready(None));
                }
                Poll::Ready(Some(_)) => panic!("unexpected message"),
            }
        });
    }

    // A message sent right before the last sender drops is always delivered before
    // disconnection is reported (via `poll_recv`).
    #[test]
    fn sender_enqueue_then_drop_racing_poll_recv_drains_message() {
        loom::model(|| {
            let (sender, mut receiver) = new::<Message>(NZUsize!(1));
            let wakes = Arc::new(AtomicUsize::new(0));
            let waker = counting_waker(wakes.clone());
            let mut cx = Context::from_waker(&waker);

            let enqueue = thread::spawn(move || {
                assert_eq!(sender.enqueue(Message::Spill(0)), Feedback::Ok);
            });

            let poll = receiver.poll_recv(&mut cx);
            enqueue.join().unwrap();

            match poll {
                Poll::Ready(Some(Message::Spill(0))) => {}
                Poll::Pending => {
                    assert!(wakes.load(Ordering::Acquire) > 0);
                    assert_eq!(
                        receiver.poll_recv(&mut cx),
                        Poll::Ready(Some(Message::Spill(0)))
                    );
                }
                Poll::Ready(None) => panic!("disconnected before draining message"),
                Poll::Ready(Some(message)) => panic!("unexpected message: {message:?}"),
            }

            assert_eq!(receiver.poll_recv(&mut cx), Poll::Ready(None));
        });
    }

    // Same guarantee as above, via `try_recv`.
    #[test]
    fn sender_enqueue_then_drop_racing_try_recv_drains_message() {
        loom::model(|| {
            let (sender, mut receiver) = new::<Message>(NZUsize!(1));

            let enqueue = thread::spawn(move || {
                assert_eq!(sender.enqueue(Message::Spill(0)), Feedback::Ok);
            });

            let result = receiver.try_recv();
            enqueue.join().unwrap();

            match result {
                Ok(Message::Spill(0)) => {}
                Err(TryRecvError::Empty) => {
                    assert_eq!(receiver.try_recv(), Ok(Message::Spill(0)));
                }
                Err(TryRecvError::Disconnected) => {
                    panic!("disconnected before draining message");
                }
                Ok(message) => panic!("unexpected message: {message:?}"),
            }

            assert_eq!(receiver.try_recv(), Err(TryRecvError::Disconnected));
        });
    }

    // A successful enqueue wakes a receiver registered through `recv` exactly once.
    #[test]
    fn handled_enqueue_wakes_registered_receiver() {
        loom::model(|| {
            let (sender, mut receiver) = new::<Message>(NZUsize!(1));
            let wakes = Arc::new(AtomicUsize::new(0));
            let waker = counting_waker(wakes.clone());
            let mut cx = Context::from_waker(&waker);

            let next = receiver.recv();
            pin_mut!(next);
            assert!(matches!(next.as_mut().poll(&mut cx), Poll::Pending));
            assert_eq!(sender.enqueue(Message::Spill(0)), Feedback::Ok);

            assert_eq!(wakes.load(Ordering::Acquire), 1);
            assert_eq!(
                next.as_mut().poll(&mut cx),
                Poll::Ready(Some(Message::Spill(0)))
            );
        });
    }

    // Racing a receiver drop: any non-`Closed` feedback must have woken the receiver, and
    // later sends report `Closed`.
    #[test]
    fn receiver_drop_racing_ready_fast_path_feedback_wakes_if_ready() {
        loom::model(|| {
            let (sender, mut receiver) = new::<Message>(NZUsize!(1));
            let wakes = Arc::new(AtomicUsize::new(0));
            let waker = counting_waker(wakes.clone());
            let mut cx = Context::from_waker(&waker);

            assert_eq!(receiver.poll_recv(&mut cx), Poll::Pending);

            let close = thread::spawn(move || {
                drop(receiver);
            });
            let feedback = sender.enqueue(Message::Spill(0));
            close.join().unwrap();

            match feedback {
                Feedback::Ok | Feedback::Backoff => assert!(wakes.load(Ordering::Acquire) > 0),
                Feedback::Closed => {}
                feedback => panic!("unexpected feedback: {feedback:?}"),
            }
            assert_eq!(sender.enqueue(Message::Spill(1)), Feedback::Closed);
        });
    }

    // Closing concurrently with a fast-path send leaves the channel closed.
    #[test]
    fn concurrent_close_and_ready_enqueue_remains_closed() {
        loom::model(|| {
            let (sender, receiver) = new::<Message>(NZUsize!(1));

            let enqueue_sender = sender.clone();
            let enqueue = thread::spawn(move || {
                let _ = enqueue_sender.enqueue(Message::Spill(1));
            });

            let close = thread::spawn(move || {
                drop(receiver);
            });

            enqueue.join().unwrap();
            close.join().unwrap();
            assert_eq!(sender.enqueue(Message::Spill(2)), Feedback::Closed);
        });
    }

    // Closing concurrently with an overflow (slow-path) send leaves the channel closed.
    #[test]
    fn concurrent_close_and_overflow_enqueue_remains_closed() {
        loom::model(|| {
            let (sender, receiver) = new::<Message>(NZUsize!(1));
            assert_eq!(sender.enqueue(Message::Drop(0)), Feedback::Ok);

            let enqueue_sender = sender.clone();
            let enqueue = thread::spawn(move || {
                let _ = enqueue_sender.enqueue(Message::Spill(1));
            });

            let close = thread::spawn(move || {
                drop(receiver);
            });

            enqueue.join().unwrap();
            close.join().unwrap();
            assert_eq!(sender.enqueue(Message::Spill(2)), Feedback::Closed);
        });
    }

    // A spill racing a receive/refill loses no messages. `idle_sender` keeps the channel
    // connected until the end.
    #[test]
    fn concurrent_spill_and_refill_preserves_messages() {
        loom::model(|| {
            let (sender, mut receiver) = new::<Message>(NZUsize!(1));
            let idle_sender = sender.clone();
            assert_eq!(sender.enqueue(Message::Spill(0)), Feedback::Ok);

            let seen = Arc::new(AtomicUsize::new(0));
            let enqueue = thread::spawn(move || {
                let feedback = sender.enqueue(Message::Spill(1));
                assert!(matches!(feedback, Feedback::Ok | Feedback::Backoff));
            });

            let seen_by_receiver = seen.clone();
            let recv = thread::spawn(move || {
                if let Ok(message) = receiver.try_recv() {
                    record(&seen_by_receiver, message);
                }
                receiver
            });

            enqueue.join().unwrap();
            let mut receiver = recv.join().unwrap();

            while let Ok(message) = receiver.try_recv() {
                record(&seen, message);
            }
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));
            drop(idle_sender);
            assert_eq!(seen.load(Ordering::Acquire), 0b11);
        });
    }

    // Two concurrent spilling senders lose no messages.
    #[test]
    fn concurrent_spill_senders_preserve_messages() {
        loom::model(|| {
            let (sender, mut receiver) = new::<Message>(NZUsize!(1));
            let idle_sender = sender.clone();
            assert_eq!(sender.enqueue(Message::Spill(0)), Feedback::Ok);

            let sender_1 = sender.clone();
            let enqueue_1 = thread::spawn(move || sender_1.enqueue(Message::Spill(1)));
            let enqueue_2 = thread::spawn(move || sender.enqueue(Message::Spill(2)));

            let seen = Arc::new(AtomicUsize::new(0));

            assert!(matches!(
                enqueue_1.join().unwrap(),
                Feedback::Ok | Feedback::Backoff
            ));
            assert!(matches!(
                enqueue_2.join().unwrap(),
                Feedback::Ok | Feedback::Backoff
            ));

            while let Ok(message) = receiver.try_recv() {
                record(&seen, message);
            }
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));
            drop(idle_sender);
            assert_eq!(seen.load(Ordering::Acquire), 0b111);
        });
    }

    // Concurrent replacements leave exactly one `Replace` in overflow (whichever ran last).
    #[test]
    fn concurrent_replace_keeps_one_overflow_message() {
        loom::model(|| {
            let (sender, mut receiver) = new::<ReplacingMessage>(NZUsize!(1));
            let idle_sender = sender.clone();
            assert_eq!(sender.enqueue(ReplacingMessage::FillReady), Feedback::Ok);
            assert_eq!(
                sender.enqueue(ReplacingMessage::Replace(1)),
                Feedback::Backoff
            );

            let sender_1 = sender.clone();
            let replace_1 = thread::spawn(move || sender_1.enqueue(ReplacingMessage::Replace(2)));
            let replace_2 = thread::spawn(move || sender.enqueue(ReplacingMessage::Replace(3)));

            assert_eq!(replace_1.join().unwrap(), Feedback::Backoff);
            assert_eq!(replace_2.join().unwrap(), Feedback::Backoff);
            assert_eq!(receiver.try_recv(), Ok(ReplacingMessage::FillReady));

            let retained = replacement_value(receiver.try_recv().unwrap()).unwrap();
            assert!(retained == 2 || retained == 3);
            assert_eq!(receiver.try_recv(), Err(TryRecvError::Empty));
            drop(idle_sender);
        });
    }

    // Once the overflow has been refilled into `ready`, a new send takes the fast path and
    // is ordered behind the refilled message.
    #[test]
    fn stale_overflow_hint_retries_ready_before_policy() {
        loom::model(|| {
            let (sender, mut receiver) = new::<Message>(NZUsize!(2));
            assert_eq!(sender.enqueue(Message::Drop(0)), Feedback::Ok);
            assert_eq!(sender.enqueue(Message::Drop(1)), Feedback::Ok);
            assert_eq!(sender.enqueue(Message::Spill(2)), Feedback::Backoff);

            assert_eq!(receiver.try_recv(), Ok(Message::Drop(0)));
            assert_eq!(receiver.try_recv(), Ok(Message::Drop(1)));

            assert_eq!(sender.enqueue(Message::Drop(3)), Feedback::Ok);
            assert_eq!(receiver.try_recv(), Ok(Message::Spill(2)));
            assert_eq!(receiver.try_recv(), Ok(Message::Drop(3)));
        });
    }
}
1333
1334 #[test]
1335 fn concurrent_overflow_cannot_be_bypassed_by_ready_fast_path() {
1336 loom::model(|| {
1337 let (sender, mut receiver) = new::<OrderedMessage>(NZUsize!(2));
1338 assert_eq!(sender.enqueue(OrderedMessage::Item(0)), Feedback::Ok);
1339 assert_eq!(sender.enqueue(OrderedMessage::Item(1)), Feedback::Ok);
1340
1341 let gate = Arc::new(AtomicUsize::new(0));
1342 let overflow_sender = sender.clone();
1343 let overflow_gate = gate.clone();
1344 let overflow = thread::spawn(move || {
1345 assert_eq!(
1346 overflow_sender.enqueue(OrderedMessage::Coordinated(2, overflow_gate)),
1347 Feedback::Backoff
1348 );
1349 });
1350
1351 while gate.load(Ordering::Acquire) == 0 {
1352 thread::yield_now();
1353 }
1354
1355 let mut observed = vec![value(receiver.try_recv().unwrap())];
1358 gate.store(2, Ordering::Release);
1359 let feedback = sender.enqueue(OrderedMessage::Item(3));
1360 assert!(matches!(feedback, Feedback::Ok | Feedback::Backoff));
1361
1362 overflow.join().unwrap();
1363 while let Ok(message) = receiver.try_recv() {
1364 observed.push(value(message));
1365 }
1366
1367 assert_eq!(observed, vec![0, 1, 2, 3]);
1368 });
1369 }
1370
1371 #[test]
1372 fn concurrent_overflow_mutation_does_not_hide_published_overflow() {
1373 loom::model(|| {
1374 let (sender, mut receiver) = new::<OrderedMessage>(NZUsize!(1));
1375 assert_eq!(sender.enqueue(OrderedMessage::Item(0)), Feedback::Ok);
1376 assert_eq!(sender.enqueue(OrderedMessage::Item(1)), Feedback::Backoff);
1377
1378 let gate = Arc::new(AtomicUsize::new(0));
1379 let overflow_sender = sender.clone();
1380 let overflow_gate = gate.clone();
1381 let overflow = thread::spawn(move || {
1382 overflow_sender.enqueue(OrderedMessage::Coordinated(2, overflow_gate))
1383 });
1384
1385 while gate.load(Ordering::Acquire) == 0 {
1386 thread::yield_now();
1387 }
1388
1389 let release_gate = gate.clone();
1390 let release = thread::spawn(move || {
1391 release_gate.store(2, Ordering::Release);
1392 });
1393
1394 let receive = thread::spawn(move || {
1395 assert_eq!(receiver.try_recv().map(value), Ok(0));
1396 assert_eq!(receiver.try_recv().map(value), Ok(1));
1397 receiver
1398 });
1399
1400 release.join().unwrap();
1401 let mut receiver = receive.join().unwrap();
1402 assert_eq!(overflow.join().unwrap(), Feedback::Backoff);
1403 assert_eq!(receiver.try_recv().map(value), Ok(2));
1404 });
1405 }
1406
1407 #[test]
1408 fn published_overflow_wakes_pending_receiver() {
1409 loom::model(|| {
1410 let (sender, mut receiver) = new::<OrderedMessage>(NZUsize!(1));
1411 let wakes = Arc::new(AtomicUsize::new(0));
1412 let waker = counting_waker(wakes.clone());
1413 let mut cx = Context::from_waker(&waker);
1414
1415 let gate = Arc::new(AtomicUsize::new(0));
1416 let overflow = {
1417 let next = receiver.recv();
1418 pin_mut!(next);
1419 assert!(matches!(next.as_mut().poll(&mut cx), Poll::Pending));
1420
1421 assert_eq!(sender.enqueue(OrderedMessage::Item(0)), Feedback::Ok);
1422 while wakes.load(Ordering::Acquire) == 0 {
1423 thread::yield_now();
1424 }
1425
1426 let overflow_sender = sender.clone();
1427 let overflow_gate = gate.clone();
1428 let overflow = thread::spawn(move || {
1429 overflow_sender.enqueue(OrderedMessage::Coordinated(1, overflow_gate))
1430 });
1431
1432 while gate.load(Ordering::Acquire) == 0 {
1433 thread::yield_now();
1434 }
1435
1436 assert_eq!(
1437 next.as_mut()
1438 .poll(&mut cx)
1439 .map(|message| message.map(value)),
1440 Poll::Ready(Some(0))
1441 );
1442 overflow
1443 };
1444
1445 {
1446 let next = receiver.recv();
1447 pin_mut!(next);
1448 assert!(matches!(next.as_mut().poll(&mut cx), Poll::Pending));
1449 assert_eq!(wakes.load(Ordering::Acquire), 1);
1450
1451 gate.store(2, Ordering::Release);
1452 while wakes.load(Ordering::Acquire) < 2 {
1453 thread::yield_now();
1454 }
1455
1456 assert_eq!(
1457 next.as_mut()
1458 .poll(&mut cx)
1459 .map(|message| message.map(value)),
1460 Poll::Ready(Some(1))
1461 );
1462 }
1463 assert_eq!(overflow.join().unwrap(), Feedback::Backoff);
1464 });
1465 }
1466
1467 #[test]
1468 fn concurrent_refill_and_enqueue_preserves_overflow_order() {
1469 loom::model(|| {
1470 let (sender, mut receiver) = new::<OrderedMessage>(NZUsize!(1));
1471 assert_eq!(sender.enqueue(OrderedMessage::Item(0)), Feedback::Ok);
1472 assert_eq!(sender.enqueue(OrderedMessage::Item(1)), Feedback::Backoff);
1473
1474 let enqueue = thread::spawn(move || sender.enqueue(OrderedMessage::Item(2)));
1475 let receive = thread::spawn(move || {
1476 assert_eq!(receiver.try_recv().map(value), Ok(0));
1477 receiver
1478 });
1479
1480 let mut receiver = receive.join().unwrap();
1481 assert_eq!(enqueue.join().unwrap(), Feedback::Backoff);
1482 assert_eq!(receiver.try_recv().map(value), Ok(1));
1483 assert_eq!(receiver.try_recv().map(value), Ok(2));
1484 });
1485 }
1486}