pub use crate::errors::{ConsumerError, HandlerError, PublisherError};
pub use crate::outcomes::{Handled, Received, ReceivedBatch, Sent, SentBatch};
use crate::CanonicalMessage;
use async_trait::async_trait;
pub use futures::future::BoxFuture;
use std::any::Any;
use std::sync::Arc;
use tracing::warn;
/// What should happen to a received message once processing is finished.
///
/// A single value is handed to a [`CommitFunc`]; a [`BatchCommitFunc`] receives a
/// `Vec` of them.
// NOTE(review): `large_enum_variant` is allowed because `Reply` stores a whole
// `CanonicalMessage` inline; presumably boxing it was judged not worth an extra
// allocation per reply — confirm.
#[derive(Clone, Debug, Default)]
#[allow(clippy::large_enum_variant)]
pub enum MessageDisposition {
    /// Acknowledge the message. This is the default disposition.
    #[default]
    Ack,
    /// Reply with the contained message (produced from `Some(msg)` and from
    /// `Handled::Publish(msg)`).
    Reply(CanonicalMessage),
    /// Negatively acknowledge the message.
    Nack,
}
impl From<Option<CanonicalMessage>> for MessageDisposition {
fn from(opt: Option<CanonicalMessage>) -> Self {
match opt {
Some(msg) => MessageDisposition::Reply(msg),
None => MessageDisposition::Ack,
}
}
}
impl From<Handled> for MessageDisposition {
    /// Maps a handler outcome to a disposition: `Publish(msg)` replies with `msg`,
    /// `Ack` acknowledges.
    fn from(handled: Handled) -> Self {
        // Exhaustive on purpose so a new `Handled` variant forces a decision here.
        match handled {
            Handled::Publish(reply) => Self::Reply(reply),
            Handled::Ack => Self::Ack,
        }
    }
}
/// Asynchronous processor for a single [`CanonicalMessage`].
#[async_trait]
pub trait Handler: Send + Sync + 'static {
    /// Processes `msg` and returns a [`Handled`] outcome.
    ///
    /// # Errors
    /// Returns a [`HandlerError`] when the message cannot be handled.
    async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError>;

    /// Registers `handler` under `type_name`.
    ///
    /// The default implementation does not support registration: it discards both
    /// arguments and returns `None`.
    /// NOTE(review): the meaning of a `Some` return for overriding implementations is
    /// not visible here (presumably a previously registered handler) — confirm.
    fn register_handler(
        &self,
        type_name: &str,
        handler: Arc<dyn Handler>,
    ) -> Option<Arc<dyn Handler>> {
        let _ = (type_name, handler);
        None
    }
}
/// Lets a shared `Arc<T>` act as a [`Handler`] by forwarding every method to `T`.
#[async_trait]
impl<T: Handler + ?Sized> Handler for Arc<T> {
    async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
        // `&Arc<T>` deref-coerces to `&T` at the call site.
        T::handle(self, msg).await
    }

    fn register_handler(
        &self,
        type_name: &str,
        handler: Arc<dyn Handler>,
    ) -> Option<Arc<dyn Handler>> {
        T::register_handler(self, type_name, handler)
    }
}
/// Handler trait that returns a boxed future directly instead of using
/// `#[async_trait]`.
///
/// Wrap an implementor in [`SimpleHandler`] to use it where a [`Handler`] is required.
pub trait AsyncHandler: Send + Sync + 'static {
    /// Processes `msg`; the returned future may borrow `self`.
    fn handle(&self, msg: CanonicalMessage) -> BoxFuture<'_, Result<Handled, HandlerError>>;
}
pub struct SimpleHandler<T>(pub T);
#[async_trait]
impl<T: AsyncHandler> Handler for SimpleHandler<T> {
async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
self.0.handle(msg).await
}
}
/// One-shot callback that settles a single received message.
///
/// It consumes a [`MessageDisposition`] and returns a boxed `'static` future that
/// resolves to `anyhow::Result<()>`.
pub type CommitFunc = Box<
    dyn FnOnce(MessageDisposition) -> BoxFuture<'static, anyhow::Result<()>> + Send + 'static,
>;

/// One-shot callback that settles a whole received batch.
///
/// It consumes a `Vec` of [`MessageDisposition`]s and returns a boxed `'static` future.
/// NOTE(review): presumably one disposition per message in batch order — confirm
/// against the implementations.
pub type BatchCommitFunc = Box<
    dyn FnOnce(Vec<MessageDisposition>) -> BoxFuture<'static, anyhow::Result<()>>
        + Send
        + 'static,
