mco_redis_rs/
aio.rs

1//! Adds experimental async IO support to redis.
2use async_trait::async_trait;
3use std::collections::VecDeque;
4use std::io;
5use std::mem;
6use std::net::SocketAddr;
7use std::net::ToSocketAddrs;
8#[cfg(unix)]
9use std::path::Path;
10use std::pin::Pin;
11use std::task::{self, Poll};
12
13use combine::{parser::combinator::AnySendSyncPartialState, stream::PointerOffset};
14
15use ::tokio::{
16    io::{AsyncRead, AsyncWrite, AsyncWriteExt},
17    sync::{mpsc, oneshot},
18};
19
20#[cfg(feature = "tls")]
21use native_tls::TlsConnector;
22
23#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
24use tokio_util::codec::Decoder;
25
26use futures_util::{
27    future::{Future, FutureExt},
28    ready,
29    sink::Sink,
30    stream::{self, Stream, StreamExt, TryStreamExt as _},
31};
32
33use pin_project_lite::pin_project;
34
35use crate::cmd::{cmd, Cmd};
36use crate::connection::{ConnectionAddr, ConnectionInfo, Msg, RedisConnectionInfo};
37
38#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
39use crate::parser::ValueCodec;
40use crate::types::{ErrorKind, FromRedisValue, RedisError, RedisFuture, RedisResult, Value};
41use crate::{from_redis_value, ToRedisArgs};
42
/// Async-std compatibility layer (only built with the `async-std-comp` feature).
///
/// Provides the `AsyncStd` runtime handle (used by the internal runtime selection when
/// spawning) and the `AsyncStdWrapped` adapter that lets async-std I/O objects serve as a
/// [`Connection`] transport.
#[cfg(feature = "async-std-comp")]
#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))]
pub mod async_std;

