mod capability;
mod test_client;
pub use capability::{MemoryRequester, PARTITION_KEY_HEADER, RequestError};
use std::{
collections::HashMap,
convert::Infallible,
sync::{Arc, Mutex, OnceLock, atomic::AtomicU64},
time::Duration,
};
use crate::{
AckError, Broker, Headers, IncomingMessage, OutgoingMessage, Publisher, RawMessage, Subscribe,
Subscriber, SubscriptionSource,
};
use bytes::Bytes;
use futures::Stream;
use tokio::sync::{Notify, mpsc};
/// Sending half of the unbounded channel that carries [`MemoryDelivery`] values to one subscriber.
type Sender = mpsc::UnboundedSender<MemoryDelivery>;
/// One message as it travels through the in-memory broker.
#[derive(Clone)]
struct MemoryDelivery {
/// Name the message was published under; fan-out looks subscribers up by this value.
name: String,
/// Message body.
payload: Bytes,
/// Headers attached to the message.
headers: Headers,
}
/// State shared (behind an `Arc`) by a broker and the handles created from it.
#[derive(Default)]
struct MemoryState {
/// Subscriber channels registered per message name.
subscribers: Mutex<HashMap<String, Vec<Sender>>>,
/// Snapshot of every message handed to `fanout`, grouped by message name.
published: Mutex<HashMap<String, Vec<RawMessage>>>,
/// Signalled with `notify_waiters` each time a message is appended to `published`.
notify: Notify,
/// Counter that is not read or written in this part of the file.
/// NOTE(review): presumably used to number reply inboxes for the requester — confirm in `capability`.
inbox_seq: AtomicU64,
}
impl MemoryState {
    /// Adds `tx` to the list of subscriber channels registered under `name`.
    ///
    /// # Panics
    ///
    /// Panics if the subscribers mutex is poisoned.
    fn register(&self, name: String, tx: Sender) {
        self.subscribers
            .lock()
            .expect("memory broker mutex poisoned")
            .entry(name)
            .or_default()
            .push(tx);
    }

    /// Removes every subscriber channel registered under `name`.
    ///
    /// # Panics
    ///
    /// Panics if the subscribers mutex is poisoned.
    fn unregister(&self, name: &str) {
        self.subscribers
            .lock()
            .expect("memory broker mutex poisoned")
            .remove(name);
    }

    /// Records `delivery` in the publish log, wakes current `notify` waiters, and then sends a
    /// clone of it to every channel registered under `delivery.name`.
    ///
    /// Channels whose receiving half has gone away are dropped from the registry while fanning
    /// out, so subscribers that were dropped do not accumulate as dead senders.
    ///
    /// # Panics
    ///
    /// Panics if either internal mutex is poisoned.
    fn fanout(&self, delivery: &MemoryDelivery) {
        let snapshot = RawMessage::new(delivery.name.clone(), delivery.payload.clone())
            .with_headers(delivery.headers.clone());
        // The log guard is a statement temporary, so it is released before waiters are woken.
        self.published
            .lock()
            .expect("memory broker mutex poisoned")
            .entry(delivery.name.clone())
            .or_default()
            .push(snapshot);
        self.notify.notify_waiters();

        let mut subs = self
            .subscribers
            .lock()
            .expect("memory broker mutex poisoned");
        if let Some(senders) = subs.get_mut(&delivery.name) {
            // An unbounded `send` only fails once the receiver is closed or dropped, and a closed
            // channel never reopens, so a failed sender can be discarded for good.
            // `retain` visits senders in registration order, preserving delivery order.
            senders.retain(|tx| tx.send(delivery.clone()).is_ok());
        }
    }
}
/// In-memory broker handle.
///
/// Cloning produces another handle to the same shared state, because the state lives behind an `Arc`.
#[derive(Clone, Default)]
pub struct MemoryBroker {
/// State shared with every clone, publisher and requester created from this broker.
state: Arc<MemoryState>,
}
impl MemoryBroker {
#[must_use]
pub fn new() -> Self {
Self::default()
}
#[must_use]
pub fn subscribe(&self, name: impl Into<String>) -> MemorySubscriber {
let (tx, rx) = mpsc::unbounded_channel();
let name = name.into();
self.state.register(name.clone(), tx.clone());
MemorySubscriber {
name,
rx,
requeue: tx,
batch_limit: DEFAULT_BATCH_LIMIT,
}
}
#[must_use]
pub fn publisher(&self) -> MemoryPublisher {
MemoryPublisher {
state: Arc::clone(&self.state),
txn: Mutex::new(None),
}
}
#[must_use]
pub fn requester(&self) -> MemoryRequester {
MemoryRequester::new(Arc::clone(&self.state))
}
}
impl std::fmt::Debug for MemoryBroker {
    /// Prints only the type name; the shared state is elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryBroker");
        out.finish_non_exhaustive()
    }
}
impl Broker for MemoryBroker {
    type Error = Infallible;

