use super::utils as protocol_utils;
use crate::{
error::{Error, ErrorKind},
modules::inner::ClientInner,
prelude::FredResult,
protocol::{cluster, utils::server_to_parts},
runtime::{RefCount, Sender},
types::{
scan::{HScanResult, SScanResult, ScanResult, ZScanResult},
Key,
Map,
Value,
},
utils,
};
use async_trait::async_trait;
use bytes_utils::Str;
use rand::Rng;
use redis_protocol::{resp2::types::BytesFrame as Resp2Frame, resp3::types::BytesFrame as Resp3Frame};
use std::{
cmp::Ordering,
collections::{BTreeMap, BTreeSet, HashMap},
convert::TryInto,
fmt::{Display, Formatter},
hash::{Hash, Hasher},
net::{SocketAddr, ToSocketAddrs},
};
#[cfg(any(
feature = "enable-rustls",
feature = "enable-native-tls",
feature = "enable-rustls-ring"
))]
use crate::types::config::TlsHostMapping;
#[cfg(any(
feature = "enable-rustls",
feature = "enable-native-tls",
feature = "enable-rustls-ring"
))]
use std::{net::IpAddr, str::FromStr};
/// A frame decoded with either the RESP2 or the RESP3 protocol.
#[derive(Debug, Clone)]
pub enum ProtocolFrame {
/// A RESP2 frame.
Resp2(Resp2Frame),
/// A RESP3 frame.
Resp3(Resp3Frame),
}
impl ProtocolFrame {
  /// Consume the frame and return it as a RESP3 frame.
  ///
  /// RESP3 frames are returned as-is; RESP2 frames are upgraded with their `into_resp3` conversion.
  pub fn into_resp3(self) -> Resp3Frame {
    match self {
      ProtocolFrame::Resp3(resp3) => resp3,
      ProtocolFrame::Resp2(resp2) => resp2.into_resp3(),
    }
  }

