Skip to main content

mq_bridge/
traits.rs

//  mq-bridge
//  © Copyright 2025, by Marco Mengelkoch
//  Licensed under MIT License, see License file for more details
//  git clone https://github.com/marcomq/mq-bridge
5
6pub use crate::errors::{ConsumerError, HandlerError, PublisherError};
7pub use crate::outcomes::{Handled, Received, ReceivedBatch, Sent, SentBatch};
8use crate::CanonicalMessage;
9use async_trait::async_trait;
10pub use futures::future::BoxFuture;
11use std::any::Any;
12use std::sync::Arc;
13use tracing::warn;
14
15/// The disposition of a processed message.
16///
17/// Implements `From<Option<CanonicalMessage>>` for compatibility:
18/// `None` maps to `Ack`, `Some(msg)` maps to `Reply(msg)`.
19#[derive(Default, Debug, Clone)]
20#[allow(clippy::large_enum_variant)]
21pub enum MessageDisposition {
22    /// Acknowledge processing (success).
23    #[default]
24    Ack,
25    /// Acknowledge processing and send a reply.
26    Reply(CanonicalMessage),
27    /// Negative acknowledgement (failure).
28    Nack,
29}
30
31impl From<Option<CanonicalMessage>> for MessageDisposition {
32    fn from(opt: Option<CanonicalMessage>) -> Self {
33        match opt {
34            Some(msg) => MessageDisposition::Reply(msg),
35            None => MessageDisposition::Ack,
36        }
37    }
38}
39
40impl From<Handled> for MessageDisposition {
41    fn from(handled: Handled) -> Self {
42        match handled {
43            Handled::Ack => MessageDisposition::Ack,
44            Handled::Publish(msg) => MessageDisposition::Reply(msg),
45        }
46    }
47}
48
49/// A generic trait for handling messages (commands or events).
50///
51/// Handlers process an incoming message and can optionally return a new
52/// message (e.g. a reply) via `Handled::Publish`, or acknowledge processing via `Handled::Ack`.
53#[async_trait]
54pub trait Handler: Send + Sync + 'static {
55    async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError>;
56
57    /// Tries to register a handler for a specific type.
58    /// Returns `None` if this handler does not support registration (e.g. it's not a TypeHandler).
59    fn register_handler(
60        &self,
61        _type_name: &str,
62        _handler: Arc<dyn Handler>,
63    ) -> Option<Arc<dyn Handler>> {
64        None
65    }
66}
67
68#[async_trait]
69impl<T: Handler + ?Sized> Handler for Arc<T> {
70    async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
71        (**self).handle(msg).await
72    }
73    fn register_handler(
74        &self,
75        type_name: &str,
76        handler: Arc<dyn Handler>,
77    ) -> Option<Arc<dyn Handler>> {
78        (**self).register_handler(type_name, handler)
79    }
80}
81
82/// A helper trait that allows implementing handlers using native `async fn` syntax
83/// without the `#[async_trait]` macro.
84///
85/// Implementations of this trait can be adapted to `Handler` using `SimpleHandler`.
86pub trait AsyncHandler: Send + Sync + 'static {
87    fn handle<'a>(&'a self, msg: CanonicalMessage) -> BoxFuture<'a, Result<Handled, HandlerError>>;
88}
89
90/// A wrapper struct that adapts an `AsyncHandler` to the `Handler` trait.
91pub struct SimpleHandler<T>(pub T);
92
93#[async_trait]
94impl<T: AsyncHandler> Handler for SimpleHandler<T> {
95    async fn handle(&self, msg: CanonicalMessage) -> Result<Handled, HandlerError> {
96        self.0.handle(msg).await
97    }
98}
99
100/// A closure that can be called to commit the message.
101/// It returns a `BoxFuture` to allow for async commit operations.
102pub type CommitFunc =
103    Box<dyn FnOnce(MessageDisposition) -> BoxFuture<'static, anyhow::Result<()>> + Send + 'static>;
104
105/// A closure for committing a batch of messages.
106pub type BatchCommitFunc = Box<
107    dyn FnOnce(Vec<MessageDisposition>) -> BoxFuture<'static, anyhow::Result<()>> + Send + 'static,
108>;
109
110/// Status information about an endpoint (Consumer or Publisher).
111#[derive(Debug, Clone, serde::Serialize)]
112pub struct EndpointStatus {
113    pub healthy: bool,
114    pub target: String,
115    #[serde(skip_serializing_if = "Option::is_none")]
116    pub pending: Option<usize>,
117    #[serde(skip_serializing_if = "Option::is_none")]
118    pub capacity: Option<usize>,
119    #[serde(skip_serializing_if = "Option::is_none")]
120    pub error: Option<String>,
121    pub details: serde_json::Value,
122}
123impl Default for EndpointStatus {
124    fn default() -> Self {
125        Self {
126            healthy: true,
127            target: String::new(),
128            pending: None,
129            capacity: None,
130            error: None,
131            details: serde_json::Value::Null,
132        }
133    }
134}
135
136#[async_trait]
137pub trait MessageConsumer: Send + Sync {
138    /// Returns an optional lifecycle hook that runs once after the consumer connection is created.
139    ///
140    /// The route awaits this hook before it reports itself as ready. Returning an error fails
141    /// route startup and lets the outer route runner reconnect or surface the startup failure.
142    ///
143    /// Use this for per-connection setup that should be shared by all messages read through this
144    /// consumer, such as warming a connection pool, creating SQLite tables or indexes, setting up
145    /// a Kafka consumer group, or authenticating a RabbitMQ channel.
146    ///
147    /// ```ignore
148    /// fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
149    ///     Some(Box::pin(async move {
150    ///         self.pool.get().await?;
151    ///         self.db.execute("CREATE TABLE IF NOT EXISTS embeddings (...)").await?;
152    ///         Ok(())
153    ///     }))
154    /// }
155    /// ```
156    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
157        None
158    }
159
160    /// Returns an optional lifecycle hook that runs before the consumer is dropped.
161    ///
162    /// The route awaits this hook during shutdown or reconnect cleanup. Errors are logged as
163    /// warnings and do not replace the route's original result.
164    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
165        None
166    }
167
168    /// Receives a batch of messages.
169    ///
170    /// This method must be implemented by all consumers.
171    /// If in doubt, implement `receive_batch` to return a single message as a vector.
172    async fn receive_batch(&mut self, _max_messages: usize)
173        -> Result<ReceivedBatch, ConsumerError>;
174
175    /// Receives a single message.
176    async fn receive(&mut self) -> Result<Received, ConsumerError> {
177        // This default implementation ensures we get exactly one message,
178        // looping if the underlying batch consumer returns an empty batch.
179        loop {
180            let mut batch = self.receive_batch(1).await?;
181            if let Some(msg) = batch.messages.pop() {
182                debug_assert!(batch.messages.is_empty());
183                if !batch.messages.is_empty() {
184                    tracing::error!(
185                        "receive_batch(1) returned {} extra messages; dropping them (implementation bug)",
186                        batch.messages.len()
187                    );
188                }
189                return Ok(Received {
190                    message: msg,
191                    commit: into_commit_func(batch.commit),
192                });
193            }
194            // Batch was success but empty, which is unexpected for receive(1). Loop.
195            tokio::task::yield_now().await;
196        }
197    }
198
199    async fn receive_batch_helper(
200        &mut self,
201        _max_messages: usize,
202    ) -> Result<ReceivedBatch, ConsumerError> {
203        let received = self.receive().await?; // The `?` now correctly handles ConsumerError
204        let batch_commit = Box::new(move |dispositions: Vec<MessageDisposition>| {
205            // The default implementation only handles one message, so we take the first disposition.
206            let single_disposition = dispositions
207                .into_iter()
208                .next()
209                .unwrap_or(MessageDisposition::Ack);
210            (received.commit)(single_disposition)
211        }) as BatchCommitFunc;
212        Ok(ReceivedBatch {
213            messages: vec![received.message],
214            commit: batch_commit,
215        })
216    }
217
218    async fn status(&self) -> EndpointStatus {
219        EndpointStatus {
220            healthy: true,
221            ..Default::default()
222        }
223    }
224    fn as_any(&self) -> &dyn Any;
225}
226
227#[async_trait]
228pub trait MessagePublisher: Send + Sync + 'static {
229    /// Returns an optional lifecycle hook that runs once after the publisher connection is created.
230    ///
231    /// The route awaits this hook before it reports itself as ready. Returning an error fails
232    /// route startup and lets the outer route runner reconnect or surface the startup failure.
233    ///
234    /// Use this for per-connection setup that should be shared by all messages published through
235    /// this publisher, such as loading an embedding model, warming a connection pool, creating
236    /// SQLite tables or indexes, setting up a Kafka producer transaction context, or
237    /// authenticating a RabbitMQ channel.
238    ///
239    /// ```ignore
240    /// struct SqliteEmbeddingPublisher {
241    ///     model: Arc<tokio::sync::Mutex<Option<EmbeddingModel>>>,
242    ///     db: sqlx::SqlitePool,
243    /// }
244    ///
245    /// impl MessagePublisher for SqliteEmbeddingPublisher {
246    ///     fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
247    ///         Some(Box::pin(async move {
248    ///             let mut model = self.model.lock().await;
249    ///             if model.is_none() {
250    ///                 *model = Some(EmbeddingModel::load("all-MiniLM-L6-v2").await?);
251    ///             }
252    ///             sqlx::query("CREATE INDEX IF NOT EXISTS idx_embeddings_id ON embeddings(id)")
253    ///                 .execute(&self.db)
254    ///                 .await?;
255    ///             Ok(())
256    ///         }))
257    ///     }
258    /// }
259    /// ```
260    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
261        None
262    }
263
264    /// Returns an optional lifecycle hook that runs before the publisher is dropped.
265    ///
266    /// The route awaits this hook during shutdown or reconnect cleanup. Errors are logged as
267    /// warnings and do not replace the route's original result.
268    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
269        None
270    }
271
272    /// Sends a batch of messages.
273    ///
274    /// This method must be implemented by all publishers.
275    /// If in doubt, implement `send_batch` to send messages one at a time.
276    async fn send_batch(
277        &self,
278        messages: Vec<CanonicalMessage>,
279    ) -> Result<SentBatch, PublisherError>;
280
281    async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
282        let message_id = message.message_id;
283        let expects_reply = message.metadata.contains_key("reply_to");
284        match self.send_batch(vec![message]).await {
285            Ok(SentBatch::Ack) => {
286                if expects_reply {
287                    warn!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Response loop might be broken.", message_id);
288                }
289                Ok(Sent::Ack)
290            }
291            Ok(SentBatch::Partial {
292                mut responses,
293                mut failed,
294            }) => {
295                if let Some((_, err)) = failed.pop() {
296                    Err(err)
297                } else if let Some(res) = responses.as_mut().and_then(|r| r.pop()) {
298                    Ok(Sent::Response(res))
299                } else {
300                    if expects_reply {
301                        warn!("Message {:032x} expected a reply (reply_to set), but publisher returned Ack. Response loop might be broken.", message_id);
302                    }
303                    Ok(Sent::Ack)
304                }
305            }
306            Err(e) => Err(e),
307        }
308    }
309
310    async fn flush(&self) -> anyhow::Result<()> {
311        Ok(())
312    }
313
314    async fn status(&self) -> EndpointStatus {
315        EndpointStatus {
316            healthy: true,
317            ..Default::default()
318        }
319    }
320    fn as_any(&self) -> &dyn Any;
321}
322
323#[async_trait]
324impl<T: MessagePublisher + ?Sized> MessagePublisher for Arc<T> {
325    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
326        (**self).on_connect_hook()
327    }
328
329    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
330        (**self).on_disconnect_hook()
331    }
332
333    async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
334        (**self).send(message).await
335    }
336
337    async fn send_batch(
338        &self,
339        messages: Vec<CanonicalMessage>,
340    ) -> Result<SentBatch, PublisherError> {
341        (**self).send_batch(messages).await
342    }
343
344    async fn flush(&self) -> anyhow::Result<()> {
345        (**self).flush().await
346    }
347
348    async fn status(&self) -> EndpointStatus {
349        (**self).status().await
350    }
351
352    fn as_any(&self) -> &dyn Any {
353        (**self).as_any()
354    }
355}
356
357#[async_trait]
358impl<T: MessagePublisher + ?Sized> MessagePublisher for Box<T> {
359    fn on_connect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
360        (**self).on_connect_hook()
361    }
362
363    fn on_disconnect_hook(&self) -> Option<BoxFuture<'_, anyhow::Result<()>>> {
364        (**self).on_disconnect_hook()
365    }
366
367    async fn send(&self, message: CanonicalMessage) -> Result<Sent, PublisherError> {
368        (**self).send(message).await
369    }
370
371    async fn send_batch(
372        &self,
373        messages: Vec<CanonicalMessage>,
374    ) -> Result<SentBatch, PublisherError> {
375        (**self).send_batch(messages).await
376    }
377
378    async fn flush(&self) -> anyhow::Result<()> {
379        (**self).flush().await
380    }
381
382    async fn status(&self) -> EndpointStatus {
383        (**self).status().await
384    }
385
386    fn as_any(&self) -> &dyn Any {
387        (**self).as_any()
388    }
389}
390
391/// Factory for creating custom endpoints (consumers and publishers).
392#[async_trait]
393pub trait CustomEndpointFactory: Send + Sync + std::fmt::Debug {
394    async fn create_consumer(
395        &self,
396        _route_name: &str,
397        _config: &serde_json::Value,
398    ) -> anyhow::Result<Box<dyn MessageConsumer>> {
399        Err(anyhow::anyhow!(
400            "This custom endpoint does not support creating consumers"
401        ))
402    }
403    async fn create_publisher(
404        &self,
405        _route_name: &str,
406        _config: &serde_json::Value,
407    ) -> anyhow::Result<Box<dyn MessagePublisher>> {
408        Err(anyhow::anyhow!(
409            "This custom endpoint does not support creating publishers"
410        ))
411    }
412}
413
414/// Factory for creating custom middleware.
415#[async_trait]
416pub trait CustomMiddlewareFactory: Send + Sync + std::fmt::Debug {
417    async fn apply_consumer(
418        &self,
419        consumer: Box<dyn MessageConsumer>,
420        _route_name: &str,
421        _config: &serde_json::Value,
422    ) -> anyhow::Result<Box<dyn MessageConsumer>> {
423        Ok(consumer)
424    }
425
426    async fn apply_publisher(
427        &self,
428        publisher: Box<dyn MessagePublisher>,
429        _route_name: &str,
430        _config: &serde_json::Value,
431    ) -> anyhow::Result<Box<dyn MessagePublisher>> {
432        Ok(publisher)
433    }
434}
435
436/// A helper function to send messages in bulk by calling `send` for each one.
437/// This is useful for `MessagePublisher` implementations that don't have a native bulk sending mechanism.
438/// Requires that "send" is implemented for the publisher. Otherwise causes an infinite loop,
439/// as send is calling "send_batch" by default.
440pub async fn send_batch_helper<P: MessagePublisher + ?Sized>(
441    publisher: &P,
442    messages: Vec<CanonicalMessage>,
443    callback: impl for<'a> Fn(&'a P, CanonicalMessage) -> BoxFuture<'a, Result<Sent, PublisherError>>
444        + Send
445        + Sync,
446) -> Result<SentBatch, PublisherError> {
447    let mut responses = Vec::new();
448    let mut failed_messages = Vec::new();
449
450    let mut iter = messages.into_iter();
451    while let Some(msg) = iter.next() {
452        match callback(publisher, msg.clone()).await {
453            Ok(Sent::Response(resp)) => responses.push(resp),
454            Ok(Sent::Ack) => {}
455            Err(PublisherError::Retryable(e)) => {
456                // A retryable error likely affects the whole connection.
457                // We must return what succeeded so far (responses) and mark the rest as failed.
458                failed_messages.push((msg, PublisherError::Retryable(e)));
459                for m in iter {
460                    failed_messages.push((
461                        m,
462                        PublisherError::Retryable(anyhow::anyhow!(
463                            "Batch aborted due to previous error"
464                        )),
465                    ));
466                }
467                break;
468            }
469            Err(PublisherError::Connection(e)) => {
470                // Treat connection errors as affecting the whole batch, propagate immediately.
471                failed_messages.push((msg, PublisherError::Connection(e)));
472                for m in iter {
473                    failed_messages.push((
474                        m,
475                        PublisherError::Connection(anyhow::anyhow!(
476                            "Batch aborted due to previous connection error"
477                        )),
478                    ));
479                }
480                break;
481            }
482            Err(PublisherError::NonRetryable(e)) => {
483                // A non-retryable error is specific to this message.
484                // Collect it and continue with the rest of the batch.
485                failed_messages.push((msg, PublisherError::NonRetryable(e)));
486            }
487        }
488    }
489
490    if failed_messages.is_empty() && responses.is_empty() {
491        Ok(SentBatch::Ack)
492    } else {
493        Ok(SentBatch::Partial {
494            responses: if responses.is_empty() {
495                None
496            } else {
497                Some(responses)
498            },
499            failed: failed_messages,
500        })
501    }
502}
503
504/// Converts a `BatchCommitFunc` into a `CommitFunc` by wrapping it.
505/// This allows a function that commits a batch of messages to be used where a
506/// function that commits a single message is expected.
507pub fn into_commit_func(batch_commit: BatchCommitFunc) -> CommitFunc {
508    Box::new(move |disposition: MessageDisposition| {
509        let batch_disposition = vec![disposition];
510        batch_commit(batch_disposition)
511    })
512}
513
514/// Converts a `CommitFunc` into a `BatchCommitFunc` by wrapping it.
515/// This allows a function that commits a single message to be used where a
516/// function that commits a batch of messages is expected. It does so by
517/// extracting the first message from the response vector (if any) and passing
518/// it to the underlying single-message commit function.
519pub fn into_batch_commit_func(commit: CommitFunc) -> BatchCommitFunc {
520    Box::new(move |mut dispositions: Vec<MessageDisposition>| {
521        let single_disposition = if dispositions.len() > 1 {
522            warn!(
523                "into_batch_commit_func called with batch of {} messages; dropping all responses to avoid partial commit (incorrect usage)",
524                dispositions.len()
525            );
526            // Default to Ack to avoid hanging if we can't process the batch correctly
527            MessageDisposition::Ack
528        } else {
529            dispositions.pop().unwrap_or(MessageDisposition::Ack)
530        };
531        commit(single_disposition)
532    })
533}
534
#[cfg(test)]
mod tests {
    use super::*;
    use crate::CanonicalMessage;
    use anyhow::anyhow;

