dynamo_runtime/transports/nats.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

//! NATS transport
//!
//! The following environment variables are used to configure the NATS client:
//!
//! - `NATS_SERVER`: the NATS server address (default: `nats://localhost:4222`)
//!
//! For authentication, the following environment variables are checked in priority order
//! (highest first); the first one (or, for username/password, the first pair) that is set
//! is used:
//!
//! - `NATS_AUTH_USERNAME`: the username for authentication
//! - `NATS_AUTH_PASSWORD`: the password for authentication
//! - `NATS_AUTH_TOKEN`: the token for authentication
//! - `NATS_AUTH_NKEY`: the nkey for authentication
//! - `NATS_AUTH_CREDENTIALS_FILE`: the path to the credentials file
//!
//! Note: `NATS_AUTH_USERNAME` and `NATS_AUTH_PASSWORD` must be used together; if only one of
//! them is set, both are ignored. If none of the variables above is set, the client falls
//! back to the username/password `user`/`user`.
19use crate::metrics::MetricsHierarchy;
20use crate::traits::events::EventPublisher;
21
22use anyhow::Result;
23use async_nats::connection::State;
24use async_nats::{Subscriber, client, jetstream};
25use async_trait::async_trait;
26use bytes::Bytes;
27use derive_builder::Builder;
28use futures::{StreamExt, TryStreamExt};
29use prometheus::{Counter, Gauge, Histogram, HistogramOpts, IntCounter, IntGauge, Opts, Registry};
30use serde::de::DeserializeOwned;
31use serde::{Deserialize, Serialize};
32use std::path::{Path, PathBuf};
33use std::sync::atomic::Ordering;
34use tokio::fs::File as TokioFile;
35use tokio::io::AsyncRead;
36use tokio::time;
37use url::Url;
38use validator::{Validate, ValidationError};
39
40use crate::metrics::prometheus_names::nats_client as nats_metrics;
41pub use crate::slug::Slug;
42use tracing as log;
43
44use super::utils::build_in_runtime;
45
/// Scheme prefix of NATS URLs, e.g. `nats://localhost:4222` or `nats://host/bucket/key`.
pub const URL_PREFIX: &str = "nats://";
47
48#[derive(Clone)]
49pub struct Client {
50    client: client::Client,
51    js_ctx: jetstream::Context,
52}
53
54impl Client {
55    /// Create a NATS [`ClientOptionsBuilder`].
56    pub fn builder() -> ClientOptionsBuilder {
57        ClientOptionsBuilder::default()
58    }
59
60    /// Returns a reference to the underlying [`async_nats::client::Client`] instance
61    pub fn client(&self) -> &client::Client {
62        &self.client
63    }
64
65    /// Returns a reference to the underlying [`async_nats::jetstream::Context`] instance
66    pub fn jetstream(&self) -> &jetstream::Context {
67        &self.js_ctx
68    }
69
70    /// host:port of NATS
71    pub fn addr(&self) -> String {
72        let info = self.client.server_info();
73        format!("{}:{}", info.host, info.port)
74    }
75
76    /// fetch the list of streams
77    pub async fn list_streams(&self) -> Result<Vec<String>> {
78        let names = self.js_ctx.stream_names();
79        let stream_names: Vec<String> = names.try_collect().await?;
80        Ok(stream_names)
81    }
82
83    /// fetch the list of consumers for a given stream
84    pub async fn list_consumers(&self, stream_name: &str) -> Result<Vec<String>> {
85        let stream = self.js_ctx.get_stream(stream_name).await?;
86        let consumers: Vec<String> = stream.consumer_names().try_collect().await?;
87        Ok(consumers)
88    }
89
90    pub async fn stream_info(&self, stream_name: &str) -> Result<jetstream::stream::State> {
91        let mut stream = self.js_ctx.get_stream(stream_name).await?;
92        let info = stream.info().await?;
93        Ok(info.state.clone())
94    }
95
96    pub async fn get_stream(&self, name: &str) -> Result<jetstream::stream::Stream> {
97        let stream = self.js_ctx.get_stream(name).await?;
98        Ok(stream)
99    }
100
101    /// Issues a broadcast request for all services with the provided `service_name` to report their
102    /// current stats. Each service will only respond once. The service may have customized the reply
103    /// so the caller should select which endpoint and what concrete data model should be used to
104    /// extract the details.
105    ///
106    /// Note: Because each endpoint will only reply once, the caller must drop the subscription after
107    /// some time or it will await forever.
108    pub async fn scrape_service(&self, service_name: &str) -> Result<Subscriber> {
109        let subject = format!("$SRV.STATS.{}", service_name);
110        let reply_subject = format!("_INBOX.{}", nuid::next());
111        let subscription = self.client.subscribe(reply_subject.clone()).await?;
112
113        // Publish the request with the reply-to subject
114        self.client
115            .publish_with_reply(subject, reply_subject, "".into())
116            .await?;
117
118        Ok(subscription)
119    }
120
121    /// Helper method to get or optionally create an object store bucket
122    ///
123    /// # Arguments
124    /// * `bucket_name` - The name of the bucket to retrieve
125    /// * `create_if_not_found` - If true, creates the bucket when it doesn't exist
126    ///
127    /// # Returns
128    /// The object store bucket or an error
129    async fn get_or_create_bucket(
130        &self,
131        bucket_name: &str,
132        create_if_not_found: bool,
133    ) -> anyhow::Result<jetstream::object_store::ObjectStore> {
134        let context = self.jetstream();
135
136        match context.get_object_store(bucket_name).await {
137            Ok(bucket) => Ok(bucket),
138            Err(err) if err.to_string().contains("stream not found") => {
139                // err.source() is GetStreamError, which has a kind() which
140                // is GetStreamErrorKind::JetStream which wraps a jetstream::Error
141                // which has code 404. Phew. So yeah check the string for now.
142
143                if create_if_not_found {
144                    tracing::debug!("Creating NATS bucket {bucket_name}");
145                    context
146                        .create_object_store(jetstream::object_store::Config {
147                            bucket: bucket_name.to_string(),
148                            ..Default::default()
149                        })
150                        .await
151                        .map_err(|e| anyhow::anyhow!("Failed creating bucket / object store: {e}"))
152                } else {
153                    anyhow::bail!(
154                        "NATS get_object_store bucket does not exist: {bucket_name}. {err}."
155                    );
156                }
157            }
158            Err(err) => {
159                anyhow::bail!("NATS get_object_store error: {err}");
160            }
161        }
162    }
163
164    /// Upload file to NATS at this URL
165    pub async fn object_store_upload(&self, filepath: &Path, nats_url: &Url) -> anyhow::Result<()> {
166        let mut disk_file = TokioFile::open(filepath).await?;
167
168        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
169        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
170
171        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
172            name: key.to_string(),
173            ..Default::default()
174        };
175        bucket.put(key_meta, &mut disk_file).await.map_err(|e| {
176            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
177        })?;
178
179        Ok(())
180    }
181
182    /// Download file from NATS at this URL
183    pub async fn object_store_download(
184        &self,
185        nats_url: &Url,
186        filepath: &Path,
187    ) -> anyhow::Result<()> {
188        let mut disk_file = TokioFile::create(filepath).await?;
189
190        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
191        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
192
193        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
194            anyhow::anyhow!(
195                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
196            )
197        })?;
198        let _bytes_copied = tokio::io::copy(&mut obj_reader, &mut disk_file).await?;
199
200        Ok(())
201    }
202
203    /// Delete a bucket and all it's contents from the NATS object store
204    pub async fn object_store_delete_bucket(&self, bucket_name: &str) -> anyhow::Result<()> {
205        let context = self.jetstream();
206        match context.delete_object_store(&bucket_name).await {
207            Ok(_) => Ok(()),
208            Err(err) if err.to_string().contains("stream not found") => {
209                tracing::trace!(bucket_name, "NATS bucket already gone");
210                Ok(())
211            }
212            Err(err) => Err(anyhow::anyhow!("NATS get_object_store error: {err}")),
213        }
214    }
215
216    /// Upload a serializable struct to NATS object store using bincode
217    pub async fn object_store_upload_data<T>(&self, data: &T, nats_url: &Url) -> anyhow::Result<()>
218    where
219        T: Serialize,
220    {
221        // Serialize the data using bincode (more efficient binary format)
222        let binary_data = bincode::serialize(data)
223            .map_err(|e| anyhow::anyhow!("Failed to serialize data with bincode: {e}"))?;
224
225        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
226        let bucket = self.get_or_create_bucket(&bucket_name, true).await?;
227
228        let key_meta = async_nats::jetstream::object_store::ObjectMetadata {
229            name: key.to_string(),
230            ..Default::default()
231        };
232
233        // Upload the serialized bytes
234        let mut cursor = std::io::Cursor::new(binary_data);
235        bucket.put(key_meta, &mut cursor).await.map_err(|e| {
236            anyhow::anyhow!("Failed uploading to bucket / object store {bucket_name}/{key}: {e}")
237        })?;
238
239        Ok(())
240    }
241
242    /// Download and deserialize a struct from NATS object store using bincode
243    pub async fn object_store_download_data<T>(&self, nats_url: &Url) -> anyhow::Result<T>
244    where
245        T: DeserializeOwned,
246    {
247        let (bucket_name, key) = url_to_bucket_and_key(nats_url)?;
248        let bucket = self.get_or_create_bucket(&bucket_name, false).await?;
249
250        let mut obj_reader = bucket.get(&key).await.map_err(|e| {
251            anyhow::anyhow!(
252                "Failed downloading from bucket / object store {bucket_name}/{key}: {e}"
253            )
254        })?;
255
256        // Read all bytes into memory
257        let mut buffer = Vec::new();
258        tokio::io::copy(&mut obj_reader, &mut buffer)
259            .await
260            .map_err(|e| anyhow::anyhow!("Failed reading object data: {e}"))?;
261        tracing::debug!("Downloaded {} bytes from {bucket_name}/{key}", buffer.len());
262
263        // Deserialize from bincode
264        let data = bincode::deserialize(&buffer)
265            .map_err(|e| anyhow::anyhow!("Failed to deserialize data with bincode: {e}"))?;
266
267        Ok(data)
268    }
269}
270
271/// NATS client options
272///
273/// This object uses the builder pattern with default values that are evaluates
274/// from the environment variables if they are not explicitly set by the builder.
275#[derive(Debug, Clone, Builder, Validate)]
276pub struct ClientOptions {
277    #[builder(setter(into), default = "default_server()")]
278    #[validate(custom(function = "validate_nats_server"))]
279    server: String,
280
281    #[builder(default)]
282    auth: NatsAuth,
283}
284
/// Server address used when none is configured explicitly: the value of the `NATS_SERVER`
/// environment variable when it can be read, otherwise `nats://localhost:4222`.
fn default_server() -> String {
    match std::env::var("NATS_SERVER") {
        Ok(addr) => addr,
        Err(_) => String::from("nats://localhost:4222"),
    }
}
292
293fn validate_nats_server(server: &str) -> Result<(), ValidationError> {
294    if server.starts_with("nats://") {
295        Ok(())
296    } else {
297        Err(ValidationError::new("server must start with 'nats://'"))
298    }
299}
300
301// TODO(jthomson04): We really shouldn't be hardcoding this.
302const NATS_WORKER_THREADS: usize = 4;
303
304impl ClientOptions {
305    /// Create a new [`ClientOptionsBuilder`]
306    pub fn builder() -> ClientOptionsBuilder {
307        ClientOptionsBuilder::default()
308    }
309
310    /// Validate the config and attempt to connection to the NATS server
311    pub async fn connect(self) -> Result<Client> {
312        self.validate()?;
313
314        let client = match self.auth {
315            NatsAuth::UserPass(username, password) => {
316                async_nats::ConnectOptions::with_user_and_password(username, password)
317            }
318            NatsAuth::Token(token) => async_nats::ConnectOptions::with_token(token),
319            NatsAuth::NKey(nkey) => async_nats::ConnectOptions::with_nkey(nkey),
320            NatsAuth::CredentialsFile(path) => {
321                async_nats::ConnectOptions::with_credentials_file(path).await?
322            }
323        };
324
325        let (client, _) = build_in_runtime(
326            async move {
327                client
328                    .connect(self.server)
329                    .await
330                    .map_err(|e| anyhow::anyhow!("Failed to connect to NATS: {e}. Verify NATS server is running and accessible."))
331            },
332            NATS_WORKER_THREADS,
333        )
334        .await?;
335
336        let js_ctx = jetstream::new(client.clone());
337
338        // Validate JetStream is available
339        js_ctx
340            .query_account()
341            .await
342            .map_err(|e| anyhow::anyhow!("JetStream not available: {e}"))?;
343
344        Ok(Client { client, js_ctx })
345    }
346}
347
348impl Default for ClientOptions {
349    fn default() -> Self {
350        ClientOptions {
351            server: default_server(),
352            auth: NatsAuth::default(),
353        }
354    }
355}
356
/// Authentication method used when connecting to NATS.
///
/// The [`Debug`] output redacts secrets (password, token, nkey). The username and the
/// credentials-file path are printed as-is.
#[derive(Clone, Eq, PartialEq)]
pub enum NatsAuth {
    /// Username and password.
    UserPass(String, String),
    /// Authentication token.
    Token(String),
    /// NKey used for authentication.
    NKey(String),
    /// Path to a NATS credentials file.
    CredentialsFile(PathBuf),
}

