1use std::collections::VecDeque;
2use std::fmt;
3use std::io::{self, Write};
4use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
5use std::ops::DerefMut;
6use std::path::PathBuf;
7use std::str::{from_utf8, FromStr};
8use std::time::Duration;
9
10use crate::cmd::{cmd, pipe, Cmd};
11use crate::parser::Parser;
12use crate::pipeline::Pipeline;
13use crate::types::{
14 from_redis_value, ErrorKind, FromRedisValue, RedisError, RedisResult, ToRedisArgs, Value,
15};
16
17#[cfg(unix)]
18use crate::types::HashMap;
19#[cfg(unix)]
20use std::os::unix::net::UnixStream;
21
22#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
23use native_tls::{TlsConnector, TlsStream};
24
25#[cfg(feature = "tls-rustls")]
26use rustls::{RootCertStore, StreamOwned};
27#[cfg(feature = "tls-rustls")]
28use std::{convert::TryInto, sync::Arc};
29
30#[cfg(feature = "tls-rustls-webpki-roots")]
31use rustls::OwnedTrustAnchor;
32#[cfg(feature = "tls-rustls-webpki-roots")]
33use webpki_roots::TLS_SERVER_ROOTS;
34
35#[cfg(all(
36 feature = "tls-rustls",
37 not(feature = "tls-native-tls"),
38 not(feature = "tls-rustls-webpki-roots")
39))]
40use rustls_native_certs::load_native_certs;
41
42#[cfg(feature = "tls-rustls")]
43use crate::tls::TlsConnParams;
44
/// Placeholder for TLS connection parameters when the `tls-rustls` feature is disabled.
///
/// This zero-sized type keeps `ConnectionAddr::TcpTls { tls_params, .. }` well-formed in
/// every feature configuration; it carries no data and is not comparable, which is why
/// `ConnectionAddr`'s equality ignores `tls_params`.
#[cfg(not(feature = "tls-rustls"))]
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct TlsConnParams;
50
/// Port used for `redis://` / `rediss://` URLs that do not specify one explicitly.
const DEFAULT_PORT: u16 = 6379;
52
/// Opens a blocking TCP connection to `addr` (resolving the host name if needed).
///
/// Socket options are chosen at compile time:
/// * `tcp_nodelay` feature: Nagle's algorithm is disabled on the new socket.
/// * `keep-alive` feature: TCP keep-alive is enabled with `socket2`'s default settings.
///
/// # Errors
///
/// Any I/O error from resolving, connecting, or applying the socket options.
#[inline(always)]
fn connect_tcp(addr: (&str, u16)) -> io::Result<TcpStream> {
    let stream = TcpStream::connect(addr)?;

    #[cfg(feature = "tcp_nodelay")]
    stream.set_nodelay(true)?;

    // Keep-alive is not exposed by `std`, so round-trip through `socket2` to set it.
    #[cfg(feature = "keep-alive")]
    let stream: TcpStream = {
        const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new();
        let raw: socket2::Socket = stream.into();
        raw.set_tcp_keepalive(&KEEP_ALIVE)?;
        raw.into()
    };

    Ok(stream)
}
72
/// Opens a TCP connection to one already-resolved `addr`, giving up after `timeout`.
///
/// Applies the same compile-time socket options as `connect_tcp` (`tcp_nodelay`,
/// `keep-alive` features).
///
/// # Errors
///
/// Any I/O error from connecting (including the timeout expiring) or from applying the
/// socket options.
#[inline(always)]
fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result<TcpStream> {
    let stream = TcpStream::connect_timeout(addr, timeout)?;

    #[cfg(feature = "tcp_nodelay")]
    stream.set_nodelay(true)?;

    // Keep-alive is not exposed by `std`, so round-trip through `socket2` to set it.
    #[cfg(feature = "keep-alive")]
    let stream: TcpStream = {
        const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new();
        let raw: socket2::Socket = stream.into();
        raw.set_tcp_keepalive(&KEEP_ALIVE)?;
        raw.into()
    };

    Ok(stream)
}
92
93pub fn parse_redis_url(input: &str) -> Option<url::Url> {
97 match url::Url::parse(input) {
98 Ok(result) => match result.scheme() {
99 "redis" | "rediss" | "redis+unix" | "unix" => Some(result),
100 _ => None,
101 },
102 Err(_) => None,
103 }
104}
105
/// Selects how strictly a TLS connection is verified.
///
/// NOTE(review): this type is not referenced in the visible part of this file; it
/// presumably corresponds to the `insecure` flag of `ConnectionAddr::TcpTls` — confirm
/// at the call sites.
#[derive(Clone, Copy)]
pub enum TlsMode {
    /// Verify the server normally (presumably `insecure: false`).
    Secure,
    /// Skip server verification (presumably `insecure: true`); only for trusted setups.
    Insecure,
}
115
/// Where, and over which transport, to connect to a Redis server.
///
/// Not every variant can be used in every build; see `ConnectionAddr::is_supported`.
#[derive(Clone, Debug)]
pub enum ConnectionAddr {
    /// Plain TCP connection to `(host, port)`.
    Tcp(String, u16),
    /// TCP connection wrapped in TLS.
    TcpTls {
        /// Server host name; also used as the TLS server name during the handshake.
        host: String,
        /// Server port.
        port: u16,
        /// When `true`, the server's certificate (and host name) is not verified.
        /// This is insecure and should only be used against trusted endpoints.
        insecure: bool,

        /// Optional extra TLS configuration (custom root store and/or client
        /// certificate). Not considered when comparing addresses for equality.
        tls_params: Option<TlsConnParams>,
    },
    /// Unix domain socket at the given filesystem path.
    Unix(PathBuf),
}
147
148impl PartialEq for ConnectionAddr {
149 fn eq(&self, other: &Self) -> bool {
150 match (self, other) {
151 (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
152 host1 == host2 && port1 == port2
153 }
154 (
155 ConnectionAddr::TcpTls {
156 host: host1,
157 port: port1,
158 insecure: insecure1,
159 tls_params: _,
160 },
161 ConnectionAddr::TcpTls {
162 host: host2,
163 port: port2,
164 insecure: insecure2,
165 tls_params: _,
166 },
167 ) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
168 (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
169 _ => false,
170 }
171 }
172}
173
// The `PartialEq` impl only compares `String`/`u16`/`bool`/`PathBuf` fields, so it is a
// full equivalence relation and the `Eq` marker is sound.
impl Eq for ConnectionAddr {}
175
176impl ConnectionAddr {
177 pub fn is_supported(&self) -> bool {
185 match *self {
186 ConnectionAddr::Tcp(_, _) => true,
187 ConnectionAddr::TcpTls { .. } => {
188 cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
189 }
190 ConnectionAddr::Unix(_) => cfg!(unix),
191 }
192 }
193}
194
195impl fmt::Display for ConnectionAddr {
196 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
197 match *self {
199 ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
200 ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
201 ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
202 }
203 }
204}
205
/// Everything needed to open and set up a connection: where to connect, plus the
/// Redis-level settings applied once the transport is up.
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
    /// Network (or Unix socket) address of the server.
    pub addr: ConnectionAddr,

    /// Database, username and password used during connection setup.
    pub redis: RedisConnectionInfo,
}
215
/// Redis-level settings applied right after the transport is connected.
///
/// The `Default` value selects database 0 and uses no credentials.
#[derive(Clone, Debug, Default)]
pub struct RedisConnectionInfo {
    /// Database index; a `SELECT` is sent during setup when it is non-zero.
    pub db: i64,
    /// Optional username sent with `AUTH`; only used when `password` is also set.
    pub username: Option<String>,
    /// Optional password; when set, `AUTH` is sent during setup.
    pub password: Option<String>,
}
226
227impl FromStr for ConnectionInfo {
228 type Err = RedisError;
229
230 fn from_str(s: &str) -> Result<Self, Self::Err> {
231 s.into_connection_info()
232 }
233}
234
/// Conversion into a `ConnectionInfo`.
///
/// Implemented in this module for `ConnectionInfo` itself, Redis URL strings (`&str`,
/// `String`, `url::Url`) and `(host, port)` tuples.
pub trait IntoConnectionInfo {
    /// Performs the conversion.
    ///
    /// # Errors
    ///
    /// The implementations here report `InvalidClientConfig` when the input cannot be
    /// interpreted as a Redis connection target.
    fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
}
242
/// Identity conversion: an existing `ConnectionInfo` is returned unchanged.
impl IntoConnectionInfo for ConnectionInfo {
    fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
        Ok(self)
    }
}
248
249impl IntoConnectionInfo for &str {
258 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
259 match parse_redis_url(self) {
260 Some(u) => u.into_connection_info(),
261 None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
262 }
263 }
264}
265
266impl<T> IntoConnectionInfo for (T, u16)
267where
268 T: Into<String>,
269{
270 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
271 Ok(ConnectionInfo {
272 addr: ConnectionAddr::Tcp(self.0.into(), self.1),
273 redis: RedisConnectionInfo::default(),
274 })
275 }
276}
277
278impl IntoConnectionInfo for String {
287 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
288 match parse_redis_url(&self) {
289 Some(u) => u.into_connection_info(),
290 None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
291 }
292 }
293}
294
295fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
296 let host = match url.host() {
297 Some(host) => {
298 match host {
310 url::Host::Domain(path) => path.to_string(),
311 url::Host::Ipv4(v4) => v4.to_string(),
312 url::Host::Ipv6(v6) => v6.to_string(),
313 }
314 }
315 None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
316 };
317 let port = url.port().unwrap_or(DEFAULT_PORT);
318 let addr = if url.scheme() == "rediss" {
319 #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
320 {
321 match url.fragment() {
322 Some("insecure") => ConnectionAddr::TcpTls {
323 host,
324 port,
325 insecure: true,
326 tls_params: None,
327 },
328 Some(_) => fail!((
329 ErrorKind::InvalidClientConfig,
330 "only #insecure is supported as URL fragment"
331 )),
332 _ => ConnectionAddr::TcpTls {
333 host,
334 port,
335 insecure: false,
336 tls_params: None,
337 },
338 }
339 }
340
341 #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
342 fail!((
343 ErrorKind::InvalidClientConfig,
344 "can't connect with TLS, the feature is not enabled"
345 ));
346 } else {
347 ConnectionAddr::Tcp(host, port)
348 };
349 Ok(ConnectionInfo {
350 addr,
351 redis: RedisConnectionInfo {
352 db: match url.path().trim_matches('/') {
353 "" => 0,
354 path => unwrap_or!(
355 path.parse::<i64>().ok(),
356 fail!((ErrorKind::InvalidClientConfig, "Invalid database number"))
357 ),
358 },
359 username: if url.username().is_empty() {
360 None
361 } else {
362 match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
363 Ok(decoded) => Some(decoded.into_owned()),
364 Err(_) => fail!((
365 ErrorKind::InvalidClientConfig,
366 "Username is not valid UTF-8 string"
367 )),
368 }
369 },
370 password: match url.password() {
371 Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
372 Ok(decoded) => Some(decoded.into_owned()),
373 Err(_) => fail!((
374 ErrorKind::InvalidClientConfig,
375 "Password is not valid UTF-8 string"
376 )),
377 },
378 None => None,
379 },
380 },
381 })
382}
383
384#[cfg(unix)]
385fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
386 let query: HashMap<_, _> = url.query_pairs().collect();
387 Ok(ConnectionInfo {
388 addr: ConnectionAddr::Unix(unwrap_or!(
389 url.to_file_path().ok(),
390 fail!((ErrorKind::InvalidClientConfig, "Missing path"))
391 )),
392 redis: RedisConnectionInfo {
393 db: match query.get("db") {
394 Some(db) => unwrap_or!(
395 db.parse::<i64>().ok(),
396 fail!((ErrorKind::InvalidClientConfig, "Invalid database number"))
397 ),
398 None => 0,
399 },
400 username: query.get("user").map(|username| username.to_string()),
401 password: query.get("pass").map(|password| password.to_string()),
402 },
403 })
404}
405
/// Fallback for non-Unix targets: Unix-socket URLs are always rejected with
/// `InvalidClientConfig`.
#[cfg(not(unix))]
fn url_to_unix_connection_info(_: url::Url) -> RedisResult<ConnectionInfo> {
    fail!((
        ErrorKind::InvalidClientConfig,
        "Unix sockets are not available on this platform."
    ));
}
413
414impl IntoConnectionInfo for url::Url {
415 fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
416 match self.scheme() {
417 "redis" | "rediss" => url_to_tcp_connection_info(self),
418 "unix" | "redis+unix" => url_to_unix_connection_info(self),
419 _ => fail!((
420 ErrorKind::InvalidClientConfig,
421 "URL provided is not a redis URL"
422 )),
423 }
424 }
425}
426
/// Plain TCP transport.
struct TcpConnection {
    /// The socket; despite the name it is used for both reading and writing.
    reader: TcpStream,
    /// Cleared once the connection is detected as dropped or shut down.
    open: bool,
}

