1use async_trait::async_trait;
3use std::collections::VecDeque;
4use std::io;
5use std::mem;
6use std::net::SocketAddr;
7use std::net::ToSocketAddrs;
8#[cfg(unix)]
9use std::path::Path;
10use std::pin::Pin;
11use std::task::{self, Poll};
12
13use combine::{parser::combinator::AnySendSyncPartialState, stream::PointerOffset};
14
15use ::tokio::{
16 io::{AsyncRead, AsyncWrite, AsyncWriteExt},
17 sync::{mpsc, oneshot},
18};
19
20#[cfg(feature = "tls")]
21use native_tls::TlsConnector;
22
23#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
24use tokio_util::codec::Decoder;
25
26use futures_util::{
27 future::{Future, FutureExt},
28 ready,
29 sink::Sink,
30 stream::{self, Stream, StreamExt, TryStreamExt as _},
31};
32
33use pin_project_lite::pin_project;
34
35use crate::cmd::{cmd, Cmd};
36use crate::connection::{ConnectionAddr, ConnectionInfo, Msg, RedisConnectionInfo};
37
38#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))]
39use crate::parser::ValueCodec;
40use crate::types::{ErrorKind, FromRedisValue, RedisError, RedisFuture, RedisResult, Value};
41use crate::{from_redis_value, ToRedisArgs};
42
43#[cfg(feature = "async-std-comp")]
45#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))]
46pub mod async_std;
47
48#[cfg(feature = "tokio-comp")]
50#[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))]
51pub mod tokio;
52
53#[async_trait]
55pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static {
56 async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult<Self>;
58
59 #[cfg(feature = "tls")]
61 async fn connect_tcp_tls(
62 hostname: &str,
63 socket_addr: SocketAddr,
64 insecure: bool,
65 ) -> RedisResult<Self>;
66
67 #[cfg(unix)]
69 async fn connect_unix(path: &Path) -> RedisResult<Self>;
70
71 fn spawn(f: impl Future<Output = ()> + Send + 'static);
72
73 fn boxed(self) -> Pin<Box<dyn AsyncStream + Send + Sync>> {
74 Box::pin(self)
75 }
76}
77
/// The async runtime used for spawning background work.
///
/// Which variants exist depends on the enabled cargo features.
#[derive(Clone, Debug)]
pub(crate) enum Runtime {
    /// Tokio; only present with the `tokio-comp` feature.
    #[cfg(feature = "tokio-comp")]
    Tokio,
    /// async-std; only present with the `async-std-comp` feature.
    #[cfg(feature = "async-std-comp")]
    AsyncStd,
}
85
86impl Runtime {
87 pub(crate) fn locate() -> Self {
88 #[cfg(all(feature = "tokio-comp", not(feature = "async-std-comp")))]
89 {
90 Runtime::Tokio
91 }
92
93 #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))]
94 {
95 Runtime::AsyncStd
96 }
97
98 #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))]
99 {
100 if ::tokio::runtime::Handle::try_current().is_ok() {
101 Runtime::Tokio
102 } else {
103 Runtime::AsyncStd
104 }
105 }
106
107 #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
108 {
109 compile_error!("tokio-comp or async-std-comp features required for aio feature")
110 }
111 }
112
113 #[allow(dead_code)]
114 fn spawn(&self, f: impl Future<Output = ()> + Send + 'static) {
115 match self {
116 #[cfg(feature = "tokio-comp")]
117 Runtime::Tokio => tokio::Tokio::spawn(f),
118 #[cfg(feature = "async-std-comp")]
119 Runtime::AsyncStd => async_std::AsyncStd::spawn(f),
120 }
121 }
122}
123
124pub trait AsyncStream: AsyncRead + AsyncWrite {}
126impl<S> AsyncStream for S where S: AsyncRead + AsyncWrite {}
127
128pub struct PubSub<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
130
131pub struct Monitor<C = Pin<Box<dyn AsyncStream + Send + Sync>>>(Connection<C>);
133
134impl<C> PubSub<C>
135where
136 C: Unpin + AsyncRead + AsyncWrite + Send,
137{
138 fn new(con: Connection<C>) -> Self {
139 Self(con)
140 }
141
142 pub async fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
144 Ok(cmd("SUBSCRIBE")
145 .arg(channel)
146 .query_async(&mut self.0)
147 .await?)
148 }
149
150 pub async fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
152 Ok(cmd("PSUBSCRIBE")
153 .arg(pchannel)
154 .query_async(&mut self.0)
155 .await?)
156 }
157
158 pub async fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
160 Ok(cmd("UNSUBSCRIBE")
161 .arg(channel)
162 .query_async(&mut self.0)
163 .await?)
164 }
165
166 pub async fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
168 Ok(cmd("PUNSUBSCRIBE")
169 .arg(pchannel)
170 .query_async(&mut self.0)
171 .await?)
172 }
173
174 pub fn on_message(&mut self) -> impl Stream<Item = Msg> + '_ {
179 ValueCodec::default()
180 .framed(&mut self.0.con)
181 .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) }))
182 }
183
184 pub fn into_on_message(self) -> impl Stream<Item = Msg> {
191 ValueCodec::default()
192 .framed(self.0.con)
193 .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) }))
194 }
195
196 pub async fn into_connection(mut self) -> Connection<C> {
198 self.0.exit_pubsub().await.ok();
199
200 self.0
201 }
202}
203
204impl<C> Monitor<C>
205where
206 C: Unpin + AsyncRead + AsyncWrite + Send,
207{
208 pub fn new(con: Connection<C>) -> Self {
210 Self(con)
211 }
212
213 pub async fn monitor(&mut self) -> RedisResult<()> {
215 Ok(cmd("MONITOR").query_async(&mut self.0).await?)
216 }
217
218 pub fn on_message<T: FromRedisValue>(&mut self) -> impl Stream<Item = T> + '_ {
220 ValueCodec::default()
221 .framed(&mut self.0.con)
222 .filter_map(|value| {
223 Box::pin(async move { T::from_redis_value(&value.ok()?.ok()?).ok() })
224 })
225 }
226
227 pub fn into_on_message<T: FromRedisValue>(self) -> impl Stream<Item = T> {
229 ValueCodec::default()
230 .framed(self.0.con)
231 .filter_map(|value| {
232 Box::pin(async move { T::from_redis_value(&value.ok()?.ok()?).ok() })
233 })
234 }
235}
236
237pub struct Connection<C = Pin<Box<dyn AsyncStream + Send + Sync>>> {
239 con: C,
240 buf: Vec<u8>,
241 decoder: combine::stream::Decoder<AnySendSyncPartialState, PointerOffset<[u8]>>,
242 db: i64,
243
244 pubsub: bool,
249}
250
251fn assert_sync<T: Sync>() {}
252
253#[allow(unused)]
254fn test() {
255 assert_sync::<Connection>();
256}
257
258impl<C> Connection<C> {
259 pub(crate) fn map<D>(self, f: impl FnOnce(C) -> D) -> Connection<D> {
260 let Self {
261 con,
262 buf,
263 decoder,
264 db,
265 pubsub,
266 } = self;
267 Connection {
268 con: f(con),
269 buf,
270 decoder,
271 db,
272 pubsub,
273 }
274 }
275}
276
277impl<C> Connection<C>
278where
279 C: Unpin + AsyncRead + AsyncWrite + Send,
280{
281 pub async fn new(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
284 let mut rv = Connection {
285 con,
286 buf: Vec::new(),
287 decoder: combine::stream::Decoder::new(),
288 db: connection_info.db,
289 pubsub: false,
290 };
291 authenticate(connection_info, &mut rv).await?;
292 Ok(rv)
293 }
294
295 pub fn into_pubsub(self) -> PubSub<C> {
297 PubSub::new(self)
298 }
299
300 pub fn into_monitor(self) -> Monitor<C> {
302 Monitor::new(self)
303 }
304
305 async fn read_response(&mut self) -> RedisResult<Value> {
307 crate::parser::parse_redis_value_async(&mut self.decoder, &mut self.con).await
308 }
309
310 async fn exit_pubsub(&mut self) -> RedisResult<()> {
317 let res = self.clear_active_subscriptions().await;
318 if res.is_ok() {
319 self.pubsub = false;
320 } else {
321 self.pubsub = true;
323 }
324
325 res
326 }
327
328 async fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
333 {
339 let unsubscribe = crate::Pipeline::new()
341 .add_command(cmd("UNSUBSCRIBE"))
342 .add_command(cmd("PUNSUBSCRIBE"))
343 .get_packed_pipeline();
344
345 self.con.write_all(&unsubscribe).await?;
347 }
348
349 let mut received_unsub = false;
355 let mut received_punsub = false;
356 loop {
357 let res: (Vec<u8>, (), isize) = from_redis_value(&self.read_response().await?)?;
358
359 match res.0.first() {
360 Some(&b'u') => received_unsub = true,
361 Some(&b'p') => received_punsub = true,
362 _ => (),
363 }
364
365 if received_unsub && received_punsub && res.2 == 0 {
366 break;
367 }
368 }
369
370 Ok(())
373 }
374}
375
#[cfg(feature = "async-std-comp")]
#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))]
impl<C> Connection<async_std::AsyncStdWrapped<C>>
where
    C: Unpin + ::async_std::io::Read + ::async_std::io::Write + Send,
{
    /// Builds a connection from an async-std stream by wrapping it in
    /// `AsyncStdWrapped` and delegating to `Connection::new`.
    pub async fn new_async_std(connection_info: &RedisConnectionInfo, con: C) -> RedisResult<Self> {
        let wrapped = async_std::AsyncStdWrapped::new(con);
        Connection::new(connection_info, wrapped).await
    }
}
388
389pub(crate) async fn connect<C>(connection_info: &ConnectionInfo) -> RedisResult<Connection<C>>
390where
391 C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send,
392{
393 let con = connect_simple::<C>(connection_info).await?;
394 Connection::new(&connection_info.redis, con).await
395}
396
397async fn authenticate<C>(connection_info: &RedisConnectionInfo, con: &mut C) -> RedisResult<()>
398where
399 C: ConnectionLike,
400{
401 if let Some(password) = &connection_info.password {
402 let mut command = cmd("AUTH");
403 if let Some(username) = &connection_info.username {
404 command.arg(username);
405 }
406 match command.arg(password).query_async(con).await {
407 Ok(Value::Okay) => (),
408 Err(e) => {
409 let err_msg = e.detail().ok_or((
410 ErrorKind::AuthenticationFailed,
411 "Password authentication failed",
412 ))?;
413
414 if !err_msg.contains("wrong number of arguments for 'auth' command") {
415 fail!((
416 ErrorKind::AuthenticationFailed,
417 "Password authentication failed",
418 ));
419 }
420
421 let mut command = cmd("AUTH");
422 match command.arg(password).query_async(con).await {
423 Ok(Value::Okay) => (),
424 _ => {
425 fail!((
426 ErrorKind::AuthenticationFailed,
427 "Password authentication failed"
428 ));
429 }
430 }
431 }
432 _ => {
433 fail!((
434 ErrorKind::AuthenticationFailed,
435 "Password authentication failed"
436 ));
437 }
438 }
439 }
440
441 if connection_info.db != 0 {
442 match cmd("SELECT").arg(connection_info.db).query_async(con).await {
443 Ok(Value::Okay) => (),
444 _ => fail!((
445 ErrorKind::ResponseError,
446 "Redis server refused to switch database"
447 )),
448 }
449 }
450
451 Ok(())
452}
453
454pub(crate) async fn connect_simple<T: RedisRuntime>(
455 connection_info: &ConnectionInfo,
456) -> RedisResult<T> {
457 Ok(match connection_info.addr {
458 ConnectionAddr::Tcp(ref host, port) => {
459 let socket_addr = get_socket_addrs(host, port)?;
460 <T>::connect_tcp(socket_addr).await?
461 }
462
463 #[cfg(feature = "tls")]
464 ConnectionAddr::TcpTls {
465 ref host,
466 port,
467 insecure,
468 } => {
469 let socket_addr = get_socket_addrs(host, port)?;
470 <T>::connect_tcp_tls(host, socket_addr, insecure).await?
471 }
472
473 #[cfg(not(feature = "tls"))]
474 ConnectionAddr::TcpTls { .. } => {
475 fail!((
476 ErrorKind::InvalidClientConfig,
477 "Cannot connect to TCP with TLS without the tls feature"
478 ));
479 }
480
481 #[cfg(unix)]
482 ConnectionAddr::Unix(ref path) => <T>::connect_unix(path).await?,
483
484 #[cfg(not(unix))]
485 ConnectionAddr::Unix(_) => {
486 return Err(RedisError::from((
487 ErrorKind::InvalidClientConfig,
488 "Cannot connect to unix sockets \
489 on this platform",
490 )))
491 }
492 })
493}
494
495fn get_socket_addrs(host: &str, port: u16) -> RedisResult<SocketAddr> {
496 let mut socket_addrs = (host, port).to_socket_addrs()?;
497 match socket_addrs.next() {
498 Some(socket_addr) => Ok(socket_addr),
499 None => Err(RedisError::from((
500 ErrorKind::InvalidClientConfig,
501 "No address found for host",
502 ))),
503 }
504}
505
506pub trait ConnectionLike {
508 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>;
511
512 fn req_packed_commands<'a>(
516 &'a mut self,
517 cmd: &'a crate::Pipeline,
518 offset: usize,
519 count: usize,
520 ) -> RedisFuture<'a, Vec<Value>>;
521
522 fn get_db(&self) -> i64;
527}
528
529impl<C> ConnectionLike for Connection<C>
530where
531 C: Unpin + AsyncRead + AsyncWrite + Send,
532{
533 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
534 (async move {
535 if self.pubsub {
536 self.exit_pubsub().await?;
537 }
538 self.buf.clear();
539 cmd.write_packed_command(&mut self.buf);
540 self.con.write_all(&self.buf).await?;
541 self.read_response().await
542 })
543 .boxed()
544 }
545
546 fn req_packed_commands<'a>(
547 &'a mut self,
548 cmd: &'a crate::Pipeline,
549 offset: usize,
550 count: usize,
551 ) -> RedisFuture<'a, Vec<Value>> {
552 (async move {
553 if self.pubsub {
554 self.exit_pubsub().await?;
555 }
556
557 self.buf.clear();
558 cmd.write_packed_pipeline(&mut self.buf);
559 self.con.write_all(&self.buf).await?;
560
561 let mut first_err = None;
562
563 for _ in 0..offset {
564 let response = self.read_response().await;
565 if let Err(err) = response {
566 if first_err.is_none() {
567 first_err = Some(err);
568 }
569 }
570 }
571
572 let mut rv = Vec::with_capacity(count);
573 for _ in 0..count {
574 let response = self.read_response().await;
575 match response {
576 Ok(item) => {
577 rv.push(item);
578 }
579 Err(err) => {
580 if first_err.is_none() {
581 first_err = Some(err);
582 }
583 }
584 }
585 }
586
587 if let Some(err) = first_err {
588 Err(err)
589 } else {
590 Ok(rv)
591 }
592 })
593 .boxed()
594 }
595
596 fn get_db(&self) -> i64 {
597 self.db
598 }
599}
600
601type PipelineOutput<O, E> = oneshot::Sender<Result<Vec<O>, E>>;
603
604struct InFlight<O, E> {
605 output: PipelineOutput<O, E>,
606 response_count: usize,
607 buffer: Vec<O>,
608}
609
610struct PipelineMessage<S, I, E> {
612 input: S,
613 output: PipelineOutput<I, E>,
614 response_count: usize,
615}
616
617struct Pipeline<SinkItem, I, E>(mpsc::Sender<PipelineMessage<SinkItem, I, E>>);
622
623impl<SinkItem, I, E> Clone for Pipeline<SinkItem, I, E> {
624 fn clone(&self) -> Self {
625 Pipeline(self.0.clone())
626 }
627}
628
629pin_project! {
630 struct PipelineSink<T, I, E> {
631 #[pin]
632 sink_stream: T,
633 in_flight: VecDeque<InFlight<I, E>>,
634 error: Option<E>,
635 }
636}
637
638impl<T, I, E> PipelineSink<T, I, E>
639where
640 T: Stream<Item = Result<I, E>> + 'static,
641{
642 fn new<SinkItem>(sink_stream: T) -> Self
643 where
644 T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
645 {
646 PipelineSink {
647 sink_stream,
648 in_flight: VecDeque::new(),
649 error: None,
650 }
651 }
652
653 fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll<Result<(), ()>> {
655 loop {
656 if self.in_flight.is_empty() {
658 return Poll::Ready(Ok(()));
659 }
660 let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) {
661 Some(result) => result,
662 None => return Poll::Ready(Err(())),
665 };
666 self.as_mut().send_result(item);
667 }
668 }
669
670 fn send_result(self: Pin<&mut Self>, result: Result<I, E>) {
671 let self_ = self.project();
672 let response = {
673 let entry = match self_.in_flight.front_mut() {
674 Some(entry) => entry,
675 None => return,
676 };
677 match result {
678 Ok(item) => {
679 entry.buffer.push(item);
680 if entry.response_count > entry.buffer.len() {
681 return;
683 }
684 Ok(mem::take(&mut entry.buffer))
685 }
686 Err(err) => Err(err),
688 }
689 };
690
691 let entry = self_.in_flight.pop_front().unwrap();
692 entry.output.send(response).ok();
696 }
697}
698
699impl<SinkItem, T, I, E> Sink<PipelineMessage<SinkItem, I, E>> for PipelineSink<T, I, E>
700where
701 T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
702{
703 type Error = ();
704
705 fn poll_ready(
707 mut self: Pin<&mut Self>,
708 cx: &mut task::Context,
709 ) -> Poll<Result<(), Self::Error>> {
710 match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) {
711 Ok(()) => Ok(()).into(),
712 Err(err) => {
713 *self.project().error = Some(err);
714 Ok(()).into()
715 }
716 }
717 }
718
719 fn start_send(
720 mut self: Pin<&mut Self>,
721 PipelineMessage {
722 input,
723 output,
724 response_count,
725 }: PipelineMessage<SinkItem, I, E>,
726 ) -> Result<(), Self::Error> {
727 if output.is_closed() {
731 return Ok(());
732 }
733
734 let self_ = self.as_mut().project();
735
736 if let Some(err) = self_.error.take() {
737 let _ = output.send(Err(err));
738 return Err(());
739 }
740
741 match self_.sink_stream.start_send(input) {
742 Ok(()) => {
743 self_.in_flight.push_back(InFlight {
744 output,
745 response_count,
746 buffer: Vec::new(),
747 });
748 Ok(())
749 }
750 Err(err) => {
751 let _ = output.send(Err(err));
752 Err(())
753 }
754 }
755 }
756
757 fn poll_flush(
758 mut self: Pin<&mut Self>,
759 cx: &mut task::Context,
760 ) -> Poll<Result<(), Self::Error>> {
761 ready!(self
762 .as_mut()
763 .project()
764 .sink_stream
765 .poll_flush(cx)
766 .map_err(|err| {
767 self.as_mut().send_result(Err(err));
768 }))?;
769 self.poll_read(cx)
770 }
771
772 fn poll_close(
773 mut self: Pin<&mut Self>,
774 cx: &mut task::Context,
775 ) -> Poll<Result<(), Self::Error>> {
776 if !self.in_flight.is_empty() {
779 ready!(self.as_mut().poll_flush(cx))?;
780 }
781 let this = self.as_mut().project();
782 this.sink_stream.poll_close(cx).map_err(|err| {
783 self.send_result(Err(err));
784 })
785 }
786}
787
788impl<SinkItem, I, E> Pipeline<SinkItem, I, E>
789where
790 SinkItem: Send + 'static,
791 I: Send + 'static,
792 E: Send + 'static,
793{
794 fn new<T>(sink_stream: T) -> (Self, impl Future<Output = ()>)
795 where
796 T: Sink<SinkItem, Error = E> + Stream<Item = Result<I, E>> + 'static,
797 T: Send + 'static,
798 T::Item: Send,
799 T::Error: Send,
800 T::Error: ::std::fmt::Debug,
801 {
802 const BUFFER_SIZE: usize = 50;
803 let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE);
804 let f = stream::poll_fn(move |cx| receiver.poll_recv(cx))
805 .map(Ok)
806 .forward(PipelineSink::new::<SinkItem>(sink_stream))
807 .map(|_| ());
808 (Pipeline(sender), f)
809 }
810
811 async fn send(&mut self, item: SinkItem) -> Result<I, Option<E>> {
813 self.send_recv_multiple(item, 1)
814 .await
815 .map(|mut item| item.pop().unwrap())
817 }
818
819 async fn send_recv_multiple(
820 &mut self,
821 input: SinkItem,
822 count: usize,
823 ) -> Result<Vec<I>, Option<E>> {
824 let (sender, receiver) = oneshot::channel();
825
826 self.0
827 .send(PipelineMessage {
828 input,
829 response_count: count,
830 output: sender,
831 })
832 .await
833 .map_err(|_| None)?;
834 match receiver.await {
835 Ok(result) => result.map_err(Some),
836 Err(_) => {
837 Err(None)
840 }
841 }
842 }
843}
844
845#[derive(Clone)]
848pub struct MultiplexedConnection {
849 pipeline: Pipeline<Vec<u8>, Value, RedisError>,
850 db: i64,
851}
852
853impl MultiplexedConnection {
854 pub async fn new<C>(
857 connection_info: &RedisConnectionInfo,
858 stream: C,
859 ) -> RedisResult<(Self, impl Future<Output = ()>)>
860 where
861 C: Unpin + AsyncRead + AsyncWrite + Send + 'static,
862 {
863 fn boxed(
864 f: impl Future<Output = ()> + Send + 'static,
865 ) -> Pin<Box<dyn Future<Output = ()> + Send>> {
866 Box::pin(f)
867 }
868
869 #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))]
870 compile_error!("tokio-comp or async-std-comp features required for aio feature");
871
872 let codec = ValueCodec::default()
873 .framed(stream)
874 .and_then(|msg| async move { msg });
875 let (pipeline, driver) = Pipeline::new(codec);
876 let driver = boxed(driver);
877 let mut con = MultiplexedConnection {
878 pipeline,
879 db: connection_info.db,
880 };
881 let driver = {
882 let auth = authenticate(connection_info, &mut con);
883 futures_util::pin_mut!(auth);
884
885 match futures_util::future::select(auth, driver).await {
886 futures_util::future::Either::Left((result, driver)) => {
887 result?;
888 driver
889 }
890 futures_util::future::Either::Right(((), _)) => {
891 unreachable!("Multiplexed connection driver unexpectedly terminated")
892 }
893 }
894 };
895 Ok((con, driver))
896 }
897}
898
899impl ConnectionLike for MultiplexedConnection {
900 fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
901 (async move {
902 let value = self
903 .pipeline
904 .send(cmd.get_packed_command())
905 .await
906 .map_err(|err| {
907 err.unwrap_or_else(|| {
908 RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))
909 })
910 })?;
911 Ok(value)
912 })
913 .boxed()
914 }
915
916 fn req_packed_commands<'a>(
917 &'a mut self,
918 cmd: &'a crate::Pipeline,
919 offset: usize,
920 count: usize,
921 ) -> RedisFuture<'a, Vec<Value>> {
922 (async move {
923 let mut value = self
924 .pipeline
925 .send_recv_multiple(cmd.get_packed_pipeline(), offset + count)
926 .await
927 .map_err(|err| {
928 err.unwrap_or_else(|| {
929 RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))
930 })
931 })?;
932
933 value.drain(..offset);
934 Ok(value)
935 })
936 .boxed()
937 }
938
939 fn get_db(&self) -> i64 {
940 self.db
941 }
942}
943
#[cfg(feature = "connection-manager")]
mod connection_manager {
    use super::*;