  /// Whether this is the `Resp3` variant.
  pub fn is_resp3(&self) -> bool {
    match self {
      ProtocolFrame::Resp3(_) => true,
      ProtocolFrame::Resp2(_) => false,
    }
  }
}
impl From<Resp2Frame> for ProtocolFrame {
fn from(frame: Resp2Frame) -> Self {
ProtocolFrame::Resp2(frame)
}
}
impl From<Resp3Frame> for ProtocolFrame {
fn from(frame: Resp3Frame) -> Self {
ProtocolFrame::Resp3(frame)
}
}
/// A server address: a host and a port, plus an optional TLS server name when a TLS feature is enabled.
///
/// Equality, hashing, and ordering for this type consider only `host` and `port`.
#[derive(Debug, Clone)]
pub struct Server {
/// The host name or IP address. When built from a unix socket path (`unix-sockets` feature),
/// this holds the path and `port` is 0.
pub host: Str,
/// The port number.
pub port: u16,
/// An optional TLS server name for this server. `set_tls_server_name` fills it in from a
/// `TlsHostMapping` policy when `host` parses as an IP address.
#[cfg(any(
feature = "enable-rustls",
feature = "enable-native-tls",
feature = "enable-rustls-ring"
))]
#[cfg_attr(
docsrs,
doc(cfg(any(
feature = "enable-rustls",
feature = "enable-native-tls",
feature = "enable-rustls-ring"
)))
)]
pub tls_server_name: Option<Str>,
}
impl Server {
  /// Create a server with an explicit, optional TLS server name.
  #[cfg(any(
    feature = "enable-rustls",
    feature = "enable-native-tls",
    feature = "enable-rustls-ring"
  ))]
  #[cfg_attr(
    docsrs,
    doc(cfg(any(
      feature = "enable-rustls",
      feature = "enable-native-tls",
      feature = "enable-rustls-ring"
    )))
  )]
  pub fn new_with_tls<S: Into<Str>>(host: S, port: u16, tls_server_name: Option<String>) -> Self {
    let host: Str = host.into();
    let tls_server_name: Option<Str> = tls_server_name.map(Into::into);
    Server {
      host,
      port,
      tls_server_name,
    }
  }

  /// Create a server from a host and port, with no TLS server name.
  pub fn new<S: Into<Str>>(host: S, port: u16) -> Self {
    let host: Str = host.into();
    Server {
      host,
      port,
      #[cfg(any(
        feature = "enable-rustls",
        feature = "enable-native-tls",
        feature = "enable-rustls-ring"
      ))]
      tls_server_name: None,
    }
  }

  /// Derive `tls_server_name` from `policy`.
  ///
  /// This does nothing when `policy` is `TlsHostMapping::None` or when `host` is not an IP address.
  /// Otherwise, if `policy.map` yields a name, it replaces the current TLS server name.
  #[cfg(any(
    feature = "enable-rustls",
    feature = "enable-native-tls",
    feature = "enable-rustls-ring"
  ))]
  pub(crate) fn set_tls_server_name(&mut self, policy: &TlsHostMapping, default_host: &str) {
    if *policy == TlsHostMapping::None {
      return;
    }

    if let Ok(ip) = IpAddr::from_str(&self.host) {
      if let Some(mapped_name) = policy.map(&ip, default_host) {
        self.tls_server_name = Some(Str::from(mapped_name));
      }
    }
  }

  /// Parse a `host:port` string (surrounding whitespace is trimmed).
  ///
  /// Returns `None` if there is no `:`, more than one `:`, or the port is not a valid `u16`.
  pub(crate) fn from_str(s: &str) -> Option<Server> {
    let (host, port_text) = s.trim().split_once(':')?;
    // Any additional `:` ends up in `port_text`, which then fails to parse as a `u16`.
    let port = port_text.parse::<u16>().ok()?;

    Some(Server {
      host: host.into(),
      port,
      #[cfg(any(
        feature = "enable-rustls",
        feature = "enable-native-tls",
        feature = "enable-rustls-ring"
      ))]
      tls_server_name: None,
    })
  }

  /// Split `server` into host and port with `server_to_parts`.
  ///
  /// An empty host is replaced by `default_host`. Returns `None` if `server_to_parts` fails.
  pub(crate) fn from_parts(server: &str, default_host: &str) -> Option<Server> {
    let (host, port) = server_to_parts(server).ok()?;
    let host = if host.is_empty() {
      Str::from(default_host)
    } else {
      Str::from(host)
    };

    Some(Server {
      host,
      port,
      #[cfg(any(
        feature = "enable-rustls",
        feature = "enable-native-tls",
        feature = "enable-rustls-ring"
      ))]
      tls_server_name: None,
    })
  }
}
#[cfg(feature = "unix-sockets")]
#[doc(hidden)]
impl From<&std::path::Path> for Server {
  /// Build a server from a unix socket path: the path becomes `host` and `port` is 0.
  fn from(path: &std::path::Path) -> Self {
    Server::new(utils::path_to_string(path), 0)
  }
}
impl TryFrom<String> for Server {
type Error = Error;
fn try_from(value: String) -> Result<Self, Self::Error> {
Server::from_str(&value).ok_or(Error::new(ErrorKind::Config, "Invalid `host:port` server."))
}
}
impl TryFrom<&str> for Server {
type Error = Error;
fn try_from(value: &str) -> Result<Self, Self::Error> {
Server::from_str(value).ok_or(Error::new(ErrorKind::Config, "Invalid `host:port` server."))
}
}
impl From<(String, u16)> for Server {
fn from((host, port): (String, u16)) -> Self {
Server {
host: host.into(),
port,
#[cfg(any(
feature = "enable-native-tls",
feature = "enable-rustls",
feature = "enable-rustls-ring"
))]
tls_server_name: None,
}
}
}
impl From<(&str, u16)> for Server {
fn from((host, port): (&str, u16)) -> Self {
Server {
host: host.into(),
port,
#[cfg(any(
feature = "enable-native-tls",
feature = "enable-rustls",
feature = "enable-rustls-ring"
))]
tls_server_name: None,
}
}
}
impl From<&Server> for Server {
fn from(value: &Server) -> Self {
value.clone()
}
}
impl Display for Server {
  /// Format as `host:port`. Any TLS server name is not included.
  fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
    write!(f, "{host}:{port}", host = self.host, port = self.port)
  }
}
impl PartialEq for Server {
  /// Servers are equal when host and port match; the TLS server name is ignored.
  fn eq(&self, other: &Self) -> bool {
    (&self.host, self.port) == (&other.host, other.port)
  }
}

