use lazy_static::__Deref;
use parking_lot::{Mutex, MutexGuard};
use std::collections::HashMap;
use std::io::prelude::*;
use std::io::{self, BufReader, Error, ErrorKind};
use std::net::{Shutdown, SocketAddr, TcpStream, ToSocketAddrs};
use std::str::FromStr;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use url::{Host, Url};
use webpki::DNSNameRef;
use crate::auth_utils;
use crate::proto::{self, ClientOp, ServerOp};
use crate::rustls::{ClientConfig, ClientSession, Session};
use crate::secure_wipe::SecureString;
use crate::{connect::ConnectInfo, inject_io_failure, AuthStyle, Options, ServerInfo};
/// Keeps track of candidate NATS servers and dials them with the configured options.
pub(crate) struct Connector {
    /// Every known server, mapped to the number of connection attempts made against it
    /// since it was (re)added or last accepted a connection.
    attempts: HashMap<ServerAddress, usize>,
    /// Connection options, shared via `Arc`.
    options: Arc<Options>,
    /// TLS client configuration, assembled once when the connector is created.
    tls_config: Arc<ClientConfig>,
}
impl Connector {
    /// Builds a connector for `urls`, assembling the TLS client configuration up front.
    ///
    /// The configuration starts from `options.tls_client_config` and gains, in order:
    /// - the platform's native root certificates (a partial or failed load is tolerated and
    ///   simply contributes fewer roots);
    /// - the PEM certificates listed in `options.certificates`;
    /// - the client certificate/key pair, when both `options.client_cert` and
    ///   `options.client_key` are set.
    ///
    /// # Errors
    ///
    /// Fails if a certificate file cannot be read or parsed, or if the client certificate
    /// and key cannot be loaded or are rejected as a pair.
    pub(crate) fn new(urls: Vec<ServerAddress>, options: Arc<Options>) -> io::Result<Connector> {
        let mut tls_config = options.tls_client_config.clone();

        // Native roots are best effort: use whatever could be loaded, even on error.
        let roots = match rustls_native_certs::load_native_certs() {
            Ok(store) | Err((Some(store), _)) => store.roots,
            Err((None, _)) => Vec::new(),
        };
        for root in roots {
            tls_config.root_store.roots.push(root);
        }

        for path in &options.certificates {
            let contents = std::fs::read(path)?;
            let mut cursor = std::io::Cursor::new(contents);
            tls_config
                .root_store
                .add_pem_file(&mut cursor)
                .map_err(|_| {
                    io::Error::new(io::ErrorKind::InvalidInput, "invalid certificate file")
                })?;
        }

        // A client certificate is only configured when both the certificate and key are set.
        if let Some(cert) = &options.client_cert {
            if let Some(key) = &options.client_key {
                tls_config
                    .set_single_client_cert(
                        auth_utils::load_certs(cert)?,
                        auth_utils::load_key(key)?,
                    )
                    .map_err(|err| {
                        io::Error::new(
                            io::ErrorKind::InvalidInput,
                            format!("invalid client certificate and key pair: {}", err),
                        )
                    })?;
            }
        }

        let connector = Connector {
            attempts: urls.into_iter().map(|url| (url, 0)).collect(),
            options,
            tls_config: Arc::new(tls_config),
        };
        Ok(connector)
    }

    /// Registers `url` as a candidate server and sets its attempt counter to zero.
    ///
    /// The counter is also reset when the server was already known.
    pub(crate) fn add_server(&mut self, url: ServerAddress) {
        self.attempts.insert(url, 0);
    }

    /// Returns a shared handle to the connection options.
    pub(crate) fn get_options(&self) -> Arc<Options> {
        self.options.clone()
    }

    /// Returns the servers that have not yet used up `options.max_reconnects` attempts.
    ///
    /// Every server is returned when no limit is configured.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::NotFound` when no server is left to try.
    fn get_servers(&mut self) -> io::Result<Vec<ServerAddress>> {
        let servers: Vec<_> = self
            .attempts
            .iter()
            .filter_map(
                |(server, reconnects)| match self.options.max_reconnects.as_ref() {
                    None => Some(server),
                    Some(max) if reconnects < max => Some(server),
                    Some(_) => None,
                },
            )
            .cloned()
            .collect();
        if servers.is_empty() {
            Err(Error::new(
                ErrorKind::NotFound,
                "no servers remaining to connect to",
            ))
        } else {
            Ok(servers)
        }
    }