    use std::sync::Arc;

    use arc_swap::{self, ArcSwap};
    use futures::future::{self, Shared};
    use futures_util::future::BoxFuture;

    use crate::Client;

    /// A `MultiplexedConnection` wrapper that replaces its connection in the
    /// background after a failure.
    ///
    /// The request that hits the failure still returns its error; the
    /// replacement is used by later requests. Clones share the same
    /// connection slot.
    #[derive(Clone)]
    pub struct ConnectionManager {
        /// Used to open replacement connections.
        client: Client,
        /// The current connection, stored as a shared future so that a
        /// reconnect in progress can be awaited by any number of requests.
        connection: Arc<ArcSwap<SharedRedisFuture<MultiplexedConnection>>>,

        /// Runtime on which reconnect futures are spawned.
        runtime: Runtime,
    }

    /// A `RedisResult` whose error is reference counted, so the result is
    /// `Clone` whenever `T` is -- a requirement for `Shared` futures.
    type CloneableRedisResult<T> = Result<T, Arc<RedisError>>;

    /// A future that many tasks can await, each getting a clone of the result.
    type SharedRedisFuture<T> = Shared<BoxFuture<'static, CloneableRedisResult<T>>>;

    impl ConnectionManager {
        /// Opens the initial connection and wraps it in a manager.
        ///
        /// # Errors
        /// Returns the error of the initial connection attempt; no retry is
        /// made here.
        pub async fn new(client: Client) -> RedisResult<Self> {
            // Determined up front so that reconnects can be spawned later.
            let runtime = Runtime::locate();
            let connection = client.get_multiplexed_async_connection().await?;

            Ok(Self {
                client,
                connection: Arc::new(ArcSwap::from_pointee(
                    future::ok(connection).boxed().shared(),
                )),
                runtime,
            })
        }

