async_nats_flyradar/
lib.rs

1// Copyright 2020-2022 The NATS Authors
2// Licensed under the Apache License, Version 2.0 (the "License");
3// you may not use this file except in compliance with the License.
4// You may obtain a copy of the License at
5//
6// http://www.apache.org/licenses/LICENSE-2.0
7//
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14//! A Rust asynchronous client for the NATS.io ecosystem.
15//!
16//! To access the repository, you can clone it by running:
17//!
18//! ```bash
19//! git clone https://github.com/nats-io/nats.rs
//! ```
//!
21//! NATS.io is a simple, secure, and high-performance open-source messaging
22//! system designed for cloud-native applications, IoT messaging, and microservices
23//! architectures.
24//!
25//! **Note**: The synchronous NATS API is deprecated and no longer actively maintained. If you need to use the deprecated synchronous API, you can refer to:
26//! <https://crates.io/crates/nats>
27//!
28//! For more information on NATS.io visit: <https://nats.io>
29//!
30//! ## Examples
31//!
//! Below are some basic examples of how to use this library.
33//!
//! For more details, refer to the documentation of the specific methods and structures.
35//!
36//! ### Complete example
37//!
38//! Connect to the NATS server, publish messages and subscribe to receive messages.
39//!
40//! ```no_run
41//! use bytes::Bytes;
42//! use futures::StreamExt;
43//!
44//! #[tokio::main]
45//! async fn main() -> Result<(), async_nats::Error> {
46//!     // Connect to the NATS server
47//!     let client = async_nats::connect("demo.nats.io").await?;
48//!
49//!     // Subscribe to the "messages" subject
50//!     let mut subscriber = client.subscribe("messages").await?;
51//!
52//!     // Publish messages to the "messages" subject
53//!     for _ in 0..10 {
54//!         client.publish("messages", "data".into()).await?;
55//!     }
56//!
57//!     // Receive and process messages
58//!     while let Some(message) = subscriber.next().await {
59//!         println!("Received message {:?}", message);
60//!     }
61//!
62//!     Ok(())
63//! }
64//! ```
65//!
66//! ### Publish
67//!
68//! Connect to the NATS server and publish messages to a subject.
69//!
70//! ```
71//! # use bytes::Bytes;
72//! # use std::error::Error;
73//! # use std::time::Instant;
74//! # #[tokio::main]
75//! # async fn main() -> Result<(), async_nats::Error> {
76//! // Connect to the NATS server
77//! let client = async_nats::connect("demo.nats.io").await?;
78//!
79//! // Prepare the subject and data
80//! let subject = "foo";
81//! let data = Bytes::from("bar");
82//!
83//! // Publish messages to the NATS server
84//! for _ in 0..10 {
85//!     client.publish(subject, data.clone()).await?;
86//! }
87//!
88//! // Flush internal buffer before exiting to make sure all messages are sent
89//! client.flush().await?;
90//!
91//! #    Ok(())
92//! # }
93//! ```
94//!
95//! ### Subscribe
96//!
97//! Connect to the NATS server, subscribe to a subject and receive messages.
98//!
99//! ```no_run
100//! # use bytes::Bytes;
101//! # use futures::StreamExt;
102//! # use std::error::Error;
103//! # use std::time::Instant;
104//! # #[tokio::main]
105//! # async fn main() -> Result<(), async_nats::Error> {
106//! // Connect to the NATS server
107//! let client = async_nats::connect("demo.nats.io").await?;
108//!
109//! // Subscribe to the "foo" subject
//! let mut subscriber = client.subscribe("foo").await?;
111//!
112//! // Receive and process messages
113//! while let Some(message) = subscriber.next().await {
114//!     println!("Received message {:?}", message);
115//! }
116//! #     Ok(())
117//! # }
118//! ```
119//!
120//! ### JetStream
121//!
//! To access the JetStream API, create a JetStream [jetstream::Context].
123//!
124//! ```no_run
125//! # #[tokio::main]
126//! # async fn main() -> Result<(), async_nats::Error> {
127//! // Connect to the NATS server
128//! let client = async_nats::connect("demo.nats.io").await?;
129//! // Create a JetStream context.
130//! let jetstream = async_nats::jetstream::new(client);
131//!
132//! // Publish JetStream messages, manage streams, consumers, etc.
133//! jetstream.publish("foo", "bar".into()).await?;
134//! # Ok(())
135//! # }
136//! ```
137//!
138//! ### Key-value Store
139//!
140//! Key-value [Store][jetstream::kv::Store] is accessed through [jetstream::Context].
141//!
142//! ```no_run
143//! # #[tokio::main]
144//! # async fn main() -> Result<(), async_nats::Error> {
145//! // Connect to the NATS server
146//! let client = async_nats::connect("demo.nats.io").await?;
147//! // Create a JetStream context.
148//! let jetstream = async_nats::jetstream::new(client);
149//! // Access an existing key-value.
150//! let kv = jetstream.get_key_value("store").await?;
151//! # Ok(())
152//! # }
153//! ```
//! ### Object Store
155//!
156//! Object [Store][jetstream::object_store::ObjectStore] is accessed through [jetstream::Context].
157//!
158//! ```no_run
159//! # #[tokio::main]
160//! # async fn main() -> Result<(), async_nats::Error> {
161//! // Connect to the NATS server
162//! let client = async_nats::connect("demo.nats.io").await?;
163//! // Create a JetStream context.
164//! let jetstream = async_nats::jetstream::new(client);
//! // Access an existing object store.
//! let object_store = jetstream.get_object_store("store").await?;
167//! # Ok(())
168//! # }
169//! ```
170//! ### Service API
171//!
172//! [Service API][service::Service] is accessible through [Client] after importing its trait.
173//!
174//! ```no_run
175//! # #[tokio::main]
176//! # async fn main() -> Result<(), async_nats::Error> {
177//! use async_nats::service::ServiceExt;
178//! // Connect to the NATS server
179//! let client = async_nats::connect("demo.nats.io").await?;
180//! let mut service = client
181//!     .service_builder()
182//!     .description("some service")
183//!     .stats_handler(|endpoint, stats| serde_json::json!({ "endpoint": endpoint }))
184//!     .start("products", "1.0.0")
185//!     .await?;
186//! # Ok(())
187//! # }
188//! ```
189
190#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_auto_cfg))]
196
197use thiserror::Error;
198
199use futures::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::fmt::Display;
207use std::future::Future;
208use std::iter;
209use std::mem;
210use std::net::SocketAddr;
211use std::option;
212use std::pin::Pin;
213use std::slice;
214use std::str::{self, FromStr};
215use std::sync::atomic::AtomicUsize;
216use std::sync::atomic::Ordering;
217use std::sync::Arc;
218use std::task::{Context, Poll};
219use tokio::io::ErrorKind;
220use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
221use url::{Host, Url};
222
223use bytes::Bytes;
224use serde::{Deserialize, Serialize};
225use serde_repr::{Deserialize_repr, Serialize_repr};
226use tokio::io;
227use tokio::sync::mpsc;
228use tokio::task;
229
/// A boxed, type-erased error, handy as a catch-all error type (the
/// crate-level examples return `Result<(), async_nats::Error>`).
///
/// Any `Send + Sync + 'static` type implementing [`std::error::Error`]
/// converts into it with `?` or `.into()`.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
231
232const VERSION: &str = env!("CARGO_PKG_VERSION");
233const LANG: &str = "rust";
234const MAX_PENDING_PINGS: usize = 2;
235const MULTIPLEXER_SID: u64 = 0;
236
/// A re-export of the `rustls` crate used by this crate,
/// for cases where a manual client configuration
/// must be provided via `ConnectOptions::tls_client_config`.
240pub use tokio_rustls::rustls;
241
242use connection::{Connection, State};
243use connector::{Connector, ConnectorOptions};
244pub use header::{HeaderMap, HeaderName, HeaderValue};
245pub use subject::Subject;
246
247mod auth;
248pub(crate) mod auth_utils;
249pub mod client;
250pub mod connection;
251mod connector;
252//INFO: Diff starts here
253pub use connector::Dialer;
254//INFO: Diff ends here
255mod options;
256
257pub use auth::Auth;
258pub use client::{
259    Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
260};
261pub use options::{AuthError, ConnectOptions};
262
263mod crypto;
264pub mod error;
265pub mod header;
266pub mod jetstream;
267pub mod message;
268#[cfg(feature = "service")]
269pub mod service;
270pub mod status;
271pub mod subject;
272mod tls;
273
274pub use message::Message;
275pub use status::StatusCode;
276
277/// Information sent by the server back to this client
278/// during initial connection, and possibly again later.
279#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
280pub struct ServerInfo {
281    /// The unique identifier of the NATS server.
282    #[serde(default)]
283    pub server_id: String,
284    /// Generated Server Name.
285    #[serde(default)]
286    pub server_name: String,
287    /// The host specified in the cluster parameter/options.
288    #[serde(default)]
289    pub host: String,
290    /// The port number specified in the cluster parameter/options.
291    #[serde(default)]
292    pub port: u16,
293    /// The version of the NATS server.
294    #[serde(default)]
295    pub version: String,
296    /// If this is set, then the server should try to authenticate upon
297    /// connect.
298    #[serde(default)]
299    pub auth_required: bool,
300    /// If this is set, then the server must authenticate using TLS.
301    #[serde(default)]
302    pub tls_required: bool,
303    /// Maximum payload size that the server will accept.
304    #[serde(default)]
305    pub max_payload: usize,
306    /// The protocol version in use.
307    #[serde(default)]
308    pub proto: i8,
309    /// The server-assigned client ID. This may change during reconnection.
310    #[serde(default)]
311    pub client_id: u64,
312    /// The version of golang the NATS server was built with.
313    #[serde(default)]
314    pub go: String,
315    /// The nonce used for nkeys.
316    #[serde(default)]
317    pub nonce: String,
318    /// A list of server urls that a client can connect to.
319    #[serde(default)]
320    pub connect_urls: Vec<String>,
321    /// The client IP as known by the server.
322    #[serde(default)]
323    pub client_ip: String,
324    /// Whether the server supports headers.
325    #[serde(default)]
326    pub headers: bool,
327    /// Whether server goes into lame duck mode.
328    #[serde(default, rename = "ldm")]
329    pub lame_duck_mode: bool,
330}
331
332#[derive(Clone, Debug, Eq, PartialEq)]
333pub(crate) enum ServerOp {
334    Ok,
335    Info(Box<ServerInfo>),
336    Ping,
337    Pong,
338    Error(ServerError),
339    Message {
340        sid: u64,
341        subject: Subject,
342        reply: Option<Subject>,
343        payload: Bytes,
344        headers: Option<HeaderMap>,
345        status: Option<StatusCode>,
346        description: Option<String>,
347        length: usize,
348    },
349}
350
351/// `PublishMessage` represents a message being published
352#[derive(Debug)]
353pub struct PublishMessage {
354    pub subject: Subject,
355    pub payload: Bytes,
356    pub reply: Option<Subject>,
357    pub headers: Option<HeaderMap>,
358}
359
360/// `Command` represents all commands that a [`Client`] can handle
361#[derive(Debug)]
362pub(crate) enum Command {
363    Publish(PublishMessage),
364    Request {
365        subject: Subject,
366        payload: Bytes,
367        respond: Subject,
368        headers: Option<HeaderMap>,
369        sender: oneshot::Sender<Message>,
370    },
371    Subscribe {
372        sid: u64,
373        subject: Subject,
374        queue_group: Option<String>,
375        sender: mpsc::Sender<Message>,
376    },
377    Unsubscribe {
378        sid: u64,
379        max: Option<u64>,
380    },
381    Flush {
382        observer: oneshot::Sender<()>,
383    },
384    Drain {
385        sid: Option<u64>,
386    },
387    Reconnect,
388}
389
390/// `ClientOp` represents all actions of `Client`.
391#[derive(Debug)]
392pub(crate) enum ClientOp {
393    Publish {
394        subject: Subject,
395        payload: Bytes,
396        respond: Option<Subject>,
397        headers: Option<HeaderMap>,
398    },
399    Subscribe {
400        sid: u64,
401        subject: Subject,
402        queue_group: Option<String>,
403    },
404    Unsubscribe {
405        sid: u64,
406        max: Option<u64>,
407    },
408    Ping,
409    Pong,
410    Connect(ConnectInfo),
411}
412
413#[derive(Debug)]
414struct Subscription {
415    subject: Subject,
416    sender: mpsc::Sender<Message>,
417    queue_group: Option<String>,
418    delivered: u64,
419    max: Option<u64>,
420    is_draining: bool,
421}
422
423#[derive(Debug)]
424struct Multiplexer {
425    subject: Subject,
426    prefix: Subject,
427    senders: HashMap<String, oneshot::Sender<Message>>,
428}
429
430/// A connection handler which facilitates communication from channels to a single shared connection.
431pub(crate) struct ConnectionHandler {
432    connection: Connection,
433    connector: Connector,
434    subscriptions: HashMap<u64, Subscription>,
435    multiplexer: Option<Multiplexer>,
436    pending_pings: usize,
437    info_sender: tokio::sync::watch::Sender<ServerInfo>,
438    ping_interval: Interval,
439    should_reconnect: bool,
440    flush_observers: Vec<oneshot::Sender<()>>,
441    is_draining: bool,
442}
443
444impl ConnectionHandler {
445    pub(crate) fn new(
446        connection: Connection,
447        connector: Connector,
448        info_sender: tokio::sync::watch::Sender<ServerInfo>,
449        ping_period: Duration,
450    ) -> ConnectionHandler {
451        let mut ping_interval = interval(ping_period);
452        ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
453
454        ConnectionHandler {
455            connection,
456            connector,
457            subscriptions: HashMap::new(),
458            multiplexer: None,
459            pending_pings: 0,
460            info_sender,
461            ping_interval,
462            should_reconnect: false,
463            flush_observers: Vec::new(),
464            is_draining: false,
465        }
466    }
467
468    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
469        struct ProcessFut<'a> {
470            handler: &'a mut ConnectionHandler,
471            receiver: &'a mut mpsc::Receiver<Command>,
472            recv_buf: &'a mut Vec<Command>,
473        }
474
475        enum ExitReason {
476            Disconnected(Option<io::Error>),
477            ReconnectRequested,
478            Closed,
479        }
480
481        impl ProcessFut<'_> {
482            const RECV_CHUNK_SIZE: usize = 16;
483
484            #[cold]
485            fn ping(&mut self) -> Poll<ExitReason> {
486                self.handler.pending_pings += 1;
487
488                if self.handler.pending_pings > MAX_PENDING_PINGS {
489                    debug!(
490                        "pending pings {}, max pings {}. disconnecting",
491                        self.handler.pending_pings, MAX_PENDING_PINGS
492                    );
493
494                    Poll::Ready(ExitReason::Disconnected(None))
495                } else {
496                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);
497
498                    Poll::Pending
499                }
500            }
501        }
502
503        impl Future for ProcessFut<'_> {
504            type Output = ExitReason;
505
506            /// Drives the connection forward.
507            ///
508            /// Returns one of the following:
509            ///
510            /// * `Poll::Pending` means that the connection
511            ///   is blocked on all fronts or there are
512            ///   no commands to send or receive
513            /// * `Poll::Ready(ExitReason::Disconnected(_))` means
514            ///   that an I/O operation failed and the connection
515            ///   is considered dead.
516            /// * `Poll::Ready(ExitReason::Closed)` means that
517            ///   [`Self::receiver`] was closed, so there's nothing
518            ///   more for us to do than to exit the client.
519            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
520                // We need to be sure the waker is registered, therefore we need to poll until we
521                // get a `Poll::Pending`. With a sane interval delay, this means that the loop
522                // breaks at the second iteration.
523                while self.handler.ping_interval.poll_tick(cx).is_ready() {
524                    if let Poll::Ready(exit) = self.ping() {
525                        return Poll::Ready(exit);
526                    }
527                }
528
529                loop {
530                    match self.handler.connection.poll_read_op(cx) {
531                        Poll::Pending => break,
532                        Poll::Ready(Ok(Some(server_op))) => {
533                            self.handler.handle_server_op(server_op);
534                        }
535                        Poll::Ready(Ok(None)) => {
536                            return Poll::Ready(ExitReason::Disconnected(None))
537                        }
538                        Poll::Ready(Err(err)) => {
539                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
540                        }
541                    }
542                }
543
544                // Before handling any commands, drop any subscriptions which are draining
545                // Note: safe to assume subscription drain has completed at this point, as we would have flushed
546                // all outgoing UNSUB messages in the previous call to this fn, and we would have processed and
547                // delivered any remaining messages to the subscription in the loop above.
548                self.handler.subscriptions.retain(|_, s| !s.is_draining);
549
550                if self.handler.is_draining {
551                    // The entire connection is draining. This means we flushed outgoing messages in the previous
552                    // call to this fn, we handled any remaining messages from the server in the loop above, and
553                    // all subs were drained, so drain is complete and we should exit instead of processing any
554                    // further messages
555                    return Poll::Ready(ExitReason::Closed);
556                }
557
558                // WARNING: after the following loop `handle_command`,
559                // or other functions which call `enqueue_write_op`,
560                // cannot be called anymore. Runtime wakeups won't
561                // trigger a call to `poll_write`
562
563                let mut made_progress = true;
564                loop {
565                    while !self.handler.connection.is_write_buf_full() {
566                        debug_assert!(self.recv_buf.is_empty());
567
568                        let Self {
569                            recv_buf,
570                            handler,
571                            receiver,
572                        } = &mut *self;
573                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
574                            Poll::Pending => break,
575                            Poll::Ready(1..) => {
576                                made_progress = true;
577
578                                for cmd in recv_buf.drain(..) {
579                                    handler.handle_command(cmd);
580                                }
581                            }
582                            // TODO: replace `_` with `0` after bumping MSRV to 1.75
583                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
584                        }
585                    }
586
587                    // The first round will poll both from
588                    // the `receiver` and the writer, giving
589                    // them both a chance to make progress
590                    // and register `Waker`s.
591                    //
592                    // If writing is `Poll::Pending` we exit.
593                    //
594                    // If writing is completed we can repeat the entire
595                    // cycle as long as the `receiver` doesn't end-up
596                    // `Poll::Pending` immediately.
597                    if !mem::take(&mut made_progress) {
598                        break;
599                    }
600
601                    match self.handler.connection.poll_write(cx) {
602                        Poll::Pending => {
603                            // Write buffer couldn't be fully emptied
604                            break;
605                        }
606                        Poll::Ready(Ok(())) => {
607                            // Write buffer is empty
608                            continue;
609                        }
610                        Poll::Ready(Err(err)) => {
611                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
612                        }
613                    }
614                }
615
616                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
617                    self.handler.connection.should_flush(),
618                    self.handler.flush_observers.is_empty(),
619                ) {
620                    match self.handler.connection.poll_flush(cx) {
621                        Poll::Pending => {}
622                        Poll::Ready(Ok(())) => {
623                            for observer in self.handler.flush_observers.drain(..) {
624                                let _ = observer.send(());
625                            }
626                        }
627                        Poll::Ready(Err(err)) => {
628                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
629                        }
630                    }
631                }
632
633                if mem::take(&mut self.handler.should_reconnect) {
634                    return Poll::Ready(ExitReason::ReconnectRequested);
635                }
636
637                Poll::Pending
638            }
639        }
640
641        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
642        loop {
643            let process = ProcessFut {
644                handler: self,
645                receiver,
646                recv_buf: &mut recv_buf,
647            };
648            match process.await {
649                ExitReason::Disconnected(err) => {
650                    debug!(?err, "disconnected");
651                    if self.handle_disconnect().await.is_err() {
652                        break;
653                    };
654                    debug!("reconnected");
655                }
656                ExitReason::Closed => {
657                    // Safe to ignore result as we're shutting down anyway
658                    self.connector.events_tx.try_send(Event::Closed).ok();
659                    break;
660                }
661                ExitReason::ReconnectRequested => {
662                    debug!("reconnect requested");
663                    // Should be ok to ingore error, as that means we are not in connected state.
664                    self.connection.stream.shutdown().await.ok();
665                    if self.handle_disconnect().await.is_err() {
666                        break;
667                    };
668                }
669            }
670        }
671    }
672
673    fn handle_server_op(&mut self, server_op: ServerOp) {
674        self.ping_interval.reset();
675
676        match server_op {
677            ServerOp::Ping => {
678                self.connection.enqueue_write_op(&ClientOp::Pong);
679            }
680            ServerOp::Pong => {
681                debug!("received PONG");
682                self.pending_pings = self.pending_pings.saturating_sub(1);
683            }
684            ServerOp::Error(error) => {
685                self.connector
686                    .events_tx
687                    .try_send(Event::ServerError(error))
688                    .ok();
689            }
690            ServerOp::Message {
691                sid,
692                subject,
693                reply,
694                payload,
695                headers,
696                status,
697                description,
698                length,
699            } => {
700                self.connector
701                    .connect_stats
702                    .in_messages
703                    .add(1, Ordering::Relaxed);
704
705                if let Some(subscription) = self.subscriptions.get_mut(&sid) {
706                    let message: Message = Message {
707                        subject,
708                        reply,
709                        payload,
710                        headers,
711                        status,
712                        description,
713                        length,
714                    };
715
716                    // if the channel for subscription was dropped, remove the
717                    // subscription from the map and unsubscribe.
718                    match subscription.sender.try_send(message) {
719                        Ok(_) => {
720                            subscription.delivered += 1;
721                            // if this `Subscription` has set `max` value, check if it
722                            // was reached. If yes, remove the `Subscription` and in
723                            // the result, `drop` the `sender` channel.
724                            if let Some(max) = subscription.max {
725                                if subscription.delivered.ge(&max) {
726                                    self.subscriptions.remove(&sid);
727                                }
728                            }
729                        }
730                        Err(mpsc::error::TrySendError::Full(_)) => {
731                            self.connector
732                                .events_tx
733                                .try_send(Event::SlowConsumer(sid))
734                                .ok();
735                        }
736                        Err(mpsc::error::TrySendError::Closed(_)) => {
737                            self.subscriptions.remove(&sid);
738                            self.connection
739                                .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
740                        }
741                    }
742                } else if sid == MULTIPLEXER_SID {
743                    if let Some(multiplexer) = self.multiplexer.as_mut() {
744                        let maybe_token =
745                            subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
746
747                        if let Some(token) = maybe_token {
748                            if let Some(sender) = multiplexer.senders.remove(token) {
749                                let message = Message {
750                                    subject,
751                                    reply,
752                                    payload,
753                                    headers,
754                                    status,
755                                    description,
756                                    length,
757                                };
758
759                                let _ = sender.send(message);
760                            }
761                        }
762                    }
763                }
764            }
765            // TODO: we should probably update advertised server list here too.
766            ServerOp::Info(info) => {
767                if info.lame_duck_mode {
768                    self.connector.events_tx.try_send(Event::LameDuckMode).ok();
769                }
770            }
771
772            _ => {
773                // TODO: don't ignore.
774            }
775        }
776    }
777
778    fn handle_command(&mut self, command: Command) {
779        self.ping_interval.reset();
780
781        match command {
782            Command::Unsubscribe { sid, max } => {
783                if let Some(subscription) = self.subscriptions.get_mut(&sid) {
784                    subscription.max = max;
785                    match subscription.max {
786                        Some(n) => {
787                            if subscription.delivered >= n {
788                                self.subscriptions.remove(&sid);
789                            }
790                        }
791                        None => {
792                            self.subscriptions.remove(&sid);
793                        }
794                    }
795
796                    self.connection
797                        .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
798                }
799            }
800            Command::Flush { observer } => {
801                self.flush_observers.push(observer);
802            }
803            Command::Drain { sid } => {
804                let mut drain_sub = |sid: u64, sub: &mut Subscription| {
805                    sub.is_draining = true;
806                    self.connection
807                        .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
808                };
809
810                if let Some(sid) = sid {
811                    if let Some(sub) = self.subscriptions.get_mut(&sid) {
812                        drain_sub(sid, sub);
813                    }
814                } else {
815                    // sid isn't set, so drain the whole client
816                    self.connector.events_tx.try_send(Event::Draining).ok();
817                    self.is_draining = true;
818                    for (&sid, sub) in self.subscriptions.iter_mut() {
819                        drain_sub(sid, sub);
820                    }
821                }
822            }
823            Command::Subscribe {
824                sid,
825                subject,
826                queue_group,
827                sender,
828            } => {
829                let subscription = Subscription {
830                    sender,
831                    delivered: 0,
832                    max: None,
833                    subject: subject.to_owned(),
834                    queue_group: queue_group.to_owned(),
835                    is_draining: false,
836                };
837
838                self.subscriptions.insert(sid, subscription);
839
840                self.connection.enqueue_write_op(&ClientOp::Subscribe {
841                    sid,
842                    subject,
843                    queue_group,
844                });
845            }
846            Command::Request {
847                subject,
848                payload,
849                respond,
850                headers,
851                sender,
852            } => {
853                let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
854
855                let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
856                    multiplexer
857                } else {
858                    let prefix = Subject::from(format!("{}.{}.", prefix, nuid::next()));
859                    let subject = Subject::from(format!("{}*", prefix));
860
861                    self.connection.enqueue_write_op(&ClientOp::Subscribe {
862                        sid: MULTIPLEXER_SID,
863                        subject: subject.clone(),
864                        queue_group: None,
865                    });
866
867                    self.multiplexer.insert(Multiplexer {
868                        subject,
869                        prefix,
870                        senders: HashMap::new(),
871                    })
872                };
873                self.connector
874                    .connect_stats
875                    .out_messages
876                    .add(1, Ordering::Relaxed);
877
878                multiplexer.senders.insert(token.to_owned(), sender);
879
880                let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
881
882                let pub_op = ClientOp::Publish {
883                    subject,
884                    payload,
885                    respond: Some(respond),
886                    headers,
887                };
888
889                self.connection.enqueue_write_op(&pub_op);
890            }
891
892            Command::Publish(PublishMessage {
893                subject,
894                payload,
895                reply: respond,
896                headers,
897            }) => {
898                self.connector
899                    .connect_stats
900                    .out_messages
901                    .add(1, Ordering::Relaxed);
902
903                let header_len = headers
904                    .as_ref()
905                    .map(|headers| headers.len())
906                    .unwrap_or_default();
907
908                self.connector.connect_stats.out_bytes.add(
909                    (payload.len()
910                        + respond.as_ref().map_or_else(|| 0, |r| r.len())
911                        + subject.len()
912                        + header_len) as u64,
913                    Ordering::Relaxed,
914                );
915
916                self.connection.enqueue_write_op(&ClientOp::Publish {
917                    subject,
918                    payload,
919                    respond,
920                    headers,
921                });
922            }
923
924            Command::Reconnect => {
925                self.should_reconnect = true;
926            }
927        }
928    }
929
930    async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
931        self.pending_pings = 0;
932        self.connector.events_tx.try_send(Event::Disconnected).ok();
933        self.connector.state_tx.send(State::Disconnected).ok();
934
935        self.handle_reconnect().await
936    }
937
938    async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
939        let (info, connection) = self.connector.connect().await?;
940        self.connection = connection;
941        let _ = self.info_sender.send(info);
942
943        self.subscriptions
944            .retain(|_, subscription| !subscription.sender.is_closed());
945
946        for (sid, subscription) in &self.subscriptions {
947            self.connection.enqueue_write_op(&ClientOp::Subscribe {
948                sid: *sid,
949                subject: subscription.subject.to_owned(),
950                queue_group: subscription.queue_group.to_owned(),
951            });
952        }
953
954        if let Some(multiplexer) = &self.multiplexer {
955            self.connection.enqueue_write_op(&ClientOp::Subscribe {
956                sid: MULTIPLEXER_SID,
957                subject: multiplexer.subject.to_owned(),
958                queue_group: None,
959            });
960        }
961        Ok(())
962    }
963}
964
965/// Connects to NATS with specified options.
966///
967/// It is generally advised to use [ConnectOptions] instead, as it provides a builder for whole
968/// configuration.
969///
970/// # Examples
971/// ```
972/// # #[tokio::main]
973/// # async fn main() ->  Result<(), async_nats::Error> {
974/// let mut nc =
975///     async_nats::connect_with_options("demo.nats.io", async_nats::ConnectOptions::new()).await?;
976/// nc.publish("test", "data".into()).await?;
977/// # Ok(())
978/// # }
979/// ```
980pub async fn connect_with_options<A: ToServerAddrs>(
981    addrs: A,
982    options: ConnectOptions,
983) -> Result<Client, ConnectError> {
984    let ping_period = options.ping_interval;
985
986    let (events_tx, mut events_rx) = mpsc::channel(128);
987    let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
988    // We're setting it to the default server payload size.
989    let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
990    let statistics = Arc::new(Statistics::default());
991
992    let mut connector = Connector::new(
993        addrs,
994        ConnectorOptions {
995            tls_required: options.tls_required,
996            certificates: options.certificates,
997            client_key: options.client_key,
998            client_cert: options.client_cert,
999            tls_client_config: options.tls_client_config,
1000            tls_first: options.tls_first,
1001            auth: options.auth,
1002            no_echo: options.no_echo,
1003            connection_timeout: options.connection_timeout,
1004            name: options.name,
1005            ignore_discovered_servers: options.ignore_discovered_servers,
1006            retain_servers_order: options.retain_servers_order,
1007            read_buffer_capacity: options.read_buffer_capacity,
1008            reconnect_delay_callback: options.reconnect_delay_callback,
1009            auth_callback: options.auth_callback,
1010            max_reconnects: options.max_reconnects,
1011            //INFO: Diff starts
1012            dialer: options.dialer,
1013            //INFO: Diff ends
1014        },
1015        events_tx,
1016        state_tx,
1017        max_payload.clone(),
1018        statistics.clone(),
1019    )
1020    .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1021
1022    let mut info: ServerInfo = Default::default();
1023    let mut connection = None;
1024    if !options.retry_on_initial_connect {
1025        debug!("retry on initial connect failure is disabled");
1026        let (info_ok, connection_ok) = connector.try_connect().await?;
1027        connection = Some(connection_ok);
1028        info = info_ok;
1029    }
1030
1031    let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1032    let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1033
1034    let client = Client::new(
1035        info_watcher,
1036        state_rx,
1037        sender,
1038        options.subscription_capacity,
1039        options.inbox_prefix,
1040        options.request_timeout,
1041        max_payload,
1042        statistics,
1043    );
1044
1045    task::spawn(async move {
1046        while let Some(event) = events_rx.recv().await {
1047            tracing::info!("event: {}", event);
1048            if let Some(event_callback) = &options.event_callback {
1049                event_callback.call(event).await;
1050            }
1051        }
1052    });
1053
1054    task::spawn(async move {
1055        if connection.is_none() && options.retry_on_initial_connect {
1056            let (info, connection_ok) = match connector.connect().await {
1057                Ok((info, connection)) => (info, connection),
1058                Err(err) => {
1059                    error!("connection closed: {}", err);
1060                    return;
1061                }
1062            };
1063            info_sender.send(info).ok();
1064            connection = Some(connection_ok);
1065        }
1066        let connection = connection.unwrap();
1067        let mut connection_handler =
1068            ConnectionHandler::new(connection, connector, info_sender, ping_period);
1069        connection_handler.process(&mut receiver).await
1070    });
1071
1072    Ok(client)
1073}
1074
/// Connection lifecycle and error notifications.
///
/// Each event is logged and, if configured, passed to the event callback from [ConnectOptions].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// The client is connected.
    Connected,
    /// The connection to the server was lost.
    Disconnected,
    /// Lame duck mode was detected on the server.
    LameDuckMode,
    /// The whole client started draining its subscriptions.
    Draining,
    /// The client was closed.
    Closed,
    /// A subscription is a slow consumer; carries the subscription id.
    SlowConsumer(u64),
    /// An error reported by the server.
    ServerError(ServerError),
    /// An error raised on the client side.
    ClientError(ClientError),
}
1086
1087impl fmt::Display for Event {
1088    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1089        match self {
1090            Event::Connected => write!(f, "connected"),
1091            Event::Disconnected => write!(f, "disconnected"),
1092            Event::LameDuckMode => write!(f, "lame duck mode detected"),
1093            Event::Draining => write!(f, "draining"),
1094            Event::Closed => write!(f, "closed"),
1095            Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1096            Event::ServerError(err) => write!(f, "server error: {err}"),
1097            Event::ClientError(err) => write!(f, "client error: {err}"),
1098        }
1099    }
1100}
1101
1102/// Connects to NATS with default config.
1103///
1104/// Returns cloneable [Client].
1105///
1106/// To have customized NATS connection, check [ConnectOptions].
1107///
1108/// # Examples
1109///
1110/// ## Single URL
1111/// ```
1112/// # #[tokio::main]
1113/// # async fn main() ->  Result<(), async_nats::Error> {
1114/// let mut nc = async_nats::connect("demo.nats.io").await?;
1115/// nc.publish("test", "data".into()).await?;
1116/// # Ok(())
1117/// # }
1118/// ```
1119///
1120/// ## Connect with [Vec] of [ServerAddr].
1121/// ```no_run
1122/// #[tokio::main]
1123/// # async fn main() -> Result<(), async_nats::Error> {
1124/// use async_nats::ServerAddr;
1125/// let client = async_nats::connect(vec![
1126///     "demo.nats.io".parse::<ServerAddr>()?,
1127///     "other.nats.io".parse::<ServerAddr>()?,
1128/// ])
1129/// .await
1130/// .unwrap();
1131/// # Ok(())
1132/// # }
1133/// ```
1134///
1135/// ## with [Vec], but parse URLs inside [crate::connect()]
1136/// ```no_run
1137/// #[tokio::main]
1138/// # async fn main() -> Result<(), async_nats::Error> {
1139/// use async_nats::ServerAddr;
1140/// let servers = vec!["demo.nats.io", "other.nats.io"];
1141/// let client = async_nats::connect(
1142///     servers
1143///         .iter()
1144///         .map(|url| url.parse())
1145///         .collect::<Result<Vec<ServerAddr>, _>>()?,
1146/// )
1147/// .await?;
1148/// # Ok(())
1149/// # }
1150/// ```
1151///
1152///
1153/// ## with slice.
1154/// ```no_run
1155/// #[tokio::main]
1156/// # async fn main() -> Result<(), async_nats::Error> {
1157/// use async_nats::ServerAddr;
1158/// let client = async_nats::connect(
1159///    [
1160///        "demo.nats.io".parse::<ServerAddr>()?,
1161///        "other.nats.io".parse::<ServerAddr>()?,
1162///    ]
1163///    .as_slice(),
1164/// )
1165/// .await?;
1166/// # Ok(())
1167/// # }
1168pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1169    connect_with_options(addrs, ConnectOptions::default()).await
1170}
1171
/// The kind of a [ConnectError], describing why connecting failed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// Parsing the passed server address failed.
    ServerParse,
    /// DNS related issues.
    Dns,
    /// Failed authentication process, signing nonce, etc.
    Authentication,
    /// Server returned authorization violation error.
    AuthorizationViolation,
    /// Connect timed out.
    TimedOut,
    /// Erroneous TLS setup.
    Tls,
    /// Other IO error.
    Io,
    /// Reached the maximum number of reconnects.
    MaxReconnects,
}
1191
1192impl Display for ConnectErrorKind {
1193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1194        match self {
1195            Self::ServerParse => write!(f, "failed to parse server or server list"),
1196            Self::Dns => write!(f, "DNS error"),
1197            Self::Authentication => write!(f, "failed signing nonce"),
1198            Self::AuthorizationViolation => write!(f, "authorization violation"),
1199            Self::TimedOut => write!(f, "timed out"),
1200            Self::Tls => write!(f, "TLS error"),
1201            Self::Io => write!(f, "IO error"),
1202            Self::MaxReconnects => write!(f, "reached maximum number of reconnects"),
1203        }
1204    }
1205}
1206
1207/// Returned when initial connection fails.
1208/// To be enumerate over the variants, call [ConnectError::kind].
1209pub type ConnectError = error::Error<ConnectErrorKind>;
1210
1211impl From<io::Error> for ConnectError {
1212    fn from(err: io::Error) -> Self {
1213        ConnectError::with_source(ConnectErrorKind::Io, err)
1214    }
1215}
1216
/// Retrieves messages from given `subscription` created by [Client::subscribe].
///
/// Implements [futures::stream::Stream] for ergonomic async message processing.
/// Dropping a `Subscriber` closes its receiver and sends an unsubscribe command.
///
/// # Examples
/// ```
/// # #[tokio::main]
/// # async fn main() ->  Result<(), async_nats::Error> {
/// let mut nc = async_nats::connect("demo.nats.io").await?;
/// # nc.publish("test", "data".into()).await?;
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
pub struct Subscriber {
    /// Subscription id, used to address `Unsubscribe`/`Drain` commands for this subscription.
    sid: u64,
    /// Messages routed to this subscription are received here (see the `Stream` impl).
    receiver: mpsc::Receiver<Message>,
    /// Command channel; presumably the same one the `Client` uses to reach the connection
    /// handler task — verify against `Client::subscribe`.
    sender: mpsc::Sender<Command>,
}
1236
1237impl Subscriber {
1238    fn new(
1239        sid: u64,
1240        sender: mpsc::Sender<Command>,
1241        receiver: mpsc::Receiver<Message>,
1242    ) -> Subscriber {
1243        Subscriber {
1244            sid,
1245            sender,
1246            receiver,
1247        }
1248    }
1249
1250    /// Unsubscribes from subscription, draining all remaining messages.
1251    ///
1252    /// # Examples
1253    /// ```
1254    /// # #[tokio::main]
1255    /// # async fn main() -> Result<(), async_nats::Error> {
1256    /// let client = async_nats::connect("demo.nats.io").await?;
1257    ///
1258    /// let mut subscriber = client.subscribe("foo").await?;
1259    ///
1260    /// subscriber.unsubscribe().await?;
1261    /// # Ok(())
1262    /// # }
1263    /// ```
1264    pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1265        self.sender
1266            .send(Command::Unsubscribe {
1267                sid: self.sid,
1268                max: None,
1269            })
1270            .await?;
1271        self.receiver.close();
1272        Ok(())
1273    }
1274
1275    /// Unsubscribes from subscription after reaching given number of messages.
1276    /// This is the total number of messages received by this subscription in it's whole
1277    /// lifespan. If it already reached or surpassed the passed value, it will immediately stop.
1278    ///
1279    /// # Examples
1280    /// ```
1281    /// # use futures::StreamExt;
1282    /// # #[tokio::main]
1283    /// # async fn main() -> Result<(), async_nats::Error> {
1284    /// let client = async_nats::connect("demo.nats.io").await?;
1285    ///
1286    /// let mut subscriber = client.subscribe("test").await?;
1287    /// subscriber.unsubscribe_after(3).await?;
1288    ///
1289    /// for _ in 0..3 {
1290    ///     client.publish("test", "data".into()).await?;
1291    /// }
1292    ///
1293    /// while let Some(message) = subscriber.next().await {
1294    ///     println!("message received: {:?}", message);
1295    /// }
1296    /// println!("no more messages, unsubscribed");
1297    /// # Ok(())
1298    /// # }
1299    /// ```
1300    pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1301        self.sender
1302            .send(Command::Unsubscribe {
1303                sid: self.sid,
1304                max: Some(unsub_after),
1305            })
1306            .await?;
1307        Ok(())
1308    }
1309
1310    /// Unsubscribes immediately but leaves the subscription open to allow any in-flight messages
1311    /// on the subscription to be delivered. The stream will be closed after any remaining messages
1312    /// are delivered
1313    ///
1314    /// # Examples
1315    /// ```no_run
1316    /// # use futures::StreamExt;
1317    /// # #[tokio::main]
1318    /// # async fn main() -> Result<(), async_nats::Error> {
1319    /// let client = async_nats::connect("demo.nats.io").await?;
1320    ///
1321    /// let mut subscriber = client.subscribe("test").await?;
1322    ///
1323    /// tokio::spawn({
1324    ///     let task_client = client.clone();
1325    ///     async move {
1326    ///         loop {
1327    ///             _ = task_client.publish("test", "data".into()).await;
1328    ///         }
1329    ///     }
1330    /// });
1331    ///
1332    /// client.flush().await?;
1333    /// subscriber.drain().await?;
1334    ///
1335    /// while let Some(message) = subscriber.next().await {
1336    ///     println!("message received: {:?}", message);
1337    /// }
1338    /// println!("no more messages, unsubscribed");
1339    /// # Ok(())
1340    /// # }
1341    /// ```
1342    pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1343        self.sender
1344            .send(Command::Drain {
1345                sid: Some(self.sid),
1346            })
1347            .await?;
1348
1349        Ok(())
1350    }
1351}
1352
/// Returned by [Subscriber::unsubscribe], [Subscriber::unsubscribe_after] and
/// [Subscriber::drain] when their command could not be sent.
///
/// Holds the display text of the underlying channel send error.
#[derive(Error, Debug, PartialEq)]
#[error("failed to send unsubscribe")]
pub struct UnsubscribeError(String);
1356
1357impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1358    fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1359        UnsubscribeError(err.to_string())
1360    }
1361}
1362
1363impl Drop for Subscriber {
1364    fn drop(&mut self) {
1365        self.receiver.close();
1366        tokio::spawn({
1367            let sender = self.sender.clone();
1368            let sid = self.sid;
1369            async move {
1370                sender
1371                    .send(Command::Unsubscribe { sid, max: None })
1372                    .await
1373                    .ok();
1374            }
1375        });
1376    }
1377}
1378
1379impl Stream for Subscriber {
1380    type Item = Message;
1381
1382    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1383        self.receiver.poll_recv(cx)
1384    }
1385}
1386
/// An error originating either on the client ([ClientError]) or on the server ([ServerError]).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CallbackError {
    /// Error raised by the client.
    Client(ClientError),
    /// Error reported by the server.
    Server(ServerError),
}
1392impl std::fmt::Display for CallbackError {
1393    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1394        match self {
1395            Self::Client(error) => write!(f, "{error}"),
1396            Self::Server(error) => write!(f, "{error}"),
1397        }
1398    }
1399}
1400
1401impl From<ServerError> for CallbackError {
1402    fn from(server_error: ServerError) -> Self {
1403        CallbackError::Server(server_error)
1404    }
1405}
1406
1407impl From<ClientError> for CallbackError {
1408    fn from(client_error: ClientError) -> Self {
1409        CallbackError::Client(client_error)
1410    }
1411}
1412
/// An error reported by the NATS server.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The server reported an authorization violation (matched case-insensitively).
    AuthorizationViolation,
    /// The subscription with this id is a slow consumer.
    SlowConsumer(u64),
    /// Any other server error; the message text is kept as received.
    Other(String),
}
1419
/// An error raised on the client side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Free-form client error message.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
1425impl std::fmt::Display for ClientError {
1426    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1427        match self {
1428            Self::Other(error) => write!(f, "nats: {error}"),
1429            Self::MaxReconnects => write!(f, "nats: max reconnects reached"),
1430        }
1431    }
1432}
1433
1434impl ServerError {
1435    fn new(error: String) -> ServerError {
1436        match error.to_lowercase().as_str() {
1437            "authorization violation" => ServerError::AuthorizationViolation,
1438            // error messages can contain case-sensitive values which should be preserved
1439            _ => ServerError::Other(error),
1440        }
1441    }
1442}
1443
1444impl std::fmt::Display for ServerError {
1445    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1446        match self {
1447            Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1448            Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1449            Self::Other(error) => write!(f, "nats: {error}"),
1450        }
1451    }
1452}
1453
/// Info to construct a CONNECT message.
///
/// Serialized with serde; field names (after the `rename` attributes) are the protocol's
/// CONNECT JSON keys.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Turns on +OK protocol acknowledgments.
    pub verbose: bool,

    /// Turns on additional strict format checking, e.g. for properly formed
    /// subjects.
    pub pedantic: bool,

    /// User's JWT.
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public nkey.
    pub nkey: Option<String>,

    /// Signed nonce, encoded to Base64URL.
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// If set to `false`, the server (version 1.2.0+) will not send originating
    /// messages from this connection to its own subscriptions. Clients should
    /// set this to `false` only for servers supporting this feature, which is
    /// when proto in the INFO protocol is set to at least 1.
    pub echo: bool,

    /// The implementation language of the client.
    pub lang: String,

    /// The version of the client.
    pub version: String,

    /// Sending 0 (or absent) indicates client supports original protocol.
    /// Sending 1 indicates that the client supports dynamic reconfiguration
    /// of cluster topology changes by asynchronously receiving INFO messages
    /// with known servers it can reconnect to.
    pub protocol: Protocol,

    /// Indicates whether the client requires a TLS connection.
    pub tls_required: bool,

    /// Connection username (if `auth_required` is set).
    pub user: Option<String>,

    /// Connection password (if `auth_required` is set).
    pub pass: Option<String>,

    /// Client authorization token (if `auth_required` is set).
    pub auth_token: Option<String>,

    /// Whether the client supports the usage of headers.
    pub headers: bool,

    /// Whether the client supports no_responders.
    pub no_responders: bool,
}
1514
/// Protocol version used by the client.
///
/// Serialized and deserialized as its numeric discriminant (`0` or `1`).
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Original protocol.
    Original = 0,
    /// Protocol with dynamic reconfiguration of cluster and lame duck mode functionality.
    Dynamic = 1,
}
1524
/// Address of a NATS server.
///
/// Wraps a URL whose scheme has been validated by [`ServerAddr::from_url`]; parsing from a
/// string (via `FromStr`) goes through the same check.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1528
1529impl FromStr for ServerAddr {
1530    type Err = io::Error;
1531
1532    /// Parse an address of a NATS server.
1533    ///
1534    /// If not stated explicitly the `nats://` schema and port `4222` is assumed.
1535    fn from_str(input: &str) -> Result<Self, Self::Err> {
1536        let url: Url = if input.contains("://") {
1537            input.parse()
1538        } else {
1539            format!("nats://{input}").parse()
1540        }
1541        .map_err(|e| {
1542            io::Error::new(
1543                ErrorKind::InvalidInput,
1544                format!("NATS server URL is invalid: {e}"),
1545            )
1546        })?;
1547
1548        Self::from_url(url)
1549    }
1550}
1551
1552impl ServerAddr {
1553    /// Check if the URL is a valid NATS server address.
1554    pub fn from_url(url: Url) -> io::Result<Self> {
1555        if url.scheme() != "nats"
1556            && url.scheme() != "tls"
1557            && url.scheme() != "ws"
1558            && url.scheme() != "wss"
1559        //INFO: Diff starts here
1560            && url.scheme() != "ipc"
1561        //INFO: Diff ends here
1562        {
1563            return Err(std::io::Error::new(
1564                ErrorKind::InvalidInput,
1565                format!("invalid scheme for NATS server URL: {}", url.scheme()),
1566            ));
1567        }
1568
1569        Ok(Self(url))
1570    }
1571
1572    /// Turn the server address into a standard URL.
1573    pub fn into_inner(self) -> Url {
1574        self.0
1575    }
1576
1577    /// Returns if tls is required by the client for this server.
1578    pub fn tls_required(&self) -> bool {
1579        self.0.scheme() == "tls"
1580    }
1581
1582    /// Returns if the server url had embedded username and password.
1583    pub fn has_user_pass(&self) -> bool {
1584        self.0.username() != ""
1585    }
1586
1587    pub fn scheme(&self) -> &str {
1588        self.0.scheme()
1589    }
1590
1591    /// Returns the host.
1592    pub fn host(&self) -> &str {
1593        match self.0.host() {
1594            Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1595            // `host_str()` for Ipv6 includes the []s
1596            Some(Host::Ipv6 { .. }) => {
1597                let host = self.0.host_str().unwrap();
1598                &host[1..host.len() - 1]
1599            }
1600            None => "",
1601        }
1602    }
1603
1604    pub fn is_websocket(&self) -> bool {
1605        self.0.scheme() == "ws" || self.0.scheme() == "wss"
1606    }
1607
1608    /// Returns the port.
1609    pub fn port(&self) -> u16 {
1610        self.0.port().unwrap_or(4222)
1611    }
1612
1613    /// Returns the optional username in the url.
1614    pub fn username(&self) -> Option<&str> {
1615        let user = self.0.username();
1616        if user.is_empty() {
1617            None
1618        } else {
1619            Some(user)
1620        }
1621    }
1622
1623    /// Returns the optional password in the url.
1624    pub fn password(&self) -> Option<&str> {
1625        self.0.password()
1626    }
1627
1628    /// Return the sockets from resolving the server address.
1629    pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1630        tokio::net::lookup_host((self.host(), self.port())).await
1631    }
1632}
1633
1634/// Capability to convert into a list of NATS server addresses.
1635///
1636/// There are several implementations ensuring the easy passing of one or more server addresses to
1637/// functions like [`crate::connect()`].
1638pub trait ToServerAddrs {
1639    /// Returned iterator over socket addresses which this type may correspond
1640    /// to.
1641    type Iter: Iterator<Item = ServerAddr>;
1642
1643    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
1644}
1645
1646impl ToServerAddrs for ServerAddr {
1647    type Iter = option::IntoIter<ServerAddr>;
1648    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1649        Ok(Some(self.clone()).into_iter())
1650    }
1651}
1652
1653impl ToServerAddrs for str {
1654    type Iter = option::IntoIter<ServerAddr>;
1655    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1656        self.parse::<ServerAddr>()
1657            .map(|addr| Some(addr).into_iter())
1658    }
1659}
1660
1661impl ToServerAddrs for String {
1662    type Iter = option::IntoIter<ServerAddr>;
1663    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1664        (**self).to_server_addrs()
1665    }
1666}
1667
1668impl<T: AsRef<str>> ToServerAddrs for [T] {
1669    type Iter = std::vec::IntoIter<ServerAddr>;
1670    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1671        self.iter()
1672            .map(AsRef::as_ref)
1673            .map(str::parse)
1674            .collect::<io::Result<_>>()
1675            .map(Vec::into_iter)
1676    }
1677}
1678
1679impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1680    type Iter = std::vec::IntoIter<ServerAddr>;
1681    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1682        self.as_slice().to_server_addrs()
1683    }
1684}
1685
1686impl<'a> ToServerAddrs for &'a [ServerAddr] {
1687    type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1688
1689    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1690        Ok(self.iter().cloned())
1691    }
1692}
1693
1694impl ToServerAddrs for Vec<ServerAddr> {
1695    type Iter = std::vec::IntoIter<ServerAddr>;
1696
1697    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1698        Ok(self.clone().into_iter())
1699    }
1700}
1701
1702impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1703    type Iter = T::Iter;
1704    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1705        (**self).to_server_addrs()
1706    }
1707}
1708
/// Reports whether `subject` passes the basic syntax check.
///
/// A subject is rejected when it starts or ends with `.` or contains any ASCII whitespace.
/// Note: empty subjects and empty tokens (e.g. `a..b`) are not rejected by this check.
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let text = subject.as_ref();
    let dot_at_edge = text.starts_with('.') || text.ends_with('.');
    let has_whitespace = text.bytes().any(|byte| byte.is_ascii_whitespace());
    !(dot_at_edge || has_whitespace)
}
/// Implements `From<$origin> for $t`, translating timeouts into the target error kind.
///
/// An `$origin` whose `kind()` is `<$origin_kind>::TimedOut` becomes `$t::new(<$k>::TimedOut)`
/// (no source attached). Every other kind becomes `$t::with_source(<$k>::Other, err)`, which
/// keeps the original error as the source.
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                match err.kind() {
                    <$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
                    _ => Self::with_source(<$k>::Other, err),
                }
            }
        }
    };
}
// Makes the macro usable by path from other modules of this crate.
pub(crate) use from_with_timeout;
1728
1729use crate::connection::ShouldFlush;
1730
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `addrs` yields exactly the hosts in `expected`, in order.
    fn assert_hosts(addrs: impl Iterator<Item = ServerAddr>, expected: &[&str]) {
        let hosts: Vec<String> = addrs.map(|addr| addr.host().to_owned()).collect();
        assert_eq!(hosts, expected);
    }

    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    #[test]
    fn server_address_defaults_to_nats_scheme_and_port() {
        let address = ServerAddr::from_str("demo.nats.io").unwrap();
        assert_eq!(address.scheme(), "nats");
        assert_eq!(address.host(), "demo.nats.io");
        assert_eq!(address.port(), 4222);
        assert!(!address.tls_required());
        assert!(!address.has_user_pass());
    }

    #[test]
    fn server_address_explicit_port_and_credentials() {
        let address = ServerAddr::from_str("tls://user:secret@example.com:4333").unwrap();
        assert!(address.tls_required());
        assert!(address.has_user_pass());
        assert_eq!(address.username(), Some("user"));
        assert_eq!(address.password(), Some("secret"));
        assert_eq!(address.port(), 4333);
    }

    #[test]
    fn server_address_rejects_unknown_scheme() {
        let err = ServerAddr::from_str("http://example.com").unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::InvalidInput);
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        assert_hosts(vec.to_server_addrs().unwrap(), &["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        assert_hosts(arr.to_server_addrs().unwrap(), &["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_hosts(vec.to_server_addrs().unwrap(), &["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_hosts(arr.to_server_addrs().unwrap(), &["127.0.0.1", "::"]);
    }
}