/// TLS-over-TCP transport backed by `native-tls` (only when rustls is not selected).
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
struct TcpNativeTlsConnection {
    /// The TLS stream; used for both reading and writing.
    reader: TlsStream<TcpStream>,
    /// Cleared once the connection is detected as dropped or shut down.
    open: bool,
}

/// TLS-over-TCP transport backed by `rustls`.
#[cfg(feature = "tls-rustls")]
struct TcpRustlsConnection {
    /// The TLS stream; used for both reading and writing.
    reader: StreamOwned<rustls::ClientConnection, TcpStream>,
    /// Cleared once the connection is detected as dropped or shut down.
    open: bool,
}

/// Unix domain socket transport.
#[cfg(unix)]
struct UnixConnection {
    /// The socket; used for both reading and writing.
    sock: UnixStream,
    /// Cleared once the connection is detected as dropped or shut down.
    open: bool,
}

/// The transport underneath a `Connection`; which variants exist depends on enabled
/// features and the target platform.
enum ActualConnection {
    Tcp(TcpConnection),
    // NOTE(review): the TLS variants are boxed, presumably to keep the enum small.
    #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
    TcpNativeTls(Box<TcpNativeTlsConnection>),
    #[cfg(feature = "tls-rustls")]
    TcpRustls(Box<TcpRustlsConnection>),
    #[cfg(unix)]
    Unix(UnixConnection),
}
459
/// rustls certificate verifier that accepts every server certificate without checking it.
///
/// Installed by `create_rustls_config` only for `insecure` connections when the
/// `tls-rustls-insecure` feature is enabled. It removes all protection against
/// man-in-the-middle attacks.
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification;

