use std::{
convert::Infallible,
sync::{Arc, atomic::Ordering},
task::Poll,
time::Duration,
};
use bytes::Bytes;
use futures::Stream;
use thiserror::Error;
use tokio::{sync::mpsc, time::timeout};
use super::{MemoryDelivery, MemoryMessage, MemoryPublisher, MemoryState, MemorySubscriber};
use crate::{
BatchSubscriber, IncomingMessage, OutgoingMessage, Partitioned, Publisher, RequestReply,
Subscriber, TransactionalPublisher,
};
/// Header whose value is reported as the message's partition key by the
/// `Partitioned` implementation for `MemoryMessage`.
pub const PARTITION_KEY_HEADER: &str = "partition-key";
/// Error type of `MemoryRequester`'s `Publisher` and `RequestReply` implementations.
#[non_exhaustive]
#[derive(Debug, Error)]
pub enum RequestError {
    /// No delivery reached the request's private inbox before the deadline.
    #[error("no reply to \"{subject}\" within {timeout:?}")]
    Timeout {
        /// Name of the outgoing message the request was published as.
        subject: String,
        /// How long the request waited before giving up.
        timeout: Duration,
    },
}
/// Request/reply client for the in-memory broker.
///
/// Clones share the same broker state through an [`Arc`].
#[derive(Clone)]
pub struct MemoryRequester {
    state: Arc<MemoryState>,
}

impl MemoryRequester {
    /// Wraps the shared broker state. Visible only to the parent module.
    pub(super) fn new(state: Arc<MemoryState>) -> Self {
        MemoryRequester { state }
    }
}

/// Hand-written rather than derived: the shared state is elided, so the
/// output is always `MemoryRequester { .. }`.
impl std::fmt::Debug for MemoryRequester {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = formatter.debug_struct("MemoryRequester");
        builder.finish_non_exhaustive()
    }
}
impl Publisher for MemoryRequester {
    type Error = RequestError;

    /// Fans `msg` out to the broker's subscribers immediately.
    ///
    /// The requester has no transaction buffer, so this never waits and never
    /// fails; the `Result` exists only to satisfy the trait.
    async fn publish(&self, msg: OutgoingMessage<'_>) -> Result<(), Self::Error> {
        // Copy the borrowed message into an owned delivery before fan-out.
        let name = msg.name().to_owned();
        let payload = Bytes::copy_from_slice(msg.payload());
        let headers = msg.headers().clone();
        let delivery = MemoryDelivery { name, payload, headers };

        self.state.fanout(&delivery);
        Ok(())
    }
}
impl RequestReply for MemoryRequester {
type Reply = MemoryMessage;
async fn request(
&self,
msg: OutgoingMessage<'_>,
wait: Duration,
) -> Result<Self::Reply, Self::Error> {
let id = self.state.inbox_seq.fetch_add(1, Ordering::Relaxed);
let inbox = format!("_inbox.{id}");
let (tx, mut rx) = mpsc::unbounded_channel();
self.state.register(inbox.clone(), tx.clone());
let mut headers = msg.headers().clone();
headers.insert("reply-to", inbox.clone());
let delivery = MemoryDelivery {
name: msg.name().to_owned(),
payload: Bytes::copy_from_slice(msg.payload()),
headers,
};
self.state.fanout(&delivery);
let outcome = timeout(wait, rx.recv()).await;
self.state.unregister(&inbox);
match outcome {
Ok(Some(reply)) => Ok(MemoryMessage {
delivery: Some(reply),
requeue: tx,
}),
Ok(None) => unreachable!("request inbox closed while its sender is held"),
Err(_) => Err(RequestError::Timeout {
subject: msg.name().to_owned(),
timeout: wait,
}),
}
}
}
impl BatchSubscriber for MemorySubscriber {
    type Batch = Vec<MemoryMessage>;

    /// Streams deliveries grouped into batches of at most `batch_limit` items.
    ///
    /// A limit of zero is treated as one, and the limit is read once, when the
    /// stream is created. Each batch waits for one delivery. It then adds
    /// deliveries that the receiver reports as immediately ready, up to the
    /// limit, without waiting for more. The stream ends once the receiver
    /// yields `None`.
    fn batches(
        &mut self,
    ) -> impl Stream<Item = Result<Self::Batch, <Self as Subscriber>::Error>> + Send + '_ {
        let max_len = self.batch_limit.max(1);
        let requeue = self.requeue.clone();
        // Every message in every batch requeues through the subscriber's sender.
        let wrap = move |delivery| MemoryMessage {
            delivery: Some(delivery),
            requeue: requeue.clone(),
        };