impl Eq for Server {}
impl Hash for Server {
fn hash<H: Hasher>(&self, state: &mut H) {
self.host.hash(state);
self.port.hash(state);
}
}
impl PartialOrd for Server {
  /// Delegate to the total order so `partial_cmp` and `cmp` can never disagree.
  fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
    Some(Ord::cmp(self, other))
  }
}
impl Ord for Server {
  /// Order by host first, then by port. The TLS server name does not participate.
  fn cmp(&self, other: &Self) -> Ordering {
    self
      .host
      .cmp(&other.host)
      .then_with(|| self.port.cmp(&other.port))
  }
}
/// The kind of a pubsub message, parsed from the type names `message`, `pmessage`, and `smessage`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MessageKind {
/// Parsed from the `message` type name.
Message,
/// Parsed from the `pmessage` type name (presumably a pattern subscription — confirm).
PMessage,
/// Parsed from the `smessage` type name (presumably a sharded subscription — confirm).
SMessage,
}
impl MessageKind {
  /// Map a message type name to its `MessageKind`.
  ///
  /// Matching is exact and case-sensitive; any other input yields `None`.
  pub(crate) fn from_str(s: &str) -> Option<MessageKind> {
    match s {
      "message" => Some(MessageKind::Message),
      "pmessage" => Some(MessageKind::PMessage),
      "smessage" => Some(MessageKind::SMessage),
      _ => None,
    }
  }
}
/// A pubsub message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Message {
/// The channel the message is associated with.
pub channel: Str,
/// The message payload.
pub value: Value,
/// Which kind of pubsub message this is.
pub kind: MessageKind,
/// The server associated with this message (presumably the one it was received from — confirm with callers).
pub server: Server,
}
/// State for an in-progress key scan that delivers each page as a `ScanResult` over `tx`.
pub struct KeyScanInner {
/// Optional hash slot for the scan (presumably used for cluster routing — confirm).
pub hash_slot: Option<u16>,
/// Optional specific server for the scan.
pub server: Option<Server>,
/// Index into `args` of the cursor argument; `update_cursor` overwrites that entry.
pub cursor_idx: usize,
/// The command arguments, including the cursor at `cursor_idx`.
pub args: Vec<Value>,
/// Channel for scan pages, or for an error via `send_error`.
pub tx: Sender<Result<ScanResult, Error>>,
}
/// State for an in-progress key scan that delivers individual keys (rather than pages) over `tx`.
pub struct KeyScanBufferedInner {
/// Optional hash slot for the scan (presumably used for cluster routing — confirm).
pub hash_slot: Option<u16>,
/// Optional specific server for the scan.
pub server: Option<Server>,
/// Index into `args` of the cursor argument; `update_cursor` overwrites that entry.
pub cursor_idx: usize,
/// The command arguments, including the cursor at `cursor_idx`.
pub args: Vec<Value>,
/// Channel for scanned keys, or for an error via `send_error`.
pub tx: Sender<Result<Key, Error>>,
}
impl KeyScanInner {
  /// Store `cursor` as the cursor argument for the next request.
  ///
  /// Panics if `cursor_idx` is out of bounds for `args`.
  pub fn update_cursor(&mut self, cursor: Str) {
    let cursor_slot = &mut self.args[self.cursor_idx];
    *cursor_slot = cursor.into();
  }

  /// Try to deliver `error` to the receiver without waiting.
  ///
  /// This is best-effort: the outcome of `try_send` is deliberately ignored.
  pub fn send_error(&self, error: Error) {
    let _ = self.tx.try_send(Err(error));
  }
}
impl KeyScanBufferedInner {
  /// Store `cursor` as the cursor argument for the next request.
  ///
  /// Panics if `cursor_idx` is out of bounds for `args`.
  pub fn update_cursor(&mut self, cursor: Str) {
    let cursor_slot = &mut self.args[self.cursor_idx];
    *cursor_slot = cursor.into();
  }

  /// Try to deliver `error` to the receiver without waiting.
  ///
  /// This is best-effort: the outcome of `try_send` is deliberately ignored.
  pub fn send_error(&self, error: Error) {
    let _ = self.tx.try_send(Err(error));
  }
}
/// One page of results from a value scan, tagged by the scan command that produced it.
pub enum ValueScanResult {
/// A set scan page.
SScan(SScanResult),
/// A hash scan page.
HScan(HScanResult),
/// A sorted set scan page.
ZScan(ZScanResult),
}
/// State for an in-progress SSCAN/HSCAN/ZSCAN that delivers pages as `ValueScanResult`s over `tx`.
pub struct ValueScanInner {
/// Index into `args` of the cursor argument; `update_cursor` overwrites that entry.
pub cursor_idx: usize,
/// The command arguments, including the cursor at `cursor_idx`.
pub args: Vec<Value>,
/// Channel for scan pages, or for an error via `send_error`.
pub tx: Sender<Result<ValueScanResult, Error>>,
}
impl ValueScanInner {
  /// Store `cursor` as the cursor argument for the next request.
  ///
  /// Panics if `cursor_idx` is out of bounds for `args`.
  pub fn update_cursor(&mut self, cursor: Str) {
    self.args[self.cursor_idx] = cursor.into();
  }

