use std::{
collections::HashMap,
pin::Pin,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
task::{Context, Poll},
};
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use crossbeam::queue::SegQueue;
use dashmap::{mapref::entry::Entry, DashMap};
use futures::{ready, Stream};
use tokio::sync::{OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use super::{Connection, DeliveryMode, Publisher, QueueHandle, QueueOptions, SyndicationMode};
use crate::{
acker::NoopAcker,
serializer::{Serializable, Serializer},
};
#[derive(Clone)]
pub struct InMemoryConnection {
queues: Arc<DashMap<String, InMemoryQueueHandle>>,
serializer: Serializer,
}
impl InMemoryConnection {
pub fn new(serializer: Serializer) -> Self {
Self {
queues: Default::default(),
serializer,
}
}
}
#[async_trait]
impl Connection for InMemoryConnection {
    type QueueHandle = InMemoryQueueHandle;

    /// Nothing to release for an in-memory connection; always succeeds.
    async fn close(&self) -> Result<()> {
        Ok(())
    }

    /// Returns the handle for `name`, creating the queue with `options` if it does not exist.
    ///
    /// If a queue with this name was already declared, its existing handle is returned and
    /// `options` is ignored.
    async fn declare_queue(&self, name: &str, options: QueueOptions) -> Result<Self::QueueHandle> {
        let serializer = self.serializer;
        let handle = self
            .queues
            .entry(name.to_owned())
            .or_insert_with(|| InMemoryQueueHandle::new(serializer, options))
            .value()
            .clone();
        Ok(handle)
    }

    /// Forgets the queue called `name`; deleting an unknown queue is not an error.
    ///
    /// Handles obtained earlier keep working on their own state, but a later
    /// `declare_queue(name, ..)` creates a fresh, empty queue.
    async fn delete_queue(&self, name: &str) -> Result<()> {
        let _ = self.queues.remove(name);
        Ok(())
    }
}
/// Queue state for [`SyndicationMode::ExactlyOnce`]: each message goes to exactly one of the
/// consumers attached to the queue.
///
/// All consumers share one message buffer and one counting semaphore. The semaphore carries
/// one permit per buffered message, so a consumer that acquires a permit always finds a
/// message to pop.
#[derive(Clone)]
struct ExactlyOnceQueue {
    /// Serialized messages waiting to be consumed, in publish order.
    messages: Arc<SegQueue<Bytes>>,
    /// One permit per message currently in `messages`.
    num_messages: PollSemaphore,
    /// Options the queue was declared with; stored but not consulted by this type.
    _options: QueueOptions,
    /// Encodes published payloads; also handed to consumers to decode them.
    serializer: Serializer,
}

impl Default for ExactlyOnceQueue {
    fn default() -> Self {
        Self {
            messages: Default::default(),
            num_messages: PollSemaphore::new(Arc::new(Semaphore::new(0))),
            _options: Default::default(),
            serializer: Default::default(),
        }
    }
}

impl ExactlyOnceQueue {
    /// Creates an empty queue with the given options and serializer.
    fn new(options: QueueOptions, serializer: Serializer) -> Self {
        Self {
            _options: options,
            serializer,
            ..Default::default()
        }
    }

    /// Serializes `payload` and makes it available to exactly one consumer.
    ///
    /// # Errors
    /// Returns the serializer's error if `payload` cannot be encoded; nothing is enqueued then.
    fn publish<PayloadTarget: Serializable>(&self, payload: &PayloadTarget) -> Result<()> {
        let bytes = self.serializer.to_bytes(payload)?;
        // Push before releasing the permit so that a consumer woken by the permit always
        // finds the message in the buffer.
        self.messages.push(bytes);
        self.num_messages.add_permits(1);
        Ok(())
    }

    /// Attaches a new consumer that competes with all other consumers for messages.
    ///
    /// `_consumer_name` is ignored: every consumer of an exactly-once queue reads the same
    /// shared buffer, whatever its name.
    fn declare_consumer<PayloadTarget: Serializable>(
        &self,
        _consumer_name: &str,
    ) -> Result<InMemoryConsumer<PayloadTarget>> {
        Ok(InMemoryConsumer {
            messages: Arc::clone(&self.messages),
            num_messages: self.num_messages.clone(),
            serializer: self.serializer,
            permit: None,
            _marker: std::marker::PhantomData,
        })
    }
}
#[derive(Clone)]
struct BroadcastConsumer {
messages: Arc<SegQueue<Bytes>>,
num_messages: PollSemaphore,
seen: Arc<DashMap<u64, ()>>,
}
impl Default for BroadcastConsumer {
fn default() -> Self {
Self {
messages: Default::default(),
num_messages: PollSemaphore::new(Arc::new(Semaphore::new(0))),
seen: Default::default(),
}
}
}
/// Queue state for [`SyndicationMode::Broadcast`]: every registered consumer receives its own
/// copy of each message published while it is registered.
///
/// With [`DeliveryMode::Persistent`], messages published while no consumer is registered are
/// kept in `history`. The next newly registered consumer drains that backlog, so only that
/// consumer receives those messages.
#[derive(Clone, Default)]
struct BroadcastQueue {
    /// Registered consumers keyed by consumer name.
    consumers: Arc<DashMap<String, BroadcastConsumer>>,
    /// Backlog of `(message_id, payload)` pairs published while no consumer was registered.
    history: Arc<SegQueue<(u64, Bytes)>>,
    /// Source of unique message ids.
    message_counter: Arc<AtomicU64>,
    /// Options the queue was declared with; `delivery_mode` controls the backlog.
    options: QueueOptions,
    /// Encodes published payloads; also handed to consumers to decode them.
    serializer: Serializer,
}

impl BroadcastQueue {
    /// Creates an empty queue with the given options and serializer.
    fn new(options: QueueOptions, serializer: Serializer) -> Self {
        Self {
            options,
            serializer,
            ..Default::default()
        }
    }

    /// Whether messages published with no consumer registered are kept for a later consumer.
    fn is_persistent(&self) -> bool {
        DeliveryMode::Persistent == self.options.delivery_mode
    }

    /// Serializes `payload` and pushes a copy to every registered consumer.
    ///
    /// In persistent mode with no consumer registered, the message is stored in the backlog.
    ///
    /// # Errors
    /// Returns the serializer's error if `payload` cannot be encoded; nothing is delivered then.
    fn publish<PayloadTarget: Serializable>(&self, payload: &PayloadTarget) -> Result<()> {
        let bytes = self.serializer.to_bytes(payload)?;
        // Relaxed is enough: the counter only has to hand out unique ids.
        let message_id = self.message_counter.fetch_add(1, Ordering::Relaxed);
        let persistent = self.is_persistent();
        if persistent && self.consumers.is_empty() {
            self.history.push((message_id, bytes.clone()));
        }
        for consumer in self.consumers.iter() {
            // A consumer that registered after the backlog push above has already received
            // this message while draining the backlog, which its `seen` entry records. Ids are
            // never reused, so the entry is dropped once consulted. Directly delivered ids are
            // not recorded, which keeps `seen` from growing without bound.
            if persistent && consumer.seen.remove(&message_id).is_some() {
                continue;
            }
            consumer.messages.push(bytes.clone());
            consumer.num_messages.add_permits(1);
        }
        Ok(())
    }

    /// Attaches a stream to the consumer called `consumer_name`, registering it if needed.
    ///
    /// A newly registered consumer takes over the whole persistent backlog. Re-declaring an
    /// existing name attaches another stream to that name's existing buffer, and those
    /// streams then compete for its messages.
    fn declare_consumer<PayloadTarget: Serializable>(
        &self,
        consumer_name: &str,
    ) -> Result<InMemoryConsumer<PayloadTarget>> {
        let consumer = match self.consumers.entry(consumer_name.to_string()) {
            Entry::Occupied(entry) => entry.get().clone(),
            // The backlog is drained while the vacant entry (and its shard lock) is still held.
            Entry::Vacant(entry) => entry.insert(self.new_consumer()).value().clone(),
        };
        Ok(self.attach(&consumer))
    }

    /// Builds the state for a newly registered consumer.
    ///
    /// In persistent mode, this drains the backlog into the consumer's buffer and remembers
    /// the drained ids in `seen`.
    fn new_consumer(&self) -> BroadcastConsumer {
        let messages = SegQueue::new();
        let mut seen = HashMap::new();
        if self.is_persistent() {
            while let Some((message_id, message)) = self.history.pop() {
                // Each backlog id is delivered at most once to this consumer.
                if seen.insert(message_id, ()).is_none() {
                    messages.push(message);
                }
            }
        }
        BroadcastConsumer {
            num_messages: PollSemaphore::new(Arc::new(Semaphore::new(seen.len()))),
            messages: Arc::new(messages),
            seen: Arc::new(seen.into_iter().collect()),
        }
    }

    /// Creates a stream that reads from `consumer`'s buffer.
    fn attach<PayloadTarget: Serializable>(
        &self,
        consumer: &BroadcastConsumer,
    ) -> InMemoryConsumer<PayloadTarget> {
        InMemoryConsumer {
            messages: Arc::clone(&consumer.messages),
            num_messages: consumer.num_messages.clone(),
            serializer: self.serializer,
            permit: None,
            _marker: std::marker::PhantomData,
        }
    }
}
/// Handle to an in-memory queue, as returned by `InMemoryConnection`'s `declare_queue`.
///
/// Both the exactly-once and the broadcast state are always allocated.
/// `options.syndication_mode` decides which of the two publishers and consumers use. Clones
/// share the same underlying queue state.
#[derive(Clone, Default)]
pub struct InMemoryQueueHandle {
    /// Options the queue was declared with.
    options: QueueOptions,
    /// State used when the syndication mode is exactly-once.
    exactly_once_queue: ExactlyOnceQueue,
    /// State used when the syndication mode is broadcast.
    broadcast_queue: BroadcastQueue,
}

impl InMemoryQueueHandle {
    /// Builds a fresh, empty queue that encodes payloads with `serializer`.
    pub fn new(serializer: Serializer, options: QueueOptions) -> Self {
        let exactly_once_queue = ExactlyOnceQueue::new(options, serializer);
        let broadcast_queue = BroadcastQueue::new(options, serializer);
        Self {
            options,
            exactly_once_queue,
            broadcast_queue,
        }
    }
}
/// Publisher of `T` payloads on an in-memory queue.
pub struct InMemoryPublisher<T> {
    /// Queue the payloads are published to.
    queue_handle: InMemoryQueueHandle,
    _marker: std::marker::PhantomData<T>,
}

impl<T> InMemoryPublisher<T> {
    /// Wraps `queue_handle` in a publisher for `T` payloads.
    pub fn new(queue_handle: InMemoryQueueHandle) -> Self {
        Self {
            _marker: Default::default(),
            queue_handle,
        }
    }
}

#[async_trait]
impl<T: Serializable> Publisher<T> for InMemoryPublisher<T> {
    /// Routes `payload` to the exactly-once or broadcast state, depending on the queue's
    /// syndication mode.
    async fn publish(&self, payload: &T) -> Result<()> {
        let handle = &self.queue_handle;
        match handle.options.syndication_mode {
            SyndicationMode::Broadcast => handle.broadcast_queue.publish(payload),
            SyndicationMode::ExactlyOnce => handle.exactly_once_queue.publish(payload),
        }
    }

    /// Nothing to flush or release; always succeeds.
    async fn close(&self) -> Result<()> {
        Ok(())
    }
}
#[async_trait]
impl QueueHandle for InMemoryQueueHandle {
    type Acker = NoopAcker;
    type Consumer<PayloadTarget: Serializable> = InMemoryConsumer<PayloadTarget>;
    type Publisher<PayloadTarget: Serializable> = InMemoryPublisher<PayloadTarget>;

    /// Returns a publisher that shares this handle's queue state.
    fn publisher<PayloadTarget: Serializable>(&self) -> Self::Publisher<PayloadTarget> {
        let handle = self.clone();
        InMemoryPublisher::new(handle)
    }

    /// Attaches a consumer named `consumer_name`.
    ///
    /// In exactly-once mode the name is ignored and all consumers compete for messages. In
    /// broadcast mode each distinct name gets its own copy of every message.
    async fn declare_consumer<PayloadTarget: Serializable>(
        &self,
        consumer_name: &str,
    ) -> Result<Self::Consumer<PayloadTarget>> {
        let name = consumer_name;
        match self.options.syndication_mode {
            SyndicationMode::Broadcast => self.broadcast_queue.declare_consumer(name),
            SyndicationMode::ExactlyOnce => self.exactly_once_queue.declare_consumer(name),
        }
    }
}
/// Stream of `(payload, acker)` pairs read from an in-memory queue buffer.
///
/// Each item costs one semaphore permit. The permit is acquired (waiting if none is
/// available) and consumed for good with `forget`. Then exactly one serialized message is
/// popped from the buffer and decoded.
pub struct InMemoryConsumer<T> {
    _marker: std::marker::PhantomData<T>,
    /// Buffer of serialized messages; may be shared with other streams.
    messages: Arc<SegQueue<Bytes>>,
    /// A permit already acquired for a buffered message; `poll_next` uses it before
    /// acquiring a new one.
    permit: Option<OwnedSemaphorePermit>,
    /// Decodes popped messages into `T`.
    serializer: Serializer,
    /// Counts the messages available in `messages`.
    num_messages: PollSemaphore,
}

impl<T: Serializable> Stream for InMemoryConsumer<T> {
    type Item = (T, NoopAcker);

    /// Yields the next message, returning `Pending` until one is available.
    ///
    /// Returns `Ready(None)` once the message semaphore is closed, because no further
    /// messages can become available.
    ///
    /// # Panics
    /// Panics if a popped message cannot be deserialized into `T`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = &mut *self;
        let permit = match this.permit.take() {
            Some(permit) => permit,
            None => match ready!(this.num_messages.poll_acquire(cx)) {
                Some(permit) => permit,
                // Closed semaphore: end the stream. Returning `Pending` here would hang the
                // task forever, because no waker is registered on this path.
                None => return Poll::Ready(None),
            },
        };
        // The permit stands for exactly one buffered message: consume it permanently instead
        // of handing it back to the semaphore on drop.
        permit.forget();
        let Some(bytes) = this.messages.pop() else {
            unreachable!("permit was acquired, but no message was available")
        };
        let item = this
            .serializer
            .from_bytes(&bytes)
            .expect("failed to deserialize");
        Poll::Ready(Some((item, NoopAcker::new())))
    }
}
// Shared test fixtures: a sample payload plus helpers that spawn publishers and consumers
// on the tokio runtime and await them with a short timeout.
#[cfg(test)]
mod helpers {
    use std::time::Duration;

