1#![forbid(unsafe_code)]
2#[macro_use]
3extern crate tracing;
4
5pub use lapin::{
6 message::Delivery, options::*, types::*, BasicProperties, Channel, Connection,
7 ConnectionProperties, Consumer as LapinConsumer, ExchangeKind, Queue,
8};
9
/// Re-exports [`lapin::message::Delivery`] under a `message` path of this crate.
pub mod message {
    pub use lapin::message::Delivery;
}
13
/// Re-exports every item of `lapin::options` under an `options` path of this crate.
pub mod options {
    pub use lapin::options::*;
}
17
/// Re-exports every item of `lapin::types` under a `types` path of this crate.
pub mod types {
    pub use lapin::types::*;
}
21
22use async_trait::async_trait;
23use futures_lite::StreamExt;
24pub use leaky_bucket::RateLimiter;
25use prometheus::{
26 opts, register_gauge_vec, register_histogram_vec, register_int_counter, register_int_gauge_vec,
27 GaugeVec, HistogramVec, IntCounter, IntGaugeVec,
28};
29use serde::Serialize;
30use std::sync::{Arc, LazyLock};
31use std::time::Duration;
32use tokio::sync::mpsc::error::SendError;
33use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
34use tokio::sync::oneshot;
35use tokio::sync::oneshot::{Receiver, Sender};
36use tokio::sync::{AcquireError, OwnedSemaphorePermit, Semaphore};
37use tokio::task;
38use tokio::time::{sleep, timeout};
39
/// Requeue flag carried as the error of a failed consumption; the consume
/// helpers of this file forward it as `BasicRejectOptions::requeue`.
pub type Requeue = bool;

