1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_cfg))]
196
197use thiserror::Error;
198
199use futures_util::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::collections::VecDeque;
207use std::fmt::Display;
208use std::future::Future;
209use std::iter;
210use std::mem;
211use std::net::SocketAddr;
212use std::option;
213use std::pin::Pin;
214use std::slice;
215use std::str::{self, FromStr};
216use std::sync::atomic::AtomicUsize;
217use std::sync::atomic::Ordering;
218use std::sync::Arc;
219use std::task::{Context, Poll};
220use tokio::io::ErrorKind;
221use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
222use url::{Host, Url};
223
224use bytes::Bytes;
225use serde::{Deserialize, Serialize};
226use serde_repr::{Deserialize_repr, Serialize_repr};
227use tokio::io;
228use tokio::sync::mpsc;
229use tokio::task;
230
/// Boxed, thread-safe (`Send + Sync`), `'static` error trait object.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Version of this crate, read from `CARGO_PKG_VERSION` at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Language identifier of this client.
// NOTE(review): presumably reported to the server in `ConnectInfo::lang` — confirm at the use site.
const LANG: &str = "rust";
/// Number of unanswered pings tolerated by `ConnectionHandler`; once the
/// pending count exceeds this value the connection is treated as disconnected.
const MAX_PENDING_PINGS: usize = 2;
/// Subscription id reserved for the request multiplexer's wildcard subscription.
const MULTIPLEXER_SID: u64 = 0;
/// Initial value (1 MiB) of the shared max-payload counter created in
/// `connect_with_options`.
pub(crate) const DEFAULT_SERVER_MAX_PAYLOAD: usize = 1024 * 1024;
238
239pub use tokio_rustls::rustls;
243
244use connection::{Connection, State};
245use connector::{Connector, ConnectorOptions};
246pub use connector::{ReconnectToServer, Server};
247pub use header::{HeaderMap, HeaderName, HeaderValue};
248pub use subject::{Subject, SubjectError, ToSubject};
249
250mod auth;
251pub(crate) mod auth_utils;
252pub mod client;
253pub mod connection;
254mod connector;
255mod options;
256
257pub use auth::Auth;
258pub use client::{
259 Client, PublishError, Request, RequestError, RequestErrorKind, ServerPoolError,
260 ServerPoolErrorKind, SetServerPoolError, SetServerPoolErrorKind, Statistics, SubscribeError,
261 SubscribeErrorKind,
262};
263pub use options::{AuthError, ConnectOptions};
264
265#[cfg(feature = "crypto")]
266#[cfg_attr(docsrs, doc(cfg(feature = "crypto")))]
267mod crypto;
268pub mod error;
269pub mod header;
270mod id_generator;
271#[cfg(feature = "jetstream")]
272#[cfg_attr(docsrs, doc(cfg(feature = "jetstream")))]
273pub mod jetstream;
274pub mod message;
275#[cfg(feature = "service")]
276#[cfg_attr(docsrs, doc(cfg(feature = "service")))]
277pub mod service;
278pub mod status;
279pub mod subject;
280mod tls;
281
282pub use message::Message;
283pub use status::StatusCode;
284
/// Information about a server, deserialized from the payload of an `INFO`
/// operation.
///
/// Every field is `#[serde(default)]`, so a key missing from the payload
/// deserializes to the field type's default value instead of failing.
///
/// NOTE(review): the per-field descriptions below follow the field names;
/// confirm exact semantics against the NATS protocol documentation.
#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
pub struct ServerInfo {
    /// Identifier of the server (logged when an `INFO` operation is received).
    #[serde(default)]
    pub server_id: String,
    /// Name of the server.
    #[serde(default)]
    pub server_name: String,
    /// Host of the server.
    #[serde(default)]
    pub host: String,
    /// Port of the server.
    #[serde(default)]
    pub port: u16,
    /// Version string of the server.
    #[serde(default)]
    pub version: String,
    /// Whether the server requires authentication.
    #[serde(default)]
    pub auth_required: bool,
    /// Whether the server requires TLS.
    #[serde(default)]
    pub tls_required: bool,
    /// Maximum payload size, presumably in bytes.
    #[serde(default)]
    pub max_payload: usize,
    /// Protocol version announced by the server.
    #[serde(default)]
    pub proto: i8,
    /// Id the server assigned to this client.
    #[serde(default)]
    pub client_id: u64,
    /// Presumably the Go toolchain version the server was built with.
    #[serde(default)]
    pub go: String,
    /// Nonce supplied by the server; presumably used for signature-based auth.
    #[serde(default)]
    pub nonce: String,
    /// Additional server URLs announced by the server.
    #[serde(default)]
    pub connect_urls: Vec<String>,
    /// Client IP address as reported by the server.
    #[serde(default)]
    pub client_ip: String,
    /// Whether the server supports message headers.
    #[serde(default)]
    pub headers: bool,
    /// Lame duck mode flag, read from the `ldm` key. When set on an incoming
    /// `INFO`, the connection handler emits `Event::LameDuckMode`.
    #[serde(default, rename = "ldm")]
    pub lame_duck_mode: bool,
    /// Cluster name, if the payload contains one.
    #[serde(default)]
    pub cluster: Option<String>,
    /// Domain name, if the payload contains one.
    #[serde(default)]
    pub domain: Option<String>,
    /// Whether JetStream is reported as available.
    #[serde(default)]
    pub jetstream: bool,
}
348
/// An operation received from the server, as consumed by
/// `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// Acknowledgement; currently ignored by the handler.
    Ok,
    /// Server information (boxed to keep the enum small).
    Info(Box<ServerInfo>),
    /// Ping from the server; answered with `ClientOp::Pong`.
    Ping,
    /// Pong from the server; decrements the handler's pending-ping counter.
    Pong,
    /// Error reported by the server; forwarded as `Event::ServerError`.
    Error(ServerError),
    /// A message delivered for subscription `sid`.
    Message {
        /// Id of the subscription the message belongs to.
        sid: u64,
        /// Subject the message was published on.
        subject: Subject,
        /// Optional reply subject.
        reply: Option<Subject>,
        /// Message payload.
        payload: Bytes,
        /// Optional message headers.
        headers: Option<HeaderMap>,
        /// Optional status code carried by the message.
        status: Option<StatusCode>,
        /// Optional description accompanying the status.
        description: Option<String>,
        /// Length of the message, copied verbatim into `Message::length`.
        length: usize,
    },
}
367
/// Deprecated alias for `crate::message::OutboundMessage`, kept so existing
/// code referring to `PublishMessage` keeps compiling.
#[deprecated(
    since = "0.44.0",
    note = "use `async_nats::message::OutboundMessage` instead"
)]
pub type PublishMessage = crate::message::OutboundMessage;
376
/// A command sent over the command channel to the `ConnectionHandler`, which
/// handles it in `ConnectionHandler::handle_command`.
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a message.
    Publish(OutboundMessage),
    /// Publish a request whose single response is routed back through the
    /// request multiplexer.
    Request {
        /// Subject the request is published on.
        subject: Subject,
        /// Request payload.
        payload: Bytes,
        /// Inbox subject; it must contain a `.` — the part after the last `.`
        /// is used as the multiplexer token.
        respond: Subject,
        /// Optional request headers.
        headers: Option<HeaderMap>,
        /// Channel on which the response message is delivered.
        sender: oneshot::Sender<Message>,
    },
    /// Register a subscription and send the subscribe operation to the server.
    Subscribe {
        /// Id of the new subscription.
        sid: u64,
        /// Subject to subscribe to.
        subject: Subject,
        /// Optional queue group.
        queue_group: Option<String>,
        /// Channel on which received messages are delivered.
        sender: mpsc::Sender<Message>,
    },
    /// Unsubscribe immediately (`max == None`) or once `max` messages have
    /// been delivered.
    Unsubscribe {
        /// Id of the subscription.
        sid: u64,
        /// Optional delivery limit after which the subscription is removed.
        max: Option<u64>,
    },
    /// Register an observer that is notified after the next successful flush.
    Flush {
        /// Notified with `()` once the connection has been flushed.
        observer: oneshot::Sender<()>,
    },
    /// Drain one subscription (`Some(sid)`) or the whole client (`None`).
    Drain {
        /// Subscription to drain; `None` drains every subscription and marks
        /// the handler as draining.
        sid: Option<u64>,
    },
    /// Request that the handler drops the current connection and reconnects.
    Reconnect,
    /// Replace the connector's server pool.
    SetServerPool {
        /// The new list of servers.
        servers: Vec<ServerAddr>,
        /// Receives the outcome returned by `Connector::set_server_pool`.
        result: oneshot::Sender<Result<(), String>>,
    },
    /// Query the connector's current server pool.
    ServerPool {
        /// Receives the value returned by `Connector::server_pool`.
        result: oneshot::Sender<Vec<connector::Server>>,
    },
}
413
/// An operation the client enqueues for writing to the server through
/// `Connection::enqueue_write_op`.
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// Publish a message.
    Publish {
        /// Subject to publish on.
        subject: Subject,
        /// Message payload.
        payload: Bytes,
        /// Optional reply subject.
        respond: Option<Subject>,
        /// Optional headers.
        headers: Option<HeaderMap>,
    },
    /// Subscribe to a subject.
    Subscribe {
        /// Id of the subscription.
        sid: u64,
        /// Subject to subscribe to.
        subject: Subject,
        /// Optional queue group.
        queue_group: Option<String>,
    },
    /// Unsubscribe, optionally after a number of messages.
    Unsubscribe {
        /// Id of the subscription.
        sid: u64,
        /// Optional message limit.
        max: Option<u64>,
    },
    /// Ping the server.
    Ping,
    /// Reply to a server ping.
    Pong,
    /// Connection handshake carrying the client's `ConnectInfo`.
    Connect(ConnectInfo),
}
436
/// Handler-side bookkeeping for one active subscription.
#[derive(Debug)]
struct Subscription {
    /// Subject of the subscription; re-sent to the server on reconnect.
    subject: Subject,
    /// Channel used to hand received messages to the subscriber.
    sender: mpsc::Sender<Message>,
    /// Optional queue group; re-sent to the server on reconnect.
    queue_group: Option<String>,
    /// Number of messages successfully handed to `sender` so far.
    delivered: u64,
    /// Optional delivery limit; the subscription is removed once `delivered`
    /// reaches it.
    max: Option<u64>,
}
445
/// Routes responses of in-flight requests through a single wildcard
/// subscription registered under `MULTIPLEXER_SID`.
#[derive(Debug)]
struct Multiplexer {
    /// The wildcard subject (`<prefix>*`) subscribed to on the server.
    subject: Subject,
    /// Prefix (ending in `.`) shared by all reply subjects; stripping it from
    /// an incoming subject yields the request token.
    prefix: Subject,
    /// Pending requests keyed by token; an entry is removed when its response
    /// arrives.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
452
/// Owns the connection and drives reads, writes, pings, subscriptions and
/// reconnects; see `ConnectionHandler::process`.
pub(crate) struct ConnectionHandler {
    /// The current connection; replaced after every successful reconnect.
    connection: Connection,
    /// Establishes connections; also exposes the event/state channels and
    /// statistics used by the handler.
    connector: Connector,
    /// Active subscriptions keyed by subscription id.
    subscriptions: HashMap<u64, Subscription>,
    /// Request multiplexer, created when the first request command is handled.
    multiplexer: Option<Multiplexer>,
    /// Pings sent that have not been answered by a pong yet.
    pending_pings: usize,
    /// Publishes the most recent `ServerInfo` to watchers.
    info_sender: tokio::sync::watch::Sender<Option<ServerInfo>>,
    /// Timer that triggers pings; reset on every server operation and command.
    ping_interval: Interval,
    /// Set by `Command::Reconnect`; consumed at the end of a poll.
    should_reconnect: bool,
    /// Observers to notify after the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    /// Set once the whole client is draining.
    is_draining: bool,
    /// Ids of subscriptions being drained; they are removed on a later poll.
    drain_pings: VecDeque<u64>,
}
467
468impl ConnectionHandler {
469 pub(crate) fn new(
470 connection: Connection,
471 connector: Connector,
472 info_sender: tokio::sync::watch::Sender<Option<ServerInfo>>,
473 ping_period: Duration,
474 ) -> ConnectionHandler {
475 let mut ping_interval = interval(ping_period);
476 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
477
478 ConnectionHandler {
479 connection,
480 connector,
481 subscriptions: HashMap::new(),
482 multiplexer: None,
483 pending_pings: 0,
484 info_sender,
485 ping_interval,
486 should_reconnect: false,
487 flush_observers: Vec::new(),
488 is_draining: false,
489 drain_pings: VecDeque::new(),
490 }
491 }
492
    /// Runs the connection until the client is closed or reconnecting fails.
    ///
    /// Repeatedly awaits a `ProcessFut`, which drives pings, reads, command
    /// handling, writes and flushes, and then reacts to the reason it exited:
    /// a disconnect or a requested reconnect triggers `handle_disconnect`,
    /// while `Closed` emits `Event::Closed` and ends the loop.
    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
        /// Future performing one "session" on the current connection.
        struct ProcessFut<'a> {
            handler: &'a mut ConnectionHandler,
            receiver: &'a mut mpsc::Receiver<Command>,
            // Scratch buffer for batched command reception; empty between polls.
            recv_buf: &'a mut Vec<Command>,
        }