impl std::fmt::Debug for NatsAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Secrets are never formatted; only non-sensitive identifiers are shown.
        match self {
            Self::UserPass(user, _) => write!(f, "UserPass({user}, <redacted>)"),
            Self::Token(_) => f.write_str("Token(<redacted>)"),
            Self::NKey(_) => f.write_str("NKey(<redacted>)"),
            Self::CredentialsFile(path) => write!(f, "CredentialsFile({path:?})"),
        }
    }
}

impl Default for NatsAuth {
    /// Resolve credentials from the environment, checking in this order:
    /// 1. `NATS_AUTH_USERNAME` together with `NATS_AUTH_PASSWORD` (both must be set);
    /// 2. `NATS_AUTH_TOKEN`;
    /// 3. `NATS_AUTH_NKEY`;
    /// 4. `NATS_AUTH_CREDENTIALS_FILE`.
    ///
    /// If none of these is set, falls back to the hard-coded username/password `user`/`user`.
    fn default() -> Self {
        let from_env = |name: &str| std::env::var(name).ok();

        if let (Some(username), Some(password)) =
            (from_env("NATS_AUTH_USERNAME"), from_env("NATS_AUTH_PASSWORD"))
        {
            return Self::UserPass(username, password);
        }

        from_env("NATS_AUTH_TOKEN")
            .map(Self::Token)
            .or_else(|| from_env("NATS_AUTH_NKEY").map(Self::NKey))
            .or_else(|| {
                from_env("NATS_AUTH_CREDENTIALS_FILE")
                    .map(|path| Self::CredentialsFile(PathBuf::from(path)))
            })
            .unwrap_or_else(|| Self::UserPass("user".to_string(), "user".to_string()))
    }
}
402
403/// Extract NATS bucket and key from a nats URL of the form:
404/// nats://host[:port]/bucket/key
405pub fn url_to_bucket_and_key(url: &Url) -> anyhow::Result<(String, String)> {
406    let Some(mut path_segments) = url.path_segments() else {
407        anyhow::bail!("No path in NATS URL: {url}");
408    };
409    let Some(bucket) = path_segments.next() else {
410        anyhow::bail!("No bucket in NATS URL: {url}");
411    };
412    let Some(key) = path_segments.next() else {
413        anyhow::bail!("No key in NATS URL: {url}");
414    };
415    Ok((bucket.to_string(), key.to_string()))
416}
417
/// Default queue name for publishing events.
///
/// Matches the `queue` subject suffix: `NatsQueue` publishes tasks to `<stream_name>.queue`.
pub const QUEUE_NAME: &str = "queue";
420
421/// A queue implementation using NATS JetStream
422pub struct NatsQueue {
423    /// The name of the stream to use for the queue
424    stream_name: String,
425    /// The NATS server URL
426    nats_server: String,
427    /// Timeout for dequeue operations in seconds
428    dequeue_timeout: time::Duration,
429    /// The NATS client
430    client: Option<Client>,
431    /// The subject pattern used for this queue
432    subject: String,
433    /// The subscriber for pull-based consumption
434    subscriber: Option<jetstream::consumer::PullConsumer>,
435    /// Optional consumer name for broadcast pattern (if None, uses "worker-group")
436    consumer_name: Option<String>,
437    /// Message stream for efficient message consumption
438    message_stream: Option<jetstream::consumer::pull::Stream>,
439}
440
441impl NatsQueue {
442    /// Create a new NatsQueue with the default "worker-group" consumer
443    pub fn new(stream_name: String, nats_server: String, dequeue_timeout: time::Duration) -> Self {
444        // Sanitize stream name to remove path separators (like in Python version)
445        // rupei: are we sure NATs stream name accepts '_'?
446        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
447        let subject = format!("{sanitized_stream_name}.*");
448
449        Self {
450            stream_name: sanitized_stream_name,
451            nats_server,
452            dequeue_timeout,
453            client: None,
454            subject,
455            subscriber: None,
456            consumer_name: Some("worker-group".to_string()),
457            message_stream: None,
458        }
459    }
460
461    /// Create a new NatsQueue without a consumer (publisher-only mode)
462    pub fn new_without_consumer(
463        stream_name: String,
464        nats_server: String,
465        dequeue_timeout: time::Duration,
466    ) -> Self {
467        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
468        let subject = format!("{sanitized_stream_name}.*");
469
470        Self {
471            stream_name: sanitized_stream_name,
472            nats_server,
473            dequeue_timeout,
474            client: None,
475            subject,
476            subscriber: None,
477            consumer_name: None,
478            message_stream: None,
479        }
480    }
481
482    /// Create a new NatsQueue with a specific consumer name for broadcast pattern
483    /// Each consumer with a unique name will receive all messages independently
484    pub fn new_with_consumer(
485        stream_name: String,
486        nats_server: String,
487        dequeue_timeout: time::Duration,
488        consumer_name: String,
489    ) -> Self {
490        let sanitized_stream_name = Slug::slugify(&stream_name).to_string();
491        let subject = format!("{sanitized_stream_name}.*");
492
493        Self {
494            stream_name: sanitized_stream_name,
495            nats_server,
496            dequeue_timeout,
497            client: None,
498            subject,
499            subscriber: None,
500            consumer_name: Some(consumer_name),
501            message_stream: None,
502        }
503    }
504
505    /// Connect to the NATS server and set up the stream and consumer
506    pub async fn connect(&mut self) -> Result<()> {
507        self.connect_with_reset(false).await
508    }
509
510    /// Connect to the NATS server and set up the stream and consumer, optionally resetting the stream
511    pub async fn connect_with_reset(&mut self, reset_stream: bool) -> Result<()> {
512        if self.client.is_none() {
513            // Create a new client
514            let client_options = Client::builder().server(self.nats_server.clone()).build()?;
515
516            let client = client_options.connect().await?;
517
518            // messages older than a hour in the stream will be automatically purged
519            let max_age = std::env::var("DYN_NATS_STREAM_MAX_AGE")
520                .ok()
521                .and_then(|s| s.parse::<u64>().ok())
522                .map(time::Duration::from_secs)
523                .unwrap_or_else(|| time::Duration::from_secs(60 * 60));
524
525            let stream_config = jetstream::stream::Config {
526                name: self.stream_name.clone(),
527                subjects: vec![self.subject.clone()],
528                max_age,
529                ..Default::default()
530            };
531
532            // Get or create the stream
533            let stream = client
534                .jetstream()
535                .get_or_create_stream(stream_config)
536                .await?;
537
538            log::debug!("Stream {} is ready", self.stream_name);
539
540            // If reset_stream is true, purge all messages from the stream
541            if reset_stream {
542                match stream.purge().await {
543                    Ok(purge_info) => {
544                        log::info!(
545                            "Successfully purged {} messages from NATS stream {}",
546                            purge_info.purged,
547                            self.stream_name
548                        );
549                    }
550                    Err(e) => {
551                        log::warn!("Failed to purge NATS stream '{}': {e}", self.stream_name);
552                    }
553                }
554            }
555
556            // Create persistent subscriber only if consumer_name is set
557            if let Some(ref consumer_name) = self.consumer_name {
558                let consumer_config = jetstream::consumer::pull::Config {
559                    durable_name: Some(consumer_name.clone()),
560                    inactive_threshold: std::time::Duration::from_secs(3600), // 1 hour
561                    ..Default::default()
562                };
563
564                let subscriber = stream.create_consumer(consumer_config).await?;
565
566                // Create the message stream for efficient consumption
567                let message_stream = subscriber.messages().await?;
568
569                self.subscriber = Some(subscriber);
570                self.message_stream = Some(message_stream);
571            }
572
573            self.client = Some(client);
574        }
575
576        Ok(())
577    }
578
579    /// Ensure we have an active connection
580    pub async fn ensure_connection(&mut self) -> Result<()> {
581        if self.client.is_none() {
582            self.connect().await?;
583        }
584        Ok(())
585    }
586
587    /// Close the connection when done
588    pub async fn close(&mut self) -> Result<()> {
589        self.message_stream = None;
590        self.subscriber = None;
591        self.client = None;
592        Ok(())
593    }
594
595    /// Shutdown the consumer by deleting it from the stream and closing the connection
596    /// This permanently removes the consumer from the server
597    ///
598    /// If `consumer_name` is provided, that specific consumer will be deleted instead of the
599    /// current consumer. This allows deletion of other consumers on the same stream.
600    pub async fn shutdown(&mut self, consumer_name: Option<String>) -> Result<()> {
601        // Determine which consumer to delete
602        let target_consumer = consumer_name.as_ref().or(self.consumer_name.as_ref());
603
604        // Warn if deleting our own consumer via explicit parameter
605        if let Some(ref passed_name) = consumer_name
606            && self.consumer_name.as_ref() == Some(passed_name)
607        {
608            log::warn!(
609                "Deleting our own consumer '{}' via explicit consumer_name parameter. \
610                Consider calling shutdown without arguments instead.",
611                passed_name
612            );
613        }
614
615        if let (Some(client), Some(consumer_to_delete)) = (&self.client, target_consumer) {
616            // Get the stream and delete the consumer
617            let stream = client.jetstream().get_stream(&self.stream_name).await?;
618            stream
619                .delete_consumer(consumer_to_delete)
620                .await
621                .map_err(|e| {
622                    anyhow::anyhow!("Failed to delete consumer {}: {}", consumer_to_delete, e)
623                })?;
624            log::debug!(
625                "Deleted consumer {} from stream {}",
626                consumer_to_delete,
627                self.stream_name
628            );
629        } else {
630            log::debug!(
631                "Cannot shutdown consumer: client or target consumer is None (client: {:?}, target_consumer: {:?})",
632                self.client.is_some(),
633                target_consumer.is_some()
634            );
635        }
636
637        // Only close the connection if we deleted our own consumer
638        if consumer_name.is_none() {
639            self.close().await
640        } else {
641            Ok(())
642        }
643    }
644
645    /// Count the number of consumers for the stream
646    pub async fn count_consumers(&mut self) -> Result<usize> {
647        self.ensure_connection().await?;
648
649        if let Some(client) = &self.client {
650            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
651            let info = stream.info().await?;
652            Ok(info.state.consumer_count)
653        } else {
654            Err(anyhow::anyhow!("Client not connected"))
655        }
656    }
657
658    /// List all consumer names for the stream
659    pub async fn list_consumers(&mut self) -> Result<Vec<String>> {
660        self.ensure_connection().await?;
661
662        if let Some(client) = &self.client {
663            client.list_consumers(&self.stream_name).await
664        } else {
665            Err(anyhow::anyhow!("Client not connected"))
666        }
667    }
668
669    /// Enqueue a task using the provided data
670    pub async fn enqueue_task(&mut self, task_data: Bytes) -> Result<()> {
671        self.ensure_connection().await?;
672
673        if let Some(client) = &self.client {
674            let subject = format!("{}.queue", self.stream_name);
675            client.jetstream().publish(subject, task_data).await?;
676            Ok(())
677        } else {
678            Err(anyhow::anyhow!("Client not connected"))
679        }
680    }
681
682    /// Dequeue and return a task as raw bytes
683    pub async fn dequeue_task(&mut self, timeout: Option<time::Duration>) -> Result<Option<Bytes>> {
684        self.ensure_connection().await?;
685
686        let Some(ref mut stream) = self.message_stream else {
687            return Err(anyhow::anyhow!("Message stream not initialized"));
688        };
689
690        let timeout_duration = timeout.unwrap_or(self.dequeue_timeout);
691
692        // Try to get next message from the stream with timeout
693        let message = tokio::time::timeout(timeout_duration, stream.next()).await;
694
695        match message {
696            Ok(Some(Ok(msg))) => {
697                msg.ack()
698                    .await
699                    .map_err(|e| anyhow::anyhow!("Failed to ack message: {}", e))?;
700                Ok(Some(msg.payload.clone()))
701            }
702
703            Ok(Some(Err(e))) => Err(anyhow::anyhow!("Failed to get message from stream: {}", e)),
704
705            Ok(None) => Err(anyhow::anyhow!("Message stream ended unexpectedly")),
706
707            // Timeout - no messages available
708            Err(_) => Ok(None),
709        }
710    }
711
712    /// Get the number of messages currently in the queue
713    pub async fn get_queue_size(&mut self) -> Result<u64> {
714        self.ensure_connection().await?;
715
716        if let Some(client) = &self.client {
717            // Get consumer info to get pending messages count
718            let stream = client.jetstream().get_stream(&self.stream_name).await?;
719            let consumer_name = self
720                .consumer_name
721                .clone()
722                .unwrap_or_else(|| "worker-group".to_string());
723            let mut consumer: jetstream::consumer::PullConsumer = stream
724                .get_consumer(&consumer_name)
725                .await
726                .map_err(|e| anyhow::anyhow!("Failed to get consumer: {}", e))?;
727            let info = consumer.info().await?;
728
729            Ok(info.num_pending)
730        } else {
731            Err(anyhow::anyhow!("Client not connected"))
732        }
733    }
734
735    /// Get the total number of messages currently in the stream
736    pub async fn get_stream_messages(&mut self) -> Result<u64> {
737        self.ensure_connection().await?;
738
739        if let Some(client) = &self.client {
740            let mut stream = client.jetstream().get_stream(&self.stream_name).await?;
741            let info = stream.info().await?;
742            Ok(info.state.messages)
743        } else {
744            Err(anyhow::anyhow!("Client not connected"))
745        }
746    }
747
748    /// Purge messages from the stream up to (but not including) the specified sequence number
749    /// This permanently removes messages and affects all consumers of the stream
750    pub async fn purge_up_to_sequence(&self, sequence: u64) -> Result<()> {
751        if let Some(client) = &self.client {
752            let stream = client.jetstream().get_stream(&self.stream_name).await?;
753
754            // NOTE: this purge excludes the sequence itself
755            // https://docs.rs/nats/latest/nats/jetstream/struct.PurgeRequest.html
756            stream.purge().sequence(sequence).await.map_err(|e| {
757                anyhow::anyhow!("Failed to purge stream up to sequence {}: {}", sequence, e)
758            })?;
759
760            log::debug!(
761                "Purged stream {} up to sequence {}",
762                self.stream_name,
763                sequence
764            );
765            Ok(())
766        } else {
767            Err(anyhow::anyhow!("Client not connected"))
768        }
769    }
770
771    /// Purge messages from the stream up to the minimum acknowledged sequence across all consumers
772    /// This finds the lowest acknowledged sequence number across all consumers and purges up to that point
773    pub async fn purge_acknowledged(&mut self) -> Result<()> {
774        self.ensure_connection().await?;
775
776        let Some(client) = &self.client else {
777            return Err(anyhow::anyhow!("Client not connected"));
778        };
779
780        let stream = client.jetstream().get_stream(&self.stream_name).await?;
781
782        // Get all consumer names for the stream
783        let consumer_names: Vec<String> = stream
784            .consumer_names()
785            .try_collect()
786            .await
787            .map_err(|e| anyhow::anyhow!("Failed to list consumers: {}", e))?;
788
789        if consumer_names.is_empty() {
790            log::debug!("No consumers found for stream {}", self.stream_name);
791            return Ok(());
792        }
793
794        // Find the minimum acknowledged sequence across all consumers
795        let mut min_ack_sequence = u64::MAX;
796
797        for consumer_name in &consumer_names {
798            let mut consumer: jetstream::consumer::PullConsumer = stream
799                .get_consumer(consumer_name)
800                .await
801                .map_err(|e| anyhow::anyhow!("Failed to get consumer {}: {}", consumer_name, e))?;
802
803            let info = consumer.info().await.map_err(|e| {
804                anyhow::anyhow!("Failed to get consumer info for {}: {}", consumer_name, e)
805            })?;
806
807            // The ack_floor contains the stream sequence of the highest contiguously acknowledged message
808            // If stream_sequence is 0, it means no messages have been acknowledged yet
809            if info.ack_floor.stream_sequence > 0 {
810                min_ack_sequence = min_ack_sequence.min(info.ack_floor.stream_sequence);
811                log::debug!(
812                    "Consumer {} has ack_floor at sequence {}",
813                    consumer_name,
814                    info.ack_floor.stream_sequence
815                );
816            }
817        }
818
819        // Only purge if we found a valid minimum acknowledged sequence
820        if min_ack_sequence < u64::MAX && min_ack_sequence > 0 {
821            // Purge up to (but not including) the minimum acknowledged sequence + 1
822            // We add 1 because we want to include the minimum acknowledged message in the purge
823            let purge_sequence = min_ack_sequence + 1;
824
825            self.purge_up_to_sequence(purge_sequence).await?;
826
827            log::debug!(
828                "Purged stream {} up to acknowledged sequence {} (purged up to sequence {})",
829                self.stream_name,
830                min_ack_sequence,
831                purge_sequence
832            );
833        } else {
834            log::debug!(
835                "No messages to purge for stream {} (min_ack_sequence: {})",
836                self.stream_name,
837                min_ack_sequence
838            );
839        }
840
841        Ok(())
842    }
843}
844
845#[async_trait]
846impl EventPublisher for NatsQueue {
847    fn subject(&self) -> String {
848        self.stream_name.clone()
849    }
850
851    async fn publish(
852        &self,
853        event_name: impl AsRef<str> + Send + Sync,
854        event: &(impl Serialize + Send + Sync),
855    ) -> Result<()> {
856        let bytes = serde_json::to_vec(event)?;
857        self.publish_bytes(event_name, bytes).await
858    }
859
860    async fn publish_bytes(
861        &self,
862        event_name: impl AsRef<str> + Send + Sync,
863        bytes: Vec<u8>,
864    ) -> Result<()> {
865        // We expect the stream to be always suffixed with "queue"
866        // This suffix itself is nothing special, just a repo standard
867        if event_name.as_ref() != QUEUE_NAME {
868            tracing::warn!(
869                "Expected event_name to be '{}', but got '{}'",
870                QUEUE_NAME,
871                event_name.as_ref()
872            );
873        }
874
875        let subject = format!("{}.{}", self.subject(), event_name.as_ref());
876
877        // Note: enqueue_task requires &mut self, but EventPublisher requires &self
878        // We need to ensure the client is connected and use it directly
879        if let Some(client) = &self.client {
880            client.jetstream().publish(subject, bytes.into()).await?;
881            Ok(())
882        } else {
883            Err(anyhow::anyhow!("Client not connected"))
884        }
885    }
886}
887
888/// Prometheus metrics that mirror the NATS client statistics (in primitive types)
889/// to be used for the System Status Server.
890///
891/// ⚠️  IMPORTANT: These Prometheus Gauges are COPIES of NATS client data, not live references!
892///
893/// How it works:
894/// 1. NATS client provides source data via client.statistics() and connection_state()
895/// 2. set_from_client_stats() reads current NATS values and updates these Prometheus Gauges
896/// 3. Prometheus scrapes these Gauge values (snapshots, not live data)
897///
898/// Flow: NATS Client → Client Statistics → set_from_client_stats() → Prometheus Gauge
899/// Note: These are snapshots updated when set_from_client_stats() is called.
900#[derive(Debug, Clone)]
901pub struct DRTNatsClientPrometheusMetrics {
902    nats_client: client::Client,
903    /// Number of bytes received (excluding protocol overhead)
904    pub in_bytes: IntGauge,
905    /// Number of bytes sent (excluding protocol overhead)
906    pub out_bytes: IntGauge,
907    /// Number of messages received
908    pub in_messages: IntGauge,
909    /// Number of messages sent
910    pub out_messages: IntGauge,
911    /// Number of times connection was established
912    pub connects: IntGauge,
913    /// Current connection state (0 = disconnected, 1 = connected, 2 = reconnecting)
914    pub connection_state: IntGauge,
915}
916
917impl DRTNatsClientPrometheusMetrics {
918    /// Create a new instance of NATS client metrics using a DistributedRuntime's Prometheus constructors
919    pub fn new(drt: &crate::DistributedRuntime, nats_client: client::Client) -> Result<Self> {
920        let metrics = drt.metrics();
921        let in_bytes = metrics.create_intgauge(
922            nats_metrics::IN_TOTAL_BYTES,
923            "Total number of bytes received by NATS client",
924            &[],
925        )?;
926        let out_bytes = metrics.create_intgauge(
927            nats_metrics::OUT_OVERHEAD_BYTES,
928            "Total number of bytes sent by NATS client",
929            &[],
930        )?;
931        let in_messages = metrics.create_intgauge(
932            nats_metrics::IN_MESSAGES,
933            "Total number of messages received by NATS client",
934            &[],
935        )?;
936        let out_messages = metrics.create_intgauge(
937            nats_metrics::OUT_MESSAGES,
938            "Total number of messages sent by NATS client",
939            &[],
940        )?;
941        let connects = metrics.create_intgauge(
942            nats_metrics::CURRENT_CONNECTIONS,
943            "Current number of active connections for NATS client",
944            &[],
945        )?;
946        let connection_state = metrics.create_intgauge(
947            nats_metrics::CONNECTION_STATE,
948            "Current connection state of NATS client (0=disconnected, 1=connected, 2=reconnecting)",
949            &[],
950        )?;
951
952        Ok(Self {
953            nats_client,
954            in_bytes,
955            out_bytes,
956            in_messages,
957            out_messages,
958            connects,
959            connection_state,
960        })
961    }
962
963    /// Copy statistics from the stored NATS client to these Prometheus metrics
964    pub fn set_from_client_stats(&self) {
965        let stats = self.nats_client.statistics();
966
967        // Get current values from the client statistics
968        let in_bytes = stats.in_bytes.load(Ordering::Relaxed);
969        let out_bytes = stats.out_bytes.load(Ordering::Relaxed);
970        let in_messages = stats.in_messages.load(Ordering::Relaxed);
971        let out_messages = stats.out_messages.load(Ordering::Relaxed);
972        let connects = stats.connects.load(Ordering::Relaxed);
973
974        // Get connection state
975        let connection_state = match self.nats_client.connection_state() {
976            State::Connected => 1,
977            // treat Disconnected and Pending as "down"
978            State::Disconnected | State::Pending => 0,
979        };
980
981        // Update Prometheus metrics
982        // Using gauges allows us to set absolute values directly
983        self.in_bytes.set(in_bytes as i64);
984        self.out_bytes.set(out_bytes as i64);
985        self.in_messages.set(in_messages as i64);
986        self.out_messages.set(out_messages as i64);
987        self.connects.set(connects as i64);
988        self.connection_state.set(connection_state);
989    }
990}
991
#[cfg(test)]
mod tests {

