dynamo_runtime/transports/
nats.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! NATS transport
5//!
6//! The following environment variables are used to configure the NATS client:
7//!
8//! - `NATS_SERVER`: the NATS server address
9//!
10//! For authentication, the following environment variables are used and prioritized in the following order:
11//!
12//! - `NATS_AUTH_USERNAME`: the username for authentication
13//! - `NATS_AUTH_PASSWORD`: the password for authentication
14//! - `NATS_AUTH_TOKEN`: the token for authentication
15//! - `NATS_AUTH_NKEY`: the nkey for authentication
16//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
17//!
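//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together.
//! If none of the variables above are set, the client falls back to the hard-coded
//! username/password pair `user`/`user`.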
19use crate::traits::events::EventPublisher;
20use crate::{Result, metrics::MetricsHierarchy};
21
22use async_nats::connection::State;
23use async_nats::{Subscriber, client, jetstream};
24use async_trait::async_trait;
25use bytes::Bytes;
26use derive_builder::Builder;
27use futures::{StreamExt, TryStreamExt};
28use prometheus::{Counter, Gauge, Histogram, HistogramOpts, IntCounter, IntGauge, Opts, Registry};
29use serde::de::DeserializeOwned;
30use serde::{Deserialize, Serialize};
31use std::path::{Path, PathBuf};
32use std::sync::atomic::Ordering;
33use tokio::fs::File as TokioFile;
34use tokio::io::AsyncRead;
35use tokio::time;
36use url::Url;
37use validator::{Validate, ValidationError};
38
39use crate::metrics::prometheus_names::nats_client as nats_metrics;
40pub use crate::slug::Slug;
41use tracing as log;
42
43use super::utils::build_in_runtime;
44
/// URL scheme prefix that identifies NATS object-store locations (checked by [`is_nats_url`]).
pub const URL_PREFIX: &str = "nats://";
46
/// Handle to a NATS connection together with a JetStream context created from it.
///
/// Instances are produced by [`ClientOptions::connect`].
#[derive(Clone)]
pub struct Client {
    /// Core NATS client, used directly for plain subscribe/publish.
    client: client::Client,
    /// JetStream context built from `client`; used for streams and object stores.
    js_ctx: jetstream::Context,
}
52
53impl Client {
54    /// Create a NATS [`ClientOptionsBuilder`].
55    pub fn builder() -> ClientOptionsBuilder {
56        ClientOptionsBuilder::default()
57    }
58
59    /// Returns a reference to the underlying [`async_nats::client::Client`] instance
60    pub fn client(&self) -> &client::Client {
61        &self.client
62    }
63
64    /// Returns a reference to the underlying [`async_nats::jetstream::Context`] instance
65    pub fn jetstream(&self) -> &jetstream::Context {
66        &self.js_ctx
67    }
68
69    /// host:port of NATS
70    pub fn addr(&self) -> String {
71        let info = self.client.server_info();
72        format!("{}:{}", info.host, info.port)
73    }
74
75    /// fetch the list of streams
76    pub async fn list_streams(&self) -> Result<Vec<String>> {
77        let names = self.js_ctx.stream_names();
78        let stream_names: Vec<String> = names.try_collect().await?;
79        Ok(stream_names)
80    }
81
82    /// fetch the list of consumers for a given stream
83    pub async fn list_consumers(&self, stream_name: &str) -> Result<Vec<String>> {
84        let stream = self.js_ctx.get_stream(stream_name).await?;
85        let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
86        Ok(consumers)
87    }
88
89    pub async fn stream_info(&self, stream_name: &str) -> Result<jetstream::stream::State> {
90        let mut stream = self.js_ctx.get_stream(stream_name).await?;
91        let info = stream.info().await?;
92        Ok(info.state.clone())
93    }
94
95    pub async fn get_stream(&self, name: &str) -> Result<jetstream::stream::Stream> {
96        let stream = self.js_ctx.get_stream(name).await?;
97        Ok(stream)
98    }
99
100    /// Issues a broadcast request for all services with the provided `service_name` to report their
101    /// current stats. Each service will only respond once. The service may have customized the reply
102    /// so the caller should select which endpoint and what concrete data model should be used to
103    /// extract the details.
104    ///
105    /// Note: Because each endpoint will only reply once, the caller must drop the subscription after
106    /// some time or it will await forever.
107    pub async fn scrape_service(&self, service_name: &str) -> Result<Subscriber> {
108        let subject = format!("$SRV.STATS.{}", service_name);
109        let reply_subject = format!("_INBOX.{}", nuid::next());
110        let subscription = self.client.subscribe(reply_subject.clone()).await?;
111
112        // Publish the request with the reply-to subject
113        self.client
114            .publish_with_reply(subject, reply_subject, "".into())
115            .await?;
116
117        Ok(subscription)
118    }
119
120    /// Helper method to get or optionally create an object store bucket
121    ///
122    /// # Arguments
123    /// * `bucket_name` - The name of the bucket to retrieve
124    /// * `create_if_not_found` - If true, creates the bucket when it doesn't exist
125    ///
126    /// # Returns
127    /// The object store bucket or an error
128    async fn get_or_create_bucket(
129        &self,
130        bucket_name: &str,
131        create_if_not_found: bool,
132    ) -> anyhow::Result<jetstream::object_store::ObjectStore> {
133        let context = self.jetstream();
134
135        match context.get_object_store(bucket_name).await {
136            Ok(bucket) => Ok(bucket),
137            Err(err) if err.to_string().contains("stream not found") => {
138                // err.source() is GetStreamError, which has a kind() which
139                // is GetStreamErrorKind::JetStream which wraps a jetstream::Error
140                // which has code 404. Phew. So yeah check the string for now.
141
142                if create_if_not_found {
143                    tracing::debug!("Creating NATS bucket {bucket_name}");
144                    context
145                        .create_object_store(jetstream::object_store::Config {
146                            bucket: bucket_name.to_string(),
147                            ..Default::default()
148                        })
149                        .await
150                        .map_err(|e| anyhow::anyhow!("Failed creating bucket / object store: {e}"))
151                } else {
152                    anyhow::bail!(
153                        "NATS get_object_store bucket does not exist: {bucket_name}. {err}."
154                    );
155                }
156            }
157            Err(err) => {
158                anyhow::bail!("NATS get_object_store error: {err}");
159            }
160        }
161    }
162
163    /// Upload file to NATS at this URL
164    pub async fn object_store_upload(&self, filepath: &Path, nats_url: &Url) -> anyhow::Result<()> {
165        let mut disk_file = TokioFile::open(filepath).await?;
166
167        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
168        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
169
170        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
171            name: key.to_string(),
172            ..Default::default()
173        };
174        bucket.put(key_meta, &mut disk_file).await.map_err(|e| {
175            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
176        })?;
177
178        Ok(())
179    }
180
181    /// Download file from NATS at this URL
182    pub async fn object_store_download(
183        &self,
184        nats_url: &Url,
185        filepath: &Path,
186    ) -> anyhow::Result<()> {
187        let mut disk_file = TokioFile::create(filepath).await?;
188
189        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
190        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
191
192        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
193            anyhow::anyhow!(
194                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
195            )
196        })?;
197        let _bytes_copied = tokio::io::copy(&mut obj_reader, &mut disk_file).await?;
198
199        Ok(())
200    }
201
202    /// Delete a bucket and all it's contents from the NATS object store
203    pub async fn object_store_delete_bucket(&self, bucket_name: &str) -> anyhow::Result<()> {
204        let context = self.jetstream();
205        match context.delete_object_store(&bucket_name).await {
206            Ok(_) => Ok(()),
207            Err(err) if err.to_string().contains("stream not found") => {
208                tracing::trace!(bucket_name, "NATS bucket already gone");
209                Ok(())
210            }
211            Err(err) => Err(anyhow::anyhow!("NATS get_object_store error: {err}")),
212        }
213    }
214
215    /// Upload a serializable struct to NATS object store using bincode
216    pub async fn object_store_upload_data<T>(&self, data: &T, nats_url: &Url) -> anyhow::Result<()>
217    where
218        T: Serialize,
219    {
220        // Serialize the data using bincode (more efficient binary format)
221        let binary_data = bincode::serialize(data)
222            .map_err(|e| anyhow::anyhow!("Failed to serialize data with bincode: {e}"))?;
223
224        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
225        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
226
227        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
228            name: key.to_string(),
229            ..Default::default()
230        };
231
232        // Upload the serialized bytes
233        let mut cursor = std::io::Cursor::new(binary_data);
234        bucket.put(key_meta, &mut cursor).await.map_err(|e| {
235            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
236        })?;
237
238        Ok(())
239    }
240
241    /// Download and deserialize a struct from NATS object store using bincode
242    pub async fn object_store_download_data<T>(&self, nats_url: &Url) -> anyhow::Result<T>
243    where
244        T: DeserializeOwned,
245    {
246        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
247        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
248
249        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
250            anyhow::anyhow!(
251                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
252            )
253        })?;
254
255        // Read all bytes into memory
256        let mut buffer = Vec::new();
257        tokio::io::copy(&mut obj_reader, &mut buffer)
258            .await
259            .map_err(|e| anyhow::anyhow!("Failed reading object data: {e}"))?;
260        tracing::debug!("Downloaded {} bytes from {bucket_name}/{key}", buffer.len());
261
262        // Deserialize from bincode
263        let data = bincode::deserialize(&buffer)
264            .map_err(|e| anyhow::anyhow!("Failed to deserialize data with bincode: {e}"))?;
265
266        Ok(data)
267    }
268}
269
270/// NATS client options
271///
272/// This object uses the builder pattern with default values that are evaluates
273/// from the environment variables if they are not explicitly set by the builder.
274#[derive(Debug, Clone, Builder, Validate)]
275pub struct ClientOptions {
276    #[builder(setter(into), default = "default_server()")]
277    #[validate(custom(function = "validate_nats_server"))]
278    server: String,
279
280    #[builder(default)]
281    auth: NatsAuth,
282}
283
/// Server URL used when none is set explicitly: `NATS_SERVER` if readable, else the local default.
fn default_server() -> String {
    std::env::var("NATS_SERVER").unwrap_or_else(|_| String::from("nats://localhost:4222"))
}
291
292fn validate_nats_server(server: &str) -> Result<(), ValidationError> {
293    if server.starts_with("nats://") {
294        Ok(())
295    } else {
296        Err(ValidationError::new("server must start with 'nats://'"))
297    }
298}
299
// TODO(jthomson04): We really shouldn't be hardcoding this.
/// Thread count passed to `build_in_runtime` when [`ClientOptions::connect`] establishes the
/// NATS connection (presumably sizes a dedicated runtime for the client — confirm).
const NATS_WORKER_THREADS: usize = 4;
302
303impl ClientOptions {
304    /// Create a new [`ClientOptionsBuilder`]
305    pub fn builder() -> ClientOptionsBuilder {
306        ClientOptionsBuilder::default()
307    }
308
309    /// Validate the config and attempt to connection to the NATS server
310    pub async fn connect(self) -> Result<Client> {
311        self.validate()?;
312
313        let client = match self.auth {
314            NatsAuth::UserPass(username, password) => {
315                async_nats::ConnectOptions::with_user_and_password(username, password)
316            }
317            NatsAuth::Token(token) => async_nats::ConnectOptions::with_token(token),
318            NatsAuth::NKey(nkey) => async_nats::ConnectOptions::with_nkey(nkey),
319            NatsAuth::CredentialsFile(path) => {
320                async_nats::ConnectOptions::with_credentials_file(path).await?
321            }
322        };
323
324        let (client, _) = build_in_runtime(
325            async move {
326                client
327                    .connect(self.server)
328                    .await
329                    .map_err(|e| anyhow::anyhow!("Failed to connect to NATS: {e}. Verify NATS server is running and accessible."))
330            },
331            NATS_WORKER_THREADS,
332        )
333        .await?;
334
335        let js_ctx = jetstream::new(client.clone());
336
337        // Validate JetStream is available
338        js_ctx
339            .query_account()
340            .await
341            .map_err(|e| anyhow::anyhow!("JetStream not available: {e}"))?;
342
343        Ok(Client { client, js_ctx })
344    }
345}
346
347impl Default for ClientOptions {
348    fn default() -> Self {
349        ClientOptions {
350            server: default_server(),
351            auth: NatsAuth::default(),
352        }
353    }
354}
355
/// Authentication method used when connecting to NATS.
///
/// The [`Default`] implementation selects a variant from the `NATS_AUTH_*` environment
/// variables; the [`Debug`](std::fmt::Debug) implementation hides secrets.
#[derive(Clone, Eq, PartialEq)]
pub enum NatsAuth {
    /// Username and password, in that order.
    UserPass(String, String),
    /// Authentication token (passed to `async_nats::ConnectOptions::with_token`).
    Token(String),
    /// NKey string (passed to `async_nats::ConnectOptions::with_nkey`).
    NKey(String),
    /// Path to a NATS credentials file.
    CredentialsFile(PathBuf),
}
363
364impl std::fmt::Debug for NatsAuth {
365    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
366        match self {
367            NatsAuth::UserPass(user, _pass) => {
368                write!(f, "UserPass({}, <redacted>)", user)
369            }
370            NatsAuth::Token(_token) => write!(f, "Token(<redacted>)"),
371            NatsAuth::NKey(_nkey) => write!(f, "NKey(<redacted>)"),
372            NatsAuth::CredentialsFile(path) => write!(f, "CredentialsFile({:?})", path),
373        }
374    }
375}
376
377impl Default for NatsAuth {
378    fn default() -> Self {
379        if let (Ok(username), Ok(password)) = (
380            std::env::var("NATS_AUTH_USERNAME"),
381            std::env::var("NATS_AUTH_PASSWORD"),
382        ) {
383            return NatsAuth::UserPass(username, password);
384        }
385
386        if let Ok(token) = std::env::var("NATS_AUTH_TOKEN") {
387            return NatsAuth::Token(token);
388        }
389
390        if let Ok(nkey) = std::env::var("NATS_AUTH_NKEY") {
391            return NatsAuth::NKey(nkey);
392        }
393
394        if let Ok(path) = std::env::var("NATS_AUTH_CREDENTIALS_FILE") {
395            return NatsAuth::CredentialsFile(PathBuf::from(path));
396        }
397
398        NatsAuth::UserPass("user".to_string(), "user".to_string())
399    }
400}
401
402/// Is this file name / url in the NATS object store?
403/// Checks the name only, does not go to the store.
404pub fn is_nats_url(s: &str) -> bool {
405    s.starts_with(URL_PREFIX)
406}
407
408/// Extract NATS bucket and key from a nats URL of the form:
409/// nats://host[:port]/bucket/key
410pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
411    let Some(mut path_segments) = url.path_segments() else {
412        anyhow::bail!("No path in NATS URL: {url}");
413    };
414    let Some(bucket) = path_segments.next() else {
415        anyhow::bail!("No bucket in NATS URL: {url}");
416    };
417    let Some(key) = path_segments.next() else {
418        anyhow::bail!("No key in NATS URL: {url}");
419    };
420    Ok((bucket.to_string(), key.to_string()))
421}
422
/// Default queue name for publishing events.
///
/// Matches the `queue` suffix of the `<stream>.queue` subject that
/// `NatsQueue::enqueue_task` publishes to.
pub const QUEUE_NAME: &str = "queue";
425
/// A queue implementation using NATS JetStream
pub struct NatsQueue {
    /// The name of the stream to use for the queue (slugified by the constructors)
    stream_name: String,
    /// The NATS server URL
    nats_server: String,
    /// Default timeout for dequeue operations, used when `dequeue_task` gets no explicit timeout
    dequeue_timeout: time::Duration,
    /// The NATS client; `None` until connected and again after `close`
    client: Option<Client>,
    /// The subject pattern bound to the stream (`<stream_name>.*`)
    subject: String,
    /// The durable pull consumer, if one was created while connecting
    subscriber: Option<jetstream::consumer::PullConsumer>,
    /// Durable consumer name. Each distinct name receives every message independently
    /// (broadcast pattern). `None` means publisher-only: no consumer is created on connect
    /// (`get_queue_size` then looks up "worker-group").
    consumer_name: Option<String>,
    /// Message stream pulled from `subscriber`, consumed by `dequeue_task`
    message_stream: Option<jetstream::consumer::pull::Stream>,
}
445
446impl NatsQueue {
447    /// Create a new NatsQueue with the default "worker-group" consumer
448    pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self {
449        // Sanitize stream name to remove path separators (like in Python version)
450        // rupei: are we sure NATs stream name accepts '_'?
451        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
452        let subject = format!("{sanitized_stream_name}.*");
453
454        Self {
455            stream_name: sanitized_stream_name,
456            nats_server,
457            dequeue_timeout,
458            client: None,
459            subject,
460            subscriber: None,
461            consumer_name: Some("worker-group".to_string()),
462            message_stream: None,
463        }
464    }
465
466    /// Create a new NatsQueue without a consumer (publisher-only mode)
467    pub fn new_without_consumer(
468        stream_name: String,
469        nats_server: String,
470        dequeue_timeout: time::Duration,
471    ) -> Self {
472        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
473        let subject = format!("{sanitized_stream_name}.*");
474
475        Self {
476            stream_name: sanitized_stream_name,
477            nats_server,
478            dequeue_timeout,
479            client: None,
480            subject,
481            subscriber: None,
482            consumer_name: None,
483            message_stream: None,
484        }
485    }
486
487    /// Create a new NatsQueue with a specific consumer name for broadcast pattern
488    /// Each consumer with a unique name will receive all messages independently
489    pub fn new_with_consumer(
490        stream_name: String,
491        nats_server: String,
492        dequeue_timeout: time::Duration,
493        consumer_name: String,
494    ) -> Self {
495        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
496        let subject = format!("{sanitized_stream_name}.*");
497
498        Self {
499            stream_name: sanitized_stream_name,
500            nats_server,
501            dequeue_timeout,
502            client: None,
503            subject,
504            subscriber: None,
505            consumer_name: Some(consumer_name),
506            message_stream: None,
507        }
508    }
509
510    /// Connect to the NATS server and set up the stream and consumer
511    pub async fn connect(&mut self) -> Result<()> {
512        self.connect_with_reset(false).await
513    }
514
515    /// Connect to the NATS server and set up the stream and consumer, optionally resetting the stream
516    pub async fn connect_with_reset(&mut self, reset_stream: bool) -> Result<()> {
517        if self.client.is_none() {
518            // Create a new client
519            let client_options = Client::builder().server(self.nats_server.clone()).build()?;
520
521            let client = client_options.connect().await?;
522
523            // messages older than a hour in the stream will be automatically purged
524            let max_age = std::env::var("DYN_NATS_STREAM_MAX_AGE")
525                .ok()
526                .and_then(|s| s.parse::<u64>().ok())
527                .map(time::Duration::from_secs)
528                .unwrap_or_else(|| time::Duration::from_secs(60 * 60));
529
530            let stream_config = jetstream::stream::Config {
531                name: self.stream_name.clone(),
532                subjects: vec![self.subject.clone()],
533                max_age,
534                ..Default::default()
535            };
536
537            // Get or create the stream
538            let stream = client
539                .jetstream()
540                .get_or_create_stream(stream_config)
541                .await?;
542
543            log::debug!("Stream {} is ready", self.stream_name);
544
545            // If reset_stream is true, purge all messages from the stream
546            if reset_stream {
547                match stream.purge().await {
548                    Ok(purge_info) => {
549                        log::info!(
550                            "Successfully purged {} messages from NATS stream {}",
551                            purge_info.purged,
552                            self.stream_name
553                        );
554                    }
555                    Err(e) => {
556                        log::warn!("Failed to purge NATS stream '{}': {e}", self.stream_name);
557                    }
558                }
559            }
560
561            // Create persistent subscriber only if consumer_name is set
562            if let Some(ref consumer_name) = self.consumer_name {
563                let consumer_config = jetstream::consumer::pull::Config {
564                    durable_name: Some(consumer_name.clone()),
565                    inactive_threshold: std::time::Duration::from_secs(3600), // 1 hour
566                    ..Default::default()
567                };
568
569                let subscriber = stream.create_consumer(consumer_config).await?;
570
571                // Create the message stream for efficient consumption
572                let message_stream = subscriber.messages().await?;
573
574                self.subscriber = Some(subscriber);
575                self.message_stream = Some(message_stream);
576            }
577
578            self.client = Some(client);
579        }
580
581        Ok(())
582    }
583
584    /// Ensure we have an active connection
585    pub async fn ensure_connection(&mut self) -> Result<()> {
586        if self.client.is_none() {
587            self.connect().await?;
588        }
589        Ok(())
590    }
591
592    /// Close the connection when done
593    pub async fn close(&mut self) -> Result<()> {
594        self.message_stream = None;
595        self.subscriber = None;
596        self.client = None;
597        Ok(())
598    }
599
600    /// Shutdown the consumer by deleting it from the stream and closing the connection
601    /// This permanently removes the consumer from the server
602    ///
603    /// If `consumer_name` is provided, that specific consumer will be deleted instead of the
604    /// current consumer. This allows deletion of other consumers on the same stream.
605    pub async fn shutdown(&mut self, consumer_name: Option<String>) -> Result<()> {
606        // Determine which consumer to delete
607        let target_consumer = consumer_name.as_ref().or(self.consumer_name.as_ref());
608
609        // Warn if deleting our own consumer via explicit parameter
610        if let Some(ref passed_name) = consumer_name
611            && self.consumer_name.as_ref() == Some(passed_name)
612        {
613            log::warn!(
614                "Deleting our own consumer '{}' via explicit consumer_name parameter. \
615                Consider calling shutdown without arguments instead.",
616                passed_name
617            );
618        }
619
620        if let (Some(client), Some(consumer_to_delete)) = (&self.client, target_consumer) {
621            // Get the stream and delete the consumer
622            let stream = client.jetstream().get_stream(&self.stream_name).await?;
623            stream
624                .delete_consumer(consumer_to_delete)
625                .await
626                .map_err(|e| {
627                    anyhow::anyhow!("Failed to delete consumer {}: {}", consumer_to_delete, e)
628                })?;
629            log::debug!(
630                "Deleted consumer {} from stream {}",
631                consumer_to_delete,
632                self.stream_name
633            );
634        } else {
635            log::debug!(
636                "Cannot shutdown consumer: client or target consumer is None (client: {:?}, target_consumer: {:?})",
637                self.client.is_some(),
638                target_consumer.is_some()
639            );
640        }
641
642        // Only close the connection if we deleted our own consumer
643        if consumer_name.is_none() {
644            self.close().await
645        } else {
646            Ok(())
647        }
648    }
649
650    /// Count the number of consumers for the stream
651    pub async fn count_consumers(&mut self) -> Result<usize> {
652        self.ensure_connection().await?;
653
654        if let Some(client) = &self.client {
655            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
656            let info = stream.info().await?;
657            Ok(info.state.consumer_count)
658        } else {
659            Err(anyhow::anyhow!("Client not connected"))
660        }
661    }
662
663    /// List all consumer names for the stream
664    pub async fn list_consumers(&mut self) -> Result<Vec<String>> {
665        self.ensure_connection().await?;
666
667        if let Some(client) = &self.client {
668            client.list_consumers(&self.stream_name).await
669        } else {
670            Err(anyhow::anyhow!("Client not connected"))
671        }
672    }
673
674    /// Enqueue a task using the provided data
675    pub async fn enqueue_task(&mut self, task_data: Bytes) -> Result<()> {
676        self.ensure_connection().await?;
677
678        if let Some(client) = &self.client {
679            let subject = format!("{}.queue", self.stream_name);
680            client.jetstream().publish(subject, task_data).await?;
681            Ok(())
682        } else {
683            Err(anyhow::anyhow!("Client not connected"))
684        }
685    }
686
687    /// Dequeue and return a task as raw bytes
688    pub async fn dequeue_task(&mut self, timeout: Option<time::Duration>) -> Result<Option<Bytes>> {
689        self.ensure_connection().await?;
690
691        let Some(ref mut stream) = self.message_stream else {
692            return Err(anyhow::anyhow!("Message stream not initialized"));
693        };
694
695        let timeout_duration = timeout.unwrap_or(self.dequeue_timeout);
696
697        // Try to get next message from the stream with timeout
698        let message = tokio::time::timeout(timeout_duration, stream.next()).await;
699
700        match message {
701            Ok(Some(Ok(msg))) => {
702                msg.ack()
703                    .await
704                    .map_err(|e| anyhow::anyhow!("Failed to ack message: {}", e))?;
705                Ok(Some(msg.payload.clone()))
706            }
707
708            Ok(Some(Err(e))) => Err(anyhow::anyhow!("Failed to get message from stream: {}", e)),
709
710            Ok(None) => Err(anyhow::anyhow!("Message stream ended unexpectedly")),
711
712            // Timeout - no messages available
713            Err(_) => Ok(None),
714        }
715    }
716
717    /// Get the number of messages currently in the queue
718    pub async fn get_queue_size(&mut self) -> Result<u64> {
719        self.ensure_connection().await?;
720
721        if let Some(client) = &self.client {
722            // Get consumer info to get pending messages count
723            let stream = client.jetstream().get_stream(&self.stream_name).await?;
724            let consumer_name = self
725                .consumer_name
726                .clone()
727                .unwrap_or_else(|| "worker-group".to_string());
728            let mut consumer: jetstream::consumer::PullConsumer = stream
729                .get_consumer(&consumer_name)
730                .await
731                .map_err(|e| anyhow::anyhow!("Failed to get consumer: {}", e))?;
732            let info = consumer.info().await?;
733
734            Ok(info.num_pending)
735        } else {
736            Err(anyhow::anyhow!("Client not connected"))
737        }
738    }
739
740    /// Get the total number of messages currently in the stream
741    pub async fn get_stream_messages(&mut self) -> Result<u64> {
742        self.ensure_connection().await?;
743
744        if let Some(client) = &self.client {
745            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
746            let info = stream.info().await?;
747            Ok(info.state.messages)
748        } else {
749            Err(anyhow::anyhow!("Client not connected"))
750        }
751    }
752
753    /// Purge messages from the stream up to (but not including) the specified sequence number
754    /// This permanently removes messages and affects all consumers of the stream
755    pub async fn purge_up_to_sequence(&self, sequence: u64) -> Result<()> {
756        if let Some(client) = &self.client {
757            let stream = client.jetstream().get_stream(&self.stream_name).await?;
758
759            // NOTE: this purge excludes the sequence itself
760            // https://docs.rs/nats/latest/nats/jetstream/struct.PurgeRequest.html
761            stream.purge().sequence(sequence).await.map_err(|e| {
762                anyhow::anyhow!("Failed to purge stream up to sequence {}: {}", sequence, e)
763            })?;
764
765            log::debug!(
766                "Purged stream {} up to sequence {}",
767                self.stream_name,
768                sequence
769            );
770            Ok(())
771        } else {
772            Err(anyhow::anyhow!("Client not connected"))
773        }
774    }
775
776    /// Purge messages from the stream up to the minimum acknowledged sequence across all consumers
777    /// This finds the lowest acknowledged sequence number across all consumers and purges up to that point
778    pub async fn purge_acknowledged(&mut self) -> Result<()> {
779        self.ensure_connection().await?;
780
781        let Some(client) = &self.client else {
782            return Err(anyhow::anyhow!("Client not connected"));
783        };
784
785        let stream = client.jetstream().get_stream(&self.stream_name).await?;
786
787        // Get all consumer names for the stream
788        let consumer_names: Vec<String> = stream
789            .consumer_names()
790            .try_collect()
791            .await
792            .map_err(|e| anyhow::anyhow!("Failed to list consumers: {}", e))?;
793
794        if consumer_names.is_empty() {
795            log::debug!("No consumers found for stream {}", self.stream_name);
796            return Ok(());
797        }
798
799        // Find the minimum acknowledged sequence across all consumers
800        let mut min_ack_sequence = u64::MAX;
801
802        for consumer_name in &consumer_names {
803            let mut consumer: jetstream::consumer::PullConsumer = stream
804                .get_consumer(consumer_name)
805                .await
806                .map_err(|e| anyhow::anyhow!("Failed to get consumer {}: {}", consumer_name, e))?;
807
808            let info = consumer.info().await.map_err(|e| {
809                anyhow::anyhow!("Failed to get consumer info for {}: {}", consumer_name, e)
810            })?;
811
812            // The ack_floor contains the stream sequence of the highest contiguously acknowledged message
813            // If stream_sequence is 0, it means no messages have been acknowledged yet
814            if info.ack_floor.stream_sequence > 0 {
815                min_ack_sequence = min_ack_sequence.min(info.ack_floor.stream_sequence);
816                log::debug!(
817                    "Consumer {} has ack_floor at sequence {}",
818                    consumer_name,
819                    info.ack_floor.stream_sequence
820                );
821            }
822        }
823
824        // Only purge if we found a valid minimum acknowledged sequence
825        if min_ack_sequence < u64::MAX && min_ack_sequence > 0 {
826            // Purge up to (but not including) the minimum acknowledged sequence + 1
827            // We add 1 because we want to include the minimum acknowledged message in the purge
828            let purge_sequence = min_ack_sequence + 1;
829
830            self.purge_up_to_sequence(purge_sequence).await?;
831
832            log::debug!(
833                "Purged stream {} up to acknowledged sequence {} (purged up to sequence {})",
834                self.stream_name,
835                min_ack_sequence,
836                purge_sequence
837            );
838        } else {
839            log::debug!(
840                "No messages to purge for stream {} (min_ack_sequence: {})",
841                self.stream_name,
842                min_ack_sequence
843            );
844        }
845
846        Ok(())
847    }
848}
849
850#[async_trait]
851impl EventPublisher for NatsQueue {
852    fn subject(&self) -> String {
853        self.stream_name.clone()
854    }
855
856    async fn publish(
857        &self,
858        event_name: impl AsRef<str> + Send + Sync,
859        event: &(impl Serialize + Send + Sync),
860    ) -> Result<()> {
861        let bytes = serde_json::to_vec(event)?;
862        self.publish_bytes(event_name, bytes).await
863    }
864
865    async fn publish_bytes(
866        &self,
867        event_name: impl AsRef<str> + Send + Sync,
868        bytes: Vec<u8>,
869    ) -> Result<()> {
870        // We expect the stream to be always suffixed with "queue"
871        // This suffix itself is nothing special, just a repo standard
872        if event_name.as_ref() != QUEUE_NAME {
873            tracing::warn!(
874                "Expected event_name to be '{}', but got '{}'",
875                QUEUE_NAME,
876                event_name.as_ref()
877            );
878        }
879
880        let subject = format!("{}.{}", self.subject(), event_name.as_ref());
881
882        // Note: enqueue_task requires &mut self, but EventPublisher requires &self
883        // We need to ensure the client is connected and use it directly
884        if let Some(client) = &self.client {
885            client.jetstream().publish(subject, bytes.into()).await?;
886            Ok(())
887        } else {
888            Err(anyhow::anyhow!("Client not connected"))
889        }
890    }
891}
892
893/// Prometheus metrics that mirror the NATS client statistics (in primitive types)
894/// to be used for the System Status Server.
895///
896/// ⚠️  IMPORTANT: These Prometheus Gauges are COPIES of NATS client data, not live references!
897///
898/// How it works:
899/// 1. NATS client provides source data via client.statistics() and connection_state()
900/// 2. set_from_client_stats() reads current NATS values and updates these Prometheus Gauges
901/// 3. Prometheus scrapes these Gauge values (snapshots, not live data)
902///
903/// Flow: NATS Client → Client Statistics → set_from_client_stats() → Prometheus Gauge
904/// Note: These are snapshots updated when set_from_client_stats() is called.
905#[derive(Debug, Clone)]
906pub struct DRTNatsClientPrometheusMetrics {
907    nats_client: client::Client,
908    /// Number of bytes received (excluding protocol overhead)
909    pub in_bytes: IntGauge,
910    /// Number of bytes sent (excluding protocol overhead)
911    pub out_bytes: IntGauge,
912    /// Number of messages received
913    pub in_messages: IntGauge,
914    /// Number of messages sent
915    pub out_messages: IntGauge,
916    /// Number of times connection was established
917    pub connects: IntGauge,
918    /// Current connection state (0 = disconnected, 1 = connected, 2 = reconnecting)
919    pub connection_state: IntGauge,
920}
921
922impl DRTNatsClientPrometheusMetrics {
923    /// Create a new instance of NATS client metrics using a DistributedRuntime's Prometheus constructors
924    pub fn new(drt: &crate::DistributedRuntime, nats_client: client::Client) -> Result<Self> {
925        let metrics = drt.metrics();
926        let in_bytes = metrics.create_intgauge(
927            nats_metrics::IN_TOTAL_BYTES,
928            "Total number of bytes received by NATS client",
929            &[],
930        )?;
931        let out_bytes = metrics.create_intgauge(
932            nats_metrics::OUT_OVERHEAD_BYTES,
933            "Total number of bytes sent by NATS client",
934            &[],
935        )?;
936        let in_messages = metrics.create_intgauge(
937            nats_metrics::IN_MESSAGES,
938            "Total number of messages received by NATS client",
939            &[],
940        )?;
941        let out_messages = metrics.create_intgauge(
942            nats_metrics::OUT_MESSAGES,
943            "Total number of messages sent by NATS client",
944            &[],
945        )?;
946        let connects = metrics.create_intgauge(
947            nats_metrics::CURRENT_CONNECTIONS,
948            "Current number of active connections for NATS client",
949            &[],
950        )?;
951        let connection_state = metrics.create_intgauge(
952            nats_metrics::CONNECTION_STATE,
953            "Current connection state of NATS client (0=disconnected, 1=connected, 2=reconnecting)",
954            &[],
955        )?;
956
957        Ok(Self {
958            nats_client,
959            in_bytes,
960            out_bytes,
961            in_messages,
962            out_messages,
963            connects,
964            connection_state,
965        })
966    }
967
968    /// Copy statistics from the stored NATS client to these Prometheus metrics
969    pub fn set_from_client_stats(&self) {
970        let stats = self.nats_client.statistics();
971
972        // Get current values from the client statistics
973        let in_bytes = stats.in_bytes.load(Ordering::Relaxed);
974        let out_bytes = stats.out_bytes.load(Ordering::Relaxed);
975        let in_messages = stats.in_messages.load(Ordering::Relaxed);
976        let out_messages = stats.out_messages.load(Ordering::Relaxed);
977        let connects = stats.connects.load(Ordering::Relaxed);
978
979        // Get connection state
980        let connection_state = match self.nats_client.connection_state() {
981            State::Connected => 1,
982            // treat Disconnected and Pending as "down"
983            State::Disconnected | State::Pending => 0,
984        };
985
986        // Update Prometheus metrics
987        // Using gauges allows us to set absolute values directly
988        self.in_bytes.set(in_bytes as i64);
989        self.out_bytes.set(out_bytes as i64);
990        self.in_messages.set(in_messages as i64);
991        self.out_messages.set(out_messages as i64);
992        self.connects.set(connects as i64);
993        self.connection_state.set(connection_state);
994    }
995}
996
// Tests for the NATS transport. `test_client_options_builder` runs offline; the other tests are
// `#[ignore]`d because they talk to a NATS server at nats://localhost:4222.
#[cfg(test)]
mod tests {

    use super::*;
    use figment::Jail;
    use serde::{Deserialize, Serialize};

    /// Sample payload for the object store upload/download round-trip test.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestData {
        id: u32,
        name: String,
        values: Vec<f64>,
    }

    /// `ClientOptions::builder()` picks up `NATS_SERVER` / `NATS_AUTH_*` from the environment.
    /// Values set explicitly on the builder override them.
    #[test]
    fn test_client_options_builder() {
        // No NATS environment set inside the jail: building must still succeed.
        Jail::expect_with(|_jail| {
            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            Ok(())
        });

        // Server and username/password are taken from the environment.
        Jail::expect_with(|jail| {
            jail.set_env("NATS_SERVER", "nats://localhost:5222");
            jail.set_env("NATS_AUTH_USERNAME", "user");
            jail.set_env("NATS_AUTH_PASSWORD", "pass");

            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();

            assert_eq!(opts.server, "nats://localhost:5222");
            assert_eq!(
                opts.auth,
                NatsAuth::UserPass("user".to_string(), "pass".to_string())
            );

            Ok(())
        });

        // Explicit builder values win over the environment.
        Jail::expect_with(|jail| {
            jail.set_env("NATS_SERVER", "nats://localhost:5222");
            jail.set_env("NATS_AUTH_USERNAME", "user");
            jail.set_env("NATS_AUTH_PASSWORD", "pass");

            let opts = ClientOptions::builder()
                .server("nats://localhost:6222")
                .auth(NatsAuth::Token("token".to_string()))
                .build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();

            assert_eq!(opts.server, "nats://localhost:6222");
            assert_eq!(opts.auth, NatsAuth::Token("token".to_string()));

            Ok(())
        });
    }

    // Integration test for object store data operations using bincode
    // Round-trips a TestData value through upload/download and then deletes the bucket.
    #[tokio::test]
    #[ignore] // Requires NATS server to be running
    async fn test_object_store_data_operations() {
        // Create test data
        let test_data = TestData {
            id: 42,
            name: "test_item".to_string(),
            values: vec![1.0, 2.5, 3.7, 4.2],
        };

        // Set up client
        let client_options = ClientOptions::builder()
            .server("nats://localhost:4222")
            .build()
            .expect("Failed to build client options");

        let client = client_options
            .connect()
            .await
            .expect("Failed to connect to NATS");

        // Test URL (using .bin extension to indicate binary format)
        let url =
            Url::parse("nats://localhost/test-bucket/test-data.bin").expect("Failed to parse URL");

        // Upload the data
        client
            .object_store_upload_data(&test_data, &url)
            .await
            .expect("Failed to upload data");

        // Download the data
        let downloaded_data: TestData = client
            .object_store_download_data(&url)
            .await
            .expect("Failed to download data");

        // Verify the data matches
        assert_eq!(test_data, downloaded_data);

        // Clean up
        client
            .object_store_delete_bucket("test-bucket")
            .await
            .expect("Failed to delete bucket");
    }

    // Integration test for broadcast pattern with purging
    // Two consumers on one stream must each see every message that was not explicitly purged.
    // `purge_acknowledged` must not remove messages a lagging consumer has not acknowledged.
    // The test also checks the consumer count across close() and shutdown().
    #[tokio::test]
    #[ignore]
    async fn test_nats_queue_broadcast_with_purge() {
        use uuid::Uuid;

        // Create unique stream name for this test
        let stream_name = format!("test-broadcast-{}", Uuid::new_v4());
        let nats_server = "nats://localhost:4222".to_string();
        // NOTE(review): passed to NatsQueue as its timeout; what 0 means is defined by
        // NatsQueue. Confirm.
        let timeout = time::Duration::from_secs(0);

        // Connect to NATS client first to delete stream if it exists
        let client_options = Client::builder()
            .server(nats_server.clone())
            .build()
            .expect("Failed to build client options");

        let client = client_options
            .connect()
            .await
            .expect("Failed to connect to NATS");

        // Delete the stream if it exists (to ensure clean start)
        let _ = client.jetstream().delete_stream(&stream_name).await;

        // Create two consumers with different names for the same stream
        let consumer1_name = format!("consumer-{}", Uuid::new_v4());
        let consumer2_name = format!("consumer-{}", Uuid::new_v4());

        let mut queue1 = NatsQueue::new_with_consumer(
            stream_name.clone(),
            nats_server.clone(),
            timeout,
            consumer1_name,
        );

        // Connect queue1 first (it will create the stream)
        queue1.connect().await.expect("Failed to connect queue1");

        // Send 4 messages using the EventPublisher trait
        let message_strings = [
            "message1".to_string(),
            "message2".to_string(),
            "message3".to_string(),
            "message4".to_string(),
        ];

        // Using the EventPublisher trait to publish messages
        for (idx, msg) in message_strings.iter().enumerate() {
            queue1
                .publish("queue", msg)
                .await
                .unwrap_or_else(|_| panic!("Failed to publish message {}", idx + 1));
        }

        // Convert messages to JSON-serialized Bytes for comparison
        // (EventPublisher::publish serializes with serde_json)
        let messages: Vec<Bytes> = message_strings
            .iter()
            .map(|s| Bytes::from(serde_json::to_vec(s).unwrap()))
            .collect();

        // Give JetStream a moment to persist the messages
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Now create and connect queue2 and queue3 AFTER messages are published (to test persistence)
        let mut queue2 = NatsQueue::new_with_consumer(
            stream_name.clone(),
            nats_server.clone(),
            timeout,
            consumer2_name,
        );

        // Create a third queue without consumer (publisher-only)
        let mut queue3 =
            NatsQueue::new_without_consumer(stream_name.clone(), nats_server.clone(), timeout);

        // Connect queue2 and queue3 after messages are already published
        queue2.connect().await.expect("Failed to connect queue2");
        queue3.connect().await.expect("Failed to connect queue3");

        // Purge the first two messages (sequence 1 and 2)
        // Note: JetStream sequences start at 1, and purge is exclusive of the sequence number
        queue1
            .purge_up_to_sequence(3)
            .await
            .expect("Failed to purge messages");

        // Give JetStream a moment to process the purge
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Consumer 1 dequeues one message (message3)
        let msg3_consumer1 = queue1
            .dequeue_task(Some(time::Duration::from_millis(500)))
            .await
            .expect("Failed to dequeue from queue1");
        assert_eq!(
            msg3_consumer1,
            Some(messages[2].clone()),
            "Consumer 1 should get message3"
        );

        // Give JetStream a moment to process acknowledgments
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Now run purge_acknowledged
        // At this point:
        // - Consumer 1 has ack'd message 3 (ack_floor = 3)
        // - Consumer 2 hasn't acknowledged anything yet, so its ack floor is below 3:
        //   0, or 2 if the server advanced it past the purged messages 1-2
        //   (NOTE(review): confirm which one the server reports)
        // - The minimum ack floor is therefore below 3, so message 3 must NOT be purged;
        //   consumer 2 still has to receive messages 3 and 4 (asserted below)
        queue1
            .purge_acknowledged()
            .await
            .expect("Failed to purge acknowledged messages");

        // Give JetStream a moment to process the purge
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Now collect remaining messages from both consumers
        // Each loop drains until dequeue_task returns None (presumably "no message available
        // right now"; confirm against dequeue_task).
        let mut consumer1_remaining = Vec::new();
        let mut consumer2_remaining = Vec::new();

        // Collect remaining messages from consumer 1
        while let Some(msg) = queue1
            .dequeue_task(None)
            .await
            .expect("Failed to dequeue from queue1")
        {
            consumer1_remaining.push(msg);
        }

        // Collect remaining messages from consumer 2
        while let Some(msg) = queue2
            .dequeue_task(None)
            .await
            .expect("Failed to dequeue from queue2")
        {
            consumer2_remaining.push(msg);
        }

        // Verify consumer 1 gets 1 remaining message (message4)
        assert_eq!(
            consumer1_remaining.len(),
            1,
            "Consumer 1 should have 1 remaining message"
        );
        assert_eq!(
            consumer1_remaining[0], messages[3],
            "Consumer 1 should get message4"
        );

        // Verify consumer 2 gets 2 messages (message3 and message4)
        assert_eq!(
            consumer2_remaining.len(),
            2,
            "Consumer 2 should have 2 messages"
        );
        assert_eq!(
            consumer2_remaining[0], messages[2],
            "Consumer 2 should get message3"
        );
        assert_eq!(
            consumer2_remaining[1], messages[3],
            "Consumer 2 should get message4"
        );

        // Test consumer count and shutdown behavior
        // First verify via consumer 1 that there are two consumers
        let consumer_count = queue1
            .count_consumers()
            .await
            .expect("Failed to count consumers");
        assert_eq!(consumer_count, 2, "Should have 2 consumers initially");

        // Close consumer 1 and verify via consumer 2 that there are still two consumers
        // (close() is expected to leave the server-side consumer in place)
        queue1.close().await.expect("Failed to close queue1");

        let consumer_count = queue2
            .count_consumers()
            .await
            .expect("Failed to count consumers");
        assert_eq!(
            consumer_count, 2,
            "Should still have 2 consumers after closing queue1"
        );

        // Reconnect queue1 to be able to shutdown
        queue1.connect().await.expect("Failed to reconnect queue1");

        // Shutdown consumer 1 and verify via consumer 2 that there is only one consumer left
        // (shutdown() is expected to remove queue1's consumer from the stream)
        queue1
            .shutdown(None)
            .await
            .expect("Failed to shutdown queue1");

        let consumer_count = queue2
            .count_consumers()
            .await
            .expect("Failed to count consumers");
        assert_eq!(
            consumer_count, 1,
            "Should have only 1 consumer after shutting down queue1"
        );

        // Clean up by deleting the stream
        client
            .jetstream()
            .delete_stream(&stream_name)
            .await
            .expect("Failed to delete test stream");
    }
}