        /// Why a `ProcessFut` completed.
        enum ExitReason {
            /// The connection ended, optionally with the I/O error that caused it.
            Disconnected(Option<io::Error>),
            /// A `Command::Reconnect` was handled.
            ReconnectRequested,
            /// The command channel was closed or a client-wide drain completed.
            Closed,
        }

        impl ProcessFut<'_> {
            /// Maximum number of commands received from the channel per batch.
            const RECV_CHUNK_SIZE: usize = 16;

            /// Accounts for one more pending ping. Enqueues a `PING` while the
            /// count is within `MAX_PENDING_PINGS`; otherwise reports the
            /// connection as disconnected.
            #[cold]
            fn ping(&mut self) -> Poll<ExitReason> {
                self.handler.pending_pings += 1;

                if self.handler.pending_pings > MAX_PENDING_PINGS {
                    debug!(
                        pending_pings = self.handler.pending_pings,
                        max_pings = MAX_PENDING_PINGS,
                        "disconnecting due to too many pending pings"
                    );

                    Poll::Ready(ExitReason::Disconnected(None))
                } else {
                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);

                    Poll::Pending
                }
            }
        }

        impl Future for ProcessFut<'_> {
            type Output = ExitReason;

            /// Drives the connection forward.
            ///
            /// Returns `Poll::Pending` when nothing further can be done right
            /// now, or `Poll::Ready` with the reason the session ended.
            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // Poll the timer until it is pending so that its waker is
                // registered; every elapsed tick accounts for one ping.
                while self.handler.ping_interval.poll_tick(cx).is_ready() {
                    if let Poll::Ready(exit) = self.ping() {
                        return Poll::Ready(exit);
                    }
                }

                // Handle every server operation that is currently readable.
                loop {
                    match self.handler.connection.poll_read_op(cx) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(Some(server_op))) => {
                            self.handler.handle_server_op(server_op);
                        }
                        // `None` means no further operation will arrive.
                        Poll::Ready(Ok(None)) => {
                            return Poll::Ready(ExitReason::Disconnected(None))
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Subscriptions queued for draining by an earlier
                // `Command::Drain` are dropped here, after the read loop above
                // and before any new command is handled.
                while let Some(sid) = self.handler.drain_pings.pop_front() {
                    self.handler.subscriptions.remove(&sid);
                }

                // A client-wide drain was requested on an earlier poll: stop
                // instead of processing further commands.
                if self.handler.is_draining {
                    return Poll::Ready(ExitReason::Closed);
                }

                // Alternate between receiving commands and writing until
                // neither side makes progress. Starts as `true` so the write
                // side is polled at least once.
                let mut made_progress = true;
                loop {
                    while !self.handler.connection.is_write_buf_full() {
                        debug_assert!(self.recv_buf.is_empty());

                        // Split the borrow so the buffer, the handler and the
                        // receiver can be used at the same time.
                        let Self {
                            recv_buf,
                            handler,
                            receiver,
                        } = &mut *self;
                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
                            Poll::Pending => break,
                            Poll::Ready(1..) => {
                                made_progress = true;

                                for cmd in recv_buf.drain(..) {
                                    handler.handle_command(cmd);
                                }
                            }
                            // Zero received commands: the channel is closed.
                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
                        }
                    }

                    // Reads and clears the flag: stop once a full round
                    // received no commands.
                    if !mem::take(&mut made_progress) {
                        break;
                    }

                    match self.handler.connection.poll_write(cx) {
                        Poll::Pending => {
                            // The write could not complete yet.
                            break;
                        }
                        Poll::Ready(Ok(())) => {
                            // Write completed; try to receive more commands.
                            continue;
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Flush when the connection asks for it or when at least one
                // flush observer is waiting.
                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
                    self.handler.connection.should_flush(),
                    self.handler.flush_observers.is_empty(),
                ) {
                    match self.handler.connection.poll_flush(cx) {
                        Poll::Pending => {}
                        Poll::Ready(Ok(())) => {
                            // Observers that already went away are ignored.
                            for observer in self.handler.flush_observers.drain(..) {
                                let _ = observer.send(());
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Reads and clears the flag set by `Command::Reconnect`.
                if mem::take(&mut self.handler.should_reconnect) {
                    return Poll::Ready(ExitReason::ReconnectRequested);
                }

                Poll::Pending
            }
        }

        // Reused across sessions so the buffer is allocated only once.
        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
        loop {
            let process = ProcessFut {
                handler: self,
                receiver,
                recv_buf: &mut recv_buf,
            };
            match process.await {
                ExitReason::Disconnected(err) => {
                    debug!(error = ?err, "disconnected");
                    // Give up when reconnecting fails.
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                    debug!("reconnected");
                }
                ExitReason::Closed => {
                    // Best effort: the event is dropped if it cannot be queued.
                    self.connector.events_tx.try_send(Event::Closed).ok();
                    break;
                }
                ExitReason::ReconnectRequested => {
                    debug!("reconnect requested");
                    // Shut the current stream down (errors ignored) before
                    // reconnecting.
                    self.connection.stream.shutdown().await.ok();
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                }
            }
        }
    }
700
701 fn handle_server_op(&mut self, server_op: ServerOp) {
702 self.ping_interval.reset();
703
704 match server_op {
705 ServerOp::Ping => {
706 debug!("received PING");
707 self.connection.enqueue_write_op(&ClientOp::Pong);
708 }
709 ServerOp::Pong => {
710 debug!("received PONG");
711 self.pending_pings = self.pending_pings.saturating_sub(1);
712 }
713 ServerOp::Error(error) => {
714 debug!("received ERROR: {:?}", error);
715 self.connector
716 .events_tx
717 .try_send(Event::ServerError(error))
718 .ok();
719 }
720 ServerOp::Message {
721 sid,
722 subject,
723 reply,
724 payload,
725 headers,
726 status,
727 description,
728 length,
729 } => {
730 debug!("received MESSAGE: sid={}, subject={}", sid, subject);
731 self.connector
732 .connect_stats
733 .in_messages
734 .add(1, Ordering::Relaxed);
735
736 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
737 let message: Message = Message {
738 subject,
739 reply,
740 payload,
741 headers,
742 status,
743 description,
744 length,
745 };
746
747 match subscription.sender.try_send(message) {
750 Ok(_) => {
751 subscription.delivered += 1;
752 if let Some(max) = subscription.max {
756 if subscription.delivered.ge(&max) {
757 debug!("max messages reached for subscription {}", sid);
758 self.subscriptions.remove(&sid);
759 }
760 }
761 }
762 Err(mpsc::error::TrySendError::Full(_)) => {
763 debug!("slow consumer detected for subscription {}", sid);
764 self.connector
765 .events_tx
766 .try_send(Event::SlowConsumer(sid))
767 .ok();
768 }
769 Err(mpsc::error::TrySendError::Closed(_)) => {
770 debug!("subscription {} channel closed", sid);
771 self.subscriptions.remove(&sid);
772 self.connection
773 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
774 }
775 }
776 } else if sid == MULTIPLEXER_SID {
777 debug!("received message for multiplexer");
778 if let Some(multiplexer) = self.multiplexer.as_mut() {
779 let maybe_token =
780 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
781
782 if let Some(token) = maybe_token {
783 if let Some(sender) = multiplexer.senders.remove(token) {
784 debug!("forwarding message to request with token {}", token);
785 let message = Message {
786 subject,
787 reply,
788 payload,
789 headers,
790 status,
791 description,
792 length,
793 };
794
795 let _ = sender.send(message);
796 }
797 }
798 }
799 }
800 }
801 ServerOp::Info(info) => {
803 debug!("received INFO: server_id={}", info.server_id);
804 if info.lame_duck_mode {
805 debug!("server in lame duck mode");
806 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
807 }
808 }
809
810 _ => {
811 }
813 }
814 }
815
816 fn handle_command(&mut self, command: Command) {
817 self.ping_interval.reset();
818
819 match command {
820 Command::Unsubscribe { sid, max } => {
821 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
822 subscription.max = max;
823 match subscription.max {
824 Some(n) => {
825 if subscription.delivered >= n {
826 self.subscriptions.remove(&sid);
827 }
828 }
829 None => {
830 self.subscriptions.remove(&sid);
831 }
832 }
833
834 self.connection
835 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
836 }
837 }
838 Command::Flush { observer } => {
839 self.flush_observers.push(observer);
840 }
841 Command::Drain { sid } => {
842 let mut drain_sub = |sid: u64| {
843 self.drain_pings.push_back(sid);
844 self.connection
845 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
846 };
847
848 if let Some(sid) = sid {
849 if self.subscriptions.get_mut(&sid).is_some() {
850 drain_sub(sid);
851 }
852 } else {
853 self.connector.events_tx.try_send(Event::Draining).ok();
855 self.is_draining = true;
856 for (&sid, _) in self.subscriptions.iter_mut() {
857 drain_sub(sid);
858 }
859 }
860 self.connection.enqueue_write_op(&ClientOp::Ping);
861 }
862 Command::Subscribe {
863 sid,
864 subject,
865 queue_group,
866 sender,
867 } => {
868 let subscription = Subscription {
869 sender,
870 delivered: 0,
871 max: None,
872 subject: subject.to_owned(),
873 queue_group: queue_group.to_owned(),
874 };
875
876 self.subscriptions.insert(sid, subscription);
877
878 self.connection.enqueue_write_op(&ClientOp::Subscribe {
879 sid,
880 subject,
881 queue_group,
882 });
883 }
884 Command::Request {
885 subject,
886 payload,
887 respond,
888 headers,
889 sender,
890 } => {
891 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
892
893 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
894 multiplexer
895 } else {
896 let prefix = Subject::from(format!("{}.{}.", prefix, id_generator::next()));
897 let subject = Subject::from(format!("{prefix}*"));
898
899 self.connection.enqueue_write_op(&ClientOp::Subscribe {
900 sid: MULTIPLEXER_SID,
901 subject: subject.clone(),
902 queue_group: None,
903 });
904
905 self.multiplexer.insert(Multiplexer {
906 subject,
907 prefix,
908 senders: HashMap::new(),
909 })
910 };
911 self.connector
912 .connect_stats
913 .out_messages
914 .add(1, Ordering::Relaxed);
915
916 multiplexer.senders.insert(token.to_owned(), sender);
917
918 let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
919
920 let pub_op = ClientOp::Publish {
921 subject,
922 payload,
923 respond: Some(respond),
924 headers,
925 };
926
927 self.connection.enqueue_write_op(&pub_op);
928 }
929
930 Command::Publish(OutboundMessage {
931 subject,
932 payload,
933 reply: respond,
934 headers,
935 }) => {
936 self.connector
937 .connect_stats
938 .out_messages
939 .add(1, Ordering::Relaxed);
940
941 let header_len = headers
942 .as_ref()
943 .map(|headers| headers.len())
944 .unwrap_or_default();
945
946 self.connector.connect_stats.out_bytes.add(
947 (payload.len()
948 + respond.as_ref().map_or_else(|| 0, |r| r.len())
949 + subject.len()
950 + header_len) as u64,
951 Ordering::Relaxed,
952 );
953
954 self.connection.enqueue_write_op(&ClientOp::Publish {
955 subject,
956 payload,
957 respond,
958 headers,
959 });
960 }
961
962 Command::Reconnect => {
963 self.should_reconnect = true;
964 }
965
966 Command::SetServerPool { servers, result } => {
967 let _ = result.send(self.connector.set_server_pool(servers));
968 }
969
970 Command::ServerPool { result } => {
971 let _ = result.send(self.connector.server_pool());
972 }
973 }
974 }
975
976 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
977 self.pending_pings = 0;
978 self.connector.events_tx.try_send(Event::Disconnected).ok();
979 self.connector.state_tx.send(State::Disconnected).ok();
980
981 self.handle_reconnect().await
982 }
983
984 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
985 let (info, connection) = self.connector.connect().await?;
986 self.connection = connection;
987 let _ = self.info_sender.send(Some(info));
988
989 self.subscriptions
990 .retain(|_, subscription| !subscription.sender.is_closed());
991
992 for (sid, subscription) in &self.subscriptions {
993 self.connection.enqueue_write_op(&ClientOp::Subscribe {
994 sid: *sid,
995 subject: subscription.subject.to_owned(),
996 queue_group: subscription.queue_group.to_owned(),
997 });
998
999 if let Some(max) = subscription.max {
1000 self.connection.enqueue_write_op(&ClientOp::Unsubscribe {
1001 sid: *sid,
1002 max: Some(max.saturating_sub(subscription.delivered)),
1003 });
1004 }
1005 }
1006
1007 if let Some(multiplexer) = &self.multiplexer {
1008 self.connection.enqueue_write_op(&ClientOp::Subscribe {
1009 sid: MULTIPLEXER_SID,
1010 subject: multiplexer.subject.to_owned(),
1011 queue_group: None,
1012 });
1013 }
1014 Ok(())
1015 }
1016}
1017
1018pub async fn connect_with_options<A: ToServerAddrs>(
1034 addrs: A,
1035 options: ConnectOptions,
1036) -> Result<Client, ConnectError> {
1037 let ping_period = options.ping_interval;
1038
1039 let (events_tx, mut events_rx) = mpsc::channel(128);
1040 let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
1041 let max_payload = Arc::new(AtomicUsize::new(DEFAULT_SERVER_MAX_PAYLOAD));
1043 let statistics = Arc::new(Statistics::default());
1044
1045 let mut connector = Connector::new(
1046 addrs,
1047 ConnectorOptions {
1048 tls_required: options.tls_required,
1049 certificates: options.certificates,
1050 client_key: options.client_key,
1051 client_cert: options.client_cert,
1052 tls_client_config: options.tls_client_config,
1053 tls_first: options.tls_first,
1054 auth: options.auth,
1055 no_echo: options.no_echo,
1056 connection_timeout: options.connection_timeout,
1057 name: options.name,
1058 ignore_discovered_servers: options.ignore_discovered_servers,
1059 retain_servers_order: options.retain_servers_order,
1060 read_buffer_capacity: options.read_buffer_capacity,
1061 reconnect_delay_callback: options.reconnect_delay_callback,
1062 auth_callback: options.auth_callback,
1063 max_reconnects: options.max_reconnects,
1064 local_address: options.local_address,
1065 reconnect_to_server_callback: options.reconnect_to_server_callback,
1066 },
1067 events_tx,
1068 state_tx,
1069 max_payload.clone(),
1070 statistics.clone(),
1071 )
1072 .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1073
1074 let mut info = None;
1075 let mut connection = None;
1076 if !options.retry_on_initial_connect {
1077 debug!("retry on initial connect failure is disabled");
1078 let (info_ok, connection_ok) = connector.try_connect().await?;
1079 connection = Some(connection_ok);
1080 info = Some(info_ok);
1081 }
1082
1083 let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1084 let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1085
1086 let client = Client::new(
1087 info_watcher,
1088 state_rx,
1089 sender,
1090 options.subscription_capacity,
1091 options.inbox_prefix,
1092 options.request_timeout,
1093 max_payload,
1094 statistics,
1095 options.skip_subject_validation,
1096 );
1097
1098 task::spawn(async move {
1099 while let Some(event) = events_rx.recv().await {
1100 tracing::info!("event: {}", event);
1101 if let Some(event_callback) = &options.event_callback {
1102 event_callback.call(event).await;
1103 }
1104 }
1105 });
1106
1107 task::spawn(async move {
1108 if connection.is_none() && options.retry_on_initial_connect {
1109 let (info, connection_ok) = match connector.connect().await {
1110 Ok((info, connection)) => (info, connection),
1111 Err(err) => {
1112 error!("connection closed: {}", err);
1113 return;
1114 }
1115 };
1116 info_sender.send(Some(info)).ok();
1117 connection = Some(connection_ok);
1118 }
1119 let connection = connection.unwrap();
1120 let mut connection_handler =
1121 ConnectionHandler::new(connection, connector, info_sender, ping_period);
1122 connection_handler.process(&mut receiver).await
1123 });
1124
1125 Ok(client)
1126}
1127
/// Events emitted by the connection machinery and delivered to the event
/// callback configured in `ConnectOptions` (if any).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// A connection was established.
    // NOTE(review): emitted outside this chunk, presumably by the connector.
    Connected,
    /// The connection was lost or a reconnect was requested.
    Disconnected,
    /// The server announced lame duck mode in an `INFO` operation.
    LameDuckMode,
    /// A drain of the whole client was started.
    Draining,
    /// The connection handler finished and the client is closed.
    Closed,
    /// A message could not be delivered because the channel of the
    /// subscription with the given id was full.
    SlowConsumer(u64),
    /// The server reported an error.
    ServerError(ServerError),
    /// The client ran into an error.
    ClientError(ClientError),
}
1139
1140impl fmt::Display for Event {
1141 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1142 match self {
1143 Event::Connected => write!(f, "connected"),
1144 Event::Disconnected => write!(f, "disconnected"),
1145 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1146 Event::Draining => write!(f, "draining"),
1147 Event::Closed => write!(f, "closed"),
1148 Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1149 Event::ServerError(err) => write!(f, "server error: {err}"),
1150 Event::ClientError(err) => write!(f, "client error: {err}"),
1151 }
1152 }
1153}
1154
1155pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1222 connect_with_options(addrs, ConnectOptions::default()).await
1223}
1224
/// The kind of failure that can occur while connecting to a server.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// DNS related failure.
    Dns,
    /// Signing the server nonce failed.
    Authentication,
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// The operation timed out.
    TimedOut,
    /// TLS related failure.
    Tls,
    /// Other I/O failure.
    Io,
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}

