1#![deny(unreachable_pub)]
191#![deny(rustdoc::broken_intra_doc_links)]
192#![deny(rustdoc::private_intra_doc_links)]
193#![deny(rustdoc::invalid_codeblock_attributes)]
194#![deny(rustdoc::invalid_rust_codeblocks)]
195#![cfg_attr(docsrs, feature(doc_auto_cfg))]
196
197use thiserror::Error;
198
199use futures::stream::Stream;
200use tokio::io::AsyncWriteExt;
201use tokio::sync::oneshot;
202use tracing::{debug, error};
203
204use core::fmt;
205use std::collections::HashMap;
206use std::fmt::Display;
207use std::future::Future;
208use std::iter;
209use std::mem;
210use std::net::SocketAddr;
211use std::option;
212use std::pin::Pin;
213use std::slice;
214use std::str::{self, FromStr};
215use std::sync::atomic::AtomicUsize;
216use std::sync::Arc;
217use std::task::{Context, Poll};
218use tokio::io::ErrorKind;
219use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
220use url::{Host, Url};
221
222use bytes::Bytes;
223use serde::{Deserialize, Serialize};
224use serde_repr::{Deserialize_repr, Serialize_repr};
225use tokio::io;
226use tokio::sync::mpsc;
227use tokio::task;
228
/// Boxed, thread-safe error type able to hold any `std::error::Error` implementation.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;

