1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_cfg))]
196
197use thiserror::Error;
198
199use futures_util::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::collections::VecDeque;
207use std::fmt::Display;
208use std::future::Future;
209use std::iter;
210use std::mem;
211use std::net::SocketAddr;
212use std::option;
213use std::pin::Pin;
214use std::slice;
215use std::str::{self, FromStr};
216use std::sync::atomic::AtomicUsize;
217use std::sync::atomic::Ordering;
218use std::sync::Arc;
219use std::task::{Context, Poll};
220use tokio::io::ErrorKind;
221use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
222use url::{Host, Url};
223
224use bytes::Bytes;
225use serde::{Deserialize, Serialize};
226use serde_repr::{Deserialize_repr, Serialize_repr};
227use tokio::io;
228use tokio::sync::mpsc;
229use tokio::task;
230
/// Boxed, thread-safe error type for places where no concrete error type is exposed.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Crate version taken from `CARGO_PKG_VERSION` at compile time.
/// Not referenced in this part of the file; presumably fills `ConnectInfo::version`.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Implementation language string; presumably fills `ConnectInfo::lang`.
const LANG: &str = "rust";
/// Once more than this many client `PING`s are outstanding (no `PONG` yet), the connection
/// handler treats the connection as dead and disconnects.
const MAX_PENDING_PINGS: usize = 2;
/// Subscription id reserved for the shared request/reply multiplexer subscription.
const MULTIPLEXER_SID: u64 = 0;
237
238pub use tokio_rustls::rustls;
242
243use connection::{Connection, State};
244use connector::{Connector, ConnectorOptions};
245pub use header::{HeaderMap, HeaderName, HeaderValue};
246pub use subject::Subject;
247
248mod auth;
249pub(crate) mod auth_utils;
250pub mod client;
251pub mod connection;
252mod connector;
253mod options;
254
255pub use auth::Auth;
256pub use client::{
257 Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
258};
259pub use options::{AuthError, ConnectOptions};
260
261#[cfg(feature = "crypto")]
262#[cfg_attr(docsrs, doc(cfg(feature = "crypto")))]
263mod crypto;
264pub mod error;
265pub mod header;
266mod id_generator;
267#[cfg(feature = "jetstream")]
268#[cfg_attr(docsrs, doc(cfg(feature = "jetstream")))]
269pub mod jetstream;
270pub mod message;
271#[cfg(feature = "service")]
272#[cfg_attr(docsrs, doc(cfg(feature = "service")))]
273pub mod service;
274pub mod status;
275pub mod subject;
276mod tls;
277
278pub use message::Message;
279pub use status::StatusCode;
280
/// Information the server sends in its `INFO` protocol message.
///
/// Every field is `#[serde(default)]`, so a field the server omits is left empty, zero,
/// `false` or `None`.
#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
pub struct ServerInfo {
    /// Unique id of the server (`server_id`).
    #[serde(default)]
    pub server_id: String,
    /// Name of the server (`server_name`).
    #[serde(default)]
    pub server_name: String,
    /// Host reported by the server (`host`).
    #[serde(default)]
    pub host: String,
    /// Port reported by the server (`port`).
    #[serde(default)]
    pub port: u16,
    /// Server version string (`version`).
    #[serde(default)]
    pub version: String,
    /// Whether the server requires authentication (`auth_required`).
    #[serde(default)]
    pub auth_required: bool,
    /// Whether the server requires TLS (`tls_required`).
    #[serde(default)]
    pub tls_required: bool,
    /// Maximum payload size the server accepts (`max_payload`).
    #[serde(default)]
    pub max_payload: usize,
    /// Protocol version supported by the server (`proto`).
    #[serde(default)]
    pub proto: i8,
    /// Id the server assigned to this client connection (`client_id`).
    #[serde(default)]
    pub client_id: u64,
    /// Go version the server was built with (`go`).
    #[serde(default)]
    pub go: String,
    /// Nonce the server asks the client to sign during authentication (`nonce`).
    #[serde(default)]
    pub nonce: String,
    /// Additional client URLs advertised by the server (`connect_urls`).
    #[serde(default)]
    pub connect_urls: Vec<String>,
    /// The client's IP as seen by the server (`client_ip`).
    #[serde(default)]
    pub client_ip: String,
    /// Whether the server supports message headers (`headers`).
    #[serde(default)]
    pub headers: bool,
    /// Whether the server is in lame duck mode (JSON field `ldm`). The connection handler
    /// emits `Event::LameDuckMode` when an `INFO` with this flag set arrives.
    #[serde(default, rename = "ldm")]
    pub lame_duck_mode: bool,
    /// Cluster name, if the server is part of one (`cluster`).
    #[serde(default)]
    pub cluster: Option<String>,
    /// Domain name, if configured (`domain`).
    #[serde(default)]
    pub domain: Option<String>,
    /// Whether JetStream is enabled on the server (`jetstream`).
    #[serde(default)]
    pub jetstream: bool,
}
344
/// A protocol operation received from the server, as consumed by
/// `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// `+OK` acknowledgement; ignored by the handler.
    Ok,
    /// `INFO` with the server's current information.
    Info(Box<ServerInfo>),
    /// `PING` from the server; answered with `PONG`.
    Ping,
    /// `PONG` answering one of the client's `PING`s.
    Pong,
    /// `-ERR` reported by the server.
    Error(ServerError),
    /// A message delivered to subscription `sid`.
    Message {
        /// Subscription id the message was delivered for.
        sid: u64,
        /// Subject the message was published to.
        subject: Subject,
        /// Reply subject, if any.
        reply: Option<Subject>,
        /// Message body.
        payload: Bytes,
        /// Headers, when the message carried any.
        headers: Option<HeaderMap>,
        /// Status code, when present.
        status: Option<StatusCode>,
        /// Status description, when present.
        description: Option<String>,
        /// Length as reported by the parser (copied into `Message::length`).
        length: usize,
    },
}
363
/// Former name of `message::OutboundMessage`, kept as an alias for backward compatibility.
#[deprecated(
    since = "0.44.0",
    note = "use `async_nats::message::OutboundMessage` instead"
)]
pub type PublishMessage = crate::message::OutboundMessage;
372
/// Requests sent from the client side to the connection handler task over its command
/// channel; each is applied by `ConnectionHandler::handle_command`.
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a message.
    Publish(OutboundMessage),
    /// Publish a request whose reply is routed through the shared multiplexer subscription.
    Request {
        /// Subject to publish to.
        subject: Subject,
        /// Request body.
        payload: Bytes,
        /// Reply subject. Its last `.`-separated token identifies the request.
        respond: Subject,
        /// Optional headers.
        headers: Option<HeaderMap>,
        /// Receives the reply message.
        sender: oneshot::Sender<Message>,
    },
    /// Register a new subscription.
    Subscribe {
        /// Subscription id.
        sid: u64,
        /// Subject to subscribe to.
        subject: Subject,
        /// Optional queue group.
        queue_group: Option<String>,
        /// Channel messages for this subscription are delivered to.
        sender: mpsc::Sender<Message>,
    },
    /// Unsubscribe now (`max: None`) or after `max` delivered messages in total.
    Unsubscribe {
        /// Subscription id.
        sid: u64,
        /// Auto-unsubscribe limit.
        max: Option<u64>,
    },
    /// Ask to be notified (via `observer`) after the next successful flush.
    Flush {
        /// Notified with `()` once the connection has been flushed.
        observer: oneshot::Sender<()>,
    },
    /// Drain one subscription (`Some(sid)`) or the whole client (`None`).
    Drain {
        /// Subscription to drain, or `None` for all.
        sid: Option<u64>,
    },
    /// Drop the current connection and reconnect.
    Reconnect,
}
402
/// Protocol operations the client writes to the server (queued via `enqueue_write_op`).
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// `PUB`/`HPUB`.
    Publish {
        /// Subject to publish to.
        subject: Subject,
        /// Message body.
        payload: Bytes,
        /// Optional reply subject.
        respond: Option<Subject>,
        /// Optional headers.
        headers: Option<HeaderMap>,
    },
    /// `SUB`.
    Subscribe {
        /// Subscription id.
        sid: u64,
        /// Subject to subscribe to.
        subject: Subject,
        /// Optional queue group.
        queue_group: Option<String>,
    },
    /// `UNSUB`, optionally after `max` messages.
    Unsubscribe {
        /// Subscription id.
        sid: u64,
        /// Auto-unsubscribe limit.
        max: Option<u64>,
    },
    /// `PING`.
    Ping,
    /// `PONG`.
    Pong,
    /// `CONNECT` with the client's connection options.
    Connect(ConnectInfo),
}
425
/// Connection-handler state for one active subscription.
#[derive(Debug)]
struct Subscription {
    /// Subject subscribed to; re-sent on reconnect.
    subject: Subject,
    /// Delivers messages to the user's `Subscriber`.
    sender: mpsc::Sender<Message>,
    /// Queue group, if any; re-sent on reconnect.
    queue_group: Option<String>,
    /// Number of messages successfully handed to `sender` so far.
    delivered: u64,
    /// Auto-unsubscribe limit set by `Command::Unsubscribe`, if any.
    max: Option<u64>,
}
434
/// A single wildcard subscription shared by all requests. Replies are matched to callers by
/// the token that follows `prefix` in the reply subject.
#[derive(Debug)]
struct Multiplexer {
    /// Wildcard subject (`<prefix>*`) subscribed under `MULTIPLEXER_SID`.
    subject: Subject,
    /// Reply-subject prefix, ending in `.`.
    prefix: Subject,
    /// Pending requests keyed by their token.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
441
/// Owns the server connection and all per-connection client state. It is driven by
/// `process`, which runs in its own task.
pub(crate) struct ConnectionHandler {
    /// Current connection; replaced on every successful reconnect.
    connection: Connection,
    /// Establishes connections and carries the event, state and statistics channels.
    connector: Connector,
    /// Active subscriptions keyed by subscription id.
    subscriptions: HashMap<u64, Subscription>,
    /// Shared request/reply subscription, created lazily by the first request.
    multiplexer: Option<Multiplexer>,
    /// `PING`s sent that have not yet been answered with `PONG`.
    pending_pings: usize,
    /// Publishes the latest `ServerInfo` (updated on reconnect).
    info_sender: tokio::sync::watch::Sender<ServerInfo>,
    /// Keep-alive timer; reset whenever a server op or command is handled.
    ping_interval: Interval,
    /// Set by `Command::Reconnect`; consumed by `process`.
    should_reconnect: bool,
    /// Waiters notified after the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    /// Set when the whole client is draining; makes `process` finish with `Closed`.
    is_draining: bool,
    /// Subscription ids whose drain `UNSUB` was queued; removed on a later poll.
    drain_pings: VecDeque<u64>,
}
456
457impl ConnectionHandler {
458 pub(crate) fn new(
459 connection: Connection,
460 connector: Connector,
461 info_sender: tokio::sync::watch::Sender<ServerInfo>,
462 ping_period: Duration,
463 ) -> ConnectionHandler {
464 let mut ping_interval = interval(ping_period);
465 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
466
467 ConnectionHandler {
468 connection,
469 connector,
470 subscriptions: HashMap::new(),
471 multiplexer: None,
472 pending_pings: 0,
473 info_sender,
474 ping_interval,
475 should_reconnect: false,
476 flush_observers: Vec::new(),
477 is_draining: false,
478 drain_pings: VecDeque::new(),
479 }
480 }
481
    /// Drives the connection until the client is closed or reconnecting fails.
    ///
    /// Each pass of the outer loop awaits a hand-written `ProcessFut` future. On every
    /// wake-up that future does the following:
    /// - sends keep-alive pings;
    /// - reads and applies server operations;
    /// - finishes pending subscription drains;
    /// - moves queued `Command`s from `receiver` into the connection's write buffer;
    /// - writes and flushes that buffer;
    /// - checks for a requested reconnect.
    ///
    /// The future's `ExitReason` decides what happens next:
    /// - disconnects and reconnect requests go through `handle_disconnect`, and the loop ends
    ///   if reconnecting fails;
    /// - `Closed` emits `Event::Closed` and ends the loop.
    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
        // Borrows everything one connection lifetime needs, so a single `poll` can service
        // the read side, the command channel and the write side together.
        struct ProcessFut<'a> {
            handler: &'a mut ConnectionHandler,
            receiver: &'a mut mpsc::Receiver<Command>,
            // Scratch buffer for `poll_recv_many`. It is drained after every batch, so it is
            // empty whenever a new batch is requested.
            recv_buf: &'a mut Vec<Command>,
        }

