Skip to main content

redis/
connection.rs

1use std::collections::VecDeque;
2use std::fmt;
3use std::io::{self, Write};
4use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
5use std::ops::DerefMut;
6use std::path::PathBuf;
7use std::str::{from_utf8, FromStr};
8use std::time::Duration;
9
10use crate::cmd::{cmd, pipe, Cmd};
11use crate::parser::Parser;
12use crate::pipeline::Pipeline;
13use crate::types::{
14    from_redis_value, ErrorKind, FromRedisValue, RedisError, RedisResult, ToRedisArgs, Value,
15};
16
17#[cfg(unix)]
18use crate::types::HashMap;
19#[cfg(unix)]
20use std::os::unix::net::UnixStream;
21
22#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
23use native_tls::{TlsConnector, TlsStream};
24
25#[cfg(feature = "tls-rustls")]
26use rustls::{RootCertStore, StreamOwned};
27#[cfg(feature = "tls-rustls")]
28use std::{convert::TryInto, sync::Arc};
29
30#[cfg(feature = "tls-rustls-webpki-roots")]
31use rustls::OwnedTrustAnchor;
32#[cfg(feature = "tls-rustls-webpki-roots")]
33use webpki_roots::TLS_SERVER_ROOTS;
34
35#[cfg(all(
36    feature = "tls-rustls",
37    not(feature = "tls-native-tls"),
38    not(feature = "tls-rustls-webpki-roots")
39))]
40use rustls_native_certs::load_native_certs;
41
42#[cfg(feature = "tls-rustls")]
43use crate::tls::TlsConnParams;
44
/// Placeholder for the TLS connection parameters, compiled only when the
/// `tls-rustls` feature is disabled so that `ConnectionAddr::TcpTls` can
/// still name the type.
///
/// Marked `#[non_exhaustive]` to prevent construction outside this crate.
#[cfg(not(feature = "tls-rustls"))]
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct TlsConnParams;
50
/// Port used for TCP connections when the redis URL does not specify one.
static DEFAULT_PORT: u16 = 6379;
52
/// Opens a blocking TCP connection to `addr` (`(host, port)`).
///
/// With the `tcp_nodelay` feature the socket gets `TCP_NODELAY` set; with the
/// `keep-alive` feature TCP keepalive is enabled using `socket2`'s default
/// (system) parameters.
///
/// # Errors
///
/// Returns the `io::Error` from connecting or from applying a socket option.
#[inline(always)]
fn connect_tcp(addr: (&str, u16)) -> io::Result<TcpStream> {
    let stream = TcpStream::connect(addr)?;

    #[cfg(feature = "tcp_nodelay")]
    stream.set_nodelay(true)?;

    #[cfg(feature = "keep-alive")]
    let stream: TcpStream = {
        // No explicit intervals are configured: rely on the system defaults for now.
        const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new();
        // Round-trip through `socket2::Socket` to reach the keepalive setter.
        let raw = socket2::Socket::from(stream);
        raw.set_tcp_keepalive(&KEEP_ALIVE)?;
        raw.into()
    };

    Ok(stream)
}
72
/// Opens a TCP connection to the already-resolved `addr`, giving up after `timeout`.
///
/// With the `tcp_nodelay` feature the socket gets `TCP_NODELAY` set; with the
/// `keep-alive` feature TCP keepalive is enabled using `socket2`'s default
/// (system) parameters.
///
/// # Errors
///
/// Returns the `io::Error` from connecting or from applying a socket option.
#[inline(always)]
fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
    let stream = TcpStream::connect_timeout(addr, timeout)?;

    #[cfg(feature = "tcp_nodelay")]
    stream.set_nodelay(true)?;

    #[cfg(feature = "keep-alive")]
    let stream: TcpStream = {
        // No explicit intervals are configured: rely on the system defaults for now.
        const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new();
        // Round-trip through `socket2::Socket` to reach the keepalive setter.
        let raw = socket2::Socket::from(stream);
        raw.set_tcp_keepalive(&KEEP_ALIVE)?;
        raw.into()
    };

    Ok(stream)
}
92
93/// This function takes a redis URL string and parses it into a URL
94/// as used by rust-url.  This is necessary as the default parser does
95/// not understand how redis URLs function.
96pub fn parse_redis_url(input: &str) -> Option<url::Url> {
97    match url::Url::parse(input) {
98        Ok(result) => match result.scheme() {
99            "redis" | "rediss" | "redis+unix" | "unix" => Some(result),
100            _ => None,
101        },
102        Err(_) => None,
103    }
104}
105
/// Selects whether a TLS connection verifies the server certificate.
///
/// See the `insecure` field of [`ConnectionAddr::TcpTls`] for what the
/// insecure mode implies.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TlsMode {
    /// Verify the server certificate.
    Secure,
    /// Do not verify the server certificate.
    Insecure,
}
115
116/// Defines the connection address.
117///
118/// Not all connection addresses are supported on all platforms.  For instance
119/// to connect to a unix socket you need to run this on an operating system
120/// that supports them.
121#[derive(Clone, Debug)]
122pub enum ConnectionAddr {
123    /// Format for this is `(host, port)`.
124    Tcp(String, u16),
125    /// Format for this is `(host, port)`.
126    TcpTls {
127        /// Hostname
128        host: String,
129        /// Port
130        port: u16,
131        /// Disable hostname verification when connecting.
132        ///
133        /// # Warning
134        ///
135        /// You should think very carefully before you use this method. If hostname
136        /// verification is not used, any valid certificate for any site will be
137        /// trusted for use from any other. This introduces a significant
138        /// vulnerability to man-in-the-middle attacks.
139        insecure: bool,
140
141        /// TLS certificates and client key.
142        tls_params: Option<TlsConnParams>,
143    },
144    /// Format for this is the path to the unix socket.
145    Unix(PathBuf),
146}
147
148impl PartialEq for ConnectionAddr {
149    fn eq(&self, other: &Self) -> bool {
150        match (self, other) {
151            (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
152                host1 == host2 && port1 == port2
153            }
154            (
155                ConnectionAddr::TcpTls {
156                    host: host1,
157                    port: port1,
158                    insecure: insecure1,
159                    tls_params: _,
160                },
161                ConnectionAddr::TcpTls {
162                    host: host2,
163                    port: port2,
164                    insecure: insecure2,
165                    tls_params: _,
166                },
167            ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
168            (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
169            _ => false,
170        }
171    }
172}
173
174impl Eq for ConnectionAddr {}
175
176impl ConnectionAddr {
177    /// Checks if this address is supported.
178    ///
179    /// Because not all platforms support all connection addresses this is a
180    /// quick way to figure out if a connection method is supported.  Currently
181    /// this only affects unix connections which are only supported on unix
182    /// platforms and on older versions of rust also require an explicit feature
183    /// to be enabled.
184    pub fn is_supported(&self) -> bool {
185        match *self {
186            ConnectionAddr::Tcp(_, _) => true,
187            ConnectionAddr::TcpTls { .. } => {
188                cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
189            }
190            ConnectionAddr::Unix(_) => cfg!(unix),
191        }
192    }
193}
194
195impl fmt::Display for ConnectionAddr {
196    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197        // Cluster::get_connection_info depends on the return value from this function
198        match *self {
199            ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
200            ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
201            ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
202        }
203    }
204}
205
206/// Holds the connection information that redis should use for connecting.
207#[derive(Clone, Debug)]
208pub struct ConnectionInfo {
209    /// A connection address for where to connect to.
210    pub addr: ConnectionAddr,
211
212    /// A boxed connection address for where to connect to.
213    pub redis: RedisConnectionInfo,
214}
215
/// Redis-specific settings that do not depend on how the connection is transported.
///
/// The `Default` value is database `0` with no username and no password.
#[derive(Debug, Clone, Default)]
pub struct RedisConnectionInfo {
    /// The database number to use.  This is usually `0`.
    pub db: i64,
    /// Optional username to authenticate with.
    pub username: Option<String>,
    /// Optional password to authenticate with.
    pub password: Option<String>,
}
226
227impl FromStr for ConnectionInfo {
228    type Err = RedisError;
229
230    fn from_str(s: &str) -> Result<Self, Self::Err> {
231        s.into_connection_info()
232    }
233}
234
235/// Converts an object into a connection info struct.  This allows the
236/// constructor of the client to accept connection information in a
237/// range of different formats.
238pub trait IntoConnectionInfo {
239    /// Converts the object into a connection info object.
240    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
241}
242
243impl IntoConnectionInfo for ConnectionInfo {
244    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
245        Ok(self)
246    }
247}
248
249/// URL format: `{redis|rediss}://[<username>][:<password>@]<hostname>[:port][/<db>]`
250///
251/// - Basic: `redis://127.0.0.1:6379`
252/// - Username & Password: `redis://user:password@127.0.0.1:6379`
253/// - Password only: `redis://:password@127.0.0.1:6379`
254/// - Specifying DB: `redis://127.0.0.1:6379/0`
255/// - Enabling TLS: `rediss://127.0.0.1:6379`
256/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
257impl IntoConnectionInfo for &str {
258    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
259        match parse_redis_url(self) {
260            Some(u) => u.into_connection_info(),
261            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
262        }
263    }
264}
265
266impl<T> IntoConnectionInfo for (T, u16)
267where
268    T: Into<String>,
269{
270    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
271        Ok(ConnectionInfo {
272            addr: ConnectionAddr::Tcp(self.0.into(), self.1),
273            redis: RedisConnectionInfo::default(),
274        })
275    }
276}
277
278/// URL format: `{redis|rediss}://[<username>][:<password>@]<hostname>[:port][/<db>]`
279///
280/// - Basic: `redis://127.0.0.1:6379`
281/// - Username & Password: `redis://user:password@127.0.0.1:6379`
282/// - Password only: `redis://:password@127.0.0.1:6379`
283/// - Specifying DB: `redis://127.0.0.1:6379/0`
284/// - Enabling TLS: `rediss://127.0.0.1:6379`
285/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
286impl IntoConnectionInfo for String {
287    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
288        match parse_redis_url(&self) {
289            Some(u) => u.into_connection_info(),
290            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
291        }
292    }
293}
294
295fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
296    let host = match url.host() {
297        Some(host) => {
298            // Here we manually match host's enum arms and call their to_string().
299            // Because url.host().to_string() will add `[` and `]` for ipv6:
300            // https://docs.rs/url/latest/src/url/host.rs.html#170
301            // And these brackets will break host.parse::<Ipv6Addr>() when
302            // `client.open()` - `ActualConnection::new()` - `addr.to_socket_addrs()`:
303            // https://doc.rust-lang.org/src/std/net/addr.rs.html#963
304            // https://doc.rust-lang.org/src/std/net/parser.rs.html#158
305            // IpAddr string with brackets can ONLY parse to SocketAddrV6:
306            // https://doc.rust-lang.org/src/std/net/parser.rs.html#255
307            // But if we call Ipv6Addr.to_string directly, it follows rfc5952 without brackets:
308            // https://doc.rust-lang.org/src/std/net/ip.rs.html#1755
309            match host {
310                url::Host::Domain(path) => path.to_string(),
311                url::Host::Ipv4(v4) => v4.to_string(),
312                url::Host::Ipv6(v6) => v6.to_string(),
313            }
314        }
315        None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
316    };
317    let port = url.port().unwrap_or(DEFAULT_PORT);
318    let addr = if url.scheme() == "rediss" {
319        #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
320        {
321            match url.fragment() {
322                Some("insecure") => ConnectionAddr::TcpTls {
323                    host,
324                    port,
325                    insecure: true,
326                    tls_params: None,
327                },
328                Some(_) => fail!((
329                    ErrorKind::InvalidClientConfig,
330                    "only #insecure is supported as URL fragment"
331                )),
332                _ => ConnectionAddr::TcpTls {
333                    host,
334                    port,
335                    insecure: false,
336                    tls_params: None,
337                },
338            }
339        }
340
341        #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
342        fail!((
343            ErrorKind::InvalidClientConfig,
344            "can't connect with TLS, the feature is not enabled"
345        ));
346    } else {
347        ConnectionAddr::Tcp(host, port)
348    };
349    Ok(ConnectionInfo {
350        addr,
351        redis: RedisConnectionInfo {
352            db: match url.path().trim_matches('/') {
353                "" => 0,
354                path => unwrap_or!(
355                    path.parse::<i64>().ok(),
356                    fail!((ErrorKind::InvalidClientConfig, "Invalid database number"))
357                ),
358            },
359            username: if url.username().is_empty() {
360                None
361            } else {
362                match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
363                    Ok(decoded) => Some(decoded.into_owned()),
364                    Err(_) => fail!((
365                        ErrorKind::InvalidClientConfig,
366                        "Username is not valid UTF-8 string"
367                    )),
368                }
369            },
370            password: match url.password() {
371                Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
372                    Ok(decoded) => Some(decoded.into_owned()),
373                    Err(_) => fail!((
374                        ErrorKind::InvalidClientConfig,
375                        "Password is not valid UTF-8 string"
376                    )),
377                },
378                None => None,
379            },
380        },
381    })
382}
383
384#[cfg(unix)]
385fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
386    let query: HashMap<_, _> = url.query_pairs().collect();
387    Ok(ConnectionInfo {
388        addr: ConnectionAddr::Unix(unwrap_or!(
389            url.to_file_path().ok(),
390            fail!((ErrorKind::InvalidClientConfig, "Missing path"))
391        )),
392        redis: RedisConnectionInfo {
393            db: match query.get("db") {
394                Some(db) => unwrap_or!(
395                    db.parse::<i64>().ok(),
396                    fail!((ErrorKind::InvalidClientConfig, "Invalid database number"))
397                ),
398                None => 0,
399            },
400            username: query.get("user").map(|username| username.to_string()),
401            password: query.get("pass").map(|password| password.to_string()),
402        },
403    })
404}
405
/// Non-unix stand-in for the unix-socket URL conversion: always fails with
/// `InvalidClientConfig`, since unix sockets do not exist on this platform.
#[cfg(not(unix))]
fn url_to_unix_connection_info(_url: url::Url) -> RedisResult<ConnectionInfo> {
    fail!((
        ErrorKind::InvalidClientConfig,
        "Unix sockets are not available on this platform."
    ))
}
413
414impl IntoConnectionInfo for url::Url {
415    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
416        match self.scheme() {
417            "redis" | "rediss" => url_to_tcp_connection_info(self),
418            "unix" | "redis+unix" => url_to_unix_connection_info(self),
419            _ => fail!((
420                ErrorKind::InvalidClientConfig,
421                "URL provided is not a redis URL"
422            )),
423        }
424    }
425}
426
/// A plain TCP transport together with its liveness flag.
struct TcpConnection {
    /// The underlying socket (written to as well as read from, despite the name).
    reader: TcpStream,
    /// Cleared once a write fails with a "connection dropped" error.
    open: bool,
}
431
/// A TLS-over-TCP transport backed by `native-tls`, together with its liveness flag.
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
struct TcpNativeTlsConnection {
    /// The TLS stream wrapping the TCP socket (used for writes as well).
    reader: TlsStream<TcpStream>,
    /// Cleared once a write fails with a "connection dropped" error.
    open: bool,
}
437
/// A TLS-over-TCP transport backed by `rustls`, together with its liveness flag.
#[cfg(feature = "tls-rustls")]
struct TcpRustlsConnection {
    /// The rustls client session paired with the TCP socket (used for writes as well).
    reader: StreamOwned<rustls::ClientConnection, TcpStream>,
    /// Cleared once a write fails with a "connection dropped" error.
    open: bool,
}
443
/// A unix-domain-socket transport together with its liveness flag.
#[cfg(unix)]
struct UnixConnection {
    /// The underlying unix socket.
    sock: UnixStream,
    /// Cleared once a write fails with a "connection dropped" error.
    open: bool,
}
449
450enum ActualConnection {
451    Tcp(TcpConnection),
452    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
453    TcpNativeTls(Box<TcpNativeTlsConnection>),
454    #[cfg(feature = "tls-rustls")]
455    TcpRustls(Box<TcpRustlsConnection>),
456    #[cfg(unix)]
457    Unix(UnixConnection),
458}
459
/// Certificate verifier that accepts every server certificate.
///
/// # Warning
///
/// Installing this verifier disables server authentication entirely. It is
/// only compiled with the `tls-rustls-insecure` feature.
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification;