/// Version of this crate, taken from `CARGO_PKG_VERSION` at compile time.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// Language identifier of this client.
// NOTE(review): `VERSION` and `LANG` are not used in this part of the file; presumably they
// populate `ConnectInfo::version` / `ConnectInfo::lang` — confirm in the connector.
const LANG: &str = "rust";
/// Highest number of unanswered pings tolerated: once a ping tick pushes the pending count
/// above this value, the connection handler treats the connection as disconnected.
const MAX_PENDING_PINGS: usize = 2;
/// Subscription id used for the single wildcard subscription of the request multiplexer.
const MULTIPLEXER_SID: u64 = 0;
235
236pub use tokio_rustls::rustls;
240
241use connection::{Connection, State};
242use connector::{Connector, ConnectorOptions};
243pub use header::{HeaderMap, HeaderName, HeaderValue};
244pub use subject::Subject;
245
246mod auth;
247pub(crate) mod auth_utils;
248pub mod client;
249pub mod connection;
250mod connector;
251mod options;
252
253pub use auth::Auth;
254pub use client::{Client, PublishError, Request, RequestError, RequestErrorKind, SubscribeError};
255pub use options::{AuthError, ConnectOptions};
256
257mod crypto;
258pub mod error;
259pub mod header;
260pub mod jetstream;
261pub mod message;
262#[cfg(feature = "service")]
263pub mod service;
264pub mod status;
265pub mod subject;
266mod tls;
267
268pub use message::Message;
269pub use status::StatusCode;
270
/// Information describing a server, as carried by [`ServerOp::Info`].
///
/// Every field is marked `#[serde(default)]`, so a field that is missing from the
/// deserialized payload takes its type's default value instead of failing deserialization.
///
/// NOTE(review): the per-field descriptions below are derived from the field names and the
/// NATS `INFO` protocol message; confirm them against the protocol documentation.
#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
pub struct ServerInfo {
    /// Identifier of the server.
    #[serde(default)]
    pub server_id: String,
    /// Name of the server.
    #[serde(default)]
    pub server_name: String,
    /// Host reported by the server.
    #[serde(default)]
    pub host: String,
    /// Port reported by the server.
    #[serde(default)]
    pub port: u16,
    /// Version string of the server.
    #[serde(default)]
    pub version: String,
    /// Whether the server reports that authentication is required.
    #[serde(default)]
    pub auth_required: bool,
    /// Whether the server reports that TLS is required.
    #[serde(default)]
    pub tls_required: bool,
    /// Maximum payload size reported by the server (presumably in bytes — TODO confirm).
    #[serde(default)]
    pub max_payload: usize,
    /// Protocol version reported by the server.
    #[serde(default)]
    pub proto: i8,
    /// Client id reported by the server (presumably assigned per connection — TODO confirm).
    #[serde(default)]
    pub client_id: u64,
    /// Presumably the Go version the server was built with — TODO confirm.
    #[serde(default)]
    pub go: String,
    /// Nonce sent by the server (presumably used for signature-based authentication —
    /// TODO confirm).
    #[serde(default)]
    pub nonce: String,
    /// Additional server URLs advertised by the server.
    #[serde(default)]
    pub connect_urls: Vec<String>,
    /// Client IP address as reported by the server.
    #[serde(default)]
    pub client_ip: String,
    /// Whether the server reports support for message headers.
    #[serde(default)]
    pub headers: bool,
    /// Lame duck mode flag, deserialized from the `ldm` field. When it is set on a
    /// received `INFO`, the connection handler emits [`Event::LameDuckMode`].
    #[serde(default, rename = "ldm")]
    pub lame_duck_mode: bool,
}
325
/// An operation received from the server, as produced by the connection's read side and
/// consumed by `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
    /// Acknowledgement from the server; the connection handler ignores it.
    Ok,
    /// Server information. The handler only inspects its `lame_duck_mode` flag.
    Info(Box<ServerInfo>),
    /// Ping from the server; the handler answers with `ClientOp::Pong`.
    Ping,
    /// Pong from the server; the handler decrements its pending ping counter.
    Pong,
    /// Error reported by the server; forwarded to the event channel as `Event::ServerError`.
    Error(ServerError),
    /// A message delivered for the subscription identified by `sid`.
    Message {
        /// Id of the subscription the message belongs to (`MULTIPLEXER_SID` for request
        /// responses routed through the multiplexer).
        sid: u64,
        /// Subject the message was published on.
        subject: Subject,
        /// Optional reply subject.
        reply: Option<Subject>,
        /// Message payload.
        payload: Bytes,
        /// Optional message headers.
        headers: Option<HeaderMap>,
        /// Optional status code attached to the message.
        status: Option<StatusCode>,
        /// Optional description attached to the message.
        description: Option<String>,
        /// Length recorded for the message; copied verbatim into `Message::length`.
        // NOTE(review): what exactly is counted (payload only, or the whole frame) is
        // decided by the parser, which is not visible here.
        length: usize,
    },
}
344
/// A command sent over the `mpsc` channel to the `ConnectionHandler`, which executes it in
/// `ConnectionHandler::handle_command`.
#[derive(Debug)]
pub(crate) enum Command {
    /// Publish a message; translated one-to-one into `ClientOp::Publish`.
    Publish {
        subject: Subject,
        payload: Bytes,
        respond: Option<Subject>,
        headers: Option<HeaderMap>,
    },
    /// Publish a request and route the response back through the request multiplexer.
    Request {
        subject: Subject,
        payload: Bytes,
        /// Reply subject chosen by the caller. It must contain a `.`; the part after the
        /// last `.` is used as the token that identifies the pending response.
        respond: Subject,
        headers: Option<HeaderMap>,
        /// Receives the response message once it arrives.
        sender: oneshot::Sender<Message>,
    },
    /// Register a subscription under `sid` and subscribe on the server.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
        /// Channel on which messages for this subscription are delivered.
        sender: mpsc::Sender<Message>,
    },
    /// Unsubscribe `sid`, either immediately (`max: None`) or once `max` messages have
    /// been delivered.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Request a flush; `observer` is notified after the next successful flush.
    Flush {
        observer: oneshot::Sender<()>,
    },
    /// Ask the handler to drop the current connection and reconnect.
    Reconnect,
}
376
/// An operation the client enqueues for writing to the server through
/// `Connection::enqueue_write_op`.
#[derive(Debug)]
pub(crate) enum ClientOp {
    /// Publish `payload` on `subject`, with an optional reply subject and headers.
    Publish {
        subject: Subject,
        payload: Bytes,
        respond: Option<Subject>,
        headers: Option<HeaderMap>,
    },
    /// Subscribe to `subject` under the subscription id `sid`.
    Subscribe {
        sid: u64,
        subject: Subject,
        queue_group: Option<String>,
    },
    /// Unsubscribe `sid`, optionally only after `max` messages.
    Unsubscribe {
        sid: u64,
        max: Option<u64>,
    },
    /// Ping sent on every ping interval tick while the pending count is within limits.
    Ping,
    /// Pong sent in reply to a server ping.
    Pong,
    /// Connection parameters for the server.
    // NOTE(review): this variant is not constructed in this file; presumably the connector
    // sends it during the handshake — confirm.
    Connect(ConnectInfo),
}
399
/// Handler-side state of one active subscription.
#[derive(Debug)]
struct Subscription {
    /// Subject of the subscription; reused to re-subscribe after a reconnect.
    subject: Subject,
    /// Channel on which incoming messages are handed to the subscriber.
    sender: mpsc::Sender<Message>,
    /// Optional queue group; reused to re-subscribe after a reconnect.
    queue_group: Option<String>,
    /// Number of messages successfully pushed into `sender` so far.
    delivered: u64,
    /// Delivery limit set by an `Unsubscribe` command; the subscription is removed once
    /// `delivered` reaches it.
    max: Option<u64>,
}
408
/// Routes responses of all in-flight requests through a single wildcard subscription
/// registered under `MULTIPLEXER_SID`.
#[derive(Debug)]
struct Multiplexer {
    /// Wildcard subject (`prefix` followed by `*`) the multiplexer is subscribed to.
    subject: Subject,
    /// Common prefix of all reply subjects, including the trailing `.`.
    prefix: Subject,
    /// Pending requests, keyed by the token that follows `prefix` in the reply subject.
    senders: HashMap<String, oneshot::Sender<Message>>,
}
415
/// Drives a single server connection: reads server operations, executes client commands,
/// keeps the connection alive with pings and reconnects when the connection is lost.
pub(crate) struct ConnectionHandler {
    /// The current connection; replaced on every successful reconnect.
    connection: Connection,
    /// Used to establish new connections and to publish events and state changes.
    connector: Connector,
    /// Active subscriptions, keyed by subscription id.
    subscriptions: HashMap<u64, Subscription>,
    /// Request multiplexer, created lazily by the first `Command::Request`.
    multiplexer: Option<Multiplexer>,
    /// Number of pings sent that have not been answered by a pong yet.
    pending_pings: usize,
    /// Publishes the `ServerInfo` obtained on reconnect.
    info_sender: tokio::sync::watch::Sender<ServerInfo>,
    /// Timer driving the pings; reset whenever a server operation or a command is handled.
    ping_interval: Interval,
    /// Set by `Command::Reconnect`; consumed at the end of the processing loop's poll.
    should_reconnect: bool,
    /// Observers waiting for the next successful flush.
    flush_observers: Vec<oneshot::Sender<()>>,
}
428
429impl ConnectionHandler {
430 pub(crate) fn new(
431 connection: Connection,
432 connector: Connector,
433 info_sender: tokio::sync::watch::Sender<ServerInfo>,
434 ping_period: Duration,
435 ) -> ConnectionHandler {
436 let mut ping_interval = interval(ping_period);
437 ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
438
439 ConnectionHandler {
440 connection,
441 connector,
442 subscriptions: HashMap::new(),
443 multiplexer: None,
444 pending_pings: 0,
445 info_sender,
446 ping_interval,
447 should_reconnect: false,
448 flush_observers: Vec::new(),
449 }
450 }
451
    /// Runs the handler until the command channel is closed or reconnecting fails.
    ///
    /// The work is done by a hand-written future (`ProcessFut`) that, on every poll, sends
    /// pings, reads server operations, executes received commands, writes and flushes. When
    /// that future completes, this method reacts to the reason it stopped: it reconnects
    /// after a disconnect or a requested reconnect, and returns once `receiver` is closed or
    /// `handle_disconnect` returns an error.
    pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
        /// One run of the processing loop over the current connection.
        struct ProcessFut<'a> {
            handler: &'a mut ConnectionHandler,
            receiver: &'a mut mpsc::Receiver<Command>,
            /// Reusable buffer for commands received in bulk; empty between polls.
            recv_buf: &'a mut Vec<Command>,
        }

        /// Why a `ProcessFut` completed.
        enum ExitReason {
            /// The connection was lost, optionally with the I/O error that caused it.
            Disconnected(Option<io::Error>),
            /// A `Command::Reconnect` was received.
            ReconnectRequested,
            /// The command channel was closed.
            Closed,
        }

        impl<'a> ProcessFut<'a> {
            /// Maximum number of commands taken from the channel per `poll_recv_many` call.
            const RECV_CHUNK_SIZE: usize = 16;

            /// Accounts for one ping tick: enqueues a ping, or reports a disconnect once
            /// more than `MAX_PENDING_PINGS` pings are pending.
            #[cold]
            fn ping(&mut self) -> Poll<ExitReason> {
                self.handler.pending_pings += 1;

                if self.handler.pending_pings > MAX_PENDING_PINGS {
                    debug!(
                        "pending pings {}, max pings {}. disconnecting",
                        self.handler.pending_pings, MAX_PENDING_PINGS
                    );

                    Poll::Ready(ExitReason::Disconnected(None))
                } else {
                    self.handler.connection.enqueue_write_op(&ClientOp::Ping);

                    Poll::Pending
                }
            }
        }

        impl<'a> Future for ProcessFut<'a> {
            type Output = ExitReason;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                // Handle every ping tick that is ready; keep polling the interval until it
                // reports `Pending`.
                while self.handler.ping_interval.poll_tick(cx).is_ready() {
                    if let Poll::Ready(exit) = self.ping() {
                        return Poll::Ready(exit);
                    }
                }

                // Drain all server operations that are currently available. A `None`
                // operation or a read error ends this run as a disconnect.
                loop {
                    match self.handler.connection.poll_read_op(cx) {
                        Poll::Pending => break,
                        Poll::Ready(Ok(Some(server_op))) => {
                            self.handler.handle_server_op(server_op);
                        }
                        Poll::Ready(Ok(None)) => {
                            return Poll::Ready(ExitReason::Disconnected(None))
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Starts as `true` so that the first pass always attempts a write, even if
                // no command is received in this poll (the steps above may have enqueued
                // pings or pongs).
                let mut made_progress = true;
                loop {
                    // Execute commands in chunks for as long as the write buffer has room.
                    while !self.handler.connection.is_write_buf_full() {
                        debug_assert!(self.recv_buf.is_empty());

                        // Split the borrow so the receiver, the buffer and the handler can
                        // be used independently.
                        let Self {
                            recv_buf,
                            handler,
                            receiver,
                        } = &mut *self;
                        match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
                            Poll::Pending => break,
                            Poll::Ready(1..) => {
                                made_progress = true;

                                for cmd in recv_buf.drain(..) {
                                    handler.handle_command(cmd);
                                }
                            }
                            // A ready result with zero commands is treated as the channel
                            // being closed.
                            Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
                        }
                    }

                    // `mem::take` reads the flag and resets it to `false`: stop once a pass
                    // received no commands after the previous write.
                    if !mem::take(&mut made_progress) {
                        break;
                    }

                    match self.handler.connection.poll_write(cx) {
                        Poll::Pending => {
                            break;
                        }
                        Poll::Ready(Ok(())) => {
                            // The write completed; go back and try to receive more commands.
                            continue;
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Flush when the connection reports `ShouldFlush::Yes`, or when it reports
                // `ShouldFlush::No` but flush observers are waiting. Any other combination
                // skips the flush in this poll.
                if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
                    self.handler.connection.should_flush(),
                    self.handler.flush_observers.is_empty(),
                ) {
                    match self.handler.connection.poll_flush(cx) {
                        Poll::Pending => {}
                        Poll::Ready(Ok(())) => {
                            // Notify everyone waiting for the flush; an observer that has
                            // gone away is ignored.
                            for observer in self.handler.flush_observers.drain(..) {
                                let _ = observer.send(());
                            }
                        }
                        Poll::Ready(Err(err)) => {
                            return Poll::Ready(ExitReason::Disconnected(Some(err)))
                        }
                    }
                }

                // Consume a pending reconnect request (the flag is reset by `mem::take`).
                if mem::take(&mut self.handler.should_reconnect) {
                    return Poll::Ready(ExitReason::ReconnectRequested);
                }

                Poll::Pending
            }
        }

        let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
        loop {
            let process = ProcessFut {
                handler: self,
                receiver,
                recv_buf: &mut recv_buf,
            };
            match process.await {
                ExitReason::Disconnected(err) => {
                    debug!(?err, "disconnected");
                    // Give up entirely when reconnecting fails.
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                    debug!("reconnected");
                }
                ExitReason::Closed => break,
                ExitReason::ReconnectRequested => {
                    debug!("reconnect requested");
                    // Shut down the current stream first; an error here is ignored.
                    self.connection.stream.shutdown().await.ok();
                    if self.handle_disconnect().await.is_err() {
                        break;
                    };
                }
            }
        }
    }
638
639 fn handle_server_op(&mut self, server_op: ServerOp) {
640 self.ping_interval.reset();
641
642 match server_op {
643 ServerOp::Ping => {
644 self.connection.enqueue_write_op(&ClientOp::Pong);
645 }
646 ServerOp::Pong => {
647 debug!("received PONG");
648 self.pending_pings = self.pending_pings.saturating_sub(1);
649 }
650 ServerOp::Error(error) => {
651 self.connector
652 .events_tx
653 .try_send(Event::ServerError(error))
654 .ok();
655 }
656 ServerOp::Message {
657 sid,
658 subject,
659 reply,
660 payload,
661 headers,
662 status,
663 description,
664 length,
665 } => {
666 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
667 let message: Message = Message {
668 subject,
669 reply,
670 payload,
671 headers,
672 status,
673 description,
674 length,
675 };
676
677 match subscription.sender.try_send(message) {
680 Ok(_) => {
681 subscription.delivered += 1;
682 if let Some(max) = subscription.max {
686 if subscription.delivered.ge(&max) {
687 self.subscriptions.remove(&sid);
688 }
689 }
690 }
691 Err(mpsc::error::TrySendError::Full(_)) => {
692 self.connector
693 .events_tx
694 .try_send(Event::SlowConsumer(sid))
695 .ok();
696 }
697 Err(mpsc::error::TrySendError::Closed(_)) => {
698 self.subscriptions.remove(&sid);
699 self.connection
700 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
701 }
702 }
703 } else if sid == MULTIPLEXER_SID {
704 if let Some(multiplexer) = self.multiplexer.as_mut() {
705 let maybe_token =
706 subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
707
708 if let Some(token) = maybe_token {
709 if let Some(sender) = multiplexer.senders.remove(token) {
710 let message = Message {
711 subject,
712 reply,
713 payload,
714 headers,
715 status,
716 description,
717 length,
718 };
719
720 let _ = sender.send(message);
721 }
722 }
723 }
724 }
725 }
726 ServerOp::Info(info) => {
728 if info.lame_duck_mode {
729 self.connector.events_tx.try_send(Event::LameDuckMode).ok();
730 }
731 }
732
733 _ => {
734 }
736 }
737 }
738
739 fn handle_command(&mut self, command: Command) {
740 self.ping_interval.reset();
741
742 match command {
743 Command::Unsubscribe { sid, max } => {
744 if let Some(subscription) = self.subscriptions.get_mut(&sid) {
745 subscription.max = max;
746 match subscription.max {
747 Some(n) => {
748 if subscription.delivered >= n {
749 self.subscriptions.remove(&sid);
750 }
751 }
752 None => {
753 self.subscriptions.remove(&sid);
754 }
755 }
756
757 self.connection
758 .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
759 }
760 }
761 Command::Flush { observer } => {
762 self.flush_observers.push(observer);
763 }
764 Command::Subscribe {
765 sid,
766 subject,
767 queue_group,
768 sender,
769 } => {
770 let subscription = Subscription {
771 sender,
772 delivered: 0,
773 max: None,
774 subject: subject.to_owned(),
775 queue_group: queue_group.to_owned(),
776 };
777
778 self.subscriptions.insert(sid, subscription);
779
780 self.connection.enqueue_write_op(&ClientOp::Subscribe {
781 sid,
782 subject,
783 queue_group,
784 });
785 }
786 Command::Request {
787 subject,
788 payload,
789 respond,
790 headers,
791 sender,
792 } => {
793 let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");
794
795 let multiplexer = if let Some(multiplexer) = self.multiplexer.as_mut() {
796 multiplexer
797 } else {
798 let prefix = Subject::from(format!("{}.{}.", prefix, nuid::next()));
799 let subject = Subject::from(format!("{}*", prefix));
800
801 self.connection.enqueue_write_op(&ClientOp::Subscribe {
802 sid: MULTIPLEXER_SID,
803 subject: subject.clone(),
804 queue_group: None,
805 });
806
807 self.multiplexer.insert(Multiplexer {
808 subject,
809 prefix,
810 senders: HashMap::new(),
811 })
812 };
813
814 multiplexer.senders.insert(token.to_owned(), sender);
815
816 let pub_op = ClientOp::Publish {
817 subject,
818 payload,
819 respond: Some(format!("{}{}", multiplexer.prefix, token).into()),
820 headers,
821 };
822
823 self.connection.enqueue_write_op(&pub_op);
824 }
825
826 Command::Publish {
827 subject,
828 payload,
829 respond,
830 headers,
831 } => {
832 self.connection.enqueue_write_op(&ClientOp::Publish {
833 subject,
834 payload,
835 respond,
836 headers,
837 });
838 }
839
840 Command::Reconnect => {
841 self.should_reconnect = true;
842 }
843 }
844 }
845
846 async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
847 self.pending_pings = 0;
848 self.connector.events_tx.try_send(Event::Disconnected).ok();
849 self.connector.state_tx.send(State::Disconnected).ok();
850
851 self.handle_reconnect().await
852 }
853
854 async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
855 let (info, connection) = self.connector.connect().await?;
856 self.connection = connection;
857 let _ = self.info_sender.send(info);
858
859 self.subscriptions
860 .retain(|_, subscription| !subscription.sender.is_closed());
861
862 for (sid, subscription) in &self.subscriptions {
863 self.connection.enqueue_write_op(&ClientOp::Subscribe {
864 sid: *sid,
865 subject: subscription.subject.to_owned(),
866 queue_group: subscription.queue_group.to_owned(),
867 });
868 }
869
870 if let Some(multiplexer) = &self.multiplexer {
871 self.connection.enqueue_write_op(&ClientOp::Subscribe {
872 sid: MULTIPLEXER_SID,
873 subject: multiplexer.subject.to_owned(),
874 queue_group: None,
875 });
876 }
877 Ok(())
878 }
879}
880
881pub async fn connect_with_options<A: ToServerAddrs>(
897 addrs: A,
898 options: ConnectOptions,
899) -> Result<Client, ConnectError> {
900 let ping_period = options.ping_interval;
901
902 let (events_tx, mut events_rx) = mpsc::channel(128);
903 let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
904 let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
906
907 let mut connector = Connector::new(
908 addrs,
909 ConnectorOptions {
910 tls_required: options.tls_required,
911 certificates: options.certificates,
912 client_key: options.client_key,
913 client_cert: options.client_cert,
914 tls_client_config: options.tls_client_config,
915 tls_first: options.tls_first,
916 auth: options.auth,
917 no_echo: options.no_echo,
918 connection_timeout: options.connection_timeout,
919 name: options.name,
920 ignore_discovered_servers: options.ignore_discovered_servers,
921 retain_servers_order: options.retain_servers_order,
922 read_buffer_capacity: options.read_buffer_capacity,
923 reconnect_delay_callback: options.reconnect_delay_callback,
924 auth_callback: options.auth_callback,
925 max_reconnects: options.max_reconnects,
926 },
927 events_tx,
928 state_tx,
929 max_payload.clone(),
930 )
931 .map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
932
933 let mut info: ServerInfo = Default::default();
934 let mut connection = None;
935 if !options.retry_on_initial_connect {
936 debug!("retry on initial connect failure is disabled");
937 let (info_ok, connection_ok) = connector.try_connect().await?;
938 connection = Some(connection_ok);
939 info = info_ok;
940 }
941
942 let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
943 let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
944
945 let client = Client::new(
946 info_watcher,
947 state_rx,
948 sender,
949 options.subscription_capacity,
950 options.inbox_prefix,
951 options.request_timeout,
952 max_payload,
953 );
954
955 task::spawn(async move {
956 while let Some(event) = events_rx.recv().await {
957 tracing::info!("event: {}", event);
958 if let Some(event_callback) = &options.event_callback {
959 event_callback.call(event).await;
960 }
961 }
962 });
963
964 task::spawn(async move {
965 if connection.is_none() && options.retry_on_initial_connect {
966 let (info, connection_ok) = match connector.connect().await {
967 Ok((info, connection)) => (info, connection),
968 Err(err) => {
969 error!("connection closed: {}", err);
970 return;
971 }
972 };
973 info_sender.send(info).ok();
974 connection = Some(connection_ok);
975 }
976 let connection = connection.unwrap();
977 let mut connection_handler =
978 ConnectionHandler::new(connection, connector, info_sender, ping_period);
979 connection_handler.process(&mut receiver).await
980 });
981
982 Ok(client)
983}
984
985#[derive(Debug, Clone, PartialEq, Eq)]
986pub enum Event {
987 Connected,
988 Disconnected,
989 LameDuckMode,
990 SlowConsumer(u64),
991 ServerError(ServerError),
992 ClientError(ClientError),
993}
994
995impl fmt::Display for Event {
996 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
997 match self {
998 Event::Connected => write!(f, "connected"),
999 Event::Disconnected => write!(f, "disconnected"),
1000 Event::LameDuckMode => write!(f, "lame duck mode detected"),
1001 Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
1002 Event::ServerError(err) => write!(f, "server error: {err}"),
1003 Event::ClientError(err) => write!(f, "client error: {err}"),
1004 }
1005 }
1006}
1007
1008pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
1075 connect_with_options(addrs, ConnectOptions::default()).await
1076}
1077
/// The kind of failure behind a [`ConnectError`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or list of addresses could not be parsed.
    ServerParse,
    /// DNS error.
    Dns,
    /// Signing the nonce failed.
    Authentication,
    /// Authorization violation.
    AuthorizationViolation,
    /// The operation timed out.
    TimedOut,
    /// TLS error.
    Tls,
    /// I/O error.
    Io,
    /// The maximum number of reconnects was reached.
    MaxReconnects,
}

