1use std::{
13 collections::HashMap,
14 pin::Pin,
15 sync::{
16 atomic::{AtomicU64, Ordering},
17 Arc,
18 },
19 task::{Context, Poll},
20};
21
22use anyhow::Result;
23use async_trait::async_trait;
24use bytes::Bytes;
25use crossbeam::queue::SegQueue;
26use dashmap::{mapref::entry::Entry, DashMap};
27use futures::{ready, Stream};
28use tokio::sync::{OwnedSemaphorePermit, Semaphore};
29use tokio_util::sync::PollSemaphore;
30
31use super::{Connection, DeliveryMode, Publisher, QueueHandle, QueueOptions, SyndicationMode};
32use crate::{
33 acker::NoopAcker,
34 serializer::{Serializable, Serializer},
35};
36
/// Connection to a queue backend that lives entirely in process memory.
///
/// Clones share the same queue registry through the inner `Arc`.
#[derive(Clone)]
pub struct InMemoryConnection {
    /// Queue handles declared on this connection, keyed by queue name.
    queues: Arc<DashMap<String, InMemoryQueueHandle>>,
    /// Serializer handed to every queue handle this connection creates.
    serializer: Serializer,
}
72
73impl InMemoryConnection {
74 pub fn new(serializer: Serializer) -> Self {
75 Self {
76 queues: Default::default(),
77 serializer,
78 }
79 }
80}
81
82#[async_trait]
83impl Connection for InMemoryConnection {
84 type QueueHandle = InMemoryQueueHandle;
85
86 async fn close(&self) -> Result<()> {
87 Ok(())
88 }
89
90 async fn declare_queue(&self, name: &str, options: QueueOptions) -> Result<Self::QueueHandle> {
91 match self.queues.entry(name.to_string()) {
92 Entry::Occupied(entry) => Ok(entry.get().clone()),
93 Entry::Vacant(entry) => {
94 let queue = InMemoryQueueHandle::new(self.serializer, options);
95 entry.insert(queue.clone());
96 Ok(queue)
97 }
98 }
99 }
100
101 async fn delete_queue(&self, name: &str) -> Result<()> {
102 self.queues.remove(name);
103
104 Ok(())
105 }
106}
107
/// Queue whose consumers all pull from one shared message buffer.
///
/// Every consumer declared on this queue receives clones of the same
/// `messages` buffer and `num_messages` semaphore.
#[derive(Clone)]
struct ExactlyOnceQueue {
    /// Serialized payloads waiting to be consumed.
    messages: Arc<SegQueue<Bytes>>,
    /// Semaphore that receives one permit per message pushed by `publish`.
    num_messages: PollSemaphore,
    /// Options the queue was created with; stored but not read in this file.
    _options: QueueOptions,
    /// Serializer used to encode published payloads and handed to consumers.
    serializer: Serializer,
}
124
125impl Default for ExactlyOnceQueue {
126 fn default() -> Self {
127 Self {
128 messages: Default::default(),
129 num_messages: PollSemaphore::new(Arc::new(Semaphore::new(0))),
130 _options: Default::default(),
131 serializer: Default::default(),
132 }
133 }
134}
135
136impl ExactlyOnceQueue {
137 fn new(options: QueueOptions, serializer: Serializer) -> Self {
138 Self {
139 messages: Default::default(),
140 num_messages: PollSemaphore::new(Arc::new(Semaphore::new(0))),
141 _options: options,
142 serializer,
143 }
144 }
145
146 fn publish<PayloadTarget: Serializable>(&self, payload: &PayloadTarget) -> Result<()> {
147 let bytes = self.serializer.to_bytes(payload)?;
148 self.messages.push(bytes.clone());
149 self.num_messages.add_permits(1);
150 Ok(())
151 }
152
153 fn declare_consumer<PayloadTarget: Serializable>(
154 &self,
155 _consumer_name: &str,
156 ) -> Result<InMemoryConsumer<PayloadTarget>> {
157 Ok(InMemoryConsumer {
158 messages: self.messages.clone(),
159 num_messages: self.num_messages.clone(),
160 serializer: self.serializer,
161 permit: None,
162 _marker: std::marker::PhantomData,
163 })
164 }
165}
166
/// Per-consumer state kept by a `BroadcastQueue`.
#[derive(Clone)]
struct BroadcastConsumer {
    /// Serialized payloads queued for this consumer.
    messages: Arc<SegQueue<Bytes>>,
    /// Semaphore that receives one permit per message queued for this consumer.
    num_messages: PollSemaphore,
    /// Ids of messages already queued for this consumer; the unit value makes
    /// the map act as a set.
    seen: Arc<DashMap<u64, ()>>,
}
184
185impl Default for BroadcastConsumer {
186 fn default() -> Self {
187 Self {
188 messages: Default::default(),
189 num_messages: PollSemaphore::new(Arc::new(Semaphore::new(0))),
190 seen: Default::default(),
191 }
192 }
193}
194
/// Queue that copies each published message to every registered consumer.
#[derive(Clone, Default)]
struct BroadcastQueue {
    /// Registered consumers, keyed by consumer name.
    consumers: Arc<DashMap<String, BroadcastConsumer>>,
    /// Messages published in `Persistent` mode while no consumer was
    /// registered, stored with their id. Drained by the next consumer that is
    /// newly declared.
    history: Arc<SegQueue<(u64, Bytes)>>,
    /// Source of message ids; incremented once per `publish` call.
    message_counter: Arc<AtomicU64>,
    /// Options the queue was created with; `delivery_mode` is read from here.
    options: QueueOptions,
    /// Serializer used to encode published payloads and handed to consumers.
    serializer: Serializer,
}
231
232impl BroadcastQueue {
233 fn new(options: QueueOptions, serializer: Serializer) -> Self {
234 Self {
235 options,
236 serializer,
237 ..Default::default()
238 }
239 }
240
241 fn publish<PayloadTarget: Serializable>(&self, payload: &PayloadTarget) -> Result<()> {
242 let bytes = self.serializer.to_bytes(payload)?;
243 let message_id = self.message_counter.fetch_add(1, Ordering::Relaxed);
246
247 if DeliveryMode::Persistent == self.options.delivery_mode && self.consumers.is_empty() {
251 self.history.push((message_id, bytes.clone()));
252 }
253
254 for consumer in self.consumers.iter() {
255 if DeliveryMode::Persistent == self.options.delivery_mode {
260 match consumer.seen.entry(message_id) {
261 Entry::Occupied(_) => continue,
262 Entry::Vacant(entry) => {
263 entry.insert(());
264 }
265 }
266 }
267
268 consumer.messages.push(bytes.clone());
269 consumer.num_messages.add_permits(1);
270 }
271
272 Ok(())
273 }
274
275 fn declare_consumer<PayloadTarget: Serializable>(
276 &self,
277 consumer_name: &str,
278 ) -> Result<InMemoryConsumer<PayloadTarget>> {
279 match self.consumers.entry(consumer_name.to_string()) {
280 Entry::Occupied(entry) => {
281 let consumer = entry.get().clone();
282 Ok(InMemoryConsumer {
283 messages: consumer.messages.clone(),
284 num_messages: consumer.num_messages.clone(),
285 serializer: self.serializer,
286 permit: None,
287 _marker: std::marker::PhantomData,
288 })
289 }
290 Entry::Vacant(entry) => {
291 let (messages, seen) = if DeliveryMode::Persistent == self.options.delivery_mode
295 && !self.history.is_empty()
296 {
297 let messages = SegQueue::new();
298 let mut seen = HashMap::new();
299
300 while let Some((message_id, message)) = self.history.pop() {
303 match seen.entry(message_id) {
305 std::collections::hash_map::Entry::Occupied(_) => continue,
306 std::collections::hash_map::Entry::Vacant(entry) => {
307 entry.insert(());
308 }
309 }
310
311 messages.push(message);
312 }
313 (messages, seen)
314 } else {
315 (Default::default(), Default::default())
316 };
317
318 let consumer = BroadcastConsumer {
319 num_messages: PollSemaphore::new(Arc::new(Semaphore::new(seen.len()))),
320 messages: Arc::new(messages),
321 seen: Arc::new(seen.into_iter().collect()),
322 };
323
324 entry.insert(consumer.clone());
325 Ok(InMemoryConsumer {
326 messages: consumer.messages.clone(),
327 num_messages: consumer.num_messages.clone(),
328 serializer: self.serializer,
329 permit: None,
330 _marker: std::marker::PhantomData,
331 })
332 }
333 }
334 }
335}
336
/// Handle to one in-memory queue.
///
/// Both queue flavours are always present; `options.syndication_mode` decides
/// which one publishing and consuming are routed to.
#[derive(Clone, Default)]
pub struct InMemoryQueueHandle {
    /// Backing queue used when the syndication mode is `Broadcast`.
    broadcast_queue: BroadcastQueue,
    /// Backing queue used when the syndication mode is `ExactlyOnce`.
    exactly_once_queue: ExactlyOnceQueue,
    /// Options the queue was declared with.
    options: QueueOptions,
}
458
459impl InMemoryQueueHandle {
460 pub fn new(serializer: Serializer, options: QueueOptions) -> Self {
461 Self {
462 options,
463 broadcast_queue: BroadcastQueue::new(options, serializer),
464 exactly_once_queue: ExactlyOnceQueue::new(options, serializer),
465 }
466 }
467}
468
/// Publisher that writes payloads of type `T` into an in-memory queue.
pub struct InMemoryPublisher<T> {
    /// Handle of the queue this publisher writes to.
    queue_handle: InMemoryQueueHandle,
    /// Ties the publisher to its payload type without storing a `T`.
    _marker: std::marker::PhantomData<T>,
}
473
474impl<T> InMemoryPublisher<T> {
475 pub fn new(queue_handle: InMemoryQueueHandle) -> Self {
476 Self {
477 queue_handle,
478 _marker: std::marker::PhantomData,
479 }
480 }
481}
482
483#[async_trait]
484impl<T: Serializable> Publisher<T> for InMemoryPublisher<T> {
485 async fn publish(&self, payload: &T) -> Result<()> {
486 match self.queue_handle.options.syndication_mode {
487 SyndicationMode::ExactlyOnce => self.queue_handle.exactly_once_queue.publish(payload),
488 SyndicationMode::Broadcast => self.queue_handle.broadcast_queue.publish(payload),
489 }
490 }
491
492 async fn close(&self) -> Result<()> {
493 Ok(())
494 }
495}
496
497#[async_trait]
498impl QueueHandle for InMemoryQueueHandle {
499 type Acker = NoopAcker;
500 type Consumer<PayloadTarget: Serializable> = InMemoryConsumer<PayloadTarget>;
501 type Publisher<PayloadTarget: Serializable> = InMemoryPublisher<PayloadTarget>;
502
503 fn publisher<PayloadTarget: Serializable>(&self) -> Self::Publisher<PayloadTarget> {
504 InMemoryPublisher::new(self.clone())
505 }
506
507 async fn declare_consumer<PayloadTarget: Serializable>(
508 &self,
509 consumer_name: &str,
510 ) -> Result<Self::Consumer<PayloadTarget>> {
511 match self.options.syndication_mode {
512 SyndicationMode::ExactlyOnce => self.exactly_once_queue.declare_consumer(consumer_name),
513 SyndicationMode::Broadcast => self.broadcast_queue.declare_consumer(consumer_name),
514 }
515 }
516}
517
/// Stream of `(T, NoopAcker)` items read from an in-memory queue.
pub struct InMemoryConsumer<T> {
    /// Ties the consumer to its payload type without storing a `T`.
    _marker: std::marker::PhantomData<T>,
    /// Buffer of serialized payloads this consumer pops from.
    messages: Arc<SegQueue<Bytes>>,
    /// Slot for a permit obtained from `num_messages`; constructed as `None`
    /// and emptied by `poll_next` before a message is popped.
    permit: Option<OwnedSemaphorePermit>,
    /// Serializer used to decode popped payloads into `T`.
    serializer: Serializer,
    /// Semaphore polled for a permit before a message is popped.
    num_messages: PollSemaphore,
}
533
534impl<T: Serializable> Stream for InMemoryConsumer<T> {
535 type Item = (T, NoopAcker);
536
537 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
538 let mut this = self.as_mut();
539
540 match this.permit.take() {
542 Some(permit) => {
543 let item = this.messages.pop();
545 permit.forget();
547 this.permit = None;
549
550 match item {
551 Some(item) => {
552 let item = this
553 .serializer
554 .from_bytes(&item)
555 .expect("failed to deserialize");
556
557 Poll::Ready(Some((item, NoopAcker::new())))
558 }
559 None => {
560 unreachable!("permit was acquired, but no message was available")
563 }
564 }
565 }
566
567 None => {
569 let permit = ready!(this.num_messages.poll_acquire(cx));
570 match permit {
571 Some(permit) => {
573 this.permit = Some(permit);
575 self.poll_next(cx)
576 }
577 None => Poll::Pending,
578 }
579 }
580 }
581 }
582}
583
#[cfg(test)]
mod helpers {
    use std::time::Duration;

