1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_auto_cfg))]
196
197use thiserror::Error;
198
199use futures::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::fmt::Display;
207use std::future::Future;
208use std::iter;
209use std::mem;
210use std::net::SocketAddr;
211use std::option;
212use std::pin::Pin;
213use std::slice;
214use std::str::{self, FromStr};
215use std::sync::atomic::AtomicUsize;
216use std::sync::atomic::Ordering;
217use std::sync::Arc;
218use std::task::{Context, Poll};
219use tokio::io::ErrorKind;
220use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
221use url::{Host, Url};
222
223use bytes::Bytes;
224use serde::{Deserialize, Serialize};
225use serde_repr::{Deserialize_repr, Serialize_repr};
226use tokio::io;
227use tokio::sync::mpsc;
228use tokio::task;
229
/// Boxed, thread-safe, `'static` error type used where a concrete error type is not needed.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

// Crate version taken from Cargo at compile time.
// NOTE(review): presumably reported to the server via `ConnectInfo::version` — confirm at the use site.
const VERSION: &str = env!("CARGO_PKG_VERSION");
// Client language identifier; presumably sent as `ConnectInfo::lang` — confirm at the use site.
const LANG: &str = "rust";
// Upper bound on PINGs awaiting a PONG. The handler disconnects when sending
// another PING would push the outstanding count above this value.
const MAX_PENDING_PINGS: usize = 2;
// Subscription id reserved for the request/response multiplexer's wildcard subscription.
const MULTIPLEXER_SID: u64 = 0;
236
237pub use tokio_rustls::rustls;
241
242use connection::{Connection, State};
243use connector::{Connector, ConnectorOptions};
244pub use header::{HeaderMap, HeaderName, HeaderValue};
245pub use subject::Subject;
246
247mod auth;
248pub(crate) mod auth_utils;
249pub mod client;
250pub mod connection;
251mod connector;
252mod options;
253
254pub use auth::Auth;
255pub use client::{
256 Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
257};
258pub use options::{AuthError, ConnectOptions};
259
260mod crypto;
261pub mod error;
262pub mod header;
263pub mod jetstream;
264pub mod message;
265#[cfg(feature = "service")]
266pub mod service;
267pub mod status;
268pub mod subject;
269mod tls;
270
271pub use message::Message;
272pub use status::StatusCode;
273
/// Information published by the server in its `INFO` protocol message.
///
/// Every field is `#[serde(default)]`, so keys missing from the JSON fall back
/// to the field type's `Default` value.
#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
pub struct ServerInfo {
    /// The server's unique id (`server_id`).
    #[serde(default)]
    pub server_id: String,
    /// The server's configured name (`server_name`).
    #[serde(default)]
    pub server_name: String,
    /// Host the server advertises (`host`).
    #[serde(default)]
    pub host: String,
    /// Port the server advertises (`port`).
    #[serde(default)]
    pub port: u16,
    /// Server version string (`version`).
    #[serde(default)]
    pub version: String,
    /// Whether the server requires the client to authenticate (`auth_required`).
    #[serde(default)]
    pub auth_required: bool,
    /// Whether the server requires TLS (`tls_required`).
    #[serde(default)]
    pub tls_required: bool,
    /// Largest payload the server accepts, in bytes (`max_payload`).
    #[serde(default)]
    pub max_payload: usize,
    /// Protocol version the server speaks (`proto`).
    #[serde(default)]
    pub proto: i8,
    /// Id the server assigned to this client connection (`client_id`).
    #[serde(default)]
    pub client_id: u64,
    /// Go version the server was built with (`go`).
    #[serde(default)]
    pub go: String,
    /// Nonce to sign for NKey/JWT authentication (`nonce`).
    #[serde(default)]
    pub nonce: String,
    /// Additional server URLs the client may connect to (`connect_urls`).
    #[serde(default)]
    pub connect_urls: Vec<String>,
    /// Client IP address as seen by the server (`client_ip`).
    #[serde(default)]
    pub client_ip: String,
    /// Whether the server supports message headers (`headers`).
    #[serde(default)]
    pub headers: bool,
    /// Lame duck mode flag, carried under the JSON key `ldm`. When an INFO with
    /// this set arrives on a live connection the handler emits `Event::LameDuckMode`.
    #[serde(default, rename = "ldm")]
    pub lame_duck_mode: bool,
}
328
/// A protocol operation received from the server, as consumed by
/// `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// `+OK` acknowledgement; ignored by the handler.
    Ok,
    /// `INFO` with the server's current information.
    Info(Box<ServerInfo>),
    /// Server keep-alive; answered with `ClientOp::Pong`.
    Ping,
    /// Reply to one of our PINGs; decrements the pending-ping counter.
    Pong,
    /// `-ERR` from the server; forwarded as `Event::ServerError`.
    Error(ServerError),
    /// A message delivered to subscription `sid`. The fields are moved into a
    /// `Message` unchanged.
    Message {
        sid: u64,
        subject: Subject,
        reply: Option<Subject>,
        payload: Bytes,
        headers: Option<HeaderMap>,
        status: Option<StatusCode>,
        description: Option<String>,
        length: usize,
    },
}
347
/// A message to publish, carried by `Command::Publish`.
#[derive(Debug)]
pub struct PublishMessage {
    /// Subject to publish on.
    pub subject: Subject,
    /// Message body.
    pub payload: Bytes,
    /// Optional reply subject sent along with the message.
    pub reply: Option<Subject>,
    /// Optional message headers.
    pub headers: Option<HeaderMap>,
}
356
/// Instructions sent from client-side handles to the `ConnectionHandler`
/// over its command channel (see `ConnectionHandler::handle_command`).
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a message.
    Publish(PublishMessage),
    /// Publish `payload` to `subject` and route the first reply to `sender`.
    /// `respond` ends in `.<token>`; the token keys the reply in the multiplexer.
    Request {
        subject: Subject,
        payload: Bytes,
        respond: Subject,
        headers: Option<HeaderMap>,
        sender: oneshot::Sender<Message>,
    },
    /// Register subscription `sid`, delivering its messages into `sender`.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
        sender: mpsc::Sender<Message>,
    },
    /// Unsubscribe `sid` now (`max == None`) or after `max` total deliveries.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Notify `observer` after the next successful flush of the write buffer.
    Flush {
        observer: oneshot::Sender<()>,
    },
    /// Drain one subscription (`Some(sid)`) or the whole connection (`None`).
    Drain {
        sid: Option<u64>,
    },
    /// Drop the current connection and establish a new one.
    Reconnect,
}
386
/// A protocol operation written to the server via `Connection::enqueue_write_op`.
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// PUB/HPUB with an optional reply subject and headers.
    Publish {
        subject: Subject,
        payload: Bytes,
        respond: Option<Subject>,
        headers: Option<HeaderMap>,
    },
    /// SUB, optionally joining a queue group.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
    },
    /// UNSUB, either immediate (`None`) or after `max` messages.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Keep-alive ping.
    Ping,
    /// Reply to a server PING.
    Pong,
    /// CONNECT handshake payload.
    Connect(ConnectInfo),
}
409
/// Handler-side state for one active subscription, keyed by sid.
#[derive(Debug)]
struct Subscription {
    /// Subject the subscription was made on; re-sent on reconnect.
    subject: Subject,
    /// Channel feeding the user-facing `Subscriber`.
    sender: mpsc::Sender<Message>,
    /// Optional queue group; re-sent on reconnect.
    queue_group: Option<String>,
    /// Number of messages successfully handed to `sender`.
    delivered: u64,
    /// Auto-unsubscribe limit; the entry is removed once `delivered` reaches it.
    max: Option<u64>,
    /// Set by a drain command; draining entries are removed after the next read pass.
    is_draining: bool,
}
419
/// Shared reply subscription used by all requests on the connection.
///
/// It is created lazily by the first `Command::Request`.
#[derive(Debug)]
struct Multiplexer {
    /// Wildcard subject actually subscribed to (`<prefix>*`), under `MULTIPLEXER_SID`.
    subject: Subject,
    /// Reply-subject prefix (`<inbox prefix>.<nuid>.`); a reply's suffix is its token.
    prefix: Subject,
    /// Pending requests keyed by token; each entry is removed when its reply arrives.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
426
/// Owns the live server connection.
///
/// It multiplexes client commands onto the connection and routes server
/// operations back to subscribers and requesters. It is driven by `process`.
pub(crate) struct ConnectionHandler {
    // Current transport; replaced wholesale on reconnect.
    connection: Connection,
    // Establishes (re)connections and owns the event/state/statistics channels.
    connector: Connector,
    // Active subscriptions by sid.
    subscriptions: HashMap<u64, Subscription>,
    // Request/reply multiplexer, created on the first request.
    multiplexer: Option<Multiplexer>,
    // PINGs sent without a matching PONG yet.
    pending_pings: usize,
    // Publishes fresh `ServerInfo` to watchers after each reconnect.
    info_sender: tokio::sync::watch::Sender<ServerInfo>,
    // Drives outgoing PINGs; reset on every server op and every command.
    ping_interval: Interval,
    // Set by `Command::Reconnect`; consumed by the poll loop.
    should_reconnect: bool,
    // Waiters notified after the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    // Set by a connection-wide drain; makes the poll loop exit with `Closed`.
    is_draining: bool,
}
440
441impl ConnectionHandler {
442 pub(crate) fn new(
443 connection: Connection,
444 connector: Connector,
445 info_sender: tokio::sync::watch::Sender<ServerInfo>,
446 ping_period: Duration,
447 ) -> ConnectionHandler {
448 let mut ping_interval = interval(ping_period);
449 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
450
451 ConnectionHandler {
452 connection,
453 connector,
454 subscriptions: HashMap::new(),
455 multiplexer: None,
456 pending_pings: 0,
457 info_sender,
458 ping_interval,
459 should_reconnect: false,
460 flush_observers: Vec::new(),
461 is_draining: false,
462 }
463 }
464
    /// Runs the connection until the client is closed or reconnection fails.
    ///
    /// Each connection's lifetime is driven by a hand-written `ProcessFut`.
    /// When it reports a disconnect or a reconnect request, `handle_disconnect`
    /// establishes a new connection and the loop continues. When it reports
    /// `Closed`, an `Event::Closed` is emitted and this function returns. It
    /// also returns if reconnecting fails.
    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
        // Future that drives one connection; borrows the handler, the command
        // receiver and a reusable buffer for batched command reads.
        struct ProcessFut<'a> {
            handler: &'a mut ConnectionHandler,
            receiver: &'a mut mpsc::Receiver<Command>,
            recv_buf: &'a mut Vec<Command>,
        }