    use futures::{Future, StreamExt};
    use serde::{Deserialize, Serialize};
    use tokio::{
        task::{JoinError, JoinHandle},
        try_join,
    };

    use super::*;
    use crate::acker::Acker;

    /// Minimal serializable payload used by every test.
    #[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Clone)]
    pub(super) struct Payload {
        pub(super) field: i64,
    }

    /// Builds a [`Payload`] carrying `field`.
    pub(super) fn new_payload(field: i64) -> Payload {
        Payload { field }
    }

    /// Awaits `fut` (normally a spawned task's `JoinHandle`) for at most 10 ms.
    ///
    /// Returns `Some(output)` if it finished in time and `None` on timeout. Panics if the
    /// task failed with a `JoinError` (the result is unwrapped).
    pub(super) async fn with_timeout<O, F: Future<Output = Result<O, JoinError>>>(
        fut: F,
    ) -> Option<O> {
        // NOTE(review): a 10 ms budget can be exceeded on a loaded machine, which may make
        // assertions that expect `Some` flaky.
        let timeout = tokio::time::sleep(Duration::from_millis(10));
        tokio::select! {
            result = fut => {
                Some(result.unwrap())
            }
            _ = timeout => {
                None
            }
        }
    }

    /// Spawns a task that takes one message from `consumer`, acks it and returns its payload.
    ///
    /// The task panics if the stream ends before yielding a message.
    pub(super) fn consume_next(mut consumer: InMemoryConsumer<Payload>) -> JoinHandle<Payload> {
        tokio::spawn(async move {
            let (payload, acker) = consumer.next().await.unwrap();
            acker.ack().await.unwrap();
            payload
        })
    }

