1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_cfg))]
196
197use thiserror::Error;
198
199use futures_util::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::collections::VecDeque;
207use std::fmt::Display;
208use std::future::Future;
209use std::iter;
210use std::mem;
211use std::net::SocketAddr;
212use std::option;
213use std::pin::Pin;
214use std::slice;
215use std::str::{self, FromStr};
216use std::sync::atomic::AtomicUsize;
217use std::sync::atomic::Ordering;
218use std::sync::Arc;
219use std::task::{Context, Poll};
220use tokio::io::ErrorKind;
221use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
222use url::{Host, Url};
223
224use bytes::Bytes;
225use serde::{Deserialize, Serialize};
226use serde_repr::{Deserialize_repr, Serialize_repr};
227use tokio::io;
228use tokio::sync::mpsc;
229use tokio::task;
230
/// Boxed, thread-safe error type for places where no concrete error type is needed.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

// Crate version baked in by Cargo at compile time (presumably reported in CONNECT's
// `version` field — confirm at the use site).
const VERSION: &str = env!("CARGO_PKG_VERSION");
// Client language identifier (presumably reported in CONNECT's `lang` field).
const LANG: &str = "rust";
// Upper bound on outstanding PINGs: one more unanswered tick than this disconnects.
const MAX_PENDING_PINGS: usize = 2;
// Subscription id reserved for the shared request/response multiplexer subscription.
const MULTIPLEXER_SID: u64 = 0;
// Initial value (1 MiB) of the shared max-payload counter handed to the connector and
// client; NOTE(review): presumably replaced by the server-advertised limit after INFO.
pub(crate) const DEFAULT_SERVER_MAX_PAYLOAD: usize = 1024 * 1024;
238
239pub use tokio_rustls::rustls;
243
244use connection::{Connection, State};
245use connector::{Connector, ConnectorOptions};
246pub use connector::{ReconnectToServer, Server};
247pub use header::{HeaderMap, HeaderName, HeaderValue};
248pub use subject::{Subject, SubjectError, ToSubject};
249
250mod auth;
251pub(crate) mod auth_utils;
252pub mod client;
253pub mod connection;
254mod connector;
255mod options;
256
257pub use auth::Auth;
258pub use client::{
259 Client, PublishError, PublishErrorKind, Request, RequestError, RequestErrorKind,
260 ServerPoolError, ServerPoolErrorKind, SetServerPoolError, SetServerPoolErrorKind, Statistics,
261 SubscribeError, SubscribeErrorKind,
262};
263pub use options::{AuthError, ConnectOptions};
264
265#[cfg(feature = "crypto")]
266#[cfg_attr(docsrs, doc(cfg(feature = "crypto")))]
267mod crypto;
268pub mod error;
269pub mod header;
270mod id_generator;
271#[cfg(feature = "jetstream")]
272#[cfg_attr(docsrs, doc(cfg(feature = "jetstream")))]
273pub mod jetstream;
274pub mod message;
275#[cfg(feature = "service")]
276#[cfg_attr(docsrs, doc(cfg(feature = "service")))]
277pub mod service;
278pub mod status;
279pub mod subject;
280mod tls;
281
282pub use message::Message;
283pub use status::StatusCode;
284
/// Server description decoded from an INFO message (carried by `ServerOp::Info`).
///
/// Every field is `#[serde(default)]`, so a field absent from the server's JSON
/// falls back to its `Default` value instead of failing deserialization.
#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
pub struct ServerInfo {
    /// Identifier of the server that sent the INFO.
    #[serde(default)]
    pub server_id: String,
    /// Human-readable server name.
    #[serde(default)]
    pub server_name: String,
    /// Host the server reports it is listening on.
    #[serde(default)]
    pub host: String,
    /// Port the server reports it is listening on.
    #[serde(default)]
    pub port: u16,
    /// Server version string.
    #[serde(default)]
    pub version: String,
    /// Whether the server requires the client to authenticate.
    #[serde(default)]
    pub auth_required: bool,
    /// Whether the server requires TLS.
    #[serde(default)]
    pub tls_required: bool,
    /// Maximum payload size the server accepts, in bytes.
    #[serde(default)]
    pub max_payload: usize,
    /// Protocol version supported by the server.
    #[serde(default)]
    pub proto: i8,
    /// Id the server assigned to this client connection.
    #[serde(default)]
    pub client_id: u64,
    /// Go version the server was built with.
    #[serde(default)]
    pub go: String,
    /// Nonce the client signs when using nonce-based authentication.
    #[serde(default)]
    pub nonce: String,
    /// Additional server URLs advertised by the server.
    #[serde(default)]
    pub connect_urls: Vec<String>,
    /// Client IP address as seen by the server.
    #[serde(default)]
    pub client_ip: String,
    /// Whether the server supports message headers.
    #[serde(default)]
    pub headers: bool,
    /// Lame duck mode flag, sent by the server as `ldm`; when set, the connection
    /// handler emits `Event::LameDuckMode`.
    #[serde(default, rename = "ldm")]
    pub lame_duck_mode: bool,
    /// Cluster name, if the server reports one.
    #[serde(default)]
    pub cluster: Option<String>,
    /// Domain name, if the server reports one (presumably the JetStream domain).
    #[serde(default)]
    pub domain: Option<String>,
    /// Whether JetStream is enabled on the server.
    #[serde(default)]
    pub jetstream: bool,
}
348
/// A protocol operation decoded from the server connection.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// `+OK` acknowledgement; ignored by the connection handler.
    Ok,
    /// INFO message describing the server.
    Info(Box<ServerInfo>),
    /// Server PING; answered with a PONG.
    Ping,
    /// Server PONG; decrements the outstanding-ping counter.
    Pong,
    /// `-ERR` message; forwarded as `Event::ServerError`.
    Error(ServerError),
    /// A message delivered for subscription `sid` (or for the multiplexer when
    /// `sid == MULTIPLEXER_SID`).
    Message {
        sid: u64,
        subject: Subject,
        reply: Option<Subject>,
        payload: Bytes,
        headers: Option<HeaderMap>,
        status: Option<StatusCode>,
        description: Option<String>,
        length: usize,
    },
}
367
/// Former name of `message::OutboundMessage`, kept as an alias for backward
/// compatibility.
#[deprecated(
    since = "0.44.0",
    note = "use `async_nats::message::OutboundMessage` instead"
)]
pub type PublishMessage = crate::message::OutboundMessage;
376
/// A request sent from the client side to the connection handler task.
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a message; counted in the outbound message/byte statistics.
    Publish(OutboundMessage),
    /// Publish a request whose reply is routed back through the shared multiplexer
    /// subscription. The last `.`-separated token of `respond` keys the reply, and
    /// `respond` must contain at least one `.` (the handler panics otherwise).
    Request {
        subject: Subject,
        payload: Bytes,
        respond: Subject,
        headers: Option<HeaderMap>,
        sender: oneshot::Sender<Message>,
    },
    /// Register a subscription under `sid` and send SUB to the server.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
        sender: mpsc::Sender<Message>,
    },
    /// Unsubscribe `sid` immediately (`max: None`) or once `max` messages in total
    /// have been delivered.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Notify `observer` after the next successful flush of buffered writes.
    Flush {
        observer: oneshot::Sender<()>,
    },
    /// Drain a single subscription (`Some(sid)`) or, with `None`, every subscription
    /// followed by closing the client.
    Drain {
        sid: Option<u64>,
    },
    /// Drop the current connection and reconnect.
    Reconnect,
    /// Replace the connector's server pool; the outcome is reported on `result`.
    SetServerPool {
        servers: Vec<ServerAddr>,
        result: oneshot::Sender<Result<(), String>>,
    },
    /// Report the connector's current server pool on `result`.
    ServerPool {
        result: oneshot::Sender<Vec<connector::Server>>,
    },
}
413
/// A protocol operation the client enqueues for writing to the server.
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// Publish a message, optionally with a reply subject and headers.
    Publish {
        subject: Subject,
        payload: Bytes,
        respond: Option<Subject>,
        headers: Option<HeaderMap>,
    },
    /// Subscribe `sid` to `subject`, optionally as part of a queue group.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
    },
    /// Unsubscribe `sid`, immediately or after `max` messages.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Liveness probe; the server is expected to answer with PONG.
    Ping,
    /// Reply to a server PING.
    Pong,
    /// CONNECT message carrying the client's connection options.
    Connect(ConnectInfo),
}
436
/// Handler-side state for one active subscription.
#[derive(Debug)]
struct Subscription {
    // Subject used when (re)sending SUB after a reconnect.
    subject: Subject,
    // Channel feeding the user's `Subscriber`.
    sender: mpsc::Sender<Message>,
    // Queue group used when (re)sending SUB after a reconnect.
    queue_group: Option<String>,
    // Messages successfully handed to `sender` so far.
    delivered: u64,
    // Auto-unsubscribe limit: the subscription is dropped once `delivered` reaches it.
    max: Option<u64>,
}
445
/// Shared wildcard subscription that routes request replies back to their callers.
#[derive(Debug)]
struct Multiplexer {
    // Wildcard subject actually subscribed to (`{prefix}*`).
    subject: Subject,
    // Reply-subject prefix (`{inbox prefix}.{unique id}.`) ending in a dot.
    prefix: Subject,
    // Pending requests keyed by the reply token that follows `prefix`.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
452
/// Owns the server connection and all per-connection state; runs in its own task
/// and is driven by `process`.
pub(crate) struct ConnectionHandler {
    // Current transport; replaced wholesale on reconnect.
    connection: Connection,
    // Establishes (re)connections and holds the event/state channels and statistics.
    connector: Connector,
    // Active subscriptions keyed by sid.
    subscriptions: HashMap<u64, Subscription>,
    // Lazily created on the first request.
    multiplexer: Option<Multiplexer>,
    // PINGs sent without a matching PONG; reset on disconnect.
    pending_pings: usize,
    // Publishes the latest `ServerInfo` to the client.
    info_sender: tokio::sync::watch::Sender<Option<ServerInfo>>,
    // Timer that triggers PINGs; reset whenever any server op arrives.
    ping_interval: Interval,
    // Set by `Command::Reconnect`; consumed by `process`.
    should_reconnect: bool,
    // Waiting for the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    // Set when the whole client is draining; makes `process` exit with `Closed`.
    is_draining: bool,
    // Sids that were UNSUBed for draining and are removed on the next poll.
    drain_pings: VecDeque<u64>,
}
467
468impl ConnectionHandler {
469 pub(crate) fn new(
470 connection: Connection,
471 connector: Connector,
472 info_sender: tokio::sync::watch::Sender<Option<ServerInfo>>,
473 ping_period: Duration,
474 ) -> ConnectionHandler {
475 let mut ping_interval = interval(ping_period);
476 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
477
478 ConnectionHandler {
479 connection,
480 connector,
481 subscriptions: HashMap::new(),
482 multiplexer: None,
483 pending_pings: 0,
484 info_sender,
485 ping_interval,
486 should_reconnect: false,
487 flush_observers: Vec::new(),
488 is_draining: false,
489 drain_pings: VecDeque::new(),
490 }
491 }
492
    /// Drives the connection until the client closes or reconnecting fails.
    ///
    /// Each round awaits a `ProcessFut`, which in a single poll services the ping
    /// timer, inbound server ops, pending drains, outbound commands, writes and
    /// flushes. When it resolves, the `ExitReason` decides what happens next:
    /// - disconnect or requested reconnect: run `handle_disconnect`;
    /// - close: emit `Event::Closed` and stop.
    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
        // One poll of the connection state machine. Borrows the handler, the command
        // receiver, and a buffer reused across polls for batched command reads.
        struct ProcessFut<'a> {
            handler: &'a mut ConnectionHandler,
            receiver: &'a mut mpsc::Receiver<Command>,
            recv_buf: &'a mut Vec<Command>,
        }

