use rand::Rng;
use strum_macros::Display;
use crate::cluster::topology::get_slot;
use crate::cmd::{Arg, Cmd};
use crate::value::Value;
use crate::value::{ErrorKind, Error, Result};
use core::cmp::Ordering;
use std::borrow::Cow;
use std::cmp::min;
use std::collections::HashMap;
use std::iter::Once;
use std::sync::Arc;
use std::sync::{RwLock, RwLockWriteGuard};
/// A cluster redirection target.
///
/// NOTE(review): not constructed in this chunk; presumably built from `MOVED` / `ASK` error
/// replies with the `String` holding the target node address — confirm. The meaning of the
/// `bool` carried by `Ask` is not visible here.
#[derive(Clone)]
pub(crate) enum Redirect {
Moved(String),
Ask(String, bool),
}
/// Element-wise logical operator applied by `logical_aggregate`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LogicalAggregateOp {
/// Logical AND; an integer reply counts as `true` when it is greater than zero.
And,
}
/// Scalar reduction applied by `aggregate` to integer replies.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AggregateOp {
/// Smallest reply; an empty input yields `i64::MAX`.
Min,
/// Sum of all replies; an empty input yields `0`.
Sum,
}
/// Element-wise reduction applied by `aggregate_array` to arrays of integer replies.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ArrayAggregateOp {
/// Element-wise minimum.
Min,
}
/// How the replies of a request sent to several nodes are merged into a single reply.
///
/// Selected per command by `ResponsePolicy::for_command`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ResponsePolicy {
/// Used for `FUNCTION KILL` and `SCRIPT KILL`; presumably succeeds if any node succeeds.
OneSucceeded,
/// Used for `RANDOMKEY`; presumably the first non-empty reply wins.
FirstSucceededNonEmptyOrAllEmpty,
/// Used for commands such as `FLUSHALL`, `CONFIG SET` and `MSET`; presumably every node
/// must succeed.
AllSucceeded,
/// Element-wise logical combination of array replies (see `logical_aggregate`).
AggregateLogical(LogicalAggregateOp),
/// Scalar combination of integer replies (see `aggregate`).
Aggregate(AggregateOp),
/// Element-wise combination of integer-array replies (see `aggregate_array`).
AggregateArray(ArrayAggregateOp),
/// Concatenation of array replies (see `combine_array_results`).
CombineArrays,
/// No generic merge; used for `INFO`, `FUNCTION STATS` and several `LATENCY` / `MEMORY`
/// subcommands.
Special,
/// Per-key summation of key/count replies (see `combine_map_results`).
CombineMaps,
}
/// Where a command should be sent, as computed by `RoutingInfo::for_routable`.
#[derive(Debug, Clone, PartialEq)]
pub enum RoutingInfo {
/// Send to exactly one node.
SingleNode(SingleNodeRoutingInfo),
/// Send to several nodes; the optional policy says how their replies are merged.
MultiNode((MultipleNodeRoutingInfo, Option<ResponsePolicy>)),
}
/// Selection of a single target node.
#[derive(Debug, Clone, PartialEq)]
pub enum SingleNodeRoutingInfo {
/// Any node; produced for `RouteBy::Random` commands, for a key count of 0, and when a
/// first-key command has no key argument.
Random,
/// Presumably any primary node; not produced in this chunk.
RandomPrimary,
/// The node selected by the given slot and `SlotAddr`.
SpecificNode(Route),
/// An explicitly addressed node; not produced in this chunk.
ByAddress {
host: String,
port: u16,
},
}
impl From<Option<Route>> for SingleNodeRoutingInfo {
    /// Targets the node for `value` when a route is given, otherwise any random node.
    fn from(value: Option<Route>) -> Self {
        match value {
            Some(route) => SingleNodeRoutingInfo::SpecificNode(route),
            None => SingleNodeRoutingInfo::Random,
        }
    }
}
/// Selection of several target nodes.
#[derive(Debug, Clone, PartialEq)]
pub enum MultipleNodeRoutingInfo {
/// Chosen for `RouteBy::AllNodes` commands; presumably every node, primaries and replicas.
AllNodes,
/// Chosen for `RouteBy::AllPrimaries` commands; presumably every primary node.
AllMasters,
/// One sub-request per route. Each route carries the argument indices (relative to the first
/// key argument) that belong to it, and the pattern describes how those arguments are grouped.
MultiSlot((Vec<(Route, Vec<usize>)>, MultiSlotArgPattern)),
}
/// Builds the per-slot sub-command for a multi-slot request.
///
/// The new command starts with `original_cmd`'s name (argument 0), followed by the arguments
/// at `indices`. Those indices are relative to the first argument after the command name, so
/// each one is shifted by one before being looked up in `original_cmd`.
pub fn command_for_multi_slot_indices<'a, 'b>(
    original_cmd: &'a impl Routable,
    indices: impl Iterator<Item = &'b usize> + 'a,
) -> Cmd
where
    'b: 'a,
{
    // Argument 0 is the command name; routed indices start right after it.
    const COMMAND_NAME_OFFSET: usize = 1;
    let mut sub_cmd = Cmd::new();
    sub_cmd.arg(original_cmd.arg_idx(0));
    indices.for_each(|idx| {
        sub_cmd.arg(original_cmd.arg_idx(idx + COMMAND_NAME_OFFSET));
    });
    sub_cmd
}
/// Reduces integer replies from several nodes to a single integer.
///
/// `AggregateOp::Min` starts from `i64::MAX` and `AggregateOp::Sum` from `0`, so an empty
/// input yields that starting value.
///
/// # Errors
/// Returns a `TypeError` as soon as a reply is not a `Value::Int`.
pub fn aggregate(values: Vec<Value>, op: AggregateOp) -> Result<Value> {
    let mut total = match op {
        AggregateOp::Min => i64::MAX,
        AggregateOp::Sum => 0,
    };
    for reply in values {
        let Value::Int(int) = reply else {
            return Err((
                ErrorKind::TypeError,
                "expected array of integers as response",
            )
                .into());
        };
        total = match op {
            AggregateOp::Min => min(total, int),
            AggregateOp::Sum => total + int,
        };
    }
    Ok(Value::Int(total))
}
/// Combines per-node array replies element-wise with a logical operator.
///
/// Every reply must be a `Value::Array` of `Value::Int`s; an integer counts as `true` when it
/// is greater than zero. The result is a `Value::Array` of `0` / `1` integers, one per position.
/// An empty `values` yields an empty array.
///
/// # Errors
/// Returns a `TypeError` if a reply is not an array, if an element is not an integer, or if
/// the replies do not all have the same length (previously a longer reply panicked with an
/// index out of bounds, and a shorter one was silently accepted).
pub fn logical_aggregate(values: Vec<Value>, op: LogicalAggregateOp) -> Result<Value> {
    let initial_value = match op {
        LogicalAggregateOp::And => true,
    };
    // `None` until the first reply fixes the expected length.
    let mut merged: Option<Vec<bool>> = None;
    for curr in values {
        let Value::Array(elements) = curr else {
            return Err((
                ErrorKind::TypeError,
                "expected array of integers as response",
            )
                .into());
        };
        let acc = merged.get_or_insert_with(|| vec![initial_value; elements.len()]);
        // Every node answered the same request, so the replies must line up position by position.
        if acc.len() != elements.len() {
            return Err((
                ErrorKind::TypeError,
                "expected arrays of equal length as response",
            )
                .into());
        }
        for (cell, element) in acc.iter_mut().zip(elements) {
            let Ok(Value::Int(int)) = element else {
                return Err((
                    ErrorKind::TypeError,
                    "expected array of integers as response",
                )
                    .into());
            };
            *cell = match op {
                LogicalAggregateOp::And => *cell && (int > 0),
            };
        }
    }
    Ok(Value::Array(
        merged
            .unwrap_or_default()
            .into_iter()
            .map(|flag| Ok(Value::Int(i64::from(flag))))
            .collect(),
    ))
}
/// Combines per-node integer-array replies element-wise.
///
/// Every reply must be a `Value::Array` of `Value::Int`s. With `ArrayAggregateOp::Min`, each
/// position of the result is the minimum of that position across all replies. An empty
/// `values` yields an empty array.
///
/// # Errors
/// Returns a `TypeError` if a reply is not an array, if an element is not an integer, or if
/// the replies do not all have the same length (previously a longer reply panicked with an
/// index out of bounds, and a shorter one was silently accepted).
pub fn aggregate_array(values: Vec<Value>, op: ArrayAggregateOp) -> Result<Value> {
    let initial_value = match op {
        ArrayAggregateOp::Min => i64::MAX,
    };
    // `None` until the first reply fixes the expected length.
    let mut merged: Option<Vec<i64>> = None;
    for curr in values {
        let Value::Array(elements) = curr else {
            return Err((
                ErrorKind::TypeError,
                "expected array of integers as response",
            )
                .into());
        };
        let acc = merged.get_or_insert_with(|| vec![initial_value; elements.len()]);
        // Every node answered the same request, so the replies must line up position by position.
        if acc.len() != elements.len() {
            return Err((
                ErrorKind::TypeError,
                "expected arrays of equal length as response",
            )
                .into());
        }
        for (cell, element) in acc.iter_mut().zip(elements) {
            let Ok(Value::Int(int)) = element else {
                return Err((
                    ErrorKind::TypeError,
                    "expected array of integers as response",
                )
                    .into());
            };
            *cell = match op {
                ArrayAggregateOp::Min => min(*cell, int),
            };
        }
    }
    Ok(Value::Array(
        merged
            .unwrap_or_default()
            .into_iter()
            .map(|i| Ok(Value::Int(i)))
            .collect(),
    ))
}
/// Merges flat key/count array replies from several nodes by summing the counts per key.
///
/// Each reply must be a `Value::Array` laid out as `key, count, key, count, ...`, with keys as
/// `Value::BulkString` and counts as `Value::Int`. The result is a `Value::Map`. Keys appear in
/// the order they are first seen, so the output is deterministic rather than following
/// `HashMap` iteration order.
///
/// # Errors
/// - Returns a `TypeError` if a reply is not an array, a key is not a bulk string, or a key
///   has no integer count after it.
/// - An error element embedded in a reply is returned as-is. Previously it silently ended the
///   scan and dropped the rest of that node's pairs.
pub fn combine_map_results(values: Vec<Value>) -> Result<Value> {
    // `positions` maps each key to its entry in `totals`.
    let mut positions: HashMap<bytes::Bytes, usize> = HashMap::new();
    let mut totals: Vec<(bytes::Bytes, i64)> = Vec::new();
    for value in values {
        let Value::Array(elements) = value else {
            return Err((ErrorKind::TypeError, "expected array of values as response").into());
        };
        let mut iter = elements.into_iter();
        while let Some(key) = iter.next() {
            let Value::BulkString(key_bytes) = key? else {
                return Err((ErrorKind::TypeError, "expected string key").into());
            };
            let Some(Ok(Value::Int(count))) = iter.next() else {
                return Err((ErrorKind::TypeError, "expected integer value").into());
            };
            match positions.get(&key_bytes) {
                Some(&idx) => totals[idx].1 += count,
                None => {
                    positions.insert(key_bytes.clone(), totals.len());
                    totals.push((key_bytes, count));
                }
            }
        }
    }
    Ok(Value::Map(
        totals
            .into_iter()
            .map(|(key, count)| (Value::BulkString(key), Value::Int(count)))
            .collect(),
    ))
}
/// Concatenates array replies from several nodes, in the order the replies are given.
///
/// # Errors
/// Returns a `TypeError` as soon as a reply is not a `Value::Array`.
pub fn combine_array_results(values: Vec<Value>) -> Result<Value> {
    let mut combined = Vec::new();
    for reply in values {
        let Value::Array(items) = reply else {
            return Err((ErrorKind::TypeError, "expected array of values as response").into());
        };
        combined.extend(items);
    }
    Ok(Value::Array(combined))
}
/// Iterator returned by `calculate_multi_slot_result_indices`: one index list per route.
type MultiSlotResIdxIter<'a> = std::iter::Map<
    std::slice::Iter<'a, (Route, Vec<usize>)>,
    fn(&'a (Route, Vec<usize>)) -> Cow<'a, [usize]>,
>;

/// For each `(route, arg_indices)` entry, yields the positions in the combined reply at which
/// that route's results belong.
///
/// `arg_indices` index the arguments after the command name, as built by `multi_shard`. Only
/// key arguments produce a result, so the indices are mapped according to `args_pattern`:
/// - `KeysOnly`: every index is a key and is returned unchanged.
/// - `KeysAndLastArg`: the trailing shared argument (the path for `JSON.MGET`) is dropped. An
///   entry without indices yields an empty list instead of underflowing and panicking.
/// - `KeyValuePairs` / `KeyWithTwoArgTriples`: only the first index of each tuple is a key;
///   dividing it by the tuple size gives the key's ordinal.
///
/// # Errors
/// Returns a `ClientError` if, for a tuple pattern, an entry's index count is not a multiple of
/// the tuple size.
fn calculate_multi_slot_result_indices<'a>(
    route_arg_indices: &'a [(Route, Vec<usize>)],
    args_pattern: &MultiSlotArgPattern,
) -> Result<MultiSlotResIdxIter<'a>> {
    let check_indices_input = |step_count: usize| {
        for (_, indices) in route_arg_indices {
            if indices.len() % step_count != 0 {
                return Err(Error::from((
                    ErrorKind::ClientError,
                    "Invalid indices input detected",
                    format!(
                        "Expected argument pattern with tuples of size {step_count}, but found indices: {indices:?}"
                    ),
                )));
            }
        }
        Ok(())
    };
    match args_pattern {
        MultiSlotArgPattern::KeysOnly => Ok(route_arg_indices
            .iter()
            .map(|(_, indices)| Cow::Borrowed(indices))),
        MultiSlotArgPattern::KeysAndLastArg => Ok(route_arg_indices
            .iter()
            .map(|(_, indices)| Cow::Borrowed(&indices[..indices.len().saturating_sub(1)]))),
        MultiSlotArgPattern::KeyWithTwoArgTriples => {
            check_indices_input(3)?;
            Ok(route_arg_indices.iter().map(|(_, indices)| {
                Cow::Owned(
                    indices
                        .iter()
                        .step_by(3)
                        .map(|idx| idx / 3)
                        .collect::<Vec<usize>>(),
                )
            }))
        }
        MultiSlotArgPattern::KeyValuePairs => {
            check_indices_input(2)?;
            Ok(route_arg_indices.iter().map(|(_, indices)| {
                Cow::Owned(
                    indices
                        .iter()
                        .step_by(2)
                        .map(|idx| idx / 2)
                        .collect::<Vec<usize>>(),
                )
            }))
        }
    }
}
/// Reassembles the per-route array replies of a multi-slot request into one array ordered
/// like the original command's keys.
///
/// `values[i]` must be the reply for `route_arg_indices[i]`. The result positions are computed
/// by `calculate_multi_slot_result_indices`.
///
/// # Errors
/// - `ClientError` if the number of replies differs from the number of routes.
/// - `ClientError` if a computed position falls outside the combined result. Previously this
///   panicked, including when a non-array reply shrank the preallocated buffer.
/// - `TypeError` if a reply is not an array.
pub(crate) fn combine_and_sort_array_results(
    values: Vec<Value>,
    route_arg_indices: &[(Route, Vec<usize>)],
    args_pattern: &MultiSlotArgPattern,
) -> Result<Value> {
    let result_indices = calculate_multi_slot_result_indices(route_arg_indices, args_pattern)?;
    if values.len() != result_indices.len() {
        return Err(Error::from((
            ErrorKind::ClientError,
            "Mismatch in the number of multi-slot results compared to the expected result count.",
            format!(
                "Expected: {:?}, Found: {:?}",
                result_indices.len(),
                values.len()
            ),
        )));
    }
    // Non-array replies contribute nothing to the size; they are rejected in the loop below.
    let total_len: usize = values
        .iter()
        .map(|value| match value {
            Value::Array(items) => items.len(),
            _ => 0,
        })
        .sum();
    let mut results: Vec<Result<Value>> = Vec::new();
    results.resize(total_len, Ok(Value::Nil));
    for (key_indices, value) in result_indices.zip(values) {
        let Value::Array(items) = value else {
            return Err((ErrorKind::TypeError, "expected array of values as response").into());
        };
        debug_assert_eq!(items.len(), key_indices.len());
        for (&index, item) in key_indices.iter().zip(items) {
            let Some(cell) = results.get_mut(index) else {
                return Err(Error::from((
                    ErrorKind::ClientError,
                    "Multi-slot result index out of range",
                    format!("Index {index} is out of range for {total_len} combined results"),
                )));
            };
            *cell = item;
        }
    }
    Ok(Value::Array(results))
}
/// Routes `key` to its hash slot.
///
/// Read-only commands may be served by a replica (`SlotAddr::ReplicaOptional`); everything
/// else goes to the slot's primary (`SlotAddr::Master`).
fn get_route(is_readonly: bool, key: &[u8]) -> Route {
    let slot_addr = if is_readonly {
        SlotAddr::ReplicaOptional
    } else {
        SlotAddr::Master
    };
    Route::new(get_slot(key), slot_addr)
}
/// Layout of a multi-key command's arguments, used to split it per slot.
#[derive(Debug, Clone, PartialEq)]
pub enum MultiSlotArgPattern {
/// Every argument is a key (e.g. `MGET`, `DEL`).
KeysOnly,
/// Alternating key/value arguments (e.g. `MSET`).
KeyValuePairs,
/// Keys followed by one trailing argument shared by all of them (e.g. `JSON.MGET`).
KeysAndLastArg,
/// Repeating triples of a key and two further arguments (e.g. `JSON.MSET`).
KeyWithTwoArgTriples,
}
/// Splits a multi-key command into per-slot groups of argument indices.
///
/// Arguments are read starting at `first_key_index`. The recorded indices are relative to that
/// position, and `args_pattern` says which arguments travel with each key.
///
/// Returns:
/// - `None` when there are no keys, or when a key/value (or key/arg/arg) tuple is incomplete;
/// - a single-node route when every key maps to the same slot;
/// - otherwise a `MultiSlot` route with the command's response policy.
fn multi_shard<R>(
    routable: &R,
    cmd: &[u8],
    first_key_index: usize,
    args_pattern: MultiSlotArgPattern,
) -> Option<RoutingInfo>
where
    R: Routable + ?Sized,
{
    let is_readonly = is_readonly_cmd(cmd);
    let mut routes: HashMap<Route, Vec<usize>> = HashMap::new();
    // Index of the current argument, relative to `first_key_index`.
    let mut curr_arg_idx = 0;
    // Attaches the argument after `curr_arg_idx` to the current key's group. Returns `None` when
    // that argument is missing, i.e. the command has an incomplete tuple.
    let incr_add_next_arg = |arg_indices: &mut Vec<usize>, mut curr_arg_idx: usize| {
        curr_arg_idx += 1;
        // The index is relative, so it must be offset before the existence check. Without the
        // offset the check looked at an earlier argument and always succeeded.
        routable.arg_idx(first_key_index + curr_arg_idx)?;
        arg_indices.push(curr_arg_idx);
        Some(curr_arg_idx)
    };
    while let Some(key) = routable.arg_idx(first_key_index + curr_arg_idx) {
        let arg_indices = routes.entry(get_route(is_readonly, key)).or_default();
        arg_indices.push(curr_arg_idx);
        match args_pattern {
            MultiSlotArgPattern::KeysOnly => {}
            MultiSlotArgPattern::KeyValuePairs => {
                curr_arg_idx = incr_add_next_arg(arg_indices, curr_arg_idx)?;
            }
            MultiSlotArgPattern::KeysAndLastArg => {
                // Once only the shared trailing argument remains, every group needs it.
                if routable
                    .arg_idx(first_key_index + curr_arg_idx + 2)
                    .is_none()
                {
                    let path_idx = curr_arg_idx + 1;
                    for group in routes.values_mut() {
                        group.push(path_idx);
                    }
                    break;
                }
            }
            MultiSlotArgPattern::KeyWithTwoArgTriples => {
                curr_arg_idx = incr_add_next_arg(arg_indices, curr_arg_idx)?;
                curr_arg_idx = incr_add_next_arg(arg_indices, curr_arg_idx)?;
            }
        }
        curr_arg_idx += 1;
    }
    let mut routes: Vec<(Route, Vec<usize>)> = routes.into_iter().collect();
    match routes.len() {
        0 => None,
        1 => routes.pop().map(|(route, _)| {
            RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))
        }),
        _ => Some(RoutingInfo::MultiNode((
            MultipleNodeRoutingInfo::MultiSlot((routes, args_pattern)),
            ResponsePolicy::for_command(cmd),
        ))),
    }
}
impl ResponsePolicy {
    /// Returns how the per-node replies of a fanned-out `cmd` are merged, or `None` when the
    /// command has no dedicated policy.
    ///
    /// `cmd` is matched verbatim; callers in this module pass the normalized name returned by
    /// `Routable::command` (upper-case, a single space before any subcommand).
    pub fn for_command(cmd: &[u8]) -> Option<ResponsePolicy> {
        let policy = match cmd {
            b"SCRIPT EXISTS" => ResponsePolicy::AggregateLogical(LogicalAggregateOp::And),
            b"DBSIZE" | b"DEL" | b"EXISTS" | b"SLOWLOG LEN" | b"TOUCH" | b"UNLINK"
            | b"LATENCY RESET" | b"PUBSUB NUMPAT" => ResponsePolicy::Aggregate(AggregateOp::Sum),
            b"WAIT" => ResponsePolicy::Aggregate(AggregateOp::Min),
            b"WAITAOF" => ResponsePolicy::AggregateArray(ArrayAggregateOp::Min),
            b"ACL SETUSER" | b"ACL DELUSER" | b"ACL SAVE" | b"AUTH" | b"CLIENT SETNAME"
            | b"CLIENT SETINFO" | b"CONFIG SET" | b"CONFIG RESETSTAT" | b"CONFIG REWRITE"
            | b"FLUSHALL" | b"FLUSHDB" | b"FUNCTION DELETE" | b"FUNCTION FLUSH"
            | b"FUNCTION LOAD" | b"FUNCTION RESTORE" | b"MEMORY PURGE" | b"MSET" | b"JSON.MSET"
            | b"PING" | b"SCRIPT FLUSH" | b"SCRIPT LOAD" | b"SELECT" | b"SLOWLOG RESET"
            | b"UNWATCH" | b"WATCH" => ResponsePolicy::AllSucceeded,
            b"KEYS" | b"FT._ALIASLIST" | b"FT._LIST" | b"MGET" | b"JSON.MGET" | b"SLOWLOG GET"
            | b"PUBSUB CHANNELS" | b"PUBSUB SHARDCHANNELS" => ResponsePolicy::CombineArrays,
            b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => ResponsePolicy::CombineMaps,
            b"FUNCTION KILL" | b"SCRIPT KILL" => ResponsePolicy::OneSucceeded,
            b"RANDOMKEY" => ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty,
            // Replies that need command-specific handling.
            b"LATENCY GRAPH" | b"LATENCY HISTOGRAM" | b"LATENCY HISTORY" | b"LATENCY DOCTOR"
            | b"LATENCY LATEST" | b"FUNCTION STATS" | b"MEMORY MALLOC-STATS"
            | b"MEMORY DOCTOR" | b"MEMORY STATS" | b"INFO" => ResponsePolicy::Special,
            _ => return None,
        };
        Some(policy)
    }
}
/// Routing strategy for a command, as classified by `base_routing` and applied by
/// `RoutingInfo::for_routable`.
enum RouteBy {
/// Fan out with `MultipleNodeRoutingInfo::AllNodes`.
AllNodes,
/// Fan out with `MultipleNodeRoutingInfo::AllMasters`.
AllPrimaries,
/// Route by argument 1; any random node if it is missing.
FirstKey,
/// Split the command per slot according to the given argument layout.
MultiShard(MultiSlotArgPattern),
/// Any single random node.
Random,
/// Route by argument 2.
SecondArg,
/// Argument 1 is a key count; route by argument 2 (random node if the count is 0).
SecondArgAfterKeyCount,
/// Argument 2 is an explicit slot number; route to that slot's primary.
SecondArgSlot,
/// Route by the argument following the `STREAMS` token.
StreamsIndex,
/// Route by argument 3.
ThirdArg,
/// Argument 2 is a key count; route by argument 3 (random node if the count is 0).
ThirdArgAfterKeyCount,
/// No routing can be derived; `for_routable` returns `None`.
Undefined,
}
/// Classifies a normalized command name into the routing strategy that
/// `RoutingInfo::for_routable` applies to it.
///
/// `cmd` is matched verbatim, so it must already be upper-case with a single space before any
/// subcommand (the shape `Routable::command` returns). Unlisted commands fall through to
/// `RouteBy::FirstKey`.
fn base_routing(cmd: &[u8]) -> RouteBy {
match cmd {
// Fanned out with `MultipleNodeRoutingInfo::AllNodes`.
b"ACL SETUSER"
| b"ACL DELUSER"
| b"ACL SAVE"
| b"AUTH"
| b"CLIENT SETNAME"
| b"CLIENT SETINFO"
| b"SELECT"
| b"SLOWLOG GET"
| b"SLOWLOG LEN"
| b"SLOWLOG RESET"
| b"CONFIG SET"
| b"CONFIG RESETSTAT"
| b"CONFIG REWRITE"
| b"SCRIPT FLUSH"
| b"SCRIPT LOAD"
| b"LATENCY RESET"
| b"LATENCY GRAPH"
| b"LATENCY HISTOGRAM"
| b"LATENCY HISTORY"
| b"LATENCY DOCTOR"
| b"LATENCY LATEST"
| b"PUBSUB NUMPAT"
| b"PUBSUB CHANNELS"
| b"PUBSUB NUMSUB"
| b"PUBSUB SHARDCHANNELS"
| b"PUBSUB SHARDNUMSUB"
| b"SCRIPT KILL"
| b"FUNCTION KILL"
| b"FUNCTION STATS" => RouteBy::AllNodes,
// Fanned out with `MultipleNodeRoutingInfo::AllMasters`.
b"DBSIZE"
| b"DEBUG"
| b"FLUSHALL"
| b"FLUSHDB"
| b"FT._ALIASLIST"
| b"FT._LIST"
| b"FUNCTION DELETE"
| b"FUNCTION FLUSH"
| b"FUNCTION LOAD"
| b"FUNCTION RESTORE"
| b"INFO"
| b"KEYS"
| b"MEMORY DOCTOR"
| b"MEMORY MALLOC-STATS"
| b"MEMORY PURGE"
| b"MEMORY STATS"
| b"PING"
| b"SCRIPT EXISTS"
| b"UNWATCH"
| b"WAIT"
| b"RANDOMKEY"
| b"WAITAOF" => RouteBy::AllPrimaries,
// Split per slot; every argument after the name is treated as a key.
b"MGET" | b"DEL" | b"EXISTS" | b"UNLINK" | b"TOUCH" | b"WATCH" | b"SUBSCRIBE"
| b"PSUBSCRIBE" | b"SSUBSCRIBE" => RouteBy::MultiShard(MultiSlotArgPattern::KeysOnly),
// Alternating key/value arguments.
b"MSET" => RouteBy::MultiShard(MultiSlotArgPattern::KeyValuePairs),
// Keys followed by one argument shared by all of them.
b"JSON.MGET" => RouteBy::MultiShard(MultiSlotArgPattern::KeysAndLastArg),
// Repeating (key, arg, arg) triples.
b"JSON.MSET" => RouteBy::MultiShard(MultiSlotArgPattern::KeyWithTwoArgTriples),
// No routing can be derived; `for_routable` returns `None`.
b"SCAN" | b"SHUTDOWN" | b"SLAVEOF" | b"REPLICAOF" => RouteBy::Undefined,
// Argument 2 is the key count; the first key is argument 3.
b"BLMPOP" | b"BZMPOP" | b"EVAL" | b"EVALSHA" | b"EVALSHA_RO" | b"EVAL_RO" | b"FCALL"
| b"FCALL_RO" => RouteBy::ThirdArgAfterKeyCount,
// The key is argument 2.
b"BITOP"
| b"MEMORY USAGE"
| b"PFDEBUG"
| b"XGROUP CREATE"
| b"XGROUP CREATECONSUMER"
| b"XGROUP DELCONSUMER"
| b"XGROUP DESTROY"
| b"XGROUP SETID"
| b"XINFO CONSUMERS"
| b"XINFO GROUPS"
| b"XINFO STREAM"
| b"OBJECT ENCODING"
| b"OBJECT FREQ"
| b"OBJECT IDLETIME"
| b"OBJECT REFCOUNT"
| b"JSON.DEBUG" => RouteBy::SecondArg,
// The key is argument 3.
b"MIGRATE" => RouteBy::ThirdArg,
// Argument 1 is the key count; the first key is argument 2.
b"LMPOP" | b"SINTERCARD" | b"ZDIFF" | b"ZINTER" | b"ZINTERCARD" | b"ZMPOP" | b"ZUNION" => {
RouteBy::SecondArgAfterKeyCount
}
// The key follows the `STREAMS` token.
b"XREAD" | b"XREADGROUP" => RouteBy::StreamsIndex,
// Sent to a single random node (`SingleNodeRoutingInfo::Random`).
b"ACL DRYRUN"
| b"ACL GENPASS"
| b"ACL GETUSER"
| b"ACL HELP"
| b"ACL LIST"
| b"ACL LOG"
| b"ACL USERS"
| b"ACL WHOAMI"
| b"BGSAVE"
| b"CLIENT GETNAME"
| b"CLIENT GETREDIR"
| b"CLIENT ID"
| b"CLIENT INFO"
| b"CLIENT KILL"
| b"CLIENT LIST"
| b"CLIENT PAUSE"
| b"CLIENT REPLY"
| b"CLIENT TRACKINGINFO"
| b"CLIENT UNBLOCK"
| b"CLIENT UNPAUSE"
| b"CLUSTER COUNT-FAILURE-REPORTS"
| b"CLUSTER INFO"
| b"CLUSTER KEYSLOT"
| b"CLUSTER MEET"
| b"CLUSTER MYSHARDID"
| b"CLUSTER NODES"
| b"CLUSTER REPLICAS"
| b"CLUSTER RESET"
| b"CLUSTER SET-CONFIG-EPOCH"
| b"CLUSTER SHARDS"
| b"CLUSTER SLOTS"
| b"COMMAND COUNT"
| b"COMMAND GETKEYS"
| b"COMMAND LIST"
| b"COMMAND"
| b"CONFIG GET"
| b"ECHO"
| b"FUNCTION LIST"
| b"LASTSAVE"
| b"LOLWUT"
| b"MODULE LIST"
| b"MODULE LOAD"
| b"MODULE LOADEX"
| b"MODULE UNLOAD"
| b"READONLY"
| b"READWRITE"
| b"SAVE"
| b"SCRIPT SHOW"
| b"TFCALL"
| b"TFCALLASYNC"
| b"TFUNCTION DELETE"
| b"TFUNCTION LIST"
| b"TFUNCTION LOAD"
| b"TIME" => RouteBy::Random,
// Argument 2 is an explicit slot number, routed to that slot's primary.
b"CLUSTER ADDSLOTS"
| b"CLUSTER COUNTKEYSINSLOT"
| b"CLUSTER DELSLOTS"
| b"CLUSTER DELSLOTSRANGE"
| b"CLUSTER GETKEYSINSLOT"
| b"CLUSTER SETSLOT" => RouteBy::SecondArgSlot,
// Everything else: the key is argument 1.
_ => RouteBy::FirstKey,
}
}
impl RoutingInfo {
pub fn is_all_nodes(cmd: &[u8]) -> bool {
matches!(base_routing(cmd), RouteBy::AllNodes)
}
pub fn is_key_routing_command(cmd: &[u8]) -> bool {
match base_routing(cmd) {
RouteBy::FirstKey
| RouteBy::SecondArg
| RouteBy::ThirdArg
| RouteBy::SecondArgAfterKeyCount
| RouteBy::ThirdArgAfterKeyCount
| RouteBy::SecondArgSlot
| RouteBy::StreamsIndex
| RouteBy::MultiShard(_) => {
if matches!(cmd, b"SPUBLISH") {
false
} else {
true
}
}
RouteBy::AllNodes | RouteBy::AllPrimaries | RouteBy::Random | RouteBy::Undefined => {
false
}
}
}
pub fn for_routable<R>(r: &R) -> Option<RoutingInfo>
where
R: Routable + ?Sized,
{
let cmd = &r.command()?[..];
match base_routing(cmd) {
RouteBy::AllNodes => Some(RoutingInfo::MultiNode((
MultipleNodeRoutingInfo::AllNodes,
ResponsePolicy::for_command(cmd),
))),
RouteBy::AllPrimaries => Some(RoutingInfo::MultiNode((
MultipleNodeRoutingInfo::AllMasters,
ResponsePolicy::for_command(cmd),
))),
RouteBy::MultiShard(arg_pattern) => multi_shard(r, cmd, 1, arg_pattern),
RouteBy::Random => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
RouteBy::ThirdArgAfterKeyCount => {
let key_count = r
.arg_idx(2)
.and_then(|x| std::str::from_utf8(x).ok())
.and_then(|x| x.parse::<u64>().ok())?;
if key_count == 0 {
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
} else {
r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key))
}
}
RouteBy::SecondArg => r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key)),
RouteBy::ThirdArg => r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key)),
RouteBy::SecondArgAfterKeyCount => {
let key_count = r
.arg_idx(1)
.and_then(|x| std::str::from_utf8(x).ok())
.and_then(|x| x.parse::<u64>().ok())?;
if key_count == 0 {
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
} else {
r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key))
}
}
RouteBy::StreamsIndex => {
let streams_position = r.position(b"STREAMS")?;
r.arg_idx(streams_position + 1)
.map(|key| RoutingInfo::for_key(cmd, key))
}
RouteBy::SecondArgSlot => r
.arg_idx(2)
.and_then(|arg| std::str::from_utf8(arg).ok())
.and_then(|slot| slot.parse::<u16>().ok())
.map(|slot| {
RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
slot,
SlotAddr::Master,
)))
}),
RouteBy::FirstKey => match r.arg_idx(1) {
Some(key) => Some(RoutingInfo::for_key(cmd, key)),
None => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
},
RouteBy::Undefined => None,
}
}
fn for_key(cmd: &[u8], key: &[u8]) -> RoutingInfo {
RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(get_route(
is_readonly_cmd(cmd),
key,
)))
}
}
pub fn is_readonly(routable: &impl Routable) -> bool {
match routable.command() {
Some(cmd) => is_readonly_cmd(cmd.as_slice()),
None => false,
}
}
/// Returns `true` if the normalized command name `cmd` is in this fixed allow-list.
///
/// `get_route` uses the answer to route key-based commands to `SlotAddr::ReplicaOptional`
/// instead of `SlotAddr::Master`. The list is matched verbatim, so `cmd` must already be
/// upper-case with a single space before any subcommand (the shape `Routable::command` returns).
///
/// NOTE(review): the list also contains administrative commands such as `CONFIG SET` and
/// `ACL SETUSER`. It appears to mean "does not write to the keyspace" rather than "has no
/// side effects" — confirm before relying on it elsewhere.
pub fn is_readonly_cmd(cmd: &[u8]) -> bool {
matches!(
cmd,
b"ACL CAT"
| b"ACL DELUSER"
| b"ACL DRYRUN"
| b"ACL GENPASS"
| b"ACL GETUSER"
| b"ACL HELP"
| b"ACL LIST"
| b"ACL LOAD"
| b"ACL LOG"
| b"ACL SAVE"
| b"ACL SETUSER"
| b"ACL USERS"
| b"ACL WHOAMI"
| b"AUTH"
| b"BGREWRITEAOF"
| b"BGSAVE"
| b"BITCOUNT"
| b"BITFIELD_RO"
| b"BITPOS"
| b"CLIENT ID"
| b"CLIENT CACHING"
| b"CLIENT CAPA"
| b"CLIENT GETNAME"
| b"CLIENT GETREDIR"
| b"CLIENT HELP"
| b"CLIENT INFO"
| b"CLIENT KILL"
| b"CLIENT LIST"
| b"CLIENT NO-EVICT"
| b"CLIENT NO-TOUCH"
| b"CLIENT PAUSE"
| b"CLIENT REPLY"
| b"CLIENT SETINFO"
| b"CLIENT SETNAME"
| b"CLIENT TRACKING"
| b"CLIENT TRACKINGINFO"
| b"CLIENT UNBLOCK"
| b"CLIENT UNPAUSE"
| b"CLUSTER COUNT-FAILURE-REPORTS"
| b"CLUSTER COUNTKEYSINSLOT"
| b"CLUSTER FAILOVER"
| b"CLUSTER GETKEYSINSLOT"
| b"CLUSTER HELP"
| b"CLUSTER INFO"
| b"CLUSTER KEYSLOT"
| b"CLUSTER LINKS"
| b"CLUSTER MYID"
| b"CLUSTER MYSHARDID"
| b"CLUSTER NODES"
| b"CLUSTER REPLICATE"
| b"CLUSTER SAVECONFIG"
| b"CLUSTER SHARDS"
| b"CLUSTER SLOTS"
| b"COMMAND COUNT"
| b"COMMAND DOCS"
| b"COMMAND GETKEYS"
| b"COMMAND GETKEYSANDFLAGS"
| b"COMMAND HELP"
| b"COMMAND INFO"
| b"COMMAND LIST"
| b"CONFIG GET"
| b"CONFIG HELP"
| b"CONFIG RESETSTAT"
| b"CONFIG REWRITE"
| b"CONFIG SET"
| b"DBSIZE"
| b"DUMP"
| b"ECHO"
| b"EVAL_RO"
| b"EVALSHA_RO"
| b"EXISTS"
| b"EXPIRETIME"
| b"FCALL_RO"
| b"FT.AGGREGATE"
| b"FT.EXPLAIN"
| b"FT.EXPLAINCLI"
| b"FT.INFO"
| b"FT.PROFILE"
| b"FT.SEARCH"
| b"FT._ALIASLIST"
| b"FT._LIST"
| b"FUNCTION DUMP"
| b"FUNCTION HELP"
| b"FUNCTION KILL"
| b"FUNCTION LIST"
| b"FUNCTION STATS"
| b"GEODIST"
| b"GEOHASH"
| b"GEOPOS"
| b"GEORADIUSBYMEMBER_RO"
| b"GEORADIUS_RO"
| b"GEOSEARCH"
| b"GET"
| b"GETBIT"
| b"GETRANGE"
| b"HELLO"
| b"HEXISTS"
| b"HGET"
| b"HGETALL"
| b"HKEYS"
| b"HLEN"
| b"HMGET"
| b"HRANDFIELD"
| b"HSCAN"
| b"HSTRLEN"
| b"HVALS"
| b"JSON.ARRINDEX"
| b"JSON.ARRLEN"
| b"JSON.DEBUG"
| b"JSON.GET"
| b"JSON.OBJLEN"
| b"JSON.OBJKEYS"
| b"JSON.MGET"
| b"JSON.RESP"
| b"JSON.STRLEN"
| b"JSON.TYPE"
| b"INFO"
| b"KEYS"
| b"LASTSAVE"
| b"LATENCY DOCTOR"
| b"LATENCY GRAPH"
| b"LATENCY HELP"
| b"LATENCY HISTOGRAM"
| b"LATENCY HISTORY"
| b"LATENCY LATEST"
| b"LATENCY RESET"
| b"LCS"
| b"LINDEX"
| b"LLEN"
| b"LOLWUT"
| b"LPOS"
| b"LRANGE"
| b"MEMORY DOCTOR"
| b"MEMORY HELP"
| b"MEMORY MALLOC-STATS"
| b"MEMORY PURGE"
| b"MEMORY STATS"
| b"MEMORY USAGE"
| b"MGET"
| b"MODULE HELP"
| b"MODULE LIST"
| b"MODULE LOAD"
| b"MODULE LOADEX"
| b"MODULE UNLOAD"
| b"OBJECT ENCODING"
| b"OBJECT FREQ"
| b"OBJECT HELP"
| b"OBJECT IDLETIME"
| b"OBJECT REFCOUNT"
| b"PEXPIRETIME"
| b"PFCOUNT"
| b"PING"
| b"PTTL"
| b"PUBLISH"
| b"PUBSUB CHANNELS"
| b"PUBSUB HELP"
| b"PUBSUB NUMPAT"
| b"PUBSUB NUMSUB"
| b"PUBSUB SHARDCHANNELS"
| b"PUBSUB SHARDNUMSUB"
| b"RANDOMKEY"
| b"REPLICAOF"
| b"RESET"
| b"ROLE"
| b"SAVE"
| b"SCAN"
| b"SCARD"
| b"SCRIPT DEBUG"
| b"SCRIPT EXISTS"
| b"SCRIPT FLUSH"
| b"SCRIPT KILL"
| b"SCRIPT LOAD"
| b"SCRIPT SHOW"
| b"SDIFF"
| b"SELECT"
| b"SENTINEL GET-MASTER-ADDR-BY-NAME"
| b"SENTINEL MASTER"
| b"SENTINEL MASTERS"
| b"SENTINEL REPLICAS"
| b"SENTINEL CKQUORUM"
| b"SHUTDOWN"
| b"SINTER"
| b"SINTERCARD"
| b"SISMEMBER"
| b"SMEMBERS"
| b"SMISMEMBER"
| b"SLOWLOG GET"
| b"SLOWLOG HELP"
| b"SLOWLOG LEN"
| b"SLOWLOG RESET"
| b"SORT_RO"
| b"SPUBLISH"
| b"SRANDMEMBER"
| b"SSCAN"
| b"SSUBSCRIBE"
| b"STRLEN"
| b"SUBSCRIBE"
| b"SUBSTR"
| b"SUNION"
| b"SUNSUBSCRIBE"
| b"TIME"
| b"TOUCH"
| b"TTL"
| b"TYPE"
| b"UNSUBSCRIBE"
| b"XINFO CONSUMERS"
| b"XINFO GROUPS"
| b"XINFO HELP"
| b"XINFO STREAM"
| b"XLEN"
| b"XPENDING"
| b"XRANGE"
| b"XREAD"
| b"XREVRANGE"
| b"ZCARD"
| b"ZCOUNT"
| b"ZDIFF"
| b"ZINTER"
| b"ZINTERCARD"
| b"ZLEXCOUNT"
| b"ZMSCORE"
| b"ZRANDMEMBER"
| b"ZRANGE"
| b"ZRANGEBYLEX"
| b"ZRANGEBYSCORE"
| b"ZRANK"
| b"ZREVRANGE"
| b"ZREVRANGEBYLEX"
| b"ZREVRANGEBYSCORE"
| b"ZREVRANK"
| b"ZSCAN"
| b"ZSCORE"
| b"ZUNION"
)
}
/// Positional access to the arguments of a command (or of a command-shaped reply), used to
/// decide where the command is routed.
pub trait Routable {
    /// Returns the normalized command name, or `None` if there is no argument 0.
    ///
    /// Argument 0 is upper-cased. For container commands such as `CONFIG` or `XGROUP`, a
    /// single space and the upper-cased argument 1 are appended when that argument exists.
    fn command(&self) -> Option<Vec<u8>> {
        let mut name = self.arg_idx(0)?.to_ascii_uppercase();
        let takes_subcommand = matches!(
            name.as_slice(),
            b"XGROUP"
                | b"OBJECT"
                | b"SLOWLOG"
                | b"FUNCTION"
                | b"MODULE"
                | b"COMMAND"
                | b"PUBSUB"
                | b"CONFIG"
                | b"MEMORY"
                | b"XINFO"
                | b"CLIENT"
                | b"ACL"
                | b"SCRIPT"
                | b"CLUSTER"
                | b"LATENCY"
                | b"SENTINEL"
        );
        if !takes_subcommand {
            return Some(name);
        }
        if let Some(subcommand) = self.arg_idx(1) {
            name.reserve(subcommand.len() + 1);
            name.push(b' ');
            name.extend(subcommand.iter().map(u8::to_ascii_uppercase));
        }
        Some(name)
    }

