use async_trait::async_trait;
use std::borrow::Borrow;
use std::collections::HashMap;
use std::convert::TryFrom;
use std::io::{self, Error, ErrorKind};
use std::net::{SocketAddr, ToSocketAddrs};
use std::ops::DerefMut;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, BufReader, ReadBuf};
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use url::Url;
use crate::auth_utils;
use crate::proto::{self, ClientOp, ServerOp};
use crate::rustls::{ClientConfig, ServerName};
use crate::secure_wipe::SecureString;
use crate::tokio_rustls::client::TlsStream;
use crate::{connect::ConnectInfo, inject_io_failure, AuthStyle, Options, ServerInfo};
/// Keeps track of the known NATS servers and opens connections to them.
pub(crate) struct Connector {
    /// Shared client options (authentication, TLS, reconnect policy, ...).
    options: Arc<Options>,
    /// TLS client configuration built once by `load_tls_certs` and reused for
    /// every TLS connection.
    tls_config: Arc<ClientConfig>,
    /// Every known server mapped to the number of connection attempts made
    /// since its last successful connection.
    attempts: HashMap<ServerAddress, usize>,
}
fn load_tls_certs(tls_options: &Arc<Options>) -> io::Result<ClientConfig> {
let roots = match rustls_native_certs::load_native_certs() {
Ok(store) => store.into_iter().map(|c| c.0).collect(),
Err(_) => Vec::new(),
};
let mut root_certs = crate::rustls::RootCertStore::empty();
let (_added, _ignored) = root_certs.add_parsable_certificates(&roots);
for path in &tls_options.certificates {
let f = std::fs::File::open(path)?;
let mut f = std::io::BufReader::new(f);
let certs = rustls_pemfile::certs(&mut f)?;
let (_added, _ignored) = root_certs.add_parsable_certificates(&certs);
}
let tls_config = tls_options
.tls_client_config
.clone()
.with_safe_defaults()
.with_root_certificates(root_certs);
let tls_config =
if let (Some(cert), Some(key)) = (&tls_options.client_cert, &tls_options.client_key) {
tls_config
.with_single_cert(auth_utils::load_certs(cert)?, auth_utils::load_key(key)?)
.map_err(|err| {
io::Error::new(
io::ErrorKind::InvalidInput,
format!("invalid client certificate and key pair: {}", err),
)
})?
} else {
tls_config.with_no_client_auth()
};
Ok(tls_config)
}
impl Connector {
pub(crate) async fn new(
urls: Vec<ServerAddress>,
options: Arc<Options>,
) -> io::Result<Connector> {
let tls_options = options.clone();
let tls_config =
tokio::task::spawn_blocking(move || load_tls_certs(&tls_options)).await??;
let connector = Connector {
attempts: urls.into_iter().map(|url| (url, 0)).collect(),
options,
tls_config: Arc::new(tls_config),
};
Ok(connector)
}
pub(crate) fn add_server(&mut self, url: ServerAddress) {
self.attempts.insert(url, 0);
}
pub(crate) fn get_options(&self) -> Arc<Options> {
self.options.clone()
}
fn get_servers(&mut self) -> io::Result<Vec<ServerAddress>> {
let servers: Vec<_> = self
.attempts
.iter()
.filter_map(
|(server, reconnects)| match self.options.max_reconnects.as_ref() {
None => Some(server),
Some(max) if reconnects < max => Some(server),
Some(_) => None,
},
)
.cloned()
.collect();
if servers.is_empty() {
Err(Error::new(
ErrorKind::NotFound,
"no servers remaining to connect to",
))
} else {
Ok(servers)
}
}
pub(crate) async fn connect(
&mut self,
use_backoff: bool,
) -> io::Result<(ServerInfo, NatsStream)> {
let mut last_err = Error::new(ErrorKind::AddrNotAvailable, "no socket addresses");
loop {
let mut servers = self.get_servers()?;
fastrand::shuffle(&mut servers);
for server in &servers {
let reconnects = self.attempts.get_mut(server).unwrap();
let sleep_duration = self
.options
.reconnect_delay_callback
.call(*reconnects)
.await;
*reconnects += 1;
let lookup_res = server.socket_addrs();
let mut addrs = match lookup_res {
Ok(addrs) => addrs.collect::<Vec<_>>(),
Err(err) => {
last_err = err;
continue;
}
};
fastrand::shuffle(&mut addrs);
for addr in addrs {
if let Some(sleep_duration) = sleep_duration {
tokio::time::sleep(sleep_duration).await;
}
let res = self.connect_addr(addr, server).await;
let (server_info, stream) = match res {
Ok(val) => val,
Err(err) => {
last_err = err;
continue;
}
};
for url in &server_info.connect_urls {
self.add_server(url.parse()?);
}
*self.attempts.get_mut(server).unwrap() = 0;
return Ok((server_info, stream));
}
}
if !use_backoff {
return Err(last_err);
}
}
}
async fn connect_addr(
&self,
addr: SocketAddr,
server: &ServerAddress,
) -> io::Result<(ServerInfo, NatsStream)> {
inject_io_failure()?;
let mut stream = TcpStream::connect(addr).await?;
stream.set_nodelay(true)?;
let mut line = crate::SecureVec::with_capacity(1024);
while !line.ends_with(b"\r\n") {
let byte = &mut [0];
stream.read_exact(byte).await?;
line.push(byte[0]);
}
let server_info = match proto::decode(&line[..]).await? {
Some(ServerOp::Info(server_info)) => server_info,
Some(op) => {
return Err(Error::new(
ErrorKind::Other,
format!("expected INFO, received: {:?}", op),
));
}
None => {
return Err(Error::new(ErrorKind::UnexpectedEof, "connection closed"));
}
};
let tls_required =
self.options.tls_required || server.tls_required() || server_info.tls_required;
let mut stream = if tls_required {
inject_io_failure()?;
let dns_name = ServerName::try_from(server_info.host.as_str())
.or_else(|_| ServerName::try_from(server.host()))
.map_err(|_| {
io::Error::new(
io::ErrorKind::InvalidInput,
"cannot determine hostname for TLS connection",
)
})?;
NatsStream::new_tls(
tokio_rustls::TlsConnector::from(self.tls_config.clone())
.connect(dns_name, stream)
.await
.map_err(|e| io::Error::new(io::ErrorKind::NotConnected, e))?,
)
} else {
NatsStream::new_tcp(stream)
};
let mut connect_info = ConnectInfo {
tls_required,
name: self.options.name.clone().map(SecureString::from),
pedantic: false,
verbose: false,
lang: crate::LANG.to_string(),
version: crate::VERSION.to_string(),
protocol: crate::connect::Protocol::Dynamic,
user: None,
pass: None,
auth_token: None,
user_jwt: None,
nkey: None,
signature: None,
echo: !self.options.no_echo,
headers: true,
no_responders: true,
};
match &self.options.auth {
AuthStyle::NoAuth => {}
AuthStyle::UserPass(user, pass) => {
connect_info.user = Some(SecureString::from(user.to_string()));
connect_info.pass = Some(SecureString::from(pass.to_string()));
}
AuthStyle::Token(token) => {
connect_info.auth_token = Some(token.to_string().into());
}
AuthStyle::Credentials { jwt_cb, sig_cb } => {
let jwt = jwt_cb()?;
let sig = sig_cb(server_info.nonce.as_bytes())?;
connect_info.user_jwt = Some(jwt);
connect_info.signature = Some(sig);
}
AuthStyle::NKey { nkey_cb, sig_cb } => {
let nkey = nkey_cb()?;
let sig = sig_cb(server_info.nonce.as_bytes())?;
connect_info.nkey = Some(nkey);
connect_info.signature = Some(sig);
}
}
if server.has_user_pass() {
connect_info.user = server.username();
connect_info.pass = server.password();
}
proto::encode(&mut stream, ClientOp::Connect(&connect_info)).await?;
proto::encode(&mut stream, ClientOp::Ping).await?;
stream.flush().await?;
let mut reader = BufReader::new(stream.clone());
loop {
match proto::decode(&mut reader).await? {
Some(ServerOp::Pong) => break,
Some(ServerOp::Ping) => {
proto::encode(&mut stream, ClientOp::Pong).await?;
stream.flush().await?;
}
Some(op) => {
return Err(Error::new(
ErrorKind::InvalidData,
format!("unexpected line while connecting: {:?}", op),
));
}
None => {
return Err(Error::new(
ErrorKind::UnexpectedEof,
"connection closed while waiting for the first PONG",
));
}
}
}
Ok((server_info, stream))
}
}
/// A cheaply clonable handle to a plain-TCP or TLS connection.
///
/// Clones share the same underlying connection (for example, one clone wrapped
/// in a reader while the original is used for writing).
#[derive(Debug, Clone)]
pub(crate) struct NatsStream {
    /// The transport, shared between all clones.
    flavor: Arc<Flavor>,
}
/// The concrete transport behind a `NatsStream`, each wrapped in an async
/// mutex so that clones of the stream can share it.
#[derive(Debug)]
// Clippy flags the size difference between the variants; the enum is only
// ever stored behind an `Arc`, so boxing the larger variant would just add
// another indirection.
#[allow(clippy::large_enum_variant)]
enum Flavor {
    /// Plain TCP connection.
    Tcp(Mutex<TcpStream>),
    /// TCP connection wrapped in a rustls client session.
    Tls(Mutex<TlsStream<TcpStream>>),
}
impl NatsStream {
    /// Wraps a plain TCP connection, enabling `TCP_NODELAY` on a best-effort
    /// basis (a failure to set it is ignored).
    fn new_tcp(tcp: TcpStream) -> Self {
        let _ = tcp.set_nodelay(true);
        let flavor = Flavor::Tcp(Mutex::new(tcp));
        Self {
            flavor: Arc::new(flavor),
        }
    }