/// Tokio compatibility layer (only built with the `tokio-comp` feature).
///
/// Provides the `Tokio` runtime handle used by the internal runtime selection when spawning.
#[cfg(feature = "tokio-comp")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
pub mod tokio;
52
53/// Represents the ability of connecting via TCP or via Unix socket
54#[async_trait]
55pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static {
56    /// Performs a TCP connection
57    async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult<Self>;
58
59    // Performs a TCP TLS connection
60    #[cfg(feature = "tls")]
61    async fn connect_tcp_tls(
62        hostname: &str,
63        socket_addr: SocketAddr,
64        insecure: bool,
65    ) -> RedisResult<Self>;
66
67    /// Performs a UNIX connection
68    #[cfg(unix)]
69    async fn connect_unix(path: &Path) -> RedisResult<Self>;
70
71    fn spawn(f: impl Future<Output = ()> + Send + 'static);
72
73    fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
74        Box::pin(self)
75    }
76}
77
78#[derive(Clone, Debug)]
79pub(crate) enum Runtime {
80    #[cfg(feature = "tokio-comp")]
81    Tokio,
82    #[cfg(feature = "async-std-comp")]
83    AsyncStd,
84}
85
86impl Runtime {
87    pub(crate) fn locate() -> Self {
88        #[cfg(all(feature = "tokio-comp", not(feature = "async-std-comp")))]
89        {
90            Runtime::Tokio
91        }
92
93        #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
94        {
95            Runtime::AsyncStd
96        }
97
98        #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))]
99        {
100            if ::tokio::runtime::Handle::try_current().is_ok() {
101                Runtime::Tokio
102            } else {
103                Runtime::AsyncStd
104            }
105        }
106
107        #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
108        {
109            compile_error!("tokio-comp or async-std-comp features required for aio feature")
110        }
111    }
112
113    #[allow(dead_code)]
114    fn spawn(&self, f: impl Future<Output = ()> + Send + 'static) {
115        match self {
116            #[cfg(feature = "tokio-comp")]
117            Runtime::Tokio => tokio::Tokio::spawn(f),
118            #[cfg(feature = "async-std-comp")]
119            Runtime::AsyncStd => async_std::AsyncStd::spawn(f),
120        }
121    }
122}
123
124/// Trait for objects that implements `AsyncRead` and `AsyncWrite`
125pub trait AsyncStream: AsyncRead + AsyncWrite {}
126impl<S> AsyncStream for S where S: AsyncRead + AsyncWrite {}
127
/// Represents a `PubSub` connection.
///
/// Wraps a [`Connection`] used for publish/subscribe. Create it with
/// [`Connection::into_pubsub`], manage subscriptions with the `subscribe` family of methods,
/// and read published messages through [`PubSub::on_message`].
pub struct PubSub<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
130
/// Represents a `Monitor` connection.
///
/// Wraps a [`Connection`]. Send `MONITOR` with [`Monitor::monitor`], then read the stream of
/// replies through [`Monitor::on_message`].
pub struct Monitor<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
133
134impl<C> PubSub<C>
135where
136    C: Unpin + AsyncRead + AsyncWrite + Send,
137{
138    fn new(con: Connection<C>) -> Self {
139        Self(con)
140    }
141
142    /// Subscribes to a new channel.
143    pub async fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
144        Ok(cmd("SUBSCRIBE")
145            .arg(channel)
146            .query_async(&mut self.0)
147            .await?)
148    }
149
150    /// Subscribes to a new channel with a pattern.
151    pub async fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
152        Ok(cmd("PSUBSCRIBE")
153            .arg(pchannel)
154            .query_async(&mut self.0)
155            .await?)
156    }
157
158    /// Unsubscribes from a channel.
159    pub async fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
160        Ok(cmd("UNSUBSCRIBE")
161            .arg(channel)
162            .query_async(&mut self.0)
163            .await?)
164    }
165
166    /// Unsubscribes from a channel with a pattern.
167    pub async fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
168        Ok(cmd("PUNSUBSCRIBE")
169            .arg(pchannel)
170            .query_async(&mut self.0)
171            .await?)
172    }
173
174    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions.
175    ///
176    /// The message itself is still generic and can be converted into an appropriate type through
177    /// the helper methods on it.
178    pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
179        ValueCodec::default()
180            .framed(&mut self.0.con)
181            .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) }))
182    }
183
184    /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it.
185    ///
186    /// The message itself is still generic and can be converted into an appropriate type through
187    /// the helper methods on it.
188    /// This can be useful in cases where the stream needs to be returned or held by something other
189    //  than the [`PubSub`].
190    pub fn into_on_message(self) -> impl Stream<Item = Msg> {
191        ValueCodec::default()
192            .framed(self.0.con)
193            .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) }))
194    }
195
196    /// Exits from `PubSub` mode and converts [`PubSub`] into [`Connection`].
197    pub async fn into_connection(mut self) -> Connection<C> {
198        self.0.exit_pubsub().await.ok();
199
200        self.0
201    }
202}
203
204impl<C> Monitor<C>
205where
206    C: Unpin + AsyncRead + AsyncWrite + Send,
207{
208    /// Create a [`Monitor`] from a [`Connection`]
209    pub fn new(con: Connection<C>) -> Self {
210        Self(con)
211    }
212
213    /// Deliver the MONITOR command to this [`Monitor`]ing wrapper.
214    pub async fn monitor(&mut self) -> RedisResult<()> {
215        Ok(cmd("MONITOR").query_async(&mut self.0).await?)
216    }
217
218    /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection
219    pub fn on_message<T: FromRedisValue>(&mut self) -> impl Stream<Item = T> + '_ {
220        ValueCodec::default()
221            .framed(&mut self.0.con)
222            .filter_map(|value| {
223                Box::pin(async move { T::from_redis_value(&value.ok()?.ok()?).ok() })
224            })
225    }
226
227    /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection
228    pub fn into_on_message<T: FromRedisValue>(self) -> impl Stream<Item = T> {
229        ValueCodec::default()
230            .framed(self.0.con)
231            .filter_map(|value| {
232                Box::pin(async move { T::from_redis_value(&value.ok()?.ok()?).ok() })
233            })
234    }
235}
236
/// Represents a stateful redis connection.
///
/// The transport `C` is any async byte stream (this module connects over TCP, TLS or unix
/// sockets); by default it is the type-erased, boxed [`AsyncStream`].
pub struct Connection<C = Pin<Box<dyn AsyncStream + Send + Sync>>> {
    /// The underlying transport.
    con: C,
    /// Scratch buffer, cleared and reused for packing every outgoing command or pipeline.
    buf: Vec<u8>,
    /// Incremental parser state used when reading responses from `con`.
    decoder: combine::stream::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
    /// Database index from the connection info (selected during authentication when non-zero).
    db: i64,