        /// Replaces the connection with a fresh connect attempt, provided the
        /// stored connection is still `current`.
        ///
        /// The compare-and-swap ensures that when several requests fail on
        /// the same connection only one of them starts a reconnect.
        fn reconnect(
            &self,
            current: arc_swap::Guard<Arc<SharedRedisFuture<MultiplexedConnection>>>,
        ) {
            let client = self.client.clone();
            let new_connection: SharedRedisFuture<MultiplexedConnection> =
                async move { Ok(client.get_multiplexed_async_connection().await?) }
                    .boxed()
                    .shared();

            let new_connection_arc = Arc::new(new_connection.clone());
            let prev = self
                .connection
                .compare_and_swap(&current, new_connection_arc);

            // `prev` equals `current` exactly when the swap took place.
            if Arc::ptr_eq(&prev, &current) {
                // Start driving the connect attempt right away instead of
                // waiting for the next request to poll it.
                self.runtime.spawn(new_connection.map(|_| ()));
            }
        }
    }

    /// Starts a reconnect when `$result` is an error that reports the
    /// connection as dropped. Does not return; the caller still hands
    /// `$result` back.
    macro_rules! reconnect_if_dropped {
        ($self:expr, $result:expr, $current:expr) => {
            if let Err(ref e) = $result {
                if e.is_connection_dropped() {
                    $self.reconnect($current);
                }
            }
        };
    }

