1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_cfg))]
196
197use thiserror::Error;
198
199use futures_util::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::collections::VecDeque;
207use std::fmt::Display;
208use std::future::Future;
209use std::iter;
210use std::mem;
211use std::net::SocketAddr;
212use std::option;
213use std::pin::Pin;
214use std::slice;
215use std::str::{self, FromStr};
216use std::sync::atomic::AtomicUsize;
217use std::sync::atomic::Ordering;
218use std::sync::Arc;
219use std::task::{Context, Poll};
220use tokio::io::ErrorKind;
221use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
222use url::{Host, Url};
223
224use bytes::Bytes;
225use serde::{Deserialize, Serialize};
226use serde_repr::{Deserialize_repr, Serialize_repr};
227use tokio::io;
228use tokio::sync::mpsc;
229use tokio::task;
230
/// Boxed, thread-safe error type for arbitrary error values.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
232
233const VERSION: &str = env!("CARGO_PKG_VERSION");
234const LANG: &str = "rust";
235const MAX_PENDING_PINGS: usize = 2;
236const MULTIPLEXER_SID: u64 = 0;
237
238pub use tokio_rustls::rustls;
242
243use connection::{Connection, State};
244use connector::{Connector, ConnectorOptions};
245pub use header::{HeaderMap, HeaderName, HeaderValue};
246pub use subject::{Subject, SubjectError, ToSubject};
247
248mod auth;
249pub(crate) mod auth_utils;
250pub mod client;
251pub mod connection;
252mod connector;
253mod options;
254
255pub use auth::Auth;
256pub use client::{
257 Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
258 SubscribeErrorKind,
259};
260pub use options::{AuthError, ConnectOptions};
261
262#[cfg(feature = "crypto")]
263#[cfg_attr(docsrs, doc(cfg(feature = "crypto")))]
264mod crypto;
265pub mod error;
266pub mod header;
267mod id_generator;
268#[cfg(feature = "jetstream")]
269#[cfg_attr(docsrs, doc(cfg(feature = "jetstream")))]
270pub mod jetstream;
271pub mod message;
272#[cfg(feature = "service")]
273#[cfg_attr(docsrs, doc(cfg(feature = "service")))]
274pub mod service;
275pub mod status;
276pub mod subject;
277mod tls;
278
279pub use message::Message;
280pub use status::StatusCode;
281
282#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
285pub struct ServerInfo {
286 #[serde(default)]
288 pub server_id: String,
289 #[serde(default)]
291 pub server_name: String,
292 #[serde(default)]
294 pub host: String,
295 #[serde(default)]
297 pub port: u16,
298 #[serde(default)]
300 pub version: String,
301 #[serde(default)]
304 pub auth_required: bool,
305 #[serde(default)]
307 pub tls_required: bool,
308 #[serde(default)]
310 pub max_payload: usize,
311 #[serde(default)]
313 pub proto: i8,
314 #[serde(default)]
316 pub client_id: u64,
317 #[serde(default)]
319 pub go: String,
320 #[serde(default)]
322 pub nonce: String,
323 #[serde(default)]
325 pub connect_urls: Vec<String>,
326 #[serde(default)]
328 pub client_ip: String,
329 #[serde(default)]
331 pub headers: bool,
332 #[serde(default, rename = "ldm")]
334 pub lame_duck_mode: bool,
335 #[serde(default)]
337 pub cluster: Option<String>,
338 #[serde(default)]
340 pub domain: Option<String>,
341 #[serde(default)]
343 pub jetstream: bool,
344}
345
346#[derive(Clone, Debug, Eq, PartialEq)]
347pub(crate) enum ServerOp {
348 Ok,
349 Info(Box<ServerInfo>),
350 Ping,
351 Pong,
352 Error(ServerError),
353 Message {
354 sid: u64,
355 subject: Subject,
356 reply: Option<Subject>,
357 payload: Bytes,
358 headers: Option<HeaderMap>,
359 status: Option<StatusCode>,
360 description: Option<String>,
361 length: usize,
362 },
363}
364
365#[deprecated(
369 since = "0.44.0",
370 note = "use `async_nats::message::OutboundMessage` instead"
371)]
372pub type PublishMessage = crate::message::OutboundMessage;
373
374#[derive(Debug)]
376pub(crate) enum Command {
377 Publish(OutboundMessage),
378 Request {
379 subject: Subject,
380 payload: Bytes,
381 respond: Subject,
382 headers: Option<HeaderMap>,
383 sender: oneshot::Sender<Message>,
384 },
385 Subscribe {
386 sid: u64,
387 subject: Subject,
388 queue_group: Option<String>,
389 sender: mpsc::Sender<Message>,
390 },
391 Unsubscribe {
392 sid: u64,
393 max: Option<u64>,
394 },
395 Flush {
396 observer: oneshot::Sender<()>,
397 },
398 Drain {
399 sid: Option<u64>,
400 },
401 Reconnect,
402}
403
404#[derive(Debug)]
406pub(crate) enum ClientOp {
407 Publish {
408 subject: Subject,
409 payload: Bytes,
410 respond: Option<Subject>,
411 headers: Option<HeaderMap>,
412 },
413 Subscribe {
414 sid: u64,
415 subject: Subject,
416 queue_group: Option<String>,
417 },
418 Unsubscribe {
419 sid: u64,
420 max: Option<u64>,
421 },
422 Ping,
423 Pong,
424 Connect(ConnectInfo),
425}
426
427#[derive(Debug)]
428struct Subscription {
429 subject: Subject,
430 sender: mpsc::Sender<Message>,
431 queue_group: Option<String>,
432 delivered: u64,
433 max: Option<u64>,
434}
435
436#[derive(Debug)]
437struct Multiplexer {
438 subject: Subject,
439 prefix: Subject,
440 senders: HashMap<String, oneshot::Sender<Message>>,
441}
442
443pub(crate) struct ConnectionHandler {
445 connection: Connection,
446 connector: Connector,
447 subscriptions: HashMap<u64, Subscription>,
448 multiplexer: Option<Multiplexer>,
449 pending_pings: usize,
450 info_sender: tokio::sync::watch::Sender<ServerInfo>,
451 ping_interval: Interval,
452 should_reconnect: bool,
453 flush_observers: Vec<oneshot::Sender<()>>,
454 is_draining: bool,
455 drain_pings: VecDeque<u64>,
456}
457
458impl ConnectionHandler {
459 pub(crate) fn new(
460 connection: Connection,
461 connector: Connector,
462 info_sender: tokio::sync::watch::Sender<ServerInfo>,
463 ping_period: Duration,
464 ) -> ConnectionHandler {
465 let mut ping_interval = interval(ping_period);
466 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
467
468 ConnectionHandler {
469 connection,
470 connector,
471 subscriptions: HashMap::new(),
472 multiplexer: None,
473 pending_pings: 0,
474 info_sender,
475 ping_interval,
476 should_reconnect: false,
477 flush_observers: Vec::new(),
478 is_draining: false,
479 drain_pings: VecDeque::new(),
480 }
481 }
482
483 pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
484 struct ProcessFut<'a> {
485 handler: &'a mut ConnectionHandler,
486 receiver: &'a mut mpsc::Receiver<Command>,
487 recv_buf: &'a mut Vec<Command>,
488 }
489
490 enum ExitReason {
491 Disconnected(Option<io::Error>),
492 ReconnectRequested,
493 Closed,
494 }
495
496 impl ProcessFut<'_> {
497 const RECV_CHUNK_SIZE: usize = 16;
498
499 #[cold]
500 fn ping(&mut self) -> Poll<ExitReason> {
501 self.handler.pending_pings += 1;
502
503 if self.handler.pending_pings > MAX_PENDING_PINGS {
504 debug!(
505 pending_pings = self.handler.pending_pings,
506 max_pings = MAX_PENDING_PINGS,
507 "disconnecting due to too many pending pings"
508 );
509
510 Poll::Ready(ExitReason::Disconnected(None))
511 } else {
512 self.handler.connection.enqueue_write_op(&ClientOp::Ping);
513
514 Poll::Pending
515 }
516 }
517 }
518
519 impl Future for ProcessFut<'_> {
520 type Output = ExitReason;
521
522 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
536 while self.handler.ping_interval.poll_tick(cx).is_ready() {
540 if let Poll::Ready(exit) = self.ping() {
541 return Poll::Ready(exit);
542 }
543 }
544
545 loop {
546 match self.handler.connection.poll_read_op(cx) {
547 Poll::Pending => break,
548 Poll::Ready(Ok(Some(server_op))) => {
549 self.handler.handle_server_op(server_op);
550 }
551 Poll::Ready(Ok(None)) => {
552 return Poll::Ready(ExitReason::Disconnected(None))
553 }
554 Poll::Ready(Err(err)) => {
555 return Poll::Ready(ExitReason::Disconnected(Some(err)))
556 }
557 }
558 }
559
560 while let Some(sid) = self.handler.drain_pings.pop_front() {
565 self.handler.subscriptions.remove(&sid);
566 }
567
568 if self.handler.is_draining {
569 return Poll::Ready(ExitReason::Closed);
574 }
575
576 let mut made_progress = true;
582 loop {
583 while !self.handler.connection.is_write_buf_full() {
584 debug_assert!(self.recv_buf.is_empty());
585
586 let Self {
587 recv_buf,
588 handler,
589 receiver,
590 } = &mut *self;
591 match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
592 Poll::Pending => break,
593 Poll::Ready(1..) => {
594 made_progress = true;
595
596 for cmd in recv_buf.drain(..) {
597 handler.handle_command(cmd);
598 }
599 }
600 Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
602 }
603 }
604
605 if !mem::take(&mut made_progress) {
616 break;
617 }
618
619 match self.handler.connection.poll_write(cx) {
620 Poll::Pending => {
621 break;
623 }
624 Poll::Ready(Ok(())) => {
625 continue;
627 }
628 Poll::Ready(Err(err)) => {
629 return Poll::Ready(ExitReason::Disconnected(Some(err)))
630 }
631 }
632 }
633
634 if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
635 self.handler.connection.should_flush(),
636 self.handler.flush_observers.is_empty(),
637 ) {
638 match self.handler.connection.poll_flush(cx) {
639 Poll::Pending => {}
640 Poll::Ready(Ok(())) => {
641 for observer in self.handler.flush_observers.drain(..) {
642 let _ = observer.send(());
643 }
644 }
645 Poll::Ready(Err(err)) => {
646 return Poll::Ready(ExitReason::Disconnected(Some(err)))
647 }
648 }
649 }
650
651 if mem::take(&mut self.handler.should_reconnect) {
652 return Poll::Ready(ExitReason::ReconnectRequested);
653 }
654
655 Poll::Pending
656 }
657 }
658
659 let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
660 loop {
661 let process = ProcessFut {
662 handler: self,
663 receiver,
664 recv_buf: &mut recv_buf,
665 };
666 match process.await {
667 ExitReason::Disconnected(err) => {
668 debug!(error = ?err, "disconnected");
669 if self.handle_disconnect().await.is_err() {
670 break;
671 };
672 debug!("reconnected");
673 }
674 ExitReason::Closed => {
675 self.connector.events_tx.try_send(Event::Closed).ok();
677 break;
678 }
679 ExitReason::ReconnectRequested => {
680 debug!("reconnect requested");
681 self.connection.stream.shutdown().await.ok();
683 if self.handle_disconnect().await.is_err() {
684 break;
685 };
686 }
687 }
688 }
689 }
690
691 fn handle_server_op(&mut self, server_op: ServerOp) {
692 self.ping_interval.reset();
693
694 match server_op {
695 ServerOp::Ping => {
696 debug!("received PING");
697 self.connection.enqueue_write_op(&ClientOp::Pong);
698 }
699 ServerOp::Pong => {
700 debug!("received PONG");
701 self.pending_pings = self.pending_pings.saturating_sub(1);
702 }
703 ServerOp::Error(error) => {
704 debug!("received ERROR: {:?}", error);
705 self.connector
706 .events_tx
707 .try_send(Event::ServerError(error))
708 .ok();
709 }
710 ServerOp::Message {
711 sid,
712 subject,
713 reply,
714 payload,
715 headers,
716 status,
717 description,
718 length,
719 } => {
720 debug!("received MESSAGE: sid={}, subject={}", sid, subject);
721 self.connector
722 .connect_stats
723 .in_messages
724 .add(1, Ordering::Relaxed);
725
726 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
727 let message: Message = Message {
728 subject,
729 reply,
730 payload,
731 headers,
732 status,
733 description,
734 length,
735 };
736
737 match subscription.sender.try_send(message) {
740 Ok(_) => {
741 subscription.delivered += 1;
742 if let Some(max) = subscription.max {
746 if subscription.delivered.ge(&max) {
747 debug!("max messages reached for subscription {}", sid);
748 self.subscriptions.remove(&sid);
749 }
750 }
751 }
752 Err(mpsc::error::TrySendError::Full(_)) => {
753 debug!("slow consumer detected for subscription {}", sid);
754 self.connector
755 .events_tx
756 .try_send(Event::SlowConsumer(sid))
757 .ok();
758 }
759 Err(mpsc::error::TrySendError::Closed(_)) => {
760 debug!("subscription {} channel closed", sid);
761 self.subscriptions.remove(&sid);
762 self.connection
763 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
764 }
765 }
766 } else if sid == MULTIPLEXER_SID {
767 debug!("received message for multiplexer");
768 if let Some(multiplexer) = self.multiplexer.as_mut() {
769 let maybe_token =
770 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
771
772 if let Some(token) = maybe_token {
773 if let Some(sender) = multiplexer.senders.remove(token) {
774 debug!("forwarding message to request with token {}", token);
775 let message = Message {
776 subject,
777 reply,
778 payload,
779 headers,
780 status,
781 description,
782 length,
783 };
784
785 let _ = sender.send(message);
786 }
787 }
788 }
789 }
790 }
791 ServerOp::Info(info) => {
793 debug!("received INFO: server_id={}", info.server_id);
794 if info.lame_duck_mode {
795 debug!("server in lame duck mode");
796 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
797 }
798 }
799
800 _ => {
801 }
803 }
804 }
805
806 fn handle_command(&mut self, command: Command) {
807 self.ping_interval.reset();
808
809 match command {
810 Command::Unsubscribe { sid, max } => {
811 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
812 subscription.max = max;
813 match subscription.max {
814 Some(n) => {
815 if subscription.delivered >= n {
816 self.subscriptions.remove(&sid);
817 }
818 }
819 None => {
820 self.subscriptions.remove(&sid);
821 }
822 }
823
824 self.connection
825 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
826 }
827 }
828 Command::Flush { observer } => {
829 self.flush_observers.push(observer);
830 }
831 Command::Drain { sid } => {
832 let mut drain_sub = |sid: u64| {
833 self.drain_pings.push_back(sid);
834 self.connection
835 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
836 };
837
838 if let Some(sid) = sid {
839 if self.subscriptions.get_mut(&sid).is_some() {
840 drain_sub(sid);
841 }
842 } else {
843 self.connector.events_tx.try_send(Event::Draining).ok();
845 self.is_draining = true;
846 for (&sid, _) in self.subscriptions.iter_mut() {
847 drain_sub(sid);
848 }
849 }
850 self.connection.enqueue_write_op(&ClientOp::Ping);
851 }
852 Command::Subscribe {
853 sid,
854 subject,
855 queue_group,
856 sender,
857 } => {
858 let subscription = Subscription {
859 sender,
860 delivered: 0,
861 max: None,
862 subject: subject.to_owned(),
863 queue_group: queue_group.to_owned(),
864 };
865
866 self.subscriptions.insert(sid, subscription);
867
868 self.connection.enqueue_write_op(&ClientOp::Subscribe {
869 sid,
870 subject,
871 queue_group,
872 });
873 }
874 Command::Request {
875 subject,
876 payload,
877 respond,
878 headers,
879 sender,
880 } => {
881 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
882
883 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
884 multiplexer
885 } else {
886 let prefix = Subject::from(format!("{}.{}.", prefix, id_generator::next()));
887 let subject = Subject::from(format!("{prefix}*"));
888
889 self.connection.enqueue_write_op(&ClientOp::Subscribe {
890 sid: MULTIPLEXER_SID,
891 subject: subject.clone(),
892 queue_group: None,
893 });
894
895 self.multiplexer.insert(Multiplexer {
896 subject,
897 prefix,
898 senders: HashMap::new(),
899 })
900 };
901 self.connector
902 .connect_stats
903 .out_messages
904 .add(1, Ordering::Relaxed);
905
906 multiplexer.senders.insert(token.to_owned(), sender);
907
908 let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
909
910 let pub_op = ClientOp::Publish {
911 subject,
912 payload,
913 respond: Some(respond),
914 headers,
915 };
916
917 self.connection.enqueue_write_op(&pub_op);
918 }
919
920 Command::Publish(OutboundMessage {
921 subject,
922 payload,
923 reply: respond,
924 headers,
925 }) => {
926 self.connector
927 .connect_stats
928 .out_messages
929 .add(1, Ordering::Relaxed);
930
931 let header_len = headers
932 .as_ref()
933 .map(|headers| headers.len())
934 .unwrap_or_default();
935
936 self.connector.connect_stats.out_bytes.add(
937 (payload.len()
938 + respond.as_ref().map_or_else(|| 0, |r| r.len())
939 + subject.len()
940 + header_len) as u64,
941 Ordering::Relaxed,
942 );
943
944 self.connection.enqueue_write_op(&ClientOp::Publish {
945 subject,
946 payload,
947 respond,
948 headers,
949 });
950 }
951
952 Command::Reconnect => {
953 self.should_reconnect = true;
954 }
955 }
956 }
957
958 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
959 self.pending_pings = 0;
960 self.connector.events_tx.try_send(Event::Disconnected).ok();
961 self.connector.state_tx.send(State::Disconnected).ok();
962
963 self.handle_reconnect().await
964 }
965
966 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
967 let (info, connection) = self.connector.connect().await?;
968 self.connection = connection;
969 let _ = self.info_sender.send(info);
970
971 self.subscriptions
972 .retain(|_, subscription| !subscription.sender.is_closed());
973
974 for (sid, subscription) in &self.subscriptions {
975 self.connection.enqueue_write_op(&ClientOp::Subscribe {
976 sid: *sid,
977 subject: subscription.subject.to_owned(),
978 queue_group: subscription.queue_group.to_owned(),
979 });
980 }
981
982 if let Some(multiplexer) = &self.multiplexer {
983 self.connection.enqueue_write_op(&ClientOp::Subscribe {
984 sid: MULTIPLEXER_SID,
985 subject: multiplexer.subject.to_owned(),
986 queue_group: None,
987 });
988 }
989 Ok(())
990 }
991}
992
993pub async fn connect_with_options<A: ToServerAddrs>(
1009 addrs: A,
1010 options: ConnectOptions,
1011) -> Result<Client, ConnectError> {
1012 let ping_period = options.ping_interval;
1013
1014 let (events_tx, mut events_rx) = mpsc::channel(128);
1015 let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
1016 let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
1018 let statistics = Arc::new(Statistics::default());
1019
1020 let mut connector = Connector::new(
1021 addrs,
1022 ConnectorOptions {
1023 tls_required: options.tls_required,
1024 certificates: options.certificates,
1025 client_key: options.client_key,
1026 client_cert: options.client_cert,
1027 tls_client_config: options.tls_client_config,
1028 tls_first: options.tls_first,
1029 auth: options.auth,
1030 no_echo: options.no_echo,
1031 connection_timeout: options.connection_timeout,
1032 name: options.name,
1033 ignore_discovered_servers: options.ignore_discovered_servers,
1034 retain_servers_order: options.retain_servers_order,
1035 read_buffer_capacity: options.read_buffer_capacity,
1036 reconnect_delay_callback: options.reconnect_delay_callback,
1037 auth_callback: options.auth_callback,
1038 max_reconnects: options.max_reconnects,
1039 local_address: options.local_address,
1040 },
1041 events_tx,
1042 state_tx,
1043 max_payload.clone(),
1044 statistics.clone(),
1045 )
1046 .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1047
1048 let mut info: ServerInfo = Default::default();
1049 let mut connection = None;
1050 if !options.retry_on_initial_connect {
1051 debug!("retry on initial connect failure is disabled");
1052 let (info_ok, connection_ok) = connector.try_connect().await?;
1053 connection = Some(connection_ok);
1054 info = info_ok;
1055 }
1056
1057 let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1058 let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1059
1060 let client = Client::new(
1061 info_watcher,
1062 state_rx,
1063 sender,
1064 options.subscription_capacity,
1065 options.inbox_prefix,
1066 options.request_timeout,
1067 max_payload,
1068 statistics,
1069 options.skip_subject_validation,
1070 );
1071
1072 task::spawn(async move {
1073 while let Some(event) = events_rx.recv().await {
1074 tracing::info!("event: {}", event);
1075 if let Some(event_callback) = &options.event_callback {
1076 event_callback.call(event).await;
1077 }
1078 }
1079 });
1080
1081 task::spawn(async move {
1082 if connection.is_none() && options.retry_on_initial_connect {
1083 let (info, connection_ok) = match connector.connect().await {
1084 Ok((info, connection)) => (info, connection),
1085 Err(err) => {
1086 error!("connection closed: {}", err);
1087 return;
1088 }
1089 };
1090 info_sender.send(info).ok();
1091 connection = Some(connection_ok);
1092 }
1093 let connection = connection.unwrap();
1094 let mut connection_handler =
1095 ConnectionHandler::new(connection, connector, info_sender, ping_period);
1096 connection_handler.process(&mut receiver).await
1097 });
1098
1099 Ok(client)
1100}
1101
1102#[derive(Debug, Clone, PartialEq, Eq)]
1103pub enum Event {
1104 Connected,
1105 Disconnected,
1106 LameDuckMode,
1107 Draining,
1108 Closed,
1109 SlowConsumer(u64),
1110 ServerError(ServerError),
1111 ClientError(ClientError),
1112}
1113
1114impl fmt::Display for Event {
1115 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1116 match self {
1117 Event::Connected => write!(f, "connected"),
1118 Event::Disconnected => write!(f, "disconnected"),
1119 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1120 Event::Draining => write!(f, "draining"),
1121 Event::Closed => write!(f, "closed"),
1122 Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1123 Event::ServerError(err) => write!(f, "server error: {err}"),
1124 Event::ClientError(err) => write!(f, "client error: {err}"),
1125 }
1126 }
1127}
1128
1129pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1196 connect_with_options(addrs, ConnectOptions::default()).await
1197}
1198
/// The kind of failure reported by a [`ConnectError`].
///
/// Derives `Eq` alongside `PartialEq`, like the other enums in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// DNS failure.
    Dns,
    /// Signing the nonce failed.
    Authentication,
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// The operation timed out.
    TimedOut,
    /// TLS failure.
    Tls,
    /// I/O failure.
    Io,
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
1218
1219impl Display for ConnectErrorKind {
1220 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1221 match self {
1222 Self::ServerParse => write!(f, "failed to parse server or server list"),
1223 Self::Dns => write!(f, "DNS error"),
1224 Self::Authentication => write!(f, "failed signing nonce"),
1225 Self::AuthorizationViolation => write!(f, "authorization violation"),
1226 Self::TimedOut => write!(f, "timed out"),
1227 Self::Tls => write!(f, "TLS error"),
1228 Self::Io => write!(f, "IO error"),
1229 Self::MaxReconnects => write!(f, "reached maximum number of reconnects"),
1230 }
1231 }
1232}
1233
1234pub type ConnectError = error::Error<ConnectErrorKind>;
1237
1238impl From<io::Error> for ConnectError {
1239 fn from(err: io::Error) -> Self {
1240 ConnectError::with_source(ConnectErrorKind::Io, err)
1241 }
1242}
1243
1244#[derive(Debug)]
1258pub struct Subscriber {
1259 sid: u64,
1260 receiver: mpsc::Receiver<Message>,
1261 sender: mpsc::Sender<Command>,
1262}
1263
1264impl Subscriber {
1265 fn new(
1266 sid: u64,
1267 sender: mpsc::Sender<Command>,
1268 receiver: mpsc::Receiver<Message>,
1269 ) -> Subscriber {
1270 Subscriber {
1271 sid,
1272 sender,
1273 receiver,
1274 }
1275 }
1276
1277 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1292 self.sender
1293 .send(Command::Unsubscribe {
1294 sid: self.sid,
1295 max: None,
1296 })
1297 .await?;
1298 self.receiver.close();
1299 Ok(())
1300 }
1301
1302 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1328 self.sender
1329 .send(Command::Unsubscribe {
1330 sid: self.sid,
1331 max: Some(unsub_after),
1332 })
1333 .await?;
1334 Ok(())
1335 }
1336
1337 pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1370 self.sender
1371 .send(Command::Drain {
1372 sid: Some(self.sid),
1373 })
1374 .await?;
1375
1376 Ok(())
1377 }
1378}
1379
1380#[derive(Error, Debug, PartialEq)]
1381#[error("failed to send unsubscribe")]
1382pub struct UnsubscribeError(String);
1383
1384impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1385 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1386 UnsubscribeError(err.to_string())
1387 }
1388}
1389
1390impl Drop for Subscriber {
1391 fn drop(&mut self) {
1392 self.receiver.close();
1393 tokio::spawn({
1394 let sender = self.sender.clone();
1395 let sid = self.sid;
1396 async move {
1397 sender
1398 .send(Command::Unsubscribe { sid, max: None })
1399 .await
1400 .ok();
1401 }
1402 });
1403 }
1404}
1405
1406impl Stream for Subscriber {
1407 type Item = Message;
1408
1409 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1410 self.receiver.poll_recv(cx)
1411 }
1412}
1413
1414#[derive(Clone, Debug, Eq, PartialEq)]
1415pub enum CallbackError {
1416 Client(ClientError),
1417 Server(ServerError),
1418}
1419impl std::fmt::Display for CallbackError {
1420 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1421 match self {
1422 Self::Client(error) => write!(f, "{error}"),
1423 Self::Server(error) => write!(f, "{error}"),
1424 }
1425 }
1426}
1427
1428impl From<ServerError> for CallbackError {
1429 fn from(server_error: ServerError) -> Self {
1430 CallbackError::Server(server_error)
1431 }
1432}
1433
1434impl From<ClientError> for CallbackError {
1435 fn from(client_error: ClientError) -> Self {
1436 CallbackError::Client(client_error)
1437 }
1438}
1439
1440#[derive(Clone, Debug, Eq, PartialEq, Error)]
1441pub enum ServerError {
1442 AuthorizationViolation,
1443 SlowConsumer(u64),
1444 Other(String),
1445}
1446
/// An error raised by the client itself.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Any other error, carrying a description.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
1452impl std::fmt::Display for ClientError {
1453 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1454 match self {
1455 Self::Other(error) => write!(f, "nats: {error}"),
1456 Self::MaxReconnects => write!(f, "nats: max reconnects reached"),
1457 }
1458 }
1459}
1460
1461impl ServerError {
1462 fn new(error: String) -> ServerError {
1463 match error.to_lowercase().as_str() {
1464 "authorization violation" => ServerError::AuthorizationViolation,
1465 _ => ServerError::Other(error),
1467 }
1468 }
1469}
1470
1471impl std::fmt::Display for ServerError {
1472 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1473 match self {
1474 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1475 Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1476 Self::Other(error) => write!(f, "nats: {error}"),
1477 }
1478 }
1479}
1480
1481#[derive(Clone, Debug, Serialize)]
1483pub struct ConnectInfo {
1484 pub verbose: bool,
1486
1487 pub pedantic: bool,
1490
1491 #[serde(rename = "jwt")]
1493 pub user_jwt: Option<String>,
1494
1495 pub nkey: Option<String>,
1497
1498 #[serde(rename = "sig")]
1500 pub signature: Option<String>,
1501
1502 pub name: Option<String>,
1504
1505 pub echo: bool,
1510
1511 pub lang: String,
1513
1514 pub version: String,
1516
1517 pub protocol: Protocol,
1522
1523 pub tls_required: bool,
1525
1526 pub user: Option<String>,
1528
1529 pub pass: Option<String>,
1531
1532 pub auth_token: Option<String>,
1534
1535 pub headers: bool,
1537
1538 pub no_responders: bool,
1540}
1541
1542#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
1544#[repr(u8)]
1545pub enum Protocol {
1546 Original = 0,
1548 Dynamic = 1,
1550}
1551
1552#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1554pub struct ServerAddr(Url);
1555
1556impl FromStr for ServerAddr {
1557 type Err = io::Error;
1558
1559 fn from_str(input: &str) -> Result<Self, Self::Err> {
1563 let url: Url = if input.contains("://") {
1564 input.parse()
1565 } else {
1566 format!("nats://{input}").parse()
1567 }
1568 .map_err(|e| {
1569 io::Error::new(
1570 ErrorKind::InvalidInput,
1571 format!("NATS server URL is invalid: {e}"),
1572 )
1573 })?;
1574
1575 Self::from_url(url)
1576 }
1577}
1578
1579impl ServerAddr {
1580 pub fn from_url(url: Url) -> io::Result<Self> {
1582 if url.scheme() != "nats"
1583 && url.scheme() != "tls"
1584 && url.scheme() != "ws"
1585 && url.scheme() != "wss"
1586 {
1587 return Err(std::io::Error::new(
1588 ErrorKind::InvalidInput,
1589 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1590 ));
1591 }
1592
1593 Ok(Self(url))
1594 }
1595
1596 pub fn into_inner(self) -> Url {
1598 self.0
1599 }
1600
1601 pub fn tls_required(&self) -> bool {
1603 self.0.scheme() == "tls"
1604 }
1605
1606 pub fn has_user_pass(&self) -> bool {
1608 self.0.username() != ""
1609 }
1610
1611 pub fn scheme(&self) -> &str {
1612 self.0.scheme()
1613 }
1614
1615 pub fn host(&self) -> &str {
1617 match self.0.host() {
1618 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1619 Some(Host::Ipv6 { .. }) => {
1621 let host = self.0.host_str().unwrap();
1622 &host[1..host.len() - 1]
1623 }
1624 None => "",
1625 }
1626 }
1627
1628 pub fn is_websocket(&self) -> bool {
1629 self.0.scheme() == "ws" || self.0.scheme() == "wss"
1630 }
1631
1632 pub fn port(&self) -> u16 {
1635 self.0.port_or_known_default().unwrap_or(4222)
1636 }
1637
1638 pub fn as_url_str(&self) -> &str {
1640 self.0.as_str()
1641 }
1642
1643 pub fn username(&self) -> Option<&str> {
1645 let user = self.0.username();
1646 if user.is_empty() {
1647 None
1648 } else {
1649 Some(user)
1650 }
1651 }
1652
1653 pub fn password(&self) -> Option<&str> {
1655 self.0.password()
1656 }
1657
1658 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1660 tokio::net::lookup_host((self.host(), self.port())).await
1661 }
1662}
1663
1664pub trait ToServerAddrs {
1669 type Iter: Iterator<Item = ServerAddr>;
1672
1673 fn to_server_addrs(&self) -> io::Result<Self::Iter>;
1674}
1675
1676impl ToServerAddrs for ServerAddr {
1677 type Iter = option::IntoIter<ServerAddr>;
1678 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1679 Ok(Some(self.clone()).into_iter())
1680 }
1681}
1682
1683impl ToServerAddrs for str {
1684 type Iter = option::IntoIter<ServerAddr>;
1685 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1686 self.parse::<ServerAddr>()
1687 .map(|addr| Some(addr).into_iter())
1688 }
1689}
1690
1691impl ToServerAddrs for String {
1692 type Iter = option::IntoIter<ServerAddr>;
1693 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1694 (**self).to_server_addrs()
1695 }
1696}
1697
1698impl<T: AsRef<str>> ToServerAddrs for [T] {
1699 type Iter = std::vec::IntoIter<ServerAddr>;
1700 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1701 self.iter()
1702 .map(AsRef::as_ref)
1703 .map(str::parse)
1704 .collect::<io::Result<_>>()
1705 .map(Vec::into_iter)
1706 }
1707}
1708
1709impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1710 type Iter = std::vec::IntoIter<ServerAddr>;
1711 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1712 self.as_slice().to_server_addrs()
1713 }
1714}
1715
1716impl<'a> ToServerAddrs for &'a [ServerAddr] {
1717 type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1718
1719 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1720 Ok(self.iter().cloned())
1721 }
1722}
1723
1724impl ToServerAddrs for Vec<ServerAddr> {
1725 type Iter = std::vec::IntoIter<ServerAddr>;
1726
1727 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1728 Ok(self.clone().into_iter())
1729 }
1730}
1731
1732impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1733 type Iter = T::Iter;
1734 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1735 (**self).to_server_addrs()
1736 }
1737}
1738
1739pub(crate) fn is_valid_publish_subject<T: AsRef<str>>(subject: T) -> bool {
1744 let bytes = subject.as_ref().as_bytes();
1745
1746 if bytes.is_empty() {
1747 return false;
1748 }
1749
1750 memchr::memchr3(b' ', b'\r', b'\n', bytes).is_none() && memchr::memchr(b'\t', bytes).is_none()
1751}
1752
1753pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
1757 let bytes = subject.as_ref().as_bytes();
1758
1759 if bytes.is_empty() {
1760 return false;
1761 }
1762
1763 bytes[0] != b'.'
1764 && bytes[bytes.len() - 1] != b'.'
1765 && memchr::memmem::find(bytes, b"..").is_none()
1766 && memchr::memchr3(b' ', b'\r', b'\n', bytes).is_none()
1767 && memchr::memchr(b'\t', bytes).is_none()
1768}
1769
1770pub(crate) fn is_valid_queue_group(queue_group: &str) -> bool {
1774 let bytes = queue_group.as_bytes();
1775
1776 if bytes.is_empty() {
1777 return false;
1778 }
1779
1780 memchr::memchr3(b' ', b'\r', b'\n', bytes).is_none() && memchr::memchr(b'\t', bytes).is_none()
1781}
1782
/// Implements `From<$origin>` for the error type `$t` with kind type `$k`.
///
/// An origin error whose `kind()` is `<$origin_kind>::TimedOut` becomes
/// `<$k>::TimedOut` without a source. Any other error becomes `<$k>::Other`
/// with the origin error as its source.
#[allow(unused_macros)]
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                if matches!(err.kind(), <$origin_kind>::TimedOut) {
                    Self::new(<$k>::TimedOut)
                } else {
                    Self::with_source(<$k>::Other, err)
                }
            }
        }
    };
}
1796#[allow(unused_imports)]
1797pub(crate) use from_with_timeout;
1798
1799use crate::connection::ShouldFlush;
1800use crate::message::OutboundMessage;
1801
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `addrs` yields exactly `127.0.0.1` followed by `::`.
    fn assert_ipv4_then_ipv6(mut addrs: impl Iterator<Item = ServerAddr>) {
        assert_eq!(addrs.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs.next().unwrap().host(), "::");
        assert_eq!(addrs.next(), None);
    }

    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        assert_ipv4_then_ipv6(vec.to_server_addrs().unwrap());
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        assert_ipv4_then_ipv6(arr.to_server_addrs().unwrap());
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_ipv4_then_ipv6(vec.to_server_addrs().unwrap());
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_ipv4_then_ipv6(arr.to_server_addrs().unwrap());
    }

    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            // Check every address in the group, not only the first one.
            let addrs: Vec<ServerAddr> = arr.to_server_addrs().unwrap().collect();
            assert_eq!(addrs.len(), arr.len());
            for addr in addrs {
                assert_eq!(
                    addr.port(),
                    expected_port,
                    "unexpected port for {}",
                    addr.as_url_str()
                );
            }
        }
    }
}