Skip to main content

ferriskey/cluster/
routing.rs

1use rand::Rng;
2use strum_macros::Display;
3
4use crate::cluster::topology::get_slot;
5use crate::cmd::{Arg, Cmd};
6use crate::value::Value;
7use crate::value::{ErrorKind, Error, Result};
8use core::cmp::Ordering;
9use std::borrow::Cow;
10use std::cmp::min;
11use std::collections::HashMap;
12use std::iter::Once;
13use std::sync::Arc;
14use std::sync::{RwLock, RwLockWriteGuard};
15
/// Where a request should be retried after the cluster redirected it; the variants presumably
/// mirror the `MOVED` and `ASK` redirections.
#[derive(Clone)]
pub(crate) enum Redirect {
    /// Retry against the given address.
    Moved(String),
    /// `(addr, should_exec_asking)`: when `should_exec_asking` is `true`, the `ASKING` command
    /// is executed as part of `get_connection` before the request is retried on `addr`.
    Ask(String, bool),
}
22
/// Boolean operators used to fold integer responses element-wise.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum LogicalAggregateOp {
    /// Logical AND: a position is true only when it is true in every response.
    And,
    // `Or` is intentionally absent: nothing constructs it yet, and an unused variant would
    // trigger dead-code warnings.
}
30
/// Numeric operators used to fold integer responses into a single integer.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum AggregateOp {
    /// Keep the smallest value.
    Min,
    /// Add all values together.
    Sum,
    // `Max` is intentionally absent: nothing constructs it yet, and an unused variant would
    // trigger dead-code warnings.
}
40
/// Numeric operators applied position-by-position across array responses.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum ArrayAggregateOp {
    /// Keep the smallest value found at each array position.
    Min,
}
47
48/// Policy defining how to combine multiple responses into one.
49#[derive(Debug, Clone, Copy, PartialEq)]
50pub enum ResponsePolicy {
51    /// Wait for one request to succeed and return its results. Return error if all requests fail.
52    OneSucceeded,
53    /// Returns the first succeeded non-empty result; if all results are empty, returns `Nil`; otherwise, returns the last received error.
54    FirstSucceededNonEmptyOrAllEmpty,
55    /// Waits for all requests to succeed, and the returns one of the successes. Returns the error on the first received error.
56    AllSucceeded,
57    /// Aggregate success results according to a logical bitwise operator. Return error on any failed request or on a response that doesn't conform to 0 or 1.
58    AggregateLogical(LogicalAggregateOp),
59    /// Aggregate success results according to a numeric operator. Return error on any failed request or on a response that isn't an integer.
60    Aggregate(AggregateOp),
61    /// Aggregate array responses element-wise according to a numeric operator. Return error on any failed request or on a response that isn't an array of integers.
62    AggregateArray(ArrayAggregateOp),
63    /// Aggregate array responses into a single array. Return error on any failed request or on a response that isn't an array.
64    CombineArrays,
65    /// Handling is not defined by the Redis standard. Will receive a special case
66    Special,
67    /// Combines multiple map responses into a single map.
68    CombineMaps,
69}
70
71/// Defines whether a request should be routed to a single node, or multiple ones.
72#[derive(Debug, Clone, PartialEq)]
73pub enum RoutingInfo {
74    /// Route to single node
75    SingleNode(SingleNodeRoutingInfo),
76    /// Route to multiple nodes
77    MultiNode((MultipleNodeRoutingInfo, Option<ResponsePolicy>)),
78}
79
80/// Defines which single node should receive a request.
81#[derive(Debug, Clone, PartialEq)]
82pub enum SingleNodeRoutingInfo {
83    /// Route to any node at random
84    Random,
85    /// Route to any *primary* node
86    RandomPrimary,
87    /// Route to the node that matches the [Route]
88    SpecificNode(Route),
89    /// Route to the node with the given address.
90    ByAddress {
91        /// DNS hostname of the node
92        host: String,
93        /// port of the node
94        port: u16,
95    },
96}
97
98impl From<Option<Route>> for SingleNodeRoutingInfo {
99    fn from(value: Option<Route>) -> Self {
100        value
101            .map(SingleNodeRoutingInfo::SpecificNode)
102            .unwrap_or(SingleNodeRoutingInfo::Random)
103    }
104}
105
106/// Defines which collection of nodes should receive a request
107#[derive(Debug, Clone, PartialEq)]
108pub enum MultipleNodeRoutingInfo {
109    /// Route to all nodes in the clusters
110    AllNodes,
111    /// Route to all primaries in the cluster
112    AllMasters,
113    /// Routes the request to multiple slots.
114    /// This variant contains instructions for splitting a multi-slot command (e.g., MGET, MSET) into sub-commands.
115    /// Each tuple consists of a `Route` representing the target node for the subcommand,
116    /// and a vector of argument indices from the original command that should be copied to each subcommand.
117    /// The `MultiSlotArgPattern` specifies the pattern of the command’s arguments, indicating how they are organized
118    /// (e.g., only keys, key-value pairs, etc).
119    MultiSlot((Vec<(Route, Vec<usize>)>, MultiSlotArgPattern)),
120}
121
122/// Takes a routable and an iterator of indices, which is assued to be created from`MultipleNodeRoutingInfo::MultiSlot`,
123/// and returns a command with the arguments matching the indices.
124pub fn command_for_multi_slot_indices<'a, 'b>(
125    original_cmd: &'a impl Routable,
126    indices: impl Iterator<Item = &'b usize> + 'a,
127) -> Cmd
128where
129    'b: 'a,
130{
131    let mut new_cmd = Cmd::new();
132    let command_length = 1; // TODO - the +1 should change if we have multi-slot commands with 2 command words.
133    new_cmd.arg(original_cmd.arg_idx(0));
134    for index in indices {
135        new_cmd.arg(original_cmd.arg_idx(index + command_length));
136    }
137    new_cmd
138}
139
140/// Aggreagte numeric responses.
141pub fn aggregate(values: Vec<Value>, op: AggregateOp) -> Result<Value> {
142    let initial_value = match op {
143        AggregateOp::Min => i64::MAX,
144        AggregateOp::Sum => 0,
145    };
146    let result = values.into_iter().try_fold(initial_value, |acc, curr| {
147        let int = match curr {
148            Value::Int(int) => int,
149            _ => {
150                return Result::Err(
151                    (
152                        ErrorKind::TypeError,
153                        "expected array of integers as response",
154                    )
155                        .into(),
156                );
157            }
158        };
159        let acc = match op {
160            AggregateOp::Min => min(acc, int),
161            AggregateOp::Sum => acc + int,
162        };
163        Ok(acc)
164    })?;
165    Ok(Value::Int(result))
166}
167
168/// Aggreagte numeric responses by a boolean operator.
169pub fn logical_aggregate(values: Vec<Value>, op: LogicalAggregateOp) -> Result<Value> {
170    let initial_value = match op {
171        LogicalAggregateOp::And => true,
172    };
173    let results = values.into_iter().try_fold(Vec::new(), |acc, curr| {
174        let values = match curr {
175            Value::Array(values) => values,
176            _ => {
177                return Result::Err(
178                    (
179                        ErrorKind::TypeError,
180                        "expected array of integers as response",
181                    )
182                        .into(),
183                );
184            }
185        };
186        let mut acc = if acc.is_empty() {
187            vec![initial_value; values.len()]
188        } else {
189            acc
190        };
191        for (index, value) in values.into_iter().enumerate() {
192            let int = match value {
193                Ok(Value::Int(int)) => int,
194                _ => {
195                    return Err((
196                        ErrorKind::TypeError,
197                        "expected array of integers as response",
198                    )
199                        .into());
200                }
201            };
202            acc[index] = match op {
203                LogicalAggregateOp::And => acc[index] && (int > 0),
204            };
205        }
206        Ok(acc)
207    })?;
208    Ok(Value::Array(
209        results
210            .into_iter()
211            .map(|result| Ok(Value::Int(result as i64)))
212            .collect(),
213    ))
214}
215
216/// Aggregate array responses element-wise according to a numeric operator.
217pub fn aggregate_array(values: Vec<Value>, op: ArrayAggregateOp) -> Result<Value> {
218    let initial_value = match op {
219        ArrayAggregateOp::Min => i64::MAX,
220    };
221    let results = values.into_iter().try_fold(Vec::new(), |acc, curr| {
222        let values = match curr {
223            Value::Array(values) => values,
224            _ => {
225                return Result::Err(
226                    (
227                        ErrorKind::TypeError,
228                        "expected array of integers as response",
229                    )
230                        .into(),
231                );
232            }
233        };
234        let mut acc = if acc.is_empty() {
235            vec![initial_value; values.len()]
236        } else {
237            acc
238        };
239        for (index, value) in values.into_iter().enumerate() {
240            let int = match value {
241                Ok(Value::Int(int)) => int,
242                _ => {
243                    return Err((
244                        ErrorKind::TypeError,
245                        "expected array of integers as response",
246                    )
247                        .into());
248                }
249            };
250            acc[index] = match op {
251                ArrayAggregateOp::Min => min(acc[index], int),
252            };
253        }
254        Ok(acc)
255    })?;
256    Ok(Value::Array(results.into_iter().map(|i| Ok(Value::Int(i))).collect()))
257}
258/// Aggregate array responses into a single map.
259pub fn combine_map_results(values: Vec<Value>) -> Result<Value> {
260    let mut map: HashMap<Vec<u8>, i64> = HashMap::new();
261
262    for value in values {
263        match value {
264            Value::Array(elements) => {
265                let mut iter = elements.into_iter();
266
267                while let Some(Ok(key)) = iter.next() {
268                    if let Value::BulkString(key_bytes) = key {
269                        if let Some(Ok(Value::Int(value))) = iter.next() {
270                            *map.entry(key_bytes.to_vec()).or_insert(0) += value;
271                        } else {
272                            return Err((ErrorKind::TypeError, "expected integer value").into());
273                        }
274                    } else {
275                        return Err((ErrorKind::TypeError, "expected string key").into());
276                    }
277                }
278            }
279            _ => {
280                return Err((ErrorKind::TypeError, "expected array of values as response").into());
281            }
282        }
283    }
284
285    let result_vec: Vec<(Value, Value)> = map
286        .into_iter()
287        .map(|(k, v)| (Value::BulkString(bytes::Bytes::from(k)), Value::Int(v)))
288        .collect();
289
290    Ok(Value::Map(result_vec))
291}
292
293/// Aggregate array responses into a single array.
294pub fn combine_array_results(values: Vec<Value>) -> Result<Value> {
295    let mut results = Vec::new();
296
297    for value in values {
298        match value {
299            Value::Array(values) => results.extend(values),
300            _ => {
301                return Err((ErrorKind::TypeError, "expected array of values as response").into());
302            }
303        }
304    }
305
306    Ok(Value::Array(results))
307}
308
309// An iterator that yields `Cow<[usize]>` representing grouped result indices according to a specified argument pattern.
310// This type is used to combine multi-slot array responses.
311type MultiSlotResIdxIter<'a> = std::iter::Map<
312    std::slice::Iter<'a, (Route, Vec<usize>)>,
313    fn(&'a (Route, Vec<usize>)) -> Cow<'a, [usize]>,
314>;
315
316/// Generates an iterator that yields a vector of result indices for each slot within the final merged results array for a multi-slot command response.
317/// The indices are calculated based on the `args_pattern` and the positions of the arguments for each slot-specific request in the original multi-slot request,
318/// ensuring that the results are ordered according to the structure of the initial multi-slot command.
319///
320/// # Arguments
321/// * `route_arg_indices` - A reference to a vector where each element is a tuple containing a route and
322///   the corresponding argument indices for that route.
323/// * `args_pattern` - Specifies the argument pattern (e.g., `KeysOnly`, `KeyValuePairs`, ..), which defines how the indices are grouped for each slot.
324///
325/// # Returns
326/// An iterator yielding `Cow<[usize]>` with the grouped result indices based on the specified argument pattern.
327///
328/// /// For example, given the command `MSET foo bar foo2 bar2 {foo}foo3 bar3` with the `KeyValuePairs` pattern:
329/// - `route_arg_indices` would include:
330///   - Slot of "foo" with argument indices `[0, 1, 4, 5]` (where `{foo}foo3` hashes to the same slot as "foo" due to curly braces).
331///   - Slot of "foo2" with argument indices `[2, 3]`.
332/// - Using the `KeyValuePairs` pattern, each key-value pair contributes a single response, yielding three responses total.
333/// - Therefore, the iterator generated by this function would yield grouped result indices as follows:
334///   - Slot "foo" is mapped to `[0, 2]` in the final result order.
335///   - Slot "foo2" is mapped to `[1]`.
336fn calculate_multi_slot_result_indices<'a>(
337    route_arg_indices: &'a [(Route, Vec<usize>)],
338    args_pattern: &MultiSlotArgPattern,
339) -> Result<MultiSlotResIdxIter<'a>> {
340    let check_indices_input = |step_count: usize| {
341        for (_, indices) in route_arg_indices {
342            if indices.len() % step_count != 0 {
343                return Err(Error::from((
344                    ErrorKind::ClientError,
345                    "Invalid indices input detected",
346                    format!(
347                        "Expected argument pattern with tuples of size {step_count}, but found indices: {indices:?}"
348                    ),
349                )));
350            }
351        }
352        Ok(())
353    };
354
355    match args_pattern {
356        MultiSlotArgPattern::KeysOnly => Ok(route_arg_indices
357            .iter()
358            .map(|(_, indices)| Cow::Borrowed(indices))),
359        MultiSlotArgPattern::KeysAndLastArg => {
360            // The last index corresponds to the path, skip it
361            Ok(route_arg_indices
362                .iter()
363                .map(|(_, indices)| Cow::Borrowed(&indices[..indices.len() - 1])))
364        }
365        MultiSlotArgPattern::KeyWithTwoArgTriples => {
366            // For each triplet (key, path, value) we receive a single response.
367            // For example, for argument indices: [(_, [0,1,2]), (_, [3,4,5,9,10,11]), (_, [6,7,8])]
368            // The resulting grouped indices would be: [0], [1, 3], [2]
369            check_indices_input(3)?;
370            Ok(route_arg_indices.iter().map(|(_, indices)| {
371                Cow::Owned(
372                    indices
373                        .iter()
374                        .step_by(3)
375                        .map(|idx| idx / 3)
376                        .collect::<Vec<usize>>(),
377                )
378            }))
379        }
380        MultiSlotArgPattern::KeyValuePairs =>
381        // For each pair (key, value) we receive a single response.
382        // For example, for argument indices: [(_, [0,1]), (_, [2,3,6,7]), (_, [4,5])]
383        // The resulting grouped indices would be: [0], [1, 3], [2]
384        {
385            check_indices_input(2)?;
386            Ok(route_arg_indices.iter().map(|(_, indices)| {
387                Cow::Owned(
388                    indices
389                        .iter()
390                        .step_by(2)
391                        .map(|idx| idx / 2)
392                        .collect::<Vec<usize>>(),
393                )
394            }))
395        }
396    }
397}
398
399/// Merges the results of a multi-slot command from the `values` field, where each entry is expected to be an array of results.
400/// The combined results are ordered according to the sequence in which they appeared in the original command.
401///
402/// # Arguments
403///
404/// * `values` - A vector of `Value`s, where each `Value` is expected to be an array representing results
405///   from separate slots in a multi-slot command. Each `Value::Array` within `values` corresponds to
406///   the results associated with a specific slot, as indicated by `route_arg_indices`.
407///
408/// * `route_arg_indices` - A reference to a vector of tuples, where each tuple represents a route and a vector of
409///   argument indices associated with that route. The route indicates the slot, while the indices vector
410///   specifies the positions of arguments relevant to this slot. This is used to construct `sorting_order`,
411///   which guides the placement of results in the final array.
412///
413/// * `args_pattern` - Specifies the argument pattern (e.g., `KeysOnly`, `KeyValuePairs`, ...).
414///   The pattern defines how the argument indices are grouped for each slot and determines
415///   the ordering of results from `values` as they are placed in the final combined array.
416///
417/// # Returns
418///
419/// Returns a `Result<Value>` containing the final ordered array (`Value::Array`) of combined results.
420pub(crate) fn combine_and_sort_array_results(
421    values: Vec<Value>,
422    route_arg_indices: &[(Route, Vec<usize>)],
423    args_pattern: &MultiSlotArgPattern,
424) -> Result<Value> {
425    let result_indices = calculate_multi_slot_result_indices(route_arg_indices, args_pattern)?;
426    let mut results: Vec<Result<Value>> = Vec::new();
427    results.resize(
428        values.iter().fold(0, |acc, value| match value {
429            Value::Array(values) => values.len() + acc,
430            _ => 0,
431        }),
432        Ok(Value::Nil),
433    );
434    if values.len() != result_indices.len() {
435        return Err(Error::from((
436            ErrorKind::ClientError,
437            "Mismatch in the number of multi-slot results compared to the expected result count.",
438            format!(
439                "Expected: {:?}, Found: {:?}",
440                values.len(),
441                result_indices.len()
442            ),
443        )));
444    }
445
446    for (key_indices, value) in result_indices.into_iter().zip(values) {
447        match value {
448            Value::Array(values) => {
449                debug_assert_eq!(values.len(), key_indices.len());
450                for (index, value) in key_indices.iter().zip(values) {
451                    results[*index] = value;
452                }
453            }
454            _ => {
455                return Err((ErrorKind::TypeError, "expected array of values as response").into());
456            }
457        }
458    }
459
460    Ok(Value::Array(results))
461}
462
463fn get_route(is_readonly: bool, key: &[u8]) -> Route {
464    let slot = get_slot(key);
465    if is_readonly {
466        Route::new(slot, SlotAddr::ReplicaOptional)
467    } else {
468        Route::new(slot, SlotAddr::Master)
469    }
470}
471
/// Describes how the arguments of a multi-slot command are laid out, so the command can be
/// split per slot and its responses reassembled.
#[derive(Clone, Debug, PartialEq)]
pub enum MultiSlotArgPattern {
    /// Only keys, e.g. `MGET key1 key2`.
    KeysOnly,

