use crate::{consumer::AcknowledgeableMessage, EncodableMessage, Topic, ValidatedMessage};
use async_channel as mpmc;
use futures_util::{
sink,
stream::{self, StreamExt},
};
use parking_lot::Mutex;
use pin_project::pin_project;
use std::{
collections::BTreeMap,
error::Error as StdError,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
/// Error type returned by the mock publisher sink, the mock consumer stream and the mock
/// acknowledgement token.
///
/// Thanks to `#[error(transparent)]`, `Display` and `source` are forwarded to `cause`.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct Error {
    /// The underlying error that caused this failure.
    pub cause: Box<dyn StdError>,
}

impl Error {
    /// Builds an [`Error`] from anything that can be boxed into a `dyn StdError`
    /// (a concrete error type, a `String`, a `&str`, ...).
    fn from<E>(from: E) -> Self
    where
        Box<dyn StdError>: From<E>,
    {
        let cause = <Box<dyn StdError> as From<E>>::from(from);
        Error { cause }
    }
}
/// The queues of one topic, keyed by subscription name.
type Subscriptions = BTreeMap<MockSubscription, Channel<ValidatedMessage>>;
/// Every topic registered with a mock publisher, together with its subscriptions.
type Topics = BTreeMap<Topic, Subscriptions>;
#[derive(Debug, Clone)]
pub struct MockPublisher {
topics: Arc<Mutex<Topics>>,
}
impl MockPublisher {
pub fn new() -> Self {
MockPublisher {
topics: Arc::new(Mutex::new(BTreeMap::new())),
}
}
pub fn new_consumer(
&self,
topic: impl Into<Topic>,
subscription: impl Into<MockSubscription>,
) -> MockConsumer {
let mut topics = self.topics.lock();
let subscriptions = topics.entry(topic.into()).or_default();
let channel = subscriptions
.entry(subscription.into())
.or_insert_with(|| {
let (sender, receiver) = mpmc::unbounded();
Channel { sender, receiver }
})
.clone();
MockConsumer {
subscription_messages: channel.receiver,
subscription_resend: channel.sender,
}
}
}
impl Default for MockPublisher {
fn default() -> Self {
Self::new()
}
}
impl<M, S> crate::Publisher<M, S> for MockPublisher
where
M: crate::EncodableMessage,
M::Error: StdError + 'static,
S: sink::Sink<M>,
S::Error: StdError + 'static,
{
type PublishError = Error;
type PublishSink = MockSink<M, S>;
fn publish_sink_with_responses(
self,
validator: M::Validator,
response_sink: S,
) -> Self::PublishSink {
MockSink {
topics: self.topics,
validator,
response_sink,
}
}
}
/// Sink returned by `MockPublisher::publish_sink_with_responses`.
///
/// Each message sent into it is encoded with `validator` and delivered to the subscriptions
/// registered for its topic in `topics`. The original message is then forwarded to
/// `response_sink`.
#[pin_project]
#[derive(Debug)]
pub struct MockSink<M: EncodableMessage, S> {
    /// Topic registry shared with the publisher that created this sink.
    topics: Arc<Mutex<Topics>>,
    /// Encodes outgoing messages into `ValidatedMessage`s.
    validator: M::Validator,
    /// Receives each message after it has been delivered to the subscriptions. Readiness,
    /// flushing and closing of this sink are delegated to it.
    #[pin]
    response_sink: S,
}
/// The queue backing one subscription.
///
/// Both halves are stored so that new consumers can be given a receiver (to read messages) and a
/// sender (to requeue nacked messages). Keeping a receiver here also keeps the channel open while
/// no consumer exists.
#[derive(Debug, Clone)]
struct Channel<T> {
    sender: mpmc::Sender<T>,
    receiver: mpmc::Receiver<T>,
}
impl<M, S> sink::Sink<M> for MockSink<M, S>
where
M: EncodableMessage,
M::Error: StdError + 'static,
S: sink::Sink<M>,
S::Error: StdError + 'static,
{
type Error = Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project()
.response_sink
.poll_ready(cx)
.map_err(Error::from)
}
fn start_send(self: Pin<&mut Self>, message: M) -> Result<(), Self::Error> {
let this = self.project();
let topic = message.topic();
let validated_message = message.encode(this.validator).map_err(Error::from)?;
{
let mut topics = this.topics.lock();
let subscriptions = topics.entry(topic).or_default();
subscriptions.retain(|_subscription_name, channel| {
match channel.sender.try_send(validated_message.clone()) {
Ok(()) => true,
Err(mpmc::TrySendError::Closed(_)) => false,
Err(mpmc::TrySendError::Full(_)) => {
unreachable!("unbounded channel should never be full")
}
}
});
}
this.response_sink
.start_send(message)
.map_err(Error::from)?;
Ok(())
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project()
.response_sink
.poll_flush(cx)
.map_err(Error::from)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project()
.response_sink
.poll_close(cx)
.map_err(Error::from)
}
}
/// Name of a subscription on a mock topic. Subscriptions compare and sort by this name.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MockSubscription(String);

/// Any value convertible into a `String` can name a subscription.
impl<S> From<S> for MockSubscription
where
    S: Into<String>,
{
    fn from(name: S) -> Self {
        let name: String = name.into();
        Self(name)
    }
}
/// Reads the messages queued for one subscription of a mock topic.
///
/// Clones read from the same queue, so each message is received by only one of them.
#[derive(Debug, Clone)]
pub struct MockConsumer {
    subscription_messages: mpmc::Receiver<ValidatedMessage>,
    subscription_resend: mpmc::Sender<ValidatedMessage>,
}

impl crate::Consumer for MockConsumer {
    type AckToken = MockAckToken;
    type Error = Error;
    type Stream = Self;

    /// The consumer is already a message stream, so it is returned unchanged.
    fn stream(self) -> Self::Stream {
        self
    }
}

impl stream::Stream for MockConsumer {
    type Item = Result<AcknowledgeableMessage<MockAckToken, ValidatedMessage>, Error>;

    /// Yields the next queued message together with an ack token. Nacking the token puts the
    /// message back on this subscription's queue.
    ///
    /// Never yields an `Err`. Returns `Poll::Ready(None)` only once the underlying receiver
    /// reports the end of the channel.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        let message = match self.subscription_messages.poll_next_unpin(cx) {
            Poll::Ready(Some(message)) => message,
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Pending => return Poll::Pending,
        };
        let ack_token = MockAckToken {
            message: message.clone(),
            subscription_resend: self.subscription_resend.clone(),
        };
        Poll::Ready(Some(Ok(AcknowledgeableMessage { ack_token, message })))
    }
}
#[derive(Debug)]
pub struct MockAckToken {
message: ValidatedMessage,
subscription_resend: mpmc::Sender<ValidatedMessage>,
}
#[async_trait::async_trait]
impl crate::consumer::AcknowledgeToken for MockAckToken {
type AckError = Error;
type NackError = Error;
type ModifyError = Error;
async fn ack(self) -> Result<(), Self::AckError> {
Ok(())
}
async fn nack(self) -> Result<(), Self::NackError> {
self.subscription_resend
.send(self.message)
.await
.map_err(|mpmc::SendError(_message)| Error {
cause: "Could not nack message because all consumers have been dropped".into(),
})
}
async fn modify_deadline(&mut self, _seconds: u32) -> Result<(), Self::ModifyError> {
Ok(())
}
}