use crate::{
clients::WithOptions,
commands,
error::Error,
interfaces::{default_send_command, FredResult},
modules::inner::ClientInner,
protocol::command::Command,
router::commands as router_commands,
types::{
config::{Config, ConnectionConfig, Options, PerformanceConfig, ReconnectPolicy, Server},
ClientState,
ConnectHandle,
CustomCommand,
FromValue,
InfoKind,
Resp3Frame,
RespVersion,
Value,
Version,
},
utils,
};
use arc_swap::ArcSwapAny;
use futures::{Stream, StreamExt};
use std::{future::Future, sync::Arc};
use tokio::sync::mpsc::{
channel as bounded_channel,
error::{TryRecvError, TrySendError},
unbounded_channel,
Receiver as BoundedReceiver,
Sender as BoundedSender,
UnboundedReceiver,
UnboundedSender,
};
pub use tokio::{
spawn,
sync::{
broadcast::{
self,
error::SendError as BroadcastSendError,
Receiver as BroadcastReceiver,
Sender as BroadcastSender,
},
oneshot::{channel as oneshot_channel, Receiver as OneshotReceiver, Sender as OneshotSender},
RwLock as AsyncRwLock,
},
task::JoinHandle,
time::sleep,
};
use tokio_stream::wrappers::{ReceiverStream, UnboundedReceiverStream};
#[cfg(feature = "dynamic-pool")]
use arc_swap::ArcSwapOption;
/// The concrete Tokio MPSC sender backing a [`Sender`]: either bounded or unbounded.
enum SenderKind<T>
where
  T: Send + 'static,
{
  Bounded(BoundedSender<T>),
  Unbounded(UnboundedSender<T>),
}

// Written by hand instead of `#[derive(Clone)]`: the derive would add a `T: Clone` bound,
// but both Tokio sender types can be cloned for any `T`.
impl<T> Clone for SenderKind<T>
where
  T: Send + 'static,
{
  fn clone(&self) -> Self {
    match self {
      Self::Bounded(inner) => Self::Bounded(inner.clone()),
      Self::Unbounded(inner) => Self::Unbounded(inner.clone()),
    }
  }
}
/// A cloneable MPSC sender that may be backed by either a bounded or an unbounded Tokio channel.
///
/// Created by [`channel`].
pub struct Sender<T: Send + 'static> {
  tx: SenderKind<T>,
}

impl<T: Send + 'static> Clone for Sender<T> {
  fn clone(&self) -> Self {
    let tx = self.tx.clone();
    Sender { tx }
  }
}

impl<T: Send + 'static> Sender<T> {
  /// Send `val`. For a bounded channel this waits for capacity; an unbounded send completes
  /// immediately.
  ///
  /// On failure the unsent value is handed back as `Err(val)`.
  pub async fn send(&self, val: T) -> Result<(), T> {
    let outcome = match &self.tx {
      SenderKind::Bounded(inner) => inner.send(val).await,
      SenderKind::Unbounded(inner) => inner.send(val),
    };

    outcome.map_err(|err| err.0)
  }

  /// Try to send `val` without waiting.
  ///
  /// A bounded channel reports its own [`TrySendError`]. An unbounded channel never reports
  /// `Full`, so its only failure is mapped to [`TrySendError::Closed`] carrying the unsent value.
  pub fn try_send(&self, val: T) -> Result<(), TrySendError<T>> {
    match &self.tx {
      SenderKind::Bounded(inner) => inner.try_send(val),
      SenderKind::Unbounded(inner) => match inner.send(val) {
        Ok(()) => Ok(()),
        Err(err) => Err(TrySendError::Closed(err.0)),
      },
    }
  }
}
/// The concrete Tokio MPSC receiver backing a [`Receiver`]: either bounded or unbounded.
enum ReceiverKind<T: Send + 'static> {
  Bounded(BoundedReceiver<T>),
  Unbounded(UnboundedReceiver<T>),
}

/// The receiving half of a channel created by [`channel`], backed by either a bounded or an
/// unbounded Tokio receiver.
pub struct Receiver<T: Send + 'static> {
  rx: ReceiverKind<T>,
}