    /// Each key is immediately followed by its value, e.g. `MSET key1 value1 key2 value2`.
    KeyValuePairs,

    /// A list of keys followed by one trailing argument shared by all of them,
    /// e.g. `JSON.MGET key1 key2 key3 path`.
    KeysAndLastArg,

    /// Each key is followed by two arguments of its own, forming key-argument-argument triples,
    /// e.g. `JSON.MSET key1 path1 value1 key2 path2 value2`.
    KeyWithTwoArgTriples,
}
492
493/// Takes the given `routable` and creates a multi-slot routing info.
494/// This is used for commands like MSET & MGET, where if the command's keys
495/// are hashed to multiple slots, the command should be split into sub-commands,
496/// each targetting a single slot. The results of these sub-commands are then
497/// usually reassembled using `combine_and_sort_array_results`. In order to do this,
498/// `MultipleNodeRoutingInfo::MultiSlot` contains the routes for each sub-command, and
499/// the indices in the final combined result for each result from the sub-command.
500///
501/// If all keys are routed to the same slot, there's no need to split the command,
502/// so a single node routing info will be returned.
503///
504/// # Arguments
505/// * `routable` - The command or structure containing key-related data that can be routed.
506/// * `cmd` - A byte slice representing the command name or opcode (e.g., `b"MGET"`).
507/// * `first_key_index` - The starting index in the command where the first key is located.
508/// * `args_pattern` - Specifies how keys and values are patterned in the command (e.g., `OnlyKeys`, `KeyValuePairs`).
509///
510/// # Returns
511/// `Some(RoutingInfo)` if routing info is created, indicating the command targets multiple slots or a single slot;
512/// `None` if no routing info could be derived.
513fn multi_shard<R>(
514    routable: &R,
515    cmd: &[u8],
516    first_key_index: usize,
517    args_pattern: MultiSlotArgPattern,
518) -> Option<RoutingInfo>
519where
520    R: Routable + ?Sized,
521{
522    let is_readonly = is_readonly_cmd(cmd);
523    let mut routes = HashMap::new();
524    let mut curr_arg_idx = 0;
525    let incr_add_next_arg = |arg_indices: &mut Vec<usize>, mut curr_arg_idx: usize| {
526        curr_arg_idx += 1;
527        // Ensure there's a value following the key
528        routable.arg_idx(curr_arg_idx)?;
529        arg_indices.push(curr_arg_idx);
530        Some(curr_arg_idx)
531    };
532    while let Some(arg) = routable.arg_idx(first_key_index + curr_arg_idx) {
533        let route = get_route(is_readonly, arg);
534        let arg_indices = routes.entry(route).or_insert(Vec::new());
535
536        arg_indices.push(curr_arg_idx);
537
538        match args_pattern {
539            MultiSlotArgPattern::KeysOnly => {} // no additional handling needed for keys-only commands
540            MultiSlotArgPattern::KeyValuePairs => {
541                // Increment to the value paired with the current key and add its index
542                curr_arg_idx = incr_add_next_arg(arg_indices, curr_arg_idx)?;
543            }
544            MultiSlotArgPattern::KeysAndLastArg => {
545                // Check if the command has more keys or if the next argument is a path
546                if routable
547                    .arg_idx(first_key_index + curr_arg_idx + 2)
548                    .is_none()
549                {
550                    // Last key reached; add the path argument index for each route and break
551                    let path_idx = curr_arg_idx + 1;
552                    for (_, arg_indices) in routes.iter_mut() {
553                        arg_indices.push(path_idx);
554                    }
555                    break;
556                }
557            }
558            MultiSlotArgPattern::KeyWithTwoArgTriples => {
559                // Increment to the first argument associated with the current key and add its index
560                curr_arg_idx = incr_add_next_arg(arg_indices, curr_arg_idx)?;
561                // Increment to the second argument associated with the current key and add its index
562                curr_arg_idx = incr_add_next_arg(arg_indices, curr_arg_idx)?;
563            }
564        }
565        curr_arg_idx += 1;
566    }
567
568    let mut routes: Vec<(Route, Vec<usize>)> = routes.into_iter().collect();
569    if routes.is_empty() {
570        return None;
571    }
572
573    Some(if routes.len() == 1 {
574        RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(routes.pop().unwrap().0))
575    } else {
576        RoutingInfo::MultiNode((
577            MultipleNodeRoutingInfo::MultiSlot((routes, args_pattern)),
578            ResponsePolicy::for_command(cmd),
579        ))
580    })
581}
582
583impl ResponsePolicy {
584    /// Parse the command for the matching response policy.
585    pub fn for_command(cmd: &[u8]) -> Option<ResponsePolicy> {
586        match cmd {
587            b"SCRIPT EXISTS" => Some(ResponsePolicy::AggregateLogical(LogicalAggregateOp::And)),
588
589            b"DBSIZE" | b"DEL" | b"EXISTS" | b"SLOWLOG LEN" | b"TOUCH" | b"UNLINK"
590            | b"LATENCY RESET" | b"PUBSUB NUMPAT" => {
591                Some(ResponsePolicy::Aggregate(AggregateOp::Sum))
592            }
593
594            b"WAIT" => Some(ResponsePolicy::Aggregate(AggregateOp::Min)),
595
596            b"WAITAOF" => Some(ResponsePolicy::AggregateArray(ArrayAggregateOp::Min)),
597
598            b"ACL SETUSER" | b"ACL DELUSER" | b"ACL SAVE" | b"AUTH" | b"CLIENT SETNAME"
599            | b"CLIENT SETINFO" | b"CONFIG SET" | b"CONFIG RESETSTAT" | b"CONFIG REWRITE"
600            | b"FLUSHALL" | b"FLUSHDB" | b"FUNCTION DELETE" | b"FUNCTION FLUSH"
601            | b"FUNCTION LOAD" | b"FUNCTION RESTORE" | b"MEMORY PURGE" | b"MSET" | b"JSON.MSET"
602            | b"PING" | b"SCRIPT FLUSH" | b"SCRIPT LOAD" | b"SELECT" | b"SLOWLOG RESET"
603            | b"UNWATCH" | b"WATCH" => Some(ResponsePolicy::AllSucceeded),
604
605            b"KEYS"
606            | b"FT._ALIASLIST"
607            | b"FT._LIST"
608            | b"MGET"
609            | b"JSON.MGET"
610            | b"SLOWLOG GET"
611            | b"PUBSUB CHANNELS"
612            | b"PUBSUB SHARDCHANNELS" => Some(ResponsePolicy::CombineArrays),
613
614            b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => Some(ResponsePolicy::CombineMaps),
615
616            b"FUNCTION KILL" | b"SCRIPT KILL" => Some(ResponsePolicy::OneSucceeded),
617
618            // This isn't based on response_tips, but on the discussion here - https://github.com/redis/redis/issues/12410
619            b"RANDOMKEY" => Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty),
620
621            b"LATENCY GRAPH" | b"LATENCY HISTOGRAM" | b"LATENCY HISTORY" | b"LATENCY DOCTOR"
622            | b"LATENCY LATEST" => Some(ResponsePolicy::Special),
623
624            b"FUNCTION STATS" => Some(ResponsePolicy::Special),
625
626            b"MEMORY MALLOC-STATS" | b"MEMORY DOCTOR" | b"MEMORY STATS" => {
627                Some(ResponsePolicy::Special)
628            }
629
630            b"INFO" => Some(ResponsePolicy::Special),
631
632            _ => None,
633        }
634    }
635}
636
637enum RouteBy {
638    AllNodes,
639    AllPrimaries,
640    FirstKey,
641    MultiShard(MultiSlotArgPattern),
642    Random,
643    SecondArg,
644    SecondArgAfterKeyCount,
645    SecondArgSlot,
646    StreamsIndex,
647    ThirdArg,
648    ThirdArgAfterKeyCount,
649    Undefined,
650}
651
652fn base_routing(cmd: &[u8]) -> RouteBy {
653    match cmd {
654        b"ACL SETUSER"
655        | b"ACL DELUSER"
656        | b"ACL SAVE"
657        | b"AUTH"
658        | b"CLIENT SETNAME"
659        | b"CLIENT SETINFO"
660        | b"SELECT"
661        | b"SLOWLOG GET"
662        | b"SLOWLOG LEN"
663        | b"SLOWLOG RESET"
664        | b"CONFIG SET"
665        | b"CONFIG RESETSTAT"
666        | b"CONFIG REWRITE"
667        | b"SCRIPT FLUSH"
668        | b"SCRIPT LOAD"
669        | b"LATENCY RESET"
670        | b"LATENCY GRAPH"
671        | b"LATENCY HISTOGRAM"
672        | b"LATENCY HISTORY"
673        | b"LATENCY DOCTOR"
674        | b"LATENCY LATEST"
675        | b"PUBSUB NUMPAT"
676        | b"PUBSUB CHANNELS"
677        | b"PUBSUB NUMSUB"
678        | b"PUBSUB SHARDCHANNELS"
679        | b"PUBSUB SHARDNUMSUB"
680        | b"SCRIPT KILL"
681        | b"FUNCTION KILL"
682        | b"FUNCTION STATS" => RouteBy::AllNodes,
683
684        b"DBSIZE"
685        | b"DEBUG"
686        | b"FLUSHALL"
687        | b"FLUSHDB"
688        | b"FT._ALIASLIST"
689        | b"FT._LIST"
690        | b"FUNCTION DELETE"
691        | b"FUNCTION FLUSH"
692        | b"FUNCTION LOAD"
693        | b"FUNCTION RESTORE"
694        | b"INFO"
695        | b"KEYS"
696        | b"MEMORY DOCTOR"
697        | b"MEMORY MALLOC-STATS"
698        | b"MEMORY PURGE"
699        | b"MEMORY STATS"
700        | b"PING"
701        | b"SCRIPT EXISTS"
702        | b"UNWATCH"
703        | b"WAIT"
704        | b"RANDOMKEY"
705        | b"WAITAOF" => RouteBy::AllPrimaries,
706
707        b"MGET" | b"DEL" | b"EXISTS" | b"UNLINK" | b"TOUCH" | b"WATCH" | b"SUBSCRIBE"
708        | b"PSUBSCRIBE" | b"SSUBSCRIBE" => RouteBy::MultiShard(MultiSlotArgPattern::KeysOnly),
709
710        b"MSET" => RouteBy::MultiShard(MultiSlotArgPattern::KeyValuePairs),
711        b"JSON.MGET" => RouteBy::MultiShard(MultiSlotArgPattern::KeysAndLastArg),
712        b"JSON.MSET" => RouteBy::MultiShard(MultiSlotArgPattern::KeyWithTwoArgTriples),
713        // TODO - special handling - b"SCAN"
714        b"SCAN" | b"SHUTDOWN" | b"SLAVEOF" | b"REPLICAOF" => RouteBy::Undefined,
715
716        b"BLMPOP" | b"BZMPOP" | b"EVAL" | b"EVALSHA" | b"EVALSHA_RO" | b"EVAL_RO" | b"FCALL"
717        | b"FCALL_RO" => RouteBy::ThirdArgAfterKeyCount,
718
719        b"BITOP"
720        | b"MEMORY USAGE"
721        | b"PFDEBUG"
722        | b"XGROUP CREATE"
723        | b"XGROUP CREATECONSUMER"
724        | b"XGROUP DELCONSUMER"
725        | b"XGROUP DESTROY"
726        | b"XGROUP SETID"
727        | b"XINFO CONSUMERS"
728        | b"XINFO GROUPS"
729        | b"XINFO STREAM"
730        | b"OBJECT ENCODING"
731        | b"OBJECT FREQ"
732        | b"OBJECT IDLETIME"
733        | b"OBJECT REFCOUNT"
734        | b"JSON.DEBUG" => RouteBy::SecondArg,
735
736        b"MIGRATE" => RouteBy::ThirdArg,
737
738        b"LMPOP" | b"SINTERCARD" | b"ZDIFF" | b"ZINTER" | b"ZINTERCARD" | b"ZMPOP" | b"ZUNION" => {
739            RouteBy::SecondArgAfterKeyCount
740        }
741
742        b"XREAD" | b"XREADGROUP" => RouteBy::StreamsIndex,
743
744        // keyless commands with more arguments, whose arguments might be wrongly taken to be keys.
745        // TODO - double check these, in order to find better ways to route some of them.
746        b"ACL DRYRUN"
747        | b"ACL GENPASS"
748        | b"ACL GETUSER"
749        | b"ACL HELP"
750        | b"ACL LIST"
751        | b"ACL LOG"
752        | b"ACL USERS"
753        | b"ACL WHOAMI"
754        | b"BGSAVE"
755        | b"CLIENT GETNAME"
756        | b"CLIENT GETREDIR"
757        | b"CLIENT ID"
758        | b"CLIENT INFO"
759        | b"CLIENT KILL"
760        | b"CLIENT LIST"
761        | b"CLIENT PAUSE"
762        | b"CLIENT REPLY"
763        | b"CLIENT TRACKINGINFO"
764        | b"CLIENT UNBLOCK"
765        | b"CLIENT UNPAUSE"
766        | b"CLUSTER COUNT-FAILURE-REPORTS"
767        | b"CLUSTER INFO"
768        | b"CLUSTER KEYSLOT"
769        | b"CLUSTER MEET"
770        | b"CLUSTER MYSHARDID"
771        | b"CLUSTER NODES"
772        | b"CLUSTER REPLICAS"
773        | b"CLUSTER RESET"
774        | b"CLUSTER SET-CONFIG-EPOCH"
775        | b"CLUSTER SHARDS"
776        | b"CLUSTER SLOTS"
777        | b"COMMAND COUNT"
778        | b"COMMAND GETKEYS"
779        | b"COMMAND LIST"
780        | b"COMMAND"
781        | b"CONFIG GET"
782        | b"ECHO"
783        | b"FUNCTION LIST"
784        | b"LASTSAVE"
785        | b"LOLWUT"
786        | b"MODULE LIST"
787        | b"MODULE LOAD"
788        | b"MODULE LOADEX"
789        | b"MODULE UNLOAD"
790        | b"READONLY"
791        | b"READWRITE"
792        | b"SAVE"
793        | b"SCRIPT SHOW"
794        | b"TFCALL"
795        | b"TFCALLASYNC"
796        | b"TFUNCTION DELETE"
797        | b"TFUNCTION LIST"
798        | b"TFUNCTION LOAD"
799        | b"TIME" => RouteBy::Random,
800
801        b"CLUSTER ADDSLOTS"
802        | b"CLUSTER COUNTKEYSINSLOT"
803        | b"CLUSTER DELSLOTS"
804        | b"CLUSTER DELSLOTSRANGE"
805        | b"CLUSTER GETKEYSINSLOT"
806        | b"CLUSTER SETSLOT" => RouteBy::SecondArgSlot,
807
808        _ => RouteBy::FirstKey,
809    }
810}
811
812impl RoutingInfo {
813    /// Returns true if the `cmd` should be routed to all nodes.
814    pub fn is_all_nodes(cmd: &[u8]) -> bool {
815        matches!(base_routing(cmd), RouteBy::AllNodes)
816    }
817
818    /// Returns true if the `cmd` is a key-based command that triggers MOVED errors.
819    /// A key-based command is one that will be accepted only by the slot owner,
820    /// while other nodes will respond with a MOVED error redirecting to the relevant primary owner.
821    pub fn is_key_routing_command(cmd: &[u8]) -> bool {
822        match base_routing(cmd) {
823            RouteBy::FirstKey
824            | RouteBy::SecondArg
825            | RouteBy::ThirdArg
826            | RouteBy::SecondArgAfterKeyCount
827            | RouteBy::ThirdArgAfterKeyCount
828            | RouteBy::SecondArgSlot
829            | RouteBy::StreamsIndex
830            | RouteBy::MultiShard(_) => {
831                if matches!(cmd, b"SPUBLISH") {
832                    // SPUBLISH does not return MOVED errors within the slot's shard. This means that even if READONLY wasn't sent to a replica,
833                    // executing SPUBLISH FOO BAR on that replica will succeed. This behavior differs from true key-based commands,
834                    // such as SET FOO BAR, where a non-readonly replica would return a MOVED error if READONLY is off.
835                    // Consequently, SPUBLISH does not meet the requirement of being a command that triggers MOVED errors.
836                    // TODO: remove this when PRIMARY_PREFERRED route for SPUBLISH is added
837                    false
838                } else {
839                    true
840                }
841            }
842            RouteBy::AllNodes | RouteBy::AllPrimaries | RouteBy::Random | RouteBy::Undefined => {
843                false
844            }
845        }
846    }
847
848    /// Returns the routing info for `r`.
849    pub fn for_routable<R>(r: &R) -> Option<RoutingInfo>
850    where
851        R: Routable + ?Sized,
852    {
853        let cmd = &r.command()?[..];
854        match base_routing(cmd) {
855            RouteBy::AllNodes => Some(RoutingInfo::MultiNode((
856                MultipleNodeRoutingInfo::AllNodes,
857                ResponsePolicy::for_command(cmd),
858            ))),
859
860            RouteBy::AllPrimaries => Some(RoutingInfo::MultiNode((
861                MultipleNodeRoutingInfo::AllMasters,
862                ResponsePolicy::for_command(cmd),
863            ))),
864
865            RouteBy::MultiShard(arg_pattern) => multi_shard(r, cmd, 1, arg_pattern),
866
867            RouteBy::Random => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
868
869            RouteBy::ThirdArgAfterKeyCount => {
870                let key_count = r
871                    .arg_idx(2)
872                    .and_then(|x| std::str::from_utf8(x).ok())
873                    .and_then(|x| x.parse::<u64>().ok())?;
874                if key_count == 0 {
875                    Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
876                } else {
877                    r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key))
878                }
879            }
880
881            RouteBy::SecondArg => r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key)),
882
883            RouteBy::ThirdArg => r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key)),
884
885            RouteBy::SecondArgAfterKeyCount => {
886                let key_count = r
887                    .arg_idx(1)
888                    .and_then(|x| std::str::from_utf8(x).ok())
889                    .and_then(|x| x.parse::<u64>().ok())?;
890                if key_count == 0 {
891                    Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
892                } else {
893                    r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key))
894                }
895            }
896
897            RouteBy::StreamsIndex => {
898                let streams_position = r.position(b"STREAMS")?;
899                r.arg_idx(streams_position + 1)
900                    .map(|key| RoutingInfo::for_key(cmd, key))
901            }
902
903            RouteBy::SecondArgSlot => r
904                .arg_idx(2)
905                .and_then(|arg| std::str::from_utf8(arg).ok())
906                .and_then(|slot| slot.parse::<u16>().ok())
907                .map(|slot| {
908                    RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new(
909                        slot,
910                        SlotAddr::Master,
911                    )))
912                }),
913
914            RouteBy::FirstKey => match r.arg_idx(1) {
915                Some(key) => Some(RoutingInfo::for_key(cmd, key)),
916                None => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
917            },
918
919            RouteBy::Undefined => None,
920        }
921    }
922
923    fn for_key(cmd: &[u8], key: &[u8]) -> RoutingInfo {
924        RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(get_route(
925            is_readonly_cmd(cmd),
926            key,
927        )))
928    }
929}
930
931/// Returns true if the given `routable` represents a readonly command.
932pub fn is_readonly(routable: &impl Routable) -> bool {
933    match routable.command() {
934        Some(cmd) => is_readonly_cmd(cmd.as_slice()),
935        None => false,
936    }
937}
938
/// Returns `true` if the given `cmd` is a readonly command.
///
/// `cmd` is expected in the normalized form produced by `Routable::command`: ASCII upper-case,
/// with the sub-command of container commands joined by a single space (e.g. `b"XINFO GROUPS"`).
/// Matching is exact and case-sensitive. `RoutingInfo::for_key` passes this result to
/// `get_route` when building a keyed route.
pub fn is_readonly_cmd(cmd: &[u8]) -> bool {
    const READONLY_COMMANDS: &[&[u8]] = &[
        b"ACL CAT",
        b"ACL DELUSER",
        b"ACL DRYRUN",
        b"ACL GENPASS",
        b"ACL GETUSER",
        b"ACL HELP",
        b"ACL LIST",
        b"ACL LOAD",
        b"ACL LOG",
        b"ACL SAVE",
        b"ACL SETUSER",
        b"ACL USERS",
        b"ACL WHOAMI",
        b"AUTH",
        b"BGREWRITEAOF",
        b"BGSAVE",
        b"BITCOUNT",
        b"BITFIELD_RO",
        b"BITPOS",
        b"CLIENT ID",
        b"CLIENT CACHING",
        b"CLIENT CAPA",
        b"CLIENT GETNAME",
        b"CLIENT GETREDIR",
        b"CLIENT HELP",
        b"CLIENT INFO",
        b"CLIENT KILL",
        b"CLIENT LIST",
        b"CLIENT NO-EVICT",
        b"CLIENT NO-TOUCH",
        b"CLIENT PAUSE",
        b"CLIENT REPLY",
        b"CLIENT SETINFO",
        b"CLIENT SETNAME",
        b"CLIENT TRACKING",
        b"CLIENT TRACKINGINFO",
        b"CLIENT UNBLOCK",
        b"CLIENT UNPAUSE",
        b"CLUSTER COUNT-FAILURE-REPORTS",
        b"CLUSTER COUNTKEYSINSLOT",
        b"CLUSTER FAILOVER",
        b"CLUSTER GETKEYSINSLOT",
        b"CLUSTER HELP",
        b"CLUSTER INFO",
        b"CLUSTER KEYSLOT",
        b"CLUSTER LINKS",
        b"CLUSTER MYID",
        b"CLUSTER MYSHARDID",
        b"CLUSTER NODES",
        b"CLUSTER REPLICATE",
        b"CLUSTER SAVECONFIG",
        b"CLUSTER SHARDS",
        b"CLUSTER SLOTS",
        b"COMMAND COUNT",
        b"COMMAND DOCS",
        b"COMMAND GETKEYS",
        b"COMMAND GETKEYSANDFLAGS",
        b"COMMAND HELP",
        b"COMMAND INFO",
        b"COMMAND LIST",
        b"CONFIG GET",
        b"CONFIG HELP",
        b"CONFIG RESETSTAT",
        b"CONFIG REWRITE",
        b"CONFIG SET",
        b"DBSIZE",
        b"DUMP",
        b"ECHO",
        b"EVAL_RO",
        b"EVALSHA_RO",
        b"EXISTS",
        b"EXPIRETIME",
        b"FCALL_RO",
        b"FT.AGGREGATE",
        b"FT.EXPLAIN",
        b"FT.EXPLAINCLI",
        b"FT.INFO",
        b"FT.PROFILE",
        b"FT.SEARCH",
        b"FT._ALIASLIST",
        b"FT._LIST",
        b"FUNCTION DUMP",
        b"FUNCTION HELP",
        b"FUNCTION KILL",
        b"FUNCTION LIST",
        b"FUNCTION STATS",
        b"GEODIST",
        b"GEOHASH",
        b"GEOPOS",
        b"GEORADIUSBYMEMBER_RO",
        b"GEORADIUS_RO",
        b"GEOSEARCH",
        b"GET",
        b"GETBIT",
        b"GETRANGE",
        b"HELLO",
        b"HEXISTS",
        b"HGET",
        b"HGETALL",
        b"HKEYS",
        b"HLEN",
        b"HMGET",
        b"HRANDFIELD",
        b"HSCAN",
        b"HSTRLEN",
        b"HVALS",
        b"JSON.ARRINDEX",
        b"JSON.ARRLEN",
        b"JSON.DEBUG",
        b"JSON.GET",
        b"JSON.OBJLEN",
        b"JSON.OBJKEYS",
        b"JSON.MGET",
        b"JSON.RESP",
        b"JSON.STRLEN",
        b"JSON.TYPE",
        b"INFO",
        b"KEYS",
        b"LASTSAVE",
        b"LATENCY DOCTOR",
        b"LATENCY GRAPH",
        b"LATENCY HELP",
        b"LATENCY HISTOGRAM",
        b"LATENCY HISTORY",
        b"LATENCY LATEST",
        b"LATENCY RESET",
        b"LCS",
        b"LINDEX",
        b"LLEN",
        b"LOLWUT",
        b"LPOS",
        b"LRANGE",
        b"MEMORY DOCTOR",
        b"MEMORY HELP",
        b"MEMORY MALLOC-STATS",
        b"MEMORY PURGE",
        b"MEMORY STATS",
        b"MEMORY USAGE",
        b"MGET",
        b"MODULE HELP",
        b"MODULE LIST",
        b"MODULE LOAD",
        b"MODULE LOADEX",
        b"MODULE UNLOAD",
        b"OBJECT ENCODING",
        b"OBJECT FREQ",
        b"OBJECT HELP",
        b"OBJECT IDLETIME",
        b"OBJECT REFCOUNT",
        b"PEXPIRETIME",
        b"PFCOUNT",
        b"PING",
        b"PTTL",
        b"PUBLISH",
        b"PUBSUB CHANNELS",
        b"PUBSUB HELP",
        b"PUBSUB NUMPAT",
        b"PUBSUB NUMSUB",
        b"PUBSUB SHARDCHANNELS",
        b"PUBSUB SHARDNUMSUB",
        b"RANDOMKEY",
        b"REPLICAOF",
        b"RESET",
        b"ROLE",
        b"SAVE",
        b"SCAN",
        b"SCARD",
        b"SCRIPT DEBUG",
        b"SCRIPT EXISTS",
        b"SCRIPT FLUSH",
        b"SCRIPT KILL",
        b"SCRIPT LOAD",
        b"SCRIPT SHOW",
        b"SDIFF",
        b"SELECT",
        b"SENTINEL GET-MASTER-ADDR-BY-NAME",
        b"SENTINEL MASTER",
        b"SENTINEL MASTERS",
        b"SENTINEL REPLICAS",
        b"SENTINEL CKQUORUM",
        b"SHUTDOWN",
        b"SINTER",
        b"SINTERCARD",
        b"SISMEMBER",
        b"SMEMBERS",
        b"SMISMEMBER",
        b"SLOWLOG GET",
        b"SLOWLOG HELP",
        b"SLOWLOG LEN",
        b"SLOWLOG RESET",
        b"SORT_RO",
        b"SPUBLISH",
        b"SRANDMEMBER",
        b"SSCAN",
        b"SSUBSCRIBE",
        b"STRLEN",
        b"SUBSCRIBE",
        b"SUBSTR",
        b"SUNION",
        b"SUNSUBSCRIBE",
        b"TIME",
        b"TOUCH",
        b"TTL",
        b"TYPE",
        b"UNSUBSCRIBE",
        b"XINFO CONSUMERS",
        b"XINFO GROUPS",
        b"XINFO HELP",
        b"XINFO STREAM",
        b"XLEN",
        b"XPENDING",
        b"XRANGE",
        b"XREAD",
        b"XREVRANGE",
        b"ZCARD",
        b"ZCOUNT",
        b"ZDIFF",
        b"ZINTER",
        b"ZINTERCARD",
        b"ZLEXCOUNT",
        b"ZMSCORE",
        b"ZRANDMEMBER",
        b"ZRANGE",
        b"ZRANGEBYLEX",
        b"ZRANGEBYSCORE",
        b"ZRANK",
        b"ZREVRANGE",
        b"ZREVRANGEBYLEX",
        b"ZREVRANGEBYSCORE",
        b"ZREVRANK",
        b"ZSCAN",
        b"ZSCORE",
        b"ZUNION",
    ];

    READONLY_COMMANDS.iter().any(|&known| known == cmd)
}
1178
/// Objects that implement this trait define a request that can be routed by a cluster client to
/// different nodes in the cluster.
pub trait Routable {
    /// Returns the normalized command name, or `None` if there is no first argument.
    ///
    /// The first argument is upper-cased (ASCII only). For container commands whose meaning
    /// depends on a sub-command (`CLIENT`, `CLUSTER`, `XINFO`, ...), the second argument, when
    /// present, is upper-cased too and appended after a single space, producing names such as
    /// `b"CLIENT SETNAME"`. All other commands yield only their first argument.
    fn command(&self) -> Option<Vec<u8>> {
        let mut name = self.arg_idx(0)?.to_ascii_uppercase();

        // Commands classified by "<COMMAND> <SUBCOMMAND>" in the routing and readonly tables.
        // `TFUNCTION` is listed so that the `TFUNCTION DELETE`/`LIST`/`LOAD` routing entries can
        // match; without it those commands were routed as if "LIST"/"LOAD"/"DELETE" were keys.
        let has_subcommand = matches!(
            name.as_slice(),
            b"XGROUP"
                | b"OBJECT"
                | b"SLOWLOG"
                | b"FUNCTION"
                | b"MODULE"
                | b"COMMAND"
                | b"PUBSUB"
                | b"CONFIG"
                | b"MEMORY"
                | b"XINFO"
                | b"CLIENT"
                | b"ACL"
                | b"SCRIPT"
                | b"CLUSTER"
                | b"LATENCY"
                | b"SENTINEL"
                | b"TFUNCTION"
        );
        if !has_subcommand {
            return Some(name);
        }

        if let Some(subcommand) = self.arg_idx(1) {
            name.reserve(subcommand.len() + 1);
            name.push(b' ');
            name.extend(subcommand.iter().map(u8::to_ascii_uppercase));
        }
        Some(name)
    }