        // Why a `ProcessFut` completed.
        enum ExitReason {
            // A read error (`Some`), end of stream (`None`), or too many unanswered pings
            // (`None`).
            Disconnected(Option<io::Error>),
            // A `Command::Reconnect` was handled.
            ReconnectRequested,
            // The command channel closed, or a whole-client drain completed.
            Closed,
        }

        impl ProcessFut<'_> {
            // Upper bound on commands taken from the channel per `poll_recv_many` call.
            const RECV_CHUNK_SIZE: usize = 16;

            // Runs once per elapsed ping-interval tick. It records another outstanding ping
            // and queues a `PING`, or asks for a disconnect once more than
            // `MAX_PENDING_PINGS` are unanswered. Here `Poll::Pending` means "keep going".
            #[cold]
            fn ping(&mut self) -> Poll<ExitReason> {
                self.handler.pending_pings += 1;

                if self.handler.pending_pings > MAX_PENDING_PINGS {
                    debug!(
                        pending_pings = self.handler.pending_pings,
                        max_pings = MAX_PENDING_PINGS,
                        "disconnecting due to too many pending pings"
                    );

                    Poll::Ready(ExitReason::Disconnected(None))
                } else {
                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);

                    Poll::Pending
                }
            }
        }

        impl Future for ProcessFut<'_> {
            type Output = ExitReason;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // Keep-alive: handle every elapsed tick. When `poll_tick` returns `Pending`,
                // the task is registered to be woken at the next tick.
                while self.handler.ping_interval.poll_tick(cx).is_ready() {
                    if let Poll::Ready(exit) = self.ping() {
                        return Poll::Ready(exit);
                    }
                }

                // Apply every server operation that is already available.
                loop {
                    match self.handler.connection.poll_read_op(cx) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(Some(server_op))) => {
                            self.handler.handle_server_op(server_op);
                        }
                        Poll::Ready(Ok(None)) => {
                            return Poll::Ready(ExitReason::Disconnected(None))
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Subscriptions queued for draining had their UNSUB queued by an earlier
                // command, and messages already read for them were delivered by the loop
                // above, so they can be forgotten now.
                while let Some(sid) = self.handler.drain_pings.pop_front() {
                    self.handler.subscriptions.remove(&sid);
                }

                // A whole-client drain completes the same way. Stop before taking new
                // commands.
                if self.handler.is_draining {
                    return Poll::Ready(ExitReason::Closed);
                }

                // Alternate between filling the write buffer from the command channel and
                // writing it out, until a round brings in no new commands. The flag starts as
                // `true` so the first write is attempted even when no commands arrive.
                let mut made_progress = true;
                loop {
                    // Pull commands in batches while the write buffer has room.
                    while !self.handler.connection.is_write_buf_full() {
                        debug_assert!(self.recv_buf.is_empty());

                        // Split the borrow of `self` so the receiver, the buffer and the
                        // handler can be used at the same time.
                        let Self {
                            recv_buf,
                            handler,
                            receiver,
                        } = &mut *self;
                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
                            Poll::Pending => break,
                            Poll::Ready(1..) => {
                                made_progress = true;

                                for cmd in recv_buf.drain(..) {
                                    handler.handle_command(cmd);
                                }
                            }
                            // Zero commands despite a non-zero limit means every sender is
                            // gone: the client was dropped.
                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
                        }
                    }

                    // Read and reset the flag; stop once a round produced nothing new.
                    if !mem::take(&mut made_progress) {
                        break;
                    }

                    match self.handler.connection.poll_write(cx) {
                        Poll::Pending => {
                            break;
                        }
                        Poll::Ready(Ok(())) => {
                            // Presumably the buffer was written out; go back for more
                            // commands.
                            continue;
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Flush when the connection asks for it, or when a `Command::Flush` caller is
                // waiting. Waiters are notified only after a successful flush.
                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
                    self.handler.connection.should_flush(),
                    self.handler.flush_observers.is_empty(),
                ) {
                    match self.handler.connection.poll_flush(cx) {
                        Poll::Pending => {}
                        Poll::Ready(Ok(())) => {
                            for observer in self.handler.flush_observers.drain(..) {
                                let _ = observer.send(());
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // A reconnect request is acted on only after the writes above were attempted.
                if mem::take(&mut self.handler.should_reconnect) {
                    return Poll::Ready(ExitReason::ReconnectRequested);
                }

                Poll::Pending
            }
        }

        // Reused across reconnects so each connection starts with an allocated buffer.
        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
        loop {
            let process = ProcessFut {
                handler: self,
                receiver,
                recv_buf: &mut recv_buf,
            };
            match process.await {
                ExitReason::Disconnected(err) => {
                    debug!(error = ?err, "disconnected");
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                    debug!("reconnected");
                }
                ExitReason::Closed => {
                    // Best effort: a full event channel must not block shutdown.
                    self.connector.events_tx.try_send(Event::Closed).ok();
                    break;
                }
                ExitReason::ReconnectRequested => {
                    debug!("reconnect requested");
                    // Close the current stream before dialing again.
                    self.connection.stream.shutdown().await.ok();
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                }
            }
        }
    }
689
690 fn handle_server_op(&mut self, server_op: ServerOp) {
691 self.ping_interval.reset();
692
693 match server_op {
694 ServerOp::Ping => {
695 debug!("received PING");
696 self.connection.enqueue_write_op(&ClientOp::Pong);
697 }
698 ServerOp::Pong => {
699 debug!("received PONG");
700 self.pending_pings = self.pending_pings.saturating_sub(1);
701 }
702 ServerOp::Error(error) => {
703 debug!("received ERROR: {:?}", error);
704 self.connector
705 .events_tx
706 .try_send(Event::ServerError(error))
707 .ok();
708 }
709 ServerOp::Message {
710 sid,
711 subject,
712 reply,
713 payload,
714 headers,
715 status,
716 description,
717 length,
718 } => {
719 debug!("received MESSAGE: sid={}, subject={}", sid, subject);
720 self.connector
721 .connect_stats
722 .in_messages
723 .add(1, Ordering::Relaxed);
724
725 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
726 let message: Message = Message {
727 subject,
728 reply,
729 payload,
730 headers,
731 status,
732 description,
733 length,
734 };
735
736 match subscription.sender.try_send(message) {
739 Ok(_) => {
740 subscription.delivered += 1;
741 if let Some(max) = subscription.max {
745 if subscription.delivered.ge(&max) {
746 debug!("max messages reached for subscription {}", sid);
747 self.subscriptions.remove(&sid);
748 }
749 }
750 }
751 Err(mpsc::error::TrySendError::Full(_)) => {
752 debug!("slow consumer detected for subscription {}", sid);
753 self.connector
754 .events_tx
755 .try_send(Event::SlowConsumer(sid))
756 .ok();
757 }
758 Err(mpsc::error::TrySendError::Closed(_)) => {
759 debug!("subscription {} channel closed", sid);
760 self.subscriptions.remove(&sid);
761 self.connection
762 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
763 }
764 }
765 } else if sid == MULTIPLEXER_SID {
766 debug!("received message for multiplexer");
767 if let Some(multiplexer) = self.multiplexer.as_mut() {
768 let maybe_token =
769 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
770
771 if let Some(token) = maybe_token {
772 if let Some(sender) = multiplexer.senders.remove(token) {
773 debug!("forwarding message to request with token {}", token);
774 let message = Message {
775 subject,
776 reply,
777 payload,
778 headers,
779 status,
780 description,
781 length,
782 };
783
784 let _ = sender.send(message);
785 }
786 }
787 }
788 }
789 }
790 ServerOp::Info(info) => {
792 debug!("received INFO: server_id={}", info.server_id);
793 if info.lame_duck_mode {
794 debug!("server in lame duck mode");
795 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
796 }
797 }
798
799 _ => {
800 }
802 }
803 }
804
805 fn handle_command(&mut self, command: Command) {
806 self.ping_interval.reset();
807
808 match command {
809 Command::Unsubscribe { sid, max } => {
810 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
811 subscription.max = max;
812 match subscription.max {
813 Some(n) => {
814 if subscription.delivered >= n {
815 self.subscriptions.remove(&sid);
816 }
817 }
818 None => {
819 self.subscriptions.remove(&sid);
820 }
821 }
822
823 self.connection
824 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
825 }
826 }
827 Command::Flush { observer } => {
828 self.flush_observers.push(observer);
829 }
830 Command::Drain { sid } => {
831 let mut drain_sub = |sid: u64| {
832 self.drain_pings.push_back(sid);
833 self.connection
834 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
835 };
836
837 if let Some(sid) = sid {
838 if self.subscriptions.get_mut(&sid).is_some() {
839 drain_sub(sid);
840 }
841 } else {
842 self.connector.events_tx.try_send(Event::Draining).ok();
844 self.is_draining = true;
845 for (&sid, _) in self.subscriptions.iter_mut() {
846 drain_sub(sid);
847 }
848 }
849 self.connection.enqueue_write_op(&ClientOp::Ping);
850 }
851 Command::Subscribe {
852 sid,
853 subject,
854 queue_group,
855 sender,
856 } => {
857 let subscription = Subscription {
858 sender,
859 delivered: 0,
860 max: None,
861 subject: subject.to_owned(),
862 queue_group: queue_group.to_owned(),
863 };
864
865 self.subscriptions.insert(sid, subscription);
866
867 self.connection.enqueue_write_op(&ClientOp::Subscribe {
868 sid,
869 subject,
870 queue_group,
871 });
872 }
873 Command::Request {
874 subject,
875 payload,
876 respond,
877 headers,
878 sender,
879 } => {
880 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
881
882 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
883 multiplexer
884 } else {
885 let prefix = Subject::from(format!("{}.{}.", prefix, id_generator::next()));
886 let subject = Subject::from(format!("{prefix}*"));
887
888 self.connection.enqueue_write_op(&ClientOp::Subscribe {
889 sid: MULTIPLEXER_SID,
890 subject: subject.clone(),
891 queue_group: None,
892 });
893
894 self.multiplexer.insert(Multiplexer {
895 subject,
896 prefix,
897 senders: HashMap::new(),
898 })
899 };
900 self.connector
901 .connect_stats
902 .out_messages
903 .add(1, Ordering::Relaxed);
904
905 multiplexer.senders.insert(token.to_owned(), sender);
906
907 let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
908
909 let pub_op = ClientOp::Publish {
910 subject,
911 payload,
912 respond: Some(respond),
913 headers,
914 };
915
916 self.connection.enqueue_write_op(&pub_op);
917 }
918
919 Command::Publish(OutboundMessage {
920 subject,
921 payload,
922 reply: respond,
923 headers,
924 }) => {
925 self.connector
926 .connect_stats
927 .out_messages
928 .add(1, Ordering::Relaxed);
929
930 let header_len = headers
931 .as_ref()
932 .map(|headers| headers.len())
933 .unwrap_or_default();
934
935 self.connector.connect_stats.out_bytes.add(
936 (payload.len()
937 + respond.as_ref().map_or_else(|| 0, |r| r.len())
938 + subject.len()
939 + header_len) as u64,
940 Ordering::Relaxed,
941 );
942
943 self.connection.enqueue_write_op(&ClientOp::Publish {
944 subject,
945 payload,
946 respond,
947 headers,
948 });
949 }
950
951 Command::Reconnect => {
952 self.should_reconnect = true;
953 }
954 }
955 }
956
957 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
958 self.pending_pings = 0;
959 self.connector.events_tx.try_send(Event::Disconnected).ok();
960 self.connector.state_tx.send(State::Disconnected).ok();
961
962 self.handle_reconnect().await
963 }
964
965 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
966 let (info, connection) = self.connector.connect().await?;
967 self.connection = connection;
968 let _ = self.info_sender.send(info);
969
970 self.subscriptions
971 .retain(|_, subscription| !subscription.sender.is_closed());
972
973 for (sid, subscription) in &self.subscriptions {
974 self.connection.enqueue_write_op(&ClientOp::Subscribe {
975 sid: *sid,
976 subject: subscription.subject.to_owned(),
977 queue_group: subscription.queue_group.to_owned(),
978 });
979 }
980
981 if let Some(multiplexer) = &self.multiplexer {
982 self.connection.enqueue_write_op(&ClientOp::Subscribe {
983 sid: MULTIPLEXER_SID,
984 subject: multiplexer.subject.to_owned(),
985 queue_group: None,
986 });
987 }
988 Ok(())
989 }
990}
991
/// Connects to the NATS server(s) at `addrs` using `options` and returns a `Client`.
///
/// Two background tasks are spawned:
/// - one logs every connection `Event` and passes it to the optional event callback;
/// - one runs the connection handler that serves the returned client.
///
/// Initial connection behavior depends on `retry_on_initial_connect`:
/// - Disabled: the first connection attempt completes before this function returns, and its
///   failure is returned.
/// - Enabled: the client is returned immediately and connecting happens in the background.
///   If that background attempt fails, the error is logged and the handler task exits.
///
/// # Errors
///
/// - `ConnectErrorKind::ServerParse` if the connector cannot be built from `addrs`.
/// - Any error from the initial connection attempt when `retry_on_initial_connect` is
///   disabled.
pub async fn connect_with_options<A: ToServerAddrs>(
    addrs: A,
    options: ConnectOptions,
) -> Result<Client, ConnectError> {
    let ping_period = options.ping_interval;

    let (events_tx, mut events_rx) = mpsc::channel(128);
    let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
    // Starts at 1 MiB. Shared with the connector, which presumably updates it from the
    // server INFO (TODO confirm).
    let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
    let statistics = Arc::new(Statistics::default());

    let mut connector = Connector::new(
        addrs,
        ConnectorOptions {
            tls_required: options.tls_required,
            certificates: options.certificates,
            client_key: options.client_key,
            client_cert: options.client_cert,
            tls_client_config: options.tls_client_config,
            tls_first: options.tls_first,
            auth: options.auth,
            no_echo: options.no_echo,
            connection_timeout: options.connection_timeout,
            name: options.name,
            ignore_discovered_servers: options.ignore_discovered_servers,
            retain_servers_order: options.retain_servers_order,
            read_buffer_capacity: options.read_buffer_capacity,
            reconnect_delay_callback: options.reconnect_delay_callback,
            auth_callback: options.auth_callback,
            max_reconnects: options.max_reconnects,
        },
        events_tx,
        state_tx,
        max_payload.clone(),
        statistics.clone(),
    )
    .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;

    // With retries disabled, connect eagerly so failures reach the caller. Otherwise the
    // handler task connects and `info` stays at its default until then.
    let mut info: ServerInfo = Default::default();
    let mut connection = None;
    if !options.retry_on_initial_connect {
        debug!("retry on initial connect failure is disabled");
        let (info_ok, connection_ok) = connector.try_connect().await?;
        connection = Some(connection_ok);
        info = info_ok;
    }

    let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
    let (sender, mut receiver) = mpsc::channel(options.sender_capacity);

    let client = Client::new(
        info_watcher,
        state_rx,
        sender,
        options.subscription_capacity,
        options.inbox_prefix,
        options.request_timeout,
        max_payload,
        statistics,
    );

    // Event pump: runs until every event sender is dropped.
    task::spawn(async move {
        while let Some(event) = events_rx.recv().await {
            tracing::info!("event: {}", event);
            if let Some(event_callback) = &options.event_callback {
                event_callback.call(event).await;
            }
        }
    });

    // Connection handler task: performs the deferred initial connect when retries are
    // enabled, then services client commands until closed.
    task::spawn(async move {
        if connection.is_none() && options.retry_on_initial_connect {
            let (info, connection_ok) = match connector.connect().await {
                Ok((info, connection)) => (info, connection),
                Err(err) => {
                    error!("connection closed: {}", err);
                    return;
                }
            };
            info_sender.send(info).ok();
            connection = Some(connection_ok);
        }
        // Set on both paths: eagerly above when retries are disabled, here otherwise.
        let connection = connection.unwrap();
        let mut connection_handler =
            ConnectionHandler::new(connection, connector, info_sender, ping_period);
        connection_handler.process(&mut receiver).await
    });

    Ok(client)
}
1098
/// Connection lifecycle and error notifications, logged and passed to the optional event
/// callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// The client connected. Not emitted in this file; presumably sent by the connector.
    Connected,
    /// The connection was lost; a reconnect attempt follows.
    Disconnected,
    /// The server announced lame duck mode in an `INFO` message.
    LameDuckMode,
    /// A drain of the whole client started.
    Draining,
    /// The connection handler stopped; the client is closed.
    Closed,
    /// A message for the subscription with this id was dropped because its channel was full.
    SlowConsumer(u64),
    /// The server reported an error (`-ERR`).
    ServerError(ServerError),
    /// A client-side error. Not emitted in this file; presumably sent by the connector.
    ClientError(ClientError),
}
1110
1111impl fmt::Display for Event {
1112 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1113 match self {
1114 Event::Connected => write!(f, "connected"),
1115 Event::Disconnected => write!(f, "disconnected"),
1116 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1117 Event::Draining => write!(f, "draining"),
1118 Event::Closed => write!(f, "closed"),
1119 Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1120 Event::ServerError(err) => write!(f, "server error: {err}"),
1121 Event::ClientError(err) => write!(f, "client error: {err}"),
1122 }
1123 }
1124}
1125
1126pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1193 connect_with_options(addrs, ConnectOptions::default()).await
1194}
1195
/// Category of a `ConnectError`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// DNS resolution failed.
    Dns,
    /// Signing the server nonce failed.
    Authentication,
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// The operation timed out.
    TimedOut,
    /// TLS setup or handshake failed.
    Tls,
    /// An I/O error occurred.
    Io,
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}

