use crate::traits::{
BoxFuture, ConsumerError, MessageConsumer, MessageDisposition, MessagePublisher,
PublisherError, Sent, SentBatch,
};
use crate::CanonicalMessage;
use async_trait::async_trait;
use std::any::Any;
use std::sync::Arc;
use tokio::sync::Mutex;
/// A [`MessagePublisher`] adapter that turns publish calls into reads.
///
/// `send`/`send_batch` never forward the outgoing payload; they pull from the
/// wrapped [`MessageConsumer`] instead and acknowledge what they read. The
/// consumer sits behind an async mutex because receiving needs `&mut` access
/// while the publisher API only provides `&self`.
pub struct ReaderPublisher {
    /// The wrapped reader, shared so hooks and sends can all reach it.
    consumer: Arc<Mutex<Box<dyn MessageConsumer>>>,
}

impl ReaderPublisher {
    /// Wraps `consumer` so it can be driven through the publisher interface.
    pub fn new(consumer: Box<dyn MessageConsumer>) -> Self {
        let shared = Arc::new(Mutex::new(consumer));
        Self { consumer: shared }
    }
}
#[async_trait]
impl MessagePublisher for ReaderPublisher {
fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
Some(Box::pin(async move {
let consumer = self.consumer.lock().await;
if let Some(hook) = consumer.on_connect_hook() {
hook.await?;
}
Ok(())
}))
}
fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
Some(Box::pin(async move {
let consumer = self.consumer.lock().await;
if let Some(hook) = consumer.on_disconnect_hook() {
hook.await?;
}
Ok(())
}))
}
async fn send(&self, _message: CanonicalMessage) -> Result<Sent, PublisherError> {
let mut consumer = self.consumer.lock().await;
match consumer.receive().await {
Ok(received) => {
if let Err(e) = (received.commit)(MessageDisposition::Ack).await {
return Err(PublisherError::Retryable(anyhow::anyhow!(
"Failed to commit message in ReaderPublisher: {}",
e
)));
}
Ok(Sent::Response(received.message))
}
Err(e) => match e {
ConsumerError::EndOfStream => Err(PublisherError::NonRetryable(anyhow::anyhow!(e))),
_ => Err(PublisherError::Retryable(anyhow::anyhow!(e))),
},
}
}
async fn send_batch(
&self,
messages: Vec<CanonicalMessage>,
) -> Result<SentBatch, PublisherError> {
let count = messages.len();
if count == 0 {
return Ok(SentBatch::Ack);
}
let mut consumer = self.consumer.lock().await;
match consumer.receive_batch(count).await {
Ok(batch) => {
let received_count = batch.messages.len();
if received_count > 0 {
if let Err(e) =
(batch.commit)(vec![MessageDisposition::Ack; received_count]).await
{
return Err(PublisherError::Retryable(anyhow::anyhow!(
"Failed to commit batch in ReaderPublisher: {}",
e
)));
}
}
Ok(SentBatch::Ack)
}
Err(e) => match e {
ConsumerError::EndOfStream => Err(PublisherError::NonRetryable(anyhow::anyhow!(e))),
_ => Err(PublisherError::Retryable(anyhow::anyhow!(e))),
},
}
}
fn as_any(&self) -> &dyn Any {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::outcomes::{Received, ReceivedBatch};
    use crate::traits::{BatchCommitFunc, CommitFunc, EndpointStatus};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::{Arc as StdArc, Mutex as StdMutex};

    /// In-memory consumer that hands out one preconfigured result per call
    /// kind and records every disposition passed to its commit callbacks.
    struct MockConsumer {
        /// Result returned (once) by `receive`; `None` makes `receive` panic.
        single_result: Option<Result<CanonicalMessage, ConsumerError>>,
        /// Result returned (once) by `receive_batch`; `None` makes it panic.
        batch_result: Option<Result<Vec<CanonicalMessage>, ConsumerError>>,
        /// Every commit call's dispositions; single commits log a 1-element vec.
        commit_log: StdArc<StdMutex<Vec<Vec<MessageDisposition>>>>,
        /// When set, commit callbacks log and then fail with this message.
        commit_error: Option<String>,
        connect_calls: StdArc<AtomicUsize>,
        disconnect_calls: StdArc<AtomicUsize>,
    }

    impl MockConsumer {
        fn new_single(
            result: Result<CanonicalMessage, ConsumerError>,
            commit_log: StdArc<StdMutex<Vec<Vec<MessageDisposition>>>>,
        ) -> Self {
            Self {
                single_result: Some(result),
                batch_result: None,
                commit_log,
                commit_error: None,
                connect_calls: StdArc::new(AtomicUsize::new(0)),
                disconnect_calls: StdArc::new(AtomicUsize::new(0)),
            }
        }

        fn new_batch(
            result: Result<Vec<CanonicalMessage>, ConsumerError>,
            commit_log: StdArc<StdMutex<Vec<Vec<MessageDisposition>>>>,
        ) -> Self {
            Self {
                single_result: None,
                batch_result: Some(result),
                commit_log,
                commit_error: None,
                connect_calls: StdArc::new(AtomicUsize::new(0)),
                disconnect_calls: StdArc::new(AtomicUsize::new(0)),
            }
        }

        fn with_commit_error(mut self, message: &str) -> Self {
            self.commit_error = Some(message.to_string());
            self
        }

        fn commit_func(&self) -> CommitFunc {
            let log = self.commit_log.clone();
            let error = self.commit_error.clone();
            Box::new(move |disposition| {
                let log = log.clone();
                let error = error.clone();
                Box::pin(async move {
                    log.lock().unwrap().push(vec![disposition]);
                    if let Some(message) = error {
                        Err(anyhow::anyhow!(message))
                    } else {
                        Ok(())
                    }
                })
            })
        }

        fn batch_commit_func(&self) -> BatchCommitFunc {
            let log = self.commit_log.clone();
            let error = self.commit_error.clone();
            Box::new(move |dispositions| {
                let log = log.clone();
                let error = error.clone();
                Box::pin(async move {
                    log.lock().unwrap().push(dispositions);
                    if let Some(message) = error {
                        Err(anyhow::anyhow!(message))
                    } else {
                        Ok(())
                    }
                })
            })
        }
    }

    #[async_trait]
    impl MessageConsumer for MockConsumer {
        fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
            let calls = self.connect_calls.clone();
            Some(Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }))
        }

        fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
            let calls = self.disconnect_calls.clone();
            Some(Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                Ok(())
            }))
        }

        async fn receive(&mut self) -> Result<Received, ConsumerError> {
            match self
                .single_result
                .take()
                .expect("single_result should be configured for this test")
            {
                Ok(message) => Ok(Received {
                    message,
                    commit: self.commit_func(),
                }),
                Err(err) => Err(err),
            }
        }

        async fn receive_batch(
            &mut self,
            _max_messages: usize,
        ) -> Result<ReceivedBatch, ConsumerError> {
            match self
                .batch_result
                .take()
                .expect("batch_result should be configured for this test")
            {
                Ok(messages) => Ok(ReceivedBatch {
                    messages,
                    commit: self.batch_commit_func(),
                }),
                Err(err) => Err(err),
            }
        }

        async fn status(&self) -> EndpointStatus {
            EndpointStatus::default()
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    #[tokio::test]
    async fn test_reader_publisher_send_returns_response_and_commits_ack() {
        let commit_log = StdArc::new(StdMutex::new(Vec::new()));
        let publisher = ReaderPublisher::new(Box::new(MockConsumer::new_single(
            Ok(CanonicalMessage::from("from-reader")),
            commit_log.clone(),
        )));
        let sent = publisher
            .send(CanonicalMessage::from("trigger"))
            .await
            .unwrap();
        match sent {
            Sent::Response(message) => assert_eq!(message.get_payload_str(), "from-reader"),
            Sent::Ack => panic!("expected response"),
        }
        assert_eq!(commit_log.lock().unwrap().len(), 1);
        assert!(matches!(
            commit_log.lock().unwrap()[0].as_slice(),
            [MessageDisposition::Ack]
        ));
    }

    #[tokio::test]
    async fn test_reader_publisher_send_maps_end_of_stream_to_non_retryable_error() {
        let publisher = ReaderPublisher::new(Box::new(MockConsumer::new_single(
            Err(ConsumerError::EndOfStream),
            StdArc::new(StdMutex::new(Vec::new())),
        )));
        let err = publisher
            .send(CanonicalMessage::from("trigger"))
            .await
            .unwrap_err();
        assert!(matches!(err, PublisherError::NonRetryable(_)));
    }

    #[tokio::test]
    async fn test_reader_publisher_send_commit_failure_is_retryable() {
        let commit_log = StdArc::new(StdMutex::new(Vec::new()));
        let publisher = ReaderPublisher::new(Box::new(
            MockConsumer::new_single(Ok(CanonicalMessage::from("one")), commit_log.clone())
                .with_commit_error("commit failed"),
        ));
        let err = publisher
            .send(CanonicalMessage::from("trigger"))
            .await
            .unwrap_err();
        assert!(matches!(err, PublisherError::Retryable(_)));
        // The ack was still attempted exactly once before the failure surfaced.
        assert_eq!(commit_log.lock().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_reader_publisher_send_batch_commits_all_received_messages() {
        let commit_log = StdArc::new(StdMutex::new(Vec::new()));
        let publisher = ReaderPublisher::new(Box::new(MockConsumer::new_batch(
            Ok(vec![
                CanonicalMessage::from("one"),
                CanonicalMessage::from("two"),
            ]),
            commit_log.clone(),
        )));
        let sent = publisher
            .send_batch(vec![
                CanonicalMessage::from("trigger-1"),
                CanonicalMessage::from("trigger-2"),
            ])
            .await
            .unwrap();
        assert!(matches!(sent, SentBatch::Ack));
        let log = commit_log.lock().unwrap();
        assert_eq!(log.len(), 1);
        // One disposition per received message; without this length check the
        // `all` below would pass vacuously on an empty list.
        assert_eq!(log[0].len(), 2);
        assert!(log[0]
            .iter()
            .all(|disposition| matches!(disposition, MessageDisposition::Ack)));
    }

    #[tokio::test]
    async fn test_reader_publisher_send_batch_commit_failure_is_retryable() {
        let publisher = ReaderPublisher::new(Box::new(
            MockConsumer::new_batch(
                Ok(vec![CanonicalMessage::from("one")]),
                StdArc::new(StdMutex::new(Vec::new())),
            )
            .with_commit_error("commit failed"),
        ));
        let err = publisher
            .send_batch(vec![CanonicalMessage::from("trigger")])
            .await
            .unwrap_err();
        assert!(matches!(err, PublisherError::Retryable(_)));
    }

    #[tokio::test]
    async fn test_reader_publisher_send_batch_maps_end_of_stream_to_non_retryable_error() {
        let publisher = ReaderPublisher::new(Box::new(MockConsumer::new_batch(
            Err(ConsumerError::EndOfStream),
            StdArc::new(StdMutex::new(Vec::new())),
        )));
        let err = publisher
            .send_batch(vec![CanonicalMessage::from("trigger")])
            .await
            .unwrap_err();
        assert!(matches!(err, PublisherError::NonRetryable(_)));
    }

    #[tokio::test]
    async fn test_reader_publisher_send_batch_empty_input_skips_consumer() {
        let commit_log = StdArc::new(StdMutex::new(Vec::new()));
        // No batch_result is configured, so reaching `receive_batch` would
        // panic: this proves an empty input never touches the consumer.
        let publisher = ReaderPublisher::new(Box::new(MockConsumer::new_single(
            Err(ConsumerError::EndOfStream),
            commit_log.clone(),
        )));
        let sent = publisher.send_batch(Vec::new()).await.unwrap();
        assert!(matches!(sent, SentBatch::Ack));
        assert!(commit_log.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_reader_publisher_send_batch_skips_commit_when_nothing_received() {
        let commit_log = StdArc::new(StdMutex::new(Vec::new()));
        let publisher = ReaderPublisher::new(Box::new(MockConsumer::new_batch(
            Ok(Vec::new()),
            commit_log.clone(),
        )));
        let sent = publisher
            .send_batch(vec![CanonicalMessage::from("trigger")])
            .await
            .unwrap();
        assert!(matches!(sent, SentBatch::Ack));
        assert!(commit_log.lock().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_reader_publisher_runs_consumer_hooks() {
        let consumer =
            MockConsumer::new_batch(Ok(Vec::new()), StdArc::new(StdMutex::new(Vec::new())));
        let connect_calls = consumer.connect_calls.clone();
        let disconnect_calls = consumer.disconnect_calls.clone();
        let publisher = ReaderPublisher::new(Box::new(consumer));
        publisher
            .on_connect_hook()
            .unwrap()
            .await
            .expect("connect hook should succeed");
        publisher
            .on_disconnect_hook()
            .unwrap()
            .await
            .expect("disconnect hook should succeed");
        assert_eq!(connect_calls.load(Ordering::SeqCst), 1);
        assert_eq!(disconnect_calls.load(Ordering::SeqCst), 1);
        assert!(publisher.as_any().is::<ReaderPublisher>());
    }
}