use crate::traits::{send_batch_helper, BoxFuture, Handler, MessagePublisher};
use crate::traits::{Handled, HandlerError};
use crate::CanonicalMessage;
use async_trait::async_trait;
use std::any::Any;
use std::future::Future;
use std::sync::Arc;
use crate::traits::{PublisherError, Sent, SentBatch};
/// Blanket implementation that turns any async function or closure of shape
/// `Fn(CanonicalMessage) -> impl Future<Output = Result<Handled, HandlerError>>`
/// into a [`Handler`], so plain closures can be passed wherever a handler is expected.
#[async_trait]
impl<F, Fut> Handler for F
where
    F: Fn(CanonicalMessage) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Result<Handled, HandlerError>> + Send,
{
    async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
        // Calling the closure only creates its future; awaiting it runs the handler.
        let pending = (self)(msg);
        pending.await
    }
}
/// A [`MessagePublisher`] that passes every outgoing message through a [`Handler`]
/// before anything reaches the wrapped publisher.
///
/// If the handler returns `Handled::Publish`, the produced message is forwarded to
/// `inner`. If it returns `Handled::Ack`, nothing is forwarded.
pub struct CommandPublisher {
    /// Publisher that receives the messages produced by `handler`.
    inner: Box<dyn MessagePublisher>,
    /// Handler invoked once for each message passed to `send`.
    handler: Arc<dyn Handler>,
}
impl CommandPublisher {
pub fn new(inner: impl MessagePublisher, handler: impl Handler + 'static) -> Self {
Self {
inner: Box::new(inner),
handler: Arc::new(handler),
}
}
}
#[async_trait]
impl MessagePublisher for CommandPublisher {
    /// Returns the wrapped publisher's connect hook unchanged.
    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        self.inner.on_connect_hook()
    }

    /// Returns the wrapped publisher's disconnect hook unchanged.
    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
        self.inner.on_disconnect_hook()
    }

    /// Runs the handler on `message` and publishes its response, if any, via `inner`.
    ///
    /// - `Handled::Ack`: nothing is forwarded; returns `Ok(Sent::Ack)`.
    /// - `Handled::Publish(response)`: `response.message_id` is overwritten with the
    ///   request's id. A `correlation_id` metadata entry is added unless the handler
    ///   already set one. Its value is the request's own `correlation_id` if present,
    ///   otherwise the request id formatted as 32 lowercase hex digits. The result
    ///   of `inner.send` is returned.
    ///
    /// # Errors
    ///
    /// Returns the handler's error unchanged, or any error from `inner.send`.
    async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
        // Capture the request's identity before the handler takes ownership of it.
        let inbound_correlation_id = message.metadata.get("correlation_id").cloned();
        let original_id = message.message_id;

        let mut response_msg = match self.handler.handle(message).await? {
            Handled::Publish(response) => response,
            Handled::Ack => return Ok(Sent::Ack),
        };

        response_msg.message_id = original_id;
        // `or_insert_with` builds the fallback id (and its `format!` allocation) only
        // when the handler did not already supply a correlation_id.
        response_msg
            .metadata
            .entry("correlation_id".to_string())
            .or_insert_with(|| {
                inbound_correlation_id.unwrap_or_else(|| format!("{:032x}", original_id))
            });
        self.inner.send(response_msg).await
    }

    /// Delegates to `send_batch_helper`, handing it a closure that calls
    /// [`Self::send`] for one message at a time. Ordering and concurrency are
    /// whatever `send_batch_helper` provides.
    async fn send_batch(
        &self,
        messages: Vec<CanonicalMessage>,
    ) -> Result<SentBatch, PublisherError> {
        send_batch_helper(self, messages, |publisher, message| {
            Box::pin(publisher.send(message))
        })
        .await
    }

    /// Flushes the wrapped publisher.
    async fn flush(&self) -> anyhow::Result<()> {
        self.inner.flush().await
    }

    fn as_any(&self) -> &dyn Any {
        self
    }
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;
    use crate::endpoints::memory::MemoryPublisher;

    #[tokio::test]
    async fn test_command_handler_produces_response() {
        let memory_publisher = MemoryPublisher::new_local("test_command_out_resp", 10);
        let channel = memory_publisher.channel();
        let handler = |msg: CanonicalMessage| async move {
            let response_payload = format!("response_to_{}", String::from_utf8_lossy(&msg.payload));
            Ok(Handled::Publish(response_payload.into()))
        };
        let publisher = CommandPublisher::new(memory_publisher, handler);
        publisher.send("command1".into()).await.unwrap();
        let received = channel.drain_messages();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].payload, "response_to_command1".as_bytes());
    }

    #[tokio::test]
    async fn test_command_handler_acks() {
        let memory_publisher = MemoryPublisher::new_local("test_command_out_ack", 10);
        let channel = memory_publisher.channel();
        let handler = |_msg: CanonicalMessage| async move { Ok(Handled::Ack) };
        let publisher = CommandPublisher::new(memory_publisher, handler);
        let result = publisher.send("command1".into()).await.unwrap();
        assert!(matches!(result, Sent::Ack));
        let received = channel.drain_messages();
        assert_eq!(received.len(), 0);
    }

    #[tokio::test]
    async fn test_command_handler_retryable_error() {
        let memory_publisher = MemoryPublisher::new_local("test_command_out_err", 10);
        let handler = |_msg: CanonicalMessage| async move {
            Err(HandlerError::Retryable(anyhow::anyhow!("db is down")))
        };
        let publisher = CommandPublisher::new(memory_publisher, handler);
        let result = publisher.send("command1".into()).await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, PublisherError::Retryable(_)));
    }

    #[tokio::test]
    async fn test_command_handler_integration_with_memory_consumer() {
        use crate::endpoints::memory::MemoryConsumer;
        use crate::traits::MessageConsumer;
        let mut consumer = MemoryConsumer::new_local("cmd_input", 10);
        let input_channel = consumer.channel();
        let memory_publisher = MemoryPublisher::new_local("cmd_output", 10);
        let output_channel = memory_publisher.channel();
        let publisher =
            CommandPublisher::new(memory_publisher, |msg: CanonicalMessage| async move {
                let payload = String::from_utf8_lossy(&msg.payload);
                let response = format!("processed_{}", payload);
                Ok(Handled::Publish(response.into()))
            });
        input_channel
            .send_message("test_data".into())
            .await
            .unwrap();
        let received = consumer.receive().await.unwrap();
        let result = publisher.send(received.message).await.unwrap();
        assert!(matches!(result, Sent::Ack));
        let output_msgs = output_channel.drain_messages();
        assert_eq!(output_msgs.len(), 1);
        assert_eq!(output_msgs[0].payload.to_vec(), b"processed_test_data");
        let _ = (received.commit)(crate::traits::MessageDisposition::Ack).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_command_handler_with_route_config() {
        use crate::models::{Endpoint, Route};
        let success = Arc::new(AtomicBool::new(false));
        let success_clone = success.clone();
        let handler = move |mut msg: CanonicalMessage| {
            success_clone.store(true, Ordering::SeqCst);
            msg.set_payload_str(format!("modified {}", msg.get_payload_str()));
            async move { Ok(Handled::Publish(msg)) }
        };
        let route = Route::new(
            Endpoint::new_memory("route_in", 100),
            Endpoint::new_memory("route_out", 100),
        )
        .with_handler(handler);
        route.deploy("command_handler_test_route").await.unwrap();
        let input_channel = route.input.channel().unwrap();
        input_channel.send_message("hello".into()).await.unwrap();
        let mut verifier = route.connect_to_output("verifier").await.unwrap();
        let received = verifier.receive().await.unwrap();
        assert_eq!(received.message.get_payload_str(), "modified hello");
        assert!(success.load(Ordering::SeqCst));
        Route::stop("command_handler_test_route").await;
    }

    #[tokio::test]
    async fn test_command_handler_inner_publisher_failure() {
        use crate::traits::MessagePublisher;
        struct FailPublisher;
        #[async_trait]
        impl MessagePublisher for FailPublisher {
            async fn send(&self, _msg: CanonicalMessage) -> Result<Sent, PublisherError> {
                Err(PublisherError::NonRetryable(anyhow::anyhow!("inner fail")))
            }
            async fn send_batch(
                &self,
                _msgs: Vec<CanonicalMessage>,
            ) -> Result<SentBatch, PublisherError> {
                Ok(SentBatch::Ack)
            }
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }
        let handler = |msg: CanonicalMessage| async move { Ok(Handled::Publish(msg)) };
        let publisher = CommandPublisher::new(FailPublisher, handler);
        let result = publisher.send("test".into()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("inner fail"));
    }

    #[tokio::test]
    async fn test_command_handler_preserves_message_id() {
        let memory_publisher = MemoryPublisher::new_local("test_cmd_id_preservation", 10);
        let channel = memory_publisher.channel();
        let handler = |_msg: CanonicalMessage| async move {
            let new_msg = CanonicalMessage::new(b"response".to_vec(), None);
            Ok(Handled::Publish(new_msg))
        };
        let publisher = CommandPublisher::new(memory_publisher, handler);
        let original_id = 987654321u128;
        publisher
            .send(CanonicalMessage::new(b"req".to_vec(), Some(original_id)))
            .await
            .unwrap();
        let received = channel.drain_messages();
        // Assert the length first so a missing response fails clearly instead of panicking on index.
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].message_id, original_id);
    }

    /// With no inbound or handler-supplied correlation_id, the request id in
    /// 32-digit lowercase hex is used.
    #[tokio::test]
    async fn test_command_handler_fallback_correlation_id() {
        let memory_publisher = MemoryPublisher::new_local("test_cmd_corr_fallback", 10);
        let channel = memory_publisher.channel();
        let handler = |_msg: CanonicalMessage| async move {
            Ok(Handled::Publish(CanonicalMessage::new(b"response".to_vec(), None)))
        };
        let publisher = CommandPublisher::new(memory_publisher, handler);
        publisher
            .send(CanonicalMessage::new(b"req".to_vec(), Some(0xabc)))
            .await
            .unwrap();
        let received = channel.drain_messages();
        assert_eq!(received.len(), 1);
        let expected = format!("{}abc", "0".repeat(29));
        assert_eq!(received[0].metadata.get("correlation_id"), Some(&expected));
    }

    /// The request's own correlation_id is copied onto a fresh response.
    #[tokio::test]
    async fn test_command_handler_propagates_inbound_correlation_id() {
        let memory_publisher = MemoryPublisher::new_local("test_cmd_corr_inbound", 10);
        let channel = memory_publisher.channel();
        let handler = |_msg: CanonicalMessage| async move {
            Ok(Handled::Publish(CanonicalMessage::new(b"response".to_vec(), None)))
        };
        let publisher = CommandPublisher::new(memory_publisher, handler);
        let mut request = CanonicalMessage::new(b"req".to_vec(), Some(1));
        request
            .metadata
            .insert("correlation_id".to_string(), "req-42".to_string());
        publisher.send(request).await.unwrap();
        let received = channel.drain_messages();
        assert_eq!(received.len(), 1);
        assert_eq!(
            received[0].metadata.get("correlation_id"),
            Some(&"req-42".to_string())
        );
    }

    /// A correlation_id set by the handler on its response is not overwritten.
    #[tokio::test]
    async fn test_command_handler_keeps_handler_correlation_id() {
        let memory_publisher = MemoryPublisher::new_local("test_cmd_corr_handler", 10);
        let channel = memory_publisher.channel();
        let handler = |_msg: CanonicalMessage| async move {
            let mut response = CanonicalMessage::new(b"response".to_vec(), None);
            response
                .metadata
                .insert("correlation_id".to_string(), "handler-set".to_string());
            Ok(Handled::Publish(response))
        };
        let publisher = CommandPublisher::new(memory_publisher, handler);
        let mut request = CanonicalMessage::new(b"req".to_vec(), Some(2));
        request
            .metadata
            .insert("correlation_id".to_string(), "req-42".to_string());
        publisher.send(request).await.unwrap();
        let received = channel.drain_messages();
        assert_eq!(received.len(), 1);
        assert_eq!(
            received[0].metadata.get("correlation_id"),
            Some(&"handler-set".to_string())
        );
    }
}