use std::{
collections::{HashMap, HashSet},
fmt,
future::Future,
io,
ops::Deref,
pin::Pin,
sync::{Arc, Mutex},
task::{self, Poll},
time::Duration,
};
mod request;
mod routing;
use crate::{
AsyncConnectionConfig, Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError,
RedisFuture, RedisResult, ToRedisArgs, Value,
aio::{ConnectionLike, HandleContainer, MultiplexedConnection, Runtime},
check_resp3,
cluster_handling::{
NodeAddress,
client::ClusterParams,
get_connection_info,
read_routing::ReadRoutingStrategy,
routing::{
MultipleNodeRoutingInfo, Redirect, ResponsePolicy, RoutingInfo, SingleNodeRoutingInfo,
},
slot_cmd,
slot_map::{Slot, SlotMap},
topology::parse_slots,
},
cmd,
errors::closed_connection_error,
subscription_tracker::SubscriptionTracker,
};
use crate::ProtocolVersion;
#[cfg(feature = "cache-aio")]
use crate::caching::{CacheManager, CacheStatistics};
use futures_util::{
future::{self, BoxFuture, FutureExt},
ready,
sink::Sink,
stream::{self, Stream, StreamExt},
};
use log::{debug, trace, warn};
use rand::{rng, seq::IteratorRandom};
use request::{CmdArg, PendingRequest, Request, RequestState, Retry};
use routing::{InternalRoutingInfo, InternalSingleNodeRouting, route_for_pipeline};
use tokio::sync::{RwLock, mpsc, oneshot};
/// Per-client state shared by all clones of a [`ClusterConnection`].
struct ClientSideState {
    /// RESP protocol version configured for this client (checked by the
    /// subscription methods via `check_resp3!`).
    protocol: ProtocolVersion,
    /// Keeps the spawned driver task alive; aborted when the last clone drops.
    _task_handle: HandleContainer,
    /// Optional overall deadline applied to each routed command/pipeline.
    overall_response_timeout: Option<Duration>,
    /// Runtime handle used for spawning the driver task and for timeouts.
    runtime: Runtime,
    #[cfg(feature = "cache-aio")]
    cache_manager: Option<CacheManager>,
}
/// A cheaply-cloneable handle to a Redis cluster connection.
///
/// Commands are forwarded over `sender` to a background driver task (see
/// `ClusterConnInner`) which owns the actual per-node connections.
#[derive(Clone)]
pub struct ClusterConnection<C = MultiplexedConnection> {
    // Shared client-side configuration plus the driver-task handle.
    state: Arc<ClientSideState>,
    // Channel into the background driver task.
    sender: mpsc::Sender<Message<C>>,
}
impl<C> ClusterConnection<C>
where
C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
{
pub(crate) async fn new(
initial_nodes: &[ConnectionInfo],
cluster_params: ClusterParams,
) -> RedisResult<ClusterConnection<C>> {
let (connection, connect_receiver) = Self::new_inner(initial_nodes, cluster_params);
connect_receiver.await.map_err(|_| {
RedisError::from((ErrorKind::Io, "Cluster connection task were dropped"))
})??;
Ok(connection)
}
pub(crate) fn new_pending(
initial_nodes: &[ConnectionInfo],
cluster_params: ClusterParams,
) -> ClusterConnection<C> {
let (connection, _connect_receiver) = Self::new_inner(initial_nodes, cluster_params);
connection
}
pub(crate) fn new_inner(
initial_nodes: &[ConnectionInfo],
cluster_params: ClusterParams,
) -> (ClusterConnection<C>, oneshot::Receiver<RedisResult<()>>) {
let protocol = cluster_params.protocol.unwrap_or_default();
let overall_response_timeout = cluster_params.overall_response_timeout;
#[cfg(feature = "cache-aio")]
let cache_manager = cluster_params.cache_manager.clone();
let runtime = Runtime::locate();
let mut inner = ClusterConnInner::new(initial_nodes, cluster_params);
let (connect_sender, connect_receiver) = oneshot::channel::<RedisResult<()>>();
let (sender, mut receiver) = mpsc::channel::<Message<_>>(100);
let stream = async move {
let connect_result = inner.wait_for_initial_connection().await;
let _ = connect_sender.send(connect_result);
let _ = stream::poll_fn(move |cx| receiver.poll_recv(cx))
.map(Ok)
.forward(inner)
.await;
};
let _task_handle = HandleContainer::new(runtime.spawn(stream));
(
ClusterConnection {
sender,
state: Arc::new(ClientSideState {
protocol,
_task_handle,
overall_response_timeout,
runtime,
#[cfg(feature = "cache-aio")]
cache_manager,
}),
},
connect_receiver,
)
}
pub async fn route_command(&mut self, cmd: Cmd, routing: RoutingInfo) -> RedisResult<Value> {
trace!("send_packed_command");
let (sender, receiver) = oneshot::channel();
let request = async {
self.sender
.send(Message {
cmd: CmdArg::Cmd {
cmd: Arc::new(cmd),
routing: routing.into(),
},
sender,
})
.await
.map_err(|_| {
RedisError::from(io::Error::new(
io::ErrorKind::BrokenPipe,
"redis_cluster: Unable to send command",
))
})?;
receiver
.await
.unwrap_or_else(|_| {
Err(RedisError::from(io::Error::new(
io::ErrorKind::BrokenPipe,
"redis_cluster: Unable to receive command",
)))
})
.map(|response| match response {
Response::Single(value) => value,
Response::Multiple(_) => unreachable!(),
})
};
match self.state.overall_response_timeout {
Some(duration) => self.state.runtime.timeout(duration, request).await?,
None => request.await,
}
}
pub async fn route_pipeline(
&mut self,
pipeline: crate::Pipeline,
offset: usize,
count: usize,
route: SingleNodeRoutingInfo,
) -> RedisResult<Vec<Value>> {
let (sender, receiver) = oneshot::channel();
let request = async {
self.sender
.send(Message {
cmd: CmdArg::Pipeline {
pipeline: Arc::new(pipeline),
offset,
count,
route: route.into(),
},
sender,
})
.await
.map_err(|_| closed_connection_error())?;
receiver
.await
.unwrap_or_else(|_| Err(closed_connection_error()))
.map(|response| match response {
Response::Multiple(values) => values,
Response::Single(_) => unreachable!(),
})
};
match self.state.overall_response_timeout {
Some(duration) => self.state.runtime.timeout(duration, request).await?,
None => request.await,
}
}
pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
check_resp3!(self.state.protocol);
let mut cmd = cmd("SUBSCRIBE");
cmd.arg(channel_name);
cmd.exec_async(self).await?;
Ok(())
}
pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
check_resp3!(self.state.protocol);
let mut cmd = cmd("UNSUBSCRIBE");
cmd.arg(channel_name);
cmd.exec_async(self).await?;
Ok(())
}
pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
check_resp3!(self.state.protocol);
let mut cmd = cmd("PSUBSCRIBE");
cmd.arg(channel_pattern);
cmd.exec_async(self).await?;
Ok(())
}
pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
check_resp3!(self.state.protocol);
let mut cmd = cmd("PUNSUBSCRIBE");
cmd.arg(channel_pattern);
cmd.exec_async(self).await?;
Ok(())
}
pub async fn ssubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
check_resp3!(self.state.protocol);
let mut cmd = cmd("SSUBSCRIBE");
cmd.arg(channel_name);
cmd.exec_async(self).await?;
Ok(())
}
pub async fn sunsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
check_resp3!(self.state.protocol);
let mut cmd = cmd("SUNSUBSCRIBE");
cmd.arg(channel_name);
cmd.exec_async(self).await?;
Ok(())
}
#[cfg(feature = "cache-aio")]
#[cfg_attr(docsrs, doc(cfg(feature = "cache-aio")))]
pub fn get_cache_statistics(&self) -> Option<CacheStatistics> {
self.state.cache_manager.as_ref().map(|cm| cm.statistics())
}
}
/// Mapping from a node's address to its (cloneable) connection.
type ConnectionMap<C> = HashMap<NodeAddress, C>;

