redis_cluster_async/
lib.rs

//! This library extends the redis-rs library so that Redis Cluster can be used asynchronously.
//! The [`Connection`] returned by [`Client::get_connection`] is used with redis-rs's async
//! query methods (such as `query_async`, as in the examples below), so you can keep using
//! redis-rs's access methods.
4//! If you want more information, read [the redis-rs documentation].
5//!
6//! Note that this library currently does not have Pubsub features.
7//!
8//! [the redis-rs documentation]: https://docs.rs/redis
9//!
10//! # Example
11//! ```rust
12//! use redis_cluster_async::{Client, redis::{Commands, cmd}};
13//!
14//! #[tokio::main]
15//! async fn main() -> redis::RedisResult<()> {
16//! #   let _ = env_logger::try_init();
17//!     let nodes = vec!["redis://127.0.0.1:7000/", "redis://127.0.0.1:7001/", "redis://127.0.0.1:7002/"];
18//!
19//!     let client = Client::open(nodes)?;
20//!     let mut connection = client.get_connection().await?;
21//!     cmd("SET").arg("test").arg("test_data").query_async(&mut connection).await?;
22//!     let res: String = cmd("GET").arg("test").query_async(&mut connection).await?;
23//!     assert_eq!(res, "test_data");
24//!     Ok(())
25//! }
26//! ```
27//!
28//! ## Pipelining
29//! ```rust
30//! use redis_cluster_async::{Client, redis::pipe};
31//!
32//! #[tokio::main]
33//! async fn main() -> redis::RedisResult<()> {
34//! #   let _ = env_logger::try_init();
35//!     let nodes = vec!["redis://127.0.0.1:7000/", "redis://127.0.0.1:7001/", "redis://127.0.0.1:7002/"];
36//!
37//!     let client = Client::open(nodes)?;
38//!     let mut connection = client.get_connection().await?;
39//!     let key = "test2";
40//!
41//!     let mut pipe = pipe();
42//!     pipe.rpush(key, "123").ignore()
43//!         .ltrim(key, -10, -1).ignore()
44//!         .expire(key, 60).ignore();
45//!     pipe.query_async(&mut connection)
46//!         .await?;
47//!     Ok(())
48//! }
49//! ```
50
51pub use redis;
52
53use std::{
54    collections::{BTreeMap, HashMap, HashSet},
55    fmt, io,
56    iter::Iterator,
57    marker::Unpin,
58    mem,
59    pin::Pin,
60    sync::Arc,
61    task::{self, Poll},
62    time::Duration,
63};
64
65use crc16::*;
66use futures::{
67    future::{self, BoxFuture},
68    prelude::*,
69    ready, stream,
70};
71use log::trace;
72use pin_project_lite::pin_project;
73use rand::seq::IteratorRandom;
74use rand::thread_rng;
75use redis::{
76    aio::ConnectionLike, Arg, Cmd, ConnectionAddr, ConnectionInfo, ErrorKind, IntoConnectionInfo,
77    RedisError, RedisFuture, RedisResult, Value,
78};
79use tokio::sync::{mpsc, oneshot};
80
/// Number of hash slots in a Redis Cluster; slots are numbered `0..SLOT_SIZE` (2^14 of them).
const SLOT_SIZE: usize = 1 << 14;
/// How many times a request is retried unless overridden with `Client::set_retries`.
const DEFAULT_RETRIES: u32 = 16;
83
84/// This is a Redis cluster client.
85pub struct Client {
86    initial_nodes: Vec<ConnectionInfo>,
87    retries: Option<u32>,
88}
89
90impl Client {
91    /// Connect to a redis cluster server and return a cluster client.
92    /// This does not actually open a connection yet but it performs some basic checks on the URL.
93    ///
94    /// # Errors
95    ///
96    /// If it is failed to parse initial_nodes, an error is returned.
97    pub fn open<T: IntoConnectionInfo>(initial_nodes: Vec<T>) -> RedisResult<Client> {
98        let mut nodes = Vec::with_capacity(initial_nodes.len());
99
100        for info in initial_nodes {
101            let info = info.into_connection_info()?;
102            if let ConnectionAddr::Unix(_) = info.addr {
103                return Err(RedisError::from((ErrorKind::InvalidClientConfig,
104                                             "This library cannot use unix socket because Redis's cluster command returns only cluster's IP and port.")));
105            }
106            nodes.push(info);
107        }
108
109        Ok(Client {
110            initial_nodes: nodes,
111            retries: Some(DEFAULT_RETRIES),
112        })
113    }
114
115    /// Set how many times we should retry a query. Set `None` to retry forever.
116    /// Default: 16
117    pub fn set_retries(&mut self, retries: Option<u32>) -> &mut Self {
118        self.retries = retries;
119        self
120    }
121
122    /// Open and get a Redis cluster connection.
123    ///
124    /// # Errors
125    ///
126    /// If it is failed to open connections and to create slots, an error is returned.
127    pub async fn get_connection(&self) -> RedisResult<Connection> {
128        Connection::new(&self.initial_nodes, self.retries).await
129    }
130
131    #[doc(hidden)]
132    pub async fn get_generic_connection<C>(&self) -> RedisResult<Connection<C>>
133    where
134        C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
135    {
136        Connection::new(&self.initial_nodes, self.retries).await
137    }
138}
139
140/// This is a connection of Redis cluster.
141#[derive(Clone)]
142pub struct Connection<C = redis::aio::MultiplexedConnection>(mpsc::Sender<Message<C>>);
143
144impl<C> Connection<C>
145where
146    C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
147{
148    async fn new(
149        initial_nodes: &[ConnectionInfo],
150        retries: Option<u32>,
151    ) -> RedisResult<Connection<C>> {
152        Pipeline::new(initial_nodes, retries).await.map(|pipeline| {
153            let (tx, mut rx) = mpsc::channel::<Message<_>>(100);
154
155            tokio::spawn(async move {
156                let _ = stream::poll_fn(move |cx| rx.poll_recv(cx))
157                    .map(Ok)
158                    .forward(pipeline)
159                    .await;
160            });
161
162            Connection(tx)
163        })
164    }
165}
166
167type SlotMap = BTreeMap<u16, String>;
168type ConnectionFuture<C> = future::Shared<BoxFuture<'static, C>>;
169type ConnectionMap<C> = HashMap<String, ConnectionFuture<C>>;
170
/// State of the background task that routes requests to cluster nodes.
struct Pipeline<C> {
    /// Node address -> connection to that node.
    connections: ConnectionMap<C>,
    /// Slot ranges -> master address, used to pick the node for a keyed request.
    slots: SlotMap,
    /// Whether a slot refresh is currently running.
    state: ConnectionState<C>,
    /// Requests that have been dispatched to a node and are awaiting a reply (or a back-off).
    in_flight_requests: stream::FuturesUnordered<
        Pin<Box<Request<BoxFuture<'static, (String, RedisResult<Response>)>, Response, C>>>,
    >,
    /// Error from a failed slot refresh, to be delivered to one waiting caller.
    refresh_error: Option<RedisError>,
    /// Requests queued (new, or re-queued after a redirection) but not yet dispatched.
    pending_requests: Vec<PendingRequest<Response, C>>,
    /// Maximum retries per request; `None` retries forever.
    retries: Option<u32>,
    /// `true` only if every initial node uses TLS; forwarded to `get_slots`.
    tls: bool,
    /// `true` only if every initial node uses insecure TLS; forwarded to `get_slots`.
    insecure: bool,
}
184
185#[derive(Clone)]
186enum CmdArg<C> {
187    Cmd {
188        cmd: Arc<redis::Cmd>,
189        func: fn(C, Arc<redis::Cmd>) -> RedisFuture<'static, Response>,
190    },
191    Pipeline {
192        pipeline: Arc<redis::Pipeline>,
193        offset: usize,
194        count: usize,
195        func: fn(C, Arc<redis::Pipeline>, usize, usize) -> RedisFuture<'static, Response>,
196    },
197}
198
199impl<C> CmdArg<C> {
200    fn exec(&self, con: C) -> RedisFuture<'static, Response> {
201        match self {
202            Self::Cmd { cmd, func } => func(con, cmd.clone()),
203            Self::Pipeline {
204                pipeline,
205                offset,
206                count,
207                func,
208            } => func(con, pipeline.clone(), *offset, *count),
209        }
210    }
211
212    fn slot(&self) -> Option<u16> {
213        fn get_cmd_arg(cmd: &Cmd, arg_num: usize) -> Option<&[u8]> {
214            cmd.args_iter().nth(arg_num).and_then(|arg| match arg {
215                redis::Arg::Simple(arg) => Some(arg),
216                redis::Arg::Cursor => None,
217            })
218        }
219
220        fn position(cmd: &Cmd, candidate: &[u8]) -> Option<usize> {
221            cmd.args_iter().position(|arg| match arg {
222                Arg::Simple(arg) => arg.eq_ignore_ascii_case(candidate),
223                _ => false,
224            })
225        }
226
227        fn slot_for_command(cmd: &Cmd) -> Option<u16> {
228            match get_cmd_arg(cmd, 0) {
229                Some(b"EVAL") | Some(b"EVALSHA") => {
230                    get_cmd_arg(cmd, 2).and_then(|key_count_bytes| {
231                        let key_count_res = std::str::from_utf8(key_count_bytes)
232                            .ok()
233                            .and_then(|key_count_str| key_count_str.parse::<usize>().ok());
234                        key_count_res.and_then(|key_count| {
235                            if key_count > 0 {
236                                get_cmd_arg(cmd, 3).map(|key| slot_for_key(key))
237                            } else {
238                                // TODO need to handle sending to all masters
239                                None
240                            }
241                        })
242                    })
243                }
244                Some(b"XGROUP") => get_cmd_arg(cmd, 2).map(|key| slot_for_key(key)),
245                Some(b"XREAD") | Some(b"XREADGROUP") => {
246                    let pos = position(cmd, b"STREAMS")?;
247                    get_cmd_arg(cmd, pos + 1).map(slot_for_key)
248                }
249                Some(b"SCRIPT") => {
250                    // TODO need to handle sending to all masters
251                    None
252                }
253                _ => get_cmd_arg(cmd, 1).map(|key| slot_for_key(key)),
254            }
255        }
256        match self {
257            Self::Cmd { cmd, .. } => slot_for_command(cmd),
258            Self::Pipeline { pipeline, .. } => {
259                let mut iter = pipeline.cmd_iter();
260                let slot = iter.next().map(slot_for_command)?;
261                for cmd in iter {
262                    if slot != slot_for_command(cmd) {
263                        return None;
264                    }
265                }
266                slot
267            }
268        }
269    }
270}
271
/// A reply produced by a `CmdArg` executor: either one `Value` or several.
enum Response {
    Single(Value),
    Multiple(Vec<Value>),
}

