dynamo_runtime/transports/
nats.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! NATS transport
5//!
6//! The following environment variables are used to configure the NATS client:
7//!
8//! - `NATS_SERVER`: the NATS server address
9//!
10//! For authentication, the following environment variables are used and prioritized in the following order:
11//!
12//! - `NATS_AUTH_USERNAME`: the username for authentication
13//! - `NATS_AUTH_PASSWORD`: the password for authentication
14//! - `NATS_AUTH_TOKEN`: the token for authentication
15//! - `NATS_AUTH_NKEY`: the nkey for authentication
16//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
17//!
18//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together.
19use crate::traits::events::EventPublisher;
20use crate::{Result, metrics::MetricsRegistry};
21
22use async_nats::connection::State;
23use async_nats::{Subscriber, client, jetstream};
24use async_trait::async_trait;
25use bytes::Bytes;
26use derive_builder::Builder;
27use futures::{StreamExt, TryStreamExt};
28use prometheus::{Counter, Gauge, Histogram, HistogramOpts, IntCounter, IntGauge, Opts, Registry};
29use serde::de::DeserializeOwned;
30use serde::{Deserialize, Serialize};
31use std::path::{Path, PathBuf};
32use std::sync::atomic::Ordering;
33use tokio::fs::File as TokioFile;
34use tokio::io::AsyncRead;
35use tokio::time;
36use url::Url;
37use validator::{Validate, ValidationError};
38
39use crate::metrics::prometheus_names::nats_client as nats_metrics;
40pub use crate::slug::Slug;
41use tracing as log;
42
43use super::utils::build_in_runtime;
44
/// URL scheme prefix that marks a string as a NATS URL; this is the prefix
/// tested by [`is_nats_url`].
pub const URL_PREFIX: &str = "nats://";
46
/// A connected NATS handle bundling the core client with a JetStream context.
///
/// Instances are produced by [`ClientOptions::connect`].
#[derive(Clone)]
pub struct Client {
    /// The underlying core NATS client, exposed through [`Client::client`].
    client: client::Client,
    /// JetStream context built from `client`, exposed through [`Client::jetstream`].
    js_ctx: jetstream::Context,
}
52
53impl Client {
54    /// Create a NATS [`ClientOptionsBuilder`].
55    pub fn builder() -> ClientOptionsBuilder {
56        ClientOptionsBuilder::default()
57    }
58
59    /// Returns a reference to the underlying [`async_nats::client::Client`] instance
60    pub fn client(&self) -> &client::Client {
61        &self.client
62    }
63
64    /// Returns a reference to the underlying [`async_nats::jetstream::Context`] instance
65    pub fn jetstream(&self) -> &jetstream::Context {
66        &self.js_ctx
67    }
68
69    /// host:port of NATS
70    pub fn addr(&self) -> String {
71        let info = self.client.server_info();
72        format!("{}:{}", info.host, info.port)
73    }
74
75    /// fetch the list of streams
76    pub async fn list_streams(&self) -> Result<Vec<String>> {
77        let names = self.js_ctx.stream_names();
78        let stream_names: Vec<String> = names.try_collect().await?;
79        Ok(stream_names)
80    }
81
82    /// fetch the list of consumers for a given stream
83    pub async fn list_consumers(&self, stream_name: &str) -> Result<Vec<String>> {
84        let stream = self.js_ctx.get_stream(stream_name).await?;
85        let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
86        Ok(consumers)
87    }
88
89    pub async fn stream_info(&self, stream_name: &str) -> Result<jetstream::stream::State> {
90        let mut stream = self.js_ctx.get_stream(stream_name).await?;
91        let info = stream.info().await?;
92        Ok(info.state.clone())
93    }
94
95    pub async fn get_stream(&self, name: &str) -> Result<jetstream::stream::Stream> {
96        let stream = self.js_ctx.get_stream(name).await?;
97        Ok(stream)
98    }
99
100    /// Issues a broadcast request for all services with the provided `service_name` to report their
101    /// current stats. Each service will only respond once. The service may have customized the reply
102    /// so the caller should select which endpoint and what concrete data model should be used to
103    /// extract the details.
104    ///
105    /// Note: Because each endpoint will only reply once, the caller must drop the subscription after
106    /// some time or it will await forever.
107    pub async fn scrape_service(&self, service_name: &str) -> Result<Subscriber> {
108        let subject = format!("$SRV.STATS.{}", service_name);
109        let reply_subject = format!("_INBOX.{}", nuid::next());
110        let subscription = self.client.subscribe(reply_subject.clone()).await?;
111
112        // Publish the request with the reply-to subject
113        self.client
114            .publish_with_reply(subject, reply_subject, "".into())
115            .await?;
116
117        Ok(subscription)
118    }
119
120    /// Helper method to get or optionally create an object store bucket
121    ///
122    /// # Arguments
123    /// * `bucket_name` - The name of the bucket to retrieve
124    /// * `create_if_not_found` - If true, creates the bucket when it doesn't exist
125    ///
126    /// # Returns
127    /// The object store bucket or an error
128    async fn get_or_create_bucket(
129        &self,
130        bucket_name: &str,
131        create_if_not_found: bool,
132    ) -> anyhow::Result<jetstream::object_store::ObjectStore> {
133        let context = self.jetstream();
134
135        match context.get_object_store(bucket_name).await {
136            Ok(bucket) => Ok(bucket),
137            Err(err) if err.to_string().contains("stream not found") => {
138                // err.source() is GetStreamError, which has a kind() which
139                // is GetStreamErrorKind::JetStream which wraps a jetstream::Error
140                // which has code 404. Phew. So yeah check the string for now.
141
142                if create_if_not_found {
143                    tracing::debug!("Creating NATS bucket {bucket_name}");
144                    context
145                        .create_object_store(jetstream::object_store::Config {
146                            bucket: bucket_name.to_string(),
147                            ..Default::default()
148                        })
149                        .await
150                        .map_err(|e| anyhow::anyhow!("Failed creating bucket / object store: {e}"))
151                } else {
152                    anyhow::bail!(
153                        "NATS get_object_store bucket does not exist: {bucket_name}. {err}."
154                    );
155                }
156            }
157            Err(err) => {
158                anyhow::bail!("NATS get_object_store error: {err}");
159            }
160        }
161    }
162
163    /// Upload file to NATS at this URL
164    pub async fn object_store_upload(&self, filepath: &Path, nats_url: &Url) -> anyhow::Result<()> {
165        let mut disk_file = TokioFile::open(filepath).await?;
166
167        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
168        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
169
170        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
171            name: key.to_string(),
172            ..Default::default()
173        };
174        bucket.put(key_meta, &mut disk_file).await.map_err(|e| {
175            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
176        })?;
177
178        Ok(())
179    }
180
181    /// Download file from NATS at this URL
182    pub async fn object_store_download(
183        &self,
184        nats_url: &Url,
185        filepath: &Path,
186    ) -> anyhow::Result<()> {
187        let mut disk_file = TokioFile::create(filepath).await?;
188
189        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
190        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
191
192        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
193            anyhow::anyhow!(
194                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
195            )
196        })?;
197        let _bytes_copied = tokio::io::copy(&mut obj_reader, &mut disk_file).await?;
198
199        Ok(())
200    }
201
202    /// Delete a bucket and all it's contents from the NATS object store
203    pub async fn object_store_delete_bucket(&self, bucket_name: &str) -> anyhow::Result<()> {
204        let context = self.jetstream();
205        match context.delete_object_store(&bucket_name).await {
206            Ok(_) => Ok(()),
207            Err(err) if err.to_string().contains("stream not found") => {
208                tracing::trace!(bucket_name, "NATS bucket already gone");
209                Ok(())
210            }
211            Err(err) => Err(anyhow::anyhow!("NATS get_object_store error: {err}")),
212        }
213    }
214
215    /// Upload a serializable struct to NATS object store using bincode
216    pub async fn object_store_upload_data<T>(&self, data: &T, nats_url: &Url) -> anyhow::Result<()>
217    where
218        T: Serialize,
219    {
220        // Serialize the data using bincode (more efficient binary format)
221        let binary_data = bincode::serialize(data)
222            .map_err(|e| anyhow::anyhow!("Failed to serialize data with bincode: {e}"))?;
223
224        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
225        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
226
227        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
228            name: key.to_string(),
229            ..Default::default()
230        };
231
232        // Upload the serialized bytes
233        let mut cursor = std::io::Cursor::new(binary_data);
234        bucket.put(key_meta, &mut cursor).await.map_err(|e| {
235            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
236        })?;
237
238        Ok(())
239    }
240
241    /// Download and deserialize a struct from NATS object store using bincode
242    pub async fn object_store_download_data<T>(&self, nats_url: &Url) -> anyhow::Result<T>
243    where
244        T: DeserializeOwned,
245    {
246        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
247        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
248
249        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
250            anyhow::anyhow!(
251                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
252            )
253        })?;
254
255        // Read all bytes into memory
256        let mut buffer = Vec::new();
257        tokio::io::copy(&mut obj_reader, &mut buffer)
258            .await
259            .map_err(|e| anyhow::anyhow!("Failed reading object data: {e}"))?;
260        tracing::debug!("Downloaded {} bytes from {bucket_name}/{key}", buffer.len());
261
262        // Deserialize from bincode
263        let data = bincode::deserialize(&buffer)
264            .map_err(|e| anyhow::anyhow!("Failed to deserialize data with bincode: {e}"))?;
265
266        Ok(data)
267    }
268}
269
/// NATS client options
///
/// This object uses the builder pattern with default values that are evaluated
/// from the environment variables if they are not explicitly set by the builder.
#[derive(Debug, Clone, Builder, Validate)]
pub struct ClientOptions {
    /// NATS server URL. Defaults to `default_server()` and must pass
    /// `validate_nats_server` (i.e. start with `nats://`).
    #[builder(setter(into), default = "default_server()")]
    #[validate(custom(function = "validate_nats_server"))]
    server: String,