    /// Tries to establish a connection to one of the known servers.
    ///
    /// Servers, and the socket addresses each one resolves to, are tried in random order.
    /// Each attempt on a server increments its attempt counter. Before every address is
    /// dialed, the connector sleeps for the delay returned by
    /// `options.reconnect_delay_callback` for the counter value before the increment.
    ///
    /// On success, the `connect_urls` advertised by the server are added as candidates and
    /// the counter of the server that accepted the connection is reset.
    ///
    /// When `use_backoff` is `false`, a single pass is made over the servers and the last
    /// error is returned. Otherwise passes repeat until a connection succeeds or every server
    /// has used up its allowed attempts.
    pub(crate) fn connect(&mut self, use_backoff: bool) -> io::Result<(ServerInfo, NatsStream)> {
        let mut last_err = Error::new(ErrorKind::AddrNotAvailable, "no socket addresses");
        loop {
            let mut servers = self.get_servers()?;
            fastrand::shuffle(&mut servers);
            for server in &servers {
                // `server` was cloned out of `attempts`, and nothing removes entries.
                let reconnects = self.attempts.get_mut(server).unwrap();
                let sleep_duration = self.options.reconnect_delay_callback.call(*reconnects);
                *reconnects += 1;
                let lookup_res = server.socket_addrs();
                let mut addrs = match lookup_res {
                    Ok(addrs) => addrs.collect::<Vec<_>>(),
                    Err(err) => {
                        last_err = err;
                        continue;
                    }
                };
                fastrand::shuffle(&mut addrs);
                for addr in addrs {
                    thread::sleep(sleep_duration);
                    let res = self.connect_addr(addr, server);
                    let (server_info, stream) = match res {
                        Ok(val) => val,
                        Err(err) => {
                            last_err = err;
                            continue;
                        }
                    };
                    // Learn about other cluster members advertised by this server. A
                    // malformed entry is skipped: failing here would drop a connection that
                    // is already established and authenticated. Every later attempt would
                    // also fail the same way, because the server keeps advertising the URL.
                    for url in &server_info.connect_urls {
                        if let Ok(discovered) = url.parse::<ServerAddress>() {
                            self.add_server(discovered);
                        }
                    }
                    *self.attempts.get_mut(server).unwrap() = 0;
                    return Ok((server_info, stream));
                }
            }
            if !use_backoff {
                return Err(last_err);
            }
        }
    }

    /// Dials a single socket address and performs the NATS handshake on it.
    ///
    /// The handshake has these steps:
    /// 1. Read the server's INFO line.
    /// 2. Upgrade to TLS if the options, the URL scheme or the server require it.
    /// 3. Send CONNECT, using credentials from the URL or, when the URL has none,
    ///    `options.auth`, followed by PING.
    /// 4. Wait for the first PONG, answering any PING that arrives in the meantime.
    ///
    /// # Errors
    ///
    /// Returns socket, TLS and authentication-callback errors. Also fails if the first
    /// message is not INFO, if any unexpected message arrives before PONG, or if the
    /// connection closes early.
    fn connect_addr(
        &self,
        addr: SocketAddr,
        server: &ServerAddress,
    ) -> io::Result<(ServerInfo, NatsStream)> {
        inject_io_failure()?;
        let mut stream = TcpStream::connect(addr)?;
        stream.set_nodelay(true)?;

        // Read the INFO line byte by byte, so nothing beyond it is consumed from the socket
        // before a possible TLS upgrade.
        let mut line = crate::SecureVec::with_capacity(1024);
        while !line.ends_with(b"\r\n") {
            let byte = &mut [0];
            stream.read_exact(byte)?;
            line.push(byte[0]);
        }
        let server_info = match proto::decode(&line[..])? {
            Some(ServerOp::Info(server_info)) => server_info,
            Some(op) => {
                return Err(Error::new(
                    ErrorKind::Other,
                    format!("expected INFO, received: {:?}", op),
                ));
            }
            None => {
                return Err(Error::new(ErrorKind::UnexpectedEof, "connection closed"));
            }
        };

        let tls_required =
            self.options.tls_required || server.tls_required() || server_info.tls_required;
        let session = if tls_required {
            inject_io_failure()?;
            // Verify the certificate against the host the client asked for. The INFO `host`
            // field arrives in plaintext before TLS is negotiated. Trusting it first would
            // let a man-in-the-middle choose the name the certificate is checked against, and
            // then receive the CONNECT credentials. The INFO host is kept only as a fallback
            // for when the dialed host is not a valid DNS name (e.g. an IP literal).
            // NOTE(review): that fallback keeps the old, weaker behavior for IP-dialed servers.
            let dns_name = DNSNameRef::try_from_ascii_str(server.host())
                .or_else(|_| DNSNameRef::try_from_ascii_str(&server_info.host))
                .map_err(|_| {
                    io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "cannot determine hostname for TLS connection",
                    )
                })?;
            Some(ClientSession::new(&self.tls_config, dns_name))
        } else {
            None
        };
        let mut stream = NatsStream::new(stream, session)?;