/// A request sent from a `Connection` handle to the background pipeline task, together with
/// the one-shot channel on which its result is delivered back to the caller.
struct Message<C> {
    cmd: CmdArg<C>,
    sender: oneshot::Sender<RedisResult<Response>>,
}
281
282type RecoverFuture<C> =
283    BoxFuture<'static, Result<(SlotMap, ConnectionMap<C>), (RedisError, ConnectionMap<C>)>>;
284
285enum ConnectionState<C> {
286    PollComplete,
287    Recover(RecoverFuture<C>),
288}
289
290impl<C> fmt::Debug for ConnectionState<C> {
291    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292        write!(
293            f,
294            "{}",
295            match self {
296                ConnectionState::PollComplete => "PollComplete",
297                ConnectionState::Recover(_) => "Recover",
298            }
299        )
300    }
301}
302
/// Everything needed to (re)dispatch a request to some node.
struct RequestInfo<C> {
    /// The command to run.
    cmd: CmdArg<C>,
    /// Target hash slot (computed from `cmd` when the request is queued); `None` makes
    /// `try_request` pick a random connection.
    slot: Option<u16>,
    /// Addresses on which this request has failed, passed as the exclusion set to
    /// `get_random_connection`; cleared on `MOVED`/`ASK` and before a back-off sleep.
    excludes: HashSet<String>,
}