    /// Spawns a task that collects the first `n` messages from `consumer`, acking each one.
    pub(super) fn consume_n(
        consumer: InMemoryConsumer<Payload>,
        n: usize,
    ) -> JoinHandle<Vec<Payload>> {
        tokio::spawn(async move {
            consumer
                .then(|(payload, acker)| async move {
                    acker.ack().await.unwrap();
                    payload
                })
                .take(n)
                .collect::<Vec<_>>()
                .await
        })
    }

    /// Like `consume_n`, but merges two consumers with `futures::stream::select` and collects
    /// `n` messages in total across both.
    pub(super) fn consume_n_select(
        c1: InMemoryConsumer<Payload>,
        c2: InMemoryConsumer<Payload>,
        n: usize,
    ) -> JoinHandle<Vec<Payload>> {
        tokio::spawn(async move {
            futures::stream::select(c1, c2)
                .then(|(payload, acker)| async move {
                    acker.ack().await.unwrap();
                    payload
                })
                .take(n)
                .collect::<Vec<_>>()
                .await
        })
    }

    /// Concurrently declares two consumers named "1" and "2" on `queue`.
    pub(super) async fn consumers<P: Serializable, H: QueueHandle>(
        queue: &H,
    ) -> (H::Consumer<P>, H::Consumer<P>) {
        try_join!(queue.declare_consumer("1"), queue.declare_consumer("2")).unwrap()
    }

