1use std::{
72 fmt::{Debug, Formatter},
73 hash::Hash,
74 pin::pin,
75 sync::{
76 Arc,
77 atomic::{AtomicBool, AtomicUsize, Ordering},
78 },
79};
80
81use crossbeam::queue::SegQueue;
82use scc::hash_map::Entry;
83use tokio::sync::Notify;
84
/// A multi-producer, multi-subscriber message broker that routes messages to
/// per-tag queues ("buckets"). A message sent for a tag is delivered to the
/// subscribers of that tag, one receiver per message.
#[derive(Debug)]
pub struct MessageQueueBroker<T: Hash + Eq, M> {
    // Shared with every `Subscriber`, so the broker state outlives this handle.
    inner: Arc<MessageQueueBrokerInner<T, M>>,
}
90
91impl<T, M> MessageQueueBroker<T, M>
92where
93 T: Hash + Eq + Clone,
94{
95 pub fn unbounded() -> Self {
97 Self {
98 inner: Arc::new(MessageQueueBrokerInner::Unbounded(Unbounded {
99 buckets: Default::default(),
100 is_closed: AtomicBool::new(false),
101 len: AtomicUsize::new(0),
102 })),
103 }
104 }
105
106 pub fn bounded(cap: usize) -> Self {
108 Self {
109 inner: Arc::new(MessageQueueBrokerInner::Bounded(Bounded {
110 buckets: Default::default(),
111 send_notify: Default::default(),
112 is_closed: AtomicBool::new(false),
113 len: AtomicUsize::new(0),
114 cap,
115 })),
116 }
117 }
118
119 pub fn subscribe(&self, tag: T) -> Subscriber<T, M> {
127 MessageQueueBrokerInner::subscribe(&self.inner, tag)
128 }
129}
130
131impl<T, M> MessageQueueBroker<T, M>
132where
133 T: Hash + Eq,
134{
135 pub fn close(&self) {
137 self.inner.close();
138 }
139
140 pub fn is_closed(&self) -> bool {
142 self.inner.is_closed()
143 }
144
145 pub fn len(&self) -> usize {
147 self.inner.len()
148 }
149
150 pub fn is_empty(&self) -> bool {
152 self.inner.is_empty()
153 }
154
155 pub fn try_send<Q>(&self, tag: &Q, msg: M) -> Result<(), TrySendError<M>>
202 where
203 Q: Hash + scc::Equivalent<T> + ?Sized,
204 {
205 self.inner.try_send(tag, msg)
206 }
207
208 pub async fn send<Q>(&self, tag: &Q, msg: M) -> Result<(), SendError<M>>
267 where
268 Q: Hash + scc::Equivalent<T> + ?Sized,
269 {
270 self.inner.send(tag, msg).await
271 }
272}
273
impl<T, M> Drop for MessageQueueBroker<T, M>
where
    T: Hash + Eq,
{
    // Dropping the broker handle closes the channel so receivers blocked in
    // `recv` are woken instead of waiting forever on a dead producer side.
    fn drop(&mut self) {
        self.close();
    }
}
282
/// Shared broker state, split by capacity policy.
#[derive(Debug)]
enum MessageQueueBrokerInner<T: Hash + Eq, M> {
    /// Capacity-limited variant: senders may wait for a free slot.
    Bounded(Bounded<T, M>),
    /// Unlimited variant: sends always succeed while the broker is open.
    Unbounded(Unbounded<T, M>),
}
288
289impl<T, M> MessageQueueBrokerInner<T, M>
290where
291 T: Hash + Eq + Clone,
292{
293 fn subscribe(this: &Arc<Self>, tag: T) -> Subscriber<T, M> {
294 let buckets = match &**this {
295 MessageQueueBrokerInner::Bounded(b) => &b.buckets,
296 MessageQueueBrokerInner::Unbounded(b) => &b.buckets,
297 };
298
299 match buckets.entry(tag.clone()) {
300 Entry::Occupied(e) => {
301 let bucket = e.get().clone();
302 bucket.subs.fetch_add(1, Ordering::Release);
303 Subscriber {
304 tag,
305 bucket,
306 broker: Arc::clone(this),
307 }
308 }
309 Entry::Vacant(e) => {
310 let bucket = Arc::new(Bucket {
311 queue: Default::default(),
312 subs: AtomicUsize::new(1),
313 recv_notify: Default::default(),
314 });
315 e.insert_entry(bucket.clone());
316
317 Subscriber {
318 tag,
319 bucket,
320 broker: Arc::clone(this),
321 }
322 }
323 }
324 }
325}
326
327impl<T, M> MessageQueueBrokerInner<T, M>
328where
329 T: Hash + Eq,
330{
331 fn close(&self) {
332 match self {
333 MessageQueueBrokerInner::Bounded(b) => b.close(),
334 MessageQueueBrokerInner::Unbounded(b) => b.close(),
335 }
336 }
337
338 fn is_closed(&self) -> bool {
339 match self {
340 MessageQueueBrokerInner::Bounded(b) => b.is_closed(),
341 MessageQueueBrokerInner::Unbounded(b) => b.is_closed(),
342 }
343 }
344
345 fn len(&self) -> usize {
346 match self {
347 MessageQueueBrokerInner::Bounded(b) => b.len(),
348 MessageQueueBrokerInner::Unbounded(b) => b.len(),
349 }
350 }
351
352 fn is_empty(&self) -> bool {
353 match self {
354 MessageQueueBrokerInner::Bounded(b) => b.is_empty(),
355 MessageQueueBrokerInner::Unbounded(b) => b.is_empty(),
356 }
357 }
358
359 fn try_send<Q>(&self, tag: &Q, msg: M) -> Result<(), TrySendError<M>>
360 where
361 Q: Hash + scc::Equivalent<T> + ?Sized,
362 {
363 match self {
364 MessageQueueBrokerInner::Bounded(b) => b.try_send(tag, msg),
365 MessageQueueBrokerInner::Unbounded(b) => b.try_send(tag, msg),
366 }
367 }
368
369 async fn send<Q>(&self, tag: &Q, msg: M) -> Result<(), SendError<M>>
370 where
371 Q: Hash + scc::Equivalent<T> + ?Sized,
372 {
373 match self {
374 MessageQueueBrokerInner::Bounded(b) => b.send(tag, msg).await,
375 MessageQueueBrokerInner::Unbounded(b) => b.send(tag, msg).await,
376 }
377 }
378
379 fn unsubscribe<Q>(&self, tag: &Q)
380 where
381 Q: Hash + scc::Equivalent<T> + ?Sized,
382 {
383 match self {
384 MessageQueueBrokerInner::Bounded(b) => b.unsubscribe(tag),
385 MessageQueueBrokerInner::Unbounded(b) => b.unsubscribe(tag),
386 }
387 }
388}
389
/// State for a capacity-limited broker.
#[derive(Debug)]
struct Bounded<T: Hash + Eq, M> {
    // Per-tag message queues, keyed by subscription tag.
    buckets: scc::HashMap<T, Arc<Bucket<M>>>,
    // Wakes senders that are parked waiting for a free slot.
    send_notify: Notify,
    // Set once by `close`; checked by every operation.
    is_closed: AtomicBool,
    // Total queued messages across all buckets; bounded by `cap`.
    len: AtomicUsize,
    // Maximum number of queued messages across all buckets.
    cap: usize,
}
398
399impl<T, M> Bounded<T, M>
400where
401 T: Hash + Eq,
402{
403 fn close(&self) {
404 self.is_closed.store(true, Ordering::Release);
405 let mut next_entry = self.buckets.first_entry();
406 while let Some(e) = next_entry {
407 e.recv_notify.notify_waiters();
408 next_entry = e.next();
409 }
410 }
411
412 fn is_closed(&self) -> bool {
413 self.is_closed.load(Ordering::Acquire)
414 }
415
416 fn len(&self) -> usize {
417 self.len.load(Ordering::Acquire)
418 }
419
420 fn is_empty(&self) -> bool {
421 self.len() == 0
422 }
423
424 fn try_send<Q>(&self, tag: &Q, msg: M) -> Result<(), TrySendError<M>>
425 where
426 Q: Hash + scc::Equivalent<T> + ?Sized,
427 {
428 if self.is_closed() {
429 return Err(TrySendError::Closed(msg));
430 }
431
432 let Some(bucket) = self.buckets.get(tag) else {
433 return Err(TrySendError::Closed(msg));
434 };
435
436 match self.try_acquire_slot() {
437 Ok(_) => {
438 bucket.queue.push(msg);
439 bucket.recv_notify.notify_one();
440 Ok(())
441 }
442 Err(_) => Err(TrySendError::Full(msg)),
443 }
444 }
445
446 async fn send<Q>(&self, tag: &Q, msg: M) -> Result<(), SendError<M>>
447 where
448 Q: Hash + scc::Equivalent<T> + ?Sized,
449 {
450 let mut notified = pin!(self.send_notify.notified());
451
452 loop {
453 if self.is_closed() {
454 return Err(SendError(msg));
455 }
456
457 {
458 let Some(bucket) = self.buckets.get(tag) else {
459 return Err(SendError(msg));
460 };
461
462 notified.as_mut().enable();
463
464 if self.try_acquire_slot().is_ok() {
465 bucket.queue.push(msg);
466 bucket.recv_notify.notify_one();
467 return Ok(());
468 }
469 }
470
471 notified.as_mut().await;
472 notified.set(self.send_notify.notified());
473 }
474 }
475
476 fn try_acquire_slot(&self) -> Result<(), ()> {
477 self.len
478 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |x| {
479 if x < self.cap { Some(x + 1) } else { None }
480 })
481 .map(|_| ())
482 .map_err(|_| ())
483 }
484
485 fn unsubscribe<Q>(&self, tag: &Q)
486 where
487 Q: Hash + scc::Equivalent<T> + ?Sized,
488 {
489 let Some((_tag, bucket)) = self.buckets.remove(tag) else {
490 return;
491 };
492 self.len.fetch_sub(bucket.queue.len(), Ordering::Release);
493 }
494}
495
/// State for a broker with no capacity limit.
#[derive(Debug)]
struct Unbounded<T: Hash + Eq, M> {
    // Per-tag message queues, keyed by subscription tag.
    buckets: scc::HashMap<T, Arc<Bucket<M>>>,
    // Set once by `close`; checked by every operation.
    is_closed: AtomicBool,
    // Total queued messages across all buckets (bookkeeping only; never
    // limits sends).
    len: AtomicUsize,
}
502
503impl<T, M> Unbounded<T, M>
504where
505 T: Hash + Eq,
506{
507 fn close(&self) {
508 self.is_closed.store(true, Ordering::Release);
509 let mut next_entry = self.buckets.first_entry();
510 while let Some(e) = next_entry {
511 e.recv_notify.notify_waiters();
512 next_entry = e.next();
513 }
514 }
515
516 fn is_closed(&self) -> bool {
517 self.is_closed.load(Ordering::Acquire)
518 }
519
520 fn len(&self) -> usize {
521 self.len.load(Ordering::Acquire)
522 }
523
524 fn is_empty(&self) -> bool {
525 self.len() == 0
526 }
527
528 fn try_send<Q>(&self, tag: &Q, msg: M) -> Result<(), TrySendError<M>>
529 where
530 Q: Hash + scc::Equivalent<T> + ?Sized,
531 {
532 if self.is_closed() {
533 return Err(TrySendError::Closed(msg));
534 }
535
536 let Some(bucket) = self.buckets.get(tag) else {
537 return Err(TrySendError::Closed(msg));
538 };
539
540 self.len.fetch_add(1, Ordering::Release);
541 bucket.queue.push(msg);
542 bucket.recv_notify.notify_one();
543 Ok(())
544 }
545
546 async fn send<Q>(&self, tag: &Q, msg: M) -> Result<(), SendError<M>>
547 where
548 Q: Hash + scc::Equivalent<T> + ?Sized,
549 {
550 self.try_send(tag, msg).map_err(|err| match err {
551 TrySendError::Closed(msg) => SendError(msg),
552 TrySendError::Full(_) => unreachable!(),
553 })
554 }
555
556 fn unsubscribe<Q>(&self, tag: &Q)
557 where
558 Q: Hash + scc::Equivalent<T> + ?Sized,
559 {
560 let Some((_tag, bucket)) = self.buckets.remove(tag) else {
561 return;
562 };
563 self.len.fetch_sub(bucket.queue.len(), Ordering::Release);
564 }
565}
566
/// A single tag's queue plus the bookkeeping shared by its subscribers.
#[derive(Debug)]
struct Bucket<M> {
    // Lock-free FIFO holding the tag's undelivered messages.
    queue: SegQueue<M>,
    // Number of live `Subscriber` handles attached to this bucket.
    subs: AtomicUsize,
    // Wakes receivers parked in `Subscriber::recv`.
    recv_notify: Notify,
}
573
/// A receiving handle for one tag of a [`MessageQueueBroker`].
///
/// Clones share the same bucket; the bucket is removed from the broker when
/// the last subscriber for its tag is dropped.
#[derive(Debug)]
pub struct Subscriber<T: Hash + Eq, M> {
    // Tag this subscriber listens on; used to unsubscribe on drop.
    tag: T,
    // The tag's shared queue and notifier.
    bucket: Arc<Bucket<M>>,
    // Keeps the broker state alive and reachable for close checks.
    broker: Arc<MessageQueueBrokerInner<T, M>>,
}
583
impl<T, M> Subscriber<T, M>
where
    T: Hash + Eq,
{
    /// Number of live subscribers attached to this tag's bucket.
    pub fn subs_count(&self) -> usize {
        self.bucket.subs.load(Ordering::Acquire)
    }

    /// Number of messages currently queued for this tag.
    pub fn len(&self) -> usize {
        self.bucket.queue.len()
    }

    /// Returns `true` when no messages are queued for this tag.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns `true` once the owning broker has been closed.
    pub fn is_closed(&self) -> bool {
        match &*self.broker {
            MessageQueueBrokerInner::Bounded(b) => b.is_closed(),
            MessageQueueBrokerInner::Unbounded(b) => b.is_closed(),
        }
    }

    /// Attempts to take one message without waiting.
    pub fn try_recv(&self) -> Result<M, TryRecvError> {
        Self::try_recv2(&self.broker, &self.bucket.queue)
    }

    /// Waits until a message is available or the broker is closed.
    pub async fn recv(&self) -> Result<M, RecvError> {
        let mut notified = pin!(self.bucket.recv_notify.notified());

        loop {
            // Register for wakeups before polling the queue, so a message
            // pushed between the poll and the await is not missed.
            notified.as_mut().enable();

            match Self::try_recv2(&self.broker, &self.bucket.queue) {
                Ok(msg) => return Ok(msg),
                Err(TryRecvError::Closed) => return Err(RecvError),
                Err(TryRecvError::Empty) => {
                    notified.as_mut().await;
                    // Re-arm a fresh notification future for the next round.
                    notified.set(self.bucket.recv_notify.notified());
                }
            }
        }
    }

    // Shared poll used by `try_recv` and `recv`; static so `recv` can call
    // it while `notified` borrows `self.bucket`.
    //
    // NOTE(review): the closed check runs before the pop, so messages still
    // queued when the broker closes are unreachable — confirm that dropping
    // in-flight messages on close is the intended semantics.
    fn try_recv2(
        broker: &MessageQueueBrokerInner<T, M>,
        bucket_queue: &SegQueue<M>,
    ) -> Result<M, TryRecvError> {
        match broker {
            MessageQueueBrokerInner::Bounded(b) => {
                if b.is_closed() {
                    return Err(TryRecvError::Closed);
                }

                let msg = bucket_queue.pop().ok_or(TryRecvError::Empty)?;
                b.len.fetch_sub(1, Ordering::Release);
                // A slot was freed: let one parked sender proceed.
                b.send_notify.notify_one();
                Ok(msg)
            }
            MessageQueueBrokerInner::Unbounded(b) => {
                if b.is_closed() {
                    return Err(TryRecvError::Closed);
                }

                let msg = bucket_queue.pop().ok_or(TryRecvError::Empty)?;
                b.len.fetch_sub(1, Ordering::Release);
                Ok(msg)
            }
        }
    }
}
770
771impl<T, M> Clone for Subscriber<T, M>
772where
773 T: Hash + Eq + Clone,
774{
775 fn clone(&self) -> Self {
776 self.bucket.subs.fetch_add(1, Ordering::Relaxed);
777 Self {
778 tag: self.tag.clone(),
779 bucket: self.bucket.clone(),
780 broker: self.broker.clone(),
781 }
782 }
783}
784
impl<T, M> Drop for Subscriber<T, M>
where
    T: Hash + Eq,
{
    fn drop(&mut self) {
        // The last subscriber for this tag removes the bucket so its queued
        // messages stop counting against the broker's capacity.
        // NOTE(review): when the broker is already closed, the short-circuit
        // skips the `fetch_sub` entirely, so `subs_count` can remain stale
        // after close — confirm this is intentional.
        if !self.is_closed()
            && self.bucket.subs.fetch_sub(1, Ordering::Relaxed) == 1
        {
            self.broker.unsubscribe(&self.tag);
        }
    }
}
797
/// Error returned by `send` when the broker is closed (or the tag has no
/// subscribers); carries the undelivered message.
#[derive(thiserror::Error, Eq, PartialEq)]
#[error("sending into a closed channel")]
pub struct SendError<T>(pub T);
801
impl<T> SendError<T> {
    /// Recovers the message that could not be sent.
    pub fn into_inner(self) -> T {
        self.0
    }
}
808
// Manual `Debug` because `T` need not implement it; the payload is elided.
impl<T> Debug for SendError<T> {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.debug_tuple("SendError").finish_non_exhaustive()
    }
}
814
/// Error returned by `try_send`; carries the undelivered message.
#[derive(thiserror::Error, Eq, PartialEq)]
pub enum TrySendError<T> {
    /// The broker is at capacity (bounded brokers only).
    #[error("sending into a full channel")]
    Full(T),
    /// The broker is closed, or the tag has no subscribers.
    #[error("sending into a closed channel")]
    Closed(T),
}
822
823impl<T> TrySendError<T> {
824 pub fn into_inner(self) -> T {
826 match self {
827 TrySendError::Full(t) => t,
828 TrySendError::Closed(t) => t,
829 }
830 }
831
832 pub fn is_full(&self) -> bool {
834 match self {
835 TrySendError::Full(_) => true,
836 TrySendError::Closed(_) => false,
837 }
838 }
839
840 pub fn is_closed(&self) -> bool {
842 match self {
843 TrySendError::Full(_) => false,
844 TrySendError::Closed(_) => true,
845 }
846 }
847}
848
849impl<T> Debug for TrySendError<T> {
850 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
851 match self {
852 TrySendError::Full(_) => {
853 f.debug_tuple("Full").finish_non_exhaustive()
854 }
855 TrySendError::Closed(_) => {
856 f.debug_tuple("Closed").finish_non_exhaustive()
857 }
858 }
859 }
860}
861
/// Error returned by `Subscriber::recv` when the broker has been closed.
#[derive(Debug, thiserror::Error, Eq, PartialEq)]
#[error("receiving from an empty and closed channel")]
pub struct RecvError;
865
866#[derive(Debug, thiserror::Error, Eq, PartialEq)]
867pub enum TryRecvError {
868 #[error("receiving from an empty channel")]
869 Empty,
870 #[error("receiving from an closed channel")]
871 Closed,
872}
873
874impl TryRecvError {
875 pub fn is_empty(&self) -> bool {
877 match self {
878 TryRecvError::Empty => true,
879 TryRecvError::Closed => false,
880 }
881 }
882
883 pub fn is_closed(&self) -> bool {
885 match self {
886 TryRecvError::Empty => false,
887 TryRecvError::Closed => true,
888 }
889 }
890}
891
#[cfg(test)]
mod tests {

