redis/
connection.rs

1use std::borrow::Cow;
2use std::collections::VecDeque;
3use std::fmt;
4use std::io::{self, Write};
5use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
6use std::ops::DerefMut;
7use std::path::PathBuf;
8use std::str::{from_utf8, FromStr};
9use std::time::{Duration, Instant};
10
11use crate::cmd::{cmd, pipe, Cmd};
12use crate::errors::{ErrorKind, RedisError, ServerError, ServerErrorKind};
13use crate::io::tcp::{stream_with_settings, TcpSettings};
14use crate::parser::Parser;
15use crate::pipeline::Pipeline;
16use crate::types::{
17    from_redis_value_ref, FromRedisValue, HashMap, PushKind, RedisResult, SyncPushSender,
18    ToRedisArgs, Value,
19};
20use crate::{check_resp3, from_redis_value, ProtocolVersion};
21
22#[cfg(unix)]
23use std::os::unix::net::UnixStream;
24
25use crate::commands::resp3_hello;
26use arcstr::ArcStr;
27#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
28use native_tls::{TlsConnector, TlsStream};
29
30#[cfg(feature = "tls-rustls")]
31use rustls::{RootCertStore, StreamOwned};
32#[cfg(feature = "tls-rustls")]
33use std::sync::Arc;
34
35use crate::PushInfo;
36
37#[cfg(all(
38    feature = "tls-rustls",
39    not(feature = "tls-native-tls"),
40    not(feature = "tls-rustls-webpki-roots")
41))]
42use rustls_native_certs::load_native_certs;
43
44#[cfg(feature = "tls-rustls")]
45use crate::tls::ClientTlsParams;
46
/// TLS parameters attached to a [`ConnectionAddr::TcpTls`] address.
///
/// The type is `#[non_exhaustive]` and all of its fields are crate-private, so values can
/// only be constructed inside this crate. The attribute matters in builds without any TLS
/// feature: every field is compiled out there, and a plain field-less public struct could
/// otherwise be constructed by downstream code.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct TlsConnParams {
    /// Optional client-side TLS parameters (only with the `tls-rustls` feature).
    #[cfg(feature = "tls-rustls")]
    pub(crate) client_tls_params: Option<ClientTlsParams>,
    /// Optional root certificate store (only with the `tls-rustls` feature).
    #[cfg(feature = "tls-rustls")]
    pub(crate) root_cert_store: Option<RootCertStore>,
    /// When `true`, the hostname of the server certificate is not verified.
    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
    pub(crate) danger_accept_invalid_hostnames: bool,
}
57
/// Port used when a TCP redis URL does not specify one explicitly.
const DEFAULT_PORT: u16 = 6379;
59
60#[inline(always)]
61fn connect_tcp(addr: (&str, u16), tcp_settings: &TcpSettings) -> io::Result<TcpStream> {
62    let socket = TcpStream::connect(addr)?;
63    stream_with_settings(socket, tcp_settings)
64}
65
66#[inline(always)]
67fn connect_tcp_timeout(
68    addr: &SocketAddr,
69    timeout: Duration,
70    tcp_settings: &TcpSettings,
71) -> io::Result<TcpStream> {
72    let socket = TcpStream::connect_timeout(addr, timeout)?;
73    stream_with_settings(socket, tcp_settings)
74}
75
76/// This function takes a redis URL string and parses it into a URL
77/// as used by rust-url.
78///
79/// This is necessary as the default parser does not understand how redis URLs function.
80pub fn parse_redis_url(input: &str) -> Option<url::Url> {
81    match url::Url::parse(input) {
82        Ok(result) => match result.scheme() {
83            "redis" | "rediss" | "valkey" | "valkeys" | "redis+unix" | "valkey+unix" | "unix" => {
84                Some(result)
85            }
86            _ => None,
87        },
88        Err(_) => None,
89    }
90}
91
/// TlsMode indicates whether the server certificate is verified or not.
///
/// Check [ConnectionAddr](ConnectionAddr::TcpTls::insecure) for more.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum TlsMode {
    /// Secure: the certificate is verified.
    Secure,
    /// Insecure: the certificate is not verified.
    Insecure,
}
103
/// Defines the connection address.
///
/// Not all connection addresses are supported on all platforms.  For instance
/// to connect to a unix socket you need to run this on an operating system
/// that supports them. Use [`ConnectionAddr::is_supported`] to check whether a
/// given address can be used in the current build.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum ConnectionAddr {
    /// Plain TCP connection. Format for this is `(host, port)`.
    Tcp(String, u16),
    /// TCP connection secured with TLS, described by host, port and TLS options.
    TcpTls {
        /// Hostname
        host: String,
        /// Port
        port: u16,
        /// Connect in insecure mode (set by the `#insecure` URL fragment).
        ///
        /// With the `native-tls` backend this disables certificate validation,
        /// hostname verification and SNI.
        ///
        /// # Warning
        ///
        /// You should think very carefully before you set this flag. If hostname
        /// verification is not used, any valid certificate for any site will be
        /// trusted for use from any other. This introduces a significant
        /// vulnerability to man-in-the-middle attacks.
        insecure: bool,

        /// TLS certificates and client key.
        tls_params: Option<TlsConnParams>,
    },
    /// Format for this is the path to the unix socket.
    Unix(PathBuf),
}
136
137impl PartialEq for ConnectionAddr {
138    fn eq(&self, other: &Self) -> bool {
139        match (self, other) {
140            (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
141                host1 == host2 && port1 == port2
142            }
143            (
144                ConnectionAddr::TcpTls {
145                    host: host1,
146                    port: port1,
147                    insecure: insecure1,
148                    tls_params: _,
149                },
150                ConnectionAddr::TcpTls {
151                    host: host2,
152                    port: port2,
153                    insecure: insecure2,
154                    tls_params: _,
155                },
156            ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
157            (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
158            _ => false,
159        }
160    }
161}
162
// Marker impl: the `PartialEq` implementation for `ConnectionAddr` only compares
// `String`, `u16`, `bool` and `PathBuf` values, all of which are themselves `Eq`.
impl Eq for ConnectionAddr {}
164
165impl ConnectionAddr {
166    /// Checks if this address is supported.
167    ///
168    /// Because not all platforms support all connection addresses this is a
169    /// quick way to figure out if a connection method is supported. Currently
170    /// this affects:
171    ///
172    /// - Unix socket addresses, which are supported only on Unix
173    ///
174    /// - TLS addresses, which are supported only if a TLS feature is enabled
175    ///   (either `tls-native-tls` or `tls-rustls`).
176    pub fn is_supported(&self) -> bool {
177        match *self {
178            ConnectionAddr::Tcp(_, _) => true,
179            ConnectionAddr::TcpTls { .. } => {
180                cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
181            }
182            ConnectionAddr::Unix(_) => cfg!(unix),
183        }
184    }
185
186    /// Configure this address to connect without checking certificate hostnames.
187    ///
188    /// # Warning
189    ///
190    /// You should think very carefully before you use this method. If hostname
191    /// verification is not used, any valid certificate for any site will be
192    /// trusted for use from any other. This introduces a significant
193    /// vulnerability to man-in-the-middle attacks.
194    #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
195    pub fn set_danger_accept_invalid_hostnames(&mut self, insecure: bool) {
196        if let ConnectionAddr::TcpTls { tls_params, .. } = self {
197            if let Some(ref mut params) = tls_params {
198                params.danger_accept_invalid_hostnames = insecure;
199            } else if insecure {
200                *tls_params = Some(TlsConnParams {
201                    #[cfg(feature = "tls-rustls")]
202                    client_tls_params: None,
203                    #[cfg(feature = "tls-rustls")]
204                    root_cert_store: None,
205                    danger_accept_invalid_hostnames: insecure,
206                });
207            }
208        }
209    }
210
211    #[cfg(feature = "cluster")]
212    pub(crate) fn tls_mode(&self) -> Option<TlsMode> {
213        match self {
214            ConnectionAddr::TcpTls { insecure, .. } => {
215                if *insecure {
216                    Some(TlsMode::Insecure)
217                } else {
218                    Some(TlsMode::Secure)
219                }
220            }
221            _ => None,
222        }
223    }
224}
225
226impl fmt::Display for ConnectionAddr {
227    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
228        // Cluster::get_connection_info depends on the return value from this function
229        match *self {
230            ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
231            ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
232            ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
233        }
234    }
235}
236
/// Holds the connection information that redis should use for connecting.
///
/// It combines where to connect (`addr`), how to configure the TCP socket
/// (`tcp_settings`) and the redis-level handshake settings (`redis`).
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
    /// A connection address for where to connect to.
    pub(crate) addr: ConnectionAddr,

    /// The settings for the TCP connection
    pub(crate) tcp_settings: TcpSettings,
    /// A redis connection info for how to handshake with redis.
    pub(crate) redis: RedisConnectionInfo,
}
248
249impl ConnectionInfo {
250    /// Returns the connection address.
251    pub fn addr(&self) -> &ConnectionAddr {
252        &self.addr
253    }
254
255    /// Returns the settings for the TCP connection.
256    pub fn tcp_settings(&self) -> &TcpSettings {
257        &self.tcp_settings
258    }
259
260    /// Returns the redis connection info for how to handshake with redis.
261    pub fn redis_settings(&self) -> &RedisConnectionInfo {
262        &self.redis
263    }
264
265    /// Sets the connection address for where to connect to.
266    pub fn set_addr(mut self, addr: ConnectionAddr) -> Self {
267        self.addr = addr;
268        self
269    }
270
271    /// Sets the TCP settings for the connection.
272    pub fn set_tcp_settings(mut self, tcp_settings: TcpSettings) -> Self {
273        self.tcp_settings = tcp_settings;
274        self
275    }
276
277    /// Set all redis connection info fields at once.
278    pub fn set_redis_settings(mut self, redis: RedisConnectionInfo) -> Self {
279        self.redis = redis;
280        self
281    }
282}
283
/// Redis specific/connection independent information used to establish a connection to redis.
///
/// Fields are crate-private; use the getters and the consuming `set_*` methods on this
/// type to read or change them.
#[derive(Clone, Debug, Default)]
pub struct RedisConnectionInfo {
    /// The database number to use.  This is usually `0`.
    pub(crate) db: i64,
    /// Optionally a username that should be used for connection.
    pub(crate) username: Option<ArcStr>,
    /// Optionally a password that should be used for connection.
    pub(crate) password: Option<ArcStr>,
    /// Version of the protocol to use.
    pub(crate) protocol: ProtocolVersion,
    /// If set, the connection shouldn't send the library name to the server.
    pub(crate) skip_set_lib_name: bool,
}
298
299impl RedisConnectionInfo {
300    /// Returns the username that should be used for connection.
301    pub fn username(&self) -> Option<&str> {
302        self.username.as_deref()
303    }
304
305    /// Returns the password that should be used for connection.
306    pub fn password(&self) -> Option<&str> {
307        self.password.as_deref()
308    }
309
310    /// Returns version of the protocol to use.
311    pub fn protocol(&self) -> ProtocolVersion {
312        self.protocol
313    }
314
315    /// Returns `true` if the `CLIENT SETINFO` command should be skipped.
316    pub fn skip_set_lib_name(&self) -> bool {
317        self.skip_set_lib_name
318    }
319
320    /// Returns the database number to use.
321    pub fn db(&self) -> i64 {
322        self.db
323    }
324
325    /// Sets the username for the connection's ACL.
326    pub fn set_username(mut self, username: impl AsRef<str>) -> Self {
327        self.username = Some(username.as_ref().into());
328        self
329    }
330
331    /// Sets the password for the connection's ACL.
332    pub fn set_password(mut self, password: impl AsRef<str>) -> Self {
333        self.password = Some(password.as_ref().into());
334        self
335    }
336
337    /// Sets the version of the RESP to use.
338    pub fn set_protocol(mut self, protocol: ProtocolVersion) -> Self {
339        self.protocol = protocol;
340        self
341    }
342
343    /// Removes the pipelined `CLIENT SETINFO` call from the connection creation.
344    pub fn set_skip_set_lib_name(mut self) -> Self {
345        self.skip_set_lib_name = true;
346        self
347    }
348
349    /// Sets the database number to use.
350    pub fn set_db(mut self, db: i64) -> Self {
351        self.db = db;
352        self
353    }
354}
355
356impl FromStr for ConnectionInfo {
357    type Err = RedisError;
358
359    fn from_str(s: &str) -> Result<Self, Self::Err> {
360        s.into_connection_info()
361    }
362}
363
/// Converts an object into a connection info struct.  This allows the
/// constructor of the client to accept connection information in a
/// range of different formats.
///
/// Implementations in this file cover [`ConnectionInfo`], [`ConnectionAddr`],
/// `&str`, `String`, `(host, port)` tuples and [`url::Url`].
pub trait IntoConnectionInfo {
    /// Converts the object into a connection info object.
    ///
    /// Returns an error when the value does not describe a usable connection
    /// (for example, a string that is not a redis URL).
    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
}
371
/// Identity conversion: an existing [`ConnectionInfo`] is returned unchanged.
impl IntoConnectionInfo for ConnectionInfo {
    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
        Ok(self)
    }
}
377
378impl IntoConnectionInfo for ConnectionAddr {
379    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
380        Ok(ConnectionInfo {
381            addr: self,
382            redis: Default::default(),
383            tcp_settings: Default::default(),
384        })
385    }
386}
387
388/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
389///
390/// - Basic: `redis://127.0.0.1:6379`
391/// - Username & Password: `redis://user:password@127.0.0.1:6379`
392/// - Password only: `redis://:password@127.0.0.1:6379`
393/// - Specifying DB: `redis://127.0.0.1:6379/0`
394/// - Enabling TLS: `rediss://127.0.0.1:6379`
395/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
396/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
397impl IntoConnectionInfo for &str {
398    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
399        match parse_redis_url(self) {
400            Some(u) => u.into_connection_info(),
401            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
402        }
403    }
404}
405
406impl<T> IntoConnectionInfo for (T, u16)
407where
408    T: Into<String>,
409{
410    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
411        Ok(ConnectionInfo {
412            addr: ConnectionAddr::Tcp(self.0.into(), self.1),
413            redis: RedisConnectionInfo::default(),
414            tcp_settings: TcpSettings::default(),
415        })
416    }
417}
418
419/// URL format: `{redis|rediss|valkey|valkeys}://[<username>][:<password>@]<hostname>[:port][/<db>]`
420///
421/// - Basic: `redis://127.0.0.1:6379`
422/// - Username & Password: `redis://user:password@127.0.0.1:6379`
423/// - Password only: `redis://:password@127.0.0.1:6379`
424/// - Specifying DB: `redis://127.0.0.1:6379/0`
425/// - Enabling TLS: `rediss://127.0.0.1:6379`
426/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure`
427/// - Enabling RESP3: `redis://127.0.0.1:6379/?protocol=resp3`
428impl IntoConnectionInfo for String {
429    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
430        match parse_redis_url(&self) {
431            Some(u) => u.into_connection_info(),
432            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
433        }
434    }
435}
436
437fn parse_protocol(query: &HashMap<Cow<str>, Cow<str>>) -> RedisResult<ProtocolVersion> {
438    Ok(match query.get("protocol") {
439        Some(protocol) => {
440            if protocol == "2" || protocol == "resp2" {
441                ProtocolVersion::RESP2
442            } else if protocol == "3" || protocol == "resp3" {
443                ProtocolVersion::RESP3
444            } else {
445                fail!((
446                    ErrorKind::InvalidClientConfig,
447                    "Invalid protocol version",
448                    protocol.to_string()
449                ))
450            }
451        }
452        None => ProtocolVersion::RESP2,
453    })
454}
455
/// Returns `true` when `address` is exactly the IPv4 (`0.0.0.0`) or IPv6 (`::`)
/// wildcard address string.
#[inline]
pub(crate) fn is_wildcard_address(address: &str) -> bool {
    matches!(address, "0.0.0.0" | "::")
}
460
461fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
462    let host = match url.host() {
463        Some(host) => {
464            // Here we manually match host's enum arms and call their to_string().
465            // Because url.host().to_string() will add `[` and `]` for ipv6:
466            // https://docs.rs/url/latest/src/url/host.rs.html#170
467            // And these brackets will break host.parse::<Ipv6Addr>() when
468            // `client.open()` - `ActualConnection::new()` - `addr.to_socket_addrs()`:
469            // https://doc.rust-lang.org/src/std/net/addr.rs.html#963
470            // https://doc.rust-lang.org/src/std/net/parser.rs.html#158
471            // IpAddr string with brackets can ONLY parse to SocketAddrV6:
472            // https://doc.rust-lang.org/src/std/net/parser.rs.html#255
473            // But if we call Ipv6Addr.to_string directly, it follows rfc5952 without brackets:
474            // https://doc.rust-lang.org/src/std/net/ip.rs.html#1755
475            let host_str = match host {
476                url::Host::Domain(path) => path.to_string(),
477                url::Host::Ipv4(v4) => v4.to_string(),
478                url::Host::Ipv6(v6) => v6.to_string(),
479            };
480
481            if is_wildcard_address(&host_str) {
482                return Err(RedisError::from((
483                    ErrorKind::InvalidClientConfig,
484                    "Cannot connect to a wildcard address (0.0.0.0 or ::)",
485                )));
486            }
487            host_str
488        }
489        None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
490    };
491    let port = url.port().unwrap_or(DEFAULT_PORT);
492    let addr = if url.scheme() == "rediss" || url.scheme() == "valkeys" {
493        #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
494        {
495            match url.fragment() {
496                Some("insecure") => ConnectionAddr::TcpTls {
497                    host,
498                    port,
499                    insecure: true,
500                    tls_params: None,
501                },
502                Some(_) => fail!((
503                    ErrorKind::InvalidClientConfig,
504                    "only #insecure is supported as URL fragment"
505                )),
506                _ => ConnectionAddr::TcpTls {
507                    host,
508                    port,
509                    insecure: false,
510                    tls_params: None,
511                },
512            }
513        }
514
515        #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
516        fail!((
517            ErrorKind::InvalidClientConfig,
518            "can't connect with TLS, the feature is not enabled"
519        ));
520    } else {
521        ConnectionAddr::Tcp(host, port)
522    };
523    let query: HashMap<_, _> = url.query_pairs().collect();
524    Ok(ConnectionInfo {
525        addr,
526        redis: RedisConnectionInfo {
527            db: match url.path().trim_matches('/') {
528                "" => 0,
529                path => path.parse::<i64>().map_err(|_| -> RedisError {
530                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
531                })?,
532            },
533            username: if url.username().is_empty() {
534                None
535            } else {
536                match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
537                    Ok(decoded) => Some(decoded.into()),
538                    Err(_) => fail!((
539                        ErrorKind::InvalidClientConfig,
540                        "Username is not valid UTF-8 string"
541                    )),
542                }
543            },
544            password: match url.password() {
545                Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
546                    Ok(decoded) => Some(decoded.into()),
547                    Err(_) => fail!((
548                        ErrorKind::InvalidClientConfig,
549                        "Password is not valid UTF-8 string"
550                    )),
551                },
552                None => None,
553            },
554            protocol: parse_protocol(&query)?,
555            skip_set_lib_name: false,
556        },
557        tcp_settings: TcpSettings::default(),
558    })
559}
560
561#[cfg(unix)]
562fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
563    let query: HashMap<_, _> = url.query_pairs().collect();
564    Ok(ConnectionInfo {
565        addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError {
566            (ErrorKind::InvalidClientConfig, "Missing path").into()
567        })?),
568        redis: RedisConnectionInfo {
569            db: match query.get("db") {
570                Some(db) => db.parse::<i64>().map_err(|_| -> RedisError {
571                    (ErrorKind::InvalidClientConfig, "Invalid database number").into()
572                })?,
573
574                None => 0,
575            },
576            username: query.get("user").map(|username| username.as_ref().into()),
577            password: query.get("pass").map(|password| password.as_ref().into()),
578            protocol: parse_protocol(&query)?,
579            ..Default::default()
580        },
581        tcp_settings: TcpSettings::default(),
582    })
583}
584
/// Non-Unix fallback: Unix socket URLs are always rejected with an
/// `InvalidClientConfig` error.
#[cfg(not(unix))]
fn url_to_unix_connection_info(_: url::Url) -> RedisResult<ConnectionInfo> {
    Err(RedisError::from((
        ErrorKind::InvalidClientConfig,
        "Unix sockets are not available on this platform.",
    )))
}
592
593impl IntoConnectionInfo for url::Url {
594    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
595        match self.scheme() {
596            "redis" | "rediss" | "valkey" | "valkeys" => url_to_tcp_connection_info(self),
597            "unix" | "redis+unix" | "valkey+unix" => url_to_unix_connection_info(self),
598            _ => fail!((
599                ErrorKind::InvalidClientConfig,
600                "URL provided is not a redis URL"
601            )),
602        }
603    }
604}
605
/// Transport state for a plain TCP connection.
struct TcpConnection {
    /// The connected TCP stream.
    reader: TcpStream,
    /// Set to `true` when the connection is created.
    /// NOTE(review): presumably cleared once the connection is known to be broken — confirm.
    open: bool,
}
610
/// Transport state for a TLS connection backed by `native-tls`.
///
/// Only compiled when `tls-native-tls` is enabled and `tls-rustls` is not.
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
struct TcpNativeTlsConnection {
    /// The TLS stream layered over the TCP stream.
    reader: TlsStream<TcpStream>,
    /// Connection-open flag.
    /// NOTE(review): presumably mirrors `TcpConnection::open` — confirm.
    open: bool,
}
616
/// Transport state for a TLS connection backed by `rustls`.
///
/// Only compiled when the `tls-rustls` feature is enabled.
#[cfg(feature = "tls-rustls")]
struct TcpRustlsConnection {
    /// The rustls client session together with the TCP stream it owns.
    reader: StreamOwned<rustls::ClientConnection, TcpStream>,
    /// Connection-open flag.
    /// NOTE(review): presumably mirrors `TcpConnection::open` — confirm.
    open: bool,
}
622
/// Transport state for a Unix domain socket connection (Unix targets only).
#[cfg(unix)]
struct UnixConnection {
    /// The connected Unix stream.
    sock: UnixStream,
    /// Connection-open flag.
    /// NOTE(review): presumably mirrors `TcpConnection::open` — confirm.
    open: bool,
}
628
/// The concrete transport behind a connection.
///
/// Which variants exist depends on the enabled TLS features and the target platform;
/// at most one of the two TLS variants is compiled in, with rustls taking precedence
/// when both TLS features are enabled.
enum ActualConnection {
    /// Plain TCP.
    Tcp(TcpConnection),
    /// TLS over TCP using `native-tls`.
    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
    TcpNativeTls(Box<TcpNativeTlsConnection>),
    /// TLS over TCP using `rustls`.
    #[cfg(feature = "tls-rustls")]
    TcpRustls(Box<TcpRustlsConnection>),
    /// Unix domain socket.
    #[cfg(unix)]
    Unix(UnixConnection),
}
638
/// rustls certificate verifier that accepts every server certificate and every
/// handshake signature without checking them.
///
/// Only compiled with the `tls-rustls-insecure` feature.
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification {
    /// Signature algorithms whose schemes are advertised via `supported_verify_schemes`.
    supported: rustls::crypto::WebPkiSupportedAlgorithms,
}
643
// SECURITY: every verification hook below unconditionally reports success, so a
// connection using this verifier provides no server authentication at all.
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
    /// Accepts any server certificate; all arguments are ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::pki_types::CertificateDer<'_>,
        _intermediates: &[rustls::pki_types::CertificateDer<'_>],
        _server_name: &rustls::pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::danger::ServerCertVerified::assertion())
    }

    /// Accepts any TLS 1.2 handshake signature; all arguments are ignored.
    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Accepts any TLS 1.3 handshake signature; all arguments are ignored.
    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &rustls::pki_types::CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
    }

    /// Reports the signature schemes of the stored `supported` algorithms.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.supported.supported_schemes()
    }
}
679
/// Prints only the type name; the `supported` field is omitted from the output.
#[cfg(feature = "tls-rustls-insecure")]
impl fmt::Debug for NoCertificateVerification {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("NoCertificateVerification")
    }
}
686
/// Insecure `ServerCertVerifier` for rustls that implements `danger_accept_invalid_hostnames`.
///
/// Verification is delegated to the wrapped WebPKI verifier; errors that only concern
/// the certificate's name are then treated as success.
#[cfg(feature = "tls-rustls-insecure")]
#[derive(Debug)]
struct AcceptInvalidHostnamesCertVerifier {
    /// The regular verifier that performs the actual checks.
    inner: Arc<rustls::client::WebPkiServerVerifier>,
}
693
/// Returns `true` when `err` is an invalid-certificate error whose reason is that the
/// certificate is not valid for the requested name (`NotValidForName` or
/// `NotValidForNameContext`).
#[cfg(feature = "tls-rustls-insecure")]
fn is_hostname_error(err: &rustls::Error) -> bool {
    match err {
        rustls::Error::InvalidCertificate(reason) => matches!(
            reason,
            rustls::CertificateError::NotValidForName
                | rustls::CertificateError::NotValidForNameContext { .. }
        ),
        _ => false,
    }
}
704
// Every check is delegated to `inner`; a failure is turned into success only when
// `is_hostname_error` classifies it as a name mismatch. All other errors pass through.
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for AcceptInvalidHostnamesCertVerifier {
    /// Verifies the server certificate, tolerating name-mismatch errors.
    fn verify_server_cert(
        &self,
        end_entity: &rustls::pki_types::CertificateDer<'_>,
        intermediates: &[rustls::pki_types::CertificateDer<'_>],
        server_name: &rustls::pki_types::ServerName<'_>,
        ocsp_response: &[u8],
        now: rustls::pki_types::UnixTime,
    ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
        let verdict = self.inner.verify_server_cert(
            end_entity,
            intermediates,
            server_name,
            ocsp_response,
            now,
        );
        match verdict {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::ServerCertVerified::assertion())
            }
            other => other,
        }
    }

    /// Verifies a TLS 1.2 handshake signature, tolerating name-mismatch errors.
    fn verify_tls12_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        match self.inner.verify_tls12_signature(message, cert, dss) {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            other => other,
        }
    }

    /// Verifies a TLS 1.3 handshake signature, tolerating name-mismatch errors.
    fn verify_tls13_signature(
        &self,
        message: &[u8],
        cert: &rustls::pki_types::CertificateDer<'_>,
        dss: &rustls::DigitallySignedStruct,
    ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
        match self.inner.verify_tls13_signature(message, cert, dss) {
            Err(err) if is_hostname_error(&err) => {
                Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
            }
            other => other,
        }
    }

    /// Reports the signature schemes supported by the wrapped verifier.
    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        self.inner.supported_verify_schemes()
    }
}
764
/// Represents a stateful redis TCP connection.
pub struct Connection {
    /// The underlying transport (see `ActualConnection`).
    con: ActualConnection,
    /// The `Parser` instance owned by this connection.
    parser: Parser,
    /// Database number associated with this connection.
    /// NOTE(review): presumably the database selected at connect time — confirm.
    db: i64,

