use bytes::Bytes;
use resp_async::Value;
use resp_async::response::RespError;
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
/// Counter that `random_u64` advances by a fixed odd step on every call.
/// The seed (and step) is ≈ 2^64 / φ, the golden-ratio constant.
static RAND_STATE: AtomicU64 = AtomicU64::new(0x9e37_79b9_7f4a_7c15);
/// Simple-string reply carrying `OK`.
pub fn ok() -> Value {
    const OK_TEXT: &[u8] = b"OK";
    Value::Simple(Bytes::from_static(OK_TEXT))
}
/// Simple-string reply carrying `PONG`.
pub fn pong() -> Value {
    const PONG_TEXT: &[u8] = b"PONG";
    Value::Simple(Bytes::from_static(PONG_TEXT))
}
/// Error for a command invoked with the wrong number of arguments.
///
/// `cmd` is interpolated verbatim between single quotes in the message.
pub fn wrong_arity(cmd: &str) -> RespError {
    let message = format!("ERR wrong number of arguments for '{}' command", cmd);
    RespError::invalid_data(message)
}
pub fn wrong_type() -> Value {
Value::Error(Bytes::from_static(
b"WRONGTYPE Operation against a key holding the wrong kind of value",
))
}
/// Error for an argument that is not an integer or does not fit in the target type.
pub fn invalid_integer() -> RespError {
    const MESSAGE: &str = "ERR value is not an integer or out of range";
    RespError::invalid_data(MESSAGE)
}
/// Wraps a fixed, caller-supplied message in an invalid-data [`RespError`].
///
/// The message is used as-is, so callers include any `ERR ` prefix themselves.
pub fn invalid_arguments(msg: &'static str) -> RespError {
    let message: &'static str = msg;
    RespError::invalid_data(message)
}
/// Returns a pseudo-random index in `0..max_exclusive`.
///
/// Non-positive bounds yield `0` without consuming randomness. The index is
/// reduced by modulo, so it carries a (tiny, for small bounds) modulo bias.
pub fn random_index(max_exclusive: i64) -> i64 {
    match u64::try_from(max_exclusive) {
        // The remainder is < bound <= i64::MAX, so the cast back cannot wrap.
        Ok(bound) if bound > 0 => (random_u64() % bound) as i64,
        _ => 0,
    }
}
/// Shuffles `items` in place using the Fisher–Yates algorithm.
///
/// Walks from the last slot down to slot 1, swapping each slot `i` with an
/// index drawn from [`random_index`] in `0..=i`. Slices of length 0 or 1 are left untouched.
pub fn shuffle_slice<T>(items: &mut [T]) {
    let len = items.len();
    if len < 2 {
        return;
    }
    let mut slot = len - 1;
    while slot > 0 {
        let partner = random_index(slot as i64 + 1) as usize;
        items.swap(slot, partner);
        slot -= 1;
    }
}
pub fn arg_as_bytes(arg: &Value) -> Result<&Bytes, RespError> {
match arg {
Value::Bulk(bytes) | Value::Simple(bytes) => Ok(bytes),
Value::Integer(_) => Err(invalid_arguments("ERR invalid bulk argument")),
Value::Null => Err(invalid_arguments("ERR invalid bulk argument")),
Value::Array(_) | Value::Error(_) => Err(invalid_arguments("ERR invalid bulk argument")),
}
}
/// Borrows a string argument as UTF-8 text.
///
/// # Errors
/// Propagates the error from [`arg_as_bytes`] for non-string variants, and returns
/// `ERR invalid string argument` when the bytes are not valid UTF-8.
pub fn arg_as_str(arg: &Value) -> Result<&str, RespError> {
    match std::str::from_utf8(arg_as_bytes(arg)?) {
        Ok(text) => Ok(text),
        Err(_) => Err(invalid_arguments("ERR invalid string argument")),
    }
}
/// Owned-`String` counterpart of [`arg_as_str`], with the same errors.
pub fn arg_to_string(arg: &Value) -> Result<String, RespError> {
    arg_as_str(arg).map(String::from)
}
pub fn arg_as_i64(arg: &Value) -> Result<i64, RespError> {
match arg {
Value::Integer(value) => Ok(*value),
Value::Bulk(bytes) | Value::Simple(bytes) => {
let text = std::str::from_utf8(bytes).map_err(|_| invalid_integer())?;
text.parse::<i64>().map_err(|_| invalid_integer())
}
_ => Err(invalid_integer()),
}
}
/// Interprets an argument as an `f64`.
///
/// Integer values are converted with `as f64`. Bulk and simple strings are decoded
/// as UTF-8 and parsed with `str::parse::<f64>`, which accepts forms such as `"1.5"`,
/// `"-3e2"`, `"inf"` and `"-inf"`.
///
/// # Errors
/// `ERR value is not a valid float` for any other variant, invalid UTF-8,
/// unparsable text, or text that parses to NaN.
pub fn arg_as_f64(arg: &Value) -> Result<f64, RespError> {
    fn not_a_float() -> RespError {
        RespError::invalid_data("ERR value is not a valid float")
    }

    let value = match arg {
        Value::Integer(value) => *value as f64,
        Value::Bulk(bytes) | Value::Simple(bytes) => std::str::from_utf8(bytes)
            .ok()
            .and_then(|text| text.parse::<f64>().ok())
            .ok_or_else(not_a_float)?,
        _ => return Err(not_a_float()),
    };
    // `f64::from_str` happily accepts "NaN". NaN compares unequal to everything
    // (itself included), so it cannot be used as an ordered number; reject it
    // like any other malformed float (Redis does the same).
    if value.is_nan() {
        return Err(not_a_float());
    }
    Ok(value)
}
/// Reports whether `pattern` contains any glob metacharacter: `*`, `?` or `[`.
///
/// Backslash escapes are not interpreted, so `\*` still counts as containing a wildcard.
pub fn glob_has_wildcards(pattern: &[u8]) -> bool {
    const META: &[u8] = b"*?[";
    pattern.iter().any(|byte| META.contains(byte))
}
/// Matches the whole of `text` against a glob-style `pattern`.
///
/// Supported syntax:
/// - `*` matches any (possibly empty) run of bytes.
/// - `?` matches exactly one byte.
/// - `[...]` matches one byte from a set.
///   - A leading `^` or `!` negates the set.
///   - `a-z` is an inclusive byte range. A `-` with no preceding set member, or
///     one directly before `]`, is literal.
///   - A `[` with no closing `]` never matches.
/// - `\x` matches the byte `x` literally; a trailing `\` matches a literal `\`.
///
/// Matching is byte-wise; multi-byte UTF-8 characters are not treated as units.
pub fn glob_match(pattern: &[u8], text: &[u8]) -> bool {
    glob_match_at(pattern, 0, text, 0)
}

