1pub use redis;
52
53use std::{
54 collections::{BTreeMap, HashMap, HashSet},
55 fmt, io,
56 iter::Iterator,
57 marker::Unpin,
58 mem,
59 pin::Pin,
60 sync::Arc,
61 task::{self, Poll},
62 time::Duration,
63};
64
65use crc16::*;
66use futures::{
67 future::{self, BoxFuture},
68 prelude::*,
69 ready, stream,
70};
71use log::trace;
72use pin_project_lite::pin_project;
73use rand::seq::IteratorRandom;
74use rand::thread_rng;
75use redis::{
76 aio::ConnectionLike, Arg, Cmd, ConnectionAddr, ConnectionInfo, ErrorKind, IntoConnectionInfo,
77 RedisError, RedisFuture, RedisResult, Value,
78};
79use tokio::sync::{mpsc, oneshot};
80
/// Number of hash slots: keys are mapped to `crc16 % SLOT_SIZE`, and a slot map
/// is only accepted when it covers exactly this many slots.
const SLOT_SIZE: usize = 16384;
/// Retry limit installed by `Client::open`; override it with `Client::set_retries`.
const DEFAULT_RETRIES: u32 = 16;
83
/// Entry point of the library: stores the seed nodes and the retry setting that
/// every connection created from this client is configured with.
pub struct Client {
    /// Connection infos of the seed nodes (unix sockets are rejected by `open`).
    initial_nodes: Vec<ConnectionInfo>,
    /// Maximum number of retries per request; `None` means no limit is enforced.
    retries: Option<u32>,
}
89
90impl Client {
91 pub fn open<T: IntoConnectionInfo>(initial_nodes: Vec<T>) -> RedisResult<Client> {
98 let mut nodes = Vec::with_capacity(initial_nodes.len());
99
100 for info in initial_nodes {
101 let info = info.into_connection_info()?;
102 if let ConnectionAddr::Unix(_) = info.addr {
103 return Err(RedisError::from((ErrorKind::InvalidClientConfig,
104 "This library cannot use unix socket because Redis's cluster command returns only cluster's IP and port.")));
105 }
106 nodes.push(info);
107 }
108
109 Ok(Client {
110 initial_nodes: nodes,
111 retries: Some(DEFAULT_RETRIES),
112 })
113 }
114
115 pub fn set_retries(&mut self, retries: Option<u32>) -> &mut Self {
118 self.retries = retries;
119 self
120 }
121
122 pub async fn get_connection(&self) -> RedisResult<Connection> {
128 Connection::new(&self.initial_nodes, self.retries).await
129 }
130
131 #[doc(hidden)]
132 pub async fn get_generic_connection<C>(&self) -> RedisResult<Connection<C>>
133 where
134 C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
135 {
136 Connection::new(&self.initial_nodes, self.retries).await
137 }
138}
139
/// Cloneable handle to a cluster connection.
///
/// It only wraps the sending half of a channel; `Connection::new` spawns a task
/// that forwards every received `Message` into a `Pipeline`.
#[derive(Clone)]
pub struct Connection<C = redis::aio::MultiplexedConnection>(mpsc::Sender<Message<C>>);
143
144impl<C> Connection<C>
145where
146 C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
147{
148 async fn new(
149 initial_nodes: &[ConnectionInfo],
150 retries: Option<u32>,
151 ) -> RedisResult<Connection<C>> {
152 Pipeline::new(initial_nodes, retries).await.map(|pipeline| {
153 let (tx, mut rx) = mpsc::channel::<Message<_>>(100);
154
155 tokio::spawn(async move {
156 let _ = stream::poll_fn(move |cx| rx.poll_recv(cx))
157 .map(Ok)
158 .forward(pipeline)
159 .await;
160 });
161
162 Connection(tx)
163 })
164 }
165}
166
/// Maps the last slot of each slot range to the address of the range's master
/// (see `build_slot_map`); lookups use `range(slot..).next()`.
type SlotMap = BTreeMap<u16, String>;
/// A shareable, cloneable future that resolves to a node connection.
type ConnectionFuture<C> = future::Shared<BoxFuture<'static, C>>;
/// Node address (connection string) to its connection future.
type ConnectionMap<C> = HashMap<String, ConnectionFuture<C>>;
170
/// State of the task that routes requests to cluster nodes.
struct Pipeline<C> {
    /// Known node connections, keyed by connection string.
    connections: ConnectionMap<C>,
    /// Current slot-to-master mapping.
    slots: SlotMap,
    /// Whether the pipeline is operating normally or refreshing its slot map.
    state: ConnectionState<C>,
    /// Requests that have been dispatched and are being polled.
    in_flight_requests: stream::FuturesUnordered<
        Pin<Box<Request<BoxFuture<'static, (String, RedisResult<Response>)>, Response, C>>>,
    >,
    /// Error from a failed slot refresh that still has to be delivered to a
    /// request (see `send_refresh_error`).
    refresh_error: Option<RedisError>,
    /// Requests queued by `start_send` (or re-queued after a redirect) that have
    /// not been dispatched yet.
    pending_requests: Vec<PendingRequest<Response, C>>,
    /// Retry limit copied into every dispatched `Request`.
    retries: Option<u32>,
    /// `true` when every initial node used a `TcpTls` address.
    tls: bool,
    /// `true` when every initial node used a `TcpTls` address marked insecure.
    insecure: bool,
}
184
/// What to execute on a node connection: either a single command or a slice of
/// a pipeline, together with the function that performs the call.
#[derive(Clone)]
enum CmdArg<C> {
    /// A single command.
    Cmd {
        cmd: Arc<redis::Cmd>,
        /// Executes `cmd` on the given connection.
        func: fn(C, Arc<redis::Cmd>) -> RedisFuture<'static, Response>,
    },
    /// A pipeline; `offset` and `count` are passed through to `func` unchanged.
    Pipeline {
        pipeline: Arc<redis::Pipeline>,
        offset: usize,
        count: usize,
        /// Executes the pipeline on the given connection.
        func: fn(C, Arc<redis::Pipeline>, usize, usize) -> RedisFuture<'static, Response>,
    },
}
198
199impl<C> CmdArg<C> {
200 fn exec(&self, con: C) -> RedisFuture<'static, Response> {
201 match self {
202 Self::Cmd { cmd, func } => func(con, cmd.clone()),
203 Self::Pipeline {
204 pipeline,
205 offset,
206 count,
207 func,
208 } => func(con, pipeline.clone(), *offset, *count),
209 }
210 }
211
212 fn slot(&self) -> Option<u16> {
213 fn get_cmd_arg(cmd: &Cmd, arg_num: usize) -> Option<&[u8]> {
214 cmd.args_iter().nth(arg_num).and_then(|arg| match arg {
215 redis::Arg::Simple(arg) => Some(arg),
216 redis::Arg::Cursor => None,
217 })
218 }
219
220 fn position(cmd: &Cmd, candidate: &[u8]) -> Option<usize> {
221 cmd.args_iter().position(|arg| match arg {
222 Arg::Simple(arg) => arg.eq_ignore_ascii_case(candidate),
223 _ => false,
224 })
225 }
226
227 fn slot_for_command(cmd: &Cmd) -> Option<u16> {
228 match get_cmd_arg(cmd, 0) {
229 Some(b"EVAL") | Some(b"EVALSHA") => {
230 get_cmd_arg(cmd, 2).and_then(|key_count_bytes| {
231 let key_count_res = std::str::from_utf8(key_count_bytes)
232 .ok()
233 .and_then(|key_count_str| key_count_str.parse::<usize>().ok());
234 key_count_res.and_then(|key_count| {
235 if key_count > 0 {
236 get_cmd_arg(cmd, 3).map(|key| slot_for_key(key))
237 } else {
238 None
240 }
241 })
242 })
243 }
244 Some(b"XGROUP") => get_cmd_arg(cmd, 2).map(|key| slot_for_key(key)),
245 Some(b"XREAD") | Some(b"XREADGROUP") => {
246 let pos = position(cmd, b"STREAMS")?;
247 get_cmd_arg(cmd, pos + 1).map(slot_for_key)
248 }
249 Some(b"SCRIPT") => {
250 None
252 }
253 _ => get_cmd_arg(cmd, 1).map(|key| slot_for_key(key)),
254 }
255 }
256 match self {
257 Self::Cmd { cmd, .. } => slot_for_command(cmd),
258 Self::Pipeline { pipeline, .. } => {
259 let mut iter = pipeline.cmd_iter();
260 let slot = iter.next().map(slot_for_command)?;
261 for cmd in iter {
262 if slot != slot_for_command(cmd) {
263 return None;
264 }
265 }
266 slot
267 }
268 }
269 }
270}
271
/// Result of executing a `CmdArg`.
enum Response {
    /// Reply to a single command.
    Single(Value),
    /// Replies to a pipeline.
    Multiple(Vec<Value>),
}
276
/// A request sent from a `Connection` handle to the pipeline task.
struct Message<C> {
    /// What to execute.
    cmd: CmdArg<C>,
    /// Channel on which the result is reported back to the caller.
    sender: oneshot::Sender<RedisResult<Response>>,
}
281
/// Future of a slot refresh. On success it yields the new slot map and
/// connection map; on failure it yields the error together with the connection
/// map so the connections are not lost.
type RecoverFuture<C> =
    BoxFuture<'static, Result<(SlotMap, ConnectionMap<C>), (RedisError, ConnectionMap<C>)>>;
