Skip to main content

rivet_async_nats/
lib.rs

1// Copyright 2020-2022 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! A Rust asynchronous client for the NATS.io ecosystem.
15//!
16//! To access the repository, you can clone it by running:
17//!
18//! ```bash
19//! git clone https://github.com/nats-io/nats.rs
//! ```
//!
21//! NATS.io is a simple, secure, and high-performance open-source messaging
22//! system designed for cloud-native applications, IoT messaging, and microservices
23//! architectures.
24//!
25//! **Note**: The synchronous NATS API is deprecated and no longer actively maintained. If you need to use the deprecated synchronous API, you can refer to:
26//! <https://crates.io/crates/nats>
27//!
28//! For more information on NATS.io visit: <https://nats.io>
29//!
30//! ## Examples
31//!
//! Below, you can find some basic examples of how to use this library.
//!
//! For more details, please refer to the documentation of the specific methods and structures.
35//!
36//! ### Complete example
37//!
38//! Connect to the NATS server, publish messages and subscribe to receive messages.
39//!
40//! ```no_run
41//! use bytes::Bytes;
42//! use futures_util::StreamExt;
43//!
44//! #[tokio::main]
45//! async fn main() -> Result<(), async_nats::Error> {
46//!     // Connect to the NATS server
47//!     let client = async_nats::connect("demo.nats.io").await?;
48//!
49//!     // Subscribe to the "messages" subject
50//!     let mut subscriber = client.subscribe("messages").await?;
51//!
52//!     // Publish messages to the "messages" subject
53//!     for _ in 0..10 {
54//!         client.publish("messages", "data".into()).await?;
55//!     }
56//!
57//!     // Receive and process messages
58//!     while let Some(message) = subscriber.next().await {
59//!         println!("Received message {:?}", message);
60//!     }
61//!
62//!     Ok(())
63//! }
64//! ```
65//!
66//! ### Publish
67//!
68//! Connect to the NATS server and publish messages to a subject.
69//!
70//! ```
71//! # use bytes::Bytes;
72//! # use std::error::Error;
73//! # use std::time::Instant;
74//! # #[tokio::main]
75//! # async fn main() -> Result<(), async_nats::Error> {
76//! // Connect to the NATS server
77//! let client = async_nats::connect("demo.nats.io").await?;
78//!
79//! // Prepare the subject and data
80//! let subject = "foo";
81//! let data = Bytes::from("bar");
82//!
83//! // Publish messages to the NATS server
84//! for _ in 0..10 {
85//!     client.publish(subject, data.clone()).await?;
86//! }
87//!
88//! // Flush internal buffer before exiting to make sure all messages are sent
89//! client.flush().await?;
90//!
91//! #    Ok(())
92//! # }
93//! ```
94//!
95//! ### Subscribe
96//!
97//! Connect to the NATS server, subscribe to a subject and receive messages.
98//!
99//! ```no_run
100//! # use bytes::Bytes;
101//! # use futures_util::StreamExt;
102//! # use std::error::Error;
103//! # use std::time::Instant;
104//! # #[tokio::main]
105//! # async fn main() -> Result<(), async_nats::Error> {
106//! // Connect to the NATS server
107//! let client = async_nats::connect("demo.nats.io").await?;
108//!
109//! // Subscribe to the "foo" subject
//! let mut subscriber = client.subscribe("foo").await?;
111//!
112//! // Receive and process messages
113//! while let Some(message) = subscriber.next().await {
114//!     println!("Received message {:?}", message);
115//! }
116//! #     Ok(())
117//! # }
118//! ```
119//!
120//! ### JetStream
121//!
//! To access the JetStream API, create a JetStream [jetstream::Context].
123//!
124//! ```no_run
125//! # #[tokio::main]
126//! # async fn main() -> Result<(), async_nats::Error> {
127//! // Connect to the NATS server
128//! let client = async_nats::connect("demo.nats.io").await?;
129//! // Create a JetStream context.
130//! let jetstream = async_nats::jetstream::new(client);
131//!
132//! // Publish JetStream messages, manage streams, consumers, etc.
133//! jetstream.publish("foo", "bar".into()).await?;
134//! # Ok(())
135//! # }
136//! ```
137//!
138//! ### Key-value Store
139//!
140//! Key-value [Store][jetstream::kv::Store] is accessed through [jetstream::Context].
141//!
142//! ```no_run
143//! # #[tokio::main]
144//! # async fn main() -> Result<(), async_nats::Error> {
145//! // Connect to the NATS server
146//! let client = async_nats::connect("demo.nats.io").await?;
147//! // Create a JetStream context.
148//! let jetstream = async_nats::jetstream::new(client);
//! // Access an existing key-value store.
150//! let kv = jetstream.get_key_value("store").await?;
151//! # Ok(())
152//! # }
153//! ```
154//! ### Object Store
155//!
156//! Object [Store][jetstream::object_store::ObjectStore] is accessed through [jetstream::Context].
157//!
158//! ```no_run
159//! # #[tokio::main]
160//! # async fn main() -> Result<(), async_nats::Error> {
161//! // Connect to the NATS server
162//! let client = async_nats::connect("demo.nats.io").await?;
163//! // Create a JetStream context.
164//! let jetstream = async_nats::jetstream::new(client);
//! // Access an existing object store.
//! let object_store = jetstream.get_object_store("store").await?;
167//! # Ok(())
168//! # }
169//! ```
170//! ### Service API
171//!
172//! [Service API][service::Service] is accessible through [Client] after importing its trait.
173//!
174//! ```no_run
175//! # #[tokio::main]
176//! # async fn main() -> Result<(), async_nats::Error> {
177//! use async_nats::service::ServiceExt;
178//! // Connect to the NATS server
179//! let client = async_nats::connect("demo.nats.io").await?;
180//! let mut service = client
181//!     .service_builder()
182//!     .description("some service")
183//!     .stats_handler(|endpoint, stats| serde_json::json!({ "endpoint": endpoint }))
184//!     .start("products", "1.0.0")
185//!     .await?;
186//! # Ok(())
187//! # }
188//! ```
189
190#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_cfg))]
196
197use thiserror::Error;
198
199use futures_util::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use portable_atomic::AtomicU64;
206use std::collections::HashMap;
207use std::collections::VecDeque;
208use std::fmt::Display;
209use std::future::Future;
210use std::iter;
211use std::mem;
212use std::net::SocketAddr;
213use std::option;
214use std::pin::Pin;
215use std::slice;
216use std::str::{self, FromStr};
217use std::sync::atomic::AtomicUsize;
218use std::sync::atomic::Ordering;
219use std::sync::Arc;
220use std::task::{Context, Poll};
221use tokio::io::ErrorKind;
222use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
223use url::{Host, Url};
224
225use bytes::Bytes;
226use serde::{Deserialize, Serialize};
227use serde_repr::{Deserialize_repr, Serialize_repr};
228use tokio::io;
229use tokio::sync::mpsc;
230use tokio::task;
231
/// Type-erased error that can be moved and shared across threads.
///
/// The crate examples use it as the error type of `main`, so any error
/// produced by the client can be propagated with `?`.
pub type Error = Box<dyn std::error::Error + Sync + Send>;
233
234const VERSION: &str = env!("CARGO_PKG_VERSION");
235const LANG: &str = "rust";
236const MAX_PENDING_PINGS: usize = 2;
237const MULTIPLEXER_SID: u64 = 0;
238
/// A re-export of the `rustls` crate used in this crate,
/// for use in cases where a manual client configuration
/// must be provided using `ConnectOptions::tls_client_config`.
242pub use tokio_rustls::rustls;
243
244use connection::{Connection, State};
245use connector::{Connector, ConnectorOptions};
246pub use header::{HeaderMap, HeaderName, HeaderValue};
247pub use subject::Subject;
248
249mod auth;
250pub(crate) mod auth_utils;
251pub mod client;
252pub mod connection;
253mod connector;
254mod options;
255
256pub use auth::Auth;
257pub use client::{
258    Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
259};
260pub use options::{AuthError, ConnectOptions};
261
262#[cfg(feature = "crypto")]
263#[cfg_attr(docsrs, doc(cfg(feature = "crypto")))]
264mod crypto;
265pub mod error;
266pub mod header;
267mod id_generator;
268#[cfg(feature = "jetstream")]
269#[cfg_attr(docsrs, doc(cfg(feature = "jetstream")))]
270pub mod jetstream;
271pub mod message;
272#[cfg(feature = "service")]
273#[cfg_attr(docsrs, doc(cfg(feature = "service")))]
274pub mod service;
275pub mod status;
276pub mod subject;
277mod tls;
278
279pub use message::Message;
280pub use status::StatusCode;
281
282/// Information sent by the server back to this client
283/// during initial connection, and possibly again later.
284#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
285pub struct ServerInfo {
286    /// The unique identifier of the NATS server.
287    #[serde(default)]
288    pub server_id: String,
289    /// Generated Server Name.
290    #[serde(default)]
291    pub server_name: String,
292    /// The host specified in the cluster parameter/options.
293    #[serde(default)]
294    pub host: String,
295    /// The port number specified in the cluster parameter/options.
296    #[serde(default)]
297    pub port: u16,
298    /// The version of the NATS server.
299    #[serde(default)]
300    pub version: String,
301    /// If this is set, then the server should try to authenticate upon
302    /// connect.
303    #[serde(default)]
304    pub auth_required: bool,
305    /// If this is set, then the server must authenticate using TLS.
306    #[serde(default)]
307    pub tls_required: bool,
308    /// Maximum payload size that the server will accept.
309    #[serde(default)]
310    pub max_payload: usize,
311    /// The protocol version in use.
312    #[serde(default)]
313    pub proto: i8,
314    /// The server-assigned client ID. This may change during reconnection.
315    #[serde(default)]
316    pub client_id: u64,
317    /// The version of golang the NATS server was built with.
318    #[serde(default)]
319    pub go: String,
320    /// The nonce used for nkeys.
321    #[serde(default)]
322    pub nonce: String,
323    /// A list of server urls that a client can connect to.
324    #[serde(default)]
325    pub connect_urls: Vec<String>,
326    /// The client IP as known by the server.
327    #[serde(default)]
328    pub client_ip: String,
329    /// Whether the server supports headers.
330    #[serde(default)]
331    pub headers: bool,
332    /// Whether server goes into lame duck mode.
333    #[serde(default, rename = "ldm")]
334    pub lame_duck_mode: bool,
335    /// Name of the cluster if the server is in cluster-mode
336    #[serde(default)]
337    pub cluster: Option<String>,
338    /// The configured NATS domain of the server.
339    #[serde(default)]
340    pub domain: Option<String>,
341    /// Whether the server supports JetStream.
342    #[serde(default)]
343    pub jetstream: bool,
344}
345
346#[derive(Clone, Debug, Eq, PartialEq)]
347pub(crate) enum ServerOp {
348    Ok,
349    Info(Box<ServerInfo>),
350    Ping,
351    Pong,
352    Error(ServerError),
353    Message {
354        sid: u64,
355        subject: Subject,
356        reply: Option<Subject>,
357        payload: Bytes,
358        headers: Option<HeaderMap>,
359        status: Option<StatusCode>,
360        description: Option<String>,
361        length: usize,
362    },
363}
364
365/// An alias. This is done to avoid breaking changes
366/// in the public API. However this will get deprecated in the future in favor of
367/// [crate::message::OutboundMessage].
368#[deprecated(
369    since = "0.44.0",
370    note = "use `async_nats::message::OutboundMessage` instead"
371)]
372pub type PublishMessage = crate::message::OutboundMessage;
373
374/// `Command` represents all commands that a [`Client`] can handle
375#[derive(Debug)]
376pub(crate) enum Command {
377    Publish(OutboundMessage),
378    Request {
379        subject: Subject,
380        payload: Bytes,
381        respond: Subject,
382        headers: Option<HeaderMap>,
383        sender: oneshot::Sender<Message>,
384    },
385    Subscribe {
386        sid: u64,
387        subject: Subject,
388        queue_group: Option<String>,
389        sender: mpsc::Sender<Message>,
390        statistics: Arc<SubscriberStatistics>,
391    },
392    Unsubscribe {
393        sid: u64,
394        max: Option<u64>,
395    },
396    Flush {
397        observer: oneshot::Sender<()>,
398    },
399    Drain {
400        sid: Option<u64>,
401    },
402    Reconnect,
403}
404
405/// `ClientOp` represents all actions of `Client`.
406#[derive(Debug)]
407pub(crate) enum ClientOp {
408    Publish {
409        subject: Subject,
410        payload: Bytes,
411        respond: Option<Subject>,
412        headers: Option<HeaderMap>,
413    },
414    Subscribe {
415        sid: u64,
416        subject: Subject,
417        queue_group: Option<String>,
418    },
419    Unsubscribe {
420        sid: u64,
421        max: Option<u64>,
422    },
423    Ping,
424    Pong,
425    Connect(ConnectInfo),
426}
427
/// Client-side state of one active subscription, stored in
/// `ConnectionHandler::subscriptions` under its subscription id.
#[derive(Debug)]
struct Subscription {
    /// Subject this subscription is for.
    subject: Subject,
    /// Channel to the subscriber. Delivery uses `try_send`, so a full channel
    /// drops the message and a closed one ends the subscription.
    sender: mpsc::Sender<Message>,
    /// Per-subscription pending/dropped message and byte counters.
    statistics: Arc<SubscriberStatistics>,
    /// Optional queue group name.
    queue_group: Option<String>,
    /// Number of messages successfully handed to `sender` so far.
    delivered: u64,
    /// Auto-unsubscribe limit; the entry is removed once `delivered` reaches it.
    max: Option<u64>,
}
437
/// Shared reply subscription used to route responses to in-flight requests.
///
/// Messages arriving for `MULTIPLEXER_SID` are matched to a waiting request by
/// the token that follows `prefix` in their subject.
#[derive(Debug)]
struct Multiplexer {
    /// Subject of the shared reply subscription (presumably `prefix` plus a
    /// wildcard — TODO confirm against where the multiplexer is created).
    subject: Subject,
    /// Common prefix of reply subjects; the remainder of a subject is the token.
    prefix: Subject,
    /// Waiting requesters keyed by token; an entry is removed when its reply arrives.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
444
/// A connection handler which facilitates communication from channels to a single shared connection.
///
/// It owns the live `Connection`, applies `Command`s received from clients,
/// dispatches `ServerOp`s read from the server, and keeps the ping, flush and
/// drain bookkeeping that `process` needs to drive the connection.
pub(crate) struct ConnectionHandler {
    /// Transport that server ops are read from and client ops are queued on.
    connection: Connection,
    /// Source of the events channel and connection statistics (and presumably
    /// of reconnection logic — not visible here).
    connector: Connector,
    /// Active subscriptions keyed by subscription id (`sid`).
    subscriptions: HashMap<u64, Subscription>,
    /// Shared reply subscription for requests; `None` until one is set up.
    multiplexer: Option<Multiplexer>,
    /// PINGs sent that are still awaiting a PONG; each PONG decrements it.
    pending_pings: usize,
    /// Watch channel carrying `ServerInfo`. NOTE(review): INFO messages handled in
    /// `handle_server_op` do not update it; presumably written on (re)connect — confirm.
    info_sender: tokio::sync::watch::Sender<ServerInfo>,
    /// Timer driving PINGs; reset by every server op and command, so PINGs are
    /// only sent while the connection is otherwise idle.
    ping_interval: Interval,
    /// Set to request a reconnect; consumed at the end of the next poll.
    should_reconnect: bool,
    /// Callers waiting to be notified after the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    /// Set when the whole client is draining; the next poll exits with `Closed`.
    is_draining: bool,
    /// Subscription ids whose drain UNSUB has been queued; removed on the next poll.
    drain_pings: VecDeque<u64>,
}
459
460impl ConnectionHandler {
461    pub(crate) fn new(
462        connection: Connection,
463        connector: Connector,
464        info_sender: tokio::sync::watch::Sender<ServerInfo>,
465        ping_period: Duration,
466    ) -> ConnectionHandler {
467        let mut ping_interval = interval(ping_period);
468        ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
469
470        ConnectionHandler {
471            connection,
472            connector,
473            subscriptions: HashMap::new(),
474            multiplexer: None,
475            pending_pings: 0,
476            info_sender,
477            ping_interval,
478            should_reconnect: false,
479            flush_observers: Vec::new(),
480            is_draining: false,
481            drain_pings: VecDeque::new(),
482        }
483    }
484
    /// Runs the connection event loop for this handler.
    ///
    /// Each iteration drives a `ProcessFut` over the current connection until it
    /// stops. A disconnect or a requested reconnect is followed by
    /// `handle_disconnect`, and the loop resumes if that succeeds. The method
    /// returns when `handle_disconnect` fails, or when the command channel is
    /// closed or a client-wide drain completes. In the last two cases it first
    /// emits `Event::Closed`.
    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
        // Future that drives one connected session; resolves with why it stopped.
        struct ProcessFut<'a> {
            handler: &'a mut ConnectionHandler,
            receiver: &'a mut mpsc::Receiver<Command>,
            // Scratch buffer reused across `poll_recv_many` calls; always empty between polls.
            recv_buf: &'a mut Vec<Command>,
        }

