1use std::{
94 collections::HashMap,
95 fmt,
96 future::Future,
97 io, mem,
98 pin::Pin,
99 sync::{Arc, Mutex},
100 task::{self, Poll},
101 time::Duration,
102};
103
104mod request;
105mod routing;
106use crate::{
107 aio::{check_resp3, ConnectionLike, HandleContainer, MultiplexedConnection, Runtime},
108 cluster::{get_connection_info, slot_cmd},
109 cluster_client::ClusterParams,
110 cluster_routing::{
111 MultipleNodeRoutingInfo, Redirect, ResponsePolicy, RoutingInfo, SingleNodeRoutingInfo,
112 Slot, SlotMap,
113 },
114 cluster_topology::parse_slots,
115 cmd,
116 subscription_tracker::SubscriptionTracker,
117 types::closed_connection_error,
118 AsyncConnectionConfig, Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError,
119 RedisFuture, RedisResult, ToRedisArgs, Value,
120};
121
122use crate::ProtocolVersion;
123use futures_sink::Sink;
124use futures_util::{
125 future::{self, BoxFuture, FutureExt},
126 ready,
127 stream::{self, Stream, StreamExt},
128};
129use log::{debug, trace, warn};
130use rand::{rng, seq::IteratorRandom};
131use request::{CmdArg, PendingRequest, Request, RequestState, Retry};
132use routing::{route_for_pipeline, InternalRoutingInfo, InternalSingleNodeRouting};
133use tokio::sync::{mpsc, oneshot, RwLock};
134
135struct ClientSideState {
136 protocol: ProtocolVersion,
137 _task_handle: HandleContainer,
138 response_timeout: Option<Duration>,
139 runtime: Runtime,
140}
141
142#[derive(Clone)]
147pub struct ClusterConnection<C = MultiplexedConnection> {
148 state: Arc<ClientSideState>,
149 sender: mpsc::Sender<Message<C>>,
150}
151
152impl<C> ClusterConnection<C>
153where
154 C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
155{
156 pub(crate) async fn new(
157 initial_nodes: &[ConnectionInfo],
158 cluster_params: ClusterParams,
159 ) -> RedisResult<ClusterConnection<C>> {
160 let protocol = cluster_params.protocol.unwrap_or_default();
161 let response_timeout = cluster_params.response_timeout;
162 let runtime = Runtime::locate();
163 ClusterConnInner::new(initial_nodes, cluster_params)
164 .await
165 .map(|inner| {
166 let (sender, mut receiver) = mpsc::channel::<Message<_>>(100);
167 let stream = async move {
168 let _ = stream::poll_fn(move |cx| receiver.poll_recv(cx))
169 .map(Ok)
170 .forward(inner)
171 .await;
172 };
173 let _task_handle = HandleContainer::new(runtime.spawn(stream));
174
175 ClusterConnection {
176 sender,
177 state: Arc::new(ClientSideState {
178 protocol,
179 _task_handle,
180 response_timeout,
181 runtime,
182 }),
183 }
184 })
185 }
186
187 pub async fn route_command(&mut self, cmd: &Cmd, routing: RoutingInfo) -> RedisResult<Value> {
189 trace!("send_packed_command");
190 let (sender, receiver) = oneshot::channel();
191 let request = async {
192 self.sender
193 .send(Message {
194 cmd: CmdArg::Cmd {
195 cmd: Arc::new(cmd.clone()), routing: routing.into(),
197 },
198 sender,
199 })
200 .await
201 .map_err(|_| {
202 RedisError::from(io::Error::new(
203 io::ErrorKind::BrokenPipe,
204 "redis_cluster: Unable to send command",
205 ))
206 })?;
207
208 receiver
209 .await
210 .unwrap_or_else(|_| {
211 Err(RedisError::from(io::Error::new(
212 io::ErrorKind::BrokenPipe,
213 "redis_cluster: Unable to receive command",
214 )))
215 })
216 .map(|response| match response {
217 Response::Single(value) => value,
218 Response::Multiple(_) => unreachable!(),
219 })
220 };
221
222 match self.state.response_timeout {
223 Some(duration) => self.state.runtime.timeout(duration, request).await?,
224 None => request.await,
225 }
226 }
227
228 pub async fn route_pipeline<'a>(
230 &'a mut self,
231 pipeline: &'a crate::Pipeline,
232 offset: usize,
233 count: usize,
234 route: SingleNodeRoutingInfo,
235 ) -> RedisResult<Vec<Value>> {
236 let (sender, receiver) = oneshot::channel();
237
238 let request = async {
239 self.sender
240 .send(Message {
241 cmd: CmdArg::Pipeline {
242 pipeline: Arc::new(pipeline.clone()), offset,
244 count,
245 route: route.into(),
246 },
247 sender,
248 })
249 .await
250 .map_err(|_| closed_connection_error())?;
251 receiver
252 .await
253 .unwrap_or_else(|_| Err(closed_connection_error()))
254 .map(|response| match response {
255 Response::Multiple(values) => values,
256 Response::Single(_) => unreachable!(),
257 })
258 };
259
260 match self.state.response_timeout {
261 Some(duration) => self.state.runtime.timeout(duration, request).await?,
262 None => request.await,
263 }
264 }
265
266 pub async fn subscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
275 check_resp3!(self.state.protocol);
276 let mut cmd = cmd("SUBSCRIBE");
277 cmd.arg(channel_name);
278 cmd.exec_async(self).await?;
279 Ok(())
280 }
281
282 pub async fn unsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
286 check_resp3!(self.state.protocol);
287 let mut cmd = cmd("UNSUBSCRIBE");
288 cmd.arg(channel_name);
289 cmd.exec_async(self).await?;
290 Ok(())
291 }
292
293 pub async fn psubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
302 check_resp3!(self.state.protocol);
303 let mut cmd = cmd("PSUBSCRIBE");
304 cmd.arg(channel_pattern);
305 cmd.exec_async(self).await?;
306 Ok(())
307 }
308
309 pub async fn punsubscribe(&mut self, channel_pattern: impl ToRedisArgs) -> RedisResult<()> {
313 check_resp3!(self.state.protocol);
314 let mut cmd = cmd("PUNSUBSCRIBE");
315 cmd.arg(channel_pattern);
316 cmd.exec_async(self).await?;
317 Ok(())
318 }
319
320 pub async fn ssubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
329 check_resp3!(self.state.protocol);
330 let mut cmd = cmd("SSUBSCRIBE");
331 cmd.arg(channel_name);
332 cmd.exec_async(self).await?;
333 Ok(())
334 }
335
336 pub async fn sunsubscribe(&mut self, channel_name: impl ToRedisArgs) -> RedisResult<()> {
340 check_resp3!(self.state.protocol);
341 let mut cmd = cmd("SUNSUBSCRIBE");
342 cmd.arg(channel_name);
343 cmd.exec_async(self).await?;
344 Ok(())
345 }
346}
347
/// Open connections keyed by node address string.
type ConnectionMap<C> = HashMap<String, C>;
349
350struct InnerCore<C> {
354 conn_lock: RwLock<(ConnectionMap<C>, SlotMap)>,
355 cluster_params: ClusterParams,
356 pending_requests: Mutex<Vec<PendingRequest<C>>>,
357 initial_nodes: Vec<ConnectionInfo>,
358 subscription_tracker: Option<Mutex<SubscriptionTracker>>,
359}
360
361type Core<C> = Arc<InnerCore<C>>;
362
363struct ClusterConnInner<C> {
367 inner: Core<C>,
368 state: ConnectionState,
369 #[allow(clippy::complexity)]
370 in_flight_requests: stream::FuturesUnordered<Pin<Box<Request<C>>>>,
371 refresh_error: Option<RedisError>,
372}
373
374fn boxed_sleep(duration: Duration) -> BoxFuture<'static, ()> {
375 Box::pin(Runtime::locate_and_sleep(duration))
376}
377
378#[derive(Debug, PartialEq)]
379pub(crate) enum Response {
380 Single(Value),
381 Multiple(Vec<Value>),
382}
383
384enum OperationTarget {
385 Node { address: String },
386 NotFound,
387 FanOut,
388}
389type OperationResult = Result<Response, (OperationTarget, RedisError)>;
390
391impl From<String> for OperationTarget {
392 fn from(address: String) -> Self {
393 OperationTarget::Node { address }
394 }
395}
396
397struct Message<C> {
398 cmd: CmdArg<C>,
399 sender: oneshot::Sender<RedisResult<Response>>,
400}
401
402enum RecoverFuture {
403 RecoverSlots(BoxFuture<'static, RedisResult<()>>),
404 Reconnect(BoxFuture<'static, ()>),
405}
406
407enum ConnectionState {
408 PollComplete,
409 Recover(RecoverFuture),
410}
411
412impl fmt::Debug for ConnectionState {
413 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
414 write!(
415 f,
416 "{}",
417 match self {
418 ConnectionState::PollComplete => "PollComplete",
419 ConnectionState::Recover(_) => "Recover",
420 }
421 )
422 }
423}
424
425impl<C> ClusterConnInner<C>
426where
427 C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
428{
429 async fn new(
430 initial_nodes: &[ConnectionInfo],
431 cluster_params: ClusterParams,
432 ) -> RedisResult<Self> {
433 let connections = Self::create_initial_connections(initial_nodes, &cluster_params).await?;
434 let subscription_tracker = if cluster_params.async_push_sender.is_some() {
435 Some(Mutex::new(SubscriptionTracker::default()))
436 } else {
437 None
438 };
439 let inner = Arc::new(InnerCore {
440 conn_lock: RwLock::new((connections, SlotMap::new(cluster_params.read_from_replicas))),
441 cluster_params,
442 pending_requests: Mutex::new(Vec::new()),
443 initial_nodes: initial_nodes.to_vec(),
444 subscription_tracker,
445 });
446 let connection = ClusterConnInner {
447 inner,
448 in_flight_requests: Default::default(),
449 refresh_error: None,
450 state: ConnectionState::PollComplete,
451 };
452 Self::refresh_slots(connection.inner.clone()).await?;
453 Ok(connection)
454 }
455
456 async fn create_initial_connections(
457 initial_nodes: &[ConnectionInfo],
458 params: &ClusterParams,
459 ) -> RedisResult<ConnectionMap<C>> {
460 let (connections, error) = stream::iter(initial_nodes.iter().cloned())
461 .map(|info| {
462 let params = params.clone();
463 async move {
464 let addr = info.addr.to_string();
465 let result = connect_and_check(&addr, params).await;
466 match result {
467 Ok(conn) => Ok((addr, conn)),
468 Err(e) => {
469 debug!("Failed to connect to initial node: {:?}", e);
470 Err(e)
471 }
472 }
473 }
474 })
475 .buffer_unordered(initial_nodes.len())
476 .fold(
477 (ConnectionMap::<C>::with_capacity(initial_nodes.len()), None),
478 |(mut connections, mut error), result| async move {
479 match result {
480 Ok((addr, conn)) => {
481 connections.insert(addr, conn);
482 }
483 Err(err) => {
484 error = Some(err);
487 }
488 }
489 (connections, error)
490 },
491 )
492 .await;
493 if connections.is_empty() {
494 if let Some(err) = error {
495 return Err(RedisError::from((
496 ErrorKind::IoError,
497 "Failed to create initial connections",
498 err.to_string(),
499 )));
500 } else {
501 return Err(RedisError::from((
502 ErrorKind::IoError,
503 "Failed to create initial connections",
504 )));
505 }
506 }
507 Ok(connections)
508 }
509
510 fn resubscribe(&self) {
511 let Some(subscription_tracker) = self.inner.subscription_tracker.as_ref() else {
512 return;
513 };
514
515 let subscription_pipe = subscription_tracker
516 .lock()
517 .unwrap()
518 .get_subscription_pipeline();
519
520 let requests = subscription_pipe.cmd_iter().map(|cmd| {
522 let routing = RoutingInfo::for_routable(cmd)
523 .unwrap_or(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
524 .into();
525 PendingRequest {
526 retry: 0,
527 sender: request::ResultExpectation::Internal,
528 cmd: CmdArg::Cmd {
529 cmd: Arc::new(cmd.clone()),
530 routing,
531 },
532 }
533 });
534 self.inner.pending_requests.lock().unwrap().extend(requests);
535 }
536
537 fn reconnect_to_initial_nodes(&mut self) -> impl Future<Output = ()> {
538 debug!("Received request to reconnect to initial nodes");
539 let inner = self.inner.clone();
540 async move {
541 let connection_map =
542 match Self::create_initial_connections(&inner.initial_nodes, &inner.cluster_params)
543 .await
544 {
545 Ok(map) => map,
546 Err(err) => {
547 warn!("Can't reconnect to initial nodes: `{err}`");
548 return;
549 }
550 };
551 let mut write_lock = inner.conn_lock.write().await;
552 *write_lock = (
553 connection_map,
554 SlotMap::new(inner.cluster_params.read_from_replicas),
555 );
556 drop(write_lock);
557 if let Err(err) = Self::refresh_slots(inner.clone()).await {
558 warn!("Can't refresh slots with initial nodes: `{err}`");
559 };
560 }
561 }
562
563 fn refresh_connections(&mut self, addrs: Vec<String>) -> impl Future<Output = ()> {
564 let inner = self.inner.clone();
565 async move {
566 let mut write_guard = inner.conn_lock.write().await;
567
568 Self::refresh_connections_locked(&inner, &mut write_guard.0, addrs).await;
569 }
570 }
571
572 async fn refresh_slots(inner: Core<C>) -> RedisResult<()> {
574 let mut write_guard = inner.conn_lock.write().await;
575 let (connections, slots) = &mut *write_guard;
576
577 let mut result = Ok(());
578 for (addr, conn) in &mut *connections {
579 result = async {
580 let value = conn
581 .req_packed_command(&slot_cmd())
582 .await
583 .and_then(|value| value.extract_error())?;
584 let v: Vec<Slot> = parse_slots(
585 value,
586 inner.cluster_params.tls,
587 addr.rsplit_once(':').unwrap().0,
588 )?;
589 Self::build_slot_map(slots, v)
590 }
591 .await;
592 if result.is_ok() {
593 break;
594 }
595 }
596 result?;
597
598 let mut nodes = slots.values().flatten().cloned().collect::<Vec<_>>();
599 nodes.sort_unstable();
600 nodes.dedup();
601 Self::refresh_connections_locked(&inner, connections, nodes).await;
602
603 Ok(())
604 }
605
606 async fn refresh_connections_locked(
607 inner: &Core<C>,
608 connections: &mut ConnectionMap<C>,
609 nodes: Vec<String>,
610 ) {
611 let nodes_len = nodes.len();
612
613 let addresses_and_connections_iter = nodes.into_iter().map(|addr| {
614 let value = connections.remove(&addr);
615 (addr, value)
616 });
617
618 let inner = &inner;
619 *connections = stream::iter(addresses_and_connections_iter)
620 .map(|(addr, connection)| async move {
621 (
622 addr.clone(),
623 Self::get_or_create_conn(&addr, connection, &inner.cluster_params).await,
624 )
625 })
626 .buffer_unordered(nodes_len.max(8))
627 .fold(
628 HashMap::with_capacity(nodes_len),
629 |mut connections, (addr, result)| async move {
630 if let Ok(conn) = result {
631 connections.insert(addr, conn);
632 }
633 connections
634 },
635 )
636 .await;
637 }
638
639 fn build_slot_map(slot_map: &mut SlotMap, slots_data: Vec<Slot>) -> RedisResult<()> {
640 slot_map.clear();
641 slot_map.fill_slots(slots_data);
642 trace!("{:?}", slot_map);
643 Ok(())
644 }
645
646 async fn aggregate_results(
647 receivers: Vec<(String, oneshot::Receiver<RedisResult<Response>>)>,
648 routing: &MultipleNodeRoutingInfo,
649 response_policy: Option<ResponsePolicy>,
650 ) -> RedisResult<Value> {
651 if receivers.is_empty() {
652 return Err((
653 ErrorKind::ClusterConnectionNotFound,
654 "No nodes found for multi-node operation",
655 )
656 .into());
657 }
658
659 let extract_result = |response| match response {
660 Response::Single(value) => value,
661 Response::Multiple(_) => unreachable!(),
662 };
663
664 let convert_result = |res: Result<RedisResult<Response>, _>| {
665 res.map_err(|_| RedisError::from((ErrorKind::ResponseError, "request wasn't handled due to internal failure"))) .and_then(|res| res.map(extract_result))
667 };
668
669 let get_receiver = |(_, receiver): (_, oneshot::Receiver<RedisResult<Response>>)| async {
670 convert_result(receiver.await)
671 };
672
673 match response_policy {
675 Some(ResponsePolicy::AllSucceeded) => {
676 future::try_join_all(receivers.into_iter().map(get_receiver))
677 .await
678 .and_then(|mut results| {
679 results.pop().ok_or(
680 (
681 ErrorKind::ClusterConnectionNotFound,
682 "No results received for multi-node operation",
683 )
684 .into(),
685 )
686 })
687 }
688 Some(ResponsePolicy::OneSucceeded) => future::select_ok(
689 receivers
690 .into_iter()
691 .map(|tuple| Box::pin(get_receiver(tuple))),
692 )
693 .await
694 .map(|(result, _)| result),
695 Some(ResponsePolicy::OneSucceededNonEmpty) => {
696 future::select_ok(receivers.into_iter().map(|(_, receiver)| {
697 Box::pin(async move {
698 let result = convert_result(receiver.await)?;
699 match result {
700 Value::Nil => Err((ErrorKind::ResponseError, "no value found").into()),
701 _ => Ok(result),
702 }
703 })
704 }))
705 .await
706 .map(|(result, _)| result)
707 }
708 Some(ResponsePolicy::Aggregate(op)) => {
709 future::try_join_all(receivers.into_iter().map(get_receiver))
710 .await
711 .and_then(|results| crate::cluster_routing::aggregate(results, op))
712 }
713 Some(ResponsePolicy::AggregateLogical(op)) => {
714 future::try_join_all(receivers.into_iter().map(get_receiver))
715 .await
716 .and_then(|results| crate::cluster_routing::logical_aggregate(results, op))
717 }
718 Some(ResponsePolicy::CombineArrays) => {
719 future::try_join_all(receivers.into_iter().map(get_receiver))
720 .await
721 .and_then(|results| match routing {
722 MultipleNodeRoutingInfo::MultiSlot(vec) => {
723 crate::cluster_routing::combine_and_sort_array_results(
724 results,
725 vec.iter().map(|(_, indices)| indices),
726 )
727 }
728 _ => crate::cluster_routing::combine_array_results(results),
729 })
730 }
731 Some(ResponsePolicy::Special) | None => {
732 future::try_join_all(receivers.into_iter().map(|(addr, receiver)| async move {
736 let result = convert_result(receiver.await)?;
737 Ok((Value::BulkString(addr.into_bytes()), result))
738 }))
739 .await
740 .map(Value::Map)
741 }
742 }
743 }
744
745 async fn execute_on_multiple_nodes<'a>(
746 cmd: &'a Arc<Cmd>,
747 routing: &'a MultipleNodeRoutingInfo,
748 core: Core<C>,
749 response_policy: Option<ResponsePolicy>,
750 ) -> OperationResult {
751 let read_guard = core.conn_lock.read().await;
752 if read_guard.0.is_empty() {
753 return OperationResult::Err((
754 OperationTarget::FanOut,
755 (
756 ErrorKind::ClusterConnectionNotFound,
757 "No connections found for multi-node operation",
758 )
759 .into(),
760 ));
761 }
762 let (receivers, requests): (Vec<_>, Vec<_>) = {
763 let to_request = |(addr, cmd): (&str, Arc<Cmd>)| {
764 read_guard.0.get(addr).cloned().map(|conn| {
765 let (sender, receiver) = oneshot::channel();
766 let addr = addr.to_string();
767 (
768 (addr.clone(), receiver),
769 PendingRequest {
770 retry: 0,
771 sender: request::ResultExpectation::External(sender),
772 cmd: CmdArg::Cmd {
773 cmd,
774 routing: InternalSingleNodeRouting::Connection {
775 identifier: addr,
776 conn,
777 }
778 .into(),
779 },
780 },
781 )
782 })
783 };
784 let slot_map = &read_guard.1;
785
786 match routing {
789 MultipleNodeRoutingInfo::AllNodes => slot_map
790 .addresses_for_all_nodes()
791 .into_iter()
792 .filter_map(|addr| to_request((addr, cmd.clone())))
793 .unzip(),
794 MultipleNodeRoutingInfo::AllMasters => slot_map
795 .addresses_for_all_primaries()
796 .into_iter()
797 .filter_map(|addr| to_request((addr, cmd.clone())))
798 .unzip(),
799 MultipleNodeRoutingInfo::MultiSlot(routes) => slot_map
800 .addresses_for_multi_slot(routes)
801 .enumerate()
802 .filter_map(|(index, addr_opt)| {
803 addr_opt.and_then(|addr| {
804 let (_, indices) = routes.get(index).unwrap();
805 let cmd =
806 Arc::new(crate::cluster_routing::command_for_multi_slot_indices(
807 cmd.as_ref(),
808 indices.iter(),
809 ));
810 to_request((addr, cmd))
811 })
812 })
813 .unzip(),
814 }
815 };
816 drop(read_guard);
817 core.pending_requests.lock().unwrap().extend(requests);
818
819 Self::aggregate_results(receivers, routing, response_policy)
820 .await
821 .map(Response::Single)
822 .map_err(|err| (OperationTarget::FanOut, err))
823 }
824
825 async fn try_cmd_request(
826 cmd: Arc<Cmd>,
827 routing: InternalRoutingInfo<C>,
828 core: Core<C>,
829 ) -> OperationResult {
830 let route = match routing {
831 InternalRoutingInfo::SingleNode(single_node_routing) => single_node_routing,
832 InternalRoutingInfo::MultiNode((multi_node_routing, response_policy)) => {
833 return Self::execute_on_multiple_nodes(
834 &cmd,
835 &multi_node_routing,
836 core,
837 response_policy,
838 )
839 .await;
840 }
841 };
842
843 match Self::get_connection(route, core).await {
844 Ok((addr, mut conn)) => conn
845 .req_packed_command(&cmd)
846 .await
847 .and_then(|value| value.extract_error())
848 .map(Response::Single)
849 .map_err(|err| (addr.into(), err)),
850 Err(err) => Err((OperationTarget::NotFound, err)),
851 }
852 }
853
854 async fn try_pipeline_request(
855 pipeline: Arc<crate::Pipeline>,
856 offset: usize,
857 count: usize,
858 conn: impl Future<Output = RedisResult<(String, C)>>,
859 ) -> OperationResult {
860 match conn.await {
861 Ok((addr, mut conn)) => conn
862 .req_packed_commands(&pipeline, offset, count)
863 .await
864 .and_then(Value::extract_error_vec)
865 .map(Response::Multiple)
866 .map_err(|err| (OperationTarget::Node { address: addr }, err)),
867 Err(err) => Err((OperationTarget::NotFound, err)),
868 }
869 }
870
871 async fn try_request(cmd: CmdArg<C>, core: Core<C>) -> OperationResult {
872 match cmd {
873 CmdArg::Cmd { cmd, routing } => Self::try_cmd_request(cmd, routing, core).await,
874 CmdArg::Pipeline {
875 pipeline,
876 offset,
877 count,
878 route,
879 } => {
880 Self::try_pipeline_request(
881 pipeline,
882 offset,
883 count,
884 Self::get_connection(route, core),
885 )
886 .await
887 }
888 }
889 }
890
891 async fn get_connection(
892 route: InternalSingleNodeRouting<C>,
893 core: Core<C>,
894 ) -> RedisResult<(String, C)> {
895 let read_guard = core.conn_lock.read().await;
896
897 let conn = match route {
898 InternalSingleNodeRouting::Random => None,
899 InternalSingleNodeRouting::SpecificNode(route) => read_guard
900 .1
901 .slot_addr_for_route(&route)
902 .map(|addr| addr.to_string()),
903 InternalSingleNodeRouting::Connection { identifier, conn } => {
904 return Ok((identifier, conn));
905 }
906 InternalSingleNodeRouting::Redirect { redirect, .. } => {
907 drop(read_guard);
908 return Self::get_redirected_connection(redirect, core).await;
910 }
911 InternalSingleNodeRouting::ByAddress(address) => {
912 if let Some(conn) = read_guard.0.get(&address).cloned() {
913 return Ok((address, conn));
914 } else {
915 return Err((
916 ErrorKind::ClientError,
917 "Requested connection not found",
918 address,
919 )
920 .into());
921 }
922 }
923 }
924 .map(|addr| {
925 let conn = read_guard.0.get(&addr).cloned();
926 (addr, conn)
927 });
928 drop(read_guard);
929
930 let addr_conn_option = match conn {
931 Some((addr, Some(conn))) => Some((addr, conn)),
932 Some((addr, None)) => connect_check_and_add(core.clone(), addr.clone())
933 .await
934 .ok()
935 .map(|conn| (addr, conn)),
936 None => None,
937 };
938
939 let (addr, conn) = match addr_conn_option {
940 Some(tuple) => tuple,
941 None => {
942 let read_guard = core.conn_lock.read().await;
943 if let Some((random_addr, random_conn)) = get_random_connection(&read_guard.0) {
944 drop(read_guard);
945 (random_addr, random_conn)
946 } else {
947 return Err(
948 (ErrorKind::ClusterConnectionNotFound, "No connections found").into(),
949 );
950 }
951 }
952 };
953
954 Ok((addr, conn))
955 }
956
957 async fn get_redirected_connection(
958 redirect: Redirect,
959 core: Core<C>,
960 ) -> RedisResult<(String, C)> {
961 let asking = matches!(redirect, Redirect::Ask(_));
962 let addr = match redirect {
963 Redirect::Moved(addr) => addr,
964 Redirect::Ask(addr) => addr,
965 };
966 let read_guard = core.conn_lock.read().await;
967 let conn = read_guard.0.get(&addr).cloned();
968 drop(read_guard);
969 let mut conn = match conn {
970 Some(conn) => conn,
971 None => connect_check_and_add(core.clone(), addr.clone()).await?,
972 };
973 if asking {
974 let _ = conn
975 .req_packed_command(&crate::cmd::cmd("ASKING"))
976 .await
977 .and_then(|value| value.extract_error());
978 }
979
980 Ok((addr, conn))
981 }
982
983 fn poll_recover(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), RedisError>> {
984 let recover_future = match &mut self.state {
985 ConnectionState::PollComplete => return Poll::Ready(Ok(())),
986 ConnectionState::Recover(future) => future,
987 };
988 let res = match recover_future {
989 RecoverFuture::RecoverSlots(ref mut future) => match ready!(future.as_mut().poll(cx)) {
990 Ok(_) => {
991 trace!("Recovered!");
992 self.state = ConnectionState::PollComplete;
993 Ok(())
994 }
995 Err(err) => {
996 trace!("Recover slots failed!");
997 *future = Box::pin(Self::refresh_slots(self.inner.clone()));
998 Err(err)
999 }
1000 },
1001 RecoverFuture::Reconnect(ref mut future) => {
1002 ready!(future.as_mut().poll(cx));
1003 trace!("Reconnected connections");
1004 self.state = ConnectionState::PollComplete;
1005 Ok(())
1006 }
1007 };
1008 if res.is_ok() {
1009 self.resubscribe();
1010 }
1011 Poll::Ready(res)
1012 }
1013
1014 fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll<PollFlushAction> {
1015 let mut poll_flush_action = PollFlushAction::None;
1016
1017 let mut pending_requests_guard = self.inner.pending_requests.lock().unwrap();
1018 if !pending_requests_guard.is_empty() {
1019 let mut pending_requests = mem::take(&mut *pending_requests_guard);
1020 for request in pending_requests.drain(..) {
1021 if request.sender.is_closed() {
1025 continue;
1026 }
1027
1028 let future = Self::try_request(request.cmd.clone(), self.inner.clone()).boxed();
1029 self.in_flight_requests.push(Box::pin(Request {
1030 retry_params: self.inner.cluster_params.retry_params.clone(),
1031 request: Some(request),
1032 future: RequestState::Future { future },
1033 }));
1034 }
1035 *pending_requests_guard = pending_requests;
1036 }
1037 drop(pending_requests_guard);
1038
1039 loop {
1040 let (request_handling, next) =
1041 match Pin::new(&mut self.in_flight_requests).poll_next(cx) {
1042 Poll::Ready(Some(result)) => result,
1043 Poll::Ready(None) | Poll::Pending => break,
1044 };
1045 match request_handling {
1046 Some(Retry::MoveToPending { request }) => {
1047 self.inner.pending_requests.lock().unwrap().push(request)
1048 }
1049 Some(Retry::Immediately { request }) => {
1050 let future = Self::try_request(request.cmd.clone(), self.inner.clone());
1051 self.in_flight_requests.push(Box::pin(Request {
1052 retry_params: self.inner.cluster_params.retry_params.clone(),
1053 request: Some(request),
1054 future: RequestState::Future {
1055 future: Box::pin(future),
1056 },
1057 }));
1058 }
1059 Some(Retry::AfterSleep {
1060 request,
1061 sleep_duration,
1062 }) => {
1063 let future = RequestState::Sleep {
1064 sleep: boxed_sleep(sleep_duration),
1065 };
1066 self.in_flight_requests.push(Box::pin(Request {
1067 retry_params: self.inner.cluster_params.retry_params.clone(),
1068 request: Some(request),
1069 future,
1070 }));
1071 }
1072 None => {}
1073 };
1074 poll_flush_action = poll_flush_action.change_state(next);
1075 }
1076
1077 if !matches!(poll_flush_action, PollFlushAction::None) || self.in_flight_requests.is_empty()
1078 {
1079 Poll::Ready(poll_flush_action)
1080 } else {
1081 Poll::Pending
1082 }
1083 }
1084
1085 fn send_refresh_error(&mut self) {
1086 if self.refresh_error.is_some() {
1087 if let Some(mut request) = Pin::new(&mut self.in_flight_requests)
1088 .iter_pin_mut()
1089 .find(|request| request.request.is_some())
1090 {
1091 (*request)
1092 .as_mut()
1093 .respond(Err(self.refresh_error.take().unwrap()));
1094 } else if let Some(request) = self.inner.pending_requests.lock().unwrap().pop() {
1095 request.sender.send(Err(self.refresh_error.take().unwrap()));
1096 }
1097 }
1098 }
1099
1100 async fn get_or_create_conn(
1101 addr: &str,
1102 conn_option: Option<C>,
1103 params: &ClusterParams,
1104 ) -> RedisResult<C> {
1105 if let Some(mut conn) = conn_option {
1106 match check_connection(&mut conn).await {
1107 Ok(_) => Ok(conn),
1108 Err(_) => connect_and_check(addr, params.clone()).await,
1109 }
1110 } else {
1111 connect_and_check(addr, params.clone()).await
1112 }
1113 }
1114}
1115
/// Follow-up action requested after in-flight requests have been polled.
#[derive(Debug, PartialEq)]
enum PollFlushAction {
    /// Nothing to do.
    None,
    /// Rebuild the slot map.
    RebuildSlots,
    /// Reconnect to the listed node addresses.
    Reconnect(Vec<String>),
    /// Reconnect starting from the initial nodes.
    ReconnectFromInitialConnections,
}

impl PollFlushAction {
    /// Merges two actions into the one that covers both.
    ///
    /// `None` yields to anything else, `ReconnectFromInitialConnections`
    /// overrides everything, `RebuildSlots` overrides `Reconnect`, and two
    /// `Reconnect` actions are combined by appending the second address list
    /// to the first.
    fn change_state(self, next_state: PollFlushAction) -> PollFlushAction {
        match (self, next_state) {
            (Self::None, other) | (other, Self::None) => other,
            (Self::Reconnect(mut addrs), Self::Reconnect(new_addrs)) => {
                addrs.extend(new_addrs);
                Self::Reconnect(addrs)
            }
            (Self::ReconnectFromInitialConnections, _)
            | (_, Self::ReconnectFromInitialConnections) => Self::ReconnectFromInitialConnections,
            (Self::RebuildSlots, _) | (_, Self::RebuildSlots) => Self::RebuildSlots,
        }
    }
}
1145
1146impl<C> Sink<Message<C>> for ClusterConnInner<C>
1147where
1148 C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
1149{
1150 type Error = ();
1151
1152 fn poll_ready(self: Pin<&mut Self>, _cx: &mut task::Context) -> Poll<Result<(), Self::Error>> {
1153 Poll::Ready(Ok(()))
1154 }
1155
1156 fn start_send(self: Pin<&mut Self>, msg: Message<C>) -> Result<(), Self::Error> {
1157 trace!("start_send");
1158 let Message { cmd, sender } = msg;
1159
1160 if let Some(tracker) = &self.inner.subscription_tracker {
1161 let mut tracker = tracker.lock().unwrap();
1163 match &cmd {
1164 CmdArg::Cmd { cmd, .. } => tracker.update_with_cmd(cmd.as_ref()),
1165 CmdArg::Pipeline { pipeline, .. } => {
1166 tracker.update_with_pipeline(pipeline.as_ref())
1167 }
1168 }
1169 };
1170
1171 self.inner
1172 .pending_requests
1173 .lock()
1174 .unwrap()
1175 .push(PendingRequest {
1176 retry: 0,
1177 sender: request::ResultExpectation::External(sender),
1178 cmd,
1179 });
1180 Ok(())
1181 }
1182
1183 fn poll_flush(
1184 mut self: Pin<&mut Self>,
1185 cx: &mut task::Context,
1186 ) -> Poll<Result<(), Self::Error>> {
1187 trace!("poll_flush: {:?}", self.state);
1188 loop {
1189 self.send_refresh_error();
1190
1191 if let Err(err) = ready!(self.as_mut().poll_recover(cx)) {
1192 self.refresh_error = Some(err);
1196
1197 cx.waker().wake_by_ref();
1201 return Poll::Pending;
1202 }
1203
1204 match ready!(self.poll_complete(cx)) {
1205 PollFlushAction::None => return Poll::Ready(Ok(())),
1206 PollFlushAction::RebuildSlots => {
1207 self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin(
1208 Self::refresh_slots(self.inner.clone()),
1209 )));
1210 }
1211 PollFlushAction::Reconnect(addrs) => {
1212 self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
1213 self.refresh_connections(addrs),
1214 )));
1215 }
1216 PollFlushAction::ReconnectFromInitialConnections => {
1217 self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin(
1218 self.reconnect_to_initial_nodes(),
1219 )));
1220 }
1221 }
1222 }
1223 }
1224
1225 fn poll_close(
1226 mut self: Pin<&mut Self>,
1227 cx: &mut task::Context,
1228 ) -> Poll<Result<(), Self::Error>> {
1229 match self.poll_complete(cx) {
1231 Poll::Ready(PollFlushAction::None) => (),
1232 Poll::Ready(_) => Err(())?,
1233 Poll::Pending => (),
1234 };
1235 if self.in_flight_requests.is_empty() {
1238 return Poll::Ready(Ok(()));
1239 }
1240
1241 self.poll_flush(cx)
1242 }
1243}
1244
1245impl<C> ConnectionLike for ClusterConnection<C>
1246where
1247 C: ConnectionLike + Send + Clone + Unpin + Sync + Connect + 'static,
1248{
1249 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
1250 let routing = RoutingInfo::for_routable(cmd)
1251 .unwrap_or(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random));
1252 self.route_command(cmd, routing).boxed()
1253 }
1254
1255 fn req_packed_commands<'a>(
1256 &'a mut self,
1257 pipeline: &'a crate::Pipeline,
1258 offset: usize,
1259 count: usize,
1260 ) -> RedisFuture<'a, Vec<Value>> {
1261 async move {
1262 let route = route_for_pipeline(pipeline)?;
1263 self.route_pipeline(pipeline, offset, count, route.into())
1264 .await
1265 }
1266 .boxed()
1267 }
1268
1269 fn get_db(&self) -> i64 {
1270 0
1271 }
1272}
1273pub trait Connect: Sized {
1276 fn connect_with_config<'a, T>(info: T, config: AsyncConnectionConfig) -> RedisFuture<'a, Self>
1278 where
1279 T: IntoConnectionInfo + Send + 'a;
1280}
1281
1282impl Connect for MultiplexedConnection {
1283 fn connect_with_config<'a, T>(info: T, config: AsyncConnectionConfig) -> RedisFuture<'a, Self>
1284 where
1285 T: IntoConnectionInfo + Send + 'a,
1286 {
1287 async move {
1288 let connection_info = info.into_connection_info()?;
1289 let client = crate::Client::open(connection_info)?;
1290 client
1291 .get_multiplexed_async_connection_with_config(&config)
1292 .await
1293 }
1294 .boxed()
1295 }
1296}
1297
1298async fn connect_check_and_add<C>(core: Core<C>, addr: String) -> RedisResult<C>
1299where
1300 C: ConnectionLike + Connect + Send + Clone + 'static,
1301{
1302 match connect_and_check::<C>(&addr, core.cluster_params.clone()).await {
1303 Ok(conn) => {
1304 let conn_clone = conn.clone();
1305 core.conn_lock.write().await.0.insert(addr, conn_clone);
1306 Ok(conn)
1307 }
1308 Err(err) => Err(err),
1309 }
1310}
1311
1312async fn connect_and_check<C>(node: &str, params: ClusterParams) -> RedisResult<C>
1313where
1314 C: ConnectionLike + Connect + Send + 'static,
1315{
1316 let read_from_replicas = params.read_from_replicas;
1317 let connection_timeout = params.connection_timeout;
1318 let response_timeout = params.response_timeout;
1319 let push_sender = params.async_push_sender.clone();
1320 let tcp_settings = params.tcp_settings.clone();
1321 let dns_resolver = params.async_dns_resolver.clone();
1322 let info = get_connection_info(node, params)?;
1323 let mut config = AsyncConnectionConfig::default()
1324 .set_connection_timeout(connection_timeout)
1325 .set_tcp_settings(tcp_settings);
1326 if let Some(response_timeout) = response_timeout {
1327 config = config.set_response_timeout(response_timeout);
1328 };
1329 if let Some(push_sender) = push_sender {
1330 config = config.set_push_sender_internal(push_sender);
1331 }
1332 if let Some(resolver) = dns_resolver {
1333 config = config.set_dns_resolver_internal(resolver.clone());
1334 }
1335 let mut conn = match C::connect_with_config(info, config).await {
1336 Ok(conn) => conn,
1337 Err(err) => {
1338 warn!("Failed to connect to node: {:?}, due to: {:?}", node, err);
1339 return Err(err);
1340 }
1341 };
1342
1343 let check = if read_from_replicas {
1344 cmd("READONLY")
1346 } else {
1347 cmd("PING")
1348 };
1349
1350 conn.req_packed_command(&check).await?;
1351 Ok(conn)
1352}
1353
1354async fn check_connection<C>(conn: &mut C) -> RedisResult<()>
1355where
1356 C: ConnectionLike + Send + 'static,
1357{
1358 let mut cmd = Cmd::new();
1359 cmd.arg("PING");
1360 cmd.query_async::<String>(conn).await?;
1361 Ok(())
1362}
1363
1364fn get_random_connection<C>(connections: &ConnectionMap<C>) -> Option<(String, C)>
1365where
1366 C: Clone,
1367{
1368 connections.keys().choose(&mut rng()).and_then(|addr| {
1369 connections
1370 .get(addr)
1371 .map(|conn| (addr.clone(), conn.clone()))
1372 })
1373}