284
/// Operating mode of the pipeline.
enum ConnectionState<C> {
    /// Normal operation: requests are dispatched and polled.
    PollComplete,
    /// A slot refresh is in progress; the future must complete first.
    Recover(RecoverFuture<C>),
}
289
290impl<C> fmt::Debug for ConnectionState<C> {
291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292 write!(
293 f,
294 "{}",
295 match self {
296 ConnectionState::PollComplete => "PollComplete",
297 ConnectionState::Recover(_) => "Recover",
298 }
299 )
300 }
301}
302
/// Routing information kept with a request across retries.
struct RequestInfo<C> {
    /// What to execute.
    cmd: CmdArg<C>,
    /// Hash slot computed from the command, if one could be determined.
    slot: Option<u16>,
    /// Addresses that already failed for this request and should be avoided
    /// when picking a random connection.
    excludes: HashSet<String>,
}
308
pin_project! {
    // What a `Request` is currently waiting on.
    #[project = RequestStateProj]
    enum RequestState<F> {
        // Not constructed anywhere in the code visible here; polling a request
        // in this state panics.
        None,
        // Waiting for the command to complete on a node.
        Future {
            #[pin]
            future: F,
        },
        // Backing off before the request is retried.
        Sleep {
            #[pin]
            sleep: tokio::time::Sleep,
        },
    }
}
323
/// A request that has not been answered yet.
struct PendingRequest<I, C> {
    /// Number of retries performed so far.
    retry: u32,
    /// Channel on which the final result is delivered.
    sender: oneshot::Sender<RedisResult<I>>,
    /// Command and routing information.
    info: RequestInfo<C>,
}
329
pin_project! {
    // An in-flight request: the pending request plus what it is waiting on.
    struct Request<F, I, C> {
        // Retry limit; `None` means no limit is enforced.
        max_retries: Option<u32>,
        // `None` once the request has been answered or handed back to the pipeline.
        request: Option<PendingRequest<I, C>>,
        #[pin]
        future: RequestState<F>,
    }
}
338
/// Outcome of polling a `Request`, telling the pipeline what to do next.
#[must_use]
enum Next<I, C> {
    /// Dispatch the request again. `error` carries the failure of the previous
    /// attempt, or `None` when the request merely finished a back-off sleep.
    TryNewConnection {
        request: PendingRequest<I, C>,
        error: Option<RedisError>,
    },
    /// The request hit a `MOVED`/`ASK` error; the pipeline re-queues it and
    /// reports `error` to its caller.
    Err {
        request: PendingRequest<I, C>,
        error: RedisError,
    },
    /// Nothing left to do for this request.
    Done,
}
351
impl<F, I, C> Future for Request<F, I, C>
where
    F: Future<Output = (String, RedisResult<I>)>,
    C: ConnectionLike,
{
    type Output = Next<I, C>;

    /// Drives the request one step and reports what the pipeline should do next.
    ///
    /// * success: the result is sent to the caller and `Next::Done` is returned;
    /// * retry limit reached: the error is sent to the caller, `Next::Done`;
    /// * `MOVED` / `ASK`: `Next::Err` hands the request back to the pipeline;
    /// * `TRYAGAIN` / `CLUSTERDOWN`: the request sleeps, then asks for a retry;
    /// * any other error: the failing address is excluded and a retry is requested.
    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        // The request was already answered or handed back.
        if this.request.is_none() {
            return Poll::Ready(Next::Done);
        }
        let future = match this.future.as_mut().project() {
            RequestStateProj::Future { future } => future,
            RequestStateProj::Sleep { sleep } => {
                // Back-off finished: ask the pipeline to dispatch the request again.
                return match ready!(sleep.poll(cx)) {
                    () => Next::TryNewConnection {
                        request: self.project().request.take().unwrap(),
                        error: None,
                    },
                }
                .into();
            }
            _ => panic!("Request future must be Some"),
        };
        match ready!(future.poll(cx)) {
            (_, Ok(item)) => {
                trace!("Ok");
                self.respond(Ok(item));
                Next::Done.into()
            }
            (addr, Err(err)) => {
                trace!("Request error {}", err);

                let request = this.request.as_mut().unwrap();

                // Give up once the configured retry limit has been reached.
                match *this.max_retries {
                    Some(max_retries) if request.retry >= max_retries => {
                        self.respond(Err(err));
                        return Next::Done.into();
                    }
                    _ => (),
                }
                request.retry = request.retry.saturating_add(1);

                if let Some(error_code) = err.code() {
                    if error_code == "MOVED" || error_code == "ASK" {
                        // Redirect: forget excluded nodes and let the pipeline
                        // handle the request.
                        request.info.excludes.clear();
                        return Next::Err {
                            request: this.request.take().unwrap(),
                            error: err,
                        }
                        .into();
                    } else if error_code == "TRYAGAIN" || error_code == "CLUSTERDOWN" {
                        // The exponent is kept within 7..=16, so the delay ranges
                        // from 2^7 * 10 = 1280 ms to 2^16 * 10 = 655360 ms.
                        // NOTE(review): `.max(7)` makes 1280 ms the minimum delay
                        // even on the first retry - confirm this is intended.
                        let sleep_duration =
                            Duration::from_millis(2u64.pow(request.retry.max(7).min(16)) * 10);
                        request.info.excludes.clear();
                        this.future.set(RequestState::Sleep {
                            sleep: tokio::time::sleep(sleep_duration),
                        });
                        // Poll again immediately so the new sleep registers its waker.
                        return self.poll(cx);
                    }
                }

                // Any other error: avoid this node on the next attempt.
                request.info.excludes.insert(addr);

                Next::TryNewConnection {
                    request: this.request.take().unwrap(),
                    error: Some(err),
                }
                .into()
            }
        }
    }
}
429
430impl<F, I, C> Request<F, I, C>
431where
432 F: Future<Output = (String, RedisResult<I>)>,
433 C: ConnectionLike,
434{
435 fn respond(self: Pin<&mut Self>, msg: RedisResult<I>) {
436 let _ = self
438 .project()
439 .request
440 .take()
441 .expect("Result should only be sent once")
442 .sender
443 .send(msg);
444 }
445}
446
447impl<C> Pipeline<C>
448where
449 C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
450{
451 async fn new(initial_nodes: &[ConnectionInfo], retries: Option<u32>) -> RedisResult<Self> {
452 let tls = initial_nodes.iter().all(|c| match c.addr {
453 ConnectionAddr::TcpTls { .. } => true,
454 _ => false,
455 });
456 let insecure = initial_nodes.iter().all(|c| match c.addr {
457 ConnectionAddr::TcpTls { insecure, .. } => insecure,
458 _ => false,
459 });
460 let connections = Self::create_initial_connections(initial_nodes).await?;
461 let mut connection = Pipeline {
462 connections,
463 slots: Default::default(),
464 in_flight_requests: Default::default(),
465 refresh_error: None,
466 pending_requests: Vec::new(),
467 state: ConnectionState::PollComplete,
468 retries,
469 tls,
470 insecure,
471 };
472 let (slots, connections) = connection.refresh_slots().await.map_err(|(err, _)| err)?;
473 connection.slots = slots;
474 connection.connections = connections;
475 Ok(connection)
476 }
477
    /// Connects to all initial nodes concurrently and returns the connections
    /// that could be established, keyed by connection string.
    ///
    /// Individual failures are tolerated as long as at least one node is
    /// reachable; otherwise the last recorded error (or a generic I/O error) is
    /// returned. Panics on an address that is neither `Tcp` nor `TcpTls`.
    async fn create_initial_connections(
        initial_nodes: &[ConnectionInfo],
    ) -> RedisResult<ConnectionMap<C>> {
        // Last connection error seen; only reported when no node is reachable.
        let mut error = None;
        let connections = stream::iter(initial_nodes.iter().cloned())
            .map(|info| async move {
                let addr = match info.addr {
                    ConnectionAddr::Tcp(ref host, port) => build_connection_string(
                        info.redis.username.as_deref(),
                        info.redis.password.as_deref(),
                        host,
                        port as i64,
                        false, // use_tls
                        false, // tls_insecure
                    ),
                    ConnectionAddr::TcpTls {
                        ref host,
                        port,
                        insecure,
                    } => build_connection_string(
                        info.redis.username.as_deref(),
                        info.redis.password.as_deref(),
                        host,
                        port as i64,
                        true, // use_tls
                        insecure,
                    ),
                    _ => panic!("No reach."),
                };

                let result = connect_and_check(info).await;
                match result {
                    // Wrap the ready connection in a shared future, the form
                    // stored in the connection map.
                    Ok(conn) => Ok((addr, async { conn }.boxed().shared())),
                    Err(e) => {
                        trace!("Failed to connect to initial node: {:?}", e);
                        Err(e)
                    }
                }
            })
            // Attempt all nodes at the same time.
            .buffer_unordered(initial_nodes.len())
            .fold(
                HashMap::with_capacity(initial_nodes.len()),
                |mut connections: ConnectionMap<C>, result| {
                    match result {
                        Ok((k, v)) => {
                            connections.insert(k, v);
                        }
                        Err(err) => error = Some(err),
                    }
                    async move { connections }
                },
            )
            .await;
        if connections.len() == 0 {
            if let Some(err) = error {
                return Err(err);
            } else {
                return Err(RedisError::from((
                    ErrorKind::IoError,
                    "Failed to create initial connections",
                )));
            }
        }
        Ok(connections)
    }
543
    /// Builds a future that reloads the slot map and rebuilds the connection map.
    ///
    /// The current connections are moved out of `self` (leaving an empty map)
    /// and into the future. On success the future yields the new slot map and a
    /// connection map holding one entry per master that could be reached; on
    /// failure it yields the error together with the original connections.
    fn refresh_slots(
        &mut self,
    ) -> impl Future<Output = Result<(SlotMap, ConnectionMap<C>), (RedisError, ConnectionMap<C>)>>
    {
        let mut connections = mem::replace(&mut self.connections, Default::default());
        let use_tls = self.tls;
        let tls_insecure = self.insecure;

        async move {
            // Ask the nodes one by one and stop at the first usable slot map;
            // if all fail, the last error is kept.
            let mut result = Ok(SlotMap::new());
            for (addr, conn) in connections.iter_mut() {
                let mut conn = conn.clone().await;
                match get_slots(addr, &mut conn, use_tls, tls_insecure)
                    .await
                    .and_then(|v| Self::build_slot_map(v))
                {
                    Ok(s) => {
                        result = Ok(s);
                        break;
                    }
                    Err(err) => result = Err(err),
                }
            }
            let slots = match result {
                Ok(slots) => slots,
                Err(err) => return Err((err, connections)),
            };

            let new_connections = HashMap::with_capacity(connections.len());

            // For every master in the slot map: reuse the existing connection if
            // it still answers PING, otherwise reconnect. Masters that cannot be
            // reached are left out of the new map.
            let (_, connections) = stream::iter(slots.values())
                .fold(
                    (connections, new_connections),
                    move |(mut connections, mut new_connections), addr| async move {
                        if !new_connections.contains_key(addr) {
                            let new_connection = if let Some(conn) = connections.remove(addr) {
                                let mut conn = conn.await;
                                match check_connection(&mut conn).await {
                                    Ok(_) => Some((addr.to_string(), conn)),
                                    Err(_) => match connect_and_check(addr.as_ref()).await {
                                        Ok(conn) => Some((addr.to_string(), conn)),
                                        Err(_) => None,
                                    },
                                }
                            } else {
                                match connect_and_check(addr.as_ref()).await {
                                    Ok(conn) => Some((addr.to_string(), conn)),
                                    Err(_) => None,
                                }
                            };
                            if let Some((addr, new_connection)) = new_connection {
                                new_connections
                                    .insert(addr, async { new_connection }.boxed().shared());
                            }
                        }
                        (connections, new_connections)
                    },
                )
                .await;
            Ok((slots, connections))
        }
    }
608
609 fn build_slot_map(mut slots_data: Vec<Slot>) -> RedisResult<SlotMap> {
610 slots_data.sort_by_key(|slot_data| slot_data.start);
611 let last_slot = slots_data.iter().try_fold(0, |prev_end, slot_data| {
612 if prev_end != slot_data.start() {
613 return Err(RedisError::from((
614 ErrorKind::ResponseError,
615 "Slot refresh error.",
616 format!(
617 "Received overlapping slots {} and {}..{}",
618 prev_end, slot_data.start, slot_data.end
619 ),
620 )));
621 }
622 Ok(slot_data.end() + 1)
623 })?;
624
625 if usize::from(last_slot) != SLOT_SIZE {
626 return Err(RedisError::from((
627 ErrorKind::ResponseError,
628 "Slot refresh error.",
629 format!("Lacks the slots >= {}", last_slot),
630 )));
631 }
632 let slot_map = slots_data
633 .iter()
634 .map(|slot_data| (slot_data.end(), slot_data.master().to_string()))
635 .collect();
636 trace!("{:?}", slot_map);
637 Ok(slot_map)
638 }
639
640 fn get_connection(&mut self, slot: u16) -> (String, ConnectionFuture<C>) {
641 if let Some((_, addr)) = self.slots.range(&slot..).next() {
642 if let Some(conn) = self.connections.get(addr) {
643 return (addr.clone(), conn.clone());
644 }
645
646 let (_, random_conn) = get_random_connection(&self.connections, None); let connection_future = {
650 let addr = addr.clone();
651 async move {
652 match connect_and_check(addr.as_ref()).await {
653 Ok(conn) => conn,
654 Err(_) => random_conn.await,
655 }
656 }
657 }
658 .boxed()
659 .shared();
660 self.connections
661 .insert(addr.clone(), connection_future.clone());
662 (addr.clone(), connection_future)
663 } else {
664 get_random_connection(&self.connections, None)
666 }
667 }
668
669 fn try_request(
670 &mut self,
671 info: &RequestInfo<C>,
672 ) -> impl Future<Output = (String, RedisResult<Response>)> {
673 let cmd = info.cmd.clone();
675 let (addr, conn) = if info.excludes.len() > 0 || info.slot.is_none() {
676 get_random_connection(&self.connections, Some(&info.excludes))
677 } else {
678 self.get_connection(info.slot.unwrap())
679 };
680 async move {
681 let conn = conn.await;
682 let result = cmd.exec(conn).await;
683 (addr, result)
684 }
685 }
686
687 fn poll_recover(
688 &mut self,
689 cx: &mut task::Context<'_>,
690 mut future: RecoverFuture<C>,
691 ) -> Poll<Result<(), RedisError>> {
692 match future.as_mut().poll(cx) {
693 Poll::Ready(Ok((slots, connections))) => {
694 trace!("Recovered with {} connections!", connections.len());
695 self.slots = slots;
696 self.connections = connections;
697 self.state = ConnectionState::PollComplete;
698 Poll::Ready(Ok(()))
699 }
700 Poll::Pending => {
701 self.state = ConnectionState::Recover(future);
702 trace!("Recover not ready");
703 Poll::Pending
704 }
705 Poll::Ready(Err((err, connections))) => {
706 self.connections = connections;
707 self.state = ConnectionState::Recover(Box::pin(self.refresh_slots()));
708 Poll::Ready(Err(err))
709 }
710 }
711 }
712
    /// Dispatches all queued requests and drives the in-flight ones.
    ///
    /// Returns `Ready(Err(_))` when at least one request reported `Next::Err`
    /// (the request itself is re-queued), `Ready(Ok(()))` when nothing is in
    /// flight any more, and `Pending` otherwise.
    fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), RedisError>> {
        let mut connection_error = None;

        if !self.pending_requests.is_empty() {
            // Take the queue so `self` can be borrowed mutably while iterating.
            let mut pending_requests = mem::take(&mut self.pending_requests);
            for request in pending_requests.drain(..) {
                // Skip requests whose receiving side is already closed.
                if request.sender.is_closed() {
                    continue;
                }

                let future = self.try_request(&request.info);
                self.in_flight_requests.push(Box::pin(Request {
                    max_retries: self.retries,
                    request: Some(request),
                    future: RequestState::Future {
                        future: future.boxed(),
                    },
                }));
            }
            // Put the (now empty) vector back so its allocation is reused.
            self.pending_requests = pending_requests;
        }

        loop {
            let result = match Pin::new(&mut self.in_flight_requests).poll_next(cx) {
                Poll::Ready(Some(result)) => result,
                Poll::Ready(None) | Poll::Pending => break,
            };
            let self_ = &mut *self;
            match result {
                Next::Done => {}
                Next::TryNewConnection { request, error } => {
                    if let Some(error) = error {
                        // At least as many nodes are excluded as there are
                        // connections: give up and report the last error.
                        if request.info.excludes.len() >= self_.connections.len() {
                            let _ = request.sender.send(Err(error));
                            continue;
                        }
                    }
                    let future = self.try_request(&request.info);
                    self.in_flight_requests.push(Box::pin(Request {
                        max_retries: self.retries,
                        request: Some(request),
                        future: RequestState::Future {
                            future: Box::pin(future),
                        },
                    }));
                }
                Next::Err { request, error } => {
                    // Keep the request for later and remember the error so the
                    // caller can react to it.
                    connection_error = Some(error);
                    self.pending_requests.push(request);
                }
            }
        }

        if let Some(err) = connection_error {
            Poll::Ready(Err(err))
        } else if self.in_flight_requests.is_empty() {
            Poll::Ready(Ok(()))
        } else {
            Poll::Pending
        }
    }