    /// Returns a reference to the data for the argument at `idx`.
    fn arg_idx(&self, idx: usize) -> Option<&[u8]>;

    /// Returns index of argument that matches `candidate`, if it exists
    fn position(&self, candidate: &[u8]) -> Option<usize>;
}
1214
1215impl Routable for Cmd {
1216    fn arg_idx(&self, idx: usize) -> Option<&[u8]> {
1217        self.arg_idx(idx)
1218    }
1219
1220    fn position(&self, candidate: &[u8]) -> Option<usize> {
1221        self.args_iter().position(|a| match a {
1222            Arg::Simple(d) => d.eq_ignore_ascii_case(candidate),
1223            _ => false,
1224        })
1225    }
1226}
1227
1228impl Routable for Value {
1229    fn arg_idx(&self, idx: usize) -> Option<&[u8]> {
1230        match self {
1231            Value::Array(args) => match args.get(idx) {
1232                Some(Ok(Value::BulkString(data))) => Some(&data[..]),
1233                _ => None,
1234            },
1235            _ => None,
1236        }
1237    }
1238
1239    fn position(&self, candidate: &[u8]) -> Option<usize> {
1240        match self {
1241            Value::Array(args) => args.iter().position(|a| match a {
1242                Ok(Value::BulkString(d)) => d.eq_ignore_ascii_case(candidate),
1243                _ => false,
1244            }),
1245            _ => None,
1246        }
1247    }
1248}
1249
/// A range of hash slots (`start`/`end` bounds) together with the addresses of the master and
/// the replicas serving it.
#[derive(Debug, Hash, Clone)]
pub(crate) struct Slot {
    pub(crate) start: u16,
    pub(crate) end: u16,
    pub(crate) master: String,
    pub(crate) replicas: Vec<String>,
}

