use super::pub_sub_message::PubSubMessage;
use crate::{
ClientError, Connection, Error, JoinHandle, ReconnectionState, Result, RetryReason,
client::{Config, Message, MessageKind},
commands::InternalPubSubCommands,
resp::{ClientReplyMode, CommandKind, RespResponse, SubscriptionType, cmd},
spawn, timeout,
};
use bytes::Bytes;
use futures_channel::{mpsc, oneshot};
use futures_util::{FutureExt, StreamExt, select};
use log::{Level, debug, error, info, log_enabled, trace, warn};
use smallvec::SmallVec;
use std::{
collections::{HashMap, VecDeque},
sync::Arc,
task::Poll,
time::Duration,
};
use tokio::{sync::broadcast, time::Instant};
/// Sending half of the channel through which callers submit [`Message`]s to the network loop.
/// The handler also keeps a clone to re-submit messages that must be retried.
pub(crate) type MsgSender = mpsc::UnboundedSender<Message>;
/// Receiving half of [`MsgSender`], owned by the network loop.
pub(crate) type MsgReceiver = mpsc::UnboundedReceiver<Message>;
/// One-shot channel delivering the result of a single/pub-sub/monitor message.
pub(crate) type ResultSender = oneshot::Sender<Result<RespResponse>>;
/// Receiving half of [`ResultSender`].
pub(crate) type ResultReceiver = oneshot::Receiver<Result<RespResponse>>;
/// One-shot channel delivering all replies of a batch message at once.
pub(crate) type ResultsSender = oneshot::Sender<Result<Vec<RespResponse>>>;
/// Receiving half of [`ResultsSender`].
pub(crate) type ResultsReceiver = oneshot::Receiver<Result<Vec<RespResponse>>>;
/// Per-subscription channel on which published pub/sub messages are forwarded.
pub(crate) type PubSubSender = mpsc::UnboundedSender<Result<RespResponse>>;
/// Receiving half of [`PubSubSender`].
pub(crate) type PubSubReceiver = mpsc::UnboundedReceiver<Result<RespResponse>>;
/// Channel on which push messages (monitor output, invalidations, unmatched pushes) are forwarded.
pub(crate) type PushSender = mpsc::UnboundedSender<Result<RespResponse>>;
/// Receiving half of [`PushSender`].
pub(crate) type PushReceiver = mpsc::UnboundedReceiver<Result<RespResponse>>;
/// Broadcast channel notified with `()` after each successful reconnection.
pub(crate) type ReconnectSender = broadcast::Sender<()>;
/// Receiving half of [`ReconnectSender`].
pub(crate) type ReconnectReceiver = broadcast::Receiver<()>;
// A single result waiting to be delivered to its caller at the end of a read burst.
type PendingResult = (ResultSender, Result<RespResponse>);
// A batch result waiting to be delivered to its caller at the end of a read burst.
type PendingResultBatch = (ResultsSender, Result<Vec<RespResponse>>);
/// Connection state of the network handler; decides how messages and replies are routed.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Status {
// Connection lost: replies are ignored, retryable messages are queued, others fail.
Disconnected,
// Normal request/response mode (push messages are routed to pub/sub or the push sender).
Connected,
// A MONITOR command was queued; the next reply switches the state to `Monitor`.
EnteringMonitor,
// Monitor lines are forwarded to the push sender; other replies go to waiting callers.
Monitor,
// A RESET was queued while monitoring; the first non-monitor reply switches back to `Connected`.
LeavingMonitor,
}
/// A caller message waiting in the send queue, with the attempt counter carried
/// across reconnections.
struct MessageToSend {
    pub message: Message,
    pub attempts: usize,
}

impl MessageToSend {
    /// Wraps a freshly submitted `message`; its attempt counter starts at zero.
    pub fn new(message: Message) -> Self {
        let attempts = 0;
        Self { message, attempts }
    }
}
/// A message already written to the connection whose replies are still expected.
///
/// `num_commands` counts the replies still outstanding for this message (commands sent
/// while `CLIENT REPLY` was off are not counted).
#[derive(Debug)]
struct MessageToReceive {
    pub message: Message,
    pub num_commands: usize,
    pub attempts: usize,
}