/// Shared state used by the driver task and the in-flight request futures.
struct InnerCore<C> {
    /// Connection map and slot map guarded together, so topology and
    /// connections are always updated consistently.
    conn_lock: RwLock<(ConnectionMap<C>, SlotMap)>,
    cluster_params: ClusterParams,
    /// Queue of requests awaiting dispatch by the driver task.
    pending_requests_tx: mpsc::UnboundedSender<PendingRequest<C>>,
    /// Nodes supplied at creation time; used for full reconnects.
    initial_nodes: Vec<ConnectionInfo>,
    /// Tracks (un)subscribe commands so they can be replayed after recovery;
    /// only present when a push sender is configured.
    subscription_tracker: Option<Mutex<SubscriptionTracker>>,
    /// Optional strategy for choosing which node serves reads.
    routing_strategy: Option<Box<dyn ReadRoutingStrategy>>,
}
/// Cheaply-cloneable shared handle to [`InnerCore`].
#[derive(Clone)]
struct Core<C>(Arc<InnerCore<C>>);

impl<C> Deref for Core<C> {
    type Target = InnerCore<C>;

    fn deref(&self) -> &InnerCore<C> {
        &self.0
    }
}
impl<C> Core<C>
where
    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
{
    /// Fans `cmd` out to every node selected by `routing` and aggregates the
    /// per-node results according to `response_policy`.
    async fn execute_on_multiple_nodes<'a>(
        &self,
        cmd: &'a Arc<Cmd>,
        routing: &'a MultipleNodeRoutingInfo,
        response_policy: Option<ResponsePolicy>,
    ) -> OperationResult {
        let read_guard = self.conn_lock.read().await;
        if read_guard.0.is_empty() {
            return (
                OperationTarget::FanOut,
                Result::Err(
                    (
                        ErrorKind::ClusterConnectionNotFound,
                        "No connections found for multi-node operation",
                    )
                        .into(),
                ),
            );
        }
        // Build one pending request per target node, keeping a receiver for
        // each so their results can be aggregated below.
        let (receivers, requests): (Vec<_>, Vec<_>) = {
            // Maps a (node, command) pair to a (receiver, request) pair;
            // yields `None` when no connection for that node is known.
            let to_request = |(addr, cmd): (&NodeAddress, Arc<Cmd>)| {
                read_guard.0.get(addr).cloned().map(|conn| {
                    let (sender, receiver) = oneshot::channel();
                    let addr = addr.clone();
                    (
                        (addr.clone(), receiver),
                        PendingRequest {
                            retry: 0,
                            sender: request::ResultExpectation::External(sender),
                            cmd: CmdArg::Cmd {
                                cmd,
                                routing: InternalSingleNodeRouting::Connection {
                                    identifier: addr,
                                    conn,
                                }
                                .into(),
                            },
                        },
                    )
                })
            };
            let slot_map = &read_guard.1;
            match routing {
                MultipleNodeRoutingInfo::AllNodes => slot_map
                    .addresses_for_all_nodes()
                    .into_iter()
                    .filter_map(|addr| to_request((addr, cmd.clone())))
                    .unzip(),
                MultipleNodeRoutingInfo::AllMasters => slot_map
                    .addresses_for_all_primaries()
                    .into_iter()
                    .filter_map(|addr| to_request((addr, cmd.clone())))
                    .unzip(),
                // Each node receives a sub-command containing only the
                // arguments whose slots it owns.
                MultipleNodeRoutingInfo::MultiSlot((routes, _)) => slot_map
                    .addresses_for_multi_slot(routes, self.routing_strategy.as_deref())
                    .enumerate()
                    .filter_map(|(index, addr_opt)| {
                        addr_opt.and_then(|addr| {
                            let (_, indices) = routes.get(index).unwrap();
                            let cmd =
                                Arc::new(crate::cluster_routing::command_for_multi_slot_indices(
                                    cmd.as_ref(),
                                    indices.iter(),
                                ));
                            to_request((addr, cmd))
                        })
                    })
                    .unzip(),
            }
        };
        // Release the read lock before dispatching so request handling can
        // re-acquire it without contention.
        drop(read_guard);
        for request in requests {
            let _ = self.pending_requests_tx.send(request);
        }
        (
            OperationTarget::FanOut,
            Self::aggregate_results(receivers, routing, response_policy)
                .await
                .map(Response::Single),
        )
    }

    /// Combines the per-node results of a fan-out request into one value,
    /// according to `response_policy`.
    async fn aggregate_results(
        receivers: Vec<(NodeAddress, oneshot::Receiver<RedisResult<Response>>)>,
        routing: &MultipleNodeRoutingInfo,
        response_policy: Option<ResponsePolicy>,
    ) -> RedisResult<Value> {
        if receivers.is_empty() {
            return Err((
                ErrorKind::ClusterConnectionNotFound,
                "No nodes found for multi-node operation",
            )
                .into());
        }
        // Fan-out requests are single commands, so only `Single` responses
        // can occur here.
        let extract_result = |response| match response {
            Response::Single(value) => value,
            Response::Multiple(_) => unreachable!(),
        };
        // A dropped sender means the request future was lost internally.
        let convert_result = |res: Result<RedisResult<Response>, _>| {
            res.map_err(|_| RedisError::from((ErrorKind::Client, "request wasn't handled due to internal failure")))
                .and_then(|res| res.map(extract_result))
        };
        let get_receiver = |(_, receiver): (_, oneshot::Receiver<RedisResult<Response>>)| async {
            convert_result(receiver.await)
        };
        match response_policy {
            // Any failure aborts; otherwise one of the results is returned.
            Some(ResponsePolicy::AllSucceeded) => {
                future::try_join_all(receivers.into_iter().map(get_receiver))
                    .await
                    .and_then(|mut results| {
                        results.pop().ok_or(
                            (
                                ErrorKind::ClusterConnectionNotFound,
                                "No results received for multi-node operation",
                            )
                                .into(),
                        )
                    })
            }
            // First success wins; errors only surface if all nodes fail.
            Some(ResponsePolicy::OneSucceeded) => future::select_ok(
                receivers
                    .into_iter()
                    .map(|tuple| Box::pin(get_receiver(tuple))),
            )
            .await
            .map(|(result, _)| result),
            // Return the first non-nil success; if every node returned nil,
            // return nil; otherwise surface the last error seen.
            Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty) => {
                let mut nil_counter = 0;
                let mut last_err = None;
                let resolved = future::join_all(receivers.into_iter().map(get_receiver)).await;
                let num_results = resolved.len();
                for val in resolved {
                    match val {
                        Ok(Value::Nil) => nil_counter += 1,
                        Ok(Value::ServerError(err)) => {
                            last_err = Some(err.into());
                        }
                        Ok(val) => return Ok(val),
                        Err(err) => {
                            last_err = Some(err);
                        }
                    }
                }
                if nil_counter == num_results {
                    Ok(Value::Nil)
                } else {
                    Err(last_err.unwrap_or_else(|| {
                        (
                            ErrorKind::ClusterConnectionNotFound,
                            "Couldn't find any connection",
                        )
                            .into()
                    }))
                }
            }
            Some(ResponsePolicy::Aggregate(op)) => {
                future::try_join_all(receivers.into_iter().map(get_receiver))
                    .await
                    .and_then(|results| crate::cluster_routing::aggregate(results, op))
            }
            Some(ResponsePolicy::AggregateLogical(op)) => {
                future::try_join_all(receivers.into_iter().map(get_receiver))
                    .await
                    .and_then(|results| crate::cluster_routing::logical_aggregate(results, op))
            }
            Some(ResponsePolicy::CombineArrays) => {
                future::try_join_all(receivers.into_iter().map(get_receiver))
                    .await
                    .and_then(|results| match routing {
                        // Multi-slot results must be re-ordered to match the
                        // original argument order.
                        MultipleNodeRoutingInfo::MultiSlot((vec, pattern)) => {
                            crate::cluster_routing::combine_and_sort_array_results(
                                results, vec, pattern,
                            )
                        }
                        _ => crate::cluster_routing::combine_array_results(results),
                    })
            }
            Some(ResponsePolicy::CombineMaps) => {
                let resolved =
                    future::try_join_all(receivers.into_iter().map(get_receiver)).await?;
                crate::cluster_routing::combine_map_results(resolved)
            }
            // No aggregation strategy: return a map from node address to that
            // node's result, converting server errors into values so one
            // failing node doesn't hide the others.
            Some(ResponsePolicy::Special) | None => {
                let results =
                    future::join_all(receivers.into_iter().map(|(addr, receiver)| async move {
                        let result =
                            convert_result(receiver.await).or_else(|err| match err.try_into() {
                                Ok(server_error) => Ok(Value::ServerError(server_error)),
                                Err(err) => Err(err),
                            })?;
                        Ok::<_, RedisError>((
                            Value::BulkString(addr.to_string().into_bytes()),
                            result,
                        ))
                    }))
                    .await
                    .into_iter()
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(Value::Map(results))
            }
        }
    }

    /// Executes a single command against the route's target node, or fans it
    /// out when the routing is multi-node.
    async fn try_cmd_request(
        &self,
        cmd: Arc<Cmd>,
        routing: InternalRoutingInfo<C>,
    ) -> OperationResult {
        let route = match routing {
            InternalRoutingInfo::SingleNode(single_node_routing) => single_node_routing,
            InternalRoutingInfo::MultiNode((multi_node_routing, response_policy)) => {
                return self
                    .execute_on_multiple_nodes(&cmd, &multi_node_routing, response_policy)
                    .await;
            }
        };
        match self.get_connection(route).await {
            Ok((addr, mut conn)) => (
                addr.into(),
                conn.req_packed_command(&cmd)
                    .await
                    .inspect(|res| {
                        // On success, record (un)subscribe commands so they
                        // can be replayed after reconnects.
                        if !matches!(res, Value::ServerError(_)) {
                            if let Some(tracker) = &self.subscription_tracker {
                                if let Some((action, args)) =
                                    SubscriptionTracker::to_request(cmd.as_ref())
                                {
                                    let mut tracker = tracker.lock().unwrap();
                                    tracker.update_with_request(action, args);
                                }
                            }
                        }
                    })
                    .map(Response::Single),
            ),
            Err(err) => (OperationTarget::NotFound, Err(err)),
        }
    }

    /// Executes a pipeline against the single node selected by `route`.
    async fn try_pipeline_request(
        &self,
        pipeline: Arc<crate::Pipeline>,
        offset: usize,
        count: usize,
        route: InternalSingleNodeRouting<C>,
    ) -> OperationResult {
        match self.get_connection(route).await {
            Ok((addr, mut conn)) => (
                OperationTarget::Node { address: addr },
                conn.req_packed_commands(&pipeline, offset, count)
                    .await
                    .inspect(|res| {
                        // Record successful (un)subscribe commands so they
                        // can be replayed after reconnects.
                        let Some(tracker) = &self.subscription_tracker else {
                            return;
                        };
                        // For a transaction the single result holds the inner
                        // replies; match commands against that array when
                        // it is present.
                        let res = if pipeline.is_transaction() {
                            debug_assert_eq!(res.len(), 1);
                            match res[0] {
                                Value::Array(ref arr) => arr,
                                _ => res,
                            }
                        } else {
                            res
                        };
                        let mut iterator = pipeline
                            .cmd_iter()
                            .enumerate()
                            .flat_map(|(index, cmd)| {
                                if matches!(res[index], Value::ServerError(_)) {
                                    None
                                } else {
                                    SubscriptionTracker::to_request(cmd)
                                }
                            })
                            .peekable();
                        // Only take the lock when there is something to record.
                        if iterator.peek().is_some() {
                            let mut tracker = tracker.lock().unwrap();
                            for (action, args) in iterator {
                                tracker.update_with_request(action, args);
                            }
                        }
                    })
                    .map(Response::Multiple),
            ),
            Err(err) => (OperationTarget::NotFound, Err(err)),
        }
    }

    /// Dispatches a queued command or pipeline request.
    async fn try_request(self, cmd: CmdArg<C>) -> OperationResult {
        match cmd {
            CmdArg::Cmd { cmd, routing } => self.try_cmd_request(cmd, routing).await,
            CmdArg::Pipeline {
                pipeline,
                offset,
                count,
                route,
            } => {
                self.try_pipeline_request(pipeline, offset, count, route)
                    .await
            }
        }
    }

    /// Resolves `route` to a concrete `(address, connection)` pair,
    /// connecting on demand and falling back to a random node when the
    /// route cannot be resolved.
    async fn get_connection(
        &self,
        route: InternalSingleNodeRouting<C>,
    ) -> RedisResult<(NodeAddress, C)> {
        let read_guard = self.conn_lock.read().await;
        // First resolve the route to an address; routes that carry their own
        // connection or need redirect handling return early.
        let conn = match route {
            InternalSingleNodeRouting::Random => None,
            InternalSingleNodeRouting::SpecificNode(route) => read_guard
                .1
                .slot_addr_for_route(&route, self.routing_strategy.as_deref())
                .cloned(),
            InternalSingleNodeRouting::Connection { identifier, conn } => {
                return Ok((identifier, conn));
            }
            InternalSingleNodeRouting::Redirect { redirect, .. } => {
                // Drop the lock before redirect handling, which may connect.
                drop(read_guard);
                return self.get_redirected_connection(redirect).await;
            }
            InternalSingleNodeRouting::ByAddress(address) => {
                if let Some(conn) = read_guard.0.get(&address).cloned() {
                    return Ok((address, conn));
                } else {
                    return Err((
                        ErrorKind::Client,
                        "Requested connection not found",
                        address.to_string(),
                    )
                        .into());
                }
            }
        }
        .map(|addr| {
            let conn = read_guard.0.get(&addr).cloned();
            (addr, conn)
        });
        drop(read_guard);
        // Connect on demand when the resolved address has no connection yet.
        let addr_conn_option = match conn {
            Some((addr, Some(conn))) => Some((addr, conn)),
            Some((addr, None)) => self
                .connect_check_and_add(&addr)
                .await
                .ok()
                .map(|conn| (addr, conn)),
            None => None,
        };
        let (addr, conn) = match addr_conn_option {
            Some(tuple) => tuple,
            None => {
                // No usable address: fall back to any connected node.
                let read_guard = self.conn_lock.read().await;
                if let Some((random_addr, random_conn)) = get_random_connection(&read_guard.0) {
                    drop(read_guard);
                    (random_addr, random_conn)
                } else {
                    return Err(
                        (ErrorKind::ClusterConnectionNotFound, "No connections found").into(),
                    );
                }
            }
        };
        Ok((addr, conn))
    }

    /// Resolves a MOVED/ASK redirect to a connection, issuing `ASKING`
    /// first when the redirect was an ASK.
    async fn get_redirected_connection(&self, redirect: Redirect) -> RedisResult<(NodeAddress, C)> {
        let asking = matches!(redirect, Redirect::Ask(_));
        let addr = match redirect {
            Redirect::Moved(addr) => addr,
            Redirect::Ask(addr) => addr,
        };
        let read_guard = self.conn_lock.read().await;
        let conn = read_guard.0.get(&addr).cloned();
        drop(read_guard);
        let mut conn = match conn {
            Some(conn) => conn,
            None => self.connect_check_and_add(&addr).await?,
        };
        if asking {
            let mut asking_cmd = crate::cmd::cmd("ASKING");
            asking_cmd.skip_concurrency_limit = true;
            // Best effort: a failed ASKING will surface on the actual command.
            let _ = conn
                .req_packed_command(&asking_cmd)
                .await
                .and_then(|value| value.extract_error());
        }
        Ok((addr, conn))
    }

    /// Connects to `addr` and, on success, stores the new connection in the
    /// connection map.
    async fn connect_check_and_add(&self, addr: &NodeAddress) -> RedisResult<C> {
        match connect_and_check::<C>(addr, &self.cluster_params).await {
            Ok(conn) => {
                self.conn_lock
                    .write()
                    .await
                    .0
                    .insert(addr.clone(), conn.clone());
                Ok(conn)
            }
            Err(err) => Err(err),
        }
    }

    /// Rebuilds the slot map by querying known nodes until one returns a
    /// parsable answer, then refreshes connections to all nodes named by the
    /// new topology.
    async fn refresh_slots(self) -> RedisResult<()> {
        let mut write_guard = self.conn_lock.write().await;
        let (connections, slots) = &mut *write_guard;
        let mut result = Ok(());
        for (addr, conn) in &mut *connections {
            result = async {
                let mut slot_refresh_cmd = slot_cmd();
                slot_refresh_cmd.skip_concurrency_limit = true;
                let value = conn
                    .req_packed_command(&slot_refresh_cmd)
                    .await
                    .and_then(|value| value.extract_error())?;
                let v: Vec<Slot> = parse_slots(value, addr.host())?;
                build_slot_map(slots, v)
            }
            .await;
            if result.is_ok() {
                break;
            }
        }
        // If every node failed, propagate the last error.
        result?;
        if let Some(ref strategy) = self.routing_strategy {
            strategy.on_topology_changed(slots.topology());
        }
        let nodes = slots.values().flatten().cloned().collect::<HashSet<_>>();
        self.refresh_connections_locked(connections, nodes).await;
        Ok(())
    }

    /// Concurrently (re)establishes the connections for `nodes` while the
    /// caller holds the connection-map write lock: live connections are
    /// reused, dead or missing ones are reopened, and nodes that fail to
    /// connect are left out of the map.
    async fn refresh_connections_locked(
        &self,
        connections: &mut ConnectionMap<C>,
        nodes: HashSet<NodeAddress>,
    ) {
        let nodes_len = nodes.len();
        let addresses_and_connections_iter = nodes
            .into_iter()
            .map(|addr| {
                let connection = connections.remove(&addr);
                async move {
                    let res = get_or_create_conn(&addr, connection, &self.cluster_params).await;
                    (addr, res)
                }
            })
            .collect::<Vec<_>>();
        stream::iter(addresses_and_connections_iter)
            .buffer_unordered(nodes_len.max(8))
            .fold(connections, |connections, (addr, result)| async move {
                if let Ok(conn) = result {
                    connections.insert(addr, conn);
                }
                connections
            })
            .await;
    }

    /// Re-sends all tracked subscription commands, routing each like a fresh
    /// request; results are handled internally (no external caller waits).
    fn resubscribe(&self) {
        let Some(subscription_tracker) = self.subscription_tracker.as_ref() else {
            return;
        };
        let subscription_pipe = subscription_tracker
            .lock()
            .unwrap()
            .get_subscription_pipeline();
        let requests = subscription_pipe.into_cmd_iter().map(|cmd| {
            let routing = RoutingInfo::for_routable(&cmd)
                .unwrap_or(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
                .into();
            PendingRequest {
                retry: 0,
                sender: request::ResultExpectation::Internal,
                cmd: CmdArg::Cmd {
                    cmd: Arc::new(cmd),
                    routing,
                },
            }
        });
        for request in requests {
            let _ = self.pending_requests_tx.send(request);
        }
    }
}
/// The driver-side state machine that owns the node connections and executes
/// queued requests; driven as a [`Sink`] by the task spawned in
/// `ClusterConnection::new_inner`.
struct ClusterConnInner<C> {
    inner: Core<C>,
    /// Whether the driver is serving requests or recovering.
    state: ConnectionState,
    /// Requests currently being executed.
    #[allow(clippy::complexity)]
    in_flight_requests: stream::FuturesUnordered<Pin<Box<Request<C>>>>,
    /// Incoming queue of requests not yet started.
    pending_requests_rx: mpsc::UnboundedReceiver<PendingRequest<C>>,
    /// Error from a failed recovery, to be delivered to queued requests.
    refresh_error: Option<RedisError>,
}
/// Returns a boxed sleep future for the ambient runtime.
fn boxed_sleep(duration: Duration) -> BoxFuture<'static, ()> {
    let sleep = Runtime::locate_and_sleep(duration);
    Box::pin(sleep)
}
/// A completed request's payload, as reported back by the driver task.
#[derive(Debug, PartialEq)]
pub(crate) enum Response {
    /// Result of a single command.
    Single(Value),
    /// Results of a pipeline.
    Multiple(Vec<Value>),
}