    /// Returns the raw bytes of argument `idx` (0 is the command name). Returns `None` when it
    /// is missing or, depending on the implementation, not a plain byte string.
    fn arg_idx(&self, idx: usize) -> Option<&[u8]>;

    /// Returns the index of the first argument matching `candidate`. The implementations in
    /// this module compare ignoring ASCII case.
    fn position(&self, candidate: &[u8]) -> Option<usize>;
}
impl Routable for Cmd {
    /// Delegates to the inherent `Cmd::arg_idx`; method resolution prefers it over this trait
    /// method.
    fn arg_idx(&self, idx: usize) -> Option<&[u8]> {
        self.arg_idx(idx)
    }

    /// Finds the first simple argument equal to `candidate`, ignoring ASCII case. Non-simple
    /// arguments never match.
    fn position(&self, candidate: &[u8]) -> Option<usize> {
        self.args_iter()
            .position(|arg| matches!(arg, Arg::Simple(data) if data.eq_ignore_ascii_case(candidate)))
    }
}
impl Routable for Value {
    /// Returns the bytes of element `idx` when `self` is an array and that element is a
    /// successful bulk string; `None` otherwise.
    fn arg_idx(&self, idx: usize) -> Option<&[u8]> {
        let Value::Array(args) = self else {
            return None;
        };
        match args.get(idx)? {
            Ok(Value::BulkString(data)) => Some(&data[..]),
            _ => None,
        }
    }