        // Why a `ProcessFut` completed.
        enum ExitReason {
            // Connection lost: EOF (`None`), an I/O error, or too many pending pings.
            Disconnected(Option<io::Error>),
            // `Command::Reconnect` was received.
            ReconnectRequested,
            // The command channel closed, or a client-wide drain was requested.
            Closed,
        }

        impl ProcessFut<'_> {
            // Maximum number of commands taken from the channel per `poll_recv_many`.
            const RECV_CHUNK_SIZE: usize = 16;

            // Handles one ping-timer tick by counting another outstanding PING.
            // Below the limit it enqueues a PING and returns `Pending` ("keep going").
            // Above `MAX_PENDING_PINGS` it returns `Ready(Disconnected(None))`.
            #[cold]
            fn ping(&mut self) -> Poll<ExitReason> {
                self.handler.pending_pings += 1;

                if self.handler.pending_pings > MAX_PENDING_PINGS {
                    debug!(
                        pending_pings = self.handler.pending_pings,
                        max_pings = MAX_PENDING_PINGS,
                        "disconnecting due to too many pending pings"
                    );

                    Poll::Ready(ExitReason::Disconnected(None))
                } else {
                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);

                    Poll::Pending
                }
            }
        }

        impl Future for ProcessFut<'_> {
            type Output = ExitReason;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // 1. Ping timer: handle every tick that is already due.
                while self.handler.ping_interval.poll_tick(cx).is_ready() {
                    if let Poll::Ready(exit) = self.ping() {
                        return Poll::Ready(exit);
                    }
                }

                // 2. Read and dispatch every server op that is already available.
                // `Ok(None)` means the server closed the stream.
                loop {
                    match self.handler.connection.poll_read_op(cx) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(Some(server_op))) => {
                            self.handler.handle_server_op(server_op);
                        }
                        Poll::Ready(Ok(None)) => {
                            return Poll::Ready(ExitReason::Disconnected(None))
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // 3. Forget subscriptions that were UNSUBed for draining.
                // NOTE(review): removal happens on the next poll, not after the PONG for
                // the drain PING arrives — confirm this is the intended semantics.
                while let Some(sid) = self.handler.drain_pings.pop_front() {
                    self.handler.subscriptions.remove(&sid);
                }

                // A client-wide drain was requested on an earlier poll; stop here.
                if self.handler.is_draining {
                    return Poll::Ready(ExitReason::Closed);
                }

                // 4. Alternate between pulling commands (while the write buffer has room)
                // and writing. The loop stops when no new command arrived since the
                // last write, or when the socket is not writable.
                let mut made_progress = true;
                loop {
                    while !self.handler.connection.is_write_buf_full() {
                        debug_assert!(self.recv_buf.is_empty());

                        // Split the borrow so the receiver, buffer and handler can be
                        // used at the same time.
                        let Self {
                            recv_buf,
                            handler,
                            receiver,
                        } = &mut *self;
                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
                            Poll::Pending => break,
                            Poll::Ready(1..) => {
                                made_progress = true;

                                for cmd in recv_buf.drain(..) {
                                    handler.handle_command(cmd);
                                }
                            }
                            // Per tokio's docs, `Ready(0)` with a non-zero limit means the
                            // channel is closed: every `Client` handle is gone.
                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
                        }
                    }

                    if !mem::take(&mut made_progress) {
                        break;
                    }

                    match self.handler.connection.poll_write(cx) {
                        Poll::Pending => {
                            break;
                        }
                        Poll::Ready(Ok(())) => {
                            continue;
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // 5. Flush if the connection says `Yes`, or if it says `No` while
                // flush observers are waiting. Observers are notified on success.
                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
                    self.handler.connection.should_flush(),
                    self.handler.flush_observers.is_empty(),
                ) {
                    match self.handler.connection.poll_flush(cx) {
                        Poll::Pending => {}
                        Poll::Ready(Ok(())) => {
                            for observer in self.handler.flush_observers.drain(..) {
                                let _ = observer.send(());
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // 6. A `Command::Reconnect` seen during this poll ends the round.
                if mem::take(&mut self.handler.should_reconnect) {
                    return Poll::Ready(ExitReason::ReconnectRequested);
                }

                Poll::Pending
            }
        }

        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
        loop {
            let process = ProcessFut {
                handler: self,
                receiver,
                recv_buf: &mut recv_buf,
            };
            match process.await {
                ExitReason::Disconnected(err) => {
                    debug!(error = ?err, "disconnected");
                    // A failed reconnect (e.g. reconnect limit hit) ends the handler.
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                    debug!("reconnected");
                }
                ExitReason::Closed => {
                    self.connector.events_tx.try_send(Event::Closed).ok();
                    break;
                }
                ExitReason::ReconnectRequested => {
                    debug!("reconnect requested");
                    // Best-effort shutdown of the old stream before reconnecting.
                    self.connection.stream.shutdown().await.ok();
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                }
            }
        }
    }
700
701 fn handle_server_op(&mut self, server_op: ServerOp) {
702 self.ping_interval.reset();
703
704 match server_op {
705 ServerOp::Ping => {
706 debug!("received PING");
707 self.connection.enqueue_write_op(&ClientOp::Pong);
708 }
709 ServerOp::Pong => {
710 debug!("received PONG");
711 self.pending_pings = self.pending_pings.saturating_sub(1);
712 }
713 ServerOp::Error(error) => {
714 debug!("received ERROR: {:?}", error);
715 self.connector
716 .events_tx
717 .try_send(Event::ServerError(error))
718 .ok();
719 }
720 ServerOp::Message {
721 sid,
722 subject,
723 reply,
724 payload,
725 headers,
726 status,
727 description,
728 length,
729 } => {
730 debug!("received MESSAGE: sid={}, subject={}", sid, subject);
731 self.connector
732 .connect_stats
733 .in_messages
734 .add(1, Ordering::Relaxed);
735
736 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
737 let message: Message = Message {
738 subject,
739 reply,
740 payload,
741 headers,
742 status,
743 description,
744 length,
745 };
746
747 match subscription.sender.try_send(message) {
750 Ok(_) => {
751 subscription.delivered += 1;
752 if let Some(max) = subscription.max {
756 if subscription.delivered.ge(&max) {
757 debug!("max messages reached for subscription {}", sid);
758 self.subscriptions.remove(&sid);
759 }
760 }
761 }
762 Err(mpsc::error::TrySendError::Full(_)) => {
763 debug!("slow consumer detected for subscription {}", sid);
764 self.connector
765 .events_tx
766 .try_send(Event::SlowConsumer(sid))
767 .ok();
768 }
769 Err(mpsc::error::TrySendError::Closed(_)) => {
770 debug!("subscription {} channel closed", sid);
771 self.subscriptions.remove(&sid);
772 self.connection
773 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
774 }
775 }
776 } else if sid == MULTIPLEXER_SID {
777 debug!("received message for multiplexer");
778 if let Some(multiplexer) = self.multiplexer.as_mut() {
779 let maybe_token =
780 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
781
782 if let Some(token) = maybe_token {
783 if let Some(sender) = multiplexer.senders.remove(token) {
784 debug!("forwarding message to request with token {}", token);
785 let message = Message {
786 subject,
787 reply,
788 payload,
789 headers,
790 status,
791 description,
792 length,
793 };
794
795 let _ = sender.send(message);
796 }
797 }
798 }
799 }
800 }
801 ServerOp::Info(info) => {
803 debug!("received INFO: server_id={}", info.server_id);
804 if info.lame_duck_mode {
805 debug!("server in lame duck mode");
806 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
807 }
808 }
809
810 _ => {
811 }
813 }
814 }
815
816 fn handle_command(&mut self, command: Command) {
817 match command {
818 Command::Unsubscribe { sid, max } => {
819 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
820 subscription.max = max;
821 match subscription.max {
822 Some(n) => {
823 if subscription.delivered >= n {
824 self.subscriptions.remove(&sid);
825 }
826 }
827 None => {
828 self.subscriptions.remove(&sid);
829 }
830 }
831
832 self.connection
833 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
834 }
835 }
836 Command::Flush { observer } => {
837 self.flush_observers.push(observer);
838 }
839 Command::Drain { sid } => {
840 let mut drain_sub = |sid: u64| {
841 self.drain_pings.push_back(sid);
842 self.connection
843 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
844 };
845
846 if let Some(sid) = sid {
847 if self.subscriptions.get_mut(&sid).is_some() {
848 drain_sub(sid);
849 }
850 } else {
851 self.connector.events_tx.try_send(Event::Draining).ok();
853 self.is_draining = true;
854 for (&sid, _) in self.subscriptions.iter_mut() {
855 drain_sub(sid);
856 }
857 }
858 self.connection.enqueue_write_op(&ClientOp::Ping);
859 }
860 Command::Subscribe {
861 sid,
862 subject,
863 queue_group,
864 sender,
865 } => {
866 let subscription = Subscription {
867 sender,
868 delivered: 0,
869 max: None,
870 subject: subject.to_owned(),
871 queue_group: queue_group.to_owned(),
872 };
873
874 self.subscriptions.insert(sid, subscription);
875
876 self.connection.enqueue_write_op(&ClientOp::Subscribe {
877 sid,
878 subject,
879 queue_group,
880 });
881 }
882 Command::Request {
883 subject,
884 payload,
885 respond,
886 headers,
887 sender,
888 } => {
889 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
890
891 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
892 multiplexer
893 } else {
894 let prefix = Subject::from(format!("{}.{}.", prefix, id_generator::next()));
895 let subject = Subject::from(format!("{prefix}*"));
896
897 self.connection.enqueue_write_op(&ClientOp::Subscribe {
898 sid: MULTIPLEXER_SID,
899 subject: subject.clone(),
900 queue_group: None,
901 });
902
903 self.multiplexer.insert(Multiplexer {
904 subject,
905 prefix,
906 senders: HashMap::new(),
907 })
908 };
909 self.connector
910 .connect_stats
911 .out_messages
912 .add(1, Ordering::Relaxed);
913
914 multiplexer.senders.insert(token.to_owned(), sender);
915
916 let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
917
918 let pub_op = ClientOp::Publish {
919 subject,
920 payload,
921 respond: Some(respond),
922 headers,
923 };
924
925 self.connection.enqueue_write_op(&pub_op);
926 }
927
928 Command::Publish(OutboundMessage {
929 subject,
930 payload,
931 reply: respond,
932 headers,
933 }) => {
934 self.connector
935 .connect_stats
936 .out_messages
937 .add(1, Ordering::Relaxed);
938
939 let header_len = headers
940 .as_ref()
941 .map(|headers| headers.len())
942 .unwrap_or_default();
943
944 self.connector.connect_stats.out_bytes.add(
945 (payload.len()
946 + respond.as_ref().map_or_else(|| 0, |r| r.len())
947 + subject.len()
948 + header_len) as u64,
949 Ordering::Relaxed,
950 );
951
952 self.connection.enqueue_write_op(&ClientOp::Publish {
953 subject,
954 payload,
955 respond,
956 headers,
957 });
958 }
959
960 Command::Reconnect => {
961 self.should_reconnect = true;
962 }
963
964 Command::SetServerPool { servers, result } => {
965 let _ = result.send(self.connector.set_server_pool(servers));
966 }
967
968 Command::ServerPool { result } => {
969 let _ = result.send(self.connector.server_pool());
970 }
971 }
972 }
973
974 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
975 self.pending_pings = 0;
976 self.connector.events_tx.try_send(Event::Disconnected).ok();
977 self.connector.state_tx.send(State::Disconnected).ok();
978
979 self.handle_reconnect().await
980 }
981
982 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
983 let (info, connection) = self.connector.connect().await?;
984 self.connection = connection;
985 let _ = self.info_sender.send(Some(info));
986
987 self.subscriptions
988 .retain(|_, subscription| !subscription.sender.is_closed());
989
990 for (sid, subscription) in &self.subscriptions {
991 self.connection.enqueue_write_op(&ClientOp::Subscribe {
992 sid: *sid,
993 subject: subscription.subject.to_owned(),
994 queue_group: subscription.queue_group.to_owned(),
995 });
996
997 if let Some(max) = subscription.max {
998 self.connection.enqueue_write_op(&ClientOp::Unsubscribe {
999 sid: *sid,
1000 max: Some(max.saturating_sub(subscription.delivered)),
1001 });
1002 }
1003 }
1004
1005 if let Some(multiplexer) = &self.multiplexer {
1006 self.connection.enqueue_write_op(&ClientOp::Subscribe {
1007 sid: MULTIPLEXER_SID,
1008 subject: multiplexer.subject.to_owned(),
1009 queue_group: None,
1010 });
1011 }
1012 Ok(())
1013 }
1014}
1015
1016pub async fn connect_with_options<A: ToServerAddrs>(
1032 addrs: A,
1033 options: ConnectOptions,
1034) -> Result<Client, ConnectError> {
1035 let ping_period = options.ping_interval;
1036
1037 let (events_tx, mut events_rx) = mpsc::channel(128);
1038 let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
1039 let max_payload = Arc::new(AtomicUsize::new(DEFAULT_SERVER_MAX_PAYLOAD));
1041 let statistics = Arc::new(Statistics::default());
1042
1043 let mut connector = Connector::new(
1044 addrs,
1045 ConnectorOptions {
1046 tls_required: options.tls_required,
1047 certificates: options.certificates,
1048 client_key: options.client_key,
1049 client_cert: options.client_cert,
1050 tls_client_config: options.tls_client_config,
1051 tls_first: options.tls_first,
1052 auth: options.auth,
1053 no_echo: options.no_echo,
1054 connection_timeout: options.connection_timeout,
1055 name: options.name,
1056 ignore_discovered_servers: options.ignore_discovered_servers,
1057 retain_servers_order: options.retain_servers_order,
1058 read_buffer_capacity: options.read_buffer_capacity,
1059 reconnect_delay_callback: options.reconnect_delay_callback,
1060 auth_callback: options.auth_callback,
1061 max_reconnects: options.max_reconnects,
1062 local_address: options.local_address,
1063 reconnect_to_server_callback: options.reconnect_to_server_callback,
1064 },
1065 events_tx,
1066 state_tx,
1067 max_payload.clone(),
1068 statistics.clone(),
1069 )
1070 .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1071
1072 let mut info = None;
1073 let mut connection = None;
1074 if !options.retry_on_initial_connect {
1075 debug!("retry on initial connect failure is disabled");
1076 let (info_ok, connection_ok) = connector.try_connect().await?;
1077 connection = Some(connection_ok);
1078 info = Some(info_ok);
1079 }
1080
1081 let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1082 let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1083
1084 let client = Client::new(
1085 info_watcher,
1086 state_rx,
1087 sender,
1088 options.subscription_capacity,
1089 options.inbox_prefix,
1090 options.request_timeout,
1091 max_payload,
1092 statistics,
1093 options.skip_subject_validation,
1094 );
1095
1096 task::spawn(async move {
1097 while let Some(event) = events_rx.recv().await {
1098 tracing::info!("event: {}", event);
1099 if let Some(event_callback) = &options.event_callback {
1100 event_callback.call(event).await;
1101 }
1102 }
1103 });
1104
1105 task::spawn(async move {
1106 if connection.is_none() && options.retry_on_initial_connect {
1107 let (info, connection_ok) = match connector.connect().await {
1108 Ok((info, connection)) => (info, connection),
1109 Err(err) => {
1110 error!("connection closed: {}", err);
1111 return;
1112 }
1113 };
1114 info_sender.send(Some(info)).ok();
1115 connection = Some(connection_ok);
1116 }
1117 let connection = connection.unwrap();
1118 let mut connection_handler =
1119 ConnectionHandler::new(connection, connector, info_sender, ping_period);
1120 connection_handler.process(&mut receiver).await
1121 });
1122
1123 Ok(client)
1124}
1125
/// Connection lifecycle and error notifications delivered to the event callback.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
    /// A connection was established (presumably emitted by the connector).
    Connected,
    /// The connection was lost; a reconnect is attempted next.
    Disconnected,
    /// The server announced lame duck mode in an INFO message.
    LameDuckMode,
    /// A client-wide drain has started.
    Draining,
    /// The connection handler stopped.
    Closed,
    /// Messages for this subscription id were dropped because its channel was full.
    SlowConsumer(u64),
    /// The server sent an `-ERR` message.
    ServerError(ServerError),
    /// A client-side error (not emitted in this file; presumably by the connector).
    ClientError(ClientError),
}
1137
1138impl fmt::Display for Event {
1139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1140 match self {
1141 Event::Connected => write!(f, "connected"),
1142 Event::Disconnected => write!(f, "disconnected"),
1143 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1144 Event::Draining => write!(f, "draining"),
1145 Event::Closed => write!(f, "closed"),
1146 Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1147 Event::ServerError(err) => write!(f, "server error: {err}"),
1148 Event::ClientError(err) => write!(f, "client error: {err}"),
1149 }
1150 }
1151}
1152
1153pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1220 connect_with_options(addrs, ConnectOptions::default()).await
1221}
1222
/// Category of a `ConnectError`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// DNS resolution failed.
    Dns,
    /// Signing the server nonce failed.
    Authentication,
    /// The server rejected the credentials.
    AuthorizationViolation,
    /// The connection attempt timed out.
    TimedOut,
    /// TLS setup or handshake failed.
    Tls,
    /// An I/O error occurred.
    Io,
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
1242
1243impl Display for ConnectErrorKind {
1244 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1245 match self {
1246 Self::ServerParse => write!(f, "failed to parse server or server list"),
1247 Self::Dns => write!(f, "DNS error"),
1248 Self::Authentication => write!(f, "failed signing nonce"),
1249 Self::AuthorizationViolation => write!(f, "authorization violation"),
1250 Self::TimedOut => write!(f, "timed out"),
1251 Self::Tls => write!(f, "TLS error"),
1252 Self::Io => write!(f, "IO error"),
1253 Self::MaxReconnects => write!(f, "reached maximum number of reconnects"),
1254 }
1255 }
1256}
1257
/// Error returned when establishing a connection fails, tagged with a `ConnectErrorKind`.
pub type ConnectError = error::Error<ConnectErrorKind>;
1261
1262impl From<io::Error> for ConnectError {
1263 fn from(err: io::Error) -> Self {
1264 ConnectError::with_source(ConnectErrorKind::Io, err)
1265 }
1266}
1267
/// Stream of messages delivered to a single subscription.
///
/// Dropping it unsubscribes from the server.
#[derive(Debug)]
pub struct Subscriber {
    // Subscription id shared with the connection handler.
    sid: u64,
    // Messages forwarded by the connection handler.
    receiver: mpsc::Receiver<Message>,
    // Command channel to the connection handler (used for UNSUB/drain).
    sender: mpsc::Sender<Command>,
}
1287
1288impl Subscriber {
1289 fn new(
1290 sid: u64,
1291 sender: mpsc::Sender<Command>,
1292 receiver: mpsc::Receiver<Message>,
1293 ) -> Subscriber {
1294 Subscriber {
1295 sid,
1296 sender,
1297 receiver,
1298 }
1299 }
1300
1301 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1316 self.sender
1317 .send(Command::Unsubscribe {
1318 sid: self.sid,
1319 max: None,
1320 })
1321 .await?;
1322 self.receiver.close();
1323 Ok(())
1324 }
1325
1326 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1352 self.sender
1353 .send(Command::Unsubscribe {
1354 sid: self.sid,
1355 max: Some(unsub_after),
1356 })
1357 .await?;
1358 Ok(())
1359 }
1360
1361 pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1394 self.sender
1395 .send(Command::Drain {
1396 sid: Some(self.sid),
1397 })
1398 .await?;
1399
1400 Ok(())
1401 }
1402}
1403
/// Returned when an unsubscribe or drain command cannot be handed to the connection
/// handler. The inner string is the underlying send error's message.
#[derive(Error, Debug, PartialEq)]
#[error("failed to send unsubscribe")]
pub struct UnsubscribeError(String);
1407
1408impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1409 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1410 UnsubscribeError(err.to_string())
1411 }
1412}
1413
1414impl Drop for Subscriber {
1415 fn drop(&mut self) {
1416 self.receiver.close();
1417 tokio::spawn({
1418 let sender = self.sender.clone();
1419 let sid = self.sid;
1420 async move {
1421 sender
1422 .send(Command::Unsubscribe { sid, max: None })
1423 .await
1424 .ok();
1425 }
1426 });
1427 }
1428}
1429
1430impl Stream for Subscriber {
1431 type Item = Message;
1432
1433 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1434 self.receiver.poll_recv(cx)
1435 }
1436}
1437
/// Error passed to user callbacks: either a client-side or a server-side error.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CallbackError {
    /// Error originating in the client.
    Client(ClientError),
    /// Error reported by the server.
    Server(ServerError),
}
1443impl std::fmt::Display for CallbackError {
1444 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1445 match self {
1446 Self::Client(error) => write!(f, "{error}"),
1447 Self::Server(error) => write!(f, "{error}"),
1448 }
1449 }
1450}
1451
1452impl From<ServerError> for CallbackError {
1453 fn from(server_error: ServerError) -> Self {
1454 CallbackError::Server(server_error)
1455 }
1456}
1457
1458impl From<ClientError> for CallbackError {
1459 fn from(client_error: ClientError) -> Self {
1460 CallbackError::Client(client_error)
1461 }
1462}
1463
/// Error reported by the server in a `-ERR` message.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The server reported an authorization violation (matched case-insensitively).
    AuthorizationViolation,
    /// The given subscription is a slow consumer (not constructed in this file).
    SlowConsumer(u64),
    /// Any other server error text, kept verbatim.
    Other(String),
}
1470
/// Error originating in the client rather than the server.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Free-form client error.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
    /// The reconnect callback chose a server that is not in the server pool.
    ServerNotInPool,
}
1479impl std::fmt::Display for ClientError {
1480 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1481 match self {
1482 Self::Other(error) => write!(f, "nats: {error}"),
1483 Self::MaxReconnects => write!(f, "nats: max reconnects reached"),
1484 Self::ServerNotInPool => {
1485 write!(f, "nats: reconnect callback returned server not in pool")
1486 }
1487 }
1488 }
1489}
1490
1491impl ServerError {
1492 fn new(error: String) -> ServerError {
1493 match error.to_lowercase().as_str() {
1494 "authorization violation" => ServerError::AuthorizationViolation,
1495 _ => ServerError::Other(error),
1497 }
1498 }
1499}
1500
1501impl std::fmt::Display for ServerError {
1502 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1503 match self {
1504 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1505 Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1506 Self::Other(error) => write!(f, "nats: {error}"),
1507 }
1508 }
1509}
1510
/// Client options serialized into the CONNECT message (`ClientOp::Connect`).
///
/// Field order here is the JSON key order on the wire.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Ask the server to acknowledge protocol messages with `+OK`.
    pub verbose: bool,

    /// Ask the server for stricter protocol checking.
    pub pedantic: bool,

    /// User JWT, serialized as `jwt`.
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public NKey used for authentication.
    pub nkey: Option<String>,

    /// Signature of the server nonce, serialized as `sig`.
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// Whether messages published on this connection may be delivered back to its
    /// own matching subscriptions.
    pub echo: bool,

    /// Client implementation language.
    pub lang: String,

    /// Client library version.
    pub version: String,

    /// Protocol level the client supports.
    pub protocol: Protocol,

    /// Whether the client requires TLS.
    pub tls_required: bool,

    /// Username for user/password authentication.
    pub user: Option<String>,

    /// Password for user/password authentication.
    pub pass: Option<String>,

    /// Token for token authentication.
    pub auth_token: Option<String>,

    /// Whether the client supports message headers.
    pub headers: bool,

    /// Whether the client wants no-responders status replies to requests.
    pub no_responders: bool,
}
1571
/// Protocol level advertised in CONNECT, serialized as its integer value.
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Original protocol (`0`).
    Original = 0,
    /// Dynamic protocol (`1`); presumably lets the server push asynchronous INFO
    /// updates — confirm against the NATS protocol docs.
    Dynamic = 1,
}
1581
/// Address of a NATS server, stored as a URL. `from_url` only accepts the `nats`,
/// `tls`, `ws` and `wss` schemes.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1585
1586impl FromStr for ServerAddr {
1587 type Err = io::Error;
1588
1589 fn from_str(input: &str) -> Result<Self, Self::Err> {
1593 let url: Url = if input.contains("://") {
1594 input.parse()
1595 } else {
1596 format!("nats://{input}").parse()
1597 }
1598 .map_err(|e| {
1599 io::Error::new(
1600 ErrorKind::InvalidInput,
1601 format!("NATS server URL is invalid: {e}"),
1602 )
1603 })?;
1604
1605 Self::from_url(url)
1606 }
1607}
1608
1609impl ServerAddr {
1610 pub fn from_url(url: Url) -> io::Result<Self> {
1612 if url.scheme() != "nats"
1613 && url.scheme() != "tls"
1614 && url.scheme() != "ws"
1615 && url.scheme() != "wss"
1616 {
1617 return Err(std::io::Error::new(
1618 ErrorKind::InvalidInput,
1619 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1620 ));
1621 }
1622
1623 Ok(Self(url))
1624 }
1625
1626 pub fn into_inner(self) -> Url {
1628 self.0
1629 }
1630
1631 pub fn tls_required(&self) -> bool {
1633 self.0.scheme() == "tls"
1634 }
1635
1636 pub fn has_user_pass(&self) -> bool {
1638 self.0.username() != ""
1639 }
1640
1641 pub fn scheme(&self) -> &str {
1642 self.0.scheme()
1643 }
1644
1645 pub fn host(&self) -> &str {
1647 match self.0.host() {
1648 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1649 Some(Host::Ipv6 { .. }) => {
1651 let host = self.0.host_str().unwrap();
1652 &host[1..host.len() - 1]
1653 }
1654 None => "",
1655 }
1656 }
1657
1658 pub fn is_websocket(&self) -> bool {
1659 self.0.scheme() == "ws" || self.0.scheme() == "wss"
1660 }
1661
1662 pub fn port(&self) -> u16 {
1665 self.0.port_or_known_default().unwrap_or(4222)
1666 }
1667
1668 pub fn as_url_str(&self) -> &str {
1670 self.0.as_str()
1671 }
1672
1673 pub fn username(&self) -> Option<&str> {
1675 let user = self.0.username();
1676 if user.is_empty() {
1677 None
1678 } else {
1679 Some(user)
1680 }
1681 }
1682
1683 pub fn password(&self) -> Option<&str> {
1685 self.0.password()
1686 }
1687
1688 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1690 tokio::net::lookup_host((self.host(), self.port())).await
1691 }
1692}
1693
/// Conversion into one or more `ServerAddr`s, accepted by `connect` and
/// `connect_with_options`.
pub trait ToServerAddrs {
    /// Iterator over the produced server addresses.
    type Iter: Iterator<Item = ServerAddr>;

