1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_cfg))]
196
197use thiserror::Error;
198
199use futures_util::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use portable_atomic::AtomicU64;
206use std::collections::HashMap;
207use std::collections::VecDeque;
208use std::fmt::Display;
209use std::future::Future;
210use std::iter;
211use std::mem;
212use std::net::SocketAddr;
213use std::option;
214use std::pin::Pin;
215use std::slice;
216use std::str::{self, FromStr};
217use std::sync::atomic::AtomicUsize;
218use std::sync::atomic::Ordering;
219use std::sync::Arc;
220use std::task::{Context, Poll};
221use tokio::io::ErrorKind;
222use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
223use url::{Host, Url};
224
225use bytes::Bytes;
226use serde::{Deserialize, Serialize};
227use serde_repr::{Deserialize_repr, Serialize_repr};
228use tokio::io;
229use tokio::sync::mpsc;
230use tokio::task;
231
/// Boxed, thread-safe, `'static` error type.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Version of this crate, read from Cargo at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Client language identifier (presumably reported to the server in CONNECT — confirm).
const LANG: &str = "rust";
/// Maximum number of unanswered client PINGs. A ping-timer tick that would push the
/// count above this value disconnects instead of sending another PING.
const MAX_PENDING_PINGS: usize = 2;
/// Subscription id reserved for the shared request/response (multiplexer) subscription.
const MULTIPLEXER_SID: u64 = 0;
238
239pub use tokio_rustls::rustls;
243
244use connection::{Connection, State};
245use connector::{Connector, ConnectorOptions};
246pub use header::{HeaderMap, HeaderName, HeaderValue};
247pub use subject::Subject;
248
249mod auth;
250pub(crate) mod auth_utils;
251pub mod client;
252pub mod connection;
253mod connector;
254mod options;
255
256pub use auth::Auth;
257pub use client::{
258 Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
259};
260pub use options::{AuthError, ConnectOptions};
261
262#[cfg(feature = "crypto")]
263#[cfg_attr(docsrs, doc(cfg(feature = "crypto")))]
264mod crypto;
265pub mod error;
266pub mod header;
267mod id_generator;
268#[cfg(feature = "jetstream")]
269#[cfg_attr(docsrs, doc(cfg(feature = "jetstream")))]
270pub mod jetstream;
271pub mod message;
272#[cfg(feature = "service")]
273#[cfg_attr(docsrs, doc(cfg(feature = "service")))]
274pub mod service;
275pub mod status;
276pub mod subject;
277mod tls;
278
279pub use message::Message;
280pub use status::StatusCode;
281
/// Server information carried by the server's `INFO` protocol message.
///
/// Deserialized with serde. Every field falls back to its `Default` value when the
/// corresponding key is absent from the INFO payload.
#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
pub struct ServerInfo {
    /// Value of the `server_id` key (the server's identifier).
    #[serde(default)]
    pub server_id: String,
    /// Value of the `server_name` key.
    #[serde(default)]
    pub server_name: String,
    /// Value of the `host` key.
    #[serde(default)]
    pub host: String,
    /// Value of the `port` key.
    #[serde(default)]
    pub port: u16,
    /// Value of the `version` key (server version string).
    #[serde(default)]
    pub version: String,
    /// Value of the `auth_required` key.
    #[serde(default)]
    pub auth_required: bool,
    /// Value of the `tls_required` key.
    #[serde(default)]
    pub tls_required: bool,
    /// Value of the `max_payload` key (presumably the largest payload the server accepts — confirm).
    #[serde(default)]
    pub max_payload: usize,
    /// Value of the `proto` key (protocol version).
    #[serde(default)]
    pub proto: i8,
    /// Value of the `client_id` key.
    #[serde(default)]
    pub client_id: u64,
    /// Value of the `go` key (presumably the Go version the server was built with).
    #[serde(default)]
    pub go: String,
    /// Value of the `nonce` key.
    #[serde(default)]
    pub nonce: String,
    /// Value of the `connect_urls` key (other URLs advertised by the server).
    #[serde(default)]
    pub connect_urls: Vec<String>,
    /// Value of the `client_ip` key.
    #[serde(default)]
    pub client_ip: String,
    /// Value of the `headers` key (whether the server advertises header support).
    #[serde(default)]
    pub headers: bool,
    /// Value of the `ldm` key. When an INFO with this flag set is received, the
    /// connection handler emits `Event::LameDuckMode`.
    #[serde(default, rename = "ldm")]
    pub lame_duck_mode: bool,
    /// Value of the `cluster` key, if present.
    #[serde(default)]
    pub cluster: Option<String>,
    /// Value of the `domain` key, if present.
    #[serde(default)]
    pub domain: Option<String>,
    /// Value of the `jetstream` key.
    #[serde(default)]
    pub jetstream: bool,
}
345
/// An operation read from the server connection and handled by
/// `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// `+OK` acknowledgement (ignored by the handler).
    Ok,
    /// `INFO` with the server's (possibly updated) information.
    Info(Box<ServerInfo>),
    /// Server-initiated `PING`; answered with a `PONG`.
    Ping,
    /// `PONG` answering one of our PINGs; decrements the pending-ping count.
    Pong,
    /// `-ERR` reported by the server; forwarded as `Event::ServerError`.
    Error(ServerError),
    /// A delivered message for subscription `sid`.
    Message {
        /// Subscription id the message was delivered for.
        sid: u64,
        /// Subject the message was published on.
        subject: Subject,
        /// Optional reply subject.
        reply: Option<Subject>,
        /// Message body.
        payload: Bytes,
        /// Optional headers.
        headers: Option<HeaderMap>,
        /// Optional status code.
        status: Option<StatusCode>,
        /// Optional status description.
        description: Option<String>,
        /// Length used for byte accounting in the subscriber statistics.
        length: usize,
    },
}
364
/// Former name of `message::OutboundMessage`, kept as an alias for backward compatibility.
#[deprecated(
    since = "0.44.0",
    note = "use `async_nats::message::OutboundMessage` instead"
)]
pub type PublishMessage = crate::message::OutboundMessage;
373
/// A request sent from client handles to the connection handler task over an mpsc channel.
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a message; translated into a `ClientOp::Publish`.
    Publish(OutboundMessage),
    /// Publish a request whose reply is routed through the shared request multiplexer.
    Request {
        /// Subject to publish the request on.
        subject: Subject,
        /// Request body.
        payload: Bytes,
        /// Inbox subject; its last `.`-separated token identifies the request.
        respond: Subject,
        /// Optional headers.
        headers: Option<HeaderMap>,
        /// Receives the first reply matching this request's token.
        sender: oneshot::Sender<Message>,
    },
    /// Register a subscription and send SUB to the server.
    Subscribe {
        /// Subscription id chosen by the client.
        sid: u64,
        /// Subject to subscribe to.
        subject: Subject,
        /// Optional queue group.
        queue_group: Option<String>,
        /// Channel delivered messages are pushed into.
        sender: mpsc::Sender<Message>,
        /// Per-subscription statistics shared with the `Subscriber`.
        statistics: Arc<SubscriberStatistics>,
    },
    /// Unsubscribe `sid`, immediately (`max: None`) or after `max` total deliveries.
    Unsubscribe {
        /// Subscription id.
        sid: u64,
        /// Optional delivery limit.
        max: Option<u64>,
    },
    /// Notify `observer` once the next flush of the write buffer completes.
    Flush {
        /// Completed when a flush succeeds.
        observer: oneshot::Sender<()>,
    },
    /// Drain one subscription (`Some(sid)`) or the whole connection (`None`).
    Drain {
        /// Subscription to drain, or `None` for all of them.
        sid: Option<u64>,
    },
    /// Force the handler to drop the current connection and reconnect.
    Reconnect,
}
404
/// A protocol operation queued for writing to the server via `Connection::enqueue_write_op`.
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// `PUB`/`HPUB`: publish `payload` on `subject`, optionally with reply subject and headers.
    Publish {
        /// Destination subject.
        subject: Subject,
        /// Message body.
        payload: Bytes,
        /// Optional reply subject.
        respond: Option<Subject>,
        /// Optional headers.
        headers: Option<HeaderMap>,
    },
    /// `SUB`: subscribe `sid` to `subject`, optionally within a queue group.
    Subscribe {
        /// Subscription id.
        sid: u64,
        /// Subject to subscribe to.
        subject: Subject,
        /// Optional queue group.
        queue_group: Option<String>,
    },
    /// `UNSUB`: unsubscribe `sid`, optionally after `max` messages.
    Unsubscribe {
        /// Subscription id.
        sid: u64,
        /// Optional delivery limit.
        max: Option<u64>,
    },
    /// `PING`.
    Ping,
    /// `PONG`.
    Pong,
    /// `CONNECT` with the given connection parameters.
    Connect(ConnectInfo),
}
427
/// Handler-side state of one active subscription.
#[derive(Debug)]
struct Subscription {
    /// Subject, kept so SUB can be re-sent after a reconnect.
    subject: Subject,
    /// Channel into the user's `Subscriber`.
    sender: mpsc::Sender<Message>,
    /// Statistics shared with the `Subscriber`.
    statistics: Arc<SubscriberStatistics>,
    /// Queue group, kept so SUB can be re-sent after a reconnect.
    queue_group: Option<String>,
    /// Number of messages successfully handed to `sender`.
    delivered: u64,
    /// Auto-unsubscribe limit: the subscription is removed once `delivered >= max`.
    max: Option<u64>,
}
437
/// Shared wildcard subscription used to route replies to in-flight requests.
#[derive(Debug)]
struct Multiplexer {
    /// Wildcard subject subscribed to (`<prefix>*`).
    subject: Subject,
    /// Prefix that reply subjects start with; the remainder is the request token.
    prefix: Subject,
    /// Pending requests keyed by token; each is removed when its reply arrives.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
444
/// Owns the server connection and all subscription state; runs in its own task and is
/// driven by `ConnectionHandler::process`.
pub(crate) struct ConnectionHandler {
    /// Current server connection (replaced on reconnect).
    connection: Connection,
    /// Used to (re)establish connections and to emit events/state changes.
    connector: Connector,
    /// Active subscriptions keyed by subscription id.
    subscriptions: HashMap<u64, Subscription>,
    /// Request/reply multiplexer, created lazily on the first request.
    multiplexer: Option<Multiplexer>,
    /// Client PINGs sent but not yet answered by a PONG.
    pending_pings: usize,
    /// Publishes the server info obtained on (re)connect.
    info_sender: tokio::sync::watch::Sender<ServerInfo>,
    /// Keep-alive timer; reset on any server op or command.
    ping_interval: Interval,
    /// Set by `Command::Reconnect`; consumed at the end of a poll.
    should_reconnect: bool,
    /// Waiters notified after the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    /// Set once a connection-wide drain has been requested.
    is_draining: bool,
    /// Subscription ids whose drain was requested; removed on the next poll.
    drain_pings: VecDeque<u64>,
}
459
460impl ConnectionHandler {
461 pub(crate) fn new(
462 connection: Connection,
463 connector: Connector,
464 info_sender: tokio::sync::watch::Sender<ServerInfo>,
465 ping_period: Duration,
466 ) -> ConnectionHandler {
467 let mut ping_interval = interval(ping_period);
468 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
469
470 ConnectionHandler {
471 connection,
472 connector,
473 subscriptions: HashMap::new(),
474 multiplexer: None,
475 pending_pings: 0,
476 info_sender,
477 ping_interval,
478 should_reconnect: false,
479 flush_observers: Vec::new(),
480 is_draining: false,
481 drain_pings: VecDeque::new(),
482 }
483 }
484
485 pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
486 struct ProcessFut<'a> {
487 handler: &'a mut ConnectionHandler,
488 receiver: &'a mut mpsc::Receiver<Command>,
489 recv_buf: &'a mut Vec<Command>,
490 }
491
492 enum ExitReason {
493 Disconnected(Option<io::Error>),
494 ReconnectRequested,
495 Closed,
496 }
497
498 impl ProcessFut<'_> {
499 const RECV_CHUNK_SIZE: usize = 16;
500
501 #[cold]
502 fn ping(&mut self) -> Poll<ExitReason> {
503 self.handler.pending_pings += 1;
504
505 if self.handler.pending_pings > MAX_PENDING_PINGS {
506 debug!(
507 pending_pings = self.handler.pending_pings,
508 max_pings = MAX_PENDING_PINGS,
509 "disconnecting due to too many pending pings"
510 );
511
512 Poll::Ready(ExitReason::Disconnected(None))
513 } else {
514 self.handler.connection.enqueue_write_op(&ClientOp::Ping);
515
516 Poll::Pending
517 }
518 }
519 }
520
521 impl Future for ProcessFut<'_> {
522 type Output = ExitReason;
523
524 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
538 while self.handler.ping_interval.poll_tick(cx).is_ready() {
542 if let Poll::Ready(exit) = self.ping() {
543 return Poll::Ready(exit);
544 }
545 }
546
547 loop {
548 match self.handler.connection.poll_read_op(cx) {
549 Poll::Pending => break,
550 Poll::Ready(Ok(Some(server_op))) => {
551 self.handler.handle_server_op(server_op);
552 }
553 Poll::Ready(Ok(None)) => {
554 return Poll::Ready(ExitReason::Disconnected(None))
555 }
556 Poll::Ready(Err(err)) => {
557 return Poll::Ready(ExitReason::Disconnected(Some(err)))
558 }
559 }
560 }
561
562 while let Some(sid) = self.handler.drain_pings.pop_front() {
567 self.handler.subscriptions.remove(&sid);
568 }
569
570 if self.handler.is_draining {
571 return Poll::Ready(ExitReason::Closed);
576 }
577
578 let mut made_progress = true;
584 loop {
585 while !self.handler.connection.is_write_buf_full() {
586 debug_assert!(self.recv_buf.is_empty());
587
588 let Self {
589 recv_buf,
590 handler,
591 receiver,
592 } = &mut *self;
593 match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
594 Poll::Pending => break,
595 Poll::Ready(1..) => {
596 made_progress = true;
597
598 for cmd in recv_buf.drain(..) {
599 handler.handle_command(cmd);
600 }
601 }
602 Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
604 }
605 }
606
607 if !mem::take(&mut made_progress) {
618 break;
619 }
620
621 match self.handler.connection.poll_write(cx) {
622 Poll::Pending => {
623 break;
625 }
626 Poll::Ready(Ok(())) => {
627 continue;
629 }
630 Poll::Ready(Err(err)) => {
631 return Poll::Ready(ExitReason::Disconnected(Some(err)))
632 }
633 }
634 }
635
636 if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
637 self.handler.connection.should_flush(),
638 self.handler.flush_observers.is_empty(),
639 ) {
640 match self.handler.connection.poll_flush(cx) {
641 Poll::Pending => {}
642 Poll::Ready(Ok(())) => {
643 for observer in self.handler.flush_observers.drain(..) {
644 let _ = observer.send(());
645 }
646 }
647 Poll::Ready(Err(err)) => {
648 return Poll::Ready(ExitReason::Disconnected(Some(err)))
649 }
650 }
651 }
652
653 if mem::take(&mut self.handler.should_reconnect) {
654 return Poll::Ready(ExitReason::ReconnectRequested);
655 }
656
657 Poll::Pending
658 }
659 }
660
661 let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
662 loop {
663 let process = ProcessFut {
664 handler: self,
665 receiver,
666 recv_buf: &mut recv_buf,
667 };
668 match process.await {
669 ExitReason::Disconnected(err) => {
670 debug!(error = ?err, "disconnected");
671 if self.handle_disconnect().await.is_err() {
672 break;
673 };
674 debug!("reconnected");
675 }
676 ExitReason::Closed => {
677 self.connector.events_tx.try_send(Event::Closed).ok();
679 break;
680 }
681 ExitReason::ReconnectRequested => {
682 debug!("reconnect requested");
683 self.connection.stream.shutdown().await.ok();
685 if self.handle_disconnect().await.is_err() {
686 break;
687 };
688 }
689 }
690 }
691 }
692
693 fn handle_server_op(&mut self, server_op: ServerOp) {
694 self.ping_interval.reset();
695
696 match server_op {
697 ServerOp::Ping => {
698 debug!("received PING");
699 self.connection.enqueue_write_op(&ClientOp::Pong);
700 }
701 ServerOp::Pong => {
702 debug!("received PONG");
703 self.pending_pings = self.pending_pings.saturating_sub(1);
704 }
705 ServerOp::Error(error) => {
706 debug!("received ERROR: {:?}", error);
707 self.connector
708 .events_tx
709 .try_send(Event::ServerError(error))
710 .ok();
711 }
712 ServerOp::Message {
713 sid,
714 subject,
715 reply,
716 payload,
717 headers,
718 status,
719 description,
720 length,
721 } => {
722 debug!("received MESSAGE: sid={}, subject={}", sid, subject);
723 self.connector
724 .connect_stats
725 .in_messages
726 .add(1, Ordering::Relaxed);
727
728 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
729 let message: Message = Message {
730 subject,
731 reply,
732 payload,
733 headers,
734 status,
735 description,
736 length,
737 };
738
739 match subscription.sender.try_send(message) {
742 Ok(_) => {
743 subscription
744 .statistics
745 .pending_messages
746 .add(1, Ordering::Relaxed);
747 subscription
748 .statistics
749 .pending_bytes
750 .add(length as u64, Ordering::Relaxed);
751 self.connector
752 .connect_stats
753 .subscription_pending_messages
754 .add(1, Ordering::Relaxed);
755 self.connector
756 .connect_stats
757 .subscription_pending_bytes
758 .add(length as u64, Ordering::Relaxed);
759 subscription.delivered += 1;
760 if let Some(max) = subscription.max {
764 if subscription.delivered.ge(&max) {
765 debug!("max messages reached for subscription {}", sid);
766 self.subscriptions.remove(&sid);
767 }
768 }
769 }
770 Err(mpsc::error::TrySendError::Full(returned_message)) => {
771 let dropped_len = returned_message.length as u64;
772 subscription
773 .statistics
774 .dropped_messages
775 .add(1, Ordering::Relaxed);
776 subscription
777 .statistics
778 .dropped_bytes
779 .add(dropped_len, Ordering::Relaxed);
780 self.connector
781 .connect_stats
782 .subscription_dropped_messages
783 .add(1, Ordering::Relaxed);
784 self.connector
785 .connect_stats
786 .subscription_dropped_bytes
787 .add(dropped_len, Ordering::Relaxed);
788 debug!("slow consumer detected for subscription {}", sid);
789 self.connector
790 .events_tx
791 .try_send(Event::SlowConsumer(SlowConsumer {
792 sid,
793 subject: returned_message.subject,
794 }))
795 .ok();
796 }
797 Err(mpsc::error::TrySendError::Closed(_)) => {
798 debug!("subscription {} channel closed", sid);
799 self.subscriptions.remove(&sid);
800 self.connection
801 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
802 }
803 }
804 } else if sid == MULTIPLEXER_SID {
805 debug!("received message for multiplexer");
806 if let Some(multiplexer) = self.multiplexer.as_mut() {
807 let maybe_token =
808 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
809
810 if let Some(token) = maybe_token {
811 if let Some(sender) = multiplexer.senders.remove(token) {
812 debug!("forwarding message to request with token {}", token);
813 let message = Message {
814 subject,
815 reply,
816 payload,
817 headers,
818 status,
819 description,
820 length,
821 };
822
823 let _ = sender.send(message);
824 }
825 }
826 }
827 }
828 }
829 ServerOp::Info(info) => {
831 debug!("received INFO: server_id={}", info.server_id);
832 if info.lame_duck_mode {
833 debug!("server in lame duck mode");
834 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
835 }
836 }
837
838 _ => {
839 }
841 }
842 }
843
844 fn handle_command(&mut self, command: Command) {
845 self.ping_interval.reset();
846
847 match command {
848 Command::Unsubscribe { sid, max } => {
849 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
850 subscription.max = max;
851 match subscription.max {
852 Some(n) => {
853 if subscription.delivered >= n {
854 self.subscriptions.remove(&sid);
855 }
856 }
857 None => {
858 self.subscriptions.remove(&sid);
859 }
860 }
861
862 self.connection
863 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
864 }
865 }
866 Command::Flush { observer } => {
867 self.flush_observers.push(observer);
868 }
869 Command::Drain { sid } => {
870 let mut drain_sub = |sid: u64| {
871 self.drain_pings.push_back(sid);
872 self.connection
873 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
874 };
875
876 if let Some(sid) = sid {
877 if self.subscriptions.get_mut(&sid).is_some() {
878 drain_sub(sid);
879 }
880 } else {
881 self.connector.events_tx.try_send(Event::Draining).ok();
883 self.is_draining = true;
884 for (&sid, _) in self.subscriptions.iter_mut() {
885 drain_sub(sid);
886 }
887 }
888 self.connection.enqueue_write_op(&ClientOp::Ping);
889 }
890 Command::Subscribe {
891 sid,
892 subject,
893 queue_group,
894 sender,
895 statistics,
896 } => {
897 let subscription = Subscription {
898 sender,
899 statistics,
900 delivered: 0,
901 max: None,
902 subject: subject.to_owned(),
903 queue_group: queue_group.to_owned(),
904 };
905
906 self.subscriptions.insert(sid, subscription);
907
908 self.connection.enqueue_write_op(&ClientOp::Subscribe {
909 sid,
910 subject,
911 queue_group,
912 });
913 }
914 Command::Request {
915 subject,
916 payload,
917 respond,
918 headers,
919 sender,
920 } => {
921 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
922
923 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
924 multiplexer
925 } else {
926 let prefix = Subject::from(format!("{}.{}.", prefix, id_generator::next()));
927 let subject = Subject::from(format!("{prefix}*"));
928
929 self.connection.enqueue_write_op(&ClientOp::Subscribe {
930 sid: MULTIPLEXER_SID,
931 subject: subject.clone(),
932 queue_group: None,
933 });
934
935 self.multiplexer.insert(Multiplexer {
936 subject,
937 prefix,
938 senders: HashMap::new(),
939 })
940 };
941 self.connector
942 .connect_stats
943 .out_messages
944 .add(1, Ordering::Relaxed);
945
946 multiplexer.senders.insert(token.to_owned(), sender);
947
948 let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
949
950 let pub_op = ClientOp::Publish {
951 subject,
952 payload,
953 respond: Some(respond),
954 headers,
955 };
956
957 self.connection.enqueue_write_op(&pub_op);
958 }
959
960 Command::Publish(OutboundMessage {
961 subject,
962 payload,
963 reply: respond,
964 headers,
965 }) => {
966 self.connector
967 .connect_stats
968 .out_messages
969 .add(1, Ordering::Relaxed);
970
971 let header_len = headers
972 .as_ref()
973 .map(|headers| headers.len())
974 .unwrap_or_default();
975
976 self.connector.connect_stats.out_bytes.add(
977 (payload.len()
978 + respond.as_ref().map_or_else(|| 0, |r| r.len())
979 + subject.len()
980 + header_len) as u64,
981 Ordering::Relaxed,
982 );
983
984 self.connection.enqueue_write_op(&ClientOp::Publish {
985 subject,
986 payload,
987 respond,
988 headers,
989 });
990 }
991
992 Command::Reconnect => {
993 self.should_reconnect = true;
994 }
995 }
996 }
997
998 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
999 self.pending_pings = 0;
1000 self.connector.events_tx.try_send(Event::Disconnected).ok();
1001 self.connector.state_tx.send(State::Disconnected).ok();
1002
1003 self.handle_reconnect().await
1004 }
1005
1006 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
1007 let (info, connection) = self.connector.connect().await?;
1008 self.connection = connection;
1009 let _ = self.info_sender.send(info);
1010
1011 self.subscriptions
1012 .retain(|_, subscription| !subscription.sender.is_closed());
1013
1014 for (sid, subscription) in &self.subscriptions {
1015 self.connection.enqueue_write_op(&ClientOp::Subscribe {
1016 sid: *sid,
1017 subject: subscription.subject.to_owned(),
1018 queue_group: subscription.queue_group.to_owned(),
1019 });
1020 }
1021
1022 if let Some(multiplexer) = &self.multiplexer {
1023 self.connection.enqueue_write_op(&ClientOp::Subscribe {
1024 sid: MULTIPLEXER_SID,
1025 subject: multiplexer.subject.to_owned(),
1026 queue_group: None,
1027 });
1028 }
1029 Ok(())
1030 }
1031}
1032
1033pub async fn connect_with_options<A: ToServerAddrs>(
1049 addrs: A,
1050 options: ConnectOptions,
1051) -> Result<Client, ConnectError> {
1052 let ping_period = options.ping_interval;
1053
1054 let (events_tx, mut events_rx) = mpsc::channel(128);
1055 let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
1056 let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
1058 let statistics = Arc::new(Statistics::default());
1059
1060 let mut connector = Connector::new(
1061 addrs,
1062 ConnectorOptions {
1063 tls_required: options.tls_required,
1064 certificates: options.certificates,
1065 client_key: options.client_key,
1066 client_cert: options.client_cert,
1067 tls_client_config: options.tls_client_config,
1068 tls_first: options.tls_first,
1069 auth: options.auth,
1070 no_echo: options.no_echo,
1071 connection_timeout: options.connection_timeout,
1072 name: options.name,
1073 ignore_discovered_servers: options.ignore_discovered_servers,
1074 retain_servers_order: options.retain_servers_order,
1075 read_buffer_capacity: options.read_buffer_capacity,
1076 reconnect_delay_callback: options.reconnect_delay_callback,
1077 auth_callback: options.auth_callback,
1078 max_reconnects: options.max_reconnects,
1079 },
1080 events_tx,
1081 state_tx,
1082 max_payload.clone(),
1083 statistics.clone(),
1084 )
1085 .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1086
1087 let mut info: ServerInfo = Default::default();
1088 let mut connection = None;
1089 if !options.retry_on_initial_connect {
1090 debug!("retry on initial connect failure is disabled");
1091 let (info_ok, connection_ok) = connector.try_connect().await?;
1092 connection = Some(connection_ok);
1093 info = info_ok;
1094 }
1095
1096 let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1097 let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1098
1099 let client = Client::new(
1100 info_watcher,
1101 state_rx,
1102 sender,
1103 options.subscription_capacity,
1104 options.inbox_prefix,
1105 options.request_timeout,
1106 max_payload,
1107 statistics,
1108 );
1109
1110 task::spawn(async move {
1111 while let Some(event) = events_rx.recv().await {
1112 tracing::info!("event: {}", event);
1113 if let Some(event_callback) = &options.event_callback {
1114 event_callback.call(event).await;
1115 }
1116 }
1117 });
1118
1119 task::spawn(async move {
1120 if connection.is_none() && options.retry_on_initial_connect {
1121 let (info, connection_ok) = match connector.connect().await {
1122 Ok((info, connection)) => (info, connection),
1123 Err(err) => {
1124 error!("connection closed: {}", err);
1125 return;
1126 }
1127 };
1128 info_sender.send(info).ok();
1129 connection = Some(connection_ok);
1130 }
1131 let connection = connection.unwrap();
1132 let mut connection_handler =
1133 ConnectionHandler::new(connection, connector, info_sender, ping_period);
1134 connection_handler.process(&mut receiver).await
1135 });
1136
1137 Ok(client)
1138}
1139
/// Connection lifecycle events delivered to the optional event callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// A connection was established.
    Connected,
    /// The connection was lost; the handler emits this before attempting to reconnect.
    Disconnected,
    /// The server reported lame duck mode in an INFO message.
    LameDuckMode,
    /// A connection-wide drain was requested.
    Draining,
    /// The connection handler stopped (command channel closed or drain finished).
    Closed,
    /// A message was dropped because the subscription's channel was full.
    SlowConsumer(SlowConsumer),
    /// The server sent an error.
    ServerError(ServerError),
    /// A client-side error occurred.
    ClientError(ClientError),
}
1151
1152impl fmt::Display for Event {
1153 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1154 match self {
1155 Event::Connected => write!(f, "connected"),
1156 Event::Disconnected => write!(f, "disconnected"),
1157 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1158 Event::Draining => write!(f, "draining"),
1159 Event::Closed => write!(f, "closed"),
1160 Event::SlowConsumer(slow_consumer) => write!(
1161 f,
1162 "slow consumers for subscription {} on subject {}",
1163 slow_consumer.sid, slow_consumer.subject
1164 ),
1165 Event::ServerError(err) => write!(f, "server error: {err}"),
1166 Event::ClientError(err) => write!(f, "client error: {err}"),
1167 }
1168 }
1169}
1170
1171pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1238 connect_with_options(addrs, ConnectOptions::default()).await
1239}
1240
/// Category of a `ConnectError`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// DNS resolution failed.
    Dns,
    /// Authentication failed (reported as a failure to sign the nonce).
    Authentication,
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// The operation timed out.
    TimedOut,
    /// TLS setup or handshake failed.
    Tls,
    /// An I/O error occurred; `From<io::Error>` maps to this kind.
    Io,
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}