        futures::stream::poll_fn(move |cx| {
            let first = match self.rx.poll_recv(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Ready(Some(delivery)) => delivery,
            };

            let mut batch = vec![wrap(first)];
            // Drain only what is already queued; a pending or closed receiver ends the batch.
            while batch.len() < max_len {
                let Poll::Ready(Some(delivery)) = self.rx.poll_recv(cx) else {
                    break;
                };
                batch.push(wrap(delivery));
            }
            Poll::Ready(Some(Ok(batch)))
        })
    }
}
impl TransactionalPublisher for MemoryPublisher {
    /// Opens a transaction buffer for this publisher.
    ///
    /// If a transaction is already open, its existing buffer is kept unchanged.
    async fn begin_transaction(&self) -> Result<(), Infallible> {
        // The guard is a temporary of this statement and is released at the `;`.
        self.txn
            .lock()
            .expect("memory broker mutex poisoned")
            .get_or_insert_with(Vec::new);
        Ok(())
    }

    /// Closes the transaction and fans out every buffered delivery in buffer order.
    ///
    /// Does nothing if no transaction is open.
    async fn commit(&self) -> Result<(), Infallible> {
        // Take the buffer in its own statement so the lock is not held during fan-out.
        let pending = self.txn.lock().expect("memory broker mutex poisoned").take();
        if let Some(deliveries) = pending {
            for delivery in deliveries {
                self.state.fanout(&delivery);
            }
        }
        Ok(())
    }

    /// Closes the transaction and discards its buffered deliveries unpublished.
    async fn abort(&self) -> Result<(), Infallible> {
        let discarded = self.txn.lock().expect("memory broker mutex poisoned").take();
        drop(discarded);
        Ok(())
    }
}
impl Partitioned for MemoryMessage {
    /// Returns the raw value of the [`PARTITION_KEY_HEADER`] header.
    ///
    /// Returns `None` if the header is absent.
    fn partition_key(&self) -> Option<&[u8]> {
        let headers = self.headers();
        headers.get(PARTITION_KEY_HEADER)
    }
}
#[cfg(test)]
mod tests {
    use futures::StreamExt;

    use super::super::MemoryBroker;
    use super::*;
    use crate::Headers;

    /// Returns `true` if any request inbox (`_inbox.*`) is still registered on `broker`.
    fn has_registered_inbox(broker: &MemoryBroker) -> bool {
        broker
            .state
            .subscribers
            .lock()
            .unwrap()
            .keys()
            .any(|name| name.starts_with("_inbox."))
    }