    use futures::{Future, StreamExt};
    use serde::{Deserialize, Serialize};
    use tokio::{
        task::{JoinError, JoinHandle},
        try_join,
    };

    use super::*;
    use crate::acker::Acker;

    /// Minimal serializable payload used by the queue tests.
    #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
    pub(super) struct Payload {
        pub(super) field: i64,
    }

    /// Shorthand for building a `Payload`.
    pub(super) fn new_payload(field: i64) -> Payload {
        Payload { field }
    }

    /// Waits up to 10 ms for a spawned task to finish.
    ///
    /// Returns `Some(output)` if the task completed in time and `None` if the
    /// timer fired first. Panics if the task itself failed to join.
    pub(super) async fn with_timeout<O, F: Future<Output = Result<O, JoinError>>>(
        fut: F,
    ) -> Option<O> {
        let deadline = tokio::time::sleep(Duration::from_millis(10));

        tokio::select! {
            joined = fut => Some(joined.unwrap()),
            _ = deadline => None,
        }
    }

    /// Spawns a task that takes one message from `consumer`, acks it and
    /// returns the payload.
    pub(super) fn consume_next(mut consumer: InMemoryConsumer<Payload>) -> JoinHandle<Payload> {
        tokio::spawn(async move {
            let (message, ack_handle) = consumer.next().await.unwrap();
            ack_handle.ack().await.unwrap();
            message
        })
    }