        // Why a `ProcessFut` finished.
        enum ExitReason {
            // Read/write/flush failed, EOF was reached, or too many PINGs went unanswered.
            Disconnected(Option<io::Error>),
            // A `Command::Reconnect` was processed.
            ReconnectRequested,
            // The command channel closed, or a connection-wide drain completed.
            Closed,
        }

        impl ProcessFut<'_> {
            // Maximum number of commands pulled from the channel per receive.
            const RECV_CHUNK_SIZE: usize = 16;

            // Called once per elapsed ping tick. `Poll::Pending` here means
            // "keep running" (a PING was queued), not "not ready".
            #[cold]
            fn ping(&mut self) -> Poll<ExitReason> {
                self.handler.pending_pings += 1;

                if self.handler.pending_pings > MAX_PENDING_PINGS {
                    debug!(
                        pending_pings = self.handler.pending_pings,
                        max_pings = MAX_PENDING_PINGS,
                        "disconnecting due to too many pending pings"
                    );

                    Poll::Ready(ExitReason::Disconnected(None))
                } else {
                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);

                    Poll::Pending
                }
            }
        }

        impl Future for ProcessFut<'_> {
            type Output = ExitReason;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // 1. Keep-alive: handle every elapsed tick. Looping until
                //    `poll_tick` is pending also registers the timer waker.
                while self.handler.ping_interval.poll_tick(cx).is_ready() {
                    if let Poll::Ready(exit) = self.ping() {
                        return Poll::Ready(exit);
                    }
                }

                // 2. Read: process every server op that is already available.
                //    `Ok(None)` is treated as the server closing the connection.
                loop {
                    match self.handler.connection.poll_read_op(cx) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(Some(server_op))) => {
                            self.handler.handle_server_op(server_op);
                        }
                        Poll::Ready(Ok(None)) => {
                            return Poll::Ready(ExitReason::Disconnected(None))
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // 3. Drop subscriptions marked for draining now that pending reads are handled.
                self.handler.subscriptions.retain(|_, s| !s.is_draining);

                if self.handler.is_draining {
                    return Poll::Ready(ExitReason::Closed);
                }

                // 4. Write: alternate between pulling commands into the write
                //    buffer (until it is full) and writing it out. Stop after a
                //    pass that received no new commands, or when the write
                //    would block.
                let mut made_progress = true;
                loop {
                    while !self.handler.connection.is_write_buf_full() {
                        debug_assert!(self.recv_buf.is_empty());

                        // Split the borrow so the receiver and handler can be used together.
                        let Self {
                            recv_buf,
                            handler,
                            receiver,
                        } = &mut *self;
                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
                            Poll::Pending => break,
                            Poll::Ready(1..) => {
                                made_progress = true;

                                for cmd in recv_buf.drain(..) {
                                    handler.handle_command(cmd);
                                }
                            }
                            // A zero-length receive is treated as the command channel being closed.
                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
                        }
                    }

                    // `take` also resets the flag for the next pass.
                    if !mem::take(&mut made_progress) {
                        break;
                    }

                    match self.handler.connection.poll_write(cx) {
                        Poll::Pending => {
                            break;
                        }
                        Poll::Ready(Ok(())) => {
                            continue;
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // 5. Flush when the connection asks for it, or when flush
                //    observers are waiting. Observers are notified only after a
                //    successful flush.
                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
                    self.handler.connection.should_flush(),
                    self.handler.flush_observers.is_empty(),
                ) {
                    match self.handler.connection.poll_flush(cx) {
                        Poll::Pending => {}
                        Poll::Ready(Ok(())) => {
                            for observer in self.handler.flush_observers.drain(..) {
                                let _ = observer.send(());
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // 6. A reconnect request is honoured only after pending writes/flushes were attempted.
                if mem::take(&mut self.handler.should_reconnect) {
                    return Poll::Ready(ExitReason::ReconnectRequested);
                }

                Poll::Pending
            }
        }

        // Reused across connections so batched receives do not reallocate.
        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
        loop {
            let process = ProcessFut {
                handler: self,
                receiver,
                recv_buf: &mut recv_buf,
            };
            match process.await {
                ExitReason::Disconnected(err) => {
                    debug!(error = ?err, "disconnected");
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                    debug!("reconnected");
                }
                ExitReason::Closed => {
                    self.connector.events_tx.try_send(Event::Closed).ok();
                    break;
                }
                ExitReason::ReconnectRequested => {
                    debug!("reconnect requested");
                    // Best-effort shutdown of the old transport before replacing it.
                    self.connection.stream.shutdown().await.ok();
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                }
            }
        }
    }
670
671 fn handle_server_op(&mut self, server_op: ServerOp) {
672 self.ping_interval.reset();
673
674 match server_op {
675 ServerOp::Ping => {
676 debug!("received PING");
677 self.connection.enqueue_write_op(&ClientOp::Pong);
678 }
679 ServerOp::Pong => {
680 debug!("received PONG");
681 self.pending_pings = self.pending_pings.saturating_sub(1);
682 }
683 ServerOp::Error(error) => {
684 debug!("received ERROR: {:?}", error);
685 self.connector
686 .events_tx
687 .try_send(Event::ServerError(error))
688 .ok();
689 }
690 ServerOp::Message {
691 sid,
692 subject,
693 reply,
694 payload,
695 headers,
696 status,
697 description,
698 length,
699 } => {
700 debug!("received MESSAGE: sid={}, subject={}", sid, subject);
701 self.connector
702 .connect_stats
703 .in_messages
704 .add(1, Ordering::Relaxed);
705
706 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
707 let message: Message = Message {
708 subject,
709 reply,
710 payload,
711 headers,
712 status,
713 description,
714 length,
715 };
716
717 match subscription.sender.try_send(message) {
720 Ok(_) => {
721 subscription.delivered += 1;
722 if let Some(max) = subscription.max {
726 if subscription.delivered.ge(&max) {
727 debug!("max messages reached for subscription {}", sid);
728 self.subscriptions.remove(&sid);
729 }
730 }
731 }
732 Err(mpsc::error::TrySendError::Full(_)) => {
733 debug!("slow consumer detected for subscription {}", sid);
734 self.connector
735 .events_tx
736 .try_send(Event::SlowConsumer(sid))
737 .ok();
738 }
739 Err(mpsc::error::TrySendError::Closed(_)) => {
740 debug!("subscription {} channel closed", sid);
741 self.subscriptions.remove(&sid);
742 self.connection
743 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
744 }
745 }
746 } else if sid == MULTIPLEXER_SID {
747 debug!("received message for multiplexer");
748 if let Some(multiplexer) = self.multiplexer.as_mut() {
749 let maybe_token =
750 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
751
752 if let Some(token) = maybe_token {
753 if let Some(sender) = multiplexer.senders.remove(token) {
754 debug!("forwarding message to request with token {}", token);
755 let message = Message {
756 subject,
757 reply,
758 payload,
759 headers,
760 status,
761 description,
762 length,
763 };
764
765 let _ = sender.send(message);
766 }
767 }
768 }
769 }
770 }
771 ServerOp::Info(info) => {
773 debug!("received INFO: server_id={}", info.server_id);
774 if info.lame_duck_mode {
775 debug!("server in lame duck mode");
776 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
777 }
778 }
779
780 _ => {
781 }
783 }
784 }
785
786 fn handle_command(&mut self, command: Command) {
787 self.ping_interval.reset();
788
789 match command {
790 Command::Unsubscribe { sid, max } => {
791 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
792 subscription.max = max;
793 match subscription.max {
794 Some(n) => {
795 if subscription.delivered >= n {
796 self.subscriptions.remove(&sid);
797 }
798 }
799 None => {
800 self.subscriptions.remove(&sid);
801 }
802 }
803
804 self.connection
805 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
806 }
807 }
808 Command::Flush { observer } => {
809 self.flush_observers.push(observer);
810 }
811 Command::Drain { sid } => {
812 let mut drain_sub = |sid: u64, sub: &mut Subscription| {
813 sub.is_draining = true;
814 self.connection
815 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
816 };
817
818 if let Some(sid) = sid {
819 if let Some(sub) = self.subscriptions.get_mut(&sid) {
820 drain_sub(sid, sub);
821 }
822 } else {
823 self.connector.events_tx.try_send(Event::Draining).ok();
825 self.is_draining = true;
826 for (&sid, sub) in self.subscriptions.iter_mut() {
827 drain_sub(sid, sub);
828 }
829 }
830 }
831 Command::Subscribe {
832 sid,
833 subject,
834 queue_group,
835 sender,
836 } => {
837 let subscription = Subscription {
838 sender,
839 delivered: 0,
840 max: None,
841 subject: subject.to_owned(),
842 queue_group: queue_group.to_owned(),
843 is_draining: false,
844 };
845
846 self.subscriptions.insert(sid, subscription);
847
848 self.connection.enqueue_write_op(&ClientOp::Subscribe {
849 sid,
850 subject,
851 queue_group,
852 });
853 }
854 Command::Request {
855 subject,
856 payload,
857 respond,
858 headers,
859 sender,
860 } => {
861 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
862
863 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
864 multiplexer
865 } else {
866 let prefix = Subject::from(format!("{}.{}.", prefix, nuid::next()));
867 let subject = Subject::from(format!("{}*", prefix));
868
869 self.connection.enqueue_write_op(&ClientOp::Subscribe {
870 sid: MULTIPLEXER_SID,
871 subject: subject.clone(),
872 queue_group: None,
873 });
874
875 self.multiplexer.insert(Multiplexer {
876 subject,
877 prefix,
878 senders: HashMap::new(),
879 })
880 };
881 self.connector
882 .connect_stats
883 .out_messages
884 .add(1, Ordering::Relaxed);
885
886 multiplexer.senders.insert(token.to_owned(), sender);
887
888 let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
889
890 let pub_op = ClientOp::Publish {
891 subject,
892 payload,
893 respond: Some(respond),
894 headers,
895 };
896
897 self.connection.enqueue_write_op(&pub_op);
898 }
899
900 Command::Publish(PublishMessage {
901 subject,
902 payload,
903 reply: respond,
904 headers,
905 }) => {
906 self.connector
907 .connect_stats
908 .out_messages
909 .add(1, Ordering::Relaxed);
910
911 let header_len = headers
912 .as_ref()
913 .map(|headers| headers.len())
914 .unwrap_or_default();
915
916 self.connector.connect_stats.out_bytes.add(
917 (payload.len()
918 + respond.as_ref().map_or_else(|| 0, |r| r.len())
919 + subject.len()
920 + header_len) as u64,
921 Ordering::Relaxed,
922 );
923
924 self.connection.enqueue_write_op(&ClientOp::Publish {
925 subject,
926 payload,
927 respond,
928 headers,
929 });
930 }
931
932 Command::Reconnect => {
933 self.should_reconnect = true;
934 }
935 }
936 }
937
938 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
939 self.pending_pings = 0;
940 self.connector.events_tx.try_send(Event::Disconnected).ok();
941 self.connector.state_tx.send(State::Disconnected).ok();
942
943 self.handle_reconnect().await
944 }
945
946 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
947 let (info, connection) = self.connector.connect().await?;
948 self.connection = connection;
949 let _ = self.info_sender.send(info);
950
951 self.subscriptions
952 .retain(|_, subscription| !subscription.sender.is_closed());
953
954 for (sid, subscription) in &self.subscriptions {
955 self.connection.enqueue_write_op(&ClientOp::Subscribe {
956 sid: *sid,
957 subject: subscription.subject.to_owned(),
958 queue_group: subscription.queue_group.to_owned(),
959 });
960 }
961
962 if let Some(multiplexer) = &self.multiplexer {
963 self.connection.enqueue_write_op(&ClientOp::Subscribe {
964 sid: MULTIPLEXER_SID,
965 subject: multiplexer.subject.to_owned(),
966 queue_group: None,
967 });
968 }
969 Ok(())
970 }
971}
972
/// Connects to the NATS server(s) in `addrs` using `options` and returns a `Client`.
///
/// Two background tasks are spawned:
/// - One logs every connection event and forwards it to the configured event
///   callback, if any.
/// - One runs the `ConnectionHandler` for the lifetime of the client.
///
/// When `retry_on_initial_connect` is disabled, the first connection attempt
/// happens before this function returns and its failure is returned. When it
/// is enabled, this returns immediately and the handler task connects (with
/// retries) in the background.
///
/// # Errors
/// - `ConnectErrorKind::ServerParse` if the connector cannot be built from `addrs`.
/// - The error of the initial connection attempt, if retrying on the initial
///   connect is disabled.
pub async fn connect_with_options<A: ToServerAddrs>(
    addrs: A,
    options: ConnectOptions,
) -> Result<Client, ConnectError> {
    let ping_period = options.ping_interval;

    // Events are sent with `try_send`, so they are dropped while this buffer is full.
    let (events_tx, mut events_rx) = mpsc::channel(128);
    let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
    // Starts at 1 MiB. NOTE(review): presumably updated from server INFO by the connector — confirm.
    let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
    let statistics = Arc::new(Statistics::default());

    let mut connector = Connector::new(
        addrs,
        ConnectorOptions {
            tls_required: options.tls_required,
            certificates: options.certificates,
            client_key: options.client_key,
            client_cert: options.client_cert,
            tls_client_config: options.tls_client_config,
            tls_first: options.tls_first,
            auth: options.auth,
            no_echo: options.no_echo,
            connection_timeout: options.connection_timeout,
            name: options.name,
            ignore_discovered_servers: options.ignore_discovered_servers,
            retain_servers_order: options.retain_servers_order,
            read_buffer_capacity: options.read_buffer_capacity,
            reconnect_delay_callback: options.reconnect_delay_callback,
            auth_callback: options.auth_callback,
            max_reconnects: options.max_reconnects,
        },
        events_tx,
        state_tx,
        max_payload.clone(),
        statistics.clone(),
    )
    .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;

    // Without retry, connect eagerly so the caller sees the failure; otherwise
    // both stay empty/default until the background task connects.
    let mut info: ServerInfo = Default::default();
    let mut connection = None;
    if !options.retry_on_initial_connect {
        debug!("retry on initial connect failure is disabled");
        let (info_ok, connection_ok) = connector.try_connect().await?;
        connection = Some(connection_ok);
        info = info_ok;
    }

    let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
    let (sender, mut receiver) = mpsc::channel(options.sender_capacity);

    let client = Client::new(
        info_watcher,
        state_rx,
        sender,
        options.subscription_capacity,
        options.inbox_prefix,
        options.request_timeout,
        max_payload,
        statistics,
    );

    // Event dispatcher: runs until every event sender has been dropped.
    task::spawn(async move {
        while let Some(event) = events_rx.recv().await {
            tracing::info!("event: {}", event);
            if let Some(event_callback) = &options.event_callback {
                event_callback.call(event).await;
            }
        }
    });

    // Connection handler task.
    task::spawn(async move {
        if connection.is_none() && options.retry_on_initial_connect {
            let (info, connection_ok) = match connector.connect().await {
                Ok((info, connection)) => (info, connection),
                Err(err) => {
                    error!("connection closed: {}", err);
                    return;
                }
            };
            info_sender.send(info).ok();
            connection = Some(connection_ok);
        }
        // Invariant: set either eagerly above or by the retry branch just executed.
        let connection = connection.unwrap();
        let mut connection_handler =
            ConnectionHandler::new(connection, connector, info_sender, ping_period);
        connection_handler.process(&mut receiver).await
    });

    Ok(client)
}
1079
1080#[derive(Debug, Clone, PartialEq, Eq)]
1081pub enum Event {
1082 Connected,
1083 Disconnected,
1084 LameDuckMode,
1085 Draining,
1086 Closed,
1087 SlowConsumer(u64),
1088 ServerError(ServerError),
1089 ClientError(ClientError),
1090}
1091
1092impl fmt::Display for Event {
1093 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1094 match self {
1095 Event::Connected => write!(f, "connected"),
1096 Event::Disconnected => write!(f, "disconnected"),
1097 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1098 Event::Draining => write!(f, "draining"),
1099 Event::Closed => write!(f, "closed"),
1100 Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1101 Event::ServerError(err) => write!(f, "server error: {err}"),
1102 Event::ClientError(err) => write!(f, "client error: {err}"),
1103 }
1104 }
1105}
1106
1107pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1174 connect_with_options(addrs, ConnectOptions::default()).await
1175}
1176
/// Why establishing a connection failed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// Resolving a server host name failed.
    Dns,
    /// Authentication failed (the Display text refers to signing the nonce).
    Authentication,
    /// The server rejected the client's credentials.
    AuthorizationViolation,
    /// The attempt timed out.
    TimedOut,
    /// TLS setup or handshake failed.
    Tls,
    /// An I/O error occurred.
    Io,
    /// The configured maximum number of reconnect attempts was reached.
    MaxReconnects,
}