impl Slot {
    /// Creates a slot range with bounds `s` and `e`, master `m` and replicas `r`.
    pub fn new(s: u16, e: u16, m: String, r: Vec<String>) -> Self {
        Slot {
            start: s,
            end: e,
            master: m,
            replicas: r,
        }
    }

    /// Address of the master node for this slot range.
    #[allow(dead_code)] // used in tests
    pub(crate) fn master(&self) -> &str {
        &self.master
    }

    /// Addresses of the replica nodes for this slot range.
    #[cfg(test)]
    pub fn replicas(&self) -> &[String] {
        self.replicas.as_slice()
    }
}
1278
/// Which node of a slot's shard a request should be routed to, assuming read from replica is
/// enabled.
///
/// `Display` is derived through `strum_macros`.
#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash, Display)]
pub enum SlotAddr {
    /// The request must be routed to the primary node.
    Master,
    /// The request may be routed to a replica node.
    /// For example, a GET command can be routed either to a replica or to the primary.
    ReplicaOptional,
    /// The request must be routed to a replica node, if one exists.
    /// For example, by user-requested routing.
    ReplicaRequired,
}
1291
/// Outcome of [`ShardAddrs::attempt_shard_role_update`].
///
/// Tells whether the requested node was already the shard's primary, was promoted from replica
/// to primary, or is not part of the shard at all.
#[derive(PartialEq, Debug)]
pub(crate) enum ShardUpdateResult {
    /// The node is already the primary for the shard; nothing was changed.
    AlreadyPrimary,
    /// The node was found among the replicas and swapped into the primary position; the
    /// previous primary took its place in the replica list.
    Promoted,
    /// The node is neither the current primary nor one of the shard's replicas.
    NodeNotFound,
}
1307
/// `expect` message used when read-locking a `ShardAddrs` field fails (std's `RwLock` reports an
/// error only when the lock is poisoned).
const READ_LK_ERR_SHARDADDRS: &str = "Failed to acquire read lock for ShardAddrs";
/// `expect` message used when write-locking a `ShardAddrs` field fails (std's `RwLock` reports
/// an error only when the lock is poisoned).
const WRITE_LK_ERR_SHARDADDRS: &str = "Failed to acquire write lock for ShardAddrs";
/// The node addresses of a single shard: its primary and its replicas (possibly none).
///
/// This is a simplified version of [`Slot`] that keeps only the addresses (no slot range),
/// to avoid the need to choose a replica each time a command is executed.
///
/// Each field sits behind its own `RwLock` so roles can be updated in place through a shared
/// reference (see `attempt_shard_role_update` and `remove_replica`). Code paths in this module
/// that hold both locks acquire `primary` before `replicas`.
#[derive(Debug)]
pub struct ShardAddrs {
    /// Address of the shard's primary node.
    primary: RwLock<Arc<String>>,
    /// Addresses of the shard's replica nodes, in insertion/promotion order.
    replicas: RwLock<Vec<Arc<String>>>,
}
1319
1320impl PartialEq for ShardAddrs {
1321    fn eq(&self, other: &Self) -> bool {
1322        let self_primary = self.primary.read().expect(READ_LK_ERR_SHARDADDRS);
1323        let other_primary = other.primary.read().expect(READ_LK_ERR_SHARDADDRS);
1324
1325        let self_replicas = self.replicas.read().expect(READ_LK_ERR_SHARDADDRS);
1326        let other_replicas = other.replicas.read().expect(READ_LK_ERR_SHARDADDRS);
1327
1328        *self_primary == *other_primary && *self_replicas == *other_replicas
1329    }
1330}
1331
1332impl Eq for ShardAddrs {}
1333
1334impl PartialOrd for ShardAddrs {
1335    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1336        Some(self.cmp(other))
1337    }
1338}
1339
1340impl Ord for ShardAddrs {
1341    fn cmp(&self, other: &Self) -> Ordering {
1342        let self_primary = self.primary.read().expect(READ_LK_ERR_SHARDADDRS);
1343        let other_primary = other.primary.read().expect(READ_LK_ERR_SHARDADDRS);
1344
1345        let primary_cmp = self_primary.cmp(&other_primary);
1346        if primary_cmp == Ordering::Equal {
1347            let self_replicas = self.replicas.read().expect(READ_LK_ERR_SHARDADDRS);
1348            let other_replicas = other.replicas.read().expect(READ_LK_ERR_SHARDADDRS);
1349            return self_replicas.cmp(&other_replicas);
1350        }
1351
1352        primary_cmp
1353    }
1354}
1355
1356impl ShardAddrs {
1357    pub(crate) fn new(primary: Arc<String>, replicas: Vec<Arc<String>>) -> Self {
1358        let primary = RwLock::new(primary);
1359        let replicas = RwLock::new(replicas);
1360        Self { primary, replicas }
1361    }
1362
1363    pub(crate) fn new_with_primary(primary: Arc<String>) -> Self {
1364        Self::new(primary, Vec::default())
1365    }
1366
1367    /// Returns the address of the primary node for this shard.
1368    pub fn primary(&self) -> Arc<String> {
1369        self.primary.read().expect(READ_LK_ERR_SHARDADDRS).clone()
1370    }
1371
1372    pub(crate) fn replicas(&self) -> std::sync::RwLockReadGuard<'_, Vec<Arc<String>>> {
1373        self.replicas.read().expect(READ_LK_ERR_SHARDADDRS)
1374    }
1375
1376    /// Attempts to update the shard roles based on the provided `new_primary`.
1377    ///
1378    /// This function evaluates whether the specified `new_primary` node is already
1379    /// the primary, a replica that can be promoted to primary, or a node not present
1380    /// in the shard. It handles three scenarios:
1381    ///
1382    /// 1. **Already Primary**: If the `new_primary` is already the current primary,
1383    ///    the function returns `ShardUpdateResult::AlreadyPrimary` and no changes are made.
1384    ///
1385    /// 2. **Promoted**: If the `new_primary` is found in the list of replicas, it is promoted
1386    ///    to primary by swapping it with the current primary, and the function returns
1387    ///    `ShardUpdateResult::Promoted`.
1388    ///
1389    /// 3. **Node Not Found**: If the `new_primary` is neither the current primary nor a replica,
1390    ///    the function returns `ShardUpdateResult::NodeNotFound` to indicate that the node is
1391    ///    not part of the current shard.
1392    ///
1393    /// # Arguments:
1394    /// * `new_primary` - Representing the node to be promoted or checked.
1395    ///
1396    /// # Returns:
1397    /// * `ShardUpdateResult` - The result of the role update operation.
1398    pub(crate) fn attempt_shard_role_update(&self, new_primary: Arc<String>) -> ShardUpdateResult {
1399        let mut primary_lock = self.primary.write().expect(WRITE_LK_ERR_SHARDADDRS);
1400        let mut replicas_lock = self.replicas.write().expect(WRITE_LK_ERR_SHARDADDRS);
1401
1402        // If the new primary is already the current primary, return early.
1403        if *primary_lock == new_primary {
1404            return ShardUpdateResult::AlreadyPrimary;
1405        }
1406
1407        // If the new primary is found among replicas, swap it with the current primary.
1408        if let Some(replica_idx) = Self::replica_index(&replicas_lock, new_primary.clone()) {
1409            std::mem::swap(&mut *primary_lock, &mut replicas_lock[replica_idx]);
1410            return ShardUpdateResult::Promoted;
1411        }
1412
1413        // If the new primary isn't part of the shard.
1414        ShardUpdateResult::NodeNotFound
1415    }
1416
1417    fn replica_index(
1418        replicas: &RwLockWriteGuard<'_, Vec<Arc<String>>>,
1419        target_replica: Arc<String>,
1420    ) -> Option<usize> {
1421        replicas
1422            .iter()
1423            .position(|curr_replica| **curr_replica == *target_replica)
1424    }
1425
1426    /// Returns true if the given address is any member of this shard (primary or replica).
1427    pub(crate) fn is_member(&self, addr: &str) -> bool {
1428        if self.primary.read().expect(READ_LK_ERR_SHARDADDRS).as_str() == addr {
1429            return true;
1430        }
1431        self.replicas
1432            .read()
1433            .expect(READ_LK_ERR_SHARDADDRS)
1434            .iter()
1435            .any(|r| r.as_str() == addr)
1436    }
1437
1438    /// Removes the specified `replica_to_remove` from the shard's replica list if it exists.
1439    pub(crate) fn remove_replica(&self, replica_to_remove: Arc<String>) -> Result<()> {
1440        let mut replicas_lock = self.replicas.write().expect(WRITE_LK_ERR_SHARDADDRS);
1441        if let Some(index) = Self::replica_index(&replicas_lock, replica_to_remove.clone()) {
1442            replicas_lock.remove(index);
1443            Ok(())
1444        } else {
1445            Err(Error::from((
1446                ErrorKind::ClientError,
1447                "Couldn't remove replica",
1448                format!("Replica {replica_to_remove:?} not found"),
1449            )))
1450        }
1451    }
1452}
1453
1454impl IntoIterator for &ShardAddrs {
1455    type Item = Arc<String>;
1456    type IntoIter = std::iter::Chain<Once<Arc<String>>, std::vec::IntoIter<Arc<String>>>;
1457
1458    fn into_iter(self) -> Self::IntoIter {
1459        let primary = self.primary.read().expect(READ_LK_ERR_SHARDADDRS).clone();
1460        let replicas = self.replicas.read().expect(READ_LK_ERR_SHARDADDRS).clone();
1461
1462        std::iter::once(primary).chain(replicas)
1463    }
1464}
1465
1466/// Defines the slot and the [`SlotAddr`] to which
1467/// a command should be sent
1468#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)]
1469pub struct Route(u16, SlotAddr);
1470
1471impl Route {
1472    /// Returns a new Route.
1473    pub fn new(slot: u16, slot_addr: SlotAddr) -> Self {
1474        Self(slot, slot_addr)
1475    }
1476
1477    /// Returns the slot number of the route.
1478    pub fn slot(&self) -> u16 {
1479        self.0
1480    }
1481
1482    /// Returns the slot address of the route.
1483    pub fn slot_addr(&self) -> SlotAddr {
1484        self.1
1485    }
1486
1487    /// Returns a new Route for a random primary node
1488    pub fn new_random_primary() -> Self {
1489        Self::new(random_slot(), SlotAddr::Master)
1490    }
1491}
1492
1493/// Choose a random slot from `0..SLOT_SIZE` (excluding)
1494fn random_slot() -> u16 {
1495    let mut rng = rand::rng();
1496    rng.random_range(0..crate::cluster::topology::SLOT_SIZE)
1497}
1498
1499#[cfg(test)]
1500mod tests_routing {
1501    use super::{
1502        AggregateOp, MultiSlotArgPattern, MultipleNodeRoutingInfo, ResponsePolicy, Route,
1503        RoutingInfo, ShardAddrs, SingleNodeRoutingInfo, SlotAddr, command_for_multi_slot_indices,
1504    };
1505    use crate::cluster::routing::{Routable, ShardUpdateResult, is_readonly, is_readonly_cmd};
1506    use crate::cluster::topology::slot;
1507    use crate::cmd::cmd;
1508    use crate::protocol::parser::parse_valkey_value;
1509    use crate::value::Value;
1510    use core::panic;
1511    use std::sync::{Arc, RwLock};
1512
1513    #[test]
1514    fn test_routing_info_mixed_capatalization() {
1515        let mut upper = cmd("XREAD");
1516        upper.arg("STREAMS").arg("foo").arg(0);
1517
1518        let mut lower = cmd("xread");
1519        lower.arg("streams").arg("foo").arg(0);
1520
1521        assert_eq!(
1522            RoutingInfo::for_routable(&upper).unwrap(),
1523            RoutingInfo::for_routable(&lower).unwrap()
1524        );
1525
1526        let mut mixed = cmd("xReAd");
1527        mixed.arg("StReAmS").arg("foo").arg(0);
1528
1529        assert_eq!(
1530            RoutingInfo::for_routable(&lower).unwrap(),
1531            RoutingInfo::for_routable(&mixed).unwrap()
1532        );
1533    }
1534
1535    #[test]
1536    fn test_routing_info() {
1537        let mut test_cmds = vec![];
1538
1539        // RoutingInfo::AllMasters
1540        let mut test_cmd = cmd("FLUSHALL");
1541        test_cmd.arg("");
1542        test_cmds.push(test_cmd);
1543
1544        // RoutingInfo::AllNodes
1545        test_cmd = cmd("ECHO");
1546        test_cmd.arg("");
1547        test_cmds.push(test_cmd);
1548
1549        // Routing key is 2nd arg ("42")
1550        test_cmd = cmd("SET");
1551        test_cmd.arg("42");
1552        test_cmds.push(test_cmd);
1553
1554        // Routing key is 3rd arg ("FOOBAR")
1555        test_cmd = cmd("XINFO");
1556        test_cmd.arg("GROUPS").arg("FOOBAR");
1557        test_cmds.push(test_cmd);
1558
1559        // Routing key is 3rd or 4th arg (3rd = "0" == RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
1560        test_cmd = cmd("EVAL");
1561        test_cmd.arg("FOO").arg("0").arg("BAR");
1562        test_cmds.push(test_cmd);
1563
1564        // Routing key is 3rd or 4th arg (3rd != "0" == RoutingInfo::Slot)
1565        test_cmd = cmd("EVAL");
1566        test_cmd.arg("FOO").arg("4").arg("BAR");
1567        test_cmds.push(test_cmd);
1568
1569        // Routing key position is variable, 3rd arg
1570        test_cmd = cmd("XREAD");
1571        test_cmd.arg("STREAMS").arg("4");
1572        test_cmds.push(test_cmd);
1573
1574        // Routing key position is variable, 4th arg
1575        test_cmd = cmd("XREAD");
1576        test_cmd.arg("FOO").arg("STREAMS").arg("4");
1577        test_cmds.push(test_cmd);
1578
1579        for cmd in test_cmds {
1580            let value = parse_valkey_value(&cmd.get_packed_command()).unwrap();
1581            assert_eq!(
1582                RoutingInfo::for_routable(&value).unwrap(),
1583                RoutingInfo::for_routable(&cmd).unwrap(),
1584            );
1585        }
1586
1587        // Assert expected RoutingInfo explicitly:
1588
1589        for cmd in [cmd("FLUSHALL"), cmd("FLUSHDB"), cmd("PING")] {
1590            assert_eq!(
1591                RoutingInfo::for_routable(&cmd),
1592                Some(RoutingInfo::MultiNode((
1593                    MultipleNodeRoutingInfo::AllMasters,
1594                    Some(ResponsePolicy::AllSucceeded)
1595                )))
1596            );
1597        }
1598
1599        assert_eq!(
1600            RoutingInfo::for_routable(&cmd("DBSIZE")),
1601            Some(RoutingInfo::MultiNode((
1602                MultipleNodeRoutingInfo::AllMasters,
1603                Some(ResponsePolicy::Aggregate(AggregateOp::Sum))
1604            )))
1605        );
1606
1607        assert_eq!(
1608            RoutingInfo::for_routable(&cmd("SCRIPT KILL")),
1609            Some(RoutingInfo::MultiNode((
1610                MultipleNodeRoutingInfo::AllNodes,
1611                Some(ResponsePolicy::OneSucceeded)
1612            )))
1613        );
1614
1615        assert_eq!(
1616            RoutingInfo::for_routable(&cmd("INFO")),
1617            Some(RoutingInfo::MultiNode((
1618                MultipleNodeRoutingInfo::AllMasters,
1619                Some(ResponsePolicy::Special)
1620            )))
1621        );
1622
1623        assert_eq!(
1624            RoutingInfo::for_routable(&cmd("KEYS")),
1625            Some(RoutingInfo::MultiNode((
1626                MultipleNodeRoutingInfo::AllMasters,
1627                Some(ResponsePolicy::CombineArrays)
1628            )))
1629        );
1630
1631        for cmd in [cmd("SCAN"),
1632            cmd("SHUTDOWN"),
1633            cmd("SLAVEOF"),
1634            cmd("REPLICAOF")] {
1635            assert_eq!(
1636                RoutingInfo::for_routable(&cmd),
1637                None,
1638                "{}",
1639                std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap()
1640            );
1641        }
1642
1643        for cmd in [
1644            cmd("EVAL").arg(r#"redis.call("PING");"#).arg(0),
1645            cmd("EVALSHA").arg(r#"redis.call("PING");"#).arg(0),
1646        ] {
1647            assert_eq!(
1648                RoutingInfo::for_routable(cmd),
1649                Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random))
1650            );
1651        }
1652
1653        // While FCALL with N keys is expected to be routed to a specific node
1654        assert_eq!(
1655            RoutingInfo::for_routable(cmd("FCALL").arg("foo").arg(1).arg("mykey")),
1656            Some(RoutingInfo::SingleNode(
1657                SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"mykey"), SlotAddr::Master))
1658            ))
1659        );
1660
1661        for (cmd, expected) in [
1662            (
1663                cmd("EVAL")
1664                    .arg(r#"redis.call("GET, KEYS[1]");"#)
1665                    .arg(1)
1666                    .arg("foo"),
1667                Some(RoutingInfo::SingleNode(
1668                    SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"foo"), SlotAddr::Master)),
1669                )),
1670            ),
1671            (
1672                cmd("XGROUP")
1673                    .arg("CREATE")
1674                    .arg("mystream")
1675                    .arg("workers")
1676                    .arg("$")
1677                    .arg("MKSTREAM"),
1678                Some(RoutingInfo::SingleNode(
1679                    SingleNodeRoutingInfo::SpecificNode(Route::new(
1680                        slot(b"mystream"),
1681                        SlotAddr::Master,
1682                    )),
1683                )),
1684            ),
1685            (
1686                cmd("XINFO").arg("GROUPS").arg("foo"),
1687                Some(RoutingInfo::SingleNode(
1688                    SingleNodeRoutingInfo::SpecificNode(Route::new(
1689                        slot(b"foo"),
1690                        SlotAddr::ReplicaOptional,
1691                    )),
1692                )),
1693            ),
1694            (
1695                cmd("XREADGROUP")
1696                    .arg("GROUP")
1697                    .arg("wkrs")
1698                    .arg("consmrs")
1699                    .arg("STREAMS")
1700                    .arg("mystream"),
1701                Some(RoutingInfo::SingleNode(
1702                    SingleNodeRoutingInfo::SpecificNode(Route::new(
1703                        slot(b"mystream"),
1704                        SlotAddr::Master,
1705                    )),
1706                )),
1707            ),
1708            (
1709                cmd("XREAD")
1710                    .arg("COUNT")
1711                    .arg("2")
1712                    .arg("STREAMS")
1713                    .arg("mystream")
1714                    .arg("writers")
1715                    .arg("0-0")
1716                    .arg("0-0"),
1717                Some(RoutingInfo::SingleNode(
1718                    SingleNodeRoutingInfo::SpecificNode(Route::new(
1719                        slot(b"mystream"),
1720                        SlotAddr::ReplicaOptional,
1721                    )),
1722                )),
1723            ),
1724        ] {
1725            assert_eq!(
1726                RoutingInfo::for_routable(cmd),
1727                expected,
1728                "{}",
1729                std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap()
1730            );
1731        }
1732    }
1733
1734    #[test]
1735    fn test_slot_for_packed_cmd() {
1736        assert!(matches!(RoutingInfo::for_routable(&parse_valkey_value(&[
1737                42, 50, 13, 10, 36, 54, 13, 10, 69, 88, 73, 83, 84, 83, 13, 10, 36, 49, 54, 13, 10,
1738                244, 93, 23, 40, 126, 127, 253, 33, 89, 47, 185, 204, 171, 249, 96, 139, 13, 10
1739            ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::ReplicaOptional)))) if slot == 964));
1740
1741        assert!(matches!(RoutingInfo::for_routable(&parse_valkey_value(&[
1742                42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 36, 241,
1743                197, 111, 180, 254, 5, 175, 143, 146, 171, 39, 172, 23, 164, 145, 13, 10, 36, 52,
1744                13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10,
1745                80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10
1746            ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 8352));
1747
1748        assert!(matches!(RoutingInfo::for_routable(&parse_valkey_value(&[
1749                42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 169, 233,
1750                247, 59, 50, 247, 100, 232, 123, 140, 2, 101, 125, 221, 66, 170, 13, 10, 36, 52,
1751                13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10,
1752                80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10
1753            ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 5210));
1754    }
1755
1756    #[test]
1757    fn test_multi_shard_keys_only() {
1758        let mut cmd = cmd("DEL");
1759        cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz");
1760        let routing = RoutingInfo::for_routable(&cmd);
1761        let mut expected = std::collections::HashMap::new();
1762        expected.insert(Route(4813, SlotAddr::Master), vec![2]);
1763        expected.insert(Route(5061, SlotAddr::Master), vec![1, 3]);
1764        expected.insert(Route(12182, SlotAddr::Master), vec![0]);
1765
1766        assert!(
1767            matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::Aggregate(AggregateOp::Sum))))) if {
1768                let routes = vec.clone().into_iter().collect();
1769                expected == routes && args_pattern == MultiSlotArgPattern::KeysOnly
1770            }),
1771            "expected={expected:?}\nrouting={routing:?}"
1772        );
1773
1774        let mut cmd = crate::cmd::cmd("MGET");
1775        cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz");
1776        let routing = RoutingInfo::for_routable(&cmd);
1777        let mut expected = std::collections::HashMap::new();
1778        expected.insert(Route(4813, SlotAddr::ReplicaOptional), vec![2]);
1779        expected.insert(Route(5061, SlotAddr::ReplicaOptional), vec![1, 3]);
1780        expected.insert(Route(12182, SlotAddr::ReplicaOptional), vec![0]);
1781
1782        assert!(
1783            matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::CombineArrays)))) if {
1784                let routes = vec.clone().into_iter().collect();
1785                expected == routes && args_pattern == MultiSlotArgPattern::KeysOnly
1786            }),
1787            "expected={expected:?}\nrouting={routing:?}"
1788        );
1789    }
1790
1791    #[test]
1792    fn test_multi_shard_key_value_pairs() {
1793        let mut cmd = cmd("MSET");
1794        cmd.arg("foo") // key slot 12182
1795            .arg("bar") // value
1796            .arg("foo2") // key slot 1044
1797            .arg("bar2") // value
1798            .arg("{foo}foo3") // key slot 12182
1799            .arg("bar3"); // value
1800        let routing = RoutingInfo::for_routable(&cmd);
1801        let mut expected = std::collections::HashMap::new();
1802        expected.insert(Route(1044, SlotAddr::Master), vec![2, 3]);
1803        expected.insert(Route(12182, SlotAddr::Master), vec![0, 1, 4, 5]);
1804
1805        assert!(
1806            matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::AllSucceeded)))) if {
1807                let routes = vec.clone().into_iter().collect();
1808                expected == routes && args_pattern == MultiSlotArgPattern::KeyValuePairs
1809            }),
1810            "expected={expected:?}\nrouting={routing:?}"
1811        );
1812    }
1813
1814    #[test]
1815    fn test_multi_shard_keys_and_path() {
1816        let mut cmd = cmd("JSON.MGET");
1817        cmd.arg("foo") // key slot 12182
1818            .arg("bar") // key slot 5061
1819            .arg("baz") // key slot 4813
1820            .arg("{bar}vaz") // key slot 5061
1821            .arg("$.f.a"); // path
1822        let routing = RoutingInfo::for_routable(&cmd);
1823        let mut expected = std::collections::HashMap::new();
1824        expected.insert(Route(4813, SlotAddr::ReplicaOptional), vec![2, 4]);
1825        expected.insert(Route(5061, SlotAddr::ReplicaOptional), vec![1, 3, 4]);
1826        expected.insert(Route(12182, SlotAddr::ReplicaOptional), vec![0, 4]);
1827
1828        assert!(
1829            matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::CombineArrays)))) if {
1830                let routes = vec.clone().into_iter().collect();
1831                expected == routes && args_pattern == MultiSlotArgPattern::KeysAndLastArg
1832            }),
1833            "expected={expected:?}\nrouting={routing:?}"
1834        );
1835    }
1836
1837    #[test]
1838    fn test_multi_shard_key_with_two_arg_triples() {
1839        let mut cmd = cmd("JSON.MSET");
1840        cmd.arg("foo") // key slot 12182
1841            .arg("$.a") // path
1842            .arg("bar") // value
1843            .arg("foo2") // key slot 1044
1844            .arg("$.f.a") // path
1845            .arg("bar2") // value
1846            .arg("{foo}foo3") // key slot 12182
1847            .arg("$.f.a") // path
1848            .arg("bar3"); // value
1849        let routing = RoutingInfo::for_routable(&cmd);
1850        let mut expected = std::collections::HashMap::new();
1851        expected.insert(Route(1044, SlotAddr::Master), vec![3, 4, 5]);
1852        expected.insert(Route(12182, SlotAddr::Master), vec![0, 1, 2, 6, 7, 8]);
1853
1854        assert!(
1855            matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot((vec, args_pattern)), Some(ResponsePolicy::AllSucceeded)))) if {
1856                let routes = vec.clone().into_iter().collect();
1857                expected == routes && args_pattern == MultiSlotArgPattern::KeyWithTwoArgTriples
1858            }),
1859            "expected={expected:?}\nrouting={routing:?}"
1860        );
1861    }
1862
1863    #[test]
1864    fn test_command_creation_for_multi_shard() {
1865        let mut original_cmd = cmd("DEL");
1866        original_cmd
1867            .arg("foo")
1868            .arg("bar")
1869            .arg("baz")
1870            .arg("{bar}vaz");
1871        let routing = RoutingInfo::for_routable(&original_cmd);
1872        let expected = [vec![0], vec![1, 3], vec![2]];
1873
1874        let mut indices: Vec<_> = match routing {
1875            Some(RoutingInfo::MultiNode((
1876                MultipleNodeRoutingInfo::MultiSlot((vec, MultiSlotArgPattern::KeysOnly)),
1877                _,
1878            ))) => vec.into_iter().map(|(_, indices)| indices).collect(),
1879            _ => panic!("unexpected routing: {routing:?}"),
1880        };
1881        indices.sort_by(|prev, next| prev.iter().next().unwrap().cmp(next.iter().next().unwrap())); // sorting because the `for_routable` doesn't return values in a consistent order between runs.
1882
1883        for (index, indices) in indices.into_iter().enumerate() {
1884            let cmd = command_for_multi_slot_indices(&original_cmd, indices.iter());
1885            let expected_indices = &expected[index];
1886            assert_eq!(original_cmd.arg_idx(0), cmd.arg_idx(0));
1887            for (index, target_index) in expected_indices.iter().enumerate() {
1888                let target_index = target_index + 1;
1889                assert_eq!(original_cmd.arg_idx(target_index), cmd.arg_idx(index + 1));
1890            }
1891        }
1892    }
1893
1894    #[test]
1895    fn test_combine_multi_shard_to_single_node_when_all_keys_are_in_same_slot() {
1896        let mut cmd = cmd("DEL");
1897        cmd.arg("foo").arg("{foo}bar").arg("{foo}baz");
1898        let routing = RoutingInfo::for_routable(&cmd);
1899
1900        assert!(
1901            matches!(
1902                routing,
1903                Some(RoutingInfo::SingleNode(
1904                    SingleNodeRoutingInfo::SpecificNode(Route(12182, SlotAddr::Master))
1905                ))
1906            ),
1907            "{routing:?}"
1908        );
1909    }
1910
1911    #[test]
1912    fn test_combining_results_into_single_array_only_keys() {
1913        // For example `MGET foo bar baz {baz}baz2 {bar}bar2 {foo}foo2`
1914        let res1 = Value::Array(vec![Ok(Value::Nil), Ok(Value::Okay)]);
1915        let res2 = Value::Array(vec![
1916            Ok(Value::BulkString("1".as_bytes().to_vec().into())),
1917            Ok(Value::BulkString("4".as_bytes().to_vec().into())),
1918        ]);
1919        let res3 = Value::Array(vec![Ok(Value::SimpleString("2".to_string())), Ok(Value::Int(3))]);
1920        let results = super::combine_and_sort_array_results(
1921            vec![res1, res2, res3],
1922            &[
1923                (Route(4813, SlotAddr::Master), vec![2, 3]),
1924                (Route(5061, SlotAddr::Master), vec![1, 4]),
1925                (Route(12182, SlotAddr::Master), vec![0, 5]),
1926            ],
1927            &MultiSlotArgPattern::KeysOnly,
1928        );
1929
1930        assert_eq!(
1931            results.unwrap(),
1932            Value::Array(vec![
1933                Ok(Value::SimpleString("2".to_string())),
1934                Ok(Value::BulkString("1".as_bytes().to_vec().into())),
1935                Ok(Value::Nil),
1936                Ok(Value::Okay),
1937                Ok(Value::BulkString("4".as_bytes().to_vec().into())),
1938                Ok(Value::Int(3)),
1939            ])
1940        );
1941    }
1942
1943    #[test]
1944    fn test_combining_results_into_single_array_key_value_paires() {
1945        // For example `MSET foo bar foo2 bar2 {foo}foo3 bar3`
1946        let res1 = Value::Array(vec![Ok(Value::Okay)]);
1947        let res2 = Value::Array(vec![
1948            Ok(Value::BulkString("1".as_bytes().to_vec().into())),
1949            Ok(Value::Nil),
1950        ]);
1951        let results = super::combine_and_sort_array_results(
1952            vec![res1, res2],
1953            &[
1954                (Route(1044, SlotAddr::Master), vec![2, 3]),
1955                (Route(12182, SlotAddr::Master), vec![0, 1, 4, 5]),
1956            ],
1957            &MultiSlotArgPattern::KeyValuePairs,
1958        );
1959
1960        assert_eq!(
1961            results.unwrap(),
1962            Value::Array(vec![
1963                Ok(Value::BulkString("1".as_bytes().to_vec().into())),
1964                Ok(Value::Okay),
1965                Ok(Value::Nil)
1966            ])
1967        );
1968    }
1969
1970    #[test]
1971    fn test_combining_results_into_single_array_keys_and_path() {
1972        // For example `JSON.MGET foo bar {foo}foo2 $.a`
1973        let res1 = Value::Array(vec![Ok(Value::Okay)]);
1974        let res2 = Value::Array(vec![
1975            Ok(Value::BulkString("1".as_bytes().to_vec().into())),
1976            Ok(Value::Nil),
1977        ]);
1978        let results = super::combine_and_sort_array_results(
1979            vec![res1, res2],
1980            &[
1981                (Route(5061, SlotAddr::Master), vec![2, 3]),
1982                (Route(12182, SlotAddr::Master), vec![0, 1, 3]),
1983            ],
1984            &MultiSlotArgPattern::KeysAndLastArg,
1985        );
1986
1987        assert_eq!(
1988            results.unwrap(),
1989            Value::Array(vec![
1990                Ok(Value::BulkString("1".as_bytes().to_vec().into())),
1991                Ok(Value::Nil),
1992                Ok(Value::Okay),
1993            ])
1994        );
1995    }
1996
1997    #[test]
1998    fn test_combining_results_into_single_array_key_with_two_arg_triples() {
1999        // For example `JSON.MSET foo $.a bar foo2 $.f.a bar2 {foo}foo3 $.f bar3`
2000        let res1 = Value::Array(vec![Ok(Value::Okay)]);
2001        let res2 = Value::Array(vec![
2002            Ok(Value::BulkString("1".as_bytes().to_vec().into())),
2003            Ok(Value::Nil),
2004        ]);
2005        let results = super::combine_and_sort_array_results(
2006            vec![res1, res2],
2007            &[
2008                (Route(5061, SlotAddr::Master), vec![3, 4, 5]),
2009                (Route(12182, SlotAddr::Master), vec![0, 1, 2, 6, 7, 8]),
2010            ],
2011            &MultiSlotArgPattern::KeyWithTwoArgTriples,
2012        );
2013
2014        assert_eq!(
2015            results.unwrap(),
2016            Value::Array(vec![
2017                Ok(Value::BulkString("1".as_bytes().to_vec().into())),
2018                Ok(Value::Okay),
2019                Ok(Value::Nil)
2020            ])
2021        );
2022    }
2023
2024    #[test]
2025    fn test_combine_map_results() {
2026        let input = vec![];
2027        let result = super::combine_map_results(input).unwrap();
2028        assert_eq!(result, Value::Map(vec![]));
2029
2030        let input = vec![
2031            Value::Array(vec![
2032                Ok(Value::BulkString(b"key1".to_vec().into())),
2033                Ok(Value::Int(5)),
2034                Ok(Value::BulkString(b"key2".to_vec().into())),
2035                Ok(Value::Int(10)),
2036            ]),
2037            Value::Array(vec![
2038                Ok(Value::BulkString(b"key1".to_vec().into())),
2039                Ok(Value::Int(3)),
2040                Ok(Value::BulkString(b"key3".to_vec().into())),
2041                Ok(Value::Int(15)),
2042            ]),
2043        ];
2044        let result = super::combine_map_results(input).unwrap();
2045        let mut expected = vec![
2046            (Value::BulkString(b"key1".to_vec().into()), Value::Int(8)),
2047            (Value::BulkString(b"key2".to_vec().into()), Value::Int(10)),
2048            (Value::BulkString(b"key3".to_vec().into()), Value::Int(15)),
2049        ];
2050        expected.sort_unstable_by(|a, b| match (&a.0, &b.0) {
2051            (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes),
2052            _ => std::cmp::Ordering::Equal,
2053        });
2054        let mut result_vec = match result {
2055            Value::Map(v) => v,
2056            _ => panic!("Expected Map"),
2057        };
2058        result_vec.sort_unstable_by(|a, b| match (&a.0, &b.0) {
2059            (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes),
2060            _ => std::cmp::Ordering::Equal,
2061        });
2062        assert_eq!(result_vec, expected);
2063
2064        let input = vec![Value::Int(5)];
2065        let result = super::combine_map_results(input);
2066        assert!(result.is_err());
2067    }
2068
2069    fn create_shard_addrs(primary: &str, replicas: Vec<&str>) -> ShardAddrs {
2070        ShardAddrs {
2071            primary: RwLock::new(Arc::new(primary.to_string())),
2072            replicas: RwLock::new(
2073                replicas
2074                    .into_iter()
2075                    .map(|r| Arc::new(r.to_string()))
2076                    .collect(),
2077            ),
2078        }
2079    }
2080
2081    #[test]
2082    fn test_attempt_shard_role_update_already_primary() {
2083        let shard_addrs = create_shard_addrs("node1:6379", vec!["node2:6379", "node3:6379"]);
2084        let result = shard_addrs.attempt_shard_role_update(Arc::new("node1:6379".to_string()));
2085        assert_eq!(result, ShardUpdateResult::AlreadyPrimary);
2086    }
2087
2088    #[test]
2089    fn test_attempt_shard_role_update_promoted() {
2090        let shard_addrs = create_shard_addrs("node1:6379", vec!["node2:6379", "node3:6379"]);
2091        let result = shard_addrs.attempt_shard_role_update(Arc::new("node2:6379".to_string()));
2092        assert_eq!(result, ShardUpdateResult::Promoted);
2093
2094        let primary = shard_addrs.primary.read().unwrap().clone();
2095        assert_eq!(primary.as_str(), "node2:6379");
2096
2097        let replicas = shard_addrs.replicas.read().unwrap();
2098        assert_eq!(replicas.len(), 2);
2099        assert!(replicas.iter().any(|r| r.as_str() == "node1:6379"));
2100    }
2101
2102    #[test]
2103    fn test_attempt_shard_role_update_node_not_found() {
2104        let shard_addrs = create_shard_addrs("node1:6379", vec!["node2:6379", "node3:6379"]);
2105        let result = shard_addrs.attempt_shard_role_update(Arc::new("node4:6379".to_string()));
2106        assert_eq!(result, ShardUpdateResult::NodeNotFound);
2107    }
2108
2109    #[test]
2110    fn test_client_list_routing() {
2111        let mut cmd = cmd("CLIENT");
2112        cmd.arg("LIST");
2113        let routing = RoutingInfo::for_routable(&cmd);
2114        assert_eq!(
2115            routing,
2116            Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)),
2117            "CLIENT LIST should be routed to a random node"
2118        );
2119    }
2120
2121    #[test]
2122    fn test_is_read_only() {
2123        assert!(is_readonly_cmd(b"SENTINEL MASTERS"));
2124        assert!(is_readonly_cmd(b"SENTINEL MASTER"));
2125        assert!(is_readonly_cmd(b"SENTINEL REPLICAS"));
2126        assert!(is_readonly_cmd(b"SENTINEL GET-MASTER-ADDR-BY-NAME"));
2127        assert!(is_readonly_cmd(b"SENTINEL CKQUORUM"));
2128
2129        assert!(!is_readonly_cmd(b"SENTINEL FAILOVER"));
2130
2131        let mut test_cmd = cmd("SENTINEL");
2132        test_cmd.arg("MASTERS").arg("my_service");
2133        assert!(is_readonly(&test_cmd));
2134        assert!(is_readonly_cmd(
2135            Routable::command(&test_cmd).unwrap().as_slice()
2136        ));
2137
2138        let mut test_cmd = cmd("SENTINEL");
2139        test_cmd.arg("GET-MASTER-ADDR-BY-NAME").arg("my_service");
2140        assert!(is_readonly(&test_cmd));
2141        assert!(is_readonly_cmd(
2142            Routable::command(&test_cmd).unwrap().as_slice()
2143        ));
2144
2145        test_cmd = cmd("SENTINEL");
2146        test_cmd.arg("FAILOVER").arg("my_service");
2147        assert!(!is_readonly(&test_cmd));
2148        assert!(!is_readonly_cmd(
2149            Routable::command(&test_cmd).unwrap().as_slice()
2150        ));
2151    }
2152}