    /// Publisher that acknowledges every batch. The tests only need it as the `publisher`
    /// argument; the behavior under test lives in the callbacks.
    struct MockPublisher;

    #[async_trait]
    impl MessagePublisher for MockPublisher {
        async fn send_batch(
            &self,
            _msgs: Vec<CanonicalMessage>,
        ) -> Result<SentBatch, PublisherError> {
            Ok(SentBatch::Ack)
        }

        fn as_any(&self) -> &dyn Any {
            self
        }
    }

    /// A retryable failure mid-batch keeps earlier responses, records the failing message,
    /// and marks every later message as failed without sending it.
    #[tokio::test]
    async fn test_send_batch_helper_partial_failure() {
        let publisher = MockPublisher;
        let batch = vec![
            CanonicalMessage::from("1"),
            CanonicalMessage::from("2"),
            CanonicalMessage::from("3"),
        ];

        let outcome = send_batch_helper(&publisher, batch, |_publisher, msg| {
            Box::pin(async move {
                let payload = msg.get_payload_str();
                if payload == "2" {
                    Err(PublisherError::Retryable(anyhow!("fail")))
                } else if payload == "1" {
                    Ok(Sent::Response(CanonicalMessage::from("resp1")))
                } else {
                    Ok(Sent::Ack)
                }
            })
        })
        .await;

        let (responses, failed) = match outcome {
            Ok(SentBatch::Partial { responses, failed }) => (responses, failed),
            _ => panic!("Expected Partial result"),
        };

        // Message 1 succeeded with a reply.
        let responses = responses.expect("message 1 should have produced a response");
        assert_eq!(responses.len(), 1);
        assert_eq!(responses[0].get_payload_str(), "resp1");

        // Message 2 failed explicitly; message 3 was aborted because of it.
        assert_eq!(failed.len(), 2);
        let (first_msg, first_err) = &failed[0];
        assert_eq!(first_msg.get_payload_str(), "2");
        assert!(matches!(first_err, PublisherError::Retryable(_)));
        let (second_msg, second_err) = &failed[1];
        assert_eq!(second_msg.get_payload_str(), "3");
        assert!(matches!(second_err, PublisherError::Retryable(_)));
    }