    // Everything queued before the first poll arrives as a single batch, in publish order.
    #[tokio::test]
    async fn batches_drain_buffered_deliveries() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("batch");
        let publisher = broker.publisher();
        for i in 0..5u8 {
            publisher
                .publish(OutgoingMessage::new("batch", &[i]))
                .await
                .unwrap();
        }
        let mut stream = std::pin::pin!(sub.batches());
        let batch = stream.next().await.unwrap().unwrap();
        let payloads: Vec<u8> = batch.iter().map(|m| m.payload()[0]).collect();
        assert_eq!(payloads, [0, 1, 2, 3, 4]);
        for msg in batch {
            msg.ack().await.unwrap();
        }
    }

    // `set_batch_limit` caps batch size; the remainder spills into the next batch.
    #[tokio::test]
    async fn batch_limit_caps_each_batch() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("batch.capped");
        sub.set_batch_limit(2);
        let publisher = broker.publisher();
        for i in 0..3u8 {
            publisher
                .publish(OutgoingMessage::new("batch.capped", &[i]))
                .await
                .unwrap();
        }
        let mut stream = std::pin::pin!(sub.batches());
        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first.len(), 2);
        let second = stream.next().await.unwrap().unwrap();
        assert_eq!(second.len(), 1);
        for msg in first.into_iter().chain(second) {
            msg.ack().await.unwrap();
        }
    }

    // Publishes inside a transaction stay invisible until `commit`, then arrive in order.
    #[tokio::test]
    async fn transaction_buffers_until_commit() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("txn");
        let publisher = broker.publisher();
        publisher.begin_transaction().await.unwrap();
        publisher
            .publish(OutgoingMessage::new("txn", b"a".as_slice()))
            .await
            .unwrap();
        publisher
            .publish(OutgoingMessage::new("txn", b"b".as_slice()))
            .await
            .unwrap();
        let mut stream = std::pin::pin!(sub.stream());
        assert!(futures::poll!(stream.next()).is_pending());
        publisher.commit().await.unwrap();
        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first.payload(), b"a");
        first.ack().await.unwrap();
        let second = stream.next().await.unwrap().unwrap();
        assert_eq!(second.payload(), b"b");
        second.ack().await.unwrap();
    }

    // `abort` drops buffered publishes; later publishes go straight through.
    #[tokio::test]
    async fn abort_discards_buffered_publishes() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("txn.abort");
        let publisher = broker.publisher();
        publisher.begin_transaction().await.unwrap();
        publisher
            .publish(OutgoingMessage::new("txn.abort", b"gone".as_slice()))
            .await
            .unwrap();
        publisher.abort().await.unwrap();
        let mut stream = std::pin::pin!(sub.stream());
        assert!(futures::poll!(stream.next()).is_pending());
        publisher
            .publish(OutgoingMessage::new("txn.abort", b"kept".as_slice()))
            .await
            .unwrap();
        let msg = stream.next().await.unwrap().unwrap();
        assert_eq!(msg.payload(), b"kept");
        msg.ack().await.unwrap();
    }

    // A clone made mid-transaction publishes directly rather than into the original's buffer.
    #[tokio::test]
    async fn clone_does_not_join_transaction() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("txn.clone");
        let transactional = broker.publisher();
        transactional.begin_transaction().await.unwrap();
        transactional
            .publish(OutgoingMessage::new("txn.clone", b"buffered".as_slice()))
            .await
            .unwrap();
        let independent = transactional.clone();
        independent
            .publish(OutgoingMessage::new("txn.clone", b"direct".as_slice()))
            .await
            .unwrap();
        let mut stream = std::pin::pin!(sub.stream());
        let first = stream.next().await.unwrap().unwrap();
        assert_eq!(first.payload(), b"direct");
        first.ack().await.unwrap();
        transactional.commit().await.unwrap();
        let second = stream.next().await.unwrap().unwrap();
        assert_eq!(second.payload(), b"buffered");
        second.ack().await.unwrap();
    }

    // A responder replying to `reply-to` resolves the request, and the inbox is released.
    #[tokio::test]
    async fn request_resolves_on_reply() {
        let broker = MemoryBroker::new();
        let mut service = broker.subscribe("svc.echo");
        let publisher = broker.publisher();
        let requester = broker.requester();
        let respond = async {
            let mut stream = std::pin::pin!(service.stream());
            let msg = stream.next().await.unwrap().unwrap();
            assert_eq!(msg.payload(), b"ping");
            let reply_to = msg.headers().reply_to().unwrap().to_owned();
            publisher
                .publish(OutgoingMessage::new(&reply_to, b"pong".as_slice()))
                .await
                .unwrap();
            msg.ack().await.unwrap();
        };
        let request = requester.request(
            OutgoingMessage::new("svc.echo", b"ping".as_slice()),
            Duration::from_secs(1),
        );
        let (reply, ()) = futures::join!(request, respond);
        assert_eq!(reply.unwrap().payload(), b"pong");
        assert!(!has_registered_inbox(&broker));
    }

    // With no responder the request times out (paused clock), and the inbox is released.
    #[tokio::test(start_paused = true)]
    async fn request_times_out_without_responder() {
        let broker = MemoryBroker::new();
        let requester = broker.requester();
        let outcome = requester
            .request(
                OutgoingMessage::new("svc.void", b"ping".as_slice()),
                Duration::from_millis(5),
            )
            .await;
        assert!(matches!(outcome, Err(RequestError::Timeout { .. })));
        assert!(!has_registered_inbox(&broker));
    }

    // Both `Partitioned` and `IncomingMessage` read the key from PARTITION_KEY_HEADER.
    #[tokio::test]
    async fn partition_key_reads_well_known_header() {
        let broker = MemoryBroker::new();
        let mut sub = broker.subscribe("keyed");
        let publisher = broker.publisher();
        let mut headers = Headers::new();
        headers.insert(PARTITION_KEY_HEADER, b"user-42".as_slice());
        publisher
            .publish(OutgoingMessage::new("keyed", b"a".as_slice()).with_headers(headers))
            .await
            .unwrap();
        publisher
            .publish(OutgoingMessage::new("keyed", b"b".as_slice()))
            .await
            .unwrap();
        let mut stream = std::pin::pin!(sub.stream());
        let keyed = stream.next().await.unwrap().unwrap();
        assert_eq!(
            Partitioned::partition_key(&keyed),
            Some(b"user-42".as_slice())
        );
        assert_eq!(
            IncomingMessage::partition_key(&keyed),
            Some(b"user-42".as_slice())
        );
        keyed.ack().await.unwrap();
        let unkeyed = stream.next().await.unwrap().unwrap();
        assert_eq!(Partitioned::partition_key(&unkeyed), None);
        unkeyed.ack().await.unwrap();
    }
}