Skip to main content

mq_bridge/
command_handler.rs

1//  mq-bridge
2//  © Copyright 2025, by Marco Mengelkoch
3//  Licensed under MIT License, see License file for more details
4//  git clone https://github.com/marcomq/mq-bridge
5
6use crate::traits::{send_batch_helper, BoxFuture, Handler, MessagePublisher};
7use crate::traits::{Handled, HandlerError};
8use crate::CanonicalMessage;
9use async_trait::async_trait;
10use std::any::Any;
11use std::future::Future;
12use std::sync::Arc;
13
14use crate::traits::{PublisherError, Sent, SentBatch};
15#[async_trait]
16impl<F, Fut> Handler for F
17where
18    F: Fn(CanonicalMessage) -> Fut + Send + Sync + 'static,
19    Fut: Future<Output = Result<Handled, HandlerError>> + Send,
20{
21    async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
22        self(msg).await
23    }
24}
25
26/// A publisher middleware that intercepts messages and passes them to a `Handler`.
27/// If the handler returns a new message, it is passed to the inner publisher.
28pub struct CommandPublisher {
29    inner: Box<dyn MessagePublisher>,
30    handler: Arc<dyn Handler>,
31}
32
33impl CommandPublisher {
34    pub fn new(inner: impl MessagePublisher, handler: impl Handler + 'static) -> Self {
35        Self {
36            inner: Box::new(inner),
37            handler: Arc::new(handler),
38        }
39    }
40}
41
42#[async_trait]
43impl MessagePublisher for CommandPublisher {
44    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
45        self.inner.on_connect_hook()
46    }
47
48    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
49        self.inner.on_disconnect_hook()
50    }
51
52    async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
53        let inbound_correlation_id = message.metadata.get("correlation_id").cloned();
54        let original_id = message.message_id;
55        match self.handler.handle(message).await {
56            Ok(Handled::Publish(mut response_msg)) => {
57                // For internal correlation, set the response message's ID to the original.
58                response_msg.message_id = original_id;
59                // For end-to-end tracing, propagate or create a correlation_id.
60                let fallback_correlation_id =
61                    inbound_correlation_id.unwrap_or_else(|| format!("{:032x}", original_id));
62                response_msg
63                    .metadata
64                    .entry("correlation_id".to_string())
65                    .or_insert(fallback_correlation_id);
66                self.inner.send(response_msg).await
67            }
68            Ok(Handled::Ack) => Ok(Sent::Ack),
69            Err(e) => Err(e), // Converts HandlerError to PublisherError
70        }
71    }
72
73    async fn send_batch(
74        &self,
75        messages: Vec<CanonicalMessage>,
76    ) -> Result<SentBatch, PublisherError> {
77        send_batch_helper(self, messages, |publisher, message| {
78            Box::pin(publisher.send(message))
79        })
80        .await
81    }
82
83    async fn flush(&self) -> anyhow::Result<()> {
84        self.inner.flush().await
85    }
86
87    fn as_any(&self) -> &dyn Any {
88        self
89    }
90}
91
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};

    use super::*;
    use crate::endpoints::memory::MemoryPublisher;

    /// A `Publish` result from the handler must land in the wrapped publisher.
    #[tokio::test]
    async fn test_command_handler_produces_response() {
        let sink = MemoryPublisher::new_local("test_command_out_resp", 10);
        let sink_channel = sink.channel();

        let handler = |msg: CanonicalMessage| async move {
            let body = String::from_utf8_lossy(&msg.payload);
            let reply = format!("response_to_{}", body);
            Ok(Handled::Publish(reply.into()))
        };

        let publisher = CommandPublisher::new(sink, handler);

        publisher.send("command1".into()).await.unwrap();

        let delivered = sink_channel.drain_messages();
        assert_eq!(delivered.len(), 1);
        assert_eq!(delivered[0].payload, "response_to_command1".as_bytes());
    }

    /// An `Ack` result from the handler must not forward anything downstream.
    #[tokio::test]
    async fn test_command_handler_acks() {
        let sink = MemoryPublisher::new_local("test_command_out_ack", 10);
        let sink_channel = sink.channel();

        let handler = |_msg: CanonicalMessage| async move { Ok(Handled::Ack) };

        let publisher = CommandPublisher::new(sink, handler);

        let outcome = publisher.send("command1".into()).await.unwrap();

        assert!(matches!(outcome, Sent::Ack));
        let delivered = sink_channel.drain_messages();
        assert_eq!(delivered.len(), 0);
    }

    /// A retryable handler error must reach the caller as a retryable publisher error.
    #[tokio::test]
    async fn test_command_handler_retryable_error() {
        let sink = MemoryPublisher::new_local("test_command_out_err", 10);

        let handler = |_msg: CanonicalMessage| async move {
            Err(HandlerError::Retryable(anyhow::anyhow!("db is down")))
        };

        let publisher = CommandPublisher::new(sink, handler);
        let outcome = publisher.send("command1".into()).await;

        assert!(outcome.is_err());
        let failure = outcome.unwrap_err();
        // The handler's error surfaces through the publisher's error type.
        assert!(matches!(failure, PublisherError::Retryable(_)));
    }

    /// Consume from a memory input, publish through the middleware, check the output.
    #[tokio::test]
    async fn test_command_handler_integration_with_memory_consumer() {
        use crate::endpoints::memory::MemoryConsumer;
        use crate::traits::MessageConsumer;

        // Input side: a memory consumer and the channel that feeds it.
        let mut consumer = MemoryConsumer::new_local("cmd_input", 10);
        let feed = consumer.channel();

        // Output side: a memory publisher that the middleware will wrap.
        let sink = MemoryPublisher::new_local("cmd_output", 10);
        let sink_channel = sink.channel();

        // Middleware with an inline handler that prefixes the payload.
        let publisher = CommandPublisher::new(sink, |msg: CanonicalMessage| async move {
            let body = String::from_utf8_lossy(&msg.payload);
            let reply = format!("processed_{}", body);
            Ok(Handled::Publish(reply.into()))
        });

        // Push one message into the input.
        feed.send_message("test_data".into()).await.unwrap();

        // One iteration of a bridge loop: consume, then publish.
        let received = consumer.receive().await.unwrap();
        let outcome = publisher.send(received.message).await.unwrap();

        // The send is acknowledged and exactly one processed message came out.
        assert!(matches!(outcome, Sent::Ack));

        let delivered = sink_channel.drain_messages();
        assert_eq!(delivered.len(), 1);
        assert_eq!(delivered[0].payload.to_vec(), b"processed_test_data");

        let _ = (received.commit)(crate::traits::MessageDisposition::Ack).await;
    }

    /// A handler attached through `Route::with_handler` runs on a deployed route.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_command_handler_with_route_config() {
        use crate::models::{Endpoint, Route};

        let handler_ran = Arc::new(AtomicBool::new(false));
        let handler_flag = handler_ran.clone();

        // Handler: record that it ran and rewrite the payload.
        let handler = move |mut msg: CanonicalMessage| {
            handler_flag.store(true, Ordering::SeqCst);
            let rewritten = format!("modified {}", msg.get_payload_str());
            msg.set_payload_str(rewritten);
            async move { Ok(Handled::Publish(msg)) }
        };

        // Route: memory in, memory out, with the handler in between.
        let source = Endpoint::new_memory("route_in", 100);
        let destination = Endpoint::new_memory("route_out", 100);
        let route = Route::new(source, destination).with_handler(handler);

        // Deploy the route.
        route.deploy("command_handler_test_route").await.unwrap();

        // Inject one message.
        let feed = route.input.channel().unwrap();
        feed.send_message("hello".into()).await.unwrap();

        // Read it back from the output and check the handler's effect.
        let mut verifier = route.connect_to_output("verifier").await.unwrap();
        let received = verifier.receive().await.unwrap();
        assert_eq!(received.message.get_payload_str(), "modified hello");
        assert!(handler_ran.load(Ordering::SeqCst));
        Route::stop("command_handler_test_route").await;
    }

    /// An error from the inner publisher must be returned by the middleware.
    #[tokio::test]
    async fn test_command_handler_inner_publisher_failure() {
        use crate::traits::MessagePublisher;

        /// Publisher whose `send` always fails with a non-retryable error.
        struct FailPublisher;

        #[async_trait]
        impl MessagePublisher for FailPublisher {
            async fn send(&self, _msg: CanonicalMessage) -> Result<Sent, PublisherError> {
                let cause = anyhow::anyhow!("inner fail");
                Err(PublisherError::NonRetryable(cause))
            }

            async fn send_batch(
                &self,
                _msgs: Vec<CanonicalMessage>,
            ) -> Result<SentBatch, PublisherError> {
                Ok(SentBatch::Ack)
            }

            fn as_any(&self) -> &dyn std::any::Any {
                self
            }
        }

        // Echo handler: forwards the request unchanged so the inner publisher is hit.
        let handler = |msg: CanonicalMessage| async move { Ok(Handled::Publish(msg)) };
        let publisher = CommandPublisher::new(FailPublisher, handler);

        let outcome = publisher.send("test".into()).await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("inner fail"));
    }

    /// The forwarded response must carry the request's message id.
    #[tokio::test]
    async fn test_command_handler_preserves_message_id() {
        let sink = MemoryPublisher::new_local("test_cmd_id_preservation", 10);
        let sink_channel = sink.channel();

        // The handler builds a brand-new message without specifying an id.
        let handler = |_msg: CanonicalMessage| async move {
            let fresh = CanonicalMessage::new(b"response".to_vec(), None);
            Ok(Handled::Publish(fresh))
        };

        let publisher = CommandPublisher::new(sink, handler);
        let original_id = 987654321u128;
        let request = CanonicalMessage::new(b"req".to_vec(), Some(original_id));
        publisher.send(request).await.unwrap();

        let delivered = sink_channel.drain_messages();
        assert_eq!(delivered[0].message_id, original_id);
    }
}