impl Display for ConnectErrorKind {
    /// Writes the fixed, human-readable description of the error kind.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
1112
1113pub type ConnectError = error::Error<ConnectErrorKind>;
1116
1117impl From<io::Error> for ConnectError {
1118 fn from(err: io::Error) -> Self {
1119 ConnectError::with_source(ConnectErrorKind::Io, err)
1120 }
1121}
1122
/// Receiving side of a subscription.
///
/// Messages are obtained through the `Stream` implementation. Dropping the subscriber
/// closes its receiver and sends an unsubscribe command for its subscription id.
#[derive(Debug)]
pub struct Subscriber {
    /// Subscription id used in the commands sent to the connection handler.
    sid: u64,
    /// Channel on which the messages of this subscription arrive.
    receiver: mpsc::Receiver<Message>,
    /// Channel used to send commands (unsubscribe) to the connection handler.
    sender: mpsc::Sender<Command>,
}
1142
1143impl Subscriber {
1144 fn new(
1145 sid: u64,
1146 sender: mpsc::Sender<Command>,
1147 receiver: mpsc::Receiver<Message>,
1148 ) -> Subscriber {
1149 Subscriber {
1150 sid,
1151 sender,
1152 receiver,
1153 }
1154 }
1155
1156 pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
1171 self.sender
1172 .send(Command::Unsubscribe {
1173 sid: self.sid,
1174 max: None,
1175 })
1176 .await?;
1177 self.receiver.close();
1178 Ok(())
1179 }
1180
1181 pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
1207 self.sender
1208 .send(Command::Unsubscribe {
1209 sid: self.sid,
1210 max: Some(unsub_after),
1211 })
1212 .await?;
1213 Ok(())
1214 }
1215}
1216
1217#[derive(Error, Debug, PartialEq)]
1218#[error("failed to send unsubscribe")]
1219pub struct UnsubscribeError(String);
1220
1221impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
1222 fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
1223 UnsubscribeError(err.to_string())
1224 }
1225}
1226
1227impl Drop for Subscriber {
1228 fn drop(&mut self) {
1229 self.receiver.close();
1230 tokio::spawn({
1231 let sender = self.sender.clone();
1232 let sid = self.sid;
1233 async move {
1234 sender
1235 .send(Command::Unsubscribe { sid, max: None })
1236 .await
1237 .ok();
1238 }
1239 });
1240 }
1241}
1242
1243impl Stream for Subscriber {
1244 type Item = Message;
1245
1246 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
1247 self.receiver.poll_recv(cx)
1248 }
1249}
1250
1251#[derive(Clone, Debug, Eq, PartialEq)]
1252pub enum CallbackError {
1253 Client(ClientError),
1254 Server(ServerError),
1255}
1256impl std::fmt::Display for CallbackError {
1257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1258 match self {
1259 Self::Client(error) => write!(f, "{error}"),
1260 Self::Server(error) => write!(f, "{error}"),
1261 }
1262 }
1263}
1264
1265impl From<ServerError> for CallbackError {
1266 fn from(server_error: ServerError) -> Self {
1267 CallbackError::Server(server_error)
1268 }
1269}
1270
1271impl From<ClientError> for CallbackError {
1272 fn from(client_error: ClientError) -> Self {
1273 CallbackError::Client(client_error)
1274 }
1275}
1276
/// An error reported by the server.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
    /// The server reported an authorization violation.
    AuthorizationViolation,
    /// The subscription with the given id is a slow consumer.
    // NOTE(review): `ServerError::new` never produces this variant — confirm where it is used.
    SlowConsumer(u64),
    /// Any other error; the server's message is kept verbatim.
    Other(String),
}
1283
/// An error raised by the client itself.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Any other error, described by the contained message.
    Other(String),
    /// The maximum number of reconnects was reached.
    MaxReconnects,
}