    /// Produces the server addresses.
    ///
    /// # Errors
    ///
    /// Returns an I/O error (e.g. `InvalidInput`) when an address cannot be parsed.
    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
}
1705
1706impl ToServerAddrs for ServerAddr {
1707 type Iter = option::IntoIter<ServerAddr>;
1708 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1709 Ok(Some(self.clone()).into_iter())
1710 }
1711}
1712
1713impl ToServerAddrs for str {
1714 type Iter = option::IntoIter<ServerAddr>;
1715 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1716 self.parse::<ServerAddr>()
1717 .map(|addr| Some(addr).into_iter())
1718 }
1719}
1720
1721impl ToServerAddrs for String {
1722 type Iter = option::IntoIter<ServerAddr>;
1723 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1724 (**self).to_server_addrs()
1725 }
1726}
1727
1728impl<T: AsRef<str>> ToServerAddrs for [T] {
1729 type Iter = std::vec::IntoIter<ServerAddr>;
1730 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1731 self.iter()
1732 .map(AsRef::as_ref)
1733 .map(str::parse)
1734 .collect::<io::Result<_>>()
1735 .map(Vec::into_iter)
1736 }
1737}
1738
1739impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1740 type Iter = std::vec::IntoIter<ServerAddr>;
1741 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1742 self.as_slice().to_server_addrs()
1743 }
1744}
1745
1746impl<'a> ToServerAddrs for &'a [ServerAddr] {
1747 type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1748
1749 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1750 Ok(self.iter().cloned())
1751 }
1752}
1753
1754impl ToServerAddrs for Vec<ServerAddr> {
1755 type Iter = std::vec::IntoIter<ServerAddr>;
1756
1757 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1758 Ok(self.clone().into_iter())
1759 }
1760}
1761
1762impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1763 type Iter = T::Iter;
1764 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1765 (**self).to_server_addrs()
1766 }
1767}
1768
/// Reports whether `subject` may be used as a publish subject.
///
/// A publish subject must be non-empty and must not contain a space, tab,
/// carriage return, or line feed. No other structure (dots, wildcards) is checked.
pub(crate) fn is_valid_publish_subject<T: AsRef<str>>(subject: T) -> bool {
    let subject = subject.as_ref();
    !subject.is_empty() && !subject.bytes().any(|b| matches!(b, b' ' | b'\t' | b'\r' | b'\n'))
}
1782
/// Reports whether `subject` is a well-formed subject.
///
/// The subject must be non-empty and must not:
/// * start or end with `.`,
/// * contain `..` (two adjacent separators), or
/// * contain a space, tab, carriage return, or line feed.
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let subject = subject.as_ref();

    let has_empty_token =
        subject.starts_with('.') || subject.ends_with('.') || subject.contains("..");
    let has_whitespace = subject
        .bytes()
        .any(|b| matches!(b, b' ' | b'\t' | b'\r' | b'\n'));

    !subject.is_empty() && !has_empty_token && !has_whitespace
}
1799
/// Reports whether `queue_group` is a valid queue group name.
///
/// The name must be non-empty and must not contain a space, tab, carriage
/// return, or line feed.
pub(crate) fn is_valid_queue_group(queue_group: &str) -> bool {
    if queue_group.is_empty() {
        return false;
    }
    queue_group
        .bytes()
        .all(|b| !matches!(b, b' ' | b'\t' | b'\r' | b'\n'))
}
1812
/// Generates `impl From<$origin> for $t` that keeps timeouts distinguishable.
///
/// * `$t`: target error type; must provide `new(kind)` and `with_source(kind, source)`.
/// * `$k`: kind type of `$t`; must have `TimedOut` and `Other` variants.
/// * `$origin`: source error type; must have a `kind()` method returning `$origin_kind`.
/// * `$origin_kind`: kind type of `$origin`; must have a `TimedOut` variant.
///
/// A source error of kind `TimedOut` becomes `$t::new($k::TimedOut)`, and the source
/// itself is dropped. Any other kind becomes `$t::with_source($k::Other, err)`,
/// which keeps the source.
// Silences the lint in builds where the macro is never invoked.
#[allow(unused_macros)]
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                if let <$origin_kind>::TimedOut = err.kind() {
                    Self::new(<$k>::TimedOut)
                } else {
                    Self::with_source(<$k>::Other, err)
                }
            }
        }
    };
}
// Re-exported so the macro can be invoked by path (`crate::from_with_timeout!`);
// the lint is silenced for builds where nothing imports it.
#[allow(unused_imports)]
pub(crate) use from_with_timeout;
1828
1829use crate::connection::ShouldFlush;
1830use crate::message::OutboundMessage;
1831
#[cfg(test)]
mod tests {
    use super::*;

