use crate::msg::{Msg, MsgType};
use super::fragment::FragmentDispatcher;
/// Key prefix identifying an "add-set" metadata key.
///
/// [`redis_verify_request`] ignores any key whose bytes start with this prefix
/// when checking that an EVAL script stays on a single shard.
/// NOTE(review): presumably internal bookkeeping keys rather than user data —
/// confirm against the code that writes them.
pub const ADD_SET_STR: &[u8] = b"._add-set";
/// Key prefix identifying a "rem-set" metadata key; [`redis_verify_request`]
/// ignores keys starting with it in the same way as [`ADD_SET_STR`].
pub const REM_SET_STR: &[u8] = b"._rem-set";
/// Checks that a Redis request can be served by a single shard.
///
/// Only `EVAL` requests carrying more than one key are inspected; every other
/// request passes unconditionally. Keys whose bytes begin with
/// [`ADD_SET_STR`] or [`REM_SET_STR`] are ignored. Each remaining key is routed
/// through `dispatcher.shard_for` using its tag bytes, and all of them must
/// resolve to the same shard.
///
/// # Errors
///
/// Returns [`VerifyError::ScriptSpansNodes`] as soon as a key resolves to a
/// shard different from the one chosen by the first non-ignored key.
pub fn redis_verify_request<D: FragmentDispatcher + ?Sized>(
    r: &Msg,
    dispatcher: &D,
) -> Result<(), VerifyError> {
    // `||` short-circuits, so `keys()` is only consulted for EVAL requests.
    if r.ty() != MsgType::ReqRedisEval || r.keys().len() <= 1 {
        return Ok(());
    }

    // Shard selected by the first key that is not a metadata key; every later
    // non-metadata key must agree with it.
    let mut anchor: Option<u32> = None;
    for kpos in r.keys() {
        let name = kpos.key();
        let is_metadata = name.starts_with(ADD_SET_STR) || name.starts_with(REM_SET_STR);
        if !is_metadata {
            let shard = dispatcher.shard_for(kpos.tag_bytes());
            if *anchor.get_or_insert(shard) != shard {
                return Err(VerifyError::ScriptSpansNodes);
            }
        }
    }
    Ok(())
}
#[derive(Copy, Clone, Debug, Eq, PartialEq, thiserror::Error)]
#[non_exhaustive]
pub enum VerifyError {
#[error("redis verify: script spans nodes")]
ScriptSpansNodes,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::msg::KeyPos;

    /// Test dispatcher that routes a key by the parity of its first byte
    /// (an empty key goes to shard 0), giving exactly two shards.
    struct OddEven;

    impl FragmentDispatcher for OddEven {
        fn shard_for(&self, key: &[u8]) -> u32 {
            u32::from(*key.first().unwrap_or(&0)) % 2
        }

        fn shard_count(&self) -> u32 {
            2
        }
    }

    /// Builds a request of type `ty` carrying the given untagged keys, in order.
    fn request(ty: MsgType, keys: &[&[u8]]) -> Msg {
        let mut r = Msg::new(0, ty, true);
        for key in keys {
            r.push_key(KeyPos::without_tag(key.to_vec()));
        }
        r
    }

    #[test]
    fn non_eval_is_ok() {
        // 'a' and 'b' are on different shards, but only EVAL is checked.
        let r = request(MsgType::ReqRedisGet, &[b"a", b"b"]);
        assert!(redis_verify_request(&r, &OddEven).is_ok());
    }

    #[test]
    fn eval_no_keys_is_ok() {
        let r = request(MsgType::ReqRedisEval, &[]);
        assert!(redis_verify_request(&r, &OddEven).is_ok());
    }

    #[test]
    fn eval_one_key_is_ok() {
        let r = request(MsgType::ReqRedisEval, &[b"a"]);
        assert!(redis_verify_request(&r, &OddEven).is_ok());
    }

    #[test]
    fn eval_same_shard_is_ok() {
        // 'a' (97) and 'c' (99) are both odd, so both map to shard 1.
        let r = request(MsgType::ReqRedisEval, &[b"a", b"c"]);
        assert!(redis_verify_request(&r, &OddEven).is_ok());
    }

    #[test]
    fn eval_disjoint_shards_errors() {
        let r = request(MsgType::ReqRedisEval, &[b"a", b"b"]);
        assert_eq!(
            redis_verify_request(&r, &OddEven),
            Err(VerifyError::ScriptSpansNodes),
        );
    }

    #[test]
    fn eval_skips_add_set_metadata_key() {
        // '.' (46) is even, so without the skip this would span shards.
        let r = request(MsgType::ReqRedisEval, &[b"a", b"._add-set"]);
        assert!(redis_verify_request(&r, &OddEven).is_ok());
    }

    #[test]
    fn eval_skips_rem_set_metadata_key() {
        let r = request(MsgType::ReqRedisEval, &[b"a", b"._rem-set"]);
        assert!(redis_verify_request(&r, &OddEven).is_ok());
    }

    #[test]
    fn eval_metadata_key_does_not_mask_conflict() {
        // A leading metadata key must not become the reference shard.
        let r = request(MsgType::ReqRedisEval, &[b"._add-set", b"a", b"b"]);
        assert_eq!(
            redis_verify_request(&r, &OddEven),
            Err(VerifyError::ScriptSpansNodes),
        );
    }
}