    use super::*;
    use figment::Jail;
    use serde::{Deserialize, Serialize};

    /// Payload used to round-trip data through the NATS object store.
    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestData {
        id: u32,
        name: String,
        values: Vec<f64>,
    }

    /// Checks `ClientOptions::builder()` in three scenarios.
    ///
    /// - With no environment variables set in the jail.
    /// - With server and user/password taken from `NATS_SERVER` / `NATS_AUTH_*`.
    /// - With explicit builder values, which must override those environment variables.
    #[test]
    fn test_client_options_builder() {
        // No variables are set inside this jail: building with defaults must succeed.
        Jail::expect_with(|_jail| {
            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            Ok(())
        });

        // Server address and username/password are read from the environment.
        Jail::expect_with(|jail| {
            jail.set_env("NATS_SERVER", "nats://localhost:5222");
            jail.set_env("NATS_AUTH_USERNAME", "user");
            jail.set_env("NATS_AUTH_PASSWORD", "pass");

            let opts = ClientOptions::builder().build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();

            assert_eq!(opts.server, "nats://localhost:5222");
            assert_eq!(
                opts.auth,
                NatsAuth::UserPass("user".to_string(), "pass".to_string())
            );

            Ok(())
        });

        // Explicit builder values win over the environment (token auth replaces user/pass).
        Jail::expect_with(|jail| {
            jail.set_env("NATS_SERVER", "nats://localhost:5222");
            jail.set_env("NATS_AUTH_USERNAME", "user");
            jail.set_env("NATS_AUTH_PASSWORD", "pass");

            let opts = ClientOptions::builder()
                .server("nats://localhost:6222")
                .auth(NatsAuth::Token("token".to_string()))
                .build();
            assert!(opts.is_ok());
            let opts = opts.unwrap();

            assert_eq!(opts.server, "nats://localhost:6222");
            assert_eq!(opts.auth, NatsAuth::Token("token".to_string()));

            Ok(())
        });
    }