impl MessageToReceive {
    /// Builds the in-flight record for `message`, expecting `num_commands` replies.
    pub fn new(message: Message, num_commands: usize, attempts: usize) -> Self {
        MessageToReceive {
            attempts,
            num_commands,
            message,
        }
    }
}
/// A channel/pattern subscription sent to the server whose confirmation has not arrived yet.
struct PendingSubscription {
// Channel or pattern name, compared against the server's confirmation.
pub channel_or_pattern: Bytes,
pub subscription_type: SubscriptionType,
// Receives the published messages once the subscription is confirmed.
pub sender: PubSubSender,
// `true` for every channel of a multi-channel request except the last one; the caller's
// result is only completed when the last confirmation arrives.
pub more_to_come: bool,
}
/// State of the task that owns the connection and multiplexes caller messages onto it.
pub(crate) struct NetworkHandler {
// Current routing state (see `Status`).
status: Status,
connection: Connection,
// Clone of the caller-facing sender, used to re-submit messages that must be retried.
msg_sender: MsgSender,
// Caller messages.
msg_receiver: MsgReceiver,
// Messages waiting to be written to the connection.
messages_to_send: VecDeque<MessageToSend>,
// Messages written to the connection, in send order, whose replies are awaited.
messages_to_receive: VecDeque<MessageToReceive>,
// Subscriptions awaiting server confirmation, in request order.
pending_subscriptions: VecDeque<PendingSubscription>,
// One map per UNSUBSCRIBE command: channels/patterns still awaiting confirmation.
pending_unsubscriptions: VecDeque<HashMap<Bytes, SubscriptionType>>,
// Confirmed subscriptions and the sender receiving their published messages.
subscriptions: HashMap<Bytes, (SubscriptionType, PubSubSender)>,
// Tracks the `CLIENT REPLY` mode to know which commands produce a reply.
is_reply_on: bool,
// Destination of monitor lines and unmatched push messages, if registered.
push_sender: Option<PushSender>,
// Replies accumulated so far for the multi-command message at the front of `messages_to_receive`.
pending_responses: Option<Vec<RespResponse>>,
// Notified after every successful reconnection.
reconnect_sender: ReconnectSender,
auto_resubscribe: bool,
auto_remonitor: bool,
// Connection tag used as a log prefix.
tag: Arc<str>,
// Supplies reconnection delays and counts attempts.
reconnection_state: ReconnectionState,
// Results collected while handling a burst of replies, delivered once the burst is processed.
pending_results: SmallVec<[PendingResult; 64]>,
pending_result_batches: SmallVec<[PendingResultBatch; 64]>,
}
impl NetworkHandler {
pub async fn connect(
config: Config,
) -> Result<(MsgSender, JoinHandle<()>, ReconnectSender, Arc<str>)> {
let auto_resubscribe = config.auto_resubscribe;
let auto_remonitor = config.auto_remonitor;
let reconnection_config = config.reconnection.clone();
let connection = Connection::connect(config).await?;
let (msg_sender, msg_receiver): (MsgSender, MsgReceiver) = mpsc::unbounded();
let (reconnect_sender, _): (ReconnectSender, ReconnectReceiver) = broadcast::channel(32);
let tag = connection.tag().to_owned();
let mut network_handler = NetworkHandler {
status: Status::Connected,
connection,
msg_sender: msg_sender.clone(),
msg_receiver,
messages_to_send: VecDeque::new(),
messages_to_receive: VecDeque::new(),
pending_subscriptions: VecDeque::new(),
pending_unsubscriptions: VecDeque::new(),
subscriptions: HashMap::new(),
is_reply_on: true,
push_sender: None,
pending_responses: None,
reconnect_sender: reconnect_sender.clone(),
auto_resubscribe,
auto_remonitor,
tag: tag.clone(),
reconnection_state: ReconnectionState::new(reconnection_config),
pending_results: SmallVec::new(),
pending_result_batches: SmallVec::new(),
};
let join_handle = spawn(async move {
if let Err(e) = network_handler.network_loop().await {
error!("[{}] network loop ended in error: {e}", network_handler.tag);
}
});
Ok((msg_sender, join_handle, reconnect_sender, tag))
}
async fn network_loop(&mut self) -> Result<()> {
loop {
select! {
msg = self.msg_receiver.next().fuse() => {
if !self.try_handle_message(msg).await { break; }
},
result = self.connection.read().fuse() => {
if !self.try_handle_result(result).await { break; }
}
}
}
debug!("[{}] end of network loop", self.tag);
Ok(())
}
async fn try_handle_message(&mut self, mut msg: Option<Message>) -> bool {
let is_channel_closed: bool;
loop {
if let Some(msg) = msg {
self.handle_message(msg);
} else {
is_channel_closed = true;
break;
}
match self.msg_receiver.try_recv() {
Ok(m) => msg = Some(m),
Err(_) => {
is_channel_closed = false;
break;
}
}
}
if self.status != Status::Disconnected {
self.send_messages().await
}
!is_channel_closed
}
fn handle_message(&mut self, mut msg: Message) {
trace!(
"[{}][{:?}] Will handle message: {msg:?}",
self.tag, self.status
);
let mut collision_error = None;
match &self.status {
Status::Connected => {
match &mut msg.kind {
MessageKind::PubSub {
subscription_type,
subscriptions,
..
} => {
for (channel_or_pattern, _sender) in subscriptions.iter() {
if self.subscriptions.contains_key(channel_or_pattern) {
debug!(
"[{}][{:?}] There is already a subscription on channel `{}`",
self.tag,
self.status,
String::from_utf8_lossy(channel_or_pattern)
);
collision_error =
Some(Error::Client(ClientError::AlreadySubscribed));
break;
}
}
if collision_error.is_none() {
let subscriptions = std::mem::take(subscriptions);
let num_pending_subscriptions = subscriptions.len();
let pending_subscriptions = subscriptions.into_iter().enumerate().map(
|(index, (channel_or_pattern, sender))| PendingSubscription {
channel_or_pattern,
subscription_type: *subscription_type,
sender,
more_to_come: index < num_pending_subscriptions - 1,
},
);
self.pending_subscriptions.extend(pending_subscriptions);
}
}
MessageKind::Monitor { push_sender, .. } => {
self.status = Status::EnteringMonitor;
let push_sender = push_sender.take();
if let Some(push_sender) = push_sender {
debug!("[{}] Registering MONITOR push_sender", self.tag);
self.push_sender = Some(push_sender);
}
}
MessageKind::Invalidation { push_sender } => {
let push_sender = push_sender.take();
if let Some(push_sender) = push_sender {
debug!("[{}] Registering Invalidation push_sender", self.tag);
self.push_sender = Some(push_sender);
}
return; }
MessageKind::Single { command, .. } => {
if let CommandKind::Unsbuscribe(subscription_type) = command.kind() {
self.pending_unsubscriptions.push_back(
command.args().map(|a| (a, *subscription_type)).collect(),
);
}
}
_ => (),
}
if let Some(err) = collision_error {
msg.send_error(&self.tag, err);
} else {
self.messages_to_send.push_back(MessageToSend::new(msg));
}
}
Status::Disconnected => {
if msg.retry_on_error {
debug!(
"[{}] network disconnected, queuing command: {:?}",
self.tag,
msg.commands()
);
self.messages_to_send.push_back(MessageToSend::new(msg));
} else {
debug!(
"[{}] network disconnected, sending command in error: {:?}",
self.tag,
msg.commands()
);
msg.send_error(&self.tag, Error::DisconnectedByPeer);
}
}
Status::EnteringMonitor => self.messages_to_send.push_back(MessageToSend::new(msg)),
Status::Monitor => {
for command in msg.commands() {
if matches!(command.kind(), CommandKind::Reset) {
self.status = Status::LeavingMonitor;
}
}
self.messages_to_send.push_back(MessageToSend::new(msg));
}
Status::LeavingMonitor => {
self.messages_to_send.push_back(MessageToSend::new(msg));
}
}
}
async fn send_messages(&mut self) {
if log_enabled!(Level::Debug) {
let num_commands = self
.messages_to_send
.iter()
.fold(0, |sum, msg| sum + msg.message.num_commands());
if num_commands > 1 {
debug!("[{}] sending batch of {} commands", self.tag, num_commands);
}
}
let mut retry_reasons = SmallVec::<[RetryReason; 10]>::new();
let start_idx = self.messages_to_receive.len();
while let Some(message_to_send) = self.messages_to_send.pop_front() {
let mut msg = message_to_send.message;
let reasons = msg.retry_reasons.take();
if let Some(reasons) = reasons {
retry_reasons.extend(reasons);
}
let mut num_commands_to_receive: usize = 0;
for command in msg.commands_mut() {
match command.kind() {
CommandKind::ClientReply(ClientReplyMode::On) => self.is_reply_on = true,
CommandKind::ClientReply(ClientReplyMode::Off | ClientReplyMode::Skip) => {
self.is_reply_on = false
}
_ => (),
}
if self.is_reply_on {
num_commands_to_receive += 1;
}
if let Err(e) = self.connection.feed(command, &retry_reasons).await {
error!("[{}] Feed error: {e}", self.tag);
msg.send_error(&self.tag, e);
return;
}
}
if num_commands_to_receive > 0 {
self.messages_to_receive.push_back(MessageToReceive::new(
msg,
num_commands_to_receive,
message_to_send.attempts,
));
}
}
if let Err(e) = self.connection.flush().await {
error!("[{}] Flush error: {e}", self.tag);
while self.messages_to_receive.len() > start_idx {
if let Some(msg_to_receive) = self.messages_to_receive.pop_back() {
msg_to_receive.message.send_error(&self.tag, e.clone());
}
}
}
}
async fn try_handle_result(&mut self, result: Option<Result<RespResponse>>) -> bool {
let Some(result) = result else {
return self.reconnect().await;
};
self.handle_result(result);
while let Poll::Ready(result) = self.connection.try_read() {
let Some(result) = result else {
return self.reconnect().await;
};
self.handle_result(result);
}
for (sender, response) in self.pending_results.drain(..) {
if let Err(e) = sender.send(response) {
warn!(
"[{}] Cannot send value to caller because receiver is not there anymore: {e:?}",
self.tag
);
}
}
for (sender, results) in self.pending_result_batches.drain(..) {
if let Err(e) = sender.send(results) {
warn!(
"[{}] Cannot send value to caller because receiver is not there anymore: {e:?}",
self.tag
);
}
}
true
}
fn handle_result(&mut self, result: Result<RespResponse>) {
match self.status {
Status::Disconnected => (),
Status::Connected => match &result {
Ok(response) if response.is_push() => {
if let Some(response) = self.try_match_pubsub_message(result) {
if response.is_err() {
self.receive_result(response);
} else {
match &mut self.push_sender {
Some(push_sender) => {
if let Err(e) = push_sender.unbounded_send(response) {
warn!(
"[{}] Cannot send push message result to caller: {e}",
self.tag
);
}
}
None => {
warn!(
"[{}] Received a push message with no sender configured: {response:?}",
self.tag
)
}
}
}
}
}
_ => {
self.receive_result(result);
}
},
Status::EnteringMonitor => {
self.receive_result(result);
self.status = Status::Monitor;
}
Status::Monitor => match &result {
Ok(response) if response.is_monitor() => {
if let Some(push_sender) = &mut self.push_sender
&& let Err(e) = push_sender.unbounded_send(result)
{
warn!("[{}] Cannot send monitor result to caller: {e}", self.tag);
}
}
_ => self.receive_result(result),
},
Status::LeavingMonitor => match &result {
Ok(response) if response.is_monitor() => {
if let Some(push_sender) = &mut self.push_sender
&& let Err(e) = push_sender.unbounded_send(result)
{
warn!("[{}] Cannot send monitor result to caller: {e}", self.tag);
}
}
_ => {
self.receive_result(result);
self.status = Status::Connected;
}
},
}
}
fn receive_result(&mut self, result: Result<RespResponse>) {
match self.messages_to_receive.front_mut() {
Some(message_to_receive) => {
log::trace!("message_to_receive: {:?}", message_to_receive);
if message_to_receive.num_commands == 1 || result.is_err() {
if let Some(mut message_to_receive) = self.messages_to_receive.pop_front() {
let mut should_retry = false;
if let Err(Error::Retry(_)) = &result {
should_retry = true;
} else if message_to_receive.message.retry_reasons.is_some() {
should_retry = true;
}
if should_retry {
if let Err(Error::Retry(reasons)) = result {
if let Some(retry_reasons) =
&mut message_to_receive.message.retry_reasons
{
retry_reasons.extend(reasons);
} else {
message_to_receive.message.retry_reasons =
Some(SmallVec::<[RetryReason; 10]>::from_iter(reasons));
}
}
let result = self.msg_sender.unbounded_send(message_to_receive.message);
if let Err(e) = result {
error!("[{}] Cannot retry message: {e}", self.tag);
}
} else {
trace!(
"[{}] Will respond to: {:?}",
self.tag, message_to_receive.message
);
match message_to_receive.message.kind {
MessageKind::Single {
result_sender: Some(result_sender),
..
}
| MessageKind::PubSub { result_sender, .. }
| MessageKind::Monitor { result_sender, .. } => {
self.pending_results.push((result_sender, result));
}
MessageKind::Batch { results_sender, .. } => match result {
Ok(resp_buf) => {
let pending_replies = self.pending_responses.take();
if let Some(mut pending_replies) = pending_replies {
pending_replies.push(resp_buf);
self.pending_result_batches
.push((results_sender, Ok(pending_replies)));
} else {
self.pending_result_batches
.push((results_sender, Ok(vec![resp_buf])));
}
}
Err(e) => {
self.pending_result_batches.push((results_sender, Err(e)));
}
},
MessageKind::Invalidation { .. }
| MessageKind::Single {
result_sender: None,
..
} => {
debug!("[{}] forget value {result:?}", self.tag)
}
}
}
}
} else {
if self.pending_responses.is_none() {
self.pending_responses = Some(Vec::new());
}
if let Some(pending_replies) = &mut self.pending_responses {
match result {
Ok(value) => {
pending_replies.push(value);
message_to_receive.num_commands -= 1;
}
Err(Error::Retry(reasons)) => {
if let Some(retry_reasons) =
&mut message_to_receive.message.retry_reasons
{
retry_reasons.extend(reasons);
} else {
message_to_receive.message.retry_reasons =
Some(SmallVec::<[RetryReason; 10]>::from_iter(reasons));
}
}
_ => (),
}
}
}
}
None => {
assert!(
result.is_err(),
"[{}] Received unexpected message: {result:?}",
self.tag
);
}
}
}
fn try_match_pubsub_message(
&mut self,
value: Result<RespResponse>,
) -> Option<Result<RespResponse>> {
if let Ok(ref_value) = &value {
if let Ok(pub_sub_message) = PubSubMessage::try_from(ref_value) {
match pub_sub_message {
PubSubMessage::Message(channel_or_pattern, _)
| PubSubMessage::SMessage(channel_or_pattern, _) => {
match self.subscriptions.get_mut(channel_or_pattern) {
Some((_subscription_type, pub_sub_sender)) => {
if let Err(e) = pub_sub_sender.unbounded_send(value) {
let error_desc = e.to_string();
if let Ok(ref_value) = &e.into_inner()
&& let Some(
PubSubMessage::Message(channel_or_pattern, _)
| PubSubMessage::SMessage(channel_or_pattern, _),
) = PubSubMessage::try_from(ref_value).ok()
{
warn!(
"[{}] Cannot send pub/sub message to caller from channel `{}`: {error_desc}",
self.tag,
String::from_utf8_lossy(channel_or_pattern)
);
}
}
}
None => {
error!(
"[{}] Unexpected message on channel `{}`",
self.tag,
String::from_utf8_lossy(channel_or_pattern)
);
}
}
None
}
PubSubMessage::Subscribe(channel_or_pattern)
| PubSubMessage::PSubscribe(channel_or_pattern)
| PubSubMessage::SSubscribe(channel_or_pattern) => {
if let Some(pending_sub) = self.pending_subscriptions.pop_front() {
if pending_sub.channel_or_pattern == channel_or_pattern {
if self
.subscriptions
.insert(
pending_sub.channel_or_pattern,
(pending_sub.subscription_type, pending_sub.sender),
)
.is_some()
{
return Some(Err(Error::Client(
ClientError::AlreadySubscribed,
)));
}
if pending_sub.more_to_come {
return None;
}
} else {
error!(
"[{}] Unexpected subscription confirmation on channel `{}`",
self.tag,
String::from_utf8_lossy(channel_or_pattern)
);
}
} else {
error!(
"[{}] Cannot find pending subscription for channel `{}`",
self.tag,
String::from_utf8_lossy(channel_or_pattern)
);
}
self.receive_result(Ok(RespResponse::ok()));
None
}
PubSubMessage::Unsubscribe(channel_or_pattern)
| PubSubMessage::PUnsubscribe(channel_or_pattern)
| PubSubMessage::SUnsubscribe(channel_or_pattern) => {
self.subscriptions.remove(channel_or_pattern);
if let Some(remaining) = self.pending_unsubscriptions.front_mut() {
if remaining.len() > 1 {
if remaining.remove(channel_or_pattern).is_none() {
error!(
"[{}] Cannot find channel or pattern to remove: `{}`",
self.tag,
String::from_utf8_lossy(channel_or_pattern)
);
}
None
} else {
let Some(mut remaining) = self.pending_unsubscriptions.pop_front()
else {
error!(
"[{}] Cannot find channel or pattern to remove: `{}`",
self.tag,
String::from_utf8_lossy(channel_or_pattern)
);
return None;
};
if remaining.remove(channel_or_pattern).is_none() {
error!(
"[{}] Cannot find channel or pattern to remove: `{}`",
self.tag,
String::from_utf8_lossy(channel_or_pattern)
);
return None;
}
self.receive_result(Ok(RespResponse::ok()));
None
}
} else {
Some(value)
}
}
PubSubMessage::PMessage(pattern, channel, _) => {
match self.subscriptions.get_mut(pattern) {
Some((_subscription_type, pub_sub_sender)) => {
if let Err(e) = pub_sub_sender.unbounded_send(value) {
warn!(
"[{}] Cannot send pub/sub message to caller: {e}",
self.tag
);
}
}
None => {
error!(
"[{}] Unexpected message on channel `{}` for pattern `{}`",
self.tag,
String::from_utf8_lossy(channel),
String::from_utf8_lossy(pattern)
);
}
}
None
}
}
} else {
Some(value)
}
} else {
Some(value)
}
}
async fn reconnect(&mut self) -> bool {
debug!("[{}] reconnecting...", self.tag);
let old_status = self.status;
self.status = Status::Disconnected;
while let Some(message_to_receive) = self.messages_to_receive.front() {
if !message_to_receive.message.retry_on_error {
if let Some(message_to_receive) = self.messages_to_receive.pop_front() {
message_to_receive
.message
.send_error(&self.tag, Error::DisconnectedByPeer);
}
} else {
break;
}
}
while let Some(message_to_send) = self.messages_to_send.front() {
if !message_to_send.message.retry_on_error {
if let Some(message_to_send) = self.messages_to_send.pop_front() {
message_to_send
.message
.send_error(&self.tag, Error::DisconnectedByPeer);
}
} else {
break;
}
}
loop {
if let Some(delay) = self.reconnection_state.next_delay() {
debug!("[{}] Waiting {delay} ms before reconnection", self.tag);
let start = Instant::now();
let end = start.checked_add(Duration::from_millis(delay)).unwrap();
loop {
let delay = end.duration_since(Instant::now());
let result = timeout(delay, self.msg_receiver.next().fuse()).await;
if let Ok(msg) = result {
if !self.try_handle_message(msg).await {
return false;
}
} else {
break;
}
}
} else {
warn!("[{}] Max reconnection attempts reached", self.tag);
while let Some(message_to_receive) = self.messages_to_receive.pop_front() {
message_to_receive
.message
.send_error(&self.tag, Error::DisconnectedByPeer);
}
while let Some(message_to_send) = self.messages_to_send.pop_front() {
message_to_send
.message
.send_error(&self.tag, Error::DisconnectedByPeer);
}
return false;
}
if let Err(e) = self.connection.reconnect().await {
error!("[{}] Failed to reconnect: {e:?}", self.tag);
continue;
}
if self.auto_resubscribe
&& let Err(e) = self.auto_resubscribe().await
{
error!("[{}] Failed to reconnect: {e:?}", self.tag);
continue;
}
if self.auto_remonitor
&& let Err(e) = self.auto_remonitor(old_status).await
{
error!("[{}] Failed to reconnect: {e:?}", self.tag);
continue;
}
if let Err(e) = self.reconnect_sender.send(()) {
debug!(
"[{}] Cannot send reconnect notification to clients: {e}",
self.tag
);
}
while let Some(message_to_receive) = self.messages_to_receive.pop_back() {
self.messages_to_send.push_front(MessageToSend {
message: message_to_receive.message,
attempts: message_to_receive.attempts,
});
}
self.send_messages().await;
if let Status::Monitor | Status::EnteringMonitor = old_status {
if self.push_sender.is_some() {
self.status = Status::Monitor;
}
} else {
self.status = Status::Connected;
}
info!("[{}] reconnected!", self.tag);
self.reconnection_state.reset_attempts();
return true;
}
}
async fn auto_resubscribe(&mut self) -> Result<()> {
if !self.subscriptions.is_empty() {
for (channel_or_pattern, (subscription_type, _)) in &self.subscriptions {
match subscription_type {
SubscriptionType::Channel => {
self.connection.subscribe(channel_or_pattern).await?;
}
SubscriptionType::Pattern => {
self.connection.psubscribe(channel_or_pattern).await?;
}
SubscriptionType::ShardChannel => {
self.connection.ssubscribe(channel_or_pattern).await?;
}
}
}
}
if !self.pending_subscriptions.is_empty() {
for pending_sub in self.pending_subscriptions.drain(..) {
match pending_sub.subscription_type {
SubscriptionType::Channel => {
self.connection
.subscribe(pending_sub.channel_or_pattern.clone())
.await?;
}
SubscriptionType::Pattern => {
self.connection
.psubscribe(pending_sub.channel_or_pattern.clone())
.await?;
}
SubscriptionType::ShardChannel => {
self.connection
.ssubscribe(pending_sub.channel_or_pattern.clone())
.await?;
}
}
self.subscriptions.insert(
pending_sub.channel_or_pattern,
(pending_sub.subscription_type, pending_sub.sender),
);
}
}
if !self.pending_unsubscriptions.is_empty() {
for mut map in self.pending_unsubscriptions.drain(..) {
for (channel_or_pattern, subscription_type) in map.drain() {
match subscription_type {
SubscriptionType::Channel => {
self.connection
.subscribe(channel_or_pattern.clone())
.await?;
}
SubscriptionType::Pattern => {
self.connection
.psubscribe(channel_or_pattern.clone())
.await?;
}
SubscriptionType::ShardChannel => {
self.connection
.ssubscribe(channel_or_pattern.clone())
.await?;
}
}
self.subscriptions.remove(&channel_or_pattern);
}
}
}
Ok(())
}
async fn auto_remonitor(&mut self, old_status: Status) -> Result<()> {
if let Status::Monitor | Status::EnteringMonitor = old_status {
self.connection.send(&cmd("MONITOR").into()).await?;
}
Ok(())
}
}