#![deny(unreachable_pub)]
#![deny(rustdoc::broken_intra_doc_links)]
#![deny(rustdoc::private_intra_doc_links)]
#![deny(rustdoc::invalid_codeblock_attributes)]
#![deny(rustdoc::invalid_rust_codeblocks)]
#![cfg_attr(docsrs, feature(doc_cfg))]
use thiserror::Error;
use futures_util::stream::Stream;
use tokio::io::AsyncWriteExt;
use tokio::sync::oneshot;
use tracing::{debug, error};
use core::fmt;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::fmt::Display;
use std::future::Future;
use std::iter;
use std::mem;
use std::net::SocketAddr;
use std::option;
use std::pin::Pin;
use std::slice;
use std::str::{self, FromStr};
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::ErrorKind;
use tokio::time::{interval, Duration, Interval, MissedTickBehavior};
use url::{Host, Url};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use tokio::io;
use tokio::sync::mpsc;
use tokio::task;
/// Boxed, dynamically typed error that can be sent and shared across threads.
pub type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
// Version of this crate, read from Cargo metadata at compile time.
// NOTE(review): not used in this part of the file; presumably reported to the server on connect.
const VERSION: &str = env!("CARGO_PKG_VERSION");
// Client language tag. NOTE(review): not used in this part of the file; presumably sent on connect.
const LANG: &str = "rust";
// Once more than this many PINGs are outstanding the connection is treated as disconnected.
const MAX_PENDING_PINGS: usize = 2;
// Subscription id reserved for the request/response multiplexer subscription.
const MULTIPLEXER_SID: u64 = 0;
pub use tokio_rustls::rustls;
use connection::{Connection, State};
use connector::{Connector, ConnectorOptions};
pub use header::{HeaderMap, HeaderName, HeaderValue};
pub use subject::{Subject, SubjectError, ToSubject};
mod auth;
pub(crate) mod auth_utils;
pub mod client;
pub mod connection;
mod connector;
mod options;
pub use auth::Auth;
pub use client::{
Client, PublishError, Request, RequestError, RequestErrorKind, Statistics, SubscribeError,
SubscribeErrorKind,
};
pub use options::{AuthError, ConnectOptions};
#[cfg(feature = "crypto")]
#[cfg_attr(docsrs, doc(cfg(feature = "crypto")))]
mod crypto;
pub mod error;
pub mod header;
mod id_generator;
#[cfg(feature = "jetstream")]
#[cfg_attr(docsrs, doc(cfg(feature = "jetstream")))]
pub mod jetstream;
pub mod message;
#[cfg(feature = "service")]
#[cfg_attr(docsrs, doc(cfg(feature = "service")))]
pub mod service;
pub mod status;
pub mod subject;
mod tls;
pub use message::Message;
pub use status::StatusCode;
/// Information describing a server, deserialized from the data the server sends
/// with its INFO operation.
///
/// Every field is optional on the wire: a missing key deserializes to the
/// field's `Default` value.
#[derive(Debug, Deserialize, Default, Clone, Eq, PartialEq)]
pub struct ServerInfo {
/// Identifier of the server.
#[serde(default)]
pub server_id: String,
/// Name of the server.
#[serde(default)]
pub server_name: String,
/// Host reported by the server.
#[serde(default)]
pub host: String,
/// Port reported by the server.
#[serde(default)]
pub port: u16,
/// Version string of the server.
#[serde(default)]
pub version: String,
/// Whether the server reports that authentication is required.
#[serde(default)]
pub auth_required: bool,
/// Whether the server reports that TLS is required.
#[serde(default)]
pub tls_required: bool,
/// Maximum payload size reported by the server (presumably in bytes; confirm
/// against the NATS protocol documentation).
#[serde(default)]
pub max_payload: usize,
/// Protocol version reported by the server.
#[serde(default)]
pub proto: i8,
/// Client identifier reported by the server (presumably assigned per connection).
#[serde(default)]
pub client_id: u64,
/// Value of the `go` key (presumably the Go version the server was built with).
#[serde(default)]
pub go: String,
/// Nonce reported by the server (presumably used for signature based authentication).
#[serde(default)]
pub nonce: String,
/// Additional server URLs reported by the server.
#[serde(default)]
pub connect_urls: Vec<String>,
/// Client IP address as reported by the server.
#[serde(default)]
pub client_ip: String,
/// Value of the `headers` key (presumably whether the server supports message headers).
#[serde(default)]
pub headers: bool,
/// Whether the server announced lame duck mode; read from the `ldm` key.
#[serde(default, rename = "ldm")]
pub lame_duck_mode: bool,
/// Cluster name, if the server reported one.
#[serde(default)]
pub cluster: Option<String>,
/// Domain, if the server reported one.
#[serde(default)]
pub domain: Option<String>,
/// Value of the `jetstream` key (presumably whether JetStream is enabled on the server).
#[serde(default)]
pub jetstream: bool,
}
/// An operation received from the server, as consumed by
/// `ConnectionHandler::handle_server_op`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) enum ServerOp {
/// Acknowledgement; the handler ignores it.
Ok,
/// Server information update (boxed to keep the enum small).
Info(Box<ServerInfo>),
/// Server PING; the handler answers with `ClientOp::Pong`.
Ping,
/// Server PONG; the handler decrements its count of outstanding pings.
Pong,
/// Error reported by the server; forwarded as `Event::ServerError`.
Error(ServerError),
/// A message delivered for subscription `sid`.
Message {
/// Id of the subscription the message was delivered for.
sid: u64,
/// Subject the message was published on.
subject: Subject,
/// Optional reply subject.
reply: Option<Subject>,
/// Message payload.
payload: Bytes,
/// Optional message headers.
headers: Option<HeaderMap>,
/// Optional status code carried by the message.
status: Option<StatusCode>,
/// Optional description carried by the message.
description: Option<String>,
/// Message length (presumably the size in bytes as computed by the parser; confirm).
length: usize,
},
}
/// Deprecated alias kept for backward compatibility; it names the same type as
/// `crate::message::OutboundMessage`.
#[deprecated(
since = "0.44.0",
note = "use `async_nats::message::OutboundMessage` instead"
)]
pub type PublishMessage = crate::message::OutboundMessage;
/// A command sent over the command channel to the `ConnectionHandler`, which
/// executes it in `handle_command`.
#[derive(Debug)]
pub(crate) enum Command {
/// Publish a message.
Publish(OutboundMessage),
/// Publish a request whose response is routed through the multiplexer subscription.
Request {
/// Subject the request is published on.
subject: Subject,
/// Request payload.
payload: Bytes,
/// Reply subject; the handler splits it at the last `.` into a prefix and a token.
respond: Subject,
/// Optional request headers.
headers: Option<HeaderMap>,
/// Channel on which the response message is delivered.
sender: oneshot::Sender<Message>,
},
/// Register a new subscription.
Subscribe {
/// Id of the new subscription.
sid: u64,
/// Subject to subscribe to.
subject: Subject,
/// Optional queue group.
queue_group: Option<String>,
/// Channel on which messages for this subscription are delivered.
sender: mpsc::Sender<Message>,
},
/// Remove a subscription, either immediately or after `max` delivered messages.
Unsubscribe {
/// Id of the subscription.
sid: u64,
/// `None` removes the subscription right away; `Some(n)` removes it once
/// `n` messages have been delivered.
max: Option<u64>,
},
/// Request a notification once the connection has been flushed.
Flush {
/// Notified after a successful flush.
observer: oneshot::Sender<()>,
},
/// Drain one subscription (`Some(sid)`) or the whole client (`None`).
Drain {
/// Subscription to drain, or `None` for every subscription.
sid: Option<u64>,
},
/// Force the handler to drop the current connection and reconnect.
Reconnect,
}
/// An operation the client queues for writing to the server through
/// `Connection::enqueue_write_op`.
#[derive(Debug)]
pub(crate) enum ClientOp {
/// Publish a message.
Publish {
/// Subject to publish on.
subject: Subject,
/// Message payload.
payload: Bytes,
/// Optional reply subject.
respond: Option<Subject>,
/// Optional message headers.
headers: Option<HeaderMap>,
},
/// Subscribe to a subject.
Subscribe {
/// Subscription id.
sid: u64,
/// Subject to subscribe to.
subject: Subject,
/// Optional queue group.
queue_group: Option<String>,
},
/// Unsubscribe, optionally after a maximum number of messages.
Unsubscribe {
/// Subscription id.
sid: u64,
/// Optional message limit.
max: Option<u64>,
},
/// Client PING.
Ping,
/// Client PONG, sent in response to a server PING.
Pong,
/// Connection handshake information.
Connect(ConnectInfo),
}
/// Handler-side state of a single subscription.
#[derive(Debug)]
struct Subscription {
// Subject of the subscription; replayed to the server after a reconnect.
subject: Subject,
// Channel on which incoming messages are delivered to the subscriber.
sender: mpsc::Sender<Message>,
// Queue group of the subscription; replayed to the server after a reconnect.
queue_group: Option<String>,
// Number of messages successfully delivered so far.
delivered: u64,
// When set, the subscription is removed once `delivered` reaches this value.
max: Option<u64>,
}
/// State of the shared subscription that routes responses to pending requests.
#[derive(Debug)]
struct Multiplexer {
// Wildcard subject (`prefix` followed by `*`) the multiplexer is subscribed to.
subject: Subject,
// Common prefix of all reply subjects; what remains after stripping it is the request token.
prefix: Subject,
// Pending requests keyed by token; an entry is removed when its response arrives.
senders: HashMap<String, oneshot::Sender<Message>>,
}
/// Owns a server connection and processes server operations and client commands for it.
pub(crate) struct ConnectionHandler {
// The currently active connection; replaced after every successful reconnect.
connection: Connection,
// Used to (re)connect and to publish events, state changes and statistics.
connector: Connector,
// Active subscriptions keyed by subscription id.
subscriptions: HashMap<u64, Subscription>,
// Created lazily by the first `Command::Request`.
multiplexer: Option<Multiplexer>,
// Number of PINGs sent by the ping timer that have not been answered yet.
pending_pings: usize,
// Receives the `ServerInfo` obtained on reconnect.
info_sender: tokio::sync::watch::Sender<ServerInfo>,
// Timer that triggers PINGs; reset whenever a server op or a command is handled.
ping_interval: Interval,
// Set by `Command::Reconnect`; consumed at the end of a poll.
should_reconnect: bool,
// Observers waiting for the next successful flush.
flush_observers: Vec<oneshot::Sender<()>>,
// Set when the whole client is being drained.
is_draining: bool,
// Ids of subscriptions being drained; they are removed on the next poll.
drain_pings: VecDeque<u64>,
}
impl ConnectionHandler {
pub(crate) fn new(
connection: Connection,
connector: Connector,
info_sender: tokio::sync::watch::Sender<ServerInfo>,
ping_period: Duration,
) -> ConnectionHandler {
let mut ping_interval = interval(ping_period);
ping_interval.set_missed_tick_behavior(MissedTickBehavior::Delay);
ConnectionHandler {
connection,
connector,
subscriptions: HashMap::new(),
multiplexer: None,
pending_pings: 0,
info_sender,
ping_interval,
should_reconnect: false,
flush_observers: Vec::new(),
is_draining: false,
drain_pings: VecDeque::new(),
}
}
/// Runs the handler's event loop until the client is closed or reconnecting fails.
///
/// Each iteration awaits a `ProcessFut`, which services the ping timer, incoming
/// server operations, incoming commands from `receiver`, and pending writes and
/// flushes for the current connection. When that future finishes, the reason it
/// finished decides whether the handler reconnects and loops again or stops.
pub(crate) async fn process<'a>(&'a mut self, receiver: &'a mut mpsc::Receiver<Command>) {
// Future that drives one connection until something forces an exit.
struct ProcessFut<'a> {
handler: &'a mut ConnectionHandler,
receiver: &'a mut mpsc::Receiver<Command>,
// Scratch buffer for commands received in bulk; always empty between polls.
recv_buf: &'a mut Vec<Command>,
}
// Why `ProcessFut` completed.
enum ExitReason {
// The connection was lost, optionally with the I/O error that caused it.
Disconnected(Option<io::Error>),
// A `Command::Reconnect` was received.
ReconnectRequested,
// The command channel was closed or a full drain was requested.
Closed,
}
impl ProcessFut<'_> {
// Maximum number of commands pulled from the channel per `poll_recv_many` call.
const RECV_CHUNK_SIZE: usize = 16;
// Called on every ping timer tick. Returns `Ready` only when too many pings
// are outstanding; `Pending` here means "ping enqueued, keep going".
#[cold]
fn ping(&mut self) -> Poll<ExitReason> {
self.handler.pending_pings += 1;
if self.handler.pending_pings > MAX_PENDING_PINGS {
debug!(
pending_pings = self.handler.pending_pings,
max_pings = MAX_PENDING_PINGS,
"disconnecting due to too many pending pings"
);
Poll::Ready(ExitReason::Disconnected(None))
} else {
self.handler.connection.enqueue_write_op(&ClientOp::Ping);
Poll::Pending
}
}
}
impl Future for ProcessFut<'_> {
type Output = ExitReason;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Handle every tick that is already due. The loop only ends once `poll_tick`
// returns `Pending` (presumably so that the waker is registered for the next tick).
while self.handler.ping_interval.poll_tick(cx).is_ready() {
if let Poll::Ready(exit) = self.ping() {
return Poll::Ready(exit);
}
}
// Process every server operation that is currently available.
loop {
match self.handler.connection.poll_read_op(cx) {
Poll::Pending => break,
Poll::Ready(Ok(Some(server_op))) => {
self.handler.handle_server_op(server_op);
}
// `None` is treated as a disconnect without an error.
Poll::Ready(Ok(None)) => {
return Poll::Ready(ExitReason::Disconnected(None))
}
Poll::Ready(Err(err)) => {
return Poll::Ready(ExitReason::Disconnected(Some(err)))
}
}
}
// Remove subscriptions that a previous `Command::Drain` queued for draining.
while let Some(sid) = self.handler.drain_pings.pop_front() {
self.handler.subscriptions.remove(&sid);
}
// The whole client is draining: stop instead of handling further commands.
if self.handler.is_draining {
return Poll::Ready(ExitReason::Closed);
}
// Alternate between receiving commands and writing. `made_progress` starts as
// `true` so that the first round always reaches `poll_write`.
let mut made_progress = true;
loop {
// Pull commands only while the write buffer has room for their output.
while !self.handler.connection.is_write_buf_full() {
debug_assert!(self.recv_buf.is_empty());
// Split the borrow so the receiver, buffer and handler can be used together.
let Self {
recv_buf,
handler,
receiver,
} = &mut *self;
match receiver.poll_recv_many(cx, recv_buf, Self::RECV_CHUNK_SIZE) {
Poll::Pending => break,
Poll::Ready(1..) => {
made_progress = true;
for cmd in recv_buf.drain(..) {
handler.handle_command(cmd);
}
}
// Zero commands received: treated as the command channel being closed.
Poll::Ready(_) => return Poll::Ready(ExitReason::Closed),
}
}
// Reads and resets the flag; stop if no command was handled since the last write.
if !mem::take(&mut made_progress) {
break;
}
match self.handler.connection.poll_write(cx) {
Poll::Pending => {
break;
}
// Write finished; go around again to look for more commands.
Poll::Ready(Ok(())) => {
continue;
}
Poll::Ready(Err(err)) => {
return Poll::Ready(ExitReason::Disconnected(Some(err)))
}
}
}
// Flush when the connection asks for it, or when flush observers are waiting.
if let (ShouldFlush::Yes, _) | (ShouldFlush::No, false) = (
self.handler.connection.should_flush(),
self.handler.flush_observers.is_empty(),
) {
match self.handler.connection.poll_flush(cx) {
Poll::Pending => {}
Poll::Ready(Ok(())) => {
// Notify everyone waiting for this flush; a dropped observer is ignored.
for observer in self.handler.flush_observers.drain(..) {
let _ = observer.send(());
}
}
Poll::Ready(Err(err)) => {
return Poll::Ready(ExitReason::Disconnected(Some(err)))
}
}
}
// Reads and clears the reconnect request set by `Command::Reconnect`.
if mem::take(&mut self.handler.should_reconnect) {
return Poll::Ready(ExitReason::ReconnectRequested);
}
Poll::Pending
}
}
// Reused across iterations so the buffer is allocated only once.
let mut recv_buf = Vec::with_capacity(ProcessFut::RECV_CHUNK_SIZE);
loop {
let process = ProcessFut {
handler: self,
receiver,
recv_buf: &mut recv_buf,
};
match process.await {
ExitReason::Disconnected(err) => {
debug!(error = ?err, "disconnected");
// Stop for good when reconnecting fails.
if self.handle_disconnect().await.is_err() {
break;
};
debug!("reconnected");
}
ExitReason::Closed => {
// Best effort: the event is dropped if the event channel is full or closed.
self.connector.events_tx.try_send(Event::Closed).ok();
break;
}
ExitReason::ReconnectRequested => {
debug!("reconnect requested");
// Shut the current stream down (errors ignored) before reconnecting.
self.connection.stream.shutdown().await.ok();
if self.handle_disconnect().await.is_err() {
break;
};
}
}
}
}
fn handle_server_op(&mut self, server_op: ServerOp) {
self.ping_interval.reset();
match server_op {
ServerOp::Ping => {
debug!("received PING");
self.connection.enqueue_write_op(&ClientOp::Pong);
}
ServerOp::Pong => {
debug!("received PONG");
self.pending_pings = self.pending_pings.saturating_sub(1);
}
ServerOp::Error(error) => {
debug!("received ERROR: {:?}", error);
self.connector
.events_tx
.try_send(Event::ServerError(error))
.ok();
}
ServerOp::Message {
sid,
subject,
reply,
payload,
headers,
status,
description,
length,
} => {
debug!("received MESSAGE: sid={}, subject={}", sid, subject);
self.connector
.connect_stats
.in_messages
.add(1, Ordering::Relaxed);
if let Some(subscription) = self.subscriptions.get_mut(&sid) {
let message: Message = Message {
subject,
reply,
payload,
headers,
status,
description,
length,
};
match subscription.sender.try_send(message) {
Ok(_) => {
subscription.delivered += 1;
if let Some(max) = subscription.max {
if subscription.delivered.ge(&max) {
debug!("max messages reached for subscription {}", sid);
self.subscriptions.remove(&sid);
}
}
}
Err(mpsc::error::TrySendError::Full(_)) => {
debug!("slow consumer detected for subscription {}", sid);
self.connector
.events_tx
.try_send(Event::SlowConsumer(sid))
.ok();
}
Err(mpsc::error::TrySendError::Closed(_)) => {
debug!("subscription {} channel closed", sid);
self.subscriptions.remove(&sid);
self.connection
.enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
}
}
} else if sid == MULTIPLEXER_SID {
debug!("received message for multiplexer");
if let Some(multiplexer) = self.multiplexer.as_mut() {
let maybe_token =
subject.strip_prefix(multiplexer.prefix.as_ref()).to_owned();
if let Some(token) = maybe_token {
if let Some(sender) = multiplexer.senders.remove(token) {
debug!("forwarding message to request with token {}", token);
let message = Message {
subject,
reply,
payload,
headers,
status,
description,
length,
};
let _ = sender.send(message);
}
}
}
}
}
ServerOp::Info(info) => {
debug!("received INFO: server_id={}", info.server_id);
if info.lame_duck_mode {
debug!("server in lame duck mode");
self.connector.events_tx.try_send(Event::LameDuckMode).ok();
}
}
_ => {
}
}
}
/// Executes a single command received from a client handle.
///
/// Every command resets the ping timer. Commands that produce protocol
/// traffic only enqueue it on the connection; the actual write happens in
/// the poll loop of `process`.
fn handle_command(&mut self, command: Command) {
    self.ping_interval.reset();

    match command {
        Command::Publish(OutboundMessage {
            subject,
            payload,
            reply: respond,
            headers,
        }) => {
            self.connector
                .connect_stats
                .out_messages
                .add(1, Ordering::Relaxed);

            let header_len = headers.as_ref().map_or(0, |headers| headers.len());
            let reply_len = respond.as_ref().map_or(0, |reply| reply.len());
            let out_bytes = payload.len() + reply_len + subject.len() + header_len;
            self.connector
                .connect_stats
                .out_bytes
                .add(out_bytes as u64, Ordering::Relaxed);

            self.connection.enqueue_write_op(&ClientOp::Publish {
                subject,
                payload,
                respond,
                headers,
            });
        }
        Command::Request {
            subject,
            payload,
            respond,
            headers,
            sender,
        } => {
            let (prefix, token) = respond.rsplit_once('.').expect("malformed request subject");

            // The first request creates the multiplexer and its wildcard subscription.
            let multiplexer = if let Some(existing) = self.multiplexer.as_mut() {
                existing
            } else {
                let prefix = Subject::from(format!("{}.{}.", prefix, id_generator::next()));
                let wildcard = Subject::from(format!("{prefix}*"));

                self.connection.enqueue_write_op(&ClientOp::Subscribe {
                    sid: MULTIPLEXER_SID,
                    subject: wildcard.clone(),
                    queue_group: None,
                });

                self.multiplexer.insert(Multiplexer {
                    subject: wildcard,
                    prefix,
                    senders: HashMap::new(),
                })
            };

            self.connector
                .connect_stats
                .out_messages
                .add(1, Ordering::Relaxed);

            // Remember who is waiting for the response carrying this token.
            multiplexer.senders.insert(token.to_owned(), sender);
            let reply_to: Subject = format!("{}{}", multiplexer.prefix, token).into();

            self.connection.enqueue_write_op(&ClientOp::Publish {
                subject,
                payload,
                respond: Some(reply_to),
                headers,
            });
        }
        Command::Subscribe {
            sid,
            subject,
            queue_group,
            sender,
        } => {
            // Keep a copy of subject and queue group for resubscribing after a reconnect.
            self.subscriptions.insert(
                sid,
                Subscription {
                    subject: subject.clone(),
                    sender,
                    queue_group: queue_group.clone(),
                    delivered: 0,
                    max: None,
                },
            );
            self.connection.enqueue_write_op(&ClientOp::Subscribe {
                sid,
                subject,
                queue_group,
            });
        }
        Command::Unsubscribe { sid, max } => {
            // Unknown subscription ids are ignored entirely.
            if let Some(subscription) = self.subscriptions.get_mut(&sid) {
                subscription.max = max;
                // Without a limit, or with a limit that is already met, remove right away.
                let finished = max.map_or(true, |limit| subscription.delivered >= limit);
                if finished {
                    self.subscriptions.remove(&sid);
                }

                self.connection
                    .enqueue_write_op(&ClientOp::Unsubscribe { sid, max });
            }
        }
        Command::Drain { sid } => {
            // Queue the subscription for removal on the next poll and unsubscribe it.
            let mut start_drain = |sid: u64| {
                self.drain_pings.push_back(sid);
                self.connection
                    .enqueue_write_op(&ClientOp::Unsubscribe { sid, max: None });
            };

            match sid {
                Some(sid) => {
                    if self.subscriptions.contains_key(&sid) {
                        start_drain(sid);
                    }
                }
                None => {
                    // No id given: drain the whole client.
                    self.connector.events_tx.try_send(Event::Draining).ok();
                    self.is_draining = true;
                    for &sid in self.subscriptions.keys() {
                        start_drain(sid);
                    }
                }
            }

            self.connection.enqueue_write_op(&ClientOp::Ping);
        }
        Command::Flush { observer } => self.flush_observers.push(observer),
        Command::Reconnect => self.should_reconnect = true,
    }
}
/// Handles a lost connection: clears the outstanding ping count, announces
/// the disconnect as an event and as a state change, then reconnects.
///
/// Returns the error from `handle_reconnect` when reconnecting fails.
async fn handle_disconnect(&mut self) -> Result<(), ConnectError> {
    self.pending_pings = 0;

    // Both notifications are best effort; nobody listening is not an error.
    let _ = self.connector.events_tx.try_send(Event::Disconnected);
    let _ = self.connector.state_tx.send(State::Disconnected);

    self.handle_reconnect().await
}
async fn handle_reconnect(&mut self) -> Result<(), ConnectError> {
let (info, connection) = self.connector.connect().await?;
self.connection = connection;
let _ = self.info_sender.send(info);
self.subscriptions
.retain(|_, subscription| !subscription.sender.is_closed());
for (sid, subscription) in &self.subscriptions {
self.connection.enqueue_write_op(&ClientOp::Subscribe {
sid: *sid,
subject: subscription.subject.to_owned(),
queue_group: subscription.queue_group.to_owned(),
});
}
if let Some(multiplexer) = &self.multiplexer {
self.connection.enqueue_write_op(&ClientOp::Subscribe {
sid: MULTIPLEXER_SID,
subject: multiplexer.subject.to_owned(),
queue_group: None,
});
}
Ok(())
}
}
/// Connects to the given server address(es) using `options` and returns a `Client`.
///
/// Two background tasks are spawned: one logs every event and forwards it to
/// the optional event callback, the other owns the connection and runs
/// `ConnectionHandler::process`.
///
/// When `retry_on_initial_connect` is disabled the first connection is
/// established before this function returns. Otherwise the client is returned
/// immediately and the connection is established by the background task.
///
/// # Errors
///
/// Returns `ConnectErrorKind::ServerParse` when the connector cannot be built
/// from `addrs`, and the error from the initial connection attempt when
/// `retry_on_initial_connect` is disabled and that attempt fails.
pub async fn connect_with_options<A: ToServerAddrs>(
addrs: A,
options: ConnectOptions,
) -> Result<Client, ConnectError> {
let ping_period = options.ping_interval;
// Events are buffered in a channel of 128 entries and consumed by the task below.
let (events_tx, mut events_rx) = mpsc::channel(128);
let (state_tx, state_rx) = tokio::sync::watch::channel(State::Pending);
// Shared with the connector; starts at 1 MiB (presumably updated from the server's INFO).
let max_payload = Arc::new(AtomicUsize::new(1024 * 1024));
let statistics = Arc::new(Statistics::default());
let mut connector = Connector::new(
addrs,
ConnectorOptions {
tls_required: options.tls_required,
certificates: options.certificates,
client_key: options.client_key,
client_cert: options.client_cert,
tls_client_config: options.tls_client_config,
tls_first: options.tls_first,
auth: options.auth,
no_echo: options.no_echo,
connection_timeout: options.connection_timeout,
name: options.name,
ignore_discovered_servers: options.ignore_discovered_servers,
retain_servers_order: options.retain_servers_order,
read_buffer_capacity: options.read_buffer_capacity,
reconnect_delay_callback: options.reconnect_delay_callback,
auth_callback: options.auth_callback,
max_reconnects: options.max_reconnects,
local_address: options.local_address,
},
events_tx,
state_tx,
max_payload.clone(),
statistics.clone(),
)
.map_err(|err| ConnectError::with_source(ConnectErrorKind::ServerParse, err))?;
// Placeholder info, replaced below when the connection is established eagerly.
let mut info: ServerInfo = Default::default();
let mut connection = None;
if !options.retry_on_initial_connect {
debug!("retry on initial connect failure is disabled");
// Connect now so that a failure is reported to the caller.
let (info_ok, connection_ok) = connector.try_connect().await?;
connection = Some(connection_ok);
info = info_ok;
}
let (info_sender, info_watcher) = tokio::sync::watch::channel(info.clone());
// Command channel between the client handles and the connection handler.
let (sender, mut receiver) = mpsc::channel(options.sender_capacity);
let client = Client::new(
info_watcher,
state_rx,
sender,
options.subscription_capacity,
options.inbox_prefix,
options.request_timeout,
max_payload,
statistics,
options.skip_subject_validation,
);
// Event task: runs until every event sender has been dropped.
task::spawn(async move {
while let Some(event) = events_rx.recv().await {
tracing::info!("event: {}", event);
if let Some(event_callback) = &options.event_callback {
event_callback.call(event).await;
}
}
});
// Connection task: connects if that has not happened yet, then runs the handler.
task::spawn(async move {
if connection.is_none() && options.retry_on_initial_connect {
let (info, connection_ok) = match connector.connect().await {
Ok((info, connection)) => (info, connection),
Err(err) => {
error!("connection closed: {}", err);
return;
}
};
info_sender.send(info).ok();
connection = Some(connection_ok);
}
// Set either by the eager connect above or by the branch just before this line.
let connection = connection.unwrap();
let mut connection_handler =
ConnectionHandler::new(connection, connector, info_sender, ping_period);
connection_handler.process(&mut receiver).await
});
Ok(client)
}
/// Events emitted by the client; they are logged and passed to the optional
/// event callback configured in `ConnectOptions`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Event {
/// The client connected to a server.
Connected,
/// The connection to the server was lost.
Disconnected,
/// The server announced lame duck mode.
LameDuckMode,
/// The whole client started draining.
Draining,
/// The connection handler stopped because the client was closed or drained.
Closed,
/// The subscription with the given id could not keep up and a message was dropped.
SlowConsumer(u64),
/// The server reported an error.
ServerError(ServerError),
/// The client encountered an error.
ClientError(ClientError),
}
impl fmt::Display for Event {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Event::Connected => write!(f, "connected"),
Event::Disconnected => write!(f, "disconnected"),
Event::LameDuckMode => write!(f, "lame duck mode detected"),
Event::Draining => write!(f, "draining"),
Event::Closed => write!(f, "closed"),
Event::SlowConsumer(sid) => write!(f, "slow consumers for subscription {sid}"),
Event::ServerError(err) => write!(f, "server error: {err}"),
Event::ClientError(err) => write!(f, "client error: {err}"),
}
}
}
/// Connects to the given server address(es) with default `ConnectOptions`.
///
/// This is a shorthand for calling `connect_with_options` with
/// `ConnectOptions::default()`; errors are those of `connect_with_options`.
pub async fn connect<A: ToServerAddrs>(addrs: A) -> Result<Client, ConnectError> {
    let options = ConnectOptions::default();
    connect_with_options(addrs, options).await
}
/// The kind of failure that prevented a connection from being established.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ConnectErrorKind {
    /// The server address or server list could not be parsed.
    ServerParse,
    /// DNS related failure.
    Dns,
    /// Signing the nonce failed.
    Authentication,
    /// The server rejected the connection with an authorization violation.
    AuthorizationViolation,
    /// The operation timed out.
    TimedOut,
    /// TLS related failure.
    Tls,
    /// Other I/O failure.
    Io,
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}