        // Why a `ProcessFut` stopped.
        enum ExitReason {
            // An I/O operation failed, the read side ended, or too many PINGs went unanswered.
            Disconnected(Option<io::Error>),
            // `should_reconnect` was set while handling commands.
            ReconnectRequested,
            // The command channel closed, or a client-wide drain finished.
            Closed,
        }

        impl ProcessFut<'_> {
            /// Maximum number of commands pulled from the receiver per `poll_recv_many` call.
            const RECV_CHUNK_SIZE: usize = 16;

            /// Handles one elapsed tick of the ping interval.
            ///
            /// Counts a new outstanding PING. Once more than `MAX_PENDING_PINGS`
            /// would be outstanding, no PING is sent and the connection is reported
            /// as disconnected. Otherwise a PING is queued and `Poll::Pending` is
            /// returned so that polling continues.
            #[cold]
            fn ping(&mut self) -> Poll<ExitReason> {
                self.handler.pending_pings += 1;

                if self.handler.pending_pings > MAX_PENDING_PINGS {
                    debug!(
                        pending_pings = self.handler.pending_pings,
                        max_pings = MAX_PENDING_PINGS,
                        "disconnecting due to too many pending pings"
                    );

                    Poll::Ready(ExitReason::Disconnected(None))
                } else {
                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);

                    Poll::Pending
                }
            }
        }

        impl Future for ProcessFut<'_> {
            type Output = ExitReason;

            /// Drives the connection forward.
            ///
            /// Returns one of the following:
            ///
            /// * `Poll::Pending` means that the connection
            ///   is blocked on all fronts or there are
            ///   no commands to send or receive
            /// * `Poll::Ready(ExitReason::Disconnected(_))` means
            ///   that an I/O operation failed, the read side
            ///   returned no more ops, or too many PINGs went
            ///   unanswered; the connection is considered dead.
            /// * `Poll::Ready(ExitReason::ReconnectRequested)` means
            ///   that a handled command set `should_reconnect`.
            /// * `Poll::Ready(ExitReason::Closed)` means that
            ///   [`Self::receiver`] was closed, or that a drain of
            ///   the whole client has completed, so there's nothing
            ///   more for us to do than to exit the client.
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // We need to be sure the waker is registered, therefore we need to poll until we
                // get a `Poll::Pending`. With a sane interval delay, this means that the loop
                // breaks at the second iteration.
                while self.handler.ping_interval.poll_tick(cx).is_ready() {
                    if let Poll::Ready(exit) = self.ping() {
                        return Poll::Ready(exit);
                    }
                }

                // Read and apply every server op that is already available.
                loop {
                    match self.handler.connection.poll_read_op(cx) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(Some(server_op))) => {
                            self.handler.handle_server_op(server_op);
                        }
                        // No further ops can be read: treat it as a disconnect.
                        Poll::Ready(Ok(None)) => {
                            return Poll::Ready(ExitReason::Disconnected(None))
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Before handling any commands, drop any subscriptions which are draining
                // Note: safe to assume subscription drain has completed at this point, as we would have flushed
                // all outgoing UNSUB messages in the previous call to this fn, and we would have processed and
                // delivered any remaining messages to the subscription in the loop above.
                while let Some(sid) = self.handler.drain_pings.pop_front() {
                    self.handler.subscriptions.remove(&sid);
                }

                if self.handler.is_draining {
                    // The entire connection is draining. This means we flushed outgoing messages in the previous
                    // call to this fn, we handled any remaining messages from the server in the loop above, and
                    // all subs were drained, so drain is complete and we should exit instead of processing any
                    // further messages
                    return Poll::Ready(ExitReason::Closed);
                }

                // WARNING: after the following loop `handle_command`,
                // or other functions which call `enqueue_write_op`,
                // cannot be called anymore. Runtime wakeups won't
                // trigger a call to `poll_write`

                let mut made_progress = true;
                loop {
                    // Pull commands only while there is room in the write buffer. This is
                    // backpressure: once the buffer is full the receiver is not polled.
                    while !self.handler.connection.is_write_buf_full() {
                        debug_assert!(self.recv_buf.is_empty());

                        // Split the borrow of `self` so the receiver, buffer and handler
                        // can be used at the same time.
                        let Self {
                            recv_buf,
                            handler,
                            receiver,
                        } = &mut *self;
                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
                            Poll::Pending => break,
                            Poll::Ready(1..) => {
                                made_progress = true;

                                for cmd in recv_buf.drain(..) {
                                    handler.handle_command(cmd);
                                }
                            }
                            // Zero commands received: the channel is closed.
                            // TODO: replace `_` with `0` after bumping MSRV to 1.75
                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
                        }
                    }

                    // The first round will poll both from
                    // the `receiver` and the writer, giving
                    // them both a chance to make progress
                    // and register `Waker`s.
                    //
                    // If writing is `Poll::Pending` we exit.
                    //
                    // If writing is completed we can repeat the entire
                    // cycle as long as the `receiver` doesn't end-up
                    // `Poll::Pending` immediately.
                    if !mem::take(&mut made_progress) {
                        break;
                    }

                    match self.handler.connection.poll_write(cx) {
                        Poll::Pending => {
                            // Write buffer couldn't be fully emptied
                            break;
                        }
                        Poll::Ready(Ok(())) => {
                            // Write buffer is empty
                            continue;
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Flush when the connection asks for it, or when callers are waiting on
                // an explicit flush even though the connection would not flush on its own.
                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
                    self.handler.connection.should_flush(),
                    self.handler.flush_observers.is_empty(),
                ) {
                    match self.handler.connection.poll_flush(cx) {
                        // Observers stay queued until a later flush completes.
                        Poll::Pending => {}
                        Poll::Ready(Ok(())) => {
                            for observer in self.handler.flush_observers.drain(..) {
                                // The waiting caller may have gone away; that is fine.
                                let _ = observer.send(());
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Checked last so that writes queued in this poll get a chance to go out first.
                if mem::take(&mut self.handler.should_reconnect) {
                    return Poll::Ready(ExitReason::ReconnectRequested);
                }

                Poll::Pending
            }
        }

        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
        loop {
            let process = ProcessFut {
                handler: self,
                receiver,
                recv_buf: &mut recv_buf,
            };
            match process.await {
                ExitReason::Disconnected(err) => {
                    debug!(error = ?err, "disconnected");
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                    debug!("reconnected");
                }
                ExitReason::Closed => {
                    // Safe to ignore result as we're shutting down anyway
                    self.connector.events_tx.try_send(Event::Closed).ok();
                    break;
                }
                ExitReason::ReconnectRequested => {
                    debug!("reconnect requested");
                    // Should be ok to ignore error, as that means we are not in connected state.
                    self.connection.stream.shutdown().await.ok();
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                }
            }
        }
    }
692
693    fn handle_server_op(&mut self, server_op: ServerOp) {
694        self.ping_interval.reset();
695
696        match server_op {
697            ServerOp::Ping => {
698                debug!("received PING");
699                self.connection.enqueue_write_op(&ClientOp::Pong);
700            }
701            ServerOp::Pong => {
702                debug!("received PONG");
703                self.pending_pings = self.pending_pings.saturating_sub(1);
704            }
705            ServerOp::Error(error) => {
706                debug!("received ERROR: {:?}", error);
707                self.connector
708                    .events_tx
709                    .try_send(Event::ServerError(error))
710                    .ok();
711            }
712            ServerOp::Message {
713                sid,
714                subject,
715                reply,
716                payload,
717                headers,
718                status,
719                description,
720                length,
721            } => {
722                debug!("received MESSAGE: sid={}, subject={}", sid, subject);
723                self.connector
724                    .connect_stats
725                    .in_messages
726                    .add(1, Ordering::Relaxed);
727
728                if let Some(subscription) = self.subscriptions.get_mut(&sid) {
729                    let message: Message = Message {
730                        subject,
731                        reply,
732                        payload,
733                        headers,
734                        status,
735                        description,
736                        length,
737                    };
738
739                    // if the channel for subscription was dropped, remove the
740                    // subscription from the map and unsubscribe.
741                    match subscription.sender.try_send(message) {
742                        Ok(_) => {
743                            subscription
744                                .statistics
745                                .pending_messages
746                                .add(1, Ordering::Relaxed);
747                            subscription
748                                .statistics
749                                .pending_bytes
750                                .add(length as u64, Ordering::Relaxed);
751                            self.connector
752                                .connect_stats
753                                .subscription_pending_messages
754                                .add(1, Ordering::Relaxed);
755                            self.connector
756                                .connect_stats
757                                .subscription_pending_bytes
758                                .add(length as u64, Ordering::Relaxed);
759                            subscription.delivered += 1;
760                            // if this `Subscription` has set `max` value, check if it
761                            // was reached. If yes, remove the `Subscription` and in
762                            // the result, `drop` the `sender` channel.
763                            if let Some(max) = subscription.max {
764                                if subscription.delivered.ge(&max) {
765                                    debug!("max messages reached for subscription {}", sid);
766                                    self.subscriptions.remove(&sid);
767                                }
768                            }
769                        }
770                        Err(mpsc::error::TrySendError::Full(returned_message)) => {
771                            let dropped_len = returned_message.length as u64;
772                            subscription
773                                .statistics
774                                .dropped_messages
775                                .add(1, Ordering::Relaxed);
776                            subscription
777                                .statistics
778                                .dropped_bytes
779                                .add(dropped_len, Ordering::Relaxed);
780                            self.connector
781                                .connect_stats
782                                .subscription_dropped_messages
783                                .add(1, Ordering::Relaxed);
784                            self.connector
785                                .connect_stats
786                                .subscription_dropped_bytes
787                                .add(dropped_len, Ordering::Relaxed);
788                            debug!("slow consumer detected for subscription {}", sid);
789                            self.connector
790                                .events_tx
791                                .try_send(Event::SlowConsumer(SlowConsumer {
792                                    sid,
793                                    subject: returned_message.subject,
794                                }))
795                                .ok();
796                        }
797                        Err(mpsc::error::TrySendError::Closed(_)) => {
798                            debug!("subscription {} channel closed", sid);
799                            self.subscriptions.remove(&sid);
800                            self.connection
801                                .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
802                        }
803                    }
804                } else if sid == MULTIPLEXER_SID {
805                    debug!("received message for multiplexer");
806                    if let Some(multiplexer) = self.multiplexer.as_mut() {
807                        let maybe_token =
808                            subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
809
810                        if let Some(token) = maybe_token {
811                            if let Some(sender) = multiplexer.senders.remove(token) {
812                                debug!("forwarding message to request with token {}", token);
813                                let message = Message {
814                                    subject,
815                                    reply,
816                                    payload,
817                                    headers,
818                                    status,
819                                    description,
820                                    length,
821                                };
822
823                                let _ = sender.send(message);
824                            }
825                        }
826                    }
827                }
828            }
829            // TODO: we should probably update advertised server list here too.
830            ServerOp::Info(info) => {
831                debug!("received INFO: server_id={}", info.server_id);
832                if info.lame_duck_mode {
833                    debug!("server in lame duck mode");
834                    self.connector.events_tx.try_send(Event::LameDuckMode).ok();
835                }
836            }
837
838            _ => {
839                // TODO: don't ignore.
840            }
841        }
842    }
843
    /// Applies one command received from a client handle to the connection state.
    ///
    /// Each command either updates local bookkeeping (subscriptions, flush observers,
    /// drain bookkeeping, the request multiplexer, reconnect flag) and/or enqueues the
    /// matching protocol operation on `self.connection`.
    fn handle_command(&mut self, command: Command) {
        // NOTE(review): every outgoing command resets the ping timer — presumably because
        // outgoing traffic already exercises the connection; confirm against `process`.
        self.ping_interval.reset();

        match command {
            Command::Unsubscribe { sid, max } => {
                // Unknown sids are ignored entirely: no UNSUB is sent for them.
                if let Some(subscription) = self.subscriptions.get_mut(&sid) {
                    subscription.max = max;
                    match subscription.max {
                        Some(n) => {
                            // The limit is already reached: forget the subscription now;
                            // otherwise keep it until `delivered` reaches `n`.
                            if subscription.delivered >= n {
                                self.subscriptions.remove(&sid);
                            }
                        }
                        None => {
                            // Unconditional unsubscribe: drop local state immediately.
                            self.subscriptions.remove(&sid);
                        }
                    }

                    // Forward the same limit to the server.
                    self.connection
                        .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
                }
            }
            Command::Flush { observer } => {
                // NOTE(review): observers are presumably notified once the pending writes
                // are flushed; the notification site is outside this block.
                self.flush_observers.push(observer);
            }
            Command::Drain { sid } => {
                // Sends UNSUB for one sid and remembers it in `drain_pings`; the PING
                // enqueued below presumably lets the handler detect when the server has
                // processed the UNSUBs — confirm against the PONG handling.
                let mut drain_sub = |sid: u64| {
                    self.drain_pings.push_back(sid);
                    self.connection
                        .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
                };

                if let Some(sid) = sid {
                    if self.subscriptions.get_mut(&sid).is_some() {
                        drain_sub(sid);
                    }
                } else {
                    // sid isn't set, so drain the whole client
                    self.connector.events_tx.try_send(Event::Draining).ok();
                    self.is_draining = true;
                    for (&sid, _) in self.subscriptions.iter_mut() {
                        drain_sub(sid);
                    }
                }
                // Note: the PING is enqueued even when a single-sid drain names an unknown sid.
                self.connection.enqueue_write_op(&ClientOp::Ping);
            }
            Command::Subscribe {
                sid,
                subject,
                queue_group,
                sender,
                statistics,
            } => {
                // Register locally first so messages for `sid` can be routed, then send SUB.
                let subscription = Subscription {
                    sender,
                    statistics,
                    delivered: 0,
                    max: None,
                    subject: subject.to_owned(),
                    queue_group: queue_group.to_owned(),
                };

                self.subscriptions.insert(sid, subscription);

                self.connection.enqueue_write_op(&ClientOp::Subscribe {
                    sid,
                    subject,
                    queue_group,
                });
            }
            Command::Request {
                subject,
                payload,
                respond,
                headers,
                sender,
            } => {
                // `respond` is expected to look like "<prefix>.<token>"; anything without a
                // '.' panics here (inside the connection handler).
                let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");

                // All requests share one wildcard subscription (the multiplexer), created
                // lazily on the first request under MULTIPLEXER_SID.
                let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
                    multiplexer
                } else {
                    let prefix = Subject::from(format!("{}.{}.", prefix, id_generator::next()));
                    let subject = Subject::from(format!("{prefix}*"));

                    self.connection.enqueue_write_op(&ClientOp::Subscribe {
                        sid: MULTIPLEXER_SID,
                        subject: subject.clone(),
                        queue_group: None,
                    });

                    self.multiplexer.insert(Multiplexer {
                        subject,
                        prefix,
                        senders: HashMap::new(),
                    })
                };
                // NOTE(review): unlike `Command::Publish`, requests are counted in
                // `out_messages` but never added to `out_bytes` — confirm whether intended.
                self.connector
                    .connect_stats
                    .out_messages
                    .add(1, Ordering::Relaxed);

                // The reply is routed back to `sender` by token.
                multiplexer.senders.insert(token.to_owned(), sender);

                // Rewrite the reply subject onto the multiplexer's wildcard prefix.
                let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();

                let pub_op = ClientOp::Publish {
                    subject,
                    payload,
                    respond: Some(respond),
                    headers,
                };

                self.connection.enqueue_write_op(&pub_op);
            }

            Command::Publish(OutboundMessage {
                subject,
                payload,
                reply: respond,
                headers,
            }) => {
                self.connector
                    .connect_stats
                    .out_messages
                    .add(1, Ordering::Relaxed);

                // NOTE(review): `headers.len()` may be the number of header entries rather
                // than their encoded byte size — verify against the header map type.
                let header_len = headers
                    .as_ref()
                    .map(|headers| headers.len())
                    .unwrap_or_default();

                // Approximate outgoing size: payload + reply subject + subject + headers.
                self.connector.connect_stats.out_bytes.add(
                    (payload.len()
                        + respond.as_ref().map_or_else(|| 0, |r| r.len())
                        + subject.len()
                        + header_len) as u64,
                    Ordering::Relaxed,
                );

                self.connection.enqueue_write_op(&ClientOp::Publish {
                    subject,
                    payload,
                    respond,
                    headers,
                });
            }

            Command::Reconnect => {
                // Only flags the request; the reconnect itself happens elsewhere.
                self.should_reconnect = true;
            }
        }
    }
997
998    async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
999        self.pending_pings = 0;
1000        self.connector.events_tx.try_send(Event::Disconnected).ok();
1001        self.connector.state_tx.send(State::Disconnected).ok();
1002
1003        self.handle_reconnect().await
1004    }
1005
1006    async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
1007        let (info, connection) = self.connector.connect().await?;
1008        self.connection = connection;
1009        let _ = self.info_sender.send(info);
1010
1011        self.subscriptions
1012            .retain(|_, subscription| !subscription.sender.is_closed());
1013
1014        for (sid, subscription) in &self.subscriptions {
1015            self.connection.enqueue_write_op(&ClientOp::Subscribe {
1016                sid: *sid,
1017                subject: subscription.subject.to_owned(),
1018                queue_group: subscription.queue_group.to_owned(),
1019            });
1020        }
1021
1022        if let Some(multiplexer) = &self.multiplexer {
1023            self.connection.enqueue_write_op(&ClientOp::Subscribe {
1024                sid: MULTIPLEXER_SID,
1025                subject: multiplexer.subject.to_owned(),
1026                queue_group: None,
1027            });
1028        }
1029        Ok(())
1030    }
1031}
1032
1033/// Connects to NATS with specified options.
1034///
1035/// It is generally advised to use [ConnectOptions] instead, as it provides a builder for whole
1036/// configuration.
1037///
1038/// # Examples
1039/// ```
1040/// # #[tokio::main]
1041/// # async fn main() ->  Result<(), async_nats::Error> {
1042/// let mut nc =
1043///     async_nats::connect_with_options("demo.nats.io", async_nats::ConnectOptions::new()).await?;
1044/// nc.publish("test", "data".into()).await?;
1045/// # Ok(())
1046/// # }
1047/// ```
1048pub async fn connect_with_options<A: ToServerAddrs>(
1049    addrs: A,
1050    options: ConnectOptions,
1051) -> Result<Client, ConnectError> {
1052    let ping_period = options.ping_interval;
1053
1054    let (events_tx, mut events_rx) = mpsc::channel(128);
1055    let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
1056    // We're setting it to the default server payload size.
1057    let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
1058    let statistics = Arc::new(Statistics::default());
1059
1060    let mut connector = Connector::new(
1061        addrs,
1062        ConnectorOptions {
1063            tls_required: options.tls_required,
1064            certificates: options.certificates,
1065            client_key: options.client_key,
1066            client_cert: options.client_cert,
1067            tls_client_config: options.tls_client_config,
1068            tls_first: options.tls_first,
1069            auth: options.auth,
1070            no_echo: options.no_echo,
1071            connection_timeout: options.connection_timeout,
1072            name: options.name,
1073            ignore_discovered_servers: options.ignore_discovered_servers,
1074            retain_servers_order: options.retain_servers_order,
1075            read_buffer_capacity: options.read_buffer_capacity,
1076            reconnect_delay_callback: options.reconnect_delay_callback,
1077            auth_callback: options.auth_callback,
1078            max_reconnects: options.max_reconnects,
1079        },
1080        events_tx,
1081        state_tx,
1082        max_payload.clone(),
1083        statistics.clone(),
1084    )
1085    .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1086
1087    let mut info: ServerInfo = Default::default();
1088    let mut connection = None;
1089    if !options.retry_on_initial_connect {
1090        debug!("retry on initial connect failure is disabled");
1091        let (info_ok, connection_ok) = connector.try_connect().await?;
1092        connection = Some(connection_ok);
1093        info = info_ok;
1094    }
1095
1096    let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1097    let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1098
1099    let client = Client::new(
1100        info_watcher,
1101        state_rx,
1102        sender,
1103        options.subscription_capacity,
1104        options.inbox_prefix,
1105        options.request_timeout,
1106        max_payload,
1107        statistics,
1108    );
1109
1110    task::spawn(async move {
1111        while let Some(event) = events_rx.recv().await {
1112            tracing::info!("event: {}", event);
1113            if let Some(event_callback) = &options.event_callback {
1114                event_callback.call(event).await;
1115            }
1116        }
1117    });
1118
1119    task::spawn(async move {
1120        if connection.is_none() && options.retry_on_initial_connect {
1121            let (info, connection_ok) = match connector.connect().await {
1122                Ok((info, connection)) => (info, connection),
1123                Err(err) => {
1124                    error!("connection closed: {}", err);
1125                    return;
1126                }
1127            };
1128            info_sender.send(info).ok();
1129            connection = Some(connection_ok);
1130        }
1131        let connection = connection.unwrap();
1132        let mut connection_handler =
1133            ConnectionHandler::new(connection, connector, info_sender, ping_period);
1134        connection_handler.process(&mut receiver).await
1135    });
1136
1137    Ok(client)
1138}
1139
/// Notification about the connection's lifecycle or about errors.
///
/// Events are logged and, when an event callback is configured in [ConnectOptions],
/// passed to that callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// A connection was established. NOTE(review): emitted outside this section —
    /// presumably by the connector; confirm.
    Connected,
    /// The connection was lost; emitted before a reconnect is attempted.
    Disconnected,
    /// The server announced lame duck mode in an INFO message.
    LameDuckMode,
    /// Draining of the whole client has started.
    Draining,
    /// The connection was closed. NOTE(review): emitting site not visible here — confirm
    /// exactly when this fires.
    Closed,
    /// A subscription was flagged as a slow consumer.
    SlowConsumer(SlowConsumer),
    /// An error attributed to the server.
    ServerError(ServerError),
    /// An error raised by the client itself.
    ClientError(ClientError),
}
1151
1152impl fmt::Display for Event {
1153    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1154        match self {
1155            Event::Connected => write!(f, "connected"),
1156            Event::Disconnected => write!(f, "disconnected"),
1157            Event::LameDuckMode => write!(f, "lame duck mode detected"),
1158            Event::Draining => write!(f, "draining"),
1159            Event::Closed => write!(f, "closed"),
1160            Event::SlowConsumer(slow_consumer) => write!(
1161                f,
1162                "slow consumers for subscription {} on subject {}",
1163                slow_consumer.sid, slow_consumer.subject
1164            ),
1165            Event::ServerError(err) => write!(f, "server error: {err}"),
1166            Event::ClientError(err) => write!(f, "client error: {err}"),
1167        }
1168    }
1169}
1170
1171/// Connects to NATS with default config.
1172///
1173/// Returns cloneable [Client].
1174///
1175/// To have customized NATS connection, check [ConnectOptions].
1176///
1177/// # Examples
1178///
1179/// ## Single URL
1180/// ```
1181/// # #[tokio::main]
1182/// # async fn main() ->  Result<(), async_nats::Error> {
1183/// let mut nc = async_nats::connect("demo.nats.io").await?;
1184/// nc.publish("test", "data".into()).await?;
1185/// # Ok(())
1186/// # }
1187/// ```
1188///
1189/// ## Connect with [Vec] of [ServerAddr].
1190/// ```no_run
1191/// #[tokio::main]
1192/// # async fn main() -> Result<(), async_nats::Error> {
1193/// use async_nats::ServerAddr;
1194/// let client = async_nats::connect(vec![
1195///     "demo.nats.io".parse::<ServerAddr>()?,
1196///     "other.nats.io".parse::<ServerAddr>()?,
1197/// ])
1198/// .await
1199/// .unwrap();
1200/// # Ok(())
1201/// # }
1202/// ```
1203///
1204/// ## with [Vec], but parse URLs inside [crate::connect()]
1205/// ```no_run
1206/// #[tokio::main]
1207/// # async fn main() -> Result<(), async_nats::Error> {
1208/// use async_nats::ServerAddr;
1209/// let servers = vec!["demo.nats.io", "other.nats.io"];
1210/// let client = async_nats::connect(
1211///     servers
1212///         .iter()
1213///         .map(|url| url.parse())
1214///         .collect::<Result<Vec<ServerAddr>, _>>()?,
1215/// )
1216/// .await?;
1217/// # Ok(())
1218/// # }
1219/// ```
1220///
1221///
1222/// ## with slice.
1223/// ```no_run
1224/// #[tokio::main]
1225/// # async fn main() -> Result<(), async_nats::Error> {
1226/// use async_nats::ServerAddr;
1227/// let client = async_nats::connect(
1228///    [
1229///        "demo.nats.io".parse::<ServerAddr>()?,
1230///        "other.nats.io".parse::<ServerAddr>()?,
1231///    ]
1232///    .as_slice(),
1233/// )
1234/// .await?;
1235/// # Ok(())
1236/// # }
1237pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1238    connect_with_options(addrs, ConnectOptions::default()).await
1239}
1240
/// The category of failure carried by a [ConnectError].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// Parsing the passed server address failed.
    ServerParse,
    /// DNS related issues.
    Dns,
    /// Failed authentication process, signing nonce, etc.
    Authentication,
    /// Server returned authorization violation error.
    AuthorizationViolation,
    /// Connect timed out.
    TimedOut,
    /// Erroneous TLS setup.
    Tls,
    /// Other IO error.
    Io,
    /// Reached the maximum number of reconnects.
    MaxReconnects,
}

