use std::{
collections::HashMap,
convert::Infallible,
sync::{Arc, Mutex, OnceLock},
time::Duration,
};
use crate::{
AckError, Broker, Headers, IncomingMessage, OutgoingMessage, Publisher, RawMessage, Subscribe,
Subscriber, testing::TestClient,
};
use bytes::Bytes;
use futures::Stream;
use tokio::{
sync::{Notify, mpsc},
time::timeout,
};
use tokio_stream::{StreamExt, wrappers::UnboundedReceiverStream};
/// Sending half of a subscriber's in-memory delivery channel.
type Sender = mpsc::UnboundedSender<MemoryDelivery>;
/// An owned copy of a published message as it travels through the in-memory broker.
#[derive(Clone)]
struct MemoryDelivery {
    /// Message name; used as the key when looking up subscribers and the publish log.
    name: String,
    /// Message body.
    payload: Bytes,
    /// Headers attached at publish time.
    headers: Headers,
}
/// State shared by every handle (broker clones, publishers) of one in-memory broker.
#[derive(Default)]
struct MemoryState {
    /// Subscriber channels registered per message name.
    subscribers: Mutex<HashMap<String, Vec<Sender>>>,
    /// Every message ever published, per message name, in the order it was logged.
    published: Mutex<HashMap<String, Vec<RawMessage>>>,
    /// Signalled with `notify_waiters` after each publish has been logged.
    notify: Notify,
}

impl MemoryState {
    /// Registers `tx` to receive every delivery published under `name` from now on.
    ///
    /// # Panics
    ///
    /// Panics if the subscriber mutex is poisoned.
    fn register(&self, name: String, tx: Sender) {
        self.subscribers
            .lock()
            .expect("memory broker mutex poisoned")
            .entry(name)
            .or_default()
            .push(tx);
    }

    /// Logs `delivery`, wakes waiters on `notify`, then forwards a copy to every
    /// subscriber registered for its name.
    ///
    /// Senders whose receiver has gone away are dropped from the registry here. Otherwise
    /// every discarded subscriber would stay registered forever and be sent (and clone) every
    /// later delivery.
    ///
    /// # Panics
    ///
    /// Panics if either mutex is poisoned.
    fn fanout(&self, delivery: &MemoryDelivery) {
        let snapshot = RawMessage::new(delivery.name.clone(), delivery.payload.clone())
            .with_headers(delivery.headers.clone());
        // The log guard is a temporary, released at the end of this statement, i.e. before
        // waiters are woken and try to read the log.
        self.published
            .lock()
            .expect("memory broker mutex poisoned")
            .entry(delivery.name.clone())
            .or_default()
            .push(snapshot);
        self.notify.notify_waiters();

        let mut subs = self
            .subscribers
            .lock()
            .expect("memory broker mutex poisoned");
        let now_empty = match subs.get_mut(&delivery.name) {
            Some(senders) => {
                // `send` fails only once the receiver is dropped or closed, which is
                // permanent, so a failed sender can never deliver again.
                senders.retain(|tx| tx.send(delivery.clone()).is_ok());
                senders.is_empty()
            }
            None => false,
        };
        if now_empty {
            subs.remove(&delivery.name);
        }
    }
}
/// In-process message broker intended for tests.
///
/// Clones, and every publisher created from any clone, share one state behind an `Arc`.
/// A message published through any of them reaches subscribers registered through any other.
#[derive(Clone, Default)]
pub struct MemoryBroker {
    state: Arc<MemoryState>,
}

impl MemoryBroker {
    /// Creates a broker with no subscribers and an empty publish log.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: Arc::default(),
        }
    }

    /// Registers a new subscriber for messages named `name`.
    ///
    /// Only messages published after this call are delivered; nothing is replayed.
    #[must_use]
    pub fn subscribe(&self, name: impl Into<String>) -> MemorySubscriber {
        let (sender, receiver) = mpsc::unbounded_channel();
        let name: String = name.into();
        self.state.register(name.clone(), sender.clone());
        MemorySubscriber {
            name,
            rx: Some(receiver),
            requeue: sender,
        }
    }

    /// Returns a publisher that fans messages out to this broker's subscribers.
    #[must_use]
    pub fn publisher(&self) -> MemoryPublisher {
        let state = Arc::clone(&self.state);
        MemoryPublisher { state }
    }
}