impl Display for ConnectErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
1211
1212pub type ConnectError = error::Error<ConnectErrorKind>;
1215
1216impl From<io::Error> for ConnectError {
1217 fn from(err: io::Error) -> Self {
1218 ConnectError::with_source(ConnectErrorKind::Io, err)
1219 }
1220}
1221
1222#[derive(Debug)]
1236pub struct Subscriber {
1237 sid: u64,
1238 receiver: mpsc::Receiver<Message>,
1239 sender: mpsc::Sender<Command>,
1240}
1241
1242impl Subscriber {
1243 fn new(
1244 sid: u64,
1245 sender: mpsc::Sender<Command>,
1246 receiver: mpsc::Receiver<Message>,
1247 ) -> Subscriber {
1248 Subscriber {
1249 sid,
1250 sender,
1251 receiver,
1252 }
1253 }
1254
1255 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1270 self.sender
1271 .send(Command::Unsubscribe {
1272 sid: self.sid,
1273 max: None,
1274 })
1275 .await?;
1276 self.receiver.close();
1277 Ok(())
1278 }
1279
1280 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1306 self.sender
1307 .send(Command::Unsubscribe {
1308 sid: self.sid,
1309 max: Some(unsub_after),
1310 })
1311 .await?;
1312 Ok(())
1313 }
1314
1315 pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1348 self.sender
1349 .send(Command::Drain {
1350 sid: Some(self.sid),
1351 })
1352 .await?;
1353
1354 Ok(())
1355 }
1356}
1357
1358#[derive(Error, Debug, PartialEq)]
1359#[error("failed to send unsubscribe")]
1360pub struct UnsubscribeError(String);
1361
1362impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1363 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1364 UnsubscribeError(err.to_string())
1365 }
1366}
1367
1368impl Drop for Subscriber {
1369 fn drop(&mut self) {
1370 self.receiver.close();
1371 tokio::spawn({
1372 let sender = self.sender.clone();
1373 let sid = self.sid;
1374 async move {
1375 sender
1376 .send(Command::Unsubscribe { sid, max: None })
1377 .await
1378 .ok();
1379 }
1380 });
1381 }
1382}
1383
1384impl Stream for Subscriber {
1385 type Item = Message;
1386
1387 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1388 self.receiver.poll_recv(cx)
1389 }
1390}
1391
1392#[derive(Clone, Debug, Eq, PartialEq)]
1393pub enum CallbackError {
1394 Client(ClientError),
1395 Server(ServerError),
1396}
1397impl std::fmt::Display for CallbackError {
1398 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1399 match self {
1400 Self::Client(error) => write!(f, "{error}"),
1401 Self::Server(error) => write!(f, "{error}"),
1402 }
1403 }
1404}
1405
1406impl From<ServerError> for CallbackError {
1407 fn from(server_error: ServerError) -> Self {
1408 CallbackError::Server(server_error)
1409 }
1410}
1411
1412impl From<ClientError> for CallbackError {
1413 fn from(client_error: ClientError) -> Self {
1414 CallbackError::Client(client_error)
1415 }
1416}
1417
/// Error reported by the NATS server (carried by `ServerOp::Error`).
///
/// `Display` is implemented by hand, so the `thiserror` derive only provides
/// the `std::error::Error` impl.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The server rejected the client's permissions or credentials.
    AuthorizationViolation,
    /// The given subscription id was reported as a slow consumer.
    SlowConsumer(u64),
    /// Any other server error message, verbatim.
    Other(String),
}
1424
/// Error originating on the client side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Free-form client error message.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("nats: ")?;
        match self {
            Self::Other(error) => f.write_str(error),
            Self::MaxReconnects => f.write_str("max reconnects reached"),
        }
    }
}
1438
1439impl ServerError {
1440 fn new(error: String) -> ServerError {
1441 match error.to_lowercase().as_str() {
1442 "authorization violation" => ServerError::AuthorizationViolation,
1443 _ => ServerError::Other(error),
1445 }
1446 }
1447}
1448
1449impl std::fmt::Display for ServerError {
1450 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1451 match self {
1452 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1453 Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1454 Self::Other(error) => write!(f, "nats: {error}"),
1455 }
1456 }
1457}
1458
/// Payload of the `CONNECT` message sent to the server (`ClientOp::Connect`).
///
/// Serialized with serde; `user_jwt` and `signature` use the protocol keys
/// `jwt` and `sig`.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Ask the server to acknowledge each protocol message with `+OK`.
    pub verbose: bool,

    /// Ask the server to perform additional strict protocol checks.
    pub pedantic: bool,

    /// User JWT for decentralized authentication (protocol key `jwt`).
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public NKey used for NKey authentication.
    pub nkey: Option<String>,

    /// Signature of the server's nonce (protocol key `sig`).
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// Whether the server may deliver this connection's own publishes back to
    /// its matching subscriptions.
    pub echo: bool,

    /// Client implementation language.
    pub lang: String,

    /// Client library version.
    pub version: String,

    /// Protocol level the client supports.
    pub protocol: Protocol,

    /// Whether the client requires TLS.
    pub tls_required: bool,

    /// Username for user/password authentication.
    pub user: Option<String>,

    /// Password for user/password authentication.
    pub pass: Option<String>,

    /// Token for token authentication.
    pub auth_token: Option<String>,

    /// Whether the client supports message headers.
    pub headers: bool,

    /// Whether the client wants "no responders" status replies for requests.
    pub no_responders: bool,
}
1519
/// Protocol level announced in `ConnectInfo`, serialized as its numeric value.
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Level 0: the original protocol.
    Original = 0,
    /// Level 1: the client can handle asynchronous INFO updates, such as
    /// cluster topology changes. This follows the NATS client protocol
    /// description.
    Dynamic = 1,
}
1529
1530#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1532pub struct ServerAddr(Url);
1533
1534impl FromStr for ServerAddr {
1535 type Err = io::Error;
1536
1537 fn from_str(input: &str) -> Result<Self, Self::Err> {
1541 let url: Url = if input.contains("://") {
1542 input.parse()
1543 } else {
1544 format!("nats://{input}").parse()
1545 }
1546 .map_err(|e| {
1547 io::Error::new(
1548 ErrorKind::InvalidInput,
1549 format!("NATS server URL is invalid: {e}"),
1550 )
1551 })?;
1552
1553 Self::from_url(url)
1554 }
1555}
1556
1557impl ServerAddr {
1558 pub fn from_url(url: Url) -> io::Result<Self> {
1560 if url.scheme() != "nats"
1561 && url.scheme() != "tls"
1562 && url.scheme() != "ws"
1563 && url.scheme() != "wss"
1564 {
1565 return Err(std::io::Error::new(
1566 ErrorKind::InvalidInput,
1567 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1568 ));
1569 }
1570
1571 Ok(Self(url))
1572 }
1573
1574 pub fn into_inner(self) -> Url {
1576 self.0
1577 }
1578
1579 pub fn tls_required(&self) -> bool {
1581 self.0.scheme() == "tls"
1582 }
1583
1584 pub fn has_user_pass(&self) -> bool {
1586 self.0.username() != ""
1587 }
1588
1589 pub fn scheme(&self) -> &str {
1590 self.0.scheme()
1591 }
1592
1593 pub fn host(&self) -> &str {
1595 match self.0.host() {
1596 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1597 Some(Host::Ipv6 { .. }) => {
1599 let host = self.0.host_str().unwrap();
1600 &host[1..host.len() - 1]
1601 }
1602 None => "",
1603 }
1604 }
1605
1606 pub fn is_websocket(&self) -> bool {
1607 self.0.scheme() == "ws" || self.0.scheme() == "wss"
1608 }
1609
1610 pub fn port(&self) -> u16 {
1613 self.0.port_or_known_default().unwrap_or(4222)
1614 }
1615
1616 pub fn username(&self) -> Option<&str> {
1618 let user = self.0.username();
1619 if user.is_empty() {
1620 None
1621 } else {
1622 Some(user)
1623 }
1624 }
1625
1626 pub fn password(&self) -> Option<&str> {
1628 self.0.password()
1629 }
1630
1631 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1633 tokio::net::lookup_host((self.host(), self.port())).await
1634 }
1635}
1636
/// Conversion into one or more `ServerAddr`s, accepted by `connect` and
/// `connect_with_options`.
pub trait ToServerAddrs {
    /// Iterator over the produced addresses.
    type Iter: Iterator<Item = ServerAddr>;