    /// Does nothing and always succeeds.
    async fn connect(&self) -> Result<(), Self::Error> {
        Ok(())
    }

    /// Clears the subscriber registry, so later publishes are no longer fanned out to
    /// subscribers created before the call.
    ///
    /// # Panics
    ///
    /// Panics if the subscribers mutex is poisoned.
    async fn shutdown(&self) -> Result<(), Self::Error> {
        {
            let mut registry = self
                .state
                .subscribers
                .lock()
                .expect("memory broker mutex poisoned");
            registry.clear();
        }
        Ok(())
    }
}
// The explicit `MemoryBroker::` path below is kept on purpose so the call visibly targets the
// inherent `subscribe`; clippy would otherwise ask for `Self::`.
#[allow(clippy::use_self)]
impl Subscribe for MemoryBroker {
    type Subscriber = MemorySubscriber;

    /// Trait entry point that forwards to the inherent [`MemoryBroker::subscribe`]; never fails.
    async fn subscribe(&self, name: &str) -> Result<Self::Subscriber, Self::Error> {
        let subscriber = MemoryBroker::subscribe(self, name);
        Ok(subscriber)
    }
}
/// Subscription source that identifies a [`MemoryBroker`] subscription by name.
#[derive(Debug, Clone)]
pub struct MemorySource {
/// Name passed to the broker when subscribing.
name: String,
}
impl MemorySource {
    /// Builds a source for the given subscription `name`.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        let name = name.into();
        Self { name }
    }
}
impl SubscriptionSource<MemoryBroker> for MemorySource {
    type Subscriber = MemorySubscriber;

    /// Returns the subscription name this source was built with.
    fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Consumes the source and subscribes to its name on `broker`; never fails.
    async fn subscribe(self, broker: &MemoryBroker) -> Result<Self::Subscriber, Infallible> {
        let Self { name } = self;
        let subscriber = broker.subscribe(name);
        Ok(subscriber)
    }
}
/// Initial `batch_limit` given to every new [`MemorySubscriber`].
const DEFAULT_BATCH_LIMIT: usize = 64;
/// Receiving side of one in-memory subscription.
pub struct MemorySubscriber {
/// Name this subscriber was registered under.
name: String,
/// Receiving half of this subscriber's channel.
rx: mpsc::UnboundedReceiver<MemoryDelivery>,
/// Sender into this subscriber's own channel; cloned into each yielded message for requeueing.
requeue: Sender,
/// Starts at `DEFAULT_BATCH_LIMIT` and is changed through `set_batch_limit`.
/// NOTE(review): not read in this part of the file — confirm where it is consumed.
batch_limit: usize,
}
impl MemorySubscriber {
/// Replaces the batch limit stored on this subscriber with `limit`.
///
/// NOTE(review): what the limit bounds is not visible in this part of the file — confirm at the use site.
pub fn set_batch_limit(&mut self, limit: usize) {
self.batch_limit = limit;
}
}
impl std::fmt::Debug for MemorySubscriber {
    /// Prints the type name and the subscription name; the channel halves are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemorySubscriber");
        out.field("name", &self.name);
        out.finish_non_exhaustive()
    }
}
impl Subscriber for MemorySubscriber {
    type Message = MemoryMessage;
    type Error = Infallible;

    /// Returns a stream over the deliveries arriving on this subscriber's channel.
    ///
    /// Every item is `Ok`. Each yielded message carries a clone of the requeue sender so it can
    /// push itself back onto the same channel. The stream borrows the subscriber mutably, so a
    /// new stream can be created once the previous one is dropped.
    fn stream(&mut self) -> impl Stream<Item = Result<Self::Message, Self::Error>> + Send + '_ {
        use std::task::Poll;