    /// Spawns a task that takes `n` messages from `consumer`, acking each one,
    /// and returns the payloads in arrival order.
    pub(super) fn consume_n(
        consumer: InMemoryConsumer<Payload>,
        n: usize,
    ) -> JoinHandle<Vec<Payload>> {
        tokio::spawn(async move {
            let acked = consumer.then(|(message, ack_handle)| async move {
                ack_handle.ack().await.unwrap();
                message
            });

            acked.take(n).collect::<Vec<_>>().await
        })
    }

    /// Spawns a task that takes `n` messages in total from the merged streams
    /// of `c1` and `c2`, acking each one.
    pub(super) fn consume_n_select(
        c1: InMemoryConsumer<Payload>,
        c2: InMemoryConsumer<Payload>,
        n: usize,
    ) -> JoinHandle<Vec<Payload>> {
        tokio::spawn(async move {
            let acked = futures::stream::select(c1, c2).then(|(message, ack_handle)| async move {
                ack_handle.ack().await.unwrap();
                message
            });

            acked.take(n).collect::<Vec<_>>().await
        })
    }

    /// Declares two consumers, named "1" and "2", on `queue`.
    pub(super) async fn consumers<P: Serializable, H: QueueHandle>(
        queue: &H,
    ) -> (H::Consumer<P>, H::Consumer<P>) {
        try_join!(queue.declare_consumer("1"), queue.declare_consumer("2")).unwrap()
    }