    /// Produces the addresses.
    ///
    /// # Errors
    /// Returns an error if any textual address fails to parse.
    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
}

// A single address yields itself (cloned).
impl ToServerAddrs for ServerAddr {
    type Iter = option::IntoIter<ServerAddr>;
    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        Ok(Some(self.clone()).into_iter())
    }
}

// A single textual address, parsed with `ServerAddr::from_str`.
impl ToServerAddrs for str {
    type Iter = option::IntoIter<ServerAddr>;
    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        self.parse::<ServerAddr>()
            .map(|addr| Some(addr).into_iter())
    }
}

impl ToServerAddrs for String {
    type Iter = option::IntoIter<ServerAddr>;
    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        (**self).to_server_addrs()
    }
}

// A list of textual addresses. Collecting into `io::Result` makes this
// all-or-nothing: the first parse error aborts the conversion.
impl<T: AsRef<str>> ToServerAddrs for [T] {
    type Iter = std::vec::IntoIter<ServerAddr>;
    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        self.iter()
            .map(AsRef::as_ref)
            .map(str::parse)
            .collect::<io::Result<_>>()
            .map(Vec::into_iter)
    }
}

impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
    type Iter = std::vec::IntoIter<ServerAddr>;
    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        self.as_slice().to_server_addrs()
    }
}