  /// Try to deliver `error` to the receiver without waiting.
  ///
  /// This is best-effort: the outcome of `try_send` is deliberately ignored.
  pub fn send_error(&self, error: Error) {
    let _ = self.tx.try_send(Err(error));
  }

  /// Convert a flat `[field, value, field, value, ...]` HSCAN page into a `Map`.
  ///
  /// # Errors
  /// Returns an `ErrorKind::Protocol` error if:
  /// - the array has an odd number of elements, or
  /// - a field is not a `String` or `Bytes` value.
  ///
  /// Also propagates any error from the final `HashMap` -> `Map` conversion.
  pub fn transform_hscan_result(mut data: Vec<Value>) -> Result<Map, Error> {
    if data.is_empty() {
      return Ok(Map::new());
    }
    if data.len() % 2 != 0 {
      return Err(Error::new(
        ErrorKind::Protocol,
        "Invalid HSCAN result. Expected array with an even number of elements.",
      ));
    }

    let mut out = HashMap::with_capacity(data.len() / 2);
    // Pairs are consumed from the back: each `pop` pair yields (value, field). The even-length
    // check above guarantees both pops succeed together, so no `unwrap` is needed.
    while let (Some(value), Some(field)) = (data.pop(), data.pop()) {
      let key: Key = match field {
        Value::String(s) => s.into(),
        Value::Bytes(b) => b.into(),
        _ => {
          return Err(Error::new(
            ErrorKind::Protocol,
            "Invalid HSCAN result. Expected string.",
          ))
        },
      };
      out.insert(key, value);
    }
    out.try_into()
  }

  /// Convert a flat `[member, score, member, score, ...]` ZSCAN page into `(member, score)` pairs.
  ///
  /// Scores may arrive as strings (parsed with `utils::string_to_f64`), integers, or doubles.
  ///
  /// # Errors
  /// Returns an `ErrorKind::Protocol` error if:
  /// - the array has an odd number of elements, or
  /// - a score has any other type.
  ///
  /// Also propagates string-to-float parse errors.
  pub fn transform_zscan_result(mut data: Vec<Value>) -> Result<Vec<(Value, f64)>, Error> {
    if data.is_empty() {
      return Ok(Vec::new());
    }
    if data.len() % 2 != 0 {
      return Err(Error::new(
        ErrorKind::Protocol,
        "Invalid ZSCAN result. Expected array with an even number of elements.",
      ));
    }

    let mut out = Vec::with_capacity(data.len() / 2);
    for chunk in data.chunks_exact_mut(2) {
      let value = chunk[0].take();
      let score = match chunk[1].take() {
        Value::String(s) => utils::string_to_f64(&s)?,
        // Integers beyond 2^53 lose precision here, as any f64 score would.
        Value::Integer(i) => i as f64,
        Value::Double(f) => f,
        _ => {
          // Previously mislabelled as an HSCAN error; this path only handles ZSCAN replies.
          return Err(Error::new(
            ErrorKind::Protocol,
            "Invalid ZSCAN result. Expected a string or number score.",
          ))
        },
      };
      out.push((value, score));
    }
    Ok(out)
  }
}
/// A contiguous range of cluster hash slots and the nodes that serve it.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SlotRange {
/// First slot in the range; `ClusterRouting` keeps its ranges sorted by this field.
pub start: u16,
/// Last slot in the range (presumably inclusive — confirm in `cluster::parse_cluster_slots`).
pub end: u16,
/// The primary node for this range; `ClusterRouting::get_server` routes to it.
pub primary: Server,
/// An identifier for the range's primary (presumably the cluster node ID — confirm).
pub id: Str,
/// Replica nodes listed for this range.
#[cfg(feature = "replicas")]
#[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
pub replicas: Vec<Server>,
}
/// A routing table that maps cluster hash slots to servers.
#[derive(Debug, Clone)]
pub struct ClusterRouting {
/// Slot ranges, sorted by `start` whenever they are (re)built so `get_server` can binary search them.
data: Vec<SlotRange>,
}
impl ClusterRouting {
  /// Create an empty routing table.
  pub fn new() -> Self {
    Self { data: Vec::new() }
  }