impl<T: Send + 'static> Receiver<T> {
  /// Wait for the next value. Returns `None` once the channel is closed and no values remain.
  pub async fn recv(&mut self) -> Option<T> {
    match &mut self.rx {
      ReceiverKind::Bounded(inner) => inner.recv().await,
      ReceiverKind::Unbounded(inner) => inner.recv().await,
    }
  }

  /// Take the next value if one is immediately available, without waiting.
  pub fn try_recv(&mut self) -> Result<T, TryRecvError> {
    match &mut self.rx {
      ReceiverKind::Bounded(inner) => inner.try_recv(),
      ReceiverKind::Unbounded(inner) => inner.try_recv(),
    }
  }

  /// Convert this receiver into a [`Stream`] of its values.
  ///
  /// Both variants are boxed so the two wrapper stream types share a single return type.
  pub fn into_stream(self) -> impl Stream<Item = T> + 'static {
    let Receiver { rx } = self;

    match rx {
      ReceiverKind::Bounded(inner) => ReceiverStream::new(inner).boxed(),
      ReceiverKind::Unbounded(inner) => UnboundedReceiverStream::new(inner).boxed(),
    }
  }
}
/// Create an MPSC channel.
///
/// A `size` of `0` creates an unbounded channel; any other value creates a bounded channel with
/// capacity `size`.
pub fn channel<T: Send + 'static>(size: usize) -> (Sender<T>, Receiver<T>) {
  let (tx, rx) = match size {
    0 => {
      let (sender, receiver) = unbounded_channel();
      (SenderKind::Unbounded(sender), ReceiverKind::Unbounded(receiver))
    },
    capacity => {
      let (sender, receiver) = bounded_channel(capacity);
      (SenderKind::Bounded(sender), ReceiverKind::Bounded(receiver))
    },
  };

  (Sender { tx }, Receiver { rx })
}
#[cfg(any(feature = "dns", feature = "trust-dns-resolver"))]
use crate::protocol::types::Resolve;
#[cfg(feature = "i-server")]
use crate::types::ShutdownFlags;
/// Reference-counted pointer (alias for [`std::sync::Arc`]).
pub type RefCount<T> = std::sync::Arc<T>;

// Atomic primitives from the standard library.
pub type AtomicBool = std::sync::atomic::AtomicBool;
pub type AtomicUsize = std::sync::atomic::AtomicUsize;

// Synchronous (non-async) locks from `parking_lot`.
pub type Mutex<T> = parking_lot::Mutex<T>;
pub type RwLock<T> = parking_lot::RwLock<T>;

// Atomically swappable references from `arc_swap`.
pub type RefSwap<T> = ArcSwapAny<T>;
#[cfg(feature = "dynamic-pool")]
pub type RefSwapOption<T> = ArcSwapOption<T>;
/// Broadcast a clone of `msg` on `tx`.
///
/// If the send fails, `func` is called with the value that could not be delivered. Tokio's
/// broadcast `send` only fails when there are no active receivers. Nothing happens on success.
pub fn broadcast_send<T: Clone, F: Fn(&T)>(tx: &BroadcastSender<T>, msg: &T, func: F) {
  match tx.send(msg.clone()) {
    Ok(_) => {},
    Err(BroadcastSendError(undelivered)) => func(&undelivered),
  }
}
/// Create a Tokio broadcast channel that buffers up to `capacity` values.
///
/// Tokio panics if `capacity` is `0`.
pub fn broadcast_channel<T: Clone>(capacity: usize) -> (BroadcastSender<T>, BroadcastReceiver<T>) {
  let (sender, receiver) = broadcast::channel::<T>(capacity);
  (sender, receiver)
}
/// Core interface shared by client handles.
///
/// Provides accessors for the client's configuration and state, functions to start, wait for, and
/// force connections, and server commands such as `quit`, `flushall`, `ping`, `info`, and custom
/// commands. Implementors only need to provide [`inner`](ClientLike::inner); every other function
/// has a default implementation built on it.
pub trait ClientLike: Clone + Send + Sync + Sized {
  /// The shared inner state backing this client handle.
  #[doc(hidden)]
  fn inner(&self) -> &Arc<ClientInner>;

  /// Hook for modifying a command just before it is sent.
  ///
  /// The default implementation does nothing.
  #[doc(hidden)]
  fn change_command(&self, _: &mut Command) {}

  /// Convert `command` into a [`Command`], apply [`change_command`](ClientLike::change_command),
  /// then pass it to `default_send_command` together with this client's inner state.
  #[doc(hidden)]
  fn send_command<C>(&self, command: C) -> Result<(), Error>
  where
    C: Into<Command>,
  {
    let mut command: Command = command.into();
    self.change_command(&mut command);
    default_send_command(self.inner(), command)
  }

  /// Read the ID associated with this client.
  fn id(&self) -> &str {
    &self.inner().id
  }

