async_nats/
lib.rs

1// Copyright 2020-2022 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! A Rust asynchronous client for the NATS.io ecosystem.
15//!
16//! To access the repository, you can clone it by running:
17//!
18//! ```bash
19//! git clone https://github.com/nats-io/nats.rs
//! ```
//!
21//! NATS.io is a simple, secure, and high-performance open-source messaging
22//! system designed for cloud-native applications, IoT messaging, and microservices
23//! architectures.
24//!
25//! **Note**: The synchronous NATS API is deprecated and no longer actively maintained. If you need to use the deprecated synchronous API, you can refer to:
26//! <https://crates.io/crates/nats>
27//!
28//! For more information on NATS.io visit: <https://nats.io>
29//!
30//! ## Examples
31//!
32//! Below, you can find some basic examples on how to use this library.
33//!
34//! For more details, please refer to the specific methods and structures documentation.
35//!
36//! ### Complete example
37//!
38//! Connect to the NATS server, publish messages and subscribe to receive messages.
39//!
40//! ```no_run
41//! use bytes::Bytes;
42//! use futures::StreamExt;
43//!
44//! #[tokio::main]
45//! async fn main() -> Result<(), async_nats::Error> {
46//!     // Connect to the NATS server
47//!     let client = async_nats::connect("demo.nats.io").await?;
48//!
49//!     // Subscribe to the "messages" subject
50//!     let mut subscriber = client.subscribe("messages").await?;
51//!
52//!     // Publish messages to the "messages" subject
53//!     for _ in 0..10 {
54//!         client.publish("messages", "data".into()).await?;
55//!     }
56//!
57//!     // Receive and process messages
58//!     while let Some(message) = subscriber.next().await {
59//!         println!("Received message {:?}", message);
60//!     }
61//!
62//!     Ok(())
63//! }
64//! ```
65//!
66//! ### Publish
67//!
68//! Connect to the NATS server and publish messages to a subject.
69//!
70//! ```
71//! # use bytes::Bytes;
72//! # use std::error::Error;
73//! # use std::time::Instant;
74//! # #[tokio::main]
75//! # async fn main() -> Result<(), async_nats::Error> {
76//! // Connect to the NATS server
77//! let client = async_nats::connect("demo.nats.io").await?;
78//!
79//! // Prepare the subject and data
80//! let subject = "foo";
81//! let data = Bytes::from("bar");
82//!
83//! // Publish messages to the NATS server
84//! for _ in 0..10 {
85//!     client.publish(subject, data.clone()).await?;
86//! }
87//!
88//! // Flush internal buffer before exiting to make sure all messages are sent
89//! client.flush().await?;
90//!
91//! #    Ok(())
92//! # }
93//! ```
94//!
95//! ### Subscribe
96//!
97//! Connect to the NATS server, subscribe to a subject and receive messages.
98//!
99//! ```no_run
100//! # use bytes::Bytes;
101//! # use futures::StreamExt;
102//! # use std::error::Error;
103//! # use std::time::Instant;
104//! # #[tokio::main]
105//! # async fn main() -> Result<(), async_nats::Error> {
106//! // Connect to the NATS server
107//! let client = async_nats::connect("demo.nats.io").await?;
108//!
109//! // Subscribe to the "foo" subject
//! let mut subscriber = client.subscribe("foo").await?;
111//!
112//! // Receive and process messages
113//! while let Some(message) = subscriber.next().await {
114//!     println!("Received message {:?}", message);
115//! }
116//! #     Ok(())
117//! # }
118//! ```
119//!
120//! ### JetStream
121//!
//! To access the JetStream API, create a JetStream [jetstream::Context].
123//!
124//! ```no_run
125//! # #[tokio::main]
126//! # async fn main() -> Result<(), async_nats::Error> {
127//! // Connect to the NATS server
128//! let client = async_nats::connect("demo.nats.io").await?;
129//! // Create a JetStream context.
130//! let jetstream = async_nats::jetstream::new(client);
131//!
132//! // Publish JetStream messages, manage streams, consumers, etc.
133//! jetstream.publish("foo", "bar".into()).await?;
134//! # Ok(())
135//! # }
136//! ```
137//!
138//! ### Key-value Store
139//!
140//! Key-value [Store][jetstream::kv::Store] is accessed through [jetstream::Context].
141//!
142//! ```no_run
143//! # #[tokio::main]
144//! # async fn main() -> Result<(), async_nats::Error> {
145//! // Connect to the NATS server
146//! let client = async_nats::connect("demo.nats.io").await?;
147//! // Create a JetStream context.
148//! let jetstream = async_nats::jetstream::new(client);
149//! // Access an existing key-value.
150//! let kv = jetstream.get_key_value("store").await?;
151//! # Ok(())
152//! # }
153//! ```
//! ### Object Store
155//!
156//! Object [Store][jetstream::object_store::ObjectStore] is accessed through [jetstream::Context].
157//!
158//! ```no_run
159//! # #[tokio::main]
160//! # async fn main() -> Result<(), async_nats::Error> {
161//! // Connect to the NATS server
162//! let client = async_nats::connect("demo.nats.io").await?;
163//! // Create a JetStream context.
164//! let jetstream = async_nats::jetstream::new(client);
//! // Access an existing object store.
//! let object_store = jetstream.get_object_store("store").await?;
167//! # Ok(())
168//! # }
169//! ```
170//! ### Service API
171//!
//! The [Service API][service::Service] is accessible through [Client] after importing the `ServiceExt` trait.
173//!
174//! ```no_run
175//! # #[tokio::main]
176//! # async fn main() -> Result<(), async_nats::Error> {
177//! use async_nats::service::ServiceExt;
178//! // Connect to the NATS server
179//! let client = async_nats::connect("demo.nats.io").await?;
180//! let mut service = client
181//!     .service_builder()
182//!     .description("some service")
183//!     .stats_handler(|endpoint, stats| serde_json::json!({ "endpoint": endpoint }))
184//!     .start("products", "1.0.0")
185//!     .await?;
186//! # Ok(())
187//! # }
188//! ```
189
190#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_auto_cfg))]
196
197use thiserror::Error;
198
199use futures::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::fmt::Display;
207use std::future::Future;
208use std::iter;
209use std::mem;
210use std::net::SocketAddr;
211use std::option;
212use std::pin::Pin;
213use std::slice;
214use std::str::{self, FromStr};
215use std::sync::atomic::AtomicUsize;
216use std::sync::atomic::Ordering;
217use std::sync::Arc;
218use std::task::{Context, Poll};
219use tokio::io::ErrorKind;
220use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
221use url::{Host, Url};
222
223use bytes::Bytes;
224use serde::{Deserialize, Serialize};
225use serde_repr::{Deserialize_repr, Serialize_repr};
226use tokio::io;
227use tokio::sync::mpsc;
228use tokio::task;
229
/// A boxed, type-erased error that is `Send + Sync + 'static`.
///
/// The crate-level examples use it as the error type returned by `main`.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Version of this crate, read from Cargo at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Language identifier of this client implementation.
const LANG: &str = "rust";
/// Maximum number of PINGs that may be outstanding without a PONG; on the
/// ping tick that would exceed it, the connection is treated as disconnected
/// instead of sending another PING.
const MAX_PENDING_PINGS: usize = 2;
/// Subscription id reserved for the shared request/response multiplexer
/// subscription.
const MULTIPLEXER_SID: u64 = 0;
236
/// A re-export of the `rustls` crate used in this crate,
/// for use in cases where a manual client configuration
/// must be provided using `ConnectOptions::tls_client_config`.
240pub use tokio_rustls::rustls;
241
242use connection::{Connection, State};
243use connector::{Connector, ConnectorOptions};
244pub use header::{HeaderMap, HeaderName, HeaderValue};
245pub use subject::Subject;
246
247mod auth;
248pub(crate) mod auth_utils;
249pub mod client;
250pub mod connection;
251mod connector;
252mod options;
253
254pub use auth::Auth;
255pub use client::{
256    Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
257};
258pub use options::{AuthError, ConnectOptions};
259
260mod crypto;
261pub mod error;
262pub mod header;
263pub mod jetstream;
264pub mod message;
265#[cfg(feature = "service")]
266pub mod service;
267pub mod status;
268pub mod subject;
269mod tls;
270
271pub use message::Message;
272pub use status::StatusCode;
273
274/// Information sent by the server back to this client
275/// during initial connection, and possibly again later.
276#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
277pub struct ServerInfo {
278    /// The unique identifier of the NATS server.
279    #[serde(default)]
280    pub server_id: String,
281    /// Generated Server Name.
282    #[serde(default)]
283    pub server_name: String,
284    /// The host specified in the cluster parameter/options.
285    #[serde(default)]
286    pub host: String,
287    /// The port number specified in the cluster parameter/options.
288    #[serde(default)]
289    pub port: u16,
290    /// The version of the NATS server.
291    #[serde(default)]
292    pub version: String,
293    /// If this is set, then the server should try to authenticate upon
294    /// connect.
295    #[serde(default)]
296    pub auth_required: bool,
297    /// If this is set, then the server must authenticate using TLS.
298    #[serde(default)]
299    pub tls_required: bool,
300    /// Maximum payload size that the server will accept.
301    #[serde(default)]
302    pub max_payload: usize,
303    /// The protocol version in use.
304    #[serde(default)]
305    pub proto: i8,
306    /// The server-assigned client ID. This may change during reconnection.
307    #[serde(default)]
308    pub client_id: u64,
309    /// The version of golang the NATS server was built with.
310    #[serde(default)]
311    pub go: String,
312    /// The nonce used for nkeys.
313    #[serde(default)]
314    pub nonce: String,
315    /// A list of server urls that a client can connect to.
316    #[serde(default)]
317    pub connect_urls: Vec<String>,
318    /// The client IP as known by the server.
319    #[serde(default)]
320    pub client_ip: String,
321    /// Whether the server supports headers.
322    #[serde(default)]
323    pub headers: bool,
324    /// Whether server goes into lame duck mode.
325    #[serde(default, rename = "ldm")]
326    pub lame_duck_mode: bool,
327}
328
/// An operation received from the server, dispatched by
/// `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// Acknowledgement from the server; currently ignored by the handler.
    Ok,
    /// Server information; a set `lame_duck_mode` flag raises `Event::LameDuckMode`.
    Info(Box<ServerInfo>),
    /// PING from the server; answered with `ClientOp::Pong`.
    Ping,
    /// PONG answering one of our PINGs; decrements the pending-ping counter.
    Pong,
    /// Error reported by the server; forwarded as `Event::ServerError`.
    Error(ServerError),
    /// A message delivered for subscription `sid`, or for the request
    /// multiplexer when `sid == MULTIPLEXER_SID`.
    Message {
        sid: u64,
        subject: Subject,
        /// Reply subject attached to the message, if any.
        reply: Option<Subject>,
        payload: Bytes,
        headers: Option<HeaderMap>,
        status: Option<StatusCode>,
        description: Option<String>,
        /// Copied verbatim into `Message::length`.
        /// NOTE(review): presumably the message's size as received — confirm.
        length: usize,
    },
}
347
/// `PublishMessage` represents a message being published
#[derive(Debug)]
pub struct PublishMessage {
    /// Subject the message is published to.
    pub subject: Subject,
    /// Message body.
    pub payload: Bytes,
    /// Optional reply subject sent along with the message.
    pub reply: Option<Subject>,
    /// Optional message headers.
    pub headers: Option<HeaderMap>,
}
356
/// `Command` represents all commands that a [`Client`] can handle
///
/// Commands are consumed by `ConnectionHandler::handle_command`.
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a message.
    Publish(PublishMessage),
    /// Publish `payload` to `subject` and route the reply to `sender` through
    /// the shared response multiplexer.
    ///
    /// The last `.`-separated token of `respond` identifies the request; the
    /// handler panics if `respond` contains no `.`.
    Request {
        subject: Subject,
        payload: Bytes,
        respond: Subject,
        headers: Option<HeaderMap>,
        sender: oneshot::Sender<Message>,
    },
    /// Register a subscription under `sid`; its messages are sent to `sender`.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
        sender: mpsc::Sender<Message>,
    },
    /// Unsubscribe `sid`: immediately when `max` is `None`, otherwise once
    /// `max` messages in total have been delivered.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Flush the connection; `observer` is notified once a flush completes.
    Flush {
        observer: oneshot::Sender<()>,
    },
    /// Drain subscription `sid`, or the whole client when `sid` is `None`.
    Drain {
        sid: Option<u64>,
    },
    /// Request a reconnect.
    /// NOTE(review): handled outside this view; presumably sets
    /// `ConnectionHandler::should_reconnect` — confirm.
    Reconnect,
}
386
/// `ClientOp` represents all actions of `Client`.
///
/// These are the protocol operations written to the server; they are queued
/// with `Connection::enqueue_write_op`.
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// Publish a message, optionally with a reply subject (`respond`) and headers.
    Publish {
        subject: Subject,
        payload: Bytes,
        respond: Option<Subject>,
        headers: Option<HeaderMap>,
    },
    /// Subscribe to `subject` under `sid`, optionally in a queue group.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
    },
    /// Unsubscribe `sid`, either immediately (`max: None`) or after `max` messages.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Keep-alive PING sent on ping-interval ticks.
    Ping,
    /// Reply to a PING from the server.
    Pong,
    /// Connect operation carrying the client's `ConnectInfo`.
    Connect(ConnectInfo),
}
409
/// Client-side state of a subscription, stored by sid in
/// `ConnectionHandler::subscriptions`.
#[derive(Debug)]
struct Subscription {
    /// Subject subscribed to.
    subject: Subject,
    /// Channel messages are delivered to; once it is closed the subscription
    /// is removed and unsubscribed.
    sender: mpsc::Sender<Message>,
    /// Optional queue group.
    queue_group: Option<String>,
    /// Number of messages successfully handed to `sender`.
    delivered: u64,
    /// When set, the subscription is removed once `delivered` reaches it.
    max: Option<u64>,
    /// Set by a drain; draining subscriptions are dropped on the next poll of
    /// the connection loop.
    is_draining: bool,
}
419
/// A single wildcard subscription shared by all in-flight requests. Replies
/// are routed to the waiting request by the token that follows `prefix`.
#[derive(Debug)]
struct Multiplexer {
    /// Wildcard subject (`{prefix}*`), subscribed under `MULTIPLEXER_SID`.
    subject: Subject,
    /// Reply-subject prefix ending in `.`; each request replies to `{prefix}{token}`.
    prefix: Subject,
    /// Pending requests: reply token -> channel awaiting the response.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
426
/// A connection handler which facilitates communication from channels to a single shared connection.
pub(crate) struct ConnectionHandler {
    /// Underlying connection to the server.
    connection: Connection,
    /// Provides the event channel (`events_tx`) and statistics (`connect_stats`).
    /// NOTE(review): presumably also used by `handle_disconnect` to reconnect — confirm.
    connector: Connector,
    /// Active subscriptions, keyed by sid.
    subscriptions: HashMap<u64, Subscription>,
    /// Request/response multiplexer, created by the first request.
    multiplexer: Option<Multiplexer>,
    /// PINGs issued that have not yet been answered by a PONG.
    pending_pings: usize,
    /// Sender for published server info.
    /// NOTE(review): not written in the code visible here — confirm its users.
    info_sender: tokio::sync::watch::Sender<ServerInfo>,
    /// Each tick issues a PING; it is reset on every server op and command.
    ping_interval: Interval,
    /// When set, `process` shuts the stream down and reconnects (the flag is cleared).
    should_reconnect: bool,
    /// Notified after the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    /// Set when the whole client is draining; the connection loop then exits
    /// with `Closed` on its next poll.
    is_draining: bool,
}
440
441impl ConnectionHandler {
442    pub(crate) fn new(
443        connection: Connection,
444        connector: Connector,
445        info_sender: tokio::sync::watch::Sender<ServerInfo>,
446        ping_period: Duration,
447    ) -> ConnectionHandler {
448        let mut ping_interval = interval(ping_period);
449        ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
450
451        ConnectionHandler {
452            connection,
453            connector,
454            subscriptions: HashMap::new(),
455            multiplexer: None,
456            pending_pings: 0,
457            info_sender,
458            ping_interval,
459            should_reconnect: false,
460            flush_observers: Vec::new(),
461            is_draining: false,
462        }
463    }
464
    /// Drives the connection until the client is closed.
    ///
    /// Repeatedly awaits an internal future (`ProcessFut`) that sends PINGs,
    /// reads and dispatches server ops, forwards commands from `receiver` and
    /// writes/flushes the connection. What happens next depends on why that
    /// future finished:
    ///
    /// * disconnected, or reconnect requested: `handle_disconnect` is awaited
    ///   (after shutting the stream down, for a requested reconnect); the loop
    ///   ends if it fails, otherwise processing resumes;
    /// * closed (`receiver` closed, or a client-wide drain completed):
    ///   `Event::Closed` is sent on a best-effort basis and the loop ends.
    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
        // One run of the connection loop, borrowing the handler, the command
        // receiver and a reusable command buffer.
        struct ProcessFut<'a> {
            handler: &'a mut ConnectionHandler,
            receiver: &'a mut mpsc::Receiver<Command>,
            // Scratch buffer for `poll_recv_many`; empty between batches.
            recv_buf: &'a mut Vec<Command>,
        }