#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::ServerCertVerifier for NoCertificateVerification {
    // Unconditionally reports the certificate as verified; every input is ignored.
    fn verify_server_cert(
        &self,
        _end_entity: &rustls::Certificate,
        _intermediates: &[rustls::Certificate],
        _server_name: &rustls::ServerName,
        _scts: &mut dyn Iterator<Item = &[u8]>,
        _ocsp: &[u8],
        _now: std::time::SystemTime,
    ) -> Result<rustls::client::ServerCertVerified, rustls::Error> {
        Ok(rustls::client::ServerCertVerified::assertion())
    }
}
477
/// A synchronous connection to a Redis server.
pub struct Connection {
    /// The underlying transport.
    con: ActualConnection,
    /// Incremental reply parser used by `read_response`.
    parser: Parser,
    /// Database index this connection was set up with (reported by `get_db`).
    db: i64,

    /// Set when leaving pub/sub mode failed. While set, the next regular command first
    /// retries unsubscribing from everything.
    pubsub: bool,
}

/// Pub/sub view of a `Connection`, obtained from `Connection::as_pubsub`.
///
/// Dropping it unsubscribes from all channels and patterns.
pub struct PubSub<'a> {
    con: &'a mut Connection,
    /// Messages received while waiting for a (p)subscribe/(p)unsubscribe reply; handed
    /// out first by `get_message`.
    waiting_messages: VecDeque<Msg>,
}