    /// Wraps an established TLS connection.
    fn new_tls(tls: TlsStream<TcpStream>) -> Self {
        let flavor = Flavor::Tls(Mutex::new(tls));
        Self {
            flavor: Arc::new(flavor),
        }
    }

    /// Shuts down the underlying TCP socket, ignoring any error.
    ///
    /// This only happens when `self` is the sole handle to the connection; if
    /// other clones are still alive a warning is logged and nothing is done.
    /// For TLS connections the raw TCP socket is shut down directly.
    pub(crate) async fn shutdown(&mut self) {
        let socket: &mut TcpStream = match Arc::get_mut(&mut self.flavor) {
            Some(Flavor::Tcp(tcp)) => tcp.get_mut(),
            Some(Flavor::Tls(tls)) => tls.get_mut().get_mut().0,
            None => {
                log::warn!("connection shutdown deferred");
                return;
            }
        };
        let _ = socket.shutdown().await;
    }
}
/// Reports contention on a stream mutex without losing the task's wakeup.
///
/// The stream mutexes in this file are only ever taken with `try_lock` inside
/// a single `poll_*` call, so a contended lock is released almost immediately.
/// Returning `Poll::Pending` without arranging a wakeup would leave the task
/// parked indefinitely; waking it first makes the executor poll it again.
fn poll_contended<T>(cx: &mut Context<'_>) -> Poll<T> {
    cx.waker().wake_by_ref();
    Poll::Pending
}