    /// Spawns a task that publishes a clone of `payload` to `queue`.
    ///
    /// The returned handle resolves to the publish result; callers that drop
    /// it do not wait for the publish to happen.
    pub(super) fn publish<H: QueueHandle + Send + Sync + 'static>(
        queue: &H,
        payload: &Payload,
    ) -> JoinHandle<Result<()>>
    where
        <H as QueueHandle>::Publisher<Payload>: Send,
    {
        let queue = queue.clone();
        let publisher = queue.publisher();
        let payload = payload.clone();

        tokio::spawn(async move { publisher.publish(&payload).await })
    }

    /// Spawns one publishing task per element of `payload`.
    pub(super) fn publish_multi<H: QueueHandle + Send + Sync + 'static>(
        queue: &H,
        payload: &[Payload],
    ) -> Vec<JoinHandle<Result<()>>>
    where
        <H as QueueHandle>::Publisher<Payload>: Send,
    {
        payload.iter().map(|item| publish(queue, item)).collect()
    }
}
692
// Tests for queues declared with `SyndicationMode::ExactlyOnce`.
//
// `publish` / `publish_multi` spawn tasks and the returned handles are dropped
// here, so publishing runs concurrently with whatever follows the call.
#[cfg(test)]
mod exactly_once {

    use tokio::{join, try_join};

