1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
//! In-memory messaging implementations, meant to imitate distributed messaging services for test
//! purposes.
//!
//! See [`MockPublisher`] for an entry point to the mock system.
use crate::{consumer::AcknowledgeableMessage, EncodableMessage, Topic, ValidatedMessage};
use async_channel as mpmc;
use futures_util::{
sink,
stream::{self, StreamExt},
};
use parking_lot::Mutex;
use pin_project::pin_project;
use std::{
collections::BTreeMap,
error::Error as StdError,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
/// Errors originating from mock publisher and consumer operations
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct Error {
/// The underlying source of the error
pub cause: Box<dyn StdError>,
}
impl Error {
fn from<E>(from: E) -> Self
where
Box<dyn StdError>: From<E>,
{
Self { cause: from.into() }
}
}
type Topics = BTreeMap<Topic, Subscriptions>;
type Subscriptions = BTreeMap<MockSubscription, Channel<ValidatedMessage>>;
/// An in-memory publisher.
///
/// Consumers for the published data can be created using the `new_consumer` method.
///
/// Messages are published to particular [`Topics`](crate::Topic). Each topic may have multiple
/// `Subscriptions`, and every message for a topic will be sent to each of its subscriptions. A
/// subscription, in turn, may have multiple consumers; consumers will take messages from the
/// subscription on a first-polled-first-served basis.
///
/// This publisher can be cloned, allowing multiple publishers to send messages to the same set of
/// topics and subscriptions. Any consumer created with `new_consumer` will receive all on-topic
/// and on-subscription messages from all the associated publishers, regardless of whether the
/// consumer was created from a cloned instance.
#[derive(Debug, Clone)]
pub struct MockPublisher {
topics: Arc<Mutex<Topics>>,
}
impl MockPublisher {
/// Create a new `MockPublisher`
pub fn new() -> Self {
MockPublisher {
topics: Arc::new(Mutex::new(BTreeMap::new())),
}
}
/// Create a new consumer which will listen for messages published to the given topic and
/// subscription by this publisher (or any of its clones)
pub fn new_consumer(
&self,
topic: impl Into<Topic>,
subscription: impl Into<MockSubscription>,
) -> MockConsumer {
let mut topics = self.topics.lock();
let subscriptions = topics.entry(topic.into()).or_insert_with(BTreeMap::new);
let channel = subscriptions
.entry(subscription.into())
.or_insert_with(|| {
let (sender, receiver) = mpmc::unbounded();
Channel { sender, receiver }
})
.clone();
MockConsumer {
subscription_messages: channel.receiver,
subscription_resend: channel.sender,
}
}
}
impl Default for MockPublisher {
fn default() -> Self {
Self::new()
}
}
impl<M, S> crate::Publisher<M, S> for MockPublisher
where
M: crate::EncodableMessage,
M::Error: StdError + 'static,
S: sink::Sink<M>,
S::Error: StdError + 'static,
{
type PublishError = Error;
type PublishSink = MockSink<M, S>;
fn publish_sink_with_responses(
self,
validator: M::Validator,
response_sink: S,
) -> Self::PublishSink {
MockSink {
topics: self.topics,
validator,
response_sink,
}
}
}
/// The sink used by the `MockPublisher`
#[pin_project]
#[derive(Debug)]
pub struct MockSink<M: EncodableMessage, S> {
topics: Arc<Mutex<Topics>>,
validator: M::Validator,
#[pin]
response_sink: S,
}
#[derive(Debug, Clone)]
struct Channel<T> {
sender: mpmc::Sender<T>,
receiver: mpmc::Receiver<T>,
}
impl<M, S> sink::Sink<M> for MockSink<M, S>
where
M: EncodableMessage,
M::Error: StdError + 'static,
S: sink::Sink<M>,
S::Error: StdError + 'static,
{
type Error = Error;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project()
.response_sink
.poll_ready(cx)
.map_err(Error::from)
}
fn start_send(self: Pin<&mut Self>, message: M) -> Result<(), Self::Error> {
let this = self.project();
let topic = message.topic();
let validated_message = message.encode(this.validator).map_err(Error::from)?;
// lock critical section
{
let mut topics = this.topics.lock();
// send the message to every subscription listening on the given topic
// find the subscriptions for this topic
let subscriptions = topics.entry(topic).or_insert_with(Subscriptions::new);
// Send to every subscription that still has consumers. If a subscription's consumers are
// all dropped, the channel will have been closed and should be removed from the list
subscriptions.retain(|_subscription_name, channel| {
match channel.sender.try_send(validated_message.clone()) {
// if successfully sent, retain the channel
Ok(()) => true,
// if the channel has disconnected due to drops, remove it from the list
Err(mpmc::TrySendError::Closed(_)) => false,
Err(mpmc::TrySendError::Full(_)) => {
unreachable!("unbounded channel should never be full")
}
}
});
}
// notify the caller that the message has been sent successfully
this.response_sink
.start_send(message)
.map_err(Error::from)?;
Ok(())
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project()
.response_sink
.poll_flush(cx)
.map_err(Error::from)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
self.project()
.response_sink
.poll_close(cx)
.map_err(Error::from)
}
}
/// An opaque identifier for individual subscriptions to a [`MockPublisher`]
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct MockSubscription(String);
impl<S> From<S> for MockSubscription
where
S: Into<String>,
{
fn from(string: S) -> Self {
MockSubscription(string.into())
}
}
/// A consumer for messages from a particular subscription to a [`MockPublisher`]
#[derive(Debug, Clone)]
pub struct MockConsumer {
// channel receiver to get messages from the subscription
subscription_messages: mpmc::Receiver<ValidatedMessage>,
// channel sender to resend messages to the subscription on nack
subscription_resend: mpmc::Sender<ValidatedMessage>,
}
impl crate::Consumer for MockConsumer {
type AckToken = MockAckToken;
type Error = Error;
type Stream = Self;
fn stream(self) -> Self::Stream {
self
}
}
impl stream::Stream for MockConsumer {
type Item = Result<AcknowledgeableMessage<MockAckToken, ValidatedMessage>, Error>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
self.subscription_messages
.poll_next_unpin(cx)
.map(|opt_message| {
opt_message.map(|message| {
Ok(AcknowledgeableMessage {
ack_token: MockAckToken {
message: message.clone(),
subscription_resend: self.subscription_resend.clone(),
},
message,
})
})
})
}
}
/// An acknowledge token associated with a particular message from a [`MockConsumer`].
///
/// When `nack` is called for a particular message's token, that message will be re-submitted to
/// consumers of the corresponding subscription. Messages otherwise do not have any timeout
/// behavior, so a message is only re-sent to consumers if it is explicitly nack'ed; `ack` and
/// `modify_deadline` have no effect
#[derive(Debug)]
pub struct MockAckToken {
message: ValidatedMessage,
subscription_resend: mpmc::Sender<ValidatedMessage>,
}
#[async_trait::async_trait]
impl crate::consumer::AcknowledgeToken for MockAckToken {
type AckError = Error;
type NackError = Error;
type ModifyError = Error;
async fn ack(self) -> Result<(), Self::AckError> {
Ok(())
}
async fn nack(self) -> Result<(), Self::NackError> {
self.subscription_resend
.send(self.message)
.await
.map_err(|mpmc::SendError(_message)| Error {
cause: "Could not nack message because all consumers have been dropped".into(),
})
}
async fn modify_deadline(&mut self, _seconds: u32) -> Result<(), Self::ModifyError> {
// currently does nothing
Ok(())
}
}