  /// Read a copy of the client's [`Config`].
  fn client_config(&self) -> Config {
    self.inner().config.as_ref().clone()
  }

  /// Read a copy of the client's [`ReconnectPolicy`], or `None` if no policy is set.
  fn client_reconnect_policy(&self) -> Option<ReconnectPolicy> {
    self.inner().policy.read().clone()
  }

  /// Read the client's [`ConnectionConfig`].
  fn connection_config(&self) -> &ConnectionConfig {
    self.inner().connection.as_ref()
  }

  /// The RESP protocol version the client is using: `RESP3` if the inner state reports RESP3,
  /// otherwise `RESP2`.
  fn protocol_version(&self) -> RespVersion {
    if self.inner().is_resp3() {
      RespVersion::RESP3
    } else {
      RespVersion::RESP2
    }
  }

  /// Whether a reconnection policy is set.
  fn has_reconnect_policy(&self) -> bool {
    self.inner().policy.read().is_some()
  }

  /// Whether the configured server is a cluster.
  fn is_clustered(&self) -> bool {
    self.inner().config.server.is_clustered()
  }

  /// Whether the configured server uses sentinels.
  fn uses_sentinels(&self) -> bool {
    self.inner().config.server.is_sentinel()
  }

  /// Replace the client's [`PerformanceConfig`].
  fn update_perf_config(&self, config: PerformanceConfig) {
    self.inner().update_performance_config(config);
  }

  /// Read a copy of the client's [`PerformanceConfig`].
  fn perf_config(&self) -> PerformanceConfig {
    self.inner().performance_config()
  }

  /// Read a copy of the current [`ClientState`].
  fn state(&self) -> ClientState {
    self.inner().state.read().clone()
  }

  /// Whether the current [`ClientState`] is `Connected`.
  fn is_connected(&self) -> bool {
    *self.inner().state.read() == ClientState::Connected
  }

  /// The servers with active connections, as reported by the inner client state.
  fn active_connections(&self) -> Vec<Server> {
    self.inner().active_connections()
  }

  /// The server version, if it is known.
  fn server_version(&self) -> Option<Version> {
    self.inner().server_state.read().kind.server_version()
  }

  /// Set the `Resolve` implementation (DNS resolver) used by the client.
  #[cfg(feature = "dns")]
  #[cfg_attr(docsrs, doc(cfg(feature = "dns")))]
  fn set_resolver(&self, resolver: Arc<dyn Resolve>) -> impl Future + Send {
    async move { self.inner().set_resolver(resolver).await }
  }

  /// Spawn the connection task and return its handle without waiting for a connection.
  ///
  /// Before spawning, `utils::reset_router_task` is called. The task clears the backchannel's
  /// router state and then runs `router_commands::start`; that result becomes the task's result.
  /// Errors that are not cancellations are also sent through `notifications.broadcast_connect`.
  /// When the router returns, `cas_client_state` is used to move the state from `Disconnecting`
  /// to `Disconnected`.
  ///
  /// Use [`init`](ClientLike::init) or [`wait_for_connect`](ClientLike::wait_for_connect) to wait
  /// for the connection.
  fn connect(&self) -> ConnectHandle {
    let inner = self.inner().clone();
    utils::reset_router_task(&inner);

    tokio::spawn(async move {
      inner.backchannel.clear_router_state(&inner).await;
      let result = router_commands::start(&inner).await;
      _trace!(inner, "Ending connection task with {:?}", result);

      if let Err(ref error) = result {
        if !error.is_canceled() {
          inner.notifications.broadcast_connect(Err(error.clone()));
        }
      }

      inner.cas_client_state(ClientState::Disconnecting, ClientState::Disconnected);
      result
    })
  }

  /// Force the client to reconnect (delegates to `commands::server::force_reconnection`).
  fn force_reconnection(&self) -> impl Future<Output = FredResult<()>> + Send {
    async move { commands::server::force_reconnection(self.inner()).await }
  }

  /// Wait for the client to connect, returning immediately if it is already connected.
  ///
  /// Otherwise this resolves with the next connection result broadcast to connect listeners.
  /// It returns an error if that result is an error, or if receiving from the notification
  /// channel fails.
  fn wait_for_connect(&self) -> impl Future<Output = FredResult<()>> + Send {
    async move {
      // Subscribe *before* reading the state. With check-then-subscribe, a connection that
      // finished (and broadcast its notification) between the two steps was never observed, and
      // this future could wait indefinitely. The braces drop the swap guard right away, as in
      // `init`.
      let mut rx = { self.inner().notifications.connect.load().subscribe() };

      if utils::read_locked(&self.inner().state) == ClientState::Connected {
        debug!("{}: Client is already connected.", self.inner().id);
        Ok(())
      } else {
        rx.recv().await?
      }
    }
  }