/// Matches `pattern[pi..]` against `text[ti..]`.
///
/// Uses single-resume-point backtracking. On a mismatch, matching restarts just
/// after the most recent `*`, which absorbs one more byte.
///
/// Every non-`*` token consumes exactly one byte, so earlier stars never need
/// revisiting. Worst-case time is therefore O(pattern.len() * text.len()),
/// rather than exponential in the number of stars. That bound matters if
/// patterns can be client-supplied.
fn glob_match_at(pattern: &[u8], mut pi: usize, text: &[u8], mut ti: usize) -> bool {
    // (pattern index just past the last `*`, text index where that star's match ends).
    let mut resume: Option<(usize, usize)> = None;
    loop {
        if pi < pattern.len() {
            if pattern[pi] == b'*' {
                while pi < pattern.len() && pattern[pi] == b'*' {
                    pi += 1;
                }
                if pi == pattern.len() {
                    // A trailing star swallows whatever text remains.
                    return true;
                }
                resume = Some((pi, ti));
                continue;
            }
            if let Some(next_pi) = glob_match_one(pattern, pi, text, ti) {
                pi = next_pi;
                ti += 1;
                continue;
            }
        } else if ti == text.len() {
            return true;
        }
        // Mismatch (or leftover text): let the last star take one more byte, if any remain.
        match resume {
            Some((star_pi, star_ti)) if star_ti < text.len() => {
                pi = star_pi;
                ti = star_ti + 1;
                resume = Some((star_pi, ti));
            }
            _ => return false,
        }
    }
}