// State of one `Request`: the attempt currently running (`Future`), a back-off timer to wait
// out before the next attempt (`Sleep`), or nothing (`None`).
pin_project! {
    #[project = RequestStateProj]
    enum RequestState<F> {
        None,
        Future {
            #[pin]
            future: F,
        },
        Sleep {
            #[pin]
            sleep: tokio::time::Sleep,
        },
    }
}

/// A request waiting to be (re)dispatched.
struct PendingRequest<I, C> {
    /// Number of failed attempts so far; compared against `Request::max_retries`.
    retry: u32,
    /// Channel on which the final result is delivered to the caller.
    sender: oneshot::Sender<RedisResult<I>>,
    /// Routing information and the command itself.
    info: RequestInfo<C>,
}

// A request being driven by the pipeline. `request` stays `Some` until the caller has been
// answered or the request is handed back to the pipeline through `Next`; afterwards polling
// yields `Next::Done`.
pin_project! {
    struct Request<F, I, C> {
        max_retries: Option<u32>,
        request: Option<PendingRequest<I, C>>,
        #[pin]
        future: RequestState<F>,
    }
}
338
/// Outcome of polling a `Request`: tells the pipeline what to do with it next.
#[must_use]
enum Next<I, C> {
    /// Dispatch the request again: after a failed attempt (`error` is `Some`) or once a
    /// back-off sleep has elapsed (`error` is `None`).
    TryNewConnection {
        request: PendingRequest<I, C>,
        error: Option<RedisError>,
    },
    /// The request was redirected (`MOVED`/`ASK`): refresh the slot map, then re-queue it.
    Err {
        request: PendingRequest<I, C>,
        error: RedisError,
    },
    /// Nothing left to do; the caller has already been answered.
    Done,
}
351
/// Drives one request attempt and decides what the pipeline should do next.
///
/// Resolves to:
/// * `Next::Done` once the caller has been answered (success, or retries exhausted), or if
///   the request had already been answered before this poll;
/// * `Next::Err` on a `MOVED`/`ASK` reply, so the pipeline refreshes its slot map;
/// * `Next::TryNewConnection` when the request should be sent again: after any other error
///   (the failing address is added to `excludes`), or after a `TRYAGAIN`/`CLUSTERDOWN`
///   back-off sleep has elapsed (with `error: None`).
impl<F, I, C> Future for Request<F, I, C>
where
    F: Future<Output = (String, RedisResult<I>)>,
    C: ConnectionLike,
{
    type Output = Next<I, C>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Self::Output> {
        let mut this = self.as_mut().project();
        // Already answered (e.g. via `respond` from the pipeline): nothing left to do.
        if this.request.is_none() {
            return Poll::Ready(Next::Done);
        }
        let future = match this.future.as_mut().project() {
            RequestStateProj::Future { future } => future,
            RequestStateProj::Sleep { sleep } => {
                // Back-off finished: hand the request back so it is dispatched again.
                return match ready!(sleep.poll(cx)) {
                    () => Next::TryNewConnection {
                        request: self.project().request.take().unwrap(),
                        error: None,
                    },
                }
                .into();
            }
            // `RequestState::None` is never expected while a request is still pending.
            _ => panic!("Request future must be Some"),
        };
        match ready!(future.poll(cx)) {
            (_, Ok(item)) => {
                trace!("Ok");
                self.respond(Ok(item));
                Next::Done.into()
            }
            (addr, Err(err)) => {
                trace!("Request error {}", err);

                let request = this.request.as_mut().unwrap();

                // Give up once the retry budget is spent (`None` means retry forever).
                match *this.max_retries {
                    Some(max_retries) if request.retry >= max_retries => {
                        self.respond(Err(err));
                        return Next::Done.into();
                    }
                    _ => (),
                }
                request.retry = request.retry.saturating_add(1);

                if let Some(error_code) = err.code() {
                    if error_code == "MOVED" || error_code == "ASK" {
                        // Refresh slots and request again.
                        // With `excludes` empty, the next attempt goes to the slot owner
                        // according to the refreshed map.
                        request.info.excludes.clear();
                        return Next::Err {
                            request: this.request.take().unwrap(),
                            error: err,
                        }
                        .into();
                    } else if error_code == "TRYAGAIN" || error_code == "CLUSTERDOWN" {
                        // Sleep and retry.
                        // Wait 2^clamp(retry, 7, 16) * 10 ms, i.e. between 1.28 s and ~655 s.
                        // NOTE(review): the `max(7)` floor makes even the first retry wait
                        // 1.28 s — confirm this is intended.
                        let sleep_duration =
                            Duration::from_millis(2u64.pow(request.retry.max(7).min(16)) * 10);
                        request.info.excludes.clear();
                        this.future.set(RequestState::Sleep {
                            sleep: tokio::time::sleep(sleep_duration),
                        });
                        // Re-poll right away so the new sleep registers this task's waker.
                        return self.poll(cx);
                    }
                }

                // Any other error: avoid this node on the next attempt.
                request.info.excludes.insert(addr);

                Next::TryNewConnection {
                    request: this.request.take().unwrap(),
                    error: Some(err),
                }
                .into()
            }
        }
    }
}
429
430impl<F, I, C> Request<F, I, C>
431where
432    F: Future<Output = (String, RedisResult<I>)>,
433    C: ConnectionLike,
434{
435    fn respond(self: Pin<&mut Self>, msg: RedisResult<I>) {
436        // If `send` errors the receiver has dropped and thus does not care about the message
437        let _ = self
438            .project()
439            .request
440            .take()
441            .expect("Result should only be sent once")
442            .sender
443            .send(msg);
444    }
445}
446
447impl<C> Pipeline<C>
448where
449    C: ConnectionLike + Connect + Clone + Send + Sync + 'static,
450{
451    async fn new(initial_nodes: &[ConnectionInfo], retries: Option<u32>) -> RedisResult<Self> {
452        let tls = initial_nodes.iter().all(|c| match c.addr {
453            ConnectionAddr::TcpTls { .. } => true,
454            _ => false,
455        });
456        let insecure = initial_nodes.iter().all(|c| match c.addr {
457            ConnectionAddr::TcpTls { insecure, .. } => insecure,
458            _ => false,
459        });
460        let connections = Self::create_initial_connections(initial_nodes).await?;
461        let mut connection = Pipeline {
462            connections,
463            slots: Default::default(),
464            in_flight_requests: Default::default(),
465            refresh_error: None,
466            pending_requests: Vec::new(),
467            state: ConnectionState::PollComplete,
468            retries,
469            tls,
470            insecure,
471        };
472        let (slots, connections) = connection.refresh_slots().await.map_err(|(err, _)| err)?;
473        connection.slots = slots;
474        connection.connections = connections;
475        Ok(connection)
476    }
477
478    async fn create_initial_connections(
479        initial_nodes: &[ConnectionInfo],
480    ) -> RedisResult<ConnectionMap<C>> {
481        let mut error = None;
482        let connections = stream::iter(initial_nodes.iter().cloned())
483            .map(|info| async move {
484                let addr = match info.addr {
485                    ConnectionAddr::Tcp(ref host, port) => build_connection_string(
486                        info.redis.username.as_deref(),
487                        info.redis.password.as_deref(),
488                        host,
489                        port as i64,
490                        false, // use_tls
491                        false, // tls_insecure
492                    ),
493                    ConnectionAddr::TcpTls {
494                        ref host,
495                        port,
496                        insecure,
497                    } => build_connection_string(
498                        info.redis.username.as_deref(),
499                        info.redis.password.as_deref(),
500                        host,
501                        port as i64,
502                        true,     // use_tls
503                        insecure, // tls_insecure
504                    ),
505                    _ => panic!("No reach."),
506                };
507
508                let result = connect_and_check(info).await;
509                match result {
510                    Ok(conn) => Ok((addr, async { conn }.boxed().shared())),
511                    Err(e) => {
512                        trace!("Failed to connect to initial node: {:?}", e);
513                        Err(e)
514                    }
515                }
516            })
517            .buffer_unordered(initial_nodes.len())
518            .fold(
519                HashMap::with_capacity(initial_nodes.len()),
520                |mut connections: ConnectionMap<C>, result| {
521                    match result {
522                        Ok((k, v)) => {
523                            connections.insert(k, v);
524                        }
525                        Err(err) => error = Some(err),
526                    }
527                    async move { connections }
528                },
529            )
530            .await;
531        if connections.len() == 0 {
532            if let Some(err) = error {
533                return Err(err);
534            } else {
535                return Err(RedisError::from((
536                    ErrorKind::IoError,
537                    "Failed to create initial connections",
538                )));
539            }
540        }
541        Ok(connections)
542    }
543
544    // Query a node to discover slot-> master mappings.
545    fn refresh_slots(
546        &mut self,
547    ) -> impl Future<Output = Result<(SlotMap, ConnectionMap<C>), (RedisError, ConnectionMap<C>)>>
548    {
549        let mut connections = mem::replace(&mut self.connections, Default::default());
550        let use_tls = self.tls;
551        let tls_insecure = self.insecure;
552
553        async move {
554            let mut result = Ok(SlotMap::new());
555            for (addr, conn) in connections.iter_mut() {
556                let mut conn = conn.clone().await;
557                match get_slots(addr, &mut conn, use_tls, tls_insecure)
558                    .await
559                    .and_then(|v| Self::build_slot_map(v))
560                {
561                    Ok(s) => {
562                        result = Ok(s);
563                        break;
564                    }
565                    Err(err) => result = Err(err),
566                }
567            }
568            let slots = match result {
569                Ok(slots) => slots,
570                Err(err) => return Err((err, connections)),
571            };
572
573            // Remove dead connections and connect to new nodes if necessary
574            let new_connections = HashMap::with_capacity(connections.len());
575
576            let (_, connections) = stream::iter(slots.values())
577                .fold(
578                    (connections, new_connections),
579                    move |(mut connections, mut new_connections), addr| async move {
580                        if !new_connections.contains_key(addr) {
581                            let new_connection = if let Some(conn) = connections.remove(addr) {
582                                let mut conn = conn.await;
583                                match check_connection(&mut conn).await {
584                                    Ok(_) => Some((addr.to_string(), conn)),
585                                    Err(_) => match connect_and_check(addr.as_ref()).await {
586                                        Ok(conn) => Some((addr.to_string(), conn)),
587                                        Err(_) => None,
588                                    },
589                                }
590                            } else {
591                                match connect_and_check(addr.as_ref()).await {
592                                    Ok(conn) => Some((addr.to_string(), conn)),
593                                    Err(_) => None,
594                                }
595                            };
596                            if let Some((addr, new_connection)) = new_connection {
597                                new_connections
598                                    .insert(addr, async { new_connection }.boxed().shared());
599                            }
600                        }
601                        (connections, new_connections)
602                    },
603                )
604                .await;
605            Ok((slots, connections))
606        }
607    }
608
609    fn build_slot_map(mut slots_data: Vec<Slot>) -> RedisResult<SlotMap> {
610        slots_data.sort_by_key(|slot_data| slot_data.start);
611        let last_slot = slots_data.iter().try_fold(0, |prev_end, slot_data| {
612            if prev_end != slot_data.start() {
613                return Err(RedisError::from((
614                    ErrorKind::ResponseError,
615                    "Slot refresh error.",
616                    format!(
617                        "Received overlapping slots {} and {}..{}",
618                        prev_end, slot_data.start, slot_data.end
619                    ),
620                )));
621            }
622            Ok(slot_data.end() + 1)
623        })?;
624
625        if usize::from(last_slot) != SLOT_SIZE {
626            return Err(RedisError::from((
627                ErrorKind::ResponseError,
628                "Slot refresh error.",
629                format!("Lacks the slots >= {}", last_slot),
630            )));
631        }
632        let slot_map = slots_data
633            .iter()
634            .map(|slot_data| (slot_data.end(), slot_data.master().to_string()))
635            .collect();
636        trace!("{:?}", slot_map);
637        Ok(slot_map)
638    }
639
640    fn get_connection(&mut self, slot: u16) -> (String, ConnectionFuture<C>) {
641        if let Some((_, addr)) = self.slots.range(&slot..).next() {
642            if let Some(conn) = self.connections.get(addr) {
643                return (addr.clone(), conn.clone());
644            }
645
646            // Create new connection.
647            //
648            let (_, random_conn) = get_random_connection(&self.connections, None); // TODO Only do this lookup if the first check fails
649            let connection_future = {
650                let addr = addr.clone();
651                async move {
652                    match connect_and_check(addr.as_ref()).await {
653                        Ok(conn) => conn,
654                        Err(_) => random_conn.await,
655                    }
656                }
657            }
658            .boxed()
659            .shared();
660            self.connections
661                .insert(addr.clone(), connection_future.clone());
662            (addr.clone(), connection_future)
663        } else {
664            // Return a random connection
665            get_random_connection(&self.connections, None)
666        }
667    }
668
669    fn try_request(
670        &mut self,
671        info: &RequestInfo<C>,
672    ) -> impl Future<Output = (String, RedisResult<Response>)> {
673        // TODO remove clone by changing the ConnectionLike trait
674        let cmd = info.cmd.clone();
675        let (addr, conn) = if info.excludes.len() > 0 || info.slot.is_none() {
676            get_random_connection(&self.connections, Some(&info.excludes))
677        } else {
678            self.get_connection(info.slot.unwrap())
679        };
680        async move {
681            let conn = conn.await;
682            let result = cmd.exec(conn).await;
683            (addr, result)
684        }
685    }
686
687    fn poll_recover(
688        &mut self,
689        cx: &mut task::Context<'_>,
690        mut future: RecoverFuture<C>,
691    ) -> Poll<Result<(), RedisError>> {
692        match future.as_mut().poll(cx) {
693            Poll::Ready(Ok((slots, connections))) => {
694                trace!("Recovered with {} connections!", connections.len());
695                self.slots = slots;
696                self.connections = connections;
697                self.state = ConnectionState::PollComplete;
698                Poll::Ready(Ok(()))
699            }
700            Poll::Pending => {
701                self.state = ConnectionState::Recover(future);
702                trace!("Recover not ready");
703                Poll::Pending
704            }
705            Poll::Ready(Err((err, connections))) => {
706                self.connections = connections;
707                self.state = ConnectionState::Recover(Box::pin(self.refresh_slots()));
708                Poll::Ready(Err(err))
709            }
710        }
711    }
712
713    fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll<Result<(), RedisError>> {
714        let mut connection_error = None;
715
716        if !self.pending_requests.is_empty() {
717            let mut pending_requests = mem::take(&mut self.pending_requests);
718            for request in pending_requests.drain(..) {
719                // Drop the request if noone is waiting for a response to free up resources for
720                // requests callers care about (load shedding). It will be ambigous whether the
721                // request actually goes through regardless.
722                if request.sender.is_closed() {
723                    continue;
724                }
725
726                let future = self.try_request(&request.info);
727                self.in_flight_requests.push(Box::pin(Request {
728                    max_retries: self.retries,
729                    request: Some(request),
730                    future: RequestState::Future {
731                        future: future.boxed(),
732                    },
733                }));
734            }
735            self.pending_requests = pending_requests;
736        }
737
738        loop {
739            let result = match Pin::new(&mut self.in_flight_requests).poll_next(cx) {
740                Poll::Ready(Some(result)) => result,
741                Poll::Ready(None) | Poll::Pending => break,
742            };
743            let self_ = &mut *self;
744            match result {
745                Next::Done => {}
746                Next::TryNewConnection { request, error } => {
747                    if let Some(error) = error {
748                        if request.info.excludes.len() >= self_.connections.len() {
749                            let _ = request.sender.send(Err(error));
750                            continue;
751                        }
752                    }
753                    let future = self.try_request(&request.info);
754                    self.in_flight_requests.push(Box::pin(Request {
755                        max_retries: self.retries,
756                        request: Some(request),
757                        future: RequestState::Future {
758                            future: Box::pin(future),
759                        },
760                    }));
761                }
762                Next::Err { request, error } => {
763                    connection_error = Some(error);
764                    self.pending_requests.push(request);
765                }
766            }
767        }
768
769        if let Some(err) = connection_error {
770            Poll::Ready(Err(err))
771        } else if self.in_flight_requests.is_empty() {
772            Poll::Ready(Ok(()))
773        } else {
774            Poll::Pending
775        }
776    }
777
778    fn send_refresh_error(&mut self) {
779        if self.refresh_error.is_some() {
780            if let Some(mut request) = Pin::new(&mut self.in_flight_requests)
781                .iter_pin_mut()
782                .find(|request| request.request.is_some())
783            {
784                (*request)
785                    .as_mut()
786                    .respond(Err(self.refresh_error.take().unwrap()));
787            } else if let Some(request) = self.pending_requests.pop() {
788                let _ = request.sender.send(Err(self.refresh_error.take().unwrap()));
789            }
790        }
791    }
792}
793
794impl<C> Sink<Message<C>> for Pipeline<C>
795where
796    C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static,
797{
798    type Error = ();
799
    /// Reports readiness to accept another request.
    ///
    /// If a slot refresh is in progress it is driven here, and the sink only becomes ready
    /// once that refresh finishes, successfully or not. A failed refresh is reported to one
    /// waiting caller: the first unanswered in-flight request, or else it is stashed in
    /// `refresh_error`. `poll_recover` has already scheduled the next refresh attempt.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        // Take the state out; `poll_recover` writes it back as appropriate.
        match mem::replace(&mut self.state, ConnectionState::PollComplete) {
            ConnectionState::PollComplete => Poll::Ready(Ok(())),
            ConnectionState::Recover(future) => {
                match ready!(self.as_mut().poll_recover(cx, future)) {
                    Ok(()) => Poll::Ready(Ok(())),
                    Err(err) => {
                        // We failed to reconnect, while we will try again we will report the
                        // error if we can to avoid getting trapped in an infinite loop of
                        // trying to reconnect
                        if let Some(mut request) = Pin::new(&mut self.in_flight_requests)
                            .iter_pin_mut()
                            .find(|request| request.request.is_some())
                        {
                            (*request).as_mut().respond(Err(err));
                        } else {
                            self.refresh_error = Some(err);
                        }
                        Poll::Ready(Ok(()))
                    }
                }
            }
        }
    }