/// Result type of this crate: any success value paired with [`Error`].
pub type Result<E> = std::result::Result<E, Error>;
/// Result of consuming a delivery: the error side is the [`Requeue`] flag.
pub type ConsumeResult<E> = std::result::Result<E, Requeue>;
44
45static STAT_CONCURRENT_TASK: LazyLock<IntGaugeVec> = LazyLock::new(|| {
46 register_int_gauge_vec!(
47 opts!(
48 "amqp_consumer_concurrent_tasks",
49 "Current/Max concurrent check",
50 ),
51 &["exchange_name", "kind"],
52 )
53 .unwrap()
54});
55
56static STAT_RATE_LIMIT: LazyLock<IntGaugeVec> = LazyLock::new(|| {
57 register_int_gauge_vec!(
58 opts!(
59 "amqp_consumer_rate_limit_tokens",
60 "Current/Max rate limiting tokens count",
61 ),
62 &["exchange_name", "kind"],
63 )
64 .unwrap()
65});
66
67static STAT_TIMED_OUT_TASK: LazyLock<IntCounter> = LazyLock::new(|| {
68 register_int_counter!(
69 "amqp_consumer_timed_out_tasks",
70 "Count of tasks that hit the emergency timeout",
71 )
72 .unwrap()
73});
74
/// Bucket upper bounds shared by the consumer and publisher duration
/// histograms (presumably expressed in seconds, as the name suggests).
const EXPONENTIAL_SECONDS: &[f64] = &[
    0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0, 20.0, 40.0,
];
78
79static STAT_CONSUMER_DURATION: LazyLock<HistogramVec> = LazyLock::new(|| {
80 register_histogram_vec!(
81 "amqp_consumer_duration",
82 "The duration of the consumer",
83 &["exchange_name"],
84 EXPONENTIAL_SECONDS.to_vec(),
85 )
86 .unwrap()
87});
88
89static STAT_PUBLISHER_DURATION: LazyLock<HistogramVec> = LazyLock::new(|| {
90 register_histogram_vec!(
91 "amqp_publisher_duration",
92 "The duration of the publisher",
93 &["exchange_name", "routing_key"],
94 EXPONENTIAL_SECONDS.to_vec(),
95 )
96 .unwrap()
97});
98
99static STAT_PUBLISHER_MSG_QUEUE: LazyLock<GaugeVec> = LazyLock::new(|| {
100 register_gauge_vec!(
101 "amqp_publisher_msg_queue",
102 "The number of messages pending in the queue",
103 &["exchange_name", "routing_key"],
104 )
105 .unwrap()
106});
107
/// Pause observed by `Broker::spawn` after a failed connection attempt.
const RECONNECTION_DELAY: Duration = Duration::from_millis(1000);
/// Longest time `Ready::wait_to_be_ready` waits for the readiness signal.
const AMQP_READINESS_TIMEOUT: Duration = Duration::from_secs(30);
111
/// Errors produced by this crate.
#[derive(thiserror::Error, Debug)]
pub enum Error {
    /// The readiness signal timed out or its sender was dropped.
    #[error("AMQP Client Readiness error: {0}")]
    ReadinessSignal(String),
    /// A message could not be handed to the internal publishing queue.
    #[error("MpscSendError: {0}")]
    SendError(#[from] SendError<QueueMessage>),
    /// Acquiring a concurrency permit from a listener's semaphore failed.
    #[error("acquire-semaphore: {0}")]
    AcquireSemaphore(#[from] AcquireError),
    /// Error reported by the `lapin` client.
    #[error("AMQP: {0}")]
    Amqp(#[from] lapin::Error),
    /// NOTE(review): not constructed anywhere in the visible part of this file.
    #[error("Missing server ID")]
    MissingServerId,
    /// Bytes could not be converted to a UTF-8 `String`.
    #[error("String UTF-8 error: {0}")]
    StringUtf8Error(#[from] std::string::FromUtf8Error),
    /// `bincode` (de)serialization failure.
    #[error("Bincode: {0}")]
    Bincode(#[from] bincode::Error),
    /// Arbitrary boxed error; presumably meant for listener implementations.
    #[error("Consumer: {0}")]
    ConsumerError(#[from] Box<dyn std::error::Error + Send + Sync>),
}
131
/// Implemented by entities that can be published through `Publisher::publish`.
#[async_trait]
pub trait BrokerPublish {
    /// Name of the exchange the serialized entity is published to.
    fn exchange_name(&self) -> &'static str;
}
137
138#[async_trait]
140pub trait BrokerListener: Send + Sync {
141 fn exchange_name(&self) -> &'static str;
143
144 fn max_concurrent_tasks(&self) -> usize {
150 1
151 }
152
153 fn task_rate_limit(&self) -> Option<RateLimiter> {
157 None
158 }
159
160 fn task_timeout(&self) -> Duration {
164 Duration::from_millis(1000 * 60 * 5)
165 }
166
167 fn requeue_on_timeout(&self) -> bool {
169 false
170 }
171
172 async fn consume(&self, delivery: &Delivery) -> std::result::Result<(), bool>;
176}
177
/// Declares the AMQP topology; `Broker::init` calls it on every (re)connection.
#[async_trait]
pub trait BrokerManager {
    /// Called with the channel that is subsequently used for publishing.
    async fn declare_publisher(&self, channel: &Channel) -> Result<()>;

    /// Called with a second, dedicated channel. The returned consumer is the
    /// one the broker reads deliveries from; with `None` nothing is consumed.
    async fn declare_consumer(&self, channel: &Channel) -> Result<Option<lapin::Consumer>>;
}
185
186pub struct Broker<M: BrokerManager> {
188 conn: Option<Connection>,
189 publisher: Publisher,
192 publisher_queue: PublisherQueue,
194 consumer: Consumer,
195 manager: M,
196 uri: String,
197 ready_tx: Option<Sender<()>>,
199 ready_rx: Option<Receiver<()>>,
200}
201
202impl<M: BrokerManager> Broker<M> {
203 pub fn new(uri: &str, manager: M) -> Self {
204 let (tx, recv) = unbounded_channel();
205 let (ready_tx, ready_rx) = oneshot::channel();
206
207 Self {
208 conn: None,
209 publisher: Publisher::new(tx),
210 publisher_queue: PublisherQueue::new(recv),
211 consumer: Consumer::new(),
212 uri: uri.to_owned(),
213 manager,
214 ready_tx: Some(ready_tx),
215 ready_rx: Some(ready_rx),
216 }
217 }
218
219 pub async fn init(&mut self) -> Result<()> {
221 let options = ConnectionProperties::default()
222 .with_executor(tokio_executor_trait::Tokio::current())
225 .with_reactor(tokio_reactor_trait::Tokio);
226
227 let conn = Connection::connect(&self.uri, options).await?;
228
229 let channel = conn.create_channel().await?;
231 self.manager.declare_publisher(&channel).await?;
232 self.publisher_queue.channel = Some(channel); let channel = conn.create_channel().await?;
236 let amqp_consumer = self.manager.declare_consumer(&channel).await?;
237 self.consumer.consumer = amqp_consumer;
238
239 info!("Broker connected, channels created.");
240
241 self.conn = Some(conn);
242
243 if let Some(ready_tx) = self.ready_tx.take() {
245 ready_tx.send(()).expect("Can't send the amqp ready signal");
246 debug!("amqp has been initialized and is ready for the first time");
247 }
248
249 Ok(())
250 }
251
252 pub fn add_listener(&mut self, listener: Arc<dyn BrokerListener>) {
255 self.consumer.listeners.push(Listener::new(listener));
256 }
257
258 pub fn publish<P>(&self, entity: &P, routing_key: &str) -> Result<()>
259 where
260 P: BrokerPublish + Serialize,
261 {
262 self.publisher.publish(entity, routing_key)
263 }
264
265 pub fn publish_raw(&self, exchange: &str, routing_key: &str, msg: &[u8]) -> Result<()> {
266 self.publisher.publish_raw(exchange, routing_key, msg)
267 }
268
269 pub fn ready_signal(&mut self) -> Ready {
272 Ready {
273 ready_rx: Some(
274 self.ready_rx
275 .take()
276 .expect("amqp::ready_signal() has already been consumed"),
277 ),
278 }
279 }
280
281 pub async fn spawn(&mut self) {
283 loop {
284 if let Err(err) = self.init().await {
285 error!(%err, "amqp connection failed");
286 sleep(RECONNECTION_DELAY).await;
287 continue; } else {
289 info!("connected to amqp");
290 }
291
292 tokio::select! {
293 err = self.consumer.consume() => {
294 if let Err(err) = &err {
295 error!(%err, "amqp consumer failed, trying to reconnect..");
296 }
297 }
298 err = self.publisher_queue.publish() => {
299 if let Err(err) = &err {
300 error!(%err, "amqp publisher failed, trying to reconnect..");
301 }
302 }
303 }
304 }
305 }
306
307 pub fn publisher(&self) -> &Publisher {
308 &self.publisher
309 }
310}
311
312pub struct Ready {
315 ready_rx: Option<Receiver<()>>,
316}
317
318impl Ready {
319 pub async fn wait_to_be_ready(&mut self) -> Result<()> {
325 let secs = AMQP_READINESS_TIMEOUT.as_secs();
326 let fut = self
327 .ready_rx
328 .take()
329 .expect("amqp has already been initialized");
330
331 timeout(AMQP_READINESS_TIMEOUT, fut)
332 .await
333 .map_err(|_| Error::ReadinessSignal(format!("timed out after {secs}s")))?
334 .map_err(|e| Error::ReadinessSignal(e.to_string()))
335 }
336}
337
338#[derive(Clone)]
342pub struct Publisher {
343 tx: UnboundedSender<QueueMessage>,
348}
349
350impl Publisher {
351 fn new(tx: UnboundedSender<QueueMessage>) -> Self {
352 Self { tx }
353 }
354
355 pub fn publish<P>(&self, entity: &P, routing_key: &str) -> Result<()>
357 where
358 P: BrokerPublish + Serialize,
359 {
360 let serialized = bincode::serialize(entity)?;
361
362 self.tx.send(QueueMessage::new(
363 entity.exchange_name(),
364 routing_key,
365 serialized,
366 ))?;
367
368 STAT_PUBLISHER_MSG_QUEUE
369 .with_label_values(&[entity.exchange_name(), routing_key])
370 .inc();
371
372 Ok(())
373 }
374
375 pub fn publish_raw(&self, exchange: &str, routing_key: &str, msg: &[u8]) -> Result<()> {
377 self.tx
378 .send(QueueMessage::new(exchange, routing_key, msg.to_owned()))?;
379
380 STAT_PUBLISHER_MSG_QUEUE
381 .with_label_values(&[exchange, routing_key])
382 .inc();
383
384 Ok(())
385 }
386}
387
388pub struct PublisherQueue {
389 recv: UnboundedReceiver<QueueMessage>,
390 channel: Option<Channel>,
391}
392
393impl PublisherQueue {
394 fn new(recv: UnboundedReceiver<QueueMessage>) -> Self {
395 Self {
396 recv,
397 channel: None,
398 }
399 }
400
401 pub async fn publish(&mut self) -> Result<()> {
406 while let Some(msg) = self.recv.recv().await {
408 let channel = self.channel.as_ref().unwrap().clone();
410
411 tokio::spawn(async move {
413 let histogram_timer = STAT_PUBLISHER_DURATION
415 .with_label_values(&[&msg.exchange, &msg.routing_key])
416 .start_timer();
417
418 let res = channel
419 .basic_publish(
420 &msg.exchange,
421 &msg.routing_key,
422 BasicPublishOptions::default(),
423 &msg.content,
424 BasicProperties::default(),
425 )
426 .await;
427
428 histogram_timer.observe_duration();
430
431 STAT_PUBLISHER_MSG_QUEUE
433 .with_label_values(&[&msg.exchange, &msg.routing_key])
434 .dec();
435
436 if let Err(err) = res {
437 error!(%err, "failed to publish an amqp message");
438 }
439 });
440 }
441
442 Ok(())
443 }
444}
445
/// A message waiting in the publisher queue.
pub struct QueueMessage {
    /// Exchange the message is published to.
    exchange: String,
    /// Routing key used for the publish.
    routing_key: String,
    /// Raw message body.
    content: Vec<u8>,
}

impl QueueMessage {
    /// Builds a message, copying `exchange` and `routing_key` into owned strings.
    fn new(exchange: &str, routing_key: &str, content: Vec<u8>) -> QueueMessage {
        let exchange = String::from(exchange);
        let routing_key = String::from(routing_key);

        QueueMessage {
            exchange,
            routing_key,
            content,
        }
    }
}
461
462#[derive(Clone)]
463pub struct Listener {
464 inner: Arc<dyn BrokerListener>,
465 semaphore: Option<Arc<Semaphore>>,
466 rate_limit: Option<Arc<RateLimiter>>,
467}
468
469impl Listener {
470 pub fn new(listener: Arc<dyn BrokerListener>) -> Self {
471 let max = listener.max_concurrent_tasks();
472 Self {
473 semaphore: if max > 1 {
474 Some(Arc::new(Semaphore::new(max)))
475 } else {
476 None
477 },
478 rate_limit: listener.task_rate_limit().map(Arc::new),
479 inner: listener,
480 }
481 }
482
483 fn listener(&self) -> &Arc<dyn BrokerListener> {
484 &self.inner
485 }
486
487 fn max_concurrent_tasks(&self) -> usize {
488 self.inner.max_concurrent_tasks()
489 }
490
491 async fn acquire_rate_limit(&self, exchange: &str) {
492 if let Some(rate_limit) = self.rate_limit.as_ref() {
493 STAT_RATE_LIMIT
494 .with_label_values(&[exchange, "max"])
495 .set(rate_limit.max() as i64);
496 debug!(
497 balance = rate_limit.balance(),
498 max = rate_limit.max(),
499 "waiting for a rate limiter permit",
500 );
501
502 rate_limit.acquire_one().await;
503 debug!("got a rate limiter permit");
504
505 STAT_RATE_LIMIT
506 .with_label_values(&[exchange, "balance"])
507 .set(rate_limit.balance() as i64);
508 }
509 }
510}
511
512#[derive(Default, Clone)]
513pub struct Consumer {
514 consumer: Option<lapin::Consumer>,
515 listeners: Vec<Listener>,
516}
517
518impl Consumer {
519 pub fn new() -> Self {
520 Default::default()
521 }
522
523 pub fn add_listener(&mut self, listener: Arc<dyn BrokerListener>) {
526 self.listeners.push(Listener::new(listener));
527 }
528
529 pub async fn consume(&mut self) -> Result<()> {
531 if self.listeners.is_empty() || self.consumer.is_none() {
532 warn!("No listeners have been found, nothing will be consumed from amqp.");
533
534 loop {
535 sleep(Duration::from_secs(60)).await;
538 }
539 } else {
540 info!(listeners = %self.listeners.len(), "Broker consuming...");
541 }
542
543 while let Some(message) = self.consumer.as_mut().unwrap().next().await {
544 match message {
545 Ok(delivery) => {
546 let listener = self.listeners.iter().find(|listener| {
548 listener.listener().exchange_name() == delivery.exchange.as_str()
549 });
550
551 if let Some(listener) = listener {
552 let listener = listener.clone();
553
554 if let Some(ref semaphore) = listener.semaphore {
555 debug!(
557 permits_available = semaphore.available_permits(),
558 permits_max = listener.max_concurrent_tasks(),
559 "waiting for a semaphore permit"
560 );
561 STAT_CONCURRENT_TASK
562 .with_label_values(&[delivery.exchange.as_str(), "max"])
563 .set(listener.max_concurrent_tasks() as i64);
564
565 let permit = semaphore.clone().acquire_owned().await?;
566 debug!("got a semaphore permit");
567
568 STAT_CONCURRENT_TASK
569 .with_label_values(&[delivery.exchange.as_str(), "permits_used"])
570 .inc();
571
572 listener.acquire_rate_limit(delivery.exchange.as_str()).await;
573 task::spawn(consume_async(delivery, listener, permit));
574 } else {
575 listener.acquire_rate_limit(delivery.exchange.as_str()).await;
577 consume_inline(delivery, listener).await;
578 }
579 } else {
580 if let Err(err) = delivery.nack(BasicNackOptions::default()).await {
582 panic!("Can't find any registered listeners for `{}` exchange: {:?} + Failed to send nack: {}", &delivery.exchange, &delivery, err);
583 } else {
584 panic!(
585 "Can't find any registered listeners for `{}` exchange: {:?}",
586 &delivery.exchange, &delivery
587 );
588 }
589 }
590 }
591 Err(err) => {
592 error!(%err, "Error when receiving a delivery");
593 Err(err)? }
595 }
596 }
597 Ok(())
598 }
599}
600
601async fn consume_async(delivery: Delivery, listener: Listener, permit: OwnedSemaphorePermit) {
608 let histogram_timer = STAT_CONSUMER_DURATION
610 .with_label_values(&[listener.inner.exchange_name()])
611 .start_timer();
612
613 let res = timeout(
615 listener.listener().task_timeout(),
616 listener.listener().consume(&delivery),
617 )
618 .await;
619 drop(permit); STAT_CONCURRENT_TASK
622 .with_label_values(&[delivery.exchange.as_str(), "permits_used"])
623 .dec();
624
625 histogram_timer.observe_duration();
627
628 let res = match res {
629 Ok(inner) => inner,
630 Err(_) => {
631 error!("Consume task timed out");
632 STAT_TIMED_OUT_TASK.inc();
633 Err(listener.listener().requeue_on_timeout())
634 }
635 };
636
637 if let Err(requeue) = res {
638 #[allow(clippy::needless_update)]
639 let options = BasicRejectOptions {
640 requeue,
641 ..Default::default()
642 };
643
644 if let Err(err_reject) = delivery.reject(options).await {
645 error!(requeue, %err_reject, "Broker failed to send REJECT");
646 } else {
647 let exchange_name = listener.inner.exchange_name();
648 let routing_key = delivery.routing_key;
649 let redelivered = delivery.redelivered;
650
651 warn!(requeue, %exchange_name, %routing_key, %redelivered, "Error during consumption of a delivery, `REJECT` sent");
652 }
653 } else {
654 if let Err(err) = delivery.ack(BasicAckOptions::default()).await {
656 error!(
657 %err, "Delivery consumed, but failed to send ACK back to the broker",
658 );
659 }
660 }
661}
662
663async fn consume_inline(delivery: Delivery, listener: Listener) {
665 let histogram_timer = STAT_CONSUMER_DURATION
666 .with_label_values(&[listener.inner.exchange_name()])
667 .start_timer();
668
669 let res = timeout(
670 listener.listener().task_timeout(),
671 listener.listener().consume(&delivery),
672 )
673 .await;
674
675 histogram_timer.observe_duration();
676
677 let res = match res {
678 Ok(inner) => inner,
679 Err(_) => {
680 error!("Consume task timed out");
681 STAT_TIMED_OUT_TASK.inc();
682 Err(listener.listener().requeue_on_timeout())
683 }
684 };
685
686 if let Err(requeue) = res {
687 #[allow(clippy::needless_update)]
688 let options = BasicRejectOptions {
689 requeue,
690 ..Default::default()
691 };
692
693 if let Err(err_reject) = delivery.reject(options).await {
694 error!(requeue, %err_reject, "Broker failed to send REJECT");
695 } else {
696 let exchange_name = listener.inner.exchange_name();
697 let routing_key = delivery.routing_key;
698 let redelivered = delivery.redelivered;
699
700 warn!(requeue, %exchange_name, %routing_key, %redelivered, "Error during consumption of a delivery, `REJECT` sent");
701 }
702 } else {
703 if let Err(err) = delivery.ack(BasicAckOptions::default()).await {
704 error!(
705 %err, "Delivery consumed, but failed to send ACK back to the broker",
706 );
707 }
708 }
709}