1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_auto_cfg))]
196
197use thiserror::Error;
198
199use futures::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::fmt::Display;
207use std::future::Future;
208use std::iter;
209use std::mem;
210use std::net::SocketAddr;
211use std::option;
212use std::pin::Pin;
213use std::slice;
214use std::str::{self, FromStr};
215use std::sync::atomic::AtomicUsize;
216use std::sync::atomic::Ordering;
217use std::sync::Arc;
218use std::task::{Context, Poll};
219use tokio::io::ErrorKind;
220use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
221use url::{Host, Url};
222
223use bytes::Bytes;
224use serde::{Deserialize, Serialize};
225use serde_repr::{Deserialize_repr, Serialize_repr};
226use tokio::io;
227use tokio::sync::mpsc;
228use tokio::task;
229
/// Boxed catch-all error type: any `std::error::Error` that is `Send + Sync + 'static`.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

// Crate version, captured from `CARGO_PKG_VERSION` at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
// Client language identifier (presumably reported to the server in CONNECT; verify in connector).
const LANG: &str = "rust";
// Ping timer ticks tolerated without a PONG; once the pending count exceeds this value the
// handler disconnects instead of sending another PING.
const MAX_PENDING_PINGS: usize = 2;
// Subscription id reserved for the shared request/reply multiplexer subscription.
const MULTIPLEXER_SID: u64 = 0;
236
237pub use tokio_rustls::rustls;
241
242use connection::{Connection, State};
243use connector::{Connector, ConnectorOptions};
244pub use header::{HeaderMap, HeaderName, HeaderValue};
245pub use subject::Subject;
246
247mod auth;
248pub(crate) mod auth_utils;
249pub mod client;
250pub mod connection;
251mod connector;
252pub use connector::Dialer;
254mod options;
256
257pub use auth::Auth;
258pub use client::{
259 Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
260};
261pub use options::{AuthError, ConnectOptions};
262
263mod crypto;
264pub mod error;
265pub mod header;
266pub mod jetstream;
267pub mod message;
268#[cfg(feature = "service")]
269pub mod service;
270pub mod status;
271pub mod subject;
272mod tls;
273
274pub use message::Message;
275pub use status::StatusCode;
276
/// Server information received in the server's `INFO` message.
///
/// Every field is `#[serde(default)]`, so keys missing from the payload fall back to the
/// field type's `Default` value.
#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
pub struct ServerInfo {
    /// Unique identifier of the server.
    #[serde(default)]
    pub server_id: String,
    /// Name of the server.
    #[serde(default)]
    pub server_name: String,
    /// Host the server advertises.
    #[serde(default)]
    pub host: String,
    /// Port the server advertises.
    #[serde(default)]
    pub port: u16,
    /// Server version string.
    #[serde(default)]
    pub version: String,
    /// Whether the server requires authentication.
    #[serde(default)]
    pub auth_required: bool,
    /// Whether the server requires TLS.
    #[serde(default)]
    pub tls_required: bool,
    /// Maximum payload size accepted by the server.
    #[serde(default)]
    pub max_payload: usize,
    /// Protocol version reported by the server.
    #[serde(default)]
    pub proto: i8,
    /// Identifier the server assigned to this connection.
    #[serde(default)]
    pub client_id: u64,
    /// Go version the server was built with.
    #[serde(default)]
    pub go: String,
    /// Nonce for the client to sign during authentication (empty when absent).
    #[serde(default)]
    pub nonce: String,
    /// Additional server URLs advertised by the server.
    #[serde(default)]
    pub connect_urls: Vec<String>,
    /// Client IP address as seen by the server.
    #[serde(default)]
    pub client_ip: String,
    /// Whether the server supports message headers.
    #[serde(default)]
    pub headers: bool,
    /// Lame duck mode flag, read from the `ldm` key; when set, the connection handler emits
    /// `Event::LameDuckMode`.
    #[serde(default, rename = "ldm")]
    pub lame_duck_mode: bool,
}
331
/// A protocol operation received from the server, as handled by
/// `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// Server acknowledgement; ignored by the handler.
    Ok,
    /// Updated server information.
    Info(Box<ServerInfo>),
    /// Server PING; the handler answers with a PONG.
    Ping,
    /// Server PONG; decrements the pending-ping counter.
    Pong,
    /// Error reported by the server; forwarded as `Event::ServerError`.
    Error(ServerError),
    /// A message delivered to subscription `sid` (or to the request multiplexer).
    Message {
        sid: u64,
        subject: Subject,
        reply: Option<Subject>,
        payload: Bytes,
        headers: Option<HeaderMap>,
        status: Option<StatusCode>,
        description: Option<String>,
        // Copied into `Message::length`; presumably the message size as parsed — TODO confirm.
        length: usize,
    },
}
350
/// A message to publish, carried to the connection handler by `Command::Publish`.
#[derive(Debug)]
pub struct PublishMessage {
    /// Subject to publish on.
    pub subject: Subject,
    /// Message body.
    pub payload: Bytes,
    /// Optional reply subject.
    pub reply: Option<Subject>,
    /// Optional message headers.
    pub headers: Option<HeaderMap>,
}
359
/// Instruction sent over the command channel to the `ConnectionHandler` task
/// (by the `Client` and by `Subscriber`s).
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a message.
    Publish(PublishMessage),
    /// Publish a request whose reply is routed to `sender`. The last `.`-separated token of
    /// `respond` identifies the request inside the reply multiplexer.
    Request {
        subject: Subject,
        payload: Bytes,
        respond: Subject,
        headers: Option<HeaderMap>,
        sender: oneshot::Sender<Message>,
    },
    /// Register subscription `sid`; matching messages are forwarded to `sender`.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
        sender: mpsc::Sender<Message>,
    },
    /// Unsubscribe `sid` immediately (`max: None`) or once `max` messages have been delivered.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Notify `observer` after the connection has been flushed.
    Flush {
        observer: oneshot::Sender<()>,
    },
    /// Drain a single subscription (`Some(sid)`) or the whole client (`None`).
    Drain {
        sid: Option<u64>,
    },
    /// Drop the current connection and establish a new one.
    Reconnect,
}
389
/// A protocol operation queued for the server via `Connection::enqueue_write_op`.
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// Publish a message, optionally with a reply subject and headers.
    Publish {
        subject: Subject,
        payload: Bytes,
        respond: Option<Subject>,
        headers: Option<HeaderMap>,
    },
    /// Subscribe `sid` to `subject`, optionally within a queue group.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
    },
    /// Unsubscribe `sid`, either immediately or after `max` messages.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Keep-alive PING.
    Ping,
    /// Answer to a server PING.
    Pong,
    /// Connection handshake payload.
    Connect(ConnectInfo),
}
412
/// Handler-side state of one subscription.
#[derive(Debug)]
struct Subscription {
    // Subject subscribed to; re-sent on reconnect.
    subject: Subject,
    // Channel feeding the `Subscriber` stream.
    sender: mpsc::Sender<Message>,
    // Optional queue group; re-sent on reconnect.
    queue_group: Option<String>,
    // Number of messages successfully handed to `sender`.
    delivered: u64,
    // Auto-unsubscribe limit compared against `delivered`, if any.
    max: Option<u64>,
    // Set by a drain command; the subscription is removed on the handler's next poll.
    is_draining: bool,
}

