#[cfg(test)]
use crate::commands::DebugCommands;
use crate::{
ClientError, Error, Future, Result,
client::{
ClientTrackingInvalidationStream, IntoConfig, Message, MonitorStream, Pipeline,
PreparedCommand, PubSubStream, Transaction,
},
commands::{
BitmapCommands, BlockingCommands, BloomCommands, ClusterCommands, ConnectionCommands,
CountMinSketchCommands, CuckooCommands, GenericCommands, GeoCommands, HashCommands,
HyperLogLogCommands, InternalPubSubCommands, JsonCommands, ListCommands, PubSubCommands,
ScriptingCommands, SearchCommands, SentinelCommands, ServerCommands, SetCommands,
SortedSetCommands, StreamCommands, StringCommands, TDigestCommands, TimeSeriesCommands,
TopKCommands, TransactionCommands, VectorSetCommands,
},
network::{
JoinHandle, MsgSender, NetworkHandler, PubSubReceiver, PubSubSender, PushReceiver,
PushSender, ReconnectReceiver, ReconnectSender, ResultReceiver, ResultSender,
ResultsReceiver, ResultsSender, timeout,
},
resp::{Command, CommandArgs, CommandArgsMut, RespResponse, Response, SubscriptionType, cmd},
};
use futures_channel::{mpsc, oneshot};
use log::{info, trace};
use serde::{Serialize, de::DeserializeOwned};
use smallvec::SmallVec;
use std::{future::IntoFuture, sync::Arc, time::Duration};
/// Handle to a server connection.
///
/// The connection-related fields are wrapped in `Arc`, so the derived `Clone` produces
/// cheap handles that share the same message channel and network-task join handle.
#[derive(Clone)]
pub struct Client {
    /// Sending half of the channel used to enqueue messages for the network task.
    /// `close` and `Drop` replace it with an `Arc` holding `None` once released.
    msg_sender: Arc<Option<MsgSender>>,
    /// Join handle of the background network task returned by `NetworkHandler::connect`;
    /// awaited by `close` when the last handle shuts the connection down.
    network_task_join_handle: Arc<Option<JoinHandle<()>>>,
    /// Source of the receivers handed out by `on_reconnect`.
    reconnect_sender: ReconnectSender,
    /// Upper bound on how long `internal_send`/`internal_send_batch` wait for a reply;
    /// `Duration::ZERO` disables the timeout.
    command_timeout: Duration,
    /// Retry policy used when a call passes `retry_on_error: None`.
    retry_on_error: bool,
    /// Tag identifying this connection in log messages.
    connection_tag: Arc<str>,
}
impl Drop for Client {
fn drop(&mut self) {
let mut network_task_join_handle: Arc<Option<JoinHandle<()>>> = Arc::new(None);
std::mem::swap(
&mut network_task_join_handle,
&mut self.network_task_join_handle,
);
if Arc::try_unwrap(network_task_join_handle).is_ok() {
let mut msg_sender: Arc<Option<MsgSender>> = Arc::new(None);
std::mem::swap(&mut msg_sender, &mut self.msg_sender);
if let Ok(Some(msg_sender)) = Arc::try_unwrap(msg_sender) {
msg_sender.close_channel();
}
};
}
}
impl Client {
#[inline]
pub async fn connect(config: impl IntoConfig) -> Result<Self> {
let config = config.into_config()?;
let command_timeout = config.command_timeout;
let retry_on_error = config.retry_on_error;
let (msg_sender, network_task_join_handle, reconnect_sender, connection_tag) =
NetworkHandler::connect(config.into_config()?).await?;
Ok(Self {
msg_sender: Arc::new(Some(msg_sender)),
network_task_join_handle: Arc::new(Some(network_task_join_handle)),
reconnect_sender,
command_timeout,
retry_on_error,
connection_tag,
})
}
#[allow(dead_code)]
pub(crate) fn connection_tag(&self) -> &str {
&self.connection_tag
}
pub async fn close(mut self) -> Result<()> {
let mut network_task_join_handle: Arc<Option<JoinHandle<()>>> = Arc::new(None);
std::mem::swap(
&mut network_task_join_handle,
&mut self.network_task_join_handle,
);
if let Ok(Some(network_task_join_handle)) = Arc::try_unwrap(network_task_join_handle) {
let mut msg_sender: Arc<Option<MsgSender>> = Arc::new(None);
std::mem::swap(&mut msg_sender, &mut self.msg_sender);
if let Ok(Some(msg_sender)) = Arc::try_unwrap(msg_sender) {
msg_sender.close_channel();
network_task_join_handle.await?;
}
};
Ok(())
}
pub fn on_reconnect(&self) -> ReconnectReceiver {
self.reconnect_sender.subscribe()
}
#[inline]
pub async fn send<T: DeserializeOwned>(
&self,
command: impl Into<Command>,
retry_on_error: Option<bool>,
) -> Result<T> {
let response = self.internal_send(command, retry_on_error).await?;
response.to()
}
#[inline]
pub(crate) async fn internal_send(
&self,
command: impl Into<Command>,
retry_on_error: Option<bool>,
) -> Result<RespResponse> {
let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
let message = Message::single(
command.into(),
result_sender,
retry_on_error.unwrap_or(self.retry_on_error),
);
self.send_message(message)?;
if self.command_timeout != Duration::ZERO {
timeout(self.command_timeout, result_receiver).await??
} else {
result_receiver.await?
}
}
#[inline]
pub fn send_and_forget(
&self,
command: impl Into<Command>,
retry_on_error: Option<bool>,
) -> Result<()> {
let message = Message::single_forget(
command.into(),
retry_on_error.unwrap_or(self.retry_on_error),
);
self.send_message(message)?;
Ok(())
}
#[inline]
pub(crate) async fn internal_send_batch(
&self,
commands: SmallVec<[Command; 10]>,
retry_on_error: Option<bool>,
) -> Result<Vec<RespResponse>> {
let (results_sender, results_receiver): (ResultsSender, ResultsReceiver) =
oneshot::channel();
let message = Message::batch(
commands,
results_sender,
retry_on_error.unwrap_or(self.retry_on_error),
);
self.send_message(message)?;
if self.command_timeout != Duration::ZERO {
timeout(self.command_timeout, results_receiver).await??
} else {
results_receiver.await?
}
}
#[inline]
fn send_message(&self, message: Message) -> Result<()> {
if let Some(msg_sender) = &self.msg_sender as &Option<MsgSender> {
trace!(
"[{}], Will enqueue message: {message:?}",
self.connection_tag
);
Ok(msg_sender.unbounded_send(message).map_err(|e| {
info!("{e}");
Error::Client(ClientError::DisconnectedFromServer)
})?)
} else {
Err(Error::Client(ClientError::InvalidChannel))
}
}
#[inline]
pub fn create_transaction(&self) -> Transaction {
Transaction::new(self.clone())
}
#[inline]
pub fn create_pipeline<'a>(&'a self) -> Pipeline<'a> {
Pipeline::new(self)
}
#[inline]
pub fn create_pub_sub(&self) -> PubSubStream {
let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
PubSubStream::new(pub_sub_sender, pub_sub_receiver, self.clone())
}
pub fn create_client_tracking_invalidation_stream(
&self,
) -> Result<ClientTrackingInvalidationStream> {
let (push_sender, push_receiver): (PushSender, PushReceiver) = mpsc::unbounded();
let message = Message::client_tracking_invalidation(push_sender);
self.send_message(message)?;
Ok(ClientTrackingInvalidationStream::new(push_receiver))
}
pub(crate) async fn subscribe_from_pub_sub_sender(
&self,
channels: &CommandArgs,
pub_sub_sender: &PubSubSender,
) -> Result<()> {
let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
let pub_sub_senders = channels
.into_iter()
.map(|c| (c, pub_sub_sender.clone()))
.collect();
let message = Message::pub_sub(
cmd("SUBSCRIBE").arg(channels).into(),
result_sender,
SubscriptionType::Channel,
pub_sub_senders,
);
self.send_message(message)?;
result_receiver.await??.to::<()>()
}
pub(crate) async fn psubscribe_from_pub_sub_sender(
&self,
patterns: &CommandArgs,
pub_sub_sender: &PubSubSender,
) -> Result<()> {
let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
let pub_sub_senders = patterns
.into_iter()
.map(|c| (c, pub_sub_sender.clone()))
.collect();
let message = Message::pub_sub(
cmd("PSUBSCRIBE").arg(patterns).into(),
result_sender,
SubscriptionType::Pattern,
pub_sub_senders,
);
self.send_message(message)?;
result_receiver.await??.to::<()>()
}
pub(crate) async fn ssubscribe_from_pub_sub_sender(
&self,
shardchannels: &CommandArgs,
pub_sub_sender: &PubSubSender,
) -> Result<()> {
let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
let pub_sub_senders = shardchannels
.into_iter()
.map(|c| (c, pub_sub_sender.clone()))
.collect();
let message = Message::pub_sub(
cmd("SSUBSCRIBE").key(shardchannels).into(),
result_sender,
SubscriptionType::ShardChannel,
pub_sub_senders,
);
self.send_message(message)?;
result_receiver.await??.to::<()>()
}
}
/// Extra operations on a [`PreparedCommand`] that is executed directly by a [`Client`].
pub trait ClientPreparedCommand<'a, R> {
    /// Sends the prepared command without waiting for a reply.
    ///
    /// For `PreparedCommand<'a, &'a Client, R>` this delegates to
    /// `Client::send_and_forget`.
    ///
    /// # Errors
    /// Returns an error if the command could not be queued for sending.
    fn forget(self) -> Result<()>;
}
impl<'a, R: Response> ClientPreparedCommand<'a, R> for PreparedCommand<'a, &'a Client, R> {
fn forget(self) -> Result<()> {
self.executor
.send_and_forget(self.command, self.retry_on_error)
}
}
impl<'a, R: Response + DeserializeOwned + 'a> IntoFuture for PreparedCommand<'a, &'a Client, R> {
type Output = Result<R>;
type IntoFuture = Future<'a, R>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move { self.executor.send(self.command, self.retry_on_error).await })
}
}
impl<'a> BitmapCommands<'a> for &'a Client {}
impl<'a> BloomCommands<'a> for &'a Client {}
impl<'a> ClusterCommands<'a> for &'a Client {}
impl<'a> CountMinSketchCommands<'a> for &'a Client {}
impl<'a> CuckooCommands<'a> for &'a Client {}
impl<'a> ConnectionCommands<'a> for &'a Client {}
#[cfg(test)]
impl<'a> DebugCommands<'a> for &'a Client {}
impl<'a> GenericCommands<'a> for &'a Client {}
impl<'a> GeoCommands<'a> for &'a Client {}
impl<'a> HashCommands<'a> for &'a Client {}
impl<'a> HyperLogLogCommands<'a> for &'a Client {}
impl<'a> InternalPubSubCommands<'a> for &'a Client {}
impl<'a> JsonCommands<'a> for &'a Client {}
impl<'a> ListCommands<'a> for &'a Client {}
impl<'a> ScriptingCommands<'a> for &'a Client {}
impl<'a> SearchCommands<'a> for &'a Client {}
impl<'a> SentinelCommands<'a> for &'a Client {}
impl<'a> ServerCommands<'a> for &'a Client {}
impl<'a> SetCommands<'a> for &'a Client {}
impl<'a> SortedSetCommands<'a> for &'a Client {}
impl<'a> StreamCommands<'a> for &'a Client {}
impl<'a> StringCommands<'a> for &'a Client {}
impl<'a> TDigestCommands<'a> for &'a Client {}
impl<'a> TimeSeriesCommands<'a> for &'a Client {}
impl<'a> TransactionCommands<'a> for &'a Client {}
impl<'a> TopKCommands<'a> for &'a Client {}
impl<'a> VectorSetCommands<'a> for &'a Client {}
impl<'a> PubSubCommands<'a> for &'a Client {
#[inline]
async fn subscribe(self, channels: impl Serialize) -> Result<PubSubStream> {
let channels = CommandArgsMut::default().arg(channels).freeze();
let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
self.subscribe_from_pub_sub_sender(&channels, &pub_sub_sender)
.await?;
Ok(PubSubStream::from_channels(
channels,
pub_sub_sender,
pub_sub_receiver,
self.clone(),
))
}
#[inline]
async fn psubscribe(self, patterns: impl Serialize) -> Result<PubSubStream> {
let patterns = CommandArgsMut::default().arg(patterns).freeze();
let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
self.psubscribe_from_pub_sub_sender(&patterns, &pub_sub_sender)
.await?;
Ok(PubSubStream::from_patterns(
patterns,
pub_sub_sender,
pub_sub_receiver,
self.clone(),
))
}
#[inline]
async fn ssubscribe(self, shardchannels: impl Serialize) -> Result<PubSubStream> {
let shardchannels = CommandArgsMut::default().arg(shardchannels).freeze();
let (pub_sub_sender, pub_sub_receiver): (PubSubSender, PubSubReceiver) = mpsc::unbounded();
self.ssubscribe_from_pub_sub_sender(&shardchannels, &pub_sub_sender)
.await?;
Ok(PubSubStream::from_shardchannels(
shardchannels,
pub_sub_sender,
pub_sub_receiver,
self.clone(),
))
}
}
impl<'a> BlockingCommands<'a> for &'a Client {
async fn monitor(self) -> Result<MonitorStream> {
let (result_sender, result_receiver): (ResultSender, ResultReceiver) = oneshot::channel();
let (push_sender, push_receiver): (PushSender, PushReceiver) = mpsc::unbounded();
let message = Message::monitor(cmd("MONITOR").into(), result_sender, push_sender);
self.send_message(message)?;
let _bytes = result_receiver.await??;
Ok(MonitorStream::new(push_receiver, self.clone()))
}
}