impl Display for ConnectErrorKind {
    /// Writes a short, lowercase description of the failure category.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
1270            Self::Io => write!(f, "IO error"),
1271            Self::MaxReconnects => write!(f, "reached maximum number of reconnects"),
1272        }
1273    }
1274}
1275
1276/// Returned when initial connection fails.
1277/// To be enumerate over the variants, call [ConnectError::kind].
1278pub type ConnectError = error::Error<ConnectErrorKind>;
1279
1280impl From<io::Error> for ConnectError {
1281    fn from(err: io::Error) -> Self {
1282        ConnectError::with_source(ConnectErrorKind::Io, err)
1283    }
1284}
1285
1286/// Retrieves messages from given `subscription` created by [Client::subscribe].
1287///
1288/// Implements [futures_util::stream::Stream] for ergonomic async message processing.
1289///
1290/// # Examples
1291/// ```
1292/// # #[tokio::main]
1293/// # async fn main() ->  Result<(), async_nats::Error> {
1294/// let mut nc = async_nats::connect("demo.nats.io").await?;
1295/// # nc.publish("test", "data".into()).await?;
1296/// # Ok(())
1297/// # }
1298/// ```
1299#[derive(Debug)]
1300pub struct Subscriber {
1301    sid: u64,
1302    receiver: mpsc::Receiver<Message>,
1303    sender: mpsc::Sender<Command>,
1304    statistics: Arc<SubscriberStatistics>,
1305    connection_stats: Arc<client::Statistics>,
1306}
1307
1308impl Subscriber {
1309    pub(crate) fn new(
1310        sid: u64,
1311        sender: mpsc::Sender<Command>,
1312        receiver: mpsc::Receiver<Message>,
1313        statistics: Arc<SubscriberStatistics>,
1314        connection_stats: Arc<client::Statistics>,
1315    ) -> Subscriber {
1316        connection_stats
1317            .active_subscriptions
1318            .add(1, Ordering::Relaxed);
1319        connection_stats
1320            .active_subscription_capacity
1321            .add(receiver.max_capacity() as u64, Ordering::Relaxed);
1322
1323        Subscriber {
1324            sid,
1325            sender,
1326            receiver,
1327            statistics,
1328            connection_stats,
1329        }
1330    }
1331
1332    /// Returns statistics for this subscription handle.
1333    pub fn statistics(&self) -> Arc<SubscriberStatistics> {
1334        self.statistics.clone()
1335    }
1336
1337    /// Returns the number of messages currently buffered in this subscriber.
1338    pub fn pending_messages(&self) -> usize {
1339        self.receiver.len()
1340    }
1341
1342    /// Returns the number of bytes currently buffered in this subscriber.
1343    pub fn pending_bytes(&self) -> u64 {
1344        self.statistics.pending_bytes.load(Ordering::Relaxed)
1345    }
1346
1347    /// Returns the remaining message capacity in this subscriber.
1348    pub fn remaining_capacity(&self) -> usize {
1349        self.receiver.capacity()
1350    }
1351
1352    /// Returns the maximum message capacity in this subscriber.
1353    pub fn max_capacity(&self) -> usize {
1354        self.receiver.max_capacity()
1355    }
1356
1357    /// Unsubscribes from subscription, draining all remaining messages.
1358    ///
1359    /// # Examples
1360    /// ```
1361    /// # #[tokio::main]
1362    /// # async fn main() -> Result<(), async_nats::Error> {
1363    /// let client = async_nats::connect("demo.nats.io").await?;
1364    ///
1365    /// let mut subscriber = client.subscribe("foo").await?;
1366    ///
1367    /// subscriber.unsubscribe().await?;
1368    /// # Ok(())
1369    /// # }
1370    /// ```
1371    pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1372        self.sender
1373            .send(Command::Unsubscribe {
1374                sid: self.sid,
1375                max: None,
1376            })
1377            .await?;
1378        self.receiver.close();
1379        Ok(())
1380    }
1381
1382    /// Unsubscribes from subscription after reaching given number of messages.
1383    /// This is the total number of messages received by this subscription in it's whole
1384    /// lifespan. If it already reached or surpassed the passed value, it will immediately stop.
1385    ///
1386    /// # Examples
1387    /// ```
1388    /// # use futures_util::StreamExt;
1389    /// # #[tokio::main]
1390    /// # async fn main() -> Result<(), async_nats::Error> {
1391    /// let client = async_nats::connect("demo.nats.io").await?;
1392    ///
1393    /// let mut subscriber = client.subscribe("test").await?;
1394    /// subscriber.unsubscribe_after(3).await?;
1395    ///
1396    /// for _ in 0..3 {
1397    ///     client.publish("test", "data".into()).await?;
1398    /// }
1399    ///
1400    /// while let Some(message) = subscriber.next().await {
1401    ///     println!("message received: {:?}", message);
1402    /// }
1403    /// println!("no more messages, unsubscribed");
1404    /// # Ok(())
1405    /// # }
1406    /// ```
1407    pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1408        self.sender
1409            .send(Command::Unsubscribe {
1410                sid: self.sid,
1411                max: Some(unsub_after),
1412            })
1413            .await?;
1414        Ok(())
1415    }
1416
1417    /// Unsubscribes immediately but leaves the subscription open to allow any in-flight messages
1418    /// on the subscription to be delivered. The stream will be closed after any remaining messages
1419    /// are delivered
1420    ///
1421    /// # Examples
1422    /// ```no_run
1423    /// # use futures_util::StreamExt;
1424    /// # #[tokio::main]
1425    /// # async fn main() -> Result<(), async_nats::Error> {
1426    /// let client = async_nats::connect("demo.nats.io").await?;
1427    ///
1428    /// let mut subscriber = client.subscribe("test").await?;
1429    ///
1430    /// tokio::spawn({
1431    ///     let task_client = client.clone();
1432    ///     async move {
1433    ///         loop {
1434    ///             _ = task_client.publish("test", "data".into()).await;
1435    ///         }
1436    ///     }
1437    /// });
1438    ///
1439    /// client.flush().await?;
1440    /// subscriber.drain().await?;
1441    ///
1442    /// while let Some(message) = subscriber.next().await {
1443    ///     println!("message received: {:?}", message);
1444    /// }
1445    /// println!("no more messages, unsubscribed");
1446    /// # Ok(())
1447    /// # }
1448    /// ```
1449    pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1450        self.sender
1451            .send(Command::Drain {
1452                sid: Some(self.sid),
1453            })
1454            .await?;
1455
1456        Ok(())
1457    }
1458}
1459
1460#[derive(Error, Debug, PartialEq)]
1461#[error("failed to send unsubscribe")]
1462pub struct UnsubscribeError(String);
1463
1464impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1465    fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1466        UnsubscribeError(err.to_string())
1467    }
1468}
1469
1470impl Drop for Subscriber {
1471    fn drop(&mut self) {
1472        self.receiver.close();
1473        let mut drained_messages = 0;
1474        let mut drained_bytes = 0;
1475
1476        while let Ok(message) = self.receiver.try_recv() {
1477            drained_messages += 1;
1478            drained_bytes += message.length as u64;
1479        }
1480
1481        if drained_messages > 0 {
1482            self.statistics
1483                .pending_messages
1484                .sub(drained_messages, Ordering::Relaxed);
1485            self.connection_stats
1486                .subscription_pending_messages
1487                .sub(drained_messages, Ordering::Relaxed);
1488        }
1489
1490        if drained_bytes > 0 {
1491            self.statistics
1492                .pending_bytes
1493                .sub(drained_bytes, Ordering::Relaxed);
1494            self.connection_stats
1495                .subscription_pending_bytes
1496                .sub(drained_bytes, Ordering::Relaxed);
1497        }
1498
1499        self.connection_stats
1500            .active_subscriptions
1501            .sub(1, Ordering::Relaxed);
1502        self.connection_stats
1503            .active_subscription_capacity
1504            .sub(self.receiver.max_capacity() as u64, Ordering::Relaxed);
1505
1506        tokio::spawn({
1507            let sender = self.sender.clone();
1508            let sid = self.sid;
1509            async move {
1510                sender
1511                    .send(Command::Unsubscribe { sid, max: None })
1512                    .await
1513                    .ok();
1514            }
1515        });
1516    }
1517}
1518
1519impl Stream for Subscriber {
1520    type Item = Message;
1521
1522    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1523        match self.receiver.poll_recv(cx) {
1524            Poll::Ready(Some(message)) => {
1525                self.statistics
1526                    .pending_messages
1527                    .sub(1, Ordering::Relaxed);
1528                self.statistics
1529                    .pending_bytes
1530                    .sub(message.length as u64, Ordering::Relaxed);
1531                self.connection_stats
1532                    .subscription_pending_messages
1533                    .sub(1, Ordering::Relaxed);
1534                self.connection_stats
1535                    .subscription_pending_bytes
1536                    .sub(message.length as u64, Ordering::Relaxed);
1537                Poll::Ready(Some(message))
1538            }
1539            other => other,
1540        }
1541    }
1542}
1543
/// Statistics for a single subscription handle.
///
/// The counters are plain atomics shared through `Arc`. The subscriber handle updates the
/// pending counters with `Ordering::Relaxed`, so the values are meant for monitoring, not
/// for synchronization.
#[derive(Default, Debug)]
pub struct SubscriberStatistics {
    /// Number of messages currently buffered in this subscription channel.
    pub pending_messages: AtomicU64,
    /// Number of bytes currently buffered in this subscription channel.
    pub pending_bytes: AtomicU64,
    /// Number of messages dropped because this subscription channel was full.
    pub dropped_messages: AtomicU64,
    /// Number of bytes dropped because this subscription channel was full.
    pub dropped_bytes: AtomicU64,
}
1556
1557#[derive(Clone, Debug, Eq, PartialEq)]
1558pub enum CallbackError {
1559    Client(ClientError),
1560    Server(ServerError),
1561}
1562impl std::fmt::Display for CallbackError {
1563    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1564        match self {
1565            Self::Client(error) => write!(f, "{error}"),
1566            Self::Server(error) => write!(f, "{error}"),
1567        }
1568    }
1569}
1570
1571impl From<ServerError> for CallbackError {
1572    fn from(server_error: ServerError) -> Self {
1573        CallbackError::Server(server_error)
1574    }
1575}
1576
1577impl From<ClientError> for CallbackError {
1578    fn from(client_error: ClientError) -> Self {
1579        CallbackError::Client(client_error)
1580    }
1581}
1582
/// An error attributed to the server side (compare [ClientError]).
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// A subscription was flagged as a slow consumer; carries its sid and subject.
    SlowConsumer(SlowConsumer),
    /// Any other server error; the original message text is kept verbatim.
    Other(String),
}
1589
/// An error raised by the client itself (compare [ServerError]).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Any other client-side error, carrying its description.
    Other(String),
    /// The maximum number of reconnects was reached.
    MaxReconnects,
}
impl std::fmt::Display for ClientError {
    /// Writes the error prefixed with `nats: `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let detail: &str = match self {
            ClientError::Other(description) => description,
            ClientError::MaxReconnects => "max reconnects reached",
        };
        write!(f, "nats: {detail}")
    }
}
1603
1604impl ServerError {
1605    fn new(error: String) -> ServerError {
1606        match error.to_lowercase().as_str() {
1607            "authorization violation" => ServerError::AuthorizationViolation,
1608            // error messages can contain case-sensitive values which should be preserved
1609            _ => ServerError::Other(error),
1610        }
1611    }
1612}
1613
1614impl std::fmt::Display for ServerError {
1615    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1616        match self {
1617            Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1618            Self::SlowConsumer(slow_consumer) => write!(
1619                f,
1620                "nats: subscription {} on subject {} is a slow consumer",
1621                slow_consumer.sid, slow_consumer.subject,
1622            ),
1623            Self::Other(error) => write!(f, "nats: {error}"),
1624        }
1625    }
1626}
1627
1628#[derive(Clone, Debug, Eq, PartialEq)]
1629pub struct SlowConsumer {
1630    pub sid: u64,
1631    pub subject: Subject,
1632}
1633
1634impl std::fmt::Display for SlowConsumer {
1635    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1636        write!(f, "slow consumer {} on subject {}", self.sid, self.subject)
1637    }
1638}
1639
/// Info to construct a CONNECT message.
///
/// Serialized with serde into the CONNECT protocol message; `user_jwt` and `signature`
/// are emitted under the protocol's `jwt` and `sig` keys.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Turns on +OK protocol acknowledgments.
    pub verbose: bool,

    /// Turns on additional strict format checking, e.g. for properly formed
    /// subjects.
    pub pedantic: bool,

    /// User's JWT.
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public nkey.
    pub nkey: Option<String>,

    /// Signed nonce, encoded to Base64URL.
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// Whether the server echoes messages published on this connection back to this
    /// connection's own subscriptions. If set to `false`, the server (version 1.2.0+)
    /// will not send originating messages from this connection to its own
    /// subscriptions. Clients should set this to `false` only for servers supporting
    /// this feature, which is when proto in the INFO protocol is set to at least 1.
    pub echo: bool,

    /// The implementation language of the client.
    pub lang: String,

    /// The version of the client.
    pub version: String,

    /// Sending 0 (or absent) indicates the client supports the original protocol.
    /// Sending 1 indicates that the client supports dynamic reconfiguration
    /// of cluster topology changes by asynchronously receiving INFO messages
    /// with known servers it can reconnect to.
    pub protocol: Protocol,

    /// Indicates whether the client requires a TLS connection.
    pub tls_required: bool,

    /// Connection username (if `auth_required` is set).
    pub user: Option<String>,

    /// Connection password (if `auth_required` is set).
    pub pass: Option<String>,

    /// Client authorization token (if `auth_required` is set).
    pub auth_token: Option<String>,

    /// Whether the client supports the usage of headers.
    pub headers: bool,

    /// Whether the client supports no_responders.
    pub no_responders: bool,
}
1700
/// Protocol version used by the client.
///
/// Serialized and deserialized as its numeric value (`0` or `1`), not by variant name.
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Original protocol.
    Original = 0,
    /// Protocol with dynamic reconfiguration of cluster and lame duck mode functionality.
    Dynamic = 1,
}
1710
/// Address of a NATS server.
///
/// Wraps a [Url]. When built through `from_url` or `FromStr`, the scheme is checked to be
/// one of `nats`, `tls`, `ws` or `wss`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1714
1715impl FromStr for ServerAddr {
1716    type Err = io::Error;
1717
1718    /// Parse an address of a NATS server.
1719    ///
1720    /// If not stated explicitly the `nats://` schema and port `4222` is assumed.
1721    fn from_str(input: &str) -> Result<Self, Self::Err> {
1722        let url: Url = if input.contains("://") {
1723            input.parse()
1724        } else {
1725            format!("nats://{input}").parse()
1726        }
1727        .map_err(|e| {
1728            io::Error::new(
1729                ErrorKind::InvalidInput,
1730                format!("NATS server URL is invalid: {e}"),
1731            )
1732        })?;
1733
1734        Self::from_url(url)
1735    }
1736}
1737
1738impl ServerAddr {
1739    /// Check if the URL is a valid NATS server address.
1740    pub fn from_url(url: Url) -> io::Result<Self> {
1741        if url.scheme() != "nats"
1742            && url.scheme() != "tls"
1743            && url.scheme() != "ws"
1744            && url.scheme() != "wss"
1745        {
1746            return Err(std::io::Error::new(
1747                ErrorKind::InvalidInput,
1748                format!("invalid scheme for NATS server URL: {}", url.scheme()),
1749            ));
1750        }
1751
1752        Ok(Self(url))
1753    }
1754
1755    /// Turn the server address into a standard URL.
1756    pub fn into_inner(self) -> Url {
1757        self.0
1758    }
1759
1760    /// Returns if tls is required by the client for this server.
1761    pub fn tls_required(&self) -> bool {
1762        self.0.scheme() == "tls"
1763    }
1764
1765    /// Returns if the server url had embedded username and password.
1766    pub fn has_user_pass(&self) -> bool {
1767        self.0.username() != ""
1768    }
1769
1770    pub fn scheme(&self) -> &str {
1771        self.0.scheme()
1772    }
1773
1774    /// Returns the host.
1775    pub fn host(&self) -> &str {
1776        match self.0.host() {
1777            Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1778            // `host_str()` for Ipv6 includes the []s
1779            Some(Host::Ipv6 { .. }) => {
1780                let host = self.0.host_str().unwrap();
1781                &host[1..host.len() - 1]
1782            }
1783            None => "",
1784        }
1785    }
1786
1787    pub fn is_websocket(&self) -> bool {
1788        self.0.scheme() == "ws" || self.0.scheme() == "wss"
1789    }
1790
1791    /// Returns the port.
1792    /// Delegates to [`Url::port_or_known_default`](https://docs.rs/url/latest/url/struct.Url.html#method.port_or_known_default) and defaults to 4222 if none was explicitly specified in creating this `ServerAddr`.
1793    pub fn port(&self) -> u16 {
1794        self.0.port_or_known_default().unwrap_or(4222)
1795    }
1796
1797    /// Returns the URL string.
1798    pub fn as_url_str(&self) -> &str {
1799        self.0.as_str()
1800    }
1801
1802    /// Returns the optional username in the url.
1803    pub fn username(&self) -> Option<&str> {
1804        let user = self.0.username();
1805        if user.is_empty() {
1806            None
1807        } else {
1808            Some(user)
1809        }
1810    }
1811
1812    /// Returns the optional password in the url.
1813    pub fn password(&self) -> Option<&str> {
1814        self.0.password()
1815    }
1816
1817    /// Return the sockets from resolving the server address.
1818    pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1819        tokio::net::lookup_host((self.host(), self.port())).await
1820    }
1821}
1822
/// Capability to convert into a list of NATS server addresses.
///
/// There are several implementations ensuring the easy passing of one or more server addresses to
/// functions like [`crate::connect()`].
///
/// This module implements it for [`ServerAddr`], `str`/`String` (a single address), slices
/// and `Vec`s of string-like values or of `ServerAddr`, and references to any implementor.
pub trait ToServerAddrs {
    /// Iterator over the NATS server addresses this value converts into.
    type Iter: Iterator<Item = ServerAddr>;