/// Generates an `AsyncWrite` method that takes only a context (`poll_flush`,
/// `poll_shutdown`) and forwards it to the stream behind the active `Flavor`.
macro_rules! impl_poll_flavors {
    ( $fname:ident ) => {
        fn $fname(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            let flavor: &Flavor = self.flavor.borrow();
            match flavor {
                Flavor::Tcp(tcp) => match tcp.try_lock() {
                    Ok(mut guard) => Pin::new(guard.deref_mut()).$fname(cx),
                    Err(_) => poll_contended(cx),
                },
                Flavor::Tls(tls) => match tls.try_lock() {
                    Ok(mut guard) => Pin::new(guard.deref_mut()).$fname(cx),
                    Err(_) => poll_contended(cx),
                },
            }
        }
    };
}

impl AsyncRead for NatsStream {
    /// Reads from the shared connection; if another clone is mid-poll, the
    /// task is rescheduled instead of blocking.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let flavor: &Flavor = self.flavor.borrow();
        match flavor {
            Flavor::Tcp(tcp) => match tcp.try_lock() {
                Ok(mut guard) => Pin::new(guard.deref_mut()).poll_read(cx, buf),
                Err(_) => poll_contended(cx),
            },
            Flavor::Tls(tls) => match tls.try_lock() {
                Ok(mut guard) => Pin::new(guard.deref_mut()).poll_read(cx, buf),
                Err(_) => poll_contended(cx),
            },
        }
    }
}

