async_nats/
lib.rs

1// Copyright 2020-2022 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! A Rust asynchronous client for the NATS.io ecosystem.
15//!
16//! To access the repository, you can clone it by running:
17//!
18//! ```bash
19//! git clone https://github.com/nats-io/nats.rs
//! ```
//!
21//! NATS.io is a simple, secure, and high-performance open-source messaging
22//! system designed for cloud-native applications, IoT messaging, and microservices
23//! architectures.
24//!
25//! **Note**: The synchronous NATS API is deprecated and no longer actively maintained. If you need to use the deprecated synchronous API, you can refer to:
26//! <https://crates.io/crates/nats>
27//!
28//! For more information on NATS.io visit: <https://nats.io>
29//!
30//! ## Examples
31//!
32//! Below, you can find some basic examples on how to use this library.
33//!
34//! For more details, please refer to the specific methods and structures documentation.
35//!
36//! ### Complete example
37//!
38//! Connect to the NATS server, publish messages and subscribe to receive messages.
39//!
40//! ```no_run
41//! use bytes::Bytes;
42//! use futures_util::StreamExt;
43//!
44//! #[tokio::main]
45//! async fn main() -> Result<(), async_nats::Error> {
46//!     // Connect to the NATS server
47//!     let client = async_nats::connect("demo.nats.io").await?;
48//!
49//!     // Subscribe to the "messages" subject
50//!     let mut subscriber = client.subscribe("messages").await?;
51//!
52//!     // Publish messages to the "messages" subject
53//!     for _ in 0..10 {
54//!         client.publish("messages", "data".into()).await?;
55//!     }
56//!
57//!     // Receive and process messages
58//!     while let Some(message) = subscriber.next().await {
59//!         println!("Received message {:?}", message);
60//!     }
61//!
62//!     Ok(())
63//! }
64//! ```
65//!
66//! ### Publish
67//!
68//! Connect to the NATS server and publish messages to a subject.
69//!
70//! ```
71//! # use bytes::Bytes;
72//! # use std::error::Error;
73//! # use std::time::Instant;
74//! # #[tokio::main]
75//! # async fn main() -> Result<(), async_nats::Error> {
76//! // Connect to the NATS server
77//! let client = async_nats::connect("demo.nats.io").await?;
78//!
79//! // Prepare the subject and data
80//! let subject = "foo";
81//! let data = Bytes::from("bar");
82//!
83//! // Publish messages to the NATS server
84//! for _ in 0..10 {
85//!     client.publish(subject, data.clone()).await?;
86//! }
87//!
88//! // Flush internal buffer before exiting to make sure all messages are sent
89//! client.flush().await?;
90//!
91//! #    Ok(())
92//! # }
93//! ```
94//!
95//! ### Subscribe
96//!
97//! Connect to the NATS server, subscribe to a subject and receive messages.
98//!
99//! ```no_run
100//! # use bytes::Bytes;
101//! # use futures_util::StreamExt;
102//! # use std::error::Error;
103//! # use std::time::Instant;
104//! # #[tokio::main]
105//! # async fn main() -> Result<(), async_nats::Error> {
106//! // Connect to the NATS server
107//! let client = async_nats::connect("demo.nats.io").await?;
108//!
109//! // Subscribe to the "foo" subject
//! let mut subscriber = client.subscribe("foo").await?;
111//!
112//! // Receive and process messages
113//! while let Some(message) = subscriber.next().await {
114//!     println!("Received message {:?}", message);
115//! }
116//! #     Ok(())
117//! # }
118//! ```
119//!
120//! ### JetStream
121//!
//! To access the JetStream API, create a JetStream [jetstream::Context].
123//!
124//! ```no_run
125//! # #[tokio::main]
126//! # async fn main() -> Result<(), async_nats::Error> {
127//! // Connect to the NATS server
128//! let client = async_nats::connect("demo.nats.io").await?;
129//! // Create a JetStream context.
130//! let jetstream = async_nats::jetstream::new(client);
131//!
132//! // Publish JetStream messages, manage streams, consumers, etc.
133//! jetstream.publish("foo", "bar".into()).await?;
134//! # Ok(())
135//! # }
136//! ```
137//!
138//! ### Key-value Store
139//!
140//! Key-value [Store][jetstream::kv::Store] is accessed through [jetstream::Context].
141//!
142//! ```no_run
143//! # #[tokio::main]
144//! # async fn main() -> Result<(), async_nats::Error> {
145//! // Connect to the NATS server
146//! let client = async_nats::connect("demo.nats.io").await?;
147//! // Create a JetStream context.
148//! let jetstream = async_nats::jetstream::new(client);
149//! // Access an existing key-value.
150//! let kv = jetstream.get_key_value("store").await?;
151//! # Ok(())
152//! # }
153//! ```
//!
//! ### Object Store
155//!
156//! Object [Store][jetstream::object_store::ObjectStore] is accessed through [jetstream::Context].
157//!
158//! ```no_run
159//! # #[tokio::main]
160//! # async fn main() -> Result<(), async_nats::Error> {
161//! // Connect to the NATS server
162//! let client = async_nats::connect("demo.nats.io").await?;
163//! // Create a JetStream context.
164//! let jetstream = async_nats::jetstream::new(client);
//! // Access an existing object store.
//! let object_store = jetstream.get_object_store("store").await?;
167//! # Ok(())
168//! # }
169//! ```
//!
//! ### Service API
171//!
172//! [Service API][service::Service] is accessible through [Client] after importing its trait.
173//!
174//! ```no_run
175//! # #[tokio::main]
176//! # async fn main() -> Result<(), async_nats::Error> {
177//! use async_nats::service::ServiceExt;
178//! // Connect to the NATS server
179//! let client = async_nats::connect("demo.nats.io").await?;
180//! let mut service = client
181//!     .service_builder()
182//!     .description("some service")
183//!     .stats_handler(|endpoint, stats| serde_json::json!({ "endpoint": endpoint }))
184//!     .start("products", "1.0.0")
185//!     .await?;
186//! # Ok(())
187//! # }
188//! ```
189
190#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_auto_cfg))]
196
197use thiserror::Error;
198
199use futures_util::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::fmt::Display;
207use std::future::Future;
208use std::iter;
209use std::mem;
210use std::net::SocketAddr;
211use std::option;
212use std::pin::Pin;
213use std::slice;
214use std::str::{self, FromStr};
215use std::sync::atomic::AtomicUsize;
216use std::sync::atomic::Ordering;
217use std::sync::Arc;
218use std::task::{Context, Poll};
219use tokio::io::ErrorKind;
220use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
221use url::{Host, Url};
222
223use bytes::Bytes;
224use serde::{Deserialize, Serialize};
225use serde_repr::{Deserialize_repr, Serialize_repr};
226use tokio::io;
227use tokio::sync::mpsc;
228use tokio::task;
229
/// Boxed, type-erased error that is `Send + Sync`.
///
/// Used as a catch-all error type, e.g. as the error of the `main` functions in
/// the crate-level examples.
pub type Error = Box<dyn std::error::Error + Send + Sync>;
231
232const VERSION: &str = env!("CARGO_PKG_VERSION");
233const LANG: &str = "rust";
234const MAX_PENDING_PINGS: usize = 2;
235const MULTIPLEXER_SID: u64 = 0;
236
237/// A re-export of the `rustls` crate used in this crate,
238/// for use in cases where manual client configurations
239/// must be provided using `Options::tls_client_config`.
240pub use tokio_rustls::rustls;
241
242use connection::{Connection, State};
243use connector::{Connector, ConnectorOptions};
244pub use header::{HeaderMap, HeaderName, HeaderValue};
245pub use subject::Subject;
246
247mod auth;
248pub(crate) mod auth_utils;
249pub mod client;
250pub mod connection;
251mod connector;
252mod options;
253
254pub use auth::Auth;
255pub use client::{
256    Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
257};
258pub use options::{AuthError, ConnectOptions};
259
260mod crypto;
261pub mod error;
262pub mod header;
263pub mod jetstream;
264pub mod message;
265#[cfg(feature = "service")]
266pub mod service;
267pub mod status;
268pub mod subject;
269mod tls;
270
271pub use message::Message;
272pub use status::StatusCode;
273
274/// Information sent by the server back to this client
275/// during initial connection, and possibly again later.
276#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
277pub struct ServerInfo {
278    /// The unique identifier of the NATS server.
279    #[serde(default)]
280    pub server_id: String,
281    /// Generated Server Name.
282    #[serde(default)]
283    pub server_name: String,
284    /// The host specified in the cluster parameter/options.
285    #[serde(default)]
286    pub host: String,
287    /// The port number specified in the cluster parameter/options.
288    #[serde(default)]
289    pub port: u16,
290    /// The version of the NATS server.
291    #[serde(default)]
292    pub version: String,
293    /// If this is set, then the server should try to authenticate upon
294    /// connect.
295    #[serde(default)]
296    pub auth_required: bool,
297    /// If this is set, then the server must authenticate using TLS.
298    #[serde(default)]
299    pub tls_required: bool,
300    /// Maximum payload size that the server will accept.
301    #[serde(default)]
302    pub max_payload: usize,
303    /// The protocol version in use.
304    #[serde(default)]
305    pub proto: i8,
306    /// The server-assigned client ID. This may change during reconnection.
307    #[serde(default)]
308    pub client_id: u64,
309    /// The version of golang the NATS server was built with.
310    #[serde(default)]
311    pub go: String,
312    /// The nonce used for nkeys.
313    #[serde(default)]
314    pub nonce: String,
315    /// A list of server urls that a client can connect to.
316    #[serde(default)]
317    pub connect_urls: Vec<String>,
318    /// The client IP as known by the server.
319    #[serde(default)]
320    pub client_ip: String,
321    /// Whether the server supports headers.
322    #[serde(default)]
323    pub headers: bool,
324    /// Whether server goes into lame duck mode.
325    #[serde(default, rename = "ldm")]
326    pub lame_duck_mode: bool,
327    /// Name of the cluster if the server is in cluster-mode
328    #[serde(default)]
329    pub cluster: Option<String>,
330    /// The configured NATS domain of the server.
331    #[serde(default)]
332    pub domain: Option<String>,
333    /// Whether the server supports JetStream.
334    #[serde(default)]
335    pub jetstream: bool,
336}
337
/// An operation received from the server, as returned by
/// `Connection::poll_read_op` and consumed by `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// `+OK` acknowledgement; the handler does not act on it.
    Ok,
    /// Server information (boxed, presumably to keep the enum small).
    Info(Box<ServerInfo>),
    /// Keep-alive probe from the server; the handler answers it with a PONG.
    Ping,
    /// Reply to a PING sent by this client.
    Pong,
    /// Error reported by the server; the handler forwards it as `Event::ServerError`.
    Error(ServerError),
    /// A message delivered for one of this client's subscriptions.
    Message {
        /// Subscription id the message belongs to (`MULTIPLEXER_SID` for
        /// responses to requests).
        sid: u64,
        /// Subject the message was published to.
        subject: Subject,
        /// Optional reply subject.
        reply: Option<Subject>,
        /// Message body.
        payload: Bytes,
        /// Headers, if the message carried any.
        headers: Option<HeaderMap>,
        /// Optional status code attached to the message.
        status: Option<StatusCode>,
        /// Optional description accompanying `status`.
        description: Option<String>,
        /// NOTE(review): copied verbatim into `Message::length`; presumably the
        /// message size in bytes — confirm against the protocol parser.
        length: usize,
    },
}
356
/// `PublishMessage` represents a message being published.
#[derive(Debug)]
pub struct PublishMessage {
    /// Subject the message is published to.
    pub subject: Subject,
    /// Message body.
    pub payload: Bytes,
    /// Optional subject that receivers should reply to.
    pub reply: Option<Subject>,
    /// Optional headers sent along with the message.
    pub headers: Option<HeaderMap>,
}
365
/// `Command` represents all commands that a [`Client`] can handle
///
/// Commands reach the connection handler through an `mpsc` channel and are
/// applied one by one by `ConnectionHandler::handle_command`.
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a single message.
    Publish(PublishMessage),
    /// Publish a request whose response is routed back through the request
    /// multiplexer.
    Request {
        /// Subject the request is published to.
        subject: Subject,
        /// Request body.
        payload: Bytes,
        /// Reply subject; its last `.`-separated token identifies this request.
        /// The handler rewrites it to `<multiplexer prefix><token>`.
        respond: Subject,
        /// Optional request headers.
        headers: Option<HeaderMap>,
        /// Receives the response message.
        sender: oneshot::Sender<Message>,
    },
    /// Register a new subscription under `sid`.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
        /// Channel that messages for `sid` are delivered to.
        sender: mpsc::Sender<Message>,
    },
    /// Unsubscribe `sid`.
    Unsubscribe {
        sid: u64,
        /// `None` removes the subscription immediately; `Some(n)` keeps it until
        /// `n` messages have been counted for it.
        max: Option<u64>,
    },
    /// Ask to be notified once the connection has been flushed.
    Flush {
        observer: oneshot::Sender<()>,
    },
    /// Drain a single subscription (`Some(sid)`) or the whole client (`None`).
    Drain {
        sid: Option<u64>,
    },
    /// NOTE(review): requests a reconnect of the underlying connection;
    /// presumably handled by setting `should_reconnect` — confirm in `handle_command`.
    Reconnect,
}
395
/// `ClientOp` represents all actions of `Client`.
///
/// These are the protocol operations the connection handler queues for writing
/// with `Connection::enqueue_write_op`.
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// Publish a message, optionally with a reply subject and headers.
    Publish {
        subject: Subject,
        payload: Bytes,
        respond: Option<Subject>,
        headers: Option<HeaderMap>,
    },
    /// Subscribe `sid` to `subject`, optionally within a queue group.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
    },
    /// Unsubscribe `sid`, either immediately (`None`) or after `max` messages.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Keep-alive probe sent to the server.
    Ping,
    /// Answer to a server PING.
    Pong,
    /// Connection handshake payload (presumably the protocol `CONNECT` op).
    Connect(ConnectInfo),
}
418
/// Client-side state of an active subscription, stored in
/// `ConnectionHandler::subscriptions` keyed by its sid.
#[derive(Debug)]
struct Subscription {
    /// Subject the subscription was created for.
    subject: Subject,
    /// Channel through which received messages are handed to the subscriber.
    sender: mpsc::Sender<Message>,
    /// Optional queue group the subscription belongs to.
    queue_group: Option<String>,
    /// Number of messages counted toward `max`.
    delivered: u64,
    /// Limit set by an `Unsubscribe` command; the subscription is removed once
    /// `delivered` reaches it.
    max: Option<u64>,
    /// Set by a drain request; draining subscriptions are removed on the
    /// handler's next poll, after pending incoming messages have been read.
    is_draining: bool,
}
428
/// Routes responses for all in-flight requests through a single wildcard
/// subscription registered under `MULTIPLEXER_SID`.
#[derive(Debug)]
struct Multiplexer {
    /// Wildcard subject (`<prefix>*`) subscribed to for responses.
    subject: Subject,
    /// Response subject prefix of the form `<reply prefix>.<nuid>.`; the part of
    /// a response subject following it is the request token.
    prefix: Subject,
    /// Pending requests keyed by token, each waiting for its response.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
435
/// A connection handler which facilitates communication from channels to a single shared connection.
pub(crate) struct ConnectionHandler {
    /// The current connection to the server.
    connection: Connection,
    /// Emits events (`events_tx`) and records statistics (`connect_stats`);
    /// presumably also used to re-establish the connection.
    connector: Connector,
    /// Active subscriptions keyed by sid.
    subscriptions: HashMap<u64, Subscription>,
    /// Request/response multiplexer, created lazily on the first request.
    multiplexer: Option<Multiplexer>,
    /// PINGs sent by this client that have not been answered by a PONG yet.
    pending_pings: usize,
    /// Watch channel carrying `ServerInfo` (presumably refreshed on reconnect).
    info_sender: tokio::sync::watch::Sender<ServerInfo>,
    /// Keep-alive timer; reset whenever a server op arrives or a command is handled.
    ping_interval: Interval,
    /// When set, the process loop exits with a reconnect request (the flag is
    /// cleared when it is observed).
    should_reconnect: bool,
    /// Observers to notify after the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    /// Set when the whole client drains; the process loop then closes on its next poll.
    is_draining: bool,
}
449
450impl ConnectionHandler {
451    pub(crate) fn new(
452        connection: Connection,
453        connector: Connector,
454        info_sender: tokio::sync::watch::Sender<ServerInfo>,
455        ping_period: Duration,
456    ) -> ConnectionHandler {
457        let mut ping_interval = interval(ping_period);
458        ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
459
460        ConnectionHandler {
461            connection,
462            connector,
463            subscriptions: HashMap::new(),
464            multiplexer: None,
465            pending_pings: 0,
466            info_sender,
467            ping_interval,
468            should_reconnect: false,
469            flush_observers: Vec::new(),
470            is_draining: false,
471        }
472    }
473
    /// Drives the connection until the client is closed.
    ///
    /// Each loop iteration awaits a fresh `ProcessFut` over the current
    /// connection. On a disconnect or a requested reconnect, `handle_disconnect`
    /// is awaited and the loop continues. The loop ends when the command channel
    /// is closed or a client-wide drain completes (after emitting
    /// `Event::Closed`), or when `handle_disconnect` returns an error.
    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
        /// Future that drives one connection "session": it borrows the handler,
        /// the command receiver and a reusable receive buffer until it exits.
        struct ProcessFut<'a> {
            handler: &'a mut ConnectionHandler,
            receiver: &'a mut mpsc::Receiver<Command>,
            recv_buf: &'a mut Vec<Command>,
        }