827
828    fn start_send(mut self: Pin<&mut Self>, msg: Message<C>) -> Result<(), Self::Error> {
829        trace!("start_send");
830        let Message { cmd, sender } = msg;
831
832        let excludes = HashSet::new();
833        let slot = cmd.slot();
834
835        let info = RequestInfo {
836            cmd,
837            slot,
838            excludes,
839        };
840
841        self.pending_requests.push(PendingRequest {
842            retry: 0,
843            sender,
844            info,
845        });
846        Ok(()).into()
847    }
848
    /// Drives all queued and in-flight requests, refreshing the slot map whenever
    /// `poll_complete` reports a redirected request.
    ///
    /// Resolves once no request is in flight. It never resolves to an error: refresh
    /// failures are handed to callers through `refresh_error` instead.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        trace!("poll_complete: {:?}", self.state);
        loop {
            // Hand any stashed refresh failure to one waiting caller first.
            self.send_refresh_error();

            // Take the state out; `poll_recover` writes it back as appropriate.
            match mem::replace(&mut self.state, ConnectionState::PollComplete) {
                ConnectionState::Recover(future) => {
                    match ready!(self.as_mut().poll_recover(cx, future)) {
                        Ok(()) => (),
                        Err(err) => {
                            // We failed to reconnect, while we will try again we will report the
                            // error if we can to avoid getting trapped in an infinite loop of
                            // trying to reconnect
                            self.refresh_error = Some(err);

                            // Give other tasks a chance to progress before we try to recover
                            // again. Since the future may not have registered a wake up we do so
                            // now so the task is not forgotten
                            cx.waker().wake_by_ref();
                            return Poll::Pending;
                        }
                    }
                }
                ConnectionState::PollComplete => match ready!(self.poll_complete(cx)) {
                    Ok(()) => return Poll::Ready(Ok(())),
                    Err(err) => {
                        // A request was redirected: start a slot refresh and loop to drive it.
                        trace!("Recovering {}", err);
                        self.state = ConnectionState::Recover(Box::pin(self.refresh_slots()));
                    }
                },
            }
        }
    }