    /// The default `send` surfaces the error from a single-message `Partial` result.
    #[tokio::test]
    async fn test_send_propagates_single_error() {
        struct FailPublisher;

        #[async_trait]
        impl MessagePublisher for FailPublisher {
            async fn send_batch(
                &self,
                msgs: Vec<CanonicalMessage>,
            ) -> Result<SentBatch, PublisherError> {
                // Mirror what send_batch_helper reports for a single failed message.
                let failed = vec![(
                    msgs[0].clone(),
                    PublisherError::NonRetryable(anyhow!("inner")),
                )];
                Ok(SentBatch::Partial {
                    responses: None,
                    failed,
                })
            }

            fn as_any(&self) -> &dyn Any {
                self
            }
        }

        let result = FailPublisher.send(CanonicalMessage::from("test")).await;

        match result {
            Err(PublisherError::NonRetryable(e)) => assert_eq!(e.to_string(), "inner"),
            _ => panic!("Expected NonRetryable error"),
        }
    }

    /// `SimpleHandler` forwards to the wrapped `AsyncHandler`.
    #[tokio::test]
    async fn test_simple_handler_wrapper() {
        struct MyLogic;

        impl AsyncHandler for MyLogic {
            fn handle<'a>(
                &'a self,
                _msg: CanonicalMessage,
            ) -> BoxFuture<'a, Result<Handled, HandlerError>> {
                Box::pin(async { Ok(Handled::Ack) })
            }
        }

        let outcome = SimpleHandler(MyLogic)
            .handle(CanonicalMessage::from("test"))
            .await;
        assert!(matches!(outcome, Ok(Handled::Ack)));
    }
}