// Already-parsed addresses. These impls coexist with the `AsRef<str>` ones
// because `ServerAddr` does not implement `AsRef<str>`.
impl<'a> ToServerAddrs for &'a [ServerAddr] {
    type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        Ok(self.iter().cloned())
    }
}

impl ToServerAddrs for Vec<ServerAddr> {
    type Iter = std::vec::IntoIter<ServerAddr>;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        Ok(self.clone().into_iter())
    }
}

// Forward through references so `&str`, `&String`, `&Vec<_>` etc. also work.
impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
    type Iter = T::Iter;
    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        (**self).to_server_addrs()
    }
}
1711
/// Lightweight subject check.
///
/// The subject must not start or end with `.` and must contain no ASCII
/// whitespace. Empty tokens inside the subject (`a..b`) and the empty string
/// are not rejected.
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let subject = subject.as_ref();
    let has_whitespace = subject.bytes().any(|byte| byte.is_ascii_whitespace());
    !(subject.starts_with('.') || subject.ends_with('.') || has_whitespace)
}
// Generates `impl From<$origin> for $t`. An origin error whose `kind()` is
// `<$origin_kind>::TimedOut` becomes `$t::new(<$k>::TimedOut)`, without a
// source. Any other origin error becomes `$t::with_source(<$k>::Other, err)`.
//
// Requirements:
// - `$origin` must expose `kind()` returning `$origin_kind`, which has a
//   `TimedOut` variant.
// - `$k` must have `TimedOut` and `Other` variants.
// - `$t` must provide `new(kind)` and `with_source(kind, source)`.
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                match err.kind() {
                    <$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
                    _ => Self::with_source(<$k>::Other, err),
                }
            }
        }
    };
}
// Re-export so other modules can invoke the macro by path.
pub(crate) use from_with_timeout;
1731
1732use crate::connection::ShouldFlush;
1733
// Unit tests for `ServerAddr` parsing (host extraction, default ports) and the
// `ToServerAddrs` conversions.
#[cfg(test)]
mod tests {
    use super::*;

    // IPv6 hosts are returned without their surrounding brackets.
    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    // The following four tests check that order is preserved for vectors and
    // arrays of `&str` and `String`.
    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    // Default ports per scheme: 4222 for nats/tls, 80 for ws, 443 for wss.
    // Only the first address of each group is checked.
    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            let mut addrs_iter = arr.to_server_addrs().unwrap();
            assert_eq!(addrs_iter.next().unwrap().port(), expected_port);
        }
    }
}