/// A message received on a subscribed channel or pattern.
#[derive(Debug)]
pub struct Msg {
    payload: Value,
    channel: Value,
    /// `Some` only for messages delivered through a pattern subscription (`pmessage`).
    pattern: Option<Value>,
}
504
505impl ActualConnection {
506 pub fn new(addr: &ConnectionAddr, timeout: Option<Duration>) -> RedisResult<ActualConnection> {
507 Ok(match *addr {
508 ConnectionAddr::Tcp(ref host, ref port) => {
509 let addr = (host.as_str(), *port);
510 let tcp = match timeout {
511 None => connect_tcp(addr)?,
512 Some(timeout) => {
513 let mut tcp = None;
514 let mut last_error = None;
515 for addr in addr.to_socket_addrs()? {
516 match connect_tcp_timeout(&addr, timeout) {
517 Ok(l) => {
518 tcp = Some(l);
519 break;
520 }
521 Err(e) => {
522 last_error = Some(e);
523 }
524 };
525 }
526 match (tcp, last_error) {
527 (Some(tcp), _) => tcp,
528 (None, Some(e)) => {
529 fail!(e);
530 }
531 (None, None) => {
532 fail!((
533 ErrorKind::InvalidClientConfig,
534 "could not resolve to any addresses"
535 ));
536 }
537 }
538 }
539 };
540 ActualConnection::Tcp(TcpConnection {
541 reader: tcp,
542 open: true,
543 })
544 }
545 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
546 ConnectionAddr::TcpTls {
547 ref host,
548 port,
549 insecure,
550 ..
551 } => {
552 let tls_connector = if insecure {
553 TlsConnector::builder()
554 .danger_accept_invalid_certs(true)
555 .danger_accept_invalid_hostnames(true)
556 .use_sni(false)
557 .build()?
558 } else {
559 TlsConnector::new()?
560 };
561 let addr = (host.as_str(), port);
562 let tls = match timeout {
563 None => {
564 let tcp = connect_tcp(addr)?;
565 match tls_connector.connect(host, tcp) {
566 Ok(res) => res,
567 Err(e) => {
568 fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string()));
569 }
570 }
571 }
572 Some(timeout) => {
573 let mut tcp = None;
574 let mut last_error = None;
575 for addr in (host.as_str(), port).to_socket_addrs()? {
576 match connect_tcp_timeout(&addr, timeout) {
577 Ok(l) => {
578 tcp = Some(l);
579 break;
580 }
581 Err(e) => {
582 last_error = Some(e);
583 }
584 };
585 }
586 match (tcp, last_error) {
587 (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
588 (None, Some(e)) => {
589 fail!(e);
590 }
591 (None, None) => {
592 fail!((
593 ErrorKind::InvalidClientConfig,
594 "could not resolve to any addresses"
595 ));
596 }
597 }
598 }
599 };
600 ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
601 reader: tls,
602 open: true,
603 }))
604 }
605 #[cfg(feature = "tls-rustls")]
606 ConnectionAddr::TcpTls {
607 ref host,
608 port,
609 insecure,
610 ref tls_params,
611 } => {
612 let host: &str = host;
613 let config = create_rustls_config(insecure, tls_params.clone())?;
614 let conn = rustls::ClientConnection::new(Arc::new(config), host.try_into()?)?;
615 let reader = match timeout {
616 None => {
617 let tcp = connect_tcp((host, port))?;
618 StreamOwned::new(conn, tcp)
619 }
620 Some(timeout) => {
621 let mut tcp = None;
622 let mut last_error = None;
623 for addr in (host, port).to_socket_addrs()? {
624 match connect_tcp_timeout(&addr, timeout) {
625 Ok(l) => {
626 tcp = Some(l);
627 break;
628 }
629 Err(e) => {
630 last_error = Some(e);
631 }
632 };
633 }
634 match (tcp, last_error) {
635 (Some(tcp), _) => StreamOwned::new(conn, tcp),
636 (None, Some(e)) => {
637 fail!(e);
638 }
639 (None, None) => {
640 fail!((
641 ErrorKind::InvalidClientConfig,
642 "could not resolve to any addresses"
643 ));
644 }
645 }
646 }
647 };
648
649 ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
650 }
651 #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
652 ConnectionAddr::TcpTls { .. } => {
653 fail!((
654 ErrorKind::InvalidClientConfig,
655 "Cannot connect to TCP with TLS without the tls feature"
656 ));
657 }
658 #[cfg(unix)]
659 ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
660 sock: UnixStream::connect(path)?,
661 open: true,
662 }),
663 #[cfg(not(unix))]
664 ConnectionAddr::Unix(ref _path) => {
665 fail!((
666 ErrorKind::InvalidClientConfig,
667 "Cannot connect to unix sockets \
668 on this platform"
669 ));
670 }
671 })
672 }
673
674 pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
675 match *self {
676 ActualConnection::Tcp(ref mut connection) => {
677 let res = connection.reader.write_all(bytes).map_err(RedisError::from);
678 match res {
679 Err(e) => {
680 if e.is_connection_dropped() {
681 connection.open = false;
682 }
683 Err(e)
684 }
685 Ok(_) => Ok(Value::Okay),
686 }
687 }
688 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
689 ActualConnection::TcpNativeTls(ref mut connection) => {
690 let res = connection.reader.write_all(bytes).map_err(RedisError::from);
691 match res {
692 Err(e) => {
693 if e.is_connection_dropped() {
694 connection.open = false;
695 }
696 Err(e)
697 }
698 Ok(_) => Ok(Value::Okay),
699 }
700 }
701 #[cfg(feature = "tls-rustls")]
702 ActualConnection::TcpRustls(ref mut connection) => {
703 let res = connection.reader.write_all(bytes).map_err(RedisError::from);
704 match res {
705 Err(e) => {
706 if e.is_connection_dropped() {
707 connection.open = false;
708 }
709 Err(e)
710 }
711 Ok(_) => Ok(Value::Okay),
712 }
713 }
714 #[cfg(unix)]
715 ActualConnection::Unix(ref mut connection) => {
716 let result = connection.sock.write_all(bytes).map_err(RedisError::from);
717 match result {
718 Err(e) => {
719 if e.is_connection_dropped() {
720 connection.open = false;
721 }
722 Err(e)
723 }
724 Ok(_) => Ok(Value::Okay),
725 }
726 }
727 }
728 }
729
730 pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
731 match *self {
732 ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
733 reader.set_write_timeout(dur)?;
734 }
735 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
736 ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
737 let reader = &(boxed_tls_connection.reader);
738 reader.get_ref().set_write_timeout(dur)?;
739 }
740 #[cfg(feature = "tls-rustls")]
741 ActualConnection::TcpRustls(ref boxed_tls_connection) => {
742 let reader = &(boxed_tls_connection.reader);
743 reader.get_ref().set_write_timeout(dur)?;
744 }
745 #[cfg(unix)]
746 ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
747 sock.set_write_timeout(dur)?;
748 }
749 }
750 Ok(())
751 }
752
753 pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
754 match *self {
755 ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
756 reader.set_read_timeout(dur)?;
757 }
758 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
759 ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
760 let reader = &(boxed_tls_connection.reader);
761 reader.get_ref().set_read_timeout(dur)?;
762 }
763 #[cfg(feature = "tls-rustls")]
764 ActualConnection::TcpRustls(ref boxed_tls_connection) => {
765 let reader = &(boxed_tls_connection.reader);
766 reader.get_ref().set_read_timeout(dur)?;
767 }
768 #[cfg(unix)]
769 ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
770 sock.set_read_timeout(dur)?;
771 }
772 }
773 Ok(())
774 }
775
776 pub fn is_open(&self) -> bool {
777 match *self {
778 ActualConnection::Tcp(TcpConnection { open, .. }) => open,
779 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
780 ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
781 #[cfg(feature = "tls-rustls")]
782 ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
783 #[cfg(unix)]
784 ActualConnection::Unix(UnixConnection { open, .. }) => open,
785 }
786 }
787}
788
/// Builds the rustls `ClientConfig` used for TLS connections.
///
/// Trust roots come from `tls_params.root_cert_store` when one is provided. Otherwise
/// they come from the bundled webpki roots (`tls-rustls-webpki-roots` feature) or, by
/// default, from the platform's native certificate store. A client certificate is
/// configured when `tls_params.client_tls_params` is set.
///
/// When `insecure` is true, SNI is disabled and server certificates are not verified;
/// this requires the `tls-rustls-insecure` feature.
///
/// # Errors
///
/// Fails if native certificates cannot be loaded or added, if the protocol versions
/// are rejected, if the client certificate/key pair is invalid, or if `insecure` is
/// requested without the `tls-rustls-insecure` feature.
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
    insecure: bool,
    tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
    use crate::tls::ClientTlsParams;

    let mut root_store = RootCertStore::empty();
    #[cfg(feature = "tls-rustls-webpki-roots")]
    root_store.add_trust_anchors(TLS_SERVER_ROOTS.0.iter().map(|ta| {
        OwnedTrustAnchor::from_subject_spki_name_constraints(
            ta.subject,
            ta.spki,
            ta.name_constraints,
        )
    }));
    #[cfg(not(feature = "tls-rustls-webpki-roots"))]
    {
        // The `load_native_certs` import at the top of the file is compiled out when
        // `tls-native-tls` is enabled as well. Name the function by path in that case,
        // otherwise that feature combination fails to build.
        #[cfg(not(feature = "tls-native-tls"))]
        let native_certs = load_native_certs()?;
        #[cfg(feature = "tls-native-tls")]
        let native_certs = rustls_native_certs::load_native_certs()?;
        for cert in native_certs {
            root_store.add(&rustls::Certificate(cert.0))?;
        }
    }

    let config = rustls::ClientConfig::builder()
        .with_safe_default_cipher_suites()
        .with_safe_default_kx_groups()
        .with_protocol_versions(rustls::ALL_VERSIONS)?;

    let config = if let Some(tls_params) = tls_params {
        // A caller-supplied root store replaces the default one entirely.
        let config_builder =
            config.with_root_certificates(tls_params.root_cert_store.unwrap_or(root_store));

        if let Some(ClientTlsParams {
            client_cert_chain: client_cert,
            client_key,
        }) = tls_params.client_tls_params
        {
            config_builder
                .with_client_auth_cert(client_cert, client_key)
                .map_err(|err| {
                    RedisError::from((
                        ErrorKind::InvalidClientConfig,
                        "Unable to build client with TLS parameters provided.",
                        err.to_string(),
                    ))
                })?
        } else {
            config_builder.with_no_client_auth()
        }
    } else {
        config
            .with_root_certificates(root_store)
            .with_no_client_auth()
    };

    match (insecure, cfg!(feature = "tls-rustls-insecure")) {
        #[cfg(feature = "tls-rustls-insecure")]
        (true, true) => {
            let mut config = config;
            config.enable_sni = false;
            config
                .dangerous()
                .set_certificate_verifier(Arc::new(NoCertificateVerification));

            Ok(config)
        }
        (true, false) => {
            fail!((
                ErrorKind::InvalidClientConfig,
                "Cannot create insecure client without tls-rustls-insecure feature"
            ));
        }
        _ => Ok(config),
    }
}
862
863fn connect_auth(con: &mut Connection, connection_info: &RedisConnectionInfo) -> RedisResult<()> {
864 let mut command = cmd("AUTH");
865 if let Some(username) = &connection_info.username {
866 command.arg(username);
867 }
868 let password = connection_info.password.as_ref().unwrap();
869 let err = match command.arg(password).query::<Value>(con) {
870 Ok(Value::Okay) => return Ok(()),
871 Ok(_) => {
872 fail!((
873 ErrorKind::ResponseError,
874 "Redis server refused to authenticate, returns Ok() != Value::Okay"
875 ));
876 }
877 Err(e) => e,
878 };
879 let err_msg = err.detail().ok_or((
880 ErrorKind::AuthenticationFailed,
881 "Password authentication failed",
882 ))?;
883 if !err_msg.contains("wrong number of arguments for 'auth' command") {
884 fail!((
885 ErrorKind::AuthenticationFailed,
886 "Password authentication failed",
887 ));
888 }
889
890 let mut command = cmd("AUTH");
892 match command.arg(password).query::<Value>(con) {
893 Ok(Value::Okay) => Ok(()),
894 _ => fail!((
895 ErrorKind::AuthenticationFailed,
896 "Password authentication failed",
897 )),
898 }
899}
900
901pub fn connect(
902 connection_info: &ConnectionInfo,
903 timeout: Option<Duration>,
904) -> RedisResult<Connection> {
905 let con = ActualConnection::new(&connection_info.addr, timeout)?;
906 setup_connection(con, &connection_info.redis)
907}
908
909pub(crate) fn client_set_info_pipeline() -> Pipeline {
910 let mut pipeline = crate::pipe();
911 pipeline
912 .cmd("CLIENT")
913 .arg("SETINFO")
914 .arg("LIB-NAME")
915 .arg("redis-rs")
916 .ignore();
917 pipeline
918 .cmd("CLIENT")
919 .arg("SETINFO")
920 .arg("LIB-VER")
921 .arg(env!("CARGO_PKG_VERSION"))
922 .ignore();
923 pipeline
924}
925
926fn setup_connection(
927 con: ActualConnection,
928 connection_info: &RedisConnectionInfo,
929) -> RedisResult<Connection> {
930 let mut rv = Connection {
931 con,
932 parser: Parser::new(),
933 db: connection_info.db,
934 pubsub: false,
935 };
936
937 if connection_info.password.is_some() {
938 connect_auth(&mut rv, connection_info)?;
939 }
940
941 if connection_info.db != 0 {
942 match cmd("SELECT")
943 .arg(connection_info.db)
944 .query::<Value>(&mut rv)
945 {
946 Ok(Value::Okay) => {}
947 _ => fail!((
948 ErrorKind::ResponseError,
949 "Redis server refused to switch database"
950 )),
951 }
952 }
953
954 let _: RedisResult<()> = client_set_info_pipeline().query(&mut rv);
957
958 Ok(rv)
959}
960
/// Anything that can execute Redis commands synchronously.
///
/// Implemented by `Connection` and, through a blanket impl, by any `DerefMut` wrapper
/// around an implementor. `req_command` and `supports_pipelining` have default
/// implementations.
pub trait ConnectionLike {
    /// Sends an already-packed command and returns the server's reply to it.
    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;