    /// Finds the first bulk-string element equal to `candidate`, ignoring ASCII case. Returns
    /// `None` when `self` is not an array.
    fn position(&self, candidate: &[u8]) -> Option<usize> {
        let Value::Array(args) = self else {
            return None;
        };
        args.iter().position(|arg| {
            matches!(arg, Ok(Value::BulkString(data)) if data.eq_ignore_ascii_case(candidate))
        })
    }
}
/// A range of hash slots together with the addresses of the nodes serving it.
#[derive(Debug, Hash, Clone)]
pub(crate) struct Slot {
    /// First slot of the range.
    pub(crate) start: u16,
    /// Last slot of the range (presumably inclusive — confirm against the topology parser).
    pub(crate) end: u16,
    /// Address of the primary serving the range.
    pub(crate) master: String,
    /// Addresses of the replicas serving the range.
    pub(crate) replicas: Vec<String>,
}

impl Slot {
    /// Builds a slot range `s..=e` (see the note on `end`) served by primary `m` and replicas `r`.
    pub fn new(s: u16, e: u16, m: String, r: Vec<String>) -> Self {
        Slot {
            start: s,
            end: e,
            master: m,
            replicas: r,
        }
    }

    /// Address of the primary.
    // NOTE(review): `dead_code` is allowed, presumably because this accessor may have no
    // non-test callers — confirm it is still needed.
    #[allow(dead_code)]
    pub(crate) fn master(&self) -> &str {
        &self.master
    }