  /// Connect and wait for the first connection result.
  ///
  /// On success, returns the connection task's handle. On failure, `utils::reset_router_task` is
  /// called and the error is returned.
  fn init(&self) -> impl Future<Output = FredResult<ConnectHandle>> + Send {
    async move {
      // Subscribe before starting the task so the first connection result cannot be missed.
      let mut rx = { self.inner().notifications.connect.load().subscribe() };
      let task = self.connect();
      let error = rx.recv().await.map_err(Error::from).and_then(|r| r).err();

      if let Some(error) = error {
        utils::reset_router_task(self.inner());
        Err(error)
      } else {
        Ok(task)
      }
    }
  }

  /// Quit the client (delegates to `commands::server::quit`).
  fn quit(&self) -> impl Future<Output = FredResult<()>> + Send {
    async move { commands::server::quit(self).await }
  }

  /// Send `SHUTDOWN` to the server with optional [`ShutdownFlags`].
  #[cfg(feature = "i-server")]
  #[cfg_attr(docsrs, doc(cfg(feature = "i-server")))]
  fn shutdown(&self, flags: Option<ShutdownFlags>) -> impl Future<Output = FredResult<()>> + Send {
    async move { commands::server::shutdown(self, flags).await }
  }

  /// Run `FLUSHALL`, selecting the asynchronous variant when `r#async` is `true`, and convert
  /// the reply into `R`.
  fn flushall<R>(&self, r#async: bool) -> impl Future<Output = FredResult<R>> + Send
  where
    R: FromValue,
  {
    async move { commands::server::flushall(self, r#async).await?.convert() }
  }

  /// Run `FLUSHALL` across the cluster (delegates to `commands::server::flushall_cluster`).
  fn flushall_cluster(&self) -> impl Future<Output = FredResult<()>> + Send {
    async move { commands::server::flushall_cluster(self).await }
  }

  /// Ping the server, optionally with `message`, and convert the reply into `R`.
  fn ping<R>(&self, message: Option<String>) -> impl Future<Output = FredResult<R>> + Send
  where
    R: FromValue,
  {
    async move { commands::server::ping(self, message).await?.convert() }
  }

  /// Read `INFO`, optionally limited to `section`, and convert the reply into `R`.
  fn info<R>(&self, section: Option<InfoKind>) -> impl Future<Output = FredResult<R>> + Send
  where
    R: FromValue,
  {
    async move { commands::server::info(self, section).await?.convert() }
  }

  /// Run a custom command and convert the reply into `R`.
  ///
  /// Every argument is converted into a [`Value`] first. A conversion error is returned before
  /// anything is sent.
  fn custom<R, T>(&self, cmd: CustomCommand, args: Vec<T>) -> impl Future<Output = FredResult<R>> + Send
  where
    R: FromValue,
    T: TryInto<Value> + Send,
    T::Error: Into<Error> + Send,
  {
    async move {
      let args = utils::try_into_vec(args)?;
      commands::server::custom(self, cmd, args).await?.convert()
    }
  }

  /// Run a custom command and return the raw [`Resp3Frame`] reply.
  ///
  /// Every argument is converted into a [`Value`] first. A conversion error is returned before
  /// anything is sent.
  fn custom_raw<T>(&self, cmd: CustomCommand, args: Vec<T>) -> impl Future<Output = FredResult<Resp3Frame>> + Send
  where
    T: TryInto<Value> + Send,
    T::Error: Into<Error> + Send,
  {
    async move {
      let args = utils::try_into_vec(args)?;
      commands::server::custom_raw(self, cmd, args).await
    }
  }

  /// Return a [`WithOptions`] that wraps clones of this client and of `options`.
  fn with_options(&self, options: &Options) -> WithOptions<Self> {
    WithOptions {
      client: self.clone(),
      options: options.clone(),
    }
  }
}
pub fn spawn_event_listener<T, F, Fut>(mut rx: BroadcastReceiver<T>, func: F) -> JoinHandle<FredResult<()>>
where
T: Clone + Send + 'static,
Fut: Future<Output = FredResult<()>> + Send + 'static,
F: Fn(T) -> Fut + Send + 'static,
{
tokio::spawn(async move {
let mut result = Ok(());
while let Ok(val) = rx.recv().await {
if let Err(err) = func(val).await {
result = Err(err);
break;
}
}
result
})
}