/// Short, fixed description of each error kind.
impl Display for ConnectErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = match self {
            Self::ServerParse => "failed to parse server or server list",
            Self::Dns => "DNS error",
            Self::Authentication => "failed signing nonce",
            Self::AuthorizationViolation => "authorization violation",
            Self::TimedOut => "timed out",
            Self::Tls => "TLS error",
            Self::Io => "IO error",
            Self::MaxReconnects => "reached maximum number of reconnects",
        };
        f.write_str(description)
    }
}
/// Error returned when connecting to a server fails; its kind is a `ConnectErrorKind`.
pub type ConnectError = error::Error<ConnectErrorKind>;

/// Wraps an I/O error as a connect error of kind `ConnectErrorKind::Io`,
/// keeping the I/O error as the source.
impl From<io::Error> for ConnectError {
    fn from(err: io::Error) -> Self {
        let kind = ConnectErrorKind::Io;
        Self::with_source(kind, err)
    }
}
/// Receiving side of a subscription; yields messages through its `Stream` implementation.
///
/// Dropping a subscriber closes its receiver and asks the connection handler to unsubscribe.
#[derive(Debug)]
pub struct Subscriber {
// Subscription id used in commands sent to the connection handler.
sid: u64,
// Messages delivered for this subscription.
receiver: mpsc::Receiver<Message>,
// Command channel to the connection handler.
sender: mpsc::Sender<Command>,
}
impl Subscriber {
    /// Creates a subscriber for subscription `sid` from the command channel to
    /// the connection handler and the channel messages arrive on.
    fn new(
        sid: u64,
        sender: mpsc::Sender<Command>,
        receiver: mpsc::Receiver<Message>,
    ) -> Subscriber {
        Self {
            sid,
            receiver,
            sender,
        }
    }