    use rand::prelude::SliceRandom;
    use tokio::sync::Semaphore;

    use super::*;

    // Stress test: WRITER_THREADS senders push MESSAGES_PER_TAG messages for
    // each of TAGS tags (tags partitioned round-robin across writers, order
    // shuffled), while READERS_PER_TAG subscribers per tag drain them. All
    // tasks start simultaneously via a semaphore; the test passes when every
    // task completes without error.
    async fn parallel_check<
        const WRITER_THREADS: usize,
        const TAGS: usize,
        const MESSAGES_PER_TAG: usize,
        const READERS_PER_TAG: usize,
    >(
        mqb: MessageQueueBroker<usize, usize>,
    ) {
        let all_threads = WRITER_THREADS + TAGS * READERS_PER_TAG;

        // NOTE(review): `gen` is a reserved keyword in the 2024 edition;
        // rename (e.g. to `rng`) if the crate migrates editions.
        let mut gen = rand::rng();

        let mqb = Arc::new(mqb);
        // Held at zero permits until all tasks are spawned, so writers and
        // readers begin together.
        let start_notify = Arc::new(Semaphore::new(0));

        let mut tasks = Vec::with_capacity(all_threads);
        for thread_idx in 0..WRITER_THREADS {
            let mqb = mqb.clone();
            let start_notify = start_notify.clone();
            // Build this writer's (tag, msg) workload up front, shuffled so
            // sends interleave across tags.
            let messages = {
                let mut msgs = Vec::with_capacity(
                    MESSAGES_PER_TAG
                        * usize::max(1, TAGS.div_ceil(WRITER_THREADS)),
                );
                for tag in
                    (0..TAGS).filter(|tag| tag % WRITER_THREADS == thread_idx)
                {
                    for msg in 0..MESSAGES_PER_TAG {
                        msgs.push((tag, msg));
                    }
                }
                msgs.shuffle(&mut gen);
                msgs
            };
            let fut = async move {
                let _permit = start_notify.acquire().await.unwrap();
                for (tag, msg) in messages {
                    mqb.send(&tag, msg).await.unwrap();
                }
            };

            tasks.push(tokio::spawn(fut));
        }

        // Split MESSAGES_PER_TAG across the tag's readers, distributing the
        // remainder one-by-one so the counts sum exactly.
        let messages_per_readers = {
            let single = MESSAGES_PER_TAG / READERS_PER_TAG;
            let mut remainder = MESSAGES_PER_TAG % READERS_PER_TAG;
            std::iter::from_fn(|| Some(single))
                .take(READERS_PER_TAG)
                .map(|v| {
                    if remainder > 0 {
                        remainder -= 1;
                        v + 1
                    } else {
                        v
                    }
                })
                .collect::<Vec<_>>()
        };

        for tag in 0..TAGS {
            for thread_idx in 0..READERS_PER_TAG {
                // Subscribe before senders start so no send hits a missing
                // bucket.
                let sub = mqb.subscribe(tag);
                let start_notify = start_notify.clone();
                let messages_per_reader = messages_per_readers[thread_idx];
                let fut = async move {
                    let _permit = start_notify.acquire().await.unwrap();

                    for _ in 0..messages_per_reader {
                        sub.recv().await.unwrap();
                    }
                };

                tasks.push(tokio::spawn(fut));
            }
        }

        // Release every task at once and require all of them to finish.
        start_notify.add_permits(all_threads);
        assert!(
            futures::future::join_all(tasks)
                .await
                .iter()
                .all(Result::is_ok)
        );
    }