  /// Build a routing table from a `CLUSTER SLOTS` response.
  ///
  /// Ranges are parsed by `cluster::parse_cluster_slots` with `default_host`, then sorted by first slot.
  /// Unlike `rebuild`, hostnames are not post-processed.
  ///
  /// # Errors
  /// Propagates any error from parsing `value`.
  pub fn from_cluster_slots<S: Into<Str>>(value: Value, default_host: S) -> Result<Self, Error> {
    let default_host: Str = default_host.into();
    let mut ranges = cluster::parse_cluster_slots(value, &default_host)?;
    ranges.sort_by_key(|range| range.start);
    Ok(Self { data: ranges })
  }

  /// Return one hash slot per distinct primary, ordered by primary (host, then port).
  ///
  /// When a primary owns several ranges, the start of the last such range in table order is used.
  pub fn unique_hash_slots(&self) -> Vec<u16> {
    let mut slot_by_primary: BTreeMap<&Server, u16> = BTreeMap::new();
    // `extend` inserts one entry at a time, so later ranges overwrite earlier values for a primary.
    slot_by_primary.extend(self.data.iter().map(|range| (&range.primary, range.start)));
    slot_by_primary.into_values().collect()
  }

  /// Return the distinct primary nodes, ordered by host and then port.
  pub fn unique_primary_nodes(&self) -> Vec<Server> {
    let mut primaries = BTreeSet::new();
    // `extend` keeps the first of any equal servers, exactly like repeated `insert` calls.
    primaries.extend(self.data.iter().map(|range| range.primary.clone()));
    primaries.into_iter().collect()
  }

  /// Replace the table with freshly parsed `CLUSTER SLOTS` data.
  ///
  /// The new ranges are sorted by first slot, then passed with `inner` to
  /// `cluster::modify_cluster_slot_hostnames`. If parsing fails, the existing table is unchanged.
  pub(crate) fn rebuild(
    &mut self,
    inner: &RefCount<ClientInner>,
    cluster_slots: Value,
    default_host: &Str,
  ) -> Result<(), Error> {
    self.data = cluster::parse_cluster_slots(cluster_slots, default_host)?;
    self.data.sort_by_key(|range| range.start);
    cluster::modify_cluster_slot_hostnames(inner, &mut self.data, default_host);
    Ok(())
  }

  /// Compute the cluster hash slot for `key` using `redis_protocol::redis_keyslot`.
  pub fn hash_key(key: &[u8]) -> u16 {
    redis_protocol::redis_keyslot(key)
  }

  /// Find the primary for the range containing `slot`.
  ///
  /// Returns `None` when the table is empty or the search finds no matching range.
  pub fn get_server(&self, slot: u16) -> Option<&Server> {
    if self.data.is_empty() {
      return None;
    }
    let idx = protocol_utils::binary_search(&self.data, slot)?;
    Some(&self.data[idx].primary)
  }

  /// Collect the distinct replicas listed on every range whose primary equals `primary`.
  /// The result is ordered by host and then port.
  #[cfg(feature = "replicas")]
  #[cfg_attr(docsrs, doc(cfg(feature = "replicas")))]
  pub fn replicas(&self, primary: &Server) -> Vec<Server> {
    let mut found = BTreeSet::new();
    for range in self.data.iter().filter(|range| range.primary == *primary) {
      found.extend(range.replicas.iter().cloned());
    }
    found.into_iter().collect()
  }

  /// Number of slot ranges in the table (not the number of individual slots).
  pub fn len(&self) -> usize {
    self.data.len()
  }

  /// Borrow the slot ranges, sorted by first slot.
  pub fn slots(&self) -> &[SlotRange] {
    &self.data
  }

  /// Pick a slot range uniformly at random, or `None` if the table is empty.
  pub fn random_slot(&self) -> Option<&SlotRange> {
    if self.data.is_empty() {
      return None;
    }
    let idx = rand::thread_rng().gen_range(0 .. self.data.len());
    self.data.get(idx)
  }

  /// The primary of a randomly chosen slot range, or `None` if the table is empty.
  pub fn random_node(&self) -> Option<&Server> {
    self.random_slot().map(|range| &range.primary)
  }