        /// Why a `ProcessFut` stopped.
        enum ExitReason {
            /// Reading or writing failed, the read side returned no more ops, or
            /// too many PINGs went unanswered.
            Disconnected(Option<io::Error>),
            /// `should_reconnect` was set.
            ReconnectRequested,
            /// The command channel closed or a client-wide drain finished.
            Closed,
        }

        impl ProcessFut<'_> {
            /// Maximum number of commands taken from `receiver` per
            /// `poll_recv_many` call.
            const RECV_CHUNK_SIZE: usize = 16;

            /// Handles one tick of the ping interval.
            ///
            /// Counts the tick as another outstanding PING. Returns
            /// `Poll::Ready(ExitReason::Disconnected(None))` once more than
            /// `MAX_PENDING_PINGS` are outstanding; otherwise queues a PING and
            /// returns `Poll::Pending`, meaning "keep going".
            #[cold]
            fn ping(&mut self) -> Poll<ExitReason> {
                self.handler.pending_pings += 1;

                if self.handler.pending_pings > MAX_PENDING_PINGS {
                    debug!(
                        pending_pings = self.handler.pending_pings,
                        max_pings = MAX_PENDING_PINGS,
                        "disconnecting due to too many pending pings"
                    );

                    Poll::Ready(ExitReason::Disconnected(None))
                } else {
                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);

                    Poll::Pending
                }
            }
        }

        impl Future for ProcessFut<'_> {
            type Output = ExitReason;

            /// Drives the connection forward.
            ///
            /// Returns one of the following:
            ///
            /// * `Poll::Pending` means that the connection
            ///   is blocked on all fronts or there are
            ///   no commands to send or receive
            /// * `Poll::Ready(ExitReason::Disconnected(_))` means
            ///   that an I/O operation failed, the read side returned
            ///   no more ops, or too many PINGs went unanswered, and
            ///   the connection is considered dead.
            /// * `Poll::Ready(ExitReason::ReconnectRequested)` means
            ///   that `should_reconnect` was set and the caller should
            ///   re-establish the connection.
            /// * `Poll::Ready(ExitReason::Closed)` means that
            ///   [`Self::receiver`] was closed (or a client-wide drain
            ///   completed), so there's nothing more for us to do than
            ///   to exit the client.
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // We need to be sure the waker is registered, therefore we need to poll until we
                // get a `Poll::Pending`. With a sane interval delay, this means that the loop
                // breaks at the second iteration.
                while self.handler.ping_interval.poll_tick(cx).is_ready() {
                    if let Poll::Ready(exit) = self.ping() {
                        return Poll::Ready(exit);
                    }
                }

                // Process every server op that is already readable.
                loop {
                    match self.handler.connection.poll_read_op(cx) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(Some(server_op))) => {
                            self.handler.handle_server_op(server_op);
                        }
                        Poll::Ready(Ok(None)) => {
                            return Poll::Ready(ExitReason::Disconnected(None))
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Before handling any commands, drop any subscriptions which are draining
                // Note: safe to assume subscription drain has completed at this point, as we would have flushed
                // all outgoing UNSUB messages in the previous call to this fn, and we would have processed and
                // delivered any remaining messages to the subscription in the loop above.
                self.handler.subscriptions.retain(|_, s| !s.is_draining);

                if self.handler.is_draining {
                    // The entire connection is draining. This means we flushed outgoing messages in the previous
                    // call to this fn, we handled any remaining messages from the server in the loop above, and
                    // all subs were drained, so drain is complete and we should exit instead of processing any
                    // further messages
                    return Poll::Ready(ExitReason::Closed);
                }

                // WARNING: after the following loop `handle_command`,
                // or other functions which call `enqueue_write_op`,
                // cannot be called anymore. Runtime wakeups won't
                // trigger a call to `poll_write`

                let mut made_progress = true;
                loop {
                    // Pull commands in chunks while the write buffer has room.
                    while !self.handler.connection.is_write_buf_full() {
                        debug_assert!(self.recv_buf.is_empty());

                        // Destructure to borrow the fields disjointly.
                        let Self {
                            recv_buf,
                            handler,
                            receiver,
                        } = &mut *self;
                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
                            Poll::Pending => break,
                            Poll::Ready(1..) => {
                                made_progress = true;

                                for cmd in recv_buf.drain(..) {
                                    handler.handle_command(cmd);
                                }
                            }
                            // TODO: replace `_` with `0` after bumping MSRV to 1.75
                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
                        }
                    }

                    // The first round will poll both from
                    // the `receiver` and the writer, giving
                    // them both a chance to make progress
                    // and register `Waker`s.
                    //
                    // If writing is `Poll::Pending` we exit.
                    //
                    // If writing is completed we can repeat the entire
                    // cycle as long as the `receiver` doesn't end-up
                    // `Poll::Pending` immediately.
                    if !mem::take(&mut made_progress) {
                        break;
                    }

                    match self.handler.connection.poll_write(cx) {
                        Poll::Pending => {
                            // Write buffer couldn't be fully emptied
                            break;
                        }
                        Poll::Ready(Ok(())) => {
                            // Write buffer is empty
                            continue;
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Flush when the connection asks for it, or when someone is
                // waiting on a flush.
                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
                    self.handler.connection.should_flush(),
                    self.handler.flush_observers.is_empty(),
                ) {
                    match self.handler.connection.poll_flush(cx) {
                        Poll::Pending => {}
                        Poll::Ready(Ok(())) => {
                            for observer in self.handler.flush_observers.drain(..) {
                                let _ = observer.send(());
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                if mem::take(&mut self.handler.should_reconnect) {
                    return Poll::Ready(ExitReason::ReconnectRequested);
                }

                Poll::Pending
            }
        }

        // Reused across sessions so each poll doesn't allocate a new buffer.
        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
        loop {
            let process = ProcessFut {
                handler: self,
                receiver,
                recv_buf: &mut recv_buf,
            };
            match process.await {
                ExitReason::Disconnected(err) => {
                    debug!(error = ?err, "disconnected");
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                    debug!("reconnected");
                }
                ExitReason::Closed => {
                    // Safe to ignore result as we're shutting down anyway
                    self.connector.events_tx.try_send(Event::Closed).ok();
                    break;
                }
                ExitReason::ReconnectRequested => {
                    debug!("reconnect requested");
                    // Should be ok to ignore error, as that means we are not in connected state.
                    self.connection.stream.shutdown().await.ok();
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                }
            }
        }
    }
679
680    fn handle_server_op(&mut self, server_op: ServerOp) {
681        self.ping_interval.reset();
682
683        match server_op {
684            ServerOp::Ping => {
685                debug!("received PING");
686                self.connection.enqueue_write_op(&ClientOp::Pong);
687            }
688            ServerOp::Pong => {
689                debug!("received PONG");
690                self.pending_pings = self.pending_pings.saturating_sub(1);
691            }
692            ServerOp::Error(error) => {
693                debug!("received ERROR: {:?}", error);
694                self.connector
695                    .events_tx
696                    .try_send(Event::ServerError(error))
697                    .ok();
698            }
699            ServerOp::Message {
700                sid,
701                subject,
702                reply,
703                payload,
704                headers,
705                status,
706                description,
707                length,
708            } => {
709                debug!("received MESSAGE: sid={}, subject={}", sid, subject);
710                self.connector
711                    .connect_stats
712                    .in_messages
713                    .add(1, Ordering::Relaxed);
714
715                if let Some(subscription) = self.subscriptions.get_mut(&sid) {
716                    let message: Message = Message {
717                        subject,
718                        reply,
719                        payload,
720                        headers,
721                        status,
722                        description,
723                        length,
724                    };
725
726                    // if the channel for subscription was dropped, remove the
727                    // subscription from the map and unsubscribe.
728                    match subscription.sender.try_send(message) {
729                        Ok(_) => {
730                            subscription.delivered += 1;
731                            // if this `Subscription` has set `max` value, check if it
732                            // was reached. If yes, remove the `Subscription` and in
733                            // the result, `drop` the `sender` channel.
734                            if let Some(max) = subscription.max {
735                                if subscription.delivered.ge(&max) {
736                                    debug!("max messages reached for subscription {}", sid);
737                                    self.subscriptions.remove(&sid);
738                                }
739                            }
740                        }
741                        Err(mpsc::error::TrySendError::Full(_)) => {
742                            debug!("slow consumer detected for subscription {}", sid);
743                            self.connector
744                                .events_tx
745                                .try_send(Event::SlowConsumer(sid))
746                                .ok();
747                        }
748                        Err(mpsc::error::TrySendError::Closed(_)) => {
749                            debug!("subscription {} channel closed", sid);
750                            self.subscriptions.remove(&sid);
751                            self.connection
752                                .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
753                        }
754                    }
755                } else if sid == MULTIPLEXER_SID {
756                    debug!("received message for multiplexer");
757                    if let Some(multiplexer) = self.multiplexer.as_mut() {
758                        let maybe_token =
759                            subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
760
761                        if let Some(token) = maybe_token {
762                            if let Some(sender) = multiplexer.senders.remove(token) {
763                                debug!("forwarding message to request with token {}", token);
764                                let message = Message {
765                                    subject,
766                                    reply,
767                                    payload,
768                                    headers,
769                                    status,
770                                    description,
771                                    length,
772                                };
773
774                                let _ = sender.send(message);
775                            }
776                        }
777                    }
778                }
779            }
780            // TODO: we should probably update advertised server list here too.
781            ServerOp::Info(info) => {
782                debug!("received INFO: server_id={}", info.server_id);
783                if info.lame_duck_mode {
784                    debug!("server in lame duck mode");
785                    self.connector.events_tx.try_send(Event::LameDuckMode).ok();
786                }
787            }
788
789            _ => {
790                // TODO: don't ignore.
791            }
792        }
793    }
794
795    fn handle_command(&mut self, command: Command) {
796        self.ping_interval.reset();
797
798        match command {
799            Command::Unsubscribe { sid, max } => {
800                if let Some(subscription) = self.subscriptions.get_mut(&sid) {
801                    subscription.max = max;
802                    match subscription.max {
803                        Some(n) => {
804                            if subscription.delivered >= n {
805                                self.subscriptions.remove(&sid);
806                            }
807                        }
808                        None => {
809                            self.subscriptions.remove(&sid);
810                        }
811                    }
812
813                    self.connection
814                        .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
815                }
816            }
817            Command::Flush { observer } => {
818                self.flush_observers.push(observer);
819            }
820            Command::Drain { sid } => {
821                let mut drain_sub = |sid: u64, sub: &mut Subscription| {
822                    sub.is_draining = true;
823                    self.connection
824                        .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
825                };
826
827                if let Some(sid) = sid {
828                    if let Some(sub) = self.subscriptions.get_mut(&sid) {
829                        drain_sub(sid, sub);
830                    }
831                } else {
832                    // sid isn't set, so drain the whole client
833                    self.connector.events_tx.try_send(Event::Draining).ok();
834                    self.is_draining = true;
835                    for (&sid, sub) in self.subscriptions.iter_mut() {
836                        drain_sub(sid, sub);
837                    }
838                }
839            }
840            Command::Subscribe {
841                sid,
842                subject,
843                queue_group,
844                sender,
845            } => {
846                let subscription = Subscription {
847                    sender,
848                    delivered: 0,
849                    max: None,
850                    subject: subject.to_owned(),
851                    queue_group: queue_group.to_owned(),
852                    is_draining: false,
853                };
854
855                self.subscriptions.insert(sid, subscription);
856
857                self.connection.enqueue_write_op(&ClientOp::Subscribe {
858                    sid,
859                    subject,
860                    queue_group,
861                });
862            }
863            Command::Request {
864                subject,
865                payload,
866                respond,
867                headers,
868                sender,
869            } => {
870                let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
871
872                let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
873                    multiplexer
874                } else {
875                    let prefix = Subject::from(format!("{}.{}.", prefix, nuid::next()));
876                    let subject = Subject::from(format!("{prefix}*"));
877
878                    self.connection.enqueue_write_op(&ClientOp::Subscribe {
879                        sid: MULTIPLEXER_SID,
880                        subject: subject.clone(),
881                        queue_group: None,
882                    });
883
884                    self.multiplexer.insert(Multiplexer {
885                        subject,
886                        prefix,
887                        senders: HashMap::new(),
888                    })
889                };
890                self.connector
891                    .connect_stats
892                    .out_messages
893                    .add(1, Ordering::Relaxed);
894
895                multiplexer.senders.insert(token.to_owned(), sender);
896
897                let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
898
899                let pub_op = ClientOp::Publish {
900                    subject,
901                    payload,
902                    respond: Some(respond),
903                    headers,
904                };
905
906                self.connection.enqueue_write_op(&pub_op);
907            }
908
909            Command::Publish(PublishMessage {
910                subject,
911                payload,
912                reply: respond,
913                headers,
914            }) => {
915                self.connector
916                    .connect_stats
917                    .out_messages
918                    .add(1, Ordering::Relaxed);
919
920                let header_len = headers
921                    .as_ref()
922                    .map(|headers| headers.len())
923                    .unwrap_or_default();
924
925                self.connector.connect_stats.out_bytes.add(
926                    (payload.len()
927                        + respond.as_ref().map_or_else(|| 0, |r| r.len())
928                        + subject.len()
929                        + header_len) as u64,
930                    Ordering::Relaxed,
931                );
932
933                self.connection.enqueue_write_op(&ClientOp::Publish {
934                    subject,
935                    payload,
936                    respond,
937                    headers,
938                });
939            }
940
941            Command::Reconnect => {
942                self.should_reconnect = true;
943            }
944        }
945    }
946
947    async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
948        self.pending_pings = 0;
949        self.connector.events_tx.try_send(Event::Disconnected).ok();
950        self.connector.state_tx.send(State::Disconnected).ok();
951
952        self.handle_reconnect().await
953    }
954
955    async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
956        let (info, connection) = self.connector.connect().await?;
957        self.connection = connection;
958        let _ = self.info_sender.send(info);
959
960        self.subscriptions
961            .retain(|_, subscription| !subscription.sender.is_closed());
962
963        for (sid, subscription) in &self.subscriptions {
964            self.connection.enqueue_write_op(&ClientOp::Subscribe {
965                sid: *sid,
966                subject: subscription.subject.to_owned(),
967                queue_group: subscription.queue_group.to_owned(),
968            });
969        }
970
971        if let Some(multiplexer) = &self.multiplexer {
972            self.connection.enqueue_write_op(&ClientOp::Subscribe {
973                sid: MULTIPLEXER_SID,
974                subject: multiplexer.subject.to_owned(),
975                queue_group: None,
976            });
977        }
978        Ok(())
979    }
980}
981
/// Connects to NATS with specified options.
///
/// It is generally advised to use [ConnectOptions] instead, as it provides a builder for whole
/// configuration.
///
/// When `retry_on_initial_connect` is disabled in `options`, the first connection attempt is
/// made before this function returns, and its error is returned to the caller. When it is
/// enabled, a [Client] is returned right away and the initial connection is established in a
/// background task; if that attempt fails, the error is only logged.
///
/// # Errors
///
/// Returns a [ConnectError] of kind [ConnectErrorKind::ServerParse] if the connector cannot be
/// created from `addrs` and `options`. With `retry_on_initial_connect` disabled, it also returns
/// the error of the initial connection attempt.
///
/// # Examples
/// ```
/// # #[tokio::main]
/// # async fn main() ->  Result<(), async_nats::Error> {
/// let nc =
///     async_nats::connect_with_options("demo.nats.io", async_nats::ConnectOptions::new()).await?;
/// nc.publish("test", "data".into()).await?;
/// # Ok(())
/// # }
/// ```
pub async fn connect_with_options<A: ToServerAddrs>(
    addrs: A,
    options: ConnectOptions,
) -> Result<Client, ConnectError> {
    let ping_period = options.ping_interval;

    // Connection events; drained by the event-callback task spawned below. The emitters visible
    // in this file use `try_send`, so events are dropped rather than blocking when it is full.
    let (events_tx, mut events_rx) = mpsc::channel(128);
    // Connection state; the receiving side is handed to the `Client`.
    let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
    // We're setting it to the default server payload size.
    let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
    let statistics = Arc::new(Statistics::default());

    let mut connector = Connector::new(
        addrs,
        ConnectorOptions {
            tls_required: options.tls_required,
            certificates: options.certificates,
            client_key: options.client_key,
            client_cert: options.client_cert,
            tls_client_config: options.tls_client_config,
            tls_first: options.tls_first,
            auth: options.auth,
            no_echo: options.no_echo,
            connection_timeout: options.connection_timeout,
            name: options.name,
            ignore_discovered_servers: options.ignore_discovered_servers,
            retain_servers_order: options.retain_servers_order,
            read_buffer_capacity: options.read_buffer_capacity,
            reconnect_delay_callback: options.reconnect_delay_callback,
            auth_callback: options.auth_callback,
            max_reconnects: options.max_reconnects,
        },
        events_tx,
        state_tx,
        max_payload.clone(),
        statistics.clone(),
    )
    .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;

    let mut info: ServerInfo = Default::default();
    let mut connection = None;
    // Without retry-on-initial-connect, fail fast: connect now and surface errors to the caller.
    if !options.retry_on_initial_connect {
        debug!("retry on initial connect failure is disabled");
        let (info_ok, connection_ok) = connector.try_connect().await?;
        connection = Some(connection_ok);
        info = info_ok;
    }

    let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
    // Commands from the `Client` (and its clones) to the connection handler task.
    let (sender, mut receiver) = mpsc::channel(options.sender_capacity);

    let client = Client::new(
        info_watcher,
        state_rx,
        sender,
        options.subscription_capacity,
        options.inbox_prefix,
        options.request_timeout,
        max_payload,
        statistics,
    );

    // Log every connection event and forward it to the user's callback, if one is configured.
    task::spawn(async move {
        while let Some(event) = events_rx.recv().await {
            tracing::info!("event: {}", event);
            if let Some(event_callback) = &options.event_callback {
                event_callback.call(event).await;
            }
        }
    });

    // Drive the connection: finish the deferred initial connect if needed, then hand the
    // connection to `ConnectionHandler` to process commands from `receiver`.
    task::spawn(async move {
        if connection.is_none() && options.retry_on_initial_connect {
            let (info, connection_ok) = match connector.connect().await {
                Ok((info, connection)) => (info, connection),
                Err(err) => {
                    error!("connection closed: {}", err);
                    return;
                }
            };
            info_sender.send(info).ok();
            connection = Some(connection_ok);
        }
        // Invariant: `connection` is `Some` here. It was set either before spawning (retry
        // disabled) or by the branch just above.
        let connection = connection.unwrap();
        let mut connection_handler =
            ConnectionHandler::new(connection, connector, info_sender, ping_period);
        connection_handler.process(&mut receiver).await
    });

    Ok(client)
}
1085
1086    Ok(client)
1087}
1088
1089#[derive(Debug, Clone, PartialEq, Eq)]
1090pub enum Event {
1091    Connected,
1092    Disconnected,
1093    LameDuckMode,
1094    Draining,
1095    Closed,
1096    SlowConsumer(u64),
1097    ServerError(ServerError),
1098    ClientError(ClientError),
1099}
1100
1101impl fmt::Display for Event {
1102    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1103        match self {
1104            Event::Connected => write!(f, "connected"),
1105            Event::Disconnected => write!(f, "disconnected"),
1106            Event::LameDuckMode => write!(f, "lame duck mode detected"),
1107            Event::Draining => write!(f, "draining"),
1108            Event::Closed => write!(f, "closed"),
1109            Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1110            Event::ServerError(err) => write!(f, "server error: {err}"),
1111            Event::ClientError(err) => write!(f, "client error: {err}"),
1112        }
1113    }
1114}
1115
1116/// Connects to NATS with default config.
1117///
1118/// Returns cloneable [Client].
1119///
1120/// To have customized NATS connection, check [ConnectOptions].
1121///
1122/// # Examples
1123///
1124/// ## Single URL
1125/// ```
1126/// # #[tokio::main]
1127/// # async fn main() ->  Result<(), async_nats::Error> {
1128/// let mut nc = async_nats::connect("demo.nats.io").await?;
1129/// nc.publish("test", "data".into()).await?;
1130/// # Ok(())
1131/// # }
1132/// ```
1133///
1134/// ## Connect with [Vec] of [ServerAddr].
1135/// ```no_run
1136/// #[tokio::main]
1137/// # async fn main() -> Result<(), async_nats::Error> {
1138/// use async_nats::ServerAddr;
1139/// let client = async_nats::connect(vec![
1140///     "demo.nats.io".parse::<ServerAddr>()?,
1141///     "other.nats.io".parse::<ServerAddr>()?,
1142/// ])
1143/// .await
1144/// .unwrap();
1145/// # Ok(())
1146/// # }
1147/// ```
1148///
1149/// ## with [Vec], but parse URLs inside [crate::connect()]
1150/// ```no_run
1151/// #[tokio::main]
1152/// # async fn main() -> Result<(), async_nats::Error> {
1153/// use async_nats::ServerAddr;
1154/// let servers = vec!["demo.nats.io", "other.nats.io"];
1155/// let client = async_nats::connect(
1156///     servers
1157///         .iter()
1158///         .map(|url| url.parse())
1159///         .collect::<Result<Vec<ServerAddr>, _>>()?,
1160/// )
1161/// .await?;
1162/// # Ok(())
1163/// # }
1164/// ```
1165///
1166///
1167/// ## with slice.
1168/// ```no_run
1169/// #[tokio::main]
1170/// # async fn main() -> Result<(), async_nats::Error> {
1171/// use async_nats::ServerAddr;
1172/// let client = async_nats::connect(
1173///    [
1174///        "demo.nats.io".parse::<ServerAddr>()?,
1175///        "other.nats.io".parse::<ServerAddr>()?,
1176///    ]
1177///    .as_slice(),
1178/// )
1179/// .await?;
1180/// # Ok(())
1181/// # }
1182pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1183    connect_with_options(addrs, ConnectOptions::default()).await
1184}
1185
1186#[derive(Debug, Clone, Copy, PartialEq)]
1187pub enum ConnectErrorKind {
1188    /// Parsing the passed server address failed.
1189    ServerParse,
1190    /// DNS related issues.
1191    Dns,
1192    /// Failed authentication process, signing nonce, etc.
1193    Authentication,
1194    /// Server returned authorization violation error.
1195    AuthorizationViolation,
1196    /// Connect timed out.
1197    TimedOut,
1198    /// Erroneous TLS setup.
1199    Tls,
1200    /// Other IO error.
1201    Io,
1202    /// Reached the maximum number of reconnects.
1203    MaxReconnects,
1204}
1205
1206impl Display for ConnectErrorKind {
1207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1208        match self {
1209            Self::ServerParse => write!(f, "failed to parse server or server list"),
1210            Self::Dns => write!(f, "DNS error"),
1211            Self::Authentication => write!(f, "failed signing nonce"),
1212            Self::AuthorizationViolation => write!(f, "authorization violation"),
1213            Self::TimedOut => write!(f, "timed out"),
1214            Self::Tls => write!(f, "TLS error"),
1215            Self::Io => write!(f, "IO error"),
1216            Self::MaxReconnects => write!(f, "reached maximum number of reconnects"),
1217        }
1218    }
1219}
1220
1221/// Returned when initial connection fails.
1222/// To be enumerate over the variants, call [ConnectError::kind].
1223pub type ConnectError = error::Error<ConnectErrorKind>;
1224
1225impl From<io::Error> for ConnectError {
1226    fn from(err: io::Error) -> Self {
1227        ConnectError::with_source(ConnectErrorKind::Io, err)
1228    }
1229}
1230
1231/// Retrieves messages from given `subscription` created by [Client::subscribe].
1232///
1233/// Implements [futures_util::stream::Stream] for ergonomic async message processing.
1234///
1235/// # Examples
1236/// ```
1237/// # #[tokio::main]
1238/// # async fn main() ->  Result<(), async_nats::Error> {
1239/// let mut nc = async_nats::connect("demo.nats.io").await?;
1240/// # nc.publish("test", "data".into()).await?;
1241/// # Ok(())
1242/// # }
1243/// ```
1244#[derive(Debug)]
1245pub struct Subscriber {
1246    sid: u64,
1247    receiver: mpsc::Receiver<Message>,
1248    sender: mpsc::Sender<Command>,
1249}
1250
1251impl Subscriber {
1252    fn new(
1253        sid: u64,
1254        sender: mpsc::Sender<Command>,
1255        receiver: mpsc::Receiver<Message>,
1256    ) -> Subscriber {
1257        Subscriber {
1258            sid,
1259            sender,
1260            receiver,
1261        }
1262    }
1263
1264    /// Unsubscribes from subscription, draining all remaining messages.
1265    ///
1266    /// # Examples
1267    /// ```
1268    /// # #[tokio::main]
1269    /// # async fn main() -> Result<(), async_nats::Error> {
1270    /// let client = async_nats::connect("demo.nats.io").await?;
1271    ///
1272    /// let mut subscriber = client.subscribe("foo").await?;
1273    ///
1274    /// subscriber.unsubscribe().await?;
1275    /// # Ok(())
1276    /// # }
1277    /// ```
1278    pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1279        self.sender
1280            .send(Command::Unsubscribe {
1281                sid: self.sid,
1282                max: None,
1283            })
1284            .await?;
1285        self.receiver.close();
1286        Ok(())
1287    }
1288
1289    /// Unsubscribes from subscription after reaching given number of messages.
1290    /// This is the total number of messages received by this subscription in it's whole
1291    /// lifespan. If it already reached or surpassed the passed value, it will immediately stop.
1292    ///
1293    /// # Examples
1294    /// ```
1295    /// # use futures_util::StreamExt;
1296    /// # #[tokio::main]
1297    /// # async fn main() -> Result<(), async_nats::Error> {
1298    /// let client = async_nats::connect("demo.nats.io").await?;
1299    ///
1300    /// let mut subscriber = client.subscribe("test").await?;
1301    /// subscriber.unsubscribe_after(3).await?;
1302    ///
1303    /// for _ in 0..3 {
1304    ///     client.publish("test", "data".into()).await?;
1305    /// }
1306    ///
1307    /// while let Some(message) = subscriber.next().await {
1308    ///     println!("message received: {:?}", message);
1309    /// }
1310    /// println!("no more messages, unsubscribed");
1311    /// # Ok(())
1312    /// # }
1313    /// ```
1314    pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1315        self.sender
1316            .send(Command::Unsubscribe {
1317                sid: self.sid,
1318                max: Some(unsub_after),
1319            })
1320            .await?;
1321        Ok(())
1322    }
1323
1324    /// Unsubscribes immediately but leaves the subscription open to allow any in-flight messages
1325    /// on the subscription to be delivered. The stream will be closed after any remaining messages
1326    /// are delivered
1327    ///
1328    /// # Examples
1329    /// ```no_run
1330    /// # use futures_util::StreamExt;
1331    /// # #[tokio::main]
1332    /// # async fn main() -> Result<(), async_nats::Error> {
1333    /// let client = async_nats::connect("demo.nats.io").await?;
1334    ///
1335    /// let mut subscriber = client.subscribe("test").await?;
1336    ///
1337    /// tokio::spawn({
1338    ///     let task_client = client.clone();
1339    ///     async move {
1340    ///         loop {
1341    ///             _ = task_client.publish("test", "data".into()).await;
1342    ///         }
1343    ///     }
1344    /// });
1345    ///
1346    /// client.flush().await?;
1347    /// subscriber.drain().await?;
1348    ///
1349    /// while let Some(message) = subscriber.next().await {
1350    ///     println!("message received: {:?}", message);
1351    /// }
1352    /// println!("no more messages, unsubscribed");
1353    /// # Ok(())
1354    /// # }
1355    /// ```
1356    pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1357        self.sender
1358            .send(Command::Drain {
1359                sid: Some(self.sid),
1360            })
1361            .await?;
1362
1363        Ok(())
1364    }
1365}
1366
1367#[derive(Error, Debug, PartialEq)]
1368#[error("failed to send unsubscribe")]
1369pub struct UnsubscribeError(String);
1370
1371impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1372    fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1373        UnsubscribeError(err.to_string())
1374    }
1375}
1376
1377impl Drop for Subscriber {
1378    fn drop(&mut self) {
1379        self.receiver.close();
1380        tokio::spawn({
1381            let sender = self.sender.clone();
1382            let sid = self.sid;
1383            async move {
1384                sender
1385                    .send(Command::Unsubscribe { sid, max: None })
1386                    .await
1387                    .ok();
1388            }
1389        });
1390    }
1391}
1392
1393impl Stream for Subscriber {
1394    type Item = Message;
1395
1396    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1397        self.receiver.poll_recv(cx)
1398    }
1399}
1400
1401#[derive(Clone, Debug, Eq, PartialEq)]
1402pub enum CallbackError {
1403    Client(ClientError),
1404    Server(ServerError),
1405}
1406impl std::fmt::Display for CallbackError {
1407    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1408        match self {
1409            Self::Client(error) => write!(f, "{error}"),
1410            Self::Server(error) => write!(f, "{error}"),
1411        }
1412    }
1413}
1414
1415impl From<ServerError> for CallbackError {
1416    fn from(server_error: ServerError) -> Self {
1417        CallbackError::Server(server_error)
1418    }
1419}
1420
1421impl From<ClientError> for CallbackError {
1422    fn from(client_error: ClientError) -> Self {
1423        CallbackError::Client(client_error)
1424    }
1425}
1426
1427#[derive(Clone, Debug, Eq, PartialEq, Error)]
1428pub enum ServerError {
1429    AuthorizationViolation,
1430    SlowConsumer(u64),
1431    Other(String),
1432}
1433
1434#[derive(Clone, Debug, Eq, PartialEq)]
1435pub enum ClientError {
1436    Other(String),
1437    MaxReconnects,
1438}
1439impl std::fmt::Display for ClientError {
1440    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1441        match self {
1442            Self::Other(error) => write!(f, "nats: {error}"),
1443            Self::MaxReconnects => write!(f, "nats: max reconnects reached"),
1444        }
1445    }
1446}
1447
1448impl ServerError {
1449    fn new(error: String) -> ServerError {
1450        match error.to_lowercase().as_str() {
1451            "authorization violation" => ServerError::AuthorizationViolation,
1452            // error messages can contain case-sensitive values which should be preserved
1453            _ => ServerError::Other(error),
1454        }
1455    }
1456}
1457
1458impl std::fmt::Display for ServerError {
1459    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1460        match self {
1461            Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1462            Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1463            Self::Other(error) => write!(f, "nats: {error}"),
1464        }
1465    }
1466}
1467
1468/// Info to construct a CONNECT message.
1469#[derive(Clone, Debug, Serialize)]
1470pub struct ConnectInfo {
1471    /// Turns on +OK protocol acknowledgments.
1472    pub verbose: bool,
1473
1474    /// Turns on additional strict format checking, e.g. for properly formed
1475    /// subjects.
1476    pub pedantic: bool,
1477
1478    /// User's JWT.
1479    #[serde(rename = "jwt")]
1480    pub user_jwt: Option<String>,
1481
1482    /// Public nkey.
1483    pub nkey: Option<String>,
1484
1485    /// Signed nonce, encoded to Base64URL.
1486    #[serde(rename = "sig")]
1487    pub signature: Option<String>,
1488
1489    /// Optional client name.
1490    pub name: Option<String>,
1491
1492    /// If set to `true`, the server (version 1.2.0+) will not send originating
1493    /// messages from this connection to its own subscriptions. Clients should
1494    /// set this to `true` only for server supporting this feature, which is
1495    /// when proto in the INFO protocol is set to at least 1.
1496    pub echo: bool,
1497
1498    /// The implementation language of the client.
1499    pub lang: String,
1500
1501    /// The version of the client.
1502    pub version: String,
1503
1504    /// Sending 0 (or absent) indicates client supports original protocol.
1505    /// Sending 1 indicates that the client supports dynamic reconfiguration
1506    /// of cluster topology changes by asynchronously receiving INFO messages
1507    /// with known servers it can reconnect to.
1508    pub protocol: Protocol,
1509
1510    /// Indicates whether the client requires an SSL connection.
1511    pub tls_required: bool,
1512
1513    /// Connection username (if `auth_required` is set)
1514    pub user: Option<String>,
1515
1516    /// Connection password (if auth_required is set)
1517    pub pass: Option<String>,
1518
1519    /// Client authorization token (if auth_required is set)
1520    pub auth_token: Option<String>,
1521
1522    /// Whether the client supports the usage of headers.
1523    pub headers: bool,
1524
1525    /// Whether the client supports no_responders.
1526    pub no_responders: bool,
1527}
1528
1529/// Protocol version used by the client.
1530#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
1531#[repr(u8)]
1532pub enum Protocol {
1533    /// Original protocol.
1534    Original = 0,
1535    /// Protocol with dynamic reconfiguration of cluster and lame duck mode functionality.
1536    Dynamic = 1,
1537}
1538
1539/// Address of a NATS server.
1540#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1541pub struct ServerAddr(Url);
1542
1543impl FromStr for ServerAddr {
1544    type Err = io::Error;
1545
1546    /// Parse an address of a NATS server.
1547    ///
1548    /// If not stated explicitly the `nats://` schema and port `4222` is assumed.
1549    fn from_str(input: &str) -> Result<Self, Self::Err> {
1550        let url: Url = if input.contains("://") {
1551            input.parse()
1552        } else {
1553            format!("nats://{input}").parse()
1554        }
1555        .map_err(|e| {
1556            io::Error::new(
1557                ErrorKind::InvalidInput,
1558                format!("NATS server URL is invalid: {e}"),
1559            )
1560        })?;
1561
1562        Self::from_url(url)
1563    }
1564}
1565
1566impl ServerAddr {
1567    /// Check if the URL is a valid NATS server address.
1568    pub fn from_url(url: Url) -> io::Result<Self> {
1569        if url.scheme() != "nats"
1570            && url.scheme() != "tls"
1571            && url.scheme() != "ws"
1572            && url.scheme() != "wss"
1573        {
1574            return Err(std::io::Error::new(
1575                ErrorKind::InvalidInput,
1576                format!("invalid scheme for NATS server URL: {}", url.scheme()),
1577            ));
1578        }
1579
1580        Ok(Self(url))
1581    }
1582
1583    /// Turn the server address into a standard URL.
1584    pub fn into_inner(self) -> Url {
1585        self.0
1586    }
1587
1588    /// Returns if tls is required by the client for this server.
1589    pub fn tls_required(&self) -> bool {
1590        self.0.scheme() == "tls"
1591    }
1592
1593    /// Returns if the server url had embedded username and password.
1594    pub fn has_user_pass(&self) -> bool {
1595        self.0.username() != ""
1596    }
1597
1598    pub fn scheme(&self) -> &str {
1599        self.0.scheme()
1600    }
1601
1602    /// Returns the host.
1603    pub fn host(&self) -> &str {
1604        match self.0.host() {
1605            Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1606            // `host_str()` for Ipv6 includes the []s
1607            Some(Host::Ipv6 { .. }) => {
1608                let host = self.0.host_str().unwrap();
1609                &host[1..host.len() - 1]
1610            }
1611            None => "",
1612        }
1613    }
1614
1615    pub fn is_websocket(&self) -> bool {
1616        self.0.scheme() == "ws" || self.0.scheme() == "wss"
1617    }
1618
1619    /// Returns the port.
1620    /// Delegates to [`Url::port_or_known_default`](https://docs.rs/url/latest/url/struct.Url.html#method.port_or_known_default) and defaults to 4222 if none was explicitly specified in creating this `ServerAddr`.
1621    pub fn port(&self) -> u16 {
1622        self.0.port_or_known_default().unwrap_or(4222)
1623    }
1624
1625    /// Returns the URL string.
1626    pub fn as_url_str(&self) -> &str {
1627        self.0.as_str()
1628    }
1629
1630    /// Returns the optional username in the url.
1631    pub fn username(&self) -> Option<&str> {
1632        let user = self.0.username();
1633        if user.is_empty() {
1634            None
1635        } else {
1636            Some(user)
1637        }
1638    }
1639
1640    /// Returns the optional password in the url.
1641    pub fn password(&self) -> Option<&str> {
1642        self.0.password()
1643    }
1644
1645    /// Return the sockets from resolving the server address.
1646    pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1647        tokio::net::lookup_host((self.host(), self.port())).await
1648    }
1649}
1650
1651/// Capability to convert into a list of NATS server addresses.
1652///
1653/// There are several implementations ensuring the easy passing of one or more server addresses to
1654/// functions like [`crate::connect()`].
1655pub trait ToServerAddrs {
1656    /// Returned iterator over socket addresses which this type may correspond
1657    /// to.
1658    type Iter: Iterator<Item = ServerAddr>;
1659
1660    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
1661}
1662
1663impl ToServerAddrs for ServerAddr {
1664    type Iter = option::IntoIter<ServerAddr>;
1665    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1666        Ok(Some(self.clone()).into_iter())
1667    }
1668}
1669
1670impl ToServerAddrs for str {
1671    type Iter = option::IntoIter<ServerAddr>;
1672    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1673        self.parse::<ServerAddr>()
1674            .map(|addr| Some(addr).into_iter())
1675    }
1676}
1677
1678impl ToServerAddrs for String {
1679    type Iter = option::IntoIter<ServerAddr>;
1680    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1681        (**self).to_server_addrs()
1682    }
1683}
1684
1685impl<T: AsRef<str>> ToServerAddrs for [T] {
1686    type Iter = std::vec::IntoIter<ServerAddr>;
1687    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1688        self.iter()
1689            .map(AsRef::as_ref)
1690            .map(str::parse)
1691            .collect::<io::Result<_>>()
1692            .map(Vec::into_iter)
1693    }
1694}
1695
1696impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1697    type Iter = std::vec::IntoIter<ServerAddr>;
1698    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1699        self.as_slice().to_server_addrs()
1700    }
1701}
1702
1703impl<'a> ToServerAddrs for &'a [ServerAddr] {
1704    type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1705
1706    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1707        Ok(self.iter().cloned())
1708    }
1709}
1710
1711impl ToServerAddrs for Vec<ServerAddr> {
1712    type Iter = std::vec::IntoIter<ServerAddr>;
1713
1714    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1715        Ok(self.clone().into_iter())
1716    }
1717}
1718
1719impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1720    type Iter = T::Iter;
1721    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1722        (**self).to_server_addrs()
1723    }
1724}
1725
1726pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
1727    let subject_str = subject.as_ref();
1728    !subject_str.starts_with('.')
1729        && !subject_str.ends_with('.')
1730        && subject_str.bytes().all(|c| !c.is_ascii_whitespace())
1731}
1732macro_rules! from_with_timeout {
1733    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
1734        impl From<$origin> for $t {
1735            fn from(err: $origin) -> Self {
1736                match err.kind() {
1737                    <$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
1738                    _ => Self::with_source(<$k>::Other, err),
1739                }
1740            }
1741        }
1742    };
1743}
1744pub(crate) use from_with_timeout;
1745
1746use crate::connection::ShouldFlush;
1747
#[cfg(test)]
mod tests {
    use super::*;

    // `host()` of a bracketed IPv6 URL is expected without the brackets.
    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    // The following four tests cover each supported container shape of
    // `ToServerAddrs`: `Vec<&str>`, `[&str; N]`, `Vec<String>` and `[String; N]`.
    // Each must yield the addresses in input order and then stop.
    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    /// Every URL in each group must resolve to that group's expected port:
    /// - `nats`/`tls` without a port yield 4222.
    /// - `ws` yields 80.
    /// - `wss` yields 443.
    ///
    /// This holds whether the port is written explicitly or left implicit.
    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            let addrs: Vec<_> = arr.to_server_addrs().unwrap().collect();
            // Make sure no entry was dropped, otherwise the zip below would silently
            // skip the missing URLs.
            assert_eq!(addrs.len(), arr.len());
            // Previously only the first address of each group was checked; verify all of
            // them so the explicit and implicit port forms are both covered.
            for (url, addr) in arr.iter().zip(&addrs) {
                assert_eq!(addr.port(), expected_port, "unexpected port for {url}");
            }
        }
    }
}