    // ServerAddr host parsing.

    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    // ToServerAddrs for collections of string-like values.

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    // ToServerAddrs for ServerAddr itself and collections of it.

    #[test]
    fn to_server_addrs_server_addr_collections() {
        let addrs: Vec<ServerAddr> = vec![
            ServerAddr::from_str("nats://127.0.0.1").unwrap(),
            ServerAddr::from_str("nats://[::]").unwrap(),
        ];

        let from_vec: Vec<ServerAddr> = ToServerAddrs::to_server_addrs(&addrs)
            .unwrap()
            .collect();
        assert_eq!(from_vec, addrs);

        let view: &[ServerAddr] = &addrs;
        let from_slice: Vec<ServerAddr> = ToServerAddrs::to_server_addrs(&view)
            .unwrap()
            .collect();
        assert_eq!(from_slice, addrs);

        let single: Vec<ServerAddr> = ToServerAddrs::to_server_addrs(&addrs[0])
            .unwrap()
            .collect();
        assert_eq!(single, vec![addrs[0].clone()]);
    }

    // Port resolution: explicit port, scheme default, or the 4222 fallback.

    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            // Check every address in the group, not only the first one. Otherwise the
            // explicit-port and IPv6 variants would go unverified.
            for addr in arr.to_server_addrs().unwrap() {
                assert_eq!(
                    addr.port(),
                    expected_port,
                    "unexpected port for {}",
                    addr.as_url_str()
                );
            }
        }
    }

    // Subject and queue-group validation helpers.

    #[test]
    fn subject_and_queue_group_validation() {
        assert!(is_valid_subject("foo.bar"));
        assert!(is_valid_subject("foo.*.>"));
        assert!(!is_valid_subject(""));
        assert!(!is_valid_subject(".foo"));
        assert!(!is_valid_subject("foo."));
        assert!(!is_valid_subject("foo..bar"));
        assert!(!is_valid_subject("foo bar"));
        assert!(!is_valid_subject("foo\tbar"));
        assert!(!is_valid_subject("foo\r\nbar"));

        assert!(is_valid_publish_subject("foo.bar"));
        assert!(!is_valid_publish_subject(""));
        assert!(!is_valid_publish_subject("foo bar"));
        assert!(!is_valid_publish_subject("foo\nbar"));

        assert!(is_valid_queue_group("workers"));
        assert!(!is_valid_queue_group(""));
        assert!(!is_valid_queue_group("work ers"));
        assert!(!is_valid_queue_group("work\ters"));
    }
}