    /// Spawns a task that publishes a clone of `payload` through a fresh publisher for `queue`.
    ///
    /// The publish runs asynchronously: a caller that drops the returned handle does not wait
    /// for it, so it may complete after later statements have run.
    pub(super) fn publish<H: QueueHandle + Send + Sync + 'static>(
        queue: &H,
        payload: &Payload,
    ) -> JoinHandle<Result<()>>
    where
        <H as QueueHandle>::Publisher<Payload>: Send,
    {
        let payload = payload.clone();
        let queue = queue.clone();
        let publisher = queue.publisher();
        tokio::spawn(async move { publisher.publish(&payload).await })
    }

    /// Spawns one `publish` task per payload; the tasks may complete in any order.
    pub(super) fn publish_multi<H: QueueHandle + Send + Sync + 'static>(
        queue: &H,
        payload: &[Payload],
    ) -> Vec<JoinHandle<Result<()>>>
    where
        <H as QueueHandle>::Publisher<Payload>: Send,
    {
        payload.iter().map(|p| publish(queue, p)).collect()
    }
}
// Tests for `SyndicationMode::ExactlyOnce`: each published message must reach exactly one of
// the competing consumers.
#[cfg(test)]
mod exactly_once {
    use tokio::{join, try_join};

    use super::helpers::*;
    use super::*;
    use crate::queue::*;