impl Display for ConnectErrorKind {
    /// Writes the fixed, human-readable description of the kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
1259
1260pub type ConnectError = error::Error<ConnectErrorKind>;
1263
1264impl From<io::Error> for ConnectError {
1265 fn from(err: io::Error) -> Self {
1266 ConnectError::with_source(ConnectErrorKind::Io, err)
1267 }
1268}
1269
/// Receiving side of a subscription.
///
/// Messages are obtained through the `Stream` implementation. Dropping the
/// subscriber closes the receiver and sends an unsubscribe command.
#[derive(Debug)]
pub struct Subscriber {
    /// Id of the subscription.
    sid: u64,
    /// Channel on which messages for this subscription arrive.
    receiver: mpsc::Receiver<Message>,
    /// Channel used to send commands to the connection handler.
    sender: mpsc::Sender<Command>,
}
1289
1290impl Subscriber {
1291 fn new(
1292 sid: u64,
1293 sender: mpsc::Sender<Command>,
1294 receiver: mpsc::Receiver<Message>,
1295 ) -> Subscriber {
1296 Subscriber {
1297 sid,
1298 sender,
1299 receiver,
1300 }
1301 }
1302
1303 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1318 self.sender
1319 .send(Command::Unsubscribe {
1320 sid: self.sid,
1321 max: None,
1322 })
1323 .await?;
1324 self.receiver.close();
1325 Ok(())
1326 }
1327
1328 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1354 self.sender
1355 .send(Command::Unsubscribe {
1356 sid: self.sid,
1357 max: Some(unsub_after),
1358 })
1359 .await?;
1360 Ok(())
1361 }
1362
1363 pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1396 self.sender
1397 .send(Command::Drain {
1398 sid: Some(self.sid),
1399 })
1400 .await?;
1401
1402 Ok(())
1403 }
1404}
1405
/// Error returned when an unsubscribe or drain command could not be sent.
///
/// The wrapped string is the text of the underlying send error; it is not
/// part of the `Display` output, which is always "failed to send unsubscribe".
#[derive(Error, Debug, PartialEq)]
#[error("failed to send unsubscribe")]
pub struct UnsubscribeError(String);
1409
1410impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1411 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1412 UnsubscribeError(err.to_string())
1413 }
1414}
1415
1416impl Drop for Subscriber {
1417 fn drop(&mut self) {
1418 self.receiver.close();
1419 tokio::spawn({
1420 let sender = self.sender.clone();
1421 let sid = self.sid;
1422 async move {
1423 sender
1424 .send(Command::Unsubscribe { sid, max: None })
1425 .await
1426 .ok();
1427 }
1428 });
1429 }
1430}
1431
1432impl Stream for Subscriber {
1433 type Item = Message;
1434
1435 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1436 self.receiver.poll_recv(cx)
1437 }
1438}
1439
1440#[derive(Clone, Debug, Eq, PartialEq)]
1441pub enum CallbackError {
1442 Client(ClientError),
1443 Server(ServerError),
1444}
1445impl std::fmt::Display for CallbackError {
1446 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1447 match self {
1448 Self::Client(error) => write!(f, "{error}"),
1449 Self::Server(error) => write!(f, "{error}"),
1450 }
1451 }
1452}
1453
1454impl From<ServerError> for CallbackError {
1455 fn from(server_error: ServerError) -> Self {
1456 CallbackError::Server(server_error)
1457 }
1458}
1459
1460impl From<ClientError> for CallbackError {
1461 fn from(client_error: ClientError) -> Self {
1462 CallbackError::Client(client_error)
1463 }
1464}
1465
/// An error reported by the server.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// The subscription with the given id is a slow consumer.
    SlowConsumer(u64),
    /// Any other error; holds the message with its original casing.
    Other(String),
}
1472
/// An error raised on the client side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// An error described only by its message.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
    /// The reconnect callback returned a server that is not in the pool.
    ServerNotInPool,
}