impl std::fmt::Debug for MemoryBroker {
    /// Renders as `MemoryBroker { .. }`; no fields are shown.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("MemoryBroker");
        builder.finish_non_exhaustive()
    }
}
impl Broker for MemoryBroker {
    type Error = Infallible;

    /// No-op: there is nothing external to connect to.
    async fn connect(&self) -> Result<(), Self::Error> {
        Ok(())
    }

    /// Forgets every registered subscriber, so later publishes are no longer delivered.
    ///
    /// The publish log is kept. Existing subscriber streams are not closed by this call.
    /// Each `MemorySubscriber`, its stream, and each message it yields hold their own clone
    /// of the channel sender (for requeueing), which keeps the channel open.
    ///
    /// # Panics
    ///
    /// Panics if the subscriber mutex is poisoned.
    async fn shutdown(&self) -> Result<(), Self::Error> {
        let mut registry = self
            .state
            .subscribers
            .lock()
            .expect("memory broker mutex poisoned");
        registry.clear();
        drop(registry);
        Ok(())
    }
}
// `use_self` is presumably silenced so the delegation below names the inherent
// `MemoryBroker::subscribe` explicitly; written as `Self::subscribe` it reads as if it
// might recurse into this trait method.
#[allow(clippy::use_self)]
impl Subscribe for MemoryBroker {
    type Subscriber = MemorySubscriber;

    /// Delegates to the inherent, infallible `MemoryBroker::subscribe`.
    async fn subscribe(&self, name: &str) -> Result<Self::Subscriber, Self::Error> {
        let subscriber = MemoryBroker::subscribe(self, name);
        Ok(subscriber)
    }
}
/// Receiving side of a subscription created by `MemoryBroker::subscribe`.
pub struct MemorySubscriber {
    /// Message name this subscriber was registered under.
    name: String,
    /// Channel receiver; moved out (leaving `None`) by the first call to `stream`.
    rx: Option<mpsc::UnboundedReceiver<MemoryDelivery>>,
    /// Sender for this subscriber's own channel. Each yielded message gets a clone, so a
    /// requeueing `nack` redelivers to this same subscriber.
    requeue: Sender,
}

impl std::fmt::Debug for MemorySubscriber {
    /// Shows the subscribed name only; channel handles are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemorySubscriber");
        out.field("name", &self.name);
        out.finish_non_exhaustive()
    }
}
impl Subscriber for MemorySubscriber {
    type Message = MemoryMessage;
    type Error = Infallible;

    /// Streams every delivery routed to this subscriber, wrapped as a `MemoryMessage`.
    /// Items are always `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if called more than once, because the receiver can only be taken once.
    fn stream(&mut self) -> impl Stream<Item = Result<Self::Message, Self::Error>> + Send + '_ {
        let Some(receiver) = self.rx.take() else {
            panic!("MemorySubscriber::stream called more than once");
        };
        let requeue = self.requeue.clone();
        UnboundedReceiverStream::new(receiver).map(move |delivery| {
            let message = MemoryMessage {
                delivery: Some(delivery),
                requeue: requeue.clone(),
            };
            Ok(message)
        })
    }
}
/// Publishing handle obtained from `MemoryBroker::publisher`; clones share the broker state.
#[derive(Clone)]
pub struct MemoryPublisher {
    state: Arc<MemoryState>,
}

impl std::fmt::Debug for MemoryPublisher {
    /// Renders as `MemoryPublisher { .. }`; no fields are shown.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("MemoryPublisher");
        builder.finish_non_exhaustive()
    }
}
impl Publisher for MemoryPublisher {
    type Error = Infallible;