    /// Authentication method. Defaults to [`NatsAuth::default`], which reads
    /// the `NATS_AUTH_*` environment variables.
    #[builder(default)]
    auth: NatsAuth,
}
283
/// Default NATS server URL: the value of the `NATS_SERVER` environment
/// variable when it can be read, otherwise `nats://localhost:4222`.
fn default_server() -> String {
    match std::env::var("NATS_SERVER") {
        Ok(server) => server,
        Err(_) => String::from("nats://localhost:4222"),
    }
}
291
292fn validate_nats_server(server: &str) -> Result<(), ValidationError> {
293    if server.starts_with("nats://") {
294        Ok(())
295    } else {
296        Err(ValidationError::new("server must start with 'nats://'"))
297    }
298}
299
// TODO(jthomson04): We really shouldn't be hardcoding this.
/// Thread count handed to `build_in_runtime` when [`ClientOptions::connect`]
/// establishes the NATS connection.
const NATS_WORKER_THREADS: usize = 4;
302
303impl ClientOptions {
304    /// Create a new [`ClientOptionsBuilder`]
305    pub fn builder() -> ClientOptionsBuilder {
306        ClientOptionsBuilder::default()
307    }
308
309    /// Validate the config and attempt to connection to the NATS server
310    pub async fn connect(self) -> Result<Client> {
311        self.validate()?;
312
313        let client = match self.auth {
314            NatsAuth::UserPass(username, password) => {
315                async_nats::ConnectOptions::with_user_and_password(username, password)
316            }
317            NatsAuth::Token(token) => async_nats::ConnectOptions::with_token(token),
318            NatsAuth::NKey(nkey) => async_nats::ConnectOptions::with_nkey(nkey),
319            NatsAuth::CredentialsFile(path) => {
320                async_nats::ConnectOptions::with_credentials_file(path).await?
321            }
322        };
323
324        let (client, _) = build_in_runtime(
325            async move {
326                client
327                    .connect(self.server)
328                    .await
329                    .map_err(|e| anyhow::anyhow!("Failed to connect to NATS: {e}"))
330            },
331            NATS_WORKER_THREADS,
332        )
333        .await?;
334
335        let js_ctx = jetstream::new(client.clone());
336
337        // Validate JetStream is available
338        js_ctx
339            .query_account()
340            .await
341            .map_err(|e| anyhow::anyhow!("JetStream not available: {e}"))?;
342
343        Ok(Client { client, js_ctx })
344    }
345}
346
347impl Default for ClientOptions {
348    fn default() -> Self {
349        ClientOptions {
350            server: default_server(),
351            auth: NatsAuth::default(),
352        }
353    }
354}
355
/// Authentication method used when connecting to NATS.
///
/// The `Debug` implementation redacts secrets, and `Default` picks a variant
/// from the `NATS_AUTH_*` environment variables.
#[derive(Clone, Eq, PartialEq)]
pub enum NatsAuth {
    /// Username and password, in that order.
    UserPass(String, String),
    /// Authentication token.
    Token(String),
    /// NKey value.
    NKey(String),
    /// Path to a credentials file.
    CredentialsFile(PathBuf),
}
363
364impl std::fmt::Debug for NatsAuth {
365    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        match self {
367            NatsAuth::UserPass(user, _pass) => {
368                write!(f, "UserPass({}, <redacted>)", user)
369            }
370            NatsAuth::Token(_token) => write!(f, "Token(<redacted>)"),
371            NatsAuth::NKey(_nkey) => write!(f, "NKey(<redacted>)"),
372            NatsAuth::CredentialsFile(path) => write!(f, "CredentialsFile({:?})", path),
373        }
374    }
375}
376
377impl Default for NatsAuth {
378    fn default() -> Self {
379        if let (Ok(username), Ok(password)) = (
380            std::env::var("NATS_AUTH_USERNAME"),
381            std::env::var("NATS_AUTH_PASSWORD"),
382        ) {
383            return NatsAuth::UserPass(username, password);
384        }
385
386        if let Ok(token) = std::env::var("NATS_AUTH_TOKEN") {
387            return NatsAuth::Token(token);
388        }
389
390        if let Ok(nkey) = std::env::var("NATS_AUTH_NKEY") {
391            return NatsAuth::NKey(nkey);
392        }
393
394        if let Ok(path) = std::env::var("NATS_AUTH_CREDENTIALS_FILE") {
395            return NatsAuth::CredentialsFile(PathBuf::from(path));
396        }
397
398        NatsAuth::UserPass("user".to_string(), "user".to_string())
399    }
400}
401
402/// Is this file name / url in the NATS object store?
403/// Checks the name only, does not go to the store.
404pub fn is_nats_url(s: &str) -> bool {
405    s.starts_with(URL_PREFIX)
406}
407
408/// Extract NATS bucket and key from a nats URL of the form:
409/// nats://host[:port]/bucket/key
410pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
411    let Some(mut path_segments) = url.path_segments() else {
412        anyhow::bail!("No path in NATS URL: {url}");
413    };
414    let Some(bucket) = path_segments.next() else {
415        anyhow::bail!("No bucket in NATS URL: {url}");
416    };
417    let Some(key) = path_segments.next() else {
418        anyhow::bail!("No key in NATS URL: {url}");
419    };
420    Ok((bucket.to_string(), key.to_string()))
421}
422
/// Default queue name for publishing events.
///
/// `publish_bytes` logs a warning when it is called with an event name other
/// than this one.
pub const QUEUE_NAME: &str = "queue";
425
/// A queue implementation using NATS JetStream
pub struct NatsQueue {
    /// The (slugified) name of the stream to use for the queue
    stream_name: String,
    /// The NATS server URL
    nats_server: String,
    /// Default timeout for dequeue operations, used when `dequeue_task` is
    /// called without an explicit timeout
    dequeue_timeout: time::Duration,
    /// The NATS client; `None` until a connection has been established
    client: Option<Client>,
    /// The subject pattern used for this queue (`<stream_name>.*`)
    subject: String,
    /// The pull consumer used for dequeuing; only set on connect when
    /// `consumer_name` is `Some`
    subscriber: Option<jetstream::consumer::PullConsumer>,
    /// Durable consumer name. `None` means publisher-only mode: no consumer is
    /// created on connect. (`get_queue_size` still falls back to
    /// "worker-group" when this is `None`.)
    consumer_name: Option<String>,
}
443
444impl NatsQueue {
445    /// Create a new NatsQueue with the default "worker-group" consumer
446    pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self {
447        // Sanitize stream name to remove path separators (like in Python version)
448        // rupei: are we sure NATs stream name accepts '_'?
449        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
450        let subject = format!("{sanitized_stream_name}.*");
451
452        Self {
453            stream_name: sanitized_stream_name,
454            nats_server,
455            dequeue_timeout,
456            client: None,
457            subject,
458            subscriber: None,
459            consumer_name: Some("worker-group".to_string()),
460        }
461    }
462
463    /// Create a new NatsQueue without a consumer (publisher-only mode)
464    pub fn new_without_consumer(
465        stream_name: String,
466        nats_server: String,
467        dequeue_timeout: time::Duration,
468    ) -> Self {
469        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
470        let subject = format!("{sanitized_stream_name}.*");
471
472        Self {
473            stream_name: sanitized_stream_name,
474            nats_server,
475            dequeue_timeout,
476            client: None,
477            subject,
478            subscriber: None,
479            consumer_name: None,
480        }
481    }
482
483    /// Create a new NatsQueue with a specific consumer name for broadcast pattern
484    /// Each consumer with a unique name will receive all messages independently
485    pub fn new_with_consumer(
486        stream_name: String,
487        nats_server: String,
488        dequeue_timeout: time::Duration,
489        consumer_name: String,
490    ) -> Self {
491        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
492        let subject = format!("{sanitized_stream_name}.*");
493
494        Self {
495            stream_name: sanitized_stream_name,
496            nats_server,
497            dequeue_timeout,
498            client: None,
499            subject,
500            subscriber: None,
501            consumer_name: Some(consumer_name),
502        }
503    }
504
505    /// Connect to the NATS server and set up the stream and consumer
506    pub async fn connect(&mut self) -> Result<()> {
507        self.connect_with_reset(false).await
508    }
509
510    /// Connect to the NATS server and set up the stream and consumer, optionally resetting the stream
511    pub async fn connect_with_reset(&mut self, reset_stream: bool) -> Result<()> {
512        if self.client.is_none() {
513            // Create a new client
514            let client_options = Client::builder().server(self.nats_server.clone()).build()?;
515
516            let client = client_options.connect().await?;
517
518            // messages older than a hour in the stream will be automatically purged
519            let max_age = std::env::var("DYN_NATS_STREAM_MAX_AGE")
520                .ok()
521                .and_then(|s| s.parse::<u64>().ok())
522                .map(time::Duration::from_secs)
523                .unwrap_or_else(|| time::Duration::from_secs(60 * 60));
524
525            let stream_config = jetstream::stream::Config {
526                name: self.stream_name.clone(),
527                subjects: vec![self.subject.clone()],
528                max_age,
529                ..Default::default()
530            };
531
532            // Get or create the stream
533            let stream = client
534                .jetstream()
535                .get_or_create_stream(stream_config)
536                .await?;
537
538            log::debug!("Stream {} is ready", self.stream_name);
539
540            // If reset_stream is true, purge all messages from the stream
541            if reset_stream {
542                match stream.purge().await {
543                    Ok(purge_info) => {
544                        log::info!(
545                            "Successfully purged {} messages from NATS stream {}",
546                            purge_info.purged,
547                            self.stream_name
548                        );
549                    }
550                    Err(e) => {
551                        log::warn!("Failed to purge NATS stream '{}': {e}", self.stream_name);
552                    }
553                }
554            }
555
556            // Create persistent subscriber only if consumer_name is set
557            if let Some(ref consumer_name) = self.consumer_name {
558                let consumer_config = jetstream::consumer::pull::Config {
559                    durable_name: Some(consumer_name.clone()),
560                    inactive_threshold: std::time::Duration::from_secs(3600), // 1 hour
561                    ..Default::default()
562                };
563
564                let subscriber = stream.create_consumer(consumer_config).await?;
565                self.subscriber = Some(subscriber);
566            }
567
568            self.client = Some(client);
569        }
570
571        Ok(())
572    }
573
574    /// Ensure we have an active connection
575    pub async fn ensure_connection(&mut self) -> Result<()> {
576        if self.client.is_none() {
577            self.connect().await?;
578        }
579        Ok(())
580    }
581
582    /// Close the connection when done
583    pub async fn close(&mut self) -> Result<()> {
584        self.subscriber = None;
585        self.client = None;
586        Ok(())
587    }
588
589    /// Shutdown the consumer by deleting it from the stream and closing the connection
590    /// This permanently removes the consumer from the server
591    ///
592    /// If `consumer_name` is provided, that specific consumer will be deleted instead of the
593    /// current consumer. This allows deletion of other consumers on the same stream.
594    pub async fn shutdown(&mut self, consumer_name: Option<String>) -> Result<()> {
595        // Determine which consumer to delete
596        let target_consumer = consumer_name.as_ref().or(self.consumer_name.as_ref());
597
598        // Warn if deleting our own consumer via explicit parameter
599        if let Some(ref passed_name) = consumer_name
600            && self.consumer_name.as_ref() == Some(passed_name)
601        {
602            log::warn!(
603                "Deleting our own consumer '{}' via explicit consumer_name parameter. \
604                Consider calling shutdown without arguments instead.",
605                passed_name
606            );
607        }
608
609        if let (Some(client), Some(consumer_to_delete)) = (&self.client, target_consumer) {
610            // Get the stream and delete the consumer
611            let stream = client.jetstream().get_stream(&self.stream_name).await?;
612            stream
613                .delete_consumer(consumer_to_delete)
614                .await
615                .map_err(|e| {
616                    anyhow::anyhow!("Failed to delete consumer {}: {}", consumer_to_delete, e)
617                })?;
618            log::debug!(
619                "Deleted consumer {} from stream {}",
620                consumer_to_delete,
621                self.stream_name
622            );
623        } else {
624            log::debug!(
625                "Cannot shutdown consumer: client or target consumer is None (client: {:?}, target_consumer: {:?})",
626                self.client.is_some(),
627                target_consumer.is_some()
628            );
629        }
630
631        // Only close the connection if we deleted our own consumer
632        if consumer_name.is_none() {
633            self.close().await
634        } else {
635            Ok(())
636        }
637    }
638
639    /// Count the number of consumers for the stream
640    pub async fn count_consumers(&mut self) -> Result<usize> {
641        self.ensure_connection().await?;
642
643        if let Some(client) = &self.client {
644            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
645            let info = stream.info().await?;
646            Ok(info.state.consumer_count)
647        } else {
648            Err(anyhow::anyhow!("Client not connected"))
649        }
650    }
651
652    /// List all consumer names for the stream
653    pub async fn list_consumers(&mut self) -> Result<Vec<String>> {
654        self.ensure_connection().await?;
655
656        if let Some(client) = &self.client {
657            client.list_consumers(&self.stream_name).await
658        } else {
659            Err(anyhow::anyhow!("Client not connected"))
660        }
661    }
662
663    /// Enqueue a task using the provided data
664    pub async fn enqueue_task(&mut self, task_data: Bytes) -> Result<()> {
665        self.ensure_connection().await?;
666
667        if let Some(client) = &self.client {
668            let subject = format!("{}.queue", self.stream_name);
669            client.jetstream().publish(subject, task_data).await?;
670            Ok(())
671        } else {
672            Err(anyhow::anyhow!("Client not connected"))
673        }
674    }
675
676    /// Dequeue and return a task as raw bytes
677    pub async fn dequeue_task(&mut self, timeout: Option<time::Duration>) -> Result<Option<Bytes>> {
678        self.ensure_connection().await?;
679
680        if let Some(subscriber) = &self.subscriber {
681            let timeout_duration = timeout.unwrap_or(self.dequeue_timeout);
682            let mut batch = subscriber
683                .fetch()
684                .expires(timeout_duration)
685                .max_messages(1)
686                .messages()
687                .await?;
688
689            if let Some(message) = batch.next().await {
690                let message =
691                    message.map_err(|e| anyhow::anyhow!("Failed to get message: {}", e))?;
692                message
693                    .ack()
694                    .await
695                    .map_err(|e| anyhow::anyhow!("Failed to ack message: {}", e))?;
696                Ok(Some(message.payload.clone()))
697            } else {
698                Ok(None)
699            }
700        } else {
701            Err(anyhow::anyhow!("Subscriber not initialized"))
702        }
703    }
704
705    /// Get the number of messages currently in the queue
706    pub async fn get_queue_size(&mut self) -> Result<u64> {
707        self.ensure_connection().await?;
708
709        if let Some(client) = &self.client {
710            // Get consumer info to get pending messages count
711            let stream = client.jetstream().get_stream(&self.stream_name).await?;
712            let consumer_name = self
713                .consumer_name
714                .clone()
715                .unwrap_or_else(|| "worker-group".to_string());
716            let mut consumer: jetstream::consumer::PullConsumer = stream
717                .get_consumer(&consumer_name)
718                .await
719                .map_err(|e| anyhow::anyhow!("Failed to get consumer: {}", e))?;
720            let info = consumer.info().await?;
721
722            Ok(info.num_pending)
723        } else {
724            Err(anyhow::anyhow!("Client not connected"))
725        }
726    }
727
728    /// Get the total number of messages currently in the stream
729    pub async fn get_stream_messages(&mut self) -> Result<u64> {
730        self.ensure_connection().await?;
731
732        if let Some(client) = &self.client {
733            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
734            let info = stream.info().await?;
735            Ok(info.state.messages)
736        } else {
737            Err(anyhow::anyhow!("Client not connected"))
738        }
739    }
740
741    /// Purge messages from the stream up to (but not including) the specified sequence number
742    /// This permanently removes messages and affects all consumers of the stream
743    pub async fn purge_up_to_sequence(&self, sequence: u64) -> Result<()> {
744        if let Some(client) = &self.client {
745            let stream = client.jetstream().get_stream(&self.stream_name).await?;
746
747            // NOTE: this purge excludes the sequence itself
748            // https://docs.rs/nats/latest/nats/jetstream/struct.PurgeRequest.html
749            stream.purge().sequence(sequence).await.map_err(|e| {
750                anyhow::anyhow!("Failed to purge stream up to sequence {}: {}", sequence, e)
751            })?;
752
753            log::debug!(
754                "Purged stream {} up to sequence {}",
755                self.stream_name,
756                sequence
757            );
758            Ok(())
759        } else {
760            Err(anyhow::anyhow!("Client not connected"))
761        }
762    }
763
764    /// Purge messages from the stream up to the minimum acknowledged sequence across all consumers
765    /// This finds the lowest acknowledged sequence number across all consumers and purges up to that point
766    pub async fn purge_acknowledged(&mut self) -> Result<()> {
767        self.ensure_connection().await?;
768
769        let Some(client) = &self.client else {
770            return Err(anyhow::anyhow!("Client not connected"));
771        };
772
773        let stream = client.jetstream().get_stream(&self.stream_name).await?;
774
775        // Get all consumer names for the stream
776        let consumer_names: Vec<String> = stream
777            .consumer_names()
778            .try_collect()
779            .await
780            .map_err(|e| anyhow::anyhow!("Failed to list consumers: {}", e))?;
781
782        if consumer_names.is_empty() {
783            log::debug!("No consumers found for stream {}", self.stream_name);
784            return Ok(());
785        }
786
787        // Find the minimum acknowledged sequence across all consumers
788        let mut min_ack_sequence = u64::MAX;
789
790        for consumer_name in &consumer_names {
791            let mut consumer: jetstream::consumer::PullConsumer = stream
792                .get_consumer(consumer_name)
793                .await
794                .map_err(|e| anyhow::anyhow!("Failed to get consumer {}: {}", consumer_name, e))?;
795
796            let info = consumer.info().await.map_err(|e| {
797                anyhow::anyhow!("Failed to get consumer info for {}: {}", consumer_name, e)
798            })?;
799
800            // The ack_floor contains the stream sequence of the highest contiguously acknowledged message
801            // If stream_sequence is 0, it means no messages have been acknowledged yet
802            if info.ack_floor.stream_sequence > 0 {
803                min_ack_sequence = min_ack_sequence.min(info.ack_floor.stream_sequence);
804                log::debug!(
805                    "Consumer {} has ack_floor at sequence {}",
806                    consumer_name,
807                    info.ack_floor.stream_sequence
808                );
809            }
810        }
811
812        // Only purge if we found a valid minimum acknowledged sequence
813        if min_ack_sequence < u64::MAX && min_ack_sequence > 0 {
814            // Purge up to (but not including) the minimum acknowledged sequence + 1
815            // We add 1 because we want to include the minimum acknowledged message in the purge
816            let purge_sequence = min_ack_sequence + 1;
817
818            self.purge_up_to_sequence(purge_sequence).await?;
819
820            log::debug!(
821                "Purged stream {} up to acknowledged sequence {} (purged up to sequence {})",
822                self.stream_name,
823                min_ack_sequence,
824                purge_sequence
825            );
826        } else {
827            log::debug!(
828                "No messages to purge for stream {} (min_ack_sequence: {})",
829                self.stream_name,
830                min_ack_sequence
831            );
832        }
833
834        Ok(())
835    }
836}
837
838#[async_trait]
839impl EventPublisher for NatsQueue {
840    fn subject(&self) -> String {
841        self.stream_name.clone()
842    }
843
844    async fn publish(
845        &self,
846        event_name: impl AsRef<str> + Send + Sync,
847        event: &(impl Serialize + Send + Sync),
848    ) -> Result<()> {
849        let bytes = serde_json::to_vec(event)?;
850        self.publish_bytes(event_name, bytes).await
851    }
852
853    async fn publish_bytes(
854        &self,
855        event_name: impl AsRef<str> + Send + Sync,
856        bytes: Vec<u8>,
857    ) -> Result<()> {
858        // We expect the stream to be always suffixed with "queue"
859        // This suffix itself is nothing special, just a repo standard
860        if event_name.as_ref() != QUEUE_NAME {
861            tracing::warn!(
862                "Expected event_name to be '{}', but got '{}'",
863                QUEUE_NAME,
864                event_name.as_ref()
865            );
866        }
867
868        let subject = format!("{}.{}", self.subject(), event_name.as_ref());
869
870        // Note: enqueue_task requires &mut self, but EventPublisher requires &self
871        // We need to ensure the client is connected and use it directly
872        if let Some(client) = &self.client {
873            client.jetstream().publish(subject, bytes.into()).await?;
874            Ok(())
875        } else {
876            Err(anyhow::anyhow!("Client not connected"))
877        }
878    }
879}
880
881/// Prometheus metrics that mirror the NATS client statistics (in primitive types)
882/// to be used for the System Status Server.
883///
884/// ⚠️  IMPORTANT: These Prometheus Gauges are COPIES of NATS client data, not live references!
885///
886/// How it works:
887/// 1. NATS client provides source data via client.statistics() and connection_state()
888/// 2. set_from_client_stats() reads current NATS values and updates these Prometheus Gauges
889/// 3. Prometheus scrapes these Gauge values (snapshots, not live data)
890///
891/// Flow: NATS Client → Client Statistics → set_from_client_stats() → Prometheus Gauge
892/// Note: These are snapshots updated when set_from_client_stats() is called.
893#[derive(Debug, Clone)]
894pub struct DRTNatsClientPrometheusMetrics {
895    nats_client: client::Client,
896    /// Number of bytes received (excluding protocol overhead)
897    pub in_bytes: IntGauge,
898    /// Number of bytes sent (excluding protocol overhead)
899    pub out_bytes: IntGauge,
900    /// Number of messages received
901    pub in_messages: IntGauge,
902    /// Number of messages sent
903    pub out_messages: IntGauge,
904    /// Number of times connection was established
905    pub connects: IntGauge,
906    /// Current connection state (0 = disconnected, 1 = connected, 2 = reconnecting)
907    pub connection_state: IntGauge,
908}
909
910impl DRTNatsClientPrometheusMetrics {
911    /// Create a new instance of NATS client metrics using a DistributedRuntime's Prometheus constructors
912    pub fn new(drt: &crate::DistributedRuntime, nats_client: client::Client) -> Result<Self> {
913        let in_bytes = drt.create_intgauge(
914            nats_metrics::IN_TOTAL_BYTES,
915            "Total number of bytes received by NATS client",
916            &[],
917        )?;
918        let out_bytes = drt.create_intgauge(
919            nats_metrics::OUT_OVERHEAD_BYTES,
920            "Total number of bytes sent by NATS client",
921            &[],
922        )?;
923        let in_messages = drt.create_intgauge(
924            nats_metrics::IN_MESSAGES,
925            "Total number of messages received by NATS client",
926            &[],
927        )?;
928        let out_messages = drt.create_intgauge(
929            nats_metrics::OUT_MESSAGES,
930            "Total number of messages sent by NATS client",
931            &[],
932        )?;
933        let connects = drt.create_intgauge(
934            nats_metrics::CURRENT_CONNECTIONS,
935            "Current number of active connections for NATS client",
936            &[],
937        )?;
938        let connection_state = drt.create_intgauge(
939            nats_metrics::CONNECTION_STATE,
940            "Current connection state of NATS client (0=disconnected, 1=connected, 2=reconnecting)",
941            &[],
942        )?;
943
944        Ok(Self {
945            nats_client,
946            in_bytes,
947            out_bytes,
948            in_messages,
949            out_messages,
950            connects,
951            connection_state,
952        })
953    }
954
955    /// Copy statistics from the stored NATS client to these Prometheus metrics
956    pub fn set_from_client_stats(&self) {
957        let stats = self.nats_client.statistics();
958
959        // Get current values from the client statistics
960        let in_bytes = stats.in_bytes.load(Ordering::Relaxed);
961        let out_bytes = stats.out_bytes.load(Ordering::Relaxed);
962        let in_messages = stats.in_messages.load(Ordering::Relaxed);
963        let out_messages = stats.out_messages.load(Ordering::Relaxed);
964        let connects = stats.connects.load(Ordering::Relaxed);
965
966        // Get connection state
967        let connection_state = match self.nats_client.connection_state() {
968            State::Connected => 1,
969            // treat Disconnected and Pending as "down"
970            State::Disconnected | State::Pending => 0,
971        };
972
973        // Update Prometheus metrics
974        // Using gauges allows us to set absolute values directly
975        self.in_bytes.set(in_bytes as i64);
976        self.out_bytes.set(out_bytes as i64);
977        self.in_messages.set(in_messages as i64);
978        self.out_messages.set(out_messages as i64);
979        self.connects.set(connects as i64);
980        self.connection_state.set(connection_state);
981    }
982}
983
984#[cfg(test)]
985mod tests {
986
987    use super::*;
988    use figment::Jail;
989    use serde::{Deserialize, Serialize};
990
991    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
992    struct TestData {
993        id: u32,
994        name: String,
995        values: Vec<f64>,
996    }
997
998    #[test]
999    fn test_client_options_builder() {
1000        Jail::expect_with(|_jail| {
1001            let opts = ClientOptions::builder().build();
1002            assert!(opts.is_ok());
1003            Ok(())
1004        });
1005
1006        Jail::expect_with(|jail| {
1007            jail.set_env("NATS_SERVER", "nats://localhost:5222");
1008            jail.set_env("NATS_AUTH_USERNAME", "user");
1009            jail.set_env("NATS_AUTH_PASSWORD", "pass");
1010
1011            let opts = ClientOptions::builder().build();
1012            assert!(opts.is_ok());
1013            let opts = opts.unwrap();
1014
1015            assert_eq!(opts.server, "nats://localhost:5222");
1016            assert_eq!(
1017                opts.auth,
1018                NatsAuth::UserPass("user".to_string(), "pass".to_string())
1019            );
1020
1021            Ok(())
1022        });
1023
1024        Jail::expect_with(|jail| {
1025            jail.set_env("NATS_SERVER", "nats://localhost:5222");
1026            jail.set_env("NATS_AUTH_USERNAME", "user");
1027            jail.set_env("NATS_AUTH_PASSWORD", "pass");
1028
1029            let opts = ClientOptions::builder()
1030                .server("nats://localhost:6222")
1031                .auth(NatsAuth::Token("token".to_string()))
1032                .build();
1033            assert!(opts.is_ok());
1034            let opts = opts.unwrap();
1035
1036            assert_eq!(opts.server, "nats://localhost:6222");
1037            assert_eq!(opts.auth, NatsAuth::Token("token".to_string()));
1038
1039            Ok(())
1040        });
1041    }
1042
1043    // Integration test for object store data operations using bincode
1044    #[tokio::test]
1045    #[ignore] // Requires NATS server to be running
1046    async fn test_object_store_data_operations() {
1047        // Create test data
1048        let test_data = TestData {
1049            id: 42,
1050            name: "test_item".to_string(),
1051            values: vec![1.0, 2.5, 3.7, 4.2],
1052        };
1053
1054        // Set up client
1055        let client_options = ClientOptions::builder()
1056            .server("nats://localhost:4222")
1057            .build()
1058            .expect("Failed to build client options");
1059
1060        let client = client_options
1061            .connect()
1062            .await
1063            .expect("Failed to connect to NATS");
1064
1065        // Test URL (using .bin extension to indicate binary format)
1066        let url =
1067            Url::parse("nats://localhost/test-bucket/test-data.bin").expect("Failed to parse URL");
1068
1069        // Upload the data
1070        client
1071            .object_store_upload_data(&test_data, &url)
1072            .await
1073            .expect("Failed to upload data");
1074
1075        // Download the data
1076        let downloaded_data: TestData = client
1077            .object_store_download_data(&url)
1078            .await
1079            .expect("Failed to download data");
1080
1081        // Verify the data matches
1082        assert_eq!(test_data, downloaded_data);
1083
1084        // Clean up
1085        client
1086            .object_store_delete_bucket("test-bucket")
1087            .await
1088            .expect("Failed to delete bucket");
1089    }
1090
1091    // Integration test for broadcast pattern with purging
1092    #[tokio::test]
1093    #[ignore]
1094    async fn test_nats_queue_broadcast_with_purge() {
1095        use uuid::Uuid;
1096
1097        // Create unique stream name for this test
1098        let stream_name = format!("test-broadcast-{}", Uuid::new_v4());
1099        let nats_server = "nats://localhost:4222".to_string();
1100        let timeout = time::Duration::from_secs(0);
1101
1102        // Connect to NATS client first to delete stream if it exists
1103        let client_options = Client::builder()
1104            .server(nats_server.clone())
1105            .build()
1106            .expect("Failed to build client options");
1107
1108        let client = client_options
1109            .connect()
1110            .await
1111            .expect("Failed to connect to NATS");
1112
1113        // Delete the stream if it exists (to ensure clean start)
1114        let _ = client.jetstream().delete_stream(&stream_name).await;
1115
1116        // Create two consumers with different names for the same stream
1117        let consumer1_name = format!("consumer-{}", Uuid::new_v4());
1118        let consumer2_name = format!("consumer-{}", Uuid::new_v4());
1119
1120        let mut queue1 = NatsQueue::new_with_consumer(
1121            stream_name.clone(),
1122            nats_server.clone(),
1123            timeout,
1124            consumer1_name,
1125        );
1126
1127        // Connect queue1 first (it will create the stream)
1128        queue1.connect().await.expect("Failed to connect queue1");
1129
1130        // Send 4 messages using the EventPublisher trait
1131        let message_strings = [
1132            "message1".to_string(),
1133            "message2".to_string(),
1134            "message3".to_string(),
1135            "message4".to_string(),
1136        ];
1137
1138        // Using the EventPublisher trait to publish messages
1139        for (idx, msg) in message_strings.iter().enumerate() {
1140            queue1
1141                .publish("queue", msg)
1142                .await
1143                .unwrap_or_else(|_| panic!("Failed to publish message {}", idx + 1));
1144        }
1145
1146        // Convert messages to JSON-serialized Bytes for comparison
1147        let messages: Vec<Bytes> = message_strings
1148            .iter()
1149            .map(|s| Bytes::from(serde_json::to_vec(s).unwrap()))
1150            .collect();
1151
1152        // Give JetStream a moment to persist the messages
1153        tokio::time::sleep(time::Duration::from_millis(100)).await;
1154
1155        // Now create and connect queue2 and queue3 AFTER messages are published (to test persistence)
1156        let mut queue2 = NatsQueue::new_with_consumer(
1157            stream_name.clone(),
1158            nats_server.clone(),
1159            timeout,
1160            consumer2_name,
1161        );
1162
1163        // Create a third queue without consumer (publisher-only)
1164        let mut queue3 =
1165            NatsQueue::new_without_consumer(stream_name.clone(), nats_server.clone(), timeout);
1166
1167        // Connect queue2 and queue3 after messages are already published
1168        queue2.connect().await.expect("Failed to connect queue2");
1169        queue3.connect().await.expect("Failed to connect queue3");
1170
1171        // Purge the first two messages (sequence 1 and 2)
1172        // Note: JetStream sequences start at 1, and purge is exclusive of the sequence number
1173        queue1
1174            .purge_up_to_sequence(3)
1175            .await
1176            .expect("Failed to purge messages");
1177
1178        // Give JetStream a moment to process the purge
1179        tokio::time::sleep(time::Duration::from_millis(100)).await;
1180
1181        // Consumer 1 dequeues one message (message3)
1182        let msg3_consumer1 = queue1
1183            .dequeue_task(Some(time::Duration::from_millis(500)))
1184            .await
1185            .expect("Failed to dequeue from queue1");
1186        assert_eq!(
1187            msg3_consumer1,
1188            Some(messages[2].clone()),
1189            "Consumer 1 should get message3"
1190        );
1191
1192        // Give JetStream a moment to process acknowledgments
1193        tokio::time::sleep(time::Duration::from_millis(100)).await;
1194
1195        // Now run purge_acknowledged
1196        // At this point:
1197        // - Consumer 1 has ack'd message 3 (ack_floor = 3)
1198        // - Consumer 2 hasn't consumed anything yet (ack_floor = 0)
1199        // - Min ack_floor = 0, so nothing will be purged
1200        queue1
1201            .purge_acknowledged()
1202            .await
1203            .expect("Failed to purge acknowledged messages");
1204
1205        // Give JetStream a moment to process the purge
1206        tokio::time::sleep(time::Duration::from_millis(100)).await;
1207
1208        // Now collect remaining messages from both consumers
1209        let mut consumer1_remaining = Vec::new();
1210        let mut consumer2_remaining = Vec::new();
1211
1212        // Collect remaining messages from consumer 1
1213        while let Some(msg) = queue1
1214            .dequeue_task(None)
1215            .await
1216            .expect("Failed to dequeue from queue1")
1217        {
1218            consumer1_remaining.push(msg);
1219        }
1220
1221        // Collect remaining messages from consumer 2
1222        while let Some(msg) = queue2
1223            .dequeue_task(None)
1224            .await
1225            .expect("Failed to dequeue from queue2")
1226        {
1227            consumer2_remaining.push(msg);
1228        }
1229
1230        // Verify consumer 1 gets 1 remaining message (message4)
1231        assert_eq!(
1232            consumer1_remaining.len(),
1233            1,
1234            "Consumer 1 should have 1 remaining message"
1235        );
1236        assert_eq!(
1237            consumer1_remaining[0], messages[3],
1238            "Consumer 1 should get message4"
1239        );
1240
1241        // Verify consumer 2 gets 2 messages (message3 and message4)
1242        assert_eq!(
1243            consumer2_remaining.len(),
1244            2,
1245            "Consumer 2 should have 2 messages"
1246        );
1247        assert_eq!(
1248            consumer2_remaining[0], messages[2],
1249            "Consumer 2 should get message3"
1250        );
1251        assert_eq!(
1252            consumer2_remaining[1], messages[3],
1253            "Consumer 2 should get message4"
1254        );
1255
1256        // Test consumer count and shutdown behavior
1257        // First verify via consumer 1 that there are two consumers
1258        let consumer_count = queue1
1259            .count_consumers()
1260            .await
1261            .expect("Failed to count consumers");
1262        assert_eq!(consumer_count, 2, "Should have 2 consumers initially");
1263
1264        // Close consumer 1 and verify via consumer 2 that there are still two consumers
1265        queue1.close().await.expect("Failed to close queue1");
1266
1267        let consumer_count = queue2
1268            .count_consumers()
1269            .await
1270            .expect("Failed to count consumers");
1271        assert_eq!(
1272            consumer_count, 2,
1273            "Should still have 2 consumers after closing queue1"
1274        );
1275
1276        // Reconnect queue1 to be able to shutdown
1277        queue1.connect().await.expect("Failed to reconnect queue1");
1278
1279        // Shutdown consumer 1 and verify via consumer 2 that there is only one consumer left
1280        queue1
1281            .shutdown(None)
1282            .await
1283            .expect("Failed to shutdown queue1");
1284
1285        let consumer_count = queue2
1286            .count_consumers()
1287            .await
1288            .expect("Failed to count consumers");
1289        assert_eq!(
1290            consumer_count, 1,
1291            "Should have only 1 consumer after shutting down queue1"
1292        );
1293
1294        // Clean up by deleting the stream
1295        client
1296            .jetstream()
1297            .delete_stream(&stream_name)
1298            .await
1299            .expect("Failed to delete test stream");
1300    }
1301}