        let requeue = self.requeue.clone();
        let rx = &mut self.rx;
        futures::stream::poll_fn(move |cx| match rx.poll_recv(cx) {
            Poll::Ready(Some(delivery)) => Poll::Ready(Some(Ok(MemoryMessage {
                delivery: Some(delivery),
                requeue: requeue.clone(),
            }))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        })
    }
}
/// Publishing handle for the in-memory broker.
pub struct MemoryPublisher {
/// Broker state that published messages are fanned out through.
state: Arc<MemoryState>,
/// When `Some`, `publish` appends to this buffer instead of delivering; `None` delivers immediately.
/// NOTE(review): nothing in this part of the file sets it to `Some` — confirm where buffering starts and ends.
txn: Mutex<Option<Vec<MemoryDelivery>>>,
}
impl Clone for MemoryPublisher {
    /// Clones the handle onto the same broker state.
    ///
    /// The clone starts with an empty (`None`) buffer; buffered messages are not copied.
    fn clone(&self) -> Self {
        let state = Arc::clone(&self.state);
        let txn = Mutex::new(None);
        Self { state, txn }
    }
}
impl std::fmt::Debug for MemoryPublisher {
    /// Prints only the type name; state and buffered messages are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("MemoryPublisher");
        out.finish_non_exhaustive()
    }
}
impl Publisher for MemoryPublisher {
    type Error = Infallible;

    /// Publishes `msg`; never fails.
    ///
    /// The name, payload and headers are copied into an owned delivery. If a buffer is present
    /// in `txn` the delivery is appended to it; otherwise it is fanned out immediately.
    ///
    /// # Panics
    ///
    /// Panics if an internal mutex is poisoned.
    async fn publish(&self, msg: OutgoingMessage<'_>) -> Result<(), Self::Error> {
        let delivery = MemoryDelivery {
            name: msg.name().to_owned(),
            payload: Bytes::copy_from_slice(msg.payload()),
            headers: msg.headers().clone(),
        };

        // Decide under the lock, but release it before fanning out.
        let immediate = {
            let mut guard = self.txn.lock().expect("memory broker mutex poisoned");
            match guard.as_mut() {
                Some(buffered) => {
                    buffered.push(delivery);
                    None
                }
                None => Some(delivery),
            }
        };

        if let Some(delivery) = immediate {
            self.state.fanout(&delivery);
        }
        Ok(())
    }
}
/// Message yielded by a [`MemorySubscriber`] stream.
pub struct MemoryMessage {
/// The delivery itself; taken (left as `None`) by `ack`, `nack`, `nack_after` and `into_raw`.
delivery: Option<MemoryDelivery>,
/// Sender into the originating subscriber's channel, used to redeliver on nack.
requeue: Sender,
}
impl std::fmt::Debug for MemoryMessage {
    /// Prints the type name and the optional message name; payload and headers are elided.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: Option<&str> = self.delivery.as_ref().map(|d| d.name.as_str());
        let mut out = f.debug_struct("MemoryMessage");
        out.field("name", &name);
        out.finish_non_exhaustive()
    }
}
impl MemoryMessage {
#[must_use]
pub fn name(&self) -> &str {
self.delivery
.as_ref()
.map(|d| d.name.as_str())
.unwrap_or_default()
}
#[must_use]
pub fn into_raw(mut self) -> RawMessage {
let delivery = self.delivery.take().expect("delivery already consumed");
RawMessage::new(delivery.name, delivery.payload).with_headers(delivery.headers)
}
}
impl IncomingMessage for MemoryMessage {
    /// Returns the message body, or an empty slice if the delivery has been taken.
    fn payload(&self) -> &[u8] {
        match self.delivery.as_ref() {
            Some(delivery) => delivery.payload.as_ref(),
            None => &[],
        }
    }

    /// Delegates to this type's `Partitioned` implementation.
    fn partition_key(&self) -> Option<&[u8]> {
        crate::Partitioned::partition_key(self)
    }

    /// Returns the message headers, or a shared empty `Headers` if the delivery has been taken.
    fn headers(&self) -> &Headers {
        // Lazily created fallback so a reference can be returned without a delivery.
        static EMPTY: OnceLock<Headers> = OnceLock::new();
        match &self.delivery {
            Some(delivery) => &delivery.headers,
            None => EMPTY.get_or_init(Headers::new),
        }
    }

    /// Acknowledges the message by discarding its delivery; always succeeds.
    async fn ack(mut self) -> Result<(), AckError> {
        drop(self.delivery.take());
        Ok(())
    }