885
886    fn poll_close(
887        mut self: Pin<&mut Self>,
888        cx: &mut task::Context,
889    ) -> Poll<Result<(), Self::Error>> {
890        // Try to drive any in flight requests to completion
891        match self.poll_complete(cx) {
892            Poll::Ready(result) => {
893                result.map_err(|_| ())?;
894            }
895            Poll::Pending => (),
896        };
897        // If we no longer have any requests in flight we are done (skips any reconnection
898        // attempts)
899        if self.in_flight_requests.is_empty() {
900            return Poll::Ready(Ok(()));
901        }
902
903        self.poll_flush(cx)
904    }
905}
906
907impl<C> ConnectionLike for Connection<C>
908where
909    C: ConnectionLike + Send + 'static,
910{
911    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
912        trace!("req_packed_command");
913        let (sender, receiver) = oneshot::channel();
914        Box::pin(async move {
915            self.0
916                .send(Message {
917                    cmd: CmdArg::Cmd {
918                        cmd: Arc::new(cmd.clone()), // TODO Remove this clone?
919                        func: |mut conn, cmd| {
920                            Box::pin(async move {
921                                conn.req_packed_command(&cmd).await.map(Response::Single)
922                            })
923                        },
924                    },
925                    sender,
926                })
927                .await
928                .map_err(|_| {
929                    RedisError::from(io::Error::new(
930                        io::ErrorKind::BrokenPipe,
931                        "redis_cluster: Unable to send command",
932                    ))
933                })?;
934            receiver
935                .await
936                .unwrap_or_else(|_| {
937                    Err(RedisError::from(io::Error::new(
938                        io::ErrorKind::BrokenPipe,
939                        "redis_cluster: Unable to receive command",
940                    )))
941                })
942                .map(|response| match response {
943                    Response::Single(value) => value,
944                    Response::Multiple(_) => unreachable!(),
945                })
946        })
947    }
948
949    fn req_packed_commands<'a>(
950        &'a mut self,
951        pipeline: &'a redis::Pipeline,
952        offset: usize,
953        count: usize,
954    ) -> RedisFuture<'a, Vec<Value>> {
955        let (sender, receiver) = oneshot::channel();
956        Box::pin(async move {
957            self.0
958                .send(Message {
959                    cmd: CmdArg::Pipeline {
960                        pipeline: Arc::new(pipeline.clone()), // TODO Remove this clone?
961                        offset,
962                        count,
963                        func: |mut conn, pipeline, offset, count| {
964                            Box::pin(async move {
965                                conn.req_packed_commands(&pipeline, offset, count)
966                                    .await
967                                    .map(Response::Multiple)
968                            })
969                        },
970                    },
971                    sender,
972                })
973                .await
974                .map_err(|_| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))?;
975
976            receiver
977                .await
978                .unwrap_or_else(|_| {
979                    Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))
980                })
981                .map(|response| match response {
982                    Response::Multiple(values) => values,
983                    Response::Single(_) => unreachable!(),
984                })
985        })
986    }
987
988    fn get_db(&self) -> i64 {
989        0
990    }
991}
992
993impl Clone for Client {
994    fn clone(&self) -> Client {
995        Client::open(self.initial_nodes.clone()).unwrap()
996    }
997}
998
999pub trait Connect: Sized {
1000    fn connect<'a, T>(info: T) -> RedisFuture<'a, Self>
1001    where
1002        T: IntoConnectionInfo + Send + 'a;
1003}
1004
1005impl Connect for redis::aio::MultiplexedConnection {
1006    fn connect<'a, T>(info: T) -> RedisFuture<'a, redis::aio::MultiplexedConnection>
1007    where
1008        T: IntoConnectionInfo + Send + 'a,
1009    {
1010        async move {
1011            let connection_info = info.into_connection_info()?;
1012            let client = redis::Client::open(connection_info)?;
1013            client.get_multiplexed_tokio_connection().await
1014        }
1015        .boxed()
1016    }
1017}
1018
1019async fn connect_and_check<T, C>(info: T) -> RedisResult<C>
1020where
1021    T: IntoConnectionInfo + Send,
1022    C: ConnectionLike + Connect + Send + 'static,
1023{
1024    let mut conn = C::connect(info).await?;
1025    check_connection(&mut conn).await?;
1026    Ok(conn)
1027}
1028
1029async fn check_connection<C>(conn: &mut C) -> RedisResult<()>
1030where
1031    C: ConnectionLike + Send + 'static,
1032{
1033    let mut cmd = Cmd::new();
1034    cmd.arg("PING");
1035    cmd.query_async::<_, String>(conn).await?;
1036    Ok(())
1037}
1038
1039fn get_random_connection<'a, C>(
1040    connections: &'a ConnectionMap<C>,
1041    excludes: Option<&'a HashSet<String>>,
1042) -> (String, ConnectionFuture<C>)
1043where
1044    C: Clone,
1045{
1046    debug_assert!(!connections.is_empty());
1047
1048    let mut rng = thread_rng();
1049    let sample = match excludes {
1050        Some(excludes) if excludes.len() < connections.len() => {
1051            let target_keys = connections.keys().filter(|key| !excludes.contains(*key));
1052            target_keys.choose(&mut rng)
1053        }
1054        _ => connections.keys().choose(&mut rng),
1055    };
1056
1057    let addr = sample.expect("No targets to choose from");
1058    (addr.to_string(), connections.get(addr).unwrap().clone())
1059}
1060
1061fn slot_for_key(key: &[u8]) -> u16 {
1062    let key = sub_key(&key);
1063    State::<XMODEM>::calculate(&key) % SLOT_SIZE as u16
1064}
1065
/// Returns the part of `key` that determines its hash slot (the Redis "hash tag" rule).
///
/// If `key` contains a `{` followed later by a `}` with at least one byte between them, only
/// the bytes between the first `{` and the first `}` after it are hashed. Otherwise (no `{`,
/// no closing `}`, or an empty `{}`) the whole key is returned.
fn sub_key(key: &[u8]) -> &[u8] {
    /// The non-empty hash tag of `key`, if it has one.
    fn hash_tag(key: &[u8]) -> Option<&[u8]> {
        let open = key.iter().position(|&byte| byte == b'{')?;
        let after_open = &key[open + 1..];
        let len = after_open.iter().position(|&byte| byte == b'}')?;
        if len == 0 {
            None
        } else {
            Some(&after_open[..len])
        }
    }

    hash_tag(key).unwrap_or(key)
}
1086
/// One contiguous range of hash slots and the nodes serving it, as built by `get_slots` from a
/// `CLUSTER SLOTS` reply.
#[derive(Debug)]
struct Slot {
    /// First slot of the range, as reported by the server.
    start: u16,
    /// Last slot of the range, as reported by the server (Redis reports both bounds inclusively).
    end: u16,
    /// Connection string (built by `build_connection_string`) of the range's master node.
    master: String,
    /// Connection strings of the range's replica nodes, in the order they appeared in the reply.
    replicas: Vec<String>,
}
1094
1095impl Slot {
1096    pub fn start(&self) -> u16 {
1097        self.start
1098    }
1099    pub fn end(&self) -> u16 {
1100        self.end
1101    }
1102    pub fn master(&self) -> &str {
1103        &self.master
1104    }
1105    #[allow(dead_code)]
1106    pub fn replicas(&self) -> &Vec<String> {
1107        &self.replicas
1108    }
1109}
1110
1111// Get slot data from connection.
1112async fn get_slots<C>(
1113    addr: &str,
1114    connection: &mut C,
1115    use_tls: bool,
1116    tls_insecure: bool,
1117) -> RedisResult<Vec<Slot>>
1118where
1119    C: ConnectionLike,
1120{
1121    trace!("get_slots");
1122    let mut cmd = Cmd::new();
1123    cmd.arg("CLUSTER").arg("SLOTS");
1124    let value = connection.req_packed_command(&cmd).await.map_err(|err| {
1125        trace!("get_slots error: {}", err);
1126        err
1127    })?;
1128    trace!("get_slots -> {:#?}", value);
1129    // Parse response.
1130    let mut result = Vec::with_capacity(2);
1131
1132    if let Value::Bulk(items) = value {
1133        // TODO optimize by calling parse_redis_url only once
1134        // TODO these values could be cached
1135        let username = get_username(addr);
1136        let password = get_password(addr);
1137        let host = get_hostname(addr);
1138
1139        let mut iter = items.into_iter();
1140        while let Some(Value::Bulk(item)) = iter.next() {
1141            if item.len() < 3 {
1142                continue;
1143            }
1144
1145            let start = if let Value::Int(start) = item[0] {
1146                start as u16
1147            } else {
1148                continue;
1149            };
1150
1151            let end = if let Value::Int(end) = item[1] {
1152                end as u16
1153            } else {
1154                continue;
1155            };
1156
1157            let mut nodes: Vec<String> = item
1158                .into_iter()
1159                .skip(2)
1160                .filter_map(|node| {
1161                    if let Value::Bulk(node) = node {
1162                        if node.len() < 2 {
1163                            return None;
1164                        }
1165
1166                        let ip = if let Value::Data(ref ip) = node[0] {
1167                            String::from_utf8_lossy(ip)
1168                        } else {
1169                            return None;
1170                        };
1171
1172                        let port = if let Value::Int(port) = node[1] {
1173                            port
1174                        } else {
1175                            return None;
1176                        };
1177
1178                        let ip = if ip != "" {
1179                            &*ip
1180                        } else {
1181                            &*host.as_ref().unwrap()
1182                        };
1183
1184                        Some(build_connection_string(
1185                            username.as_deref(),
1186                            password.as_deref(),
1187                            &ip,
1188                            port,
1189                            use_tls,
1190                            tls_insecure,
1191                        ))
1192                    } else {
1193                        None
1194                    }
1195                })
1196                .collect();
1197
1198            if nodes.len() < 1 {
1199                continue;
1200            }
1201
1202            let replicas = nodes.split_off(1);
1203            result.push(Slot {
1204                start,
1205                end,
1206                master: nodes.pop().unwrap(),
1207                replicas,
1208            });
1209        }
1210    }
1211
1212    Ok(result)
1213}
1214
/// Builds a `redis://` (or `rediss://` when `use_tls`) connection URL for `host:port`.
///
/// * `username` / `password` are inserted verbatim as URL userinfo when present. A password
///   without a username is written as `:password@`.
/// * `#insecure` is appended only when both `use_tls` and `tls_insecure` are set.
/// * A bare IPv6 address in `host` (containing `:` and not already bracketed) is wrapped in
///   `[...]`, as a URL authority requires. Without that, e.g. `redis://::1:7000` would be
///   produced, which is not a valid URL.
fn build_connection_string(
    username: Option<&str>,
    password: Option<&str>,
    host: &str,
    port: i64,
    use_tls: bool,
    tls_insecure: bool,
) -> String {
    let scheme = if use_tls { "rediss" } else { "redis" };
    let fragment = if use_tls && tls_insecure {
        "#insecure"
    } else {
        ""
    };
    let userinfo = match (username, password) {
        (Some(username), Some(pw)) => format!("{}:{}@", username, pw),
        (None, Some(pw)) => format!(":{}@", pw),
        (Some(username), None) => format!("{}@", username),
        (None, None) => String::new(),
    };
    if host.contains(':') && !host.starts_with('[') {
        format!("{}://{}[{}]:{}{}", scheme, userinfo, host, port, fragment)
    } else {
        format!("{}://{}{}:{}{}", scheme, userinfo, host, port, fragment)
    }
}
1247
1248fn get_password(addr: &str) -> Option<String> {
1249    redis::parse_redis_url(addr).and_then(|url| url.password().map(|s| s.into()))
1250}
1251
1252fn get_username(addr: &str) -> Option<String> {
1253    redis::parse_redis_url(addr).and_then(|url| {
1254        let username = url.username();
1255        if username != "" {
1256            Some(url.username().to_string())
1257        } else {
1258            None
1259        }
1260    })
1261}
1262
1263fn get_hostname(addr: &str) -> Option<String> {
1264    redis::parse_redis_url(addr).and_then(|url| url.host_str().map(String::from))
1265}
1266
#[cfg(test)]
mod tests {
    use super::*;