    /// Declares a fresh persistent, exactly-once, non-durable queue on a new connection.
    async fn queue_handle() -> InMemoryQueueHandle {
        let connection = InMemoryConnection::new(Serializer::default());
        connection
            .declare_queue(
                "my_queue",
                QueueOptions {
                    delivery_mode: DeliveryMode::Persistent,
                    syndication_mode: SyndicationMode::ExactlyOnce,
                    durability: QueueDurability::NonDurable,
                },
            )
            .await
            .unwrap()
    }

    /// One message and two consumers: one consumer gets the message, the other times out.
    /// The publish is spawned and not awaited, so it may run before or after the consumers
    /// are declared.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn single_message_delivers_once_publish_first() {
        let queue = queue_handle().await;
        let clone = queue.clone();
        publish(&clone, &new_payload(1));
        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));
        let (r1, r2) = join!(with_timeout(r1), with_timeout(r2));
        assert!([r1.clone(), r2.clone()].iter().any(|r| r.is_none()));
        assert!([r1.clone(), r2.clone()]
            .into_iter()
            .any(|r| r == Some(new_payload(1))));
    }

    /// Same as above, but both consumers are already waiting when the message is published.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn single_message_delivers_once_publish_last() {
        let queue = queue_handle().await;
        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));
        publish(&queue, &new_payload(1));
        let (r1, r2) = join!(with_timeout(r1), with_timeout(r2));
        assert!([r1.clone(), r2.clone()].iter().any(|p| p.is_none()));
        assert!([r1.clone(), r2.clone()]
            .into_iter()
            .any(|p| p == Some(new_payload(1))));
    }

    /// Two messages and two consumers: each consumer gets a different message. There is no
    /// timeout here, so a lost message hangs the test rather than failing it.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn double_message_delivers_once_publish_first() {
        let queue = queue_handle().await;
        publish(&queue, &new_payload(1));
        publish(&queue, &new_payload(2));
        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));
        let (r1, r2) = try_join!(r1, r2).unwrap();
        assert_ne!(r1, r2)
    }

    /// Same as above, with the consumers waiting before the messages are published.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn double_message_delivers_once_publish_last() {
        let queue = queue_handle().await;
        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));
        publish(&queue, &new_payload(1));
        publish(&queue, &new_payload(2));
        let (r1, r2) = try_join!(r1, r2).unwrap();
        assert_ne!(r1, r2)
    }

    /// A single consumer receives all 100 messages. The spawned publishes race, so the
    /// results are sorted before comparing.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_single_consumer() {
        let queue = queue_handle().await;
        let payloads = (0..100).map(new_payload).collect::<Vec<_>>();
        publish_multi(&queue, &payloads);
        let c = queue.declare_consumer("1").await.unwrap();
        let mut results = consume_n(c, payloads.len()).await.unwrap();
        results.sort_by_key(|a| a.field);
        assert_eq!(payloads, results)
    }

    /// Two merged consumers receive all 100 messages between them, without duplicates.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_two_consumers() {
        let queue = queue_handle().await;
        let payloads = (0..100).map(new_payload).collect::<Vec<_>>();
        publish_multi(&queue, &payloads);
        let (c1, c2) = consumers(&queue).await;
        let mut results = consume_n_select(c1, c2, payloads.len()).await.unwrap();
        results.sort_by_key(|a| a.field);
        assert_eq!(payloads, results)
    }
}
// Tests for `SyndicationMode::Broadcast` with persistent delivery: every registered consumer
// gets its own copy of each message, and messages published with no consumer registered go
// to the backlog.
#[cfg(test)]
mod broadcast {
    use tokio::{join, try_join};