    // Flag indicating that the connection may still be in the PubSub state: it is raised by
    // `exit_pubsub` when leaving PubSub mode fails.
    //
    // This flag is checked when attempting to send a command, and if it's raised, we attempt to
    // exit the pubsub state before executing the new request.
    pubsub: bool,
}
250
251fn assert_sync<T: Sync>() {}
252
253#[allow(unused)]
254fn test() {
255    assert_sync::<Connection>();
256}
257
258impl<C> Connection<C> {
259    pub(crate) fn map<D>(self, f: impl FnOnce(C) -> D) -> Connection<D> {
260        let Self {
261            con,
262            buf,
263            decoder,
264            db,
265            pubsub,
266        } = self;
267        Connection {
268            con: f(con),
269            buf,
270            decoder,
271            db,
272            pubsub,
273        }
274    }
275}
276
277impl<C> Connection<C>
278where
279    C: Unpin + AsyncRead + AsyncWrite + Send,
280{
281    /// Constructs a new `Connection` out of a `AsyncRead + AsyncWrite` object
282    /// and a `RedisConnectionInfo`
283    pub async fn new(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
284        let mut rv = Connection {
285            con,
286            buf: Vec::new(),
287            decoder: combine::stream::Decoder::new(),
288            db: connection_info.db,
289            pubsub: false,
290        };
291        authenticate(connection_info, &mut rv).await?;
292        Ok(rv)
293    }
294
295    /// Converts this [`Connection`] into [`PubSub`].
296    pub fn into_pubsub(self) -> PubSub<C> {
297        PubSub::new(self)
298    }
299
300    /// Converts this [`Connection`] into [`Monitor`]
301    pub fn into_monitor(self) -> Monitor<C> {
302        Monitor::new(self)
303    }
304
305    /// Fetches a single response from the connection.
306    async fn read_response(&mut self) -> RedisResult<Value> {
307        crate::parser::parse_redis_value_async(&mut self.decoder, &mut self.con).await
308    }
309
310    /// Brings [`Connection`] out of `PubSub` mode.
311    ///
312    /// This will unsubscribe this [`Connection`] from all subscriptions.
313    ///
314    /// If this function returns error then on all command send tries will be performed attempt
315    /// to exit from `PubSub` mode until it will be successful.
316    async fn exit_pubsub(&mut self) -> RedisResult<()> {
317        let res = self.clear_active_subscriptions().await;
318        if res.is_ok() {
319            self.pubsub = false;
320        } else {
321            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
322            self.pubsub = true;
323        }
324
325        res
326    }
327
328    /// Get the inner connection out of a PubSub
329    ///
330    /// Any active subscriptions are unsubscribed. In the event of an error, the connection is
331    /// dropped.
332    async fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
333        // Responses to unsubscribe commands return in a 3-tuple with values
334        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
335        // The "count of remaining subs" includes both pattern subscriptions and non pattern
336        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
337        // server, both commands need to be executed at once.
338        {
339            // Prepare both unsubscribe commands
340            let unsubscribe = crate::Pipeline::new()
341                .add_command(cmd("UNSUBSCRIBE"))
342                .add_command(cmd("PUNSUBSCRIBE"))
343                .get_packed_pipeline();
344
345            // Execute commands
346            self.con.write_all(&unsubscribe).await?;
347        }
348
349        // Receive responses
350        //
351        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
352        // commands. There may be more responses if there are active subscriptions. In this case,
353        // messages are received until the _subscription count_ in the responses reach zero.
354        let mut received_unsub = false;
355        let mut received_punsub = false;
356        loop {
357            let res: (Vec<u8>, (), isize) = from_redis_value(&self.read_response().await?)?;
358
359            match res.0.first() {
360                Some(&b'u') => received_unsub = true,
361                Some(&b'p') => received_punsub = true,
362                _ => (),
363            }
364
365            if received_unsub && received_punsub && res.2 == 0 {
366                break;
367            }
368        }
369
370        // Finally, the connection is back in its normal state since all subscriptions were
371        // cancelled *and* all unsubscribe messages were received.
372        Ok(())
373    }
374}
375
#[cfg(feature = "async-std-comp")]
#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))]
impl<C> Connection<async_std::AsyncStdWrapped<C>>
where
    C: Unpin + ::async_std::io::Read + ::async_std::io::Write + Send,
{
    /// Constructs a new `Connection` out of an `async_std::io::Read + async_std::io::Write`
    /// object and a `RedisConnectionInfo`.
    ///
    /// The object is wrapped in `AsyncStdWrapped`, and construction (including authentication)
    /// is then delegated to [`Connection::new`].
    pub async fn new_async_std(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
        Connection::new(connection_info, async_std::AsyncStdWrapped::new(con)).await
    }
}
388
389pub(crate) async fn connect<C>(connection_info: &ConnectionInfo) -> RedisResult<Connection<C>>
390where
391    C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send,
392{
393    let con = connect_simple::<C>(connection_info).await?;
394    Connection::new(&connection_info.redis, con).await
395}
396
397async fn authenticate<C>(connection_info: &RedisConnectionInfo, con: &mut C) -> RedisResult<()>
398where
399    C: ConnectionLike,
400{
401    if let Some(password) = &connection_info.password {
402        let mut command = cmd("AUTH");
403        if let Some(username) = &connection_info.username {
404            command.arg(username);
405        }
406        match command.arg(password).query_async(con).await {
407            Ok(Value::Okay) => (),
408            Err(e) => {
409                let err_msg = e.detail().ok_or((
410                    ErrorKind::AuthenticationFailed,
411                    "Password authentication failed",
412                ))?;
413
414                if !err_msg.contains("wrong number of arguments for 'auth' command") {
415                    fail!((
416                        ErrorKind::AuthenticationFailed,
417                        "Password authentication failed",
418                    ));
419                }
420
421                let mut command = cmd("AUTH");
422                match command.arg(password).query_async(con).await {
423                    Ok(Value::Okay) => (),
424                    _ => {
425                        fail!((
426                            ErrorKind::AuthenticationFailed,
427                            "Password authentication failed"
428                        ));
429                    }
430                }
431            }
432            _ => {
433                fail!((
434                    ErrorKind::AuthenticationFailed,
435                    "Password authentication failed"
436                ));
437            }
438        }
439    }
440
441    if connection_info.db != 0 {
442        match cmd("SELECT").arg(connection_info.db).query_async(con).await {
443            Ok(Value::Okay) => (),
444            _ => fail!((
445                ErrorKind::ResponseError,
446                "Redis server refused to switch database"
447            )),
448        }
449    }
450
451    Ok(())
452}
453
454pub(crate) async fn connect_simple<T: RedisRuntime>(
455    connection_info: &ConnectionInfo,
456) -> RedisResult<T> {
457    Ok(match connection_info.addr {
458        ConnectionAddr::Tcp(ref host, port) => {
459            let socket_addr = get_socket_addrs(host, port)?;
460            <T>::connect_tcp(socket_addr).await?
461        }
462
463        #[cfg(feature = "tls")]
464        ConnectionAddr::TcpTls {
465            ref host,
466            port,
467            insecure,
468        } => {
469            let socket_addr = get_socket_addrs(host, port)?;
470            <T>::connect_tcp_tls(host, socket_addr, insecure).await?
471        }
472
473        #[cfg(not(feature = "tls"))]
474        ConnectionAddr::TcpTls { .. } => {
475            fail!((
476                ErrorKind::InvalidClientConfig,
477                "Cannot connect to TCP with TLS without the tls feature"
478            ));
479        }
480
481        #[cfg(unix)]
482        ConnectionAddr::Unix(ref path) => <T>::connect_unix(path).await?,
483
484        #[cfg(not(unix))]
485        ConnectionAddr::Unix(_) => {
486            return Err(RedisError::from((
487                ErrorKind::InvalidClientConfig,
488                "Cannot connect to unix sockets \
489                 on this platform",
490            )))
491        }
492    })
493}
494
495fn get_socket_addrs(host: &str, port: u16) -> RedisResult<SocketAddr> {
496    let mut socket_addrs = (host, port).to_socket_addrs()?;
497    match socket_addrs.next() {
498        Some(socket_addr) => Ok(socket_addr),
499        None => Err(RedisError::from((
500            ErrorKind::InvalidClientConfig,
501            "No address found for host",
502        ))),
503    }
504}
505
506/// An async abstraction over connections.
507pub trait ConnectionLike {
508    /// Sends an already encoded (packed) command into the TCP socket and
509    /// reads the single response from it.
510    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>;
511
512    /// Sends multiple already encoded (packed) command into the TCP socket
513    /// and reads `count` responses from it.  This is used to implement
514    /// pipelining.
515    fn req_packed_commands<'a>(
516        &'a mut self,
517        cmd: &'a crate::Pipeline,
518        offset: usize,
519        count: usize,
520    ) -> RedisFuture<'a, Vec<Value>>;
521
522    /// Returns the database this connection is bound to.  Note that this
523    /// information might be unreliable because it's initially cached and
524    /// also might be incorrect if the connection like object is not
525    /// actually connected.
526    fn get_db(&self) -> i64;
527}
528
529impl<C> ConnectionLike for Connection<C>
530where
531    C: Unpin + AsyncRead + AsyncWrite + Send,
532{
533    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
534        (async move {
535            if self.pubsub {
536                self.exit_pubsub().await?;
537            }
538            self.buf.clear();
539            cmd.write_packed_command(&mut self.buf);
540            self.con.write_all(&self.buf).await?;
541            self.read_response().await
542        })
543        .boxed()
544    }
545
546    fn req_packed_commands<'a>(
547        &'a mut self,
548        cmd: &'a crate::Pipeline,
549        offset: usize,
550        count: usize,
551    ) -> RedisFuture<'a, Vec<Value>> {
552        (async move {
553            if self.pubsub {
554                self.exit_pubsub().await?;
555            }
556
557            self.buf.clear();
558            cmd.write_packed_pipeline(&mut self.buf);
559            self.con.write_all(&self.buf).await?;
560
561            let mut first_err = None;
562
563            for _ in 0..offset {
564                let response = self.read_response().await;
565                if let Err(err) = response {
566                    if first_err.is_none() {
567                        first_err = Some(err);
568                    }
569                }
570            }
571
572            let mut rv = Vec::with_capacity(count);
573            for _ in 0..count {
574                let response = self.read_response().await;
575                match response {
576                    Ok(item) => {
577                        rv.push(item);
578                    }
579                    Err(err) => {
580                        if first_err.is_none() {
581                            first_err = Some(err);
582                        }
583                    }
584                }
585            }
586
587            if let Some(err) = first_err {
588                Err(err)
589            } else {
590                Ok(rv)
591            }
592        })
593        .boxed()
594    }
595
596    fn get_db(&self) -> i64 {
597        self.db
598    }
599}
600
601// Senders which the result of a single request are sent through
602type PipelineOutput<O, E> = oneshot::Sender<Result<Vec<O>, E>>;
603
604struct InFlight<O, E> {
605    output: PipelineOutput<O, E>,
606    response_count: usize,
607    buffer: Vec<O>,
608}
609
610// A single message sent through the pipeline
611struct PipelineMessage<S, I, E> {
612    input: S,
613    output: PipelineOutput<I, E>,
614    response_count: usize,
615}
616
617/// Wrapper around a `Stream + Sink` where each item sent through the `Sink` results in one or more
618/// items being output by the `Stream` (the number is specified at time of sending). With the
619/// interface provided by `Pipeline` an easy interface of request to response, hiding the `Stream`
620/// and `Sink`.
621struct Pipeline<SinkItem, I, E>(mpsc::Sender<PipelineMessage<SinkItem, I, E>>);
622
623impl<SinkItem, I, E> Clone for Pipeline<SinkItem, I, E> {
624    fn clone(&self) -> Self {
625        Pipeline(self.0.clone())
626    }
627}
628
629pin_project! {
630    struct PipelineSink<T, I, E> {
631        #[pin]
632        sink_stream: T,
633        in_flight: VecDeque<InFlight<I, E>>,
634        error: Option<E>,
635    }
636}
637
638impl<T, I, E> PipelineSink<T, I, E>
639where
640    T: Stream<Item = Result<I, E>> + 'static,
641{
642    fn new<SinkItem>(sink_stream: T) -> Self
643    where
644        T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
645    {
646        PipelineSink {
647            sink_stream,
648            in_flight: VecDeque::new(),
649            error: None,
650        }
651    }
652
653    // Read messages from the stream and send them back to the caller
654    fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
655        loop {
656            // No need to try reading a message if there is no message in flight
657            if self.in_flight.is_empty() {
658                return Poll::Ready(Ok(()));
659            }
660            let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
661                Some(result) => result,
662                // The redis response stream is not going to produce any more items so we `Err`
663                // to break out of the `forward` combinator and stop handling requests
664                None => return Poll::Ready(Err(())),
665            };
666            self.as_mut().send_result(item);
667        }
668    }
669
670    fn send_result(self: Pin<&mut Self>, result: Result<I, E>) {
671        let self_ = self.project();
672        let response = {
673            let entry = match self_.in_flight.front_mut() {
674                Some(entry) => entry,
675                None => return,
676            };
677            match result {
678                Ok(item) => {
679                    entry.buffer.push(item);
680                    if entry.response_count > entry.buffer.len() {
681                        // Need to gather more response values
682                        return;
683                    }
684                    Ok(mem::take(&mut entry.buffer))
685                }
686                // If we fail we must respond immediately
687                Err(err) => Err(err),
688            }
689        };
690
691        let entry = self_.in_flight.pop_front().unwrap();
692        // `Err` means that the receiver was dropped in which case it does not
693        // care about the output and we can continue by just dropping the value
694        // and sender
695        entry.output.send(response).ok();
696    }
697}
698
impl<SinkItem, T, I, E> Sink<PipelineMessage<SinkItem, I, E>> for PipelineSink<T, I, E>
where
    T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
{
    // Errors are reported to individual requests through their oneshot senders; returning
    // `Err(())` only tells the driving `forward` combinator to stop.
    type Error = ();

    // Waits until the underlying sink can accept another request. A sink error is stashed in
    // `error` and readiness is still reported, so that `start_send` can hand the error to the
    // request being sent.
    fn poll_ready(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
            Ok(()) => Ok(()).into(),
            Err(err) => {
                *self.project().error = Some(err);
                Ok(()).into()
            }
        }
    }

    // Writes one request to the underlying sink and registers it as in flight.
    fn start_send(
        mut self: Pin<&mut Self>,
        PipelineMessage {
            input,
            output,
            response_count,
        }: PipelineMessage<SinkItem, I, E>,
    ) -> Result<(), Self::Error> {
        // If nobody is waiting for the output (the receiver was dropped), the request is not
        // sent at all; this sheds some load on the connection.
        if output.is_closed() {
            return Ok(());
        }

        let self_ = self.as_mut().project();

        // A sink error stashed by `poll_ready` is delivered to this request, and the driver
        // is stopped.
        if let Some(err) = self_.error.take() {
            let _ = output.send(Err(err));
            return Err(());
        }

        match self_.sink_stream.start_send(input) {
            Ok(()) => {
                self_.in_flight.push_back(InFlight {
                    output,
                    response_count,
                    buffer: Vec::new(),
                });
                Ok(())
            }
            Err(err) => {
                let _ = output.send(Err(err));
                Err(())
            }
        }
    }

    // Flushes written requests, then reads responses for the in-flight ones. A flush error is
    // passed to `send_result` (i.e. attributed to the oldest in-flight request) and stops the
    // driver.
    fn poll_flush(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        ready!(self
            .as_mut()
            .project()
            .sink_stream
            .poll_flush(cx)
            .map_err(|err| {
                self.as_mut().send_result(Err(err));
            }))?;
        self.poll_read(cx)
    }

    // Completes outstanding requests (via `poll_flush`), then closes the underlying sink. A
    // close error is passed to `send_result`.
    fn poll_close(
        mut self: Pin<&mut Self>,
        cx: &mut task::Context,
    ) -> Poll<Result<(), Self::Error>> {
        // No new requests will come in after the first call to `close` but we need to complete any
        // in progress requests before closing
        if !self.in_flight.is_empty() {
            ready!(self.as_mut().poll_flush(cx))?;
        }
        let this = self.as_mut().project();
        this.sink_stream.poll_close(cx).map_err(|err| {
            self.send_result(Err(err));
        })
    }
}
787
788impl<SinkItem, I, E> Pipeline<SinkItem, I, E>
789where
790    SinkItem: Send + 'static,
791    I: Send + 'static,
792    E: Send + 'static,
793{
794    fn new<T>(sink_stream: T) -> (Self, impl Future<Output = ()>)
795    where
796        T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
797        T: Send + 'static,
798        T::Item: Send,
799        T::Error: Send,
800        T::Error: ::std::fmt::Debug,
801    {
802        const BUFFER_SIZE: usize = 50;
803        let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
804        let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
805            .map(Ok)
806            .forward(PipelineSink::new::<SinkItem>(sink_stream))
807            .map(|_| ());
808        (Pipeline(sender), f)
809    }
810
811    // `None` means that the stream was out of items causing that poll loop to shut down.
812    async fn send(&mut self, item: SinkItem) -> Result<I, Option<E>> {
813        self.send_recv_multiple(item, 1)
814            .await
815            // We can unwrap since we do a request for `1` item
816            .map(|mut item| item.pop().unwrap())
817    }
818
819    async fn send_recv_multiple(
820        &mut self,
821        input: SinkItem,
822        count: usize,
823    ) -> Result<Vec<I>, Option<E>> {
824        let (sender, receiver) = oneshot::channel();
825
826        self.0
827            .send(PipelineMessage {
828                input,
829                response_count: count,
830                output: sender,
831            })
832            .await
833            .map_err(|_| None)?;
834        match receiver.await {
835            Ok(result) => result.map_err(Some),
836            Err(_) => {
837                // The `sender` was dropped which likely means that the stream part
838                // failed for one reason or another
839                Err(None)
840            }
841        }
842    }
843}
844
/// A connection object which can be cloned, allowing requests to be sent concurrently
/// on the same underlying connection (tcp/unix socket).
///
/// All clones submit their requests through the same pipeline, which is serviced by the
/// driver future returned from [`MultiplexedConnection::new`]; requests only make progress
/// while that future is being polled.
#[derive(Clone)]
pub struct MultiplexedConnection {
    // Handle used to submit packed commands and await their responses.
    pipeline: Pipeline<Vec<u8>, Value, RedisError>,
    // Database index taken from the connection info; returned by `get_db`.
    db: i64,
}
852
853impl MultiplexedConnection {
854    /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object
855    /// and a `ConnectionInfo`
856    pub async fn new<C>(
857        connection_info: &RedisConnectionInfo,
858        stream: C,
859    ) -> RedisResult<(Self, impl Future<Output = ()>)>
860    where
861        C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
862    {
863        fn boxed(
864            f: impl Future<Output = ()> + Send + 'static,
865        ) -> Pin<Box<dyn Future<Output = ()> + Send>> {
866            Box::pin(f)
867        }
868
869        #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
870        compile_error!("tokio-comp or async-std-comp features required for aio feature");
871
872        let codec = ValueCodec::default()
873            .framed(stream)
874            .and_then(|msg| async move { msg });
875        let (pipeline, driver) = Pipeline::new(codec);
876        let driver = boxed(driver);
877        let mut con = MultiplexedConnection {
878            pipeline,
879            db: connection_info.db,
880        };
881        let driver = {
882            let auth = authenticate(connection_info, &mut con);
883            futures_util::pin_mut!(auth);
884
885            match futures_util::future::select(auth, driver).await {
886                futures_util::future::Either::Left((result, driver)) => {
887                    result?;
888                    driver
889                }
890                futures_util::future::Either::Right(((), _)) => {
891                    unreachable!("Multiplexed connection driver unexpectedly terminated")
892                }
893            }
894        };
895        Ok((con, driver))
896    }
897}
898
899impl ConnectionLike for MultiplexedConnection {
900    fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
901        (async move {
902            let value = self
903                .pipeline
904                .send(cmd.get_packed_command())
905                .await
906                .map_err(|err| {
907                    err.unwrap_or_else(|| {
908                        RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))
909                    })
910                })?;
911            Ok(value)
912        })
913        .boxed()
914    }
915
916    fn req_packed_commands<'a>(
917        &'a mut self,
918        cmd: &'a crate::Pipeline,
919        offset: usize,
920        count: usize,
921    ) -> RedisFuture<'a, Vec<Value>> {
922        (async move {
923            let mut value = self
924                .pipeline
925                .send_recv_multiple(cmd.get_packed_pipeline(), offset + count)
926                .await
927                .map_err(|err| {
928                    err.unwrap_or_else(|| {
929                        RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))
930                    })
931                })?;
932
933            value.drain(..offset);
934            Ok(value)
935        })
936        .boxed()
937    }
938
939    fn get_db(&self) -> i64 {
940        self.db
941    }
942}
943
#[cfg(feature = "connection-manager")]
mod connection_manager {
    use super::*;