/// Routes all request replies through one wildcard subscription (`MULTIPLEXER_SID`).
#[derive(Debug)]
struct Multiplexer {
    // Wildcard subject that is subscribed to (`<prefix>*`).
    subject: Subject,
    // Reply-subject prefix; a reply arrives on this prefix followed by the request token.
    prefix: Subject,
    // Waiting requesters, keyed by request token.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
429
/// Owns the server connection and all subscription state. It runs in the task spawned by
/// `connect_with_options` and is driven by `process`.
pub(crate) struct ConnectionHandler {
    // Current server connection; replaced on every reconnect.
    connection: Connection,
    // Establishes connections and holds the event/state channels and statistics.
    connector: Connector,
    // Active subscriptions keyed by sid.
    subscriptions: HashMap<u64, Subscription>,
    // Request/reply multiplexer, created lazily by the first request.
    multiplexer: Option<Multiplexer>,
    // PINGs sent without a matching PONG yet.
    pending_pings: usize,
    // Publishes the latest `ServerInfo` after (re)connecting.
    info_sender: tokio::sync::watch::Sender<ServerInfo>,
    // Keep-alive timer; reset by any inbound or outbound activity.
    ping_interval: Interval,
    // Set by `Command::Reconnect`, consumed at the end of a poll.
    should_reconnect: bool,
    // Callers waiting for the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
    // Set when the whole client is draining; the handler then exits as `Closed`.
    is_draining: bool,
}
443
444impl ConnectionHandler {
445 pub(crate) fn new(
446 connection: Connection,
447 connector: Connector,
448 info_sender: tokio::sync::watch::Sender<ServerInfo>,
449 ping_period: Duration,
450 ) -> ConnectionHandler {
451 let mut ping_interval = interval(ping_period);
452 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
453
454 ConnectionHandler {
455 connection,
456 connector,
457 subscriptions: HashMap::new(),
458 multiplexer: None,
459 pending_pings: 0,
460 info_sender,
461 ping_interval,
462 should_reconnect: false,
463 flush_observers: Vec::new(),
464 is_draining: false,
465 }
466 }
467
    /// Drives the connection until the client is closed or reconnecting fails.
    ///
    /// Each loop iteration awaits a `ProcessFut` over the current connection.
    /// - `Disconnected` and `ReconnectRequested` both go through `handle_disconnect`, which
    ///   establishes a new connection.
    /// - The loop ends on `Closed`, after emitting `Event::Closed`, or when reconnecting returns
    ///   an error.
    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
        // Hand-written future that, within a single `poll`, services the ping timer, reads server
        // operations, takes commands, writes and flushes for one connection.
        struct ProcessFut<'a> {
            handler: &'a mut ConnectionHandler,
            receiver: &'a mut mpsc::Receiver<Command>,
            // Scratch buffer, reused across polls, for batches of received commands.
            recv_buf: &'a mut Vec<Command>,
        }