777
778 fn send_refresh_error(&mut self) {
779 if self.refresh_error.is_some() {
780 if let Some(mut request) = Pin::new(&mut self.in_flight_requests)
781 .iter_pin_mut()
782 .find(|request| request.request.is_some())
783 {
784 (*request)
785 .as_mut()
786 .respond(Err(self.refresh_error.take().unwrap()));
787 } else if let Some(request) = self.pending_requests.pop() {
788 let _ = request.sender.send(Err(self.refresh_error.take().unwrap()));
789 }
790 }
791 }
792}
793
794impl<C> Sink<Message<C>> for Pipeline<C>
795where
796 C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
797{
798 type Error = ();
799
    /// Reports readiness to accept a new message.
    ///
    /// In normal operation the sink is always ready. During recovery the
    /// refresh future is polled first; if it fails, the error is handed to an
    /// unanswered in-flight request, or stored for later delivery, and the sink
    /// still reports ready.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        // Take the state out temporarily; `poll_recover` stores the follow-up state.
        match mem::replace(&mut self.state, ConnectionState::PollComplete) {
            ConnectionState::PollComplete => Poll::Ready(Ok(())),
            ConnectionState::Recover(future) => {
                match ready!(self.as_mut().poll_recover(cx, future)) {
                    Ok(()) => Poll::Ready(Ok(())),
                    Err(err) => {
                        if let Some(mut request) = Pin::new(&mut self.in_flight_requests)
                            .iter_pin_mut()
                            .find(|request| request.request.is_some())
                        {
                            (*request).as_mut().respond(Err(err));
                        } else {
                            self.refresh_error = Some(err);
                        }
                        Poll::Ready(Ok(()))
                    }
                }
            }
        }
    }