        let mut connect_info = ConnectInfo {
            tls_required,
            name: self.options.name.clone().map(SecureString::from),
            pedantic: false,
            verbose: false,
            lang: crate::LANG.to_string(),
            version: crate::VERSION.to_string(),
            protocol: crate::connect::Protocol::Dynamic,
            user: None,
            pass: None,
            auth_token: None,
            user_jwt: None,
            nkey: None,
            signature: None,
            echo: !self.options.no_echo,
            headers: true,
            no_responders: true,
        };

        // Credentials embedded in the server URL take precedence over `options.auth`.
        let server_auth = server.auth();
        let auth = if let AuthStyle::NoAuth = server_auth {
            &self.options.auth
        } else {
            &server_auth
        };
        match auth {
            AuthStyle::NoAuth => {}
            AuthStyle::UserPass(user, pass) => {
                connect_info.user = Some(SecureString::from(user.to_string()));
                connect_info.pass = Some(SecureString::from(pass.to_string()));
            }
            AuthStyle::Token(token) => {
                connect_info.auth_token = Some(token.to_string().into());
            }
            AuthStyle::Credentials { jwt_cb, sig_cb } => {
                let jwt = jwt_cb()?;
                let sig = sig_cb(server_info.nonce.as_bytes())?;
                connect_info.user_jwt = Some(jwt);
                connect_info.signature = Some(sig);
            }
            AuthStyle::NKey { nkey_cb, sig_cb } => {
                let nkey = nkey_cb()?;
                let sig = sig_cb(server_info.nonce.as_bytes())?;
                connect_info.nkey = Some(nkey);
                connect_info.signature = Some(sig);
            }
        }

        proto::encode(&mut stream, ClientOp::Connect(&connect_info))?;
        proto::encode(&mut stream, ClientOp::Ping)?;
        stream.flush()?;

        // The reader is a clone of the same stream, so replies can be written while reading.
        let mut reader = BufReader::new(stream.clone());
        loop {
            match proto::decode(&mut reader)? {
                Some(ServerOp::Pong) => break,
                Some(ServerOp::Ping) => {
                    proto::encode(&mut stream, ClientOp::Pong)?;
                    stream.flush()?;
                }
                Some(op) => {
                    return Err(Error::new(
                        ErrorKind::InvalidData,
                        format!("unexpected line while connecting: {:?}", op),
                    ));
                }
                None => {
                    return Err(Error::new(
                        ErrorKind::UnexpectedEof,
                        "connection closed while waiting for the first PONG",
                    ));
                }
            }
        }
        Ok((server_info, stream))
    }
}
/// A connection to a NATS server, either plain TCP or TLS over TCP.
///
/// Cloning is cheap: every clone shares the same transport through an `Arc`. This lets one
/// clone be read from while another is written to.
#[derive(Clone)]
pub(crate) struct NatsStream {
    flavor: Arc<Flavor>,
}

/// The transport behind a `NatsStream`.
enum Flavor {
    /// A plain TCP socket; reads and writes go straight through `&TcpStream`.
    Tcp(TcpStream),
    /// A TLS session over a non-blocking TCP socket. It sits behind a mutex because all
    /// clones of the stream share the session.
    Tls(Box<Mutex<TlsStream>>),
}

