redis-on-mysql 0.0.1

A Redis-compatible proxy that stores all data and Pub/Sub state in MySQL.
Documentation
use bytes::Bytes;
use resp_async::Value;
use resp_async::response::RespError;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};

/// Shared counter for `random_u64`: every call advances it by the same odd step
/// (`0x9E37_79B9_7F4A_7C15`) that it starts at, so concurrent callers never see the same value.
static RAND_STATE: AtomicU64 = AtomicU64::new(0x9E37_79B9_7F4A_7C15);

/// Builds the simple-string reply `OK`.
pub fn ok() -> Value {
    let status = Bytes::from_static(b"OK");
    Value::Simple(status)
}

/// Builds the simple-string reply `PONG`.
pub fn pong() -> Value {
    let status = Bytes::from_static(b"PONG");
    Value::Simple(status)
}

/// Builds the error for a command invoked with the wrong number of arguments.
///
/// `cmd` is interpolated verbatim into the message; the text is wrapped via
/// `RespError::invalid_data`.
pub fn wrong_arity(cmd: &str) -> RespError {
    let message = format!("ERR wrong number of arguments for '{}' command", cmd);
    RespError::invalid_data(message)
}

pub fn wrong_type() -> Value {
    Value::Error(Bytes::from_static(
        b"WRONGTYPE Operation against a key holding the wrong kind of value",
    ))
}

/// Error for an argument that is not a valid 64-bit integer.
pub fn invalid_integer() -> RespError {
    invalid_arguments("ERR value is not an integer or out of range")
}

/// Wraps a static error message (which should carry its own `ERR ...` prefix) as a
/// `RespError`.
pub fn invalid_arguments(msg: &'static str) -> RespError {
    RespError::invalid_data(msg)
}

/// Returns a pseudo-random index in `0..max_exclusive`.
///
/// Non-positive bounds yield `0` without touching the generator. The value is reduced
/// with `%`, so there is a small modulo bias; not suitable for cryptographic use.
pub fn random_index(max_exclusive: i64) -> i64 {
    match max_exclusive {
        bound if bound > 0 => (random_u64() % bound as u64) as i64,
        _ => 0,
    }
}

/// Shuffles `items` in place with a Fisher–Yates pass from the back of the slice.
///
/// Slices of length 0 or 1 are left untouched.
pub fn shuffle_slice<T>(items: &mut [T]) {
    let mut cursor = items.len();
    while cursor > 1 {
        cursor -= 1;
        // Pick a partner from the not-yet-fixed prefix `0..=cursor`.
        let partner = random_index((cursor + 1) as i64) as usize;
        items.swap(cursor, partner);
    }
}

pub fn arg_as_bytes(arg: &Value) -> Result<&Bytes, RespError> {
    match arg {
        Value::Bulk(bytes) | Value::Simple(bytes) => Ok(bytes),
        Value::Integer(_) => Err(invalid_arguments("ERR invalid bulk argument")),
        Value::Null => Err(invalid_arguments("ERR invalid bulk argument")),
        Value::Array(_) | Value::Error(_) => Err(invalid_arguments("ERR invalid bulk argument")),
    }
}

/// Borrows a bulk/simple-string argument as UTF-8 text.
///
/// # Errors
/// Propagates `arg_as_bytes` errors; non-UTF-8 payloads yield `ERR invalid string argument`.
pub fn arg_as_str(arg: &Value) -> Result<&str, RespError> {
    match std::str::from_utf8(arg_as_bytes(arg)?) {
        Ok(text) => Ok(text),
        Err(_) => Err(invalid_arguments("ERR invalid string argument")),
    }
}

/// Like `arg_as_str`, but returns an owned `String`.
pub fn arg_to_string(arg: &Value) -> Result<String, RespError> {
    arg_as_str(arg).map(str::to_owned)
}