  /// Summarize the table per primary.
  ///
  /// Each entry holds the primary's `(start, end)` ranges in table order and its replica set.
  /// The replica set is only filled when the `replicas` feature is enabled.
  pub fn pretty(&self) -> BTreeMap<Server, (Vec<(u16, u16)>, BTreeSet<Server>)> {
    let mut summary: BTreeMap<Server, (Vec<(u16, u16)>, BTreeSet<Server>)> = BTreeMap::new();
    for range in &self.data {
      let entry = summary.entry(range.primary.clone()).or_default();
      entry.0.push((range.start, range.end));
      #[cfg(feature = "replicas")]
      entry.1.extend(range.replicas.iter().cloned());
    }
    summary
  }
}
/// The built-in resolver, which looks up `host:port` with the standard library's `ToSocketAddrs`.
#[derive(Clone, Debug)]
pub struct DefaultResolver {
/// Client ID, included in the resolver's trace log lines.
id: Str,
}
impl DefaultResolver {
pub fn new(id: &Str) -> Self {
DefaultResolver { id: id.clone() }
}
}
/// Resolves a host and port to socket addresses.
///
/// This variant, used with the `glommio` feature, does not require `Send` futures or `Send + Sync` implementors.
#[cfg(feature = "glommio")]
#[async_trait(?Send)]
#[cfg_attr(docsrs, doc(cfg(feature = "dns")))]
pub trait Resolve: 'static {
/// Resolve `host` and `port` into the socket addresses to try.
async fn resolve(&self, host: Str, port: u16) -> FredResult<Vec<SocketAddr>>;
}
#[cfg(feature = "glommio")]
#[async_trait(?Send)]
impl Resolve for DefaultResolver {
/// Resolve `host:port` with `ToSocketAddrs` inside a task started by `crate::runtime::spawn`.
///
/// Returns an `ErrorKind::IO` error when no addresses are found. Lookup errors and errors from
/// awaiting the spawned task are propagated with `?`.
async fn resolve(&self, host: Str, port: u16) -> FredResult<Vec<SocketAddr>> {
let client_id = self.id.clone();
// NOTE(review): `to_socket_addrs` may perform a blocking DNS lookup. Under glommio this task
// presumably runs on the same thread as the executor — confirm this cannot stall other tasks.
crate::runtime::spawn(async move {
let addr = format!("{}:{}", host, port);
let ips: Vec<SocketAddr> = addr.to_socket_addrs()?.collect();
if ips.is_empty() {
Err(Error::new(
ErrorKind::IO,
format!("Failed to resolve {}:{}", host, port),
))
} else {
trace!("{}: Found {} addresses for {}", client_id, ips.len(), addr);
Ok(ips)
}
})
.await?
}
}
/// Resolves a host and port to socket addresses.
///
/// This variant, used when the `glommio` feature is disabled, requires `Send + Sync` implementors.
#[cfg(not(feature = "glommio"))]
#[async_trait]
#[cfg_attr(docsrs, doc(cfg(feature = "dns")))]
pub trait Resolve: Send + Sync + 'static {
/// Resolve `host` and `port` into the socket addresses to try.
async fn resolve(&self, host: Str, port: u16) -> FredResult<Vec<SocketAddr>>;
}
#[cfg(not(feature = "glommio"))]
#[async_trait]
impl Resolve for DefaultResolver {
  /// Resolve `host:port` with `ToSocketAddrs` on tokio's blocking thread pool.
  ///
  /// Returns an `ErrorKind::IO` error when no addresses are found. Lookup errors and join errors
  /// from the blocking task are propagated with `?`.
  async fn resolve(&self, host: Str, port: u16) -> FredResult<Vec<SocketAddr>> {
    let client_id = self.id.clone();

    // `to_socket_addrs` may block on DNS, so it runs via `spawn_blocking` instead of on a worker.
    let lookup = move || {
      let target = format!("{}:{}", host, port);
      let resolved: Vec<SocketAddr> = target.to_socket_addrs()?.collect();

      if resolved.is_empty() {
        Err(Error::new(
          ErrorKind::IO,
          format!("Failed to resolve {}:{}", host, port),
        ))
      } else {
        trace!("{}: Found {} addresses for {}", client_id, resolved.len(), target);
        Ok(resolved)
      }
    };

    tokio::task::spawn_blocking(lookup).await?
  }
}