    /// Addresses of the replicas (test builds only).
    #[cfg(test)]
    pub fn replicas(&self) -> &[String] {
        self.replicas.as_slice()
    }
}
/// Which node of a slot's shard should serve a request.
#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash, Display)]
pub enum SlotAddr {
/// The primary; chosen by `get_route` for commands not listed by `is_readonly_cmd`, and for
/// explicit-slot commands.
Master,
/// A replica may serve the request; chosen by `get_route` for commands listed by
/// `is_readonly_cmd`.
ReplicaOptional,
/// Presumably requires a replica; not produced in this chunk.
ReplicaRequired,
}
/// Outcome of `ShardAddrs::attempt_shard_role_update`.
#[derive(PartialEq, Debug)]
pub(crate) enum ShardUpdateResult {
/// The requested node already was the primary; nothing changed.
AlreadyPrimary,
/// The requested node was a replica and has been swapped with the previous primary.
Promoted,
/// The requested node is neither the primary nor a replica of this shard.
NodeNotFound,
}
/// `expect` message used when a `ShardAddrs` read lock cannot be acquired (the lock is poisoned).
const READ_LK_ERR_SHARDADDRS: &str = "Failed to acquire read lock for ShardAddrs";
/// `expect` message used when a `ShardAddrs` write lock cannot be acquired (the lock is poisoned).
const WRITE_LK_ERR_SHARDADDRS: &str = "Failed to acquire write lock for ShardAddrs";
/// Addresses of the nodes of one shard: a primary and its replicas.
///
/// Each part sits behind its own `RwLock` so roles can be updated in place through a shared
/// reference (see `attempt_shard_role_update` and `remove_replica`).
#[derive(Debug)]
pub struct ShardAddrs {
primary: RwLock<Arc<String>>,
replicas: RwLock<Vec<Arc<String>>>,
}
impl PartialEq for ShardAddrs {
    /// Two shards are equal when they have the same primary and the same replicas in the same
    /// order.
    fn eq(&self, other: &Self) -> bool {
        // `a == a` would otherwise take two read locks on the same `RwLock` from one thread.
        // std documents that this may panic, and it can deadlock if a writer queues in between.
        if std::ptr::eq(self, other) {
            return true;
        }
        // Scope the primary guards so they are released before the replica locks are taken.
        let primaries_equal = {
            let self_primary = self.primary.read().expect(READ_LK_ERR_SHARDADDRS);
            let other_primary = other.primary.read().expect(READ_LK_ERR_SHARDADDRS);
            *self_primary == *other_primary
        };
        if !primaries_equal {
            return false;
        }
        let self_replicas = self.replicas.read().expect(READ_LK_ERR_SHARDADDRS);
        let other_replicas = other.replicas.read().expect(READ_LK_ERR_SHARDADDRS);
        *self_replicas == *other_replicas
    }
}
impl Eq for ShardAddrs {}
impl PartialOrd for ShardAddrs {
    /// Delegates to the total order so `partial_cmp` and `cmp` always agree.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}
impl Ord for ShardAddrs {
    /// Orders shards by primary address, then lexicographically by replica list.
    fn cmp(&self, other: &Self) -> Ordering {
        // Comparing a shard with itself would otherwise take two read locks on the same
        // `RwLock` from one thread. std documents that this may panic, and it can deadlock.
        if std::ptr::eq(self, other) {
            return Ordering::Equal;
        }
        // Scope the primary guards so they are released before the replica locks are taken.
        let primary_cmp = {
            let self_primary = self.primary.read().expect(READ_LK_ERR_SHARDADDRS);
            let other_primary = other.primary.read().expect(READ_LK_ERR_SHARDADDRS);
            self_primary.cmp(&other_primary)
        };
        if primary_cmp != Ordering::Equal {
            return primary_cmp;
        }
        let self_replicas = self.replicas.read().expect(READ_LK_ERR_SHARDADDRS);
        let other_replicas = other.replicas.read().expect(READ_LK_ERR_SHARDADDRS);
        self_replicas.cmp(&other_replicas)
    }
}
impl ShardAddrs {
pub(crate) fn new(primary: Arc<String>, replicas: Vec<Arc<String>>) -> Self {
let primary = RwLock::new(primary);
let replicas = RwLock::new(replicas);
Self { primary, replicas }
}
pub(crate) fn new_with_primary(primary: Arc<String>) -> Self {
Self::new(primary, Vec::default())
}
pub fn primary(&self) -> Arc<String> {
self.primary.read().expect(READ_LK_ERR_SHARDADDRS).clone()
}
pub(crate) fn replicas(&self) -> std::sync::RwLockReadGuard<'_, Vec<Arc<String>>> {
self.replicas.read().expect(READ_LK_ERR_SHARDADDRS)
}
pub(crate) fn attempt_shard_role_update(&self, new_primary: Arc<String>) -> ShardUpdateResult {
let mut primary_lock = self.primary.write().expect(WRITE_LK_ERR_SHARDADDRS);
let mut replicas_lock = self.replicas.write().expect(WRITE_LK_ERR_SHARDADDRS);
if *primary_lock == new_primary {
return ShardUpdateResult::AlreadyPrimary;
}
if let Some(replica_idx) = Self::replica_index(&replicas_lock, new_primary.clone()) {
std::mem::swap(&mut *primary_lock, &mut replicas_lock[replica_idx]);
return ShardUpdateResult::Promoted;
}
ShardUpdateResult::NodeNotFound
}
fn replica_index(
replicas: &RwLockWriteGuard<'_, Vec<Arc<String>>>,
target_replica: Arc<String>,
) -> Option<usize> {
replicas
.iter()
.position(|curr_replica| **curr_replica == *target_replica)
}
pub(crate) fn is_member(&self, addr: &str) -> bool {
if self.primary.read().expect(READ_LK_ERR_SHARDADDRS).as_str() == addr {
return true;
}
self.replicas
.read()
.expect(READ_LK_ERR_SHARDADDRS)
.iter()
.any(|r| r.as_str() == addr)
}
pub(crate) fn remove_replica(&self, replica_to_remove: Arc<String>) -> Result<()> {
let mut replicas_lock = self.replicas.write().expect(WRITE_LK_ERR_SHARDADDRS);
if let Some(index) = Self::replica_index(&replicas_lock, replica_to_remove.clone()) {
replicas_lock.remove(index);
Ok(())
} else {
Err(Error::from((
ErrorKind::ClientError,
"Couldn't remove replica",
format!("Replica {replica_to_remove:?} not found"),
)))
}
}
}
impl IntoIterator for &ShardAddrs {
    type Item = Arc<String>;
    type IntoIter = std::iter::Chain<Once<Arc<String>>, std::vec::IntoIter<Arc<String>>>;

