Skip to main content

dynamo_runtime/transports/
nats.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! NATS transport
5//!
6//! The following environment variables are used to configure the NATS client:
7//!
//! - `NATS_SERVER`: the NATS server address (must start with `nats://`; defaults to `nats://localhost:4222`)
9//!
10//! For authentication, the following environment variables are used and prioritized in the following order:
11//!
12//! - `NATS_AUTH_USERNAME`: the username for authentication
13//! - `NATS_AUTH_PASSWORD`: the password for authentication
14//! - `NATS_AUTH_TOKEN`: the token for authentication
15//! - `NATS_AUTH_NKEY`: the nkey for authentication
16//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
17//!
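//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together; setting only one of
//! them has no effect. If none of the variables above is set, the client falls back to the
//! username/password pair `user`/`user`.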
19use crate::metrics::MetricsHierarchy;
20use crate::protocols::EndpointId;
21use crate::traits::events::EventPublisher;
22
23use anyhow::Result;
24use async_nats::connection::State;
25use async_nats::{Subscriber, client, jetstream};
26use async_trait::async_trait;
27use bytes::Bytes;
28use derive_builder::Builder;
29use futures::{StreamExt, TryStreamExt};
30use prometheus::{Counter, Gauge, Histogram, HistogramOpts, IntCounter, IntGauge, Opts, Registry};
31use serde::de::DeserializeOwned;
32use serde::{Deserialize, Serialize};
33use std::path::{Path, PathBuf};
34use std::sync::atomic::Ordering;
35use tokio::fs::File as TokioFile;
36use tokio::io::AsyncRead;
37use tokio::time;
38use url::Url;
39use validator::{Validate, ValidationError};
40
41use crate::config::environment_names::nats as env_nats;
42pub use crate::slug::Slug;
43use tracing as log;
44
45use super::utils::build_in_runtime;
46
/// Scheme prefix every NATS server address handled by this module starts with.
pub const URL_PREFIX: &str = "nats://";
48
49#[derive(Clone)]
50pub struct Client {
51    client: client::Client,
52    js_ctx: jetstream::Context,
53}
54
55impl Client {
56    /// Create a NATS [`ClientOptionsBuilder`].
57    pub fn builder() -> ClientOptionsBuilder {
58        ClientOptionsBuilder::default()
59    }
60
61    /// Returns a reference to the underlying [`async_nats::client::Client`] instance
62    pub fn client(&self) -> &client::Client {
63        &self.client
64    }
65
66    /// Returns a reference to the underlying [`async_nats::jetstream::Context`] instance
67    pub fn jetstream(&self) -> &jetstream::Context {
68        &self.js_ctx
69    }
70
71    /// host:port of NATS
72    pub fn addr(&self) -> String {
73        let info = self.client.server_info();
74        format!("{}:{}", info.host, info.port)
75    }
76
77    /// fetch the list of streams
78    pub async fn list_streams(&self) -> Result<Vec<String>> {
79        let names = self.js_ctx.stream_names();
80        let stream_names: Vec<String> = names.try_collect().await?;
81        Ok(stream_names)
82    }
83
84    /// fetch the list of consumers for a given stream
85    pub async fn list_consumers(&self, stream_name: &str) -> Result<Vec<String>> {
86        let stream = self.js_ctx.get_stream(stream_name).await?;
87        let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
88        Ok(consumers)
89    }
90
91    pub async fn stream_info(&self, stream_name: &str) -> Result<jetstream::stream::State> {
92        let mut stream = self.js_ctx.get_stream(stream_name).await?;
93        let info = stream.info().await?;
94        Ok(info.state.clone())
95    }
96
97    pub async fn get_stream(&self, name: &str) -> Result<jetstream::stream::Stream> {
98        let stream = self.js_ctx.get_stream(name).await?;
99        Ok(stream)
100    }
101
102    /// Issues a broadcast request for all services with the provided `service_name` to report their
103    /// current stats. Each service will only respond once. The service may have customized the reply
104    /// so the caller should select which endpoint and what concrete data model should be used to
105    /// extract the details.
106    ///
107    /// Note: Because each endpoint will only reply once, the caller must drop the subscription after
108    /// some time or it will await forever.
109    pub async fn scrape_service(&self, service_name: &str) -> Result<Subscriber> {
110        let subject = format!("$SRV.STATS.{}", service_name);
111        let reply_subject = format!("_INBOX.{}", nuid::next());
112        let subscription = self.client.subscribe(reply_subject.clone()).await?;
113
114        // Publish the request with the reply-to subject
115        self.client
116            .publish_with_reply(subject, reply_subject, "".into())
117            .await?;
118
119        Ok(subscription)
120    }
121
122    /// Helper method to get or optionally create an object store bucket
123    ///
124    /// # Arguments
125    /// * `bucket_name` - The name of the bucket to retrieve
126    /// * `create_if_not_found` - If true, creates the bucket when it doesn't exist
127    ///
128    /// # Returns
129    /// The object store bucket or an error
130    async fn get_or_create_bucket(
131        &self,
132        bucket_name: &str,
133        create_if_not_found: bool,
134    ) -> anyhow::Result<jetstream::object_store::ObjectStore> {
135        let context = self.jetstream();
136
137        match context.get_object_store(bucket_name).await {
138            Ok(bucket) => Ok(bucket),
139            Err(err) if err.to_string().contains("stream not found") => {
140                // err.source() is GetStreamError, which has a kind() which
141                // is GetStreamErrorKind::JetStream which wraps a jetstream::Error
142                // which has code 404. Phew. So yeah check the string for now.
143
144                if create_if_not_found {
145                    tracing::debug!("Creating NATS bucket {bucket_name}");
146                    context
147                        .create_object_store(jetstream::object_store::Config {
148                            bucket: bucket_name.to_string(),
149                            ..Default::default()
150                        })
151                        .await
152                        .map_err(|e| anyhow::anyhow!("Failed creating bucket / object store: {e}"))
153                } else {
154                    anyhow::bail!(
155                        "NATS get_object_store bucket does not exist: {bucket_name}. {err}."
156                    );
157                }
158            }
159            Err(err) => {
160                anyhow::bail!("NATS get_object_store error: {err}");
161            }
162        }
163    }
164
165    /// Upload file to NATS at this URL
166    pub async fn object_store_upload(&self, filepath: &Path, nats_url: &Url) -> anyhow::Result<()> {
167        let mut disk_file = TokioFile::open(filepath).await?;
168
169        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
170        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
171
172        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
173            name: key.to_string(),
174            ..Default::default()
175        };
176        bucket.put(key_meta, &mut disk_file).await.map_err(|e| {
177            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
178        })?;
179
180        Ok(())
181    }
182
183    /// Download file from NATS at this URL
184    pub async fn object_store_download(
185        &self,
186        nats_url: &Url,
187        filepath: &Path,
188    ) -> anyhow::Result<()> {
189        let mut disk_file = TokioFile::create(filepath).await?;
190
191        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
192        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
193
194        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
195            anyhow::anyhow!(
196                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
197            )
198        })?;
199        let _bytes_copied = tokio::io::copy(&mut obj_reader, &mut disk_file).await?;
200
201        Ok(())
202    }
203
204    /// Delete a bucket and all it's contents from the NATS object store
205    pub async fn object_store_delete_bucket(&self, bucket_name: &str) -> anyhow::Result<()> {
206        let context = self.jetstream();
207        match context.delete_object_store(&bucket_name).await {
208            Ok(_) => Ok(()),
209            Err(err) if err.to_string().contains("stream not found") => {
210                tracing::trace!(bucket_name, "NATS bucket already gone");
211                Ok(())
212            }
213            Err(err) => Err(anyhow::anyhow!("NATS get_object_store error: {err}")),
214        }
215    }
216
217    /// Upload a serializable struct to NATS object store using bincode
218    pub async fn object_store_upload_data<T>(&self, data: &T, nats_url: &Url) -> anyhow::Result<()>
219    where
220        T: Serialize,
221    {
222        // Serialize the data using bincode (more efficient binary format)
223        let binary_data = bincode::serialize(data)
224            .map_err(|e| anyhow::anyhow!("Failed to serialize data with bincode: {e}"))?;
225
226        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
227        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
228
229        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
230            name: key.to_string(),
231            ..Default::default()
232        };
233
234        // Upload the serialized bytes
235        let mut cursor = std::io::Cursor::new(binary_data);
236        bucket.put(key_meta, &mut cursor).await.map_err(|e| {
237            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
238        })?;
239
240        Ok(())
241    }
242
243    /// Download and deserialize a struct from NATS object store using bincode
244    pub async fn object_store_download_data<T>(&self, nats_url: &Url) -> anyhow::Result<T>
245    where
246        T: DeserializeOwned,
247    {
248        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
249        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
250
251        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
252            anyhow::anyhow!(
253                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
254            )
255        })?;
256
257        // Read all bytes into memory
258        let mut buffer = Vec::new();
259        tokio::io::copy(&mut obj_reader, &mut buffer)
260            .await
261            .map_err(|e| anyhow::anyhow!("Failed reading object data: {e}"))?;
262        tracing::debug!("Downloaded {} bytes from {bucket_name}/{key}", buffer.len());
263
264        // Deserialize from bincode
265        let data = bincode::deserialize(&buffer)
266            .map_err(|e| anyhow::anyhow!("Failed to deserialize data with bincode: {e}"))?;
267
268        Ok(data)
269    }
270}
271
272/// NATS client options
273///
274/// This object uses the builder pattern with default values that are evaluates
275/// from the environment variables if they are not explicitly set by the builder.
276#[derive(Debug, Clone, Builder, Validate)]
277pub struct ClientOptions {
278    #[builder(setter(into), default = "default_server()")]
279    #[validate(custom(function = "validate_nats_server"))]
280    server: String,
281
282    #[builder(default)]
283    auth: NatsAuth,
284}
285
286fn default_server() -> String {
287    if let Ok(server) = std::env::var(env_nats::NATS_SERVER) {
288        return server;
289    }
290
291    "nats://localhost:4222".to_string()
292}
293
294fn validate_nats_server(server: &str) -> Result<(), ValidationError> {
295    if server.starts_with("nats://") {
296        Ok(())
297    } else {
298        Err(ValidationError::new("server must start with 'nats://'"))
299    }
300}
301
/// Worker-thread count passed to `build_in_runtime` when [`ClientOptions::connect`]
/// establishes the NATS connection (presumably the size of the runtime it builds — confirm
/// against `build_in_runtime`).
// TODO(jthomson04): We really shouldn't be hardcoding this.
const NATS_WORKER_THREADS: usize = 4;
304
305impl ClientOptions {
306    /// Create a new [`ClientOptionsBuilder`]
307    pub fn builder() -> ClientOptionsBuilder {
308        ClientOptionsBuilder::default()
309    }
310
311    /// Validate the config and attempt to connection to the NATS server
312    pub async fn connect(self) -> Result<Client> {
313        self.validate()?;
314
315        let client = match self.auth {
316            NatsAuth::UserPass(username, password) => {
317                async_nats::ConnectOptions::with_user_and_password(username, password)
318            }
319            NatsAuth::Token(token) => async_nats::ConnectOptions::with_token(token),
320            NatsAuth::NKey(nkey) => async_nats::ConnectOptions::with_nkey(nkey),
321            NatsAuth::CredentialsFile(path) => {
322                async_nats::ConnectOptions::with_credentials_file(path).await?
323            }
324        };
325
326        let (client, _) = build_in_runtime(
327            async move {
328                client
329                    .connect(self.server)
330                    .await
331                    .map_err(|e| anyhow::anyhow!("Failed to connect to NATS: {e}. Verify NATS server is running and accessible."))
332            },
333            NATS_WORKER_THREADS,
334        )
335        .await?;
336
337        let js_ctx = jetstream::new(client.clone());
338
339        // Validate JetStream is available
340        js_ctx
341            .query_account()
342            .await
343            .map_err(|e| anyhow::anyhow!("JetStream not available: {e}"))?;
344
345        Ok(Client { client, js_ctx })
346    }
347}
348
349impl Default for ClientOptions {
350    fn default() -> Self {
351        ClientOptions {
352            server: default_server(),
353            auth: NatsAuth::default(),
354        }
355    }
356}
357
/// Authentication method used when connecting to NATS.
///
/// The `Debug` implementation redacts secrets: only the username and the credentials-file
/// path are printed.
#[derive(Clone, Eq, PartialEq)]
pub enum NatsAuth {
    /// Username and password, in that order.
    UserPass(String, String),
    /// Authentication token.
    Token(String),
    /// NKey passed to `async_nats::ConnectOptions::with_nkey`.
    NKey(String),
    /// Path to a NATS credentials file.
    CredentialsFile(PathBuf),
}

impl std::fmt::Debug for NatsAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print secret material; only non-sensitive identifiers are shown.
        match self {
            Self::UserPass(user, _) => write!(f, "UserPass({user}, <redacted>)"),
            Self::Token(_) => f.write_str("Token(<redacted>)"),
            Self::NKey(_) => f.write_str("NKey(<redacted>)"),
            Self::CredentialsFile(path) => write!(f, "CredentialsFile({path:?})"),
        }
    }
}
378
379impl Default for NatsAuth {
380    fn default() -> Self {
381        if let (Ok(username), Ok(password)) = (
382            std::env::var(env_nats::auth::NATS_AUTH_USERNAME),
383            std::env::var(env_nats::auth::NATS_AUTH_PASSWORD),
384        ) {
385            return NatsAuth::UserPass(username, password);
386        }
387
388        if let Ok(token) = std::env::var(env_nats::auth::NATS_AUTH_TOKEN) {
389            return NatsAuth::Token(token);
390        }
391
392        if let Ok(nkey) = std::env::var(env_nats::auth::NATS_AUTH_NKEY) {
393            return NatsAuth::NKey(nkey);
394        }
395
396        if let Ok(path) = std::env::var(env_nats::auth::NATS_AUTH_CREDENTIALS_FILE) {
397            return NatsAuth::CredentialsFile(PathBuf::from(path));
398        }
399
400        NatsAuth::UserPass("user".to_string(), "user".to_string())
401    }
402}
403
404/// Extract NATS bucket and key from a nats URL of the form:
405/// nats://host[:port]/bucket/key
406pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
407    let Some(mut path_segments) = url.path_segments() else {
408        anyhow::bail!("No path in NATS URL: {url}");
409    };
410    let Some(bucket) = path_segments.next() else {
411        anyhow::bail!("No bucket in NATS URL: {url}");
412    };
413    let Some(key) = path_segments.next() else {
414        anyhow::bail!("No key in NATS URL: {url}");
415    };
416    Ok((bucket.to_string(), key.to_string()))
417}
418
419/// A queue implementation using NATS JetStream
420pub struct NatsQueue {
421    /// The name of the stream to use for the queue
422    stream_name: String,
423    /// The NATS server URL
424    nats_server: String,
425    /// Timeout for dequeue operations in seconds
426    dequeue_timeout: time::Duration,
427    /// The NATS client
428    client: Option<Client>,
429    /// The subject pattern used for this queue
430    subject: String,
431    /// The subscriber for pull-based consumption
432    subscriber: Option<jetstream::consumer::PullConsumer>,
433    /// Optional consumer name for broadcast pattern (if None, uses "worker-group")
434    consumer_name: Option<String>,
435    /// Message stream for efficient message consumption
436    message_stream: Option<jetstream::consumer::pull::Stream>,
437}
438
439impl NatsQueue {
440    /// Create a new NatsQueue with the default "worker-group" consumer
441    pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self {
442        // Sanitize stream name to remove path separators (like in Python version)
443        // rupei: are we sure NATs stream name accepts '_'?
444        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
445        let subject = format!("{sanitized_stream_name}.*");
446
447        Self {
448            stream_name: sanitized_stream_name,
449            nats_server,
450            dequeue_timeout,
451            client: None,
452            subject,
453            subscriber: None,
454            consumer_name: Some("worker-group".to_string()),
455            message_stream: None,
456        }
457    }
458
459    /// Create a new NatsQueue without a consumer (publisher-only mode)
460    pub fn new_without_consumer(
461        stream_name: String,
462        nats_server: String,
463        dequeue_timeout: time::Duration,
464    ) -> Self {
465        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
466        let subject = format!("{sanitized_stream_name}.*");
467
468        Self {
469            stream_name: sanitized_stream_name,
470            nats_server,
471            dequeue_timeout,
472            client: None,
473            subject,
474            subscriber: None,
475            consumer_name: None,
476            message_stream: None,
477        }
478    }
479
480    /// Create a new NatsQueue with a specific consumer name for broadcast pattern
481    /// Each consumer with a unique name will receive all messages independently
482    pub fn new_with_consumer(
483        stream_name: String,
484        nats_server: String,
485        dequeue_timeout: time::Duration,
486        consumer_name: String,
487    ) -> Self {
488        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
489        let subject = format!("{sanitized_stream_name}.*");
490
491        Self {
492            stream_name: sanitized_stream_name,
493            nats_server,
494            dequeue_timeout,
495            client: None,
496            subject,
497            subscriber: None,
498            consumer_name: Some(consumer_name),
499            message_stream: None,
500        }
501    }
502
503    /// Connect to the NATS server and set up the stream and consumer
504    pub async fn connect(&mut self) -> Result<()> {
505        self.connect_with_reset(false).await
506    }
507
508    /// Connect to the NATS server and set up the stream and consumer, optionally resetting the stream
509    pub async fn connect_with_reset(&mut self, reset_stream: bool) -> Result<()> {
510        if self.client.is_none() {
511            // Create a new client
512            let client_options = Client::builder().server(self.nats_server.clone()).build()?;
513
514            let client = client_options.connect().await?;
515
516            // messages older than a hour in the stream will be automatically purged
517            let max_age = std::env::var(env_nats::stream::DYN_NATS_STREAM_MAX_AGE)
518                .ok()
519                .and_then(|s| s.parse::<u64>().ok())
520                .map(time::Duration::from_secs)
521                .unwrap_or_else(|| time::Duration::from_secs(60 * 60));
522
523            let stream_config = jetstream::stream::Config {
524                name: self.stream_name.clone(),
525                subjects: vec![self.subject.clone()],
526                max_age,
527                ..Default::default()
528            };
529
530            // Get or create the stream
531            let stream = client
532                .jetstream()
533                .get_or_create_stream(stream_config)
534                .await?;
535
536            log::debug!("Stream {} is ready", self.stream_name);
537
538            // If reset_stream is true, purge all messages from the stream
539            if reset_stream {
540                match stream.purge().await {
541                    Ok(purge_info) => {
542                        log::info!(
543                            "Successfully purged {} messages from NATS stream {}",
544                            purge_info.purged,
545                            self.stream_name
546                        );
547                    }
548                    Err(e) => {
549                        log::warn!("Failed to purge NATS stream '{}': {e}", self.stream_name);
550                    }
551                }
552            }
553
554            // Create persistent subscriber only if consumer_name is set
555            if let Some(ref consumer_name) = self.consumer_name {
556                let consumer_config = jetstream::consumer::pull::Config {
557                    durable_name: Some(consumer_name.clone()),
558                    inactive_threshold: std::time::Duration::from_secs(3600), // 1 hour
559                    ..Default::default()
560                };
561
562                let subscriber = stream.create_consumer(consumer_config).await?;
563
564                // Create the message stream for efficient consumption
565                let message_stream = subscriber.messages().await?;
566
567                self.subscriber = Some(subscriber);
568                self.message_stream = Some(message_stream);
569            }
570
571            self.client = Some(client);
572        }
573
574        Ok(())
575    }
576
577    /// Ensure we have an active connection
578    pub async fn ensure_connection(&mut self) -> Result<()> {
579        if self.client.is_none() {
580            self.connect().await?;
581        }
582        Ok(())
583    }
584
585    /// Close the connection when done
586    pub async fn close(&mut self) -> Result<()> {
587        self.message_stream = None;
588        self.subscriber = None;
589        self.client = None;
590        Ok(())
591    }
592
593    /// Shutdown the consumer by deleting it from the stream and closing the connection
594    /// This permanently removes the consumer from the server
595    ///
596    /// If `consumer_name` is provided, that specific consumer will be deleted instead of the
597    /// current consumer. This allows deletion of other consumers on the same stream.
598    pub async fn shutdown(&mut self, consumer_name: Option<String>) -> Result<()> {
599        // Determine which consumer to delete
600        let target_consumer = consumer_name.as_ref().or(self.consumer_name.as_ref());
601
602        // Warn if deleting our own consumer via explicit parameter
603        if let Some(ref passed_name) = consumer_name
604            && self.consumer_name.as_ref() == Some(passed_name)
605        {
606            log::warn!(
607                "Deleting our own consumer '{}' via explicit consumer_name parameter. \
608                Consider calling shutdown without arguments instead.",
609                passed_name
610            );
611        }
612
613        if let (Some(client), Some(consumer_to_delete)) = (&self.client, target_consumer) {
614            // Get the stream and delete the consumer
615            let stream = client.jetstream().get_stream(&self.stream_name).await?;
616            stream
617                .delete_consumer(consumer_to_delete)
618                .await
619                .map_err(|e| {
620                    anyhow::anyhow!("Failed to delete consumer {}: {}", consumer_to_delete, e)
621                })?;
622            log::debug!(
623                "Deleted consumer {} from stream {}",
624                consumer_to_delete,
625                self.stream_name
626            );
627        } else {
628            log::debug!(
629                "Cannot shutdown consumer: client or target consumer is None (client: {:?}, target_consumer: {:?})",
630                self.client.is_some(),
631                target_consumer.is_some()
632            );
633        }
634
635        // Only close the connection if we deleted our own consumer
636        if consumer_name.is_none() {
637            self.close().await
638        } else {
639            Ok(())
640        }
641    }
642
643    /// Count the number of consumers for the stream
644    pub async fn count_consumers(&mut self) -> Result<usize> {
645        self.ensure_connection().await?;
646
647        if let Some(client) = &self.client {
648            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
649            let info = stream.info().await?;
650            Ok(info.state.consumer_count)
651        } else {
652            Err(anyhow::anyhow!("Client not connected"))
653        }
654    }
655
656    /// List all consumer names for the stream
657    pub async fn list_consumers(&mut self) -> Result<Vec<String>> {
658        self.ensure_connection().await?;
659
660        if let Some(client) = &self.client {
661            client.list_consumers(&self.stream_name).await
662        } else {
663            Err(anyhow::anyhow!("Client not connected"))
664        }
665    }
666
667    /// Enqueue a task using the provided data
668    pub async fn enqueue_task(&mut self, task_data: Bytes) -> Result<()> {
669        self.ensure_connection().await?;
670
671        if let Some(client) = &self.client {
672            let subject = format!("{}.queue", self.stream_name);
673            client.jetstream().publish(subject, task_data).await?;
674            Ok(())
675        } else {
676            Err(anyhow::anyhow!("Client not connected"))
677        }
678    }
679
680    /// Dequeue and return a task as raw bytes
681    pub async fn dequeue_task(&mut self, timeout: Option<time::Duration>) -> Result<Option<Bytes>> {
682        self.ensure_connection().await?;
683
684        let Some(ref mut stream) = self.message_stream else {
685            return Err(anyhow::anyhow!("Message stream not initialized"));
686        };
687
688        let timeout_duration = timeout.unwrap_or(self.dequeue_timeout);
689
690        // Try to get next message from the stream with timeout
691        let message = tokio::time::timeout(timeout_duration, stream.next()).await;
692
693        match message {
694            Ok(Some(Ok(msg))) => {
695                msg.ack()
696                    .await
697                    .map_err(|e| anyhow::anyhow!("Failed to ack message: {}", e))?;
698                Ok(Some(msg.payload.clone()))
699            }
700
701            Ok(Some(Err(e))) => Err(anyhow::anyhow!("Failed to get message from stream: {}", e)),
702
703            Ok(None) => Err(anyhow::anyhow!("Message stream ended unexpectedly")),
704
705            // Timeout - no messages available
706            Err(_) => Ok(None),
707        }
708    }
709
710    /// Get the number of messages currently in the queue
711    pub async fn get_queue_size(&mut self) -> Result<u64> {
712        self.ensure_connection().await?;
713
714        if let Some(client) = &self.client {
715            // Get consumer info to get pending messages count
716            let stream = client.jetstream().get_stream(&self.stream_name).await?;
717            let consumer_name = self
718                .consumer_name
719                .clone()
720                .unwrap_or_else(|| "worker-group".to_string());
721            let mut consumer: jetstream::consumer::PullConsumer = stream
722                .get_consumer(&consumer_name)
723                .await
724                .map_err(|e| anyhow::anyhow!("Failed to get consumer: {}", e))?;
725            let info = consumer.info().await?;
726
727            Ok(info.num_pending)
728        } else {
729            Err(anyhow::anyhow!("Client not connected"))
730        }
731    }
732
733    /// Get the total number of messages currently in the stream
734    pub async fn get_stream_messages(&mut self) -> Result<u64> {
735        self.ensure_connection().await?;
736
737        if let Some(client) = &self.client {
738            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
739            let info = stream.info().await?;
740            Ok(info.state.messages)
741        } else {
742            Err(anyhow::anyhow!("Client not connected"))
743        }
744    }
745
746    /// Purge messages from the stream up to (but not including) the specified sequence number
747    /// This permanently removes messages and affects all consumers of the stream
748    pub async fn purge_up_to_sequence(&self, sequence: u64) -> Result<()> {
749        if let Some(client) = &self.client {
750            let stream = client.jetstream().get_stream(&self.stream_name).await?;
751
752            // NOTE: this purge excludes the sequence itself
753            // https://docs.rs/nats/latest/nats/jetstream/struct.PurgeRequest.html
754            stream.purge().sequence(sequence).await.map_err(|e| {
755                anyhow::anyhow!("Failed to purge stream up to sequence {}: {}", sequence, e)
756            })?;
757
758            log::debug!(
759                "Purged stream {} up to sequence {}",
760                self.stream_name,
761                sequence
762            );
763            Ok(())
764        } else {
765            Err(anyhow::anyhow!("Client not connected"))
766        }
767    }
768
769    /// Purge messages from the stream up to the minimum acknowledged sequence across all consumers
770    /// This finds the lowest acknowledged sequence number across all consumers and purges up to that point
771    pub async fn purge_acknowledged(&mut self) -> Result<()> {
772        self.ensure_connection().await?;
773
774        let Some(client) = &self.client else {
775            return Err(anyhow::anyhow!("Client not connected"));
776        };
777
778        let stream = client.jetstream().get_stream(&self.stream_name).await?;
779
780        // Get all consumer names for the stream
781        let consumer_names: Vec<String> = stream
782            .consumer_names()
783            .try_collect()
784            .await
785            .map_err(|e| anyhow::anyhow!("Failed to list consumers: {}", e))?;
786
787        if consumer_names.is_empty() {
788            log::debug!("No consumers found for stream {}", self.stream_name);
789            return Ok(());
790        }
791
792        // Find the minimum acknowledged sequence across all consumers
793        let mut min_ack_sequence = u64::MAX;
794
795        for consumer_name in &consumer_names {
796            let mut consumer: jetstream::consumer::PullConsumer = stream
797                .get_consumer(consumer_name)
798                .await
799                .map_err(|e| anyhow::anyhow!("Failed to get consumer {}: {}", consumer_name, e))?;
800
801            let info = consumer.info().await.map_err(|e| {
802                anyhow::anyhow!("Failed to get consumer info for {}: {}", consumer_name, e)
803            })?;
804
805            // The ack_floor contains the stream sequence of the highest contiguously acknowledged message
806            // If stream_sequence is 0, it means no messages have been acknowledged yet
807            if info.ack_floor.stream_sequence > 0 {
808                min_ack_sequence = min_ack_sequence.min(info.ack_floor.stream_sequence);
809                log::debug!(
810                    "Consumer {} has ack_floor at sequence {}",
811                    consumer_name,
812                    info.ack_floor.stream_sequence
813                );
814            }
815        }
816
817        // Only purge if we found a valid minimum acknowledged sequence
818        if min_ack_sequence < u64::MAX && min_ack_sequence > 0 {
819            // Purge up to (but not including) the minimum acknowledged sequence + 1
820            // We add 1 because we want to include the minimum acknowledged message in the purge
821            let purge_sequence = min_ack_sequence + 1;
822
823            self.purge_up_to_sequence(purge_sequence).await?;
824
825            log::debug!(
826                "Purged stream {} up to acknowledged sequence {} (purged up to sequence {})",
827                self.stream_name,
828                min_ack_sequence,
829                purge_sequence
830            );
831        } else {
832            log::debug!(
833                "No messages to purge for stream {} (min_ack_sequence: {})",
834                self.stream_name,
835                min_ack_sequence
836            );
837        }
838
839        Ok(())
840    }
841}
842
843#[async_trait]
844impl EventPublisher for NatsQueue {
845    fn subject(&self) -> String {
846        self.stream_name.clone()
847    }
848
849    async fn publish(
850        &self,
851        event_name: impl AsRef<str> + Send + Sync,
852        event: &(impl Serialize + Send + Sync),
853    ) -> Result<()> {
854        let bytes = serde_json::to_vec(event)?;
855        self.publish_bytes(event_name, bytes).await
856    }
857
858    async fn publish_bytes(
859        &self,
860        event_name: impl AsRef<str> + Send + Sync,
861        bytes: Vec<u8>,
862    ) -> Result<()> {
863        let subject = format!("{}.{}", self.subject(), event_name.as_ref());
864
865        // Note: enqueue_task requires &mut self, but EventPublisher requires &self
866        // We need to ensure the client is connected and use it directly
867        if let Some(client) = &self.client {
868            client.jetstream().publish(subject, bytes.into()).await?;
869            Ok(())
870        } else {
871            Err(anyhow::anyhow!("Client not connected"))
872        }
873    }
874}
875
876/// The NATS subject / inbox to talk to an instance on.
877/// TODO: Do we need to sanitize the names?
878pub fn instance_subject(endpoint_id: &EndpointId, instance_id: u64) -> String {
879    format!(
880        "{}_{}.{}-{:x}",
881        endpoint_id.namespace, endpoint_id.component, endpoint_id.name, instance_id,
882    )
883}
884
#[cfg(test)]
mod tests {
    // Unit tests for `ClientOptions` building, plus `#[ignore]`d integration tests that need a
    // NATS server (with JetStream) listening on nats://localhost:4222.

    use super::*;
    use figment::Jail;
    use serde::{Deserialize, Serialize};

    /// Sample payload round-tripped through the object store in
    /// `test_object_store_data_operations`.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestData {
        id: u32,
        name: String,
        values: Vec<f64>,
    }

    /// `ClientOptions::builder()` must build without any NATS env vars set by the test, pick up
    /// the server and username/password from the environment, and let values set explicitly on
    /// the builder take precedence over the environment.
    #[test]
    fn test_client_options_builder() {
        // No env vars are set inside this jail; building must still succeed.
        Jail::expect_with(|_jail| {
            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            Ok(())
        });

        // Server and username/password are read from the environment.
        Jail::expect_with(|jail| {
            jail.set_env(env_nats::NATS_SERVER, "nats://localhost:5222");
            jail.set_env(env_nats::auth::NATS_AUTH_USERNAME, "user");
            jail.set_env(env_nats::auth::NATS_AUTH_PASSWORD, "pass");

            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();

            assert_eq!(opts.server, "nats://localhost:5222");
            assert_eq!(
                opts.auth,
                NatsAuth::UserPass("user".to_string(), "pass".to_string())
            );

            Ok(())
        });

        // Explicit builder values (server and token auth) override the env vars set here.
        Jail::expect_with(|jail| {
            jail.set_env(env_nats::NATS_SERVER, "nats://localhost:5222");
            jail.set_env(env_nats::auth::NATS_AUTH_USERNAME, "user");
            jail.set_env(env_nats::auth::NATS_AUTH_PASSWORD, "pass");

            let opts = ClientOptions::builder()
                .server("nats://localhost:6222")
                .auth(NatsAuth::Token("token".to_string()))
                .build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();

            assert_eq!(opts.server, "nats://localhost:6222");
            assert_eq!(opts.auth, NatsAuth::Token("token".to_string()));

            Ok(())
        });
    }

    // Integration test for object store data operations using bincode
    // Uploads a `TestData` value to bucket "test-bucket", downloads it back, checks it is
    // unchanged, then deletes the bucket.
    // NOTE(review): the serialization format is chosen by `object_store_upload_data` /
    // `object_store_download_data`, not by this test; confirm it is still bincode.
    #[tokio::test]
    #[ignore] // Requires NATS server to be running
    async fn test_object_store_data_operations() {
        // Create test data
        let test_data = TestData {
            id: 42,
            name: "test_item".to_string(),
            values: vec![1.0, 2.5, 3.7, 4.2],
        };

        // Set up client
        let client_options = ClientOptions::builder()
            .server("nats://localhost:4222")
            .build()
            .expect("Failed to build client options");

        let client = client_options
            .connect()
            .await
            .expect("Failed to connect to NATS");

        // Test URL (using .bin extension to indicate binary format)
        // The first path segment is used as the bucket name (cleaned up below as "test-bucket").
        let url =
            Url::parse("nats://localhost/test-bucket/test-data.bin").expect("Failed to parse URL");

        // Upload the data
        client
            .object_store_upload_data(&test_data, &url)
            .await
            .expect("Failed to upload data");

        // Download the data
        let downloaded_data: TestData = client
            .object_store_download_data(&url)
            .await
            .expect("Failed to download data");

        // Verify the data matches
        assert_eq!(test_data, downloaded_data);

        // Clean up
        client
            .object_store_delete_bucket("test-bucket")
            .await
            .expect("Failed to delete bucket");
    }

    // Integration test for broadcast pattern with purging
    // Flow: two queues with distinct consumers (queue1, queue2) plus a publisher-only queue3
    // share one stream. Four messages are published, sequences 1-2 are purged explicitly,
    // consumer 1 reads message3, and `purge_acknowledged` must then keep message3 for
    // consumer 2. Finally checks that `close` keeps consumer 1 registered while `shutdown`
    // removes it.
    #[tokio::test]
    #[ignore] // Requires NATS server to be running
    async fn test_nats_queue_broadcast_with_purge() {
        use uuid::Uuid;

        // Create unique stream name for this test
        let stream_name = format!("test-broadcast-{}", Uuid::new_v4());
        let nats_server = "nats://localhost:4222".to_string();
        // NOTE(review): passed to every NatsQueue constructor below; what a zero timeout means
        // is defined by NatsQueue (not visible here) - confirm.
        let timeout = time::Duration::from_secs(0);

        // Connect to NATS client first to delete stream if it exists
        let client_options = Client::builder()
            .server(nats_server.clone())
            .build()
            .expect("Failed to build client options");

        let client = client_options
            .connect()
            .await
            .expect("Failed to connect to NATS");

        // Delete the stream if it exists (to ensure clean start)
        // Errors are ignored on purpose: the stream normally does not exist yet.
        let _ = client.jetstream().delete_stream(&stream_name).await;

        // Create two consumers with different names for the same stream
        let consumer1_name = format!("consumer-{}", Uuid::new_v4());
        let consumer2_name = format!("consumer-{}", Uuid::new_v4());

        let mut queue1 = NatsQueue::new_with_consumer(
            stream_name.clone(),
            nats_server.clone(),
            timeout,
            consumer1_name,
        );

        // Connect queue1 first (it will create the stream)
        queue1.connect().await.expect("Failed to connect queue1");

        // Send 4 messages using the EventPublisher trait
        let message_strings = [
            "message1".to_string(),
            "message2".to_string(),
            "message3".to_string(),
            "message4".to_string(),
        ];

        // Using the EventPublisher trait to publish messages
        for (idx, msg) in message_strings.iter().enumerate() {
            queue1
                .publish("queue", msg)
                .await
                .unwrap_or_else(|_| panic!("Failed to publish message {}", idx + 1));
        }

        // Convert messages to JSON-serialized Bytes for comparison
        // (`publish` JSON-encodes the event, so dequeued payloads are the JSON bytes).
        let messages: Vec<Bytes> = message_strings
            .iter()
            .map(|s| Bytes::from(serde_json::to_vec(s).unwrap()))
            .collect();

        // Give JetStream a moment to persist the messages
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Now create and connect queue2 and queue3 AFTER messages are published (to test persistence)
        let mut queue2 = NatsQueue::new_with_consumer(
            stream_name.clone(),
            nats_server.clone(),
            timeout,
            consumer2_name,
        );

        // Create a third queue without consumer (publisher-only)
        let mut queue3 =
            NatsQueue::new_without_consumer(stream_name.clone(), nats_server.clone(), timeout);

        // Connect queue2 and queue3 after messages are already published
        queue2.connect().await.expect("Failed to connect queue2");
        queue3.connect().await.expect("Failed to connect queue3");

        // Purge the first two messages (sequence 1 and 2)
        // Note: JetStream sequences start at 1, and purge is exclusive of the sequence number
        queue1
            .purge_up_to_sequence(3)
            .await
            .expect("Failed to purge messages");

        // Give JetStream a moment to process the purge
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Consumer 1 dequeues one message (message3)
        let msg3_consumer1 = queue1
            .dequeue_task(Some(time::Duration::from_millis(500)))
            .await
            .expect("Failed to dequeue from queue1");
        assert_eq!(
            msg3_consumer1,
            Some(messages[2].clone()),
            "Consumer 1 should get message3"
        );

        // Give JetStream a moment to process acknowledgments
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Now run purge_acknowledged
        // At this point:
        // - Consumer 1 has ack'd message 3 (ack_floor = 3)
        // - Consumer 2 hasn't consumed or acknowledged anything yet
        // - So purge_acknowledged must not remove message 3 (asserted below: consumer 2 still
        //   receives message3 and message4)
        // NOTE(review): consumer 2's reported ack_floor here may be 0 or may have been advanced
        // to the earlier purge point (2) by the server; either way the minimum is below 3.
        queue1
            .purge_acknowledged()
            .await
            .expect("Failed to purge acknowledged messages");

        // Give JetStream a moment to process the purge
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Now collect remaining messages from both consumers
        let mut consumer1_remaining = Vec::new();
        let mut consumer2_remaining = Vec::new();

        // Collect remaining messages from consumer 1
        // NOTE(review): these loops rely on `dequeue_task(None)` returning `Ok(None)` once no
        // message is available within the queue's default wait - confirm against NatsQueue.
        while let Some(msg) = queue1
            .dequeue_task(None)
            .await
            .expect("Failed to dequeue from queue1")
        {
            consumer1_remaining.push(msg);
        }

        // Collect remaining messages from consumer 2
        while let Some(msg) = queue2
            .dequeue_task(None)
            .await
            .expect("Failed to dequeue from queue2")
        {
            consumer2_remaining.push(msg);
        }

        // Verify consumer 1 gets 1 remaining message (message4)
        assert_eq!(
            consumer1_remaining.len(),
            1,
            "Consumer 1 should have 1 remaining message"
        );
        assert_eq!(
            consumer1_remaining[0], messages[3],
            "Consumer 1 should get message4"
        );

        // Verify consumer 2 gets 2 messages (message3 and message4)
        assert_eq!(
            consumer2_remaining.len(),
            2,
            "Consumer 2 should have 2 messages"
        );
        assert_eq!(
            consumer2_remaining[0], messages[2],
            "Consumer 2 should get message3"
        );
        assert_eq!(
            consumer2_remaining[1], messages[3],
            "Consumer 2 should get message4"
        );

        // Test consumer count and shutdown behavior
        // First verify via consumer 1 that there are two consumers
        let consumer_count = queue1
            .count_consumers()
            .await
            .expect("Failed to count consumers");
        assert_eq!(consumer_count, 2, "Should have 2 consumers initially");

        // Close consumer 1 and verify via consumer 2 that there are still two consumers
        // (close does not remove the consumer from the stream; shutdown below does).
        queue1.close().await.expect("Failed to close queue1");

        let consumer_count = queue2
            .count_consumers()
            .await
            .expect("Failed to count consumers");
        assert_eq!(
            consumer_count, 2,
            "Should still have 2 consumers after closing queue1"
        );

        // Reconnect queue1 to be able to shutdown
        queue1.connect().await.expect("Failed to reconnect queue1");

        // Shutdown consumer 1 and verify via consumer 2 that there is only one consumer left
        queue1
            .shutdown(None)
            .await
            .expect("Failed to shutdown queue1");

        let consumer_count = queue2
            .count_consumers()
            .await
            .expect("Failed to count consumers");
        assert_eq!(
            consumer_count, 1,
            "Should have only 1 consumer after shutting down queue1"
        );

        // Clean up by deleting the stream
        // (not reached if an assertion above fails; the stream name is unique per run).
        client
            .jetstream()
            .delete_stream(&stream_name)
            .await
            .expect("Failed to delete test stream");
    }
}