    // Integration test for object store data operations using bincode
    /// Uploads `TestData` to the object store, downloads it back, and checks equality.
    /// It then deletes the bucket. Expects a NATS server at `nats://localhost:4222`.
    #[tokio::test]
    #[ignore] // Requires NATS server to be running
    async fn test_object_store_data_operations() {
        // Create test data
        let test_data = TestData {
            id: 42,
            name: "test_item".to_string(),
            values: vec![1.0, 2.5, 3.7, 4.2],
        };

        // Set up client
        let client_options = ClientOptions::builder()
            .server("nats://localhost:4222")
            .build()
            .expect("Failed to build client options");

        let client = client_options
            .connect()
            .await
            .expect("Failed to connect to NATS");

        // Test URL (using .bin extension to indicate binary format)
        let url =
            Url::parse("nats://localhost/test-bucket/test-data.bin").expect("Failed to parse URL");

        // Upload the data
        client
            .object_store_upload_data(&test_data, &url)
            .await
            .expect("Failed to upload data");

        // Download the data
        let downloaded_data: TestData = client
            .object_store_download_data(&url)
            .await
            .expect("Failed to download data");

        // Verify the data matches
        assert_eq!(test_data, downloaded_data);

        // Clean up
        client
            .object_store_delete_bucket("test-bucket")
            .await
            .expect("Failed to delete bucket");
    }