827
828 fn start_send(mut self: Pin<&mut Self>, msg: Message<C>) -> Result<(), Self::Error> {
829 trace!("start_send");
830 let Message { cmd, sender } = msg;
831
832 let excludes = HashSet::new();
833 let slot = cmd.slot();
834
835 let info = RequestInfo {
836 cmd,
837 slot,
838 excludes,
839 };
840
841 self.pending_requests.push(PendingRequest {
842 retry: 0,
843 sender,
844 info,
845 });
846 Ok(()).into()
847 }
848
    /// Drives the pipeline until all requests are complete, switching to
    /// recovery (slot refresh) whenever `poll_complete` reports an error.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        trace!("poll_complete: {:?}", self.state);
        loop {
            // Hand a previously stored refresh error to a waiting request, if any.
            self.send_refresh_error();

            match mem::replace(&mut self.state, ConnectionState::PollComplete) {
                ConnectionState::Recover(future) => {
                    match ready!(self.as_mut().poll_recover(cx, future)) {
                        Ok(()) => (),
                        Err(err) => {
                            // Keep the error for the next iteration, which
                            // delivers it through `send_refresh_error`.
                            self.refresh_error = Some(err);

                            // `poll_recover` already scheduled a new refresh;
                            // wake ourselves so it is polled again.
                            cx.waker().wake_by_ref();
                            return Poll::Pending;
                        }
                    }
                }
                ConnectionState::PollComplete => match ready!(self.poll_complete(cx)) {
                    Ok(()) => return Poll::Ready(Ok(())),
                    Err(err) => {
                        trace!("Recovering {}", err);
                        self.state = ConnectionState::Recover(Box::pin(self.refresh_slots()));
                    }
                },
            }
        }
    }