    use std::sync::Arc;

    use arc_swap::ArcSwap;
    use futures::future::{self, Shared};
    use futures_util::future::BoxFuture;

    use crate::Client;

    /// A `ConnectionManager` is a proxy that wraps a [multiplexed
    /// connection][multiplexed-connection] and automatically reconnects to the
    /// server when necessary.
    ///
    /// Like the [`MultiplexedConnection`][multiplexed-connection], this
    /// manager can be cloned, allowing requests to be sent concurrently on
    /// the same underlying connection (tcp/unix socket).
    ///
    /// ## Behavior
    ///
    /// - When creating an instance of the `ConnectionManager`, an initial
    ///   connection will be established and awaited. Connection errors will be
    ///   returned directly.
    /// - When a command sent to the server fails with an error that represents
    ///   a "connection dropped" condition, that error will be passed on to the
    ///   user, but it will trigger a reconnection in the background.
    /// - The reconnect code will atomically swap the current (dead) connection
    ///   with a future that will eventually resolve to a `MultiplexedConnection`
    ///   or to a `RedisError`
    /// - All commands that are issued after the reconnect process has been
    ///   initiated, will have to await the connection future.
    /// - If reconnecting fails, all pending commands will be failed as well. A
    ///   new reconnection attempt will be triggered if the error is an I/O error.
    ///
    /// [multiplexed-connection]: struct.MultiplexedConnection.html
    #[derive(Clone)]
    pub struct ConnectionManager {
        /// Information used for the connection. This is needed to be able to reconnect.
        client: Client,
        /// The connection future.
        ///
        /// The `ArcSwap` is required to be able to replace the connection
        /// without making the `ConnectionManager` mutable.
        connection: Arc<ArcSwap<SharedRedisFuture<MultiplexedConnection>>>,