/// Tests whether the single non-`*` token at `pattern[pi]` matches `text[ti]`.
///
/// Returns the pattern index just past the token on success. Returns `None` if
/// the byte does not match, `text` is exhausted, or the token is an unterminated
/// `[` set. Callers must handle `*` themselves; here it would be treated as a literal.
fn glob_match_one(pattern: &[u8], pi: usize, text: &[u8], ti: usize) -> Option<usize> {
    let c = *text.get(ti)?;
    match pattern[pi] {
        b'?' => Some(pi + 1),
        b'[' => {
            let mut i = pi + 1;
            let negate = matches!(pattern.get(i), Some(&b'^') | Some(&b'!'));
            if negate {
                i += 1;
            }
            let mut matched = false;
            // Previous set member; a following `-x` turns it into the start of a range.
            let mut prev: Option<u8> = None;
            while i < pattern.len() && pattern[i] != b']' {
                let ch = pattern[i];
                if let (b'-', Some(start)) = (ch, prev) {
                    if i + 1 < pattern.len() && pattern[i + 1] != b']' {
                        let end = pattern[i + 1];
                        matched |= start <= c && c <= end;
                        i += 2;
                        prev = None;
                        continue;
                    }
                }
                matched |= c == ch;
                prev = Some(ch);
                i += 1;
            }
            if i == pattern.len() {
                // No closing `]`: the set can never match.
                return None;
            }
            if matched != negate {
                Some(i + 1)
            } else {
                None
            }
        }
        b'\\' => match pattern.get(pi + 1) {
            // Escaped byte must match literally.
            Some(&escaped) => {
                if escaped == c {
                    Some(pi + 2)
                } else {
                    None
                }
            }
            // A trailing backslash matches a literal backslash.
            None => {
                if c == b'\\' {
                    Some(pi + 1)
                } else {
                    None
                }
            }
        },
        literal => {
            if literal == c {
                Some(pi + 1)
            } else {
                None
            }
        }
    }
}
/// CRC-16/XMODEM checksum: polynomial `0x1021`, initial value 0, MSB-first,
/// no reflection and no final XOR.
///
/// The check value for `b"123456789"` is `0x31C3`; empty input yields 0.
pub fn crc16(data: &[u8]) -> u16 {
    const POLY: u16 = 0x1021;
    data.iter().fold(0u16, |crc, &byte| {
        (0..8).fold(crc ^ (u16::from(byte) << 8), |acc, _| {
            let carry = (acc & 0x8000) != 0;
            let shifted = acc << 1;
            if carry {
                shifted ^ POLY
            } else {
                shifted
            }
        })
    })
}
/// Returns a pseudo-random `u64`. It is fast but predictable, so do not use it for
/// security-sensitive purposes.
///
/// Each call:
/// 1. advances [`RAND_STATE`] by a fixed odd step;
/// 2. XORs the previous counter value with the wall-clock nanoseconds (0 if the
///    clock reads before the Unix epoch);
/// 3. scrambles the result with the xorshift64* output function.
fn random_u64() -> u64 {
    const STEP: u64 = 0x9e37_79b9_7f4a_7c15;
    const MULTIPLIER: u64 = 0x2545_f491_4f6c_dd1d;
    let clock_nanos = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_nanos() as u64,
        Err(_) => 0,
    };
    let mut mixed = clock_nanos ^ RAND_STATE.fetch_add(STEP, Ordering::Relaxed);
    mixed ^= mixed >> 12;
    mixed ^= mixed << 25;
    mixed ^= mixed >> 27;
    mixed.wrapping_mul(MULTIPLIER)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn glob_star() {
        assert!(glob_match(b"foo*", b"foobar"));
        assert!(glob_match(b"*bar", b"foobar"));
        assert!(!glob_match(b"foo", b"foobar"));
    }

    #[test]
    fn glob_question() {
        assert!(glob_match(b"f?o", b"foo"));
        assert!(!glob_match(b"f?o", b"fooo"));
    }

    #[test]
    fn glob_class() {
        assert!(glob_match(b"f[oa]o", b"foo"));
        assert!(glob_match(b"f[oa]o", b"fao"));
        assert!(!glob_match(b"f[oa]o", b"fbo"));
    }

    #[test]
    fn glob_class_negation_and_ranges() {
        assert!(glob_match(b"h[^e]llo", b"hallo"));
        assert!(!glob_match(b"h[^e]llo", b"hello"));
        assert!(glob_match(b"h[!e]llo", b"hallo"));
        assert!(glob_match(b"h[a-c]llo", b"hbllo"));
        assert!(!glob_match(b"h[a-c]llo", b"hdllo"));
        // A `-` right before `]` is a literal dash.
        assert!(glob_match(b"h[a-]llo", b"h-llo"));
        // An unterminated class never matches.
        assert!(!glob_match(b"h[ab", b"ha"));
    }

    #[test]
    fn glob_escape() {
        assert!(glob_match(b"a\\*b", b"a*b"));
        assert!(!glob_match(b"a\\*b", b"axb"));
        // A trailing backslash matches itself.
        assert!(glob_match(b"a\\", b"a\\"));
    }

    #[test]
    fn glob_empty_inputs() {
        assert!(glob_match(b"", b""));
        assert!(!glob_match(b"", b"a"));
        assert!(glob_match(b"*", b""));
        assert!(!glob_match(b"?", b""));
    }

    #[test]
    fn glob_wildcard_detection() {
        assert!(glob_has_wildcards(b"foo*"));
        assert!(glob_has_wildcards(b"f?o"));
        assert!(glob_has_wildcards(b"[ab]"));
        assert!(!glob_has_wildcards(b"plain"));
    }

    #[test]
    fn crc16_known() {
        assert_eq!(crc16(b"123456789"), 0x31C3);
        assert_eq!(crc16(b""), 0);
    }
}