885
886 fn poll_close(
887 mut self: Pin<&mut Self>,
888 cx: &mut task::Context,
889 ) -> Poll<Result<(), Self::Error>> {
890 match self.poll_complete(cx) {
892 Poll::Ready(result) => {
893 result.map_err(|_| ())?;
894 }
895 Poll::Pending => (),
896 };
897 if self.in_flight_requests.is_empty() {
900 return Poll::Ready(Ok(()));
901 }
902
903 self.poll_flush(cx)
904 }
905}
906
907impl<C> ConnectionLike for Connection<C>
908where
909 C: ConnectionLike + Send + 'static,
910{
911 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
912 trace!("req_packed_command");
913 let (sender, receiver) = oneshot::channel();
914 Box::pin(async move {
915 self.0
916 .send(Message {
917 cmd: CmdArg::Cmd {
918 cmd: Arc::new(cmd.clone()), func: |mut conn, cmd| {
920 Box::pin(async move {
921 conn.req_packed_command(&cmd).await.map(Response::Single)
922 })
923 },
924 },
925 sender,
926 })
927 .await
928 .map_err(|_| {
929 RedisError::from(io::Error::new(
930 io::ErrorKind::BrokenPipe,
931 "redis_cluster: Unable to send command",
932 ))
933 })?;
934 receiver
935 .await
936 .unwrap_or_else(|_| {
937 Err(RedisError::from(io::Error::new(
938 io::ErrorKind::BrokenPipe,
939 "redis_cluster: Unable to receive command",
940 )))
941 })
942 .map(|response| match response {
943 Response::Single(value) => value,
944 Response::Multiple(_) => unreachable!(),
945 })
946 })
947 }
948
949 fn req_packed_commands<'a>(
950 &'a mut self,
951 pipeline: &'a redis::Pipeline,
952 offset: usize,
953 count: usize,
954 ) -> RedisFuture<'a, Vec<Value>> {
955 let (sender, receiver) = oneshot::channel();
956 Box::pin(async move {
957 self.0
958 .send(Message {
959 cmd: CmdArg::Pipeline {
960 pipeline: Arc::new(pipeline.clone()), offset,
962 count,
963 func: |mut conn, pipeline, offset, count| {
964 Box::pin(async move {
965 conn.req_packed_commands(&pipeline, offset, count)
966 .await
967 .map(Response::Multiple)
968 })
969 },
970 },
971 sender,
972 })
973 .await
974 .map_err(|_| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))?;
975
976 receiver
977 .await
978 .unwrap_or_else(|_| {
979 Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))
980 })
981 .map(|response| match response {
982 Response::Multiple(values) => values,
983 Response::Single(_) => unreachable!(),
984 })
985 })
986 }
987
988 fn get_db(&self) -> i64 {
989 0
990 }
991}
992
993impl Clone for Client {
994 fn clone(&self) -> Client {
995 Client::open(self.initial_nodes.clone()).unwrap()
996 }
997}
998
/// Connection types that can be opened from connection information.
pub trait Connect: Sized {
    /// Opens a connection to the node described by `info`.
    fn connect<'a, T>(info: T) -> RedisFuture<'a, Self>
    where
        T: IntoConnectionInfo + Send + 'a;
}
1004
1005impl Connect for redis::aio::MultiplexedConnection {
1006 fn connect<'a, T>(info: T) -> RedisFuture<'a, redis::aio::MultiplexedConnection>
1007 where
1008 T: IntoConnectionInfo + Send + 'a,
1009 {
1010 async move {
1011 let connection_info = info.into_connection_info()?;
1012 let client = redis::Client::open(connection_info)?;
1013 client.get_multiplexed_tokio_connection().await
1014 }
1015 .boxed()
1016 }
1017}
1018
1019async fn connect_and_check<T, C>(info: T) -> RedisResult<C>
1020where
1021 T: IntoConnectionInfo + Send,
1022 C: ConnectionLike + Connect + Send + 'static,
1023{
1024 let mut conn = C::connect(info).await?;
1025 check_connection(&mut conn).await?;
1026 Ok(conn)
1027}
1028
1029async fn check_connection<C>(conn: &mut C) -> RedisResult<()>
1030where
1031 C: ConnectionLike + Send + 'static,
1032{
1033 let mut cmd = Cmd::new();
1034 cmd.arg("PING");
1035 cmd.query_async::<_, String>(conn).await?;
1036 Ok(())
1037}
1038
1039fn get_random_connection<'a, C>(
1040 connections: &'a ConnectionMap<C>,
1041 excludes: Option<&'a HashSet<String>>,
1042) -> (String, ConnectionFuture<C>)
1043where
1044 C: Clone,
1045{
1046 debug_assert!(!connections.is_empty());
1047
1048 let mut rng = thread_rng();
1049 let sample = match excludes {
1050 Some(excludes) if excludes.len() < connections.len() => {
1051 let target_keys = connections.keys().filter(|key| !excludes.contains(*key));
1052 target_keys.choose(&mut rng)
1053 }
1054 _ => connections.keys().choose(&mut rng),
1055 };
1056
1057 let addr = sample.expect("No targets to choose from");
1058 (addr.to_string(), connections.get(addr).unwrap().clone())
1059}
1060
1061fn slot_for_key(key: &[u8]) -> u16 {
1062 let key = sub_key(&key);
1063 State::<XMODEM>::calculate(&key) % SLOT_SIZE as u16
1064}
1065
/// Returns the part of `key` that is hashed.
///
/// This is the text between the first `{` and the first `}` that follows it,
/// provided that text is non-empty. In every other case (no `{`, no matching
/// `}`, or an empty `{}`) the whole key is returned.
fn sub_key(key: &[u8]) -> &[u8] {
    let open = match key.iter().position(|&byte| byte == b'{') {
        Some(index) => index,
        None => return key,
    };

    let after_open = &key[open + 1..];
    match after_open.iter().position(|&byte| byte == b'}') {
        Some(len) if len > 0 => &after_open[..len],
        _ => key,
    }
}
1086
/// One slot range as reported by `CLUSTER SLOTS`.
#[derive(Debug)]
struct Slot {
    /// First slot of the range.
    start: u16,
    /// Last slot of the range (inclusive).
    end: u16,
    /// Connection string of the first node listed for the range.
    master: String,
    /// Connection strings of the remaining nodes listed for the range.
    replicas: Vec<String>,
}

