Skip to main content

dynamo_runtime/transports/
nats.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! NATS transport
5//!
6//! The following environment variables are used to configure the NATS client:
7//!
//! - `NATS_SERVER`: the NATS server address (default: `nats://localhost:4222`)
9//!
10//! For authentication, the following environment variables are used and prioritized in the following order:
11//!
12//! - `NATS_AUTH_USERNAME`: the username for authentication
13//! - `NATS_AUTH_PASSWORD`: the password for authentication
14//! - `NATS_AUTH_TOKEN`: the token for authentication
15//! - `NATS_AUTH_NKEY`: the nkey for authentication
16//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
17//!
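//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together. If none of the
//! authentication variables above are set, the client falls back to username `user` and
//! password `user`.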
19use crate::metrics::MetricsHierarchy;
20use crate::protocols::EndpointId;
21
22use anyhow::Result;
23use async_nats::connection::State;
24use async_nats::{Subscriber, client, jetstream};
25use async_trait::async_trait;
26use bytes::Bytes;
27use derive_builder::Builder;
28use futures::{StreamExt, TryStreamExt};
29use prometheus::{Counter, Gauge, Histogram, HistogramOpts, IntCounter, IntGauge, Opts, Registry};
30use serde::de::DeserializeOwned;
31use serde::{Deserialize, Serialize};
32use std::path::{Path, PathBuf};
33use std::sync::atomic::Ordering;
34use tokio::fs::File as TokioFile;
35use tokio::io::AsyncRead;
36use tokio::time;
37use url::Url;
38use validator::{Validate, ValidationError};
39
40use crate::config::environment_names::nats as env_nats;
41pub use crate::slug::Slug;
42use tracing as log;
43
44use super::utils::build_in_runtime;
45
46pub const URL_PREFIX: &str = "nats://";
47
48#[derive(Clone)]
49pub struct Client {
50    client: client::Client,
51    js_ctx: jetstream::Context,
52}
53
54impl Client {
55    /// Create a NATS [`ClientOptionsBuilder`].
56    pub fn builder() -> ClientOptionsBuilder {
57        ClientOptionsBuilder::default()
58    }
59
60    /// Returns a reference to the underlying [`async_nats::client::Client`] instance
61    pub fn client(&self) -> &client::Client {
62        &self.client
63    }
64
65    /// Returns a reference to the underlying [`async_nats::jetstream::Context`] instance
66    pub fn jetstream(&self) -> &jetstream::Context {
67        &self.js_ctx
68    }
69
70    /// host:port of NATS
71    pub fn addr(&self) -> String {
72        let info = self.client.server_info();
73        format!("{}:{}", info.host, info.port)
74    }
75
76    /// fetch the list of streams
77    pub async fn list_streams(&self) -> Result<Vec<String>> {
78        let names = self.js_ctx.stream_names();
79        let stream_names: Vec<String> = names.try_collect().await?;
80        Ok(stream_names)
81    }
82
83    /// fetch the list of consumers for a given stream
84    pub async fn list_consumers(&self, stream_name: &str) -> Result<Vec<String>> {
85        let stream = self.js_ctx.get_stream(stream_name).await?;
86        let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
87        Ok(consumers)
88    }
89
90    pub async fn stream_info(&self, stream_name: &str) -> Result<jetstream::stream::State> {
91        let mut stream = self.js_ctx.get_stream(stream_name).await?;
92        let info = stream.info().await?;
93        Ok(info.state.clone())
94    }
95
96    pub async fn get_stream(&self, name: &str) -> Result<jetstream::stream::Stream> {
97        let stream = self.js_ctx.get_stream(name).await?;
98        Ok(stream)
99    }
100
101    /// Issues a broadcast request for all services with the provided `service_name` to report their
102    /// current stats. Each service will only respond once. The service may have customized the reply
103    /// so the caller should select which endpoint and what concrete data model should be used to
104    /// extract the details.
105    ///
106    /// Note: Because each endpoint will only reply once, the caller must drop the subscription after
107    /// some time or it will await forever.
108    pub async fn scrape_service(&self, service_name: &str) -> Result<Subscriber> {
109        let subject = format!("$SRV.STATS.{}", service_name);
110        let reply_subject = format!("_INBOX.{}", nuid::next());
111        let subscription = self.client.subscribe(reply_subject.clone()).await?;
112
113        // Publish the request with the reply-to subject
114        self.client
115            .publish_with_reply(subject, reply_subject, "".into())
116            .await?;
117
118        Ok(subscription)
119    }
120
121    /// Helper method to get or optionally create an object store bucket
122    ///
123    /// # Arguments
124    /// * `bucket_name` - The name of the bucket to retrieve
125    /// * `create_if_not_found` - If true, creates the bucket when it doesn't exist
126    ///
127    /// # Returns
128    /// The object store bucket or an error
129    async fn get_or_create_bucket(
130        &self,
131        bucket_name: &str,
132        create_if_not_found: bool,
133    ) -> anyhow::Result<jetstream::object_store::ObjectStore> {
134        let context = self.jetstream();
135
136        match context.get_object_store(bucket_name).await {
137            Ok(bucket) => Ok(bucket),
138            Err(err) if err.to_string().contains("stream not found") => {
139                // err.source() is GetStreamError, which has a kind() which
140                // is GetStreamErrorKind::JetStream which wraps a jetstream::Error
141                // which has code 404. Phew. So yeah check the string for now.
142
143                if create_if_not_found {
144                    tracing::debug!("Creating NATS bucket {bucket_name}");
145                    context
146                        .create_object_store(jetstream::object_store::Config {
147                            bucket: bucket_name.to_string(),
148                            ..Default::default()
149                        })
150                        .await
151                        .map_err(|e| anyhow::anyhow!("Failed creating bucket / object store: {e}"))
152                } else {
153                    anyhow::bail!(
154                        "NATS get_object_store bucket does not exist: {bucket_name}. {err}."
155                    );
156                }
157            }
158            Err(err) => {
159                anyhow::bail!("NATS get_object_store error: {err}");
160            }
161        }
162    }
163
164    /// Upload file to NATS at this URL
165    pub async fn object_store_upload(&self, filepath: &Path, nats_url: &Url) -> anyhow::Result<()> {
166        let mut disk_file = TokioFile::open(filepath).await?;
167
168        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
169        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
170
171        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
172            name: key.to_string(),
173            ..Default::default()
174        };
175        bucket.put(key_meta, &mut disk_file).await.map_err(|e| {
176            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
177        })?;
178
179        Ok(())
180    }
181
182    /// Download file from NATS at this URL
183    pub async fn object_store_download(
184        &self,
185        nats_url: &Url,
186        filepath: &Path,
187    ) -> anyhow::Result<()> {
188        let mut disk_file = TokioFile::create(filepath).await?;
189
190        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
191        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
192
193        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
194            anyhow::anyhow!(
195                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
196            )
197        })?;
198        let _bytes_copied = tokio::io::copy(&mut obj_reader, &mut disk_file).await?;
199
200        Ok(())
201    }
202
203    /// Delete a bucket and all it's contents from the NATS object store
204    pub async fn object_store_delete_bucket(&self, bucket_name: &str) -> anyhow::Result<()> {
205        let context = self.jetstream();
206        match context.delete_object_store(&bucket_name).await {
207            Ok(_) => Ok(()),
208            Err(err) if err.to_string().contains("stream not found") => {
209                tracing::trace!(bucket_name, "NATS bucket already gone");
210                Ok(())
211            }
212            Err(err) => Err(anyhow::anyhow!("NATS get_object_store error: {err}")),
213        }
214    }
215
216    /// Upload a serializable struct to NATS object store using bincode
217    pub async fn object_store_upload_data<T>(&self, data: &T, nats_url: &Url) -> anyhow::Result<()>
218    where
219        T: Serialize,
220    {
221        // Serialize the data using bincode (more efficient binary format)
222        let binary_data = bincode::serialize(data)
223            .map_err(|e| anyhow::anyhow!("Failed to serialize data with bincode: {e}"))?;
224
225        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
226        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
227
228        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
229            name: key.to_string(),
230            ..Default::default()
231        };
232
233        // Upload the serialized bytes
234        let mut cursor = std::io::Cursor::new(binary_data);
235        bucket.put(key_meta, &mut cursor).await.map_err(|e| {
236            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
237        })?;
238
239        Ok(())
240    }
241
242    /// Download and deserialize a struct from NATS object store using bincode
243    pub async fn object_store_download_data<T>(&self, nats_url: &Url) -> anyhow::Result<T>
244    where
245        T: DeserializeOwned,
246    {
247        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
248        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
249
250        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
251            anyhow::anyhow!(
252                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
253            )
254        })?;
255
256        // Read all bytes into memory
257        let mut buffer = Vec::new();
258        tokio::io::copy(&mut obj_reader, &mut buffer)
259            .await
260            .map_err(|e| anyhow::anyhow!("Failed reading object data: {e}"))?;
261        tracing::debug!("Downloaded {} bytes from {bucket_name}/{key}", buffer.len());
262
263        // Deserialize from bincode
264        let data = bincode::deserialize(&buffer)
265            .map_err(|e| anyhow::anyhow!("Failed to deserialize data with bincode: {e}"))?;
266
267        Ok(data)
268    }
269}
270
271/// NATS client options
272///
273/// This object uses the builder pattern with default values that are evaluates
274/// from the environment variables if they are not explicitly set by the builder.
275#[derive(Debug, Clone, Builder, Validate)]
276pub struct ClientOptions {
277    #[builder(setter(into), default = "default_server()")]
278    #[validate(custom(function = "validate_nats_server"))]
279    server: String,
280
281    #[builder(default)]
282    auth: NatsAuth,
283}
284
285fn default_server() -> String {
286    if let Ok(server) = std::env::var(env_nats::NATS_SERVER) {
287        return server;
288    }
289
290    "nats://localhost:4222".to_string()
291}
292
293fn validate_nats_server(server: &str) -> Result<(), ValidationError> {
294    if server.starts_with("nats://") {
295        Ok(())
296    } else {
297        Err(ValidationError::new("server must start with 'nats://'"))
298    }
299}
300
301// TODO(jthomson04): We really shouldn't be hardcoding this.
302const NATS_WORKER_THREADS: usize = 4;
303
304impl ClientOptions {
305    /// Create a new [`ClientOptionsBuilder`]
306    pub fn builder() -> ClientOptionsBuilder {
307        ClientOptionsBuilder::default()
308    }
309
310    /// Validate the config and attempt to connection to the NATS server
311    pub async fn connect(self) -> Result<Client> {
312        self.validate()?;
313
314        let client = match self.auth {
315            NatsAuth::UserPass(username, password) => {
316                async_nats::ConnectOptions::with_user_and_password(username, password)
317            }
318            NatsAuth::Token(token) => async_nats::ConnectOptions::with_token(token),
319            NatsAuth::NKey(nkey) => async_nats::ConnectOptions::with_nkey(nkey),
320            NatsAuth::CredentialsFile(path) => {
321                async_nats::ConnectOptions::with_credentials_file(path).await?
322            }
323        };
324
325        let (client, _) = build_in_runtime(
326            async move {
327                client
328                    .connect(self.server)
329                    .await
330                    .map_err(|e| anyhow::anyhow!("Failed to connect to NATS: {e}. Verify NATS server is running and accessible."))
331            },
332            NATS_WORKER_THREADS,
333        )
334        .await?;
335
336        let js_ctx = jetstream::new(client.clone());
337
338        Ok(Client { client, js_ctx })
339    }
340}
341
342impl Default for ClientOptions {
343    fn default() -> Self {
344        ClientOptions {
345            server: default_server(),
346            auth: NatsAuth::default(),
347        }
348    }
349}
350
351#[derive(Clone, Eq, PartialEq)]
352pub enum NatsAuth {
353    UserPass(String, String),
354    Token(String),
355    NKey(String),
356    CredentialsFile(PathBuf),
357}
358
359impl std::fmt::Debug for NatsAuth {
360    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361        match self {
362            NatsAuth::UserPass(user, _pass) => {
363                write!(f, "UserPass({}, <redacted>)", user)
364            }
365            NatsAuth::Token(_token) => write!(f, "Token(<redacted>)"),
366            NatsAuth::NKey(_nkey) => write!(f, "NKey(<redacted>)"),
367            NatsAuth::CredentialsFile(path) => write!(f, "CredentialsFile({:?})", path),
368        }
369    }
370}
371
372impl Default for NatsAuth {
373    fn default() -> Self {
374        if let (Ok(username), Ok(password)) = (
375            std::env::var(env_nats::auth::NATS_AUTH_USERNAME),
376            std::env::var(env_nats::auth::NATS_AUTH_PASSWORD),
377        ) {
378            return NatsAuth::UserPass(username, password);
379        }
380
381        if let Ok(token) = std::env::var(env_nats::auth::NATS_AUTH_TOKEN) {
382            return NatsAuth::Token(token);
383        }
384
385        if let Ok(nkey) = std::env::var(env_nats::auth::NATS_AUTH_NKEY) {
386            return NatsAuth::NKey(nkey);
387        }
388
389        if let Ok(path) = std::env::var(env_nats::auth::NATS_AUTH_CREDENTIALS_FILE) {
390            return NatsAuth::CredentialsFile(PathBuf::from(path));
391        }
392
393        NatsAuth::UserPass("user".to_string(), "user".to_string())
394    }
395}
396
397/// Extract NATS bucket and key from a nats URL of the form:
398/// nats://host[:port]/bucket/key
399pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
400    let Some(mut path_segments) = url.path_segments() else {
401        anyhow::bail!("No path in NATS URL: {url}");
402    };
403    let Some(bucket) = path_segments.next() else {
404        anyhow::bail!("No bucket in NATS URL: {url}");
405    };
406    let Some(key) = path_segments.next() else {
407        anyhow::bail!("No key in NATS URL: {url}");
408    };
409    Ok((bucket.to_string(), key.to_string()))
410}
411
412/// A queue implementation using NATS JetStream
413pub struct NatsQueue {
414    /// The name of the stream to use for the queue
415    stream_name: String,
416    /// The NATS server URL
417    nats_server: String,
418    /// Timeout for dequeue operations in seconds
419    dequeue_timeout: time::Duration,
420    /// The NATS client
421    client: Option<Client>,
422    /// The subject pattern used for this queue
423    subject: String,
424    /// The subscriber for pull-based consumption
425    subscriber: Option<jetstream::consumer::PullConsumer>,
426    /// Optional consumer name for broadcast pattern (if None, uses "worker-group")
427    consumer_name: Option<String>,
428    /// Message stream for efficient message consumption
429    message_stream: Option<jetstream::consumer::pull::Stream>,
430}
431
432impl NatsQueue {
433    /// Create a new NatsQueue with the default "worker-group" consumer
434    pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self {
435        // Sanitize stream name to remove path separators (like in Python version)
436        // rupei: are we sure NATs stream name accepts '_'?
437        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
438        let subject = format!("{sanitized_stream_name}.*");
439
440        Self {
441            stream_name: sanitized_stream_name,
442            nats_server,
443            dequeue_timeout,
444            client: None,
445            subject,
446            subscriber: None,
447            consumer_name: Some("worker-group".to_string()),
448            message_stream: None,
449        }
450    }
451
452    /// Create a new NatsQueue without a consumer (publisher-only mode)
453    pub fn new_without_consumer(
454        stream_name: String,
455        nats_server: String,
456        dequeue_timeout: time::Duration,
457    ) -> Self {
458        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
459        let subject = format!("{sanitized_stream_name}.*");
460
461        Self {
462            stream_name: sanitized_stream_name,
463            nats_server,
464            dequeue_timeout,
465            client: None,
466            subject,
467            subscriber: None,
468            consumer_name: None,
469            message_stream: None,
470        }
471    }
472
473    /// Create a new NatsQueue with a specific consumer name for broadcast pattern
474    /// Each consumer with a unique name will receive all messages independently
475    pub fn new_with_consumer(
476        stream_name: String,
477        nats_server: String,
478        dequeue_timeout: time::Duration,
479        consumer_name: String,
480    ) -> Self {
481        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
482        let subject = format!("{sanitized_stream_name}.*");
483
484        Self {
485            stream_name: sanitized_stream_name,
486            nats_server,
487            dequeue_timeout,
488            client: None,
489            subject,
490            subscriber: None,
491            consumer_name: Some(consumer_name),
492            message_stream: None,
493        }
494    }
495
496    /// Connect to the NATS server and set up the stream and consumer
497    pub async fn connect(&mut self) -> Result<()> {
498        self.connect_with_reset(false).await
499    }
500
501    /// Connect to the NATS server and set up the stream and consumer, optionally resetting the stream
502    pub async fn connect_with_reset(&mut self, reset_stream: bool) -> Result<()> {
503        if self.client.is_none() {
504            // Create a new client
505            let client_options = Client::builder().server(self.nats_server.clone()).build()?;
506
507            let client = client_options.connect().await?;
508
509            // messages older than a hour in the stream will be automatically purged
510            let max_age = std::env::var(env_nats::stream::DYN_NATS_STREAM_MAX_AGE)
511                .ok()
512                .and_then(|s| s.parse::<u64>().ok())
513                .map(time::Duration::from_secs)
514                .unwrap_or_else(|| time::Duration::from_secs(60 * 60));
515
516            let stream_config = jetstream::stream::Config {
517                name: self.stream_name.clone(),
518                subjects: vec![self.subject.clone()],
519                max_age,
520                ..Default::default()
521            };
522
523            // Get or create the stream
524            let stream = client
525                .jetstream()
526                .get_or_create_stream(stream_config)
527                .await?;
528
529            log::debug!("Stream {} is ready", self.stream_name);
530
531            // If reset_stream is true, purge all messages from the stream
532            if reset_stream {
533                match stream.purge().await {
534                    Ok(purge_info) => {
535                        log::info!(
536                            "Successfully purged {} messages from NATS stream {}",
537                            purge_info.purged,
538                            self.stream_name
539                        );
540                    }
541                    Err(e) => {
542                        log::warn!("Failed to purge NATS stream '{}': {e}", self.stream_name);
543                    }
544                }
545            }
546
547            // Create persistent subscriber only if consumer_name is set
548            if let Some(ref consumer_name) = self.consumer_name {
549                let consumer_config = jetstream::consumer::pull::Config {
550                    durable_name: Some(consumer_name.clone()),
551                    inactive_threshold: std::time::Duration::from_secs(300), // 5 minutes
552                    ..Default::default()
553                };
554
555                let subscriber = stream.create_consumer(consumer_config).await?;
556
557                // Create the message stream for efficient consumption
558                let message_stream = subscriber.messages().await?;
559
560                self.subscriber = Some(subscriber);
561                self.message_stream = Some(message_stream);
562            }
563
564            self.client = Some(client);
565        }
566
567        Ok(())
568    }
569
570    /// Ensure we have an active connection
571    pub async fn ensure_connection(&mut self) -> Result<()> {
572        if self.client.is_none() {
573            self.connect().await?;
574        }
575        Ok(())
576    }
577
578    /// Close the connection when done
579    pub async fn close(&mut self) -> Result<()> {
580        self.message_stream = None;
581        self.subscriber = None;
582        self.client = None;
583        Ok(())
584    }
585
586    /// Shutdown the consumer by deleting it from the stream and closing the connection
587    /// This permanently removes the consumer from the server
588    ///
589    /// If `consumer_name` is provided, that specific consumer will be deleted instead of the
590    /// current consumer. This allows deletion of other consumers on the same stream.
591    pub async fn shutdown(&mut self, consumer_name: Option<String>) -> Result<()> {
592        // Determine which consumer to delete
593        let target_consumer = consumer_name.as_ref().or(self.consumer_name.as_ref());
594
595        // Warn if deleting our own consumer via explicit parameter
596        if let Some(ref passed_name) = consumer_name
597            && self.consumer_name.as_ref() == Some(passed_name)
598        {
599            log::warn!(
600                "Deleting our own consumer '{}' via explicit consumer_name parameter. \
601                Consider calling shutdown without arguments instead.",
602                passed_name
603            );
604        }
605
606        if let (Some(client), Some(consumer_to_delete)) = (&self.client, target_consumer) {
607            // Get the stream and delete the consumer
608            let stream = client.jetstream().get_stream(&self.stream_name).await?;
609            stream
610                .delete_consumer(consumer_to_delete)
611                .await
612                .map_err(|e| {
613                    anyhow::anyhow!("Failed to delete consumer {}: {}", consumer_to_delete, e)
614                })?;
615            log::debug!(
616                "Deleted consumer {} from stream {}",
617                consumer_to_delete,
618                self.stream_name
619            );
620        } else {
621            log::debug!(
622                "Cannot shutdown consumer: client or target consumer is None (client: {:?}, target_consumer: {:?})",
623                self.client.is_some(),
624                target_consumer.is_some()
625            );
626        }
627
628        // Only close the connection if we deleted our own consumer
629        if consumer_name.is_none() {
630            self.close().await
631        } else {
632            Ok(())
633        }
634    }
635
636    /// Count the number of consumers for the stream
637    pub async fn count_consumers(&mut self) -> Result<usize> {
638        self.ensure_connection().await?;
639
640        if let Some(client) = &self.client {
641            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
642            let info = stream.info().await?;
643            Ok(info.state.consumer_count)
644        } else {
645            Err(anyhow::anyhow!("Client not connected"))
646        }
647    }
648
649    /// List all consumer names for the stream
650    pub async fn list_consumers(&mut self) -> Result<Vec<String>> {
651        self.ensure_connection().await?;
652
653        if let Some(client) = &self.client {
654            client.list_consumers(&self.stream_name).await
655        } else {
656            Err(anyhow::anyhow!("Client not connected"))
657        }
658    }
659
660    /// Enqueue a task using the provided data
661    pub async fn enqueue_task(&mut self, task_data: Bytes) -> Result<()> {
662        self.ensure_connection().await?;
663
664        if let Some(client) = &self.client {
665            let subject = format!("{}.queue", self.stream_name);
666            client.jetstream().publish(subject, task_data).await?;
667            Ok(())
668        } else {
669            Err(anyhow::anyhow!("Client not connected"))
670        }
671    }
672
673    /// Dequeue and return a task as raw bytes
674    pub async fn dequeue_task(&mut self, timeout: Option<time::Duration>) -> Result<Option<Bytes>> {
675        self.ensure_connection().await?;
676
677        let Some(ref mut stream) = self.message_stream else {
678            return Err(anyhow::anyhow!("Message stream not initialized"));
679        };
680
681        let timeout_duration = timeout.unwrap_or(self.dequeue_timeout);
682
683        // Try to get next message from the stream with timeout
684        let message = tokio::time::timeout(timeout_duration, stream.next()).await;
685
686        match message {
687            Ok(Some(Ok(msg))) => {
688                msg.ack()
689                    .await
690                    .map_err(|e| anyhow::anyhow!("Failed to ack message: {}", e))?;
691                Ok(Some(msg.payload.clone()))
692            }
693
694            Ok(Some(Err(e))) => Err(anyhow::anyhow!("Failed to get message from stream: {}", e)),
695
696            Ok(None) => Err(anyhow::anyhow!("Message stream ended unexpectedly")),
697
698            // Timeout - no messages available
699            Err(_) => Ok(None),
700        }
701    }
702
703    /// Get the number of messages currently in the queue
704    pub async fn get_queue_size(&mut self) -> Result<u64> {
705        self.ensure_connection().await?;
706
707        if let Some(client) = &self.client {
708            // Get consumer info to get pending messages count
709            let stream = client.jetstream().get_stream(&self.stream_name).await?;
710            let consumer_name = self
711                .consumer_name
712                .clone()
713                .unwrap_or_else(|| "worker-group".to_string());
714            let mut consumer: jetstream::consumer::PullConsumer = stream
715                .get_consumer(&consumer_name)
716                .await
717                .map_err(|e| anyhow::anyhow!("Failed to get consumer: {}", e))?;
718            let info = consumer.info().await?;
719
720            Ok(info.num_pending)
721        } else {
722            Err(anyhow::anyhow!("Client not connected"))
723        }
724    }
725
726    /// Get the total number of messages currently in the stream
727    pub async fn get_stream_messages(&mut self) -> Result<u64> {
728        self.ensure_connection().await?;
729
730        if let Some(client) = &self.client {
731            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
732            let info = stream.info().await?;
733            Ok(info.state.messages)
734        } else {
735            Err(anyhow::anyhow!("Client not connected"))
736        }
737    }
738
739    /// Purge messages from the stream up to (but not including) the specified sequence number
740    /// This permanently removes messages and affects all consumers of the stream
741    pub async fn purge_up_to_sequence(&self, sequence: u64) -> Result<()> {
742        if let Some(client) = &self.client {
743            let stream = client.jetstream().get_stream(&self.stream_name).await?;
744
745            // NOTE: this purge excludes the sequence itself
746            // https://docs.rs/nats/latest/nats/jetstream/struct.PurgeRequest.html
747            stream.purge().sequence(sequence).await.map_err(|e| {
748                anyhow::anyhow!("Failed to purge stream up to sequence {}: {}", sequence, e)
749            })?;
750
751            log::debug!(
752                "Purged stream {} up to sequence {}",
753                self.stream_name,
754                sequence
755            );
756            Ok(())
757        } else {
758            Err(anyhow::anyhow!("Client not connected"))
759        }
760    }
761
762    /// Purge messages from the stream up to the minimum acknowledged sequence across all consumers
763    /// This finds the lowest acknowledged sequence number across all consumers and purges up to that point
764    pub async fn purge_acknowledged(&mut self) -> Result<()> {
765        self.ensure_connection().await?;
766
767        let Some(client) = &self.client else {
768            return Err(anyhow::anyhow!("Client not connected"));
769        };
770
771        let stream = client.jetstream().get_stream(&self.stream_name).await?;
772
773        // Get all consumer names for the stream
774        let consumer_names: Vec<String> = stream
775            .consumer_names()
776            .try_collect()
777            .await
778            .map_err(|e| anyhow::anyhow!("Failed to list consumers: {}", e))?;
779
780        if consumer_names.is_empty() {
781            log::debug!("No consumers found for stream {}", self.stream_name);
782            return Ok(());
783        }
784
785        // Find the minimum acknowledged sequence across all consumers
786        let mut min_ack_sequence = u64::MAX;
787
788        for consumer_name in &consumer_names {
789            let mut consumer: jetstream::consumer::PullConsumer = stream
790                .get_consumer(consumer_name)
791                .await
792                .map_err(|e| anyhow::anyhow!("Failed to get consumer {}: {}", consumer_name, e))?;
793
794            let info = consumer.info().await.map_err(|e| {
795                anyhow::anyhow!("Failed to get consumer info for {}: {}", consumer_name, e)
796            })?;
797
798            // The ack_floor contains the stream sequence of the highest contiguously acknowledged message
799            // If stream_sequence is 0, it means no messages have been acknowledged yet
800            if info.ack_floor.stream_sequence > 0 {
801                min_ack_sequence = min_ack_sequence.min(info.ack_floor.stream_sequence);
802                log::debug!(
803                    "Consumer {} has ack_floor at sequence {}",
804                    consumer_name,
805                    info.ack_floor.stream_sequence
806                );
807            }
808        }
809
810        // Only purge if we found a valid minimum acknowledged sequence
811        if min_ack_sequence < u64::MAX && min_ack_sequence > 0 {
812            // Purge up to (but not including) the minimum acknowledged sequence + 1
813            // We add 1 because we want to include the minimum acknowledged message in the purge
814            let purge_sequence = min_ack_sequence + 1;
815
816            self.purge_up_to_sequence(purge_sequence).await?;
817
818            log::debug!(
819                "Purged stream {} up to acknowledged sequence {} (purged up to sequence {})",
820                self.stream_name,
821                min_ack_sequence,
822                purge_sequence
823            );
824        } else {
825            log::debug!(
826                "No messages to purge for stream {} (min_ack_sequence: {})",
827                self.stream_name,
828                min_ack_sequence
829            );
830        }
831
832        Ok(())
833    }
834}
835
836impl NatsQueue {
837    pub fn event_subject(&self) -> String {
838        self.stream_name.clone()
839    }
840
841    pub async fn publish_event(
842        &self,
843        event_name: impl AsRef<str> + Send + Sync,
844        event: &(impl Serialize + Send + Sync),
845    ) -> Result<()> {
846        let bytes = serde_json::to_vec(event)?;
847        self.publish_event_bytes(event_name, bytes).await
848    }
849
850    pub async fn publish_event_bytes(
851        &self,
852        event_name: impl AsRef<str> + Send + Sync,
853        bytes: Vec<u8>,
854    ) -> Result<()> {
855        let subject = format!("{}.{}", self.event_subject(), event_name.as_ref());
856
857        // Note: enqueue_task requires &mut self, but EventPublisher requires &self
858        // We need to ensure the client is connected and use it directly
859        if let Some(client) = &self.client {
860            client.jetstream().publish(subject, bytes.into()).await?;
861            Ok(())
862        } else {
863            Err(anyhow::anyhow!("Client not connected"))
864        }
865    }
866}
867
868/// The NATS subject / inbox to talk to an instance on.
869/// TODO: Do we need to sanitize the names?
870pub fn instance_subject(endpoint_id: &EndpointId, instance_id: u64) -> String {
871    format!(
872        "{}_{}.{}-{:x}",
873        endpoint_id.namespace, endpoint_id.component, endpoint_id.name, instance_id,
874    )
875}
876
877#[cfg(test)]
878mod tests {
879
880    use super::*;
881    use figment::Jail;
882    use serde::{Deserialize, Serialize};
883
884    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
885    struct TestData {
886        id: u32,
887        name: String,
888        values: Vec<f64>,
889    }
890
891    #[test]
892    fn test_client_options_builder() {
893        Jail::expect_with(|_jail| {
894            let opts = ClientOptions::builder().build();
895            assert!(opts.is_ok());
896            Ok(())
897        });
898
899        Jail::expect_with(|jail| {
900            jail.set_env(env_nats::NATS_SERVER, "nats://localhost:5222");
901            jail.set_env(env_nats::auth::NATS_AUTH_USERNAME, "user");
902            jail.set_env(env_nats::auth::NATS_AUTH_PASSWORD, "pass");
903
904            let opts = ClientOptions::builder().build();
905            assert!(opts.is_ok());
906            let opts = opts.unwrap();
907
908            assert_eq!(opts.server, "nats://localhost:5222");
909            assert_eq!(
910                opts.auth,
911                NatsAuth::UserPass("user".to_string(), "pass".to_string())
912            );
913
914            Ok(())
915        });
916
917        Jail::expect_with(|jail| {
918            jail.set_env(env_nats::NATS_SERVER, "nats://localhost:5222");
919            jail.set_env(env_nats::auth::NATS_AUTH_USERNAME, "user");
920            jail.set_env(env_nats::auth::NATS_AUTH_PASSWORD, "pass");
921
922            let opts = ClientOptions::builder()
923                .server("nats://localhost:6222")
924                .auth(NatsAuth::Token("token".to_string()))
925                .build();
926            assert!(opts.is_ok());
927            let opts = opts.unwrap();
928
929            assert_eq!(opts.server, "nats://localhost:6222");
930            assert_eq!(opts.auth, NatsAuth::Token("token".to_string()));
931
932            Ok(())
933        });
934    }
935
936    // Integration test for object store data operations using bincode
937    #[tokio::test]
938    #[ignore] // Requires NATS server to be running
939    async fn test_object_store_data_operations() {
940        // Create test data
941        let test_data = TestData {
942            id: 42,
943            name: "test_item".to_string(),
944            values: vec![1.0, 2.5, 3.7, 4.2],
945        };
946
947        // Set up client
948        let client_options = ClientOptions::builder()
949            .server("nats://localhost:4222")
950            .build()
951            .expect("Failed to build client options");
952
953        let client = client_options
954            .connect()
955            .await
956            .expect("Failed to connect to NATS");
957
958        // Test URL (using .bin extension to indicate binary format)
959        let url =
960            Url::parse("nats://localhost/test-bucket/test-data.bin").expect("Failed to parse URL");
961
962        // Upload the data
963        client
964            .object_store_upload_data(&test_data, &url)
965            .await
966            .expect("Failed to upload data");
967
968        // Download the data
969        let downloaded_data: TestData = client
970            .object_store_download_data(&url)
971            .await
972            .expect("Failed to download data");
973
974        // Verify the data matches
975        assert_eq!(test_data, downloaded_data);
976
977        // Clean up
978        client
979            .object_store_delete_bucket("test-bucket")
980            .await
981            .expect("Failed to delete bucket");
982    }
983
984    // Integration test for broadcast pattern with purging
985    #[tokio::test]
986    #[ignore]
987    async fn test_nats_queue_broadcast_with_purge() {
988        use uuid::Uuid;
989
990        // Create unique stream name for this test
991        let stream_name = format!("test-broadcast-{}", Uuid::new_v4());
992        let nats_server = "nats://localhost:4222".to_string();
993        let timeout = time::Duration::from_secs(0);
994
995        // Connect to NATS client first to delete stream if it exists
996        let client_options = Client::builder()
997            .server(nats_server.clone())
998            .build()
999            .expect("Failed to build client options");
1000
1001        let client = client_options
1002            .connect()
1003            .await
1004            .expect("Failed to connect to NATS");
1005
1006        // Delete the stream if it exists (to ensure clean start)
1007        let _ = client.jetstream().delete_stream(&stream_name).await;
1008
1009        // Create two consumers with different names for the same stream
1010        let consumer1_name = format!("consumer-{}", Uuid::new_v4());
1011        let consumer2_name = format!("consumer-{}", Uuid::new_v4());
1012
1013        let mut queue1 = NatsQueue::new_with_consumer(
1014            stream_name.clone(),
1015            nats_server.clone(),
1016            timeout,
1017            consumer1_name,
1018        );
1019
1020        // Connect queue1 first (it will create the stream)
1021        queue1.connect().await.expect("Failed to connect queue1");
1022
1023        // Send 4 messages using the EventPublisher trait
1024        let message_strings = [
1025            "message1".to_string(),
1026            "message2".to_string(),
1027            "message3".to_string(),
1028            "message4".to_string(),
1029        ];
1030
1031        // Publish messages using NatsQueue
1032        for (idx, msg) in message_strings.iter().enumerate() {
1033            queue1
1034                .publish_event("queue", msg)
1035                .await
1036                .unwrap_or_else(|_| panic!("Failed to publish message {}", idx + 1));
1037        }
1038
1039        // Convert messages to JSON-serialized Bytes for comparison
1040        let messages: Vec<Bytes> = message_strings
1041            .iter()
1042            .map(|s| Bytes::from(serde_json::to_vec(s).unwrap()))
1043            .collect();
1044
1045        // Give JetStream a moment to persist the messages
1046        tokio::time::sleep(time::Duration::from_millis(100)).await;
1047
1048        // Now create and connect queue2 and queue3 AFTER messages are published (to test persistence)
1049        let mut queue2 = NatsQueue::new_with_consumer(
1050            stream_name.clone(),
1051            nats_server.clone(),
1052            timeout,
1053            consumer2_name,
1054        );
1055
1056        // Create a third queue without consumer (publisher-only)
1057        let mut queue3 =
1058            NatsQueue::new_without_consumer(stream_name.clone(), nats_server.clone(), timeout);
1059
1060        // Connect queue2 and queue3 after messages are already published
1061        queue2.connect().await.expect("Failed to connect queue2");
1062        queue3.connect().await.expect("Failed to connect queue3");
1063
1064        // Purge the first two messages (sequence 1 and 2)
1065        // Note: JetStream sequences start at 1, and purge is exclusive of the sequence number
1066        queue1
1067            .purge_up_to_sequence(3)
1068            .await
1069            .expect("Failed to purge messages");
1070
1071        // Give JetStream a moment to process the purge
1072        tokio::time::sleep(time::Duration::from_millis(100)).await;
1073
1074        // Consumer 1 dequeues one message (message3)
1075        let msg3_consumer1 = queue1
1076            .dequeue_task(Some(time::Duration::from_millis(500)))
1077            .await
1078            .expect("Failed to dequeue from queue1");
1079        assert_eq!(
1080            msg3_consumer1,
1081            Some(messages[2].clone()),
1082            "Consumer 1 should get message3"
1083        );
1084
1085        // Give JetStream a moment to process acknowledgments
1086        tokio::time::sleep(time::Duration::from_millis(100)).await;
1087
1088        // Now run purge_acknowledged
1089        // At this point:
1090        // - Consumer 1 has ack'd message 3 (ack_floor = 3)
1091        // - Consumer 2 hasn't consumed anything yet (ack_floor = 0)
1092        // - Min ack_floor = 0, so nothing will be purged
1093        queue1
1094            .purge_acknowledged()
1095            .await
1096            .expect("Failed to purge acknowledged messages");
1097
1098        // Give JetStream a moment to process the purge
1099        tokio::time::sleep(time::Duration::from_millis(100)).await;
1100
1101        // Now collect remaining messages from both consumers
1102        let mut consumer1_remaining = Vec::new();
1103        let mut consumer2_remaining = Vec::new();
1104
1105        // Collect remaining messages from consumer 1
1106        while let Some(msg) = queue1
1107            .dequeue_task(None)
1108            .await
1109            .expect("Failed to dequeue from queue1")
1110        {
1111            consumer1_remaining.push(msg);
1112        }
1113
1114        // Collect remaining messages from consumer 2
1115        while let Some(msg) = queue2
1116            .dequeue_task(None)
1117            .await
1118            .expect("Failed to dequeue from queue2")
1119        {
1120            consumer2_remaining.push(msg);
1121        }
1122
1123        // Verify consumer 1 gets 1 remaining message (message4)
1124        assert_eq!(
1125            consumer1_remaining.len(),
1126            1,
1127            "Consumer 1 should have 1 remaining message"
1128        );
1129        assert_eq!(
1130            consumer1_remaining[0], messages[3],
1131            "Consumer 1 should get message4"
1132        );
1133
1134        // Verify consumer 2 gets 2 messages (message3 and message4)
1135        assert_eq!(
1136            consumer2_remaining.len(),
1137            2,
1138            "Consumer 2 should have 2 messages"
1139        );
1140        assert_eq!(
1141            consumer2_remaining[0], messages[2],
1142            "Consumer 2 should get message3"
1143        );
1144        assert_eq!(
1145            consumer2_remaining[1], messages[3],
1146            "Consumer 2 should get message4"
1147        );
1148
1149        // Test consumer count and shutdown behavior
1150        // First verify via consumer 1 that there are two consumers
1151        let consumer_count = queue1
1152            .count_consumers()
1153            .await
1154            .expect("Failed to count consumers");
1155        assert_eq!(consumer_count, 2, "Should have 2 consumers initially");
1156
1157        // Close consumer 1 and verify via consumer 2 that there are still two consumers
1158        queue1.close().await.expect("Failed to close queue1");
1159
1160        let consumer_count = queue2
1161            .count_consumers()
1162            .await
1163            .expect("Failed to count consumers");
1164        assert_eq!(
1165            consumer_count, 2,
1166            "Should still have 2 consumers after closing queue1"
1167        );
1168
1169        // Reconnect queue1 to be able to shutdown
1170        queue1.connect().await.expect("Failed to reconnect queue1");
1171
1172        // Shutdown consumer 1 and verify via consumer 2 that there is only one consumer left
1173        queue1
1174            .shutdown(None)
1175            .await
1176            .expect("Failed to shutdown queue1");
1177
1178        let consumer_count = queue2
1179            .count_consumers()
1180            .await
1181            .expect("Failed to count consumers");
1182        assert_eq!(
1183            consumer_count, 1,
1184            "Should have only 1 consumer after shutting down queue1"
1185        );
1186
1187        // Clean up by deleting the stream
1188        client
1189            .jetstream()
1190            .delete_stream(&stream_name)
1191            .await
1192            .expect("Failed to delete test stream");
1193    }
1194}