impl Display for ConnectErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
1275
1276pub type ConnectError = error::Error<ConnectErrorKind>;
1279
1280impl From<io::Error> for ConnectError {
1281 fn from(err: io::Error) -> Self {
1282 ConnectError::with_source(ConnectErrorKind::Io, err)
1283 }
1284}
1285
/// A stream of messages delivered for one subscription.
///
/// Dropping it unsubscribes. Statistics are shared with the connection handler.
#[derive(Debug)]
pub struct Subscriber {
    /// Subscription id.
    sid: u64,
    /// Messages pushed by the connection handler.
    receiver: mpsc::Receiver<Message>,
    /// Command channel to the connection handler (used for unsubscribe/drain).
    sender: mpsc::Sender<Command>,
    /// Per-subscription statistics.
    statistics: Arc<SubscriberStatistics>,
    /// Connection-wide statistics, adjusted on creation, receipt and drop.
    connection_stats: Arc<client::Statistics>,
}
1307
1308impl Subscriber {
1309 pub(crate) fn new(
1310 sid: u64,
1311 sender: mpsc::Sender<Command>,
1312 receiver: mpsc::Receiver<Message>,
1313 statistics: Arc<SubscriberStatistics>,
1314 connection_stats: Arc<client::Statistics>,
1315 ) -> Subscriber {
1316 connection_stats
1317 .active_subscriptions
1318 .add(1, Ordering::Relaxed);
1319 connection_stats
1320 .active_subscription_capacity
1321 .add(receiver.max_capacity() as u64, Ordering::Relaxed);
1322
1323 Subscriber {
1324 sid,
1325 sender,
1326 receiver,
1327 statistics,
1328 connection_stats,
1329 }
1330 }
1331
1332 pub fn statistics(&self) -> Arc<SubscriberStatistics> {
1334 self.statistics.clone()
1335 }
1336
1337 pub fn pending_messages(&self) -> usize {
1339 self.receiver.len()
1340 }
1341
1342 pub fn pending_bytes(&self) -> u64 {
1344 self.statistics.pending_bytes.load(Ordering::Relaxed)
1345 }
1346
1347 pub fn remaining_capacity(&self) -> usize {
1349 self.receiver.capacity()
1350 }
1351
1352 pub fn max_capacity(&self) -> usize {
1354 self.receiver.max_capacity()
1355 }
1356
1357 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1372 self.sender
1373 .send(Command::Unsubscribe {
1374 sid: self.sid,
1375 max: None,
1376 })
1377 .await?;
1378 self.receiver.close();
1379 Ok(())
1380 }
1381
1382 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1408 self.sender
1409 .send(Command::Unsubscribe {
1410 sid: self.sid,
1411 max: Some(unsub_after),
1412 })
1413 .await?;
1414 Ok(())
1415 }
1416
1417 pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1450 self.sender
1451 .send(Command::Drain {
1452 sid: Some(self.sid),
1453 })
1454 .await?;
1455
1456 Ok(())
1457 }
1458}
1459
1460#[derive(Error, Debug, PartialEq)]
1461#[error("failed to send unsubscribe")]
1462pub struct UnsubscribeError(String);
1463
1464impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1465 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1466 UnsubscribeError(err.to_string())
1467 }
1468}
1469
1470impl Drop for Subscriber {
1471 fn drop(&mut self) {
1472 self.receiver.close();
1473 let mut drained_messages = 0;
1474 let mut drained_bytes = 0;
1475
1476 while let Ok(message) = self.receiver.try_recv() {
1477 drained_messages += 1;
1478 drained_bytes += message.length as u64;
1479 }
1480
1481 if drained_messages > 0 {
1482 self.statistics
1483 .pending_messages
1484 .sub(drained_messages, Ordering::Relaxed);
1485 self.connection_stats
1486 .subscription_pending_messages
1487 .sub(drained_messages, Ordering::Relaxed);
1488 }
1489
1490 if drained_bytes > 0 {
1491 self.statistics
1492 .pending_bytes
1493 .sub(drained_bytes, Ordering::Relaxed);
1494 self.connection_stats
1495 .subscription_pending_bytes
1496 .sub(drained_bytes, Ordering::Relaxed);
1497 }
1498
1499 self.connection_stats
1500 .active_subscriptions
1501 .sub(1, Ordering::Relaxed);
1502 self.connection_stats
1503 .active_subscription_capacity
1504 .sub(self.receiver.max_capacity() as u64, Ordering::Relaxed);
1505
1506 tokio::spawn({
1507 let sender = self.sender.clone();
1508 let sid = self.sid;
1509 async move {
1510 sender
1511 .send(Command::Unsubscribe { sid, max: None })
1512 .await
1513 .ok();
1514 }
1515 });
1516 }
1517}
1518
1519impl Stream for Subscriber {
1520 type Item = Message;
1521
1522 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1523 match self.receiver.poll_recv(cx) {
1524 Poll::Ready(Some(message)) => {
1525 self.statistics
1526 .pending_messages
1527 .sub(1, Ordering::Relaxed);
1528 self.statistics
1529 .pending_bytes
1530 .sub(message.length as u64, Ordering::Relaxed);
1531 self.connection_stats
1532 .subscription_pending_messages
1533 .sub(1, Ordering::Relaxed);
1534 self.connection_stats
1535 .subscription_pending_bytes
1536 .sub(message.length as u64, Ordering::Relaxed);
1537 Poll::Ready(Some(message))
1538 }
1539 other => other,
1540 }
1541 }
1542}
1543
/// Per-subscription counters shared between the connection handler and the `Subscriber`.
#[derive(Default, Debug)]
pub struct SubscriberStatistics {
    /// Messages handed to the subscription's channel but not yet consumed.
    pub pending_messages: AtomicU64,
    /// Bytes (sum of message lengths) handed to the channel but not yet consumed.
    pub pending_bytes: AtomicU64,
    /// Messages dropped because the channel was full.
    pub dropped_messages: AtomicU64,
    /// Bytes of messages dropped because the channel was full.
    pub dropped_bytes: AtomicU64,
}
1556
1557#[derive(Clone, Debug, Eq, PartialEq)]
1558pub enum CallbackError {
1559 Client(ClientError),
1560 Server(ServerError),
1561}
1562impl std::fmt::Display for CallbackError {
1563 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1564 match self {
1565 Self::Client(error) => write!(f, "{error}"),
1566 Self::Server(error) => write!(f, "{error}"),
1567 }
1568 }
1569}
1570
1571impl From<ServerError> for CallbackError {
1572 fn from(server_error: ServerError) -> Self {
1573 CallbackError::Server(server_error)
1574 }
1575}
1576
1577impl From<ClientError> for CallbackError {
1578 fn from(client_error: ClientError) -> Self {
1579 CallbackError::Client(client_error)
1580 }
1581}
1582
/// Errors reported by the server (see `ServerError::new` for how text is classified).
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// A subscription was flagged as a slow consumer.
    SlowConsumer(SlowConsumer),
    /// Any other server error; holds the server's message text.
    Other(String),
}
1589
/// Errors originating in the client rather than reported by the server.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Any other client-side failure, described by the contained message.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Other(error) => write!(f, "nats: {error}"),
            Self::MaxReconnects => write!(f, "nats: max reconnects reached"),
        }
    }
}