    /// Converts `self` into an iterator of server addresses.
    ///
    /// In this module, the string-based implementations return the first parse error
    /// they hit, and the implementations over `ServerAddr` values always succeed.
    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
}
1834
1835impl ToServerAddrs for ServerAddr {
1836    type Iter = option::IntoIter<ServerAddr>;
1837    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1838        Ok(Some(self.clone()).into_iter())
1839    }
1840}
1841
1842impl ToServerAddrs for str {
1843    type Iter = option::IntoIter<ServerAddr>;
1844    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1845        self.parse::<ServerAddr>()
1846            .map(|addr| Some(addr).into_iter())
1847    }
1848}
1849
1850impl ToServerAddrs for String {
1851    type Iter = option::IntoIter<ServerAddr>;
1852    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1853        (**self).to_server_addrs()
1854    }
1855}
1856
1857impl<T: AsRef<str>> ToServerAddrs for [T] {
1858    type Iter = std::vec::IntoIter<ServerAddr>;
1859    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1860        self.iter()
1861            .map(AsRef::as_ref)
1862            .map(str::parse)
1863            .collect::<io::Result<_>>()
1864            .map(Vec::into_iter)
1865    }
1866}
1867
1868impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1869    type Iter = std::vec::IntoIter<ServerAddr>;
1870    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1871        self.as_slice().to_server_addrs()
1872    }
1873}
1874
1875impl<'a> ToServerAddrs for &'a [ServerAddr] {
1876    type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1877
1878    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1879        Ok(self.iter().cloned())
1880    }
1881}
1882
1883impl ToServerAddrs for Vec<ServerAddr> {
1884    type Iter = std::vec::IntoIter<ServerAddr>;
1885
1886    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1887        Ok(self.clone().into_iter())
1888    }
1889}
1890
1891impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1892    type Iter = T::Iter;
1893    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1894        (**self).to_server_addrs()
1895    }
1896}
1897
/// Returns `true` if `subject` is a syntactically valid NATS subject.
///
/// A subject is valid when all of the following hold:
/// - it is non-empty;
/// - every `.`-separated token is non-empty, so there are no leading, trailing or
///   consecutive dots;
/// - it contains no ASCII whitespace.
///
/// Wildcard tokens (`*`, `>`) are accepted as ordinary tokens; their placement is
/// not checked.
// NOTE(review): `dead_code` is allowed, presumably because all callers are
// feature-gated. Confirm.
#[allow(dead_code)]
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let subject = subject.as_ref();
    // `"".split('.')` yields a single empty token, so the empty subject is rejected too.
    subject
        .split('.')
        .all(|token| !token.is_empty() && !token.bytes().any(|b| b.is_ascii_whitespace()))
}
/// Generates `impl From<$origin> for $t`, mapping timeouts to a dedicated error kind.
///
/// Parameters:
/// - `$t`: the target error type; must provide `new(kind)` and `with_source(kind, source)`.
/// - `$k`: the target's kind type; must have `TimedOut` and `Other` variants.
/// - `$origin`: the source error type; must have a `kind()` method.
/// - `$origin_kind`: the type returned by `$origin::kind()`; `TimedOut` must be usable
///   as a pattern on it.
///
/// A source error whose kind is `TimedOut` becomes `$t::new($k::TimedOut)`, and the
/// source error is dropped. Every other kind becomes `$t::with_source($k::Other, err)`,
/// which keeps the original error as the source.
// NOTE(review): `unused_macros` is allowed, presumably because invocations are
// feature-gated. Confirm.
#[allow(unused_macros)]
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                match err.kind() {
                    <$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
                    _ => Self::with_source(<$k>::Other, err),
                }
            }
        }
    };
}
// Re-export the macro by path so other modules in the crate can invoke it as
// `crate::from_with_timeout!`.
#[allow(unused_imports)]
pub(crate) use from_with_timeout;
1920
1921use crate::connection::ShouldFlush;
1922use crate::message::OutboundMessage;
1923
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    #[test]
    fn server_address_defaults_to_nats_scheme() {
        let address = ServerAddr::from_str("example.com").unwrap();
        assert_eq!(address.scheme(), "nats");
        assert_eq!(address.host(), "example.com");
        assert_eq!(address.port(), 4222);
    }

    #[test]
    fn server_address_explicit_port() {
        let address = ServerAddr::from_str("127.0.0.1:4333").unwrap();
        assert_eq!(address.scheme(), "nats");
        assert_eq!(address.host(), "127.0.0.1");
        assert_eq!(address.port(), 4333);
    }

    #[test]
    fn server_address_rejects_unknown_scheme() {
        let err = ServerAddr::from_str("http://example.com").unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidInput);
    }

    #[test]
    fn server_address_credentials() {
        let address = ServerAddr::from_str("nats://alice:secret@example.com").unwrap();
        assert!(address.has_user_pass());
        assert_eq!(address.username(), Some("alice"));
        assert_eq!(address.password(), Some("secret"));

        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert!(!address.has_user_pass());
        assert_eq!(address.username(), None);
        assert_eq!(address.password(), None);
    }

    #[test]
    fn server_address_scheme_flags() {
        let tls = ServerAddr::from_str("tls://example.com").unwrap();
        assert!(tls.tls_required());
        assert!(!tls.is_websocket());

        let plain = ServerAddr::from_str("nats://example.com").unwrap();
        assert!(!plain.tls_required());
        assert!(!plain.is_websocket());

        assert!(ServerAddr::from_str("ws://example.com").unwrap().is_websocket());
        assert!(ServerAddr::from_str("wss://example.com").unwrap().is_websocket());
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            // Check every address in the group, not just the first one.
            let addrs: Vec<ServerAddr> = arr.to_server_addrs().unwrap().collect();
            assert_eq!(addrs.len(), arr.len());
            for addr in addrs {
                assert_eq!(
                    addr.port(),
                    expected_port,
                    "unexpected port for {}",
                    addr.as_url_str()
                );
            }
        }
    }
}