    /// Returns early from the enclosing function when `$result` is an error,
    /// starting a reconnect first if it is an I/O error.
    macro_rules! reconnect_if_io_error {
        ($self:expr, $result:expr, $current:expr) => {
            if let Err(e) = $result {
                if e.is_io_error() {
                    $self.reconnect($current);
                }
                return Err(e);
            }
        };
    }

    impl ConnectionLike for ConnectionManager {
        /// Sends `cmd` over the current connection, triggering a background
        /// reconnect when the connection turns out to be unusable.
        fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> {
            (async move {
                // Remember which connection this request used so that a
                // reconnect only replaces that one.
                let guard = self.connection.load();
                let connection_result = (**guard)
                    .clone()
                    .await
                    .map_err(|e| e.clone_mostly("Reconnecting failed"));
                reconnect_if_io_error!(self, connection_result, guard);
                let result = connection_result?.req_packed_command(cmd).await;
                reconnect_if_dropped!(self, &result, guard);
                result
            })
            .boxed()
        }

        /// Pipeline counterpart of `req_packed_command`, with the same
        /// reconnect behaviour.
        fn req_packed_commands<'a>(
            &'a mut self,
            cmd: &'a crate::Pipeline,
            offset: usize,
            count: usize,
        ) -> RedisFuture<'a, Vec<Value>> {
            (async move {
                // Remember which connection this request used so that a
                // reconnect only replaces that one.
                let guard = self.connection.load();
                let connection_result = (**guard)
                    .clone()
                    .await
                    .map_err(|e| e.clone_mostly("Reconnecting failed"));
                reconnect_if_io_error!(self, connection_result, guard);
                let result = connection_result?
                    .req_packed_commands(cmd, offset, count)
                    .await;
                reconnect_if_dropped!(self, &result, guard);
                result
            })
            .boxed()
        }

        /// Returns the database index configured on the client.
        fn get_db(&self) -> i64 {
            self.client.connection_info().redis.db
        }
    }
}
1118
1119#[cfg(feature = "connection-manager")]
1120#[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))]
1121pub use connection_manager::ConnectionManager;