// A public error type should implement `std::error::Error`, like `ServerError` does.
// This lets callers box it as `dyn Error` and use it with `?`.
impl std::error::Error for ClientError {}
1603
1604impl ServerError {
1605 fn new(error: String) -> ServerError {
1606 match error.to_lowercase().as_str() {
1607 "authorization violation" => ServerError::AuthorizationViolation,
1608 _ => ServerError::Other(error),
1610 }
1611 }
1612}
1613
1614impl std::fmt::Display for ServerError {
1615 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1616 match self {
1617 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1618 Self::SlowConsumer(slow_consumer) => write!(
1619 f,
1620 "nats: subscription {} on subject {} is a slow consumer",
1621 slow_consumer.sid, slow_consumer.subject,
1622 ),
1623 Self::Other(error) => write!(f, "nats: {error}"),
1624 }
1625 }
1626}
1627
1628#[derive(Clone, Debug, Eq, PartialEq)]
1629pub struct SlowConsumer {
1630 pub sid: u64,
1631 pub subject: Subject,
1632}
1633
1634impl std::fmt::Display for SlowConsumer {
1635 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1636 write!(f, "slow consumer {} on subject {}", self.sid, self.subject)
1637 }
1638}
1639
/// Parameters serialized (with serde) into the CONNECT operation (`ClientOp::Connect`).
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Serialized as `verbose` (presumably requests `+OK` acknowledgements — confirm).
    pub verbose: bool,

    /// Serialized as `pedantic` (presumably enables stricter server-side checks — confirm).
    pub pedantic: bool,

    /// User JWT, serialized under the `jwt` key.
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public NKey, serialized as `nkey`.
    pub nkey: Option<String>,

    /// Signature, serialized under the `sig` key (presumably over the server nonce — confirm).
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name, serialized as `name`.
    pub name: Option<String>,

    /// Serialized as `echo` (presumably whether the client receives its own publishes — confirm).
    pub echo: bool,

    /// Client language, serialized as `lang`.
    pub lang: String,

    /// Client version, serialized as `version`.
    pub version: String,

    /// Protocol level, serialized as its numeric value.
    pub protocol: Protocol,

    /// Serialized as `tls_required`.
    pub tls_required: bool,

    /// Username, serialized as `user`.
    pub user: Option<String>,

    /// Password, serialized as `pass`.
    pub pass: Option<String>,

    /// Auth token, serialized as `auth_token`.
    pub auth_token: Option<String>,

    /// Serialized as `headers` (presumably advertises header support — confirm).
    pub headers: bool,

    /// Serialized as `no_responders` (presumably opts into no-responders status replies — confirm).
    pub no_responders: bool,
}
1700
/// Client protocol level; (de)serialized as its `u8` value via `serde_repr`.
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Protocol level 0.
    Original = 0,
    /// Protocol level 1.
    Dynamic = 1,
}
1710
/// A validated server URL with a `nats`, `tls`, `ws` or `wss` scheme.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1714
1715impl FromStr for ServerAddr {
1716 type Err = io::Error;
1717
1718 fn from_str(input: &str) -> Result<Self, Self::Err> {
1722 let url: Url = if input.contains("://") {
1723 input.parse()
1724 } else {
1725 format!("nats://{input}").parse()
1726 }
1727 .map_err(|e| {
1728 io::Error::new(
1729 ErrorKind::InvalidInput,
1730 format!("NATS server URL is invalid: {e}"),
1731 )
1732 })?;
1733
1734 Self::from_url(url)
1735 }
1736}
1737
1738impl ServerAddr {
1739 pub fn from_url(url: Url) -> io::Result<Self> {
1741 if url.scheme() != "nats"
1742 && url.scheme() != "tls"
1743 && url.scheme() != "ws"
1744 && url.scheme() != "wss"
1745 {
1746 return Err(std::io::Error::new(
1747 ErrorKind::InvalidInput,
1748 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1749 ));
1750 }
1751
1752 Ok(Self(url))
1753 }
1754
1755 pub fn into_inner(self) -> Url {
1757 self.0
1758 }
1759
1760 pub fn tls_required(&self) -> bool {
1762 self.0.scheme() == "tls"
1763 }
1764
1765 pub fn has_user_pass(&self) -> bool {
1767 self.0.username() != ""
1768 }
1769
1770 pub fn scheme(&self) -> &str {
1771 self.0.scheme()
1772 }
1773
1774 pub fn host(&self) -> &str {
1776 match self.0.host() {
1777 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1778 Some(Host::Ipv6 { .. }) => {
1780 let host = self.0.host_str().unwrap();
1781 &host[1..host.len() - 1]
1782 }
1783 None => "",
1784 }
1785 }
1786
1787 pub fn is_websocket(&self) -> bool {
1788 self.0.scheme() == "ws" || self.0.scheme() == "wss"
1789 }
1790
1791 pub fn port(&self) -> u16 {
1794 self.0.port_or_known_default().unwrap_or(4222)
1795 }
1796
1797 pub fn as_url_str(&self) -> &str {
1799 self.0.as_str()
1800 }
1801
1802 pub fn username(&self) -> Option<&str> {
1804 let user = self.0.username();
1805 if user.is_empty() {
1806 None
1807 } else {
1808 Some(user)
1809 }
1810 }
1811
1812 pub fn password(&self) -> Option<&str> {
1814 self.0.password()
1815 }
1816
1817 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1819 tokio::net::lookup_host((self.host(), self.port())).await
1820 }
1821}
1822
/// Conversion into a sequence of `ServerAddr` values.
///
/// Implemented for a single `ServerAddr`, URL strings (`str`, `String`), slices
/// and vectors of string-like values, slices and vectors of `ServerAddr`, and
/// references to any implementor. String inputs without a `scheme://` prefix
/// are parsed as `nats://` URLs.
pub trait ToServerAddrs {
    /// Iterator over the produced server addresses.
    type Iter: Iterator<Item = ServerAddr>;