impl std::fmt::Display for ClientError {
    /// Writes the error prefixed with `nats: `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fixed = match self {
            Self::Other(error) => return write!(f, "nats: {error}"),
            Self::MaxReconnects => "nats: max reconnects reached",
            Self::ServerNotInPool => "nats: reconnect callback returned server not in pool",
        };
        f.write_str(fixed)
    }
}
1492
1493impl ServerError {
1494 fn new(error: String) -> ServerError {
1495 match error.to_lowercase().as_str() {
1496 "authorization violation" => ServerError::AuthorizationViolation,
1497 _ => ServerError::Other(error),
1499 }
1500 }
1501}
1502
1503impl std::fmt::Display for ServerError {
1504 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1505 match self {
1506 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1507 Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1508 Self::Other(error) => write!(f, "nats: {error}"),
1509 }
1510 }
1511}
1512
/// Connection parameters serialized and sent to the server as the payload of
/// `ClientOp::Connect`.
///
/// NOTE(review): the per-field descriptions below follow the field names;
/// confirm exact semantics against the NATS protocol documentation.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Verbose mode flag.
    pub verbose: bool,

    /// Pedantic mode flag.
    pub pedantic: bool,

    /// User JWT, serialized under the key `jwt`.
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public nkey of the client.
    pub nkey: Option<String>,

    /// Signature, serialized under the key `sig`.
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// Echo flag.
    pub echo: bool,

    /// Implementation language of the client.
    pub lang: String,

    /// Version of the client.
    pub version: String,

    /// Protocol level, serialized as a number (see `Protocol`).
    pub protocol: Protocol,

    /// Whether the client requires TLS.
    pub tls_required: bool,

    /// Username for authentication.
    pub user: Option<String>,

    /// Password for authentication.
    pub pass: Option<String>,

    /// Authentication token.
    pub auth_token: Option<String>,

    /// Whether the client supports headers.
    pub headers: bool,

    /// Whether the client supports "no responders" notifications.
    pub no_responders: bool,
}
1573
/// Protocol level, (de)serialized as its `u8` discriminant.
// NOTE(review): variant semantics are not visible here — confirm against the
// NATS protocol documentation.
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Serialized as `0`.
    Original = 0,
    /// Serialized as `1`.
    Dynamic = 1,
}
1583
/// Address of a NATS server: a `Url` whose scheme is one of `nats`, `tls`,
/// `ws` or `wss` when built through `ServerAddr::from_url` or `FromStr`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1587
1588impl FromStr for ServerAddr {
1589 type Err = io::Error;
1590
1591 fn from_str(input: &str) -> Result<Self, Self::Err> {
1595 let url: Url = if input.contains("://") {
1596 input.parse()
1597 } else {
1598 format!("nats://{input}").parse()
1599 }
1600 .map_err(|e| {
1601 io::Error::new(
1602 ErrorKind::InvalidInput,
1603 format!("NATS server URL is invalid: {e}"),
1604 )
1605 })?;
1606
1607 Self::from_url(url)
1608 }
1609}
1610
1611impl ServerAddr {
1612 pub fn from_url(url: Url) -> io::Result<Self> {
1614 if url.scheme() != "nats"
1615 && url.scheme() != "tls"
1616 && url.scheme() != "ws"
1617 && url.scheme() != "wss"
1618 {
1619 return Err(std::io::Error::new(
1620 ErrorKind::InvalidInput,
1621 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1622 ));
1623 }
1624
1625 Ok(Self(url))
1626 }
1627
1628 pub fn into_inner(self) -> Url {
1630 self.0
1631 }
1632
1633 pub fn tls_required(&self) -> bool {
1635 self.0.scheme() == "tls"
1636 }
1637
1638 pub fn has_user_pass(&self) -> bool {
1640 self.0.username() != ""
1641 }
1642
1643 pub fn scheme(&self) -> &str {
1644 self.0.scheme()
1645 }
1646
1647 pub fn host(&self) -> &str {
1649 match self.0.host() {
1650 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1651 Some(Host::Ipv6 { .. }) => {
1653 let host = self.0.host_str().unwrap();
1654 &host[1..host.len() - 1]
1655 }
1656 None => "",
1657 }
1658 }
1659
1660 pub fn is_websocket(&self) -> bool {
1661 self.0.scheme() == "ws" || self.0.scheme() == "wss"
1662 }
1663
1664 pub fn port(&self) -> u16 {
1667 self.0.port_or_known_default().unwrap_or(4222)
1668 }
1669
1670 pub fn as_url_str(&self) -> &str {
1672 self.0.as_str()
1673 }
1674
1675 pub fn username(&self) -> Option<&str> {
1677 let user = self.0.username();
1678 if user.is_empty() {
1679 None
1680 } else {
1681 Some(user)
1682 }
1683 }
1684
1685 pub fn password(&self) -> Option<&str> {
1687 self.0.password()
1688 }
1689
1690 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1692 tokio::net::lookup_host((self.host(), self.port())).await
1693 }
1694}
1695
/// Conversion of a value into one or more `ServerAddr`s.
///
/// Implemented in this file for `ServerAddr`, string types, slices and
/// vectors of strings or addresses, and references to any implementor.
pub trait ToServerAddrs {
    /// Iterator over the produced server addresses.
    type Iter: Iterator<Item = ServerAddr>;