    // Integration test for broadcast pattern with purging
    /// Exercises a JetStream-backed `NatsQueue` shared by two named consumers. It covers:
    ///
    /// - persistence of messages for a consumer connected after publishing;
    /// - explicit purging with `purge_up_to_sequence`;
    /// - `purge_acknowledged`;
    /// - consumer counting across `close` and `shutdown`.
    ///
    /// Expects a NATS server at `nats://localhost:4222`.
    #[tokio::test]
    #[ignore] // Requires NATS server to be running
    async fn test_nats_queue_broadcast_with_purge() {
        use uuid::Uuid;

        // Create unique stream name for this test
        let stream_name = format!("test-broadcast-{}", Uuid::new_v4());
        let nats_server = "nats://localhost:4222".to_string();
        let timeout = time::Duration::from_secs(0);

        // Connect to NATS client first to delete stream if it exists
        let client_options = Client::builder()
            .server(nats_server.clone())
            .build()
            .expect("Failed to build client options");

        let client = client_options
            .connect()
            .await
            .expect("Failed to connect to NATS");

        // Delete the stream if it exists (to ensure clean start)
        let _ = client.jetstream().delete_stream(&stream_name).await;

        // Create two consumers with different names for the same stream
        let consumer1_name = format!("consumer-{}", Uuid::new_v4());
        let consumer2_name = format!("consumer-{}", Uuid::new_v4());

        let mut queue1 = NatsQueue::new_with_consumer(
            stream_name.clone(),
            nats_server.clone(),
            timeout,
            consumer1_name,
        );

        // Connect queue1 first (it will create the stream)
        queue1.connect().await.expect("Failed to connect queue1");

        // Send 4 messages using the EventPublisher trait
        let message_strings = [
            "message1".to_string(),
            "message2".to_string(),
            "message3".to_string(),
            "message4".to_string(),
        ];

        // Using the EventPublisher trait to publish messages
        for (idx, msg) in message_strings.iter().enumerate() {
            queue1
                .publish("queue", msg)
                .await
                .unwrap_or_else(|_| panic!("Failed to publish message {}", idx + 1));
        }

        // Convert messages to JSON-serialized Bytes for comparison
        let messages: Vec<Bytes> = message_strings
            .iter()
            .map(|s| Bytes::from(serde_json::to_vec(s).unwrap()))
            .collect();

        // Give JetStream a moment to persist the messages
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Now create and connect queue2 and queue3 AFTER messages are published (to test persistence)
        let mut queue2 = NatsQueue::new_with_consumer(
            stream_name.clone(),
            nats_server.clone(),
            timeout,
            consumer2_name,
        );

        // Create a third queue without consumer (publisher-only)
        let mut queue3 =
            NatsQueue::new_without_consumer(stream_name.clone(), nats_server.clone(), timeout);

        // Connect queue2 and queue3 after messages are already published
        queue2.connect().await.expect("Failed to connect queue2");
        queue3.connect().await.expect("Failed to connect queue3");

        // Purge the first two messages (sequence 1 and 2)
        // Note: JetStream sequences start at 1, and purge is exclusive of the sequence number
        queue1
            .purge_up_to_sequence(3)
            .await
            .expect("Failed to purge messages");

        // Give JetStream a moment to process the purge
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Consumer 1 dequeues one message (message3)
        let msg3_consumer1 = queue1
            .dequeue_task(Some(time::Duration::from_millis(500)))
            .await
            .expect("Failed to dequeue from queue1");
        assert_eq!(
            msg3_consumer1,
            Some(messages[2].clone()),
            "Consumer 1 should get message3"
        );

        // Give JetStream a moment to process acknowledgments
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Now run purge_acknowledged
        // At this point:
        // - Consumer 1 has ack'd message 3 (ack_floor = 3)
        // - Consumer 2 hasn't consumed anything yet (ack_floor = 0)
        // - Min ack_floor = 0, so nothing will be purged
        // NOTE(review): the assertions below require that this call does NOT purge message3,
        // i.e. that consumer 2's ack floor bounds the purge. The server may already have
        // advanced that floor to 2 when sequences 1-2 were purged above — confirm.
        queue1
            .purge_acknowledged()
            .await
            .expect("Failed to purge acknowledged messages");

        // Give JetStream a moment to process the purge
        tokio::time::sleep(time::Duration::from_millis(100)).await;

        // Now collect remaining messages from both consumers
        let mut consumer1_remaining = Vec::new();
        let mut consumer2_remaining = Vec::new();

        // Collect remaining messages from consumer 1
        // (loops until dequeue_task returns None)
        while let Some(msg) = queue1
            .dequeue_task(None)
            .await
            .expect("Failed to dequeue from queue1")
        {
            consumer1_remaining.push(msg);
        }

        // Collect remaining messages from consumer 2
        while let Some(msg) = queue2
            .dequeue_task(None)
            .await
            .expect("Failed to dequeue from queue2")
        {
            consumer2_remaining.push(msg);
        }

        // Verify consumer 1 gets 1 remaining message (message4)
        assert_eq!(
            consumer1_remaining.len(),
            1,
            "Consumer 1 should have 1 remaining message"
        );
        assert_eq!(
            consumer1_remaining[0], messages[3],
            "Consumer 1 should get message4"
        );

        // Verify consumer 2 gets 2 messages (message3 and message4)
        assert_eq!(
            consumer2_remaining.len(),
            2,
            "Consumer 2 should have 2 messages"
        );
        assert_eq!(
            consumer2_remaining[0], messages[2],
            "Consumer 2 should get message3"
        );
        assert_eq!(
            consumer2_remaining[1], messages[3],
            "Consumer 2 should get message4"
        );

        // Test consumer count and shutdown behavior
        // First verify via consumer 1 that there are two consumers
        let consumer_count = queue1
            .count_consumers()
            .await
            .expect("Failed to count consumers");
        assert_eq!(consumer_count, 2, "Should have 2 consumers initially");

        // Close consumer 1 and verify via consumer 2 that there are still two consumers
        // (the test expects close to keep the named consumer registered on the stream)
        queue1.close().await.expect("Failed to close queue1");

        let consumer_count = queue2
            .count_consumers()
            .await
            .expect("Failed to count consumers");
        assert_eq!(
            consumer_count, 2,
            "Should still have 2 consumers after closing queue1"
        );

        // Reconnect queue1 to be able to shutdown
        queue1.connect().await.expect("Failed to reconnect queue1");

        // Shutdown consumer 1 and verify via consumer 2 that there is only one consumer left
        // (the test expects shutdown to remove the consumer from the stream)
        queue1
            .shutdown(None)
            .await
            .expect("Failed to shutdown queue1");

        let consumer_count = queue2
            .count_consumers()
            .await
            .expect("Failed to count consumers");
        assert_eq!(
            consumer_count, 1,
            "Should have only 1 consumer after shutting down queue1"
        );

        // Clean up by deleting the stream
        client
            .jetstream()
            .delete_stream(&stream_name)
            .await
            .expect("Failed to delete test stream");
    }
}