/// A TCP socket paired with the rustls client session that encrypts traffic over it.
struct TlsStream {
    tcp: TcpStream,
    session: ClientSession,
}
impl NatsStream {
fn new(tcp: TcpStream, session: Option<ClientSession>) -> io::Result<NatsStream> {
let flavor = match session {
None => Flavor::Tcp(tcp),
Some(session) => {
tcp.set_nonblocking(true)?;
Flavor::Tls(Box::new(Mutex::new(TlsStream { tcp, session })))
}
};
let flavor = Arc::new(flavor);
Ok(NatsStream { flavor })
}
pub(crate) fn set_write_timeout(&self, timeout: Option<Duration>) -> io::Result<()> {
match &*self.flavor {
Flavor::Tcp(tcp) => tcp.set_write_timeout(timeout),
Flavor::Tls(tls) => tls.lock().tcp.set_write_timeout(timeout),
}
}
pub(crate) fn shutdown(&self) {
match &*self.flavor {
Flavor::Tcp(tcp) => tcp.shutdown(Shutdown::Both),
Flavor::Tls(tls) => tls.lock().tcp.shutdown(Shutdown::Both),
}
.ok();
}
}
/// Reading from an owned stream delegates to the shared-reference implementation.
impl Read for NatsStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut shared: &NatsStream = self;
        Read::read(&mut shared, buf)
    }
}

/// Reads plaintext from the stream.
///
/// For TLS, an empty read before end-of-file is turned into `WouldBlock`, so `tls_op`
/// waits for more ciphertext instead of reporting EOF too early.
impl Read for &NatsStream {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let tls = match &*self.flavor {
            Flavor::Tcp(tcp) => return (tcp.deref()).read(buf),
            Flavor::Tls(tls) => tls,
        };
        tls_op(tls, |session, eof| {
            let outcome = session.read(buf);
            match outcome {
                Ok(0) if !eof => Err(io::ErrorKind::WouldBlock.into()),
                other => other,
            }
        })
    }
}
/// Writing to an owned stream delegates to the shared-reference implementation.
impl Write for NatsStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let mut shared: &NatsStream = self;
        Write::write(&mut shared, buf)
    }
    fn flush(&mut self) -> io::Result<()> {
        let mut shared: &NatsStream = self;
        Write::flush(&mut shared)
    }
}