>;
/// Status snapshot returned by `status()` on consumers and publishers.
///
/// `None` optional fields are omitted from the serialized form. Field order is the
/// serialization order.
#[derive(Clone, Debug, serde::Serialize)]
pub struct EndpointStatus {
    /// Whether the endpoint reports itself as healthy (`true` by default).
    pub healthy: bool,
    /// Description of the endpoint's target; empty by default.
    /// NOTE(review): the format is endpoint-specific and not defined here.
    pub target: String,
    /// Optional count of pending items. NOTE(review): what counts as pending is
    /// endpoint-specific — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pending: Option<usize>,
    /// Optional capacity figure. NOTE(review): unit is endpoint-specific — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub capacity: Option<usize>,
    /// Optional error description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Endpoint-specific extra data. It is always serialized, as `null` by default.
    pub details: serde_json::Value,
}
impl Default for EndpointStatus {
    /// A healthy status with an empty target, no optional fields and `Null` details.
    fn default() -> Self {
        let details = serde_json::Value::Null;
        Self {
            healthy: true,
            target: String::default(),
            pending: None,
            capacity: None,
            error: None,
            details,
        }
    }
}
#[async_trait]
pub trait MessageConsumer: Send + Sync {
async fn receive_batch(&mut self, _max_messages: usize)
-> Result<ReceivedBatch, ConsumerError>;
async fn receive(&mut self) -> Result<Received, ConsumerError> {
loop {
let mut batch = self.receive_batch(1).await?;
if let Some(msg) = batch.messages.pop() {
debug_assert!(batch.messages.is_empty());
if !batch.messages.is_empty() {
tracing::error!(
"receive_batch(1) returned {} extra messages; dropping them (implementation bug)",
batch.messages.len()
);
}
return Ok(Received {
message: msg,
commit: into_commit_func(batch.commit),
});
}
tokio::task::yield_now().await;
}
}
async fn receive_batch_helper(
&mut self,
_max_messages: usize,
) -> Result<ReceivedBatch, ConsumerError> {
let received = self.receive().await?; let batch_commit = Box::new(move |dispositions: Vec<MessageDisposition>| {
let single_disposition = dispositions
.into_iter()
.next()
.unwrap_or(MessageDisposition::Ack);
(received.commit)(single_disposition)
}) as BatchCommitFunc;
Ok(ReceivedBatch {
messages: vec![received.message],
commit: batch_commit,
})
}
async fn status(&self) -> EndpointStatus {
EndpointStatus {
healthy: true,
..Default::default()
}
}
fn as_any(&self) -> &dyn Any;
}
#[async_trait]
pub trait MessagePublisher: Send + Sync + 'static {
async fn send_batch(
&self,
messages: Vec<CanonicalMessage>,
) -> Result<SentBatch, PublisherError>;
async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
match self.send_batch(vec![message]).await {
Ok(SentBatch::Ack) => Ok(Sent::Ack),
Ok(SentBatch::Partial {
mut responses,
mut failed,
}) => {
if let Some((_, err)) = failed.pop() {
Err(err)
} else if let Some(res) = responses.as_mut().and_then(|r| r.pop()) {
Ok(Sent::Response(res))
} else {
Ok(Sent::Ack)
}
}
Err(e) => Err(e),
}
}
async fn flush(&self) -> anyhow::Result<()> {
Ok(())
}
async fn status(&self) -> EndpointStatus {
EndpointStatus {
healthy: true,
..Default::default()
}
}
fn as_any(&self) -> &dyn Any;
}
/// Lets a shared `Arc<T>` act as a [`MessagePublisher`] by forwarding every method to
/// `T`.
#[async_trait]
impl<T: MessagePublisher + ?Sized> MessagePublisher for Arc<T> {
    async fn send_batch(
        &self,
        messages: Vec<CanonicalMessage>,
    ) -> Result<SentBatch, PublisherError> {
        T::send_batch(self, messages).await
    }

    // Forwarded explicitly so an override of `send` in `T` is used instead of the
    // trait's default (which would route through `send_batch`).
    async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
        T::send(self, message).await
    }

    async fn flush(&self) -> anyhow::Result<()> {
        T::flush(self).await
    }

    async fn status(&self) -> EndpointStatus {
        T::status(self).await
    }

    fn as_any(&self) -> &dyn Any {
        T::as_any(self)
    }
}
/// Lets a `Box<T>` (such as the `Box<dyn MessagePublisher>` returned by factories) act
/// as a [`MessagePublisher`] by forwarding every method to `T`.
#[async_trait]
impl<T: MessagePublisher + ?Sized> MessagePublisher for Box<T> {
    async fn send_batch(
        &self,
        messages: Vec<CanonicalMessage>,
    ) -> Result<SentBatch, PublisherError> {
        T::send_batch(self, messages).await
    }