        // Why a `ProcessFut` finished.
        enum ExitReason {
            // The connection is considered dead (with the I/O error, if any).
            Disconnected(Option<io::Error>),
            // `should_reconnect` was set.
            ReconnectRequested,
            // The command channel closed or a client-wide drain completed.
            Closed,
        }

        impl ProcessFut<'_> {
            // Maximum number of commands taken per `poll_recv_many` call.
            const RECV_CHUNK_SIZE: usize = 16;

            // Called on each ping-interval tick: counts one more pending PING
            // and either queues it, or reports a disconnect when more than
            // `MAX_PENDING_PINGS` would be outstanding.
            #[cold]
            fn ping(&mut self) -> Poll<ExitReason> {
                self.handler.pending_pings += 1;

                if self.handler.pending_pings > MAX_PENDING_PINGS {
                    debug!(
                        pending_pings = self.handler.pending_pings,
                        max_pings = MAX_PENDING_PINGS,
                        "disconnecting due to too many pending pings"
                    );

                    Poll::Ready(ExitReason::Disconnected(None))
                } else {
                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);

                    Poll::Pending
                }
            }
        }

        impl Future for ProcessFut<'_> {
            type Output = ExitReason;

            /// Drives the connection forward.
            ///
            /// Returns one of the following:
            ///
            /// * `Poll::Pending` means that the connection
            ///   is blocked on all fronts or there are
            ///   no commands to send or receive
            /// * `Poll::Ready(ExitReason::Disconnected(_))` means
            ///   that an I/O operation failed, `poll_read_op`
            ///   yielded no further op, or too many PINGs are
            ///   pending, so the connection is considered dead.
            /// * `Poll::Ready(ExitReason::ReconnectRequested)` means
            ///   that `should_reconnect` was set (it is cleared here).
            /// * `Poll::Ready(ExitReason::Closed)` means that
            ///   [`Self::receiver`] was closed, or that a client-wide
            ///   drain has completed, so there's nothing
            ///   more for us to do than to exit the client.
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // We need to be sure the waker is registered, therefore we need to poll until we
                // get a `Poll::Pending`. With a sane interval delay, this means that the loop
                // breaks at the second iteration.
                while self.handler.ping_interval.poll_tick(cx).is_ready() {
                    if let Poll::Ready(exit) = self.ping() {
                        return Poll::Ready(exit);
                    }
                }

                // Read and dispatch every server op that is already available.
                loop {
                    match self.handler.connection.poll_read_op(cx) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(Some(server_op))) => {
                            self.handler.handle_server_op(server_op);
                        }
                        Poll::Ready(Ok(None)) => {
                            return Poll::Ready(ExitReason::Disconnected(None))
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Before handling any commands, drop any subscriptions which are draining
                // Note: safe to assume subscription drain has completed at this point, as we would have flushed
                // all outgoing UNSUB messages in the previous call to this fn, and we would have processed and
                // delivered any remaining messages to the subscription in the loop above.
                self.handler.subscriptions.retain(|_, s| !s.is_draining);

                if self.handler.is_draining {
                    // The entire connection is draining. This means we flushed outgoing messages in the previous
                    // call to this fn, we handled any remaining messages from the server in the loop above, and
                    // all subs were drained, so drain is complete and we should exit instead of processing any
                    // further messages
                    return Poll::Ready(ExitReason::Closed);
                }

                // WARNING: after the following loop `handle_command`,
                // or other functions which call `enqueue_write_op`,
                // cannot be called anymore. Runtime wakeups won't
                // trigger a call to `poll_write`

                // Starts as `true` so the first round always attempts a write.
                let mut made_progress = true;
                loop {
                    // Pull commands in batches while the write buffer has room.
                    while !self.handler.connection.is_write_buf_full() {
                        debug_assert!(self.recv_buf.is_empty());

                        // Split the borrow of `self` so the receiver, buffer and
                        // handler can be used simultaneously.
                        let Self {
                            recv_buf,
                            handler,
                            receiver,
                        } = &mut *self;
                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
                            Poll::Pending => break,
                            Poll::Ready(1..) => {
                                made_progress = true;

                                for cmd in recv_buf.drain(..) {
                                    handler.handle_command(cmd);
                                }
                            }
                            // TODO: replace `_` with `0` after bumping MSRV to 1.75
                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
                        }
                    }

                    // The first round will poll both from
                    // the `receiver` and the writer, giving
                    // them both a chance to make progress
                    // and register `Waker`s.
                    //
                    // If writing is `Poll::Pending` we exit.
                    //
                    // If writing is completed we can repeat the entire
                    // cycle as long as the `receiver` doesn't end-up
                    // `Poll::Pending` immediately.
                    if !mem::take(&mut made_progress) {
                        break;
                    }

                    match self.handler.connection.poll_write(cx) {
                        Poll::Pending => {
                            // Write buffer couldn't be fully emptied
                            break;
                        }
                        Poll::Ready(Ok(())) => {
                            // Write buffer is empty
                            continue;
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Flush when the connection asks for it, or when it reports that
                // no flush is needed but someone is waiting on one (so their
                // observers are still notified). NOTE(review): any other
                // `ShouldFlush` state defers flushing to a later poll.
                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
                    self.handler.connection.should_flush(),
                    self.handler.flush_observers.is_empty(),
                ) {
                    match self.handler.connection.poll_flush(cx) {
                        Poll::Pending => {}
                        Poll::Ready(Ok(())) => {
                            // Observers may have stopped waiting; a failed send is fine.
                            for observer in self.handler.flush_observers.drain(..) {
                                let _ = observer.send(());
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                if mem::take(&mut self.handler.should_reconnect) {
                    return Poll::Ready(ExitReason::ReconnectRequested);
                }

                Poll::Pending
            }
        }

        // Reused across runs so command batches don't reallocate.
        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
        loop {
            let process = ProcessFut {
                handler: self,
                receiver,
                recv_buf: &mut recv_buf,
            };
            match process.await {
                ExitReason::Disconnected(err) => {
                    debug!(error = ?err, "disconnected");
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                    debug!("reconnected");
                }
                ExitReason::Closed => {
                    // Safe to ignore result as we're shutting down anyway
                    self.connector.events_tx.try_send(Event::Closed).ok();
                    break;
                }
                ExitReason::ReconnectRequested => {
                    debug!("reconnect requested");
                    // Should be ok to ignore error, as that means we are not in connected state.
                    self.connection.stream.shutdown().await.ok();
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                }
            }
        }
    }
670
671    fn handle_server_op(&mut self, server_op: ServerOp) {
672        self.ping_interval.reset();
673
674        match server_op {
675            ServerOp::Ping => {
676                debug!("received PING");
677                self.connection.enqueue_write_op(&ClientOp::Pong);
678            }
679            ServerOp::Pong => {
680                debug!("received PONG");
681                self.pending_pings = self.pending_pings.saturating_sub(1);
682            }
683            ServerOp::Error(error) => {
684                debug!("received ERROR: {:?}", error);
685                self.connector
686                    .events_tx
687                    .try_send(Event::ServerError(error))
688                    .ok();
689            }
690            ServerOp::Message {
691                sid,
692                subject,
693                reply,
694                payload,
695                headers,
696                status,
697                description,
698                length,
699            } => {
700                debug!("received MESSAGE: sid={}, subject={}", sid, subject);
701                self.connector
702                    .connect_stats
703                    .in_messages
704                    .add(1, Ordering::Relaxed);
705
706                if let Some(subscription) = self.subscriptions.get_mut(&sid) {
707                    let message: Message = Message {
708                        subject,
709                        reply,
710                        payload,
711                        headers,
712                        status,
713                        description,
714                        length,
715                    };
716
717                    // if the channel for subscription was dropped, remove the
718                    // subscription from the map and unsubscribe.
719                    match subscription.sender.try_send(message) {
720                        Ok(_) => {
721                            subscription.delivered += 1;
722                            // if this `Subscription` has set `max` value, check if it
723                            // was reached. If yes, remove the `Subscription` and in
724                            // the result, `drop` the `sender` channel.
725                            if let Some(max) = subscription.max {
726                                if subscription.delivered.ge(&max) {
727                                    debug!("max messages reached for subscription {}", sid);
728                                    self.subscriptions.remove(&sid);
729                                }
730                            }
731                        }
732                        Err(mpsc::error::TrySendError::Full(_)) => {
733                            debug!("slow consumer detected for subscription {}", sid);
734                            self.connector
735                                .events_tx
736                                .try_send(Event::SlowConsumer(sid))
737                                .ok();
738                        }
739                        Err(mpsc::error::TrySendError::Closed(_)) => {
740                            debug!("subscription {} channel closed", sid);
741                            self.subscriptions.remove(&sid);
742                            self.connection
743                                .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
744                        }
745                    }
746                } else if sid == MULTIPLEXER_SID {
747                    debug!("received message for multiplexer");
748                    if let Some(multiplexer) = self.multiplexer.as_mut() {
749                        let maybe_token =
750                            subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
751
752                        if let Some(token) = maybe_token {
753                            if let Some(sender) = multiplexer.senders.remove(token) {
754                                debug!("forwarding message to request with token {}", token);
755                                let message = Message {
756                                    subject,
757                                    reply,
758                                    payload,
759                                    headers,
760                                    status,
761                                    description,
762                                    length,
763                                };
764
765                                let _ = sender.send(message);
766                            }
767                        }
768                    }
769                }
770            }
771            // TODO: we should probably update advertised server list here too.
772            ServerOp::Info(info) => {
773                debug!("received INFO: server_id={}", info.server_id);
774                if info.lame_duck_mode {
775                    debug!("server in lame duck mode");
776                    self.connector.events_tx.try_send(Event::LameDuckMode).ok();
777                }
778            }
779
780            _ => {
781                // TODO: don't ignore.
782            }
783        }
784    }
785
786    fn handle_command(&mut self, command: Command) {
787        self.ping_interval.reset();
788
789        match command {
790            Command::Unsubscribe { sid, max } => {
791                if let Some(subscription) = self.subscriptions.get_mut(&sid) {
792                    subscription.max = max;
793                    match subscription.max {
794                        Some(n) => {
795                            if subscription.delivered >= n {
796                                self.subscriptions.remove(&sid);
797                            }
798                        }
799                        None => {
800                            self.subscriptions.remove(&sid);
801                        }
802                    }
803
804                    self.connection
805                        .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
806                }
807            }
808            Command::Flush { observer } => {
809                self.flush_observers.push(observer);
810            }
811            Command::Drain { sid } => {
812                let mut drain_sub = |sid: u64, sub: &mut Subscription| {
813                    sub.is_draining = true;
814                    self.connection
815                        .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
816                };
817
818                if let Some(sid) = sid {
819                    if let Some(sub) = self.subscriptions.get_mut(&sid) {
820                        drain_sub(sid, sub);
821                    }
822                } else {
823                    // sid isn't set, so drain the whole client
824                    self.connector.events_tx.try_send(Event::Draining).ok();
825                    self.is_draining = true;
826                    for (&sid, sub) in self.subscriptions.iter_mut() {
827                        drain_sub(sid, sub);
828                    }
829                }
830            }
831            Command::Subscribe {
832                sid,
833                subject,
834                queue_group,
835                sender,
836            } => {
837                let subscription = Subscription {
838                    sender,
839                    delivered: 0,
840                    max: None,
841                    subject: subject.to_owned(),
842                    queue_group: queue_group.to_owned(),
843                    is_draining: false,
844                };
845
846                self.subscriptions.insert(sid, subscription);
847
848                self.connection.enqueue_write_op(&ClientOp::Subscribe {
849                    sid,
850                    subject,
851                    queue_group,
852                });
853            }
854            Command::Request {
855                subject,
856                payload,
857                respond,
858                headers,
859                sender,
860            } => {
861                let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
862
863                let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
864                    multiplexer
865                } else {
866                    let prefix = Subject::from(format!("{}.{}.", prefix, nuid::next()));
867                    let subject = Subject::from(format!("{}*", prefix));
868
869                    self.connection.enqueue_write_op(&ClientOp::Subscribe {
870                        sid: MULTIPLEXER_SID,
871                        subject: subject.clone(),
872                        queue_group: None,
873                    });
874
875                    self.multiplexer.insert(Multiplexer {
876                        subject,
877                        prefix,
878                        senders: HashMap::new(),
879                    })
880                };
881                self.connector
882                    .connect_stats
883                    .out_messages
884                    .add(1, Ordering::Relaxed);
885
886                multiplexer.senders.insert(token.to_owned(), sender);
887
888                let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
889
890                let pub_op = ClientOp::Publish {
891                    subject,
892                    payload,
893                    respond: Some(respond),
894                    headers,
895                };
896
897                self.connection.enqueue_write_op(&pub_op);
898            }
899
900            Command::Publish(PublishMessage {
901                subject,
902                payload,
903                reply: respond,
904                headers,
905            }) => {
906                self.connector
907                    .connect_stats
908                    .out_messages
909                    .add(1, Ordering::Relaxed);
910
911                let header_len = headers
912                    .as_ref()
913                    .map(|headers| headers.len())
914                    .unwrap_or_default();
915
916                self.connector.connect_stats.out_bytes.add(
917                    (payload.len()
918                        + respond.as_ref().map_or_else(|| 0, |r| r.len())
919                        + subject.len()
920                        + header_len) as u64,
921                    Ordering::Relaxed,
922                );
923
924                self.connection.enqueue_write_op(&ClientOp::Publish {
925                    subject,
926                    payload,
927                    respond,
928                    headers,
929                });
930            }
931
932            Command::Reconnect => {
933                self.should_reconnect = true;
934            }
935        }
936    }
937
938    async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
939        self.pending_pings = 0;
940        self.connector.events_tx.try_send(Event::Disconnected).ok();
941        self.connector.state_tx.send(State::Disconnected).ok();
942
943        self.handle_reconnect().await
944    }
945
946    async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
947        let (info, connection) = self.connector.connect().await?;
948        self.connection = connection;
949        let _ = self.info_sender.send(info);
950
951        self.subscriptions
952            .retain(|_, subscription| !subscription.sender.is_closed());
953
954        for (sid, subscription) in &self.subscriptions {
955            self.connection.enqueue_write_op(&ClientOp::Subscribe {
956                sid: *sid,
957                subject: subscription.subject.to_owned(),
958                queue_group: subscription.queue_group.to_owned(),
959            });
960        }
961
962        if let Some(multiplexer) = &self.multiplexer {
963            self.connection.enqueue_write_op(&ClientOp::Subscribe {
964                sid: MULTIPLEXER_SID,
965                subject: multiplexer.subject.to_owned(),
966                queue_group: None,
967            });
968        }
969        Ok(())
970    }
971}
972
973/// Connects to NATS with specified options.
974///
975/// It is generally advised to use [ConnectOptions] instead, as it provides a builder for whole
976/// configuration.
977///
978/// # Examples
979/// ```
980/// # #[tokio::main]
981/// # async fn main() ->  Result<(), async_nats::Error> {
982/// let mut nc =
983///     async_nats::connect_with_options("demo.nats.io", async_nats::ConnectOptions::new()).await?;
984/// nc.publish("test", "data".into()).await?;
985/// # Ok(())
986/// # }
987/// ```
988pub async fn connect_with_options<A: ToServerAddrs>(
989    addrs: A,
990    options: ConnectOptions,
991) -> Result<Client, ConnectError> {
992    let ping_period = options.ping_interval;
993
994    let (events_tx, mut events_rx) = mpsc::channel(128);
995    let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
996    // We're setting it to the default server payload size.
997    let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
998    let statistics = Arc::new(Statistics::default());
999
1000    let mut connector = Connector::new(
1001        addrs,
1002        ConnectorOptions {
1003            tls_required: options.tls_required,
1004            certificates: options.certificates,
1005            client_key: options.client_key,
1006            client_cert: options.client_cert,
1007            tls_client_config: options.tls_client_config,
1008            tls_first: options.tls_first,
1009            auth: options.auth,
1010            no_echo: options.no_echo,
1011            connection_timeout: options.connection_timeout,
1012            name: options.name,
1013            ignore_discovered_servers: options.ignore_discovered_servers,
1014            retain_servers_order: options.retain_servers_order,
1015            read_buffer_capacity: options.read_buffer_capacity,
1016            reconnect_delay_callback: options.reconnect_delay_callback,
1017            auth_callback: options.auth_callback,
1018            max_reconnects: options.max_reconnects,
1019        },
1020        events_tx,
1021        state_tx,
1022        max_payload.clone(),
1023        statistics.clone(),
1024    )
1025    .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1026
1027    let mut info: ServerInfo = Default::default();
1028    let mut connection = None;
1029    if !options.retry_on_initial_connect {
1030        debug!("retry on initial connect failure is disabled");
1031        let (info_ok, connection_ok) = connector.try_connect().await?;
1032        connection = Some(connection_ok);
1033        info = info_ok;
1034    }
1035
1036    let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1037    let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1038
1039    let client = Client::new(
1040        info_watcher,
1041        state_rx,
1042        sender,
1043        options.subscription_capacity,
1044        options.inbox_prefix,
1045        options.request_timeout,
1046        max_payload,
1047        statistics,
1048    );
1049
1050    task::spawn(async move {
1051        while let Some(event) = events_rx.recv().await {
1052            tracing::info!("event: {}", event);
1053            if let Some(event_callback) = &options.event_callback {
1054                event_callback.call(event).await;
1055            }
1056        }
1057    });
1058
1059    task::spawn(async move {
1060        if connection.is_none() && options.retry_on_initial_connect {
1061            let (info, connection_ok) = match connector.connect().await {
1062                Ok((info, connection)) => (info, connection),
1063                Err(err) => {
1064                    error!("connection closed: {}", err);
1065                    return;
1066                }
1067            };
1068            info_sender.send(info).ok();
1069            connection = Some(connection_ok);
1070        }
1071        let connection = connection.unwrap();
1072        let mut connection_handler =
1073            ConnectionHandler::new(connection, connector, info_sender, ping_period);
1074        connection_handler.process(&mut receiver).await
1075    });
1076
1077    Ok(client)
1078}
1079
/// Connection lifecycle and error notifications.
///
/// Events are pushed into an internal channel. In `connect_with_options`, a spawned task
/// receives them, logs each one and forwards it to the configured event callback, if any.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// Displayed as `connected`.
    Connected,
    /// Emitted when the connection handler detects a lost connection, before reconnecting.
    Disconnected,
    /// Displayed as `lame duck mode detected`.
    LameDuckMode,
    /// Emitted when the whole client starts draining (a `Drain` command without a sid).
    Draining,
    /// Displayed as `closed`.
    Closed,
    /// The subscription with the contained sid is a slow consumer.
    SlowConsumer(u64),
    /// An error reported by the server.
    ServerError(ServerError),
    /// An error raised on the client side.
    ClientError(ClientError),
}
1091
1092impl fmt::Display for Event {
1093    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1094        match self {
1095            Event::Connected => write!(f, "connected"),
1096            Event::Disconnected => write!(f, "disconnected"),
1097            Event::LameDuckMode => write!(f, "lame duck mode detected"),
1098            Event::Draining => write!(f, "draining"),
1099            Event::Closed => write!(f, "closed"),
1100            Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1101            Event::ServerError(err) => write!(f, "server error: {err}"),
1102            Event::ClientError(err) => write!(f, "client error: {err}"),
1103        }
1104    }
1105}
1106
1107/// Connects to NATS with default config.
1108///
1109/// Returns cloneable [Client].
1110///
1111/// To have customized NATS connection, check [ConnectOptions].
1112///
1113/// # Examples
1114///
1115/// ## Single URL
1116/// ```
1117/// # #[tokio::main]
1118/// # async fn main() ->  Result<(), async_nats::Error> {
1119/// let mut nc = async_nats::connect("demo.nats.io").await?;
1120/// nc.publish("test", "data".into()).await?;
1121/// # Ok(())
1122/// # }
1123/// ```
1124///
1125/// ## Connect with [Vec] of [ServerAddr].
1126/// ```no_run
1127/// #[tokio::main]
1128/// # async fn main() -> Result<(), async_nats::Error> {
1129/// use async_nats::ServerAddr;
1130/// let client = async_nats::connect(vec![
1131///     "demo.nats.io".parse::<ServerAddr>()?,
1132///     "other.nats.io".parse::<ServerAddr>()?,
1133/// ])
1134/// .await
1135/// .unwrap();
1136/// # Ok(())
1137/// # }
1138/// ```
1139///
1140/// ## with [Vec], but parse URLs inside [crate::connect()]
1141/// ```no_run
1142/// #[tokio::main]
1143/// # async fn main() -> Result<(), async_nats::Error> {
1144/// use async_nats::ServerAddr;
1145/// let servers = vec!["demo.nats.io", "other.nats.io"];
1146/// let client = async_nats::connect(
1147///     servers
1148///         .iter()
1149///         .map(|url| url.parse())
1150///         .collect::<Result<Vec<ServerAddr>, _>>()?,
1151/// )
1152/// .await?;
1153/// # Ok(())
1154/// # }
1155/// ```
1156///
1157///
1158/// ## with slice.
1159/// ```no_run
1160/// #[tokio::main]
1161/// # async fn main() -> Result<(), async_nats::Error> {
1162/// use async_nats::ServerAddr;
1163/// let client = async_nats::connect(
1164///    [
1165///        "demo.nats.io".parse::<ServerAddr>()?,
1166///        "other.nats.io".parse::<ServerAddr>()?,
1167///    ]
1168///    .as_slice(),
1169/// )
1170/// .await?;
1171/// # Ok(())
1172/// # }
1173pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1174    connect_with_options(addrs, ConnectOptions::default()).await
1175}
1176
/// Classifies why establishing a connection failed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// Parsing the passed server address failed.
    ServerParse,
    /// DNS related issues.
    Dns,
    /// Failed authentication process, signing nonce, etc.
    Authentication,
    /// Server returned authorization violation error.
    AuthorizationViolation,
    /// Connect timed out.
    TimedOut,
    /// Erroneous TLS setup.
    Tls,
    /// Other IO error.
    Io,
    /// Reached the maximum number of reconnects.
    MaxReconnects,
}