#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::ServerCertVerifier for NoCertificateVerification {
    /// Ignores every argument and unconditionally reports the certificate as verified.
    fn verify_server_cert(
        &self,
        _cert: &rustls::Certificate,
        _chain: &[rustls::Certificate],
        _name: &rustls::ServerName,
        _sct_list: &mut dyn Iterator<Item = &[u8]>,
        _ocsp_response: &[u8],
        _checked_at: std::time::SystemTime,
    ) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
        let verified = rustls::client::ServerCertVerified::assertion();
        Ok(verified)
    }
}
477
478/// Represents a stateful redis TCP connection.
479pub struct Connection {
480    con: ActualConnection,
481    parser: Parser,
482    db: i64,
483
484    /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`.
485    ///
486    /// This flag is checked when attempting to send a command, and if it's raised, we attempt to
487    /// exit the pubsub state before executing the new request.
488    pubsub: bool,
489}
490
491/// Represents a pubsub connection.
492pub struct PubSub<'a> {
493    con: &'a mut Connection,
494    waiting_messages: VecDeque<Msg>,
495}
496
497/// Represents a pubsub message.
498#[derive(Debug)]
499pub struct Msg {
500    payload: Value,
501    channel: Value,
502    pattern: Option<Value>,
503}
504
505impl ActualConnection {
506    pub fn new(addr: &ConnectionAddr, timeout: Option<Duration>) -> RedisResult<ActualConnection> {
507        Ok(match *addr {
508            ConnectionAddr::Tcp(ref host, ref port) => {
509                let addr = (host.as_str(), *port);
510                let tcp = match timeout {
511                    None => connect_tcp(addr)?,
512                    Some(timeout) => {
513                        let mut tcp = None;
514                        let mut last_error = None;
515                        for addr in addr.to_socket_addrs()? {
516                            match connect_tcp_timeout(&addr, timeout) {
517                                Ok(l) => {
518                                    tcp = Some(l);
519                                    break;
520                                }
521                                Err(e) => {
522                                    last_error = Some(e);
523                                }
524                            };
525                        }
526                        match (tcp, last_error) {
527                            (Some(tcp), _) => tcp,
528                            (None, Some(e)) => {
529                                fail!(e);
530                            }
531                            (None, None) => {
532                                fail!((
533                                    ErrorKind::InvalidClientConfig,
534                                    "could not resolve to any addresses"
535                                ));
536                            }
537                        }
538                    }
539                };
540                ActualConnection::Tcp(TcpConnection {
541                    reader: tcp,
542                    open: true,
543                })
544            }
545            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
546            ConnectionAddr::TcpTls {
547                ref host,
548                port,
549                insecure,
550                ..
551            } => {
552                let tls_connector = if insecure {
553                    TlsConnector::builder()
554                        .danger_accept_invalid_certs(true)
555                        .danger_accept_invalid_hostnames(true)
556                        .use_sni(false)
557                        .build()?
558                } else {
559                    TlsConnector::new()?
560                };
561                let addr = (host.as_str(), port);
562                let tls = match timeout {
563                    None => {
564                        let tcp = connect_tcp(addr)?;
565                        match tls_connector.connect(host, tcp) {
566                            Ok(res) => res,
567                            Err(e) => {
568                                fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string()));
569                            }
570                        }
571                    }
572                    Some(timeout) => {
573                        let mut tcp = None;
574                        let mut last_error = None;
575                        for addr in (host.as_str(), port).to_socket_addrs()? {
576                            match connect_tcp_timeout(&addr, timeout) {
577                                Ok(l) => {
578                                    tcp = Some(l);
579                                    break;
580                                }
581                                Err(e) => {
582                                    last_error = Some(e);
583                                }
584                            };
585                        }
586                        match (tcp, last_error) {
587                            (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
588                            (None, Some(e)) => {
589                                fail!(e);
590                            }
591                            (None, None) => {
592                                fail!((
593                                    ErrorKind::InvalidClientConfig,
594                                    "could not resolve to any addresses"
595                                ));
596                            }
597                        }
598                    }
599                };
600                ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
601                    reader: tls,
602                    open: true,
603                }))
604            }
605            #[cfg(feature = "tls-rustls")]
606            ConnectionAddr::TcpTls {
607                ref host,
608                port,
609                insecure,
610                ref tls_params,
611            } => {
612                let host: &str = host;
613                let config = create_rustls_config(insecure, tls_params.clone())?;
614                let conn = rustls::ClientConnection::new(Arc::new(config), host.try_into()?)?;
615                let reader = match timeout {
616                    None => {
617                        let tcp = connect_tcp((host, port))?;
618                        StreamOwned::new(conn, tcp)
619                    }
620                    Some(timeout) => {
621                        let mut tcp = None;
622                        let mut last_error = None;
623                        for addr in (host, port).to_socket_addrs()? {
624                            match connect_tcp_timeout(&addr, timeout) {
625                                Ok(l) => {
626                                    tcp = Some(l);
627                                    break;
628                                }
629                                Err(e) => {
630                                    last_error = Some(e);
631                                }
632                            };
633                        }
634                        match (tcp, last_error) {
635                            (Some(tcp), _) => StreamOwned::new(conn, tcp),
636                            (None, Some(e)) => {
637                                fail!(e);
638                            }
639                            (None, None) => {
640                                fail!((
641                                    ErrorKind::InvalidClientConfig,
642                                    "could not resolve to any addresses"
643                                ));
644                            }
645                        }
646                    }
647                };
648
649                ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
650            }
651            #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
652            ConnectionAddr::TcpTls { .. } => {
653                fail!((
654                    ErrorKind::InvalidClientConfig,
655                    "Cannot connect to TCP with TLS without the tls feature"
656                ));
657            }
658            #[cfg(unix)]
659            ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
660                sock: UnixStream::connect(path)?,
661                open: true,
662            }),
663            #[cfg(not(unix))]
664            ConnectionAddr::Unix(ref _path) => {
665                fail!((
666                    ErrorKind::InvalidClientConfig,
667                    "Cannot connect to unix sockets \
668                     on this platform"
669                ));
670            }
671        })
672    }
673
674    pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
675        match *self {
676            ActualConnection::Tcp(ref mut connection) => {
677                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
678                match res {
679                    Err(e) => {
680                        if e.is_connection_dropped() {
681                            connection.open = false;
682                        }
683                        Err(e)
684                    }
685                    Ok(_) => Ok(Value::Okay),
686                }
687            }
688            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
689            ActualConnection::TcpNativeTls(ref mut connection) => {
690                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
691                match res {
692                    Err(e) => {
693                        if e.is_connection_dropped() {
694                            connection.open = false;
695                        }
696                        Err(e)
697                    }
698                    Ok(_) => Ok(Value::Okay),
699                }
700            }
701            #[cfg(feature = "tls-rustls")]
702            ActualConnection::TcpRustls(ref mut connection) => {
703                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
704                match res {
705                    Err(e) => {
706                        if e.is_connection_dropped() {
707                            connection.open = false;
708                        }
709                        Err(e)
710                    }
711                    Ok(_) => Ok(Value::Okay),
712                }
713            }
714            #[cfg(unix)]
715            ActualConnection::Unix(ref mut connection) => {
716                let result = connection.sock.write_all(bytes).map_err(RedisError::from);
717                match result {
718                    Err(e) => {
719                        if e.is_connection_dropped() {
720                            connection.open = false;
721                        }
722                        Err(e)
723                    }
724                    Ok(_) => Ok(Value::Okay),
725                }
726            }
727        }
728    }
729
730    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
731        match *self {
732            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
733                reader.set_write_timeout(dur)?;
734            }
735            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
736            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
737                let reader = &(boxed_tls_connection.reader);
738                reader.get_ref().set_write_timeout(dur)?;
739            }
740            #[cfg(feature = "tls-rustls")]
741            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
742                let reader = &(boxed_tls_connection.reader);
743                reader.get_ref().set_write_timeout(dur)?;
744            }
745            #[cfg(unix)]
746            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
747                sock.set_write_timeout(dur)?;
748            }
749        }
750        Ok(())
751    }
752
753    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
754        match *self {
755            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
756                reader.set_read_timeout(dur)?;
757            }
758            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
759            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
760                let reader = &(boxed_tls_connection.reader);
761                reader.get_ref().set_read_timeout(dur)?;
762            }
763            #[cfg(feature = "tls-rustls")]
764            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
765                let reader = &(boxed_tls_connection.reader);
766                reader.get_ref().set_read_timeout(dur)?;
767            }
768            #[cfg(unix)]
769            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
770                sock.set_read_timeout(dur)?;
771            }
772        }
773        Ok(())
774    }
775
776    pub fn is_open(&self) -> bool {
777        match *self {
778            ActualConnection::Tcp(TcpConnection { open, .. }) => open,
779            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
780            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
781            #[cfg(feature = "tls-rustls")]
782            ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
783            #[cfg(unix)]
784            ActualConnection::Unix(UnixConnection { open, .. }) => open,
785        }
786    }
787}
788
/// Builds the `rustls` client configuration for a TLS connection.
///
/// A default root store is filled from the bundled webpki roots (feature
/// `tls-rustls-webpki-roots`) or otherwise from the native certificates.  When
/// `tls_params` carries its own root store, that one is used instead; when it
/// carries client TLS parameters, client authentication is configured.
///
/// With `insecure` set, SNI is disabled and certificate verification is
/// switched off, which requires the `tls-rustls-insecure` feature.
///
/// # Errors
///
/// Fails when native certificates cannot be loaded or added, when the
/// protocol versions or client certificate are rejected by `rustls`, or with
/// `InvalidClientConfig` when `insecure` is requested without the
/// `tls-rustls-insecure` feature.
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
    insecure: bool,
    tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
    use crate::tls::ClientTlsParams;

    let mut default_roots = RootCertStore::empty();
    #[cfg(feature = "tls-rustls-webpki-roots")]
    default_roots.add_trust_anchors(TLS_SERVER_ROOTS.0.iter().map(|anchor| {
        OwnedTrustAnchor::from_subject_spki_name_constraints(
            anchor.subject,
            anchor.spki,
            anchor.name_constraints,
        )
    }));
    #[cfg(all(feature = "tls-rustls", not(feature = "tls-rustls-webpki-roots")))]
    for native_cert in load_native_certs()? {
        default_roots.add(&rustls::Certificate(native_cert.0))?;
    }

    let builder = rustls::ClientConfig::builder()
        .with_safe_default_cipher_suites()
        .with_safe_default_kx_groups()
        .with_protocol_versions(rustls::ALL_VERSIONS)?;

    let config = match tls_params {
        None => builder
            .with_root_certificates(default_roots)
            .with_no_client_auth(),
        Some(params) => {
            // Caller-supplied roots take precedence over the defaults.
            let roots = params.root_cert_store.unwrap_or(default_roots);
            let builder = builder.with_root_certificates(roots);

            match params.client_tls_params {
                Some(ClientTlsParams {
                    client_cert_chain,
                    client_key,
                }) => builder
                    .with_client_auth_cert(client_cert_chain, client_key)
                    .map_err(|err| {
                        RedisError::from((
                            ErrorKind::InvalidClientConfig,
                            "Unable to build client with TLS parameters provided.",
                            err.to_string(),
                        ))
                    })?,
                None => builder.with_no_client_auth(),
            }
        }
    };

    if !insecure {
        return Ok(config);
    }

    // From here on the caller asked for an insecure connection.
    #[cfg(feature = "tls-rustls-insecure")]
    {
        let mut config = config;
        config.enable_sni = false;
        config
            .dangerous()
            .set_certificate_verifier(Arc::new(NoCertificateVerification));

        Ok(config)
    }
    #[cfg(not(feature = "tls-rustls-insecure"))]
    {
        fail!((
            ErrorKind::InvalidClientConfig,
            "Cannot create insecure client without tls-rustls-insecure feature"
        ));
    }
}
862
863fn connect_auth(con: &mut Connection, connection_info: &RedisConnectionInfo) -> RedisResult<()> {
864    let mut command = cmd("AUTH");
865    if let Some(username) = &connection_info.username {
866        command.arg(username);
867    }
868    let password = connection_info.password.as_ref().unwrap();
869    let err = match command.arg(password).query::<Value>(con) {
870        Ok(Value::Okay) => return Ok(()),
871        Ok(_) => {
872            fail!((
873                ErrorKind::ResponseError,
874                "Redis server refused to authenticate, returns Ok() != Value::Okay"
875            ));
876        }
877        Err(e) => e,
878    };
879    let err_msg = err.detail().ok_or((
880        ErrorKind::AuthenticationFailed,
881        "Password authentication failed",
882    ))?;
883    if !err_msg.contains("wrong number of arguments for 'auth' command") {
884        fail!((
885            ErrorKind::AuthenticationFailed,
886            "Password authentication failed",
887        ));
888    }
889
890    // fallback to AUTH version <= 5
891    let mut command = cmd("AUTH");
892    match command.arg(password).query::<Value>(con) {
893        Ok(Value::Okay) => Ok(()),
894        _ => fail!((
895            ErrorKind::AuthenticationFailed,
896            "Password authentication failed",
897        )),
898    }
899}
900
901pub fn connect(
902    connection_info: &ConnectionInfo,
903    timeout: Option<Duration>,
904) -> RedisResult<Connection> {
905    let con = ActualConnection::new(&connection_info.addr, timeout)?;
906    setup_connection(con, &connection_info.redis)
907}
908
909pub(crate) fn client_set_info_pipeline() -> Pipeline {
910    let mut pipeline = crate::pipe();
911    pipeline
912        .cmd("CLIENT")
913        .arg("SETINFO")
914        .arg("LIB-NAME")
915        .arg("redis-rs")
916        .ignore();
917    pipeline
918        .cmd("CLIENT")
919        .arg("SETINFO")
920        .arg("LIB-VER")
921        .arg(env!("CARGO_PKG_VERSION"))
922        .ignore();
923    pipeline
924}
925
926fn setup_connection(
927    con: ActualConnection,
928    connection_info: &RedisConnectionInfo,
929) -> RedisResult<Connection> {
930    let mut rv = Connection {
931        con,
932        parser: Parser::new(),
933        db: connection_info.db,
934        pubsub: false,
935    };
936
937    if connection_info.password.is_some() {
938        connect_auth(&mut rv, connection_info)?;
939    }
940
941    if connection_info.db != 0 {
942        match cmd("SELECT")
943            .arg(connection_info.db)
944            .query::<Value>(&mut rv)
945        {
946            Ok(Value::Okay) => {}
947            _ => fail!((
948                ErrorKind::ResponseError,
949                "Redis server refused to switch database"
950            )),
951        }
952    }
953
954    // result is ignored, as per the command's instructions.
955    // https://redis.io/commands/client-setinfo/
956    let _: RedisResult<()> = client_set_info_pipeline().query(&mut rv);
957
958    Ok(rv)
959}
960
961/// Implements the "stateless" part of the connection interface that is used by the
962/// different objects in redis-rs.  Primarily it obviously applies to `Connection`
963/// object but also some other objects implement the interface (for instance
964/// whole clients or certain redis results).
965///
966/// Generally clients and connections (as well as redis results of those) implement
967/// this trait.  Actual connections provide more functionality which can be used
968/// to implement things like `PubSub` but they also can modify the intrinsic
969/// state of the TCP connection.  This is not possible with `ConnectionLike`
970/// implementors because that functionality is not exposed.
971pub trait ConnectionLike {
972    /// Sends an already encoded (packed) command into the TCP socket and
973    /// reads the single response from it.
974    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;
975
976    /// Sends multiple already encoded (packed) command into the TCP socket
977    /// and reads `count` responses from it.  This is used to implement
978    /// pipelining.
979    fn req_packed_commands(
980        &mut self,
981        cmd: &[u8],
982        offset: usize,
983        count: usize,
984    ) -> RedisResult<Vec<Value>>;
985
986    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
987    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
988        let pcmd = cmd.get_packed_command();
989        self.req_packed_command(&pcmd)
990    }
991
992    /// Returns the database this connection is bound to.  Note that this
993    /// information might be unreliable because it's initially cached and
994    /// also might be incorrect if the connection like object is not
995    /// actually connected.
996    fn get_db(&self) -> i64;
997
998    /// Does this connection support pipelining?
999    #[doc(hidden)]
1000    fn supports_pipelining(&self) -> bool {
1001        true
1002    }
1003
1004    /// Check that all connections it has are available (`PING` internally).
1005    fn check_connection(&mut self) -> bool;
1006
1007    /// Returns the connection status.
1008    ///
1009    /// The connection is open until any `read_response` call recieved an
1010    /// invalid response from the server (most likely a closed or dropped
1011    /// connection, otherwise a Redis protocol error). When using unix
1012    /// sockets the connection is open until writing a command failed with a
1013    /// `BrokenPipe` error.
1014    fn is_open(&self) -> bool;
1015}
1016
1017/// A connection is an object that represents a single redis connection.  It
1018/// provides basic support for sending encoded commands into a redis connection
1019/// and to read a response from it.  It's bound to a single database and can
1020/// only be created from the client.
1021///
1022/// You generally do not much with this object other than passing it to
1023/// `Cmd` objects.
1024impl Connection {
1025    /// Sends an already encoded (packed) command into the TCP socket and
1026    /// does not read a response.  This is useful for commands like
1027    /// `MONITOR` which yield multiple items.  This needs to be used with
1028    /// care because it changes the state of the connection.
1029    pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
1030        self.con.send_bytes(cmd)?;
1031        Ok(())
1032    }
1033
1034    /// Fetches a single response from the connection.  This is useful
1035    /// if used in combination with `send_packed_command`.
1036    pub fn recv_response(&mut self) -> RedisResult<Value> {
1037        self.read_response()
1038    }
1039
1040    /// Sets the write timeout for the connection.
1041    ///
1042    /// If the provided value is `None`, then `send_packed_command` call will
1043    /// block indefinitely. It is an error to pass the zero `Duration` to this
1044    /// method.
1045    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1046        self.con.set_write_timeout(dur)
1047    }
1048
1049    /// Sets the read timeout for the connection.
1050    ///
1051    /// If the provided value is `None`, then `recv_response` call will
1052    /// block indefinitely. It is an error to pass the zero `Duration` to this
1053    /// method.
1054    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1055        self.con.set_read_timeout(dur)
1056    }
1057
1058    /// Creates a [`PubSub`] instance for this connection.
1059    pub fn as_pubsub(&mut self) -> PubSub<'_> {
1060        // NOTE: The pubsub flag is intentionally not raised at this time since
1061        // running commands within the pubsub state should not try and exit from
1062        // the pubsub state.
1063        PubSub::new(self)
1064    }
1065
1066    fn exit_pubsub(&mut self) -> RedisResult<()> {
1067        let res = self.clear_active_subscriptions();
1068        if res.is_ok() {
1069            self.pubsub = false;
1070        } else {
1071            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
1072            self.pubsub = true;
1073        }
1074
1075        res
1076    }
1077
1078    /// Get the inner connection out of a PubSub
1079    ///
1080    /// Any active subscriptions are unsubscribed. In the event of an error, the connection is
1081    /// dropped.
1082    fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
1083        // Responses to unsubscribe commands return in a 3-tuple with values
1084        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
1085        // The "count of remaining subs" includes both pattern subscriptions and non pattern
1086        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
1087        // server, both commands need to be executed at once.
1088        {
1089            // Prepare both unsubscribe commands
1090            let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
1091            let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();
1092
1093            // Grab a reference to the underlying connection so that we may send
1094            // the commands without immediately blocking for a response.
1095            let con = &mut self.con;
1096
1097            // Execute commands
1098            con.send_bytes(&unsubscribe)?;
1099            con.send_bytes(&punsubscribe)?;
1100        }
1101
1102        // Receive responses
1103        //
1104        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
1105        // commands. There may be more responses if there are active subscriptions. In this case,
1106        // messages are received until the _subscription count_ in the responses reach zero.
1107        let mut received_unsub = false;
1108        let mut received_punsub = false;
1109        loop {
1110            let res: (Vec<u8>, (), isize) = from_redis_value(&self.recv_response()?)?;
1111
1112            match res.0.first() {
1113                Some(&b'u') => received_unsub = true,
1114                Some(&b'p') => received_punsub = true,
1115                _ => (),
1116            }
1117
1118            if received_unsub && received_punsub && res.2 == 0 {
1119                break;
1120            }
1121        }
1122
1123        // Finally, the connection is back in its normal state since all subscriptions were
1124        // cancelled *and* all unsubscribe messages were received.
1125        Ok(())
1126    }
1127
1128    /// Fetches a single response from the connection.
1129    fn read_response(&mut self) -> RedisResult<Value> {
1130        let result = match self.con {
1131            ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
1132                self.parser.parse_value(reader)
1133            }
1134            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1135            ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
1136                let reader = &mut boxed_tls_connection.reader;
1137                self.parser.parse_value(reader)
1138            }
1139            #[cfg(feature = "tls-rustls")]
1140            ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
1141                let reader = &mut boxed_tls_connection.reader;
1142                self.parser.parse_value(reader)
1143            }
1144            #[cfg(unix)]
1145            ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
1146                self.parser.parse_value(sock)
1147            }
1148        };
1149        // shutdown connection on protocol error
1150        if let Err(e) = &result {
1151            let shutdown = match e.as_io_error() {
1152                Some(e) => e.kind() == io::ErrorKind::UnexpectedEof,
1153                None => false,
1154            };
1155            if shutdown {
1156                match self.con {
1157                    ActualConnection::Tcp(ref mut connection) => {
1158                        let _ = connection.reader.shutdown(net::Shutdown::Both);
1159                        connection.open = false;
1160                    }
1161                    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1162                    ActualConnection::TcpNativeTls(ref mut connection) => {
1163                        let _ = connection.reader.shutdown();
1164                        connection.open = false;
1165                    }
1166                    #[cfg(feature = "tls-rustls")]
1167                    ActualConnection::TcpRustls(ref mut connection) => {
1168                        let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
1169                        connection.open = false;
1170                    }
1171                    #[cfg(unix)]
1172                    ActualConnection::Unix(ref mut connection) => {
1173                        let _ = connection.sock.shutdown(net::Shutdown::Both);
1174                        connection.open = false;
1175                    }
1176                }
1177            }
1178        }
1179        result
1180    }
1181}
1182
1183impl ConnectionLike for Connection {
1184    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1185        if self.pubsub {
1186            self.exit_pubsub()?;
1187        }
1188
1189        self.con.send_bytes(cmd)?;
1190        self.read_response()
1191    }
1192
1193    fn req_packed_commands(
1194        &mut self,
1195        cmd: &[u8],
1196        offset: usize,
1197        count: usize,
1198    ) -> RedisResult<Vec<Value>> {
1199        if self.pubsub {
1200            self.exit_pubsub()?;
1201        }
1202        self.con.send_bytes(cmd)?;
1203        let mut rv = vec![];
1204        let mut first_err = None;
1205        for idx in 0..(offset + count) {
1206            // When processing a transaction, some responses may be errors.
1207            // We need to keep processing the rest of the responses in that case,
1208            // so bailing early with `?` would not be correct.
1209            // See: https://github.com/redis-rs/redis-rs/issues/436
1210            let response = self.read_response();
1211            match response {
1212                Ok(item) => {
1213                    if idx >= offset {
1214                        rv.push(item);
1215                    }
1216                }
1217                Err(err) => {
1218                    if first_err.is_none() {
1219                        first_err = Some(err);
1220                    }
1221                }
1222            }
1223        }
1224
1225        first_err.map_or(Ok(rv), Err)
1226    }
1227
1228    fn get_db(&self) -> i64 {
1229        self.db
1230    }
1231
1232    fn is_open(&self) -> bool {
1233        self.con.is_open()
1234    }
1235
1236    fn check_connection(&mut self) -> bool {
1237        cmd("PING").query::<String>(self).is_ok()
1238    }
1239}
1240
1241impl<C, T> ConnectionLike for T
1242where
1243    C: ConnectionLike,
1244    T: DerefMut<Target = C>,
1245{
1246    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1247        self.deref_mut().req_packed_command(cmd)
1248    }
1249
1250    fn req_packed_commands(
1251        &mut self,
1252        cmd: &[u8],
1253        offset: usize,
1254        count: usize,
1255    ) -> RedisResult<Vec<Value>> {
1256        self.deref_mut().req_packed_commands(cmd, offset, count)
1257    }
1258
1259    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1260        self.deref_mut().req_command(cmd)
1261    }
1262
1263    fn get_db(&self) -> i64 {
1264        self.deref().get_db()
1265    }
1266
1267    fn supports_pipelining(&self) -> bool {
1268        self.deref().supports_pipelining()
1269    }
1270
1271    fn check_connection(&mut self) -> bool {
1272        self.deref_mut().check_connection()
1273    }
1274
1275    fn is_open(&self) -> bool {
1276        self.deref().is_open()
1277    }
1278}
1279
1280/// The pubsub object provides convenient access to the redis pubsub
1281/// system.  Once created you can subscribe and unsubscribe from channels
1282/// and listen in on messages.
1283///
1284/// Example:
1285///
1286/// ```rust,no_run
1287/// # fn do_something() -> redis::RedisResult<()> {
1288/// let client = redis::Client::open("redis://127.0.0.1/")?;
1289/// let mut con = client.get_connection()?;
1290/// let mut pubsub = con.as_pubsub();
1291/// pubsub.subscribe("channel_1")?;
1292/// pubsub.subscribe("channel_2")?;
1293///
1294/// loop {
1295///     let msg = pubsub.get_message()?;
1296///     let payload : String = msg.get_payload()?;
1297///     println!("channel '{}': {}", msg.get_channel_name(), payload);
1298/// }
1299/// # }
1300/// ```
1301impl<'a> PubSub<'a> {
1302    fn new(con: &'a mut Connection) -> Self {
1303        Self {
1304            con,
1305            waiting_messages: VecDeque::new(),
1306        }
1307    }
1308
1309    fn cache_messages_until_received_response(&mut self, cmd: &Cmd) -> RedisResult<()> {
1310        let mut response = self.con.req_packed_command(&cmd.get_packed_command())?;
1311        loop {
1312            if let Some(msg) = Msg::from_value(&response) {
1313                self.waiting_messages.push_back(msg);
1314            } else {
1315                return Ok(());
1316            }
1317            response = self.con.recv_response()?;
1318        }
1319    }
1320
1321    /// Subscribes to a new channel.
1322    pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1323        self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel))
1324    }
1325
1326    /// Subscribes to a new channel with a pattern.
1327    pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1328        self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel))
1329    }
1330
1331    /// Unsubscribes from a channel.
1332    pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1333        self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel))
1334    }
1335
1336    /// Unsubscribes from a channel with a pattern.
1337    pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1338        self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel))
1339    }
1340
1341    /// Fetches the next message from the pubsub connection.  Blocks until
1342    /// a message becomes available.  This currently does not provide a
1343    /// wait not to block :(
1344    ///
1345    /// The message itself is still generic and can be converted into an
1346    /// appropriate type through the helper methods on it.
1347    pub fn get_message(&mut self) -> RedisResult<Msg> {
1348        if let Some(msg) = self.waiting_messages.pop_front() {
1349            return Ok(msg);
1350        }
1351        loop {
1352            if let Some(msg) = Msg::from_value(&self.con.recv_response()?) {
1353                return Ok(msg);
1354            } else {
1355                continue;
1356            }
1357        }
1358    }
1359
1360    /// Sets the read timeout for the connection.
1361    ///
1362    /// If the provided value is `None`, then `get_message` call will
1363    /// block indefinitely. It is an error to pass the zero `Duration` to this
1364    /// method.
1365    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1366        self.con.set_read_timeout(dur)
1367    }
1368}
1369
1370impl<'a> Drop for PubSub<'a> {
1371    fn drop(&mut self) {
1372        let _ = self.con.exit_pubsub();
1373    }
1374}
1375
1376/// This holds the data that comes from listening to a pubsub
1377/// connection.  It only contains actual message data.
1378impl Msg {
1379    /// Tries to convert provided [`Value`] into [`Msg`].
1380    pub fn from_value(value: &Value) -> Option<Self> {
1381        let raw_msg: Vec<Value> = from_redis_value(value).ok()?;
1382        let mut iter = raw_msg.into_iter();
1383        let msg_type: String = from_redis_value(&iter.next()?).ok()?;
1384        let mut pattern = None;
1385        let payload;
1386        let channel;
1387
1388        if msg_type == "message" {
1389            channel = iter.next()?;
1390            payload = iter.next()?;
1391        } else if msg_type == "pmessage" {
1392            pattern = Some(iter.next()?);
1393            channel = iter.next()?;
1394            payload = iter.next()?;
1395        } else {
1396            return None;
1397        }
1398
1399        Some(Msg {
1400            payload,
1401            channel,
1402            pattern,
1403        })
1404    }
1405
1406    /// Returns the channel this message came on.
1407    pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
1408        from_redis_value(&self.channel)
1409    }
1410
1411    /// Convenience method to get a string version of the channel.  Unless
1412    /// your channel contains non utf-8 bytes you can always use this
1413    /// method.  If the channel is not a valid string (which really should
1414    /// not happen) then the return value is `"?"`.
1415    pub fn get_channel_name(&self) -> &str {
1416        match self.channel {
1417            Value::Data(ref bytes) => from_utf8(bytes).unwrap_or("?"),
1418            _ => "?",
1419        }
1420    }
1421
1422    /// Returns the message's payload in a specific format.
1423    pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
1424        from_redis_value(&self.payload)
1425    }
1426
1427    /// Returns the bytes that are the message's payload.  This can be used
1428    /// as an alternative to the `get_payload` function if you are interested
1429    /// in the raw bytes in it.
1430    pub fn get_payload_bytes(&self) -> &[u8] {
1431        match self.payload {
1432            Value::Data(ref bytes) => bytes,
1433            _ => b"",
1434        }
1435    }
1436
1437    /// Returns true if the message was constructed from a pattern
1438    /// subscription.
1439    #[allow(clippy::wrong_self_convention)]
1440    pub fn from_pattern(&self) -> bool {
1441        self.pattern.is_some()
1442    }
1443
1444    /// If the message was constructed from a message pattern this can be
1445    /// used to find out which one.  It's recommended to match against
1446    /// an `Option<String>` so that you do not need to use `from_pattern`
1447    /// to figure out if a pattern was set.
1448    pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
1449        match self.pattern {
1450            None => from_redis_value(&Value::Nil),
1451            Some(ref x) => from_redis_value(x),
1452        }
1453    }
1454}
1455
1456/// This function simplifies transaction management slightly.  What it
1457/// does is automatically watching keys and then going into a transaction
1458/// loop util it succeeds.  Once it goes through the results are
1459/// returned.
1460///
1461/// To use the transaction two pieces of information are needed: a list
1462/// of all the keys that need to be watched for modifications and a
1463/// closure with the code that should be execute in the context of the
1464/// transaction.  The closure is invoked with a fresh pipeline in atomic
1465/// mode.  To use the transaction the function needs to return the result
1466/// from querying the pipeline with the connection.
1467///
1468/// The end result of the transaction is then available as the return
1469/// value from the function call.
1470///
1471/// Example:
1472///
1473/// ```rust,no_run
1474/// use redis::Commands;
1475/// # fn do_something() -> redis::RedisResult<()> {
1476/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
1477/// # let mut con = client.get_connection().unwrap();
1478/// let key = "the_key";
1479/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
1480///     let old_val : isize = con.get(key)?;
1481///     pipe
1482///         .set(key, old_val + 1).ignore()
1483///         .get(key).query(con)
1484/// })?;
1485/// println!("The incremented number is: {}", new_val);
1486/// # Ok(()) }
1487/// ```
1488pub fn transaction<
1489    C: ConnectionLike,
1490    K: ToRedisArgs,
1491    T,
1492    F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
1493>(
1494    con: &mut C,
1495    keys: &[K],
1496    func: F,
1497) -> RedisResult<T> {
1498    let mut func = func;
1499    loop {
1500        cmd("WATCH").arg(keys).query::<()>(con)?;
1501        let mut p = pipe();
1502        let response: Option<T> = func(con, p.atomic())?;
1503        match response {
1504            None => {
1505                continue;
1506            }
1507            Some(response) => {
1508                // make sure no watch is left in the connection, even if
1509                // someone forgot to use the pipeline.
1510                cmd("UNWATCH").query::<()>(con)?;
1511                return Ok(response);
1512            }
1513        }
1514    }
1515}
1516
// Unit tests for URL parsing and conversion into `ConnectionInfo`.
#[cfg(test)]
mod tests {
    use super::*;