    /// Copies `msg` into an owned delivery and fans it out before returning.
    ///
    /// The delivery is appended to the publish log read by `expect_published`, and sent to
    /// every subscriber currently registered for its name. Never returns an error.
    async fn publish(&self, msg: OutgoingMessage<'_>) -> Result<(), Self::Error> {
        let name = msg.name().to_owned();
        let payload = Bytes::copy_from_slice(msg.payload());
        let headers = msg.headers().clone();
        self.state.fanout(&MemoryDelivery {
            name,
            payload,
            headers,
        });
        Ok(())
    }
}
pub struct MemoryMessage {
delivery: Option<MemoryDelivery>,
requeue: Sender,
}
impl std::fmt::Debug for MemoryMessage {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("MemoryMessage")
.field("name", &self.delivery.as_ref().map(|d| d.name.as_str()))
.finish_non_exhaustive()
}
}
impl MemoryMessage {
#[must_use]
pub fn name(&self) -> &str {
self.delivery
.as_ref()
.map(|d| d.name.as_str())
.unwrap_or_default()
}
#[must_use]
pub fn into_raw(mut self) -> RawMessage {
let delivery = self.delivery.take().expect("delivery already consumed");
RawMessage::new(delivery.name, delivery.payload).with_headers(delivery.headers)
}
}
impl IncomingMessage for MemoryMessage {
    /// Payload bytes, or an empty slice if the delivery is gone.
    fn payload(&self) -> &[u8] {
        let Some(delivery) = &self.delivery else {
            return &[];
        };
        delivery.payload.as_ref()
    }

    /// Message headers, or a shared empty set if the delivery is gone.
    fn headers(&self) -> &Headers {
        static EMPTY: OnceLock<Headers> = OnceLock::new();
        match &self.delivery {
            Some(delivery) => &delivery.headers,
            None => EMPTY.get_or_init(Headers::new),
        }
    }

    /// Acknowledges the message. In memory this only drops the delivery, so it never fails.
    async fn ack(mut self) -> Result<(), AckError> {
        drop(self.delivery.take());
        Ok(())
    }

    /// Negatively acknowledges the message.
    ///
    /// With `requeue` set, the delivery is appended back onto the originating subscriber's
    /// channel (not re-logged as published); otherwise it is discarded. Never fails.
    ///
    /// # Panics
    ///
    /// Panics with "delivery already consumed" if the delivery was already taken (an
    /// invariant guard; every method here that takes it consumes `self`).
    async fn nack(mut self, requeue: bool) -> Result<(), AckError> {
        let delivery = self.delivery.take().expect("delivery already consumed");
        if !requeue {
            return Ok(());
        }
        // Best effort: the send fails only if the subscriber's receiver is gone.
        let _ = self.requeue.send(delivery);
        Ok(())
    }
}
impl TestClient for MemoryBroker {
type Broker = Self;
type Subscriber = MemorySubscriber;
type Publisher = MemoryPublisher;
type Error = Infallible;
async fn start() -> Result<Self, Self::Error> {
Ok(Self::new())
}
fn broker(&self) -> &Self::Broker {
self
}
async fn publish(&self, name: &str, payload: &[u8]) -> Result<(), Self::Error> {
let publisher = Self::publisher(self);
publisher.publish(OutgoingMessage::new(name, payload)).await
}
async fn subscribe(&self, name: &str) -> Result<Self::Subscriber, Self::Error> {
Ok(Self::subscribe(self, name))
}
async fn publisher(&self) -> Result<Self::Publisher, Self::Error> {
Ok(Self::publisher(self))
}
async fn expect_published(
&self,
name: &str,
count: usize,
timeout_duration: Duration,
) -> Result<Vec<RawMessage>, Self::Error> {
let name_for_wait = name.to_owned();
let name_for_fallback = name_for_wait.clone();
let state = Arc::clone(&self.state);
let wait = async move {
loop {
{
let log = state
.published
.lock()
.expect("memory broker mutex poisoned");
if let Some(messages) = log.get(&name_for_wait) {
if messages.len() >= count {
return messages.iter().take(count).cloned().collect::<Vec<_>>();
}
}
}
state.notify.notified().await;
}
};
let result = timeout(timeout_duration, wait).await;
let messages = result.unwrap_or_else(|_| {
self.state
.published
.lock()
.expect("memory broker mutex poisoned")
.get(&name_for_fallback)
.map(|m| m.iter().take(count).cloned().collect())
.unwrap_or_default()
});
Ok(messages)
}
async fn shutdown(self) -> Result<(), Self::Error> {
<Self as Broker>::shutdown(&self).await
}
}