    use super::helpers::*;
    use super::*;
    use crate::queue::*;

    /// Declares a fresh persistent, broadcast, non-durable queue on a new connection.
    async fn broadcast_handle() -> InMemoryQueueHandle {
        let connection = InMemoryConnection::new(Default::default());
        connection
            .declare_queue(
                "my_broadcast_queue",
                QueueOptions {
                    delivery_mode: DeliveryMode::Persistent,
                    syndication_mode: SyndicationMode::Broadcast,
                    durability: QueueDurability::NonDurable,
                },
            )
            .await
            .unwrap()
    }

    /// Both consumers registered before the publish receive the same message.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn single_message_delivers_to_all_publish_last() {
        let queue = broadcast_handle().await;
        let expected = new_payload(1);
        let (c1, c2) = consumers(&queue).await;
        publish(&queue, &expected);
        let (r1, r2) = try_join!(consume_next(c1), consume_next(c2)).unwrap();
        assert_eq!(expected, r1);
        assert_eq!(r1, r2)
    }

    /// A message published (spawned, so possibly late) around consumer registration reaches
    /// at least one consumer. Only the first newly registered consumer drains the backlog,
    /// so "all" is not asserted.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn single_message_delivers_to_at_least_one_publish_first() {
        let queue = broadcast_handle().await;
        publish(&queue, &new_payload(1));
        let (c1, c2) = consumers(&queue).await;
        let (r1, r2) = (consume_next(c1), consume_next(c2));
        let (r1, r2) = join!(with_timeout(r1), with_timeout(r2));
        assert!([r1, r2].into_iter().any(|r| r == Some(new_payload(1))));
    }

    /// A lone consumer receives every message, whether it arrived via the backlog or directly.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_single_consumer_publish_first() {
        let queue = broadcast_handle().await;
        let payloads = (0..5).map(new_payload).collect::<Vec<_>>();
        publish_multi(&queue, &payloads);
        let c = queue.declare_consumer("1").await.unwrap();
        let mut results = consume_n(c, payloads.len()).await.unwrap();
        results.sort_by_key(|a| a.field);
        assert_eq!(payloads, results)
    }

    /// A consumer registered before publishing receives every message.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_single_consumer_publish_last() {
        let queue = broadcast_handle().await;
        let payloads = (0..5).map(new_payload).collect::<Vec<_>>();
        let c = queue.declare_consumer("1").await.unwrap();
        publish_multi(&queue, &payloads);
        let mut results = consume_n(c, payloads.len()).await.unwrap();
        results.sort_by_key(|a| a.field);
        assert_eq!(payloads, results)
    }

    /// With publishes racing consumer registration, at least one of the two consumers sees
    /// all five messages within the timeout.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_multi_consumer_publish_first() {
        let queue = broadcast_handle().await;
        let payloads = (0..5).map(new_payload).collect::<Vec<_>>();
        publish_multi(&queue, &payloads);
        let (c1, c2) = consumers(&queue).await;
        let (mut r1, mut r2) = join!(
            with_timeout(consume_n(c1, payloads.len())),
            with_timeout(consume_n(c2, payloads.len()))
        );
        if let Some(v) = r1.as_mut() {
            v.sort_by_key(|a| a.field);
        }
        if let Some(v) = r2.as_mut() {
            v.sort_by_key(|a| a.field);
        }
        let expected = Some(payloads);
        assert!([r1, r2].into_iter().any(|r| r == expected));
    }

    /// Consumers registered before publishing: at least one (in practice both) sees all
    /// five messages within the timeout.
    #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
    async fn many_messages_multi_consumer_publish_last() {
        let queue = broadcast_handle().await;
        let payloads = (0..5).map(new_payload).collect::<Vec<_>>();
        let (c1, c2) = consumers(&queue).await;
        publish_multi(&queue, &payloads);
        let (mut r1, mut r2) = join!(
            with_timeout(consume_n(c1, payloads.len())),
            with_timeout(consume_n(c2, payloads.len()))
        );
        if let Some(v) = r1.as_mut() {
            v.sort_by_key(|a| a.field);
        }
        if let Some(v) = r2.as_mut() {
            v.sort_by_key(|a| a.field);
        }
        let expected = Some(payloads);
        assert!([r1, r2].into_iter().any(|r| r == expected));
    }
}