Skip to main content

queue_runtime/providers/
nats.rs

1//! NATS provider implementation using JetStream.
2//!
3//! This module provides production-ready NATS integration via the `async-nats` client
4//! with JetStream.  It implements the [`QueueProvider`] and [`SessionProvider`] traits,
5//! enabling NATS to be used as a drop-in backend in the queue-runtime abstraction layer.
6//!
7//! ## Key Features
8//!
9//! - **JetStream**: Persistent, at-least-once message delivery with ack/nak semantics
10//! - **Pull consumers**: Explicit message fetch with configurable ack-wait (visibility
11//!   timeout analog)
12//! - **Dead letter support**: Failed messages forwarded to a configurable DLQ stream
13//! - **Session emulation**: Sessions emulated via per-session filter subjects within a
14//!   shared stream
15//! - **Batch operations**: Up to 100 messages per batch
16//!
17//! ## Stream and Subject Naming
18//!
19//! The provider creates one JetStream stream per queue.  Naming conventions:
20//!
//! - Stream name: `{stream_prefix}-{queue_name}` (hyphens and spaces inside both parts
//!   are replaced with underscores; a hyphen joins them because stream names may not contain dots)
23//! - Subject: `{stream_prefix}.{queue_name}`
24//! - Session subject: `{stream_prefix}.{queue_name}.session.{session_id}`
25//!
26//! ## Dead Letter Support
27//!
28//! When `enable_dead_letter` is `true` and `dead_letter_subject_prefix` is set,
29//! messages dead-lettered via [`dead_letter_message`] are published to
30//! `{prefix}.{queue_name}` using a separate JetStream stream for DLQ messages.
31//!
32//! ## Session Support
33//!
34//! Sessions are emulated via JetStream subject filtering.  Each session gets its own
35//! subject (`{prefix}.{queue}.session.{session_id}`), and a [`NatsSessionProvider`]
36//! creates a per-session pull consumer filtered to that subject.
37//!
38//! ## Connection
39//!
40//! The provider uses the `async-nats` client which reconnects automatically on
41//! connection loss.  Optional NATS credentials can be loaded from a `.creds` file
42//! via the `credentials_path` configuration field.
43//!
//! ## Example
45//!
46//! ```rust,no_run
47//! use queue_runtime::providers::NatsProvider;
48//! use queue_runtime::NatsConfig;
49//!
50//! # async fn test_example() {
51//! let config = NatsConfig {
52//!     url: "nats://localhost:4222".to_string(),
53//!     ..NatsConfig::default()
54//! };
55//!
56//! let provider = NatsProvider::new(config).await.unwrap();
57//! # }
58//! ```
59
60use crate::client::{QueueProvider, SessionProvider};
61use crate::error::QueueError;
62use crate::message::{
63    Message, MessageId, QueueName, ReceiptHandle, ReceivedMessage, SessionId, Timestamp,
64};
65use crate::provider::{ProviderType, SessionSupport};
66use async_nats::jetstream::{
67    self, consumer::pull::Config as ConsumerConfig, stream::Config as StreamConfig, AckKind,
68    Context as JetStreamContext,
69};
70use async_trait::async_trait;
71use bytes::Bytes;
72use chrono::Duration;
73use futures::StreamExt;
74use serde::{Deserialize, Serialize};
75use std::collections::HashMap;
76use std::sync::Arc;
77use tokio::sync::Mutex;
78use tracing::{debug, instrument, warn};
79
80#[cfg(test)]
81#[path = "nats_tests.rs"]
82mod tests;
83
84// ============================================================================
85// Configuration
86// ============================================================================
87
88/// NATS provider configuration using JetStream
89///
90/// # Examples
91///
92/// ```rust
93/// use queue_runtime::NatsConfig;
94/// use chrono::Duration;
95///
96/// let config = NatsConfig {
97///     url: "nats://localhost:4222".to_string(),
98///     stream_prefix: "queue-runtime".to_string(),
99///     max_deliver: Some(3),
100///     ack_wait: Duration::seconds(30),
101///     session_lock_duration: Duration::minutes(5),
102///     enable_dead_letter: true,
103///     dead_letter_subject_prefix: Some("dlq".to_string()),
104///     credentials_path: None,
105/// };
106/// ```
107#[derive(Debug, Clone, Serialize, Deserialize)]
108pub struct NatsConfig {
109    /// NATS server URL (e.g. `nats://localhost:4222` or `nats://user:pass@host:port`)
110    pub url: String,
111    /// Prefix for JetStream stream names (stream name = `{prefix}-{queue_name}`)
112    pub stream_prefix: String,
113    /// Maximum number of delivery attempts before giving up (None = unlimited)
114    pub max_deliver: Option<i64>,
115    /// Duration to wait for ack before re-delivering (visibility timeout analog)
116    pub ack_wait: Duration,
117    /// Duration to hold a session lock before expiry
118    pub session_lock_duration: Duration,
119    /// Whether to enable dead letter queue routing via a separate stream
120    pub enable_dead_letter: bool,
121    /// Subject prefix for dead letter messages (`{prefix}.{queue}`)
122    pub dead_letter_subject_prefix: Option<String>,
123    /// Path to NATS credentials file (`.creds` format)
124    pub credentials_path: Option<String>,
125}
126
127impl Default for NatsConfig {
128    fn default() -> Self {
129        Self {
130            url: "nats://localhost:4222".to_string(),
131            stream_prefix: "queue-runtime".to_string(),
132            max_deliver: Some(3),
133            ack_wait: Duration::seconds(30),
134            session_lock_duration: Duration::minutes(5),
135            enable_dead_letter: true,
136            dead_letter_subject_prefix: Some("dlq".to_string()),
137            credentials_path: None,
138        }
139    }
140}
141
142// ============================================================================
143// Error types
144// ============================================================================
145
146/// NATS-specific error type.
147#[derive(Debug)]
148pub struct NatsError {
149    message: String,
150}
151
152impl NatsError {
153    fn new(message: impl Into<String>) -> Self {
154        Self {
155            message: message.into(),
156        }
157    }
158
159    /// Convert to a [`QueueError`].
160    pub fn to_queue_error(&self) -> QueueError {
161        QueueError::ProviderError {
162            provider: "nats".to_string(),
163            code: "NATS_ERROR".to_string(),
164            message: self.message.clone(),
165        }
166    }
167}
168
169impl std::fmt::Display for NatsError {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        write!(f, "NATS error: {}", self.message)
172    }
173}
174
175impl std::error::Error for NatsError {}
176
177// ============================================================================
178// Internal helpers
179// ============================================================================
180
181/// An in-flight JetStream message pending acknowledgement.
182struct InFlightEntry {
183    /// The JetStream message (contains ack/nak methods)
184    js_message: async_nats::jetstream::Message,
185    /// Lock expiry (maps to the JetStream ack-wait on the consumer)
186    lock_expires_at: Timestamp,
187    /// Dead letter subject for this message's queue
188    dead_letter_subject: Option<String>,
189}
190
191// ============================================================================
192// Helper functions
193// ============================================================================
194
195/// Redact any userinfo (username and password) from a URL, keeping host and path.
196///
197/// Used to prevent credential leakage in log fields and error messages when
198/// connection URLs contain embedded credentials
199/// (e.g. `nats://user:pass@host:4222` → `nats://***:***@host:4222`).
200fn redact_url(url: &str) -> String {
201    match url::Url::parse(url) {
202        Ok(mut parsed) => {
203            let has_credentials = !parsed.username().is_empty() || parsed.password().is_some();
204            if has_credentials {
205                let _ = parsed.set_username("***");
206                let _ = parsed.set_password(Some("***"));
207            }
208            parsed.to_string()
209        }
210        Err(_) => "<invalid-url>".to_string(),
211    }
212}
213
214/// Sanitise a queue name for use in NATS subject/stream identifiers.
215///
216/// NATS subjects use `.` as a separator and do not allow spaces.
217fn nats_safe(s: &str) -> String {
218    s.replace(['-', ' '], "_")
219}
220
221/// Build the NATS subject for a queue.
222fn queue_subject(config: &NatsConfig, queue: &QueueName) -> String {
223    format!(
224        "{}.{}",
225        nats_safe(&config.stream_prefix),
226        nats_safe(queue.as_str())
227    )
228}
229
230/// Sanitise a session ID for use in NATS subject and consumer name identifiers.
231///
232/// NATS subjects use `.` as a wildcard-separator and consumer/stream names must
233/// contain only alphanumeric characters, underscores, or hyphens.  This helper
234/// replaces every character that is not an ASCII alphanumeric or `_` with `_`,
235/// covering all NATS-invalid characters (`.`, `/`, ` `, `*`, `>`, etc.).
236fn nats_safe_session_id(id: &str) -> String {
237    id.chars()
238        .map(|c| {
239            if c.is_ascii_alphanumeric() || c == '_' {
240                c
241            } else {
242                '_'
243            }
244        })
245        .collect()
246}
247
248/// Build the NATS subject for a session within a queue.
249fn session_subject(config: &NatsConfig, queue: &QueueName, session_id: &SessionId) -> String {
250    let safe_session = nats_safe_session_id(session_id.as_str());
251    format!(
252        "{}.{}.session.{}",
253        nats_safe(&config.stream_prefix),
254        nats_safe(queue.as_str()),
255        safe_session
256    )
257}
258
259/// Build the JetStream stream name for a queue.
260fn stream_name(config: &NatsConfig, queue: &QueueName) -> String {
261    // JetStream stream names may not contain dots.
262    format!(
263        "{}-{}",
264        nats_safe(&config.stream_prefix),
265        nats_safe(queue.as_str())
266    )
267}
268
269/// Build a stable durable consumer name for a queue.
270///
271/// Using a named durable consumer (rather than an ephemeral one) means that
272/// `Info::delivered` increments across successive `receive_message` calls,
273/// even when each call creates a fresh handle to the server-side consumer.
274/// With ephemeral consumers the server issues a new consumer-sequence on each
275/// call, and `num_delivered` is always 1 from the new consumer's perspective.
276fn consumer_name(config: &NatsConfig, queue: &QueueName) -> String {
277    format!(
278        "{}-{}-consumer",
279        nats_safe(&config.stream_prefix),
280        nats_safe(queue.as_str())
281    )
282}
283
284/// Build a stable durable consumer name for a session consumer.
285///
286/// Session consumers must have a name that is distinct from the queue-level
287/// consumer returned by [`consumer_name`].  NATS enforces config consistency
288/// on named durable consumers: if the same name is submitted with a different
289/// `filter_subject` the server returns an error (or silently reuses the old
290/// consumer, routing messages through the wrong filter).  Incorporating the
291/// session ID into the name keeps each session consumer independent.
292fn session_consumer_name(config: &NatsConfig, queue: &QueueName, session_id: &SessionId) -> String {
293    let safe_sid = nats_safe_session_id(session_id.as_str());
294    format!(
295        "{}-{}-session-{}-consumer",
296        nats_safe(&config.stream_prefix),
297        nats_safe(queue.as_str()),
298        safe_sid
299    )
300}
301
302/// Build the dead-letter subject for a queue if DLQ is enabled.
303fn dead_letter_subject(config: &NatsConfig, queue: &QueueName) -> Option<String> {
304    if !config.enable_dead_letter {
305        return None;
306    }
307    config
308        .dead_letter_subject_prefix
309        .as_ref()
310        .map(|prefix| format!("{}.{}", nats_safe(prefix), nats_safe(queue.as_str())))
311}
312
313// ============================================================================
314// NatsProvider
315// ============================================================================
316
317/// NATS queue provider using JetStream for persistent, at-least-once delivery.
318///
319/// Each queue maps to a JetStream stream created on demand.  Messages are published
320/// to per-queue subjects and consumed via pull consumers.
321pub struct NatsProvider {
322    client: async_nats::Client,
323    jetstream: JetStreamContext,
324    config: NatsConfig,
325    /// In-flight messages indexed by their receipt handle UUID.
326    in_flight: Arc<Mutex<HashMap<String, InFlightEntry>>>,
327}
328
329impl NatsProvider {
330    /// Create a new [`NatsProvider`] and connect to the NATS server.
331    ///
332    /// # Errors
333    ///
334    /// Returns [`NatsError`] if the connection cannot be established or the
335    /// credentials file cannot be read.
336    ///
337    /// # Examples
338    ///
339    /// ```rust,no_run
340    /// use queue_runtime::providers::NatsProvider;
341    /// use queue_runtime::NatsConfig;
342    ///
343    /// # async fn example() {
344    /// let config = NatsConfig::default();
345    /// let provider = NatsProvider::new(config).await.unwrap();
346    /// # }
347    /// ```
348    pub async fn new(config: NatsConfig) -> Result<Self, NatsError> {
349        let connect_options = if let Some(ref creds_path) = config.credentials_path {
350            async_nats::ConnectOptions::with_credentials_file(creds_path.as_str())
351                .await
352                .map_err(|e| NatsError::new(format!("failed to load NATS credentials: {}", e)))?
353        } else {
354            async_nats::ConnectOptions::new()
355        };
356
357        let client = connect_options.connect(&config.url).await.map_err(|e| {
358            NatsError::new(format!(
359                "failed to connect to NATS at '{}': {}",
360                redact_url(&config.url),
361                e
362            ))
363        })?;
364
365        let jetstream = jetstream::new(client.clone());
366
367        debug!(url = %redact_url(&config.url), "Connected to NATS");
368
369        Ok(Self {
370            client,
371            jetstream,
372            config,
373            in_flight: Arc::new(Mutex::new(HashMap::new())),
374        })
375    }
376
377    /// Ensure a JetStream stream exists for the given queue, creating it if needed.
378    ///
379    /// Streams are created with the `WorkQueue` retention policy so each message
380    /// is delivered to exactly one consumer.
381    async fn ensure_stream(&self, queue: &QueueName) -> Result<(), QueueError> {
382        let name = stream_name(&self.config, queue);
383        let subject = queue_subject(&self.config, queue);
384
385        // Subject wildcard: capture both the main subject and all session subjects.
386        let subjects = vec![subject.clone(), format!("{}.session.>", subject)];
387
388        let stream_config = StreamConfig {
389            name: name.clone(),
390            subjects,
391            retention: async_nats::jetstream::stream::RetentionPolicy::WorkQueue,
392            storage: async_nats::jetstream::stream::StorageType::File,
393            ..Default::default()
394        };
395
396        self.jetstream
397            .get_or_create_stream(stream_config)
398            .await
399            .map_err(|e| QueueError::ProviderError {
400                provider: "nats".to_string(),
401                code: "STREAM_CREATE_FAILED".to_string(),
402                message: format!("failed to ensure JetStream stream '{}': {}", name, e),
403            })?;
404
405        // Also ensure the DLQ stream exists if dead letter is enabled.
406        self.ensure_dlq_stream(queue).await?;
407
408        Ok(())
409    }
410
411    /// Ensure a JetStream stream exists for the dead letter queue.
412    async fn ensure_dlq_stream(&self, queue: &QueueName) -> Result<(), QueueError> {
413        let dlq_subject = match dead_letter_subject(&self.config, queue) {
414            Some(s) => s,
415            None => return Ok(()),
416        };
417
418        let dlq_stream_name = format!(
419            "dlq-{}-{}",
420            nats_safe(&self.config.stream_prefix),
421            nats_safe(queue.as_str())
422        );
423
424        let stream_config = StreamConfig {
425            name: dlq_stream_name.clone(),
426            subjects: vec![dlq_subject],
427            storage: async_nats::jetstream::stream::StorageType::File,
428            ..Default::default()
429        };
430
431        self.jetstream
432            .get_or_create_stream(stream_config)
433            .await
434            .map_err(|e| QueueError::ProviderError {
435                provider: "nats".to_string(),
436                code: "DLQ_STREAM_CREATE_FAILED".to_string(),
437                message: format!("failed to ensure DLQ stream '{}': {}", dlq_stream_name, e),
438            })?;
439
440        Ok(())
441    }
442
443    /// Create or retrieve a named durable pull consumer for the given subject filter.
444    ///
445    /// Using named durable consumers means the server tracks delivery count across
446    /// successive `receive_message` calls (each call reuses the same server-side
447    /// consumer, so `Info::delivered` increments correctly on redelivery).
448    ///
449    /// The `name` parameter must be unique per `filter_subject`.  Queue-level
450    /// consumers and per-session consumers must use distinct names (see
451    /// [`consumer_name`] and [`session_consumer_name`]) to avoid the NATS server
452    /// rejecting a name reuse with a different filter.
453    async fn create_consumer(
454        &self,
455        queue: &QueueName,
456        name: &str,
457        filter_subject: &str,
458    ) -> Result<async_nats::jetstream::consumer::Consumer<ConsumerConfig>, QueueError> {
459        let stream_name = stream_name(&self.config, queue);
460        let ack_wait_std = self
461            .config
462            .ack_wait
463            .to_std()
464            .unwrap_or(std::time::Duration::from_secs(30));
465
466        let consumer_config = ConsumerConfig {
467            name: Some(name.to_string()),
468            durable_name: Some(name.to_string()),
469            filter_subject: filter_subject.to_string(),
470            ack_policy: async_nats::jetstream::consumer::AckPolicy::Explicit,
471            ack_wait: ack_wait_std,
472            max_deliver: self.config.max_deliver.unwrap_or(-1),
473            // Expire session consumers after twice the session lock duration so
474            // they are cleaned up automatically by the server once a session is
475            // no longer in use.  Without this, named durable consumers accumulate
476            // indefinitely on the server (one per unique session ID).  Queue-level
477            // consumers also benefit: if the process dies, the server reclaims the
478            // consumer slot after the inactive window instead of keeping it forever.
479            inactive_threshold: self
480                .config
481                .session_lock_duration
482                .to_std()
483                .unwrap_or(std::time::Duration::from_secs(300))
484                .saturating_mul(2),
485            ..Default::default()
486        };
487
488        let stream = self.jetstream.get_stream(&stream_name).await.map_err(|e| {
489            QueueError::ProviderError {
490                provider: "nats".to_string(),
491                code: "STREAM_GET_FAILED".to_string(),
492                message: format!("failed to get stream '{}': {}", stream_name, e),
493            }
494        })?;
495
496        let consumer = stream
497            .get_or_create_consumer(name, consumer_config)
498            .await
499            .map_err(|e| QueueError::ProviderError {
500                provider: "nats".to_string(),
501                code: "CONSUMER_CREATE_FAILED".to_string(),
502                message: format!(
503                    "failed to get or create pull consumer on '{}': {}",
504                    stream_name, e
505                ),
506            })?;
507
508        Ok(consumer)
509    }
510
511    /// Encode message metadata into NATS message headers.
512    fn build_headers(message: &Message) -> async_nats::header::HeaderMap {
513        let mut headers = async_nats::header::HeaderMap::new();
514
515        if let Some(ref sid) = message.session_id {
516            headers.insert("x-session-id", sid.as_str());
517        }
518        if let Some(ref corr_id) = message.correlation_id {
519            headers.insert("x-correlation-id", corr_id.as_str());
520        }
521        for (k, v) in &message.attributes {
522            // Prefix user attributes to distinguish from provider headers.
523            headers.insert(format!("x-attr-{}", k).as_str(), v.as_str());
524        }
525
526        headers
527    }
528
529    /// Extract message attributes from NATS headers.
530    fn extract_attributes(
531        headers: &Option<async_nats::header::HeaderMap>,
532    ) -> HashMap<String, String> {
533        let mut attrs = HashMap::new();
534        if let Some(hm) = headers {
535            for (name, values) in hm.iter() {
536                // HeaderName implements AsRef<str> for &str access
537                let key: &str = name.as_ref();
538                if let Some(attr_key) = key.strip_prefix("x-attr-") {
539                    if let Some(val) = values.first() {
540                        attrs.insert(attr_key.to_string(), val.as_str().to_string());
541                    }
542                }
543            }
544        }
545        attrs
546    }
547
548    /// Extract the session ID from NATS headers.
549    fn extract_session_id(headers: &Option<async_nats::header::HeaderMap>) -> Option<SessionId> {
550        if let Some(hm) = headers {
551            if let Some(val) = hm.get("x-session-id") {
552                let id = val.as_str().to_string();
553                return SessionId::new(id).ok();
554            }
555        }
556        None
557    }
558
559    /// Extract the correlation ID from NATS headers.
560    fn extract_correlation_id(headers: &Option<async_nats::header::HeaderMap>) -> Option<String> {
561        if let Some(hm) = headers {
562            if let Some(val) = hm.get("x-correlation-id") {
563                return Some(val.as_str().to_string());
564            }
565        }
566        None
567    }
568
569    /// Register a JetStream message in the in-flight map and return a [`ReceivedMessage`].
570    async fn register_js_message(
571        &self,
572        js_message: async_nats::jetstream::Message,
573        queue: &QueueName,
574    ) -> ReceivedMessage {
575        let headers = js_message.message.headers.clone();
576        let session_id = Self::extract_session_id(&headers);
577        let attributes = Self::extract_attributes(&headers);
578        let correlation_id = Self::extract_correlation_id(&headers);
579        let delivery_count = js_message.info().map(|i| i.delivered as u32).unwrap_or(1);
580        let body = Bytes::copy_from_slice(&js_message.message.payload);
581
582        let now = Timestamp::now();
583        let lock_expires_at =
584            Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
585
586        let receipt_id = uuid::Uuid::new_v4().to_string();
587        let message_id = MessageId::new();
588
589        let dlq_subject = dead_letter_subject(&self.config, queue);
590
591        self.in_flight.lock().await.insert(
592            receipt_id.clone(),
593            InFlightEntry {
594                js_message,
595                lock_expires_at,
596                dead_letter_subject: dlq_subject,
597            },
598        );
599
600        ReceivedMessage {
601            message_id,
602            body,
603            attributes,
604            session_id,
605            correlation_id,
606            receipt_handle: ReceiptHandle::new(receipt_id, lock_expires_at, ProviderType::Nats),
607            delivery_count,
608            first_delivered_at: now,
609            delivered_at: now,
610        }
611    }
612}
613
614// ============================================================================
615// QueueProvider implementation
616// ============================================================================
617
618#[async_trait]
619impl QueueProvider for NatsProvider {
620    #[instrument(skip(self, message), fields(queue = %queue))]
621    async fn send_message(
622        &self,
623        queue: &QueueName,
624        message: &Message,
625    ) -> Result<MessageId, QueueError> {
626        let size = message.body.len();
627        let max_size = self.provider_type().max_message_size();
628        if size > max_size {
629            return Err(QueueError::MessageTooLarge { size, max_size });
630        }
631
632        self.ensure_stream(queue).await?;
633
634        // Route to session subject if session_id is set, otherwise main subject.
635        let subject = if let Some(ref sid) = message.session_id {
636            session_subject(&self.config, queue, sid)
637        } else {
638            queue_subject(&self.config, queue)
639        };
640
641        let headers = Self::build_headers(message);
642        let payload = Bytes::copy_from_slice(&message.body);
643
644        self.jetstream
645            .publish_with_headers(subject.clone(), headers, payload)
646            .await
647            .map_err(|e| QueueError::ProviderError {
648                provider: "nats".to_string(),
649                code: "PUBLISH_FAILED".to_string(),
650                message: format!("failed to publish to subject '{}': {}", subject, e),
651            })?
652            .await
653            .map_err(|e| QueueError::ProviderError {
654                provider: "nats".to_string(),
655                code: "PUBLISH_ACK_FAILED".to_string(),
656                message: format!("JetStream publish ack failed: {}", e),
657            })?;
658
659        let message_id = MessageId::new();
660        debug!(%message_id, %queue, "Published message to NATS JetStream");
661        Ok(message_id)
662    }
663
664    #[instrument(skip(self, messages), fields(queue = %queue, count = messages.len()))]
665    async fn send_messages(
666        &self,
667        queue: &QueueName,
668        messages: &[Message],
669    ) -> Result<Vec<MessageId>, QueueError> {
670        if messages.len() > self.max_batch_size() as usize {
671            return Err(QueueError::BatchTooLarge {
672                size: messages.len(),
673                max_size: self.max_batch_size() as usize,
674            });
675        }
676
677        // JetStream has no multi-message publish; messages are sent sequentially.
678        // This satisfies the batch API contract (multiple IDs returned in one call)
679        // but does not reduce the number of network round-trips compared with
680        // individual send_message calls.  The NATS protocol does not expose a
681        // PublishBatch primitive, so sequential sending is the best achievable
682        // implementation.  See docs/spec/assertions.md Assertion 20 for context.
683        let mut ids = Vec::with_capacity(messages.len());
684        for message in messages {
685            ids.push(self.send_message(queue, message).await?);
686        }
687        Ok(ids)
688    }
689
690    #[instrument(skip(self), fields(queue = %queue))]
691    async fn receive_message(
692        &self,
693        queue: &QueueName,
694        timeout: Duration,
695    ) -> Result<Option<ReceivedMessage>, QueueError> {
696        self.ensure_stream(queue).await?;
697
698        let subject = queue_subject(&self.config, queue);
699        let name = consumer_name(&self.config, queue);
700        let consumer = self.create_consumer(queue, &name, &subject).await?;
701
702        let timeout_std = timeout
703            .to_std()
704            .unwrap_or(std::time::Duration::from_secs(30));
705
706        let mut messages = consumer
707            .fetch()
708            .max_messages(1)
709            .expires(timeout_std)
710            .messages()
711            .await
712            .map_err(|e| QueueError::ProviderError {
713                provider: "nats".to_string(),
714                code: "FETCH_FAILED".to_string(),
715                message: format!("failed to fetch from JetStream: {}", e),
716            })?;
717
718        match tokio::time::timeout(timeout_std, messages.next()).await {
719            Ok(Some(Ok(js_msg))) => {
720                let msg = self.register_js_message(js_msg, queue).await;
721                Ok(Some(msg))
722            }
723            Ok(Some(Err(e))) => Err(QueueError::ProviderError {
724                provider: "nats".to_string(),
725                code: "MESSAGE_ERROR".to_string(),
726                message: format!("error reading JetStream message: {}", e),
727            }),
728            Ok(None) | Err(_) => Ok(None),
729        }
730    }
731
732    #[instrument(skip(self), fields(queue = %queue, max = max_messages))]
733    async fn receive_messages(
734        &self,
735        queue: &QueueName,
736        max_messages: u32,
737        timeout: Duration,
738    ) -> Result<Vec<ReceivedMessage>, QueueError> {
739        self.ensure_stream(queue).await?;
740
741        let subject = queue_subject(&self.config, queue);
742        let name = consumer_name(&self.config, queue);
743        let consumer = self.create_consumer(queue, &name, &subject).await?;
744
745        let timeout_std = timeout
746            .to_std()
747            .unwrap_or(std::time::Duration::from_secs(30));
748
749        let mut js_messages = consumer
750            .fetch()
751            .max_messages(max_messages as usize)
752            .expires(timeout_std)
753            .messages()
754            .await
755            .map_err(|e| QueueError::ProviderError {
756                provider: "nats".to_string(),
757                code: "FETCH_FAILED".to_string(),
758                message: format!("failed to fetch from JetStream: {}", e),
759            })?;
760
761        let mut result = Vec::new();
762        let deadline = tokio::time::Instant::now() + timeout_std;
763
764        loop {
765            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
766            if remaining.is_zero() || result.len() >= max_messages as usize {
767                break;
768            }
769
770            match tokio::time::timeout(remaining, js_messages.next()).await {
771                Ok(Some(Ok(js_msg))) => {
772                    let msg = self.register_js_message(js_msg, queue).await;
773                    result.push(msg);
774                }
775                Ok(Some(Err(e))) => {
776                    return Err(QueueError::ProviderError {
777                        provider: "nats".to_string(),
778                        code: "MESSAGE_ERROR".to_string(),
779                        message: format!("error reading JetStream message: {}", e),
780                    });
781                }
782                Ok(None) | Err(_) => break,
783            }
784        }
785
786        Ok(result)
787    }
788
789    #[instrument(skip(self, receipt))]
790    async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
791        let mut in_flight = self.in_flight.lock().await;
792
793        // Check existence and expiry before removal so callers receive a
794        // meaningful error and the JetStream message is not abandoned silently.
795        match in_flight.get(receipt.handle()) {
796            None => {
797                return Err(QueueError::MessageNotFound {
798                    receipt: receipt.handle().to_string(),
799                });
800            }
801            Some(entry) if Timestamp::now() > entry.lock_expires_at => {
802                in_flight.remove(receipt.handle());
803                return Err(QueueError::MessageNotFound {
804                    receipt: format!("{}(expired)", receipt.handle()),
805                });
806            }
807            Some(_) => {}
808        }
809
810        let entry = in_flight
811            .remove(receipt.handle())
812            .expect("entry present after pre-check");
813
814        entry
815            .js_message
816            .ack()
817            .await
818            .map_err(|e| QueueError::ProviderError {
819                provider: "nats".to_string(),
820                code: "ACK_FAILED".to_string(),
821                message: format!("JetStream ack failed: {}", e),
822            })?;
823
824        Ok(())
825    }
826
827    #[instrument(skip(self, receipt))]
828    async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
829        let mut in_flight = self.in_flight.lock().await;
830
831        match in_flight.get(receipt.handle()) {
832            None => {
833                return Err(QueueError::MessageNotFound {
834                    receipt: receipt.handle().to_string(),
835                });
836            }
837            Some(entry) if Timestamp::now() > entry.lock_expires_at => {
838                in_flight.remove(receipt.handle());
839                return Err(QueueError::MessageNotFound {
840                    receipt: format!("{}(expired)", receipt.handle()),
841                });
842            }
843            Some(_) => {}
844        }
845
846        let entry = in_flight
847            .remove(receipt.handle())
848            .expect("entry present after pre-check");
849
850        entry
851            .js_message
852            .ack_with(AckKind::Nak(None))
853            .await
854            .map_err(|e| QueueError::ProviderError {
855                provider: "nats".to_string(),
856                code: "NAK_FAILED".to_string(),
857                message: format!("JetStream nak failed: {}", e),
858            })?;
859
860        Ok(())
861    }
862
863    #[instrument(skip(self, receipt), fields(reason = %reason))]
864    async fn dead_letter_message(
865        &self,
866        receipt: &ReceiptHandle,
867        reason: &str,
868    ) -> Result<(), QueueError> {
869        let mut in_flight = self.in_flight.lock().await;
870
871        match in_flight.get(receipt.handle()) {
872            None => {
873                return Err(QueueError::MessageNotFound {
874                    receipt: receipt.handle().to_string(),
875                });
876            }
877            Some(entry) if Timestamp::now() > entry.lock_expires_at => {
878                in_flight.remove(receipt.handle());
879                return Err(QueueError::MessageNotFound {
880                    receipt: format!("{}(expired)", receipt.handle()),
881                });
882            }
883            Some(_) => {}
884        }
885
886        let entry = in_flight
887            .remove(receipt.handle())
888            .expect("entry present after pre-check");
889
890        // Terminate delivery so JetStream stops redelivering this message.
891        // This must happen before the DLQ publish so that even if DLQ write
892        // fails the message is not stuck in re-delivery.
893        entry
894            .js_message
895            .ack_with(async_nats::jetstream::AckKind::Term)
896            .await
897            .map_err(|e| QueueError::ProviderError {
898                provider: "nats".to_string(),
899                code: "TERM_FAILED".to_string(),
900                message: format!("JetStream term ack failed: {}", e),
901            })?;
902
903        // Publish to DLQ stream if configured.  The Term above is the
904        // authoritative action; DLQ publishing is best-effort.  If the
905        // publish fails the message has still been terminated from JetStream
906        // (it will not be redelivered), so we log the error and return Ok
907        // rather than signalling failure to the caller.
908        if let Some(ref dlq_subject) = entry.dead_letter_subject {
909            let mut headers = async_nats::header::HeaderMap::new();
910            headers.insert("x-dead-letter-reason", reason);
911            let payload = entry.js_message.message.payload.clone();
912            if let Some(msg_headers) = &entry.js_message.message.headers {
913                for (name, values) in msg_headers.iter() {
914                    // HeaderName implements AsRef<str>
915                    let key: &str = name.as_ref();
916                    for val in values.iter() {
917                        headers.insert(key, val.as_str());
918                    }
919                }
920            }
921
922            if let Err(e) = self
923                .client
924                .publish_with_headers(dlq_subject.clone(), headers, payload)
925                .await
926            {
927                // Log the failure but do not surface it — the message has
928                // already been terminated from JetStream.
929                warn!(
930                    reason,
931                    dlq_subject,
932                    error = %e,
933                    "Failed to publish dead-lettered message to DLQ (message already terminated)"
934                );
935            } else {
936                debug!(
937                    reason,
938                    dlq_subject, "Message dead-lettered and published to DLQ"
939                );
940            }
941        } else {
942            debug!(reason, "Message terminated (no DLQ configured)");
943        }
944
945        Ok(())
946    }
947
948    #[instrument(skip(self), fields(queue = %queue))]
949    async fn create_session_client(
950        &self,
951        queue: &QueueName,
952        session_id: Option<SessionId>,
953    ) -> Result<Box<dyn SessionProvider>, QueueError> {
954        let sid = match session_id {
955            Some(id) => id,
956            None => {
957                // NATS does not enumerate active sessions; an explicit ID is required.
958                return Err(QueueError::SessionNotFound {
959                    session_id: "<any>".to_string(),
960                });
961            }
962        };
963
964        self.ensure_stream(queue).await?;
965
966        let subject = session_subject(&self.config, queue, &sid);
967        let name = session_consumer_name(&self.config, queue, &sid);
968        let consumer = self.create_consumer(queue, &name, &subject).await?;
969
970        let now = Timestamp::now();
971        let lock_expires_at =
972            Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
973
974        Ok(Box::new(NatsSessionProvider {
975            consumer: Arc::new(Mutex::new(consumer)),
976            client: self.client.clone(),
977            session_id: sid,
978            queue_name: queue.clone(),
979            in_flight: self.in_flight.clone(),
980            lock_expires_at: Arc::new(std::sync::Mutex::new(lock_expires_at)),
981            config: self.config.clone(),
982        }))
983    }
984
985    fn provider_type(&self) -> ProviderType {
986        ProviderType::Nats
987    }
988
989    fn supports_sessions(&self) -> SessionSupport {
990        SessionSupport::Emulated
991    }
992
993    fn supports_batching(&self) -> bool {
994        true
995    }
996
997    fn max_batch_size(&self) -> u32 {
998        100
999    }
1000}
1001
1002// ============================================================================
1003// NatsSessionProvider
1004// ============================================================================
1005
/// Session provider for NATS using a per-session JetStream pull consumer.
///
/// Messages for the session are filtered by a dedicated subject
/// (`{prefix}.{queue}.session.{session_id}`). Fetches go through a mutex-guarded
/// consumer one message at a time, so receives through a single provider
/// instance are serialized.
///
/// NOTE(review): the session lock (`lock_expires_at`) is tracked only in this
/// process, and renewing it makes no server call. Whether delivery is
/// *exclusive* across clients depends on how `create_consumer` configures the
/// consumer — confirm before relying on it.
pub struct NatsSessionProvider {
    /// Pull consumer filtered to this session's subject. The async mutex
    /// serializes fetches on this session.
    consumer: Arc<Mutex<async_nats::jetstream::consumer::Consumer<ConsumerConfig>>>,
    /// Client used to publish dead-lettered copies to the DLQ subject.
    client: async_nats::Client,
    /// The session this provider is bound to.
    session_id: SessionId,
    /// Queue the session belongs to; used to derive the dead-letter subject.
    queue_name: QueueName,
    /// In-flight map shared with the `NatsProvider` that created this session,
    /// keyed by receipt id.
    in_flight: Arc<Mutex<HashMap<String, InFlightEntry>>>,
    /// Current session-lock expiry; shared so `renew_session_lock` can update it.
    lock_expires_at: Arc<std::sync::Mutex<Timestamp>>,
    /// Provider configuration (e.g. `session_lock_duration`, dead-letter settings).
    config: NatsConfig,
}
1021
1022#[async_trait]
1023impl SessionProvider for NatsSessionProvider {
1024    #[instrument(skip(self), fields(session_id = %self.session_id))]
1025    async fn receive_message(
1026        &self,
1027        timeout: Duration,
1028    ) -> Result<Option<ReceivedMessage>, QueueError> {
1029        self.check_lock()?;
1030
1031        let timeout_std = timeout
1032            .to_std()
1033            .unwrap_or(std::time::Duration::from_secs(30));
1034
1035        let consumer = self.consumer.lock().await;
1036
1037        let mut messages = consumer
1038            .fetch()
1039            .max_messages(1)
1040            .expires(timeout_std)
1041            .messages()
1042            .await
1043            .map_err(|e| QueueError::ProviderError {
1044                provider: "nats".to_string(),
1045                code: "FETCH_FAILED".to_string(),
1046                message: format!("session fetch failed: {}", e),
1047            })?;
1048
1049        match tokio::time::timeout(timeout_std, messages.next()).await {
1050            Ok(Some(Ok(js_msg))) => {
1051                let msg = self.register_session_message(js_msg).await;
1052                Ok(Some(msg))
1053            }
1054            Ok(Some(Err(e))) => Err(QueueError::ProviderError {
1055                provider: "nats".to_string(),
1056                code: "MESSAGE_ERROR".to_string(),
1057                message: format!("session message error: {}", e),
1058            }),
1059            Ok(None) | Err(_) => Ok(None),
1060        }
1061    }
1062
1063    #[instrument(skip(self, receipt))]
1064    async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1065        self.check_lock()?;
1066        self.ack_message(receipt, SettlementKind::Ack).await
1067    }
1068
1069    #[instrument(skip(self, receipt))]
1070    async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1071        self.check_lock()?;
1072        self.ack_message(receipt, SettlementKind::Nak).await
1073    }
1074
1075    #[instrument(skip(self, receipt), fields(reason = %reason))]
1076    async fn dead_letter_message(
1077        &self,
1078        receipt: &ReceiptHandle,
1079        reason: &str,
1080    ) -> Result<(), QueueError> {
1081        self.check_lock()?;
1082
1083        let mut in_flight = self.in_flight.lock().await;
1084
1085        // Check existence and expiry before removal.
1086        match in_flight.get(receipt.handle()) {
1087            None => {
1088                return Err(QueueError::MessageNotFound {
1089                    receipt: receipt.handle().to_string(),
1090                });
1091            }
1092            Some(entry) if Timestamp::now() > entry.lock_expires_at => {
1093                in_flight.remove(receipt.handle());
1094                return Err(QueueError::MessageNotFound {
1095                    receipt: format!("{}(expired)", receipt.handle()),
1096                });
1097            }
1098            Some(_) => {}
1099        }
1100
1101        let entry = in_flight
1102            .remove(receipt.handle())
1103            .expect("entry present after pre-check");
1104
1105        entry
1106            .js_message
1107            .ack_with(async_nats::jetstream::AckKind::Term)
1108            .await
1109            .map_err(|e| QueueError::ProviderError {
1110                provider: "nats".to_string(),
1111                code: "TERM_FAILED".to_string(),
1112                message: format!("session term ack failed: {}", e),
1113            })?;
1114
1115        // Forward to DLQ if configured.  Best-effort: Term above is canonical;
1116        // if the DLQ publish fails we log and do not propagate the error.
1117        if let Some(ref dlq_subject) = entry.dead_letter_subject {
1118            let mut headers = async_nats::header::HeaderMap::new();
1119            headers.insert("x-dead-letter-reason", reason);
1120            let payload = entry.js_message.message.payload.clone();
1121
1122            if let Err(e) = self
1123                .client
1124                .publish_with_headers(dlq_subject.clone(), headers, payload)
1125                .await
1126            {
1127                warn!(
1128                    reason,
1129                    dlq_subject,
1130                    error = %e,
1131                    "Session: failed to publish dead-lettered message to DLQ (message already terminated)"
1132                );
1133            } else {
1134                debug!(reason, dlq_subject, "Session message dead-lettered");
1135            }
1136        }
1137
1138        Ok(())
1139    }
1140
1141    async fn renew_session_lock(&self) -> Result<(), QueueError> {
1142        advance_session_lock(&self.lock_expires_at, self.config.session_lock_duration)?;
1143        debug!(session_id = %self.session_id, "NATS session lock renewed");
1144        Ok(())
1145    }
1146
1147    async fn close_session(&self) -> Result<(), QueueError> {
1148        // Pull consumers are ephemeral and cleaned up by the server; nothing to do.
1149        Ok(())
1150    }
1151
1152    fn session_id(&self) -> &SessionId {
1153        &self.session_id
1154    }
1155
1156    fn session_expires_at(&self) -> Timestamp {
1157        // Recover from a poisoned lock by using the last known value.
1158        *self
1159            .lock_expires_at
1160            .lock()
1161            .unwrap_or_else(|e| e.into_inner())
1162    }
1163}
1164
1165// ============================================================================
1166// Session lock helpers — module-level so they are testable without a live server
1167// ============================================================================
1168
1169/// Return an error when the session lock timestamp has expired.
1170///
1171/// Extracted from [`NatsSessionProvider::check_lock`] so the expiry logic can be
1172/// verified in unit tests without constructing a full provider.
1173fn check_session_lock(
1174    lock_expires_at: &std::sync::Mutex<Timestamp>,
1175    session_id: &SessionId,
1176) -> Result<(), QueueError> {
1177    let expires = *lock_expires_at
1178        .lock()
1179        .map_err(|_| QueueError::ProviderError {
1180            provider: "nats".to_string(),
1181            code: "INTERNAL_ERROR".to_string(),
1182            message: "session lock mutex poisoned".to_string(),
1183        })?;
1184    if Timestamp::now() > expires {
1185        return Err(QueueError::SessionLocked {
1186            session_id: session_id.as_str().to_string(),
1187            locked_until: expires,
1188        });
1189    }
1190    Ok(())
1191}
1192
1193/// Advance the session lock by `duration` from now and return the new expiry.
1194///
1195/// Extracted from [`NatsSessionProvider::renew_session_lock`] for the same reason.
1196fn advance_session_lock(
1197    lock_expires_at: &std::sync::Mutex<Timestamp>,
1198    duration: Duration,
1199) -> Result<Timestamp, QueueError> {
1200    let new_expiry = Timestamp::from_datetime(Timestamp::now().as_datetime() + duration);
1201    *lock_expires_at
1202        .lock()
1203        .map_err(|_| QueueError::ProviderError {
1204            provider: "nats".to_string(),
1205            code: "INTERNAL_ERROR".to_string(),
1206            message: "session lock mutex poisoned".to_string(),
1207        })? = new_expiry;
1208    Ok(new_expiry)
1209}
1210
/// Internal settlement kind for session operations.
///
/// Selects which JetStream acknowledgement `NatsSessionProvider::ack_message`
/// sends. It is a plain unit enum, so it derives `Copy` and `Debug` for cheap
/// passing and tracing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SettlementKind {
    /// Positive acknowledgement, sent via `ack()`.
    Ack,
    /// Negative acknowledgement, sent via `ack_with(AckKind::Nak(None))`.
    Nak,
}
1216
1217impl NatsSessionProvider {
1218    /// Return an error if the session lock has expired.
1219    fn check_lock(&self) -> Result<(), QueueError> {
1220        check_session_lock(&self.lock_expires_at, &self.session_id)
1221    }
1222
1223    /// Ack or nak a message identified by its receipt handle.
1224    async fn ack_message(
1225        &self,
1226        receipt: &ReceiptHandle,
1227        kind: SettlementKind,
1228    ) -> Result<(), QueueError> {
1229        let mut in_flight = self.in_flight.lock().await;
1230
1231        // Check existence and expiry before removal (same pattern as NatsProvider).
1232        match in_flight.get(receipt.handle()) {
1233            None => {
1234                return Err(QueueError::MessageNotFound {
1235                    receipt: receipt.handle().to_string(),
1236                });
1237            }
1238            Some(entry) if Timestamp::now() > entry.lock_expires_at => {
1239                in_flight.remove(receipt.handle());
1240                return Err(QueueError::MessageNotFound {
1241                    receipt: format!("{}(expired)", receipt.handle()),
1242                });
1243            }
1244            Some(_) => {}
1245        }
1246
1247        let entry = in_flight
1248            .remove(receipt.handle())
1249            .expect("entry present after pre-check");
1250
1251        match kind {
1252            SettlementKind::Ack => {
1253                entry
1254                    .js_message
1255                    .ack()
1256                    .await
1257                    .map_err(|e| QueueError::ProviderError {
1258                        provider: "nats".to_string(),
1259                        code: "ACK_FAILED".to_string(),
1260                        message: format!("session ack failed: {}", e),
1261                    })
1262            }
1263            SettlementKind::Nak => {
1264                entry
1265                    .js_message
1266                    .ack_with(AckKind::Nak(None))
1267                    .await
1268                    .map_err(|e| QueueError::ProviderError {
1269                        provider: "nats".to_string(),
1270                        code: "NAK_FAILED".to_string(),
1271                        message: format!("session nak failed: {}", e),
1272                    })
1273            }
1274        }
1275    }
1276
1277    /// Register a JetStream message in the in-flight map and build a [`ReceivedMessage`].
1278    async fn register_session_message(
1279        &self,
1280        js_message: async_nats::jetstream::Message,
1281    ) -> ReceivedMessage {
1282        let headers = js_message.message.headers.clone();
1283        let attributes = NatsProvider::extract_attributes(&headers);
1284        let correlation_id = NatsProvider::extract_correlation_id(&headers);
1285        let delivery_count = js_message.info().map(|i| i.delivered as u32).unwrap_or(1);
1286        let body = Bytes::copy_from_slice(&js_message.message.payload);
1287
1288        let now = Timestamp::now();
1289        let lock_expires_at =
1290            Timestamp::from_datetime(now.as_datetime() + self.config.session_lock_duration);
1291
1292        let receipt_id = uuid::Uuid::new_v4().to_string();
1293        let message_id = MessageId::new();
1294
1295        let dlq_subject = dead_letter_subject(&self.config, &self.queue_name);
1296
1297        self.in_flight.lock().await.insert(
1298            receipt_id.clone(),
1299            InFlightEntry {
1300                js_message,
1301                lock_expires_at,
1302                dead_letter_subject: dlq_subject,
1303            },
1304        );
1305
1306        ReceivedMessage {
1307            message_id,
1308            body,
1309            attributes,
1310            session_id: Some(self.session_id.clone()),
1311            correlation_id,
1312            receipt_handle: ReceiptHandle::new(receipt_id, lock_expires_at, ProviderType::Nats),
1313            delivery_count,
1314            first_delivered_at: now,
1315            delivered_at: now,
1316        }
1317    }
1318}