    /// Unsubscribes immediately and closes the message receiver.
    ///
    /// # Errors
    ///
    /// Returns `UnsubscribeError` when the command cannot be sent to the
    /// connection handler; in that case the receiver is left open.
    pub async fn unsubscribe(&mut self) -> Result<(), UnsubscribeError> {
        let command = Command::Unsubscribe {
            sid: self.sid,
            max: None,
        };
        self.sender.send(command).await?;
        self.receiver.close();
        Ok(())
    }

    /// Asks the connection handler to remove the subscription once
    /// `unsub_after` messages have been delivered.
    ///
    /// # Errors
    ///
    /// Returns `UnsubscribeError` when the command cannot be sent to the
    /// connection handler.
    pub async fn unsubscribe_after(&mut self, unsub_after: u64) -> Result<(), UnsubscribeError> {
        let command = Command::Unsubscribe {
            sid: self.sid,
            max: Some(unsub_after),
        };
        self.sender.send(command).await?;
        Ok(())
    }

    /// Asks the connection handler to drain this subscription.
    ///
    /// # Errors
    ///
    /// Returns `UnsubscribeError` when the command cannot be sent to the
    /// connection handler.
    pub async fn drain(&mut self) -> Result<(), UnsubscribeError> {
        let command = Command::Drain {
            sid: Some(self.sid),
        };
        self.sender.send(command).await?;
        Ok(())
    }
}
/// Error returned when an unsubscribe or drain command could not be sent to
/// the connection handler.
///
/// The wrapped string holds the text of the underlying send error; it is not
/// part of the `Display` output.
#[derive(Error, Debug, PartialEq)]
#[error("failed to send unsubscribe")]
pub struct UnsubscribeError(String);

