mod codec;
#[cfg(test)]
mod tests;
mod tls_or_tcp_stream;
mod types;
mod util;
#[cfg(feature = "tls")]
use crate::types::tls::TlsConfig;
use futures::{
future::{self, Either, FutureExt},
lock::{Mutex, MutexGuard},
pin_mut,
stream::StreamExt,
};
use log::{debug, error, info, trace, warn};
use owning_ref::{OwningRef, OwningRefMut};
use rand::seq::SliceRandom;
use std::{
collections::HashMap, io::ErrorKind, mem, result::Result as StdResult, sync::Arc,
time::Duration,
};
use tokio::{
io::{self, AsyncWriteExt, ReadHalf, WriteHalf},
net::TcpStream,
sync::{
mpsc, oneshot,
watch::{self, Sender as WatchSender},
},
time::{self, Elapsed},
};
use tokio_util::codec::FramedRead;
use uuid::Uuid;
use crate::{
codec::Codec,
error::{Error, Result},
tls_or_tcp_stream::TlsOrTcpStream,
types::{
ClientControl, ConnectionState, ServerMessage, StableMutexGuard, StateTransition,
StateTransitionResult,
},
};
#[cfg(feature = "native-tls")]
pub use native_tls_crate as native_tls;
#[cfg(feature = "rustls-tls")]
pub use rustls;
pub use tokio::sync::{
mpsc::Receiver as MpscReceiver, mpsc::Sender as MpscSender, watch::Receiver as WatchReceiver,
};
pub use crate::types::{
error, Address, Authorization, ClientRef, ClientRefMut, ClientState, Connect, Info, Msg,
ProtocolError, Sid, Subject, SubjectBuilder, Subscription,
};
/// Log message emitted whenever the server message stream ends, i.e. the socket went away.
const TCP_SOCKET_DISCONNECTED_MESSAGE: &str = "TCP socket was disconnected";
/// Callback deciding how long to wait before the next connection attempt.
///
/// Arguments: the client, the number of connect attempts made so far, and the number of candidate
/// server addresses.
pub type DelayGenerator = Box<dyn Send + Fn(&Client, u64, u64) -> Duration>;
/// Builds the default [`DelayGenerator`].
///
/// The returned closure receives the number of connect attempts made so far and the number of
/// candidate addresses, and chooses:
///
/// * `cool_down` after every `connect_series_attempts_before_cool_down` full passes over the
///   addresses. It is also chosen when there are no addresses at all, so an empty address list
///   does not retry with a short delay.
/// * `connect_series_delay` after every other full pass over the addresses.
/// * `connect_delay` between individual attempts within a pass.
///
/// A `connect_series_attempts_before_cool_down` of `0` disables the cool down.
pub fn generate_delay_generator(
    connect_series_attempts_before_cool_down: u64,
    connect_delay: Duration,
    connect_series_delay: Duration,
    cool_down: Duration,
) -> DelayGenerator {
    Box::new(move |_: &Client, connect_attempts: u64, addresses: u64| {
        // Attempts between two cool downs. Saturate instead of overflowing. A zero divisor must
        // never reach `%`: it panics, and `addresses` is 0 whenever the address list is empty.
        let attempts_per_cool_down =
            addresses.saturating_mul(connect_series_attempts_before_cool_down);
        if addresses == 0 || connect_attempts.checked_rem(attempts_per_cool_down) == Some(0) {
            trace!("Using cool down delay {}s", cool_down.as_secs_f32());
            cool_down
        } else if connect_attempts % addresses == 0 {
            // `addresses != 0` is guaranteed by the first branch.
            trace!(
                "Using connect series delay {}s",
                connect_series_delay.as_secs_f32()
            );
            connect_series_delay
        } else {
            trace!("Using connect delay {}s", connect_delay.as_secs_f32());
            connect_delay
        }
    })
}
#[derive(Clone)]
pub struct Client {
sync: Arc<Mutex<SyncClient>>,
}
impl Client {
pub fn new(addresses: Vec<Address>) -> Self {
Self::with_connect(addresses, Connect::new())
}
pub fn with_connect(addresses: Vec<Address>, connect: Connect) -> Self {
Self {
sync: Arc::new(Mutex::new(SyncClient::with_connect(addresses, connect))),
}
}
pub async fn state(&self) -> ClientState {
self.lock().await.state()
}
pub async fn state_stream(&self) -> WatchReceiver<ClientState> {
self.lock().await.state_stream()
}
pub async fn info(&self) -> Info {
self.lock().await.info()
}
pub async fn connect_mut(&self) -> ClientRefMut<'_, Connect> {
ClientRefMut(
OwningRefMut::new(StableMutexGuard(self.lock().await)).map_mut(|c| c.connect_mut()),
)
}
pub async fn addresses_mut(&self) -> ClientRefMut<'_, [Address]> {
ClientRefMut(
OwningRefMut::new(StableMutexGuard(self.lock().await)).map_mut(|c| c.addresses_mut()),
)
}
pub async fn tcp_connect_timeout(&self) -> Duration {
self.lock().await.tcp_connect_timeout()
}
pub async fn set_tcp_connect_timeout(&self, tcp_connect_timeout: Duration) -> &Self {
self.lock()
.await
.set_tcp_connect_timeout(tcp_connect_timeout);
self
}
pub async fn delay_generator_mut(&self) -> ClientRefMut<'_, DelayGenerator> {
ClientRefMut(
OwningRefMut::new(StableMutexGuard(self.lock().await))
.map_mut(|c| c.delay_generator_mut()),
)
}
#[cfg(feature = "tls")]
pub async fn set_tls_config(&mut self, tls_config: TlsConfig) -> &mut Self {
self.lock().await.set_tls_config(tls_config);
self
}
#[cfg(feature = "tls")]
pub async fn set_tls_domain(&mut self, tls_domain: String) -> &mut Self {
self.lock().await.set_tls_domain(tls_domain);
self
}
pub async fn sids(&self) -> Vec<Sid> {
self.lock()
.await
.subscriptions()
.map(|(sid, _)| *sid)
.collect()
}
pub async fn subscription(&self, sid: Sid) -> Option<ClientRef<'_, Subscription>> {
let client = self.lock().await;
if client.subscriptions.contains_key(&sid) {
Some(ClientRef(
OwningRef::new(StableMutexGuard(client))
.map(|c| c.subscriptions.get(&sid).unwrap()),
))
} else {
None
}
}
pub async fn send_connect(&self) -> Result<()> {
let mut client = self.lock().await;
client.send_connect().await
}
pub async fn connect(&self) {
SyncClient::connect(Self::clone(self)).await
}
pub async fn disconnect(&self) {
SyncClient::disconnect(Arc::clone(&self.sync)).await
}
pub async fn publish(&self, subject: &Subject, payload: &[u8]) -> Result<()> {
self.publish_with_optional_reply(subject, None, payload)
.await
}
pub async fn publish_with_reply(
&self,
subject: &Subject,
reply_to: &Subject,
payload: &[u8],
) -> Result<()> {
self.publish_with_optional_reply(subject, Some(reply_to), payload)
.await
}
pub async fn publish_with_optional_reply(
&self,
subject: &Subject,
reply_to: Option<&Subject>,
payload: &[u8],
) -> Result<()> {
let mut client = self.lock().await;
client
.publish_with_optional_reply(subject, reply_to, payload)
.await
}
pub async fn request(&self, subject: &Subject, payload: &[u8]) -> Result<Msg> {
SyncClient::request_with_timeout(Arc::clone(&self.sync), subject, payload, None).await
}
pub async fn request_with_timeout(
&self,
subject: &Subject,
payload: &[u8],
duration: Duration,
) -> Result<Msg> {
SyncClient::request_with_timeout(Arc::clone(&self.sync), subject, payload, Some(duration))
.await
}
pub async fn subscribe(
&self,
subject: &Subject,
buffer: usize,
) -> Result<(Sid, MpscReceiver<Msg>)> {
self.subscribe_with_optional_queue_group(subject, None, buffer)
.await
}
pub async fn subscribe_with_queue_group(
&self,
subject: &Subject,
queue_group: &str,
buffer: usize,
) -> Result<(Sid, MpscReceiver<Msg>)> {
self.subscribe_with_optional_queue_group(subject, Some(queue_group), buffer)
.await
}
pub async fn subscribe_with_optional_queue_group(
&self,
subject: &Subject,
queue_group: Option<&str>,
buffer: usize,
) -> Result<(Sid, MpscReceiver<Msg>)> {
let mut client = self.lock().await;
client
.subscribe_with_optional_queue_group(subject, queue_group, buffer)
.await
}
pub async fn unsubscribe(&self, sid: Sid) -> Result<()> {
self.unsubscribe_optional_max_msgs(sid, None).await
}
pub async fn unsubscribe_with_max_msgs(&self, sid: Sid, max_msgs: u64) -> Result<()> {
self.unsubscribe_optional_max_msgs(sid, Some(max_msgs))
.await
}
pub async fn unsubscribe_optional_max_msgs(
&self,
sid: Sid,
max_msgs: Option<u64>,
) -> Result<()> {
let mut client = self.lock().await;
client.unsubscribe_optional_max_msgs(sid, max_msgs).await
}
pub async fn unsubscribe_all(&self) -> Result<()> {
let unsubscribes = self
.sids()
.await
.into_iter()
.map(|sid| self.unsubscribe(sid));
future::try_join_all(unsubscribes).await?;
Ok(())
}
pub async fn info_stream(&self) -> WatchReceiver<Info> {
self.lock().await.info_stream()
}
pub async fn ping_stream(&self) -> WatchReceiver<()> {
self.lock().await.ping_stream()
}
pub async fn pong_stream(&self) -> WatchReceiver<()> {
self.lock().await.pong_stream()
}
pub async fn ok_stream(&self) -> WatchReceiver<()> {
self.lock().await.ok_stream()
}
pub async fn err_stream(&self) -> WatchReceiver<ProtocolError> {
self.lock().await.err_stream()
}
pub async fn ping(&self) -> Result<()> {
let mut client = self.lock().await;
client.ping().await
}
pub async fn pong(&self) -> Result<()> {
let mut client = self.lock().await;
client.pong().await
}
pub async fn ping_pong(&self) -> Result<()> {
SyncClient::ping_pong(Arc::clone(&self.sync)).await
}
async fn lock(&self) -> MutexGuard<'_, SyncClient> {
self.sync.lock().await
}
}
impl Drop for Client {
fn drop(&mut self) {
trace!("Client was dropped");
}
}
/// Mutable client state shared by every [`Client`] clone behind an `Arc<Mutex<_>>`.
struct SyncClient {
    /// Server addresses supplied by the user; combined with the `INFO` `connect_urls` when
    /// connecting.
    addresses: Vec<Address>,
    /// Options sent in the `CONNECT` message on every (re)connect.
    connect: Connect,
    /// Internal connection state; while connected it owns the write half of the stream.
    state: ConnectionState,
    /// Broadcasts the public `ClientState` after every state transition.
    state_tx: WatchSender<ClientState>,
    /// Kept so that new state receivers can be cloned from it.
    state_rx: WatchReceiver<ClientState>,
    /// Broadcasts each `INFO` message received from the server.
    info_tx: WatchSender<Info>,
    info_rx: WatchReceiver<Info>,
    /// Signalled for every `PING` received from the server.
    ping_tx: WatchSender<()>,
    ping_rx: WatchReceiver<()>,
    /// Signalled for every `PONG` received from the server.
    pong_tx: WatchSender<()>,
    pong_rx: WatchReceiver<()>,
    /// Signalled for every `OK` received from the server.
    ok_tx: WatchSender<()>,
    ok_rx: WatchReceiver<()>,
    /// Broadcasts each protocol error received from the server.
    err_tx: WatchSender<ProtocolError>,
    err_rx: WatchReceiver<ProtocolError>,
    /// Bound on the TCP connect and on each handshake wait (`INFO`, `OK`, `PONG`).
    tcp_connect_timeout: Duration,
    /// Chooses the delay between reconnect attempts.
    delay_generator: DelayGenerator,
    /// TLS settings used when the server requires TLS; `None` makes the upgrade fail with
    /// `Error::TlsDisabled`.
    #[cfg(feature = "tls")]
    tls_config: Option<TlsConfig>,
    /// Overrides the address's domain when upgrading to TLS.
    #[cfg(feature = "tls")]
    tls_domain: Option<String>,
    /// Active subscriptions by sid; re-sent to the server after every reconnect.
    subscriptions: HashMap<Sid, Subscription>,
    /// Reply-inbox subject of each pending request -> channel delivering its reply.
    request_inbox_mapping: HashMap<Subject, MpscSender<Msg>>,
    /// Sid of the shared request-inbox wildcard subscription, once it has been created.
    request_wildcard_subscription: Option<Sid>,
    /// Random UUID (simple format) that prefixes this client's request reply inboxes.
    request_base_inbox: String,
}
impl SyncClient {
fn with_connect(addresses: Vec<Address>, connect: Connect) -> Self {
let state = ConnectionState::Disconnected;
let (state_tx, state_rx) = watch::channel((&state).into());
let (info_tx, info_rx) = watch::channel(Info::new());
let (ping_tx, ping_rx) = watch::channel(());
let (pong_tx, pong_rx) = watch::channel(());
let (ok_tx, ok_rx) = watch::channel(());
let (err_tx, err_rx) = watch::channel(ProtocolError::UnknownProtocolOperation);
Self {
addresses,
connect,
state,
state_tx,
state_rx,
info_tx,
info_rx,
ping_tx,
ping_rx,
pong_tx,
pong_rx,
ok_tx,
ok_rx,
err_tx,
err_rx,
tcp_connect_timeout: util::DEFAULT_TCP_CONNECT_TIMEOUT,
delay_generator: generate_delay_generator(
util::DEFAULT_CONNECT_SERIES_ATTEMPTS_BEFORE_COOL_DOWN,
util::DEFAULT_CONNECT_DELAY,
util::DEFAULT_CONNECT_SERIES_DELAY,
util::DEFAULT_COOL_DOWN,
),
#[cfg(feature = "tls")]
tls_config: None,
#[cfg(feature = "tls")]
tls_domain: None,
subscriptions: HashMap::new(),
request_inbox_mapping: HashMap::new(),
request_wildcard_subscription: None,
request_base_inbox: Uuid::new_v4().to_simple().to_string(),
}
}
fn state(&self) -> ClientState {
self.state_rx.borrow().clone()
}
fn state_stream(&self) -> WatchReceiver<ClientState> {
self.state_rx.clone()
}
pub fn info(&self) -> Info {
self.info_rx.borrow().clone()
}
fn connect_mut(&mut self) -> &mut Connect {
&mut self.connect
}
fn addresses_mut(&mut self) -> &mut [Address] {
&mut self.addresses
}
fn tcp_connect_timeout(&self) -> Duration {
self.tcp_connect_timeout
}
fn set_tcp_connect_timeout(&mut self, tcp_connect_timeout: Duration) -> &mut Self {
self.tcp_connect_timeout = tcp_connect_timeout;
self
}
fn delay_generator_mut(&mut self) -> &mut DelayGenerator {
&mut self.delay_generator
}
#[cfg(feature = "tls")]
fn set_tls_config(&mut self, tls_config: TlsConfig) -> &mut Self {
self.tls_config = Some(tls_config);
self
}
#[cfg(feature = "tls")]
fn set_tls_domain(&mut self, domain: String) -> &mut Self {
self.tls_domain = Some(domain);
self
}
fn subscriptions(&self) -> impl Iterator<Item = (&Sid, &Subscription)> {
self.subscriptions.iter()
}
async fn send_connect(&mut self) -> Result<()> {
if let ConnectionState::Connected(address, writer) = &mut self.state {
Self::send_connect_with_writer(writer, &self.connect, address).await
} else {
Err(Error::NotConnected)
}
}
#[cfg(feature = "tls")]
async fn upgrade_to_tls(
&mut self,
stream: TlsOrTcpStream,
domain: &str,
) -> Result<TlsOrTcpStream> {
let domain = self.tls_domain.as_deref().unwrap_or(domain).to_string();
info!(
"Using '{}' as the domain to upgrade to a TLS connection",
domain
);
let tls_config = self.tls_config.clone().ok_or(Error::TlsDisabled)?;
Ok(stream.upgrade(tls_config, &domain).await?)
}
#[cfg(not(feature = "tls"))]
async fn upgrade_to_tls(
&mut self,
_stream: TlsOrTcpStream,
_domain: &str,
) -> Result<TlsOrTcpStream> {
Err(Error::TlsDisabled)
}
/// Connects to one of the candidate servers, retrying until connected or told to disconnect.
///
/// Returns immediately if already `Connected`. Otherwise the candidate list is the configured
/// `addresses` plus the `connect_urls` from the latest `INFO`. It is shuffled once and then
/// cycled. Every attempt except the first is preceded by a sleep whose length comes from
/// `delay_generator`.
///
/// Each attempt runs while holding the client mutex, so other client operations (including
/// `disconnect`) wait until it finishes. It performs these steps:
/// 1. stop if the state became `Connected`, or finish a pending disconnect if `Disconnecting`;
/// 2. open a TCP connection, bounded by `tcp_connect_timeout`;
/// 3. wait for the server's `INFO` and upgrade to TLS if it reports `tls_required`;
/// 4. send `CONNECT`, wait for `OK` when `verbose` is set, then do a `PING`/`PONG` round trip;
/// 5. re-send `SUB` for every existing subscription, dropping the ones that fail;
/// 6. spawn the server message handler and transition to `Connected`.
///
/// Any failure is logged and the loop moves on to the next attempt.
/// NOTE(review): with an empty candidate list this loops forever, logging
/// "No addresses to connect to".
#[allow(clippy::cognitive_complexity)]
async fn connect(wrapped_client: Client) {
    if let ConnectionState::Connected(_, _) = wrapped_client.lock().await.state {
        return;
    }
    // Snapshot and shuffle the candidates once; `cycle` then repeats them in that order.
    let (addresses_len, mut addresses_iter) = {
        let client = wrapped_client.lock().await;
        let mut addresses = client
            .addresses
            .iter()
            .chain(client.info_rx.borrow().connect_urls().iter())
            .cloned()
            .collect::<Vec<_>>();
        let addresses_len = addresses.len() as u64;
        addresses.shuffle(&mut rand::thread_rng());
        (addresses_len, addresses.into_iter().cycle())
    };
    let mut connect_attempts = 0;
    loop {
        if connect_attempts != 0 {
            // The lock is only held while computing the delay, not while sleeping.
            let delay = (wrapped_client.lock().await.delay_generator)(
                &wrapped_client,
                connect_attempts,
                addresses_len,
            );
            debug!(
                "Delaying for {}s after {} connect attempts with {} addresses",
                delay.as_secs_f32(),
                connect_attempts,
                addresses_len
            );
            time::delay_for(delay).await;
        }
        connect_attempts += 1;
        // Held for the rest of this attempt.
        let mut client = wrapped_client.lock().await;
        match client.state {
            ConnectionState::Connected(_, _) => {
                return;
            }
            // A disconnect was requested while we were retrying: complete it and stop.
            ConnectionState::Disconnecting(_) => {
                client.state_transition(StateTransition::ToDisconnected);
                return;
            }
            _ => (),
        }
        let address = if let Some(address) = addresses_iter.next() {
            address
        } else {
            error!("No addresses to connect to");
            continue;
        };
        client.state_transition(StateTransition::ToConnecting(address.clone()));
        let connect = time::timeout(
            client.tcp_connect_timeout,
            TcpStream::connect(address.address()),
        );
        let (reader, writer) = match connect.await {
            Ok(Ok(stream)) => {
                io::split(TlsOrTcpStream::new(stream))
            }
            Ok(Err(e)) => {
                error!("Failed to connect to '{}', err: {}", address, e);
                continue;
            }
            Err(_) => {
                error!("Timed out while connecting to '{}'", address);
                continue;
            }
        };
        let mut reader = FramedRead::new(reader, Codec::new());
        // The first server message must be `INFO`; it also tells us whether TLS is required.
        let wait_for_info = time::timeout(client.tcp_connect_timeout, reader.next());
        let tls_required = if let Some(message) =
            Self::unwrap_server_message_with_timeout(wait_for_info.await, util::INFO_OP_NAME)
        {
            if let ServerMessage::Info(info) = message {
                let tls_required = info.tls_required();
                client.handle_info_message(info);
                tls_required
            } else {
                error!(
                    "First message should be {} instead received '{:?}'",
                    util::INFO_OP_NAME,
                    message
                );
                // Also panics in debug builds when a server sends something else first.
                debug_assert!(false);
                continue;
            }
        } else {
            continue;
        };
        // For TLS, re-join the split halves, upgrade the whole stream, then split it again.
        let (mut reader, mut writer) = if tls_required {
            let stream = reader.into_inner().unsplit(writer);
            let upgraded_stream = match client.upgrade_to_tls(stream, address.domain()).await {
                Ok(stream) => stream,
                Err(e) => {
                    error!("Failed to upgrade to TLS connection, err: {}", e);
                    continue;
                }
            };
            let (reader, writer) = io::split(upgraded_stream);
            (FramedRead::new(reader, Codec::new()), writer)
        } else {
            (reader, writer)
        };
        if let Err(e) =
            Self::send_connect_with_writer(&mut writer, &client.connect, &address).await
        {
            error!("Failed to send connect message, err: {}", e);
            continue;
        }
        // In verbose mode the server acknowledges `CONNECT` with `OK`.
        if client.connect_mut().is_verbose() {
            let wait_for_ok = time::timeout(client.tcp_connect_timeout, reader.next());
            if let Some(message) =
                Self::unwrap_server_message_with_timeout(wait_for_ok.await, util::OK_OP_NAME)
            {
                match message {
                    ServerMessage::Ok => (),
                    ServerMessage::Err(e) => {
                        error!(
                            "Protocol error waiting for {} message, err: {}",
                            util::OK_OP_NAME,
                            e
                        );
                        continue;
                    }
                    message => {
                        error!(
                            "Next message should be {} instead received '{:?}'",
                            util::OK_OP_NAME,
                            message
                        );
                        debug_assert!(false);
                        continue;
                    }
                }
            } else {
                continue;
            };
        }
        // A PING/PONG round trip confirms the connection before it is declared `Connected`.
        if let Err(e) = Self::ping_with_writer(&mut writer).await {
            error!("Failed to send ping when verifying connection, err: {}", e);
            continue;
        }
        let wait_for_pong = time::timeout(client.tcp_connect_timeout, reader.next());
        if let Some(message) =
            Self::unwrap_server_message_with_timeout(wait_for_pong.await, util::PONG_OP_NAME)
        {
            match message {
                ServerMessage::Pong => (),
                ServerMessage::Err(e) => {
                    error!(
                        "Protocol error waiting for {} message, err: {}",
                        util::PONG_OP_NAME,
                        e
                    );
                    continue;
                }
                message => {
                    error!(
                        "Next message should be {} instead received '{:?}'",
                        util::PONG_OP_NAME,
                        message
                    );
                    debug_assert!(false);
                    continue;
                }
            }
        } else {
            continue;
        };
        // Restore subscriptions from a previous connection; forget those that cannot be re-sent.
        let mut failed_to_resubscribe = Vec::new();
        for (sid, subscription) in &client.subscriptions {
            if let Err(e) =
                Self::write_line(&mut writer, ClientControl::Sub(subscription)).await
            {
                error!(
                    "Failed to resubscribe to sid '{}' with subject '{}', err: {}",
                    sid,
                    subscription.subject(),
                    e
                );
                failed_to_resubscribe.push(*sid);
            }
        }
        client
            .subscriptions
            .retain(|sid, _| !failed_to_resubscribe.contains(&sid));
        // The reader goes to the background handler; the writer is stored in the state.
        tokio::spawn(Self::type_erased_server_messages_handler(
            Client::clone(&wrapped_client),
            reader,
        ));
        client.state_transition(StateTransition::ToConnected(writer));
        return;
    }
}
/// Disconnects from the server and waits until the client state becomes `Disconnected`.
///
/// * Already `Disconnected`: returns immediately.
/// * Already `Disconnecting` (e.g. another `disconnect` call is in flight): waits for that
///   disconnect to finish. It must not request a second `ToDisconnecting` transition, because
///   `state_transition` has no `Disconnecting -> Disconnecting` edge and would panic via
///   `unreachable!()`.
/// * Otherwise: transitions to `Disconnecting` and waits.
///
/// The state receiver is obtained under the lock, before the transition is requested.
async fn disconnect(wrapped_client: Arc<Mutex<Self>>) {
    let (tx, rx) = oneshot::channel();
    {
        let mut client = wrapped_client.lock().await;
        let already_disconnecting = match client.state {
            ConnectionState::Disconnected => return,
            ConnectionState::Disconnecting(_) => true,
            _ => false,
        };
        let mut state_stream = client.state_stream();
        tokio::spawn(async move {
            while let Some(state) = state_stream.next().await {
                if state.is_disconnected() {
                    // The caller may have stopped waiting (its future was dropped). That is not
                    // a reason to panic this background task.
                    let _ = tx.send(());
                    break;
                }
            }
        });
        if !already_disconnecting {
            client.state_transition(StateTransition::ToDisconnecting);
        }
    }
    rx.await.expect("to receive disconnect signal");
}
async fn publish_with_reply(
&mut self,
subject: &Subject,
reply_to: &Subject,
payload: &[u8],
) -> Result<()> {
self.publish_with_optional_reply(subject, Some(reply_to), payload)
.await
}
async fn publish_with_optional_reply(
&mut self,
subject: &Subject,
reply_to: Option<&Subject>,
payload: &[u8],
) -> Result<()> {
let max_payload = self.info().max_payload;
if let ConnectionState::Connected(_, writer) = &mut self.state {
if payload.len() > max_payload {
return Err(Error::ExceedsMaxPayload {
tried: payload.len(),
limit: max_payload,
});
}
Self::write_line(writer, ClientControl::Pub(subject, reply_to, payload.len())).await?;
writer.write_all(payload).await?;
writer
.write_all(util::MESSAGE_TERMINATOR.as_bytes())
.await?;
Ok(())
} else {
Err(Error::NotConnected)
}
}
/// Routes replies received on the shared request wildcard subscription to waiting requests.
///
/// For each message, the sender registered for the message's subject is removed from
/// `request_inbox_mapping`, and the message is forwarded to it. When the subscription channel
/// closes, all pending request mappings are cleared and `request_wildcard_subscription` is
/// reset, so the next request creates a new wildcard subscription.
async fn request_wildcard_handler(
    wrapped_client: Arc<Mutex<Self>>,
    mut subscription_rx: MpscReceiver<Msg>,
) {
    while let Some(msg) = subscription_rx.next().await {
        // Take the sender out under the lock, but release the lock before awaiting the send so
        // other client operations are not blocked on this requester.
        let requester_tx = wrapped_client
            .lock()
            .await
            .request_inbox_mapping
            .remove(&msg.subject());
        match requester_tx {
            Some(mut requester_tx) => {
                // The requester may legitimately be gone: it may have timed out and dropped its
                // receiver before its mapping entry was removed. So this is only a warning, not a
                // debug-build panic.
                requester_tx.send(msg).await.unwrap_or_else(|err| {
                    warn!("Could not write response to pending request via mapping channel. Skipping! Err: {}", err);
                });
            }
            None => {
                warn!(
                    "Could not find response channel for request with subject: {}",
                    &msg.subject()
                );
            }
        }
    }
    let mut client = wrapped_client.lock().await;
    client.request_inbox_mapping.clear();
    client.request_wildcard_subscription = None;
}
/// Sends a request on `subject` and waits for the first reply, at most `duration` if given.
async fn request_with_timeout(
    wrapped_client: Arc<Mutex<Self>>,
    subject: &Subject,
    payload: &[u8],
    duration: Option<Duration>,
) -> Result<Msg> {
    Request::new(wrapped_client)
        .await?
        .call(subject, payload, duration)
        .await
}
async fn subscribe(
&mut self,
subject: &Subject,
buffer: usize,
) -> Result<(Sid, MpscReceiver<Msg>)> {
self.subscribe_with_optional_queue_group(subject, None, buffer)
.await
}
async fn subscribe_with_optional_queue_group(
&mut self,
subject: &Subject,
queue_group: Option<&str>,
buffer: usize,
) -> Result<(Sid, MpscReceiver<Msg>)> {
if let ConnectionState::Connected(_, writer) = &mut self.state {
let (tx, rx) = mpsc::channel(buffer);
let subscription =
Subscription::new(subject.clone(), queue_group.map(String::from), tx);
Self::write_line(writer, ClientControl::Sub(&subscription)).await?;
let sid = subscription.sid();
self.subscriptions.insert(sid, subscription);
Ok((sid, rx))
} else {
Err(Error::NotConnected)
}
}
async fn unsubscribe(&mut self, sid: Sid) -> Result<()> {
self.unsubscribe_optional_max_msgs(sid, None).await
}
async fn unsubscribe_optional_max_msgs(
&mut self,
sid: Sid,
max_msgs: Option<u64>,
) -> Result<()> {
if let ConnectionState::Connected(_, writer) = &mut self.state {
let subscription = match self.subscriptions.get_mut(&sid) {
Some(subscription) => subscription,
None => return Err(Error::UnknownSid(sid)),
};
subscription.unsubscribe_after = max_msgs;
Self::write_line(writer, ClientControl::Unsub(sid, max_msgs)).await?;
if subscription.unsubscribe_after.is_none() {
self.subscriptions.remove(&sid);
}
Ok(())
} else {
Err(Error::NotConnected)
}
}
pub fn info_stream(&mut self) -> WatchReceiver<Info> {
self.info_rx.clone()
}
pub fn ping_stream(&mut self) -> WatchReceiver<()> {
self.ping_rx.clone()
}
pub fn pong_stream(&mut self) -> WatchReceiver<()> {
self.pong_rx.clone()
}
pub fn ok_stream(&mut self) -> WatchReceiver<()> {
self.ok_rx.clone()
}
pub fn err_stream(&mut self) -> WatchReceiver<ProtocolError> {
self.err_rx.clone()
}
async fn ping(&mut self) -> Result<()> {
if let ConnectionState::Connected(_, writer) = &mut self.state {
Self::ping_with_writer(writer).await?;
Ok(())
} else {
Err(Error::NotConnected)
}
}
async fn pong(&mut self) -> Result<()> {
if let ConnectionState::Connected(_, writer) = &mut self.state {
Self::write_line(writer, ClientControl::Pong).await?;
Ok(())
} else {
Err(Error::NotConnected)
}
}
/// Sends a `PING` and waits until the `PONG` watch channel signals again.
async fn ping_pong(wrapped_client: Arc<Mutex<Self>>) -> Result<()> {
    let mut pongs = {
        let mut client = wrapped_client.lock().await;
        let mut pongs = client.pong_stream();
        // Poll once without waiting, consuming any value that is already available. Then the
        // wait below only completes on a later signal.
        let _ = pongs.next().now_or_never();
        client.ping().await?;
        pongs
    };
    let _ = pongs.next().await;
    Ok(())
}
/// Reads and dispatches server messages until the connection ends, then cleans up.
///
/// The loop stops in three cases:
/// * the socket closes (`reader` yields `None`);
/// * an unrecoverable I/O error occurs;
/// * the client state becomes `Disconnecting`.
///
/// Decoding errors are logged and skipped. After the loop, the handler:
/// * removes the request wildcard subscription (if any) from `subscriptions`;
/// * transitions to `Disconnected`, re-joins `reader` with the writer returned by the
///   transition, and shuts the stream down;
/// * spawns a new `connect`, unless the state was `Disconnecting` (the disconnect was requested).
async fn server_messages_handler(
    wrapped_client: Client,
    mut reader: FramedRead<ReadHalf<TlsOrTcpStream>, Codec>,
) {
    // Resolves once the client state becomes `Disconnecting`.
    let disconnecting = Self::disconnecting(Arc::clone(&wrapped_client.sync));
    pin_mut!(disconnecting);
    loop {
        // Race the next message against a disconnect request. The still-pending `disconnecting`
        // future is carried into the next iteration.
        let wrapped_message = match future::select(reader.next(), disconnecting).await {
            Either::Left((Some(message), unresolved_disconnecting)) => {
                disconnecting = unresolved_disconnecting;
                message
            }
            Either::Left((None, _)) => {
                error!("{}", TCP_SOCKET_DISCONNECTED_MESSAGE);
                break;
            }
            Either::Right(((), _)) => break,
        };
        match Disposition::from_output(wrapped_message) {
            Disposition::Message(m) => {
                Self::handle_server_message(Arc::clone(&wrapped_client.sync), m).await;
                continue;
            }
            Disposition::DecodingError(e) => {
                error!("Received invalid server message, err: {}", e);
                continue;
            }
            Disposition::UnrecoverableError(e) => {
                error!("TCP socket error, err: {}", e);
                break;
            }
        }
    }
    let mut client = wrapped_client.lock().await;
    // NOTE(review): `request_wildcard_subscription` is not reset here. Removing the subscription
    // drops its sender, which presumably ends `request_wildcard_handler`'s loop, and that
    // handler then resets it. Confirm: a request issued before that happens still sees the old
    // sid.
    if let Some(request_wildcard_sid) = client.request_wildcard_subscription {
        client.subscriptions.remove(&request_wildcard_sid);
    }
    // A requested disconnect (state `Disconnecting`) must not trigger a reconnect.
    let should_reconnect = !client.state().is_disconnecting();
    if let StateTransitionResult::Writer(writer) =
        client.state_transition(StateTransition::ToDisconnected)
    {
        // Re-join both halves so the whole stream can be shut down.
        let mut stream = reader.into_inner().unsplit(writer);
        if let Err(e) = stream.shutdown().await {
            // `NotConnected` is not logged as an error.
            if e.kind() != ErrorKind::NotConnected {
                error!("Failed to shutdown TCP stream, err: {}", e);
            }
        }
    } else {
        error!("Disconnected with no TCP writer. Unable to shutdown TCP stream.");
        debug_assert!(false);
    }
    if should_reconnect {
        tokio::spawn(Self::connect(Client::clone(&wrapped_client)));
    }
}
#[allow(clippy::cognitive_complexity)]
async fn handle_server_message(wrapped_client: Arc<Mutex<Self>>, message: ServerMessage) {
match message {
ServerMessage::Info(info) => {
wrapped_client.lock().await.handle_info_message(info);
}
ServerMessage::Msg(msg) => {
let sid = msg.sid();
let mut client = wrapped_client.lock().await;
let subscription = match client.subscriptions.get_mut(&sid) {
Some(subscription) => subscription,
None => {
error!("Received unknown sid '{}'", sid);
debug_assert!(false);
let wrapped_client = Arc::clone(&wrapped_client);
tokio::spawn(async move {
info!("Unsubscribing from unknown sid '{}'", sid);
let mut client = wrapped_client.lock().await;
if let Err(e) = client.unsubscribe(sid).await {
error!("Failed to unsubscribe from '{}', err: {}", sid, e);
}
});
return;
}
};
if subscription.tx.send(msg).await.is_err() {
let wrapped_client = Arc::clone(&wrapped_client);
tokio::spawn(async move {
info!("Unsubscribing from closed sid '{}'", sid);
let mut client = wrapped_client.lock().await;
if let Err(e) = client.unsubscribe(sid).await {
error!("Failed to unsubscribe from sid '{}', err: {}", sid, e);
}
});
}
if let Some(unsubscribe_after) = &mut subscription.unsubscribe_after() {
*unsubscribe_after -= 1;
if *unsubscribe_after == 0 {
client.subscriptions.remove(&sid);
}
}
}
ServerMessage::Ping => {
if let Err(e) = wrapped_client.lock().await.ping_tx.broadcast(()) {
error!("Failed to broadcast {}, err: {}", util::PING_OP_NAME, e);
}
let wrapped_client = Arc::clone(&wrapped_client);
tokio::spawn(async move {
let mut client = wrapped_client.lock().await;
if let Err(e) = client.pong().await {
error!("Failed to send {}, err: {}", util::PONG_OP_NAME, e);
}
});
}
ServerMessage::Pong => {
if let Err(e) = wrapped_client.lock().await.pong_tx.broadcast(()) {
error!("Failed to broadcast {}, err: {}", util::PONG_OP_NAME, e);
}
}
ServerMessage::Ok => {
if let Err(e) = wrapped_client.lock().await.ok_tx.broadcast(()) {
error!("Failed to broadcast {}, err: {}", util::OK_OP_NAME, e);
}
}
ServerMessage::Err(e) => {
error!("Protocol error, err: '{}'", e);
if let Err(e) = wrapped_client.lock().await.err_tx.broadcast(e) {
error!("Failed to broadcast {}, err: {}", util::ERR_OP_NAME, e);
}
}
}
}
fn handle_info_message(&mut self, info: Info) {
if let Err(e) = self.info_tx.broadcast(info) {
error!("Failed to broadcast {}, err: {}", util::INFO_OP_NAME, e);
}
}
/// Returns the `server_messages_handler` future behind an opaque `impl Future + Send` type.
///
/// `connect` spawns this handler, and the handler spawns `connect` again after a disconnect.
/// The explicit `Send` in the return type is presumably there so the compiler can prove the
/// futures in that cycle are `Send` for `tokio::spawn` (NOTE(review): confirm).
fn type_erased_server_messages_handler(
    wrapped_client: Client,
    reader: FramedRead<ReadHalf<TlsOrTcpStream>, Codec>,
) -> impl std::future::Future<Output = ()> + Send {
    Self::server_messages_handler(wrapped_client, reader)
}
/// Resolves once the client state becomes `Disconnecting`, or when the state stream ends.
async fn disconnecting(wrapped_client: Arc<Mutex<Self>>) {
    let mut states = wrapped_client.lock().await.state_stream();
    loop {
        let next = states.next().await;
        let stop = match next {
            Some(state) => state.is_disconnecting(),
            None => true,
        };
        if stop {
            break;
        }
    }
}
/// Serializes `control_line` and writes it to `writer`.
async fn write_line(
    writer: &mut WriteHalf<TlsOrTcpStream>,
    control_line: ClientControl<'_>,
) -> Result<()> {
    let serialized = control_line.to_line();
    writer.write_all(serialized.as_bytes()).await?;
    Ok(())
}
/// Writes a `CONNECT` line built from `connect`.
///
/// If `address` carries authorization, it replaces the authorization in `connect`.
async fn send_connect_with_writer(
    writer: &mut WriteHalf<TlsOrTcpStream>,
    connect: &Connect,
    address: &Address,
) -> Result<()> {
    let effective = match address.authorization() {
        Some(authorization) => {
            let mut with_auth = connect.clone();
            with_auth.set_authorization(Some(authorization.clone()));
            with_auth
        }
        None => connect.clone(),
    };
    Self::write_line(writer, ClientControl::Connect(&effective)).await
}
/// Writes a `PING` control line.
async fn ping_with_writer(writer: &mut WriteHalf<TlsOrTcpStream>) -> Result<()> {
    let ping = ClientControl::Ping;
    Self::write_line(writer, ping).await
}
/// Extracts the server message from a timed `reader.next()` result.
///
/// A timeout, a closed socket, a decoding error or an I/O error is logged and yields `None`.
/// `waiting_for` names the expected message in the timeout log line.
fn unwrap_server_message_with_timeout(
    wrapped_server_message: StdResult<
        Option<StdResult<Result<ServerMessage>, io::Error>>,
        Elapsed,
    >,
    waiting_for: &str,
) -> Option<ServerMessage> {
    let decoder_output = match wrapped_server_message {
        Err(_) => {
            error!("Timed out waiting for {} message", waiting_for);
            return None;
        }
        Ok(None) => {
            error!("{}", TCP_SOCKET_DISCONNECTED_MESSAGE);
            return None;
        }
        Ok(Some(output)) => output,
    };
    match Disposition::from_output(decoder_output) {
        Disposition::Message(message) => Some(message),
        Disposition::DecodingError(e) => {
            error!("Received invalid server message, err: {}", e);
            None
        }
        Disposition::UnrecoverableError(e) => {
            error!("TCP socket error, err: {}", e);
            None
        }
    }
}
/// Applies `transition` to `self.state` and broadcasts the resulting `ClientState`.
///
/// Allowed transitions:
/// * `Disconnected` or `Connecting(_)` via `ToConnecting(address)` -> `Connecting(address)`
/// * `Connecting(address)` via `ToConnected(writer)` -> `Connected(address, writer)`
/// * `Connecting(_)` via `ToDisconnecting` -> `Disconnecting(None)`
/// * `Connected(_, writer)` via `ToDisconnecting` -> `Disconnecting(Some(writer))`
/// * `Connected(_, writer)` or `Disconnecting(Some(writer))` via `ToDisconnected` ->
///   `Disconnected`, returning the writer as `StateTransitionResult::Writer`
/// * `Disconnecting(None)` via `ToDisconnected` -> `Disconnected`
///
/// # Panics
/// On any other combination (after logging it), and if broadcasting the new state fails.
fn state_transition(&mut self, transition: StateTransition) -> StateTransitionResult {
    let previous_client_state = ClientState::from(&self.state);
    // Move the current state out so its contents (address, writer) can move into the next
    // state; `self.state` is overwritten again below.
    let previous_state = mem::replace(&mut self.state, ConnectionState::Disconnected);
    let (next_state, result) = match (previous_state, transition) {
        (ConnectionState::Disconnected, StateTransition::ToConnecting(address)) => (
            ConnectionState::Connecting(address),
            StateTransitionResult::None,
        ),
        (ConnectionState::Connecting(_), StateTransition::ToConnecting(address)) => (
            ConnectionState::Connecting(address),
            StateTransitionResult::None,
        ),
        (ConnectionState::Connecting(address), StateTransition::ToConnected(writer)) => (
            ConnectionState::Connected(address, writer),
            StateTransitionResult::None,
        ),
        (ConnectionState::Connecting(_), StateTransition::ToDisconnecting) => (
            ConnectionState::Disconnecting(None),
            StateTransitionResult::None,
        ),
        (ConnectionState::Connected(_, writer), StateTransition::ToDisconnecting) => (
            ConnectionState::Disconnecting(Some(writer)),
            StateTransitionResult::None,
        ),
        // Leaving a state that owns the writer hands it back, so the caller can shut the
        // stream down.
        (ConnectionState::Connected(_, writer), StateTransition::ToDisconnected) => (
            ConnectionState::Disconnected,
            StateTransitionResult::Writer(writer),
        ),
        (ConnectionState::Disconnecting(Some(writer)), StateTransition::ToDisconnected) => (
            ConnectionState::Disconnected,
            StateTransitionResult::Writer(writer),
        ),
        (ConnectionState::Disconnecting(None), StateTransition::ToDisconnected) => {
            (ConnectionState::Disconnected, StateTransitionResult::None)
        }
        (_, transition) => {
            error!(
                "Invalid transition '{:?}' from '{}'",
                transition, previous_client_state,
            );
            unreachable!();
        }
    };
    self.state = next_state;
    let next_client_state = ClientState::from(&self.state);
    info!(
        "Transitioned to state '{}' from '{}'",
        next_client_state, previous_client_state
    );
    self.state_tx
        .broadcast(next_client_state)
        .expect("to broadcast state transition");
    result
}
}
/// Item type produced by the `FramedRead<_, Codec>` server message stream. The outer `Err` is a
/// transport error; the inner `Err` is a decoding error.
type DecoderStreamOutput = StdResult<Result<ServerMessage>, io::Error>;
/// Classification of one decoder stream item.
enum Disposition {
    /// A successfully decoded server message.
    Message(ServerMessage),
    /// The bytes could not be decoded; the stream itself is still usable.
    DecodingError(Error),
    /// An I/O error on the underlying stream.
    UnrecoverableError(io::Error),
}
impl Disposition {
    /// Classifies a decoder stream item.
    fn from_output(o: DecoderStreamOutput) -> Self {
        match o {
            Err(io_error) => Self::UnrecoverableError(io_error),
            Ok(decoded) => decoded
                .map(Self::Message)
                .unwrap_or_else(Self::DecodingError),
        }
    }
}
struct Request {
reply_to: Subject,
wrapped_client: Arc<Mutex<SyncClient>>,
request_inbox_mapping_was_removed: bool,
}
impl Request {
async fn new(wrapped_client: Arc<Mutex<SyncClient>>) -> Result<Self> {
let inbox_uuid = Uuid::new_v4();
let client = wrapped_client.lock().await;
let request_inbox = inbox_uuid.to_simple();
let reply_to: Subject = format!(
"{}.{}.{}",
util::INBOX_PREFIX,
client.request_base_inbox,
request_inbox
)
.parse()?;
Ok(Self {
reply_to,
wrapped_client: wrapped_client.clone(),
request_inbox_mapping_was_removed: false,
})
}
async fn call(
mut self,
subject: &Subject,
payload: &[u8],
duration: Option<Duration>,
) -> Result<Msg> {
let mut rx = {
let mut client = self.wrapped_client.lock().await;
if client.request_wildcard_subscription.is_none() {
let global_reply_to =
format!("{}.{}.*", util::INBOX_PREFIX, client.request_base_inbox).parse()?;
let (sid, rx) = client.subscribe(&global_reply_to, 1024).await?;
client.request_wildcard_subscription = Some(sid);
tokio::spawn(SyncClient::request_wildcard_handler(
self.wrapped_client.clone(),
rx,
));
}
let (tx, rx) = mpsc::channel(1);
client
.request_inbox_mapping
.insert(self.reply_to.clone(), tx);
client
.publish_with_reply(subject, &self.reply_to, payload)
.await?;
rx
};
let next_message = match duration {
Some(duration) => tokio::time::timeout(duration, rx.next()).await?,
None => rx.next().await,
};
match next_message {
Some(response) => {
self.request_inbox_mapping_was_removed = true;
Ok(response)
}
None => Err(Error::NoResponse),
}
}
}
impl Drop for Request {
fn drop(&mut self) {
if self.request_inbox_mapping_was_removed {
return;
}
futures::executor::block_on(async {
let mut client = self.wrapped_client.lock().await;
client.request_inbox_mapping.remove(&self.reply_to);
});
}
}