/// Where an operation was executed (or attempted); used by the retry logic
/// to decide how to recover from failures.
enum OperationTarget {
    /// A specific node.
    Node { address: NodeAddress },
    /// No connection could be resolved for the request's route.
    NotFound,
    /// The request was fanned out to multiple nodes.
    FanOut,
}

/// The target an operation ran against, paired with its outcome.
type OperationResult = (OperationTarget, Result<Response, RedisError>);
impl From<NodeAddress> for OperationTarget {
fn from(address: NodeAddress) -> Self {
OperationTarget::Node { address }
}
}
/// A request sent from a [`ClusterConnection`] handle to the driver task.
struct Message<C> {
    cmd: CmdArg<C>,
    /// Channel on which the driver reports the request's result.
    sender: oneshot::Sender<RedisResult<Response>>,
}

/// In-progress recovery work the driver must finish before serving requests.
enum RecoverFuture {
    /// Re-fetching and rebuilding the slot map.
    RecoverSlots(BoxFuture<'static, RedisResult<()>>),
    /// Re-establishing node connections.
    Reconnect(BoxFuture<'static, RedisResult<()>>),
}

/// Driver state: either serving requests or recovering.
enum ConnectionState {
    PollComplete,
    Recover(RecoverFuture),
}
impl fmt::Debug for ConnectionState {
    /// Prints only the variant name; the recovery future itself is opaque.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ConnectionState::PollComplete => "PollComplete",
            ConnectionState::Recover(_) => "Recover",
        };
        f.write_str(label)
    }
}
/// Replaces the contents of `slot_map` with the freshly parsed `slots_data`.
fn build_slot_map(slot_map: &mut SlotMap, slots_data: Vec<Slot>) -> RedisResult<()> {
    // Clear first so stale entries from the previous topology cannot linger.
    slot_map.clear();
    slot_map.fill_slots(slots_data);
    trace!("{slot_map:?}");
    Ok(())
}
impl<C> ClusterConnInner<C>
where
    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
{
    /// Builds the driver state machine; it starts in the `Recover` state so
    /// the first poll connects to the initial nodes.
    fn new(initial_nodes: &[ConnectionInfo], cluster_params: ClusterParams) -> Self {
        // Subscriptions are only tracked when a push sender is configured,
        // i.e. when the user can actually receive push messages.
        let subscription_tracker = if cluster_params.async_push_sender.is_some() {
            Some(Mutex::new(SubscriptionTracker::default()))
        } else {
            None
        };
        let routing_strategy = cluster_params
            .read_routing_factory
            .as_ref()
            .map(|f| f.create_strategy());
        let (pending_requests_tx, pending_requests_rx) = mpsc::unbounded_channel();
        let inner = Arc::new(InnerCore {
            conn_lock: RwLock::new((Default::default(), SlotMap::new())),
            cluster_params,
            pending_requests_tx,
            initial_nodes: initial_nodes.to_vec(),
            subscription_tracker,
            routing_strategy,
        });
        let core = Core(inner);
        let mut inner = ClusterConnInner {
            inner: core.clone(),
            in_flight_requests: Default::default(),
            pending_requests_rx,
            refresh_error: None,
            state: ConnectionState::PollComplete,
        };
        // Assigned afterwards because `reconnect_to_initial_nodes` needs
        // `&mut inner`.
        inner.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
            inner.reconnect_to_initial_nodes(),
        )));
        inner
    }

    /// Connects to all `initial_nodes` concurrently; succeeds if at least
    /// one connection could be established.
    async fn create_initial_connections(
        initial_nodes: &[ConnectionInfo],
        params: &ClusterParams,
    ) -> RedisResult<ConnectionMap<C>> {
        let (connections, error) = stream::iter(initial_nodes.iter().cloned())
            .map(async move |info| {
                let addr = NodeAddress::try_from(&info.addr)?;
                let result = connect_and_check(&addr, params).await;
                match result {
                    Ok(conn) => Ok((addr, conn)),
                    Err(e) => {
                        debug!("Failed to connect to initial node: {e:?}");
                        Err(e)
                    }
                }
            })
            .buffer_unordered(initial_nodes.len())
            .fold(
                (ConnectionMap::<C>::with_capacity(initial_nodes.len()), None),
                |(mut connections, mut error), result| async move {
                    match result {
                        Ok((addr, conn)) => {
                            connections.insert(addr, conn);
                        }
                        Err(err) => {
                            // Keep the last error for reporting in case
                            // nothing connects at all.
                            error = Some(err);
                        }
                    }
                    (connections, error)
                },
            )
            .await;
        if connections.is_empty() {
            if let Some(err) = error {
                return Err(RedisError::from((
                    ErrorKind::Io,
                    "Failed to create initial connections",
                    err.to_string(),
                )));
            } else {
                return Err(RedisError::from((
                    ErrorKind::Io,
                    "Failed to create initial connections",
                )));
            }
        }
        Ok(connections)
    }

    /// Replaces all current connections with fresh connections to the
    /// configured initial nodes and rebuilds the slot map.
    fn reconnect_to_initial_nodes(&mut self) -> impl Future<Output = RedisResult<()>> + use<C> {
        debug!("Received request to reconnect to initial nodes");
        let inner = self.inner.clone();
        async move {
            let connection_map =
                Self::create_initial_connections(&inner.initial_nodes, &inner.cluster_params)
                    .await?;
            *inner.conn_lock.write().await = (connection_map, SlotMap::new());
            inner.refresh_slots().await?;
            Ok(())
        }
    }

    /// Re-establishes the connections for the given addresses.
    fn refresh_connections(
        &mut self,
        addrs: HashSet<NodeAddress>,
    ) -> impl Future<Output = ()> + use<C> {
        let inner = self.inner.clone();
        async move {
            let mut write_guard = inner.conn_lock.write().await;
            inner
                .refresh_connections_locked(&mut write_guard.0, addrs)
                .await;
        }
    }

    /// Runs the pending recovery future (if any) to completion; used to
    /// await the initial connection before the sink starts processing.
    async fn wait_for_initial_connection(&mut self) -> RedisResult<()> {
        if let ConnectionState::Recover(fut) =
            std::mem::replace(&mut self.state, ConnectionState::PollComplete)
        {
            match fut {
                RecoverFuture::RecoverSlots(fut) => fut.await?,
                RecoverFuture::Reconnect(fut) => fut.await?,
            }
        }
        Ok(())
    }

    /// Polls any pending recovery work; `Ready(Ok)` means requests may flow.
    fn poll_recover(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), RedisError>> {
        let recover_future = match &mut self.state {
            ConnectionState::PollComplete => return Poll::Ready(Ok(())),
            ConnectionState::Recover(future) => future,
        };
        let res = match recover_future {
            RecoverFuture::RecoverSlots(future) => match ready!(future.as_mut().poll(cx)) {
                Ok(_) => {
                    trace!("Recovered!");
                    self.state = ConnectionState::PollComplete;
                    Ok(())
                }
                Err(err) => {
                    trace!("Recover slots failed!");
                    // Stay in the recover state and retry the slot refresh on
                    // the next poll, while reporting this failure.
                    *future = Box::pin(self.inner.clone().refresh_slots());
                    Err(err)
                }
            },
            RecoverFuture::Reconnect(future) => {
                // Reconnect failures are logged but not propagated; the sink
                // resumes and individual requests will surface errors.
                match ready!(future.as_mut().poll(cx)) {
                    Err(err) => warn!("Can't reconnect to initial nodes: `{err}`"),
                    Ok(()) => trace!("Reconnected connections"),
                }
                self.state = ConnectionState::PollComplete;
                Ok(())
            }
        };
        if res.is_ok() {
            // Replay tracked subscriptions on the refreshed connections.
            self.inner.resubscribe();
        }
        Poll::Ready(res)
    }

    /// Requeues or reschedules a request according to its retry decision.
    fn handle_retries(&mut self, request_handling: Option<Retry<C>>) {
        match request_handling {
            Some(Retry::MoveToPending { request }) => {
                let _ = self.inner.pending_requests_tx.send(request);
            }
            Some(Retry::Immediately { request }) => {
                let future = self.inner.clone().try_request(request.cmd.clone());
                self.in_flight_requests.push(Box::pin(Request {
                    retry_params: self.inner.cluster_params.retry_params.clone(),
                    request: Some(request),
                    future: RequestState::Future {
                        future: Box::pin(future),
                    },
                }));
            }
            Some(Retry::AfterSleep {
                request,
                sleep_duration,
            }) => {
                let future = RequestState::Sleep {
                    sleep: boxed_sleep(sleep_duration),
                };
                self.in_flight_requests.push(Box::pin(Request {
                    retry_params: self.inner.cluster_params.retry_params.clone(),
                    request: Some(request),
                    future,
                }));
            }
            None => {}
        }
    }

    /// Starts queued requests, drives in-flight ones, and reports whether a
    /// recovery action (slot rebuild / reconnect) is needed.
    fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll<PollFlushAction> {
        let mut poll_flush_action = PollFlushAction::None;
        // Move all newly queued requests into the in-flight set.
        loop {
            let request = match self.pending_requests_rx.try_recv() {
                Ok(request) => request,
                Err(tokio::sync::mpsc::error::TryRecvError::Empty)
                | Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
            };
            // Skip requests whose caller has already given up waiting.
            if request.sender.is_closed() {
                continue;
            }
            let future = self.inner.clone().try_request(request.cmd.clone()).boxed();
            self.in_flight_requests.push(Box::pin(Request {
                retry_params: self.inner.cluster_params.retry_params.clone(),
                request: Some(request),
                future: RequestState::Future { future },
            }));
        }
        // Drain completed requests, merging the recovery actions they request.
        loop {
            let (request_handling, next) =
                match Pin::new(&mut self.in_flight_requests).poll_next(cx) {
                    Poll::Ready(Some(result)) => result,
                    Poll::Ready(None) | Poll::Pending => break,
                };
            self.handle_retries(request_handling);
            poll_flush_action = poll_flush_action.change_state(next);
        }
        if !matches!(poll_flush_action, PollFlushAction::None) || self.in_flight_requests.is_empty()
        {
            Poll::Ready(poll_flush_action)
        } else {
            Poll::Pending
        }
    }

    /// Delivers a stored recovery error to the in-flight requests (or, when
    /// there are none, to one queued request) via their retry logic.
    fn send_refresh_error(&mut self) {
        let Some(refresh_error) = self.refresh_error.take() else {
            return;
        };
        let mut inflight_requests = Vec::new();
        for mut request in Pin::new(&mut self.in_flight_requests).iter_pin_mut() {
            let mut request = request.as_mut();
            let Some(pending_request) = request.request.take() else {
                continue;
            };
            inflight_requests.push((pending_request, std::mem::take(&mut request.retry_params)));
        }
        if inflight_requests.is_empty() {
            let maybe_request = self.pending_requests_rx.try_recv().ok();
            if let Some(request) = maybe_request {
                inflight_requests.push((request, self.inner.cluster_params.retry_params.clone()));
            }
        }
        for (pending_request, retry_params) in inflight_requests {
            self.handle_retries(
                request::choose_response(
                    (OperationTarget::NotFound, Err(refresh_error.clone())),
                    pending_request,
                    &retry_params,
                )
                .0,
            );
        }
    }
}
/// Reuses `conn_option` when it still passes a health check; otherwise opens
/// a fresh, checked connection to `addr`.
async fn get_or_create_conn<C>(
    addr: &NodeAddress,
    conn_option: Option<C>,
    params: &ClusterParams,
) -> RedisResult<C>
where
    C: Connect + ConnectionLike + Clone + Send + Sync + 'static,
{
    // Early return on a healthy existing connection; everything else
    // falls through to a reconnect.
    if let Some(mut conn) = conn_option {
        if check_connection(&mut conn).await.is_ok() {
            return Ok(conn);
        }
    }
    connect_and_check(addr, params).await
}
/// Recovery action requested after a flush, ordered by severity (see
/// `change_state`).
#[derive(Debug, PartialEq)]
enum PollFlushAction {
    /// Nothing to do.
    None,
    /// The slot map must be refreshed.
    RebuildSlots,
    /// The listed nodes must be reconnected.
    Reconnect(HashSet<NodeAddress>),
    /// Full reset: reconnect starting from the initial nodes.
    ReconnectFromInitialConnections,
}
impl PollFlushAction {
fn change_state(self, next_state: PollFlushAction) -> PollFlushAction {
match (self, next_state) {
(PollFlushAction::None, next_state) => next_state,
(next_state, PollFlushAction::None) => next_state,
(PollFlushAction::ReconnectFromInitialConnections, _)
| (_, PollFlushAction::ReconnectFromInitialConnections) => {
PollFlushAction::ReconnectFromInitialConnections
}
(PollFlushAction::RebuildSlots, _) | (_, PollFlushAction::RebuildSlots) => {
PollFlushAction::RebuildSlots
}
(PollFlushAction::Reconnect(mut addrs), PollFlushAction::Reconnect(new_addrs)) => {
addrs.extend(new_addrs);
Self::Reconnect(addrs)
}
}
}
}
impl<C> Sink<Message<C>> for ClusterConnInner<C>
where
    C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
{
    type Error = ();

    // The pending-request channel is unbounded, so the sink is always ready.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut task::Context) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Enqueues the message as a pending request; actual dispatch happens
    /// in `poll_flush`.
    fn start_send(self: Pin<&mut Self>, msg: Message<C>) -> Result<(), Self::Error> {
        trace!("start_send");
        let Message { cmd, sender } = msg;
        let _ = self.inner.pending_requests_tx.send(PendingRequest {
            retry: 0,
            sender: request::ResultExpectation::External(sender),
            cmd,
        });
        Ok(())
    }

    /// Alternates between recovery and request execution until all requests
    /// complete or further progress needs an external wakeup.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        trace!("poll_flush: {:?}", self.state);
        loop {
            // Surface any error from a failed recovery to waiting requests.
            self.send_refresh_error();
            if let Err(err) = ready!(self.as_mut().poll_recover(cx)) {
                // Store the error for the next iteration and request an
                // immediate re-poll so it is delivered promptly.
                self.refresh_error = Some(err);
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            // A non-`None` action switches back into the recover state and
            // loops to poll the new recovery future.
            match ready!(self.poll_complete(cx)) {
                PollFlushAction::None => return Poll::Ready(Ok(())),
                PollFlushAction::RebuildSlots => {
                    self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin(
                        self.inner.clone().refresh_slots(),
                    )));
                }
                PollFlushAction::Reconnect(addrs) => {
                    self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
                        self.refresh_connections(addrs).map(Ok),
                    )));
                }
                PollFlushAction::ReconnectFromInitialConnections => {
                    self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
                        self.reconnect_to_initial_nodes(),
                    )));
                }
            }
        }
    }

    /// Completes once in-flight requests have drained; errors if they ask
    /// for further recovery work.
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match self.poll_complete(cx) {
            Poll::Ready(PollFlushAction::None) => (),
            Poll::Ready(_) => Err(())?,
            Poll::Pending => (),
        };
        if self.in_flight_requests.is_empty() {
            return Poll::Ready(Ok(()));
        }
        self.poll_flush(cx)
    }
}
impl<C> ConnectionLike for ClusterConnection<C>
where
    C: ConnectionLike + Send + Clone + Unpin + Sync + Connect + 'static,
{
    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
        // Fall back to a random node when no route can be derived for the
        // command.
        let routing = match RoutingInfo::for_routable(cmd) {
            Some(routing) => routing,
            None => RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random),
        };
        Box::pin(self.route_command(cmd.clone(), routing))
    }

    fn req_packed_commands<'a>(
        &'a mut self,
        pipeline: &'a crate::Pipeline,
        offset: usize,
        count: usize,
    ) -> RedisFuture<'a, Vec<Value>> {
        Box::pin(async move {
            // Pipelines must resolve to a single-node route.
            let route = route_for_pipeline(pipeline)?;
            self.route_pipeline(pipeline.clone(), offset, count, route.into())
                .await
        })
    }

    fn get_db(&self) -> i64 {
        // Always reports database 0.
        0
    }
}
/// Implemented by connection types that can be opened from connection info
/// plus an [`AsyncConnectionConfig`]; this is the factory used by the
/// cluster logic to create per-node connections.
pub trait Connect: Sized {
    /// Connects to `info` using the given configuration.
    fn connect_with_config<'a, T>(info: T, config: AsyncConnectionConfig) -> RedisFuture<'a, Self>
    where
        T: IntoConnectionInfo + Send + 'a;
}
impl Connect for MultiplexedConnection {
    /// Opens a multiplexed connection through a temporary [`crate::Client`].
    fn connect_with_config<'a, T>(info: T, config: AsyncConnectionConfig) -> RedisFuture<'a, Self>
    where
        T: IntoConnectionInfo + Send + 'a,
    {
        Box::pin(async move {
            let connection_info = info.into_connection_info()?;
            let client = crate::Client::open(connection_info)?;
            client
                .get_multiplexed_async_connection_with_config(&config)
                .await
        })
    }
}
/// Opens a connection to `node` configured from `params` and puts it into
/// READONLY mode before returning it.
async fn connect_and_check<C>(node: &NodeAddress, params: &ClusterParams) -> RedisResult<C>
where
    C: ConnectionLike + Connect + Send + 'static,
{
    let info = get_connection_info(node, params);
    // Translate the cluster-level parameters into a per-connection config.
    let mut config =
        AsyncConnectionConfig::default().set_connection_timeout(Some(params.connection_timeout));
    config = config.set_response_timeout(params.response_timeout);
    if let Some(push_sender) = &params.async_push_sender {
        config = config.set_push_sender_internal(push_sender.clone());
    }
    if let Some(resolver) = &params.async_dns_resolver {
        config = config.set_dns_resolver_internal(resolver.clone());
    }
    #[cfg(feature = "cache-aio")]
    if let Some(cache_manager) = &params.cache_manager {
        // Each new connection gets a fresh cache epoch (per
        // `clone_and_increase_epoch`).
        config = config.set_cache_manager(cache_manager.clone_and_increase_epoch());
    }
    #[cfg(feature = "token-based-authentication")]
    if let Some(credentials_provider) = &params.credentials_provider {
        config = config.set_credentials_provider_internal(credentials_provider.clone());
    }
    if let Some(limit) = params.connection_concurrency_limit {
        config = config.set_concurrency_limit(limit);
    }
    let mut conn = match C::connect_with_config(info, config).await {
        Ok(conn) => conn,
        Err(err) => {
            warn!("Failed to connect to node: {node:?}, due to: {err:?}");
            return Err(err);
        }
    };
    let mut readonly_cmd = cmd("READONLY");
    readonly_cmd.skip_concurrency_limit = true;
    // NOTE(review): unlike `check_connection`, the READONLY reply is not run
    // through `extract_error()`, so an inline server-error value would pass
    // silently — confirm whether that is intentional.
    conn.req_packed_command(&readonly_cmd).await?;
    Ok(conn)
}
/// Health-checks `conn` by issuing a PING and surfacing any server error
/// embedded in the reply.
async fn check_connection<C>(conn: &mut C) -> RedisResult<()>
where
    C: ConnectionLike + Send + 'static,
{
    let mut ping = cmd("PING");
    // Health checks bypass the connection's concurrency limit.
    ping.skip_concurrency_limit = true;
    let reply = conn.req_packed_command(&ping).await?;
    reply.extract_error()?;
    Ok(())
}
/// Picks one uniformly random `(address, connection)` pair from the map, or
/// `None` when the map is empty.
fn get_random_connection<C>(connections: &ConnectionMap<C>) -> Option<(NodeAddress, C)>
where
    C: Clone,
{
    let picked = connections.iter().choose(&mut rng());
    picked.map(|(addr, conn)| (addr.clone(), conn.clone()))
}