/// Converts a failed command send into an `UnsubscribeError`, keeping only
/// the error's text.
impl From<tokio::sync::mpsc::error::SendError<Command>> for UnsubscribeError {
    fn from(err: tokio::sync::mpsc::error::SendError<Command>) -> Self {
        let reason = err.to_string();
        Self(reason)
    }
}
impl Drop for Subscriber {
fn drop(&mut self) {
self.receiver.close();
tokio::spawn({
let sender = self.sender.clone();
let sid = self.sid;
async move {
sender
.send(Command::Unsubscribe { sid, max: None })
.await
.ok();
}
});
}
}
/// Yields the messages delivered to this subscription, in the order they are
/// received from the underlying channel.
impl Stream for Subscriber {
    type Item = Message;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        this.receiver.poll_recv(cx)
    }
}
/// Either a client side or a server side error.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum CallbackError {
    /// Error originating in the client.
    Client(ClientError),
    /// Error reported by the server.
    Server(ServerError),
}

/// Displays exactly what the wrapped error displays.
impl std::fmt::Display for CallbackError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner: &dyn std::fmt::Display = match self {
            Self::Client(error) => error,
            Self::Server(error) => error,
        };
        write!(f, "{inner}")
    }
}

/// Wraps a server error as `CallbackError::Server`.
impl From<ServerError> for CallbackError {
    fn from(server_error: ServerError) -> Self {
        Self::Server(server_error)
    }
}