        // Why a `ProcessFut` finished.
        enum ExitReason {
            // Connection lost: EOF, an I/O error, or too many unanswered PINGs.
            Disconnected(Option<io::Error>),
            // `Command::Reconnect` was handled.
            ReconnectRequested,
            // The command channel closed, or the whole client finished draining.
            Closed,
        }

        impl ProcessFut<'_> {
            // Maximum number of commands taken per `poll_recv_many` call.
            const RECV_CHUNK_SIZE: usize = 16;

            // Called for every elapsed ping tick. It counts the tick as a pending ping first.
            // - If the count now exceeds `MAX_PENDING_PINGS`, it returns `Ready(Disconnected)`.
            // - Otherwise it queues a PING and returns `Pending`, meaning "keep going".
            #[cold]
            fn ping(&mut self) -> Poll<ExitReason> {
                self.handler.pending_pings += 1;

                if self.handler.pending_pings > MAX_PENDING_PINGS {
                    debug!(
                        "pending pings {}, max pings {}. disconnecting",
                        self.handler.pending_pings, MAX_PENDING_PINGS
                    );

                    Poll::Ready(ExitReason::Disconnected(None))
                } else {
                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);

                    Poll::Pending
                }
            }
        }

        impl Future for ProcessFut<'_> {
            type Output = ExitReason;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // 1. Keep-alive: handle every elapsed tick. The final `Pending` poll leaves the
                //    timer's waker registered.
                while self.handler.ping_interval.poll_tick(cx).is_ready() {
                    if let Poll::Ready(exit) = self.ping() {
                        return Poll::Ready(exit);
                    }
                }

                // 2. Handle every server operation that is already available.
                loop {
                    match self.handler.connection.poll_read_op(cx) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(Some(server_op))) => {
                            self.handler.handle_server_op(server_op);
                        }
                        Poll::Ready(Ok(None)) => {
                            return Poll::Ready(ExitReason::Disconnected(None))
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // 3. Drop subscriptions marked as draining. Their UNSUB was queued when the drain
                //    command was handled on an earlier poll, and the messages read in step 2 have
                //    already been delivered. A whole-client drain ends the handler here.
                self.handler.subscriptions.retain(|_, s| !s.is_draining);

                if self.handler.is_draining {
                    return Poll::Ready(ExitReason::Closed);
                }

                // 4. Alternate between taking commands (while the write buffer has room) and
                //    writing the buffer out. `made_progress` starts as `true`, so the writer is
                //    polled at least once. After that the cycle repeats only while writes complete
                //    and new commands keep arriving.
                let mut made_progress = true;
                loop {
                    while !self.handler.connection.is_write_buf_full() {
                        debug_assert!(self.recv_buf.is_empty());

                        // Split the borrow so the receiver, buffer and handler can be used together.
                        let Self {
                            recv_buf,
                            handler,
                            receiver,
                        } = &mut *self;
                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
                            Poll::Pending => break,
                            Poll::Ready(1..) => {
                                made_progress = true;

                                for cmd in recv_buf.drain(..) {
                                    handler.handle_command(cmd);
                                }
                            }
                            // `0`: the channel is closed, i.e. every command sender is gone.
                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
                        }
                    }

                    if !mem::take(&mut made_progress) {
                        break;
                    }

                    match self.handler.connection.poll_write(cx) {
                        Poll::Pending => {
                            break;
                        }
                        Poll::Ready(Ok(())) => {
                            continue;
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // 5. Flush when the connection reports `ShouldFlush::Yes`, or when it reports
                //    `ShouldFlush::No` while `Command::Flush` callers are waiting. A successful
                //    flush notifies and clears all observers.
                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
                    self.handler.connection.should_flush(),
                    self.handler.flush_observers.is_empty(),
                ) {
                    match self.handler.connection.poll_flush(cx) {
                        Poll::Pending => {}
                        Poll::Ready(Ok(())) => {
                            for observer in self.handler.flush_observers.drain(..) {
                                let _ = observer.send(());
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // 6. A requested reconnect takes effect only after the work above.
                if mem::take(&mut self.handler.should_reconnect) {
                    return Poll::Ready(ExitReason::ReconnectRequested);
                }

                Poll::Pending
            }
        }

        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
        loop {
            let process = ProcessFut {
                handler: self,
                receiver,
                recv_buf: &mut recv_buf,
            };
            match process.await {
                ExitReason::Disconnected(err) => {
                    debug!(?err, "disconnected");
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                    debug!("reconnected");
                }
                ExitReason::Closed => {
                    // Best effort: the event is dropped if the events channel is full.
                    self.connector.events_tx.try_send(Event::Closed).ok();
                    break;
                }
                ExitReason::ReconnectRequested => {
                    debug!("reconnect requested");
                    // The old stream is being replaced, so shutdown errors are ignored.
                    self.connection.stream.shutdown().await.ok();
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                }
            }
        }
    }
672
673 fn handle_server_op(&mut self, server_op: ServerOp) {
674 self.ping_interval.reset();
675
676 match server_op {
677 ServerOp::Ping => {
678 self.connection.enqueue_write_op(&ClientOp::Pong);
679 }
680 ServerOp::Pong => {
681 debug!("received PONG");
682 self.pending_pings = self.pending_pings.saturating_sub(1);
683 }
684 ServerOp::Error(error) => {
685 self.connector
686 .events_tx
687 .try_send(Event::ServerError(error))
688 .ok();
689 }
690 ServerOp::Message {
691 sid,
692 subject,
693 reply,
694 payload,
695 headers,
696 status,
697 description,
698 length,
699 } => {
700 self.connector
701 .connect_stats
702 .in_messages
703 .add(1, Ordering::Relaxed);
704
705 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
706 let message: Message = Message {
707 subject,
708 reply,
709 payload,
710 headers,
711 status,
712 description,
713 length,
714 };
715
716 match subscription.sender.try_send(message) {
719 Ok(_) => {
720 subscription.delivered += 1;
721 if let Some(max) = subscription.max {
725 if subscription.delivered.ge(&max) {
726 self.subscriptions.remove(&sid);
727 }
728 }
729 }
730 Err(mpsc::error::TrySendError::Full(_)) => {
731 self.connector
732 .events_tx
733 .try_send(Event::SlowConsumer(sid))
734 .ok();
735 }
736 Err(mpsc::error::TrySendError::Closed(_)) => {
737 self.subscriptions.remove(&sid);
738 self.connection
739 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
740 }
741 }
742 } else if sid == MULTIPLEXER_SID {
743 if let Some(multiplexer) = self.multiplexer.as_mut() {
744 let maybe_token =
745 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
746
747 if let Some(token) = maybe_token {
748 if let Some(sender) = multiplexer.senders.remove(token) {
749 let message = Message {
750 subject,
751 reply,
752 payload,
753 headers,
754 status,
755 description,
756 length,
757 };
758
759 let _ = sender.send(message);
760 }
761 }
762 }
763 }
764 }
765 ServerOp::Info(info) => {
767 if info.lame_duck_mode {
768 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
769 }
770 }
771
772 _ => {
773 }
775 }
776 }
777
778 fn handle_command(&mut self, command: Command) {
779 self.ping_interval.reset();
780
781 match command {
782 Command::Unsubscribe { sid, max } => {
783 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
784 subscription.max = max;
785 match subscription.max {
786 Some(n) => {
787 if subscription.delivered >= n {
788 self.subscriptions.remove(&sid);
789 }
790 }
791 None => {
792 self.subscriptions.remove(&sid);
793 }
794 }
795
796 self.connection
797 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
798 }
799 }
800 Command::Flush { observer } => {
801 self.flush_observers.push(observer);
802 }
803 Command::Drain { sid } => {
804 let mut drain_sub = |sid: u64, sub: &mut Subscription| {
805 sub.is_draining = true;
806 self.connection
807 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
808 };
809
810 if let Some(sid) = sid {
811 if let Some(sub) = self.subscriptions.get_mut(&sid) {
812 drain_sub(sid, sub);
813 }
814 } else {
815 self.connector.events_tx.try_send(Event::Draining).ok();
817 self.is_draining = true;
818 for (&sid, sub) in self.subscriptions.iter_mut() {
819 drain_sub(sid, sub);
820 }
821 }
822 }
823 Command::Subscribe {
824 sid,
825 subject,
826 queue_group,
827 sender,
828 } => {
829 let subscription = Subscription {
830 sender,
831 delivered: 0,
832 max: None,
833 subject: subject.to_owned(),
834 queue_group: queue_group.to_owned(),
835 is_draining: false,
836 };
837
838 self.subscriptions.insert(sid, subscription);
839
840 self.connection.enqueue_write_op(&ClientOp::Subscribe {
841 sid,
842 subject,
843 queue_group,
844 });
845 }
846 Command::Request {
847 subject,
848 payload,
849 respond,
850 headers,
851 sender,
852 } => {
853 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
854
855 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
856 multiplexer
857 } else {
858 let prefix = Subject::from(format!("{}.{}.", prefix, nuid::next()));
859 let subject = Subject::from(format!("{}*", prefix));
860
861 self.connection.enqueue_write_op(&ClientOp::Subscribe {
862 sid: MULTIPLEXER_SID,
863 subject: subject.clone(),
864 queue_group: None,
865 });
866
867 self.multiplexer.insert(Multiplexer {
868 subject,
869 prefix,
870 senders: HashMap::new(),
871 })
872 };
873 self.connector
874 .connect_stats
875 .out_messages
876 .add(1, Ordering::Relaxed);
877
878 multiplexer.senders.insert(token.to_owned(), sender);
879
880 let respond: Subject = format!("{}{}", multiplexer.prefix, token).into();
881
882 let pub_op = ClientOp::Publish {
883 subject,
884 payload,
885 respond: Some(respond),
886 headers,
887 };
888
889 self.connection.enqueue_write_op(&pub_op);
890 }
891
892 Command::Publish(PublishMessage {
893 subject,
894 payload,
895 reply: respond,
896 headers,
897 }) => {
898 self.connector
899 .connect_stats
900 .out_messages
901 .add(1, Ordering::Relaxed);
902
903 let header_len = headers
904 .as_ref()
905 .map(|headers| headers.len())
906 .unwrap_or_default();
907
908 self.connector.connect_stats.out_bytes.add(
909 (payload.len()
910 + respond.as_ref().map_or_else(|| 0, |r| r.len())
911 + subject.len()
912 + header_len) as u64,
913 Ordering::Relaxed,
914 );
915
916 self.connection.enqueue_write_op(&ClientOp::Publish {
917 subject,
918 payload,
919 respond,
920 headers,
921 });
922 }
923
924 Command::Reconnect => {
925 self.should_reconnect = true;
926 }
927 }
928 }
929
930 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
931 self.pending_pings = 0;
932 self.connector.events_tx.try_send(Event::Disconnected).ok();
933 self.connector.state_tx.send(State::Disconnected).ok();
934
935 self.handle_reconnect().await
936 }
937
938 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
939 let (info, connection) = self.connector.connect().await?;
940 self.connection = connection;
941 let _ = self.info_sender.send(info);
942
943 self.subscriptions
944 .retain(|_, subscription| !subscription.sender.is_closed());
945
946 for (sid, subscription) in &self.subscriptions {
947 self.connection.enqueue_write_op(&ClientOp::Subscribe {
948 sid: *sid,
949 subject: subscription.subject.to_owned(),
950 queue_group: subscription.queue_group.to_owned(),
951 });
952 }
953
954 if let Some(multiplexer) = &self.multiplexer {
955 self.connection.enqueue_write_op(&ClientOp::Subscribe {
956 sid: MULTIPLEXER_SID,
957 subject: multiplexer.subject.to_owned(),
958 queue_group: None,
959 });
960 }
961 Ok(())
962 }
963}
964
965pub async fn connect_with_options<A: ToServerAddrs>(
981 addrs: A,
982 options: ConnectOptions,
983) -> Result<Client, ConnectError> {
984 let ping_period = options.ping_interval;
985
986 let (events_tx, mut events_rx) = mpsc::channel(128);
987 let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
988 let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
990 let statistics = Arc::new(Statistics::default());
991
992 let mut connector = Connector::new(
993 addrs,
994 ConnectorOptions {
995 tls_required: options.tls_required,
996 certificates: options.certificates,
997 client_key: options.client_key,
998 client_cert: options.client_cert,
999 tls_client_config: options.tls_client_config,
1000 tls_first: options.tls_first,
1001 auth: options.auth,
1002 no_echo: options.no_echo,
1003 connection_timeout: options.connection_timeout,
1004 name: options.name,
1005 ignore_discovered_servers: options.ignore_discovered_servers,
1006 retain_servers_order: options.retain_servers_order,
1007 read_buffer_capacity: options.read_buffer_capacity,
1008 reconnect_delay_callback: options.reconnect_delay_callback,
1009 auth_callback: options.auth_callback,
1010 max_reconnects: options.max_reconnects,
1011 dialer: options.dialer,
1013 },
1015 events_tx,
1016 state_tx,
1017 max_payload.clone(),
1018 statistics.clone(),
1019 )
1020 .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
1021
1022 let mut info: ServerInfo = Default::default();
1023 let mut connection = None;
1024 if !options.retry_on_initial_connect {
1025 debug!("retry on initial connect failure is disabled");
1026 let (info_ok, connection_ok) = connector.try_connect().await?;
1027 connection = Some(connection_ok);
1028 info = info_ok;
1029 }
1030
1031 let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
1032 let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
1033
1034 let client = Client::new(
1035 info_watcher,
1036 state_rx,
1037 sender,
1038 options.subscription_capacity,
1039 options.inbox_prefix,
1040 options.request_timeout,
1041 max_payload,
1042 statistics,
1043 );
1044
1045 task::spawn(async move {
1046 while let Some(event) = events_rx.recv().await {
1047 tracing::info!("event: {}", event);
1048 if let Some(event_callback) = &options.event_callback {
1049 event_callback.call(event).await;
1050 }
1051 }
1052 });
1053
1054 task::spawn(async move {
1055 if connection.is_none() && options.retry_on_initial_connect {
1056 let (info, connection_ok) = match connector.connect().await {
1057 Ok((info, connection)) => (info, connection),
1058 Err(err) => {
1059 error!("connection closed: {}", err);
1060 return;
1061 }
1062 };
1063 info_sender.send(info).ok();
1064 connection = Some(connection_ok);
1065 }
1066 let connection = connection.unwrap();
1067 let mut connection_handler =
1068 ConnectionHandler::new(connection, connector, info_sender, ping_period);
1069 connection_handler.process(&mut receiver).await
1070 });
1071
1072 Ok(client)
1073}
1074
1075#[derive(Debug, Clone, PartialEq, Eq)]
1076pub enum Event {
1077 Connected,
1078 Disconnected,
1079 LameDuckMode,
1080 Draining,
1081 Closed,
1082 SlowConsumer(u64),
1083 ServerError(ServerError),
1084 ClientError(ClientError),
1085}
1086
1087impl fmt::Display for Event {
1088 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1089 match self {
1090 Event::Connected => write!(f, "connected"),
1091 Event::Disconnected => write!(f, "disconnected"),
1092 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1093 Event::Draining => write!(f, "draining"),
1094 Event::Closed => write!(f, "closed"),
1095 Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1096 Event::ServerError(err) => write!(f, "server error: {err}"),
1097 Event::ClientError(err) => write!(f, "client error: {err}"),
1098 }
1099 }
1100}
1101
1102pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1169 connect_with_options(addrs, ConnectOptions::default()).await
1170}
1171
/// Reason a connection attempt failed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// Host name resolution failed.
    Dns,
    /// Authentication failed while signing the server nonce.
    Authentication,
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// The attempt timed out.
    TimedOut,
    /// Setting up TLS failed.
    Tls,
    /// An I/O error occurred.
    Io,
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}