    /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`.
    ///
    /// This flag is checked when attempting to send a command, and if it's raised, we attempt to
    /// exit the pubsub state before executing the new request.
    pubsub: bool,

    /// Field indicating which protocol to use for server communications.
    protocol: ProtocolVersion,

    /// This is used to manage Push messages in RESP3 mode.
    push_sender: Option<SyncPushSender>,

    /// The number of messages that are expected to be returned from the server,
    /// but the user no longer waits for - answers for requests that already returned a transient error.
    messages_to_skip: usize,
}
787
/// Represents a RESP2 pubsub connection.
///
/// If you're using a DB that supports RESP3, consider using a regular connection and
/// setting a push sender on it using [Connection::set_push_sender].
pub struct PubSub<'a> {
    /// The connection this pubsub session borrows exclusively for its lifetime.
    con: &'a mut Connection,
    /// Queue of buffered messages.
    /// NOTE(review): presumably messages received but not yet handed to the caller — confirm.
    waiting_messages: VecDeque<Msg>,
}
795
/// Represents a pubsub message.
#[derive(Debug, Clone)]
pub struct Msg {
    /// The message payload as a raw redis value.
    payload: Value,
    /// The channel as a raw redis value.
    channel: Value,
    /// Optional pattern as a raw redis value.
    /// NOTE(review): presumably only set for pattern subscriptions — confirm.
    pattern: Option<Value>,
}
803
804impl ActualConnection {
805    pub fn new(
806        addr: &ConnectionAddr,
807        timeout: Option<Duration>,
808        tcp_settings: &TcpSettings,
809    ) -> RedisResult<ActualConnection> {
810        Ok(match *addr {
811            ConnectionAddr::Tcp(ref host, ref port) => {
812                if is_wildcard_address(host) {
813                    fail!((
814                        ErrorKind::InvalidClientConfig,
815                        "Cannot connect to a wildcard address (0.0.0.0 or ::)"
816                    ));
817                }
818                let addr = (host.as_str(), *port);
819                let tcp = match timeout {
820                    None => connect_tcp(addr, tcp_settings)?,
821                    Some(timeout) => {
822                        let mut tcp = None;
823                        let mut last_error = None;
824                        for addr in addr.to_socket_addrs()? {
825                            match connect_tcp_timeout(&addr, timeout, tcp_settings) {
826                                Ok(l) => {
827                                    tcp = Some(l);
828                                    break;
829                                }
830                                Err(e) => {
831                                    last_error = Some(e);
832                                }
833                            };
834                        }
835                        match (tcp, last_error) {
836                            (Some(tcp), _) => tcp,
837                            (None, Some(e)) => {
838                                fail!(e);
839                            }
840                            (None, None) => {
841                                fail!((
842                                    ErrorKind::InvalidClientConfig,
843                                    "could not resolve to any addresses"
844                                ));
845                            }
846                        }
847                    }
848                };
849                ActualConnection::Tcp(TcpConnection {
850                    reader: tcp,
851                    open: true,
852                })
853            }
854            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
855            ConnectionAddr::TcpTls {
856                ref host,
857                port,
858                insecure,
859                ref tls_params,
860            } => {
861                let tls_connector = if insecure {
862                    TlsConnector::builder()
863                        .danger_accept_invalid_certs(true)
864                        .danger_accept_invalid_hostnames(true)
865                        .use_sni(false)
866                        .build()?
867                } else if let Some(params) = tls_params {
868                    TlsConnector::builder()
869                        .danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
870                        .build()?
871                } else {
872                    TlsConnector::new()?
873                };
874                let addr = (host.as_str(), port);
875                let tls = match timeout {
876                    None => {
877                        let tcp = connect_tcp(addr, tcp_settings)?;
878                        match tls_connector.connect(host, tcp) {
879                            Ok(res) => res,
880                            Err(e) => {
881                                fail!((ErrorKind::Io, "SSL Handshake error", e.to_string()));
882                            }
883                        }
884                    }
885                    Some(timeout) => {
886                        let mut tcp = None;
887                        let mut last_error = None;
888                        for addr in (host.as_str(), port).to_socket_addrs()? {
889                            match connect_tcp_timeout(&addr, timeout, tcp_settings) {
890                                Ok(l) => {
891                                    tcp = Some(l);
892                                    break;
893                                }
894                                Err(e) => {
895                                    last_error = Some(e);
896                                }
897                            };
898                        }
899                        match (tcp, last_error) {
900                            (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
901                            (None, Some(e)) => {
902                                fail!(e);
903                            }
904                            (None, None) => {
905                                fail!((
906                                    ErrorKind::InvalidClientConfig,
907                                    "could not resolve to any addresses"
908                                ));
909                            }
910                        }
911                    }
912                };
913                ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
914                    reader: tls,
915                    open: true,
916                }))
917            }
918            #[cfg(feature = "tls-rustls")]
919            ConnectionAddr::TcpTls {
920                ref host,
921                port,
922                insecure,
923                ref tls_params,
924            } => {
925                let host: &str = host;
926                let config = create_rustls_config(insecure, tls_params.clone())?;
927                let conn = rustls::ClientConnection::new(
928                    Arc::new(config),
929                    rustls::pki_types::ServerName::try_from(host)?.to_owned(),
930                )?;
931                let reader = match timeout {
932                    None => {
933                        let tcp = connect_tcp((host, port), tcp_settings)?;
934                        StreamOwned::new(conn, tcp)
935                    }
936                    Some(timeout) => {
937                        let mut tcp = None;
938                        let mut last_error = None;
939                        for addr in (host, port).to_socket_addrs()? {
940                            match connect_tcp_timeout(&addr, timeout, tcp_settings) {
941                                Ok(l) => {
942                                    tcp = Some(l);
943                                    break;
944                                }
945                                Err(e) => {
946                                    last_error = Some(e);
947                                }
948                            };
949                        }
950                        match (tcp, last_error) {
951                            (Some(tcp), _) => StreamOwned::new(conn, tcp),
952                            (None, Some(e)) => {
953                                fail!(e);
954                            }
955                            (None, None) => {
956                                fail!((
957                                    ErrorKind::InvalidClientConfig,
958                                    "could not resolve to any addresses"
959                                ));
960                            }
961                        }
962                    }
963                };
964
965                ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
966            }
967            #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
968            ConnectionAddr::TcpTls { .. } => {
969                fail!((
970                    ErrorKind::InvalidClientConfig,
971                    "Cannot connect to TCP with TLS without the tls feature"
972                ));
973            }
974            #[cfg(unix)]
975            ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
976                sock: UnixStream::connect(path)?,
977                open: true,
978            }),
979            #[cfg(not(unix))]
980            ConnectionAddr::Unix(ref _path) => {
981                fail!((
982                    ErrorKind::InvalidClientConfig,
983                    "Cannot connect to unix sockets \
984                     on this platform"
985                ));
986            }
987        })
988    }
989
990    pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
991        match *self {
992            ActualConnection::Tcp(ref mut connection) => {
993                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
994                match res {
995                    Err(e) => {
996                        if e.is_unrecoverable_error() {
997                            connection.open = false;
998                        }
999                        Err(e)
1000                    }
1001                    Ok(_) => Ok(Value::Okay),
1002                }
1003            }
1004            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1005            ActualConnection::TcpNativeTls(ref mut connection) => {
1006                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
1007                match res {
1008                    Err(e) => {
1009                        if e.is_unrecoverable_error() {
1010                            connection.open = false;
1011                        }
1012                        Err(e)
1013                    }
1014                    Ok(_) => Ok(Value::Okay),
1015                }
1016            }
1017            #[cfg(feature = "tls-rustls")]
1018            ActualConnection::TcpRustls(ref mut connection) => {
1019                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
1020                match res {
1021                    Err(e) => {
1022                        if e.is_unrecoverable_error() {
1023                            connection.open = false;
1024                        }
1025                        Err(e)
1026                    }
1027                    Ok(_) => Ok(Value::Okay),
1028                }
1029            }
1030            #[cfg(unix)]
1031            ActualConnection::Unix(ref mut connection) => {
1032                let result = connection.sock.write_all(bytes).map_err(RedisError::from);
1033                match result {
1034                    Err(e) => {
1035                        if e.is_unrecoverable_error() {
1036                            connection.open = false;
1037                        }
1038                        Err(e)
1039                    }
1040                    Ok(_) => Ok(Value::Okay),
1041                }
1042            }
1043        }
1044    }
1045
1046    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1047        match *self {
1048            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
1049                reader.set_write_timeout(dur)?;
1050            }
1051            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1052            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
1053                let reader = &(boxed_tls_connection.reader);
1054                reader.get_ref().set_write_timeout(dur)?;
1055            }
1056            #[cfg(feature = "tls-rustls")]
1057            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
1058                let reader = &(boxed_tls_connection.reader);
1059                reader.get_ref().set_write_timeout(dur)?;
1060            }
1061            #[cfg(unix)]
1062            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
1063                sock.set_write_timeout(dur)?;
1064            }
1065        }
1066        Ok(())
1067    }
1068
1069    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1070        match *self {
1071            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
1072                reader.set_read_timeout(dur)?;
1073            }
1074            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1075            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
1076                let reader = &(boxed_tls_connection.reader);
1077                reader.get_ref().set_read_timeout(dur)?;
1078            }
1079            #[cfg(feature = "tls-rustls")]
1080            ActualConnection::TcpRustls(ref boxed_tls_connection) => {
1081                let reader = &(boxed_tls_connection.reader);
1082                reader.get_ref().set_read_timeout(dur)?;
1083            }
1084            #[cfg(unix)]
1085            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
1086                sock.set_read_timeout(dur)?;
1087            }
1088        }
1089        Ok(())
1090    }
1091
1092    pub fn is_open(&self) -> bool {
1093        match *self {
1094            ActualConnection::Tcp(TcpConnection { open, .. }) => open,
1095            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1096            ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
1097            #[cfg(feature = "tls-rustls")]
1098            ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
1099            #[cfg(unix)]
1100            ActualConnection::Unix(UnixConnection { open, .. }) => open,
1101        }
1102    }
1103}
1104
/// Builds the `rustls` client configuration used for TLS connections.
///
/// The default root store starts out empty and is filled according to the enabled features:
/// from `webpki_roots` when `tls-rustls-webpki-roots` is on, or from the platform's native
/// certificates when neither `tls-native-tls` nor `tls-rustls-webpki-roots` is on. In any
/// other feature combination the default store stays empty.
///
/// * `insecure` - when `true`, SNI is disabled and certificate verification is replaced by
///   `NoCertificateVerification`. This needs the `tls-rustls-insecure` feature; without it an
///   `InvalidClientConfig` error is returned.
/// * `tls_params` - optional overrides: a root certificate store that replaces the default
///   one, a client certificate chain and key used for client authentication, and the
///   `danger_accept_invalid_hostnames` flag.
///
/// # Errors
///
/// Returns an error when loading the native certificates reported an error, when a
/// certificate cannot be added to the root store, when the client certificate/key pair is
/// rejected, when an insecure option is requested without the `tls-rustls-insecure` feature,
/// or when no default crypto provider is available in insecure mode.
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
    insecure: bool,
    tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
    // `mut` is only exercised in the feature combinations that populate the store below.
    #[allow(unused_mut)]
    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "tls-rustls-webpki-roots")]
    root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
    #[cfg(all(
        feature = "tls-rustls",
        not(feature = "tls-native-tls"),
        not(feature = "tls-rustls-webpki-roots")
    ))]
    {
        let mut certificate_result = load_native_certs();
        // Any error reported while loading the native certificates aborts the setup; the
        // error surfaced is whichever one `pop` yields.
        if let Some(error) = certificate_result.errors.pop() {
            return Err(error.into());
        }
        for cert in certificate_result.certs {
            root_store.add(cert)?;
        }
    }

    let config = rustls::ClientConfig::builder();
    let config = if let Some(tls_params) = tls_params {
        // A caller-provided root store replaces the default one entirely.
        let root_cert_store = tls_params.root_cert_store.unwrap_or(root_store);
        // Cloned because the store is used again below when the hostname-ignoring verifier
        // is built.
        let config_builder = config.with_root_certificates(root_cert_store.clone());

        // Use client authentication only when a certificate chain and key were supplied.
        let config_builder = if let Some(ClientTlsParams {
            client_cert_chain: client_cert,
            client_key,
        }) = tls_params.client_tls_params
        {
            config_builder
                .with_client_auth_cert(client_cert, client_key)
                .map_err(|err| {
                    RedisError::from((
                        ErrorKind::InvalidClientConfig,
                        "Unable to build client with TLS parameters provided.",
                        err.to_string(),
                    ))
                })?
        } else {
            config_builder.with_no_client_auth()
        };

        // Implement `danger_accept_invalid_hostnames`.
        //
        // The strange cfg here is to handle a specific unusual combination of features: if
        // `tls-native-tls` and `tls-rustls` are enabled, but `tls-rustls-insecure` is not, and the
        // application tries to use the danger flag.
        //
        // The flag is skipped when `insecure` is set, since that mode installs its own
        // verifier at the end of this function.
        #[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
        let config_builder = if !insecure && tls_params.danger_accept_invalid_hostnames {
            #[cfg(not(feature = "tls-rustls-insecure"))]
            {
                // This code should not enable an insecure mode if the `insecure` feature is not
                // set, but it shouldn't silently ignore the flag either. So return an error.
                fail!((
                    ErrorKind::InvalidClientConfig,
                    "Cannot create insecure client via danger_accept_invalid_hostnames without tls-rustls-insecure feature"
                ));
            }

            #[cfg(feature = "tls-rustls-insecure")]
            {
                // Wrap a WebPKI verifier built over the same root store in
                // `AcceptInvalidHostnamesCertVerifier`.
                let mut config = config_builder;
                config.dangerous().set_certificate_verifier(Arc::new(
                    AcceptInvalidHostnamesCertVerifier {
                        inner: rustls::client::WebPkiServerVerifier::builder(Arc::new(
                            root_cert_store,
                        ))
                        .build()
                        .map_err(|err| rustls::Error::from(rustls::OtherError(Arc::new(err))))?,
                    },
                ));
                config
            }
        } else {
            config_builder
        };

        config_builder
    } else {
        // No custom parameters: default roots, no client authentication.
        config
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };

    // `cfg!` yields a compile-time boolean, so the `(true, true)` arm can be compiled out
    // when the feature is off while the catch-all arm keeps the match exhaustive.
    match (insecure, cfg!(feature = "tls-rustls-insecure")) {
        #[cfg(feature = "tls-rustls-insecure")]
        (true, true) => {
            let mut config = config;
            config.enable_sni = false;
            let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() else {
                return Err(RedisError::from((
                    ErrorKind::InvalidClientConfig,
                    "No crypto provider available for rustls",
                )));
            };
            config
                .dangerous()
                .set_certificate_verifier(Arc::new(NoCertificateVerification {
                    supported: crypto_provider.signature_verification_algorithms,
                }));

            Ok(config)
        }
        (true, false) => {
            fail!((
                ErrorKind::InvalidClientConfig,
                "Cannot create insecure client without tls-rustls-insecure feature"
            ));
        }
        _ => Ok(config),
    }
}
1222
1223fn authenticate_cmd(
1224    connection_info: &RedisConnectionInfo,
1225    check_username: bool,
1226    password: &str,
1227) -> Cmd {
1228    let mut command = cmd("AUTH");
1229    if check_username {
1230        if let Some(username) = &connection_info.username {
1231            command.arg(username.as_str());
1232        }
1233    }
1234    command.arg(password);
1235    command
1236}
1237
1238pub fn connect(
1239    connection_info: &ConnectionInfo,
1240    timeout: Option<Duration>,
1241) -> RedisResult<Connection> {
1242    let start = Instant::now();
1243    let con: ActualConnection = ActualConnection::new(
1244        &connection_info.addr,
1245        timeout,
1246        &connection_info.tcp_settings,
1247    )?;
1248
1249    // we temporarily set the timeout, and will remove it after finishing setup.
1250    let remaining_timeout = timeout.and_then(|timeout| timeout.checked_sub(start.elapsed()));
1251    // TLS could run logic that doesn't contain a timeout, and should fail if it takes too long.
1252    if timeout.is_some() && remaining_timeout.is_none() {
1253        return Err(RedisError::from(std::io::Error::new(
1254            std::io::ErrorKind::TimedOut,
1255            "Connection timed out",
1256        )));
1257    }
1258    con.set_read_timeout(remaining_timeout)?;
1259    con.set_write_timeout(remaining_timeout)?;
1260
1261    let con = setup_connection(
1262        con,
1263        &connection_info.redis,
1264        #[cfg(feature = "cache-aio")]
1265        None,
1266    )?;
1267
1268    // remove the temporary timeout.
1269    con.set_read_timeout(None)?;
1270    con.set_write_timeout(None)?;
1271
1272    Ok(con)
1273}
1274
/// Positions, within the pipeline built by `connection_setup_pipeline`, of the commands whose
/// responses are validated by `check_connection_setup`.
///
/// A field is `None` when the corresponding command was not added to the pipeline.
pub(crate) struct ConnectionSetupComponents {
    /// Index of the command produced by `resp3_hello`.
    resp3_auth_cmd_idx: Option<usize>,
    /// Index of the RESP2 `AUTH` command.
    resp2_auth_cmd_idx: Option<usize>,
    /// Index of the `SELECT` command.
    select_cmd_idx: Option<usize>,
    /// Index of the `CLIENT TRACKING ON` command.
    #[cfg(feature = "cache-aio")]
    cache_cmd_idx: Option<usize>,
}
1282
1283pub(crate) fn connection_setup_pipeline(
1284    connection_info: &RedisConnectionInfo,
1285    check_username: bool,
1286    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1287) -> (crate::Pipeline, ConnectionSetupComponents) {
1288    let mut pipeline = pipe();
1289    let (authenticate_with_resp3_cmd_index, authenticate_with_resp2_cmd_index) =
1290        if connection_info.protocol.supports_resp3() {
1291            pipeline.add_command(resp3_hello(connection_info));
1292            (Some(0), None)
1293        } else if connection_info.password.is_some() {
1294            pipeline.add_command(authenticate_cmd(
1295                connection_info,
1296                check_username,
1297                connection_info.password.as_ref().unwrap(),
1298            ));
1299            (None, Some(0))
1300        } else {
1301            (None, None)
1302        };
1303
1304    let select_db_cmd_index = (connection_info.db != 0)
1305        .then(|| pipeline.len())
1306        .inspect(|_| {
1307            pipeline.cmd("SELECT").arg(connection_info.db);
1308        });
1309
1310    #[cfg(feature = "cache-aio")]
1311    let cache_cmd_index = cache_config.map(|cache_config| {
1312        pipeline.cmd("CLIENT").arg("TRACKING").arg("ON");
1313        match cache_config.mode {
1314            crate::caching::CacheMode::All => {}
1315            crate::caching::CacheMode::OptIn => {
1316                pipeline.arg("OPTIN");
1317            }
1318        }
1319        pipeline.len() - 1
1320    });
1321
1322    // result is ignored, as per the command's instructions.
1323    // https://redis.io/commands/client-setinfo/
1324    if !connection_info.skip_set_lib_name {
1325        pipeline
1326            .cmd("CLIENT")
1327            .arg("SETINFO")
1328            .arg("LIB-NAME")
1329            .arg("redis-rs")
1330            .ignore();
1331        pipeline
1332            .cmd("CLIENT")
1333            .arg("SETINFO")
1334            .arg("LIB-VER")
1335            .arg(env!("CARGO_PKG_VERSION"))
1336            .ignore();
1337    }
1338
1339    (
1340        pipeline,
1341        ConnectionSetupComponents {
1342            resp3_auth_cmd_idx: authenticate_with_resp3_cmd_index,
1343            resp2_auth_cmd_idx: authenticate_with_resp2_cmd_index,
1344            select_cmd_idx: select_db_cmd_index,
1345            #[cfg(feature = "cache-aio")]
1346            cache_cmd_idx: cache_cmd_index,
1347        },
1348    )
1349}
1350
1351fn check_resp3_auth(result: &Value) -> RedisResult<()> {
1352    if let Value::ServerError(err) = result {
1353        return Err(get_resp3_hello_command_error(err.clone().into()));
1354    }
1355    Ok(())
1356}
1357
/// Outcome of validating the responses to the connection setup pipeline.
#[derive(PartialEq)]
pub(crate) enum AuthResult {
    /// The setup responses were accepted.
    Succeeded,
    /// The server rejected the `AUTH` command with a "wrong number of arguments" error, so the
    /// setup should be attempted again with an `AUTH` command that omits the username.
    ShouldRetryWithoutUsername,
}
1363
1364fn check_resp2_auth(result: &Value) -> RedisResult<AuthResult> {
1365    let err = match result {
1366        Value::Okay => {
1367            return Ok(AuthResult::Succeeded);
1368        }
1369        Value::ServerError(err) => err,
1370        _ => {
1371            return Err((
1372                ServerErrorKind::ResponseError.into(),
1373                "Redis server refused to authenticate, returns Ok() != Value::Okay",
1374            )
1375                .into());
1376        }
1377    };
1378
1379    let err_msg = err.details().ok_or((
1380        ErrorKind::AuthenticationFailed,
1381        "Password authentication failed",
1382    ))?;
1383    if !err_msg.contains("wrong number of arguments for 'auth' command") {
1384        return Err((
1385            ErrorKind::AuthenticationFailed,
1386            "Password authentication failed",
1387        )
1388            .into());
1389    }
1390    Ok(AuthResult::ShouldRetryWithoutUsername)
1391}
1392
1393fn check_db_select(value: &Value) -> RedisResult<()> {
1394    let Value::ServerError(err) = value else {
1395        return Ok(());
1396    };
1397
1398    match err.details() {
1399        Some(err_msg) => Err((
1400            ServerErrorKind::ResponseError.into(),
1401            "Redis server refused to switch database",
1402            err_msg.to_string(),
1403        )
1404            .into()),
1405        None => Err((
1406            ServerErrorKind::ResponseError.into(),
1407            "Redis server refused to switch database",
1408        )
1409            .into()),
1410    }
1411}
1412
/// Validates the response to the `CLIENT TRACKING` command: anything other than
/// `Value::Okay` is reported as a response error that includes the debug form of the response.
#[cfg(feature = "cache-aio")]
fn check_caching(result: &Value) -> RedisResult<()> {
    if matches!(result, Value::Okay) {
        return Ok(());
    }

    Err((
        ServerErrorKind::ResponseError.into(),
        "Client-side caching returned unknown response",
        format!("{result:?}"),
    )
        .into())
}
1425
1426pub(crate) fn check_connection_setup(
1427    results: Vec<Value>,
1428    ConnectionSetupComponents {
1429        resp3_auth_cmd_idx,
1430        resp2_auth_cmd_idx,
1431        select_cmd_idx,
1432        #[cfg(feature = "cache-aio")]
1433        cache_cmd_idx,
1434    }: ConnectionSetupComponents,
1435) -> RedisResult<AuthResult> {
1436    // can't have both values set
1437    assert!(!(resp2_auth_cmd_idx.is_some() && resp3_auth_cmd_idx.is_some()));
1438
1439    if let Some(index) = resp3_auth_cmd_idx {
1440        let Some(value) = results.get(index) else {
1441            return Err((ErrorKind::Client, "Missing RESP3 auth response").into());
1442        };
1443        check_resp3_auth(value)?;
1444    } else if let Some(index) = resp2_auth_cmd_idx {
1445        let Some(value) = results.get(index) else {
1446            return Err((ErrorKind::Client, "Missing RESP2 auth response").into());
1447        };
1448        if check_resp2_auth(value)? == AuthResult::ShouldRetryWithoutUsername {
1449            return Ok(AuthResult::ShouldRetryWithoutUsername);
1450        }
1451    }
1452
1453    if let Some(index) = select_cmd_idx {
1454        let Some(value) = results.get(index) else {
1455            return Err((ErrorKind::Client, "Missing SELECT DB response").into());
1456        };
1457        check_db_select(value)?;
1458    }
1459
1460    #[cfg(feature = "cache-aio")]
1461    if let Some(index) = cache_cmd_idx {
1462        let Some(value) = results.get(index) else {
1463            return Err((ErrorKind::Client, "Missing Caching response").into());
1464        };
1465        check_caching(value)?;
1466    }
1467
1468    Ok(AuthResult::Succeeded)
1469}
1470
1471fn execute_connection_pipeline(
1472    rv: &mut Connection,
1473    (pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
1474) -> RedisResult<AuthResult> {
1475    if pipeline.is_empty() {
1476        return Ok(AuthResult::Succeeded);
1477    }
1478    let results = rv.req_packed_commands(&pipeline.get_packed_pipeline(), 0, pipeline.len())?;
1479
1480    check_connection_setup(results, instructions)
1481}
1482
1483fn setup_connection(
1484    con: ActualConnection,
1485    connection_info: &RedisConnectionInfo,
1486    #[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
1487) -> RedisResult<Connection> {
1488    let mut rv = Connection {
1489        con,
1490        parser: Parser::new(),
1491        db: connection_info.db,
1492        pubsub: false,
1493        protocol: connection_info.protocol,
1494        push_sender: None,
1495        messages_to_skip: 0,
1496    };
1497
1498    if execute_connection_pipeline(
1499        &mut rv,
1500        connection_setup_pipeline(
1501            connection_info,
1502            true,
1503            #[cfg(feature = "cache-aio")]
1504            cache_config,
1505        ),
1506    )? == AuthResult::ShouldRetryWithoutUsername
1507    {
1508        execute_connection_pipeline(
1509            &mut rv,
1510            connection_setup_pipeline(
1511                connection_info,
1512                false,
1513                #[cfg(feature = "cache-aio")]
1514                cache_config,
1515            ),
1516        )?;
1517    }
1518
1519    Ok(rv)
1520}
1521
1522/// Implements the "stateless" part of the connection interface that is used by the
1523/// different objects in redis-rs.
1524///
1525/// Primarily it obviously applies to `Connection` object but also some other objects
1526///  implement the interface (for instance whole clients or certain redis results).
1527///
1528/// Generally clients and connections (as well as redis results of those) implement
1529/// this trait.  Actual connections provide more functionality which can be used
1530/// to implement things like `PubSub` but they also can modify the intrinsic
1531/// state of the TCP connection.  This is not possible with `ConnectionLike`
1532/// implementors because that functionality is not exposed.
1533pub trait ConnectionLike {
1534    /// Sends an already encoded (packed) command into the TCP socket and
1535    /// reads the single response from it.
1536    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;
1537
1538    /// Sends multiple already encoded (packed) command into the TCP socket
1539    /// and reads `count` responses from it.  This is used to implement
1540    /// pipelining.
1541    /// Important - this function is meant for internal usage, since it's
1542    /// easy to pass incorrect `offset` & `count` parameters, which might
1543    /// cause the connection to enter an erroneous state. Users shouldn't
1544    /// call it, instead using the Pipeline::query function.
1545    #[doc(hidden)]
1546    fn req_packed_commands(
1547        &mut self,
1548        cmd: &[u8],
1549        offset: usize,
1550        count: usize,
1551    ) -> RedisResult<Vec<Value>>;
1552
1553    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1554    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1555        let pcmd = cmd.get_packed_command();
1556        self.req_packed_command(&pcmd)
1557    }
1558
1559    /// Returns the database this connection is bound to.  Note that this
1560    /// information might be unreliable because it's initially cached and
1561    /// also might be incorrect if the connection like object is not
1562    /// actually connected.
1563    fn get_db(&self) -> i64;
1564
1565    /// Does this connection support pipelining?
1566    #[doc(hidden)]
1567    fn supports_pipelining(&self) -> bool {
1568        true
1569    }
1570
1571    /// Check that all connections it has are available (`PING` internally).
1572    fn check_connection(&mut self) -> bool;
1573
1574    /// Returns the connection status.
1575    ///
1576    /// The connection is open until any `read` call received an
1577    /// invalid response from the server (most likely a closed or dropped
1578    /// connection, otherwise a Redis protocol error). When using unix
1579    /// sockets the connection is open until writing a command failed with a
1580    /// `BrokenPipe` error.
1581    fn is_open(&self) -> bool;
1582}
1583
1584/// A connection is an object that represents a single redis connection.  It
1585/// provides basic support for sending encoded commands into a redis connection
1586/// and to read a response from it.  It's bound to a single database and can
1587/// only be created from the client.
1588///
/// You generally do not do much with this object other than pass it to
/// `Cmd` objects.
1591impl Connection {
1592    /// Sends an already encoded (packed) command into the TCP socket and
1593    /// does not read a response.  This is useful for commands like
1594    /// `MONITOR` which yield multiple items.  This needs to be used with
1595    /// care because it changes the state of the connection.
1596    pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
1597        self.send_bytes(cmd)?;
1598        Ok(())
1599    }
1600
    /// Fetches a single response from the connection.
    ///
    /// This is useful in combination with `send_packed_command`, which sends a command
    /// without reading its response.
    pub fn recv_response(&mut self) -> RedisResult<Value> {
        self.read(true)
    }
1606
    /// Sets the write timeout for the connection.
    ///
    /// If the provided value is `None`, then `send_packed_command` calls will
    /// block indefinitely. It is an error to pass the zero `Duration` to this
    /// method.
    ///
    /// The value is forwarded to the underlying transport.
    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
        self.con.set_write_timeout(dur)
    }
1615
    /// Sets the read timeout for the connection.
    ///
    /// If the provided value is `None`, then `recv_response` calls will
    /// block indefinitely. It is an error to pass the zero `Duration` to this
    /// method.
    ///
    /// The value is forwarded to the underlying transport.
    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
        self.con.set_read_timeout(dur)
    }
1624
    /// Creates a [`PubSub`] instance for this connection.
    ///
    /// The returned [`PubSub`] holds the mutable borrow of this connection for as long as it
    /// lives.
    pub fn as_pubsub(&mut self) -> PubSub<'_> {
        // NOTE: The pubsub flag is intentionally not raised at this time since
        // running commands within the pubsub state should not try and exit from
        // the pubsub state.
        PubSub::new(self)
    }
1632
1633    fn exit_pubsub(&mut self) -> RedisResult<()> {
1634        let res = self.clear_active_subscriptions();
1635        if res.is_ok() {
1636            self.pubsub = false;
1637        } else {
1638            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
1639            self.pubsub = true;
1640        }
1641
1642        res
1643    }
1644
1645    /// Get the inner connection out of a PubSub
1646    ///
1647    /// Any active subscriptions are unsubscribed. In the event of an error, the connection is
1648    /// dropped.
1649    fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
1650        // Responses to unsubscribe commands return in a 3-tuple with values
1651        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
1652        // The "count of remaining subs" includes both pattern subscriptions and non pattern
1653        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
1654        // server, both commands need to be executed at once.
1655        {
1656            // Prepare both unsubscribe commands
1657            let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
1658            let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();
1659
1660            // Execute commands
1661            self.send_bytes(&unsubscribe)?;
1662            self.send_bytes(&punsubscribe)?;
1663        }
1664
1665        // Receive responses
1666        //
1667        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
1668        // commands. There may be more responses if there are active subscriptions. In this case,
1669        // messages are received until the _subscription count_ in the responses reach zero.
1670        let mut received_unsub = false;
1671        let mut received_punsub = false;
1672
1673        loop {
1674            let resp = self.recv_response()?;
1675
1676            match resp {
1677                Value::Push { kind, data } => {
1678                    if data.len() >= 2 {
1679                        if let Value::Int(num) = data[1] {
1680                            if resp3_is_pub_sub_state_cleared(
1681                                &mut received_unsub,
1682                                &mut received_punsub,
1683                                &kind,
1684                                num as isize,
1685                            ) {
1686                                break;
1687                            }
1688                        }
1689                    }
1690                }
1691                Value::ServerError(err) => {
1692                    // a new error behavior, introduced in valkey 8.
1693                    // https://github.com/valkey-io/valkey/pull/759
1694                    if err.kind() == Some(ServerErrorKind::NoSub) {
1695                        if no_sub_err_is_pub_sub_state_cleared(
1696                            &mut received_unsub,
1697                            &mut received_punsub,
1698                            &err,
1699                        ) {
1700                            break;
1701                        } else {
1702                            continue;
1703                        }
1704                    }
1705
1706                    return Err(err.into());
1707                }
1708                Value::Array(vec) => {
1709                    let res: (Vec<u8>, (), isize) = from_redis_value(Value::Array(vec))?;
1710                    if resp2_is_pub_sub_state_cleared(
1711                        &mut received_unsub,
1712                        &mut received_punsub,
1713                        &res.0,
1714                        res.2,
1715                    ) {
1716                        break;
1717                    }
1718                }
1719                _ => {
1720                    return Err((
1721                        ErrorKind::Client,
1722                        "Unexpected unsubscribe response",
1723                        format!("{resp:?}"),
1724                    )
1725                        .into())
1726                }
1727            }
1728        }
1729
1730        // Finally, the connection is back in its normal state since all subscriptions were
1731        // cancelled *and* all unsubscribe messages were received.
1732        Ok(())
1733    }
1734
1735    fn send_push(&self, push: PushInfo) {
1736        if let Some(sender) = &self.push_sender {
1737            let _ = sender.send(push);
1738        }
1739    }
1740
1741    fn try_send(&self, value: &RedisResult<Value>) {
1742        if let Ok(Value::Push { kind, data }) = value {
1743            self.send_push(PushInfo {
1744                kind: kind.clone(),
1745                data: data.clone(),
1746            });
1747        }
1748    }
1749
1750    fn send_disconnect(&self) {
1751        self.send_push(PushInfo::disconnect())
1752    }
1753
1754    fn close_connection(&mut self) {
1755        // Notify the PushManager that the connection was lost
1756        self.send_disconnect();
1757        match self.con {
1758            ActualConnection::Tcp(ref mut connection) => {
1759                let _ = connection.reader.shutdown(net::Shutdown::Both);
1760                connection.open = false;
1761            }
1762            #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1763            ActualConnection::TcpNativeTls(ref mut connection) => {
1764                let _ = connection.reader.shutdown();
1765                connection.open = false;
1766            }
1767            #[cfg(feature = "tls-rustls")]
1768            ActualConnection::TcpRustls(ref mut connection) => {
1769                let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
1770                connection.open = false;
1771            }
1772            #[cfg(unix)]
1773            ActualConnection::Unix(ref mut connection) => {
1774                let _ = connection.sock.shutdown(net::Shutdown::Both);
1775                connection.open = false;
1776            }
1777        }
1778    }
1779
1780    /// Fetches a single message from the connection. If the message is a response,
1781    /// increment `messages_to_skip` if it wasn't received before a timeout.
1782    fn read(&mut self, is_response: bool) -> RedisResult<Value> {
1783        loop {
1784            let result = match self.con {
1785                ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
1786                    self.parser.parse_value(reader)
1787                }
1788                #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1789                ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
1790                    let reader = &mut boxed_tls_connection.reader;
1791                    self.parser.parse_value(reader)
1792                }
1793                #[cfg(feature = "tls-rustls")]
1794                ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
1795                    let reader = &mut boxed_tls_connection.reader;
1796                    self.parser.parse_value(reader)
1797                }
1798                #[cfg(unix)]
1799                ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
1800                    self.parser.parse_value(sock)
1801                }
1802            };
1803            self.try_send(&result);
1804
1805            let Err(err) = &result else {
1806                if self.messages_to_skip > 0 {
1807                    self.messages_to_skip -= 1;
1808                    continue;
1809                }
1810                return result;
1811            };
1812            let Some(io_error) = err.as_io_error() else {
1813                if self.messages_to_skip > 0 {
1814                    self.messages_to_skip -= 1;
1815                    continue;
1816                }
1817                return result;
1818            };
1819            // shutdown connection on protocol error
1820            if io_error.kind() == io::ErrorKind::UnexpectedEof {
1821                self.close_connection();
1822            } else if is_response {
1823                self.messages_to_skip += 1;
1824            }
1825
1826            return result;
1827        }
1828    }
1829
1830    /// Sets sender channel for push values.
1831    pub fn set_push_sender(&mut self, sender: SyncPushSender) {
1832        self.push_sender = Some(sender);
1833    }
1834
1835    fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
1836        if bytes.is_empty() {
1837            return Err(RedisError::make_empty_command());
1838        }
1839        let result = self.con.send_bytes(bytes);
1840        if self.protocol.supports_resp3() {
1841            if let Err(e) = &result {
1842                if e.is_connection_dropped() {
1843                    self.send_disconnect();
1844                }
1845            }
1846        }
1847        result
1848    }
1849
1850    /// Subscribes to a new channel(s).
1851    ///
1852    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1853    pub fn subscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1854        check_resp3!(self.protocol);
1855        cmd("SUBSCRIBE")
1856            .arg(channel)
1857            .set_no_response(true)
1858            .exec(self)
1859    }
1860
1861    /// Subscribes to new channel(s) with pattern(s).
1862    ///
1863    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1864    pub fn psubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1865        check_resp3!(self.protocol);
1866        cmd("PSUBSCRIBE")
1867            .arg(pchannel)
1868            .set_no_response(true)
1869            .exec(self)
1870    }
1871
1872    /// Unsubscribes from a channel(s).
1873    ///
1874    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1875    pub fn unsubscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1876        check_resp3!(self.protocol);
1877        cmd("UNSUBSCRIBE")
1878            .arg(channel)
1879            .set_no_response(true)
1880            .exec(self)
1881    }
1882
1883    /// Unsubscribes from channel pattern(s).
1884    ///
1885    /// This only works if the connection was configured with [ProtocolVersion::RESP3] and [Self::set_push_sender].
1886    pub fn punsubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1887        check_resp3!(self.protocol);
1888        cmd("PUNSUBSCRIBE")
1889            .arg(pchannel)
1890            .set_no_response(true)
1891            .exec(self)
1892    }
1893}
1894
1895impl ConnectionLike for Connection {
1896    /// Sends a [Cmd] into the TCP socket and reads a single response from it.
1897    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
1898        let pcmd = cmd.get_packed_command();
1899        if self.pubsub {
1900            self.exit_pubsub()?;
1901        }
1902
1903        self.send_bytes(&pcmd)?;
1904        if cmd.is_no_response() {
1905            return Ok(Value::Nil);
1906        }
1907        loop {
1908            match self.read(true)? {
1909                Value::Push {
1910                    kind: _kind,
1911                    data: _data,
1912                } => continue,
1913                val => return Ok(val),
1914            }
1915        }
1916    }
1917    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1918        if self.pubsub {
1919            self.exit_pubsub()?;
1920        }
1921
1922        self.send_bytes(cmd)?;
1923        loop {
1924            match self.read(true)? {
1925                Value::Push {
1926                    kind: _kind,
1927                    data: _data,
1928                } => continue,
1929                val => return Ok(val),
1930            }
1931        }
1932    }
1933
1934    fn req_packed_commands(
1935        &mut self,
1936        cmd: &[u8],
1937        offset: usize,
1938        count: usize,
1939    ) -> RedisResult<Vec<Value>> {
1940        if self.pubsub {
1941            self.exit_pubsub()?;
1942        }
1943        self.send_bytes(cmd)?;
1944        let mut rv = vec![];
1945        let mut first_err = None;
1946        let mut server_errors = vec![];
1947        let mut count = count;
1948        let mut idx = 0;
1949        while idx < (offset + count) {
1950            // When processing a transaction, some responses may be errors.
1951            // We need to keep processing the rest of the responses in that case,
1952            // so bailing early with `?` would not be correct.
1953            // See: https://github.com/redis-rs/redis-rs/issues/436
1954            let response = self.read(true);
1955            match response {
1956                Ok(Value::ServerError(err)) => {
1957                    if idx < offset {
1958                        server_errors.push((idx - 1, err)); // -1, to offset the added MULTI call.
1959                    } else {
1960                        rv.push(Value::ServerError(err));
1961                    }
1962                }
1963                Ok(item) => {
1964                    // RESP3 can insert push data between command replies
1965                    if let Value::Push {
1966                        kind: _kind,
1967                        data: _data,
1968                    } = item
1969                    {
1970                        // if that is the case we have to extend the loop and handle push data
1971                        count += 1;
1972                    } else if idx >= offset {
1973                        rv.push(item);
1974                    }
1975                }
1976                Err(err) => {
1977                    if first_err.is_none() {
1978                        first_err = Some(err);
1979                    }
1980                }
1981            }
1982            idx += 1;
1983        }
1984
1985        if !server_errors.is_empty() {
1986            return Err(RedisError::make_aborted_transaction(server_errors));
1987        }
1988
1989        first_err.map_or(Ok(rv), Err)
1990    }
1991
1992    fn get_db(&self) -> i64 {
1993        self.db
1994    }
1995
1996    fn check_connection(&mut self) -> bool {
1997        cmd("PING").query::<String>(self).is_ok()
1998    }
1999
2000    fn is_open(&self) -> bool {
2001        self.con.is_open()
2002    }
2003}
2004
2005impl<C, T> ConnectionLike for T
2006where
2007    C: ConnectionLike,
2008    T: DerefMut<Target = C>,
2009{
2010    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
2011        self.deref_mut().req_packed_command(cmd)
2012    }
2013
2014    fn req_packed_commands(
2015        &mut self,
2016        cmd: &[u8],
2017        offset: usize,
2018        count: usize,
2019    ) -> RedisResult<Vec<Value>> {
2020        self.deref_mut().req_packed_commands(cmd, offset, count)
2021    }
2022
2023    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
2024        self.deref_mut().req_command(cmd)
2025    }
2026
2027    fn get_db(&self) -> i64 {
2028        self.deref().get_db()
2029    }
2030
2031    fn supports_pipelining(&self) -> bool {
2032        self.deref().supports_pipelining()
2033    }
2034
2035    fn check_connection(&mut self) -> bool {
2036        self.deref_mut().check_connection()
2037    }
2038
2039    fn is_open(&self) -> bool {
2040        self.deref().is_open()
2041    }
2042}
2043
2044/// The pubsub object provides convenient access to the redis pubsub
2045/// system.  Once created you can subscribe and unsubscribe from channels
2046/// and listen in on messages.
2047///
2048/// Example:
2049///
2050/// ```rust,no_run
2051/// # fn do_something() -> redis::RedisResult<()> {
2052/// let client = redis::Client::open("redis://127.0.0.1/")?;
2053/// let mut con = client.get_connection()?;
2054/// let mut pubsub = con.as_pubsub();
2055/// pubsub.subscribe("channel_1")?;
2056/// pubsub.subscribe("channel_2")?;
2057///
2058/// loop {
2059///     let msg = pubsub.get_message()?;
2060///     let payload : String = msg.get_payload()?;
2061///     println!("channel '{}': {}", msg.get_channel_name(), payload);
2062/// }
2063/// # }
2064/// ```
2065impl<'a> PubSub<'a> {
2066    fn new(con: &'a mut Connection) -> Self {
2067        Self {
2068            con,
2069            waiting_messages: VecDeque::new(),
2070        }
2071    }
2072
2073    fn cache_messages_until_received_response(
2074        &mut self,
2075        cmd: &mut Cmd,
2076        is_sub_unsub: bool,
2077    ) -> RedisResult<Value> {
2078        let ignore_response = self.con.protocol.supports_resp3() && is_sub_unsub;
2079        cmd.set_no_response(ignore_response);
2080
2081        self.con.send_packed_command(&cmd.get_packed_command())?;
2082
2083        loop {
2084            let response = self.con.recv_response()?;
2085            if let Some(msg) = Msg::from_value(&response) {
2086                self.waiting_messages.push_back(msg);
2087            } else {
2088                return Ok(response);
2089            }
2090        }
2091    }
2092
2093    /// Subscribes to a new channel(s).
2094    pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
2095        self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel), true)?;
2096        Ok(())
2097    }
2098
2099    /// Subscribes to new channel(s) with pattern(s).
2100    pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
2101        self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel), true)?;
2102        Ok(())
2103    }
2104
2105    /// Unsubscribes from a channel(s).
2106    pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
2107        self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel), true)?;
2108        Ok(())
2109    }
2110
2111    /// Unsubscribes from channel pattern(s).
2112    pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
2113        self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel), true)?;
2114        Ok(())
2115    }
2116
2117    /// Sends a ping with a message to the server
2118    pub fn ping_message<T: FromRedisValue>(&mut self, message: impl ToRedisArgs) -> RedisResult<T> {
2119        Ok(from_redis_value(
2120            self.cache_messages_until_received_response(cmd("PING").arg(message), false)?,
2121        )?)
2122    }
2123    /// Sends a ping to the server
2124    pub fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
2125        Ok(from_redis_value(
2126            self.cache_messages_until_received_response(&mut cmd("PING"), false)?,
2127        )?)
2128    }
2129
2130    /// Fetches the next message from the pubsub connection.  Blocks until
2131    /// a message becomes available.  This currently does not provide a
2132    /// wait not to block :(
2133    ///
2134    /// The message itself is still generic and can be converted into an
2135    /// appropriate type through the helper methods on it.
2136    pub fn get_message(&mut self) -> RedisResult<Msg> {
2137        if let Some(msg) = self.waiting_messages.pop_front() {
2138            return Ok(msg);
2139        }
2140        loop {
2141            if let Some(msg) = Msg::from_owned_value(self.con.read(false)?) {
2142                return Ok(msg);
2143            } else {
2144                continue;
2145            }
2146        }
2147    }
2148
2149    /// Sets the read timeout for the connection.
2150    ///
2151    /// If the provided value is `None`, then `get_message` call will
2152    /// block indefinitely. It is an error to pass the zero `Duration` to this
2153    /// method.
2154    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
2155        self.con.set_read_timeout(dur)
2156    }
2157}
2158
2159impl Drop for PubSub<'_> {
2160    fn drop(&mut self) {
2161        let _ = self.con.exit_pubsub();
2162    }
2163}
2164
2165/// This holds the data that comes from listening to a pubsub
2166/// connection.  It only contains actual message data.
2167impl Msg {
2168    /// Tries to convert provided [`Value`] into [`Msg`].
2169    pub fn from_value(value: &Value) -> Option<Self> {
2170        Self::from_owned_value(value.clone())
2171    }
2172
2173    /// Tries to convert provided [`Value`] into [`Msg`].
2174    pub fn from_owned_value(value: Value) -> Option<Self> {
2175        let mut pattern = None;
2176        let payload;
2177        let channel;
2178
2179        if let Value::Push { kind, data } = value {
2180            return Self::from_push_info(PushInfo { kind, data });
2181        } else {
2182            let raw_msg: Vec<Value> = from_redis_value(value).ok()?;
2183            let mut iter = raw_msg.into_iter();
2184            let msg_type: String = from_redis_value(iter.next()?).ok()?;
2185            if msg_type == "message" {
2186                channel = iter.next()?;
2187                payload = iter.next()?;
2188            } else if msg_type == "pmessage" {
2189                pattern = Some(iter.next()?);
2190                channel = iter.next()?;
2191                payload = iter.next()?;
2192            } else {
2193                return None;
2194            }
2195        };
2196        Some(Msg {
2197            payload,
2198            channel,
2199            pattern,
2200        })
2201    }
2202
2203    /// Tries to convert provided [`PushInfo`] into [`Msg`].
2204    pub fn from_push_info(push_info: PushInfo) -> Option<Self> {
2205        let mut pattern = None;
2206        let payload;
2207        let channel;
2208
2209        let mut iter = push_info.data.into_iter();
2210        if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
2211            channel = iter.next()?;
2212            payload = iter.next()?;
2213        } else if push_info.kind == PushKind::PMessage {
2214            pattern = Some(iter.next()?);
2215            channel = iter.next()?;
2216            payload = iter.next()?;
2217        } else {
2218            return None;
2219        }
2220
2221        Some(Msg {
2222            payload,
2223            channel,
2224            pattern,
2225        })
2226    }
2227
2228    /// Returns the channel this message came on.
2229    pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
2230        Ok(from_redis_value_ref(&self.channel)?)
2231    }
2232
2233    /// Convenience method to get a string version of the channel.  Unless
2234    /// your channel contains non utf-8 bytes you can always use this
2235    /// method.  If the channel is not a valid string (which really should
2236    /// not happen) then the return value is `"?"`.
2237    pub fn get_channel_name(&self) -> &str {
2238        match self.channel {
2239            Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
2240            _ => "?",
2241        }
2242    }
2243
2244    /// Returns the message's payload in a specific format.
2245    pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
2246        Ok(from_redis_value_ref(&self.payload)?)
2247    }
2248
2249    /// Returns the bytes that are the message's payload.  This can be used
2250    /// as an alternative to the `get_payload` function if you are interested
2251    /// in the raw bytes in it.
2252    pub fn get_payload_bytes(&self) -> &[u8] {
2253        match self.payload {
2254            Value::BulkString(ref bytes) => bytes,
2255            _ => b"",
2256        }
2257    }
2258
2259    /// Returns true if the message was constructed from a pattern
2260    /// subscription.
2261    #[allow(clippy::wrong_self_convention)]
2262    pub fn from_pattern(&self) -> bool {
2263        self.pattern.is_some()
2264    }
2265
2266    /// If the message was constructed from a message pattern this can be
2267    /// used to find out which one.  It's recommended to match against
2268    /// an `Option<String>` so that you do not need to use `from_pattern`
2269    /// to figure out if a pattern was set.
2270    pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
2271        Ok(match self.pattern {
2272            None => from_redis_value_ref(&Value::Nil),
2273            Some(ref x) => from_redis_value_ref(x),
2274        }?)
2275    }
2276}
2277
2278/// This function simplifies transaction management slightly.  What it
2279/// does is automatically watching keys and then going into a transaction
2280/// loop util it succeeds.  Once it goes through the results are
2281/// returned.
2282///
2283/// To use the transaction two pieces of information are needed: a list
2284/// of all the keys that need to be watched for modifications and a
2285/// closure with the code that should be execute in the context of the
2286/// transaction.  The closure is invoked with a fresh pipeline in atomic
2287/// mode.  To use the transaction the function needs to return the result
2288/// from querying the pipeline with the connection.
2289///
2290/// The end result of the transaction is then available as the return
2291/// value from the function call.
2292///
2293/// Example:
2294///
2295/// ```rust,no_run
2296/// use redis::Commands;
2297/// # fn do_something() -> redis::RedisResult<()> {
2298/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
2299/// # let mut con = client.get_connection().unwrap();
2300/// let key = "the_key";
2301/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
2302///     let old_val : isize = con.get(key)?;
2303///     pipe
2304///         .set(key, old_val + 1).ignore()
2305///         .get(key).query(con)
2306/// })?;
2307/// println!("The incremented number is: {}", new_val);
2308/// # Ok(()) }
2309/// ```
2310pub fn transaction<
2311    C: ConnectionLike,
2312    K: ToRedisArgs,
2313    T,
2314    F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
2315>(
2316    con: &mut C,
2317    keys: &[K],
2318    func: F,
2319) -> RedisResult<T> {
2320    let mut func = func;
2321    loop {
2322        cmd("WATCH").arg(keys).exec(con)?;
2323        let mut p = pipe();
2324        let response: Option<T> = func(con, p.atomic())?;
2325        match response {
2326            None => {
2327                continue;
2328            }
2329            Some(response) => {
2330                // make sure no watch is left in the connection, even if
2331                // someone forgot to use the pipeline.
2332                cmd("UNWATCH").exec(con)?;
2333                return Ok(response);
2334            }
2335        }
2336    }
2337}
// TODO: both subscription-clearing helpers below should also support sharded channels.
2339
/// Common logic for clearing subscriptions in RESP2 async/sync
///
/// `kind` is the reply's first element. A leading `u` marks an unsubscribe
/// acknowledgement and a leading `p` a punsubscribe acknowledgement; the
/// matching flag is raised. Returns `true` once both flags are raised and
/// `num` (the remaining subscription count) is zero.
pub fn resp2_is_pub_sub_state_cleared(
    received_unsub: &mut bool,
    received_punsub: &mut bool,
    kind: &[u8],
    num: isize,
) -> bool {
    if let Some(&leading) = kind.first() {
        if leading == b'u' {
            *received_unsub = true;
        } else if leading == b'p' {
            *received_punsub = true;
        }
    }
    num == 0 && *received_unsub && *received_punsub
}
2354
2355/// Common logic for clearing subscriptions in RESP3 async/sync
2356pub fn resp3_is_pub_sub_state_cleared(
2357    received_unsub: &mut bool,
2358    received_punsub: &mut bool,
2359    kind: &PushKind,
2360    num: isize,
2361) -> bool {
2362    match kind {
2363        PushKind::Unsubscribe => *received_unsub = true,
2364        PushKind::PUnsubscribe => *received_punsub = true,
2365        _ => (),
2366    };
2367    *received_unsub && *received_punsub && num == 0
2368}
2369
2370pub fn no_sub_err_is_pub_sub_state_cleared(
2371    received_unsub: &mut bool,
2372    received_punsub: &mut bool,
2373    err: &ServerError,
2374) -> bool {
2375    let details = err.details();
2376    *received_unsub = *received_unsub
2377        || details
2378            .map(|details| details.starts_with("'unsub"))
2379            .unwrap_or_default();
2380    *received_punsub = *received_punsub
2381        || details
2382            .map(|details| details.starts_with("'punsub"))
2383            .unwrap_or_default();
2384    *received_unsub && *received_punsub
2385}
2386
2387/// Common logic for checking real cause of hello3 command error
2388pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError {
2389    if let Some(detail) = err.detail() {
2390        if detail.starts_with("unknown command `HELLO`") {
2391            return (
2392                ErrorKind::RESP3NotSupported,
2393                "Redis Server doesn't support HELLO command therefore resp3 cannot be used",
2394            )
2395                .into();
2396        }
2397    }
2398    err
2399}
2400
2401#[cfg(test)]
2402mod tests {
2403    use super::*;
2404
2405    #[test]
2406    fn test_parse_redis_url() {
2407        let cases = vec![
2408            ("redis://127.0.0.1", true),
2409            ("redis://[::1]", true),
2410            ("rediss://127.0.0.1", true),
2411            ("rediss://[::1]", true),
2412            ("valkey://127.0.0.1", true),
2413            ("valkey://[::1]", true),
2414            ("valkeys://127.0.0.1", true),
2415            ("valkeys://[::1]", true),
2416            ("redis+unix:///run/redis.sock", true),
2417            ("valkey+unix:///run/valkey.sock", true),
2418            ("unix:///run/redis.sock", true),
2419            ("http://127.0.0.1", false),
2420            ("tcp://127.0.0.1", false),
2421        ];
2422        for (url, expected) in cases.into_iter() {
2423            let res = parse_redis_url(url);
2424            assert_eq!(
2425                res.is_some(),
2426                expected,
2427                "Parsed result of `{url}` is not expected",
2428            );
2429        }
2430    }
2431
2432    #[test]
2433    fn test_url_to_tcp_connection_info() {
2434        let cases = vec![
2435            (
2436                url::Url::parse("redis://127.0.0.1").unwrap(),
2437                ConnectionInfo {
2438                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
2439                    redis: Default::default(),
2440                    tcp_settings: TcpSettings::default(),
2441                },
2442            ),
2443            (
2444                url::Url::parse("redis://[::1]").unwrap(),
2445                ConnectionInfo {
2446                    addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
2447                    redis: Default::default(),
2448                    tcp_settings: TcpSettings::default(),
2449                },
2450            ),
2451            (
2452                url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
2453                ConnectionInfo {
2454                    addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
2455                    redis: RedisConnectionInfo {
2456                        db: 2,
2457                        username: Some("%johndoe%".into()),
2458                        password: Some("#@<>$".into()),
2459                        ..Default::default()
2460                    },
2461                    tcp_settings: TcpSettings::default(),
2462                },
2463            ),
2464            (
2465                url::Url::parse("redis://127.0.0.1/?protocol=2").unwrap(),
2466                ConnectionInfo {
2467                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
2468                    redis: Default::default(),
2469                    tcp_settings: TcpSettings::default(),
2470                },
2471            ),
2472            (
2473                url::Url::parse("redis://127.0.0.1/?protocol=resp3").unwrap(),
2474                ConnectionInfo {
2475                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
2476                    redis: RedisConnectionInfo {
2477                        protocol: ProtocolVersion::RESP3,
2478                        ..Default::default()
2479                    },
2480                    tcp_settings: TcpSettings::default(),
2481                },
2482            ),
2483        ];
2484        for (url, expected) in cases.into_iter() {
2485            let res = url_to_tcp_connection_info(url.clone()).unwrap();
2486            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
2487            assert_eq!(
2488                res.redis.db, expected.redis.db,
2489                "db of {url} is not expected",
2490            );
2491            assert_eq!(
2492                res.redis.username, expected.redis.username,
2493                "username of {url} is not expected",
2494            );
2495            assert_eq!(
2496                res.redis.password, expected.redis.password,
2497                "password of {url} is not expected",
2498            );
2499        }
2500    }
2501
2502    #[test]
2503    fn test_url_to_tcp_connection_info_failed() {
2504        let cases = vec![
2505            (
2506                url::Url::parse("redis://").unwrap(),
2507                "Missing hostname",
2508                None,
2509            ),
2510            (
2511                url::Url::parse("redis://127.0.0.1/db").unwrap(),
2512                "Invalid database number",
2513                None,
2514            ),
2515            (
2516                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
2517                "Username is not valid UTF-8 string",
2518                None,
2519            ),
2520            (
2521                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
2522                "Password is not valid UTF-8 string",
2523                None,
2524            ),
2525            (
2526                url::Url::parse("redis://127.0.0.1/?protocol=4").unwrap(),
2527                "Invalid protocol version",
2528                Some("4"),
2529            ),
2530        ];
2531        for (url, expected, detail) in cases.into_iter() {
2532            let res = url_to_tcp_connection_info(url).unwrap_err();
2533            assert_eq!(res.kind(), crate::ErrorKind::InvalidClientConfig,);
2534            let desc = res.to_string();
2535            assert!(desc.contains(expected), "{desc}");
2536            assert_eq!(res.detail(), detail);
2537        }
2538    }
2539
2540    #[test]
2541    #[cfg(unix)]
2542    fn test_url_to_unix_connection_info() {
2543        let cases = vec![
2544            (
2545                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
2546                ConnectionInfo {
2547                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
2548                    redis: RedisConnectionInfo {
2549                        db: 0,
2550                        username: None,
2551                        password: None,
2552                        protocol: ProtocolVersion::RESP2,
2553                        skip_set_lib_name: false,
2554                    },
2555                    tcp_settings: Default::default(),
2556                },
2557            ),
2558            (
2559                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
2560                ConnectionInfo {
2561                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
2562                    redis: RedisConnectionInfo {
2563                        db: 1,
2564                        ..Default::default()
2565                    },
2566                    tcp_settings: TcpSettings::default(),
2567                },
2568            ),
2569            (
2570                url::Url::parse(
2571                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
2572                )
2573                .unwrap(),
2574                ConnectionInfo {
2575                    addr: ConnectionAddr::Unix("/example.sock".into()),
2576                    redis: RedisConnectionInfo {
2577                        db: 2,
2578                        username: Some("%johndoe%".into()),
2579                        password: Some("#@<>$".into()),
2580                        ..Default::default()
2581                    },
2582                    tcp_settings: TcpSettings::default(),
2583                },
2584            ),
2585            (
2586                url::Url::parse(
2587                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
2588                )
2589                .unwrap(),
2590                ConnectionInfo {
2591                    addr: ConnectionAddr::Unix("/example.sock".into()),
2592                    redis: RedisConnectionInfo {
2593                        db: 2,
2594                        username: Some("%johndoe%".into()),
2595                        password: Some("&?= *+".into()),
2596                        ..Default::default()
2597                    },
2598                    tcp_settings: TcpSettings::default(),
2599                },
2600            ),
2601            (
2602                url::Url::parse("redis+unix:///var/run/redis.sock?protocol=3").unwrap(),
2603                ConnectionInfo {
2604                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
2605                    redis: RedisConnectionInfo {
2606                        protocol: ProtocolVersion::RESP3,
2607                        ..Default::default()
2608                    },
2609                    tcp_settings: TcpSettings::default(),
2610                },
2611            ),
2612        ];
2613        for (url, expected) in cases.into_iter() {
2614            assert_eq!(
2615                ConnectionAddr::Unix(url.to_file_path().unwrap()),
2616                expected.addr,
2617                "addr of {url} is not expected",
2618            );
2619            let res = url_to_unix_connection_info(url.clone()).unwrap();
2620            assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
2621            assert_eq!(
2622                res.redis.db, expected.redis.db,
2623                "db of {url} is not expected",
2624            );
2625            assert_eq!(
2626                res.redis.username, expected.redis.username,
2627                "username of {url} is not expected",
2628            );
2629            assert_eq!(
2630                res.redis.password, expected.redis.password,
2631                "password of {url} is not expected",
2632            );
2633        }
2634    }
2635}