    use super::helpers::*;
    use super::*;
    use crate::queue::*;

    /// Declares a fresh persistent, non-durable, exactly-once queue on a new
    /// connection.
    async fn queue_handle() -> InMemoryQueueHandle {
        let connection = InMemoryConnection::new(Serializer::default());
        connection
            .declare_queue(
                "my_queue",
                QueueOptions {
                    delivery_mode: DeliveryMode::Persistent,
                    syndication_mode: SyndicationMode::ExactlyOnce,
                    durability: QueueDurability::NonDurable,
                },
            )
            .await
            .unwrap()
    }

    /// One message, publish started before the consumers are declared: one
    /// consumer receives it, the other times out.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn single_message_delivers_once_publish_first() {
        let queue = queue_handle().await;
        let clone = queue.clone();
        publish(&clone, &new_payload(1));

        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));
        let (r1, r2) = join!(with_timeout(r1), with_timeout(r2));

        // At least one consumer timed out...
        assert!([r1.clone(), r2.clone()].iter().any(|r| r.is_none()));
        // ...and at least one received the payload.
        assert!([r1.clone(), r2.clone()]
            .into_iter()
            .any(|r| r == Some(new_payload(1))));
    }

    /// One message, publish started after both consumers are already waiting:
    /// one consumer receives it, the other times out.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn single_message_delivers_once_publish_last() {
        let queue = queue_handle().await;

        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));

        publish(&queue, &new_payload(1));

        let (r1, r2) = join!(with_timeout(r1), with_timeout(r2));

        // At least one consumer timed out...
        assert!([r1.clone(), r2.clone()].iter().any(|p| p.is_none()));
        // ...and at least one received the payload.
        assert!([r1.clone(), r2.clone()]
            .into_iter()
            .any(|p| p == Some(new_payload(1))));
    }

    /// Two messages, publishes started first: each consumer takes one message
    /// and the two received payloads differ.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn double_message_delivers_once_publish_first() {
        let queue = queue_handle().await;
        publish(&queue, &new_payload(1));
        publish(&queue, &new_payload(2));
        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));
        let (r1, r2) = try_join!(r1, r2).unwrap();

        assert_ne!(r1, r2)
    }

    /// Two messages, publishes started after the consumers are waiting: the
    /// two received payloads differ.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn double_message_delivers_once_publish_last() {
        let queue = queue_handle().await;

        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));
        publish(&queue, &new_payload(1));
        publish(&queue, &new_payload(2));
        let (r1, r2) = try_join!(r1, r2).unwrap();

        assert_ne!(r1, r2)
    }

    /// 100 messages, one consumer: every payload arrives exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_single_consumer() {
        let queue = queue_handle().await;
        let payloads = (0..100).map(new_payload).collect::<Vec<_>>();
        publish_multi(&queue, &payloads);

        let c = queue.declare_consumer("1").await.unwrap();
        let mut results = consume_n(c, payloads.len()).await.unwrap();
        // Arrival order is not asserted; compare after sorting.
        results.sort_by_key(|a| a.field);
        assert_eq!(payloads, results)
    }

    /// 100 messages, two consumers merged into one stream: together they
    /// receive every payload exactly once.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_two_consumers() {
        let queue = queue_handle().await;
        let payloads = (0..100).map(new_payload).collect::<Vec<_>>();
        publish_multi(&queue, &payloads);

        let (c1, c2) = consumers(&queue).await;
        let mut results = consume_n_select(c1, c2, payloads.len()).await.unwrap();
        // Arrival order is not asserted; compare after sorting.
        results.sort_by_key(|a| a.field);
        assert_eq!(payloads, results)
    }
}
799
// Tests for queues declared with `SyndicationMode::Broadcast`.
//
// `publish` / `publish_multi` spawn tasks and the returned handles are dropped
// here, so publishing runs concurrently with whatever follows the call.
#[cfg(test)]
mod broadcast {
    use tokio::{join, try_join};

    use super::helpers::*;
    use super::*;
    use crate::queue::*;