impl Display for ConnectErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
1206
1207pub type ConnectError = error::Error<ConnectErrorKind>;
1210
1211impl From<io::Error> for ConnectError {
1212 fn from(err: io::Error) -> Self {
1213 ConnectError::with_source(ConnectErrorKind::Io, err)
1214 }
1215}
1216
1217#[derive(Debug)]
1231pub struct Subscriber {
1232 sid: u64,
1233 receiver: mpsc::Receiver<Message>,
1234 sender: mpsc::Sender<Command>,
1235}
1236
1237impl Subscriber {
1238 fn new(
1239 sid: u64,
1240 sender: mpsc::Sender<Command>,
1241 receiver: mpsc::Receiver<Message>,
1242 ) -> Subscriber {
1243 Subscriber {
1244 sid,
1245 sender,
1246 receiver,
1247 }
1248 }
1249
1250 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1265 self.sender
1266 .send(Command::Unsubscribe {
1267 sid: self.sid,
1268 max: None,
1269 })
1270 .await?;
1271 self.receiver.close();
1272 Ok(())
1273 }
1274
1275 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1301 self.sender
1302 .send(Command::Unsubscribe {
1303 sid: self.sid,
1304 max: Some(unsub_after),
1305 })
1306 .await?;
1307 Ok(())
1308 }
1309
1310 pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
1343 self.sender
1344 .send(Command::Drain {
1345 sid: Some(self.sid),
1346 })
1347 .await?;
1348
1349 Ok(())
1350 }
1351}
1352
1353#[derive(Error, Debug, PartialEq)]
1354#[error("failed to send unsubscribe")]
1355pub struct UnsubscribeError(String);
1356
1357impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1358 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1359 UnsubscribeError(err.to_string())
1360 }
1361}
1362
1363impl Drop for Subscriber {
1364 fn drop(&mut self) {
1365 self.receiver.close();
1366 tokio::spawn({
1367 let sender = self.sender.clone();
1368 let sid = self.sid;
1369 async move {
1370 sender
1371 .send(Command::Unsubscribe { sid, max: None })
1372 .await
1373 .ok();
1374 }
1375 });
1376 }
1377}
1378
1379impl Stream for Subscriber {
1380 type Item = Message;
1381
1382 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1383 self.receiver.poll_recv(cx)
1384 }
1385}
1386
1387#[derive(Clone, Debug, Eq, PartialEq)]
1388pub enum CallbackError {
1389 Client(ClientError),
1390 Server(ServerError),
1391}
1392impl std::fmt::Display for CallbackError {
1393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1394 match self {
1395 Self::Client(error) => write!(f, "{error}"),
1396 Self::Server(error) => write!(f, "{error}"),
1397 }
1398 }
1399}
1400
1401impl From<ServerError> for CallbackError {
1402 fn from(server_error: ServerError) -> Self {
1403 CallbackError::Server(server_error)
1404 }
1405}
1406
1407impl From<ClientError> for CallbackError {
1408 fn from(client_error: ClientError) -> Self {
1409 CallbackError::Client(client_error)
1410 }
1411}
1412
/// Error reported by the server (see `ServerOp::Error`).
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// The subscription with this id is a slow consumer.
    SlowConsumer(u64),
    /// Any other error; holds the server's message verbatim.
    Other(String),
}
1419
/// Error raised on the client side.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Free-form client error.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}
impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("nats: ")?;
        match self {
            Self::Other(error) => f.write_str(error),
            Self::MaxReconnects => f.write_str("max reconnects reached"),
        }
    }
}
1433
1434impl ServerError {
1435 fn new(error: String) -> ServerError {
1436 match error.to_lowercase().as_str() {
1437 "authorization violation" => ServerError::AuthorizationViolation,
1438 _ => ServerError::Other(error),
1440 }
1441 }
1442}
1443
1444impl std::fmt::Display for ServerError {
1445 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1446 match self {
1447 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1448 Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1449 Self::Other(error) => write!(f, "nats: {error}"),
1450 }
1451 }
1452}
1453
/// Payload of the connection handshake (`ClientOp::Connect`), serialized with serde.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Ask the server to acknowledge each protocol message (NATS `verbose`).
    pub verbose: bool,

    /// Ask the server for stricter protocol checking (NATS `pedantic`).
    pub pedantic: bool,

    /// User JWT, serialized under the key `jwt`.
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public NKey used for authentication.
    pub nkey: Option<String>,

    /// Signature of the server nonce, serialized under the key `sig`.
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// Whether the server may deliver this connection's own publishes back to it.
    pub echo: bool,

    /// Client implementation language.
    pub lang: String,

    /// Client library version.
    pub version: String,

    /// Protocol level requested by the client.
    pub protocol: Protocol,

    /// Whether the client requires TLS.
    pub tls_required: bool,

    /// User name for user/password authentication.
    pub user: Option<String>,

    /// Password for user/password authentication.
    pub pass: Option<String>,

    /// Token for token authentication.
    pub auth_token: Option<String>,

    /// Whether the client supports message headers.
    pub headers: bool,

    /// Whether the client wants "no responders" status replies for requests nobody handles.
    pub no_responders: bool,
}
1514
/// Protocol level sent in `ConnectInfo::protocol`, serialized as its integer value.
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Level `0`: the original protocol.
    Original = 0,
    /// Level `1`: per the NATS protocol docs, lets the server push asynchronous INFO updates
    /// (e.g. cluster topology changes).
    Dynamic = 1,
}
1524
/// Address of a NATS server, stored as a `Url` whose scheme has been validated
/// (`nats`, `tls`, `ws`, `wss` or `ipc`).
///
/// When parsed from a string without a scheme, `nats://` is assumed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1528
1529impl FromStr for ServerAddr {
1530 type Err = io::Error;
1531
1532 fn from_str(input: &str) -> Result<Self, Self::Err> {
1536 let url: Url = if input.contains("://") {
1537 input.parse()
1538 } else {
1539 format!("nats://{input}").parse()
1540 }
1541 .map_err(|e| {
1542 io::Error::new(
1543 ErrorKind::InvalidInput,
1544 format!("NATS server URL is invalid: {e}"),
1545 )
1546 })?;
1547
1548 Self::from_url(url)
1549 }
1550}
1551
1552impl ServerAddr {
1553 pub fn from_url(url: Url) -> io::Result<Self> {
1555 if url.scheme() != "nats"
1556 && url.scheme() != "tls"
1557 && url.scheme() != "ws"
1558 && url.scheme() != "wss"
1559 && url.scheme() != "ipc"
1561 {
1563 return Err(std::io::Error::new(
1564 ErrorKind::InvalidInput,
1565 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1566 ));
1567 }
1568
1569 Ok(Self(url))
1570 }
1571
1572 pub fn into_inner(self) -> Url {
1574 self.0
1575 }
1576
1577 pub fn tls_required(&self) -> bool {
1579 self.0.scheme() == "tls"
1580 }
1581
1582 pub fn has_user_pass(&self) -> bool {
1584 self.0.username() != ""
1585 }
1586
1587 pub fn scheme(&self) -> &str {
1588 self.0.scheme()
1589 }
1590
1591 pub fn host(&self) -> &str {
1593 match self.0.host() {
1594 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1595 Some(Host::Ipv6 { .. }) => {
1597 let host = self.0.host_str().unwrap();
1598 &host[1..host.len() - 1]
1599 }
1600 None => "",
1601 }
1602 }
1603
1604 pub fn is_websocket(&self) -> bool {
1605 self.0.scheme() == "ws" || self.0.scheme() == "wss"
1606 }
1607
1608 pub fn port(&self) -> u16 {
1610 self.0.port().unwrap_or(4222)
1611 }
1612
1613 pub fn username(&self) -> Option<&str> {
1615 let user = self.0.username();
1616 if user.is_empty() {
1617 None
1618 } else {
1619 Some(user)
1620 }
1621 }
1622
1623 pub fn password(&self) -> Option<&str> {
1625 self.0.password()
1626 }
1627
1628 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1630 tokio::net::lookup_host((self.host(), self.port())).await
1631 }
1632}
1633
1634pub trait ToServerAddrs {
1639 type Iter: Iterator<Item = ServerAddr>;
1642
1643 fn to_server_addrs(&self) -> io::Result<Self::Iter>;
1644}
1645
1646impl ToServerAddrs for ServerAddr {
1647 type Iter = option::IntoIter<ServerAddr>;
1648 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1649 Ok(Some(self.clone()).into_iter())
1650 }
1651}
1652
1653impl ToServerAddrs for str {
1654 type Iter = option::IntoIter<ServerAddr>;
1655 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1656 self.parse::<ServerAddr>()
1657 .map(|addr| Some(addr).into_iter())
1658 }
1659}
1660
1661impl ToServerAddrs for String {
1662 type Iter = option::IntoIter<ServerAddr>;
1663 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1664 (**self).to_server_addrs()
1665 }
1666}
1667
1668impl<T: AsRef<str>> ToServerAddrs for [T] {
1669 type Iter = std::vec::IntoIter<ServerAddr>;
1670 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1671 self.iter()
1672 .map(AsRef::as_ref)
1673 .map(str::parse)
1674 .collect::<io::Result<_>>()
1675 .map(Vec::into_iter)
1676 }
1677}
1678
1679impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1680 type Iter = std::vec::IntoIter<ServerAddr>;
1681 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1682 self.as_slice().to_server_addrs()
1683 }
1684}
1685
1686impl<'a> ToServerAddrs for &'a [ServerAddr] {
1687 type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1688
1689 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1690 Ok(self.iter().cloned())
1691 }
1692}
1693
1694impl ToServerAddrs for Vec<ServerAddr> {
1695 type Iter = std::vec::IntoIter<ServerAddr>;
1696
1697 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1698 Ok(self.clone().into_iter())
1699 }
1700}
1701
1702impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1703 type Iter = T::Iter;
1704 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1705 (**self).to_server_addrs()
1706 }
1707}
1708
/// Basic subject sanity check: rejects subjects that start or end with `.` or that contain
/// ASCII whitespace.
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let text = subject.as_ref();
    if text.starts_with('.') || text.ends_with('.') {
        return false;
    }
    !text.bytes().any(|byte| byte.is_ascii_whitespace())
}
// Generates `impl From<$origin> for $t`:
// - an origin error whose `kind()` is `<$origin_kind>::TimedOut` becomes
//   `$t::new(<$k>::TimedOut)`, without a source;
// - every other kind becomes `$t::with_source(<$k>::Other, err)`, keeping the original error.
// Requires `$origin` to expose `kind()` and `$t` to provide `new`/`with_source` constructors.
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                match err.kind() {
                    <$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
                    _ => Self::with_source(<$k>::Other, err),
                }
            }
        }
    };
}
// Re-export so other modules of the crate can invoke the macro by path.
pub(crate) use from_with_timeout;
1728
1729use crate::connection::ShouldFlush;
1730
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the host of every address yielded by `addrs`, in order.
    fn hosts(addrs: impl Iterator<Item = ServerAddr>) -> Vec<String> {
        addrs.map(|addr| addr.host().to_owned()).collect()
    }

    #[test]
    fn server_address_ipv6() {
        let address: ServerAddr = "nats://[::]".parse().unwrap();
        assert_eq!(address.host(), "::");
    }

    #[test]
    fn server_address_ipv4() {
        let address: ServerAddr = "nats://127.0.0.1".parse().unwrap();
        assert_eq!(address.host(), "127.0.0.1");
    }

    #[test]
    fn server_address_domain() {
        let address: ServerAddr = "nats://example.com".parse().unwrap();
        assert_eq!(address.host(), "example.com");
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        assert_eq!(hosts(vec.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        assert_eq!(hosts(arr.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_eq!(hosts(vec.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_eq!(hosts(arr.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }
}
1788}