pub fn arg_as_i64(arg: &Value) -> Result<i64, RespError> {
    match arg {
        Value::Integer(value) => Ok(*value),
        Value::Bulk(bytes) | Value::Simple(bytes) => {
            let text = std::str::from_utf8(bytes).map_err(|_| invalid_integer())?;
            text.parse::<i64>().map_err(|_| invalid_integer())
        }
        _ => Err(invalid_integer()),
    }
}

/// Interprets an argument as a double-precision float, following Redis' validation rules.
///
/// RESP integers are converted with `as f64` (values beyond 2^53 may round). Bulk/simple
/// strings must be UTF-8 text accepted by `str::parse::<f64>`, with two extra checks that
/// mirror Redis:
/// * `NaN` (in any spelling) is rejected;
/// * a finite literal that overflows to infinity (e.g. `1e400`) is rejected, while an explicit
///   `inf` / `infinity` (optionally signed, any case) is accepted.
///
/// # Errors
/// Every rejected input yields `ERR value is not a valid float`.
pub fn arg_as_f64(arg: &Value) -> Result<f64, RespError> {
    fn invalid_float() -> RespError {
        RespError::invalid_data("ERR value is not a valid float")
    }

    let bytes = match arg {
        Value::Integer(value) => return Ok(*value as f64),
        Value::Bulk(bytes) | Value::Simple(bytes) => bytes,
        _ => return Err(invalid_float()),
    };
    let text = std::str::from_utf8(bytes).map_err(|_| invalid_float())?;
    let value: f64 = text.parse().map_err(|_| invalid_float())?;

    if value.is_nan() {
        // NaN has no ordering; accepting it would poison sorted-set scores and comparisons.
        return Err(invalid_float());
    }
    if value.is_infinite() {
        // Rust's parser maps overflow to ±inf; only accept infinity when it was written out.
        let unsigned = text
            .strip_prefix(|c: char| c == '+' || c == '-')
            .unwrap_or(text);
        let spelled_out =
            unsigned.eq_ignore_ascii_case("inf") || unsigned.eq_ignore_ascii_case("infinity");
        if !spelled_out {
            return Err(invalid_float());
        }
    }
    Ok(value)
}

/// Reports whether `pattern` contains any glob metacharacter (`*`, `?` or `[`).
pub fn glob_has_wildcards(pattern: &[u8]) -> bool {
    const SPECIAL: &[u8] = b"*?[";
    pattern.iter().any(|byte| SPECIAL.contains(byte))
}

/// Matches `text` against a Redis-style glob `pattern`.
///
/// Supported syntax:
/// * `*` matches any run of bytes (including none); `?` matches exactly one byte;
/// * `[abc]`, `[a-z]`, `[^x]` / `[!x]` match one byte against a class. Ranges may be written
///   in either order (`[z-a]` behaves like `[a-z]`), and `\` escapes a byte inside a class;
/// * `\x` matches `x` literally; a lone trailing `\` matches a literal backslash.
///
/// An unterminated `[` class matches nothing.
pub fn glob_match(pattern: &[u8], text: &[u8]) -> bool {
    glob_match_at(pattern, 0, text, 0)
}

/// Matches `pattern[pi..]` against `text[ti..]` in O(pattern.len() * text.len()) time.
///
/// Every token other than `*` consumes exactly one text byte. So on a mismatch it is enough
/// to let the most recent `*` absorb one more byte and retry; earlier stars never need to be
/// revisited. This avoids the exponential blow-up that per-star recursion suffers on
/// client-supplied patterns such as `*a*a*a*a*a*b` against a long run of `a`s.
fn glob_match_at(pattern: &[u8], mut pi: usize, text: &[u8], mut ti: usize) -> bool {
    // (pattern index just after the latest `*`, text index where that star's match ends).
    let mut last_star: Option<(usize, usize)> = None;
    loop {
        if pi < pattern.len() && pattern[pi] == b'*' {
            while pi < pattern.len() && pattern[pi] == b'*' {
                pi += 1;
            }
            if pi == pattern.len() {
                // A trailing star swallows whatever text is left.
                return true;
            }
            last_star = Some((pi, ti));
            continue;
        }
        if pi == pattern.len() && ti == text.len() {
            return true;
        }
        if pi < pattern.len() {
            if let Some(next_pi) = glob_match_token(pattern, pi, text, ti) {
                pi = next_pi;
                ti += 1;
                continue;
            }
        }
        // Mismatch: let the latest star consume one more byte, or give up if there is none.
        match last_star {
            Some((star_pi, star_ti)) if star_ti < text.len() => {
                last_star = Some((star_pi, star_ti + 1));
                pi = star_pi;
                ti = star_ti + 1;
            }
            _ => return false,
        }
    }
}

