mco_redis_rs/
connection.rs

1use std::fmt;
2use std::io::{self, Write};
3use std::net::{self, ToSocketAddrs};
4use std::path::PathBuf;
5use std::str::{from_utf8, FromStr};
6use std::time::Duration;
7
8use crate::cmd::{cmd, pipe, Cmd};
9use crate::parser::Parser;
10use crate::pipeline::Pipeline;
11use crate::types::{
12    from_redis_value, ErrorKind, FromRedisValue, RedisError, RedisResult, ToRedisArgs, Value,
13};
14
15#[cfg(unix)]
16#[cfg(not(feature = "mco"))]
17use std::os::unix::net::UnixStream;
18
19#[cfg(not(feature = "mco"))]
20use std::net::TcpStream;
21
22#[cfg(feature = "mco")]
23use mco::net::TcpStream;
24
25#[cfg(feature = "mco")]
26use mco::net::TcpStream as UnixStream;
27
28
29#[cfg(feature = "tls")]
30use native_tls::{TlsConnector, TlsStream};
31
32static DEFAULT_PORT: u16 = 6379;
33
34/// This function takes a redis URL string and parses it into a URL
35/// as used by rust-url.  This is necessary as the default parser does
36/// not understand how redis URLs function.
37pub fn parse_redis_url(input: &str) -> Option<url::Url> {
38    match url::Url::parse(input) {
39        Ok(result) => match result.scheme() {
40            "redis" | "rediss" | "redis+unix" | "unix" => Some(result),
41            _ => None,
42        },
43        Err(_) => None,
44    }
45}
46
/// Describes where a redis server can be reached.
///
/// Some address kinds are only usable on certain platforms or with certain
/// crate features enabled; see [`ConnectionAddr::is_supported`].
#[derive(Clone, Debug, PartialEq)]
pub enum ConnectionAddr {
    /// Plain TCP connection, given as `(host, port)`.
    Tcp(String, u16),
    /// TCP connection wrapped in TLS.
    TcpTls {
        /// Hostname
        host: String,
        /// Port
        port: u16,
        /// Disable hostname verification when connecting.
        ///
        /// # Warning
        ///
        /// Think very carefully before enabling this. Without hostname
        /// verification, any valid certificate for any site is trusted for
        /// use from any other, which opens the connection up to
        /// man-in-the-middle attacks.
        insecure: bool,
    },
    /// Path of a unix domain socket.
    Unix(PathBuf),
}

impl ConnectionAddr {
    /// Reports whether this kind of address can be used in the current build.
    ///
    /// Plain TCP is always available, TLS requires the `tls` feature, and
    /// unix sockets require a unix target.
    pub fn is_supported(&self) -> bool {
        match self {
            ConnectionAddr::Tcp(..) => true,
            ConnectionAddr::TcpTls { .. } => cfg!(feature = "tls"),
            ConnectionAddr::Unix(..) => cfg!(unix),
        }
    }
}