impl Slot {
    /// First slot of the range.
    pub fn start(&self) -> u16 {
        self.start
    }
    /// Last slot of the range (inclusive).
    pub fn end(&self) -> u16 {
        self.end
    }
    /// Connection string of the range's master.
    pub fn master(&self) -> &str {
        &self.master
    }
    /// Connection strings of the range's replicas.
    #[allow(dead_code)]
    pub fn replicas(&self) -> &Vec<String> {
        &self.replicas
    }
}
1110
1111async fn get_slots<C>(
1113 addr: &str,
1114 connection: &mut C,
1115 use_tls: bool,
1116 tls_insecure: bool,
1117) -> RedisResult<Vec<Slot>>
1118where
1119 C: ConnectionLike,
1120{
1121 trace!("get_slots");
1122 let mut cmd = Cmd::new();
1123 cmd.arg("CLUSTER").arg("SLOTS");
1124 let value = connection.req_packed_command(&cmd).await.map_err(|err| {
1125 trace!("get_slots error: {}", err);
1126 err
1127 })?;
1128 trace!("get_slots -> {:#?}", value);
1129 let mut result = Vec::with_capacity(2);
1131
1132 if let Value::Bulk(items) = value {
1133 let username = get_username(addr);
1136 let password = get_password(addr);
1137 let host = get_hostname(addr);
1138
1139 let mut iter = items.into_iter();
1140 while let Some(Value::Bulk(item)) = iter.next() {
1141 if item.len() < 3 {
1142 continue;
1143 }
1144
1145 let start = if let Value::Int(start) = item[0] {
1146 start as u16
1147 } else {
1148 continue;
1149 };
1150
1151 let end = if let Value::Int(end) = item[1] {
1152 end as u16
1153 } else {
1154 continue;
1155 };
1156
1157 let mut nodes: Vec<String> = item
1158 .into_iter()
1159 .skip(2)
1160 .filter_map(|node| {
1161 if let Value::Bulk(node) = node {
1162 if node.len() < 2 {
1163 return None;
1164 }
1165
1166 let ip = if let Value::Data(ref ip) = node[0] {
1167 String::from_utf8_lossy(ip)
1168 } else {
1169 return None;
1170 };
1171
1172 let port = if let Value::Int(port) = node[1] {
1173 port
1174 } else {
1175 return None;
1176 };
1177
1178 let ip = if ip != "" {
1179 &*ip
1180 } else {
1181 &*host.as_ref().unwrap()
1182 };
1183
1184 Some(build_connection_string(
1185 username.as_deref(),
1186 password.as_deref(),
1187 &ip,
1188 port,
1189 use_tls,
1190 tls_insecure,
1191 ))
1192 } else {
1193 None
1194 }
1195 })
1196 .collect();
1197
1198 if nodes.len() < 1 {
1199 continue;
1200 }
1201
1202 let replicas = nodes.split_off(1);
1203 result.push(Slot {
1204 start,
1205 end,
1206 master: nodes.pop().unwrap(),
1207 replicas,
1208 });
1209 }
1210 }
1211
1212 Ok(result)
1213}
1214
/// Formats a Redis connection URL.
///
/// The scheme is `rediss` when `use_tls` is set and `redis` otherwise. The
/// `#insecure` fragment is appended only when both `use_tls` and `tls_insecure`
/// are set. Credentials are inserted verbatim (`user:pw@`, `:pw@` or `user@`)
/// when present.
fn build_connection_string(
    username: Option<&str>,
    password: Option<&str>,
    host: &str,
    port: i64,
    use_tls: bool,
    tls_insecure: bool,
) -> String {
    let scheme = if use_tls { "rediss" } else { "redis" };
    let fragment = if use_tls && tls_insecure {
        "#insecure"
    } else {
        ""
    };
    let credentials = match (username, password) {
        (Some(user), Some(pw)) => format!("{}:{}@", user, pw),
        (None, Some(pw)) => format!(":{}@", pw),
        (Some(user), None) => format!("{}@", user),
        (None, None) => String::new(),
    };

    format!("{}://{}{}:{}{}", scheme, credentials, host, port, fragment)
}
1247
1248fn get_password(addr: &str) -> Option<String> {
1249 redis::parse_redis_url(addr).and_then(|url| url.password().map(|s| s.into()))
1250}
1251
1252fn get_username(addr: &str) -> Option<String> {
1253 redis::parse_redis_url(addr).and_then(|url| {
1254 let username = url.username();
1255 if username != "" {
1256 Some(url.username().to_string())
1257 } else {
1258 None
1259 }
1260 })
1261}
1262
1263fn get_hostname(addr: &str) -> Option<String> {
1264 redis::parse_redis_url(addr).and_then(|url| url.host_str().map(String::from))
1265}
1266
#[cfg(test)]
mod tests {
    use super::*;