/// Matches the single non-`*` token at `pattern[pi]` against `text[ti]`.
///
/// Returns the pattern index just past the token, or `None` on mismatch or end of text.
fn glob_match_token(pattern: &[u8], pi: usize, text: &[u8], ti: usize) -> Option<usize> {
    let &byte = text.get(ti)?;
    let (matched, next_pi) = match pattern[pi] {
        b'?' => (true, pi + 1),
        b'[' => glob_match_class(pattern, pi + 1, byte)?,
        b'\\' if pi + 1 < pattern.len() => (pattern[pi + 1] == byte, pi + 2),
        literal => (literal == byte, pi + 1),
    };
    if matched {
        Some(next_pi)
    } else {
        None
    }
}

/// Evaluates the bracket class whose body starts at `pattern[start]` (just after `[`).
///
/// Returns `(matched, index just past the closing ']')`, or `None` if the class is never
/// closed.
fn glob_match_class(pattern: &[u8], start: usize, byte: u8) -> Option<(bool, usize)> {
    let mut i = start;
    let negate = matches!(pattern.get(i).copied(), Some(b'^') | Some(b'!'));
    if negate {
        i += 1;
    }
    let mut matched = false;
    // Previous literal member; a following `-` turns it into a range start.
    let mut prev: Option<u8> = None;
    loop {
        let ch = *pattern.get(i)?;
        if ch == b']' {
            return Some((matched != negate, i + 1));
        }
        if ch == b'\\' && i + 1 < pattern.len() {
            // Escaped member: always literal and never a range endpoint.
            matched |= pattern[i + 1] == byte;
            prev = None;
            i += 2;
            continue;
        }
        if ch == b'-' {
            if let (Some(first), Some(&last)) = (prev, pattern.get(i + 1)) {
                if last != b']' {
                    // Redis accepts ranges written backwards, e.g. `[z-a]`.
                    let (lo, hi) = if first <= last { (first, last) } else { (last, first) };
                    matched |= lo <= byte && byte <= hi;
                    prev = None;
                    i += 2;
                    continue;
                }
            }
        }
        matched |= ch == byte;
        prev = Some(ch);
        i += 1;
    }
}

/// Bitwise CRC-16 with polynomial `0x1021`, zero initial value and no reflection
/// (`crc16(b"123456789") == 0x31C3`).
pub fn crc16(data: &[u8]) -> u16 {
    const POLY: u16 = 0x1021;
    data.iter().fold(0u16, |acc, &byte| {
        let mut crc = acc ^ (u16::from(byte) << 8);
        for _ in 0..8 {
            // Shift out the top bit; fold the polynomial back in when that bit was set.
            crc = if crc & 0x8000 == 0 {
                crc << 1
            } else {
                (crc << 1) ^ POLY
            };
        }
        crc
    })
}