impl std::fmt::Display for ClientError {
    /// Writes the error prefixed with `nats: `.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("nats: ")?;
        match self {
            Self::Other(error) => f.write_str(error),
            Self::MaxReconnects => f.write_str("max reconnects reached"),
        }
    }
}
1297
1298impl ServerError {
1299 fn new(error: String) -> ServerError {
1300 match error.to_lowercase().as_str() {
1301 "authorization violation" => ServerError::AuthorizationViolation,
1302 _ => ServerError::Other(error),
1304 }
1305 }
1306}
1307
1308impl std::fmt::Display for ServerError {
1309 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1310 match self {
1311 Self::AuthorizationViolation => write!(f, "nats: authorization violation"),
1312 Self::SlowConsumer(sid) => write!(f, "nats: subscription {sid} is a slow consumer"),
1313 Self::Other(error) => write!(f, "nats: {error}"),
1314 }
1315 }
1316}
1317
/// Connection parameters sent to the server, carried by `ClientOp::Connect`.
///
/// NOTE(review): the per-field descriptions are derived from the field names and the NATS
/// `CONNECT` protocol message; confirm them against the protocol documentation.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
    /// Verbose mode flag.
    pub verbose: bool,

    /// Pedantic mode flag.
    pub pedantic: bool,

    /// User JWT; serialized as `jwt`.
    #[serde(rename = "jwt")]
    pub user_jwt: Option<String>,

    /// Public nkey.
    pub nkey: Option<String>,

    /// Signature; serialized as `sig`.
    #[serde(rename = "sig")]
    pub signature: Option<String>,

    /// Optional client name.
    pub name: Option<String>,

    /// Echo flag (presumably whether the server echoes the client's own messages back —
    /// TODO confirm).
    pub echo: bool,

    /// Implementation language of the client.
    pub lang: String,

    /// Version of the client.
    pub version: String,

    /// Protocol level announced by the client.
    pub protocol: Protocol,

    /// Whether the client requires TLS.
    pub tls_required: bool,

    /// Connection username.
    pub user: Option<String>,

    /// Connection password.
    pub pass: Option<String>,

    /// Authorization token.
    pub auth_token: Option<String>,

    /// Whether the client supports headers.
    pub headers: bool,

    /// Whether the client supports no-responders notifications.
    pub no_responders: bool,
}
1378
/// Protocol level announced by the client; (de)serialized as its `u8` value.
// NOTE(review): what the server does differently for each level is not visible here —
// confirm against the NATS protocol documentation.
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
    /// Protocol level `0`.
    Original = 0,
    /// Protocol level `1`.
    Dynamic = 1,
}
1388
/// Address of a NATS server, stored as a [`Url`] whose scheme is either `nats` or `tls`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
1392
1393impl FromStr for ServerAddr {
1394 type Err = io::Error;
1395
1396 fn from_str(input: &str) -> Result<Self, Self::Err> {
1400 let url: Url = if input.contains("://") {
1401 input.parse()
1402 } else {
1403 format!("nats://{input}").parse()
1404 }
1405 .map_err(|e| {
1406 io::Error::new(
1407 ErrorKind::InvalidInput,
1408 format!("NATS server URL is invalid: {e}"),
1409 )
1410 })?;
1411
1412 Self::from_url(url)
1413 }
1414}
1415
1416impl ServerAddr {
1417 pub fn from_url(url: Url) -> io::Result<Self> {
1419 if url.scheme() != "nats" && url.scheme() != "tls" {
1420 return Err(std::io::Error::new(
1421 ErrorKind::InvalidInput,
1422 format!("invalid scheme for NATS server URL: {}", url.scheme()),
1423 ));
1424 }
1425
1426 Ok(Self(url))
1427 }
1428
1429 pub fn into_inner(self) -> Url {
1431 self.0
1432 }
1433
1434 pub fn tls_required(&self) -> bool {
1436 self.0.scheme() == "tls"
1437 }
1438
1439 pub fn has_user_pass(&self) -> bool {
1441 self.0.username() != ""
1442 }
1443
1444 pub fn host(&self) -> &str {
1446 match self.0.host() {
1447 Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
1448 Some(Host::Ipv6 { .. }) => {
1450 let host = self.0.host_str().unwrap();
1451 &host[1..host.len() - 1]
1452 }
1453 None => "",
1454 }
1455 }
1456
1457 pub fn port(&self) -> u16 {
1459 self.0.port().unwrap_or(4222)
1460 }
1461
1462 pub fn username(&self) -> Option<&str> {
1464 let user = self.0.username();
1465 if user.is_empty() {
1466 None
1467 } else {
1468 Some(user)
1469 }
1470 }
1471
1472 pub fn password(&self) -> Option<&str> {
1474 self.0.password()
1475 }
1476
1477 pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
1479 tokio::net::lookup_host((self.host(), self.port())).await
1480 }
1481}
1482
/// Conversion of a value into one or more [`ServerAddr`]s.
///
/// Implemented in this file for single addresses, strings, and slices and vectors of
/// strings or addresses, as well as for references to any implementor.
pub trait ToServerAddrs {
    /// Iterator over the resulting addresses.
    type Iter: Iterator<Item = ServerAddr>;

