1use crate::client::{QueueProvider, SessionProvider};
58use crate::error::QueueError;
59use crate::message::{
60 Message, MessageId, QueueName, ReceiptHandle, ReceivedMessage, SessionId, Timestamp,
61};
62use crate::provider::{ProviderType, SessionSupport};
63use async_trait::async_trait;
64use bytes::Bytes;
65use chrono::Duration;
66use futures::StreamExt;
67use lapin::{
68 options::{
69 BasicAckOptions, BasicConsumeOptions, BasicGetOptions, BasicNackOptions,
70 BasicPublishOptions, BasicQosOptions, QueueDeclareOptions,
71 },
72 types::{AMQPValue, FieldTable, LongString, ShortString},
73 BasicProperties, Channel, Connection, ConnectionProperties,
74};
75use serde::{Deserialize, Serialize};
76use std::collections::HashMap;
77use std::sync::Arc;
78use tokio::sync::{mpsc, Mutex};
79use tracing::{debug, instrument, warn};
80
// Unit tests for this module live in the sibling file `rabbitmq_tests.rs`.
#[cfg(test)]
#[path = "rabbitmq_tests.rs"]
mod tests;
84
/// Connection and queue-declaration settings for [`RabbitMqProvider`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RabbitMqConfig {
    /// AMQP connection URL used to open the connection. Credentials in it are
    /// replaced by `***` before the URL appears in logs or error messages.
    pub url: String,
    /// Virtual host name. NOTE(review): the connection is opened from `url`;
    /// check `RabbitMqProvider::new` for how (or whether) this field is applied.
    pub virtual_host: String,
    /// `basic.qos` prefetch applied to every channel the provider opens;
    /// `0` skips the `basic.qos` call entirely.
    pub prefetch_count: u16,
    /// Length of a session lock, and of each received message's lock
    /// (a receipt handle expires at delivery time + this duration).
    pub session_lock_duration: Duration,
    /// When set to a positive duration, declared as the `x-message-ttl`
    /// queue argument (in milliseconds) on non-session queues.
    pub message_ttl: Option<Duration>,
    /// When true and `dead_letter_exchange` is set, non-session queues are
    /// declared with an `x-dead-letter-exchange` argument.
    pub enable_dead_letter: bool,
    /// Exchange named in `x-dead-letter-exchange`; only consulted when
    /// `enable_dead_letter` is true.
    pub dead_letter_exchange: Option<String>,
}
124
125impl Default for RabbitMqConfig {
126 fn default() -> Self {
127 Self {
128 url: "amqp://guest:guest@localhost:5672".to_string(),
129 virtual_host: "/".to_string(),
130 prefetch_count: 10,
131 session_lock_duration: Duration::minutes(5),
132 message_ttl: None,
133 enable_dead_letter: true,
134 dead_letter_exchange: Some("dlx".to_string()),
135 }
136 }
137}
138
/// Error type for provider construction (returned by `RabbitMqProvider::new`
/// in this file); carries a human-readable message.
#[derive(Debug)]
pub struct RabbitMqError {
    /// Description of the failure.
    message: String,
}
148
149impl RabbitMqError {
150 fn new(message: impl Into<String>) -> Self {
151 Self {
152 message: message.into(),
153 }
154 }
155
156 pub fn to_queue_error(&self) -> QueueError {
158 QueueError::ProviderError {
159 provider: "rabbitmq".to_string(),
160 code: "AMQP_ERROR".to_string(),
161 message: self.message.clone(),
162 }
163 }
164}
165
166impl std::fmt::Display for RabbitMqError {
167 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168 write!(f, "RabbitMQ error: {}", self.message)
169 }
170}
171
// Marker impl: the derived `Debug` and the `Display` impl supply the required
// supertraits; no underlying source error is chained.
impl std::error::Error for RabbitMqError {}
173
/// Bookkeeping for a delivery handed to a caller but not yet acked or nacked.
/// Stored in the in-flight map under the receipt-handle id.
struct InFlightEntry {
    /// Channel the delivery arrived on; AMQP delivery tags are only valid on
    /// the channel that issued them.
    channel: Channel,
    /// Broker-assigned delivery tag passed to `basic_ack` / `basic_nack`.
    delivery_tag: u64,
    /// Instant after which settling through this entry is refused as expired.
    lock_expires_at: Timestamp,
}
187
188fn redact_url(url: &str) -> String {
198 match url::Url::parse(url) {
199 Ok(mut parsed) => {
200 let has_credentials = !parsed.username().is_empty() || parsed.password().is_some();
201 if has_credentials {
202 let _ = parsed.set_username("***");
204 let _ = parsed.set_password(Some("***"));
205 }
206 parsed.to_string()
207 }
208 Err(_) => "<invalid-url>".to_string(),
209 }
210}
211
212fn session_queue_name(queue: &QueueName, session_id: &SessionId) -> String {
214 let safe = session_id.as_str().replace(['/', ' ', '\\'], "_");
216 format!("{}.session.{}", queue.as_str(), safe)
217}
218
/// `QueueProvider` backed by RabbitMQ through `lapin`. Sessions are emulated
/// with one durable queue per session id.
pub struct RabbitMqProvider {
    /// Shared AMQP connection on which all channels are opened.
    connection: Arc<Connection>,
    /// Settings supplied at construction.
    config: RabbitMqConfig,
    /// Unsettled deliveries keyed by receipt-handle id; shared with session clients.
    in_flight: Arc<Mutex<HashMap<String, InFlightEntry>>>,
    /// Lazily opened channel reused for publishing.
    publish_channel: Arc<Mutex<Option<Channel>>>,
    /// Lazily opened channel reused for `basic_get` polling.
    receive_channel: Arc<Mutex<Option<Channel>>>,
}
250
251impl RabbitMqProvider {
252 pub async fn new(config: RabbitMqConfig) -> Result<Self, RabbitMqError> {
271 let conn = Connection::connect(&config.url, ConnectionProperties::default())
272 .await
273 .map_err(|e| {
274 RabbitMqError::new(format!(
275 "failed to connect to RabbitMQ at '{}': {}",
276 redact_url(&config.url),
277 e
278 ))
279 })?;
280
281 debug!(url = %redact_url(&config.url), "Connected to RabbitMQ");
282
283 Ok(Self {
284 connection: Arc::new(conn),
285 config,
286 in_flight: Arc::new(Mutex::new(HashMap::new())),
287 publish_channel: Arc::new(Mutex::new(None)),
288 receive_channel: Arc::new(Mutex::new(None)),
289 })
290 }
291
292 async fn open_channel(&self) -> Result<Channel, QueueError> {
294 let channel =
295 self.connection
296 .create_channel()
297 .await
298 .map_err(|e| QueueError::ConnectionFailed {
299 message: format!("failed to create AMQP channel: {}", e),
300 })?;
301
302 if self.config.prefetch_count > 0 {
303 channel
304 .basic_qos(self.config.prefetch_count, BasicQosOptions::default())
305 .await
306 .map_err(|e| QueueError::ProviderError {
307 provider: "rabbitmq".to_string(),
308 code: "QOS_FAILED".to_string(),
309 message: format!("failed to set QoS prefetch: {}", e),
310 })?;
311 }
312
313 Ok(channel)
314 }
315
316 async fn get_publish_channel(&self) -> Result<Channel, QueueError> {
321 let mut guard = self.publish_channel.lock().await;
322 if let Some(ref ch) = *guard {
323 if ch.status().connected() {
324 return Ok(ch.clone());
325 }
326 }
327 let ch = self.open_channel().await?;
328 *guard = Some(ch.clone());
329 Ok(ch)
330 }
331
332 async fn get_receive_channel(&self) -> Result<Channel, QueueError> {
334 let mut guard = self.receive_channel.lock().await;
335 if let Some(ref ch) = *guard {
336 if ch.status().connected() {
337 return Ok(ch.clone());
338 }
339 }
340 let ch = self.open_channel().await?;
341 *guard = Some(ch.clone());
342 Ok(ch)
343 }
344
345 async fn declare_queue(&self, channel: &Channel, queue: &QueueName) -> Result<(), QueueError> {
348 let mut args = FieldTable::default();
349
350 if self.config.enable_dead_letter {
351 if let Some(ref dlx) = self.config.dead_letter_exchange {
352 args.insert(
353 ShortString::from("x-dead-letter-exchange"),
354 AMQPValue::LongString(LongString::from(dlx.as_bytes())),
355 );
356 }
357 }
358
359 if let Some(ttl) = self.config.message_ttl {
360 let ttl_ms = ttl.num_milliseconds();
361 if ttl_ms > 0 {
362 args.insert(
363 ShortString::from("x-message-ttl"),
364 AMQPValue::LongLongInt(ttl_ms),
365 );
366 }
367 }
368
369 let opts = QueueDeclareOptions {
370 durable: true,
371 ..Default::default()
372 };
373
374 channel
375 .queue_declare(queue.as_str().into(), opts, args)
376 .await
377 .map_err(|e| QueueError::ProviderError {
378 provider: "rabbitmq".to_string(),
379 code: "QUEUE_DECLARE_FAILED".to_string(),
380 message: format!("failed to declare queue '{}': {}", queue.as_str(), e),
381 })?;
382
383 Ok(())
384 }
385
386 async fn declare_session_queue(
391 &self,
392 channel: &Channel,
393 queue: &QueueName,
394 session_id: &SessionId,
395 ) -> Result<String, QueueError> {
396 let name = session_queue_name(queue, session_id);
397
398 let opts = QueueDeclareOptions {
399 durable: true,
400 ..Default::default()
401 };
402
403 channel
404 .queue_declare(name.as_str().into(), opts, FieldTable::default())
405 .await
406 .map_err(|e| QueueError::ProviderError {
407 provider: "rabbitmq".to_string(),
408 code: "SESSION_QUEUE_DECLARE_FAILED".to_string(),
409 message: format!("failed to declare session queue '{}': {}", name, e),
410 })?;
411
412 Ok(name)
413 }
414
415 fn build_properties(message: &Message) -> BasicProperties {
417 let mut props = BasicProperties::default().with_delivery_mode(2); if let Some(ref corr_id) = message.correlation_id {
420 props = props.with_correlation_id(ShortString::from(corr_id.as_str()));
421 }
422
423 if let Some(ttl) = message.time_to_live {
424 let ttl_ms = ttl.num_milliseconds();
425 if ttl_ms > 0 {
426 props = props.with_expiration(ShortString::from(ttl_ms.to_string().as_str()));
427 }
428 }
429
430 let mut headers = FieldTable::default();
436 for (k, v) in &message.attributes {
437 let header_key = format!("x-attr-{}", k);
438 headers.insert(
439 ShortString::from(header_key.as_str()),
440 AMQPValue::LongString(LongString::from(v.as_bytes())),
441 );
442 }
443 if let Some(ref sid) = message.session_id {
444 headers.insert(
445 ShortString::from("x-session-id"),
446 AMQPValue::LongString(LongString::from(sid.as_str().as_bytes())),
447 );
448 }
449
450 props.with_headers(headers)
451 }
452
453 fn extract_attributes(headers: &Option<FieldTable>) -> HashMap<String, String> {
461 let mut attrs = HashMap::new();
462 if let Some(ht) = headers {
463 for (k, v) in ht.inner() {
464 let key = k.as_str();
465 if let Some(attr_key) = key.strip_prefix("x-attr-") {
467 if let AMQPValue::LongString(s) = v {
468 attrs.insert(
469 attr_key.to_string(),
470 String::from_utf8_lossy(s.as_bytes()).to_string(),
471 );
472 }
473 }
474 }
475 }
476 attrs
477 }
478
479 fn extract_session_id(headers: &Option<FieldTable>) -> Option<SessionId> {
481 if let Some(ht) = headers {
482 if let Some(AMQPValue::LongString(s)) = ht.inner().get("x-session-id") {
483 let id = String::from_utf8_lossy(s.as_bytes()).to_string();
484 return SessionId::new(id).ok();
485 }
486 }
487 None
488 }
489
490 fn extract_delivery_count(headers: &Option<FieldTable>, redelivered: bool) -> u32 {
497 if let Some(ht) = headers {
498 if let Some(AMQPValue::LongLongInt(n)) = ht.inner().get("x-delivery-count") {
499 return (*n as u32).saturating_add(1);
500 }
501 }
502 if redelivered {
503 2
504 } else {
505 1
506 }
507 }
508
509 async fn register_delivery(
511 &self,
512 channel: &Channel,
513 delivery_tag: u64,
514 data: &[u8],
515 headers: Option<FieldTable>,
516 correlation_id: Option<String>,
517 redelivered: bool,
518 ) -> ReceivedMessage {
519 let session_id = Self::extract_session_id(&headers);
520 let attributes = Self::extract_attributes(&headers);
521 let delivery_count = Self::extract_delivery_count(&headers, redelivered);
522
523 let now = Timestamp::now();
524 let lock_expires_at =
525 Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
526
527 let receipt_id = uuid::Uuid::new_v4().to_string();
528 let message_id = MessageId::new();
529 let body = Bytes::copy_from_slice(data);
530
531 self.in_flight.lock().await.insert(
532 receipt_id.clone(),
533 InFlightEntry {
534 channel: channel.clone(),
535 delivery_tag,
536 lock_expires_at,
537 },
538 );
539
540 ReceivedMessage {
541 message_id,
542 body,
543 attributes,
544 session_id,
545 correlation_id,
546 receipt_handle: ReceiptHandle::new(receipt_id, lock_expires_at, ProviderType::RabbitMq),
547 delivery_count,
548 first_delivered_at: now,
549 delivered_at: now,
550 }
551 }
552
553 async fn settle_message(
555 &self,
556 receipt: &ReceiptHandle,
557 requeue: Option<bool>,
558 ) -> Result<(), QueueError> {
559 let mut in_flight = self.in_flight.lock().await;
560
561 match in_flight.get(receipt.handle()) {
566 None => {
567 return Err(QueueError::MessageNotFound {
568 receipt: receipt.handle().to_string(),
569 });
570 }
571 Some(entry) if Timestamp::now() > entry.lock_expires_at => {
572 in_flight.remove(receipt.handle());
574 return Err(QueueError::MessageNotFound {
575 receipt: format!("{}(expired)", receipt.handle()),
576 });
577 }
578 Some(_) => {}
579 }
580
581 let entry = in_flight
582 .remove(receipt.handle())
583 .expect("entry present after pre-check");
584
585 match requeue {
586 None => {
587 entry
589 .channel
590 .basic_ack(entry.delivery_tag, BasicAckOptions::default())
591 .await
592 .map_err(|e| QueueError::ProviderError {
593 provider: "rabbitmq".to_string(),
594 code: "BASIC_ACK_FAILED".to_string(),
595 message: format!("basic_ack failed: {}", e),
596 })?;
597 }
598 Some(requeue_flag) => {
599 entry
601 .channel
602 .basic_nack(
603 entry.delivery_tag,
604 BasicNackOptions {
605 requeue: requeue_flag,
606 ..Default::default()
607 },
608 )
609 .await
610 .map_err(|e| QueueError::ProviderError {
611 provider: "rabbitmq".to_string(),
612 code: "BASIC_NACK_FAILED".to_string(),
613 message: format!("basic_nack failed: {}", e),
614 })?;
615 }
616 }
617
618 Ok(())
619 }
620}
621
622#[async_trait]
627impl QueueProvider for RabbitMqProvider {
628 #[instrument(skip(self, message), fields(queue = %queue))]
629 async fn send_message(
630 &self,
631 queue: &QueueName,
632 message: &Message,
633 ) -> Result<MessageId, QueueError> {
634 let size = message.body.len();
635 let max_size = self.provider_type().max_message_size();
636 if size > max_size {
637 return Err(QueueError::MessageTooLarge { size, max_size });
638 }
639
640 let channel = self.get_publish_channel().await?;
641
642 let routing_key = if let Some(ref sid) = message.session_id {
645 self.declare_session_queue(&channel, queue, sid).await?
646 } else {
647 self.declare_queue(&channel, queue).await?;
648 queue.as_str().to_string()
649 };
650
651 let props = Self::build_properties(message);
652
653 channel
654 .basic_publish(
655 "".into(),
656 routing_key.as_str().into(),
657 BasicPublishOptions::default(),
658 &message.body,
659 props,
660 )
661 .await
662 .map_err(|e| QueueError::ProviderError {
663 provider: "rabbitmq".to_string(),
664 code: "PUBLISH_FAILED".to_string(),
665 message: format!("failed to publish message to '{}': {}", routing_key, e),
666 })?
667 .await
668 .map_err(|e| QueueError::ProviderError {
669 provider: "rabbitmq".to_string(),
670 code: "PUBLISH_CONFIRM_FAILED".to_string(),
671 message: format!("publish confirmation failed: {}", e),
672 })?;
673
674 let message_id = MessageId::new();
675 debug!(%message_id, %queue, "Published message to RabbitMQ");
676 Ok(message_id)
677 }
678
679 #[instrument(skip(self, messages), fields(queue = %queue, count = messages.len()))]
680 async fn send_messages(
681 &self,
682 queue: &QueueName,
683 messages: &[Message],
684 ) -> Result<Vec<MessageId>, QueueError> {
685 if messages.len() > self.max_batch_size() as usize {
686 return Err(QueueError::BatchTooLarge {
687 size: messages.len(),
688 max_size: self.max_batch_size() as usize,
689 });
690 }
691
692 let mut ids = Vec::with_capacity(messages.len());
699 for message in messages {
700 ids.push(self.send_message(queue, message).await?);
701 }
702 Ok(ids)
703 }
704
705 #[instrument(skip(self), fields(queue = %queue))]
706 async fn receive_message(
707 &self,
708 queue: &QueueName,
709 timeout: Duration,
710 ) -> Result<Option<ReceivedMessage>, QueueError> {
711 let channel = self.get_receive_channel().await?;
712 self.declare_queue(&channel, queue).await?;
713
714 let start = std::time::Instant::now();
715 let timeout_std = timeout
716 .to_std()
717 .unwrap_or(std::time::Duration::from_secs(30));
718
719 loop {
720 let get = channel
721 .basic_get(queue.as_str().into(), BasicGetOptions { no_ack: false })
722 .await
723 .map_err(|e| QueueError::ProviderError {
724 provider: "rabbitmq".to_string(),
725 code: "BASIC_GET_FAILED".to_string(),
726 message: format!("basic_get on '{}' failed: {}", queue.as_str(), e),
727 })?;
728
729 if let Some(delivery) = get {
730 let headers = delivery.delivery.properties.headers().clone();
731 let redelivered = delivery.delivery.redelivered;
732 let correlation_id = delivery
733 .delivery
734 .properties
735 .correlation_id()
736 .as_ref()
737 .map(|s| s.to_string());
738 let msg = self
739 .register_delivery(
740 &channel,
741 delivery.delivery.delivery_tag,
742 &delivery.delivery.data,
743 headers,
744 correlation_id,
745 redelivered,
746 )
747 .await;
748 return Ok(Some(msg));
749 }
750
751 if start.elapsed() >= timeout_std {
752 return Ok(None);
753 }
754
755 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
756 }
757 }
758
759 #[instrument(skip(self), fields(queue = %queue, max = max_messages))]
760 async fn receive_messages(
761 &self,
762 queue: &QueueName,
763 max_messages: u32,
764 timeout: Duration,
765 ) -> Result<Vec<ReceivedMessage>, QueueError> {
766 let channel = self.get_receive_channel().await?;
767 self.declare_queue(&channel, queue).await?;
768
769 let mut messages = Vec::new();
770 let start = std::time::Instant::now();
771 let timeout_std = timeout
772 .to_std()
773 .unwrap_or(std::time::Duration::from_secs(30));
774
775 while messages.len() < max_messages as usize {
776 if start.elapsed() >= timeout_std {
777 break;
778 }
779
780 let get = channel
781 .basic_get(queue.as_str().into(), BasicGetOptions { no_ack: false })
782 .await
783 .map_err(|e| QueueError::ProviderError {
784 provider: "rabbitmq".to_string(),
785 code: "BASIC_GET_FAILED".to_string(),
786 message: format!("basic_get on '{}' failed: {}", queue.as_str(), e),
787 })?;
788
789 match get {
790 Some(delivery) => {
791 let headers = delivery.delivery.properties.headers().clone();
792 let redelivered = delivery.delivery.redelivered;
793 let correlation_id = delivery
794 .delivery
795 .properties
796 .correlation_id()
797 .as_ref()
798 .map(|s| s.to_string());
799 let msg = self
800 .register_delivery(
801 &channel,
802 delivery.delivery.delivery_tag,
803 &delivery.delivery.data,
804 headers,
805 correlation_id,
806 redelivered,
807 )
808 .await;
809 messages.push(msg);
810 }
811 None => {
814 if start.elapsed() >= timeout_std {
815 break;
816 }
817 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
818 }
819 }
820 }
821
822 Ok(messages)
823 }
824
825 #[instrument(skip(self, receipt))]
826 async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
827 self.settle_message(receipt, None).await
828 }
829
830 #[instrument(skip(self, receipt))]
831 async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
832 self.settle_message(receipt, Some(true)).await
833 }
834
835 #[instrument(skip(self, receipt), fields(reason = %reason))]
836 async fn dead_letter_message(
837 &self,
838 receipt: &ReceiptHandle,
839 reason: &str,
840 ) -> Result<(), QueueError> {
841 debug!(reason, "Dead-lettering RabbitMQ message");
842 self.settle_message(receipt, Some(false)).await
846 }
847
848 #[instrument(skip(self), fields(queue = %queue))]
849 async fn create_session_client(
850 &self,
851 queue: &QueueName,
852 session_id: Option<SessionId>,
853 ) -> Result<Box<dyn SessionProvider>, QueueError> {
854 let sid = match session_id {
855 Some(id) => id,
856 None => {
857 return Err(QueueError::SessionNotFound {
859 session_id: "<any>".to_string(),
860 });
861 }
862 };
863
864 let channel = self.open_channel().await?;
865 let session_queue = self.declare_session_queue(&channel, queue, &sid).await?;
866
867 let consumer = channel
870 .basic_consume(
871 session_queue.as_str().into(),
872 format!("session-{}", uuid::Uuid::new_v4()).as_str().into(),
873 BasicConsumeOptions {
874 exclusive: true,
875 no_ack: false,
876 ..Default::default()
877 },
878 FieldTable::default(),
879 )
880 .await
881 .map_err(|e| QueueError::ProviderError {
882 provider: "rabbitmq".to_string(),
883 code: "CONSUME_FAILED".to_string(),
884 message: format!(
885 "failed to start exclusive consumer on '{}': {}",
886 session_queue, e
887 ),
888 })?;
889
890 let now = Timestamp::now();
891 let lock_expires_at =
892 Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
893
894 let (tx, rx) = mpsc::unbounded_channel::<lapin::message::Delivery>();
898 tokio::spawn(async move {
899 let mut consumer = consumer;
900 while let Some(result) = consumer.next().await {
901 match result {
902 Ok(delivery) => {
903 if tx.send(delivery).is_err() {
904 break;
905 }
906 }
907 Err(e) => {
908 warn!(error = %e, "RabbitMQ session consumer error");
909 break;
910 }
911 }
912 }
913 });
914
915 Ok(Box::new(RabbitMqSessionProvider {
916 channel,
917 deliveries: Arc::new(Mutex::new(rx)),
918 session_id: sid,
919 in_flight: self.in_flight.clone(),
920 lock_expires_at: Arc::new(std::sync::Mutex::new(lock_expires_at)),
921 config: self.config.clone(),
922 }))
923 }
924
925 fn provider_type(&self) -> ProviderType {
926 ProviderType::RabbitMq
927 }
928
929 fn supports_sessions(&self) -> SessionSupport {
930 SessionSupport::Emulated
931 }
932
933 fn supports_batching(&self) -> bool {
934 true
935 }
936
937 fn max_batch_size(&self) -> u32 {
938 100
939 }
940}
941
/// Session client created by `RabbitMqProvider::create_session_client`.
///
/// It owns a dedicated channel with an exclusive consumer on the session queue.
/// A background task forwards that consumer's deliveries into `deliveries`.
pub struct RabbitMqSessionProvider {
    /// Dedicated channel for this session; closed by `close_session`.
    channel: Channel,
    /// Receiving end of the deliveries forwarded by the consumer task.
    deliveries: Arc<Mutex<mpsc::UnboundedReceiver<lapin::message::Delivery>>>,
    /// Session this client is bound to.
    session_id: SessionId,
    /// In-flight map shared with the parent provider.
    in_flight: Arc<Mutex<HashMap<String, InFlightEntry>>>,
    /// Current session-lock expiry; read and advanced without awaiting.
    lock_expires_at: Arc<std::sync::Mutex<Timestamp>>,
    /// Copy of the parent provider's configuration.
    config: RabbitMqConfig,
}
961
962#[async_trait]
963impl SessionProvider for RabbitMqSessionProvider {
964 #[instrument(skip(self), fields(session_id = %self.session_id))]
965 async fn receive_message(
966 &self,
967 timeout: Duration,
968 ) -> Result<Option<ReceivedMessage>, QueueError> {
969 self.check_lock()?;
970
971 let timeout_std = timeout
972 .to_std()
973 .unwrap_or(std::time::Duration::from_secs(30));
974
975 let mut rx = self.deliveries.lock().await;
976 match tokio::time::timeout(timeout_std, rx.recv()).await {
977 Ok(Some(delivery)) => {
978 let msg = self.register_session_delivery(delivery).await;
979 Ok(Some(msg))
980 }
981 Ok(None) => Ok(None),
982 Err(_) => Ok(None), }
984 }
985
986 #[instrument(skip(self, receipt))]
987 async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
988 self.check_lock()?;
989 self.settle(receipt, None).await
990 }
991
992 #[instrument(skip(self, receipt))]
993 async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
994 self.check_lock()?;
995 self.settle(receipt, Some(true)).await
996 }
997
998 #[instrument(skip(self, receipt), fields(reason = %reason))]
999 async fn dead_letter_message(
1000 &self,
1001 receipt: &ReceiptHandle,
1002 reason: &str,
1003 ) -> Result<(), QueueError> {
1004 self.check_lock()?;
1005 debug!(reason, "Dead-lettering session message");
1006 self.settle(receipt, Some(false)).await
1007 }
1008
1009 async fn renew_session_lock(&self) -> Result<(), QueueError> {
1010 advance_session_lock(&self.lock_expires_at, self.config.session_lock_duration)?;
1011 debug!(session_id = %self.session_id, "RabbitMQ session lock renewed");
1012 Ok(())
1013 }
1014
1015 async fn close_session(&self) -> Result<(), QueueError> {
1016 if let Err(e) = self.channel.close(200, "session closed".into()).await {
1017 warn!(error = %e, "Failed to cleanly close RabbitMQ session channel");
1018 }
1019 Ok(())
1020 }
1021
1022 fn session_id(&self) -> &SessionId {
1023 &self.session_id
1024 }
1025
1026 fn session_expires_at(&self) -> Timestamp {
1027 *self
1029 .lock_expires_at
1030 .lock()
1031 .unwrap_or_else(|e| e.into_inner())
1032 }
1033}
1034
1035fn check_session_lock(
1044 lock_expires_at: &std::sync::Mutex<Timestamp>,
1045 session_id: &SessionId,
1046) -> Result<(), QueueError> {
1047 let expires = *lock_expires_at
1048 .lock()
1049 .map_err(|_| QueueError::ProviderError {
1050 provider: "rabbitmq".to_string(),
1051 code: "INTERNAL_ERROR".to_string(),
1052 message: "session lock mutex poisoned".to_string(),
1053 })?;
1054 if Timestamp::now() > expires {
1055 return Err(QueueError::SessionLocked {
1056 session_id: session_id.as_str().to_string(),
1057 locked_until: expires,
1058 });
1059 }
1060 Ok(())
1061}
1062
1063fn advance_session_lock(
1067 lock_expires_at: &std::sync::Mutex<Timestamp>,
1068 duration: Duration,
1069) -> Result<Timestamp, QueueError> {
1070 let new_expiry = Timestamp::from_datetime(Timestamp::now().as_datetime() + duration);
1071 *lock_expires_at
1072 .lock()
1073 .map_err(|_| QueueError::ProviderError {
1074 provider: "rabbitmq".to_string(),
1075 code: "INTERNAL_ERROR".to_string(),
1076 message: "session lock mutex poisoned".to_string(),
1077 })? = new_expiry;
1078 Ok(new_expiry)
1079}
1080
1081impl RabbitMqSessionProvider {
1082 fn check_lock(&self) -> Result<(), QueueError> {
1084 check_session_lock(&self.lock_expires_at, &self.session_id)
1085 }
1086
1087 async fn register_session_delivery(
1089 &self,
1090 delivery: lapin::message::Delivery,
1091 ) -> ReceivedMessage {
1092 let delivery_tag = delivery.delivery_tag;
1093 let redelivered = delivery.redelivered;
1094 let headers = delivery.properties.headers().clone();
1095 let attributes = RabbitMqProvider::extract_attributes(&headers);
1096 let delivery_count = RabbitMqProvider::extract_delivery_count(&headers, redelivered);
1097 let correlation_id = delivery
1098 .properties
1099 .correlation_id()
1100 .as_ref()
1101 .map(|s| s.to_string());
1102
1103 let now = Timestamp::now();
1104 let lock_expires_at =
1105 Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
1106
1107 let receipt_id = uuid::Uuid::new_v4().to_string();
1108 let message_id = MessageId::new();
1109 let body = Bytes::copy_from_slice(&delivery.data);
1110
1111 self.in_flight.lock().await.insert(
1112 receipt_id.clone(),
1113 InFlightEntry {
1114 channel: self.channel.clone(),
1115 delivery_tag,
1116 lock_expires_at,
1117 },
1118 );
1119
1120 ReceivedMessage {
1121 message_id,
1122 body,
1123 attributes,
1124 session_id: Some(self.session_id.clone()),
1125 correlation_id,
1126 receipt_handle: ReceiptHandle::new(receipt_id, lock_expires_at, ProviderType::RabbitMq),
1127 delivery_count,
1128 first_delivered_at: now,
1129 delivered_at: now,
1130 }
1131 }
1132
1133 async fn settle(
1135 &self,
1136 receipt: &ReceiptHandle,
1137 requeue: Option<bool>,
1138 ) -> Result<(), QueueError> {
1139 let mut in_flight = self.in_flight.lock().await;
1140
1141 match in_flight.get(receipt.handle()) {
1143 None => {
1144 return Err(QueueError::MessageNotFound {
1145 receipt: receipt.handle().to_string(),
1146 });
1147 }
1148 Some(entry) if Timestamp::now() > entry.lock_expires_at => {
1149 in_flight.remove(receipt.handle());
1150 return Err(QueueError::MessageNotFound {
1151 receipt: format!("{}(expired)", receipt.handle()),
1152 });
1153 }
1154 Some(_) => {}
1155 }
1156
1157 let entry = in_flight
1158 .remove(receipt.handle())
1159 .expect("entry present after pre-check");
1160
1161 match requeue {
1162 None => {
1163 entry
1164 .channel
1165 .basic_ack(entry.delivery_tag, BasicAckOptions::default())
1166 .await
1167 .map_err(|e| QueueError::ProviderError {
1168 provider: "rabbitmq".to_string(),
1169 code: "BASIC_ACK_FAILED".to_string(),
1170 message: format!("basic_ack failed: {}", e),
1171 })?;
1172 }
1173 Some(requeue_flag) => {
1174 entry
1175 .channel
1176 .basic_nack(
1177 entry.delivery_tag,
1178 BasicNackOptions {
1179 requeue: requeue_flag,
1180 ..Default::default()
1181 },
1182 )
1183 .await
1184 .map_err(|e| QueueError::ProviderError {
1185 provider: "rabbitmq".to_string(),
1186 code: "BASIC_NACK_FAILED".to_string(),
1187 message: format!("basic_nack failed: {}", e),
1188 })?;
1189 }
1190 }
1191
1192 Ok(())
1193 }
1194}