//! amqp_lapin_helper — AMQP (lapin) broker helper: an in-memory publisher
//! queue, consumer listeners with concurrency/rate limiting, automatic
//! reconnection, and Prometheus instrumentation.

1#![forbid(unsafe_code)]
2#[macro_use]
3extern crate tracing;
4
5pub use lapin::{
6    message::Delivery, options::*, types::*, BasicProperties, Channel, Connection,
7    ConnectionProperties, Consumer as LapinConsumer, ExchangeKind, Queue,
8};
9
/// Re-export of `lapin::message` items (e.g. [`Delivery`]) under a stable path.
pub mod message {
    pub use lapin::message::Delivery;
}
13
/// Re-export of all `lapin::options` types (consume/publish/ack/reject options).
pub mod options {
    pub use lapin::options::*;
}
17
/// Re-export of all `lapin::types` (AMQP field/value types).
pub mod types {
    pub use lapin::types::*;
}
21
22use async_trait::async_trait;
23use futures_lite::StreamExt;
24pub use leaky_bucket::RateLimiter;
25use prometheus::{
26    opts, register_gauge_vec, register_histogram_vec, register_int_counter, register_int_gauge_vec,
27    GaugeVec, HistogramVec, IntCounter, IntGaugeVec,
28};
29use serde::Serialize;
30use std::sync::{Arc, LazyLock};
31use std::time::Duration;
32use tokio::sync::mpsc::error::SendError;
33use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
34use tokio::sync::oneshot;
35use tokio::sync::oneshot::{Receiver, Sender};
36use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
37use tokio::task;
38use tokio::time::{sleep, timeout};
39
40pub type Requeue = bool;
41
42pub type Result<E> = std::result::Result<E, Error>;
43pub type ConsumeResult<E> = std::result::Result<E, Requeue>;
44
// Gauge tracking consumer concurrency, labelled by exchange and by "kind"
// ("max" = configured limit, "permits_used" = tasks currently holding a permit).
static STAT_CONCURRENT_TASK: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    register_int_gauge_vec!(
        opts!(
            "amqp_consumer_concurrent_tasks",
            "Current/Max concurrent check",
        ),
        &["exchange_name", "kind"],
    )
    .unwrap()
});
55
// Gauge tracking the leaky-bucket rate limiter, labelled by exchange and by
// "kind" ("max" = bucket capacity, "balance" = tokens left after an acquire).
static STAT_RATE_LIMIT: LazyLock<IntGaugeVec> = LazyLock::new(|| {
    register_int_gauge_vec!(
        opts!(
            "amqp_consumer_rate_limit_tokens",
            "Current/Max rate limiting tokens count",
        ),
        &["exchange_name", "kind"],
    )
    .unwrap()
});
66
// Counter of consumer tasks killed by the emergency `task_timeout()`.
static STAT_TIMED_OUT_TASK: LazyLock<IntCounter> = LazyLock::new(|| {
    register_int_counter!(
        "amqp_consumer_timed_out_tasks",
        "Count of tasks that hit the emergency timeout",
    )
    .unwrap()
});
74
// Histogram bucket boundaries (seconds) shared by the duration metrics below.
const EXPONENTIAL_SECONDS: &[f64] = &[
    0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 40.0,
];
78
// Histogram of per-delivery consumer duration, labelled by exchange.
static STAT_CONSUMER_DURATION: LazyLock<HistogramVec> = LazyLock::new(|| {
    register_histogram_vec!(
        "amqp_consumer_duration",
        "The duration of the consumer",
        &["exchange_name"],
        EXPONENTIAL_SECONDS.to_vec(),
    )
    .unwrap()
});
88
// Histogram of `basic_publish` duration, labelled by exchange and routing key.
static STAT_PUBLISHER_DURATION: LazyLock<HistogramVec> = LazyLock::new(|| {
    register_histogram_vec!(
        "amqp_publisher_duration",
        "The duration of the publisher",
        &["exchange_name", "routing_key"],
        EXPONENTIAL_SECONDS.to_vec(),
    )
    .unwrap()
});
98
// Gauge of messages pending in the in-memory publisher queue
// (incremented by `Publisher::publish*`, decremented once pushed to AMQP).
static STAT_PUBLISHER_MSG_QUEUE: LazyLock<GaugeVec> = LazyLock::new(|| {
    register_gauge_vec!(
        "amqp_publisher_msg_queue",
        "The number of messages pending in the queue",
        &["exchange_name", "routing_key"],
    )
    .unwrap()
});
107
/// The reconnection delay when it detects an amqp disconnection
const RECONNECTION_DELAY: Duration = Duration::from_millis(1000);
/// How long `Ready::wait_to_be_ready` waits for the first successful AMQP
/// connection before erroring out (lets e.g. k8s restart a misconfigured container).
const AMQP_READINESS_TIMEOUT: Duration = Duration::from_secs(30);
111
/// All errors produced by this crate.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The readiness signal timed out or its sender side was dropped.
    #[error("AMQP Client Readiness error: {0}")]
    ReadinessSignal(String),
    /// Pushing onto the in-memory publisher queue failed (receiver dropped).
    #[error("MpscSendError: {0}")]
    SendError(#[from] SendError<QueueMessage>),
    /// Acquiring a concurrency permit failed (semaphore closed).
    #[error("acquire-semaphore: {0}")]
    AcquireSemaphore(#[from] AcquireError),
    /// Any underlying lapin/AMQP protocol error.
    #[error("AMQP: {0}")]
    Amqp(#[from] lapin::Error),
    /// An expected server ID was absent.
    #[error("Missing server ID")]
    MissingServerId,
    /// Payload bytes were not valid UTF-8.
    #[error("String UTF-8 error: {0}")]
    StringUtf8Error(#[from] std::string::FromUtf8Error),
    /// (De)serialization through bincode failed.
    #[error("Bincode: {0}")]
    Bincode(#[from] bincode::Error),
    /// Opaque error bubbled up from a consumer implementation.
    #[error("Consumer: {0}")]
    ConsumerError(#[from] Box<dyn std::error::Error + Send + Sync>),
}
131
/// Tag an object as Publishable: it names the exchange its serialized form
/// is published to (see [`Publisher::publish`]).
///
/// Note: the former `#[async_trait]` attribute was removed — this trait has
/// no async methods, so the macro was a no-op.
pub trait BrokerPublish {
    /// The exchange this entity is published to.
    fn exchange_name(&self) -> &'static str;
}
137
138/// Plug listeners to the broker.
139#[async_trait]
140pub trait BrokerListener: Send + Sync {
141    /// Bind the queue & struct to this exchange name
142    fn exchange_name(&self) -> &'static str;
143
144    /// The concurrency limit for the listener (will not run more than this number of running tasks).
145    /// (Note: if your tasks are doing I/O with timeouts, waiting on timeouts counts as running,
146    ///  and so a few of timeouts happening at the same time can really impact the consumption rate,
147    ///  so prefer rate limiting for those scenarios instead of relying on concurrency limiting,
148    ///  setting this limit way high.)
149    fn max_concurrent_tasks(&self) -> usize {
150        1
151    }
152
153    /// The rate limit for the listener, based on the leaky bucket algorithm
154    /// (unlike the concurrency limit, new permits are also refilled each time interval,
155    ///  not when a task completes, so e.g. time-sensitive work would not get delayed by timeouts).
156    fn task_rate_limit(&self) -> Option<RateLimiter> {
157        None
158    }
159
160    /// The emergency termination timeout for the `consume` method
161    /// (Prefer having shorter fine-grained timeouts on inner operations, but this is the last resort against
162    ///  tasks hanging forever and starving the queue. 5 minutes by default)
163    fn task_timeout(&self) -> Duration {
164        Duration::from_millis(1000 * 60 * 5)
165    }
166
167    /// Whether to requeue upon an emergency termination timeout
168    fn requeue_on_timeout(&self) -> bool {
169        false
170    }
171
172    /// The method that will be called in the struct impl on every messages received
173    /// Err(false): reject.requeue = false
174    /// Err(true): reject.requeue = true
175    async fn consume(&self, delivery: &Delivery) -> std::result::Result<(), bool>;
176}
177
/// Declares the AMQP topology (exchanges/queues/bindings) on freshly created
/// channels; called by `Broker::init` on every (re)connection.
#[async_trait]
pub trait BrokerManager {
    /// Declare everything the publisher channel needs.
    async fn declare_publisher(&self, channel: &Channel) -> Result<()>;

    /// Consumer's declaration is not required, some apps only need a publisher, vice & versa
    async fn declare_consumer(&self, channel: &Channel) -> Result<Option<lapin::Consumer>>;
}
185
/// AMQP Client
pub struct Broker<M: BrokerManager> {
    /// Live connection; `None` until `init()` succeeds.
    conn: Option<Connection>,
    /// The publisher will be copied & cloned across the app.
    /// Messages will be pushed into a queue, and then process to amqp asynchronously.
    publisher: Publisher,
    /// A daemon is spawned to process messages queue
    publisher_queue: PublisherQueue,
    /// Routes incoming deliveries to registered listeners.
    consumer: Consumer,
    /// Declares the exchange/queue topology on each (re)connection.
    manager: M,
    /// AMQP endpoint URI used by `init()`.
    uri: String,
    /// This will synchronize the parent app with amqp being ready to consume.
    ready_tx: Option<Sender<()>>,
    /// Handed out once via `ready_signal()`.
    ready_rx: Option<Receiver<()>>,
}
201
impl<M: BrokerManager> Broker<M> {
    /// Build a disconnected broker; no I/O happens until `init()`/`spawn()`.
    pub fn new(uri: &str, manager: M) -> Self {
        // In-memory publisher queue: the sender lives in `Publisher` clones,
        // the receiver is drained by `PublisherQueue::publish`.
        let (tx, recv) = unbounded_channel();
        // One-shot readiness signal, fired after the first successful `init`.
        let (ready_tx, ready_rx) = oneshot::channel();

        Self {
            conn: None,
            publisher: Publisher::new(tx),
            publisher_queue: PublisherQueue::new(recv),
            consumer: Consumer::new(),
            uri: uri.to_owned(),
            manager,
            ready_tx: Some(ready_tx),
            ready_rx: Some(ready_rx),
        }
    }

    /// Connect `Broker` to the AMQP endpoint, then declare Proxy's queue.
    pub async fn init(&mut self) -> Result<()> {
        let options = ConnectionProperties::default()
            // Use tokio executor and reactor.
            // At the moment the reactor is only available for unix.
            .with_executor(tokio_executor_trait::Tokio::current())
            .with_reactor(tokio_reactor_trait::Tokio);

        let conn = Connection::connect(&self.uri, options).await?;

        // Create Publisher
        let channel = conn.create_channel().await?;
        self.manager.declare_publisher(&channel).await?;
        self.publisher_queue.channel = Some(channel); // not sure whether that is required or not

        // Create Consumer
        let channel = conn.create_channel().await?;
        let amqp_consumer = self.manager.declare_consumer(&channel).await?;
        self.consumer.consumer = amqp_consumer;

        info!("Broker connected, channels created.");

        self.conn = Some(conn);

        // send a signal to tell the software that AMQP is ready and listening for incoming messages
        // (the `take()` ensures it only fires on the first successful connection)
        if let Some(ready_tx) = self.ready_tx.take() {
            ready_tx.send(()).expect("Can't send the amqp ready signal");
            debug!("amqp has been initialized and is ready for the first time");
        }

        Ok(())
    }

    /// Add and store listeners
    /// When a listener is added, it will bind the queue to the specified exchange name.
    pub fn add_listener(&mut self, listener: Arc<dyn BrokerListener>) {
        self.consumer.listeners.push(Listener::new(listener));
    }

    /// Serialize `entity` and enqueue it for asynchronous publication
    /// (delegates to [`Publisher::publish`]).
    pub fn publish<P>(&self, entity: &P, routing_key: &str) -> Result<()>
    where
        P: BrokerPublish + Serialize,
    {
        self.publisher.publish(entity, routing_key)
    }

    /// Enqueue an already-serialized payload (delegates to [`Publisher::publish_raw`]).
    pub fn publish_raw(&self, exchange: &str, routing_key: &str, msg: &[u8]) -> Result<()> {
        self.publisher.publish_raw(exchange, routing_key, msg)
    }

    /// This will send a copy of the receiver to receive a signal that says the consumer & publisher are ready
    /// so we could start the software and making sure we've haven't missed any messages.
    ///
    /// Panics if called more than once.
    pub fn ready_signal(&mut self) -> Ready {
        Ready {
            ready_rx: Some(
                self.ready_rx
                    .take()
                    .expect("amqp::ready_signal() has already been consumed"),
            ),
        }
    }

    /// Spawn the consumer and retry on connection interruption
    pub async fn spawn(&mut self) {
        loop {
            if let Err(err) = self.init().await {
                error!(%err, "amqp connection failed");
                sleep(RECONNECTION_DELAY).await;
                continue; // retry connection before hitting the spawn
            } else {
                info!("connected to amqp");
            }

            // Run consumer and publisher side by side; whichever returns
            // first (normally on a connection error) aborts the other, and
            // the outer loop reconnects from scratch.
            tokio::select! {
                err = self.consumer.consume() => {
                    if let Err(err) = &err {
                        error!(%err, "amqp consumer failed, trying to reconnect..");
                    }
                }
                err = self.publisher_queue.publish() => {
                    if let Err(err) = &err {
                        error!(%err, "amqp publisher failed, trying to reconnect..");
                    }
                }
            }
        }
    }

    /// Borrow the cloneable publisher handle.
    pub fn publisher(&self) -> &Publisher {
        &self.publisher
    }
}
311
/// This will send a copy of the receiver to receive a signal that says the consumer & publisher are ready
/// so we could start the software and making sure we've haven't missed any messages.
pub struct Ready {
    /// `Some` until `wait_to_be_ready` consumes it.
    ready_rx: Option<Receiver<()>>,
}
317
318impl Ready {
319    /// This will send a copy of the receiver to receive a signal that says the consumer & publisher are ready
320    /// so we could start the software and making sure we've haven't missed any messages.
321    ///
322    /// A timeout has been added to make sure k8s will restart the containers to avoid
323    /// the container to wait indefinitely the readiness of amqp due to some potential configuration issue.
324    pub async fn wait_to_be_ready(&mut self) -> Result<()> {
325        let secs = AMQP_READINESS_TIMEOUT.as_secs();
326        let fut = self
327            .ready_rx
328            .take()
329            .expect("amqp has already been initialized");
330
331        timeout(AMQP_READINESS_TIMEOUT, fut)
332            .await
333            .map_err(|_| Error::ReadinessSignal(format!("timed out after {secs}s")))?
334            .map_err(|e| Error::ReadinessSignal(e.to_string()))
335    }
336}
337
/// Cloneable handle used across the app to enqueue outgoing messages.
/// It needs to be an MPSC queue so that, if a disconnection happens, messages
/// can still be added to the queue; they wait there until AMQP is connected
/// again and `PublisherQueue` drains them onto the broker.
#[derive(Clone)]
pub struct Publisher {
    /// Here we choose Unbounded channel, because if a disconnection happens, a large number of
    /// messages can be produced while rabbitmq has been disconnected.
    /// Is it a good thing to keep all messages in memory? Unsure yet, the binary memory can grow
    /// indefinitely if rabbitmq isn't available again. But at least it won't block the async runtime.
    tx: UnboundedSender<QueueMessage>,
}
349
350impl Publisher {
351    fn new(tx: UnboundedSender<QueueMessage>) -> Self {
352        Self { tx }
353    }
354
355    /// Push item into memory queue before pushing it to amqp
356    pub fn publish<P>(&self, entity: &P, routing_key: &str) -> Result<()>
357    where
358        P: BrokerPublish + Serialize,
359    {
360        let serialized = bincode::serialize(entity)?;
361
362        self.tx.send(QueueMessage::new(
363            entity.exchange_name(),
364            routing_key,
365            serialized,
366        ))?;
367
368        STAT_PUBLISHER_MSG_QUEUE
369            .with_label_values(&[entity.exchange_name(), routing_key])
370            .inc();
371
372        Ok(())
373    }
374
375    /// Push without serializing, serializing has been made before calling this function
376    pub fn publish_raw(&self, exchange: &str, routing_key: &str, msg: &[u8]) -> Result<()> {
377        self.tx
378            .send(QueueMessage::new(exchange, routing_key, msg.to_owned()))?;
379
380        STAT_PUBLISHER_MSG_QUEUE
381            .with_label_values(&[exchange, routing_key])
382            .inc();
383
384        Ok(())
385    }
386}
387
/// Drains the in-memory publisher queue and pushes each message to AMQP.
pub struct PublisherQueue {
    /// Receiving half of the unbounded publisher queue.
    recv: UnboundedReceiver<QueueMessage>,
    /// AMQP channel, set by `Broker::init` after each (re)connection.
    channel: Option<Channel>,
}
392
393impl PublisherQueue {
394    fn new(recv: UnboundedReceiver<QueueMessage>) -> Self {
395        Self {
396            recv,
397            channel: None,
398        }
399    }
400
401    /// Process the messages from the queue to AMQP spawn a separate thread to publish
402    /// the message and avoid network delays for new incoming messages
403    ///
404    /// TODO: make sure this doesn't cause any problem such as messaging order issue
405    pub async fn publish(&mut self) -> Result<()> {
406        // process the mpsc queue for new messages
407        while let Some(msg) = self.recv.recv().await {
408            // Clone the Channel to move it to an async thread
409            let channel = self.channel.as_ref().unwrap().clone();
410
411            // send the message on amqp separately
412            tokio::spawn(async move {
413                // start prometheus duration timer
414                let histogram_timer = STAT_PUBLISHER_DURATION
415                    .with_label_values(&[&msg.exchange, &msg.routing_key])
416                    .start_timer();
417
418                let res = channel
419                    .basic_publish(
420                        &msg.exchange,
421                        &msg.routing_key,
422                        BasicPublishOptions::default(),
423                        &msg.content,
424                        BasicProperties::default(),
425                    )
426                    .await;
427
428                // finish and compute the duration to prometheus
429                histogram_timer.observe_duration();
430
431                // decrement the gauge about the number of pending msg in the queue
432                STAT_PUBLISHER_MSG_QUEUE
433                    .with_label_values(&[&msg.exchange, &msg.routing_key])
434                    .dec();
435
436                if let Err(err) = res {
437                    error!(%err, "failed to publish an amqp message");
438                }
439            });
440        }
441
442        Ok(())
443    }
444}
445
/// A message waiting in the in-memory queue before being pushed to AMQP.
pub struct QueueMessage {
    /// Target exchange name.
    exchange: String,
    /// Target routing key.
    routing_key: String,
    /// Already-serialized payload bytes.
    content: Vec<u8>,
}

impl QueueMessage {
    /// Build a message, taking ownership of the serialized payload.
    /// (Return type is now `Self` for consistency with the body.)
    fn new(exchange: &str, routing_key: &str, content: Vec<u8>) -> Self {
        Self {
            exchange: exchange.to_owned(),
            routing_key: routing_key.to_owned(),
            content,
        }
    }
}
461
/// A registered `BrokerListener` bundled with its (optional) concurrency
/// semaphore and (optional) rate limiter; cheap to clone (all `Arc`s).
#[derive(Clone)]
pub struct Listener {
    inner: Arc<dyn BrokerListener>,
    /// Present only when `max_concurrent_tasks() > 1`.
    semaphore: Option<Arc<Semaphore>>,
    /// Present only when the listener defines `task_rate_limit()`.
    rate_limit: Option<Arc<RateLimiter>>,
}
468
impl Listener {
    /// Wrap a user listener, building the semaphore / rate limiter up front.
    pub fn new(listener: Arc<dyn BrokerListener>) -> Self {
        let max = listener.max_concurrent_tasks();
        Self {
            // A semaphore is only useful when more than one task may run;
            // with max <= 1 the consumer processes deliveries inline instead.
            semaphore: if max > 1 {
                Some(Arc::new(Semaphore::new(max)))
            } else {
                None
            },
            rate_limit: listener.task_rate_limit().map(Arc::new),
            inner: listener,
        }
    }

    /// Borrow the wrapped user listener.
    fn listener(&self) -> &Arc<dyn BrokerListener> {
        &self.inner
    }

    /// Forward to the user listener's configured concurrency limit.
    fn max_concurrent_tasks(&self) -> usize {
        self.inner.max_concurrent_tasks()
    }

    /// Wait for one rate-limiter token (no-op when no limiter is set),
    /// reporting bucket capacity and remaining balance to Prometheus.
    async fn acquire_rate_limit(&self, exchange: &str) {
        if let Some(rate_limit) = self.rate_limit.as_ref() {
            STAT_RATE_LIMIT
                .with_label_values(&[exchange, "max"])
                .set(rate_limit.max() as i64);
            debug!(
                balance = rate_limit.balance(),
                max = rate_limit.max(),
                "waiting for a rate limiter permit",
            );

            rate_limit.acquire_one().await;
            debug!("got a rate limiter permit");

            // Report how many tokens remain after this acquisition.
            STAT_RATE_LIMIT
                .with_label_values(&[exchange, "balance"])
                .set(rate_limit.balance() as i64);
        }
    }
}
511
/// Routes incoming AMQP deliveries to their registered listeners.
#[derive(Default, Clone)]
pub struct Consumer {
    /// Underlying lapin consumer, set by `Broker::init` (may stay `None`).
    consumer: Option<lapin::Consumer>,
    /// One entry per registered exchange listener.
    listeners: Vec<Listener>,
}
517
impl Consumer {
    /// Create an empty consumer (no lapin consumer, no listeners).
    pub fn new() -> Self {
        Default::default()
    }

    /// Add and store listeners
    /// When a listener is added, it will bind the queue to the specified exchange name.
    pub fn add_listener(&mut self, listener: Arc<dyn BrokerListener>) {
        self.listeners.push(Listener::new(listener));
    }

    /// Consume messages by finding the appropriated listener.
    pub async fn consume(&mut self) -> Result<()> {
        if self.listeners.is_empty() || self.consumer.is_none() {
            // NOTE(review): this branch is also taken when the lapin consumer
            // itself is `None`; the log message only mentions listeners.
            warn!("No listeners have been found, nothing will be consumed from amqp.");

            loop {
                // in case there's no listeners, in order to avoid breaking
                // the job spawn for the publisher, we make an infinite loop here.
                sleep(Duration::from_secs(60)).await;
            }
        } else {
            info!(listeners = %self.listeners.len(), "Broker consuming...");
        }

        // `unwrap` is safe here: the `is_none` case above never falls through.
        while let Some(message) = self.consumer.as_mut().unwrap().next().await {
            match message {
                Ok(delivery) => {
                    // info!("received message: {:?}", delivery);
                    // Route the delivery to the listener bound to its exchange.
                    let listener = self.listeners.iter().find(|listener| {
                        listener.listener().exchange_name() == delivery.exchange.as_str()
                    });

                    if let Some(listener) = listener {
                        let listener = listener.clone();

                        if let Some(ref semaphore) = listener.semaphore {
                            // Concurrent path: semaphore first, then rate_limit
                            debug!(
                                permits_available = semaphore.available_permits(),
                                permits_max = listener.max_concurrent_tasks(),
                                "waiting for a semaphore permit"
                            );
                            STAT_CONCURRENT_TASK
                                .with_label_values(&[delivery.exchange.as_str(), "max"])
                                .set(listener.max_concurrent_tasks() as i64);

                            let permit = semaphore.clone().acquire_owned().await?;
                            debug!("got a semaphore permit");

                            STAT_CONCURRENT_TASK
                                .with_label_values(&[delivery.exchange.as_str(), "permits_used"])
                                .inc();

                            listener.acquire_rate_limit(delivery.exchange.as_str()).await;
                            // The permit is released inside `consume_async`.
                            task::spawn(consume_async(delivery, listener, permit));
                        } else {
                            // Inline path: rate_limit only (no semaphore)
                            listener.acquire_rate_limit(delivery.exchange.as_str()).await;
                            consume_inline(delivery, listener).await;
                        }
                    } else {
                        // No listener found for that exchange
                        if let Err(err) = delivery.nack(BasicNackOptions::default()).await {
                            panic!("Can't find any registered listeners for `{}` exchange: {:?} + Failed to send nack: {}", &delivery.exchange, &delivery, err);
                        } else {
                            panic!(
                                "Can't find any registered listeners for `{}` exchange: {:?}",
                                &delivery.exchange, &delivery
                            );
                        }
                    }
                }
                Err(err) => {
                    error!(%err, "Error when receiving a delivery");
                    Err(err)? // force the binary to shutdown on any AMQP error received
                }
            }
        }
        Ok(())
    }
}
600
601// async fn consume_async<L: BrokerListener + ?Sized>(
602//     delivery: Delivery,
603//     listener: Arc<L>,
604//     channel: Channel,
605// ) {
606/// Consume the delivery async
607async fn consume_async(delivery: Delivery, listener: Listener, permit: OwnedSemaphorePermit) {
608    // start prometheus duration timer
609    let histogram_timer = STAT_CONSUMER_DURATION
610        .with_label_values(&[listener.inner.exchange_name()])
611        .start_timer();
612
613    // launch the consumer
614    let res = timeout(
615        listener.listener().task_timeout(),
616        listener.listener().consume(&delivery),
617    )
618    .await;
619    drop(permit); // release the permit immediately
620
621    STAT_CONCURRENT_TASK
622        .with_label_values(&[delivery.exchange.as_str(), "permits_used"])
623        .dec();
624
625    // finish and compute the duration to prometheus
626    histogram_timer.observe_duration();
627
628    let res = match res {
629        Ok(inner) => inner,
630        Err(_) => {
631            error!("Consume task timed out");
632            STAT_TIMED_OUT_TASK.inc();
633            Err(listener.listener().requeue_on_timeout())
634        }
635    };
636
637    if let Err(requeue) = res {
638        #[allow(clippy::needless_update)]
639        let options = BasicRejectOptions {
640            requeue,
641            ..Default::default()
642        };
643
644        if let Err(err_reject) = delivery.reject(options).await {
645            error!(requeue, %err_reject, "Broker failed to send REJECT");
646        } else {
647            let exchange_name = listener.inner.exchange_name();
648            let routing_key = delivery.routing_key;
649            let redelivered = delivery.redelivered;
650
651            warn!(requeue, %exchange_name, %routing_key, %redelivered, "Error during consumption of a delivery, `REJECT` sent");
652        }
653    } else {
654        // Consumption went fine, we send ACK
655        if let Err(err) = delivery.ack(BasicAckOptions::default()).await {
656            error!(
657                %err, "Delivery consumed, but failed to send ACK back to the broker",
658            );
659        }
660    }
661}
662
663/// Inline variant of consume_async for single-task listeners (no semaphore).
664async fn consume_inline(delivery: Delivery, listener: Listener) {
665    let histogram_timer = STAT_CONSUMER_DURATION
666        .with_label_values(&[listener.inner.exchange_name()])
667        .start_timer();
668
669    let res = timeout(
670        listener.listener().task_timeout(),
671        listener.listener().consume(&delivery),
672    )
673    .await;
674
675    histogram_timer.observe_duration();
676
677    let res = match res {
678        Ok(inner) => inner,
679        Err(_) => {
680            error!("Consume task timed out");
681            STAT_TIMED_OUT_TASK.inc();
682            Err(listener.listener().requeue_on_timeout())
683        }
684    };
685
686    if let Err(requeue) = res {
687        #[allow(clippy::needless_update)]
688        let options = BasicRejectOptions {
689            requeue,
690            ..Default::default()
691        };
692
693        if let Err(err_reject) = delivery.reject(options).await {
694            error!(requeue, %err_reject, "Broker failed to send REJECT");
695        } else {
696            let exchange_name = listener.inner.exchange_name();
697            let routing_key = delivery.routing_key;
698            let redelivered = delivery.redelivered;
699
700            warn!(requeue, %exchange_name, %routing_key, %redelivered, "Error during consumption of a delivery, `REJECT` sent");
701        }
702    } else {
703        if let Err(err) = delivery.ack(BasicAckOptions::default()).await {
704            error!(
705                %err, "Delivery consumed, but failed to send ACK back to the broker",
706            );
707        }
708    }
709}