    #[tokio::test]
    async fn unbounded_parallel() {
        parallel_check::<1, 1000, 1, 1>(MessageQueueBroker::unbounded()).await;
        parallel_check::<20, 1000, 100, 1>(MessageQueueBroker::unbounded())
            .await;
        parallel_check::<20, 1000, 100, 2>(MessageQueueBroker::unbounded())
            .await;
    }

    #[tokio::test]
    async fn bounded_parallel() {
        // Capacity (10) is far below the message volume, so senders must
        // repeatedly block on capacity and be woken by receivers.
        parallel_check::<1, 1000, 1, 1>(MessageQueueBroker::bounded(10)).await;
        parallel_check::<20, 1000, 100, 1>(MessageQueueBroker::bounded(10))
            .await;
        parallel_check::<20, 1000, 100, 2>(MessageQueueBroker::bounded(10))
            .await;
    }

    // Basic routing and length bookkeeping for the unbounded variant.
    #[tokio::test]
    async fn unbounded() {
        let mbq = MessageQueueBroker::unbounded();

        let sub1 = mbq.subscribe(1);
        let sub2 = mbq.subscribe(2);

        mbq.send(&1, 1).await.unwrap();
        mbq.send(&2, 2).await.unwrap();
        assert_eq!(mbq.len(), 2);
        // Tag 3 has no subscriber: reported as `Closed`, nothing queued.
        assert_eq!(mbq.try_send(&3, 42).unwrap_err(), TrySendError::Closed(42));
        assert_eq!(mbq.len(), 2);

        assert_eq!(sub1.len(), 1);
        assert_eq!(sub1.recv().await, Ok(1));
        assert_eq!(sub1.len(), 0);
        assert_eq!(mbq.len(), 1);

        assert_eq!(sub2.len(), 1);
        assert_eq!(sub2.recv().await, Ok(2));
        assert_eq!(sub2.len(), 0);
        assert_eq!(mbq.len(), 0);

        assert!(mbq.is_empty());
    }

