1use crate::{
2 connection::ConnectionManager,
3 error::{RabbitError, Result},
4 metrics::RustRabbitMetrics,
5 publisher::{CustomExchangeDeclareOptions, CustomQueueDeclareOptions, Publisher},
6 retry::{DelayedMessageExchange, RetryPolicy},
7};
8use async_trait::async_trait;
9use futures::StreamExt;
10use lapin::{
11 message::Delivery,
12 options::{
13 BasicAckOptions, BasicConsumeOptions, BasicNackOptions, BasicPublishOptions,
14 ExchangeDeclareOptions, QueueBindOptions, QueueDeclareOptions as LapinQueueDeclareOptions,
15 },
16 types::FieldTable,
17 BasicProperties, Channel, ExchangeKind,
18};
19use serde::de::DeserializeOwned;
20use std::sync::Arc;
21use tokio::sync::Semaphore;
22use tracing::{debug, error, info, warn};
23
24#[async_trait]
26pub trait MessageHandler<T>: Send + Sync + 'static
27where
28 T: DeserializeOwned + Send + Sync,
29{
30 async fn handle(&self, message: T, context: MessageContext) -> MessageResult;
32}
33
34#[derive(Debug, Clone)]
36pub struct MessageContext {
37 pub message_id: Option<String>,
38 pub correlation_id: Option<String>,
39 pub reply_to: Option<String>,
40 pub delivery_tag: u64,
41 pub redelivered: bool,
42 pub exchange: String,
43 pub routing_key: String,
44 pub headers: FieldTable,
45 pub timestamp: Option<u64>,
46 pub retry_count: u32,
47}
48
/// Verdict returned by a [`MessageHandler`], telling the consumer how to
/// settle the delivery.
///
/// Derives `Clone`, `Copy`, `PartialEq`, `Eq` and `Hash` so callers can store,
/// compare and match on results freely (e.g. in tests or metrics).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MessageResult {
    /// Processing succeeded; the delivery is acknowledged.
    Ack,
    /// Processing should be retried later. With a retry policy the message is
    /// republished to the delayed retry exchange; without one it is requeued.
    Retry,
    /// Processing failed permanently. With a dead-letter exchange configured
    /// the message is published there; otherwise it is rejected without requeue.
    Reject,
    /// Return the message to the queue immediately (nack with requeue).
    Requeue,
}
61
62#[derive(Debug, Clone)]
64pub struct ConsumerOptions {
65 pub queue_name: String,
67
68 pub consumer_tag: Option<String>,
70
71 pub concurrency: usize,
73
74 pub prefetch_count: Option<u16>,
76
77 pub auto_declare_queue: bool,
79
80 pub queue_options: CustomQueueDeclareOptions,
82
83 pub auto_declare_exchange: bool,
85
86 pub exchange_name: Option<String>,
88
89 pub exchange_options: CustomExchangeDeclareOptions,
91
92 pub routing_key: Option<String>,
94
95 pub retry_policy: Option<RetryPolicy>,
97
98 pub dead_letter_exchange: Option<String>,
100
101 pub auto_ack: bool,
103
104 pub exclusive: bool,
106
107 pub arguments: FieldTable,
109}
110
111impl ConsumerOptions {
112 pub fn builder<S: Into<String>>(queue_name: S) -> ConsumerOptionsBuilder {
114 ConsumerOptionsBuilder::new(queue_name.into())
115 }
116}
117
118#[derive(Debug, Clone)]
120pub struct ConsumerOptionsBuilder {
121 queue_name: String,
122 consumer_tag: Option<String>,
123 concurrency: usize,
124 prefetch_count: Option<u16>,
125 auto_declare_queue: bool,
126 queue_options: CustomQueueDeclareOptions,
127 auto_declare_exchange: bool,
128 exchange_name: Option<String>,
129 exchange_options: CustomExchangeDeclareOptions,
130 routing_key: Option<String>,
131 retry_policy: Option<RetryPolicy>,
132 dead_letter_exchange: Option<String>,
133 auto_ack: bool,
134 exclusive: bool,
135 arguments: FieldTable,
136}
137
138impl ConsumerOptionsBuilder {
139 pub fn new(queue_name: String) -> Self {
141 Self {
142 queue_name,
143 consumer_tag: None,
144 concurrency: 1,
145 prefetch_count: Some(10),
146 auto_declare_queue: false,
147 queue_options: CustomQueueDeclareOptions::default(),
148 auto_declare_exchange: false,
149 exchange_name: None,
150 exchange_options: CustomExchangeDeclareOptions::default(),
151 routing_key: None,
152 retry_policy: None,
153 dead_letter_exchange: None,
154 auto_ack: false,
155 exclusive: false,
156 arguments: FieldTable::default(),
157 }
158 }
159
160 pub fn consumer_tag<S: Into<String>>(mut self, tag: S) -> Self {
162 self.consumer_tag = Some(tag.into());
163 self
164 }
165
166 pub fn concurrency(mut self, concurrency: usize) -> Self {
168 self.concurrency = concurrency;
169 self
170 }
171
172 pub fn prefetch_count(mut self, count: u16) -> Self {
174 self.prefetch_count = Some(count);
175 self
176 }
177
178 pub fn no_prefetch_limit(mut self) -> Self {
180 self.prefetch_count = None;
181 self
182 }
183
184 pub fn auto_declare_queue(mut self) -> Self {
186 self.auto_declare_queue = true;
187 self
188 }
189
190 pub fn auto_declare_exchange(mut self) -> Self {
192 self.auto_declare_exchange = true;
193 self
194 }
195
196 pub fn exchange_name<S: Into<String>>(mut self, name: S) -> Self {
198 self.exchange_name = Some(name.into());
199 self
200 }
201
202 pub fn exchange_options(mut self, options: CustomExchangeDeclareOptions) -> Self {
204 self.exchange_options = options;
205 self
206 }
207
208 pub fn routing_key<S: Into<String>>(mut self, key: S) -> Self {
210 self.routing_key = Some(key.into());
211 self
212 }
213
214 pub fn queue_options(mut self, options: CustomQueueDeclareOptions) -> Self {
216 self.queue_options = options;
217 self
218 }
219
220 pub fn retry_policy(mut self, policy: RetryPolicy) -> Self {
222 self.retry_policy = Some(policy);
223 self
224 }
225
226 pub fn dead_letter_exchange<S: Into<String>>(mut self, exchange: S) -> Self {
228 self.dead_letter_exchange = Some(exchange.into());
229 self
230 }
231
232 pub fn auto_ack(mut self) -> Self {
234 self.auto_ack = true;
235 self
236 }
237
238 pub fn manual_ack(mut self) -> Self {
240 self.auto_ack = false;
241 self
242 }
243
244 pub fn exclusive(mut self) -> Self {
246 self.exclusive = true;
247 self
248 }
249
250 pub fn high_throughput(mut self) -> Self {
252 self.concurrency = 20;
253 self.prefetch_count = Some(50);
254 self.auto_ack = false;
255 self
256 }
257
258 pub fn reliable(mut self) -> Self {
260 self.concurrency = 1;
261 self.prefetch_count = Some(1);
262 self.auto_ack = false;
263 self
264 }
265
266 pub fn development(mut self) -> Self {
268 self.concurrency = 1;
269 self.prefetch_count = Some(1);
270 self.auto_ack = true;
271 self.auto_declare_queue = true;
272 self.auto_declare_exchange = true; self
274 }
275
276 pub fn minutes_retry(mut self) -> Self {
283 let queue_name = self.queue_name.clone();
284
285 self.auto_declare_queue = true;
286 self.auto_declare_exchange = true;
287 self.retry_policy = Some(RetryPolicy::minutes_exponential_for_queue(&queue_name));
288 self.concurrency = 1; self.prefetch_count = Some(1); self.auto_ack = false; self
292 }
293
294 pub fn build(self) -> ConsumerOptions {
296 ConsumerOptions {
297 queue_name: self.queue_name,
298 consumer_tag: self.consumer_tag,
299 concurrency: self.concurrency,
300 prefetch_count: self.prefetch_count,
301 auto_declare_queue: self.auto_declare_queue,
302 queue_options: self.queue_options,
303 auto_declare_exchange: self.auto_declare_exchange,
304 exchange_name: self.exchange_name,
305 exchange_options: self.exchange_options,
306 routing_key: self.routing_key,
307 retry_policy: self.retry_policy,
308 dead_letter_exchange: self.dead_letter_exchange,
309 auto_ack: self.auto_ack,
310 exclusive: self.exclusive,
311 arguments: self.arguments,
312 }
313 }
314}
315
316impl Default for ConsumerOptions {
317 fn default() -> Self {
318 Self {
319 queue_name: String::new(),
320 consumer_tag: None,
321 concurrency: 1,
322 prefetch_count: Some(10),
323 auto_declare_queue: false,
324 queue_options: CustomQueueDeclareOptions::default(),
325 auto_declare_exchange: false,
326 exchange_name: None,
327 exchange_options: CustomExchangeDeclareOptions::default(),
328 routing_key: None,
329 retry_policy: None,
330 dead_letter_exchange: None,
331 auto_ack: false,
332 exclusive: false,
333 arguments: FieldTable::default(),
334 }
335 }
336}
337
338pub struct Consumer {
340 #[allow(dead_code)] connection_manager: ConnectionManager,
342 options: ConsumerOptions,
343 channel: Channel,
344 semaphore: Arc<Semaphore>,
345 metrics: Option<RustRabbitMetrics>,
346 publisher: Publisher,
347}
348
349impl Consumer {
350 pub async fn new(
352 connection_manager: ConnectionManager,
353 options: ConsumerOptions,
354 ) -> Result<Self> {
355 let connection = connection_manager.get_connection().await?;
356 let channel = connection.create_channel().await?;
357
358 if let Some(prefetch_count) = options.prefetch_count {
360 debug!("Setting prefetch_count: {}", prefetch_count);
361 channel
362 .basic_qos(
363 prefetch_count,
364 lapin::options::BasicQosOptions { global: false },
365 )
366 .await
367 .map_err(|e| {
368 error!("Failed to set QoS prefetch_count={}: {}", prefetch_count, e);
369 RabbitError::Connection(e)
370 })?;
371 debug!("Successfully set prefetch_count: {}", prefetch_count);
372 }
373
374 if options.auto_declare_queue {
376 Self::declare_queue_and_exchange(&channel, &options).await?;
377 }
378
379 let semaphore = Arc::new(Semaphore::new(options.concurrency));
380
381 if options.retry_policy.is_some() {
383 Self::setup_retry_infrastructure(&connection_manager, &options).await?;
384 }
385
386 let publisher = Publisher::new(connection_manager.clone());
387
388 Ok(Self {
389 connection_manager,
390 options,
391 channel,
392 semaphore,
393 metrics: None,
394 publisher,
395 })
396 }
397
398 pub fn set_metrics(&mut self, metrics: RustRabbitMetrics) {
400 self.metrics = Some(metrics);
401 }
402
403 pub async fn consume<T, H>(&self, handler: Arc<H>) -> Result<()>
405 where
406 T: DeserializeOwned + Send + Sync + 'static,
407 H: MessageHandler<T>,
408 {
409 let consumer_tag = self
410 .options
411 .consumer_tag
412 .clone()
413 .unwrap_or_else(|| format!("rust-rabbit-{}", uuid::Uuid::new_v4()));
414
415 let consume_options = BasicConsumeOptions {
416 no_local: false,
417 no_ack: self.options.auto_ack,
418 exclusive: self.options.exclusive,
419 nowait: false,
420 };
421
422 let mut consumer = self
423 .channel
424 .basic_consume(
425 &self.options.queue_name,
426 &consumer_tag,
427 consume_options,
428 self.options.arguments.clone(),
429 )
430 .await?;
431
432 info!(
433 "Started consuming from queue: {} with tag: {}",
434 self.options.queue_name, consumer_tag
435 );
436
437 while let Some(delivery) = consumer.next().await {
438 let delivery = delivery?;
439 let permit = self
440 .semaphore
441 .clone()
442 .acquire_owned()
443 .await
444 .map_err(|e| RabbitError::Generic(e.into()))?;
445
446 let handler = handler.clone();
447 let retry_policy = self.options.retry_policy.clone();
448 let dead_letter_exchange = self.options.dead_letter_exchange.clone();
449 let channel = self.channel.clone();
450 let publisher = self.publisher.clone();
451 let exchange_name = self
452 .options
453 .exchange_name
454 .clone()
455 .unwrap_or_else(|| self.options.queue_name.clone());
456
457 tokio::spawn(async move {
459 let _permit = permit; if let Err(e) = Self::process_message::<T, H>(
462 delivery,
463 handler,
464 retry_policy,
465 dead_letter_exchange,
466 channel,
467 publisher,
468 exchange_name,
469 )
470 .await
471 {
472 error!("Error processing message: {}", e);
473 }
474 });
475 }
476
477 warn!(
478 "Consumer stream ended for queue: {}",
479 self.options.queue_name
480 );
481 Ok(())
482 }
483
484 async fn process_message<T, H>(
486 delivery: Delivery,
487 handler: Arc<H>,
488 retry_policy: Option<RetryPolicy>,
489 dead_letter_exchange: Option<String>,
490 channel: Channel,
491 publisher: Publisher,
492 exchange_name: String,
493 ) -> Result<()>
494 where
495 T: DeserializeOwned + Send + Sync,
496 H: MessageHandler<T>,
497 {
498 let context = Self::build_message_context(&delivery);
499
500 let message: T = match serde_json::from_slice(&delivery.data) {
502 Ok(msg) => msg,
503 Err(e) => {
504 error!("Failed to deserialize message: {}", e);
505 Self::reject_message(&delivery, &channel, false).await?;
506 return Ok(());
507 }
508 };
509
510 let result = handler.handle(message, context.clone()).await;
512
513 match result {
514 MessageResult::Ack => {
515 Self::ack_message(&delivery, &channel).await?;
516 debug!("Message acknowledged: {}", delivery.delivery_tag);
517 }
518 MessageResult::Retry => {
519 if let Some(ref policy) = retry_policy {
520 Self::handle_retry(
521 &delivery,
522 &channel,
523 &context,
524 policy,
525 &publisher,
526 &exchange_name,
527 )
528 .await?;
529 } else {
530 Self::reject_message(&delivery, &channel, true).await?;
531 }
532 }
533 MessageResult::Reject => {
534 if let Some(ref dle) = dead_letter_exchange {
535 Self::send_to_dead_letter(&delivery, dle, &context, &publisher).await?;
536 } else {
537 Self::reject_message(&delivery, &channel, false).await?;
538 }
539 }
540 MessageResult::Requeue => {
541 Self::reject_message(&delivery, &channel, true).await?;
542 }
543 }
544
545 Ok(())
546 }
547
548 fn build_message_context(delivery: &Delivery) -> MessageContext {
550 let properties = &delivery.properties;
551
552 MessageContext {
553 message_id: properties.message_id().as_ref().map(|s| s.to_string()),
554 correlation_id: properties.correlation_id().as_ref().map(|s| s.to_string()),
555 reply_to: properties.reply_to().as_ref().map(|s| s.to_string()),
556 delivery_tag: delivery.delivery_tag,
557 redelivered: delivery.redelivered,
558 exchange: delivery.exchange.to_string(),
559 routing_key: delivery.routing_key.to_string(),
560 headers: properties.headers().clone().unwrap_or_default(),
561 timestamp: *properties.timestamp(),
562 retry_count: Self::get_retry_count_from_headers(
563 properties
564 .headers()
565 .as_ref()
566 .unwrap_or(&FieldTable::default()),
567 ),
568 }
569 }
570
571 fn get_retry_count_from_headers(headers: &FieldTable) -> u32 {
573 headers
574 .inner()
575 .get("x-retry-count")
576 .and_then(|v| match v {
577 lapin::types::AMQPValue::LongInt(count) => Some(*count as u32),
578 lapin::types::AMQPValue::LongLongInt(count) => Some(*count as u32),
579 _ => None,
580 })
581 .unwrap_or(0)
582 }
583
584 async fn ack_message(delivery: &Delivery, channel: &Channel) -> Result<()> {
586 channel
587 .basic_ack(delivery.delivery_tag, BasicAckOptions::default())
588 .await?;
589 Ok(())
590 }
591
592 async fn reject_message(delivery: &Delivery, channel: &Channel, requeue: bool) -> Result<()> {
594 channel
595 .basic_nack(
596 delivery.delivery_tag,
597 BasicNackOptions {
598 multiple: false,
599 requeue,
600 },
601 )
602 .await?;
603 Ok(())
604 }
605
606 async fn handle_retry(
608 delivery: &Delivery,
609 channel: &Channel,
610 context: &MessageContext,
611 retry_policy: &RetryPolicy,
612 publisher: &Publisher,
613 exchange_name: &str,
614 ) -> Result<()> {
615 if context.retry_count >= retry_policy.max_retries {
616 warn!(
617 "Max retries exceeded for message: {}",
618 delivery.delivery_tag
619 );
620
621 if let Some(ref dle) = retry_policy.dead_letter_exchange {
623 Self::send_to_dead_letter(delivery, dle, context, publisher).await?;
624 } else {
625 Self::reject_message(delivery, channel, false).await?;
626 }
627 return Ok(());
628 }
629
630 let delay = retry_policy.calculate_delay(context.retry_count);
632 let delayed_exchange_name = format!("{}.retry", exchange_name);
633
634 let mut headers = delivery.properties.headers().clone().unwrap_or_default();
636 headers.insert(
637 "x-retry-count".into(),
638 lapin::types::AMQPValue::LongInt((context.retry_count + 1) as i32),
639 );
640 headers.insert(
641 "x-original-queue".into(),
642 lapin::types::AMQPValue::LongString(context.routing_key.clone().into()),
643 );
644
645 let mut properties = BasicProperties::default()
647 .with_content_type("application/json".into())
648 .with_delivery_mode(2)
649 .with_headers(headers);
650
651 let mut delay_headers = properties.headers().clone().unwrap_or_default();
653 delay_headers.insert(
654 "x-delay".into(),
655 lapin::types::AMQPValue::LongLongInt(delay.as_millis() as i64),
656 );
657 properties = properties.with_headers(delay_headers);
658
659 channel
661 .basic_publish(
662 &delayed_exchange_name,
663 &context.routing_key,
664 BasicPublishOptions::default(),
665 &delivery.data,
666 properties,
667 )
668 .await?;
669
670 info!(
671 "Retrying message after {:?} (attempt {})",
672 delay,
673 context.retry_count + 1
674 );
675
676 Self::ack_message(delivery, channel).await?;
678
679 Ok(())
680 }
681
682 async fn send_to_dead_letter(
684 delivery: &Delivery,
685 dead_letter_exchange: &str,
686 _context: &MessageContext,
687 publisher: &Publisher,
688 ) -> Result<()> {
689 let mut headers = delivery.properties.headers().clone().unwrap_or_default();
691 headers.insert(
692 "x-death-reason".into(),
693 lapin::types::AMQPValue::LongString("max-retries-exceeded".into()),
694 );
695 headers.insert(
696 "x-death-time".into(),
697 lapin::types::AMQPValue::LongLongInt(chrono::Utc::now().timestamp_millis()),
698 );
699
700 let properties = BasicProperties::default()
702 .with_content_type("application/json".into())
703 .with_delivery_mode(2)
704 .with_headers(headers);
705
706 let connection = publisher.get_connection().await?;
708 let dlx_channel = connection.create_channel().await?;
709
710 dlx_channel
711 .basic_publish(
712 dead_letter_exchange,
713 "dead-letter", BasicPublishOptions::default(),
715 &delivery.data,
716 properties,
717 )
718 .await?;
719
720 warn!(
721 "Sent message to dead letter exchange: {}",
722 dead_letter_exchange
723 );
724
725 Ok(())
726 }
727
728 pub async fn stop(&self) -> Result<()> {
730 info!("Stopping consumer for queue: {}", self.options.queue_name);
733 Ok(())
734 }
735
736 async fn declare_queue_and_exchange(
738 channel: &Channel,
739 options: &ConsumerOptions,
740 ) -> Result<()> {
741 let queue_options = LapinQueueDeclareOptions {
743 passive: options.queue_options.passive,
744 durable: options.queue_options.durable,
745 exclusive: options.queue_options.exclusive,
746 auto_delete: options.queue_options.auto_delete,
747 nowait: false,
748 };
749
750 channel
751 .queue_declare(
752 &options.queue_name,
753 queue_options,
754 options.queue_options.arguments.clone(),
755 )
756 .await?;
757
758 debug!("Declared queue: {}", options.queue_name);
759
760 if options.auto_declare_exchange {
762 let exchange_name = options
763 .exchange_name
764 .as_ref()
765 .unwrap_or(&options.queue_name);
766
767 let exchange_options = ExchangeDeclareOptions {
769 passive: options.exchange_options.passive,
770 durable: options.exchange_options.durable,
771 auto_delete: options.exchange_options.auto_delete,
772 internal: options.exchange_options.internal,
773 nowait: false,
774 };
775
776 let mut arguments = options.exchange_options.arguments.clone();
778 if matches!(options.exchange_options.exchange_type, ExchangeKind::Custom(ref kind) if kind == "x-delayed-message")
779 {
780 arguments.insert(
781 "x-delayed-type".into(),
782 lapin::types::AMQPValue::LongString(
783 match options.exchange_options.original_type {
784 ExchangeKind::Direct => "direct".into(),
785 ExchangeKind::Fanout => "fanout".into(),
786 ExchangeKind::Topic => "topic".into(),
787 ExchangeKind::Headers => "headers".into(),
788 ExchangeKind::Custom(ref s) => s.clone().into(),
789 },
790 ),
791 );
792 }
793
794 channel
795 .exchange_declare(
796 exchange_name,
797 options.exchange_options.exchange_type.clone(),
798 exchange_options,
799 arguments,
800 )
801 .await?;
802
803 debug!("Declared exchange: {}", exchange_name);
804
805 let routing_key = options.routing_key.as_ref().unwrap_or(&options.queue_name);
807
808 channel
809 .queue_bind(
810 &options.queue_name,
811 exchange_name,
812 routing_key,
813 QueueBindOptions::default(),
814 FieldTable::default(),
815 )
816 .await?;
817
818 debug!(
819 "Bound queue '{}' to exchange '{}' with routing key '{}'",
820 options.queue_name, exchange_name, routing_key
821 );
822 }
823
824 Ok(())
825 }
826
827 async fn setup_retry_infrastructure(
829 connection_manager: &ConnectionManager,
830 options: &ConsumerOptions,
831 ) -> Result<()> {
832 if let Some(ref retry_policy) = options.retry_policy {
833 let delayed_exchange_name = format!(
835 "{}.retry",
836 options
837 .exchange_name
838 .as_ref()
839 .unwrap_or(&options.queue_name)
840 );
841
842 let delayed_exchange = DelayedMessageExchange::new(
844 connection_manager.clone(),
845 delayed_exchange_name.clone(),
846 retry_policy.clone(),
847 );
848
849 delayed_exchange.setup().await?;
851
852 delayed_exchange
854 .setup_queue_retry(&options.queue_name)
855 .await?;
856
857 debug!(
858 "Setup retry infrastructure for queue: {} with delayed exchange: {}",
859 options.queue_name, delayed_exchange_name
860 );
861 }
862
863 Ok(())
864 }
865}
866
867pub struct SimpleMessageHandler<F, T>
869where
870 F: Fn(T, MessageContext) -> MessageResult + Send + Sync,
871 T: DeserializeOwned + Send + Sync,
872{
873 handler_fn: F,
874 _phantom: std::marker::PhantomData<T>,
875}
876
877impl<F, T> SimpleMessageHandler<F, T>
878where
879 F: Fn(T, MessageContext) -> MessageResult + Send + Sync + 'static,
880 T: DeserializeOwned + Send + Sync + 'static,
881{
882 pub fn new(handler_fn: F) -> Self {
883 Self {
884 handler_fn,
885 _phantom: std::marker::PhantomData,
886 }
887 }
888}
889
890#[async_trait]
891impl<F, T> MessageHandler<T> for SimpleMessageHandler<F, T>
892where
893 F: Fn(T, MessageContext) -> MessageResult + Send + Sync + 'static,
894 T: DeserializeOwned + Send + Sync + 'static,
895{
896 async fn handle(&self, message: T, context: MessageContext) -> MessageResult {
897 (self.handler_fn)(message, context)
898 }
899}