impl fmt::Display for ConnectionAddr {
    /// Writes `host:port` for both TCP flavours and the socket path for unix
    /// addresses.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self {
            ConnectionAddr::Tcp(host, port) | ConnectionAddr::TcpTls { host, port, .. } => {
                write!(f, "{}:{}", host, port)
            }
            ConnectionAddr::Unix(path) => write!(f, "{}", path.display()),
        }
    }
}
102
/// Holds the connection information that redis should use for connecting.
///
/// It combines the transport-level address with the redis-level settings
/// that are applied after the transport has been established.
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
    /// A connection address for where to connect to.
    pub addr: ConnectionAddr,

    /// Redis-specific settings (database number, username, password) used
    /// when setting up the connection.
    pub redis: RedisConnectionInfo,
}
112
/// Redis specific/connection independent information used to establish a connection to redis.
///
/// The derived `Default` yields database `0` with no username and no password.
#[derive(Clone, Debug, Default)]
pub struct RedisConnectionInfo {
    /// The database number to use.  This is usually `0`.
    pub db: i64,
    /// Optionally a username that should be used for connection.
    pub username: Option<String>,
    /// Optionally a password that should be used for connection.
    pub password: Option<String>,
}
123
124impl FromStr for ConnectionInfo {
125    type Err = RedisError;
126
127    fn from_str(s: &str) -> Result<Self, Self::Err> {
128        s.into_connection_info()
129    }
130}
131
/// Converts an object into a connection info struct.  This allows the
/// constructor of the client to accept connection information in a
/// range of different formats.
///
/// Implementations in this file cover `ConnectionInfo` itself, URL strings
/// (`&str` and `String`), `(host, port)` tuples and `url::Url`.
pub trait IntoConnectionInfo {
    /// Converts the object into a connection info object.
    ///
    /// Returns an error when the value cannot be interpreted as redis
    /// connection information.
    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
}
139
/// Identity conversion: an existing `ConnectionInfo` is returned unchanged.
impl IntoConnectionInfo for ConnectionInfo {
    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
        Ok(self)
    }
}
145
146impl<'a> IntoConnectionInfo for &'a str {
147    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
148        match parse_redis_url(self) {
149            Some(u) => u.into_connection_info(),
150            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
151        }
152    }
153}
154
155impl<T> IntoConnectionInfo for (T, u16)
156where
157    T: Into<String>,
158{
159    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
160        Ok(ConnectionInfo {
161            addr: ConnectionAddr::Tcp(self.0.into(), self.1),
162            redis: RedisConnectionInfo::default(),
163        })
164    }
165}
166
167impl IntoConnectionInfo for String {
168    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
169        match parse_redis_url(&self) {
170            Some(u) => u.into_connection_info(),
171            None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
172        }
173    }
174}
175
176fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
177    let host = match url.host() {
178        Some(host) => host.to_string(),
179        None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
180    };
181    let port = url.port().unwrap_or(DEFAULT_PORT);
182    let addr = if url.scheme() == "rediss" {
183        #[cfg(feature = "tls")]
184        {
185            match url.fragment() {
186                Some("insecure") => ConnectionAddr::TcpTls {
187                    host,
188                    port,
189                    insecure: true,
190                },
191                Some(_) => fail!((
192                    ErrorKind::InvalidClientConfig,
193                    "only #insecure is supported as URL fragment"
194                )),
195                _ => ConnectionAddr::TcpTls {
196                    host,
197                    port,
198                    insecure: false,
199                },
200            }
201        }
202
203        #[cfg(not(feature = "tls"))]
204        fail!((
205            ErrorKind::InvalidClientConfig,
206            "can't connect with TLS, the feature is not enabled"
207        ));
208    } else {
209        ConnectionAddr::Tcp(host, port)
210    };
211    Ok(ConnectionInfo {
212        addr,
213        redis: RedisConnectionInfo {
214            db: match url.path().trim_matches('/') {
215                "" => 0,
216                path => unwrap_or!(
217                    path.parse::<i64>().ok(),
218                    fail!((ErrorKind::InvalidClientConfig, "Invalid database number"))
219                ),
220            },
221            username: if url.username().is_empty() {
222                None
223            } else {
224                match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
225                    Ok(decoded) => Some(decoded.into_owned()),
226                    Err(_) => fail!((
227                        ErrorKind::InvalidClientConfig,
228                        "Username is not valid UTF-8 string"
229                    )),
230                }
231            },
232            password: match url.password() {
233                Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
234                    Ok(decoded) => Some(decoded.into_owned()),
235                    Err(_) => fail!((
236                        ErrorKind::InvalidClientConfig,
237                        "Password is not valid UTF-8 string"
238                    )),
239                },
240                None => None,
241            },
242        },
243    })
244}
245
246#[cfg(unix)]
247fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
248    let query: std::collections::HashMap<_, _> = url.query_pairs().collect();
249    Ok(ConnectionInfo {
250        addr: ConnectionAddr::Unix(unwrap_or!(
251            url.to_file_path().ok(),
252            fail!((ErrorKind::InvalidClientConfig, "Missing path"))
253        )),
254        redis: RedisConnectionInfo {
255            db: match query.get("db") {
256                Some(db) => unwrap_or!(
257                    db.parse::<i64>().ok(),
258                    fail!((ErrorKind::InvalidClientConfig, "Invalid database number"))
259                ),
260                None => 0,
261            },
262            username: query.get("user").map(|username| username.to_string()),
263            password: query.get("pass").map(|password| password.to_string()),
264        },
265    })
266}
267
/// Fallback for non-unix targets: unix socket URLs are always rejected with
/// `ErrorKind::InvalidClientConfig`.
#[cfg(not(unix))]
fn url_to_unix_connection_info(_: url::Url) -> RedisResult<ConnectionInfo> {
    fail!((
        ErrorKind::InvalidClientConfig,
        "Unix sockets are not available on this platform."
    ));
}
275
276impl IntoConnectionInfo for url::Url {
277    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
278        match self.scheme() {
279            "redis" | "rediss" => url_to_tcp_connection_info(self),
280            "unix" | "redis+unix" => url_to_unix_connection_info(self),
281            _ => fail!((
282                ErrorKind::InvalidClientConfig,
283                "URL provided is not a redis URL"
284            )),
285        }
286    }
287}
288
/// A plain TCP connection to the server.
struct TcpConnection {
    // The underlying stream; despite the name it is used for both reading
    // responses and writing commands.
    reader: TcpStream,
    // Set to `true` on connect; cleared when a write reports a dropped
    // connection or a read hits an unexpected EOF.
    open: bool,
}
293
/// A TLS-wrapped TCP connection to the server (only with the `tls` feature).
#[cfg(feature = "tls")]
struct TcpTlsConnection {
    // The TLS stream; used for both reading responses and writing commands.
    reader: TlsStream<TcpStream>,
    // Set to `true` on connect; cleared when a write reports a dropped
    // connection or a read hits an unexpected EOF.
    open: bool,
}
299
/// A unix domain socket connection to the server (unix targets only).
#[cfg(unix)]
struct UnixConnection {
    // The underlying socket, used for both reading and writing.
    // NOTE: with the `mco` feature `UnixStream` is an alias for
    // `mco::net::TcpStream` (see the imports at the top of the file).
    sock: UnixStream,
    // Set to `true` on connect; cleared when a write reports a dropped
    // connection or a read hits an unexpected EOF.
    open: bool,
}
305
/// The concrete transport behind a `Connection`.
///
/// Which variants exist depends on the target platform and enabled features.
enum ActualConnection {
    /// Plain TCP.
    Tcp(TcpConnection),
    /// TCP wrapped in TLS; requires the `tls` feature.
    #[cfg(feature = "tls")]
    TcpTls(TcpTlsConnection),
    /// Unix domain socket; requires a unix target.
    #[cfg(unix)]
    Unix(UnixConnection),
}
313
/// Represents a stateful redis connection.
///
/// The transport is not necessarily TCP: it may also be TLS or a unix
/// socket, depending on the address the connection was opened with.
pub struct Connection {
    // The underlying transport.
    con: ActualConnection,
    // Response parser used when reading replies from the transport.
    parser: Parser,
    // Database number recorded at connection setup; returned by `get_db`.
    db: i64,