    /// Computes the hash slot of the key in a RESP-encoded command, if it has one.
    ///
    /// Delegates to `slot_for_key` so the tests exercise the production slot computation
    /// instead of a copy of it.
    fn slot_for_packed_command(cmd: &[u8]) -> Option<u16> {
        command_key(cmd).map(|key| slot_for_key(&key))
    }

    /// Extracts the first argument after the command name from a RESP-encoded command, if it
    /// is a bulk string.
    fn command_key(cmd: &[u8]) -> Option<Vec<u8>> {
        redis::parse_redis_value(cmd)
            .ok()
            .and_then(|value| match value {
                Value::Bulk(mut args) => {
                    if args.len() >= 2 {
                        match args.swap_remove(1) {
                            Value::Data(key) => Some(key),
                            _ => None,
                        }
                    } else {
                        None
                    }
                }
                _ => None,
            })
    }

    #[test]
    fn slot() {
        assert_eq!(
            slot_for_packed_command(&[
                42, 50, 13, 10, 36, 54, 13, 10, 69, 88, 73, 83, 84, 83, 13, 10, 36, 49, 54, 13, 10,
                244, 93, 23, 40, 126, 127, 253, 33, 89, 47, 185, 204, 171, 249, 96, 139, 13, 10
            ]),
            Some(964)
        );
        assert_eq!(
            slot_for_packed_command(&[
                42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 36, 241,
                197, 111, 180, 254, 5, 175, 143, 146, 171, 39, 172, 23, 164, 145, 13, 10, 36, 52,
                13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10,
                80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10
            ]),
            Some(8352)
        );

        assert_eq!(
            slot_for_packed_command(&[
                42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 169, 233,
                247, 59, 50, 247, 100, 232, 123, 140, 2, 101, 125, 221, 66, 170, 13, 10, 36, 52,
                13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10,
                80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10
            ]),
            Some(5210),
        );
    }

