1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_auto_cfg))]
196
197use thiserror::Error;
198
199use futures_util::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::fmt::Display;
207use std::future::Future;
208use std::iter;
209use std::mem;
210use std::net::SocketAddr;
211use std::option;
212use std::pin::Pin;
213use std::slice;
214use std::str::{self, FromStr};
215use std::sync::atomic::AtomicUsize;
216use std::sync::atomic::Ordering;
217use std::sync::Arc;
218use std::task::{Context, Poll};
219use tokio::io::ErrorKind;
220use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
221use url::{Host, Url};
222
223use bytes::Bytes;
224use serde::{Deserialize, Serialize};
225use serde_repr::{Deserialize_repr, Serialize_repr};
226use tokio::io;
227use tokio::sync::mpsc;
228use tokio::task;
229
/// Type-erased, thread-safe error used where a concrete error type is not needed.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Version of this crate, read from Cargo at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Language identifier of this client implementation.
/// NOTE(review): presumably reported to the server in CONNECT — confirm at the use site.
const LANG: &str = "rust";
/// Maximum number of client PINGs that may be outstanding; the connection handler
/// disconnects once a ping tick pushes the unanswered count above this value.
const MAX_PENDING_PINGS: usize = 2;
/// Subscription id reserved for the shared request/response ("multiplexer") subscription.
const MULTIPLEXER_SID: u64 = 0;
236
237pub use tokio_rustls::rustls;
241
242use connection::{Connection, State};
243use connector::{Connector, ConnectorOptions};
244pub use header::{HeaderMap, HeaderName, HeaderValue};
245pub use subject::Subject;
246
247mod auth;
248pub(crate) mod auth_utils;
249pub mod client;
250pub mod connection;
251mod connector;
252mod options;
253
254pub use auth::Auth;
255pub use client::{
256 Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
257};
258pub use options::{AuthError, ConnectOptions};
259
260mod crypto;
261pub mod error;
262pub mod header;
263pub mod jetstream;
264pub mod message;
265#[cfg(feature = "service")]
266pub mod service;
267pub mod status;
268pub mod subject;
269mod tls;
270
271pub use message::Message;
272pub use status::StatusCode;
273
274#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
277pub struct ServerInfo {
278 #[serde(default)]
280 pub server_id: String,
281 #[serde(default)]
283 pub server_name: String,
284 #[serde(default)]
286 pub host: String,
287 #[serde(default)]
289 pub port: u16,
290 #[serde(default)]
292 pub version: String,
293 #[serde(default)]
296 pub auth_required: bool,
297 #[serde(default)]
299 pub tls_required: bool,
300 #[serde(default)]
302 pub max_payload: usize,
303 #[serde(default)]
305 pub proto: i8,
306 #[serde(default)]
308 pub client_id: u64,
309 #[serde(default)]
311 pub go: String,
312 #[serde(default)]
314 pub nonce: String,
315 #[serde(default)]
317 pub connect_urls: Vec<String>,
318 #[serde(default)]
320 pub client_ip: String,
321 #[serde(default)]
323 pub headers: bool,
324 #[serde(default, rename = "ldm")]
326 pub lame_duck_mode: bool,
327 #[serde(default)]
329 pub cluster: Option<String>,
330 #[serde(default)]
332 pub domain: Option<String>,
333 #[serde(default)]
335 pub jetstream: bool,
336}
337
338#[derive(Clone, Debug, Eq, PartialEq)]
339pub(crate) enum ServerOp {
340 Ok,
341 Info(Box<ServerInfo>),
342 Ping,
343 Pong,
344 Error(ServerError),
345 Message {
346 sid: u64,
347 subject: Subject,
348 reply: Option<Subject>,
349 payload: Bytes,
350 headers: Option<HeaderMap>,
351 status: Option<StatusCode>,
352 description: Option<String>,
353 length: usize,
354 },
355}
356
357#[derive(Debug)]
359pub struct PublishMessage {
360 pub subject: Subject,
361 pub payload: Bytes,
362 pub reply: Option<Subject>,
363 pub headers: Option<HeaderMap>,
364}
365
366#[derive(Debug)]
368pub(crate) enum Command {
369 Publish(PublishMessage),
370 Request {
371 subject: Subject,
372 payload: Bytes,
373 respond: Subject,
374 headers: Option<HeaderMap>,
375 sender: oneshot::Sender<Message>,
376 },
377 Subscribe {
378 sid: u64,
379 subject: Subject,
380 queue_group: Option<String>,
381 sender: mpsc::Sender<Message>,
382 },
383 Unsubscribe {
384 sid: u64,
385 max: Option<u64>,
386 },
387 Flush {
388 observer: oneshot::Sender<()>,
389 },
390 Drain {
391 sid: Option<u64>,
392 },
393 Reconnect,
394}
395
396#[derive(Debug)]
398pub(crate) enum ClientOp {
399 Publish {
400 subject: Subject,
401 payload: Bytes,
402 respond: Option<Subject>,
403 headers: Option<HeaderMap>,
404 },
405 Subscribe {
406 sid: u64,
407 subject: Subject,
408 queue_group: Option<String>,
409 },
410 Unsubscribe {
411 sid: u64,
412 max: Option<u64>,
413 },
414 Ping,
415 Pong,
416 Connect(ConnectInfo),
417}
418
419#[derive(Debug)]
420struct Subscription {
421 subject: Subject,
422 sender: mpsc::Sender<Message>,
423 queue_group: Option<String>,
424 delivered: u64,
425 max: Option<u64>,
426 is_draining: bool,
427}
428
429#[derive(Debug)]
430struct Multiplexer {
431 subject: Subject,
432 prefix: Subject,
433 senders: HashMap<String, oneshot::Sender<Message>>,
434}
435
436pub(crate) struct ConnectionHandler {
438 connection: Connection,
439 connector: Connector,
440 subscriptions: HashMap<u64, Subscription>,
441 multiplexer: Option<Multiplexer>,
442 pending_pings: usize,
443 info_sender: tokio::sync::watch::Sender<ServerInfo>,
444 ping_interval: Interval,
445 should_reconnect: bool,
446 flush_observers: Vec<oneshot::Sender<()>>,
447 is_draining: bool,
448}
449
450impl ConnectionHandler {
451 pub(crate) fn new(
452 connection: Connection,
453 connector: Connector,
454 info_sender: tokio::sync::watch::Sender<ServerInfo>,
455 ping_period: Duration,
456 ) -> ConnectionHandler {
457 let mut ping_interval = interval(ping_period);
458 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
459
460 ConnectionHandler {
461 connection,
462 connector,
463 subscriptions: HashMap::new(),
464 multiplexer: None,
465 pending_pings: 0,
466 info_sender,
467 ping_interval,
468 should_reconnect: false,
469 flush_observers: Vec::new(),
470 is_draining: false,
471 }
472 }
473
474 pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
475 struct ProcessFut<'a> {
476 handler: &'a mut ConnectionHandler,
477 receiver: &'a mut mpsc::Receiver<Command>,
478 recv_buf: &'a mut Vec<Command>,
479 }
480
481 enum ExitReason {
482 Disconnected(Option<io::Error>),
483 ReconnectRequested,
484 Closed,
485 }
486
487 impl ProcessFut<'_> {
488 const RECV_CHUNK_SIZE: usize = 16;
489
490 #[cold]
491 fn ping(&mut self) -> Poll<ExitReason> {
492 self.handler.pending_pings += 1;
493
494 if self.handler.pending_pings > MAX_PENDING_PINGS {
495 debug!(
496 pending_pings = self.handler.pending_pings,
497 max_pings = MAX_PENDING_PINGS,
498 "disconnecting due to too many pending pings"
499 );
500
501 Poll::Ready(ExitReason::Disconnected(None))
502 } else {
503 self.handler.connection.enqueue_write_op(&ClientOp::Ping);
504
505 Poll::Pending
506 }
507 }
508 }
509
510 impl Future for ProcessFut<'_> {
511 type Output = ExitReason;
512
513 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
527 while self.handler.ping_interval.poll_tick(cx).is_ready() {
531 if let Poll::Ready(exit) = self.ping() {
532 return Poll::Ready(exit);
533 }
534 }
535
536 loop {
537 match self.handler.connection.poll_read_op(cx) {
538 Poll::Pending => break,
539 Poll::Ready(Ok(Some(server_op))) => {
540 self.handler.handle_server_op(server_op);
541 }
542 Poll::Ready(Ok(None)) => {
543 return Poll::Ready(ExitReason::Disconnected(None))
544 }
545 Poll::Ready(Err(err)) => {
546 return Poll::Ready(ExitReason::Disconnected(Some(err)))
547 }
548 }
549 }
550
551 self.handler.subscriptions.retain(|_, s| !s.is_draining);
556
557 if self.handler.is_draining {
558 return Poll::Ready(ExitReason::Closed);
563 }
564
565 let mut made_progress = true;
571 loop {
572 while !self.handler.connection.is_write_buf_full() {
573 debug_assert!(self.recv_buf.is_empty());
574
575 let Self {
576 recv_buf,
577 handler,
578 receiver,
579 } = &mut *self;
580 match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
581 Poll::Pending => break,
582 Poll::Ready(1..) => {
583 made_progress = true;
584
585 for cmd in recv_buf.drain(..) {
586 handler.handle_command(cmd);
587 }
588 }
589 Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
591 }
592 }
593
594 if !mem::take(&mut made_progress) {
605 break;
606 }
607
608 match self.handler.connection.poll_write(cx) {
609 Poll::Pending => {
610 break;
612 }
613 Poll::Ready(Ok(())) => {
614 continue;
616 }
617 Poll::Ready(Err(err)) => {
618 return Poll::Ready(ExitReason::Disconnected(Some(err)))
619 }
620 }
621 }
622
623 if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
624 self.handler.connection.should_flush(),
625 self.handler.flush_observers.is_empty(),
626 ) {
627 match self.handler.connection.poll_flush(cx) {
628 Poll::Pending => {}
629 Poll::Ready(Ok(())) => {
630 for observer in self.handler.flush_observers.drain(..) {
631 let _ = observer.send(());
632 }
633 }
634 Poll::Ready(Err(err)) => {
635 return Poll::Ready(ExitReason::Disconnected(Some(err)))
636 }
637 }
638 }
639
640 if mem::take(&mut self.handler.should_reconnect) {
641 return Poll::Ready(ExitReason::ReconnectRequested);
642 }
643
644 Poll::Pending
645 }
646 }
647
648 let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
649 loop {
650 let process = ProcessFut {
651 handler: self,
652 receiver,
653 recv_buf: &mut recv_buf,
654 };
655 match process.await {
656 ExitReason::Disconnected(err) => {
657 debug!(error = ?err, "disconnected");
658 if self.handle_disconnect().await.is_err() {
659 break;
660 };
661 debug!("reconnected");
662 }
663 ExitReason::Closed => {
664 self.connector.events_tx.try_send(Event::Closed).ok();
666 break;
667 }
668 ExitReason::ReconnectRequested => {
669 debug!("reconnect requested");
670 self.connection.stream.shutdown().await.ok();
672 if self.handle_disconnect().await.is_err() {
673 break;
674 };
675 }
676 }
677 }
678 }
679
680 fn handle_server_op(&mut self, server_op: ServerOp) {
681 self.ping_interval.reset();
682
683 match server_op {
684 ServerOp::Ping => {
685 debug!("received PING");
686 self.connection.enqueue_write_op(&ClientOp::Pong);
687 }
688 ServerOp::Pong => {
689 debug!("received PONG");
690 self.pending_pings = self.pending_pings.saturating_sub(1);
691 }
692 ServerOp::Error(error) => {
693 debug!("received ERROR: {:?}", error);
694 self.connector
695 .events_tx
696 .try_send(Event::ServerError(error))
697 .ok();
698 }
699 ServerOp::Message {
700 sid,
701 subject,
702 reply,
703 payload,
704 headers,
705 status,
706 description,
707 length,
708 } => {
709 debug!("received MESSAGE: sid={}, subject={}", sid, subject);
710 self.connector
711 .connect_stats
712 .in_messages
713 .add(1, Ordering::Relaxed);
714
715 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
716 let message: Message = Message {
717 subject,
718 reply,
719 payload,
720 headers,
721 status,
722 description,
723 length,
724 };
725
726 match subscription.sender.try_send(message) {
729 Ok(_) => {
730 subscription.delivered += 1;
731 if let Some(max) = subscription.max {
735 if subscription.delivered.ge(&max) {
736 debug!("max messages reached for subscription {}", sid);
737 self.subscriptions.remove(&sid);
738 }
739 }
740 }
741 Err(mpsc::error::TrySendError::Full(_)) => {
742 debug!("slow consumer detected for subscription {}", sid);
743 self.connector
744 .events_tx
745 .try_send(Event::SlowConsumer(sid))
746 .ok();
747 }
748 Err(mpsc::error::TrySendError::Closed(_)) => {
749 debug!("subscription {} channel closed", sid);
750 self.subscriptions.remove(&sid);
751 self.connection
752 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
753 }
754 }
755 } else if sid == MULTIPLEXER_SID {
756 debug!("received message for multiplexer");
757 if let Some(multiplexer) = self.multiplexer.as_mut() {
758 let maybe_token =
759 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
760
761 if let Some(token) = maybe_token {
762 if let Some(sender) = multiplexer.senders.remove(token) {
763 debug!("forwarding message to request with token {}", token);
764 let message = Message {
765 subject,
766 reply,
767 payload,
768 headers,
769 status,
770 description,
771 length,
772 };
773
774 let _ = sender.send(message);
775 }
776 }
777 }
778 }
779 }
780 ServerOp::Info(info) => {
782 debug!("received INFO: server_id={}", info.server_id);
783 if info.lame_duck_mode {
784 debug!("server in lame duck mode");
785 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
786 }
787 }
788
789 _ => {
790 }
792 }
793 }
794
795 fn handle_command(&mut self, command: Command) {
796 self.ping_interval.reset();
797
798 match command {
799 Command::Unsubscribe { sid, max } => {
800 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
801 subscription.max = max;
802 match subscription.max {
803 Some(n) => {
804 if subscription.delivered >= n {
805 self.subscriptions.remove(&sid);
806 }
807 }
808 None => {
809 self.subscriptions.remove(&sid);
810 }
811 }
812
813 self.connection
814 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
815 }
816 }
817 Command::Flush { observer } => {
818 self.flush_observers.push(observer);
819 }
820 Command::Drain { sid } => {
821 let mut drain_sub = |sid: u64, sub: &mut Subscription| {
822 sub.is_draining = true;
823 self.connection
824 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
825 };
826
827 if let Some(sid) = sid {
828 if let Some(sub) = self.subscriptions.get_mut(&sid) {
829 drain_sub(sid, sub);
830 }
831 } else {
832 self.connector.events_tx.try_send(Event::Draining).ok();
834 self.is_draining = true;
835 for (&sid, sub) in self.subscriptions.iter_mut() {
836 drain_sub(sid, sub);
837 }
838 }
839 }
840 Command::Subscribe {
841 sid,
842 subject,
843 queue_group,
844 sender,
845 } => {
846 let subscription = Subscription {
847 sender,
848 delivered: 0,
849 max: None,
850 subject: subject.to_owned(),
851 queue_group: queue_group.to_owned(),
852 is_draining: false,
853 };
854
855 self.subscriptions.insert(sid, subscription);
856
857 self.connection.enqueue_write_op(&ClientOp::Subscribe {
858 sid,
859 subject,
860 queue_group,
861 });
862 }
863 Command::Request {
864 subject,
865 payload,
866 respond,
867 headers,
868 sender,
869 } => {
870 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
871
872 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
873 multiplexer
874 } else {
875 let prefix = Subject::from(format!("{}.{}.", prefix, nuid::next()));
876 let subject = Subject::from(format!("{prefix}*"));
877
878 self.connection.enqueue_write_op(&ClientOp::Subscribe {
879 sid: MULTIPLEXER_SID,
880 subject: subject.clone(),
881 queue_group: None,
882 });
883
884 self.multiplexer.insert(Multiplexer {
885 subject,
886 prefix,
887 senders: HashMap::new(),
888 })
889 };
890 self.connector
891 .connect_stats
892 .out_messages
893 .add(1, Ordering::Relaxed);
894
895 multiplexer.senders.insert(token.to_owned(), sender);
896
897 let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
898
899 let pub_op = ClientOp::Publish {
900 subject,
901 payload,
902 respond: Some(respond),
903 headers,
904 };
905
906 self.connection.enqueue_write_op(&pub_op);
907 }
908
909 Command::Publish(PublishMessage {
910 subject,
911 payload,
912 reply: respond,
913 headers,
914 }) => {
915 self.connector
916 .connect_stats
917 .out_messages
918 .add(1, Ordering::Relaxed);
919
920 let header_len = headers
921 .as_ref()
922 .map(|headers| headers.len())
923 .unwrap_or_default();
924
925 self.connector.connect_stats.out_bytes.add(
926 (payload.len()
927 + respond.as_ref().map_or_else(|| 0, |r| r.len())
928 + subject.len()
929 + header_len) as u64,
930 Ordering::Relaxed,
931 );
932
933 self.connection.enqueue_write_op(&ClientOp::Publish {
934 subject,
935 payload,
936 respond,
937 headers,
938 });
939 }
940
941 Command::Reconnect => {
942 self.should_reconnect = true;
943 }
944 }
945 }
946
947 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
948 self.pending_pings = 0;
949 self.connector.events_tx.try_send(Event::Disconnected).ok();
950 self.connector.state_tx.send(State::Disconnected).ok();
951
952 self.handle_reconnect().await
953 }
954
955 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
956 let (info, connection) = self.connector.connect().await?;
957 self.connection = connection;
958 let _ = self.info_sender.send(info);
959
960 self.subscriptions
961 .retain(|_, subscription| !subscription.sender.is_closed());
962
963 for (sid, subscription) in &self.subscriptions {
964 self.connection.enqueue_write_op(&ClientOp::Subscribe {
965 sid: *sid,
966 subject: subscription.subject.to_owned(),
967 queue_group: subscription.queue_group.to_owned(),
968 });
969 }
970
971 if let Some(multiplexer) = &self.multiplexer {
972 self.connection.enqueue_write_op(&ClientOp::Subscribe {
973 sid: MULTIPLEXER_SID,
974 subject: multiplexer.subject.to_owned(),
975 queue_group: None,
976 });
977 }
978 Ok(())
979 }
980}
981
982pub async fn connect_with_options<A: ToServerAddrs>(
998 addrs: A,
999 options: ConnectOptions,
1000) -> Result<Client, ConnectError> {
1001 let ping_period = options.ping_interval;
1002
1003 let (events_tx, mut events_rx) = mpsc::channel(128);
1004 let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
1005 let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
1007 let statistics = Arc::new(Statistics::default());
1008
1009 let mut connector = Connector::new(
1010 addrs,
1011 ConnectorOptions {
1012 tls_required: options.tls_required,
1013 certificates: options.certificates,
1014 client_key: options.client_key,
1015 client_cert: options.client_cert,
1016 tls_client_config: options.tls_client_config,
1017 tls_first: options.tls_first,
1018 auth: options.auth,
1019 no_echo: options.no_echo,
1020 connection_timeout: options.connection_timeout,
1021 name: options.name,
1022 ignore_discovered_servers: options.ignore_discovered_servers,
1023 retain_servers_order: options.retain_servers_order,
1024 read_buffer_capacity: options.read_buffer_capacity,
1025 reconnect_delay_callback: options.reconnect_delay_callback,
1026 auth_callback: options.auth_callback,
1027 max_reconnects: options.max_reconnects,
1028 },
1029 events_tx,
1030 state_tx,
1031 max_payload.clone(),
1032 statistics.clone(),
1033 )
1034 .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1035
1036 let mut info: ServerInfo = Default::default();
1037 let mut connection = None;
1038 if !options.retry_on_initial_connect {
1039 debug!("retry on initial connect failure is disabled");
1040 let (info_ok, connection_ok) = connector.try_connect().await?;
1041 connection = Some(connection_ok);
1042 info = info_ok;
1043 }
1044
1045 let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1046 let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1047
1048 let client = Client::new(
1049 info_watcher,
1050 state_rx,
1051 sender,
1052 options.subscription_capacity,
1053 options.inbox_prefix,
1054 options.request_timeout,
1055 max_payload,
1056 statistics,
1057 );
1058
1059 task::spawn(async move {
1060 while let Some(event) = events_rx.recv().await {
1061 tracing::info!("event: {}", event);
1062 if let Some(event_callback) = &options.event_callback {
1063 event_callback.call(event).await;
1064 }
1065 }
1066 });
1067
1068 task::spawn(async move {
1069 if connection.is_none() && options.retry_on_initial_connect {
1070 let (info, connection_ok) = match connector.connect().await {
1071 Ok((info, connection)) => (info, connection),
1072 Err(err) => {
1073 error!("connection closed: {}", err);
1074 return;
1075 }
1076 };
1077 info_sender.send(info).ok();
1078 connection = Some(connection_ok);
1079 }
1080 let connection = connection.unwrap();
1081 let mut connection_handler =
1082 ConnectionHandler::new(connection, connector, info_sender, ping_period);
1083 connection_handler.process(&mut receiver).await
1084 });
1085
1086 Ok(client)
1087}
1088
1089#[derive(Debug, Clone, PartialEq, Eq)]
1090pub enum Event {
1091 Connected,
1092 Disconnected,
1093 LameDuckMode,
1094 Draining,
1095 Closed,
1096 SlowConsumer(u64),
1097 ServerError(ServerError),
1098 ClientError(ClientError),
1099}
1100
1101impl fmt::Display for Event {
1102 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1103 match self {
1104 Event::Connected => write!(f, "connected"),
1105 Event::Disconnected => write!(f, "disconnected"),
1106 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1107 Event::Draining => write!(f, "draining"),
1108 Event::Closed => write!(f, "closed"),
1109 Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1110 Event::ServerError(err) => write!(f, "server error: {err}"),
1111 Event::ClientError(err) => write!(f, "client error: {err}"),
1112 }
1113 }
1114}
1115
1116pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1183 connect_with_options(addrs, ConnectOptions::default()).await
1184}
1185
/// Category of a `ConnectError`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// DNS resolution failed.
    Dns,
    /// Signing the server nonce failed.
    Authentication,
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// The attempt timed out.
    TimedOut,
    /// A TLS error occurred.
    Tls,
    /// An I/O error occurred.
    Io,
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
1205
1206impl Display for ConnectErrorKind {
1207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1208 match self {
1209 Self::ServerParse => write!(f, "failed to parse server or server list"),
1210 Self::Dns => write!(f, "DNS error"),
1211 Self::Authentication => write!(f, "failed signing nonce"),
1212 Self::AuthorizationViolation => write!(f, "authorization violation"),
1213 Self::TimedOut => write!(f, "timed out"),
1214 Self::Tls => write!(f, "TLS error"),
1215 Self::Io => write!(f, "IO error"),
1216 Self::MaxReconnects => write!(f, "reached maximum number of reconnects"),
1217 }
1218 }
1219}
1220
1221pub type ConnectError = error::Error<ConnectErrorKind>;
1224
1225impl From<io::Error> for ConnectError {
1226 fn from(err: io::Error) -> Self {
1227 ConnectError::with_source(ConnectErrorKind::Io, err)
1228 }
1229}
1230
1231#[derive(Debug)]
1245pub struct Subscriber {
1246 sid: u64,
1247 receiver: mpsc::Receiver<Message>,
1248 sender: mpsc::Sender<Command>,
1249}
1250
1251impl Subscriber {
1252 fn new(
1253 sid: u64,
1254 sender: mpsc::Sender<Command>,
1255 receiver: mpsc::Receiver<Message>,
1256 ) -> Subscriber {
1257 Subscriber {
1258 sid,
1259 sender,
1260 receiver,
1261 }
1262 }
1263
1264 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1279 self.sender
1280 .send(Command::Unsubscribe {
1281 sid: self.sid,
1282 max: None,
1283 })
1284 .await?;
1285 self.receiver.close();
1286 Ok(())
1287 }
1288
1289 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1315 self.sender
1316 .send(Command::Unsubscribe {
1317 sid: self.sid,
1318 max: Some(unsub_after),
1319 })
1320 .await?;
1321 Ok(())
1322 }
1323
1324 pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1357 self.sender
1358 .send(Command::Drain {
1359 sid: Some(self.sid),
1360 })
1361 .await?;
1362
1363 Ok(())
1364 }
1365}
1366
1367#[derive(Error, Debug, PartialEq)]
1368#[error("failed to send unsubscribe")]
1369pub struct UnsubscribeError(String);
1370
1371impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1372 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1373 UnsubscribeError(err.to_string())
1374 }
1375}
1376
1377impl Drop for Subscriber {
1378 fn drop(&mut self) {
1379 self.receiver.close();
1380 tokio::spawn({
1381 let sender = self.sender.clone();
1382 let sid = self.sid;
1383 async move {
1384 sender
1385 .send(Command::Unsubscribe { sid, max: None })
1386 .await
1387 .ok();
1388 }
1389 });
1390 }
1391}
1392
1393impl Stream for Subscriber {
1394 type Item = Message;
1395
1396 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1397 self.receiver.poll_recv(cx)
1398 }
1399}
1400
1401#[derive(Clone, Debug, Eq, PartialEq)]
1402pub enum CallbackError {
1403 Client(ClientError),
1404 Server(ServerError),
1405}
1406impl std::fmt::Display for CallbackError {
1407 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1408 match self {
1409 Self::Client(error) => write!(f, "{error}"),
1410 Self::Server(error) => write!(f, "{error}"),
1411 }
1412 }
1413}
1414
1415impl From<ServerError> for CallbackError {
1416 fn from(server_error: ServerError) -> Self {
1417 CallbackError::Server(server_error)
1418 }
1419}
1420
1421impl From<ClientError> for CallbackError {
1422 fn from(client_error: ClientError) -> Self {
1423 CallbackError::Client(client_error)
1424 }
1425}
1426
1427#[derive(Clone, Debug, Eq, PartialEq, Error)]
1428pub enum ServerError {
1429 AuthorizationViolation,
1430 SlowConsumer(u64),
1431 Other(String),
1432}
1433
/// Error raised on the client side.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientError {
    /// Free-form client error.
    Other(String),
    /// The maximum number of reconnects was reached.
    MaxReconnects,
}
1439impl std::fmt::Display for ClientError {
1440 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1441 match self {
1442 Self::Other(error) => write!(f, "nats: {error}"),
1443 Self::MaxReconnects => write!(f, "nats: max reconnects reached"),
1444 }
1445 }
1446}
1447
1448impl ServerError {
1449 fn new(error: String) -> ServerError {
1450 match error.to_lowercase().as_str() {
1451 "authorization violation" => ServerError::AuthorizationViolation,
1452 _ => ServerError::Other(error),
1454 }
1455 }
1456}
1457
1458impl std::fmt::Display for ServerError {
1459 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1460 match self {
1461 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1462 Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1463 Self::Other(error) => write!(f, "nats: {error}"),
1464 }
1465 }
1466}
1467
1468#[derive(Clone, Debug, Serialize)]
1470pub struct ConnectInfo {
1471 pub verbose: bool,
1473
1474 pub pedantic: bool,
1477
1478 #[serde(rename = "jwt")]
1480 pub user_jwt: Option<String>,
1481
1482 pub nkey: Option<String>,
1484
1485 #[serde(rename = "sig")]
1487 pub signature: Option<String>,
1488
1489 pub name: Option<String>,
1491
1492 pub echo: bool,
1497
1498 pub lang: String,
1500
1501 pub version: String,
1503
1504 pub protocol: Protocol,
1509
1510 pub tls_required: bool,
1512
1513 pub user: Option<String>,
1515
1516 pub pass: Option<String>,
1518
1519 pub auth_token: Option<String>,
1521
1522 pub headers: bool,
1524
1525 pub no_responders: bool,
1527}
1528
1529#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
1531#[repr(u8)]
1532pub enum Protocol {
1533 Original = 0,
1535 Dynamic = 1,
1537}
1538
1539#[derive(Debug, Clone, PartialEq, Eq, Hash)]
1541pub struct ServerAddr(Url);
1542
1543impl FromStr for ServerAddr {
1544 type Err = io::Error;
1545
1546 fn from_str(input: &str) -> Result<Self, Self::Err> {
1550 let url: Url = if input.contains("://") {
1551 input.parse()
1552 } else {
1553 format!("nats://{input}").parse()
1554 }
1555 .map_err(|e| {
1556 io::Error::new(
1557 ErrorKind::InvalidInput,
1558 format!("NATS server URL is invalid: {e}"),
1559 )
1560 })?;
1561
1562 Self::from_url(url)
1563 }
1564}
1565
1566impl ServerAddr {
1567 pub fn from_url(url: Url) -> io::Result<Self> {
1569 if url.scheme() != "nats"
1570 && url.scheme() != "tls"
1571 && url.scheme() != "ws"
1572 && url.scheme() != "wss"
1573 {
1574 return Err(std::io::Error::new(
1575 ErrorKind::InvalidInput,
1576 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1577 ));
1578 }
1579
1580 Ok(Self(url))
1581 }
1582
1583 pub fn into_inner(self) -> Url {
1585 self.0
1586 }
1587
1588 pub fn tls_required(&self) -> bool {
1590 self.0.scheme() == "tls"
1591 }
1592
1593 pub fn has_user_pass(&self) -> bool {
1595 self.0.username() != ""
1596 }
1597
1598 pub fn scheme(&self) -> &str {
1599 self.0.scheme()
1600 }
1601
1602 pub fn host(&self) -> &str {
1604 match self.0.host() {
1605 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1606 Some(Host::Ipv6 { .. }) => {
1608 let host = self.0.host_str().unwrap();
1609 &host[1..host.len() - 1]
1610 }
1611 None => "",
1612 }
1613 }
1614
1615 pub fn is_websocket(&self) -> bool {
1616 self.0.scheme() == "ws" || self.0.scheme() == "wss"
1617 }
1618
1619 pub fn port(&self) -> u16 {
1622 self.0.port_or_known_default().unwrap_or(4222)
1623 }
1624
1625 pub fn as_url_str(&self) -> &str {
1627 self.0.as_str()
1628 }
1629
1630 pub fn username(&self) -> Option<&str> {
1632 let user = self.0.username();
1633 if user.is_empty() {
1634 None
1635 } else {
1636 Some(user)
1637 }
1638 }
1639
1640 pub fn password(&self) -> Option<&str> {
1642 self.0.password()
1643 }
1644
1645 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1647 tokio::net::lookup_host((self.host(), self.port())).await
1648 }
1649}
1650
1651pub trait ToServerAddrs {
1656 type Iter: Iterator<Item = ServerAddr>;
1659
1660 fn to_server_addrs(&self) -> io::Result<Self::Iter>;
1661}
1662
1663impl ToServerAddrs for ServerAddr {
1664 type Iter = option::IntoIter<ServerAddr>;
1665 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1666 Ok(Some(self.clone()).into_iter())
1667 }
1668}
1669
1670impl ToServerAddrs for str {
1671 type Iter = option::IntoIter<ServerAddr>;
1672 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1673 self.parse::<ServerAddr>()
1674 .map(|addr| Some(addr).into_iter())
1675 }
1676}
1677
1678impl ToServerAddrs for String {
1679 type Iter = option::IntoIter<ServerAddr>;
1680 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1681 (**self).to_server_addrs()
1682 }
1683}
1684
1685impl<T: AsRef<str>> ToServerAddrs for [T] {
1686 type Iter = std::vec::IntoIter<ServerAddr>;
1687 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1688 self.iter()
1689 .map(AsRef::as_ref)
1690 .map(str::parse)
1691 .collect::<io::Result<_>>()
1692 .map(Vec::into_iter)
1693 }
1694}
1695
1696impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1697 type Iter = std::vec::IntoIter<ServerAddr>;
1698 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1699 self.as_slice().to_server_addrs()
1700 }
1701}
1702
1703impl<'a> ToServerAddrs for &'a [ServerAddr] {
1704 type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1705
1706 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1707 Ok(self.iter().cloned())
1708 }
1709}
1710
1711impl ToServerAddrs for Vec<ServerAddr> {
1712 type Iter = std::vec::IntoIter<ServerAddr>;
1713
1714 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1715 Ok(self.clone().into_iter())
1716 }
1717}
1718
1719impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1720 type Iter = T::Iter;
1721 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1722 (**self).to_server_addrs()
1723 }
1724}
1725
/// Returns `true` if `subject` is usable as a NATS subject.
///
/// A valid subject is non-empty, neither starts nor ends with `.`, and contains no ASCII
/// whitespace. The empty check is explicit because the leading/trailing checks pass
/// vacuously for `""`, which would otherwise be accepted.
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let subject_str = subject.as_ref();
    !subject_str.is_empty()
        && !subject_str.starts_with('.')
        && !subject_str.ends_with('.')
        && subject_str.bytes().all(|c| !c.is_ascii_whitespace())
}
/// Implements `From<$origin>` for the error type `$t`, whose kind enum is `$k`.
///
/// An origin error whose `kind()` is `$origin_kind::TimedOut` becomes `$k::TimedOut` with no
/// source. Any other origin error becomes `$k::Other` and is kept as the source.
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                if let <$origin_kind>::TimedOut = err.kind() {
                    Self::new(<$k>::TimedOut)
                } else {
                    Self::with_source(<$k>::Other, err)
                }
            }
        }
    };
}
pub(crate) use from_with_timeout;
1745
1746use crate::connection::ShouldFlush;
1747
/// Unit tests for `ServerAddr` parsing and the `ToServerAddrs` conversions.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    #[test]
    fn server_address_without_scheme_defaults_to_nats() {
        let address = ServerAddr::from_str("127.0.0.1:4333").unwrap();
        assert_eq!(address.scheme(), "nats");
        assert_eq!(address.host(), "127.0.0.1");
        assert_eq!(address.port(), 4333);
        assert!(!address.tls_required());
        assert!(!address.is_websocket());
    }

    #[test]
    fn server_address_rejects_unsupported_scheme() {
        let err = ServerAddr::from_str("http://127.0.0.1").unwrap_err();
        assert_eq!(err.kind(), ErrorKind::InvalidInput);
    }

    #[test]
    fn server_address_credentials() {
        let address = ServerAddr::from_str("nats://derek:s3cr3t@127.0.0.1").unwrap();
        assert!(address.has_user_pass());
        assert_eq!(address.username(), Some("derek"));
        assert_eq!(address.password(), Some("s3cr3t"));

        let anonymous = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert!(!anonymous.has_user_pass());
        assert_eq!(anonymous.username(), None);
        assert_eq!(anonymous.password(), None);
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            // Check every address, not just the first one of each group.
            let ports: Vec<u16> = arr
                .to_server_addrs()
                .unwrap()
                .map(|addr| addr.port())
                .collect();
            assert_eq!(ports, vec![expected_port; 4]);
        }
    }
}