    /// Sends an already-packed batch of commands and reads `offset + count` replies,
    /// returning only the last `count` of them (the first `offset` are read and
    /// discarded). `Connection` keeps reading after an error reply and returns the
    /// first error encountered.
    fn req_packed_commands(
        &mut self,
        cmd: &[u8],
        offset: usize,
        count: usize,
    ) -> RedisResult<Vec<Value>>;

    /// Packs `cmd` and sends it via `req_packed_command`.
    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
        let pcmd = cmd.get_packed_command();
        self.req_packed_command(&pcmd)
    }

    /// Returns the database index this connection is using.
    fn get_db(&self) -> i64;

    #[doc(hidden)]
    // Whether batched commands may be sent as a single pipeline; `true` by default.
    fn supports_pipelining(&self) -> bool {
        true
    }

    /// Performs a round trip to check that the connection works (`Connection` sends
    /// `PING`).
    fn check_connection(&mut self) -> bool;

    /// Returns whether the connection has not (yet) been detected as closed.
    /// For `Connection` this only reads a flag and performs no I/O.
    fn is_open(&self) -> bool;
}
1016
1017impl Connection {
1025 pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
1030 self.con.send_bytes(cmd)?;
1031 Ok(())
1032 }
1033
1034 pub fn recv_response(&mut self) -> RedisResult<Value> {
1037 self.read_response()
1038 }
1039
1040 pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1046 self.con.set_write_timeout(dur)
1047 }
1048
1049 pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1055 self.con.set_read_timeout(dur)
1056 }
1057
1058 pub fn as_pubsub(&mut self) -> PubSub<'_> {
1060 PubSub::new(self)
1064 }
1065
1066 fn exit_pubsub(&mut self) -> RedisResult<()> {
1067 let res = self.clear_active_subscriptions();
1068 if res.is_ok() {
1069 self.pubsub = false;
1070 } else {
1071 self.pubsub = true;
1073 }
1074
1075 res
1076 }
1077
1078 fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
1083 {
1089 let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
1091 let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();
1092
1093 let con = &mut self.con;
1096
1097 con.send_bytes(&unsubscribe)?;
1099 con.send_bytes(&punsubscribe)?;
1100 }
1101
1102 let mut received_unsub = false;
1108 let mut received_punsub = false;
1109 loop {
1110 let res: (Vec<u8>, (), isize) = from_redis_value(&self.recv_response()?)?;
1111
1112 match res.0.first() {
1113 Some(&b'u') => received_unsub = true,
1114 Some(&b'p') => received_punsub = true,
1115 _ => (),
1116 }
1117
1118 if received_unsub && received_punsub && res.2 == 0 {
1119 break;
1120 }
1121 }
1122
1123 Ok(())
1126 }
1127
1128 fn read_response(&mut self) -> RedisResult<Value> {
1130 let result = match self.con {
1131 ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
1132 self.parser.parse_value(reader)
1133 }
1134 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1135 ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
1136 let reader = &mut boxed_tls_connection.reader;
1137 self.parser.parse_value(reader)
1138 }
1139 #[cfg(feature = "tls-rustls")]
1140 ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
1141 let reader = &mut boxed_tls_connection.reader;
1142 self.parser.parse_value(reader)
1143 }
1144 #[cfg(unix)]
1145 ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
1146 self.parser.parse_value(sock)
1147 }
1148 };
1149 if let Err(e) = &result {
1151 let shutdown = match e.as_io_error() {
1152 Some(e) => e.kind() == io::ErrorKind::UnexpectedEof,
1153 None => false,
1154 };
1155 if shutdown {
1156 match self.con {
1157 ActualConnection::Tcp(ref mut connection) => {
1158 let _ = connection.reader.shutdown(net::Shutdown::Both);
1159 connection.open = false;
1160 }
1161 #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
1162 ActualConnection::TcpNativeTls(ref mut connection) => {
1163 let _ = connection.reader.shutdown();
1164 connection.open = false;
1165 }
1166 #[cfg(feature = "tls-rustls")]
1167 ActualConnection::TcpRustls(ref mut connection) => {
1168 let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
1169 connection.open = false;
1170 }
1171 #[cfg(unix)]
1172 ActualConnection::Unix(ref mut connection) => {
1173 let _ = connection.sock.shutdown(net::Shutdown::Both);
1174 connection.open = false;
1175 }
1176 }
1177 }
1178 }
1179 result
1180 }
1181}
1182
1183impl ConnectionLike for Connection {
1184 fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
1185 if self.pubsub {
1186 self.exit_pubsub()?;
1187 }
1188
1189 self.con.send_bytes(cmd)?;
1190 self.read_response()
1191 }
1192
1193 fn req_packed_commands(
1194 &mut self,
1195 cmd: &[u8],
1196 offset: usize,
1197 count: usize,
1198 ) -> RedisResult<Vec<Value>> {
1199 if self.pubsub {
1200 self.exit_pubsub()?;
1201 }
1202 self.con.send_bytes(cmd)?;
1203 let mut rv = vec![];
1204 let mut first_err = None;
1205 for idx in 0..(offset + count) {
1206 let response = self.read_response();
1211 match response {
1212 Ok(item) => {
1213 if idx >= offset {
1214 rv.push(item);
1215 }
1216 }
1217 Err(err) => {
1218 if first_err.is_none() {
1219 first_err = Some(err);
1220 }
1221 }
1222 }
1223 }
1224
1225 first_err.map_or(Ok(rv), Err)
1226 }
1227
1228 fn get_db(&self) -> i64 {
1229 self.db
1230 }
1231
1232 fn is_open(&self) -> bool {
1233 self.con.is_open()
1234 }
1235
1236 fn check_connection(&mut self) -> bool {
1237 cmd("PING").query::<String>(self).is_ok()
1238 }
1239}
1240
/// Lets any `DerefMut` wrapper around a connection (boxes, guards, ...) be used wherever
/// a `ConnectionLike` is expected, by forwarding every method to the target.
impl<C, T> ConnectionLike for T
where
    C: ConnectionLike,
    T: DerefMut<Target = C>,
{
    fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
        self.deref_mut().req_packed_command(cmd)
    }

    fn req_packed_commands(
        &mut self,
        cmd: &[u8],
        offset: usize,
        count: usize,
    ) -> RedisResult<Vec<Value>> {
        self.deref_mut().req_packed_commands(cmd, offset, count)
    }

    fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
        self.deref_mut().req_command(cmd)
    }

    fn get_db(&self) -> i64 {
        self.deref().get_db()
    }

    fn supports_pipelining(&self) -> bool {
        self.deref().supports_pipelining()
    }

    fn check_connection(&mut self) -> bool {
        self.deref_mut().check_connection()
    }

    fn is_open(&self) -> bool {
        self.deref().is_open()
    }
}
1279
1280impl<'a> PubSub<'a> {
1302 fn new(con: &'a mut Connection) -> Self {
1303 Self {
1304 con,
1305 waiting_messages: VecDeque::new(),
1306 }
1307 }
1308
1309 fn cache_messages_until_received_response(&mut self, cmd: &Cmd) -> RedisResult<()> {
1310 let mut response = self.con.req_packed_command(&cmd.get_packed_command())?;
1311 loop {
1312 if let Some(msg) = Msg::from_value(&response) {
1313 self.waiting_messages.push_back(msg);
1314 } else {
1315 return Ok(());
1316 }
1317 response = self.con.recv_response()?;
1318 }
1319 }
1320
1321 pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1323 self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel))
1324 }
1325
1326 pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1328 self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel))
1329 }
1330
1331 pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
1333 self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel))
1334 }
1335
1336 pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
1338 self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel))
1339 }
1340
1341 pub fn get_message(&mut self) -> RedisResult<Msg> {
1348 if let Some(msg) = self.waiting_messages.pop_front() {
1349 return Ok(msg);
1350 }
1351 loop {
1352 if let Some(msg) = Msg::from_value(&self.con.recv_response()?) {
1353 return Ok(msg);
1354 } else {
1355 continue;
1356 }
1357 }
1358 }
1359
1360 pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
1366 self.con.set_read_timeout(dur)
1367 }
1368}
1369
/// Leaves pub/sub mode when the `PubSub` view goes away.
///
/// Errors cannot be reported from `drop`. A failed cleanup leaves the connection
/// flagged, so its next regular command retries the unsubscription.
impl<'a> Drop for PubSub<'a> {
    fn drop(&mut self) {
        let _ = self.con.exit_pubsub();
    }
}
1375
1376impl Msg {
1379 pub fn from_value(value: &Value) -> Option<Self> {
1381 let raw_msg: Vec<Value> = from_redis_value(value).ok()?;
1382 let mut iter = raw_msg.into_iter();
1383 let msg_type: String = from_redis_value(&iter.next()?).ok()?;
1384 let mut pattern = None;
1385 let payload;
1386 let channel;
1387
1388 if msg_type == "message" {
1389 channel = iter.next()?;
1390 payload = iter.next()?;
1391 } else if msg_type == "pmessage" {
1392 pattern = Some(iter.next()?);
1393 channel = iter.next()?;
1394 payload = iter.next()?;
1395 } else {
1396 return None;
1397 }
1398
1399 Some(Msg {
1400 payload,
1401 channel,
1402 pattern,
1403 })
1404 }
1405
1406 pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
1408 from_redis_value(&self.channel)
1409 }
1410
1411 pub fn get_channel_name(&self) -> &str {
1416 match self.channel {
1417 Value::Data(ref bytes) => from_utf8(bytes).unwrap_or("?"),
1418 _ => "?",
1419 }
1420 }
1421
1422 pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
1424 from_redis_value(&self.payload)
1425 }
1426
1427 pub fn get_payload_bytes(&self) -> &[u8] {
1431 match self.payload {
1432 Value::Data(ref bytes) => bytes,
1433 _ => b"",
1434 }
1435 }
1436
1437 #[allow(clippy::wrong_self_convention)]
1440 pub fn from_pattern(&self) -> bool {
1441 self.pattern.is_some()
1442 }
1443
1444 pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
1449 match self.pattern {
1450 None => from_redis_value(&Value::Nil),
1451 Some(ref x) => from_redis_value(x),
1452 }
1453 }
1454}
1455
1456pub fn transaction<
1489 C: ConnectionLike,
1490 K: ToRedisArgs,
1491 T,
1492 F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
1493>(
1494 con: &mut C,
1495 keys: &[K],
1496 func: F,
1497) -> RedisResult<T> {
1498 let mut func = func;
1499 loop {
1500 cmd("WATCH").arg(keys).query::<()>(con)?;
1501 let mut p = pipe();
1502 let response: Option<T> = func(con, p.atomic())?;
1503 match response {
1504 None => {
1505 continue;
1506 }
1507 Some(response) => {
1508 cmd("UNWATCH").query::<()>(con)?;
1511 return Ok(response);
1512 }
1513 }
1514 }
1515}
1516
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `actual` against `expected` field by field (address, database index,
    /// username, password), naming `url` in any failure message.
    fn assert_connection_info_matches(
        url: &url::Url,
        actual: &ConnectionInfo,
        expected: &ConnectionInfo,
    ) {
        assert_eq!(actual.addr, expected.addr, "addr of {url} is not expected");
        assert_eq!(
            actual.redis.db, expected.redis.db,
            "db of {url} is not expected",
        );
        assert_eq!(
            actual.redis.username, expected.redis.username,
            "username of {url} is not expected",
        );
        assert_eq!(
            actual.redis.password, expected.redis.password,
            "password of {url} is not expected",
        );
    }

    #[test]
    fn test_parse_redis_url() {
        // (input URL, whether `parse_redis_url` should accept it)
        let cases: &[(&str, bool)] = &[
            ("redis://127.0.0.1", true),
            ("redis://[::1]", true),
            ("redis+unix:///run/redis.sock", true),
            ("unix:///run/redis.sock", true),
            ("http://127.0.0.1", false),
            ("tcp://127.0.0.1", false),
        ];
        for &(url, should_parse) in cases {
            let parsed = parse_redis_url(url);
            assert_eq!(
                parsed.is_some(),
                should_parse,
                "Parsed result of `{url}` is not expected",
            );
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info() {
        let cases = vec![
            (
                url::Url::parse("redis://127.0.0.1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://[::1]").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
                    redis: Default::default(),
                },
            ),
            (
                url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                    },
                },
            ),
        ];
        for (url, expected) in cases {
            let actual = url_to_tcp_connection_info(url.clone()).unwrap();
            assert_connection_info_matches(&url, &actual, &expected);
        }
    }

    #[test]
    fn test_url_to_tcp_connection_info_failed() {
        // (malformed URL, expected error description)
        let cases = vec![
            (url::Url::parse("redis://").unwrap(), "Missing hostname"),
            (
                url::Url::parse("redis://127.0.0.1/db").unwrap(),
                "Invalid database number",
            ),
            (
                url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
                "Username is not valid UTF-8 string",
            ),
            (
                url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
                "Password is not valid UTF-8 string",
            ),
        ];
        for (url, expected_description) in cases {
            let err = url_to_tcp_connection_info(url).unwrap_err();
            assert_eq!(
                err.kind(),
                crate::ErrorKind::InvalidClientConfig,
                "{}",
                &err,
            );
            // `Error::description` is deprecated, but it is the accessor that returns
            // the short description these cases compare against.
            #[allow(deprecated)]
            let description = std::error::Error::description(&err);
            assert_eq!(description, expected_description, "{}", &err);
            assert_eq!(err.detail(), None, "{}", &err);
        }
    }

    #[test]
    #[cfg(unix)]
    fn test_url_to_unix_connection_info() {
        let cases = vec![
            (
                url::Url::parse("unix:///var/run/redis.sock").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 0,
                        username: None,
                        password: None,
                    },
                },
            ),
            (
                url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 1,
                        username: None,
                        password: None,
                    },
                },
            ),
            (
                url::Url::parse(
                    "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("#@<>$".to_string()),
                    },
                },
            ),
            (
                url::Url::parse(
                    "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
                )
                .unwrap(),
                ConnectionInfo {
                    addr: ConnectionAddr::Unix("/example.sock".into()),
                    redis: RedisConnectionInfo {
                        db: 2,
                        username: Some("%johndoe%".to_string()),
                        password: Some("&?= *+".to_string()),
                    },
                },
            ),
        ];
        for (url, expected) in cases {
            // First check that the URL itself maps to the expected socket path.
            assert_eq!(
                ConnectionAddr::Unix(url.to_file_path().unwrap()),
                expected.addr,
                "addr of {url} is not expected",
            );
            let actual = url_to_unix_connection_info(url.clone()).unwrap();
            assert_connection_info_matches(&url, &actual, &expected);
        }
    }
}