impl Display for ConnectErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
1230
1231pub type ConnectError = error::Error<ConnectErrorKind>;
1234
1235impl From<io::Error> for ConnectError {
1236 fn from(err: io::Error) -> Self {
1237 ConnectError::with_source(ConnectErrorKind::Io, err)
1238 }
1239}
1240
/// Receiving side of a subscription.
///
/// Messages are consumed through the `Stream` implementation. Dropping the subscriber closes
/// it and asks the connection handler to unsubscribe.
#[derive(Debug)]
pub struct Subscriber {
    // Subscription id used for SUB/UNSUB.
    sid: u64,
    // Messages routed to this subscription by the connection handler.
    receiver: mpsc::Receiver<Message>,
    // Command channel to the connection handler.
    sender: mpsc::Sender<Command>,
}
1260
1261impl Subscriber {
1262 fn new(
1263 sid: u64,
1264 sender: mpsc::Sender<Command>,
1265 receiver: mpsc::Receiver<Message>,
1266 ) -> Subscriber {
1267 Subscriber {
1268 sid,
1269 sender,
1270 receiver,
1271 }
1272 }
1273
1274 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1289 self.sender
1290 .send(Command::Unsubscribe {
1291 sid: self.sid,
1292 max: None,
1293 })
1294 .await?;
1295 self.receiver.close();
1296 Ok(())
1297 }
1298
1299 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1325 self.sender
1326 .send(Command::Unsubscribe {
1327 sid: self.sid,
1328 max: Some(unsub_after),
1329 })
1330 .await?;
1331 Ok(())
1332 }
1333
1334 pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1367 self.sender
1368 .send(Command::Drain {
1369 sid: Some(self.sid),
1370 })
1371 .await?;
1372
1373 Ok(())
1374 }
1375}
1376
1377#[derive(Error, Debug, PartialEq)]
1378#[error("failed to send unsubscribe")]
1379pub struct UnsubscribeError(String);
1380
1381impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1382 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1383 UnsubscribeError(err.to_string())
1384 }
1385}
1386
1387impl Drop for Subscriber {
1388 fn drop(&mut self) {
1389 self.receiver.close();
1390 tokio::spawn({
1391 let sender = self.sender.clone();
1392 let sid = self.sid;
1393 async move {
1394 sender
1395 .send(Command::Unsubscribe { sid, max: None })
1396 .await
1397 .ok();
1398 }
1399 });
1400 }
1401}
1402
1403impl Stream for Subscriber {
1404 type Item = Message;
1405
1406 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1407 self.receiver.poll_recv(cx)
1408 }
1409}
1410
1411#[derive(Clone, Debug, Eq, PartialEq)]
1412pub enum CallbackError {
1413 Client(ClientError),
1414 Server(ServerError),
1415}
1416impl std::fmt::Display for CallbackError {
1417 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1418 match self {
1419 Self::Client(error) => write!(f, "{error}"),
1420 Self::Server(error) => write!(f, "{error}"),
1421 }
1422 }
1423}
1424
1425impl From<ServerError> for CallbackError {
1426 fn from(server_error: ServerError) -> Self {
1427 CallbackError::Server(server_error)
1428 }
1429}
1430
1431impl From<ClientError> for CallbackError {
1432 fn from(client_error: ClientError) -> Self {
1433 CallbackError::Client(client_error)
1434 }
1435}
1436
/// Error reported by the server, or a server-related condition.
// `Display` is implemented by hand below; the `Error` derive supplies only
// `std::error::Error`.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The server rejected an operation as an authorization violation.
    AuthorizationViolation,
    /// The subscription with this id is a slow consumer. Not constructed in this file.
    SlowConsumer(u64),
    /// Any other server error, with its original text.
    Other(String),
}
1443
/// Error originating on the client side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Any other client error, with its message.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Every client error carries the same "nats: " prefix.
        f.write_str("nats: ")?;
        match self {
            Self::Other(error) => f.write_str(error),
            Self::MaxReconnects => f.write_str("max reconnects reached"),
        }
    }
}
1457
1458impl ServerError {
1459 fn new(error: String) -> ServerError {
1460 match error.to_lowercase().as_str() {
1461 "authorization violation" => ServerError::AuthorizationViolation,
1462 _ => ServerError::Other(error),
1464 }
1465 }
1466}
1467
1468impl std::fmt::Display for ServerError {
1469 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1470 match self {
1471 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1472 Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1473 Self::Other(error) => write!(f, "nats: {error}"),
1474 }
1475 }
1476}
1477
/// Options sent to the server in the `CONNECT` message (`ClientOp::Connect`), serialized as
/// JSON.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// `verbose` flag: per the NATS protocol, asks the server to acknowledge operations with
    /// `+OK`.
    pub verbose: bool,

    /// `pedantic` flag: per the NATS protocol, asks the server for stricter checks.
    pub pedantic: bool,

    /// User JWT, serialized as `jwt`.
    pub user_jwt: Option<String>,

    /// Public NKey, serialized as `nkey`.
    pub nkey: Option<String>,

    /// Signature of the server nonce, serialized as `sig`.
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// `echo` flag: whether the server may deliver this client's own publications back to
    /// its subscriptions.
    pub echo: bool,

    /// Client implementation language.
    pub lang: String,

    /// Client version.
    pub version: String,

    /// Protocol version the client supports, serialized as a number.
    pub protocol: Protocol,

    /// Whether the client requires TLS.
    pub tls_required: bool,

    /// Username, for user/password authentication.
    pub user: Option<String>,

    /// Password, for user/password authentication.
    pub pass: Option<String>,

    /// Token, for token authentication.
    pub auth_token: Option<String>,

    /// Whether the client supports message headers.
    pub headers: bool,

    /// Whether the client wants no-responders status messages for requests.
    pub no_responders: bool,
}
1538
/// Client protocol version advertised in `ConnectInfo::protocol`.
/// Serialized as its numeric value (`#[repr(u8)]` with `serde_repr`).
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Original protocol (`0`).
    Original = 0,
    /// Dynamic protocol (`1`). Per the NATS client protocol, this lets the server send
    /// cluster topology updates in asynchronous `INFO` messages.
    Dynamic = 1,
}
1548
/// Address of a NATS server: a URL with scheme `nats`, `tls`, `ws` or `wss`.
/// Construct it with `FromStr` or `ServerAddr::from_url`, which enforce the scheme.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1552
1553impl FromStr for ServerAddr {
1554 type Err = io::Error;
1555
1556 fn from_str(input: &str) -> Result<Self, Self::Err> {
1560 let url: Url = if input.contains("://") {
1561 input.parse()
1562 } else {
1563 format!("nats://{input}").parse()
1564 }
1565 .map_err(|e| {
1566 io::Error::new(
1567 ErrorKind::InvalidInput,
1568 format!("NATS server URL is invalid: {e}"),
1569 )
1570 })?;
1571
1572 Self::from_url(url)
1573 }
1574}
1575
1576impl ServerAddr {
1577 pub fn from_url(url: Url) -> io::Result<Self> {
1579 if url.scheme() != "nats"
1580 && url.scheme() != "tls"
1581 && url.scheme() != "ws"
1582 && url.scheme() != "wss"
1583 {
1584 return Err(std::io::Error::new(
1585 ErrorKind::InvalidInput,
1586 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1587 ));
1588 }
1589
1590 Ok(Self(url))
1591 }
1592
1593 pub fn into_inner(self) -> Url {
1595 self.0
1596 }
1597
1598 pub fn tls_required(&self) -> bool {
1600 self.0.scheme() == "tls"
1601 }
1602
1603 pub fn has_user_pass(&self) -> bool {
1605 self.0.username() != ""
1606 }
1607
1608 pub fn scheme(&self) -> &str {
1609 self.0.scheme()
1610 }
1611
1612 pub fn host(&self) -> &str {
1614 match self.0.host() {
1615 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1616 Some(Host::Ipv6 { .. }) => {
1618 let host = self.0.host_str().unwrap();
1619 &host[1..host.len() - 1]
1620 }
1621 None => "",
1622 }
1623 }
1624
1625 pub fn is_websocket(&self) -> bool {
1626 self.0.scheme() == "ws" || self.0.scheme() == "wss"
1627 }
1628
1629 pub fn port(&self) -> u16 {
1632 self.0.port_or_known_default().unwrap_or(4222)
1633 }
1634
1635 pub fn as_url_str(&self) -> &str {
1637 self.0.as_str()
1638 }
1639
1640 pub fn username(&self) -> Option<&str> {
1642 let user = self.0.username();
1643 if user.is_empty() {
1644 None
1645 } else {
1646 Some(user)
1647 }
1648 }
1649
1650 pub fn password(&self) -> Option<&str> {
1652 self.0.password()
1653 }
1654
1655 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1657 tokio::net::lookup_host((self.host(), self.port())).await
1658 }
1659}
1660
1661pub trait ToServerAddrs {
1666 type Iter: Iterator<Item = ServerAddr>;
1669
1670 fn to_server_addrs(&self) -> io::Result<Self::Iter>;
1671}
1672
1673impl ToServerAddrs for ServerAddr {
1674 type Iter = option::IntoIter<ServerAddr>;
1675 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1676 Ok(Some(self.clone()).into_iter())
1677 }
1678}
1679
1680impl ToServerAddrs for str {
1681 type Iter = option::IntoIter<ServerAddr>;
1682 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1683 self.parse::<ServerAddr>()
1684 .map(|addr| Some(addr).into_iter())
1685 }
1686}
1687
1688impl ToServerAddrs for String {
1689 type Iter = option::IntoIter<ServerAddr>;
1690 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1691 (**self).to_server_addrs()
1692 }
1693}
1694
1695impl<T: AsRef<str>> ToServerAddrs for [T] {
1696 type Iter = std::vec::IntoIter<ServerAddr>;
1697 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1698 self.iter()
1699 .map(AsRef::as_ref)
1700 .map(str::parse)
1701 .collect::<io::Result<_>>()
1702 .map(Vec::into_iter)
1703 }
1704}
1705
1706impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1707 type Iter = std::vec::IntoIter<ServerAddr>;
1708 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1709 self.as_slice().to_server_addrs()
1710 }
1711}
1712
1713impl<'a> ToServerAddrs for &'a [ServerAddr] {
1714 type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1715
1716 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1717 Ok(self.iter().cloned())
1718 }
1719}
1720
1721impl ToServerAddrs for Vec<ServerAddr> {
1722 type Iter = std::vec::IntoIter<ServerAddr>;
1723
1724 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1725 Ok(self.clone().into_iter())
1726 }
1727}
1728
1729impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1730 type Iter = T::Iter;
1731 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1732 (**self).to_server_addrs()
1733 }
1734}
1735
/// Reports whether `subject` is a well-formed NATS subject.
///
/// A subject is valid when all of the following hold:
/// - it is non-empty;
/// - it contains no ASCII whitespace;
/// - every `.`-separated token is non-empty, so it cannot start or end with `.` or
///   contain `..`.
///
/// Wildcard tokens (`*`, `>`) are not examined.
#[allow(dead_code)]
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let subject = subject.as_ref();
    !subject.is_empty()
        && !subject.bytes().any(|byte| byte.is_ascii_whitespace())
        && subject.split('.').all(|token| !token.is_empty())
}
// Generates `impl From<$origin> for $t`:
// - an `$origin` whose `kind()` is `$origin_kind::TimedOut` becomes `$t::new($k::TimedOut)`,
//   with no source;
// - any other kind becomes `$t::with_source($k::Other, err)`, keeping the original error.
// The expansion requires `$origin` to have `kind()`, `$t` to have `new`/`with_source`, and
// both kind types to have the variants named above.
#[allow(unused_macros)]
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                match err.kind() {
                    <$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
                    _ => Self::with_source(<$k>::Other, err),
                }
            }
        }
    };
}
// Re-exported so other modules of the crate can invoke the macro by path.
#[allow(unused_imports)]
pub(crate) use from_with_timeout;
1758
1759use crate::connection::ShouldFlush;
1760use crate::message::OutboundMessage;
1761
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the hosts of `addrs`, so a whole address list is compared in one assertion
    /// (which also checks that nothing extra is yielded).
    fn hosts(addrs: impl Iterator<Item = ServerAddr>) -> Vec<String> {
        addrs.map(|addr| addr.host().to_owned()).collect()
    }

    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    #[test]
    fn server_address_defaults_to_nats_scheme() {
        let address = ServerAddr::from_str("127.0.0.1:4333").unwrap();
        assert_eq!(address.scheme(), "nats");
        assert_eq!(address.host(), "127.0.0.1");
        assert_eq!(address.port(), 4333);
    }

    #[test]
    fn server_address_rejects_unknown_scheme() {
        let err = ServerAddr::from_str("http://127.0.0.1").unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn server_address_credentials() {
        let address = ServerAddr::from_str("nats://user:pass@127.0.0.1").unwrap();
        assert!(address.has_user_pass());
        assert_eq!(address.username(), Some("user"));
        assert_eq!(address.password(), Some("pass"));

        let anonymous = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert!(!anonymous.has_user_pass());
        assert_eq!(anonymous.username(), None);
        assert_eq!(anonymous.password(), None);
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        assert_eq!(hosts(vec.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        assert_eq!(hosts(arr.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_eq!(hosts(vec.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_eq!(hosts(arr.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_ports_arr_string() {
        let cases = [
            (
                ["nats://127.0.0.1", "nats://[::]", "tls://127.0.0.1", "tls://[::]"],
                4222,
            ),
            (
                ["ws://127.0.0.1:80", "ws://[::]:80", "ws://127.0.0.1", "ws://[::]"],
                80,
            ),
            (
                [
                    "wss://127.0.0.1",
                    "wss://[::]",
                    "wss://127.0.0.1:443",
                    "wss://[::]:443",
                ],
                443,
            ),
        ];
        for (urls, expected_port) in cases {
            let arr = urls.map(String::from);
            // Check every address in the group, not just the first one.
            for addr in arr.to_server_addrs().unwrap() {
                assert_eq!(addr.port(), expected_port, "port of {}", addr.as_url_str());
            }
        }
    }
}