    // `parse_redis_url` accepts the redis and unix schemes listed below and
    // rejects the other schemes.
    #[test]
    fn test_parse_redis_url() {
        // (input URL, whether parsing is expected to succeed)
        let cases = vec![
            ("redis://127.0.0.1", true),
            ("redis://[::1]", true),
            ("redis+unix:///run/redis.sock", true),
            ("unix:///run/redis.sock", true),
            ("http://127.0.0.1", false),
            ("tcp://127.0.0.1", false),
        ];
        for (url, expected) in cases.into_iter() {
            let res = parse_redis_url(url);
            assert_eq!(
                res.is_some(),
                expected,
                "Parsed result of `{url}` is not expected",
            );
        }
    }

    // TCP URLs: default port, IPv6 host, and percent-encoded credentials
    // together with a database number in the path.
    #[test]
    fn test_url_to_tcp_connection_info() {
        // (input URL, expected connection info)
        let cases = vec![
            (
                url::Url::parse("redis://127.0.0.1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://[::1]").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                // Username and password are percent-encoded in the URL and
                // expected to come out decoded.
                url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url.clone()).unwrap();
            // Fields are compared one by one.
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
        }
    }

    // Invalid TCP URLs must fail with `InvalidClientConfig`, the listed
    // description, and no detail.
    #[test]
    fn test_url_to_tcp_connection_info_failed() {
        // (input URL, expected error description)
        let cases = vec![
            (url::Url::parse("redis://").unwrap(), "Missing hostname"),
            (
                url::Url::parse("redis://127.0.0.1/db").unwrap(),
                "Invalid database number",
            ),
            (
                // `%B0` is the byte 0xB0, which is not valid UTF-8 on its own.
                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
                "Username is not valid UTF-8 string",
            ),
            (
                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
                "Password is not valid UTF-8 string",
            ),
        ];
        for (url, expected) in cases.into_iter() {
            let res = url_to_tcp_connection_info(url).unwrap_err();
            assert_eq!(
                res.kind(),
                crate::ErrorKind::InvalidClientConfig,
                "{}",
                &res,
            );
            // `Error::description` is deprecated but is what exposes the
            // plain description string compared here.
            #[allow(deprecated)]
            let desc = std::error::Error::description(&res);
            assert_eq!(desc, expected, "{}", &res);
            assert_eq!(res.detail(), None, "{}", &res);
        }
    }