    /// Converts `self` into an iterator of server addresses.
    ///
    /// # Errors
    ///
    /// Returns an error if any string input is not a valid NATS server URL.
    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
}
1834
1835impl ToServerAddrs for ServerAddr {
1836 type Iter = option::IntoIter<ServerAddr>;
1837 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1838 Ok(Some(self.clone()).into_iter())
1839 }
1840}
1841
1842impl ToServerAddrs for str {
1843 type Iter = option::IntoIter<ServerAddr>;
1844 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1845 self.parse::<ServerAddr>()
1846 .map(|addr| Some(addr).into_iter())
1847 }
1848}
1849
1850impl ToServerAddrs for String {
1851 type Iter = option::IntoIter<ServerAddr>;
1852 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1853 (**self).to_server_addrs()
1854 }
1855}
1856
1857impl<T: AsRef<str>> ToServerAddrs for [T] {
1858 type Iter = std::vec::IntoIter<ServerAddr>;
1859 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1860 self.iter()
1861 .map(AsRef::as_ref)
1862 .map(str::parse)
1863 .collect::<io::Result<_>>()
1864 .map(Vec::into_iter)
1865 }
1866}
1867
1868impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1869 type Iter = std::vec::IntoIter<ServerAddr>;
1870 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1871 self.as_slice().to_server_addrs()
1872 }
1873}
1874
1875impl<'a> ToServerAddrs for &'a [ServerAddr] {
1876 type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1877
1878 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1879 Ok(self.iter().cloned())
1880 }
1881}
1882
1883impl ToServerAddrs for Vec<ServerAddr> {
1884 type Iter = std::vec::IntoIter<ServerAddr>;
1885
1886 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1887 Ok(self.clone().into_iter())
1888 }
1889}
1890
1891impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1892 type Iter = T::Iter;
1893 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1894 (**self).to_server_addrs()
1895 }
1896}
1897
/// Reports whether `subject` is a well-formed NATS subject.
///
/// A valid subject consists of `.`-separated tokens that are all non-empty. This
/// rules out the empty subject and any leading, trailing or consecutive dot. It
/// must also contain no ASCII whitespace. Wildcard tokens (`*`, `>`) are not
/// inspected here.
// The `allow` suggests some build configurations never call this helper.
#[allow(dead_code)]
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let subject = subject.as_ref();
    // `"".split('.')` yields a single empty token, so the empty subject is
    // rejected by the same per-token check as "a..b", ".a" and "a.".
    subject.split('.').all(|token| !token.is_empty())
        && !subject.bytes().any(|byte| byte.is_ascii_whitespace())
}
/// Generates `impl From<$source> for $target` with special handling for timeouts.
///
/// Arguments are, in order:
/// 1. the error type to implement `From` for;
/// 2. its kind enum;
/// 3. the source error type;
/// 4. the source error's kind enum.
///
/// A source error whose `kind()` is `TimedOut` becomes
/// `$target::new($kind::TimedOut)`. Any other source error is passed, together
/// with `$kind::Other`, to `$target::with_source`.
#[allow(unused_macros)]
macro_rules! from_with_timeout {
    ($target:ty, $kind:ty, $source:ty, $source_kind:ty) => {
        impl From<$source> for $target {
            fn from(source: $source) -> Self {
                match source.kind() {
                    <$source_kind>::TimedOut => Self::new(<$kind>::TimedOut),
                    _ => Self::with_source(<$kind>::Other, source),
                }
            }
        }
    };
}
// Re-exported so sibling modules can invoke the macro by path.
#[allow(unused_imports)]
pub(crate) use from_with_timeout;
1920
1921use crate::connection::ShouldFlush;
1922use crate::message::OutboundMessage;
1923
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    #[test]
    fn server_address_without_scheme_defaults_to_nats() {
        let address = ServerAddr::from_str("localhost:4222").unwrap();
        assert_eq!(address.scheme(), "nats");
        assert_eq!(address.host(), "localhost");
        assert_eq!(address.port(), 4222);
    }

    #[test]
    fn server_address_rejects_unsupported_scheme() {
        let err = ServerAddr::from_str("http://localhost:4222").unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn server_address_credentials() {
        let address = ServerAddr::from_str("nats://derek:s3cr3t@localhost:4222").unwrap();
        assert!(address.has_user_pass());
        assert_eq!(address.username(), Some("derek"));
        assert_eq!(address.password(), Some("s3cr3t"));

        let anonymous = ServerAddr::from_str("nats://localhost:4222").unwrap();
        assert!(!anonymous.has_user_pass());
        assert_eq!(anonymous.username(), None);
        assert_eq!(anonymous.password(), None);
    }

    #[test]
    fn server_address_transport_flags() {
        assert!(ServerAddr::from_str("tls://localhost").unwrap().tls_required());
        assert!(!ServerAddr::from_str("nats://localhost").unwrap().tls_required());
        assert!(ServerAddr::from_str("ws://localhost").unwrap().is_websocket());
        assert!(ServerAddr::from_str("wss://localhost").unwrap().is_websocket());
        assert!(!ServerAddr::from_str("nats://localhost").unwrap().is_websocket());
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            // Check every address in the group, not only the first one.
            for addr in arr.to_server_addrs().unwrap() {
                assert_eq!(
                    addr.port(),
                    expected_port,
                    "unexpected port for {}",
                    addr.as_url_str()
                );
            }
        }
    }

    #[test]
    fn subject_validation_basics() {
        assert!(is_valid_subject("foo.bar"));
        assert!(!is_valid_subject(".foo"));
        assert!(!is_valid_subject("foo."));
        assert!(!is_valid_subject("foo bar"));
    }
}