/// Wraps a client error as `CallbackError::Client`.
impl From<ClientError> for CallbackError {
    fn from(client_error: ClientError) -> Self {
        Self::Client(client_error)
    }
}
/// An error reported by the server.
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum ServerError {
/// The server reported an authorization violation.
AuthorizationViolation,
/// The subscription with the given id is a slow consumer.
SlowConsumer(u64),
/// Any other error, carrying the text reported by the server.
Other(String),
}
/// An error raised by the client itself.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ClientError {
    /// Any other error, described by the contained text.
    Other(String),
    /// The maximum number of reconnect attempts was reached.
    MaxReconnects,
}

/// Formats the error with a `nats: ` prefix.
impl std::fmt::Display for ClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MaxReconnects => f.write_str("nats: max reconnects reached"),
            Self::Other(error) => write!(f, "nats: {error}"),
        }
    }
}
impl ServerError {
fn new(error: String) -> ServerError {
match error.to_lowercase().as_str() {
"authorization violation" => ServerError::AuthorizationViolation,
_ => ServerError::Other(error),
}
}
}
/// Formats the error with a `nats: ` prefix.
impl std::fmt::Display for ServerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Other(error) => write!(f, "nats: {error}"),
            Self::SlowConsumer(sid) => {
                write!(f, "nats: subscription {sid} is a slow consumer")
            }
            Self::AuthorizationViolation => f.write_str("nats: authorization violation"),
        }
    }
}
/// Connection parameters serialized and sent to the server as part of
/// `ClientOp::Connect`.
///
/// NOTE(review): field meanings below follow the field names; confirm details
/// against the NATS protocol documentation for CONNECT.
#[derive(Clone, Debug, Serialize)]
pub struct ConnectInfo {
/// Serialized as `verbose`.
pub verbose: bool,
/// Serialized as `pedantic`.
pub pedantic: bool,
/// User JWT, serialized under the key `jwt`.
#[serde(rename = "jwt")]
pub user_jwt: Option<String>,
/// Optional nkey.
pub nkey: Option<String>,
/// Optional signature, serialized under the key `sig`.
#[serde(rename = "sig")]
pub signature: Option<String>,
/// Optional client name.
pub name: Option<String>,
/// Serialized as `echo`.
pub echo: bool,
/// Implementation language of the client.
pub lang: String,
/// Version of the client.
pub version: String,
/// Protocol level announced by the client.
pub protocol: Protocol,
/// Whether the client requires TLS.
pub tls_required: bool,
/// Optional user name.
pub user: Option<String>,
/// Optional password.
pub pass: Option<String>,
/// Optional authentication token.
pub auth_token: Option<String>,
/// Serialized as `headers`.
pub headers: bool,
/// Serialized as `no_responders`.
pub no_responders: bool,
}
/// Protocol level announced in `ConnectInfo`; serialized as its `u8` value.
#[derive(Serialize_repr, Deserialize_repr, PartialEq, Eq, Debug, Clone, Copy)]
#[repr(u8)]
pub enum Protocol {
/// Protocol level `0`.
Original = 0,
/// Protocol level `1`.
/// NOTE(review): presumably signals support for dynamic server info updates; confirm.
Dynamic = 1,
}
/// Address of a NATS server, stored as a URL whose scheme is one of `nats`,
/// `tls`, `ws` or `wss`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddr(Url);
impl FromStr for ServerAddr {
type Err = io::Error;
fn from_str(input: &str) -> Result<Self, Self::Err> {
let url: Url = if input.contains("://") {
input.parse()
} else {
format!("nats://{input}").parse()
}
.map_err(|e| {
io::Error::new(
ErrorKind::InvalidInput,
format!("NATS server URL is invalid: {e}"),
)
})?;
Self::from_url(url)
}
}
impl ServerAddr {
pub fn from_url(url: Url) -> io::Result<Self> {
if url.scheme() != "nats"
&& url.scheme() != "tls"
&& url.scheme() != "ws"
&& url.scheme() != "wss"
{
return Err(std::io::Error::new(
ErrorKind::InvalidInput,
format!("invalid scheme for NATS server URL: {}", url.scheme()),
));
}
Ok(Self(url))
}
pub fn into_inner(self) -> Url {
self.0
}
pub fn tls_required(&self) -> bool {
self.0.scheme() == "tls"
}
pub fn has_user_pass(&self) -> bool {
self.0.username() != ""
}
pub fn scheme(&self) -> &str {
self.0.scheme()
}
pub fn host(&self) -> &str {
match self.0.host() {
Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
Some(Host::Ipv6 { .. }) => {
let host = self.0.host_str().unwrap();
&host[1..host.len() - 1]
}
None => "",
}
}
pub fn is_websocket(&self) -> bool {
self.0.scheme() == "ws" || self.0.scheme() == "wss"
}
pub fn port(&self) -> u16 {
self.0.port_or_known_default().unwrap_or(4222)
}
pub fn as_url_str(&self) -> &str {
self.0.as_str()
}
pub fn username(&self) -> Option<&str> {
let user = self.0.username();
if user.is_empty() {
None
} else {
Some(user)
}
}
pub fn password(&self) -> Option<&str> {
self.0.password()
}
pub async fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr> + '_> {
tokio::net::lookup_host((self.host(), self.port())).await
}
}
/// Conversion of a value into one or more server addresses.
pub trait ToServerAddrs {
/// Iterator over the resulting addresses.
type Iter: Iterator<Item = ServerAddr>;
/// Converts `self` into an iterator of server addresses.
///
/// Implementations that parse text return an error when an address is invalid.
fn to_server_addrs(&self) -> io::Result<Self::Iter>;
}
/// A single address converts to an iterator yielding a clone of itself.
impl ToServerAddrs for ServerAddr {
    type Iter = option::IntoIter<ServerAddr>;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        let only = self.clone();
        Ok(Some(only).into_iter())
    }
}
/// A string slice is parsed as a single server address.
impl ToServerAddrs for str {
    type Iter = option::IntoIter<ServerAddr>;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        let addr: ServerAddr = self.parse()?;
        Ok(Some(addr).into_iter())
    }
}
/// An owned string is parsed exactly like the string slice it holds.
impl ToServerAddrs for String {
    type Iter = option::IntoIter<ServerAddr>;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        self.as_str().to_server_addrs()
    }
}
/// Every element of the slice is parsed as a server address; the first
/// invalid element makes the whole conversion fail.
impl<T: AsRef<str>> ToServerAddrs for [T] {
    type Iter = std::vec::IntoIter<ServerAddr>;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        let parsed = self
            .iter()
            .map(|addr| addr.as_ref().parse::<ServerAddr>())
            .collect::<io::Result<Vec<ServerAddr>>>()?;
        Ok(parsed.into_iter())
    }
}
/// A vector of strings is converted like the slice of its elements.
impl<T: AsRef<str>> ToServerAddrs for Vec<T> {
    type Iter = std::vec::IntoIter<ServerAddr>;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        let elements: &[T] = self;
        elements.to_server_addrs()
    }
}
/// A slice of addresses converts to an iterator over clones of its elements.
impl<'a> ToServerAddrs for &'a [ServerAddr] {
    type Iter = iter::Cloned<slice::Iter<'a, ServerAddr>>;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        let addrs: &'a [ServerAddr] = *self;
        Ok(addrs.iter().cloned())
    }
}
/// A vector of addresses converts to an iterator over a copy of the vector.
impl ToServerAddrs for Vec<ServerAddr> {
    type Iter = std::vec::IntoIter<ServerAddr>;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        let copy = self.to_vec();
        Ok(copy.into_iter())
    }
}
/// A reference converts exactly like the value it points to.
impl<T: ToServerAddrs + ?Sized> ToServerAddrs for &T {
    type Iter = T::Iter;