    /// Converts `self` into an iterator of server addresses.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` when the value cannot be converted, e.g. when a string is
    /// not a valid server address.
    fn to_server_addrs(&self) -> io::Result<Self::Iter>;
}
1494
1495impl ToServerAddrs for ServerAddr {
1496 type Iter = option::IntoIter<ServerAddr>;
1497 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1498 Ok(Some(self.clone()).into_iter())
1499 }
1500}
1501
1502impl ToServerAddrs for str {
1503 type Iter = option::IntoIter<ServerAddr>;
1504 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1505 self.parse::<ServerAddr>()
1506 .map(|addr| Some(addr).into_iter())
1507 }
1508}
1509
1510impl ToServerAddrs for String {
1511 type Iter = option::IntoIter<ServerAddr>;
1512 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1513 (**self).to_server_addrs()
1514 }
1515}
1516
1517impl<T: AsRef<str>> ToServerAddrs for [T] {
1518 type Iter = std::vec::IntoIter<ServerAddr>;
1519 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1520 self.iter()
1521 .map(AsRef::as_ref)
1522 .map(str::parse)
1523 .collect::<io::Result<_>>()
1524 .map(Vec::into_iter)
1525 }
1526}
1527
1528impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
1529 type Iter = std::vec::IntoIter<ServerAddr>;
1530 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1531 self.as_slice().to_server_addrs()
1532 }
1533}
1534
1535impl<'a> ToServerAddrs for &'a [ServerAddr] {
1536 type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;
1537
1538 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1539 Ok(self.iter().cloned())
1540 }
1541}
1542
1543impl ToServerAddrs for Vec<ServerAddr> {
1544 type Iter = std::vec::IntoIter<ServerAddr>;
1545
1546 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1547 Ok(self.clone().into_iter())
1548 }
1549}
1550
1551impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
1552 type Iter = T::Iter;
1553 fn to_server_addrs(&self) -> io::Result<Self::Iter> {
1554 (**self).to_server_addrs()
1555 }
1556}
1557
/// Returns `true` when `subject` contains none of the characters `' '`, `'.'`, `'\r'`
/// and `'\n'`. An empty string is accepted.
// NOTE(review): rejecting every `.` means dot-separated, multi-token subjects are reported
// as invalid; presumably callers use this to validate single tokens/names — confirm.
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let has_forbidden_char = subject
        .as_ref()
        .chars()
        .any(|c| matches!(c, ' ' | '.' | '\r' | '\n'));
    !has_forbidden_char
}
1561
/// Generates `impl From<$origin> for $t`, mapping the `TimedOut` kind of the origin error
/// to a `TimedOut` error of the target type and everything else to `Other`.
///
/// Arguments:
/// * `$t` — the target error type; must provide `new(kind)` and `with_source(kind, source)`.
/// * `$k` — the target's kind type; must have `TimedOut` and `Other` variants.
/// * `$origin` — the source error type; must provide `kind()`.
/// * `$origin_kind` — the source's kind type; must have a `TimedOut` variant.
///
/// A timed-out origin error is not kept as the source of the new error; any other origin
/// error is.
macro_rules! from_with_timeout {
    ($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
        impl From<$origin> for $t {
            fn from(err: $origin) -> Self {
                match err.kind() {
                    <$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
                    _ => Self::with_source(<$k>::Other, err),
                }
            }
        }
    };
}
pub(crate) use from_with_timeout;
1575
1576use crate::connection::ShouldFlush;
1577
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the host of every address, in iteration order.
    fn hosts(addrs: impl Iterator<Item = ServerAddr>) -> Vec<String> {
        addrs.map(|addr| addr.host().to_owned()).collect()
    }

    /// Parses `input` as a server address, panicking when it is invalid.
    fn parse(input: &str) -> ServerAddr {
        ServerAddr::from_str(input).unwrap()
    }

    // IPv6 hosts are reported without the surrounding brackets.
    #[test]
    fn server_address_ipv6() {
        assert_eq!(parse("nats://[::]").host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        assert_eq!(parse("nats://127.0.0.1").host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        assert_eq!(parse("nats://example.com").host(), "example.com")
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        assert_eq!(hosts(vec.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        assert_eq!(hosts(arr.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_eq!(hosts(vec.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        assert_eq!(hosts(arr.to_server_addrs().unwrap()), ["127.0.0.1", "::"]);
    }
}