    // Forwarded explicitly so an override of `send` in `T` is used instead of the
    // trait's default (which would route through `send_batch`).
    async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
        T::send(self, message).await
    }

    async fn flush(&self) -> anyhow::Result<()> {
        T::flush(self).await
    }

    async fn status(&self) -> EndpointStatus {
        T::status(self).await
    }

    fn as_any(&self) -> &dyn Any {
        T::as_any(self)
    }
}
/// Factory for user-defined endpoints, building consumers and/or publishers from a
/// route name and a JSON configuration value.
///
/// Both methods default to returning an error, so implementors override only the
/// side(s) they support.
#[async_trait]
pub trait CustomEndpointFactory: Send + Sync + std::fmt::Debug {
    /// Creates a consumer for the given route and configuration.
    ///
    /// # Errors
    /// The default implementation always fails because consumers are unsupported.
    async fn create_consumer(
        &self,
        _route_name: &str,
        _config: &serde_json::Value,
    ) -> anyhow::Result<Box<dyn MessageConsumer>> {
        anyhow::bail!("This custom endpoint does not support creating consumers")
    }

    /// Creates a publisher for the given route and configuration.
    ///
    /// # Errors
    /// The default implementation always fails because publishers are unsupported.
    async fn create_publisher(
        &self,
        _route_name: &str,
        _config: &serde_json::Value,
    ) -> anyhow::Result<Box<dyn MessagePublisher>> {
        anyhow::bail!("This custom endpoint does not support creating publishers")
    }
}
/// Factory for user-defined middleware that may wrap a route's consumer and/or
/// publisher.
///
/// Both methods default to returning the endpoint unchanged.
#[async_trait]
pub trait CustomMiddlewareFactory: Send + Sync + std::fmt::Debug {
    /// Optionally wraps `consumer`. The default returns it as-is.
    ///
    /// # Errors
    /// The default never fails. Implementations may fail on invalid configuration.
    async fn apply_consumer(
        &self,
        consumer: Box<dyn MessageConsumer>,
        _route_name: &str,
        _config: &serde_json::Value,
    ) -> anyhow::Result<Box<dyn MessageConsumer>> {
        Ok(consumer)
    }