impl Display for ConnectErrorKind {
    /// Writes a fixed, lowercase description of the failure category.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
1211
1212/// Returned when initial connection fails.
1213/// To be enumerate over the variants, call [ConnectError::kind].
1214pub type ConnectError = error::Error<ConnectErrorKind>;
1215
1216impl From<io::Error> for ConnectError {
1217    fn from(err: io::Error) -> Self {
1218        ConnectError::with_source(ConnectErrorKind::Io, err)
1219    }
1220}
1221
/// Retrieves messages from given `subscription` created by [Client::subscribe].
///
/// Implements [futures::stream::Stream] for ergonomic async message processing.
///
/// # Examples
/// ```
/// # #[tokio::main]
/// # async fn main() ->  Result<(), async_nats::Error> {
/// let mut nc = async_nats::connect("demo.nats.io").await?;
/// # nc.publish("test", "data".into()).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct Subscriber {
    // Subscription id; identifies this subscription in `Unsubscribe`/`Drain` commands.
    sid: u64,
    // Incoming messages for this subscription; polled by the `Stream` impl.
    receiver: mpsc::Receiver<Message>,
    // Command channel to the connection handler.
    sender: mpsc::Sender<Command>,
}
1241
1242impl Subscriber {
1243    fn new(
1244        sid: u64,
1245        sender: mpsc::Sender<Command>,
1246        receiver: mpsc::Receiver<Message>,
1247    ) -> Subscriber {
1248        Subscriber {
1249            sid,
1250            sender,
1251            receiver,
1252        }
1253    }
1254
1255    /// Unsubscribes from subscription, draining all remaining messages.
1256    ///
1257    /// # Examples
1258    /// ```
1259    /// # #[tokio::main]
1260    /// # async fn main() -> Result<(), async_nats::Error> {
1261    /// let client = async_nats::connect("demo.nats.io").await?;
1262    ///
1263    /// let mut subscriber = client.subscribe("foo").await?;
1264    ///
1265    /// subscriber.unsubscribe().await?;
1266    /// # Ok(())
1267    /// # }
1268    /// ```
1269    pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1270        self.sender
1271            .send(Command::Unsubscribe {
1272                sid: self.sid,
1273                max: None,
1274            })
1275            .await?;
1276        self.receiver.close();
1277        Ok(())
1278    }
1279
1280    /// Unsubscribes from subscription after reaching given number of messages.
1281    /// This is the total number of messages received by this subscription in it's whole
1282    /// lifespan. If it already reached or surpassed the passed value, it will immediately stop.
1283    ///
1284    /// # Examples
1285    /// ```
1286    /// # use futures::StreamExt;
1287    /// # #[tokio::main]
1288    /// # async fn main() -> Result<(), async_nats::Error> {
1289    /// let client = async_nats::connect("demo.nats.io").await?;
1290    ///
1291    /// let mut subscriber = client.subscribe("test").await?;
1292    /// subscriber.unsubscribe_after(3).await?;
1293    ///
1294    /// for _ in 0..3 {
1295    ///     client.publish("test", "data".into()).await?;
1296    /// }
1297    ///
1298    /// while let Some(message) = subscriber.next().await {
1299    ///     println!("message received: {:?}", message);
1300    /// }
1301    /// println!("no more messages, unsubscribed");
1302    /// # Ok(())
1303    /// # }
1304    /// ```
1305    pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1306        self.sender
1307            .send(Command::Unsubscribe {
1308                sid: self.sid,
1309                max: Some(unsub_after),
1310            })
1311            .await?;
1312        Ok(())
1313    }
1314
1315    /// Unsubscribes immediately but leaves the subscription open to allow any in-flight messages
1316    /// on the subscription to be delivered. The stream will be closed after any remaining messages
1317    /// are delivered
1318    ///
1319    /// # Examples
1320    /// ```no_run
1321    /// # use futures::StreamExt;
1322    /// # #[tokio::main]
1323    /// # async fn main() -> Result<(), async_nats::Error> {
1324    /// let client = async_nats::connect("demo.nats.io").await?;
1325    ///
1326    /// let mut subscriber = client.subscribe("test").await?;
1327    ///
1328    /// tokio::spawn({
1329    ///     let task_client = client.clone();
1330    ///     async move {
1331    ///         loop {
1332    ///             _ = task_client.publish("test", "data".into()).await;
1333    ///         }
1334    ///     }
1335    /// });
1336    ///
1337    /// client.flush().await?;
1338    /// subscriber.drain().await?;
1339    ///
1340    /// while let Some(message) = subscriber.next().await {
1341    ///     println!("message received: {:?}", message);
1342    /// }
1343    /// println!("no more messages, unsubscribed");
1344    /// # Ok(())
1345    /// # }
1346    /// ```
1347    pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1348        self.sender
1349            .send(Command::Drain {
1350                sid: Some(self.sid),
1351            })
1352            .await?;
1353
1354        Ok(())
1355    }
1356}
1357
/// Returned when an unsubscribe or drain command cannot be sent to the connection handler.
///
/// The inner string holds the message of the underlying channel send error.
#[derive(Error, Debug, PartialEq)]
#[error("failed to send unsubscribe")]
pub struct UnsubscribeError(String);
1361
1362impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1363    fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1364        UnsubscribeError(err.to_string())
1365    }
1366}
1367
1368impl Drop for Subscriber {
1369    fn drop(&mut self) {
1370        self.receiver.close();
1371        tokio::spawn({
1372            let sender = self.sender.clone();
1373            let sid = self.sid;
1374            async move {
1375                sender
1376                    .send(Command::Unsubscribe { sid, max: None })
1377                    .await
1378                    .ok();
1379            }
1380        });
1381    }
1382}
1383
1384impl Stream for Subscriber {
1385    type Item = Message;
1386
1387    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1388        self.receiver.poll_recv(cx)
1389    }
1390}
1391
/// Either a client-side [ClientError] or a server-reported [ServerError].
///
/// Displays exactly like the wrapped error.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CallbackError {
    /// Wraps a [ClientError].
    Client(ClientError),
    /// Wraps a [ServerError].
    Server(ServerError),
}
1397impl std::fmt::Display for CallbackError {
1398    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1399        match self {
1400            Self::Client(error) => write!(f, "{error}"),
1401            Self::Server(error) => write!(f, "{error}"),
1402        }
1403    }
1404}
1405
1406impl From<ServerError> for CallbackError {
1407    fn from(server_error: ServerError) -> Self {
1408        CallbackError::Server(server_error)
1409    }
1410}
1411
1412impl From<ClientError> for CallbackError {
1413    fn from(client_error: ClientError) -> Self {
1414        CallbackError::Client(client_error)
1415    }
1416}
1417
/// Error reported by the server; built from the raw error text by `ServerError::new`.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The error text matched `authorization violation` (case-insensitively).
    AuthorizationViolation,
    /// The subscription with the contained sid is a slow consumer.
    SlowConsumer(u64),
    /// Any other error text, preserved verbatim.
    Other(String),
}
1424
/// Error raised on the client side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Free-form client error message.
    Other(String),
    /// The configured maximum number of reconnect attempts was reached.
    MaxReconnects,
}
impl std::fmt::Display for ClientError {
    /// Renders the error with a `nats: ` prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MaxReconnects => f.write_str("nats: max reconnects reached"),
            Self::Other(message) => write!(f, "nats: {}", message),
        }
    }
}
1438
1439impl ServerError {
1440    fn new(error: String) -> ServerError {
1441        match error.to_lowercase().as_str() {
1442            "authorization violation" => ServerError::AuthorizationViolation,
1443            // error messages can contain case-sensitive values which should be preserved
1444            _ => ServerError::Other(error),
1445        }
1446    }
1447}
1448
1449impl std::fmt::Display for ServerError {
1450    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1451        match self {
1452            Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1453            Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1454            Self::Other(error) => write!(f, "nats: {error}"),
1455        }
1456    }
1457}
1458
/// Info to construct a CONNECT message.
///
/// Serialized with serde; field names are used as-is except where a `rename` is given.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Turns on +OK protocol acknowledgments.
    pub verbose: bool,

    /// Turns on additional strict format checking, e.g. for properly formed
    /// subjects.
    pub pedantic: bool,

    /// User's JWT.
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public nkey.
    pub nkey: Option<String>,

    /// Signed nonce, encoded to Base64URL.
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// If set to `false`, the server (version 1.2.0+) will not send originating
    /// messages from this connection to its own subscriptions. Clients should
    /// set this to `false` only for servers supporting this feature, which is
    /// when proto in the INFO protocol is set to at least 1.
    pub echo: bool,

    /// The implementation language of the client.
    pub lang: String,

    /// The version of the client.
    pub version: String,

    /// Sending 0 (or absent) indicates client supports original protocol.
    /// Sending 1 indicates that the client supports dynamic reconfiguration
    /// of cluster topology changes by asynchronously receiving INFO messages
    /// with known servers it can reconnect to.
    pub protocol: Protocol,

    /// Indicates whether the client requires a TLS connection.
    pub tls_required: bool,

    /// Connection username (if `auth_required` is set).
    pub user: Option<String>,

    /// Connection password (if `auth_required` is set).
    pub pass: Option<String>,

    /// Client authorization token (if `auth_required` is set).
    pub auth_token: Option<String>,

    /// Whether the client supports the usage of headers.
    pub headers: bool,

    /// Whether the client supports no_responders.
    pub no_responders: bool,
}
1519
/// Protocol version used by the client.
///
/// Serialized as its numeric `u8` value (via `Serialize_repr`/`Deserialize_repr`).
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Original protocol.
    Original = 0,
    /// Protocol with dynamic reconfiguration of cluster and lame duck mode functionality.
    Dynamic = 1,
}
1529
/// Address of a NATS server.
///
/// Wraps a [`Url`]. When built through [`ServerAddr::from_url`] (which `FromStr` also uses),
/// its scheme is guaranteed to be one of `nats`, `tls`, `ws` or `wss`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1533
1534impl FromStr for ServerAddr {
1535    type Err = io::Error;
1536
1537    /// Parse an address of a NATS server.
1538    ///
1539    /// If not stated explicitly the `nats://` schema and port `4222` is assumed.
1540    fn from_str(input: &str) -> Result<Self, Self::Err> {
1541        let url: Url = if input.contains("://") {
1542            input.parse()
1543        } else {
1544            format!("nats://{input}").parse()
1545        }
1546        .map_err(|e| {
1547            io::Error::new(
1548                ErrorKind::InvalidInput,
1549                format!("NATS server URL is invalid: {e}"),
1550            )
1551        })?;
1552
1553        Self::from_url(url)
1554    }
1555}
1556
1557impl ServerAddr {
1558    /// Check if the URL is a valid NATS server address.
1559    pub fn from_url(url: Url) -> io::Result<Self> {
1560        if url.scheme() != "nats"
1561            && url.scheme() != "tls"
1562            && url.scheme() != "ws"
1563            && url.scheme() != "wss"
1564        {
1565            return Err(std::io::Error::new(
1566                ErrorKind::InvalidInput,
1567                format!("invalid scheme for NATS server URL: {}", url.scheme()),
1568            ));
1569        }
1570
1571        Ok(Self(url))
1572    }
1573
1574    /// Turn the server address into a standard URL.
1575    pub fn into_inner(self) -> Url {
1576        self.0
1577    }
1578
1579    /// Returns if tls is required by the client for this server.
1580    pub fn tls_required(&self) -> bool {
1581        self.0.scheme() == "tls"
1582    }
1583
1584    /// Returns if the server url had embedded username and password.
1585    pub fn has_user_pass(&self) -> bool {
1586        self.0.username() != ""
1587    }
1588
1589    pub fn scheme(&self) -> &str {
1590        self.0.scheme()
1591    }
1592
1593    /// Returns the host.
1594    pub fn host(&self) -> &str {
1595        match self.0.host() {
1596            Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1597            // `host_str()` for Ipv6 includes the []s
1598            Some(Host::Ipv6 { .. }) => {
1599                let host = self.0.host_str().unwrap();
1600                &host[1..host.len() - 1]
1601            }
1602            None => "",
1603        }
1604    }
1605
1606    pub fn is_websocket(&self) -> bool {
1607        self.0.scheme() == "ws" || self.0.scheme() == "wss"
1608    }
1609
1610    /// Returns the port.
1611    /// Delegates to [`Url::port_or_known_default`](https://docs.rs/url/latest/url/struct.Url.html#method.port_or_known_default) and defaults to 4222 if none was explicitly specified in creating this `ServerAddr`.
1612    pub fn port(&self) -> u16 {
1613        self.0.port_or_known_default().unwrap_or(4222)
1614    }
1615
1616    /// Returns the optional username in the url.
1617    pub fn username(&self) -> Option<&str> {
1618        let user = self.0.username();
1619        if user.is_empty() {
1620            None
1621        } else {
1622            Some(user)
1623        }
1624    }
1625
1626    /// Returns the optional password in the url.
1627    pub fn password(&self) -> Option<&str> {
1628        self.0.password()
1629    }
1630
1631    /// Return the sockets from resolving the server address.
1632    pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1633        tokio::net::lookup_host((self.host(), self.port())).await
1634    }
1635}
1636
1637/// Capability to convert into a list of NATS server addresses.
1638///
1639/// There are several implementations ensuring the easy passing of one or more server addresses to
1640/// functions like [`crate::connect()`].
1641pub trait ToServerAddrs {
1642    /// Returned iterator over socket addresses which this type may correspond
1643    /// to.
1644    type Iter: Iterator<Item = ServerAddr>;
1645
1646    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
1647}
1648
1649impl ToServerAddrs for ServerAddr {
1650    type Iter = option::IntoIter<ServerAddr>;
1651    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1652        Ok(Some(self.clone()).into_iter())
1653    }
1654}
1655
1656impl ToServerAddrs for str {
1657    type Iter = option::IntoIter<ServerAddr>;
1658    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1659        self.parse::<ServerAddr>()
1660            .map(|addr| Some(addr).into_iter())
1661    }
1662}
1663
1664impl ToServerAddrs for String {
1665    type Iter = option::IntoIter<ServerAddr>;
1666    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1667        (**self).to_server_addrs()
1668    }
1669}
1670
1671impl<T: AsRef<str>> ToServerAddrs for [T] {
1672    type Iter = std::vec::IntoIter<ServerAddr>;
1673    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1674        self.iter()
1675            .map(AsRef::as_ref)
1676            .map(str::parse)
1677            .collect::<io::Result<_>>()
1678            .map(Vec::into_iter)
1679    }
1680}
1681
1682impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1683    type Iter = std::vec::IntoIter<ServerAddr>;
1684    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1685        self.as_slice().to_server_addrs()
1686    }
1687}
1688
1689impl<'a> ToServerAddrs for &'a [ServerAddr] {
1690    type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1691
1692    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1693        Ok(self.iter().cloned())
1694    }
1695}
1696
1697impl ToServerAddrs for Vec<ServerAddr> {
1698    type Iter = std::vec::IntoIter<ServerAddr>;
1699
1700    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1701        Ok(self.clone().into_iter())
1702    }
1703}
1704
1705impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1706    type Iter = T::Iter;
1707    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1708        (**self).to_server_addrs()
1709    }
1710}
1711
/// Returns `true` if `subject` is a well-formed NATS subject.
///
/// A subject is rejected when it:
/// - is empty,
/// - contains an empty token, i.e. a leading or trailing `.` or two consecutive dots
///   such as `foo..bar`, or
/// - contains ASCII whitespace.
///
/// Wildcard tokens (`*`, `>`) are not inspected here.
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let subject_str = subject.as_ref();
    // `split('.')` yields a single empty token for "" and an empty token for every
    // leading, trailing or doubled dot, so this one check covers all of those cases.
    subject_str.split('.').all(|token| !token.is_empty())
        && subject_str.bytes().all(|c| !c.is_ascii_whitespace())
}
/// Generates `impl From<$origin> for $t` that maps timeouts specially.
///
/// - `$t`: target error type; must provide `new(kind)` and `with_source(kind, err)`.
/// - `$k`: kind enum of `$t`; must have `TimedOut` and `Other` variants.
/// - `$origin`: source error type; must provide `kind()`.
/// - `$origin_kind`: kind type returned by `$origin::kind()`; must have a `TimedOut` variant.
///
/// A source error whose kind is `TimedOut` becomes `$k::TimedOut` without a source; every
/// other kind becomes `$k::Other` with the original error kept as the source.
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                match err.kind() {
                    // Timeouts keep their meaning in the target kind; the source is dropped.
                    <$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
                    _ => Self::with_source(<$k>::Other, err),
                }
            }
        }
    };
}
// Re-exported so other modules in the crate can invoke the macro by path.
pub(crate) use from_with_timeout;
1731
1732use crate::connection::ShouldFlush;
1733
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    #[test]
    fn server_address_defaults_to_nats_scheme() {
        let address = ServerAddr::from_str("127.0.0.1").unwrap();
        assert_eq!(address.scheme(), "nats");
        assert_eq!(address.port(), 4222);
        assert!(!address.tls_required());
        assert!(!address.is_websocket());
    }

    #[test]
    fn server_address_rejects_unknown_scheme() {
        assert!(ServerAddr::from_str("http://127.0.0.1").is_err());
    }

    #[test]
    fn server_address_credentials() {
        let address = ServerAddr::from_str("nats://user:secret@127.0.0.1:4333").unwrap();
        assert!(address.has_user_pass());
        assert_eq!(address.username(), Some("user"));
        assert_eq!(address.password(), Some("secret"));
        assert_eq!(address.port(), 4333);

        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert!(!address.has_user_pass());
        assert_eq!(address.username(), None);
        assert_eq!(address.password(), None);
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            let addrs: Vec<ServerAddr> = arr.to_server_addrs().unwrap().collect();
            // Check every address in the group, not only the first one.
            assert_eq!(addrs.len(), arr.len());
            for addr in addrs {
                assert_eq!(addr.port(), expected_port, "unexpected port for {addr:?}");
            }
        }
    }

    #[test]
    fn server_error_classification() {
        assert_eq!(
            ServerError::new("Authorization Violation".to_string()),
            ServerError::AuthorizationViolation
        );
        assert_eq!(
            ServerError::new("Permissions Violation for Publish".to_string()),
            ServerError::Other("Permissions Violation for Publish".to_string())
        );
    }

    #[test]
    fn subject_validation_basics() {
        assert!(is_valid_subject("foo.bar"));
        assert!(is_valid_subject("foo.*.>"));
        assert!(!is_valid_subject(".foo"));
        assert!(!is_valid_subject("foo."));
        assert!(!is_valid_subject("foo bar"));
    }
}