use std::borrow::Cow;
use std::collections::VecDeque;
use std::fmt;
use std::io::{self, Write};
use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs};
use std::ops::DerefMut;
use std::path::PathBuf;
use std::str::{FromStr, from_utf8};
use std::time::{Duration, Instant};
use crate::cmd::{Cmd, cmd, pipe};
use crate::errors::{ErrorKind, RedisError, ServerError, ServerErrorKind};
use crate::io::tcp::{TcpSettings, stream_with_settings};
use crate::parser::Parser;
use crate::pipeline::Pipeline;
use crate::types::{
FromRedisValue, HashMap, PushKind, RedisResult, SyncPushSender, ToRedisArgs, Value,
from_redis_value_ref,
};
use crate::{ProtocolVersion, check_resp3, from_redis_value};
#[cfg(unix)]
use std::os::unix::net::UnixStream;
use crate::commands::resp3_hello;
use arcstr::ArcStr;
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
use native_tls::{TlsConnector, TlsStream};
#[cfg(feature = "tls-rustls")]
use rustls::{RootCertStore, StreamOwned};
#[cfg(feature = "tls-rustls")]
use std::sync::Arc;
use crate::PushInfo;
#[cfg(all(
feature = "tls-rustls",
not(feature = "tls-native-tls"),
not(feature = "tls-rustls-webpki-roots")
))]
use rustls_native_certs::load_native_certs;
#[cfg(feature = "tls-rustls")]
use crate::tls::ClientTlsParams;
#[derive(Clone, Debug)]
pub struct TlsConnParams {
#[cfg(feature = "tls-rustls")]
pub(crate) client_tls_params: Option<ClientTlsParams>,
#[cfg(feature = "tls-rustls")]
pub(crate) root_cert_store: Option<RootCertStore>,
#[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
pub(crate) danger_accept_invalid_hostnames: bool,
}
static DEFAULT_PORT: u16 = 6379;
const DEFAULT_CLIENT_SETINFO_LIB_NAME: &str = "redis-rs";
const DEFAULT_CLIENT_SETINFO_LIB_VER: &str = env!("CARGO_PKG_VERSION");
#[inline(always)]
fn connect_tcp(addr: (&str, u16), tcp_settings: &TcpSettings) -> io::Result<TcpStream> {
let socket = TcpStream::connect(addr)?;
stream_with_settings(socket, tcp_settings)
}
#[inline(always)]
fn connect_tcp_timeout(
addr: &SocketAddr,
timeout: Duration,
tcp_settings: &TcpSettings,
) -> io::Result<TcpStream> {
let socket = TcpStream::connect_timeout(addr, timeout)?;
stream_with_settings(socket, tcp_settings)
}
pub fn parse_redis_url(input: &str) -> Option<url::Url> {
match url::Url::parse(input) {
Ok(result) => match result.scheme() {
"redis" | "rediss" | "valkey" | "valkeys" | "redis+unix" | "valkey+unix" | "unix" => {
Some(result)
}
_ => None,
},
Err(_) => None,
}
}
#[derive(Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum TlsMode {
Secure,
Insecure,
}
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum ConnectionAddr {
Tcp(String, u16),
TcpTls {
host: String,
port: u16,
insecure: bool,
tls_params: Option<TlsConnParams>,
},
Unix(PathBuf),
}
impl PartialEq for ConnectionAddr {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => {
host1 == host2 && port1 == port2
}
(
ConnectionAddr::TcpTls {
host: host1,
port: port1,
insecure: insecure1,
tls_params: _,
},
ConnectionAddr::TcpTls {
host: host2,
port: port2,
insecure: insecure2,
tls_params: _,
},
) => port1 == port2 && host1 == host2 && insecure1 == insecure2,
(ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2,
_ => false,
}
}
}
impl Eq for ConnectionAddr {}
impl ConnectionAddr {
pub fn is_supported(&self) -> bool {
match *self {
ConnectionAddr::Tcp(_, _) => true,
ConnectionAddr::TcpTls { .. } => {
cfg!(any(feature = "tls-native-tls", feature = "tls-rustls"))
}
ConnectionAddr::Unix(_) => cfg!(unix),
}
}
#[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
pub fn set_danger_accept_invalid_hostnames(&mut self, insecure: bool) {
if let ConnectionAddr::TcpTls { tls_params, .. } = self {
if let Some(params) = tls_params {
params.danger_accept_invalid_hostnames = insecure;
} else if insecure {
*tls_params = Some(TlsConnParams {
#[cfg(feature = "tls-rustls")]
client_tls_params: None,
#[cfg(feature = "tls-rustls")]
root_cert_store: None,
danger_accept_invalid_hostnames: insecure,
});
}
}
}
#[cfg(feature = "cluster")]
pub(crate) fn tls_mode(&self) -> Option<TlsMode> {
match self {
ConnectionAddr::TcpTls { insecure, .. } => {
if *insecure {
Some(TlsMode::Insecure)
} else {
Some(TlsMode::Secure)
}
}
_ => None,
}
}
}
impl fmt::Display for ConnectionAddr {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"),
ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"),
ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()),
}
}
}
#[derive(Clone, Debug)]
pub struct ConnectionInfo {
pub(crate) addr: ConnectionAddr,
pub(crate) tcp_settings: TcpSettings,
pub(crate) redis: RedisConnectionInfo,
}
impl ConnectionInfo {
pub fn addr(&self) -> &ConnectionAddr {
&self.addr
}
pub fn tcp_settings(&self) -> &TcpSettings {
&self.tcp_settings
}
pub fn redis_settings(&self) -> &RedisConnectionInfo {
&self.redis
}
pub fn set_addr(mut self, addr: ConnectionAddr) -> Self {
self.addr = addr;
self
}
pub fn set_tcp_settings(mut self, tcp_settings: TcpSettings) -> Self {
self.tcp_settings = tcp_settings;
self
}
pub fn set_redis_settings(mut self, redis: RedisConnectionInfo) -> Self {
self.redis = redis;
self
}
}
#[derive(Clone, Default)]
pub struct RedisConnectionInfo {
pub(crate) db: i64,
pub(crate) username: Option<ArcStr>,
pub(crate) password: Option<ArcStr>,
pub(crate) protocol: ProtocolVersion,
pub(crate) skip_set_lib_name: bool,
pub(crate) lib_name: Option<ArcStr>,
pub(crate) lib_ver: Option<ArcStr>,
}
impl RedisConnectionInfo {
pub fn username(&self) -> Option<&str> {
self.username.as_deref()
}
pub fn password(&self) -> Option<&str> {
self.password.as_deref()
}
pub fn protocol(&self) -> ProtocolVersion {
self.protocol
}
pub fn skip_set_lib_name(&self) -> bool {
self.skip_set_lib_name
}
pub fn lib_name(&self) -> Option<&str> {
self.lib_name.as_deref()
}
pub fn lib_ver(&self) -> Option<&str> {
self.lib_ver.as_deref()
}
pub fn db(&self) -> i64 {
self.db
}
pub fn set_username(mut self, username: impl AsRef<str>) -> Self {
self.username = Some(username.as_ref().into());
self
}
pub fn set_password(mut self, password: impl AsRef<str>) -> Self {
self.password = Some(password.as_ref().into());
self
}
pub fn set_protocol(mut self, protocol: ProtocolVersion) -> Self {
self.protocol = protocol;
self
}
pub fn set_skip_set_lib_name(mut self) -> Self {
self.skip_set_lib_name = true;
self
}
pub fn set_lib_name(mut self, lib_name: impl AsRef<str>, lib_ver: impl AsRef<str>) -> Self {
self.lib_name = Some(lib_name.as_ref().into());
self.lib_ver = Some(lib_ver.as_ref().into());
self.skip_set_lib_name = false;
self
}
pub fn set_db(mut self, db: i64) -> Self {
self.db = db;
self
}
}
impl std::fmt::Debug for RedisConnectionInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let RedisConnectionInfo {
db,
username,
password,
protocol,
skip_set_lib_name,
lib_name,
lib_ver,
} = self;
let mut debug_info = f.debug_struct("RedisConnectionInfo");
debug_info.field("db", &db);
debug_info.field("username", &username);
debug_info.field("password", &password.as_ref().map(|_| "<redacted>"));
debug_info.field("protocol", &protocol);
debug_info.field("skip_set_lib_name", &skip_set_lib_name);
debug_info.field("lib_name", &lib_name);
debug_info.field("lib_ver", &lib_ver);
debug_info.finish()
}
}
impl FromStr for ConnectionInfo {
type Err = RedisError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
s.into_connection_info()
}
}
pub trait IntoConnectionInfo {
fn into_connection_info(self) -> RedisResult<ConnectionInfo>;
}
impl IntoConnectionInfo for ConnectionInfo {
fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
Ok(self)
}
}
impl IntoConnectionInfo for ConnectionAddr {
fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
Ok(ConnectionInfo {
addr: self,
redis: Default::default(),
tcp_settings: Default::default(),
})
}
}
impl IntoConnectionInfo for &str {
fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
match parse_redis_url(self) {
Some(u) => u.into_connection_info(),
None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
}
}
}
impl<T> IntoConnectionInfo for (T, u16)
where
T: Into<String>,
{
fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
Ok(ConnectionInfo {
addr: ConnectionAddr::Tcp(self.0.into(), self.1),
redis: RedisConnectionInfo::default(),
tcp_settings: TcpSettings::default(),
})
}
}
impl IntoConnectionInfo for String {
fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
match parse_redis_url(&self) {
Some(u) => u.into_connection_info(),
None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")),
}
}
}
fn parse_protocol(query: &HashMap<Cow<str>, Cow<str>>) -> RedisResult<ProtocolVersion> {
Ok(match query.get("protocol") {
Some(protocol) => {
if protocol == "2" || protocol == "resp2" {
ProtocolVersion::RESP2
} else if protocol == "3" || protocol == "resp3" {
ProtocolVersion::RESP3
} else {
fail!((
ErrorKind::InvalidClientConfig,
"Invalid protocol version",
protocol.to_string()
))
}
}
None => ProtocolVersion::RESP2,
})
}
#[inline]
pub(crate) fn is_wildcard_address(address: &str) -> bool {
address == "0.0.0.0" || address == "::"
}
fn url_to_tcp_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
let host = match url.host() {
Some(host) => {
let host_str = match host {
url::Host::Domain(path) => path.to_string(),
url::Host::Ipv4(v4) => v4.to_string(),
url::Host::Ipv6(v6) => v6.to_string(),
};
if is_wildcard_address(&host_str) {
return Err(RedisError::from((
ErrorKind::InvalidClientConfig,
"Cannot connect to a wildcard address (0.0.0.0 or ::)",
)));
}
host_str
}
None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")),
};
let port = url.port().unwrap_or(DEFAULT_PORT);
let addr = if url.scheme() == "rediss" || url.scheme() == "valkeys" {
#[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))]
{
match url.fragment() {
Some("insecure") => ConnectionAddr::TcpTls {
host,
port,
insecure: true,
tls_params: None,
},
Some(_) => fail!((
ErrorKind::InvalidClientConfig,
"only #insecure is supported as URL fragment"
)),
_ => ConnectionAddr::TcpTls {
host,
port,
insecure: false,
tls_params: None,
},
}
}
#[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
fail!((
ErrorKind::InvalidClientConfig,
"can't connect with TLS, the feature is not enabled"
));
} else {
ConnectionAddr::Tcp(host, port)
};
let query: HashMap<_, _> = url.query_pairs().collect();
Ok(ConnectionInfo {
addr,
redis: RedisConnectionInfo {
db: match url.path().trim_matches('/') {
"" => 0,
path => path.parse::<i64>().map_err(|_| -> RedisError {
(ErrorKind::InvalidClientConfig, "Invalid database number").into()
})?,
},
username: if url.username().is_empty() {
None
} else {
match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() {
Ok(decoded) => Some(decoded.into()),
Err(_) => fail!((
ErrorKind::InvalidClientConfig,
"Username is not valid UTF-8 string"
)),
}
},
password: match url.password() {
Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() {
Ok(decoded) => Some(decoded.into()),
Err(_) => fail!((
ErrorKind::InvalidClientConfig,
"Password is not valid UTF-8 string"
)),
},
None => None,
},
protocol: parse_protocol(&query)?,
skip_set_lib_name: false,
lib_name: None,
lib_ver: None,
},
tcp_settings: TcpSettings::default(),
})
}
#[cfg(unix)]
fn url_to_unix_connection_info(url: url::Url) -> RedisResult<ConnectionInfo> {
let query: HashMap<_, _> = url.query_pairs().collect();
Ok(ConnectionInfo {
addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError {
(ErrorKind::InvalidClientConfig, "Missing path").into()
})?),
redis: RedisConnectionInfo {
db: match query.get("db") {
Some(db) => db.parse::<i64>().map_err(|_| -> RedisError {
(ErrorKind::InvalidClientConfig, "Invalid database number").into()
})?,
None => 0,
},
username: query.get("user").map(|username| username.as_ref().into()),
password: query.get("pass").map(|password| password.as_ref().into()),
protocol: parse_protocol(&query)?,
skip_set_lib_name: false,
lib_name: None,
lib_ver: None,
},
tcp_settings: TcpSettings::default(),
})
}
#[cfg(not(unix))]
fn url_to_unix_connection_info(_: url::Url) -> RedisResult<ConnectionInfo> {
fail!((
ErrorKind::InvalidClientConfig,
"Unix sockets are not available on this platform."
));
}
impl IntoConnectionInfo for url::Url {
fn into_connection_info(self) -> RedisResult<ConnectionInfo> {
match self.scheme() {
"redis" | "rediss" | "valkey" | "valkeys" => url_to_tcp_connection_info(self),
"unix" | "redis+unix" | "valkey+unix" => url_to_unix_connection_info(self),
_ => fail!((
ErrorKind::InvalidClientConfig,
"URL provided is not a redis URL"
)),
}
}
}
struct TcpConnection {
reader: TcpStream,
open: bool,
}
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
struct TcpNativeTlsConnection {
reader: TlsStream<TcpStream>,
open: bool,
}
#[cfg(feature = "tls-rustls")]
struct TcpRustlsConnection {
reader: StreamOwned<rustls::ClientConnection, TcpStream>,
open: bool,
}
#[cfg(unix)]
struct UnixConnection {
sock: UnixStream,
open: bool,
}
enum ActualConnection {
Tcp(TcpConnection),
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
TcpNativeTls(Box<TcpNativeTlsConnection>),
#[cfg(feature = "tls-rustls")]
TcpRustls(Box<TcpRustlsConnection>),
#[cfg(unix)]
Unix(UnixConnection),
}
#[cfg(feature = "tls-rustls-insecure")]
struct NoCertificateVerification {
supported: rustls::crypto::WebPkiSupportedAlgorithms,
}
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
Ok(rustls::client::danger::ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.supported.supported_schemes()
}
}
#[cfg(feature = "tls-rustls-insecure")]
impl fmt::Debug for NoCertificateVerification {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("NoCertificateVerification").finish()
}
}
#[cfg(feature = "tls-rustls-insecure")]
#[derive(Debug)]
struct AcceptInvalidHostnamesCertVerifier {
inner: Arc<rustls::client::WebPkiServerVerifier>,
}
#[cfg(feature = "tls-rustls-insecure")]
fn is_hostname_error(err: &rustls::Error) -> bool {
matches!(
err,
rustls::Error::InvalidCertificate(
rustls::CertificateError::NotValidForName
| rustls::CertificateError::NotValidForNameContext { .. }
)
)
}
#[cfg(feature = "tls-rustls-insecure")]
impl rustls::client::danger::ServerCertVerifier for AcceptInvalidHostnamesCertVerifier {
fn verify_server_cert(
&self,
end_entity: &rustls::pki_types::CertificateDer<'_>,
intermediates: &[rustls::pki_types::CertificateDer<'_>],
server_name: &rustls::pki_types::ServerName<'_>,
ocsp_response: &[u8],
now: rustls::pki_types::UnixTime,
) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
self.inner
.verify_server_cert(end_entity, intermediates, server_name, ocsp_response, now)
.or_else(|err| {
if is_hostname_error(&err) {
Ok(rustls::client::danger::ServerCertVerified::assertion())
} else {
Err(err)
}
})
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner
.verify_tls12_signature(message, cert, dss)
.or_else(|err| {
if is_hostname_error(&err) {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
} else {
Err(err)
}
})
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &rustls::pki_types::CertificateDer<'_>,
dss: &rustls::DigitallySignedStruct,
) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
self.inner
.verify_tls13_signature(message, cert, dss)
.or_else(|err| {
if is_hostname_error(&err) {
Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
} else {
Err(err)
}
})
}
fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
pub struct Connection {
con: ActualConnection,
parser: Parser,
db: i64,
pubsub: bool,
protocol: ProtocolVersion,
push_sender: Option<SyncPushSender>,
messages_to_skip: usize,
}
pub struct PubSub<'a> {
con: &'a mut Connection,
waiting_messages: VecDeque<Msg>,
}
#[derive(Debug, Clone)]
pub struct Msg {
payload: Value,
channel: Value,
pattern: Option<Value>,
}
impl ActualConnection {
pub fn new(
addr: &ConnectionAddr,
timeout: Option<Duration>,
tcp_settings: &TcpSettings,
) -> RedisResult<ActualConnection> {
Ok(match *addr {
ConnectionAddr::Tcp(ref host, ref port) => {
if is_wildcard_address(host) {
fail!((
ErrorKind::InvalidClientConfig,
"Cannot connect to a wildcard address (0.0.0.0 or ::)"
));
}
let addr = (host.as_str(), *port);
let tcp = match timeout {
None => connect_tcp(addr, tcp_settings)?,
Some(timeout) => {
let mut tcp = None;
let mut last_error = None;
for addr in addr.to_socket_addrs()? {
match connect_tcp_timeout(&addr, timeout, tcp_settings) {
Ok(l) => {
tcp = Some(l);
break;
}
Err(e) => {
last_error = Some(e);
}
};
}
match (tcp, last_error) {
(Some(tcp), _) => tcp,
(None, Some(e)) => {
fail!(e);
}
(None, None) => {
fail!((
ErrorKind::InvalidClientConfig,
"could not resolve to any addresses"
));
}
}
}
};
ActualConnection::Tcp(TcpConnection {
reader: tcp,
open: true,
})
}
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
ConnectionAddr::TcpTls {
ref host,
port,
insecure,
ref tls_params,
} => {
let tls_connector = if insecure {
TlsConnector::builder()
.danger_accept_invalid_certs(true)
.danger_accept_invalid_hostnames(true)
.use_sni(false)
.build()?
} else if let Some(params) = tls_params {
TlsConnector::builder()
.danger_accept_invalid_hostnames(params.danger_accept_invalid_hostnames)
.build()?
} else {
TlsConnector::new()?
};
let addr = (host.as_str(), port);
let tls = match timeout {
None => {
let tcp = connect_tcp(addr, tcp_settings)?;
match tls_connector.connect(host, tcp) {
Ok(res) => res,
Err(e) => {
fail!((ErrorKind::Io, "SSL Handshake error", e.to_string()));
}
}
}
Some(timeout) => {
let mut tcp = None;
let mut last_error = None;
for addr in (host.as_str(), port).to_socket_addrs()? {
match connect_tcp_timeout(&addr, timeout, tcp_settings) {
Ok(l) => {
tcp = Some(l);
break;
}
Err(e) => {
last_error = Some(e);
}
};
}
match (tcp, last_error) {
(Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(),
(None, Some(e)) => {
fail!(e);
}
(None, None) => {
fail!((
ErrorKind::InvalidClientConfig,
"could not resolve to any addresses"
));
}
}
}
};
ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection {
reader: tls,
open: true,
}))
}
#[cfg(feature = "tls-rustls")]
ConnectionAddr::TcpTls {
ref host,
port,
insecure,
ref tls_params,
} => {
let host: &str = host;
let config = create_rustls_config(insecure, tls_params.clone())?;
let conn = rustls::ClientConnection::new(
Arc::new(config),
rustls::pki_types::ServerName::try_from(host)?.to_owned(),
)?;
let reader = match timeout {
None => {
let tcp = connect_tcp((host, port), tcp_settings)?;
StreamOwned::new(conn, tcp)
}
Some(timeout) => {
let mut tcp = None;
let mut last_error = None;
for addr in (host, port).to_socket_addrs()? {
match connect_tcp_timeout(&addr, timeout, tcp_settings) {
Ok(l) => {
tcp = Some(l);
break;
}
Err(e) => {
last_error = Some(e);
}
};
}
match (tcp, last_error) {
(Some(tcp), _) => StreamOwned::new(conn, tcp),
(None, Some(e)) => {
fail!(e);
}
(None, None) => {
fail!((
ErrorKind::InvalidClientConfig,
"could not resolve to any addresses"
));
}
}
}
};
ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true }))
}
#[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))]
ConnectionAddr::TcpTls { .. } => {
fail!((
ErrorKind::InvalidClientConfig,
"Cannot connect to TCP with TLS without the tls feature"
));
}
#[cfg(unix)]
ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection {
sock: UnixStream::connect(path)?,
open: true,
}),
#[cfg(not(unix))]
ConnectionAddr::Unix(ref _path) => {
fail!((
ErrorKind::InvalidClientConfig,
"Cannot connect to unix sockets \
on this platform"
));
}
})
}
pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
match *self {
ActualConnection::Tcp(ref mut connection) => {
let res = connection.reader.write_all(bytes).map_err(RedisError::from);
match res {
Err(e) => {
if e.is_unrecoverable_error() {
connection.open = false;
}
Err(e)
}
Ok(_) => Ok(Value::Okay),
}
}
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
ActualConnection::TcpNativeTls(ref mut connection) => {
let res = connection.reader.write_all(bytes).map_err(RedisError::from);
match res {
Err(e) => {
if e.is_unrecoverable_error() {
connection.open = false;
}
Err(e)
}
Ok(_) => Ok(Value::Okay),
}
}
#[cfg(feature = "tls-rustls")]
ActualConnection::TcpRustls(ref mut connection) => {
let res = connection.reader.write_all(bytes).map_err(RedisError::from);
match res {
Err(e) => {
if e.is_unrecoverable_error() {
connection.open = false;
}
Err(e)
}
Ok(_) => Ok(Value::Okay),
}
}
#[cfg(unix)]
ActualConnection::Unix(ref mut connection) => {
let result = connection.sock.write_all(bytes).map_err(RedisError::from);
match result {
Err(e) => {
if e.is_unrecoverable_error() {
connection.open = false;
}
Err(e)
}
Ok(_) => Ok(Value::Okay),
}
}
}
}
pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
match *self {
ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
reader.set_write_timeout(dur)?;
}
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
let reader = &(boxed_tls_connection.reader);
reader.get_ref().set_write_timeout(dur)?;
}
#[cfg(feature = "tls-rustls")]
ActualConnection::TcpRustls(ref boxed_tls_connection) => {
let reader = &(boxed_tls_connection.reader);
reader.get_ref().set_write_timeout(dur)?;
}
#[cfg(unix)]
ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
sock.set_write_timeout(dur)?;
}
}
Ok(())
}
pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
match *self {
ActualConnection::Tcp(TcpConnection { ref reader, .. }) => {
reader.set_read_timeout(dur)?;
}
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
ActualConnection::TcpNativeTls(ref boxed_tls_connection) => {
let reader = &(boxed_tls_connection.reader);
reader.get_ref().set_read_timeout(dur)?;
}
#[cfg(feature = "tls-rustls")]
ActualConnection::TcpRustls(ref boxed_tls_connection) => {
let reader = &(boxed_tls_connection.reader);
reader.get_ref().set_read_timeout(dur)?;
}
#[cfg(unix)]
ActualConnection::Unix(UnixConnection { ref sock, .. }) => {
sock.set_read_timeout(dur)?;
}
}
Ok(())
}
pub fn is_open(&self) -> bool {
match *self {
ActualConnection::Tcp(TcpConnection { open, .. }) => open,
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open,
#[cfg(feature = "tls-rustls")]
ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open,
#[cfg(unix)]
ActualConnection::Unix(UnixConnection { open, .. }) => open,
}
}
}
#[cfg(feature = "tls-rustls")]
pub(crate) fn create_rustls_config(
insecure: bool,
tls_params: Option<TlsConnParams>,
) -> RedisResult<rustls::ClientConfig> {
#[allow(unused_mut)]
let mut root_store = RootCertStore::empty();
#[cfg(feature = "tls-rustls-webpki-roots")]
root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
#[cfg(all(
feature = "tls-rustls",
not(feature = "tls-native-tls"),
not(feature = "tls-rustls-webpki-roots")
))]
{
let mut certificate_result = load_native_certs();
if let Some(error) = certificate_result.errors.pop() {
return Err(error.into());
}
for cert in certificate_result.certs {
root_store.add(cert)?;
}
}
let config = rustls::ClientConfig::builder();
let config = if let Some(tls_params) = tls_params {
let root_cert_store = tls_params.root_cert_store.unwrap_or(root_store);
let config_builder = config.with_root_certificates(root_cert_store.clone());
let config_builder = if let Some(ClientTlsParams {
client_cert_chain: client_cert,
client_key,
}) = tls_params.client_tls_params
{
config_builder
.with_client_auth_cert(client_cert, client_key)
.map_err(|err| {
RedisError::from((
ErrorKind::InvalidClientConfig,
"Unable to build client with TLS parameters provided.",
err.to_string(),
))
})?
} else {
config_builder.with_no_client_auth()
};
#[cfg(any(feature = "tls-rustls-insecure", feature = "tls-native-tls"))]
let config_builder = if !insecure && tls_params.danger_accept_invalid_hostnames {
#[cfg(not(feature = "tls-rustls-insecure"))]
{
fail!((
ErrorKind::InvalidClientConfig,
"Cannot create insecure client via danger_accept_invalid_hostnames without tls-rustls-insecure feature"
));
}
#[cfg(feature = "tls-rustls-insecure")]
{
let mut config = config_builder;
config.dangerous().set_certificate_verifier(Arc::new(
AcceptInvalidHostnamesCertVerifier {
inner: rustls::client::WebPkiServerVerifier::builder(Arc::new(
root_cert_store,
))
.build()
.map_err(|err| rustls::Error::from(rustls::OtherError(Arc::new(err))))?,
},
));
config
}
} else {
config_builder
};
config_builder
} else {
config
.with_root_certificates(root_store)
.with_no_client_auth()
};
match (insecure, cfg!(feature = "tls-rustls-insecure")) {
#[cfg(feature = "tls-rustls-insecure")]
(true, true) => {
let mut config = config;
config.enable_sni = false;
let Some(crypto_provider) = rustls::crypto::CryptoProvider::get_default() else {
return Err(RedisError::from((
ErrorKind::InvalidClientConfig,
"No crypto provider available for rustls",
)));
};
config
.dangerous()
.set_certificate_verifier(Arc::new(NoCertificateVerification {
supported: crypto_provider.signature_verification_algorithms,
}));
Ok(config)
}
(true, false) => {
fail!((
ErrorKind::InvalidClientConfig,
"Cannot create insecure client without tls-rustls-insecure feature"
));
}
_ => Ok(config),
}
}
pub(crate) fn authenticate_cmd(username: Option<&str>, password: &str) -> Cmd {
let mut command = cmd("AUTH");
if let Some(username) = &username {
command.arg(username);
}
command.arg(password);
command
}
pub fn connect(
connection_info: &ConnectionInfo,
timeout: Option<Duration>,
) -> RedisResult<Connection> {
let start = Instant::now();
let con: ActualConnection = ActualConnection::new(
&connection_info.addr,
timeout,
&connection_info.tcp_settings,
)?;
let remaining_timeout = timeout.and_then(|timeout| timeout.checked_sub(start.elapsed()));
if timeout.is_some() && remaining_timeout.is_none() {
return Err(RedisError::from(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Connection timed out",
)));
}
con.set_read_timeout(remaining_timeout)?;
con.set_write_timeout(remaining_timeout)?;
let con = setup_connection(
con,
&connection_info.redis,
#[cfg(feature = "cache-aio")]
None,
)?;
con.set_read_timeout(None)?;
con.set_write_timeout(None)?;
Ok(con)
}
pub(crate) struct ConnectionSetupComponents {
resp3_auth_cmd_idx: Option<usize>,
resp2_auth_cmd_idx: Option<usize>,
select_cmd_idx: Option<usize>,
#[cfg(feature = "cache-aio")]
cache_cmd_idx: Option<usize>,
}
pub(crate) fn connection_setup_pipeline(
connection_info: &RedisConnectionInfo,
check_username: bool,
#[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
) -> (crate::Pipeline, ConnectionSetupComponents) {
let mut pipeline = pipe();
let (authenticate_with_resp3_cmd_index, authenticate_with_resp2_cmd_index) =
if connection_info.protocol.supports_resp3() {
pipeline.add_command(resp3_hello(connection_info));
(Some(0), None)
} else if let Some(password) = connection_info.password.as_ref() {
pipeline.add_command(authenticate_cmd(
check_username.then(|| connection_info.username()).flatten(),
password,
));
(None, Some(0))
} else {
(None, None)
};
let select_db_cmd_index = (connection_info.db != 0)
.then(|| pipeline.len())
.inspect(|_| {
pipeline.cmd("SELECT").arg(connection_info.db);
});
#[cfg(feature = "cache-aio")]
let cache_cmd_index = cache_config.map(|cache_config| {
pipeline.cmd("CLIENT").arg("TRACKING").arg("ON");
match cache_config.mode {
crate::caching::CacheMode::All => {}
crate::caching::CacheMode::OptIn => {
pipeline.arg("OPTIN");
}
}
pipeline.len() - 1
});
if !connection_info.skip_set_lib_name {
pipeline
.cmd("CLIENT")
.arg("SETINFO")
.arg("LIB-NAME")
.arg(
connection_info
.lib_name
.as_ref()
.map_or(DEFAULT_CLIENT_SETINFO_LIB_NAME, ArcStr::as_str),
)
.ignore();
pipeline
.cmd("CLIENT")
.arg("SETINFO")
.arg("LIB-VER")
.arg(
connection_info
.lib_ver
.as_ref()
.map_or(DEFAULT_CLIENT_SETINFO_LIB_VER, ArcStr::as_str),
)
.ignore();
}
(
pipeline,
ConnectionSetupComponents {
resp3_auth_cmd_idx: authenticate_with_resp3_cmd_index,
resp2_auth_cmd_idx: authenticate_with_resp2_cmd_index,
select_cmd_idx: select_db_cmd_index,
#[cfg(feature = "cache-aio")]
cache_cmd_idx: cache_cmd_index,
},
)
}
fn check_resp3_auth(result: &Value) -> RedisResult<()> {
if let Value::ServerError(err) = result {
return Err(get_resp3_hello_command_error(err.clone().into()));
}
Ok(())
}
#[derive(PartialEq)]
pub(crate) enum AuthResult {
Succeeded,
ShouldRetryWithoutUsername,
}
fn check_resp2_auth(result: &Value) -> RedisResult<AuthResult> {
let err = match result {
Value::Okay => {
return Ok(AuthResult::Succeeded);
}
Value::ServerError(err) => err,
_ => {
return Err((
ServerErrorKind::ResponseError.into(),
"Redis server refused to authenticate, returns Ok() != Value::Okay",
)
.into());
}
};
let err_msg = err.details().ok_or((
ErrorKind::AuthenticationFailed,
"Password authentication failed",
))?;
if !err_msg.contains("wrong number of arguments for 'auth' command") {
return Err((
ErrorKind::AuthenticationFailed,
"Password authentication failed",
)
.into());
}
Ok(AuthResult::ShouldRetryWithoutUsername)
}
fn check_db_select(value: &Value) -> RedisResult<()> {
let Value::ServerError(err) = value else {
return Ok(());
};
match err.details() {
Some(err_msg) => Err((
ServerErrorKind::ResponseError.into(),
"Redis server refused to switch database",
err_msg.to_string(),
)
.into()),
None => Err((
ServerErrorKind::ResponseError.into(),
"Redis server refused to switch database",
)
.into()),
}
}
#[cfg(feature = "cache-aio")]
fn check_caching(result: &Value) -> RedisResult<()> {
match result {
Value::Okay => Ok(()),
_ => Err((
ServerErrorKind::ResponseError.into(),
"Client-side caching returned unknown response",
format!("{result:?}"),
)
.into()),
}
}
pub(crate) fn check_connection_setup(
results: Vec<Value>,
ConnectionSetupComponents {
resp3_auth_cmd_idx,
resp2_auth_cmd_idx,
select_cmd_idx,
#[cfg(feature = "cache-aio")]
cache_cmd_idx,
}: ConnectionSetupComponents,
) -> RedisResult<AuthResult> {
assert!(!(resp2_auth_cmd_idx.is_some() && resp3_auth_cmd_idx.is_some()));
if let Some(index) = resp3_auth_cmd_idx {
let Some(value) = results.get(index) else {
return Err((ErrorKind::Client, "Missing RESP3 auth response").into());
};
check_resp3_auth(value)?;
} else if let Some(index) = resp2_auth_cmd_idx {
let Some(value) = results.get(index) else {
return Err((ErrorKind::Client, "Missing RESP2 auth response").into());
};
if check_resp2_auth(value)? == AuthResult::ShouldRetryWithoutUsername {
return Ok(AuthResult::ShouldRetryWithoutUsername);
}
}
if let Some(index) = select_cmd_idx {
let Some(value) = results.get(index) else {
return Err((ErrorKind::Client, "Missing SELECT DB response").into());
};
check_db_select(value)?;
}
#[cfg(feature = "cache-aio")]
if let Some(index) = cache_cmd_idx {
let Some(value) = results.get(index) else {
return Err((ErrorKind::Client, "Missing Caching response").into());
};
check_caching(value)?;
}
Ok(AuthResult::Succeeded)
}
fn execute_connection_pipeline(
rv: &mut Connection,
(pipeline, instructions): (crate::Pipeline, ConnectionSetupComponents),
) -> RedisResult<AuthResult> {
if pipeline.is_empty() {
return Ok(AuthResult::Succeeded);
}
let results = rv.req_packed_commands(&pipeline.get_packed_pipeline(), 0, pipeline.len())?;
check_connection_setup(results, instructions)
}
fn setup_connection(
con: ActualConnection,
connection_info: &RedisConnectionInfo,
#[cfg(feature = "cache-aio")] cache_config: Option<crate::caching::CacheConfig>,
) -> RedisResult<Connection> {
let mut rv = Connection {
con,
parser: Parser::new(),
db: connection_info.db,
pubsub: false,
protocol: connection_info.protocol,
push_sender: None,
messages_to_skip: 0,
};
if execute_connection_pipeline(
&mut rv,
connection_setup_pipeline(
connection_info,
true,
#[cfg(feature = "cache-aio")]
cache_config,
),
)? == AuthResult::ShouldRetryWithoutUsername
{
execute_connection_pipeline(
&mut rv,
connection_setup_pipeline(
connection_info,
false,
#[cfg(feature = "cache-aio")]
cache_config,
),
)?;
}
Ok(rv)
}
pub trait ConnectionLike {
fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value>;
#[doc(hidden)]
fn req_packed_commands(
&mut self,
cmd: &[u8],
offset: usize,
count: usize,
) -> RedisResult<Vec<Value>>;
fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
let pcmd = cmd.get_packed_command();
self.req_packed_command(&pcmd)
}
fn get_db(&self) -> i64;
#[doc(hidden)]
fn supports_pipelining(&self) -> bool {
true
}
fn check_connection(&mut self) -> bool;
fn is_open(&self) -> bool;
}
impl Connection {
pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> {
self.send_bytes(cmd)?;
Ok(())
}
pub fn recv_response(&mut self) -> RedisResult<Value> {
self.read(true)
}
pub fn set_write_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
self.con.set_write_timeout(dur)
}
pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
self.con.set_read_timeout(dur)
}
pub fn as_pubsub(&mut self) -> PubSub<'_> {
PubSub::new(self)
}
fn exit_pubsub(&mut self) -> RedisResult<()> {
let res = self.clear_active_subscriptions();
if res.is_ok() {
self.pubsub = false;
} else {
self.pubsub = true;
}
res
}
fn clear_active_subscriptions(&mut self) -> RedisResult<()> {
{
let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command();
let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command();
self.send_bytes(&unsubscribe)?;
self.send_bytes(&punsubscribe)?;
}
let mut received_unsub = false;
let mut received_punsub = false;
loop {
let resp = self.recv_response()?;
match resp {
Value::Push { kind, data } => {
if data.len() >= 2 {
if let Value::Int(num) = data[1] {
if resp3_is_pub_sub_state_cleared(
&mut received_unsub,
&mut received_punsub,
&kind,
num as isize,
) {
break;
}
}
}
}
Value::ServerError(err) => {
if err.kind() == Some(ServerErrorKind::NoSub) {
if no_sub_err_is_pub_sub_state_cleared(
&mut received_unsub,
&mut received_punsub,
&err,
) {
break;
} else {
continue;
}
}
return Err(err.into());
}
Value::Array(vec) => {
let res: (Vec<u8>, (), isize) = from_redis_value(Value::Array(vec))?;
if resp2_is_pub_sub_state_cleared(
&mut received_unsub,
&mut received_punsub,
&res.0,
res.2,
) {
break;
}
}
_ => {
return Err((
ErrorKind::Client,
"Unexpected unsubscribe response",
format!("{resp:?}"),
)
.into());
}
}
}
Ok(())
}
fn send_push(&self, push: PushInfo) {
if let Some(sender) = &self.push_sender {
let _ = sender.send(push);
}
}
fn try_send(&self, value: &RedisResult<Value>) {
if let Ok(Value::Push { kind, data }) = value {
self.send_push(PushInfo {
kind: kind.clone(),
data: data.clone(),
});
}
}
fn send_disconnect(&self) {
self.send_push(PushInfo::disconnect())
}
fn close_connection(&mut self) {
self.send_disconnect();
match self.con {
ActualConnection::Tcp(ref mut connection) => {
let _ = connection.reader.shutdown(net::Shutdown::Both);
connection.open = false;
}
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
ActualConnection::TcpNativeTls(ref mut connection) => {
let _ = connection.reader.shutdown();
connection.open = false;
}
#[cfg(feature = "tls-rustls")]
ActualConnection::TcpRustls(ref mut connection) => {
let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both);
connection.open = false;
}
#[cfg(unix)]
ActualConnection::Unix(ref mut connection) => {
let _ = connection.sock.shutdown(net::Shutdown::Both);
connection.open = false;
}
}
}
fn read(&mut self, is_response: bool) -> RedisResult<Value> {
loop {
let result = match self.con {
ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => {
self.parser.parse_value(reader)
}
#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))]
ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => {
let reader = &mut boxed_tls_connection.reader;
self.parser.parse_value(reader)
}
#[cfg(feature = "tls-rustls")]
ActualConnection::TcpRustls(ref mut boxed_tls_connection) => {
let reader = &mut boxed_tls_connection.reader;
self.parser.parse_value(reader)
}
#[cfg(unix)]
ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => {
self.parser.parse_value(sock)
}
};
self.try_send(&result);
let Err(err) = &result else {
if self.messages_to_skip > 0 {
self.messages_to_skip -= 1;
continue;
}
return result;
};
let Some(io_error) = err.as_io_error() else {
if self.messages_to_skip > 0 {
self.messages_to_skip -= 1;
continue;
}
return result;
};
if io_error.kind() == io::ErrorKind::UnexpectedEof {
self.close_connection();
} else if is_response {
self.messages_to_skip += 1;
}
return result;
}
}
pub fn set_push_sender(&mut self, sender: SyncPushSender) {
self.push_sender = Some(sender);
}
fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult<Value> {
if bytes.is_empty() {
return Err(RedisError::make_empty_command());
}
let result = self.con.send_bytes(bytes);
if self.protocol.supports_resp3() {
if let Err(e) = &result {
if e.is_connection_dropped() {
self.send_disconnect();
}
}
}
result
}
pub fn subscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
check_resp3!(self.protocol);
cmd("SUBSCRIBE")
.arg(channel)
.set_no_response(true)
.exec(self)
}
pub fn psubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
check_resp3!(self.protocol);
cmd("PSUBSCRIBE")
.arg(pchannel)
.set_no_response(true)
.exec(self)
}
pub fn unsubscribe_resp3<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
check_resp3!(self.protocol);
cmd("UNSUBSCRIBE")
.arg(channel)
.set_no_response(true)
.exec(self)
}
pub fn punsubscribe_resp3<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
check_resp3!(self.protocol);
cmd("PUNSUBSCRIBE")
.arg(pchannel)
.set_no_response(true)
.exec(self)
}
}
impl ConnectionLike for Connection {
fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
let pcmd = cmd.get_packed_command();
if self.pubsub {
self.exit_pubsub()?;
}
self.send_bytes(&pcmd)?;
if cmd.is_no_response() {
return Ok(Value::Nil);
}
loop {
match self.read(true)? {
Value::Push {
kind: _kind,
data: _data,
} => continue,
val => return Ok(val),
}
}
}
fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
if self.pubsub {
self.exit_pubsub()?;
}
self.send_bytes(cmd)?;
loop {
match self.read(true)? {
Value::Push {
kind: _kind,
data: _data,
} => continue,
val => return Ok(val),
}
}
}
fn req_packed_commands(
&mut self,
cmd: &[u8],
offset: usize,
count: usize,
) -> RedisResult<Vec<Value>> {
if self.pubsub {
self.exit_pubsub()?;
}
self.send_bytes(cmd)?;
let mut rv = vec![];
let mut first_err = None;
let mut server_errors = vec![];
let mut count = count;
let mut idx = 0;
while idx < (offset + count) {
let response = self.read(true);
match response {
Ok(Value::ServerError(err)) => {
if idx < offset {
server_errors.push((idx - 1, err)); } else {
rv.push(Value::ServerError(err));
}
}
Ok(item) => {
if let Value::Push {
kind: _kind,
data: _data,
} = item
{
count += 1;
} else if idx >= offset {
rv.push(item);
}
}
Err(err) => {
if first_err.is_none() {
first_err = Some(err);
}
}
}
idx += 1;
}
if !server_errors.is_empty() {
return Err(RedisError::make_aborted_transaction(server_errors));
}
first_err.map_or(Ok(rv), Err)
}
fn get_db(&self) -> i64 {
self.db
}
fn check_connection(&mut self) -> bool {
cmd("PING").query::<String>(self).is_ok()
}
fn is_open(&self) -> bool {
self.con.is_open()
}
}
impl<C, T> ConnectionLike for T
where
C: ConnectionLike,
T: DerefMut<Target = C>,
{
fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult<Value> {
self.deref_mut().req_packed_command(cmd)
}
fn req_packed_commands(
&mut self,
cmd: &[u8],
offset: usize,
count: usize,
) -> RedisResult<Vec<Value>> {
self.deref_mut().req_packed_commands(cmd, offset, count)
}
fn req_command(&mut self, cmd: &Cmd) -> RedisResult<Value> {
self.deref_mut().req_command(cmd)
}
fn get_db(&self) -> i64 {
self.deref().get_db()
}
fn supports_pipelining(&self) -> bool {
self.deref().supports_pipelining()
}
fn check_connection(&mut self) -> bool {
self.deref_mut().check_connection()
}
fn is_open(&self) -> bool {
self.deref().is_open()
}
}
impl<'a> PubSub<'a> {
fn new(con: &'a mut Connection) -> Self {
Self {
con,
waiting_messages: VecDeque::new(),
}
}
fn cache_messages_until_received_response(
&mut self,
cmd: &mut Cmd,
is_sub_unsub: bool,
) -> RedisResult<Value> {
let ignore_response = self.con.protocol.supports_resp3() && is_sub_unsub;
cmd.set_no_response(ignore_response);
self.con.send_packed_command(&cmd.get_packed_command())?;
loop {
let response = self.con.recv_response()?;
if let Some(msg) = Msg::from_value(&response) {
self.waiting_messages.push_back(msg);
} else {
return Ok(response);
}
}
}
pub fn subscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel), true)?;
Ok(())
}
pub fn psubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel), true)?;
Ok(())
}
pub fn unsubscribe<T: ToRedisArgs>(&mut self, channel: T) -> RedisResult<()> {
self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel), true)?;
Ok(())
}
pub fn punsubscribe<T: ToRedisArgs>(&mut self, pchannel: T) -> RedisResult<()> {
self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel), true)?;
Ok(())
}
pub fn ping_message<T: FromRedisValue>(&mut self, message: impl ToRedisArgs) -> RedisResult<T> {
Ok(from_redis_value(
self.cache_messages_until_received_response(cmd("PING").arg(message), false)?,
)?)
}
pub fn ping<T: FromRedisValue>(&mut self) -> RedisResult<T> {
Ok(from_redis_value(
self.cache_messages_until_received_response(&mut cmd("PING"), false)?,
)?)
}
pub fn get_message(&mut self) -> RedisResult<Msg> {
if let Some(msg) = self.waiting_messages.pop_front() {
return Ok(msg);
}
loop {
if let Some(msg) = Msg::from_owned_value(self.con.read(false)?) {
return Ok(msg);
} else {
continue;
}
}
}
pub fn set_read_timeout(&self, dur: Option<Duration>) -> RedisResult<()> {
self.con.set_read_timeout(dur)
}
}
impl Drop for PubSub<'_> {
fn drop(&mut self) {
let _ = self.con.exit_pubsub();
}
}
impl Msg {
pub fn from_value(value: &Value) -> Option<Self> {
Self::from_owned_value(value.clone())
}
pub fn from_owned_value(value: Value) -> Option<Self> {
let mut pattern = None;
let payload;
let channel;
if let Value::Push { kind, data } = value {
return Self::from_push_info(PushInfo { kind, data });
} else {
let raw_msg: Vec<Value> = from_redis_value(value).ok()?;
let mut iter = raw_msg.into_iter();
let msg_type: String = from_redis_value(iter.next()?).ok()?;
if msg_type == "message" {
channel = iter.next()?;
payload = iter.next()?;
} else if msg_type == "pmessage" {
pattern = Some(iter.next()?);
channel = iter.next()?;
payload = iter.next()?;
} else {
return None;
}
};
Some(Msg {
payload,
channel,
pattern,
})
}
pub fn from_push_info(push_info: PushInfo) -> Option<Self> {
let mut pattern = None;
let payload;
let channel;
let mut iter = push_info.data.into_iter();
if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage {
channel = iter.next()?;
payload = iter.next()?;
} else if push_info.kind == PushKind::PMessage {
pattern = Some(iter.next()?);
channel = iter.next()?;
payload = iter.next()?;
} else {
return None;
}
Some(Msg {
payload,
channel,
pattern,
})
}
pub fn get_channel<T: FromRedisValue>(&self) -> RedisResult<T> {
Ok(from_redis_value_ref(&self.channel)?)
}
pub fn get_channel_name(&self) -> &str {
match self.channel {
Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"),
_ => "?",
}
}
pub fn get_payload<T: FromRedisValue>(&self) -> RedisResult<T> {
Ok(from_redis_value_ref(&self.payload)?)
}
pub fn get_payload_bytes(&self) -> &[u8] {
match self.payload {
Value::BulkString(ref bytes) => bytes,
_ => b"",
}
}
#[allow(clippy::wrong_self_convention)]
pub fn from_pattern(&self) -> bool {
self.pattern.is_some()
}
pub fn get_pattern<T: FromRedisValue>(&self) -> RedisResult<T> {
Ok(match self.pattern {
None => from_redis_value_ref(&Value::Nil),
Some(ref x) => from_redis_value_ref(x),
}?)
}
}
pub fn transaction<
C: ConnectionLike,
K: ToRedisArgs,
T,
F: FnMut(&mut C, &mut Pipeline) -> RedisResult<Option<T>>,
>(
con: &mut C,
keys: &[K],
func: F,
) -> RedisResult<T> {
let mut func = func;
loop {
cmd("WATCH").arg(keys).exec(con)?;
let mut p = pipe();
let response: Option<T> = func(con, p.atomic())?;
match response {
None => {
continue;
}
Some(response) => {
cmd("UNWATCH").exec(con)?;
return Ok(response);
}
}
}
}
pub fn resp2_is_pub_sub_state_cleared(
received_unsub: &mut bool,
received_punsub: &mut bool,
kind: &[u8],
num: isize,
) -> bool {
match kind.first() {
Some(&b'u') => *received_unsub = true,
Some(&b'p') => *received_punsub = true,
_ => (),
};
*received_unsub && *received_punsub && num == 0
}
pub fn resp3_is_pub_sub_state_cleared(
received_unsub: &mut bool,
received_punsub: &mut bool,
kind: &PushKind,
num: isize,
) -> bool {
match kind {
PushKind::Unsubscribe => *received_unsub = true,
PushKind::PUnsubscribe => *received_punsub = true,
_ => (),
};
*received_unsub && *received_punsub && num == 0
}
pub fn no_sub_err_is_pub_sub_state_cleared(
received_unsub: &mut bool,
received_punsub: &mut bool,
err: &ServerError,
) -> bool {
let details = err.details();
*received_unsub = *received_unsub
|| details
.map(|details| details.starts_with("'unsub"))
.unwrap_or_default();
*received_punsub = *received_punsub
|| details
.map(|details| details.starts_with("'punsub"))
.unwrap_or_default();
*received_unsub && *received_punsub
}
pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError {
if let Some(detail) = err.detail() {
if detail.starts_with("unknown command `HELLO`") {
return (
ErrorKind::RESP3NotSupported,
"Redis Server doesn't support HELLO command therefore resp3 cannot be used",
)
.into();
}
}
err
}
#[cfg(test)]
mod tests {
mod util {
use crate::connection::connection_setup_pipeline;
use crate::{RedisConnectionInfo, cmd};
pub fn assert_lib_name_in_connection_setup_pipeline(
redis_connection_info: &RedisConnectionInfo,
expected_lib_name: &str,
expected_lib_ver: &str,
) {
let pipeline = connection_setup_pipeline(
redis_connection_info,
false,
#[cfg(feature = "cache-aio")]
None,
)
.0;
let actual_packed_cmds = pipeline
.commands
.iter()
.map(|c| c.get_packed_command())
.collect::<Vec<_>>();
let expected_lib_name_packed_cmd = cmd("CLIENT")
.arg("SETINFO")
.arg("LIB-NAME")
.arg(expected_lib_name)
.get_packed_command();
assert!(actual_packed_cmds.contains(&expected_lib_name_packed_cmd));
let expected_lib_ver_packed_cmd = cmd("CLIENT")
.arg("SETINFO")
.arg("LIB-VER")
.arg(expected_lib_ver)
.get_packed_command();
assert!(actual_packed_cmds.contains(&expected_lib_ver_packed_cmd));
}
}
use super::*;
use util::assert_lib_name_in_connection_setup_pipeline;
#[test]
fn test_parse_redis_url() {
let cases = vec![
("redis://127.0.0.1", true),
("redis://[::1]", true),
("rediss://127.0.0.1", true),
("rediss://[::1]", true),
("valkey://127.0.0.1", true),
("valkey://[::1]", true),
("valkeys://127.0.0.1", true),
("valkeys://[::1]", true),
("redis+unix:///run/redis.sock", true),
("valkey+unix:///run/valkey.sock", true),
("unix:///run/redis.sock", true),
("http://127.0.0.1", false),
("tcp://127.0.0.1", false),
];
for (url, expected) in cases.into_iter() {
let res = parse_redis_url(url);
assert_eq!(
res.is_some(),
expected,
"Parsed result of `{url}` is not expected",
);
}
}
#[test]
fn test_url_to_tcp_connection_info() {
let cases = vec![
(
url::Url::parse("redis://127.0.0.1").unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
redis: Default::default(),
tcp_settings: TcpSettings::default(),
},
),
(
url::Url::parse("redis://[::1]").unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Tcp("::1".to_string(), 6379),
redis: Default::default(),
tcp_settings: TcpSettings::default(),
},
),
(
url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Tcp("example.com".to_string(), 6379),
redis: RedisConnectionInfo {
db: 2,
username: Some("%johndoe%".into()),
password: Some("#@<>$".into()),
protocol: ProtocolVersion::RESP2,
skip_set_lib_name: false,
lib_name: None,
lib_ver: None,
},
tcp_settings: TcpSettings::default(),
},
),
(
url::Url::parse("redis://127.0.0.1/?protocol=2").unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
redis: Default::default(),
tcp_settings: TcpSettings::default(),
},
),
(
url::Url::parse("redis://127.0.0.1/?protocol=resp3").unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379),
redis: RedisConnectionInfo {
db: 0,
username: None,
password: None,
protocol: ProtocolVersion::RESP3,
skip_set_lib_name: false,
lib_name: None,
lib_ver: None,
},
tcp_settings: TcpSettings::default(),
},
),
];
for (url, expected) in cases.into_iter() {
let res = url_to_tcp_connection_info(url.clone()).unwrap();
assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
assert_eq!(
res.redis.db, expected.redis.db,
"db of {url} is not expected",
);
assert_eq!(
res.redis.username, expected.redis.username,
"username of {url} is not expected",
);
assert_eq!(
res.redis.password, expected.redis.password,
"password of {url} is not expected",
);
}
}
#[test]
fn test_url_to_tcp_connection_info_failed() {
let cases = vec![
(
url::Url::parse("redis://").unwrap(),
"Missing hostname",
None,
),
(
url::Url::parse("redis://127.0.0.1/db").unwrap(),
"Invalid database number",
None,
),
(
url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(),
"Username is not valid UTF-8 string",
None,
),
(
url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(),
"Password is not valid UTF-8 string",
None,
),
(
url::Url::parse("redis://127.0.0.1/?protocol=4").unwrap(),
"Invalid protocol version",
Some("4"),
),
];
for (url, expected, detail) in cases.into_iter() {
let res = url_to_tcp_connection_info(url).unwrap_err();
assert_eq!(res.kind(), crate::ErrorKind::InvalidClientConfig,);
let desc = res.to_string();
assert!(desc.contains(expected), "{desc}");
assert_eq!(res.detail(), detail);
}
}
#[test]
#[cfg(unix)]
fn test_url_to_unix_connection_info() {
let cases = vec![
(
url::Url::parse("unix:///var/run/redis.sock").unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
redis: RedisConnectionInfo {
db: 0,
username: None,
password: None,
protocol: ProtocolVersion::RESP2,
skip_set_lib_name: false,
lib_name: None,
lib_ver: None,
},
tcp_settings: Default::default(),
},
),
(
url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
redis: RedisConnectionInfo {
db: 1,
username: None,
password: None,
protocol: ProtocolVersion::RESP2,
skip_set_lib_name: false,
lib_name: None,
lib_ver: None,
},
tcp_settings: TcpSettings::default(),
},
),
(
url::Url::parse(
"unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2",
)
.unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Unix("/example.sock".into()),
redis: RedisConnectionInfo {
db: 2,
username: Some("%johndoe%".into()),
password: Some("#@<>$".into()),
protocol: ProtocolVersion::RESP2,
skip_set_lib_name: false,
lib_name: None,
lib_ver: None,
},
tcp_settings: TcpSettings::default(),
},
),
(
url::Url::parse(
"redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25",
)
.unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Unix("/example.sock".into()),
redis: RedisConnectionInfo {
db: 2,
username: Some("%johndoe%".into()),
password: Some("&?= *+".into()),
protocol: ProtocolVersion::RESP2,
skip_set_lib_name: false,
lib_name: None,
lib_ver: None,
},
tcp_settings: TcpSettings::default(),
},
),
(
url::Url::parse("redis+unix:///var/run/redis.sock?protocol=3").unwrap(),
ConnectionInfo {
addr: ConnectionAddr::Unix("/var/run/redis.sock".into()),
redis: RedisConnectionInfo {
db: 0,
username: None,
password: None,
protocol: ProtocolVersion::RESP3,
skip_set_lib_name: false,
lib_name: None,
lib_ver: None,
},
tcp_settings: TcpSettings::default(),
},
),
];
for (url, expected) in cases.into_iter() {
assert_eq!(
ConnectionAddr::Unix(url.to_file_path().unwrap()),
expected.addr,
"addr of {url} is not expected",
);
let res = url_to_unix_connection_info(url.clone()).unwrap();
assert_eq!(res.addr, expected.addr, "addr of {url} is not expected");
assert_eq!(
res.redis.db, expected.redis.db,
"db of {url} is not expected",
);
assert_eq!(
res.redis.username, expected.redis.username,
"username of {url} is not expected",
);
assert_eq!(
res.redis.password, expected.redis.password,
"password of {url} is not expected",
);
}
}
#[test]
fn redis_connection_info_lib_name_default() {
let redis_connection_info = RedisConnectionInfo::default();
assert_eq!(redis_connection_info.lib_name(), None);
assert_eq!(redis_connection_info.lib_ver(), None);
assert_lib_name_in_connection_setup_pipeline(
&redis_connection_info,
DEFAULT_CLIENT_SETINFO_LIB_NAME,
DEFAULT_CLIENT_SETINFO_LIB_VER,
);
}
#[test]
fn redis_connection_info_lib_name_custom() {
let mut redis_connection_info = RedisConnectionInfo::default();
redis_connection_info = redis_connection_info.set_skip_set_lib_name();
assert!(redis_connection_info.skip_set_lib_name());
redis_connection_info = redis_connection_info.set_lib_name("foo", "42.4711");
assert!(!redis_connection_info.skip_set_lib_name());
assert_eq!(redis_connection_info.lib_name(), Some("foo"));
assert_eq!(redis_connection_info.lib_ver(), Some("42.4711"));
assert_lib_name_in_connection_setup_pipeline(&redis_connection_info, "foo", "42.4711");
}
}