        /// Runtime handle used to spawn background reconnection attempts.
        runtime: Runtime,
    }

    /// A `RedisResult` that can be cloned because `RedisError` is behind an `Arc`.
    type CloneableRedisResult<T> = Result<T, Arc<RedisError>>;

    /// Type alias for a shared boxed future that will resolve to a `CloneableRedisResult`.
    type SharedRedisFuture<T> = Shared<BoxFuture<'static, CloneableRedisResult<T>>>;

    impl ConnectionManager {
        /// Connect to the server and store the connection inside the returned `ConnectionManager`.
        ///
        /// This requires the `connection-manager` feature, which will also pull in
        /// the Tokio executor.
        ///
        /// # Errors
        ///
        /// Returns the error from establishing the initial multiplexed connection.
        pub async fn new(client: Client) -> RedisResult<Self> {
            // Create a MultiplexedConnection and wait for it to be established
            let runtime = Runtime::locate();
            let connection = client.get_multiplexed_async_connection().await?;

            // Wrap the connection in an `ArcSwap` instance for fast atomic access
            Ok(Self {
                client,
                connection: Arc::new(ArcSwap::from_pointee(
                    future::ok(connection).boxed().shared(),
                )),
                runtime,
            })
        }

        /// Reconnect and overwrite the old connection.
        ///
        /// `current` is the shared connection future that was active when the
        /// connection loss was detected. The replacement future is installed,
        /// and its connection attempt spawned, only if `current` is still the
        /// stored future. If another caller has already replaced it, this call
        /// has no effect.
        fn reconnect(&self, current: &Arc<SharedRedisFuture<MultiplexedConnection>>) {
            let client = self.client.clone();
            let new_connection: SharedRedisFuture<MultiplexedConnection> =
                async move { Ok(client.get_multiplexed_async_connection().await?) }
                    .boxed()
                    .shared();

            // Update the connection in the connection manager
            let new_connection_arc = Arc::new(new_connection.clone());
            let prev = self
                .connection
                .compare_and_swap(current, new_connection_arc);

            // If the swap happened...
            if Arc::ptr_eq(&prev, current) {
                // ...start the connection attempt immediately but do not wait on it.
                self.runtime.spawn(new_connection.map(|_| ()));
            }
        }
    }