/// Writes plaintext to the stream.
///
/// For TLS, `write` only queues data in the session. The ciphertext reaches the socket
/// whenever `tls_op` next pumps it, which `flush` forces.
impl Write for &NatsStream {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match &*self.flavor {
            Flavor::Tcp(tcp) => (tcp.deref()).write(buf),
            Flavor::Tls(tls) => tls_op(tls, |session, _| session.write(buf)),
        }
    }

    /// Flushes buffered data to the socket.
    ///
    /// For TLS, this keeps going until the session has no ciphertext left to send. A single
    /// non-blocking `write_tls` may be partial. Returning after it would report success while
    /// data (e.g. the PING sent during the handshake) is still queued in the session, and it
    /// would stay queued until some unrelated read pumped it out.
    fn flush(&mut self) -> io::Result<()> {
        match &*self.flavor {
            Flavor::Tcp(tcp) => (tcp.deref()).flush(),
            Flavor::Tls(tls) => tls_op(tls, |session, _| {
                session.flush()?;
                if session.wants_write() {
                    // Ciphertext is still queued: `tls_op` waits for the socket to become
                    // writable and tries again.
                    Err(io::ErrorKind::WouldBlock.into())
                } else {
                    Ok(())
                }
            }),
        }
    }
}
/// Runs `op` against the TLS session, moving ciphertext between the session and the socket.
///
/// Each round does the following:
/// 1. Take the lock.
/// 2. Feed any available ciphertext into the session and process it, without blocking.
/// 3. Try to send pending ciphertext, without blocking.
/// 4. Call `op` with the session and a flag that is `true` when the socket reported
///    end-of-file during this round.
///
/// If `op` returns `WouldBlock`, the call waits in `tls_wait`, which releases the lock, and
/// starts another round. Any other result from `op` is returned as-is.
///
/// # Errors
///
/// Returns socket errors other than `WouldBlock`, TLS protocol errors (as
/// `ErrorKind::Other`), and any non-`WouldBlock` error produced by `op`.
fn tls_op<T: std::fmt::Debug>(
    tls: &Mutex<TlsStream>,
    mut op: impl FnMut(&mut ClientSession, bool) -> io::Result<T>,
) -> io::Result<T> {
    loop {
        let mut guard = tls.lock();
        let TlsStream { tcp, session } = &mut *guard;

        let mut saw_eof = false;
        if session.wants_read() {
            match session.read_tls(tcp) {
                Ok(0) => saw_eof = true,
                Ok(_) => {
                    if let Err(err) = session.process_new_packets() {
                        return Err(Error::new(ErrorKind::Other, format!("TLS error: {}", err)));
                    }
                }
                Err(err) if err.kind() == ErrorKind::WouldBlock => {}
                Err(err) => return Err(err),
            }
        }

        if session.wants_write() {
            if let Err(err) = session.write_tls(tcp) {
                if err.kind() != ErrorKind::WouldBlock {
                    return Err(err);
                }
            }
        }

        let result = op(session, saw_eof);
        let would_block = matches!(&result, Err(err) if err.kind() == ErrorKind::WouldBlock);
        if !would_block {
            return result;
        }
        tls_wait(guard)?;
    }
}
/// Blocks until the socket behind `tls` is ready for the I/O the TLS session currently wants.
///
/// Readiness interest is computed from the session while the lock is still held: readable
/// if the session `wants_read`, writable if it `wants_write`. The guard is then dropped
/// before waiting, so other users of the stream are not blocked. `poll` is called with no
/// timeout (`-1`) and retried when interrupted by a signal.
///
/// # Errors
///
/// Returns the OS error reported by `poll`/`WSAPoll`, except `Interrupted`, which is retried.
fn tls_wait(mut tls: MutexGuard<'_, TlsStream>) -> io::Result<()> {
    // `poll`/`pollfd` come from libc on Unix and from WinSock on Windows; `sys` names the
    // module holding the matching POLL* constants on each platform.
    #[cfg(unix)]
    use {
        libc::{self as sys, poll, pollfd},
        std::os::unix::io::AsRawFd,
    };
    #[cfg(windows)]
    use {
        std::os::windows::io::AsRawSocket,
        winapi::um::winsock2::{self as sys, WSAPoll as poll, WSAPOLLFD as pollfd},
    };
    let TlsStream { tcp, session } = &mut *tls;
    // The `as _` casts adapt the raw handle to the platform's `fd` field type; they can be
    // identity casts on some targets, which the `trivial_numeric_casts` lint would reject.
    #[allow(trivial_numeric_casts)]
    let mut pollfd = pollfd {
        #[cfg(unix)]
        fd: tcp.as_raw_fd() as _,
        #[cfg(windows)]
        fd: tcp.as_raw_socket() as _,
        // NOTE(review): POLLERR is only requested on Unix, presumably because WSAPoll does
        // not accept it in `events`; confirm.
        #[cfg(unix)]
        events: sys::POLLERR,
        #[cfg(windows)]
        events: 0,
        revents: 0,
    };
    if session.wants_read() {
        pollfd.events |= sys::POLLIN;
    }
    if session.wants_write() {
        pollfd.events |= sys::POLLOUT;
    }
    // Release the lock before blocking so other threads can keep driving the session.
    drop(tls);
    // SAFETY: `pollfd` is a live, initialized local for the whole call, and the count of 1
    // matches the single entry passed. In `tls_op` the caller still borrows the mutex that
    // owns the socket, so the descriptor stays open while waiting.
    #[allow(unsafe_code)]
    while unsafe { poll(&mut pollfd, 1, -1) } == -1 {
        let err = Error::last_os_error();
        if err.kind() != io::ErrorKind::Interrupted {
            return Err(err);
        }
    }
    Ok(())
}
/// Address of a NATS server, stored as a URL whose scheme is `nats` or `tls`.
///
/// Parse one from text with `str::parse`, where a missing scheme defaults to `nats://`, or
/// wrap an existing URL with `ServerAddress::from_url`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ServerAddress(Url);
/// Conversion into a list of NATS server addresses.
///
/// Implemented for comma-separated strings, slices and arrays of strings, `String`s, a
/// single `ServerAddress`, a vector of them, and an already computed
/// `io::Result<Vec<ServerAddress>>`.
pub trait IntoServerList {
    /// Converts `self` into the list of server addresses to connect to.
    ///
    /// # Errors
    ///
    /// Implementations that parse text fail if any entry is not a valid server address.
    fn into_server_list(self) -> io::Result<Vec<ServerAddress>>;
}
impl FromStr for ServerAddress {
    type Err = Error;