    /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`.
    ///
    /// This flag is checked when attempting to send a command, and if it's raised, we attempt to
    /// exit the pubsub state before executing the new request.
    pubsub: bool,
}
326
/// Represents a pubsub connection.
///
/// It holds a mutable borrow of the `Connection` it was created from, so the
/// connection cannot be used directly while the `PubSub` is alive.
pub struct PubSub<'a> {
    // The borrowed connection that pubsub commands are sent over.
    con: &'a mut Connection,
}
331
/// Represents a pubsub message.
#[derive(Debug)]
pub struct Msg {
    // Raw message payload as received from the server.
    payload: Value,
    // Raw channel value as received from the server.
    channel: Value,
    // NOTE(review): presumably only set for messages delivered through a
    // pattern subscription — confirm against the code that builds `Msg`.
    pattern: Option<Value>,
}
339
340impl ActualConnection {
341    pub fn new(addr: &ConnectionAddr, timeout: Option<Duration>) -> RedisResult<ActualConnection> {
342        Ok(match *addr {
343            ConnectionAddr::Tcp(ref host, ref port) => {
344                let host: &str = &*host;
345                let tcp = match timeout {
346                    None => TcpStream::connect((host, *port))?,
347                    Some(timeout) => {
348                        let mut tcp = None;
349                        let mut last_error = None;
350                        for addr in (host, *port).to_socket_addrs()? {
351                            match TcpStream::connect_timeout(&addr, timeout) {
352                                Ok(l) => {
353                                    tcp = Some(l);
354                                    break;
355                                }
356                                Err(e) => {
357                                    last_error = Some(e);
358                                }
359                            };
360                        }
361                        match (tcp, last_error) {
362                            (Some(tcp), _) => tcp,
363                            (None, Some(e)) => {
364                                fail!(e);
365                            }
366                            (None, None) => {
367                                fail!((
368                                    ErrorKind::InvalidClientConfig,
369                                    "could not resolve to any addresses"
370                                ));
371                            }
372                        }
373                    }
374                };
375                ActualConnection::Tcp(TcpConnection {
376                    reader: tcp,
377                    open: true,
378                })
379            }
380            #[cfg(feature = "tls")]
381            ConnectionAddr::TcpTls {
382                ref host,
383                port,
384                insecure,
385            } => {
386                let tls_connector = if insecure {
387                    TlsConnector::builder()
388                        .danger_accept_invalid_certs(true)
389                        .danger_accept_invalid_hostnames(true)
390                        .use_sni(false)
391                        .build()?
392                } else {
393                    TlsConnector::new()?
394                };
395                let host: &str = &*host;
396                let tls = match timeout {
397                    None => {
398                        let tcp = TcpStream::connect((host, port))?;
399                        match tls_connector.connect(host, tcp) {
400                            Ok(res) => res,
401                            Err(e) => {
402                                fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string()));
403                            }
404                        }
405                    }
406                    Some(timeout) => {
407                        let mut tcp = None;
408                        let mut last_error = None;
409                        for addr in (host, port).to_socket_addrs()? {
410                            match TcpStream::connect_timeout(&addr, timeout) {
411                                Ok(l) => {
412                                    tcp = Some(l);
413                                    break;
414                                }
415                                Err(e) => {
416                                    last_error = Some(e);
417                                }
418                            };
419                        }
420                        match (tcp, last_error) {
421                            (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
422                            (None, Some(e)) => {
423                                fail!(e);
424                            }
425                            (None, None) => {
426                                fail!((
427                                    ErrorKind::InvalidClientConfig,
428                                    "could not resolve to any addresses"
429                                ));
430                            }
431                        }
432                    }
433                };
434                ActualConnection::TcpTls(TcpTlsConnection {
435                    reader: tls,
436                    open: true,
437                })
438            }
439            #[cfg(not(feature = "tls"))]
440            ConnectionAddr::TcpTls { .. } => {
441                fail!((
442                    ErrorKind::InvalidClientConfig,
443                    "Cannot connect to TCP with TLS without the tls feature"
444                ));
445            }
446            #[cfg(unix)]
447            ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
448                sock: UnixStream::connect(path.to_str().unwrap_or_default())?,
449                open: true,
450            }),
451            #[cfg(not(unix))]
452            ConnectionAddr::Unix(ref _path) => {
453                fail!((
454                    ErrorKind::InvalidClientConfig,
455                    "Cannot connect to unix sockets \
456                     on this platform"
457                ));
458            }
459        })
460    }
461
462    pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
463        match *self {
464            ActualConnection::Tcp(ref mut connection) => {
465                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
466                match res {
467                    Err(e) => {
468                        if e.is_connection_dropped() {
469                            connection.open = false;
470                        }
471                        Err(e)
472                    }
473                    Ok(_) => Ok(Value::Okay),
474                }
475            }
476            #[cfg(feature = "tls")]
477            ActualConnection::TcpTls(ref mut connection) => {
478                let res = connection.reader.write_all(bytes).map_err(RedisError::from);
479                match res {
480                    Err(e) => {
481                        if e.is_connection_dropped() {
482                            connection.open = false;
483                        }
484                        Err(e)
485                    }
486                    Ok(_) => Ok(Value::Okay),
487                }
488            }
489            #[cfg(unix)]
490            ActualConnection::Unix(ref mut connection) => {
491                let result = connection.sock.write_all(bytes).map_err(RedisError::from);
492                match result {
493                    Err(e) => {
494                        if e.is_connection_dropped() {
495                            connection.open = false;
496                        }
497                        Err(e)
498                    }
499                    Ok(_) => Ok(Value::Okay),
500                }
501            }
502        }
503    }
504
505    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
506        match *self {
507            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
508                reader.set_write_timeout(dur)?;
509            }
510            #[cfg(feature = "tls")]
511            ActualConnection::TcpTls(TcpTlsConnection { ref reader, .. }) => {
512                reader.get_ref().set_write_timeout(dur)?;
513            }
514            #[cfg(unix)]
515            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
516                sock.set_write_timeout(dur)?;
517            }
518        }
519        Ok(())
520    }
521
522    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
523        match *self {
524            ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
525                reader.set_read_timeout(dur)?;
526            }
527            #[cfg(feature = "tls")]
528            ActualConnection::TcpTls(TcpTlsConnection { ref reader, .. }) => {
529                reader.get_ref().set_read_timeout(dur)?;
530            }
531            #[cfg(unix)]
532            ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
533                sock.set_read_timeout(dur)?;
534            }
535        }
536        Ok(())
537    }
538
539    pub fn is_open(&self) -> bool {
540        match *self {
541            ActualConnection::Tcp(TcpConnection { open, .. }) => open,
542            #[cfg(feature = "tls")]
543            ActualConnection::TcpTls(TcpTlsConnection { open, .. }) => open,
544            #[cfg(unix)]
545            ActualConnection::Unix(UnixConnection { open, .. }) => open,
546        }
547    }
548}
549
550fn connect_auth(con: &mut Connection, connection_info: &RedisConnectionInfo) -> RedisResult<()> {
551    let mut command = cmd("AUTH");
552    if let Some(username) = &connection_info.username {
553        command.arg(username);
554    }
555    let password = connection_info.password.as_ref().unwrap();
556    let err = match command.arg(password).query::<Value>(con) {
557        Ok(Value::Okay) => return Ok(()),
558        Ok(_) => {
559            fail!((
560                ErrorKind::ResponseError,
561                "Redis server refused to authenticate, returns Ok() != Value::Okay"
562            ));
563        }
564        Err(e) => e,
565    };
566    let err_msg = err.detail().ok_or((
567        ErrorKind::AuthenticationFailed,
568        "Password authentication failed",
569    ))?;
570    if !err_msg.contains("wrong number of arguments for 'auth' command") {
571        fail!((
572            ErrorKind::AuthenticationFailed,
573            "Password authentication failed",
574        ));
575    }
576
577    // fallback to AUTH version <= 5
578    let mut command = cmd("AUTH");
579    match command.arg(password).query::<Value>(con) {
580        Ok(Value::Okay) => Ok(()),
581        _ => fail!((
582            ErrorKind::AuthenticationFailed,
583            "Password authentication failed",
584        )),
585    }
586}
587
588pub fn connect(
589    connection_info: &ConnectionInfo,
590    timeout: Option<Duration>,
591) -> RedisResult<Connection> {
592    let con = ActualConnection::new(&connection_info.addr, timeout)?;
593    setup_connection(con, &connection_info.redis)
594}
595
596fn setup_connection(
597    con: ActualConnection,
598    connection_info: &RedisConnectionInfo,
599) -> RedisResult<Connection> {
600    let mut rv = Connection {
601        con,
602        parser: Parser::new(),
603        db: connection_info.db,
604        pubsub: false,
605    };
606
607    if connection_info.password.is_some() {
608        connect_auth(&mut rv, connection_info)?;
609    }
610
611    if connection_info.db != 0 {
612        match cmd("SELECT")
613            .arg(connection_info.db)
614            .query::<Value>(&mut rv)
615        {
616            Ok(Value::Okay) => {}
617            _ => fail!((
618                ErrorKind::ResponseError,
619                "Redis server refused to switch database"
620            )),
621        }
622    }
623
624    Ok(rv)
625}
626
/// Implements the "stateless" part of the connection interface that is used by the
/// different objects in redis-rs.  Primarily it obviously applies to `Connection`
/// object but also some other objects implement the interface (for instance
/// whole clients or certain redis results).
///
/// Generally clients and connections (as well as redis results of those) implement
/// this trait.  Actual connections provide more functionality which can be used
/// to implement things like `PubSub` but they also can modify the intrinsic
/// state of the TCP connection.  This is not possible with `ConnectionLike`
/// implementors because that functionality is not exposed.
pub trait ConnectionLike {
    /// Sends an already encoded (packed) command into the TCP socket and
    /// reads the single response from it.
    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;

    /// Sends multiple already encoded (packed) commands into the TCP socket
    /// and reads `count` responses from it.  This is used to implement
    /// pipelining.
    ///
    /// `offset` is the number of leading responses to read and discard
    /// before the `count` responses that are returned.
    fn req_packed_commands(
        &mut self,
        cmd: &[u8],
        offset: usize,
        count: usize,
    ) -> RedisResult<Vec<Value>>;

    /// Sends a [Cmd](Cmd) into the TCP socket and reads a single response from it.
    ///
    /// The default implementation packs the command and forwards it to
    /// `req_packed_command`.
    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
        let pcmd = cmd.get_packed_command();
        self.req_packed_command(&pcmd)
    }

    /// Returns the database this connection is bound to.  Note that this
    /// information might be unreliable because it's initially cached and
    /// also might be incorrect if the connection like object is not
    /// actually connected.
    fn get_db(&self) -> i64;

    /// Does this connection support pipelining?
    ///
    /// Defaults to `true`.
    #[doc(hidden)]
    fn supports_pipelining(&self) -> bool {
        true
    }

    /// Check that all connections it has are available (`PING` internally).
    fn check_connection(&mut self) -> bool;

    /// Returns the connection status.
    ///
    /// The connection is open until any `read_response` call received an
    /// invalid response from the server (most likely a closed or dropped
    /// connection, otherwise a Redis protocol error). When using unix
    /// sockets the connection is open until writing a command failed with a
    /// `BrokenPipe` error.
    fn is_open(&self) -> bool;
}
682
683/// A connection is an object that represents a single redis connection.  It
684/// provides basic support for sending encoded commands into a redis connection
685/// and to read a response from it.  It's bound to a single database and can
686/// only be created from the client.
687///
/// You generally do not do much with this object other than passing it to
/// `Cmd` objects.
690impl Connection {
691    /// Sends an already encoded (packed) command into the TCP socket and
692    /// does not read a response.  This is useful for commands like
693    /// `MONITOR` which yield multiple items.  This needs to be used with
694    /// care because it changes the state of the connection.
695    pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
696        self.con.send_bytes(cmd)?;
697        Ok(())
698    }
699
    /// Fetches a single response from the connection.  This is useful
    /// if used in combination with `send_packed_command`.
    ///
    /// This is the public entry point to the private `read_response`.
    pub fn recv_response(&mut self) -> RedisResult<Value> {
        self.read_response()
    }
705
    /// Sets the write timeout for the connection.
    ///
    /// If the provided value is `None`, then `send_packed_command` call will
    /// block indefinitely. It is an error to pass the zero `Duration` to this
    /// method.
    ///
    /// Delegates to the underlying transport.
    pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
        self.con.set_write_timeout(dur)
    }
714
    /// Sets the read timeout for the connection.
    ///
    /// If the provided value is `None`, then `recv_response` call will
    /// block indefinitely. It is an error to pass the zero `Duration` to this
    /// method.
    ///
    /// Delegates to the underlying transport.
    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
        self.con.set_read_timeout(dur)
    }
723
    /// Creates a [`PubSub`] instance for this connection.
    ///
    /// The returned value mutably borrows the connection for its lifetime.
    pub fn as_pubsub(&mut self) -> PubSub<'_> {
        // NOTE: The pubsub flag is intentionally not raised at this time since
        // running commands within the pubsub state should not try and exit from
        // the pubsub state.
        PubSub::new(self)
    }
731
732    fn exit_pubsub(&mut self) -> RedisResult<()> {
733        let res = self.clear_active_subscriptions();
734        if res.is_ok() {
735            self.pubsub = false;
736        } else {
737            // Raise the pubsub flag to indicate the connection is "stuck" in that state.
738            self.pubsub = true;
739        }
740
741        res
742    }
743
    /// Cancels every active subscription and drains the confirmations.
    ///
    /// Sends `UNSUBSCRIBE` and `PUNSUBSCRIBE` back to back, then reads
    /// responses until a reply to each of the two commands has been seen and
    /// the reported number of remaining subscriptions is zero.
    ///
    /// Any error from sending, receiving or decoding a response is returned
    /// to the caller immediately.
    fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
        // Responses to unsubscribe commands return in a 3-tuple with values
        // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs).
        // The "count of remaining subs" includes both pattern subscriptions and non pattern
        // subscriptions. Thus, to accurately drain all unsubscribe messages received from the
        // server, both commands need to be executed at once.
        {
            // Prepare both unsubscribe commands
            let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
            let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();

            // Grab a reference to the underlying connection so that we may send
            // the commands without immediately blocking for a response.
            let con = &mut self.con;

            // Execute commands
            con.send_bytes(&unsubscribe)?;
            con.send_bytes(&punsubscribe)?;
        }

        // Receive responses
        //
        // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe
        // commands. There may be more responses if there are active subscriptions. In this case,
        // messages are received until the _subscription count_ in the responses reach zero.
        let mut received_unsub = false;
        let mut received_punsub = false;
        loop {
            // Decode as (kind, <ignored name>, remaining subscription count).
            let res: (Vec<u8>, (), isize) = from_redis_value(&self.recv_response()?)?;

            // The first byte of the kind is enough to tell the two replies
            // apart: 'u' for "unsubscribe", 'p' for "punsubscribe".
            match res.0.first() {
                Some(&b'u') => received_unsub = true,
                Some(&b'p') => received_punsub = true,
                _ => (),
            }

            if received_unsub && received_punsub && res.2 == 0 {
                break;
            }
        }

        // Finally, the connection is back in its normal state since all subscriptions were
        // cancelled *and* all unsubscribe messages were received.
        Ok(())
    }
793
794    /// Fetches a single response from the connection.
795    fn read_response(&mut self) -> RedisResult<Value> {
796        let result = match self.con {
797            ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
798                self.parser.parse_value(reader)
799            }
800            #[cfg(feature = "tls")]
801            ActualConnection::TcpTls(TcpTlsConnection { ref mut reader, .. }) => {
802                self.parser.parse_value(reader)
803            }
804            #[cfg(unix)]
805            ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
806                self.parser.parse_value(sock)
807            }
808        };
809        // shutdown connection on protocol error
810        if let Err(e) = &result {
811            let shutdown = match e.as_io_error() {
812                Some(e) => e.kind() == io::ErrorKind::UnexpectedEof,
813                None => false,
814            };
815            if shutdown {
816                match self.con {
817                    ActualConnection::Tcp(ref mut connection) => {
818                        let _ = connection.reader.shutdown(net::Shutdown::Both);
819                        connection.open = false;
820                    }
821                    #[cfg(feature = "tls")]
822                    ActualConnection::TcpTls(ref mut connection) => {
823                        let _ = connection.reader.shutdown();
824                        connection.open = false;
825                    }
826                    #[cfg(unix)]
827                    ActualConnection::Unix(ref mut connection) => {
828                        let _ = connection.sock.shutdown(net::Shutdown::Both);
829                        connection.open = false;
830                    }
831                }
832            }
833        }
834        result
835    }
836}
837
838impl ConnectionLike for Connection {
839    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
840        if self.pubsub {
841            self.exit_pubsub()?;
842        }
843
844        self.con.send_bytes(cmd)?;
845        self.read_response()
846    }
847
848    fn req_packed_commands(
849        &mut self,
850        cmd: &[u8],
851        offset: usize,
852        count: usize,
853    ) -> RedisResult<Vec<Value>> {
854        if self.pubsub {
855            self.exit_pubsub()?;
856        }
857        self.con.send_bytes(cmd)?;
858        let mut rv = vec![];
859        let mut first_err = None;
860        for idx in 0..(offset + count) {
861            // When processing a transaction, some responses may be errors.
862            // We need to keep processing the rest of the responses in that case,
863            // so bailing early with `?` would not be correct.
864            // See: https://github.com/mitsuhiko/redis-rs/issues/436
865            let response = self.read_response();
866            match response {
867                Ok(item) => {
868                    if idx >= offset {
869                        rv.push(item);
870                    }
871                }
872                Err(err) => {
873                    if first_err.is_none() {
874                        first_err = Some(err);
875                    }
876                }
877            }
878        }
879
880        if let Some(err) = first_err {
881            Err(err)
882        } else {
883            Ok(rv)
884        }
885    }
886
887    fn get_db(&self) -> i64 {
888        self.db
889    }
890
891    fn is_open(&self) -> bool {
892        self.con.is_open()
893    }
894
895    fn check_connection(&mut self) -> bool {
896        cmd("PING").query::<String>(self).is_ok()
897    }
898}
899
900/// The pubsub object provides convenient access to the redis pubsub
901/// system.  Once created you can subscribe and unsubscribe from channels
902/// and listen in on messages.
903///
904/// Example:
905///
906/// ```rust,no_run
907/// # fn do_something() -> redis::RedisResult<()> {
908/// let client = redis::Client::open("redis://127.0.0.1/")?;
909/// let mut con = client.get_connection()?;
910/// let mut pubsub = con.as_pubsub();
911/// pubsub.subscribe("channel_1")?;
912/// pubsub.subscribe("channel_2")?;
913///
914/// loop {
915///     let msg = pubsub.get_message()?;
916///     let payload : String = msg.get_payload()?;
917///     println!("channel '{}': {}", msg.get_channel_name(), payload);
918/// }
919/// # }
920/// ```
921impl<'a> PubSub<'a> {
922    fn new(con: &'a mut Connection) -> Self {
923        Self { con }
924    }
925
926    /// Subscribes to a new channel.
927    pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
928        cmd("SUBSCRIBE").arg(channel).query(self.con)
929    }
930
931    /// Subscribes to a new channel with a pattern.
932    pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
933        cmd("PSUBSCRIBE").arg(pchannel).query(self.con)
934    }
935
936    /// Unsubscribes from a channel.
937    pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
938        cmd("UNSUBSCRIBE").arg(channel).query(self.con)
939    }
940
941    /// Unsubscribes from a channel with a pattern.
942    pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
943        cmd("PUNSUBSCRIBE").arg(pchannel).query(self.con)
944    }
945
946    /// Fetches the next message from the pubsub connection.  Blocks until
947    /// a message becomes available.  This currently does not provide a
948    /// wait not to block :(
949    ///
950    /// The message itself is still generic and can be converted into an
951    /// appropriate type through the helper methods on it.
952    pub fn get_message(&mut self) -> RedisResult<Msg> {
953        loop {
954            if let Some(msg) = Msg::from_value(&self.con.recv_response()?) {
955                return Ok(msg);
956            } else {
957                continue;
958            }
959        }
960    }
961
962    /// Sets the read timeout for the connection.
963    ///
964    /// If the provided value is `None`, then `get_message` call will
965    /// block indefinitely. It is an error to pass the zero `Duration` to this
966    /// method.
967    pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
968        self.con.set_read_timeout(dur)
969    }
970}
971
972impl<'a> Drop for PubSub<'a> {
973    fn drop(&mut self) {
974        let _ = self.con.exit_pubsub();
975    }
976}
977
978/// This holds the data that comes from listening to a pubsub
979/// connection.  It only contains actual message data.
980impl Msg {
981    /// Tries to convert provided [`Value`] into [`Msg`].
982    pub fn from_value(value: &Value) -> Option<Self> {
983        let raw_msg: Vec<Value> = from_redis_value(value).ok()?;
984        let mut iter = raw_msg.into_iter();
985        let msg_type: String = from_redis_value(&iter.next()?).ok()?;
986        let mut pattern = None;
987        let payload;
988        let channel;
989
990        if msg_type == "message" {
991            channel = iter.next()?;
992            payload = iter.next()?;
993        } else if msg_type == "pmessage" {
994            pattern = Some(iter.next()?);
995            channel = iter.next()?;
996            payload = iter.next()?;
997        } else {
998            return None;
999        }
1000
1001        Some(Msg {
1002            payload,
1003            channel,
1004            pattern,
1005        })
1006    }
1007
1008    /// Returns the channel this message came on.
1009    pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
1010        from_redis_value(&self.channel)
1011    }
1012
1013    /// Convenience method to get a string version of the channel.  Unless
1014    /// your channel contains non utf-8 bytes you can always use this
1015    /// method.  If the channel is not a valid string (which really should
1016    /// not happen) then the return value is `"?"`.
1017    pub fn get_channel_name(&self) -> &str {
1018        match self.channel {
1019            Value::Data(ref bytes) => from_utf8(bytes).unwrap_or("?"),
1020            _ => "?",
1021        }
1022    }
1023
1024    /// Returns the message's payload in a specific format.
1025    pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
1026        from_redis_value(&self.payload)
1027    }
1028
1029    /// Returns the bytes that are the message's payload.  This can be used
1030    /// as an alternative to the `get_payload` function if you are interested
1031    /// in the raw bytes in it.
1032    pub fn get_payload_bytes(&self) -> &[u8] {
1033        match self.payload {
1034            Value::Data(ref bytes) => bytes,
1035            _ => b"",
1036        }
1037    }
1038
1039    /// Returns true if the message was constructed from a pattern
1040    /// subscription.
1041    #[allow(clippy::wrong_self_convention)]
1042    pub fn from_pattern(&self) -> bool {
1043        self.pattern.is_some()
1044    }
1045
1046    /// If the message was constructed from a message pattern this can be
1047    /// used to find out which one.  It's recommended to match against
1048    /// an `Option<String>` so that you do not need to use `from_pattern`
1049    /// to figure out if a pattern was set.
1050    pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
1051        match self.pattern {
1052            None => from_redis_value(&Value::Nil),
1053            Some(ref x) => from_redis_value(x),
1054        }
1055    }
1056}
1057
1058/// This function simplifies transaction management slightly.  What it
1059/// does is automatically watching keys and then going into a transaction
1060/// loop util it succeeds.  Once it goes through the results are
1061/// returned.
1062///
1063/// To use the transaction two pieces of information are needed: a list
1064/// of all the keys that need to be watched for modifications and a
1065/// closure with the code that should be execute in the context of the
1066/// transaction.  The closure is invoked with a fresh pipeline in atomic
1067/// mode.  To use the transaction the function needs to return the result
1068/// from querying the pipeline with the connection.
1069///
1070/// The end result of the transaction is then available as the return
1071/// value from the function call.
1072///
1073/// Example:
1074///
1075/// ```rust,no_run
1076/// use redis::Commands;
1077/// # fn do_something() -> redis::RedisResult<()> {
1078/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap();
1079/// # let mut con = client.get_connection().unwrap();
1080/// let key = "the_key";
1081/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| {
1082///     let old_val : isize = con.get(key)?;
1083///     pipe
1084///         .set(key, old_val + 1).ignore()
1085///         .get(key).query(con)
1086/// })?;
1087/// println!("The incremented number is: {}", new_val);
1088/// # Ok(()) }
1089/// ```
1090pub fn transaction<
1091    C: ConnectionLike,
1092    K: ToRedisArgs,
1093    T,
1094    F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
1095>(
1096    con: &mut C,
1097    keys: &[K],
1098    func: F,
1099) -> RedisResult<T> {
1100    let mut func = func;
1101    loop {
1102        cmd("WATCH").arg(keys).query::<()>(con)?;
1103        let mut p = pipe();
1104        let response: Option<T> = func(con, p.atomic())?;
1105        match response {
1106            None => {
1107                continue;
1108            }
1109            Some(response) => {
1110                // make sure no watch is left in the connection, even if
1111                // someone forgot to use the pipeline.
1112                cmd("UNWATCH").query::<()>(con)?;
1113                return Ok(response);
1114            }
1115        }
1116    }
1117}
1118
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the redis-specific URL schemes are accepted by `parse_redis_url`.
    #[test]
    fn test_parse_redis_url() {
        let cases = vec![
            ("redis://127.0.0.1", true),
            ("redis+unix:///run/redis.sock", true),
            ("unix:///run/redis.sock", true),
            ("http://127.0.0.1", false),
            ("tcp://127.0.0.1", false),
        ];
        for (url, expected) in cases {
            let parsed = parse_redis_url(url);
            assert_eq!(
                parsed.is_some(),
                expected,
                "Parsed result of `{}` is not expected",
                url,
            );
        }
    }

    /// TCP URLs yield the expected address, database and credentials.
    #[test]
    fn test_url_to_tcp_connection_info() {
        let plain = (
            url::Url::parse("redis://127.0.0.1").unwrap(),
            ConnectionInfo {
                addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                redis: Default::default(),
            },
        );
        let with_credentials = (
            url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
            ConnectionInfo {
                addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
                redis: RedisConnectionInfo {
                    db: 2,
                    username: Some("%johndoe%".to_string()),
                    password: Some("#@<>$".to_string()),
                },
            },
        );

        for (url, expected) in vec![plain, with_credentials] {
            let actual = url_to_tcp_connection_info(url.clone()).unwrap();
            assert_eq!(actual.addr, expected.addr, "addr of {} is not expected", url);
            assert_eq!(
                actual.redis.db, expected.redis.db,
                "db of {} is not expected",
                url
            );
            assert_eq!(
                actual.redis.username, expected.redis.username,
                "username of {} is not expected",
                url
            );
            assert_eq!(
                actual.redis.password, expected.redis.password,
                "password of {} is not expected",
                url
            );
        }
    }

    /// Malformed TCP URLs are rejected with `InvalidClientConfig` and the
    /// expected message.
    #[test]
    fn test_url_to_tcp_connection_info_failed() {
        let cases = vec![
            (url::Url::parse("redis://").unwrap(), "Missing hostname"),
            (
                url::Url::parse("redis://127.0.0.1/db").unwrap(),
                "Invalid database number",
            ),
            (
                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
                "Username is not valid UTF-8 string",
            ),
            (
                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
                "Password is not valid UTF-8 string",
            ),
        ];
        for (url, expected) in cases {
            let err = url_to_tcp_connection_info(url).unwrap_err();
            assert_eq!(
                err.kind(),
                crate::ErrorKind::InvalidClientConfig,
                "{}",
                err,
            );
            assert_eq!(err.to_string(), expected, "{}", err);
        }
    }

    /// Unix-socket URLs yield the expected path, database and credentials.
    #[test]
    #[cfg(unix)]
    fn test_url_to_unix_connection_info() {
        let cases = vec![
            (
                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 0,
                        username: None,
                        password: None,
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 1,
                        username: None,
                        password: None,
                    },
                },
            ),
            (
                url::Url::parse(
                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                    },
                },
            ),
            (
                url::Url::parse(
                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("&?= *+".to_string()),
                    },
                },
            ),
        ];
        for (url, expected) in cases {
            // The path taken straight from the URL must already match.
            let path_from_url = url.to_file_path().unwrap();
            assert_eq!(
                ConnectionAddr::Unix(path_from_url),
                expected.addr,
                "addr of {} is not expected",
                url
            );

            let actual = url_to_unix_connection_info(url.clone()).unwrap();
            assert_eq!(actual.addr, expected.addr, "addr of {} is not expected", url);
            assert_eq!(
                actual.redis.db, expected.redis.db,
                "db of {} is not expected",
                url
            );
            assert_eq!(
                actual.redis.username, expected.redis.username,
                "username of {} is not expected",
                url
            );
            assert_eq!(
                actual.redis.password, expected.redis.password,
                "password of {} is not expected",
                url
            );
        }
    }
}