    #[test]
    fn slot_for_known_keys() {
        // Slots reported by `redis-cli -c` redirections in the Redis cluster tutorial.
        assert_eq!(slot_for_key(b"foo"), 12182);
        assert_eq!(slot_for_key(b"hello"), 866);
        // Keys sharing a hash tag land in the same slot.
        assert_eq!(
            slot_for_key(b"{user1000}.following"),
            slot_for_key(b"{user1000}.followers")
        );
    }

    #[test]
    fn sub_key_hash_tags() {
        assert_eq!(sub_key(b"foo"), &b"foo"[..]);
        assert_eq!(sub_key(b"{user1000}.following"), &b"user1000"[..]);
        assert_eq!(sub_key(b"foo{}{bar}"), &b"foo{}{bar}"[..]);
        assert_eq!(sub_key(b"foo{{bar}}zap"), &b"{bar"[..]);
        assert_eq!(sub_key(b"foo{bar}{zap}"), &b"bar"[..]);
        assert_eq!(sub_key(b"{abc"), &b"{abc"[..]);
        assert_eq!(sub_key(b""), &b""[..]);
    }

    #[test]
    fn test_get_username_password() {
        let testcases: Vec<(&str, Option<String>, Option<String>)> = vec![
            ("redis://127.0.0.1:7000", None, None),
            (
                "redis://:password@127.0.0.1:7000",
                None,
                Some("password".to_string()),
            ),
            (
                "redis://username:password@127.0.0.1:7000",
                Some("username".to_string()),
                Some("password".to_string()),
            ),
            (
                "redis://username:@127.0.0.1:7000",
                Some("username".to_string()),
                None,
            ),
            (
                "redis://username@127.0.0.1:7000",
                Some("username".to_string()),
                None,
            ),
        ];

        for (redis_url, username, password) in testcases {
            assert_eq!(username, get_username(redis_url));
            assert_eq!(password, get_password(redis_url));
        }
    }
}