    /// Computes the slot of a packed (wire-format) command from its second
    /// argument, applying the same hash-tag and CRC16 rules as `slot_for_key`.
    fn slot_for_packed_command(cmd: &[u8]) -> Option<u16> {
        command_key(cmd).map(|key| {
            let key = sub_key(&key);
            State::<XMODEM>::calculate(&key) % SLOT_SIZE as u16
        })
    }

    /// Parses a packed command and returns its second element (the key) when
    /// it is a data value.
    fn command_key(cmd: &[u8]) -> Option<Vec<u8>> {
        redis::parse_redis_value(cmd)
            .ok()
            .and_then(|value| match value {
                Value::Bulk(mut args) => {
                    if args.len() >= 2 {
                        match args.swap_remove(1) {
                            Value::Data(key) => Some(key),
                            _ => None,
                        }
                    } else {
                        None
                    }
                }
                _ => None,
            })
    }

    /// Checks slot computation against known values for binary keys.
    #[test]
    fn slot() {
        // `EXISTS <16-byte binary key>`
        assert_eq!(
            slot_for_packed_command(&[
                42, 50, 13, 10, 36, 54, 13, 10, 69, 88, 73, 83, 84, 83, 13, 10, 36, 49, 54, 13, 10,
                244, 93, 23, 40, 126, 127, 253, 33, 89, 47, 185, 204, 171, 249, 96, 139, 13, 10
            ]),
            Some(964)
        );
        // `SET <16-byte binary key> true NX PX 1800000`
        assert_eq!(
            slot_for_packed_command(&[
                42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 36, 241,
                197, 111, 180, 254, 5, 175, 143, 146, 171, 39, 172, 23, 164, 145, 13, 10, 36, 52,
                13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10,
                80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10
            ]),
            Some(8352)
        );

        // Same command, but the binary key contains `{` (123) ... `}` (125),
        // so only the bytes between them are hashed.
        assert_eq!(
            slot_for_packed_command(&[
                42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 169, 233,
                247, 59, 50, 247, 100, 232, 123, 140, 2, 101, 125, 221, 66, 170, 13, 10, 36, 52,
                13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10,
                80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10
            ]),
            Some(5210),
        );
    }

    /// Checks username/password extraction for URLs with and without credentials.
    #[test]
    fn test_get_username_password() {
        // (url, expected username, expected password)
        let testcases: Vec<(&str, Option<String>, Option<String>)> = vec![
            ("redis://127.0.0.1:7000", None, None),
            (
                "redis://:password@127.0.0.1:7000",
                None,
                Some("password".to_string()),
            ),
            (
                "redis://username:password@127.0.0.1:7000",
                Some("username".to_string()),
                Some("password".to_string()),
            ),
            (
                "redis://username:@127.0.0.1:7000",
                Some("username".to_string()),
                None,
            ),
            (
                "redis://username@127.0.0.1:7000",
                Some("username".to_string()),
                None,
            ),
        ];

        for (redis_url, username, password) in testcases {
            assert_eq!(username, get_username(redis_url));
            assert_eq!(password, get_password(redis_url));
        }
    }
}