    /// Rejects the message; with `requeue` set, the delivery is pushed straight back onto the
    /// originating subscriber's channel, otherwise it is discarded.
    ///
    /// A failed requeue send is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the delivery has already been taken.
    async fn nack(mut self, requeue: bool) -> Result<(), AckError> {
        let delivery = self.delivery.take().expect("delivery already consumed");
        if !requeue {
            return Ok(());
        }
        let _ = self.requeue.send(delivery);
        Ok(())
    }

    /// Delayed redelivery is available for in-memory messages.
    fn supports_nack_after(&self) -> bool {
        true
    }

    /// Rejects the message and schedules redelivery after `delay`.
    ///
    /// A detached tokio task sleeps for `delay` and then pushes the delivery back onto the
    /// originating subscriber's channel; a failed send is ignored.
    ///
    /// # Panics
    ///
    /// Panics if the delivery has already been taken.
    async fn nack_after(mut self, delay: Duration) -> Result<(), AckError> {
        let delivery = self.delivery.take().expect("delivery already consumed");
        let requeue = self.requeue.clone();
        let redeliver = async move {
            tokio::time::sleep(delay).await;
            let _ = requeue.send(delivery);
        };
        tokio::spawn(redeliver);
        Ok(())
    }
}
// Unit tests for the in-memory broker: Debug output, message accessors, delayed
// redelivery, and re-creating a stream on the same subscriber.
#[cfg(test)]
mod tests {
use futures::StreamExt;
use super::*;
// Checks the Debug output of each public type and the `name`/`into_raw` accessors
// on a message received through a real publish.
#[tokio::test]
async fn debug_formats_and_message_accessors() {
let broker = MemoryBroker::new();
assert!(format!("{broker:?}").contains("MemoryBroker"));
let source = MemorySource::new("orders");
assert_eq!(source.name(), "orders");
let publisher = broker.publisher();
assert!(format!("{publisher:?}").contains("MemoryPublisher"));
let mut sub = broker.subscribe("dbg");
assert!(format!("{sub:?}").contains("MemorySubscriber"));
publisher
.publish(OutgoingMessage::new("dbg", b"payload".as_slice()))
.await
.unwrap();
let mut stream = std::pin::pin!(sub.stream());
let msg = stream.next().await.unwrap().unwrap();
assert!(format!("{msg:?}").contains("MemoryMessage"));
assert_eq!(msg.name(), "dbg");
// `into_raw` consumes the message, so it has to come after the borrowing accessors.
let raw = msg.into_raw();
assert_eq!(raw.name(), "dbg");
assert_eq!(raw.payload(), b"payload");
}
// Runs with a paused tokio clock so the five-second delay is advanced manually.
#[tokio::test(start_paused = true)]
async fn nack_after_redelivers_after_the_delay() {
let broker = MemoryBroker::new();
let mut sub = MemoryBroker::subscribe(&broker, "delayed");
let publisher = broker.publisher();
publisher
.publish(OutgoingMessage::new("delayed", b"later".as_slice()))
.await
.unwrap();
let mut stream = std::pin::pin!(sub.stream());
let msg = stream.next().await.unwrap().unwrap();
msg.nack_after(Duration::from_secs(5)).await.unwrap();
// Nothing may be redelivered before the delay has elapsed.
assert!(futures::poll!(stream.next()).is_pending());
// Advance the paused clock past the delay, then yield so the spawned task can run.
tokio::time::advance(Duration::from_secs(5)).await;
tokio::task::yield_now().await;
let redelivered = stream.next().await.unwrap().unwrap();
assert_eq!(redelivered.payload(), b"later");
redelivered.ack().await.unwrap();
}
// A second stream created after the first one is dropped must still receive messages.
#[tokio::test]
async fn stream_can_be_reentered() {
let broker = MemoryBroker::new();
let mut sub = MemoryBroker::subscribe(&broker, "test");
let publisher = broker.publisher();
publisher
.publish(OutgoingMessage::new("test", b"one".as_slice()))
.await
.unwrap();
// The inner scope drops the first stream, releasing its mutable borrow of `sub`.
{
let mut stream = std::pin::pin!(sub.stream());
let msg = stream.next().await.unwrap().unwrap();
assert_eq!(msg.payload(), b"one");
msg.ack().await.unwrap();
}
publisher
.publish(OutgoingMessage::new("test", b"two".as_slice()))
.await
.unwrap();
let mut stream = std::pin::pin!(sub.stream());
let msg = stream.next().await.unwrap().unwrap();
assert_eq!(msg.payload(), b"two");
msg.ack().await.unwrap();
}
}