    /// Parses a server address such as `nats://user:pass@host:4222`, `tls://host` or a bare
    /// `host:port`.
    ///
    /// Surrounding whitespace is ignored, so comma-separated lists written with spaces
    /// (`"a:4222, b:4222"`) parse. Without trimming, a bare `" b:4222"` would become
    /// `nats:// b:4222`, whose host contains a space and is rejected. Input without a
    /// `://` separator is treated as a `nats://` URL.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` when the text is not a valid URL or its scheme is
    /// neither `nats` nor `tls`.
    fn from_str(input: &str) -> Result<Self, Self::Err> {
        let input = input.trim();
        let url: Url = if input.contains("://") {
            input.parse()
        } else {
            format!("nats://{}", input).parse()
        }
        .map_err(|e| {
            Error::new(
                ErrorKind::InvalidInput,
                format!("NATS server URL is invalid: {}", e),
            )
        })?;
        Self::from_url(url)
    }
}
impl ServerAddress {
pub fn from_url(url: Url) -> io::Result<Self> {
if url.scheme() != "nats" && url.scheme() != "tls" {
return Err(Error::new(
ErrorKind::InvalidInput,
format!("invalid scheme for NATS server URL: {}", url.scheme()),
));
}
Ok(Self(url))
}
pub fn into_inner(self) -> Url {
self.0
}
pub fn tls_required(&self) -> bool {
self.0.scheme() == "tls"
}
pub fn has_user_pass(&self) -> bool {
self.0.username() != ""
}
pub(crate) fn auth(&self) -> AuthStyle {
if let Some(password) = self.0.password() {
if self.0.username() == "" {
AuthStyle::NoAuth
} else {
AuthStyle::UserPass(self.0.username().to_string(), password.to_string())
}
} else if "" != self.0.username() {
AuthStyle::Token(self.0.username().to_string())
} else {
AuthStyle::NoAuth
}
}
pub fn host(&self) -> &str {
match self.0.host() {
Some(Host::Domain(_)) | Some(Host::Ipv4 { .. }) => self.0.host_str().unwrap(),
Some(Host::Ipv6 { .. }) => {
let host = self.0.host_str().unwrap();
&host[1..host.len() - 1]
}
None => "",
}
}
pub fn port(&self) -> u16 {
self.0.port().unwrap_or(4222)
}
pub fn username(&self) -> Option<SecureString> {
let user = self.0.username();
if user.is_empty() {
None
} else {
Some(SecureString::from(user.to_string()))
}
}
pub fn password(&self) -> Option<SecureString> {
self.0
.password()
.map(|password| SecureString::from(password.to_string()))
}
pub fn socket_addrs(&self) -> io::Result<impl Iterator<Item = SocketAddr>> {
inject_io_failure().and_then(|_| (self.host(), self.port()).to_socket_addrs())
}
}
impl<'s> IntoServerList for &'s str {
fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
self.split(',').map(|url| url.parse()).collect()
}
}
impl<'s> IntoServerList for &'s [&'s str] {
fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
self.iter().map(|url| url.parse()).collect()
}
}
impl<'s, const N: usize> IntoServerList for &'s [&'s str; N] {
fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
self.as_ref().into_server_list()
}
}
impl IntoServerList for String {
fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
self.as_str().into_server_list()
}
}
impl<'s> IntoServerList for &'s String {
fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
self.as_str().into_server_list()
}
}
impl IntoServerList for ServerAddress {
fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
Ok(vec![self])
}
}
impl IntoServerList for Vec<ServerAddress> {
fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
Ok(self)
}
}
impl IntoServerList for io::Result<Vec<ServerAddress>> {
fn into_server_list(self) -> io::Result<Vec<ServerAddress>> {
self
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, panicking on failure; every test input here is a valid address.
    fn addr(input: &str) -> ServerAddress {
        ServerAddress::from_str(input).unwrap()
    }

    #[test]
    fn server_address_ipv6() {
        assert_eq!(addr("nats://[::]").host(), "::");
    }

    #[test]
    fn server_address_ipv4() {
        assert_eq!(addr("nats://127.0.0.1").host(), "127.0.0.1");
    }

    #[test]
    fn server_address_domain() {
        assert_eq!(addr("nats://example.com").host(), "example.com");
    }

    #[test]
    fn server_address_no_auth() {
        let auth = addr("nats://localhost").auth();
        assert!(matches!(auth, AuthStyle::NoAuth));
    }

    #[test]
    fn server_address_token_auth() {
        let auth = addr("nats://mytoken@localhost").auth();
        assert!(matches!(auth, AuthStyle::Token(token) if token == "mytoken"));
    }

    #[test]
    fn server_address_user_auth() {
        let auth = addr("nats://myuser:mypass@localhost").auth();
        assert!(matches!(
            auth,
            AuthStyle::UserPass(username, password) if username == "myuser" && password == "mypass"
        ));
    }
}