/// Produces a pseudo-random `u64`. Not cryptographically secure.
///
/// Steps:
/// 1. Advance the shared `RAND_STATE` counter.
/// 2. XOR in the wall-clock nanoseconds; a clock before the epoch contributes 0.
/// 3. Scramble the result with the xorshift64* output function.
fn random_u64() -> u64 {
    let clock_nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    };
    let counter = RAND_STATE.fetch_add(0x9e3779b97f4a7c15, Ordering::Relaxed);
    let mut state = counter ^ clock_nanos;
    state ^= state >> 12;
    state ^= state << 25;
    state ^= state >> 27;
    state.wrapping_mul(0x2545_F491_4F6C_DD1D)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn glob_star() {
        assert!(glob_match(b"foo*", b"foobar"));
        assert!(glob_match(b"*bar", b"foobar"));
        assert!(!glob_match(b"foo", b"foobar"));
        assert!(glob_match(b"a*b*c", b"axxbyyc"));
        assert!(glob_match(b"*", b""));
    }

    #[test]
    fn glob_question() {
        assert!(glob_match(b"f?o", b"foo"));
        assert!(!glob_match(b"f?o", b"fooo"));
        assert!(!glob_match(b"?", b""));
    }

    #[test]
    fn glob_class() {
        assert!(glob_match(b"f[oa]o", b"foo"));
        assert!(glob_match(b"f[oa]o", b"fao"));
        assert!(!glob_match(b"f[oa]o", b"fbo"));
        // Negated classes.
        assert!(glob_match(b"h[^e]llo", b"hallo"));
        assert!(!glob_match(b"h[^e]llo", b"hello"));
        // Ranges.
        assert!(glob_match(b"h[a-c]llo", b"hbllo"));
        assert!(!glob_match(b"h[a-c]llo", b"hdllo"));
    }

    #[test]
    fn glob_escape() {
        assert!(glob_match(b"foo\\*", b"foo*"));
        assert!(!glob_match(b"foo\\*", b"foobar"));
        // A lone trailing backslash matches itself.
        assert!(glob_match(b"a\\", b"a\\"));
    }

    #[test]
    fn glob_empty() {
        assert!(glob_match(b"", b""));
        assert!(!glob_match(b"", b"a"));
    }

    #[test]
    fn wildcard_detection() {
        assert!(glob_has_wildcards(b"user:*"));
        assert!(glob_has_wildcards(b"h?llo"));
        assert!(glob_has_wildcards(b"h[ae]llo"));
        assert!(!glob_has_wildcards(b"plain-key"));
    }

    #[test]
    fn crc16_known() {
        assert_eq!(crc16(b"123456789"), 0x31C3);
        assert_eq!(crc16(b""), 0);
        // Redis Cluster hashes key "foo" to slot 12182.
        assert_eq!(crc16(b"foo") % 16384, 12182);
    }

    #[test]
    fn random_index_bounds() {
        assert_eq!(random_index(0), 0);
        assert_eq!(random_index(-3), 0);
        assert_eq!(random_index(1), 0);
        for _ in 0..1000 {
            let index = random_index(7);
            assert!((0..7).contains(&index));
        }
    }

    #[test]
    fn shuffle_keeps_elements() {
        let mut items: Vec<u32> = (0..50).collect();
        shuffle_slice(&mut items);
        items.sort_unstable();
        assert_eq!(items, (0..50).collect::<Vec<u32>>());

        let mut empty: Vec<u8> = Vec::new();
        shuffle_slice(&mut empty);
        assert!(empty.is_empty());
    }

    #[test]
    fn integer_arguments() {
        assert_eq!(arg_as_i64(&Value::Integer(-7)).ok(), Some(-7));
        assert_eq!(arg_as_i64(&Value::Bulk(Bytes::from_static(b"42"))).ok(), Some(42));
        assert!(arg_as_i64(&Value::Bulk(Bytes::from_static(b"4.2"))).is_err());
        assert!(arg_as_i64(&Value::Null).is_err());
    }

    #[test]
    fn float_arguments() {
        assert_eq!(arg_as_f64(&Value::Integer(3)).ok(), Some(3.0));
        assert_eq!(arg_as_f64(&Value::Bulk(Bytes::from_static(b"1.5"))).ok(), Some(1.5));
        assert!(arg_as_f64(&Value::Bulk(Bytes::from_static(b"abc"))).is_err());
    }
}