    /// Converts `self` into an iterator of server addresses.
    ///
    /// # Errors
    ///
    /// Implementations that parse strings return the parse error as an
    /// `io::Error`.
    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
}
1707
1708impl ToServerAddrs for ServerAddr {
1709 type Iter = option::IntoIter<ServerAddr>;
1710 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1711 Ok(Some(self.clone()).into_iter())
1712 }
1713}
1714
1715impl ToServerAddrs for str {
1716 type Iter = option::IntoIter<ServerAddr>;
1717 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1718 self.parse::<ServerAddr>()
1719 .map(|addr| Some(addr).into_iter())
1720 }
1721}
1722
1723impl ToServerAddrs for String {
1724 type Iter = option::IntoIter<ServerAddr>;
1725 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1726 (**self).to_server_addrs()
1727 }
1728}
1729
1730impl<T: AsRef<str>> ToServerAddrs for [T] {
1731 type Iter = std::vec::IntoIter<ServerAddr>;
1732 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1733 self.iter()
1734 .map(AsRef::as_ref)
1735 .map(str::parse)
1736 .collect::<io::Result<_>>()
1737 .map(Vec::into_iter)
1738 }
1739}
1740
1741impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1742 type Iter = std::vec::IntoIter<ServerAddr>;
1743 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1744 self.as_slice().to_server_addrs()
1745 }
1746}
1747
1748impl<'a> ToServerAddrs for &'a [ServerAddr] {
1749 type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1750
1751 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1752 Ok(self.iter().cloned())
1753 }
1754}
1755
1756impl ToServerAddrs for Vec<ServerAddr> {
1757 type Iter = std::vec::IntoIter<ServerAddr>;
1758
1759 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1760 Ok(self.clone().into_iter())
1761 }
1762}
1763
1764impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1765 type Iter = T::Iter;
1766 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1767 (**self).to_server_addrs()
1768 }
1769}
1770
1771pub(crate) fn is_valid_publish_subject<T: AsRef<str>>(subject: T) -> bool {
1776 let bytes = subject.as_ref().as_bytes();
1777
1778 if bytes.is_empty() {
1779 return false;
1780 }
1781
1782 memchr::memchr3(b' ', b'\r', b'\n', bytes).is_none() && memchr::memchr(b'\t', bytes).is_none()
1783}
1784
1785pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
1789 let bytes = subject.as_ref().as_bytes();
1790
1791 if bytes.is_empty() {
1792 return false;
1793 }
1794
1795 bytes[0] != b'.'
1796 && bytes[bytes.len() - 1] != b'.'
1797 && memchr::memmem::find(bytes, b"..").is_none()
1798 && memchr::memchr3(b' ', b'\r', b'\n', bytes).is_none()
1799 && memchr::memchr(b'\t', bytes).is_none()
1800}
1801
1802pub(crate) fn is_valid_queue_group(queue_group: &str) -> bool {
1806 let bytes = queue_group.as_bytes();
1807
1808 if bytes.is_empty() {
1809 return false;
1810 }
1811
1812 memchr::memchr3(b' ', b'\r', b'\n', bytes).is_none() && memchr::memchr(b'\t', bytes).is_none()
1813}
1814
// Generates `impl From<$origin> for $target`.
//
// Arguments, in order: the target error type, the target's kind type, the
// origin error type, and the origin's kind type. An origin whose `kind()` is
// `TimedOut` becomes `$target::new($target_kind::TimedOut)`; every other
// origin becomes `$target::with_source($target_kind::Other, origin)`, which
// hands the original error over as the source.
#[allow(unused_macros)]
macro_rules! from_with_timeout {
    ($target:ty, $target_kind:ty, $origin:ty, $origin_kind:ty) => {
        impl From<$origin> for $target {
            fn from(source: $origin) -> Self {
                if let <$origin_kind>::TimedOut = source.kind() {
                    Self::new(<$target_kind>::TimedOut)
                } else {
                    Self::with_source(<$target_kind>::Other, source)
                }
            }
        }
    };
}
#[allow(unused_imports)]
pub(crate) use from_with_timeout;
1830
1831use crate::connection::ShouldFlush;
1832use crate::message::OutboundMessage;
1833
#[cfg(test)]
mod tests {
    use super::*;

    // `host()` is expected to return an IPv6 literal without the URL brackets.
    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    // Every address in a group must report the group's expected port, both
    // when the port is spelled out and when it is left to the default.
    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            let addrs: Vec<ServerAddr> = arr.to_server_addrs().unwrap().collect();
            // Guard against the conversion silently dropping entries.
            assert_eq!(addrs.len(), arr.len());

            // Check every address, not only the first one in the group.
            for addr in &addrs {
                assert_eq!(
                    addr.port(),
                    expected_port,
                    "unexpected port for {}",
                    addr.as_url_str()
                );
            }
        }
    }
}