use std::{
collections::HashMap,
fmt, io,
iter::Iterator,
marker::Unpin,
mem,
pin::Pin,
sync::{Arc, Mutex},
task::{self, Poll},
};
use crate::{
aio::{ConnectionLike, MultiplexedConnection},
cluster::{get_connection_info, parse_slots, slot_cmd},
cluster_client::{ClusterParams, RetryParams},
cluster_routing::{
MultipleNodeRoutingInfo, Redirect, ResponsePolicy, Route, RoutingInfo,
SingleNodeRoutingInfo, Slot, SlotAddr, SlotMap,
},
Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError, RedisFuture, RedisResult,
Value,
};
#[cfg(feature = "tls-rustls")]
use crate::tls::TlsConnParams;
#[cfg(not(feature = "tls-rustls"))]
use crate::connection::TlsConnParams;
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
use crate::aio::{async_std::AsyncStd, RedisRuntime};
use futures::{
future::{self, BoxFuture},
prelude::*,
ready, stream,
};
use log::trace;
use pin_project_lite::pin_project;
use rand::{seq::IteratorRandom, thread_rng};
use tokio::sync::{mpsc, oneshot, RwLock};
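/// Total number of hash slots in a Redis Cluster keyspace.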
const SLOT_SIZE: usize = 16384;
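/// An async cluster connection handle. Cloning is cheap: every clone sends
/// requests over the same channel to a single background driver task.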
#[derive(Clone)]
pub struct ClusterConnection<C = MultiplexedConnection>(mpsc::Sender<Message<C>>);
impl<C> ClusterConnection<C>
where
C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
{
pub(crate) async fn new(
initial_nodes: &[ConnectionInfo],
cluster_params: ClusterParams,
tls_params: Option<TlsConnParams>,
) -> RedisResult<ClusterConnection<C>> {
ClusterConnInner::new(initial_nodes, cluster_params, tls_params)
.await
.map(|inner| {
let (tx, mut rx) = mpsc::channel::<Message<_>>(100);
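                // Forward queued messages into the `Sink` implementation
                // below, which drives all in-flight requests.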
let stream = async move {
let _ = stream::poll_fn(move |cx| rx.poll_recv(cx))
.map(Ok)
.forward(inner)
.await;
};
#[cfg(feature = "tokio-comp")]
tokio::spawn(stream);
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
AsyncStd::spawn(stream);
ClusterConnection(tx)
})
}
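    /// Sends `cmd` to the cluster, routed according to `routing`, and returns
    /// the reply.
    ///
    /// A minimal usage sketch (assumes the `cluster-async` feature and the
    /// crate's public `cluster_async`/`cluster_routing` re-exports; not a
    /// complete program):
    ///
    /// ```rust,no_run
    /// # async fn example(
    /// #     conn: &mut redis::cluster_async::ClusterConnection,
    /// # ) -> redis::RedisResult<()> {
    /// use redis::cluster_routing::{RoutingInfo, SingleNodeRoutingInfo};
    ///
    /// // PING carries no key, so send it to a random node.
    /// let routing = RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random);
    /// let value = conn.route_command(&redis::cmd("PING"), routing).await?;
    /// println!("{value:?}");
    /// # Ok(())
    /// # }
    /// ```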
pub async fn route_command(&mut self, cmd: &Cmd, routing: RoutingInfo) -> RedisResult<Value> {
trace!("send_packed_command");
let (sender, receiver) = oneshot::channel();
self.0
.send(Message {
cmd: CmdArg::Cmd {
                    cmd: Arc::new(cmd.clone()),
                    routing: CommandRouting::Route(routing),
},
sender,
})
.await
.map_err(|_| {
RedisError::from(io::Error::new(
io::ErrorKind::BrokenPipe,
"redis_cluster: Unable to send command",
))
})?;
receiver
.await
.unwrap_or_else(|_| {
Err(RedisError::from(io::Error::new(
io::ErrorKind::BrokenPipe,
"redis_cluster: Unable to receive command",
)))
})
.map(|response| match response {
Response::Single(value) => value,
Response::Multiple(_) => unreachable!(),
})
}
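    /// Sends `pipeline` to the single node selected by `route`, skipping
    /// `offset` initial replies and collecting `count` replies.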
pub async fn route_pipeline<'a>(
&'a mut self,
pipeline: &'a crate::Pipeline,
offset: usize,
count: usize,
route: SingleNodeRoutingInfo,
) -> RedisResult<Vec<Value>> {
let (sender, receiver) = oneshot::channel();
self.0
.send(Message {
cmd: CmdArg::Pipeline {
                    pipeline: Arc::new(pipeline.clone()),
                    offset,
count,
route,
},
sender,
})
.await
.map_err(|_| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))?;
receiver
.await
.unwrap_or_else(|_| Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))))
.map(|response| match response {
Response::Multiple(values) => values,
Response::Single(_) => unreachable!(),
})
}
}
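/// A cloneable, shareable future resolving to a connection; multiple requests
/// can await the same instance. `ConnectionMap` keys these by node address.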
type ConnectionFuture<C> = future::Shared<BoxFuture<'static, C>>;
type ConnectionMap<C> = HashMap<String, ConnectionFuture<C>>;
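/// State shared between the connection handle and the driver task: the
/// connection map and slot map (kept under one lock so they stay consistent),
/// the cluster parameters, and requests queued until the driver picks them up.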
struct InnerCore<C> {
conn_lock: RwLock<(ConnectionMap<C>, SlotMap)>,
cluster_params: ClusterParams,
pending_requests: Mutex<Vec<PendingRequest<Response, C>>>,
tls_params: Option<TlsConnParams>,
}
type Core<C> = Arc<InnerCore<C>>;
struct ClusterConnInner<C> {
inner: Core<C>,
state: ConnectionState,
#[allow(clippy::complexity)]
in_flight_requests: stream::FuturesUnordered<
Pin<
Box<Request<BoxFuture<'static, (OperationTarget, RedisResult<Response>)>, Response, C>>,
>,
>,
refresh_error: Option<RedisError>,
}
#[derive(Clone)]
enum CommandRouting<C> {
Route(RoutingInfo),
Connection {
addr: String,
conn: ConnectionFuture<C>,
},
}
#[derive(Clone)]
enum CmdArg<C> {
Cmd {
cmd: Arc<Cmd>,
routing: CommandRouting<C>,
},
Pipeline {
pipeline: Arc<crate::Pipeline>,
offset: usize,
count: usize,
route: SingleNodeRoutingInfo,
},
}
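/// Derives one route for an entire pipeline: `None` when no command constrains
/// the route, otherwise the slot shared by all routed commands. Commands that
/// map to different slots produce a `CrossSlot` error, and a primary route
/// wins over a replica route so that reads mixed with writes on the same key
/// go to the primary.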
fn route_for_pipeline(pipeline: &crate::Pipeline) -> RedisResult<Option<Route>> {
fn route_for_command(cmd: &Cmd) -> Option<Route> {
match RoutingInfo::for_routable(cmd) {
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None,
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => {
Some(route)
}
Some(RoutingInfo::MultiNode(_)) => None,
None => None,
}
}
pipeline.cmd_iter().map(route_for_command).try_fold(
None,
|chosen_route, next_cmd_route| match (chosen_route, next_cmd_route) {
(None, _) => Ok(next_cmd_route),
(_, None) => Ok(chosen_route),
(Some(chosen_route), Some(next_cmd_route)) => {
if chosen_route.slot() != next_cmd_route.slot() {
Err((ErrorKind::CrossSlot, "Received crossed slots in pipeline").into())
} else if chosen_route.slot_addr() == &SlotAddr::Replica {
Ok(Some(next_cmd_route))
} else {
Ok(Some(chosen_route))
}
}
},
)
}
enum Response {
Single(Value),
Multiple(Vec<Value>),
}
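/// Where a failed request was headed: a concrete node (so it can be retried or
/// the node reconnected) or a fan-out, whose error is reported back as-is.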
enum OperationTarget {
Node { address: String },
FanOut,
}
impl From<String> for OperationTarget {
fn from(address: String) -> Self {
OperationTarget::Node { address }
}
}
struct Message<C> {
cmd: CmdArg<C>,
sender: oneshot::Sender<RedisResult<Response>>,
}
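/// In-progress recovery work: a full slot-map refresh or reconnection to
/// specific nodes.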
enum RecoverFuture {
RecoverSlots(BoxFuture<'static, RedisResult<()>>),
Reconnect(BoxFuture<'static, ()>),
}
enum ConnectionState {
PollComplete,
Recover(RecoverFuture),
}
impl fmt::Debug for ConnectionState {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}",
match self {
ConnectionState::PollComplete => "PollComplete",
ConnectionState::Recover(_) => "Recover",
}
)
}
}
#[derive(Clone)]
struct RequestInfo<C> {
cmd: CmdArg<C>,
redirect: Option<Redirect>,
}
pin_project! {
#[project = RequestStateProj]
enum RequestState<F> {
None,
Future {
#[pin]
future: F,
},
Sleep {
#[pin]
sleep: BoxFuture<'static, ()>,
},
}
}
struct PendingRequest<I, C> {
retry: u32,
sender: oneshot::Sender<RedisResult<I>>,
info: RequestInfo<C>,
}
pin_project! {
struct Request<F, I, C> {
retry_params: RetryParams,
request: Option<PendingRequest<I, C>>,
#[pin]
future: RequestState<F>,
}
}
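/// What the driver should do once a request's future has resolved.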
#[must_use]
enum Next<I, C> {
Retry {
request: PendingRequest<I, C>,
},
Reconnect {
request: PendingRequest<I, C>,
target: String,
},
RefreshSlots {
request: PendingRequest<I, C>,
},
Done,
}
impl<F, I, C> Future for Request<F, I, C>
where
F: Future<Output = (OperationTarget, RedisResult<I>)>,
{
type Output = Next<I, C>;
fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
let mut this = self.as_mut().project();
if this.request.is_none() {
return Poll::Ready(Next::Done);
}
let future = match this.future.as_mut().project() {
RequestStateProj::Future { future } => future,
RequestStateProj::Sleep { sleep } => {
ready!(sleep.poll(cx));
return Next::Retry {
request: self.project().request.take().unwrap(),
}
.into();
}
_ => panic!("Request future must be Some"),
};
match ready!(future.poll(cx)) {
(_, Ok(item)) => {
trace!("Ok");
self.respond(Ok(item));
Next::Done.into()
}
(target, Err(err)) => {
trace!("Request error {}", err);
let address = match target {
OperationTarget::Node { address } => address,
OperationTarget::FanOut => {
self.respond(Err(err));
return Next::Done.into();
}
};
let request = this.request.as_mut().unwrap();
if request.retry >= this.retry_params.number_of_retries {
self.respond(Err(err));
return Next::Done.into();
}
request.retry = request.retry.saturating_add(1);
match err.kind() {
ErrorKind::Ask => {
let mut request = this.request.take().unwrap();
request.info.redirect = err
.redirect_node()
.map(|(node, _slot)| Redirect::Ask(node.to_string()));
Next::Retry { request }.into()
}
ErrorKind::Moved => {
let mut request = this.request.take().unwrap();
request.info.redirect = err
.redirect_node()
.map(|(node, _slot)| Redirect::Moved(node.to_string()));
Next::RefreshSlots { request }.into()
}
ErrorKind::TryAgain | ErrorKind::ClusterDown => {
let sleep_duration = this.retry_params.wait_time_for_retry(request.retry);
this.future.set(RequestState::Sleep {
#[cfg(feature = "tokio-comp")]
sleep: Box::pin(tokio::time::sleep(sleep_duration)),
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
sleep: Box::pin(async_std::task::sleep(sleep_duration)),
});
self.poll(cx)
}
ErrorKind::IoError => Next::Reconnect {
request: this.request.take().unwrap(),
target: address,
}
.into(),
_ => {
if err.is_retryable() {
Next::Retry {
request: this.request.take().unwrap(),
}
.into()
} else {
self.respond(Err(err));
Next::Done.into()
}
}
}
}
}
}
}
impl<F, I, C> Request<F, I, C>
where
F: Future<Output = (OperationTarget, RedisResult<I>)>,
{
fn respond(self: Pin<&mut Self>, msg: RedisResult<I>) {
let _ = self
.project()
.request
.take()
.expect("Result should only be sent once")
.sender
.send(msg);
}
}
impl<C> ClusterConnInner<C>
where
C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
{
async fn new(
initial_nodes: &[ConnectionInfo],
cluster_params: ClusterParams,
tls_params: Option<TlsConnParams>,
) -> RedisResult<Self> {
let connections =
Self::create_initial_connections(initial_nodes, &cluster_params, &tls_params).await?;
let inner = Arc::new(InnerCore {
conn_lock: RwLock::new((connections, Default::default())),
cluster_params,
pending_requests: Mutex::new(Vec::new()),
tls_params,
});
let mut connection = ClusterConnInner {
inner,
in_flight_requests: Default::default(),
refresh_error: None,
state: ConnectionState::PollComplete,
};
connection.refresh_slots().await?;
Ok(connection)
}
async fn create_initial_connections(
initial_nodes: &[ConnectionInfo],
params: &ClusterParams,
tls_params: &Option<TlsConnParams>,
) -> RedisResult<ConnectionMap<C>> {
let connections = stream::iter(initial_nodes.iter().cloned())
.map(|info| {
let params = params.clone();
let tls_params = tls_params.clone();
async move {
let addr = info.addr.to_string();
let result = connect_and_check(&addr, params, tls_params).await;
match result {
Ok(conn) => Some((addr, async { conn }.boxed().shared())),
Err(e) => {
trace!("Failed to connect to initial node: {:?}", e);
None
}
}
}
})
.buffer_unordered(initial_nodes.len())
.fold(
HashMap::with_capacity(initial_nodes.len()),
|mut connections: ConnectionMap<C>, conn| async move {
connections.extend(conn);
connections
},
)
.await;
if connections.is_empty() {
return Err(RedisError::from((
ErrorKind::IoError,
"Failed to create initial connections",
)));
}
Ok(connections)
}
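    /// Re-establishes (or creates) connections for the given addresses and
    /// swaps them into the shared connection map.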
fn refresh_connections(&mut self, addrs: Vec<String>) -> impl Future<Output = ()> {
let inner = self.inner.clone();
async move {
let mut write_guard = inner.conn_lock.write().await;
            let connections = stream::iter(addrs)
.fold(
mem::take(&mut write_guard.0),
|mut connections, addr| async {
let conn = Self::get_or_create_conn(
&addr,
connections.remove(&addr),
&inner.cluster_params,
inner.tls_params.clone(),
)
.await;
if let Ok(conn) = conn {
connections.insert(addr, async { conn }.boxed().shared());
}
connections
},
)
.await;
            write_guard.0 = connections;
}
}
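    /// Queries `CLUSTER SLOTS` on known nodes until one yields a valid slot
    /// map, then rebuilds the connection map for every discovered node.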
fn refresh_slots(&mut self) -> impl Future<Output = RedisResult<()>> {
let inner = self.inner.clone();
async move {
let mut write_guard = inner.conn_lock.write().await;
let mut connections = mem::take(&mut write_guard.0);
let slots = &mut write_guard.1;
let mut result = Ok(());
for (_, conn) in connections.iter_mut() {
let mut conn = conn.clone().await;
let value = match conn.req_packed_command(&slot_cmd()).await {
Ok(value) => value,
Err(err) => {
result = Err(err);
continue;
}
};
match parse_slots(value, inner.cluster_params.tls).and_then(|v| {
Self::build_slot_map(slots, v, inner.cluster_params.read_from_replicas)
}) {
Ok(_) => {
result = Ok(());
break;
}
Err(err) => result = Err(err),
}
}
result?;
let mut nodes = write_guard.1.values().flatten().collect::<Vec<_>>();
nodes.sort_unstable();
nodes.dedup();
let nodes_len = nodes.len();
let addresses_and_connections_iter = nodes
.into_iter()
.map(|addr| (addr, connections.remove(addr)));
write_guard.0 = stream::iter(addresses_and_connections_iter)
.fold(
HashMap::with_capacity(nodes_len),
|mut connections, (addr, connection)| async {
let conn = Self::get_or_create_conn(
addr,
connection,
&inner.cluster_params,
inner.tls_params.clone(),
)
.await;
if let Ok(conn) = conn {
connections.insert(addr.to_string(), async { conn }.boxed().shared());
}
connections
},
)
.await;
Ok(())
}
}
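    /// Validates that the reported slot ranges are contiguous and cover all
    /// `SLOT_SIZE` slots before replacing the contents of `slot_map`.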
fn build_slot_map(
slot_map: &mut SlotMap,
mut slots_data: Vec<Slot>,
read_from_replicas: bool,
) -> RedisResult<()> {
slots_data.sort_by_key(|slot_data| slot_data.start());
let last_slot = slots_data.iter().try_fold(0, |prev_end, slot_data| {
if prev_end != slot_data.start() {
return Err(RedisError::from((
ErrorKind::ResponseError,
"Slot refresh error.",
format!(
"Received overlapping slots {} and {}..{}",
prev_end,
slot_data.start(),
slot_data.end()
),
)));
}
Ok(slot_data.end() + 1)
})?;
if usize::from(last_slot) != SLOT_SIZE {
return Err(RedisError::from((
ErrorKind::ResponseError,
"Slot refresh error.",
format!("Lacks the slots >= {last_slot}"),
)));
}
slot_map.clear();
slot_map.fill_slots(&slots_data, read_from_replicas);
trace!("{:?}", slot_map);
Ok(())
}
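    /// Combines the per-node replies of a fan-out command according to the
    /// command's response policy (e.g. all-succeeded, one-succeeded, or an
    /// arithmetic/logical aggregate).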
async fn aggregate_results(
receivers: Vec<(String, oneshot::Receiver<RedisResult<Response>>)>,
routing: &MultipleNodeRoutingInfo,
response_policy: Option<ResponsePolicy>,
) -> RedisResult<Value> {
let extract_result = |response| match response {
Response::Single(value) => value,
Response::Multiple(_) => unreachable!(),
};
let convert_result = |res: Result<RedisResult<Response>, _>| {
            res.map_err(|_| {
                RedisError::from((
                    ErrorKind::ResponseError,
                    "request wasn't handled due to internal failure",
                ))
            })
            .and_then(|res| res.map(extract_result))
};
let get_receiver = |(_, receiver): (_, oneshot::Receiver<RedisResult<Response>>)| async {
convert_result(receiver.await)
};
match response_policy {
Some(ResponsePolicy::AllSucceeded) => {
future::try_join_all(receivers.into_iter().map(get_receiver))
.await
                    // Every request succeeded, so returning any single reply
                    // (here the last) is acceptable.
                    .map(|mut results| results.pop().unwrap())
            }
Some(ResponsePolicy::OneSucceeded) => future::select_ok(
receivers
.into_iter()
.map(|tuple| Box::pin(get_receiver(tuple))),
)
.await
.map(|(result, _)| result),
Some(ResponsePolicy::OneSucceededNonEmpty) => {
future::select_ok(receivers.into_iter().map(|(_, receiver)| {
Box::pin(async move {
let result = convert_result(receiver.await)?;
match result {
Value::Nil => Err((ErrorKind::ResponseError, "no value found").into()),
_ => Ok(result),
}
})
}))
.await
.map(|(result, _)| result)
}
Some(ResponsePolicy::Aggregate(op)) => {
future::try_join_all(receivers.into_iter().map(get_receiver))
.await
.and_then(|results| crate::cluster_routing::aggregate(results, op))
}
Some(ResponsePolicy::AggregateLogical(op)) => {
future::try_join_all(receivers.into_iter().map(get_receiver))
.await
.and_then(|results| crate::cluster_routing::logical_aggregate(results, op))
}
Some(ResponsePolicy::CombineArrays) => {
future::try_join_all(receivers.into_iter().map(get_receiver))
.await
.and_then(|results| match routing {
MultipleNodeRoutingInfo::MultiSlot(vec) => {
crate::cluster_routing::combine_and_sort_array_results(
results,
vec.iter().map(|(_, indices)| indices),
)
}
_ => crate::cluster_routing::combine_array_results(results),
})
}
Some(ResponsePolicy::Special) | None => {
future::try_join_all(receivers.into_iter().map(|(addr, receiver)| async move {
let result = convert_result(receiver.await)?;
Ok(Value::Bulk(vec![Value::Data(addr.into_bytes()), result]))
}))
.await
.map(Value::Bulk)
}
}
}
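    /// Fans `cmd` out to every node selected by `routing`, queueing one
    /// pending request per node and aggregating the replies.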
async fn execute_on_multiple_nodes<'a>(
cmd: &'a Arc<Cmd>,
routing: &'a MultipleNodeRoutingInfo,
core: Core<C>,
response_policy: Option<ResponsePolicy>,
) -> (OperationTarget, RedisResult<Response>) {
let read_guard = core.conn_lock.read().await;
let (receivers, requests): (Vec<_>, Vec<_>) = read_guard
.1
.addresses_for_multi_routing(routing)
.into_iter()
.enumerate()
.filter_map(|(index, addr)| {
read_guard.0.get(addr).cloned().map(|conn| {
let cmd = match routing {
MultipleNodeRoutingInfo::MultiSlot(vec) => {
let (_, indices) = vec.get(index).unwrap();
Arc::new(crate::cluster_routing::command_for_multi_slot_indices(
cmd.as_ref(),
indices.iter(),
))
}
_ => cmd.clone(),
};
let (sender, receiver) = oneshot::channel();
let addr = addr.to_string();
(
(addr.clone(), receiver),
PendingRequest {
retry: 0,
sender,
info: RequestInfo {
cmd: CmdArg::Cmd {
cmd,
routing: CommandRouting::Connection { addr, conn },
},
redirect: None,
},
},
)
})
})
.unzip();
drop(read_guard);
core.pending_requests.lock().unwrap().extend(requests);
let result = Self::aggregate_results(receivers, routing, response_policy)
.await
.map(Response::Single);
(OperationTarget::FanOut, result)
}
async fn try_cmd_request(
cmd: Arc<Cmd>,
redirect: Option<Redirect>,
routing: CommandRouting<C>,
core: Core<C>,
asking: bool,
) -> (OperationTarget, RedisResult<Response>) {
let route = if redirect.is_some() {
SingleNodeRoutingInfo::Random
} else {
match routing {
CommandRouting::Route(RoutingInfo::MultiNode((
multi_node_routing,
response_policy,
))) => {
assert!(!asking);
assert!(redirect.is_none());
return Self::execute_on_multiple_nodes(
&cmd,
&multi_node_routing,
core,
response_policy,
)
.await;
}
CommandRouting::Route(RoutingInfo::SingleNode(single_node_routing)) => {
single_node_routing
}
CommandRouting::Connection { addr, conn } => {
let mut conn = conn.await;
let result = conn.req_packed_command(&cmd).await.map(Response::Single);
return (addr.into(), result);
}
}
};
let (addr, mut conn) = Self::get_connection(redirect, route, core, asking).await;
let result = conn.req_packed_command(&cmd).await.map(Response::Single);
(addr.into(), result)
}
async fn try_pipeline_request(
pipeline: Arc<crate::Pipeline>,
offset: usize,
count: usize,
conn: impl Future<Output = (String, C)>,
) -> (OperationTarget, RedisResult<Response>) {
let (addr, mut conn) = conn.await;
let result = conn
.req_packed_commands(&pipeline, offset, count)
.await
.map(Response::Multiple);
(OperationTarget::Node { address: addr }, result)
}
async fn try_request(
info: RequestInfo<C>,
core: Core<C>,
) -> (OperationTarget, RedisResult<Response>) {
let asking = matches!(&info.redirect, Some(Redirect::Ask(_)));
match info.cmd {
CmdArg::Cmd { cmd, routing } => {
Self::try_cmd_request(cmd, info.redirect, routing, core, asking).await
}
CmdArg::Pipeline {
pipeline,
offset,
count,
route,
} => {
Self::try_pipeline_request(
pipeline,
offset,
count,
Self::get_connection(info.redirect, route, core, asking),
)
.await
}
}
}
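    /// Resolves a redirect or single-node route to `(address, connection)`,
    /// connecting on demand and falling back to a random node when the route
    /// has no known target.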
async fn get_connection(
mut redirect: Option<Redirect>,
route: SingleNodeRoutingInfo,
core: Core<C>,
asking: bool,
) -> (String, C) {
let read_guard = core.conn_lock.read().await;
let conn = match redirect.take() {
Some(Redirect::Moved(moved_addr)) => Some(moved_addr),
Some(Redirect::Ask(ask_addr)) => Some(ask_addr),
None => match route {
SingleNodeRoutingInfo::Random => None,
SingleNodeRoutingInfo::SpecificNode(route) => read_guard
.1
.slot_addr_for_route(&route)
.map(|addr| addr.to_string()),
},
}
.map(|addr| {
let conn = read_guard.0.get(&addr).cloned();
(addr, conn)
});
drop(read_guard);
let addr_conn_option = match conn {
Some((addr, Some(conn))) => Some((addr, conn.await)),
Some((addr, None)) => {
connect_and_check(&addr, core.cluster_params.clone(), core.tls_params.clone())
.await
.ok()
.map(|conn| (addr, conn))
}
None => None,
};
let (addr, mut conn) = match addr_conn_option {
Some(tuple) => tuple,
None => {
let read_guard = core.conn_lock.read().await;
let (random_addr, random_conn_future) = get_random_connection(&read_guard.0);
drop(read_guard);
(random_addr, random_conn_future.await)
}
};
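        // Per the cluster spec, a request that follows an ASK redirect must be
        // preceded by the ASKING command on the target node.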
if asking {
let _ = conn.req_packed_command(&crate::cmd::cmd("ASKING")).await;
}
(addr, conn)
}
fn poll_recover(
&mut self,
cx: &mut task::Context<'_>,
future: RecoverFuture,
) -> Poll<Result<(), RedisError>> {
match future {
RecoverFuture::RecoverSlots(mut future) => match future.as_mut().poll(cx) {
Poll::Ready(Ok(_)) => {
trace!("Recovered!");
self.state = ConnectionState::PollComplete;
Poll::Ready(Ok(()))
}
Poll::Pending => {
self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(future));
trace!("Recover not ready");
Poll::Pending
}
Poll::Ready(Err(err)) => {
self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin(
self.refresh_slots(),
)));
Poll::Ready(Err(err))
}
},
RecoverFuture::Reconnect(mut future) => match future.as_mut().poll(cx) {
Poll::Ready(_) => {
trace!("Reconnected connections");
self.state = ConnectionState::PollComplete;
Poll::Ready(Ok(()))
}
Poll::Pending => {
self.state = ConnectionState::Recover(RecoverFuture::Reconnect(future));
trace!("Recover not ready");
Poll::Pending
}
},
}
}
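    /// Spawns queued requests, polls in-flight ones, and reports whether the
    /// slot map must be rebuilt or specific nodes reconnected.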
fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll<PollFlushAction> {
let mut poll_flush_action = PollFlushAction::None;
let mut pending_requests_guard = self.inner.pending_requests.lock().unwrap();
if !pending_requests_guard.is_empty() {
let mut pending_requests = mem::take(&mut *pending_requests_guard);
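            // Drain instead of iterating so the Vec (and its allocation) can
            // be handed back to the guard below.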
for request in pending_requests.drain(..) {
if request.sender.is_closed() {
continue;
}
let future = Self::try_request(request.info.clone(), self.inner.clone()).boxed();
self.in_flight_requests.push(Box::pin(Request {
retry_params: self.inner.cluster_params.retry_params.clone(),
request: Some(request),
future: RequestState::Future { future },
}));
}
*pending_requests_guard = pending_requests;
}
drop(pending_requests_guard);
loop {
let result = match Pin::new(&mut self.in_flight_requests).poll_next(cx) {
Poll::Ready(Some(result)) => result,
Poll::Ready(None) | Poll::Pending => break,
};
match result {
Next::Done => {}
Next::Retry { request } => {
let future = Self::try_request(request.info.clone(), self.inner.clone());
self.in_flight_requests.push(Box::pin(Request {
retry_params: self.inner.cluster_params.retry_params.clone(),
request: Some(request),
future: RequestState::Future {
future: Box::pin(future),
},
}));
}
Next::RefreshSlots { request } => {
poll_flush_action =
poll_flush_action.change_state(PollFlushAction::RebuildSlots);
let future = Self::try_request(request.info.clone(), self.inner.clone());
self.in_flight_requests.push(Box::pin(Request {
retry_params: self.inner.cluster_params.retry_params.clone(),
request: Some(request),
future: RequestState::Future {
future: Box::pin(future),
},
}));
}
Next::Reconnect {
request, target, ..
} => {
poll_flush_action =
poll_flush_action.change_state(PollFlushAction::Reconnect(vec![target]));
self.inner.pending_requests.lock().unwrap().push(request);
}
}
}
match poll_flush_action {
PollFlushAction::None => {
if self.in_flight_requests.is_empty() {
Poll::Ready(poll_flush_action)
} else {
Poll::Pending
}
}
rebuild @ PollFlushAction::RebuildSlots => Poll::Ready(rebuild),
reestablish @ PollFlushAction::Reconnect(_) => Poll::Ready(reestablish),
}
}
fn send_refresh_error(&mut self) {
if self.refresh_error.is_some() {
if let Some(mut request) = Pin::new(&mut self.in_flight_requests)
.iter_pin_mut()
.find(|request| request.request.is_some())
{
(*request)
.as_mut()
.respond(Err(self.refresh_error.take().unwrap()));
} else if let Some(request) = self.inner.pending_requests.lock().unwrap().pop() {
let _ = request.sender.send(Err(self.refresh_error.take().unwrap()));
}
}
}
async fn get_or_create_conn(
addr: &str,
conn_option: Option<ConnectionFuture<C>>,
params: &ClusterParams,
tls_params: Option<TlsConnParams>,
) -> RedisResult<C> {
if let Some(conn) = conn_option {
let mut conn = conn.await;
match check_connection(&mut conn).await {
Ok(_) => Ok(conn),
Err(_) => connect_and_check(addr, params.clone(), tls_params).await,
}
} else {
connect_and_check(addr, params.clone(), tls_params).await
}
}
}
enum PollFlushAction {
None,
RebuildSlots,
Reconnect(Vec<String>),
}
impl PollFlushAction {
fn change_state(self, next_state: PollFlushAction) -> PollFlushAction {
match self {
Self::None => next_state,
rebuild @ Self::RebuildSlots => rebuild,
Self::Reconnect(mut addrs) => match next_state {
rebuild @ Self::RebuildSlots => rebuild,
Self::Reconnect(new_addrs) => {
addrs.extend(new_addrs);
Self::Reconnect(addrs)
}
Self::None => Self::Reconnect(addrs),
},
}
}
}
impl<C> Sink<Message<C>> for ClusterConnInner<C>
where
C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
{
type Error = ();
fn poll_ready(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
match mem::replace(&mut self.state, ConnectionState::PollComplete) {
ConnectionState::PollComplete => Poll::Ready(Ok(())),
ConnectionState::Recover(future) => {
match ready!(self.as_mut().poll_recover(cx, future)) {
Ok(()) => Poll::Ready(Ok(())),
Err(err) => {
if let Some(mut request) = Pin::new(&mut self.in_flight_requests)
.iter_pin_mut()
.find(|request| request.request.is_some())
{
(*request).as_mut().respond(Err(err));
} else {
self.refresh_error = Some(err);
}
Poll::Ready(Ok(()))
}
}
}
}
}
fn start_send(self: Pin<&mut Self>, msg: Message<C>) -> Result<(), Self::Error> {
trace!("start_send");
let Message { cmd, sender } = msg;
let redirect = None;
let info = RequestInfo { cmd, redirect };
self.inner
.pending_requests
.lock()
.unwrap()
.push(PendingRequest {
retry: 0,
sender,
info,
});
Ok(())
}
fn poll_flush(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
trace!("poll_complete: {:?}", self.state);
loop {
self.send_refresh_error();
match mem::replace(&mut self.state, ConnectionState::PollComplete) {
ConnectionState::Recover(future) => {
match ready!(self.as_mut().poll_recover(cx, future)) {
Ok(()) => (),
Err(err) => {
self.refresh_error = Some(err);
cx.waker().wake_by_ref();
return Poll::Pending;
}
}
}
ConnectionState::PollComplete => match ready!(self.poll_complete(cx)) {
PollFlushAction::None => return Poll::Ready(Ok(())),
PollFlushAction::RebuildSlots => {
self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(
Box::pin(self.refresh_slots()),
));
}
PollFlushAction::Reconnect(addrs) => {
self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
self.refresh_connections(addrs),
)));
}
},
}
}
}
fn poll_close(
mut self: Pin<&mut Self>,
cx: &mut task::Context,
) -> Poll<Result<(), Self::Error>> {
match self.poll_complete(cx) {
Poll::Ready(poll_flush_action) => match poll_flush_action {
PollFlushAction::None => (),
                _ => return Poll::Ready(Err(())),
},
Poll::Pending => (),
};
if self.in_flight_requests.is_empty() {
return Poll::Ready(Ok(()));
}
self.poll_flush(cx)
}
}
impl<C> ConnectionLike for ClusterConnection<C>
where
C: ConnectionLike + Send + Clone + Unpin + Sync + Connect + 'static,
{
fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
let routing = RoutingInfo::for_routable(cmd)
.unwrap_or(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random));
self.route_command(cmd, routing).boxed()
}
fn req_packed_commands<'a>(
&'a mut self,
pipeline: &'a crate::Pipeline,
offset: usize,
count: usize,
) -> RedisFuture<'a, Vec<Value>> {
async move {
let route = route_for_pipeline(pipeline)?;
self.route_pipeline(pipeline, offset, count, route.into())
.await
}
.boxed()
}
fn get_db(&self) -> i64 {
0
}
}
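/// Implemented by connection types that can be established from connection
/// info; this is what lets `ClusterConnection` stay generic over the
/// underlying connection.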
pub trait Connect: Sized {
fn connect<'a, T>(info: T) -> RedisFuture<'a, Self>
where
T: IntoConnectionInfo + Send + 'a;
}
impl Connect for MultiplexedConnection {
fn connect<'a, T>(info: T) -> RedisFuture<'a, MultiplexedConnection>
where
T: IntoConnectionInfo + Send + 'a,
{
async move {
let connection_info = info.into_connection_info()?;
let client = crate::Client::open(connection_info)?;
#[cfg(feature = "tokio-comp")]
return client.get_multiplexed_tokio_connection().await;
#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
return client.get_multiplexed_async_std_connection().await;
}
.boxed()
}
}
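/// Connects to `node`, verifies the link with a `PING`, and enables replica
/// reads with `READONLY` when configured.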
async fn connect_and_check<C>(
node: &str,
params: ClusterParams,
tls_params: Option<TlsConnParams>,
) -> RedisResult<C>
where
C: ConnectionLike + Connect + Send + 'static,
{
let read_from_replicas = params.read_from_replicas;
let info = get_connection_info(node, params, tls_params)?;
let mut conn = C::connect(info).await?;
check_connection(&mut conn).await?;
if read_from_replicas {
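        // READONLY is a no-op on primaries, so it is safe to send it to any
        // node once replica reads are enabled.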
() = crate::cmd("READONLY").query_async(&mut conn).await?;
}
Ok(conn)
}
async fn check_connection<C>(conn: &mut C) -> RedisResult<()>
where
C: ConnectionLike + Send + 'static,
{
let mut cmd = Cmd::new();
cmd.arg("PING");
cmd.query_async::<_, String>(conn).await?;
Ok(())
}
fn get_random_connection<C>(connections: &ConnectionMap<C>) -> (String, ConnectionFuture<C>)
where
C: Clone,
{
let addr = connections
.keys()
.choose(&mut thread_rng())
.expect("Connections is empty")
.to_string();
let conn = connections
.get(&addr)
.expect("Connections is empty")
.clone();
(addr, conn)
}
#[cfg(test)]
mod pipeline_routing_tests {
use super::route_for_pipeline;
use crate::{
cluster_routing::{Route, SlotAddr},
cmd,
};
#[test]
fn test_first_route_is_found() {
let mut pipeline = crate::Pipeline::new();
pipeline
.add_command(cmd("FLUSHALL")) .get("foo") .add_command(cmd("EVAL"));
assert_eq!(
route_for_pipeline(&pipeline),
Ok(Some(Route::new(12182, SlotAddr::Replica)))
);
}
#[test]
fn test_return_none_if_no_route_is_found() {
let mut pipeline = crate::Pipeline::new();
pipeline
.add_command(cmd("FLUSHALL")) .add_command(cmd("EVAL"));
assert_eq!(route_for_pipeline(&pipeline), Ok(None));
}
#[test]
fn test_prefer_primary_route_over_replica() {
let mut pipeline = crate::Pipeline::new();
pipeline
.get("foo") .add_command(cmd("FLUSHALL")) .add_command(cmd("EVAL")) .set("foo", "bar");
assert_eq!(
route_for_pipeline(&pipeline),
Ok(Some(Route::new(12182, SlotAddr::Master)))
);
}
#[test]
fn test_raise_cross_slot_error_on_conflicting_slots() {
let mut pipeline = crate::Pipeline::new();
pipeline
.add_command(cmd("FLUSHALL")) .set("baz", "bar") .get("foo");
assert_eq!(
route_for_pipeline(&pipeline).unwrap_err().kind(),
crate::ErrorKind::CrossSlot
);
}
}