    fn to_server_addrs(&self) -> io::Result<Self::Iter> {
        T::to_server_addrs(*self)
    }
}
/// Returns `true` when `subject` is non-empty and contains no space, carriage
/// return, line feed or tab.
pub(crate) fn is_valid_publish_subject<T: AsRef<str>>(subject: T) -> bool {
    let raw = subject.as_ref().as_bytes();
    let has_whitespace = || {
        memchr::memchr3(b' ', b'\r', b'\n', raw).is_some()
            || memchr::memchr(b'\t', raw).is_some()
    };
    !raw.is_empty() && !has_whitespace()
}
/// Returns `true` when `subject` is non-empty, neither starts nor ends with a
/// dot, has no two consecutive dots, and contains no space, carriage return,
/// line feed or tab.
pub(crate) fn is_valid_subject<T: AsRef<str>>(subject: T) -> bool {
    let raw = subject.as_ref().as_bytes();
    if raw.is_empty() {
        return false;
    }
    if raw.starts_with(b".") || raw.ends_with(b".") {
        return false;
    }
    // Two dots in a row would make an empty token.
    if memchr::memmem::find(raw, b"..").is_some() {
        return false;
    }
    let has_whitespace = memchr::memchr3(b' ', b'\r', b'\n', raw).is_some()
        || memchr::memchr(b'\t', raw).is_some();
    !has_whitespace
}
/// Returns `true` when `queue_group` is non-empty and contains no space,
/// carriage return, line feed or tab.
pub(crate) fn is_valid_queue_group(queue_group: &str) -> bool {
    let raw = queue_group.as_bytes();
    let has_whitespace = || {
        memchr::memchr3(b' ', b'\r', b'\n', raw).is_some()
            || memchr::memchr(b'\t', raw).is_some()
    };
    !raw.is_empty() && !has_whitespace()
}
// Generates `impl From<$origin> for $t`.
//
// The generated conversion inspects `err.kind()`: the `TimedOut` variant of
// `$origin_kind` becomes `$t::new` with the `TimedOut` variant of `$k` (the source
// error is dropped), and every other kind becomes `$t::with_source` with the
// `Other` variant of `$k` and the original error as source.
//
// Requirements visible from the expansion: `$origin` has a `kind()` method,
// `$origin_kind` and `$k` both have a `TimedOut` variant, `$k` has an `Other`
// variant, and `$t` provides `new` and `with_source`.
//
// The lint is allowed, presumably because the macro is only used when certain
// optional features are enabled.
#[allow(unused_macros)]
macro_rules! from_with_timeout {
($t:ty, $k:ty, $origin: ty, $origin_kind: ty) => {
impl From<$origin> for $t {
fn from(err: $origin) -> Self {
match err.kind() {
<$origin_kind>::TimedOut => Self::new(<$k>::TimedOut),
_ => Self::with_source(<$k>::Other, err),
}
}
}
};
}
#[allow(unused_imports)]
pub(crate) use from_with_timeout;
use crate::connection::ShouldFlush;
use crate::message::OutboundMessage;
#[cfg(test)]
mod tests {
    use super::*;