    /// Optionally wraps `publisher`. The default returns it as-is.
    ///
    /// # Errors
    /// The default never fails. Implementations may fail on invalid configuration.
    async fn apply_publisher(
        &self,
        publisher: Box<dyn MessagePublisher>,
        _route_name: &str,
        _config: &serde_json::Value,
    ) -> anyhow::Result<Box<dyn MessagePublisher>> {
        Ok(publisher)
    }
}
/// Publishes `messages` one at a time through `callback` and aggregates the outcomes
/// into a [`SentBatch`].
///
/// Per-message handling:
/// - `Sent::Response` values are collected; `Sent::Ack` results are discarded.
/// - A `NonRetryable` error records that message as failed, and sending continues.
/// - A `Retryable` error records that message as failed. Every remaining message is
///   then recorded with a retryable "Batch aborted due to previous error", and
///   sending stops.
///
/// Returns `SentBatch::Ack` when nothing failed and no responses were produced,
/// otherwise `SentBatch::Partial`. This function itself never returns `Err`.
pub async fn send_batch_helper<P: MessagePublisher + ?Sized>(
    publisher: &P,
    messages: Vec<CanonicalMessage>,
    callback: impl for<'a> Fn(&'a P, CanonicalMessage) -> BoxFuture<'a, Result<Sent, PublisherError>>
        + Send
        + Sync,
) -> Result<SentBatch, PublisherError> {
    let mut collected = Vec::new();
    let mut failed = Vec::new();
    let mut pending = messages.into_iter();

    while let Some(message) = pending.next() {
        // The callback consumes its argument; keep a copy in case this message has
        // to be reported as failed.
        let outcome = callback(publisher, message.clone()).await;
        match outcome {
            Ok(Sent::Ack) => {}
            Ok(Sent::Response(reply)) => collected.push(reply),
            Err(PublisherError::NonRetryable(e)) => {
                failed.push((message, PublisherError::NonRetryable(e)));
            }
            Err(PublisherError::Retryable(e)) => {
                failed.push((message, PublisherError::Retryable(e)));
                // Don't send anything after a retryable failure; mark the rest as
                // retryable so the caller can resend them.
                failed.extend(pending.map(|rest| {
                    (
                        rest,
                        PublisherError::Retryable(anyhow::anyhow!(
                            "Batch aborted due to previous error"
                        )),
                    )
                }));
                break;
            }
        }
    }

    if failed.is_empty() && collected.is_empty() {
        return Ok(SentBatch::Ack);
    }
    Ok(SentBatch::Partial {
        responses: Some(collected).filter(|r| !r.is_empty()),
        failed,
    })
}
/// Adapts a [`BatchCommitFunc`] into a single-message [`CommitFunc`].
///
/// The returned closure wraps its disposition in a one-element `Vec` and passes it to
/// `batch_commit`.
pub fn into_commit_func(batch_commit: BatchCommitFunc) -> CommitFunc {
    Box::new(move |disposition: MessageDisposition| batch_commit(vec![disposition]))
}
/// Adapts a single-message [`CommitFunc`] into a [`BatchCommitFunc`].
///
/// The returned closure expects at most one disposition:
/// - An empty `Vec` commits with `Ack`.
/// - A single entry is passed through.
/// - More than one entry logs a warning and commits with `Ack`, discarding every
///   supplied disposition (including any `Reply` or `Nack`).
pub fn into_batch_commit_func(commit: CommitFunc) -> BatchCommitFunc {
    Box::new(move |mut dispositions: Vec<MessageDisposition>| {
        let count = dispositions.len();
        let chosen = if count <= 1 {
            // `MessageDisposition::default()` is `Ack`.
            dispositions.pop().unwrap_or_default()
        } else {
            warn!(
                "into_batch_commit_func called with batch of {} messages; dropping all responses to avoid partial commit (incorrect usage)",
                count
            );
            MessageDisposition::Ack
        };
        commit(chosen)
    })
}
// Unit tests for the default `MessagePublisher::send`, `send_batch_helper` and the
// `SimpleHandler` adapter.
#[cfg(test)]
mod tests {
use super::*;
use crate::CanonicalMessage;
use anyhow::anyhow;
// Minimal publisher whose `send_batch` always acks. It only serves as the `&P`
// argument to `send_batch_helper`; the callback below ignores it.
struct MockPublisher;
#[async_trait]
impl MessagePublisher for MockPublisher {
async fn send_batch(
&self,
_msgs: Vec<CanonicalMessage>,
) -> Result<SentBatch, PublisherError> {
Ok(SentBatch::Ack)
}
fn as_any(&self) -> &dyn Any {
self
}
}
// Message "1" produces a response and "2" fails with a retryable error. The helper
// must therefore stop and report "3" as failed (retryable) without sending it.
#[tokio::test]
async fn test_send_batch_helper_partial_failure() {
let publisher = MockPublisher;
let msgs = vec![
CanonicalMessage::from("1"),
CanonicalMessage::from("2"),
CanonicalMessage::from("3"),
];
let result = send_batch_helper(&publisher, msgs.clone(), |_pub, msg| {
Box::pin(async move {
let payload = msg.get_payload_str();
if payload == "1" {
Ok(Sent::Response(CanonicalMessage::from("resp1")))
} else if payload == "2" {
Err(PublisherError::Retryable(anyhow!("fail")))
} else {
Ok(Sent::Ack)
}
})
})
.await;
match result {
Ok(SentBatch::Partial { responses, failed }) => {
assert!(responses.is_some());
let resps = responses.unwrap();
assert_eq!(resps.len(), 1);
assert_eq!(resps[0].get_payload_str(), "resp1");
// Both the failing message and the aborted one are reported, in order.
assert_eq!(failed.len(), 2);
assert_eq!(failed[0].0.get_payload_str(), "2");
assert!(matches!(failed[0].1, PublisherError::Retryable(_)));
assert_eq!(failed[1].0.get_payload_str(), "3");
assert!(matches!(failed[1].1, PublisherError::Retryable(_)));
}
_ => panic!("Expected Partial result"),
}
}
// The default `send` must turn a per-message failure inside a `Partial` batch into
// an `Err` that carries the original error.
#[tokio::test]
async fn test_send_propagates_single_error() {
struct FailPublisher;
#[async_trait]
impl MessagePublisher for FailPublisher {
async fn send_batch(
&self,
msgs: Vec<CanonicalMessage>,
) -> Result<SentBatch, PublisherError> {
Ok(SentBatch::Partial {
responses: None,
failed: vec![(
msgs[0].clone(),
PublisherError::NonRetryable(anyhow!("inner")),
)],
})
}
fn as_any(&self) -> &dyn Any {
self
}
}
let publ = FailPublisher;
let res = publ.send(CanonicalMessage::from("test")).await;
assert!(res.is_err());
match res.unwrap_err() {
PublisherError::NonRetryable(e) => assert_eq!(e.to_string(), "inner"),
_ => panic!("Expected NonRetryable error"),
}
}
// `SimpleHandler` must forward `Handler::handle` to the wrapped `AsyncHandler`.
#[tokio::test]
async fn test_simple_handler_wrapper() {
struct MyLogic;
impl AsyncHandler for MyLogic {
fn handle<'a>(
&'a self,
_msg: CanonicalMessage,
) -> BoxFuture<'a, Result<Handled, HandlerError>> {
Box::pin(async { Ok(Handled::Ack) })
}
}
let handler = SimpleHandler(MyLogic);
let res = handler.handle(CanonicalMessage::from("test")).await;
assert!(matches!(res, Ok(Handled::Ack)));
}
}