Skip to main content

queue_runtime/providers/
rabbitmq.rs

1//! RabbitMQ provider implementation using AMQP 0-9-1.
2//!
3//! This module provides production-ready RabbitMQ integration via the `lapin` AMQP client.
4//! It implements the [`QueueProvider`] and [`SessionProvider`] traits, enabling
5//! RabbitMQ to be used as a drop-in backend within the queue-runtime abstraction layer.
6//!
7//! ## Key Features
8//!
9//! - **AMQP 0-9-1**: Full native protocol support via the `lapin` crate
10//! - **Persistent queues**: Messages survive broker restarts (durable queues, persistent
11//!   delivery mode)
12//! - **Manual acknowledgement**: Messages remain unacked until explicitly completed,
13//!   abandoned, or dead-lettered, giving visibility-timeout-like semantics
14//! - **Dead letter exchange (DLX)**: Nacked messages are routed to a configured DLX,
15//!   mirroring the DLQ behaviour of other providers
16//! - **Session emulation**: Sessions are emulated via per-session sub-queues
17//!   (`{queue}.session.{session_id}`) with exclusive consumers
18//! - **Batch operations**: Up to 100 messages per batch
19//!
20//! ## Queue Naming
21//!
22//! Queues are declared as durable on first use.  The naming conventions are:
23//!
24//! - Main queue: `{queue_name}`
25//! - Session queue: `{queue_name}.session.{session_id}`
26//!
27//! ## Dead Letter Routing
28//!
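//! When `enable_dead_letter` is `true` and `dead_letter_exchange` is set, each
//! main queue is declared with the `x-dead-letter-exchange` argument pointing at
//! the configured exchange (session sub-queues are declared without it).
//! Messages dead-lettered via `dead_letter_message` are nacked without requeue,
//! causing RabbitMQ to republish them to that exchange; the exchange and a queue
//! bound to it must exist for dead-lettered messages to be retained.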
33//!
34//! ## Session Support
35//!
36//! Sessions are emulated using per-session sub-queues.  Sending a message with a
37//! `session_id` routes it to `{queue}.session.{session_id}`.  A
38//! [`RabbitMqSessionProvider`] holds an exclusive consumer on that sub-queue for
39//! the duration of the session.
40//!
41//! ## Testing
42//!
43//! ```rust,no_run
44//! use queue_runtime::providers::RabbitMqProvider;
45//! use queue_runtime::RabbitMqConfig;
46//!
47//! # async fn test_example() {
48//! let config = RabbitMqConfig {
49//!     url: "amqp://guest:guest@localhost:5672".to_string(),
50//!     ..RabbitMqConfig::default()
51//! };
52//!
53//! let provider = RabbitMqProvider::new(config).await.unwrap();
54//! # }
55//! ```
56
use crate::client::{QueueProvider, SessionProvider};
use crate::error::QueueError;
use crate::message::{
    Message, MessageId, QueueName, ReceiptHandle, ReceivedMessage, SessionId, Timestamp,
};
use crate::provider::{ProviderType, SessionSupport};
use async_trait::async_trait;
use bytes::Bytes;
use chrono::Duration;
use futures::StreamExt;
use lapin::{
    options::{
        BasicAckOptions, BasicConsumeOptions, BasicGetOptions, BasicNackOptions,
        BasicPublishOptions, BasicQosOptions, ConfirmSelectOptions, QueueDeclareOptions,
    },
    types::{AMQPValue, FieldTable, LongString, ShortString},
    BasicProperties, Channel, Connection, ConnectionProperties,
};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use tracing::{debug, instrument, warn};
75use serde::{Deserialize, Serialize};
76use std::collections::HashMap;
77use std::sync::Arc;
78use tokio::sync::{mpsc, Mutex};
79use tracing::{debug, instrument, warn};
80
81#[cfg(test)]
82#[path = "rabbitmq_tests.rs"]
83mod tests;
84
85// ============================================================================
86// Configuration
87// ============================================================================
88
/// RabbitMQ provider configuration using AMQP 0-9-1
///
/// # Examples
///
/// ```rust
/// use queue_runtime::RabbitMqConfig;
/// use chrono::Duration;
///
/// let config = RabbitMqConfig {
///     url: "amqp://guest:guest@localhost:5672".to_string(),
///     virtual_host: "/".to_string(),
///     prefetch_count: 10,
///     session_lock_duration: Duration::minutes(5),
///     message_ttl: None,
///     enable_dead_letter: true,
///     dead_letter_exchange: Some("dlx".to_string()),
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RabbitMqConfig {
    /// AMQP connection URL (e.g. `amqp://user:pass@host:port/vhost`)
    ///
    /// Credentials embedded in the URL are redacted before the URL is logged
    /// or included in connection error messages.
    pub url: String,
    /// RabbitMQ virtual host (defaults to `/`)
    ///
    /// NOTE(review): `RabbitMqProvider::new` connects using `url` alone and does
    /// not read this field, so the effective vhost is whatever the URL path
    /// selects — confirm whether it is applied anywhere else.
    pub virtual_host: String,
    /// Number of messages to prefetch per channel (0 = unlimited)
    ///
    /// When non-zero it is applied with `basic.qos` to every channel the
    /// provider opens; `0` skips the `basic.qos` call entirely.
    /// NOTE(review): the provider's receive paths poll with `basic.get`, which
    /// presumably is not limited by `basic.qos` — confirm this has an effect.
    pub prefetch_count: u16,
    /// Duration to hold a session lock before expiry
    ///
    /// Also used as the lock window for every message the provider receives,
    /// not only session messages; settling a message after it lapses fails.
    pub session_lock_duration: Duration,
    /// Default message time-to-live
    ///
    /// Applied as the `x-message-ttl` queue argument (in milliseconds) when main
    /// queues are declared; non-positive durations are ignored.
    pub message_ttl: Option<Duration>,
    /// Whether to enable dead letter queue routing
    pub enable_dead_letter: bool,
    /// Name of the dead letter exchange (required when `enable_dead_letter` is true)
    ///
    /// Used as the `x-dead-letter-exchange` argument of main queues; session
    /// sub-queues are declared without it.
    pub dead_letter_exchange: Option<String>,
}
124
125impl Default for RabbitMqConfig {
126    fn default() -> Self {
127        Self {
128            url: "amqp://guest:guest@localhost:5672".to_string(),
129            virtual_host: "/".to_string(),
130            prefetch_count: 10,
131            session_lock_duration: Duration::minutes(5),
132            message_ttl: None,
133            enable_dead_letter: true,
134            dead_letter_exchange: Some("dlx".to_string()),
135        }
136    }
137}
138
139// ============================================================================
140// Error types
141// ============================================================================
142
143/// RabbitMQ-specific error type that wraps AMQP errors.
144#[derive(Debug)]
145pub struct RabbitMqError {
146    message: String,
147}
148
149impl RabbitMqError {
150    fn new(message: impl Into<String>) -> Self {
151        Self {
152            message: message.into(),
153        }
154    }
155
156    /// Convert to a [`QueueError`].
157    pub fn to_queue_error(&self) -> QueueError {
158        QueueError::ProviderError {
159            provider: "rabbitmq".to_string(),
160            code: "AMQP_ERROR".to_string(),
161            message: self.message.clone(),
162        }
163    }
164}
165
166impl std::fmt::Display for RabbitMqError {
167    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
168        write!(f, "RabbitMQ error: {}", self.message)
169    }
170}
171
172impl std::error::Error for RabbitMqError {}
173
174// ============================================================================
175// Internal helpers
176// ============================================================================
177
/// Bookkeeping for one delivery that has been handed to a caller but not yet
/// acked or nacked.
struct InFlightEntry {
    /// The AMQP channel the message was received on; ack/nack must be sent on
    /// this same channel because delivery tags are scoped to it.
    channel: Channel,
    /// AMQP delivery tag used for ack/nack
    delivery_tag: u64,
    /// Lock expiry for best-effort visibility timeout; settling after this
    /// instant is rejected.
    lock_expires_at: Timestamp,
}
187
188// ============================================================================
189// Helper functions
190// ============================================================================
191
192/// Redact any userinfo (username and password) from a URL, keeping host and path.
193///
194/// Used to prevent credential leakage in log fields and error messages when
195/// connection URLs contain embedded credentials
196/// (e.g. `amqp://user:pass@host:5672` → `amqp://***:***@host:5672`).
197fn redact_url(url: &str) -> String {
198    match url::Url::parse(url) {
199        Ok(mut parsed) => {
200            let has_credentials = !parsed.username().is_empty() || parsed.password().is_some();
201            if has_credentials {
202                // Errors here are non-fatal; fall through to return the redacted form.
203                let _ = parsed.set_username("***");
204                let _ = parsed.set_password(Some("***"));
205            }
206            parsed.to_string()
207        }
208        Err(_) => "<invalid-url>".to_string(),
209    }
210}
211
212/// Build the session sub-queue name from a queue name and session ID.
213fn session_queue_name(queue: &QueueName, session_id: &SessionId) -> String {
214    // Replace characters not valid in AMQP queue names.
215    let safe = session_id.as_str().replace(['/', ' ', '\\'], "_");
216    format!("{}.session.{}", queue.as_str(), safe)
217}
218
219// ============================================================================
220// RabbitMqProvider
221// ============================================================================
222
/// RabbitMQ queue provider using AMQP 0-9-1 via the `lapin` crate.
///
/// All queues are declared as durable on first use.  Messages are published with
/// `delivery_mode = 2` (persistent) so they survive broker restarts.
///
/// The provider emulates visibility timeouts by keeping in-flight messages in an
/// unacked state on the AMQP broker.  If the application crashes before calling
/// [`complete_message`](QueueProvider::complete_message) or
/// [`abandon_message`](QueueProvider::abandon_message), the AMQP broker will
/// redeliver the message once the connection is dropped.
///
/// # Channel Strategy
///
/// A single persistent AMQP channel is reused for all publish operations and a
/// second for all receive operations instead of opening a new channel per call.
/// Both channels are recreated transparently if the broker closes them.  This
/// keeps the channel count bounded regardless of message throughput and avoids
/// exhausting the broker's per-connection channel limit.
pub struct RabbitMqProvider {
    /// AMQP connection from which the publish and receive channels are opened.
    connection: Arc<Connection>,
    /// Configuration supplied to [`RabbitMqProvider::new`].
    config: RabbitMqConfig,
    /// In-flight messages indexed by their receipt handle UUID.
    in_flight: Arc<Mutex<HashMap<String, InFlightEntry>>>,
    /// Shared channel used for all publish operations.
    publish_channel: Arc<Mutex<Option<Channel>>>,
    /// Shared channel used for all receive operations.
    receive_channel: Arc<Mutex<Option<Channel>>>,
}
250
251impl RabbitMqProvider {
252    /// Create a new [`RabbitMqProvider`] and establish an AMQP connection.
253    ///
254    /// # Errors
255    ///
256    /// Returns [`RabbitMqError`] if the AMQP connection cannot be established or
257    /// if the URL is malformed.
258    ///
259    /// # Examples
260    ///
261    /// ```rust,no_run
262    /// use queue_runtime::providers::RabbitMqProvider;
263    /// use queue_runtime::RabbitMqConfig;
264    ///
265    /// # async fn example() {
266    /// let config = RabbitMqConfig::default();
267    /// let provider = RabbitMqProvider::new(config).await.unwrap();
268    /// # }
269    /// ```
270    pub async fn new(config: RabbitMqConfig) -> Result<Self, RabbitMqError> {
271        let conn = Connection::connect(&config.url, ConnectionProperties::default())
272            .await
273            .map_err(|e| {
274                RabbitMqError::new(format!(
275                    "failed to connect to RabbitMQ at '{}': {}",
276                    redact_url(&config.url),
277                    e
278                ))
279            })?;
280
281        debug!(url = %redact_url(&config.url), "Connected to RabbitMQ");
282
283        Ok(Self {
284            connection: Arc::new(conn),
285            config,
286            in_flight: Arc::new(Mutex::new(HashMap::new())),
287            publish_channel: Arc::new(Mutex::new(None)),
288            receive_channel: Arc::new(Mutex::new(None)),
289        })
290    }
291
292    /// Open a new AMQP channel, optionally configuring QoS prefetch.
293    async fn open_channel(&self) -> Result<Channel, QueueError> {
294        let channel =
295            self.connection
296                .create_channel()
297                .await
298                .map_err(|e| QueueError::ConnectionFailed {
299                    message: format!("failed to create AMQP channel: {}", e),
300                })?;
301
302        if self.config.prefetch_count > 0 {
303            channel
304                .basic_qos(self.config.prefetch_count, BasicQosOptions::default())
305                .await
306                .map_err(|e| QueueError::ProviderError {
307                    provider: "rabbitmq".to_string(),
308                    code: "QOS_FAILED".to_string(),
309                    message: format!("failed to set QoS prefetch: {}", e),
310                })?;
311        }
312
313        Ok(channel)
314    }
315
316    /// Return the shared publish channel, creating or recreating it if needed.
317    ///
318    /// If the stored channel has been closed by the broker (e.g. after a
319    /// channel error), a new channel is transparently opened and stored.
320    async fn get_publish_channel(&self) -> Result<Channel, QueueError> {
321        let mut guard = self.publish_channel.lock().await;
322        if let Some(ref ch) = *guard {
323            if ch.status().connected() {
324                return Ok(ch.clone());
325            }
326        }
327        let ch = self.open_channel().await?;
328        *guard = Some(ch.clone());
329        Ok(ch)
330    }
331
332    /// Return the shared receive channel, creating or recreating it if needed.
333    async fn get_receive_channel(&self) -> Result<Channel, QueueError> {
334        let mut guard = self.receive_channel.lock().await;
335        if let Some(ref ch) = *guard {
336            if ch.status().connected() {
337                return Ok(ch.clone());
338            }
339        }
340        let ch = self.open_channel().await?;
341        *guard = Some(ch.clone());
342        Ok(ch)
343    }
344
345    /// Declare a durable queue on the broker, optionally wiring up dead letter
346    /// exchange and message TTL via queue arguments.
347    async fn declare_queue(&self, channel: &Channel, queue: &QueueName) -> Result<(), QueueError> {
348        let mut args = FieldTable::default();
349
350        if self.config.enable_dead_letter {
351            if let Some(ref dlx) = self.config.dead_letter_exchange {
352                args.insert(
353                    ShortString::from("x-dead-letter-exchange"),
354                    AMQPValue::LongString(LongString::from(dlx.as_bytes())),
355                );
356            }
357        }
358
359        if let Some(ttl) = self.config.message_ttl {
360            let ttl_ms = ttl.num_milliseconds();
361            if ttl_ms > 0 {
362                args.insert(
363                    ShortString::from("x-message-ttl"),
364                    AMQPValue::LongLongInt(ttl_ms),
365                );
366            }
367        }
368
369        let opts = QueueDeclareOptions {
370            durable: true,
371            ..Default::default()
372        };
373
374        channel
375            .queue_declare(queue.as_str().into(), opts, args)
376            .await
377            .map_err(|e| QueueError::ProviderError {
378                provider: "rabbitmq".to_string(),
379                code: "QUEUE_DECLARE_FAILED".to_string(),
380                message: format!("failed to declare queue '{}': {}", queue.as_str(), e),
381            })?;
382
383        Ok(())
384    }
385
386    /// Declare a durable session sub-queue.
387    ///
388    /// Session queues are named `{queue}.session.{session_id}` and are bound to
389    /// the default exchange automatically.
390    async fn declare_session_queue(
391        &self,
392        channel: &Channel,
393        queue: &QueueName,
394        session_id: &SessionId,
395    ) -> Result<String, QueueError> {
396        let name = session_queue_name(queue, session_id);
397
398        let opts = QueueDeclareOptions {
399            durable: true,
400            ..Default::default()
401        };
402
403        channel
404            .queue_declare(name.as_str().into(), opts, FieldTable::default())
405            .await
406            .map_err(|e| QueueError::ProviderError {
407                provider: "rabbitmq".to_string(),
408                code: "SESSION_QUEUE_DECLARE_FAILED".to_string(),
409                message: format!("failed to declare session queue '{}': {}", name, e),
410            })?;
411
412        Ok(name)
413    }
414
415    /// Build AMQP [`BasicProperties`] from a [`Message`].
416    fn build_properties(message: &Message) -> BasicProperties {
417        let mut props = BasicProperties::default().with_delivery_mode(2); // persistent
418
419        if let Some(ref corr_id) = message.correlation_id {
420            props = props.with_correlation_id(ShortString::from(corr_id.as_str()));
421        }
422
423        if let Some(ttl) = message.time_to_live {
424            let ttl_ms = ttl.num_milliseconds();
425            if ttl_ms > 0 {
426                props = props.with_expiration(ShortString::from(ttl_ms.to_string().as_str()));
427            }
428        }
429
430        // Encode message attributes and session ID into AMQP headers.
431        // User attributes are prefixed with "x-attr-" so they survive the
432        // extract_attributes filter that drops broker/extension x- headers.
433        // This mirrors the NATS provider's "x-attr-" scheme and ensures that
434        // attribute names starting with "x-" roundtrip correctly.
435        let mut headers = FieldTable::default();
436        for (k, v) in &message.attributes {
437            let header_key = format!("x-attr-{}", k);
438            headers.insert(
439                ShortString::from(header_key.as_str()),
440                AMQPValue::LongString(LongString::from(v.as_bytes())),
441            );
442        }
443        if let Some(ref sid) = message.session_id {
444            headers.insert(
445                ShortString::from("x-session-id"),
446                AMQPValue::LongString(LongString::from(sid.as_str().as_bytes())),
447            );
448        }
449
450        props.with_headers(headers)
451    }
452
453    /// Extract message attributes from AMQP headers.
454    ///
455    /// Only headers with the `x-attr-` prefix are considered user attributes;
456    /// the prefix is stripped before returning. This scheme lets attribute keys
457    /// that themselves start with `x-` (e.g. `x-custom-meta`) round-trip
458    /// correctly while keeping AMQP broker/extension headers (`x-session-id`,
459    /// `x-delivery-count`, `x-death`, …) out of the attribute map.
460    fn extract_attributes(headers: &Option<FieldTable>) -> HashMap<String, String> {
461        let mut attrs = HashMap::new();
462        if let Some(ht) = headers {
463            for (k, v) in ht.inner() {
464                let key = k.as_str();
465                // Only extract user-attribute headers (written with "x-attr-" prefix).
466                if let Some(attr_key) = key.strip_prefix("x-attr-") {
467                    if let AMQPValue::LongString(s) = v {
468                        attrs.insert(
469                            attr_key.to_string(),
470                            String::from_utf8_lossy(s.as_bytes()).to_string(),
471                        );
472                    }
473                }
474            }
475        }
476        attrs
477    }
478
479    /// Extract the session ID from AMQP headers.
480    fn extract_session_id(headers: &Option<FieldTable>) -> Option<SessionId> {
481        if let Some(ht) = headers {
482            if let Some(AMQPValue::LongString(s)) = ht.inner().get("x-session-id") {
483                let id = String::from_utf8_lossy(s.as_bytes()).to_string();
484                return SessionId::new(id).ok();
485            }
486        }
487        None
488    }
489
490    /// Extract the delivery count from AMQP `x-delivery-count` headers.
491    ///
492    /// RabbitMQ Classic Queues do not set `x-delivery-count` on nack+requeue;
493    /// they only set the `redelivered` flag in the delivery frame.  When the
494    /// header is absent we fall back to `redelivered` to distinguish a first
495    /// delivery (1) from at least one redeliver (2).
496    fn extract_delivery_count(headers: &Option<FieldTable>, redelivered: bool) -> u32 {
497        if let Some(ht) = headers {
498            if let Some(AMQPValue::LongLongInt(n)) = ht.inner().get("x-delivery-count") {
499                return (*n as u32).saturating_add(1);
500            }
501        }
502        if redelivered {
503            2
504        } else {
505            1
506        }
507    }
508
509    /// Register a delivery in the in-flight map and return its [`ReceivedMessage`].
510    async fn register_delivery(
511        &self,
512        channel: &Channel,
513        delivery_tag: u64,
514        data: &[u8],
515        headers: Option<FieldTable>,
516        correlation_id: Option<String>,
517        redelivered: bool,
518    ) -> ReceivedMessage {
519        let session_id = Self::extract_session_id(&headers);
520        let attributes = Self::extract_attributes(&headers);
521        let delivery_count = Self::extract_delivery_count(&headers, redelivered);
522
523        let now = Timestamp::now();
524        let lock_expires_at =
525            Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
526
527        let receipt_id = uuid::Uuid::new_v4().to_string();
528        let message_id = MessageId::new();
529        let body = Bytes::copy_from_slice(data);
530
531        self.in_flight.lock().await.insert(
532            receipt_id.clone(),
533            InFlightEntry {
534                channel: channel.clone(),
535                delivery_tag,
536                lock_expires_at,
537            },
538        );
539
540        ReceivedMessage {
541            message_id,
542            body,
543            attributes,
544            session_id,
545            correlation_id,
546            receipt_handle: ReceiptHandle::new(receipt_id, lock_expires_at, ProviderType::RabbitMq),
547            delivery_count,
548            first_delivered_at: now,
549            delivered_at: now,
550        }
551    }
552
553    /// Ack, nack-requeue, or nack-discard a message identified by its receipt handle.
554    async fn settle_message(
555        &self,
556        receipt: &ReceiptHandle,
557        requeue: Option<bool>,
558    ) -> Result<(), QueueError> {
559        let mut in_flight = self.in_flight.lock().await;
560
561        // Check existence and lock expiry before removal so callers receive a
562        // meaningful error: MessageNotFound for an unknown receipt, and a
563        // separate expiry error for a receipt whose lock has lapsed.  Expired
564        // entries are then pruned to prevent unbounded map growth.
565        match in_flight.get(receipt.handle()) {
566            None => {
567                return Err(QueueError::MessageNotFound {
568                    receipt: receipt.handle().to_string(),
569                });
570            }
571            Some(entry) if Timestamp::now() > entry.lock_expires_at => {
572                // Remove the stale entry and surface a clear expiry error.
573                in_flight.remove(receipt.handle());
574                return Err(QueueError::MessageNotFound {
575                    receipt: format!("{}(expired)", receipt.handle()),
576                });
577            }
578            Some(_) => {}
579        }
580
581        let entry = in_flight
582            .remove(receipt.handle())
583            .expect("entry present after pre-check");
584
585        match requeue {
586            None => {
587                // Complete (ack)
588                entry
589                    .channel
590                    .basic_ack(entry.delivery_tag, BasicAckOptions::default())
591                    .await
592                    .map_err(|e| QueueError::ProviderError {
593                        provider: "rabbitmq".to_string(),
594                        code: "BASIC_ACK_FAILED".to_string(),
595                        message: format!("basic_ack failed: {}", e),
596                    })?;
597            }
598            Some(requeue_flag) => {
599                // Abandon (requeue=true) or dead-letter (requeue=false)
600                entry
601                    .channel
602                    .basic_nack(
603                        entry.delivery_tag,
604                        BasicNackOptions {
605                            requeue: requeue_flag,
606                            ..Default::default()
607                        },
608                    )
609                    .await
610                    .map_err(|e| QueueError::ProviderError {
611                        provider: "rabbitmq".to_string(),
612                        code: "BASIC_NACK_FAILED".to_string(),
613                        message: format!("basic_nack failed: {}", e),
614                    })?;
615            }
616        }
617
618        Ok(())
619    }
620}
621
622// ============================================================================
623// QueueProvider implementation
624// ============================================================================
625
626#[async_trait]
627impl QueueProvider for RabbitMqProvider {
628    #[instrument(skip(self, message), fields(queue = %queue))]
629    async fn send_message(
630        &self,
631        queue: &QueueName,
632        message: &Message,
633    ) -> Result<MessageId, QueueError> {
634        let size = message.body.len();
635        let max_size = self.provider_type().max_message_size();
636        if size > max_size {
637            return Err(QueueError::MessageTooLarge { size, max_size });
638        }
639
640        let channel = self.get_publish_channel().await?;
641
642        // Route session messages to their dedicated sub-queue; all others go to
643        // the main queue via the default exchange.
644        let routing_key = if let Some(ref sid) = message.session_id {
645            self.declare_session_queue(&channel, queue, sid).await?
646        } else {
647            self.declare_queue(&channel, queue).await?;
648            queue.as_str().to_string()
649        };
650
651        let props = Self::build_properties(message);
652
653        channel
654            .basic_publish(
655                "".into(),
656                routing_key.as_str().into(),
657                BasicPublishOptions::default(),
658                &message.body,
659                props,
660            )
661            .await
662            .map_err(|e| QueueError::ProviderError {
663                provider: "rabbitmq".to_string(),
664                code: "PUBLISH_FAILED".to_string(),
665                message: format!("failed to publish message to '{}': {}", routing_key, e),
666            })?
667            .await
668            .map_err(|e| QueueError::ProviderError {
669                provider: "rabbitmq".to_string(),
670                code: "PUBLISH_CONFIRM_FAILED".to_string(),
671                message: format!("publish confirmation failed: {}", e),
672            })?;
673
674        let message_id = MessageId::new();
675        debug!(%message_id, %queue, "Published message to RabbitMQ");
676        Ok(message_id)
677    }
678
679    #[instrument(skip(self, messages), fields(queue = %queue, count = messages.len()))]
680    async fn send_messages(
681        &self,
682        queue: &QueueName,
683        messages: &[Message],
684    ) -> Result<Vec<MessageId>, QueueError> {
685        if messages.len() > self.max_batch_size() as usize {
686            return Err(QueueError::BatchTooLarge {
687                size: messages.len(),
688                max_size: self.max_batch_size() as usize,
689            });
690        }
691
692        // AMQP 0-9-1 has no native batch-publish command; messages are sent sequentially.
693        // This satisfies the batch API contract (multiple IDs returned in one call)
694        // but does not reduce the number of network round-trips compared with
695        // individual send_message calls.  The AMQP protocol does not expose a
696        // PublishBatch primitive, so sequential sending is the best achievable
697        // implementation.  See docs/spec/assertions.md Assertion 20 for context.
698        let mut ids = Vec::with_capacity(messages.len());
699        for message in messages {
700            ids.push(self.send_message(queue, message).await?);
701        }
702        Ok(ids)
703    }
704
705    #[instrument(skip(self), fields(queue = %queue))]
706    async fn receive_message(
707        &self,
708        queue: &QueueName,
709        timeout: Duration,
710    ) -> Result<Option<ReceivedMessage>, QueueError> {
711        let channel = self.get_receive_channel().await?;
712        self.declare_queue(&channel, queue).await?;
713
714        let start = std::time::Instant::now();
715        let timeout_std = timeout
716            .to_std()
717            .unwrap_or(std::time::Duration::from_secs(30));
718
719        loop {
720            let get = channel
721                .basic_get(queue.as_str().into(), BasicGetOptions { no_ack: false })
722                .await
723                .map_err(|e| QueueError::ProviderError {
724                    provider: "rabbitmq".to_string(),
725                    code: "BASIC_GET_FAILED".to_string(),
726                    message: format!("basic_get on '{}' failed: {}", queue.as_str(), e),
727                })?;
728
729            if let Some(delivery) = get {
730                let headers = delivery.delivery.properties.headers().clone();
731                let redelivered = delivery.delivery.redelivered;
732                let correlation_id = delivery
733                    .delivery
734                    .properties
735                    .correlation_id()
736                    .as_ref()
737                    .map(|s| s.to_string());
738                let msg = self
739                    .register_delivery(
740                        &channel,
741                        delivery.delivery.delivery_tag,
742                        &delivery.delivery.data,
743                        headers,
744                        correlation_id,
745                        redelivered,
746                    )
747                    .await;
748                return Ok(Some(msg));
749            }
750
751            if start.elapsed() >= timeout_std {
752                return Ok(None);
753            }
754
755            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
756        }
757    }
758
759    #[instrument(skip(self), fields(queue = %queue, max = max_messages))]
760    async fn receive_messages(
761        &self,
762        queue: &QueueName,
763        max_messages: u32,
764        timeout: Duration,
765    ) -> Result<Vec<ReceivedMessage>, QueueError> {
766        let channel = self.get_receive_channel().await?;
767        self.declare_queue(&channel, queue).await?;
768
769        let mut messages = Vec::new();
770        let start = std::time::Instant::now();
771        let timeout_std = timeout
772            .to_std()
773            .unwrap_or(std::time::Duration::from_secs(30));
774
775        while messages.len() < max_messages as usize {
776            if start.elapsed() >= timeout_std {
777                break;
778            }
779
780            let get = channel
781                .basic_get(queue.as_str().into(), BasicGetOptions { no_ack: false })
782                .await
783                .map_err(|e| QueueError::ProviderError {
784                    provider: "rabbitmq".to_string(),
785                    code: "BASIC_GET_FAILED".to_string(),
786                    message: format!("basic_get on '{}' failed: {}", queue.as_str(), e),
787                })?;
788
789            match get {
790                Some(delivery) => {
791                    let headers = delivery.delivery.properties.headers().clone();
792                    let redelivered = delivery.delivery.redelivered;
793                    let correlation_id = delivery
794                        .delivery
795                        .properties
796                        .correlation_id()
797                        .as_ref()
798                        .map(|s| s.to_string());
799                    let msg = self
800                        .register_delivery(
801                            &channel,
802                            delivery.delivery.delivery_tag,
803                            &delivery.delivery.data,
804                            headers,
805                            correlation_id,
806                            redelivered,
807                        )
808                        .await;
809                    messages.push(msg);
810                }
811                // Queue is empty; poll with backoff until timeout, mirroring the
812                // behaviour of `receive_message` rather than returning immediately.
813                None => {
814                    if start.elapsed() >= timeout_std {
815                        break;
816                    }
817                    tokio::time::sleep(std::time::Duration::from_millis(50)).await;
818                }
819            }
820        }
821
822        Ok(messages)
823    }
824
825    #[instrument(skip(self, receipt))]
826    async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
827        self.settle_message(receipt, None).await
828    }
829
830    #[instrument(skip(self, receipt))]
831    async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
832        self.settle_message(receipt, Some(true)).await
833    }
834
835    #[instrument(skip(self, receipt), fields(reason = %reason))]
836    async fn dead_letter_message(
837        &self,
838        receipt: &ReceiptHandle,
839        reason: &str,
840    ) -> Result<(), QueueError> {
841        debug!(reason, "Dead-lettering RabbitMQ message");
842        // The `reason` argument is logged but cannot be forwarded to the DLX because the
843        // AMQP `basic_nack` command carries no metadata payload.  Use the NATS provider
844        // if dead-letter reason propagation is required.
845        self.settle_message(receipt, Some(false)).await
846    }
847
848    #[instrument(skip(self), fields(queue = %queue))]
849    async fn create_session_client(
850        &self,
851        queue: &QueueName,
852        session_id: Option<SessionId>,
853    ) -> Result<Box<dyn SessionProvider>, QueueError> {
854        let sid = match session_id {
855            Some(id) => id,
856            None => {
857                // RabbitMQ does not enumerate sessions; an explicit ID is required.
858                return Err(QueueError::SessionNotFound {
859                    session_id: "<any>".to_string(),
860                });
861            }
862        };
863
864        let channel = self.open_channel().await?;
865        let session_queue = self.declare_session_queue(&channel, queue, &sid).await?;
866
867        // Create an exclusive consumer on the session sub-queue so only one
868        // active session client can process messages at a time.
869        let consumer = channel
870            .basic_consume(
871                session_queue.as_str().into(),
872                format!("session-{}", uuid::Uuid::new_v4()).as_str().into(),
873                BasicConsumeOptions {
874                    exclusive: true,
875                    no_ack: false,
876                    ..Default::default()
877                },
878                FieldTable::default(),
879            )
880            .await
881            .map_err(|e| QueueError::ProviderError {
882                provider: "rabbitmq".to_string(),
883                code: "CONSUME_FAILED".to_string(),
884                message: format!(
885                    "failed to start exclusive consumer on '{}': {}",
886                    session_queue, e
887                ),
888            })?;
889
890        let now = Timestamp::now();
891        let lock_expires_at =
892            Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
893
894        // Spawn a background task that forwards deliveries over a channel so
895        // we can use standard tokio async patterns without depending on
896        // futures_lite directly in application code.
897        let (tx, rx) = mpsc::unbounded_channel::<lapin::message::Delivery>();
898        tokio::spawn(async move {
899            let mut consumer = consumer;
900            while let Some(result) = consumer.next().await {
901                match result {
902                    Ok(delivery) => {
903                        if tx.send(delivery).is_err() {
904                            break;
905                        }
906                    }
907                    Err(e) => {
908                        warn!(error = %e, "RabbitMQ session consumer error");
909                        break;
910                    }
911                }
912            }
913        });
914
915        Ok(Box::new(RabbitMqSessionProvider {
916            channel,
917            deliveries: Arc::new(Mutex::new(rx)),
918            session_id: sid,
919            in_flight: self.in_flight.clone(),
920            lock_expires_at: Arc::new(std::sync::Mutex::new(lock_expires_at)),
921            config: self.config.clone(),
922        }))
923    }
924
925    fn provider_type(&self) -> ProviderType {
926        ProviderType::RabbitMq
927    }
928
929    fn supports_sessions(&self) -> SessionSupport {
930        SessionSupport::Emulated
931    }
932
933    fn supports_batching(&self) -> bool {
934        true
935    }
936
937    fn max_batch_size(&self) -> u32 {
938        100
939    }
940}
941
942// ============================================================================
943// RabbitMqSessionProvider
944// ============================================================================
945
/// Session provider for RabbitMQ backed by an exclusive per-session consumer.
///
/// An exclusive consumer on the session sub-queue ensures that only one active
/// [`RabbitMqSessionProvider`] processes messages for a given session at a time,
/// providing the ordering and exclusive-access semantics expected of sessions.
///
/// Constructed by the queue provider's `create_session_client`.
pub struct RabbitMqSessionProvider {
    /// AMQP channel that carries the exclusive session consumer. It is cloned into
    /// every in-flight entry this session registers (so acks/nacks go out on the
    /// same channel) and is closed by `close_session`.
    channel: Channel,
    /// Tokio channel receiving deliveries forwarded from the AMQP consumer.
    deliveries: Arc<Mutex<mpsc::UnboundedReceiver<lapin::message::Delivery>>>,
    /// Session this provider is bound to; its sub-queue is the one being consumed.
    session_id: SessionId,
    /// Receipt-id → in-flight delivery map, shared with the parent provider
    /// (cloned from its `in_flight` when the session client is created).
    in_flight: Arc<Mutex<HashMap<String, InFlightEntry>>>,
    /// Current session-lock expiry; shared so `renew_session_lock` can update it.
    lock_expires_at: Arc<std::sync::Mutex<Timestamp>>,
    /// Copy of the provider configuration; `session_lock_duration` is read from it
    /// when computing session and message lock expiries.
    config: RabbitMqConfig,
}
961
962#[async_trait]
963impl SessionProvider for RabbitMqSessionProvider {
964    #[instrument(skip(self), fields(session_id = %self.session_id))]
965    async fn receive_message(
966        &self,
967        timeout: Duration,
968    ) -> Result<Option<ReceivedMessage>, QueueError> {
969        self.check_lock()?;
970
971        let timeout_std = timeout
972            .to_std()
973            .unwrap_or(std::time::Duration::from_secs(30));
974
975        let mut rx = self.deliveries.lock().await;
976        match tokio::time::timeout(timeout_std, rx.recv()).await {
977            Ok(Some(delivery)) => {
978                let msg = self.register_session_delivery(delivery).await;
979                Ok(Some(msg))
980            }
981            Ok(None) => Ok(None),
982            Err(_) => Ok(None), // timeout
983        }
984    }
985
986    #[instrument(skip(self, receipt))]
987    async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
988        self.check_lock()?;
989        self.settle(receipt, None).await
990    }
991
992    #[instrument(skip(self, receipt))]
993    async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
994        self.check_lock()?;
995        self.settle(receipt, Some(true)).await
996    }
997
998    #[instrument(skip(self, receipt), fields(reason = %reason))]
999    async fn dead_letter_message(
1000        &self,
1001        receipt: &ReceiptHandle,
1002        reason: &str,
1003    ) -> Result<(), QueueError> {
1004        self.check_lock()?;
1005        debug!(reason, "Dead-lettering session message");
1006        self.settle(receipt, Some(false)).await
1007    }
1008
1009    async fn renew_session_lock(&self) -> Result<(), QueueError> {
1010        advance_session_lock(&self.lock_expires_at, self.config.session_lock_duration)?;
1011        debug!(session_id = %self.session_id, "RabbitMQ session lock renewed");
1012        Ok(())
1013    }
1014
1015    async fn close_session(&self) -> Result<(), QueueError> {
1016        if let Err(e) = self.channel.close(200, "session closed".into()).await {
1017            warn!(error = %e, "Failed to cleanly close RabbitMQ session channel");
1018        }
1019        Ok(())
1020    }
1021
1022    fn session_id(&self) -> &SessionId {
1023        &self.session_id
1024    }
1025
1026    fn session_expires_at(&self) -> Timestamp {
1027        // Recover from a poisoned lock by using the last known value.
1028        *self
1029            .lock_expires_at
1030            .lock()
1031            .unwrap_or_else(|e| e.into_inner())
1032    }
1033}
1034
1035// ============================================================================
1036// Session lock helpers — module-level so they are testable without a live broker
1037// ============================================================================
1038
1039/// Return an error when the session lock timestamp has expired.
1040///
1041/// Extracted from [`RabbitMqSessionProvider::check_lock`] so the expiry logic can
1042/// be verified in unit tests without constructing a full provider.
1043fn check_session_lock(
1044    lock_expires_at: &std::sync::Mutex<Timestamp>,
1045    session_id: &SessionId,
1046) -> Result<(), QueueError> {
1047    let expires = *lock_expires_at
1048        .lock()
1049        .map_err(|_| QueueError::ProviderError {
1050            provider: "rabbitmq".to_string(),
1051            code: "INTERNAL_ERROR".to_string(),
1052            message: "session lock mutex poisoned".to_string(),
1053        })?;
1054    if Timestamp::now() > expires {
1055        return Err(QueueError::SessionLocked {
1056            session_id: session_id.as_str().to_string(),
1057            locked_until: expires,
1058        });
1059    }
1060    Ok(())
1061}
1062
1063/// Advance the session lock by `duration` from now and return the new expiry.
1064///
1065/// Extracted from [`RabbitMqSessionProvider::renew_session_lock`] for the same reason.
1066fn advance_session_lock(
1067    lock_expires_at: &std::sync::Mutex<Timestamp>,
1068    duration: Duration,
1069) -> Result<Timestamp, QueueError> {
1070    let new_expiry = Timestamp::from_datetime(Timestamp::now().as_datetime() + duration);
1071    *lock_expires_at
1072        .lock()
1073        .map_err(|_| QueueError::ProviderError {
1074            provider: "rabbitmq".to_string(),
1075            code: "INTERNAL_ERROR".to_string(),
1076            message: "session lock mutex poisoned".to_string(),
1077        })? = new_expiry;
1078    Ok(new_expiry)
1079}
1080
1081impl RabbitMqSessionProvider {
1082    /// Return an error if the session lock has expired.
1083    fn check_lock(&self) -> Result<(), QueueError> {
1084        check_session_lock(&self.lock_expires_at, &self.session_id)
1085    }
1086
1087    /// Register a delivery in the in-flight map and build a [`ReceivedMessage`].
1088    async fn register_session_delivery(
1089        &self,
1090        delivery: lapin::message::Delivery,
1091    ) -> ReceivedMessage {
1092        let delivery_tag = delivery.delivery_tag;
1093        let redelivered = delivery.redelivered;
1094        let headers = delivery.properties.headers().clone();
1095        let attributes = RabbitMqProvider::extract_attributes(&headers);
1096        let delivery_count = RabbitMqProvider::extract_delivery_count(&headers, redelivered);
1097        let correlation_id = delivery
1098            .properties
1099            .correlation_id()
1100            .as_ref()
1101            .map(|s| s.to_string());
1102
1103        let now = Timestamp::now();
1104        let lock_expires_at =
1105            Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
1106
1107        let receipt_id = uuid::Uuid::new_v4().to_string();
1108        let message_id = MessageId::new();
1109        let body = Bytes::copy_from_slice(&delivery.data);
1110
1111        self.in_flight.lock().await.insert(
1112            receipt_id.clone(),
1113            InFlightEntry {
1114                channel: self.channel.clone(),
1115                delivery_tag,
1116                lock_expires_at,
1117            },
1118        );
1119
1120        ReceivedMessage {
1121            message_id,
1122            body,
1123            attributes,
1124            session_id: Some(self.session_id.clone()),
1125            correlation_id,
1126            receipt_handle: ReceiptHandle::new(receipt_id, lock_expires_at, ProviderType::RabbitMq),
1127            delivery_count,
1128            first_delivered_at: now,
1129            delivered_at: now,
1130        }
1131    }
1132
1133    /// Ack, nack-requeue, or nack-discard a message identified by its receipt handle.
1134    async fn settle(
1135        &self,
1136        receipt: &ReceiptHandle,
1137        requeue: Option<bool>,
1138    ) -> Result<(), QueueError> {
1139        let mut in_flight = self.in_flight.lock().await;
1140
1141        // Check existence and lock expiry before removal (mirrors settle_message).
1142        match in_flight.get(receipt.handle()) {
1143            None => {
1144                return Err(QueueError::MessageNotFound {
1145                    receipt: receipt.handle().to_string(),
1146                });
1147            }
1148            Some(entry) if Timestamp::now() > entry.lock_expires_at => {
1149                in_flight.remove(receipt.handle());
1150                return Err(QueueError::MessageNotFound {
1151                    receipt: format!("{}(expired)", receipt.handle()),
1152                });
1153            }
1154            Some(_) => {}
1155        }
1156
1157        let entry = in_flight
1158            .remove(receipt.handle())
1159            .expect("entry present after pre-check");
1160
1161        match requeue {
1162            None => {
1163                entry
1164                    .channel
1165                    .basic_ack(entry.delivery_tag, BasicAckOptions::default())
1166                    .await
1167                    .map_err(|e| QueueError::ProviderError {
1168                        provider: "rabbitmq".to_string(),
1169                        code: "BASIC_ACK_FAILED".to_string(),
1170                        message: format!("basic_ack failed: {}", e),
1171                    })?;
1172            }
1173            Some(requeue_flag) => {
1174                entry
1175                    .channel
1176                    .basic_nack(
1177                        entry.delivery_tag,
1178                        BasicNackOptions {
1179                            requeue: requeue_flag,
1180                            ..Default::default()
1181                        },
1182                    )
1183                    .await
1184                    .map_err(|e| QueueError::ProviderError {
1185                        provider: "rabbitmq".to_string(),
1186                        code: "BASIC_NACK_FAILED".to_string(),
1187                        message: format!("basic_nack failed: {}", e),
1188                    })?;
1189            }
1190        }
1191
1192        Ok(())
1193    }
1194}