#[async_trait]
impl AsyncWrite for NatsStream {
    /// Writes to the shared connection; if another clone is mid-poll, the
    /// task is rescheduled instead of blocking.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let flavor: &Flavor = self.flavor.borrow();
        match flavor {
            Flavor::Tcp(tcp) => match tcp.try_lock() {
                Ok(mut guard) => Pin::new(guard.deref_mut()).poll_write(cx, buf),
                Err(_) => poll_contended(cx),
            },
            Flavor::Tls(tls) => match tls.try_lock() {
                Ok(mut guard) => Pin::new(guard.deref_mut()).poll_write(cx, buf),
                Err(_) => poll_contended(cx),
            },
        }
    }

    impl_poll_flavors!(poll_flush);
    impl_poll_flavors!(poll_shutdown);
}
/// Address of a NATS server: a `nats://` or `tls://` URL, validated by
/// `ServerAddress::from_url`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddress(Url);
/// Conversion of various inputs into a list of NATS server addresses.
pub trait IntoServerList {
    /// Converts `self` into a list of server addresses.
    ///
    /// # Errors
    /// Returns an error if any address cannot be parsed or validated.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>>;
}
impl FromStr for ServerAddress {
type Err = Error;
fn from_str(input: &str) -> Result<Self, Self::Err> {
let url: Url = if input.contains("://") {
input.parse()
} else {
format!("nats://{}", input).parse()
}
.map_err(|e| {
Error::new(
ErrorKind::InvalidInput,
format!("NATS server URL is invalid: {}", e),
)
})?;
Self::from_url(url)
}
}
impl ServerAddress {
    /// Validates `url` and wraps it as a server address.
    ///
    /// # Errors
    /// `ErrorKind::InvalidInput` if the scheme is neither `nats` nor `tls`, or
    /// if the URL has no host at all (e.g. `nats:localhost`, which parses as a
    /// path), since such an address cannot be dialed.
    pub fn from_url(url: Url) -> io::Result<Self> {
        if url.scheme() != "nats" && url.scheme() != "tls" {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                format!("invalid scheme for NATS server URL: {}", url.scheme()),
            ));
        }
        if url.host_str().is_none() {
            return Err(Error::new(
                ErrorKind::InvalidInput,
                "NATS server URL has no host",
            ));
        }
        Ok(Self(url))
    }

    /// Returns the wrapped URL.
    pub fn into_inner(self) -> Url {
        self.0
    }

    /// Whether the URL scheme itself (`tls://`) requires TLS.
    pub fn tls_required(&self) -> bool {
        self.0.scheme() == "tls"
    }

    /// Whether the URL carries a non-empty username.
    pub fn has_user_pass(&self) -> bool {
        !self.0.username().is_empty()
    }

    /// Host name or IP address to dial.
    ///
    /// IPv6 literals are returned without the brackets that URL syntax
    /// requires (`[::1]` becomes `::1`), so the value can be passed directly
    /// to address resolution and TLS server-name parsing. Returns `""` if the
    /// URL has no host instead of panicking.
    pub fn host(&self) -> &str {
        let host = self.0.host_str().unwrap_or("");
        host.strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(host)
    }

    /// Port to dial, defaulting to 4222 when the URL specifies none.
    pub fn port(&self) -> u16 {
        self.0.port().unwrap_or(4222)
    }

    /// Username embedded in the URL, if non-empty.
    pub fn username(&self) -> Option<SecureString> {
        let user = self.0.username();
        if user.is_empty() {
            None
        } else {
            Some(SecureString::from(user.to_string()))
        }
    }

    /// Password embedded in the URL, if any.
    pub fn password(&self) -> Option<SecureString> {
        self.0
            .password()
            .map(|password| SecureString::from(password.to_string()))
    }

    /// Resolves the address with the (blocking) standard-library resolver.
    pub fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr>> {
        inject_io_failure().and_then(|_| (self.host(), self.port()).to_socket_addrs())
    }
}
impl<'s> IntoServerList for &'s str {
    /// Parses a comma-separated list of server addresses.
    ///
    /// Whitespace around each entry is trimmed first, so inputs such as
    /// `"localhost:4222, localhost:4223"` parse correctly.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
        self.split(',').map(|url| url.trim().parse()).collect()
    }
}
impl<'s> IntoServerList for &'s [&'s str] {
    /// Parses each element as one server address.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
        self.iter()
            .copied()
            .map(str::parse::<ServerAddress>)
            .collect()
    }
}
impl<'s, const N: usize> IntoServerList for &'s [&'s str; N] {
    /// Delegates to the slice implementation.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
        let entries: &'s [&'s str] = self;
        entries.into_server_list()
    }
}
impl IntoServerList for String {
    /// Delegates to the comma-separated `&str` implementation.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
        let text: &str = &self;
        text.into_server_list()
    }
}
impl<'s> IntoServerList for &'s String {
    /// Delegates to the comma-separated `&str` implementation.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
        let text: &'s str = self;
        text.into_server_list()
    }
}
impl IntoServerList for ServerAddress {
    /// A single address becomes a one-element list.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
        Ok(Vec::from([self]))
    }
}
impl IntoServerList for Vec<ServerAddress> {
    /// An already-parsed list is returned unchanged.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
        io::Result::Ok(self)
    }
}
impl IntoServerList for io::Result<Vec<ServerAddress>> {
    /// Passes a previously computed result through as-is, so callers can feed
    /// the output of another conversion straight back in.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
        self
    }
}