    // Unix socket URLs: the path becomes the address; `db`, `user` and `pass`
    // are taken from the query string.
    #[test]
    #[cfg(unix)]
    fn test_url_to_unix_connection_info() {
        // (input URL, expected connection info)
        let cases = vec![
            (
                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 0,
                        username: None,
                        password: None,
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 1,
                        username: None,
                        password: None,
                    },
                },
            ),
            (
                url::Url::parse(
                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                    },
                },
            ),
            (
                // Same parameters in a different order; the `+` in the
                // password is expected to decode to a space.
                url::Url::parse(
                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("&?= *+".to_string()),
                    },
                },
            ),
        ];
        for (url, expected) in cases.into_iter() {
            // Sanity check of the fixture itself: the URL's file path must
            // equal the expected socket path.
            assert_eq!(
                ConnectionAddr::Unix(url.to_file_path().unwrap()),
                expected.addr,
                "addr of {url} is not expected",
            );
            let res = url_to_unix_connection_info(url.clone()).unwrap();
            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
            assert_eq!(
                res.redis.db, expected.redis.db,
                "db of {url} is not expected",
            );
            assert_eq!(
                res.redis.username, expected.redis.username,
                "username of {url} is not expected",
            );
            assert_eq!(
                res.redis.password, expected.redis.password,
                "password of {url} is not expected",
            );
        }
    }
}