    /// Declares a fresh persistent, non-durable, broadcast queue on a new
    /// connection.
    async fn broadcast_handle() -> InMemoryQueueHandle {
        let connection = InMemoryConnection::new(Default::default());
        connection
            .declare_queue(
                "my_broadcast_queue",
                QueueOptions {
                    delivery_mode: DeliveryMode::Persistent,
                    syndication_mode: SyndicationMode::Broadcast,
                    durability: QueueDurability::NonDurable,
                },
            )
            .await
            .unwrap()
    }

    /// One message published after both consumers are declared: both
    /// consumers receive it.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn single_message_delivers_to_all_publish_last() {
        let queue = broadcast_handle().await;
        let expected = new_payload(1);

        let (c1, c2) = consumers(&queue).await;
        publish(&queue, &expected);
        let (r1, r2) = try_join!(consume_next(c1), consume_next(c2)).unwrap();

        assert_eq!(expected, r1);
        assert_eq!(r1, r2)
    }

    /// One message whose publish is started before the consumers are
    /// declared: at least one consumer receives it within the timeout.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn single_message_delivers_to_at_least_one_publish_first() {
        let queue = broadcast_handle().await;

        publish(&queue, &new_payload(1));

        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));
        let (r1, r2) = join!(with_timeout(r1), with_timeout(r2));

        assert!([r1, r2].into_iter().any(|r| r == Some(new_payload(1))));
    }

    /// Five messages, publishes started before the single consumer is
    /// declared: the consumer receives all of them.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_single_consumer_publish_first() {
        let queue = broadcast_handle().await;
        let payloads = (0..5).map(new_payload).collect::<Vec<_>>();
        publish_multi(&queue, &payloads);
        let c = queue.declare_consumer("1").await.unwrap();

        let mut results = consume_n(c, payloads.len()).await.unwrap();
        // Arrival order is not asserted; compare after sorting.
        results.sort_by_key(|a| a.field);

        assert_eq!(payloads, results)
    }

    /// Five messages, publishes started after the single consumer is
    /// declared: the consumer receives all of them.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_single_consumer_publish_last() {
        let queue = broadcast_handle().await;
        let payloads = (0..5).map(new_payload).collect::<Vec<_>>();
        let c = queue.declare_consumer("1").await.unwrap();
        publish_multi(&queue, &payloads);

        let mut results = consume_n(c, payloads.len()).await.unwrap();
        // Arrival order is not asserted; compare after sorting.
        results.sort_by_key(|a| a.field);

        assert_eq!(payloads, results)
    }

    /// Five messages, publishes started before two consumers are declared: at
    /// least one consumer receives the full set within the timeout.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_multi_consumer_publish_first() {
        let queue = broadcast_handle().await;
        let payloads = (0..5).map(new_payload).collect::<Vec<_>>();
        publish_multi(&queue, &payloads);
        let (c1, c2) = consumers(&queue).await;

        let (mut r1, mut r2) = join!(
            with_timeout(consume_n(c1, payloads.len())),
            with_timeout(consume_n(c2, payloads.len()))
        );
        // Sort whichever results arrived so they can be compared to `payloads`.
        if let Some(v) = r1.as_mut() {
            v.sort_by_key(|a| a.field);
        }
        if let Some(v) = r2.as_mut() {
            v.sort_by_key(|a| a.field);
        }
        let expected = Some(payloads);

        assert!([r1, r2].into_iter().any(|r| r == expected));
    }

    /// Five messages, publishes started after two consumers are declared: at
    /// least one consumer receives the full set within the timeout.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_multi_consumer_publish_last() {
        let queue = broadcast_handle().await;
        let payloads = (0..5).map(new_payload).collect::<Vec<_>>();
        let (c1, c2) = consumers(&queue).await;
        publish_multi(&queue, &payloads);

        let (mut r1, mut r2) = join!(
            with_timeout(consume_n(c1, payloads.len())),
            with_timeout(consume_n(c2, payloads.len()))
        );
        // Sort whichever results arrived so they can be compared to `payloads`.
        if let Some(v) = r1.as_mut() {
            v.sort_by_key(|a| a.field);
        }
        if let Some(v) = r2.as_mut() {
            v.sort_by_key(|a| a.field);
        }
        let expected = Some(payloads);

        assert!([r1, r2].into_iter().any(|r| r == expected));
    }
}