mq_bridge/
command_handler.rs1use crate::traits::{send_batch_helper, BoxFuture, Handler, MessagePublisher};
7use crate::traits::{Handled, HandlerError};
8use crate::CanonicalMessage;
9use async_trait::async_trait;
10use std::any::Any;
11use std::future::Future;
12use std::sync::Arc;
13
14use crate::traits::{PublisherError, Sent, SentBatch};
15#[async_trait]
16impl<F, Fut> Handler for F
17where
18 F: Fn(CanonicalMessage) -> Fut + Send + Sync + 'static,
19 Fut: Future<Output = Result<Handled, HandlerError>> + Send,
20{
21 async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
22 self(msg).await
23 }
24}
25
26pub struct CommandPublisher {
29 inner: Box<dyn MessagePublisher>,
30 handler: Arc<dyn Handler>,
31}
32
33impl CommandPublisher {
34 pub fn new(inner: impl MessagePublisher, handler: impl Handler + 'static) -> Self {
35 Self {
36 inner: Box::new(inner),
37 handler: Arc::new(handler),
38 }
39 }
40}
41
42#[async_trait]
43impl MessagePublisher for CommandPublisher {
44 fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
45 self.inner.on_connect_hook()
46 }
47
48 fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
49 self.inner.on_disconnect_hook()
50 }
51
52 async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
53 let inbound_correlation_id = message.metadata.get("correlation_id").cloned();
54 let original_id = message.message_id;
55 match self.handler.handle(message).await {
56 Ok(Handled::Publish(mut response_msg)) => {
57 response_msg.message_id = original_id;
59 let fallback_correlation_id =
61 inbound_correlation_id.unwrap_or_else(|| format!("{:032x}", original_id));
62 response_msg
63 .metadata
64 .entry("correlation_id".to_string())
65 .or_insert(fallback_correlation_id);
66 self.inner.send(response_msg).await
67 }
68 Ok(Handled::Ack) => Ok(Sent::Ack),
69 Err(e) => Err(e), }
71 }
72
73 async fn send_batch(
74 &self,
75 messages: Vec<CanonicalMessage>,
76 ) -> Result<SentBatch, PublisherError> {
77 send_batch_helper(self, messages, |publisher, message| {
78 Box::pin(publisher.send(message))
79 })
80 .await
81 }
82
83 async fn flush(&self) -> anyhow::Result<()> {
84 self.inner.flush().await
85 }
86
87 fn as_any(&self) -> &dyn Any {
88 self
89 }
90}
91
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;
    use crate::endpoints::memory::MemoryPublisher;

    #[tokio::test]
    async fn test_command_handler_produces_response() {
        let memory_publisher = MemoryPublisher::new_local("test_command_out_resp", 10);
        let channel = memory_publisher.channel();

        let handler = |msg: CanonicalMessage| async move {
            let response_payload = format!("response_to_{}", String::from_utf8_lossy(&msg.payload));
            Ok(Handled::Publish(response_payload.into()))
        };

        let publisher = CommandPublisher::new(memory_publisher, handler);

        publisher.send("command1".into()).await.unwrap();

        let received = channel.drain_messages();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].payload, "response_to_command1".as_bytes());
    }

    #[tokio::test]
    async fn test_command_handler_acks() {
        let memory_publisher = MemoryPublisher::new_local("test_command_out_ack", 10);
        let channel = memory_publisher.channel();

        let handler = |_msg: CanonicalMessage| async move { Ok(Handled::Ack) };

        let publisher = CommandPublisher::new(memory_publisher, handler);

        let result = publisher.send("command1".into()).await.unwrap();

        assert!(matches!(result, Sent::Ack));
        let received = channel.drain_messages();
        assert_eq!(received.len(), 0);
    }

    #[tokio::test]
    async fn test_command_handler_retryable_error() {
        let memory_publisher = MemoryPublisher::new_local("test_command_out_err", 10);

        let handler = |_msg: CanonicalMessage| async move {
            Err(HandlerError::Retryable(anyhow::anyhow!("db is down")))
        };

        let publisher = CommandPublisher::new(memory_publisher, handler);
        let result = publisher.send("command1".into()).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, PublisherError::Retryable(_)));
    }

    #[tokio::test]
    async fn test_command_handler_integration_with_memory_consumer() {
        use crate::endpoints::memory::MemoryConsumer;
        use crate::traits::MessageConsumer;

        let mut consumer = MemoryConsumer::new_local("cmd_input", 10);
        let input_channel = consumer.channel();

        let memory_publisher = MemoryPublisher::new_local("cmd_output", 10);
        let output_channel = memory_publisher.channel();

        let publisher =
            CommandPublisher::new(memory_publisher, |msg: CanonicalMessage| async move {
                let payload = String::from_utf8_lossy(&msg.payload);
                let response = format!("processed_{}", payload);
                Ok(Handled::Publish(response.into()))
            });

        input_channel
            .send_message("test_data".into())
            .await
            .unwrap();

        let received = consumer.receive().await.unwrap();
        let result = publisher.send(received.message).await.unwrap();

        assert!(matches!(result, Sent::Ack));

        let output_msgs = output_channel.drain_messages();
        assert_eq!(output_msgs.len(), 1);
        assert_eq!(output_msgs[0].payload.to_vec(), b"processed_test_data");

        let _ = (received.commit)(crate::traits::MessageDisposition::Ack).await;
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn test_command_handler_with_route_config() {
        use crate::models::{Endpoint, Route};

        let success = Arc::new(AtomicBool::new(false));
        let success_clone = success.clone();

        let handler = move |mut msg: CanonicalMessage| {
            success_clone.store(true, Ordering::SeqCst);
            msg.set_payload_str(format!("modified {}", msg.get_payload_str()));
            async move { Ok(Handled::Publish(msg)) }
        };
        let route = Route::new(
            Endpoint::new_memory("route_in", 100),
            Endpoint::new_memory("route_out", 100),
        )
        .with_handler(handler);

        route.deploy("command_handler_test_route").await.unwrap();

        let input_channel = route.input.channel().unwrap();
        input_channel.send_message("hello".into()).await.unwrap();

        let mut verifier = route.connect_to_output("verifier").await.unwrap();
        let received = verifier.receive().await.unwrap();
        assert_eq!(received.message.get_payload_str(), "modified hello");
        assert!(success.load(Ordering::SeqCst));
        Route::stop("command_handler_test_route").await;
    }

    #[tokio::test]
    async fn test_command_handler_inner_publisher_failure() {
        use crate::traits::MessagePublisher;

        struct FailPublisher;
        #[async_trait]
        impl MessagePublisher for FailPublisher {
            async fn send(&self, _msg: CanonicalMessage) -> Result<Sent, PublisherError> {
                Err(PublisherError::NonRetryable(anyhow::anyhow!("inner fail")))
            }
            async fn send_batch(
                &self,
                _msgs: Vec<CanonicalMessage>,
            ) -> Result<SentBatch, PublisherError> {
                Ok(SentBatch::Ack)
            }
            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }

        let handler = |msg: CanonicalMessage| async move { Ok(Handled::Publish(msg)) };
        let publisher = CommandPublisher::new(FailPublisher, handler);
        let result = publisher.send("test".into()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("inner fail"));
    }

    #[tokio::test]
    async fn test_command_handler_preserves_message_id() {
        let memory_publisher = MemoryPublisher::new_local("test_cmd_id_preservation", 10);
        let channel = memory_publisher.channel();

        let handler = |_msg: CanonicalMessage| async move {
            let new_msg = CanonicalMessage::new(b"response".to_vec(), None);
            Ok(Handled::Publish(new_msg))
        };

        let publisher = CommandPublisher::new(memory_publisher, handler);
        let original_id = 987654321u128;
        publisher
            .send(CanonicalMessage::new(b"req".to_vec(), Some(original_id)))
            .await
            .unwrap();

        let received = channel.drain_messages();
        // Assert the count first so an empty output fails with a clear message
        // instead of an index-out-of-bounds panic.
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].message_id, original_id);
    }

    #[tokio::test]
    async fn test_command_handler_propagates_inbound_correlation_id() {
        let memory_publisher = MemoryPublisher::new_local("test_cmd_corr_inbound", 10);
        let channel = memory_publisher.channel();

        let handler = |_msg: CanonicalMessage| async move {
            let mut response = CanonicalMessage::new(b"response".to_vec(), None);
            // Ensure the response carries no correlation id of its own.
            response.metadata.remove("correlation_id");
            Ok(Handled::Publish(response))
        };

        let publisher = CommandPublisher::new(memory_publisher, handler);
        let mut request = CanonicalMessage::new(b"req".to_vec(), Some(1));
        request
            .metadata
            .insert("correlation_id".to_string(), "req-42".to_string());
        publisher.send(request).await.unwrap();

        let received = channel.drain_messages();
        assert_eq!(received.len(), 1);
        assert_eq!(
            received[0].metadata.get("correlation_id").map(String::as_str),
            Some("req-42")
        );
    }

    #[tokio::test]
    async fn test_command_handler_falls_back_to_hex_message_id() {
        let memory_publisher = MemoryPublisher::new_local("test_cmd_corr_fallback", 10);
        let channel = memory_publisher.channel();

        let handler = |_msg: CanonicalMessage| async move {
            let mut response = CanonicalMessage::new(b"response".to_vec(), None);
            response.metadata.remove("correlation_id");
            Ok(Handled::Publish(response))
        };

        let publisher = CommandPublisher::new(memory_publisher, handler);
        let mut request = CanonicalMessage::new(b"req".to_vec(), Some(0xabc));
        request.metadata.remove("correlation_id");
        publisher.send(request).await.unwrap();

        let received = channel.drain_messages();
        assert_eq!(received.len(), 1);
        // 0xabc rendered as 32 zero-padded lowercase hex digits.
        let expected = format!("{}abc", "0".repeat(29));
        assert_eq!(
            received[0].metadata.get("correlation_id").map(String::as_str),
            Some(expected.as_str())
        );
    }

    #[tokio::test]
    async fn test_command_handler_keeps_handler_correlation_id() {
        let memory_publisher = MemoryPublisher::new_local("test_cmd_corr_handler", 10);
        let channel = memory_publisher.channel();

        let handler = |_msg: CanonicalMessage| async move {
            let mut response = CanonicalMessage::new(b"response".to_vec(), None);
            response
                .metadata
                .insert("correlation_id".to_string(), "from-handler".to_string());
            Ok(Handled::Publish(response))
        };

        let publisher = CommandPublisher::new(memory_publisher, handler);
        let mut request = CanonicalMessage::new(b"req".to_vec(), Some(2));
        request
            .metadata
            .insert("correlation_id".to_string(), "req-42".to_string());
        publisher.send(request).await.unwrap();

        let received = channel.drain_messages();
        assert_eq!(received.len(), 1);
        assert_eq!(
            received[0].metadata.get("correlation_id").map(String::as_str),
            Some("from-handler")
        );
    }
}