    /// Handle a command result. If the connection was dropped, reconnect.
    macro_rules! reconnect_if_dropped {
        ($self:expr, $result:expr, $current:expr) => {
            if let Err(ref e) = $result {
                if e.is_connection_dropped() {
                    $self.reconnect($current);
                }
            }
        };
    }

    /// Handle a connection result. If there's an I/O error, reconnect.
    /// Propagate any error.
    macro_rules! reconnect_if_io_error {
        ($self:expr, $result:expr, $current:expr) => {
            if let Err(e) = $result {
                if e.is_io_error() {
                    $self.reconnect($current);
                }
                return Err(e);
            }
        };
    }

    impl ConnectionLike for ConnectionManager {
        fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
            (async move {
                // Take an owned snapshot of the current connection future.
                // `load_full` is used rather than `load` because the snapshot is
                // held across `.await` points (including the server round-trip),
                // and arc_swap's `load` guards are meant to be short-lived.
                let current = self.connection.load_full();
                let connection_result = (*current)
                    .clone()
                    .await
                    .map_err(|e| e.clone_mostly("Reconnecting failed"));
                reconnect_if_io_error!(self, connection_result, &current);
                let result = connection_result?.req_packed_command(cmd).await;
                reconnect_if_dropped!(self, &result, &current);
                result
            })
            .boxed()
        }

        fn req_packed_commands<'a>(
            &'a mut self,
            cmd: &'a crate::Pipeline,
            offset: usize,
            count: usize,
        ) -> RedisFuture<'a, Vec<Value>> {
            (async move {
                // Owned snapshot of the shared connection future; see
                // `req_packed_command` for why `load_full` is used here.
                let current = self.connection.load_full();
                let connection_result = (*current)
                    .clone()
                    .await
                    .map_err(|e| e.clone_mostly("Reconnecting failed"));
                reconnect_if_io_error!(self, connection_result, &current);
                let result = connection_result?
                    .req_packed_commands(cmd, offset, count)
                    .await;
                reconnect_if_dropped!(self, &result, &current);
                result
            })
            .boxed()
        }

        fn get_db(&self) -> i64 {
            self.client.connection_info().redis.db
        }
    }
}
1118
1119#[cfg(feature = "connection-manager")]
1120#[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
1121pub use connection_manager::ConnectionManager;