    /// Yields a snapshot of the shard's addresses: the primary first, then every replica.
    ///
    /// Both locks are released before the iterator is returned, so later role changes are not
    /// reflected in it.
    fn into_iter(self) -> Self::IntoIter {
        let primary = Arc::clone(&*self.primary.read().expect(READ_LK_ERR_SHARDADDRS));
        let replicas = self
            .replicas
            .read()
            .expect(READ_LK_ERR_SHARDADDRS)
            .to_vec();
        std::iter::once(primary).chain(replicas)
    }
}
/// A request target: a hash slot plus which node of its shard (`SlotAddr`) should serve it.
#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)]
pub struct Route(u16, SlotAddr);
impl Route {
pub fn new(slot: u16, slot_addr: SlotAddr) -> Self {
Self(slot, slot_addr)
}
pub fn slot(&self) -> u16 {
self.0
}
pub fn slot_addr(&self) -> SlotAddr {
self.1
}
pub fn new_random_primary() -> Self {
Self::new(random_slot(), SlotAddr::Master)
}
}
/// Picks a slot in `0..SLOT_SIZE` using the thread-local RNG.
fn random_slot() -> u16 {
    rand::rng().random_range(0..crate::cluster::topology::SLOT_SIZE)
}
#[cfg(test)]
mod tests_routing {
use super::{
AggregateOp, MultiSlotArgPattern, MultipleNodeRoutingInfo, ResponsePolicy, Route,
RoutingInfo, ShardAddrs, SingleNodeRoutingInfo, SlotAddr, command_for_multi_slot_indices,
};
use crate::cluster::routing::{Routable, ShardUpdateResult, is_readonly, is_readonly_cmd};
use crate::cluster::topology::slot;
use crate::cmd::cmd;
use crate::protocol::parser::parse_valkey_value;
use crate::value::Value;
use core::panic;
use std::sync::{Arc, RwLock};
#[test]
fn test_routing_info_mixed_capatalization() {
let mut upper = cmd("XREAD");
upper.arg("STREAMS").arg("foo").arg(0);
let mut lower = cmd("xread");
lower.arg("streams").arg("foo").arg(0);
assert_eq!(
RoutingInfo::for_routable(&upper).unwrap(),
RoutingInfo::for_routable(&lower).unwrap()
);
let mut mixed = cmd("xReAd");
mixed.arg("StReAmS").arg("foo").arg(0);
assert_eq!(
RoutingInfo::for_routable(&lower).unwrap(),
RoutingInfo::for_routable(&mixed).unwrap()
);
}
#[test]
fn test_routing_info() {
let mut test_cmds = vec![];
let mut test_cmd = cmd("FLUSHALL");
test_cmd.arg("");
test_cmds.push(test_cmd);
test_cmd = cmd("ECHO");
test_cmd.arg("");
test_cmds.push(test_cmd);
test_cmd = cmd("SET");
test_cmd.arg("42");
test_cmds.push(test_cmd);
test_cmd = cmd("XINFO");
test_cmd.arg("GROUPS").arg("FOOBAR");
test_cmds.push(test_cmd);
test_cmd = cmd("EVAL");
test_cmd.arg("FOO").arg("0").arg("BAR");
test_cmds.push(test_cmd);
test_cmd = cmd("EVAL");
test_cmd.arg("FOO").arg("4").arg("BAR");
test_cmds.push(test_cmd);
test_cmd = cmd("XREAD");
test_cmd.arg("STREAMS").arg("4");
test_cmds.push(test_cmd);
test_cmd = cmd("XREAD");
test_cmd.arg("FOO").arg("STREAMS").arg("4");
test_cmds.push(test_cmd);
for cmd in test_cmds {
let value = parse_valkey_value(&cmd.get_packed_command()).unwrap();
assert_eq!(
RoutingInfo::for_routable(&value).unwrap(),
RoutingInfo::for_routable(&cmd).unwrap(),
);
}
for cmd in [cmd("FLUSHALL"), cmd("FLUSHDB"), cmd("PING")] {
assert_eq!(
RoutingInfo::for_routable(&cmd),
Some(RoutingInfo::MultiNode((
MultipleNodeRoutingInfo::AllMasters,
Some(ResponsePolicy::AllSucceeded)
)))
);
}
assert_eq!(
RoutingInfo::for_routable(&cmd("DBSIZE")),
Some(RoutingInfo::MultiNode((
MultipleNodeRoutingInfo::AllMasters,
Some(ResponsePolicy::Aggregate(AggregateOp::Sum))
)))
);
assert_eq!(
RoutingInfo::for_routable(&cmd("SCRIPT KILL")),
Some(RoutingInfo::MultiNode((
MultipleNodeRoutingInfo::AllNodes,
Some(ResponsePolicy::OneSucceeded)
)))
);
assert_eq!(
RoutingInfo::for_routable(&cmd("INFO")),
Some(RoutingInfo::MultiNode((
MultipleNodeRoutingInfo::AllMasters,
Some(ResponsePolicy::Special)
)))
);
assert_eq!(
RoutingInfo::for_routable(&cmd("KEYS")),
Some(RoutingInfo::MultiNode((
MultipleNodeRoutingInfo::AllMasters,
Some(ResponsePolicy::CombineArrays)
)))
);
for cmd in [cmd("SCAN"),
cmd("SHUTDOWN"),
cmd("SLAVEOF"),
cmd("REPLICAOF")] {
assert_eq!(
RoutingInfo::for_routable(&cmd),
None,
"{}",
std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap()
);
}
for cmd in [
cmd("EVAL").arg(r#"redis.call("PING");"#).arg(0),
cmd("EVALSHA").arg(r#"redis.call("PING");"#).arg(0),
] {
assert_eq!(
RoutingInfo::for_routable(cmd),
Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
);
}
assert_eq!(
RoutingInfo::for_routable(cmd("FCALL").arg("foo").arg(1).arg("mykey")),
Some(RoutingInfo::SingleNode(
SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"mykey"), SlotAddr::Master))
))
);
for (cmd, expected) in [
(
cmd("EVAL")
.arg(r#"redis.call("GET, KEYS[1]");"#)
.arg(1)
.arg("foo"),
Some(RoutingInfo::SingleNode(
SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"foo"), SlotAddr::Master)),
)),
),
(
cmd("XGROUP")
.arg("CREATE")
.arg("mystream")
.arg("workers")
.arg("$")
.arg("MKSTREAM"),
Some(RoutingInfo::SingleNode(
SingleNodeRoutingInfo::SpecificNode(Route::new(
slot(b"mystream"),
SlotAddr::Master,
)),
)),
),
(
cmd("XINFO").arg("GROUPS").arg("foo"),
Some(RoutingInfo::SingleNode(
SingleNodeRoutingInfo::SpecificNode(Route::new(
slot(b"foo"),
SlotAddr::ReplicaOptional,
)),
)),
),
(
cmd("XREADGROUP")
.arg("GROUP")
.arg("wkrs")
.arg("consmrs")
.arg("STREAMS")
.arg("mystream"),
Some(RoutingInfo::SingleNode(
SingleNodeRoutingInfo::SpecificNode(Route::new(
slot(b"mystream"),
SlotAddr::Master,
)),
)),
),
(
cmd("XREAD")
.arg("COUNT")
.arg("2")
.arg("STREAMS")
.arg("mystream")
.arg("writers")
.arg("0-0")
.arg("0-0"),
Some(RoutingInfo::SingleNode(
SingleNodeRoutingInfo::SpecificNode(Route::new(
slot(b"mystream"),
SlotAddr::ReplicaOptional,
)),
)),
),
] {
assert_eq!(
RoutingInfo::for_routable(cmd),
expected,
"{}",
std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap()
);
}
}
/// Pre-serialized RESP arrays must be routed by the slot of their key.
///
/// Cases: an `EXISTS <16-byte binary key>` (expected on a replica-optional route) and two
/// `SET <16-byte binary key> true NX PX 1800000` commands (expected on the primary).
#[test]
fn test_slot_for_packed_cmd() {
    let cases: [(&[u8], _, SlotAddr); 3] = [
        (
            &[
                42, 50, 13, 10, 36, 54, 13, 10, 69, 88, 73, 83, 84, 83, 13, 10, 36, 49, 54, 13, 10,
                244, 93, 23, 40, 126, 127, 253, 33, 89, 47, 185, 204, 171, 249, 96, 139, 13, 10,
            ],
            964,
            SlotAddr::ReplicaOptional,
        ),
        (
            &[
                42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 36, 241,
                197, 111, 180, 254, 5, 175, 143, 146, 171, 39, 172, 23, 164, 145, 13, 10, 36, 52,
                13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10,
                80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10,
            ],
            8352,
            SlotAddr::Master,
        ),
        (
            &[
                42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 169, 233,
                247, 59, 50, 247, 100, 232, 123, 140, 2, 101, 125, 221, 66, 170, 13, 10, 36, 52,
                13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10,
                80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10,
            ],
            5210,
            SlotAddr::Master,
        ),
    ];

    for (packed, expected_slot, expected_addr) in cases {
        let value = parse_valkey_value(packed).unwrap();
        assert_eq!(
            RoutingInfo::for_routable(&value),
            Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(
                Route(expected_slot, expected_addr)
            )))
        );
    }
}
/// Keys-only commands fan out into one sub-command per slot.
///
/// `DEL` targets primaries and aggregates with `Sum`. `MGET` uses replica-optional routes and
/// combines arrays. The `{bar}` hash tag puts `{bar}vaz` (index 3) in the same slot as `bar`
/// (index 1).
#[test]
fn test_multi_shard_keys_only() {
    let keys = ["foo", "bar", "baz", "{bar}vaz"];

    let mut del = cmd("DEL");
    for key in keys {
        del.arg(key);
    }
    let routing = RoutingInfo::for_routable(&del);
    let expected: std::collections::HashMap<_, _> = [
        (Route(4813, SlotAddr::Master), vec![2]),
        (Route(5061, SlotAddr::Master), vec![1, 3]),
        (Route(12182, SlotAddr::Master), vec![0]),
    ]
    .into_iter()
    .collect();
    let matched = match routing.clone() {
        Some(RoutingInfo::MultiNode((
            MultipleNodeRoutingInfo::MultiSlot((slots, pattern)),
            Some(ResponsePolicy::Aggregate(AggregateOp::Sum)),
        ))) => {
            let actual: std::collections::HashMap<_, _> = slots.into_iter().collect();
            actual == expected && pattern == MultiSlotArgPattern::KeysOnly
        }
        _ => false,
    };
    assert!(matched, "expected={expected:?}\nrouting={routing:?}");

    let mut mget = crate::cmd::cmd("MGET");
    for key in keys {
        mget.arg(key);
    }
    let routing = RoutingInfo::for_routable(&mget);
    let expected: std::collections::HashMap<_, _> = [
        (Route(4813, SlotAddr::ReplicaOptional), vec![2]),
        (Route(5061, SlotAddr::ReplicaOptional), vec![1, 3]),
        (Route(12182, SlotAddr::ReplicaOptional), vec![0]),
    ]
    .into_iter()
    .collect();
    let matched = match routing.clone() {
        Some(RoutingInfo::MultiNode((
            MultipleNodeRoutingInfo::MultiSlot((slots, pattern)),
            Some(ResponsePolicy::CombineArrays),
        ))) => {
            let actual: std::collections::HashMap<_, _> = slots.into_iter().collect();
            actual == expected && pattern == MultiSlotArgPattern::KeysOnly
        }
        _ => false,
    };
    assert!(matched, "expected={expected:?}\nrouting={routing:?}");
}
/// `MSET` groups each key together with the value that follows it, per slot.
///
/// The `{foo}` hash tag puts the third pair (indices 4, 5) in the same slot as the first
/// (indices 0, 1).
#[test]
fn test_multi_shard_key_value_pairs() {
    let mut mset = cmd("MSET");
    mset.arg("foo")
        .arg("bar")
        .arg("foo2")
        .arg("bar2")
        .arg("{foo}foo3")
        .arg("bar3");
    let routing = RoutingInfo::for_routable(&mset);
    let expected: std::collections::HashMap<_, _> = [
        (Route(1044, SlotAddr::Master), vec![2, 3]),
        (Route(12182, SlotAddr::Master), vec![0, 1, 4, 5]),
    ]
    .into_iter()
    .collect();
    let matched = match routing.clone() {
        Some(RoutingInfo::MultiNode((
            MultipleNodeRoutingInfo::MultiSlot((slots, pattern)),
            Some(ResponsePolicy::AllSucceeded),
        ))) => {
            let actual: std::collections::HashMap<_, _> = slots.into_iter().collect();
            actual == expected && pattern == MultiSlotArgPattern::KeyValuePairs
        }
        _ => false,
    };
    assert!(matched, "expected={expected:?}\nrouting={routing:?}");
}
/// `JSON.MGET` groups keys by slot and appends the trailing path argument (index 4) to every
/// group.
///
/// Routes are expected to be replica-optional, with the per-slot arrays combined.
#[test]
fn test_multi_shard_keys_and_path() {
    let mut json_mget = cmd("JSON.MGET");
    json_mget
        .arg("foo")
        .arg("bar")
        .arg("baz")
        .arg("{bar}vaz")
        .arg("$.f.a");
    let routing = RoutingInfo::for_routable(&json_mget);
    let expected: std::collections::HashMap<_, _> = [
        (Route(4813, SlotAddr::ReplicaOptional), vec![2, 4]),
        (Route(5061, SlotAddr::ReplicaOptional), vec![1, 3, 4]),
        (Route(12182, SlotAddr::ReplicaOptional), vec![0, 4]),
    ]
    .into_iter()
    .collect();
    let matched = match routing.clone() {
        Some(RoutingInfo::MultiNode((
            MultipleNodeRoutingInfo::MultiSlot((slots, pattern)),
            Some(ResponsePolicy::CombineArrays),
        ))) => {
            let actual: std::collections::HashMap<_, _> = slots.into_iter().collect();
            actual == expected && pattern == MultiSlotArgPattern::KeysAndLastArg
        }
        _ => false,
    };
    assert!(matched, "expected={expected:?}\nrouting={routing:?}");
}
/// `JSON.MSET` groups `(key, path, value)` triples by the slot of their key.
///
/// The `{foo}` hash tag puts the third triple (indices 6-8) in the same slot as the first
/// (indices 0-2).
#[test]
fn test_multi_shard_key_with_two_arg_triples() {
    let mut json_mset = cmd("JSON.MSET");
    json_mset
        .arg("foo")
        .arg("$.a")
        .arg("bar")
        .arg("foo2")
        .arg("$.f.a")
        .arg("bar2")
        .arg("{foo}foo3")
        .arg("$.f.a")
        .arg("bar3");
    let routing = RoutingInfo::for_routable(&json_mset);
    let expected: std::collections::HashMap<_, _> = [
        (Route(1044, SlotAddr::Master), vec![3, 4, 5]),
        (Route(12182, SlotAddr::Master), vec![0, 1, 2, 6, 7, 8]),
    ]
    .into_iter()
    .collect();
    let matched = match routing.clone() {
        Some(RoutingInfo::MultiNode((
            MultipleNodeRoutingInfo::MultiSlot((slots, pattern)),
            Some(ResponsePolicy::AllSucceeded),
        ))) => {
            let actual: std::collections::HashMap<_, _> = slots.into_iter().collect();
            actual == expected && pattern == MultiSlotArgPattern::KeyWithTwoArgTriples
        }
        _ => false,
    };
    assert!(matched, "expected={expected:?}\nrouting={routing:?}");
}
/// Splitting `DEL foo bar baz {bar}vaz` by slot must yield exactly three sub-commands.
///
/// Each sub-command consists of the original command name followed by exactly the keys of its
/// slot, in order, with no extra arguments.
#[test]
fn test_command_creation_for_multi_shard() {
    let mut original_cmd = cmd("DEL");
    original_cmd
        .arg("foo")
        .arg("bar")
        .arg("baz")
        .arg("{bar}vaz");
    let routing = RoutingInfo::for_routable(&original_cmd);
    // Key positions (0-based, not counting the command name), grouped by slot and ordered by
    // their first key.
    let expected = [vec![0], vec![1, 3], vec![2]];
    let mut groups: Vec<Vec<usize>> = match routing {
        Some(RoutingInfo::MultiNode((
            MultipleNodeRoutingInfo::MultiSlot((routes, MultiSlotArgPattern::KeysOnly)),
            _,
        ))) => routes.into_iter().map(|(_, indices)| indices).collect(),
        _ => panic!("unexpected routing: {routing:?}"),
    };
    // Without this, a routing that produced fewer groups would pass the zip below silently.
    assert_eq!(
        groups.len(),
        expected.len(),
        "unexpected slot grouping: {groups:?}"
    );
    // Every group holds at least one key, so its first index is a valid sort key.
    groups.sort_by_key(|group| group[0]);
    for (group, expected_keys) in groups.iter().zip(expected.iter()) {
        assert_eq!(group, expected_keys);
        let sub_cmd = command_for_multi_slot_indices(&original_cmd, group.iter());
        assert_eq!(original_cmd.arg_idx(0), sub_cmd.arg_idx(0));
        for (position, key_index) in expected_keys.iter().enumerate() {
            // `+ 1` skips the command name in both commands.
            assert_eq!(
                original_cmd.arg_idx(key_index + 1),
                sub_cmd.arg_idx(position + 1)
            );
        }
        // Nothing may follow this slot's keys.
        assert!(sub_cmd.arg_idx(expected_keys.len() + 1).is_none());
    }
}
/// When every key of a multi-key command lands in one slot (here via the `{foo}` hash tag),
/// routing collapses to a single primary route instead of a multi-slot fan-out.
#[test]
fn test_combine_multi_shard_to_single_node_when_all_keys_are_in_same_slot() {
    let mut del = cmd("DEL");
    del.arg("foo").arg("{foo}bar").arg("{foo}baz");
    let routing = RoutingInfo::for_routable(&del);
    let single_primary = Some(RoutingInfo::SingleNode(
        SingleNodeRoutingInfo::SpecificNode(Route(12182, SlotAddr::Master)),
    ));
    assert!(routing == single_primary, "{routing:?}");
}
/// Per-slot replies for a keys-only command are scattered back into the original key order.
#[test]
fn test_combining_results_into_single_array_only_keys() {
    let bulk = |text: &str| Value::BulkString(text.as_bytes().to_vec().into());
    let per_slot_replies = vec![
        Value::Array(vec![Ok(Value::Nil), Ok(Value::Okay)]),
        Value::Array(vec![Ok(bulk("1")), Ok(bulk("4"))]),
        Value::Array(vec![
            Ok(Value::SimpleString("2".to_string())),
            Ok(Value::Int(3)),
        ]),
    ];
    let routes = [
        (Route(4813, SlotAddr::Master), vec![2, 3]),
        (Route(5061, SlotAddr::Master), vec![1, 4]),
        (Route(12182, SlotAddr::Master), vec![0, 5]),
    ];
    let combined = super::combine_and_sort_array_results(
        per_slot_replies,
        &routes,
        &MultiSlotArgPattern::KeysOnly,
    );
    let in_key_order = Value::Array(vec![
        Ok(Value::SimpleString("2".to_string())),
        Ok(bulk("1")),
        Ok(Value::Nil),
        Ok(Value::Okay),
        Ok(bulk("4")),
        Ok(Value::Int(3)),
    ]);
    assert_eq!(combined.unwrap(), in_key_order);
}
/// Per-slot replies for a key/value-pair command yield one entry per pair.
///
/// Pairs are identified by their argument indices; each entry is placed at its pair's
/// position.
#[test]
fn test_combining_results_into_single_array_key_value_paires() {
    let bulk = |text: &str| Value::BulkString(text.as_bytes().to_vec().into());
    let per_slot_replies = vec![
        Value::Array(vec![Ok(Value::Okay)]),
        Value::Array(vec![Ok(bulk("1")), Ok(Value::Nil)]),
    ];
    let routes = [
        (Route(1044, SlotAddr::Master), vec![2, 3]),
        (Route(12182, SlotAddr::Master), vec![0, 1, 4, 5]),
    ];
    let combined = super::combine_and_sort_array_results(
        per_slot_replies,
        &routes,
        &MultiSlotArgPattern::KeyValuePairs,
    );
    let in_pair_order = Value::Array(vec![Ok(bulk("1")), Ok(Value::Okay), Ok(Value::Nil)]);
    assert_eq!(combined.unwrap(), in_pair_order);
}
/// Per-slot replies for a keys-plus-trailing-argument command yield one entry per key.
///
/// The shared last argument (index 3) appears in every route and is not itself a key.
#[test]
fn test_combining_results_into_single_array_keys_and_path() {
    let bulk = |text: &str| Value::BulkString(text.as_bytes().to_vec().into());
    let per_slot_replies = vec![
        Value::Array(vec![Ok(Value::Okay)]),
        Value::Array(vec![Ok(bulk("1")), Ok(Value::Nil)]),
    ];
    let routes = [
        (Route(5061, SlotAddr::Master), vec![2, 3]),
        (Route(12182, SlotAddr::Master), vec![0, 1, 3]),
    ];
    let combined = super::combine_and_sort_array_results(
        per_slot_replies,
        &routes,
        &MultiSlotArgPattern::KeysAndLastArg,
    );
    let in_key_order = Value::Array(vec![Ok(bulk("1")), Ok(Value::Nil), Ok(Value::Okay)]);
    assert_eq!(combined.unwrap(), in_key_order);
}
/// Per-slot replies for a triple-based command yield one entry per three-argument group,
/// placed at that group's position.
#[test]
fn test_combining_results_into_single_array_key_with_two_arg_triples() {
    let bulk = |text: &str| Value::BulkString(text.as_bytes().to_vec().into());
    let per_slot_replies = vec![
        Value::Array(vec![Ok(Value::Okay)]),
        Value::Array(vec![Ok(bulk("1")), Ok(Value::Nil)]),
    ];
    let routes = [
        (Route(5061, SlotAddr::Master), vec![3, 4, 5]),
        (Route(12182, SlotAddr::Master), vec![0, 1, 2, 6, 7, 8]),
    ];
    let combined = super::combine_and_sort_array_results(
        per_slot_replies,
        &routes,
        &MultiSlotArgPattern::KeyWithTwoArgTriples,
    );
    let in_triple_order = Value::Array(vec![Ok(bulk("1")), Ok(Value::Okay), Ok(Value::Nil)]);
    assert_eq!(combined.unwrap(), in_triple_order);
}
/// `combine_map_results` has three expected behaviors:
/// - No input gives an empty map.
/// - Flat `[key, value, ...]` arrays from several shards merge, summing values that share a key.
/// - A non-array reply is an error.
#[test]
fn test_combine_map_results() {
    let bulk = |text: &str| Value::BulkString(text.as_bytes().to_vec().into());
    // Map entries come back in no particular order, so both sides are sorted by key bytes.
    let by_bulk_key = |left: &(Value, Value), right: &(Value, Value)| match (&left.0, &right.0) {
        (Value::BulkString(left_bytes), Value::BulkString(right_bytes)) => {
            left_bytes.cmp(right_bytes)
        }
        _ => std::cmp::Ordering::Equal,
    };

    assert_eq!(
        super::combine_map_results(vec![]).unwrap(),
        Value::Map(vec![])
    );

    let shard_replies = vec![
        Value::Array(vec![
            Ok(bulk("key1")),
            Ok(Value::Int(5)),
            Ok(bulk("key2")),
            Ok(Value::Int(10)),
        ]),
        Value::Array(vec![
            Ok(bulk("key1")),
            Ok(Value::Int(3)),
            Ok(bulk("key3")),
            Ok(Value::Int(15)),
        ]),
    ];
    let mut actual = match super::combine_map_results(shard_replies).unwrap() {
        Value::Map(entries) => entries,
        _ => panic!("Expected Map"),
    };
    let mut expected = vec![
        (bulk("key1"), Value::Int(8)),
        (bulk("key2"), Value::Int(10)),
        (bulk("key3"), Value::Int(15)),
    ];
    expected.sort_unstable_by(by_bulk_key);
    actual.sort_unstable_by(by_bulk_key);
    assert_eq!(actual, expected);

    assert!(super::combine_map_results(vec![Value::Int(5)]).is_err());
}
/// Test helper that builds a `ShardAddrs` from plain address strings.
///
/// `primary` becomes the primary address and `replicas` become the replica list, kept in the
/// given order.
fn create_shard_addrs(primary: &str, replicas: Vec<&str>) -> ShardAddrs {
    let replica_addrs = replicas
        .iter()
        .map(|addr| Arc::new(String::from(*addr)))
        .collect();
    let primary_addr = Arc::new(String::from(primary));
    ShardAddrs {
        primary: RwLock::new(primary_addr),
        replicas: RwLock::new(replica_addrs),
    }
}
/// Re-announcing the current primary reports `AlreadyPrimary` and must leave both the primary
/// and the replica list untouched.
#[test]
fn test_attempt_shard_role_update_already_primary() {
    let shard_addrs = create_shard_addrs("node1:6379", vec!["node2:6379", "node3:6379"]);
    let result = shard_addrs.attempt_shard_role_update(Arc::new("node1:6379".to_string()));
    assert_eq!(result, ShardUpdateResult::AlreadyPrimary);
    // A no-op result must not hide a state mutation.
    assert_eq!(shard_addrs.primary.read().unwrap().as_str(), "node1:6379");
    let replicas: Vec<String> = shard_addrs
        .replicas
        .read()
        .unwrap()
        .iter()
        .map(|r| r.as_str().to_owned())
        .collect();
    assert_eq!(replicas, vec!["node2:6379", "node3:6379"]);
}
/// Promoting a replica swaps roles.
///
/// `node2` becomes the primary and the old primary `node1` becomes a replica. The untouched
/// replica `node3` is kept, and `node2` is no longer listed as a replica.
#[test]
fn test_attempt_shard_role_update_promoted() {
    let shard_addrs = create_shard_addrs("node1:6379", vec!["node2:6379", "node3:6379"]);
    let result = shard_addrs.attempt_shard_role_update(Arc::new("node2:6379".to_string()));
    assert_eq!(result, ShardUpdateResult::Promoted);
    let primary = shard_addrs.primary.read().unwrap().clone();
    assert_eq!(primary.as_str(), "node2:6379");
    let replicas = shard_addrs.replicas.read().unwrap();
    assert_eq!(replicas.len(), 2);
    assert!(replicas.iter().any(|r| r.as_str() == "node1:6379"));
    // The length check alone would not catch dropping node3 or keeping node2 as a replica.
    assert!(replicas.iter().any(|r| r.as_str() == "node3:6379"));
    assert!(replicas.iter().all(|r| r.as_str() != "node2:6379"));
}
/// An address that is neither the primary nor a replica reports `NodeNotFound`.
///
/// The shard's primary and replica list must be left unchanged.
#[test]
fn test_attempt_shard_role_update_node_not_found() {
    let shard_addrs = create_shard_addrs("node1:6379", vec!["node2:6379", "node3:6379"]);
    let result = shard_addrs.attempt_shard_role_update(Arc::new("node4:6379".to_string()));
    assert_eq!(result, ShardUpdateResult::NodeNotFound);
    // A miss must not partially apply a role change.
    assert_eq!(shard_addrs.primary.read().unwrap().as_str(), "node1:6379");
    let replicas: Vec<String> = shard_addrs
        .replicas
        .read()
        .unwrap()
        .iter()
        .map(|r| r.as_str().to_owned())
        .collect();
    assert_eq!(replicas, vec!["node2:6379", "node3:6379"]);
}
/// `CLIENT LIST` (no key arguments) is expected to be routed to a random node.
#[test]
fn test_client_list_routing() {
    let mut client_list = cmd("CLIENT");
    client_list.arg("LIST");
    assert_eq!(
        RoutingInfo::for_routable(&client_list),
        Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
        "CLIENT LIST should be routed to a random node"
    );
}
/// Read-only classification of `SENTINEL` subcommands.
///
/// Checked both on raw command names and on built commands, via `is_readonly` and via the
/// command name extracted with `Routable::command`.
#[test]
fn test_is_read_only() {
    let read_only_names: [&[u8]; 5] = [
        b"SENTINEL MASTERS",
        b"SENTINEL MASTER",
        b"SENTINEL REPLICAS",
        b"SENTINEL GET-MASTER-ADDR-BY-NAME",
        b"SENTINEL CKQUORUM",
    ];
    for name in read_only_names {
        assert!(is_readonly_cmd(name));
    }
    assert!(!is_readonly_cmd(b"SENTINEL FAILOVER"));

    for (subcommand, read_only) in [
        ("MASTERS", true),
        ("GET-MASTER-ADDR-BY-NAME", true),
        ("FAILOVER", false),
    ] {
        let mut sentinel = cmd("SENTINEL");
        sentinel.arg(subcommand).arg("my_service");
        assert_eq!(is_readonly(&sentinel), read_only);
        assert_eq!(
            is_readonly_cmd(Routable::command(&sentinel).unwrap().as_slice()),
            read_only
        );
    }
}
}