    // Basic routing plus the `Full` path for the bounded variant.
    #[tokio::test]
    async fn bounded() {
        let mqb = MessageQueueBroker::bounded(2);

        let sub1 = mqb.subscribe(1);
        let sub2 = mqb.subscribe(2);

        mqb.send(&1, 1).await.unwrap();
        mqb.send(&2, 2).await.unwrap();
        assert_eq!(mqb.len(), 2);
        assert_eq!(mqb.try_send(&3, 42).unwrap_err(), TrySendError::Closed(42));
        // Capacity (2) exhausted: an existing tag now reports `Full`.
        assert_eq!(mqb.try_send(&2, 3).unwrap_err(), TrySendError::Full(3));
        assert_eq!(mqb.len(), 2);

        assert_eq!(sub1.len(), 1);
        assert_eq!(sub1.recv().await, Ok(1));
        assert_eq!(sub1.len(), 0);
        assert_eq!(mqb.len(), 1);

        assert_eq!(sub2.len(), 1);
        assert_eq!(sub2.recv().await, Ok(2));
        assert_eq!(sub2.len(), 0);
        assert_eq!(mqb.len(), 0);

        assert!(mqb.is_empty());
    }

    // Subscriber reference counting: clones and duplicate subscriptions
    // share one count; dropping the last handle unsubscribes the tag.
    #[tokio::test]
    async fn sub_unsub() {
        let mqb = MessageQueueBroker::unbounded();

        let sub1 = mqb.subscribe(1);
        let sub1_copy1 = mqb.subscribe(1);
        let sub1_copy2 = sub1.clone();

        assert_eq!(sub1.subs_count(), 3);

        drop(sub1_copy1);
        assert_eq!(sub1.subs_count(), 2);

        drop(sub1_copy2);
        assert_eq!(sub1.subs_count(), 1);

        drop(sub1);
        assert_eq!(mqb.try_send(&1, 1).unwrap_err(), TrySendError::Closed(1));
        assert_eq!(mqb.send(&1, 1).await.unwrap_err(), SendError(1));
    }

    // Dropping the broker handle closes the channel for live subscribers.
    #[tokio::test]
    async fn close() {
        let mqb = MessageQueueBroker::<i32, i32>::unbounded();
        let sub1 = mqb.subscribe(1);

        assert!(!sub1.is_closed());
        drop(mqb);
        assert!(sub1.is_closed());
    }
}