    // The brackets of an IPv6 host must not be part of `host()`.
    #[test]
    fn server_address_ipv6() {
        let address = ServerAddr::from_str("nats://[::]").unwrap();
        assert_eq!(address.host(), "::")
    }

    #[test]
    fn server_address_ipv4() {
        let address = ServerAddr::from_str("nats://127.0.0.1").unwrap();
        assert_eq!(address.host(), "127.0.0.1")
    }

    #[test]
    fn server_address_domain() {
        let address = ServerAddr::from_str("nats://example.com").unwrap();
        assert_eq!(address.host(), "example.com")
    }

    #[test]
    fn to_server_addrs_vec_str() {
        let vec = vec!["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_str() {
        let arr = ["nats://127.0.0.1", "nats://[::]"];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_vec_string() {
        let vec = vec!["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = vec.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    #[test]
    fn to_server_addrs_arr_string() {
        let arr = ["nats://127.0.0.1".to_string(), "nats://[::]".to_string()];
        let mut addrs_iter = arr.to_server_addrs().unwrap();
        assert_eq!(addrs_iter.next().unwrap().host(), "127.0.0.1");
        assert_eq!(addrs_iter.next().unwrap().host(), "::");
        assert_eq!(addrs_iter.next(), None);
    }

    // Every address of a group must resolve to the group's expected port,
    // whether the port is explicit or comes from the scheme default.
    #[test]
    fn to_server_ports_arr_string() {
        for (arr, expected_port) in [
            (
                [
                    "nats://127.0.0.1".to_string(),
                    "nats://[::]".to_string(),
                    "tls://127.0.0.1".to_string(),
                    "tls://[::]".to_string(),
                ],
                4222,
            ),
            (
                [
                    "ws://127.0.0.1:80".to_string(),
                    "ws://[::]:80".to_string(),
                    "ws://127.0.0.1".to_string(),
                    "ws://[::]".to_string(),
                ],
                80,
            ),
            (
                [
                    "wss://127.0.0.1".to_string(),
                    "wss://[::]".to_string(),
                    "wss://127.0.0.1:443".to_string(),
                    "wss://[::]:443".to_string(),
                ],
                443,
            ),
        ] {
            let mut checked = 0;
            // Check all addresses, not only the first one of each group.
            for addr in arr.to_server_addrs().unwrap() {
                assert_eq!(
                    addr.port(),
                    expected_port,
                    "unexpected port for {}",
                    addr.as_url_str()
                );
                checked += 1;
            }
            assert_eq!(checked, arr.len());
        }
    }
}