use anyhow::*;
use bytes::{Buf, Bytes, BytesMut, BufMut};
/// A client request decoded by `parse_one` from a RESP array of bulk strings.
///
/// Arguments are kept as raw [`Bytes`] exactly as received; no text decoding
/// is performed.
#[derive(Clone, Debug, PartialEq)]
pub enum Cmd {
    /// `PING`.
    Ping,
    /// `GET a` decodes to `Get(a)`.
    Get(Bytes),
    /// `SET a b` decodes to `Set(a, b)`.
    Set(Bytes, Bytes),
    /// `DEL a` decodes to `Del(a)`.
    Del(Bytes),
    /// `RENAME a b` decodes to `Rename(a, b)`.
    Rename(Bytes, Bytes),
    /// `EXISTS a` decodes to `Exists(a)`.
    Exists(Bytes),
    /// `INCR a` decodes to `Incr(a)`.
    Incr(Bytes),
    /// `MGET a b ...` decodes to every argument, in wire order.
    MGet(Vec<Bytes>),
    /// `MSET a1 b1 a2 b2 ...` decodes to consecutive argument pairs, in wire order.
    MSet(Vec<(Bytes, Bytes)>),
}
/// A tagged value with string, integer, and blob variants.
///
/// NOTE(review): this type is not used in the visible part of the file;
/// presumably it models stored values or replies — confirm against callers,
/// including how `Str` and `Blob` are meant to differ.
#[derive(Clone, Debug, PartialEq)]
pub enum Value {
    /// Byte string tagged as "string".
    Str(Bytes),
    /// Signed 64-bit integer.
    Int(i64),
    /// Byte string tagged as "blob".
    Blob(Bytes),
}
pub fn parse_one(data: &[u8]) -> Result<Option<(usize, Cmd)>> {
if data.is_empty() {
return Ok(None);
}
if data[0] != b'*' {
bail!("protocol error: expected array");
}
let (i, n) = read_decimal_line(&data[1..])?;
if i == 0 {
return Ok(None);
}
let mut cursor = 1 + i;
if n <= 0 {
bail!("empty array");
}
let mut items: Vec<Bytes> = Vec::with_capacity(n as usize);
for _ in 0..n {
if cursor >= data.len() {
return Ok(None); }
if data[cursor] != b'$' {
bail!("expected bulk");
}
let (i2, len) = read_decimal_line(&data[cursor + 1..])?;
if i2 == 0 {
return Ok(None);
}
cursor += 1 + i2;
let need = len as usize + 2;
if cursor + need > data.len() {
return Ok(None); }
let payload = &data[cursor..cursor + len as usize];
items.push(Bytes::copy_from_slice(payload));
cursor += need;
}
if items.is_empty() {
bail!("empty array body");
}
let cmd = if items[0].eq_ignore_ascii_case(b"PING") {
Cmd::Ping
} else if items[0].eq_ignore_ascii_case(b"GET") && items.len() >= 2 {
Cmd::Get(items[1].clone())
} else if items[0].eq_ignore_ascii_case(b"SET") && items.len() >= 3 {
Cmd::Set(items[1].clone(), items[2].clone())
} else if items[0].eq_ignore_ascii_case(b"DEL") && items.len() >= 2 {
Cmd::Del(items[1].clone())
} else if items[0].eq_ignore_ascii_case(b"RENAME") && items.len() >= 3 {
Cmd::Rename(items[1].clone(), items[2].clone())
} else if items[0].eq_ignore_ascii_case(b"EXISTS") && items.len() >= 2 {
Cmd::Exists(items[1].clone())
} else if items[0].eq_ignore_ascii_case(b"INCR") && items.len() >= 2 {
Cmd::Incr(items[1].clone())
} else if items[0].eq_ignore_ascii_case(b"MGET") && items.len() >= 2 {
Cmd::MGet(items[1..].to_vec())
} else if items[0].eq_ignore_ascii_case(b"MSET") && items.len() >= 3 && items.len() % 2 == 1 {
let mut v = Vec::with_capacity((items.len() - 1) / 2);
for pair in items[1..].chunks(2) {
if pair.len() == 2 {
v.push((pair[0].clone(), pair[1].clone()));
}
}
Cmd::MSet(v)
} else {
bail!("unknown/invalid command");
};
Ok(Some((cursor, cmd)))
}
/// Decodes every complete request at the front of `buf` and appends it to `out`.
///
/// The bytes of each decoded request are removed from `buf`. A trailing
/// partial request stays in `buf` for the next call. If `parse_one` fails, the
/// error is returned at once. Requests decoded before the failure have already
/// been pushed to `out` and consumed from `buf`.
pub fn parse_many(buf: &mut bytes::BytesMut, out: &mut Vec<Cmd>) -> Result<()> {
    while let Some((used, cmd)) = parse_one(&buf[..])? {
        buf.advance(used);
        out.push(cmd);
    }
    Ok(())
}
fn read_decimal_line(s: &[u8]) -> Result<(usize, i64)> {
let mut i = 0;
let mut num: i64 = 0;
let mut sign: i64 = 1;
if i < s.len() && s[i] == b'-' {
sign = -1;
i += 1;
}
let start = i;
while i + 8 <= s.len() {
let chunk = u64::from_le_bytes(s[i..i+8].try_into().unwrap());
let val_minus_0 = chunk.wrapping_sub(0x3030303030303030);
let val_plus_46 = chunk.wrapping_add(0x4646464646464646);
if (val_minus_0 | val_plus_46) & 0x8080808080808080 != 0 {
break;
}
break;
}
while i < s.len() {
let c = s[i];
if c.is_ascii_digit() {
num = num.wrapping_mul(10).wrapping_add((c - b'0') as i64);
i += 1;
} else {
break;
}
}
if i == start {
}
if i + 1 < s.len() && s[i] == b'\r' && s[i + 1] == b'\n' {
Ok((i + 2, num * sign))
} else if i + 1 >= s.len() {
Ok((0, 0))
} else {
bail!("expected CRLF");
}
}
/// Encodes `s` as a RESP simple-string frame: `+<s>\r\n`.
///
/// `s` is copied verbatim. RESP simple strings must not contain CR or LF, and
/// this function does not check for them.
pub fn resp_simple(s: &str) -> Vec<u8> {
    let parts: [&[u8]; 3] = [b"+", s.as_bytes(), b"\r\n"];
    parts.concat()
}
/// Encodes `b` as a RESP bulk-string frame: `$<len>\r\n<b>\r\n`.
///
/// The payload is length-prefixed, so it may contain any bytes, including CRLF.
pub fn resp_bulk(b: &[u8]) -> Vec<u8> {
    let mut frame = format!("${}\r\n", b.len()).into_bytes();
    frame.reserve(b.len() + 2);
    frame.extend_from_slice(b);
    frame.extend_from_slice(b"\r\n");
    frame
}
/// Returns the RESP null bulk-string frame, `$-1\r\n`.
pub fn resp_null() -> Vec<u8> {
    Vec::from(&b"$-1\r\n"[..])
}
/// Encodes `i` as a RESP integer frame: `:<decimal i>\r\n`.
pub fn resp_integer(i: i64) -> Vec<u8> {
    format!(":{}\r\n", i).into_bytes()
}
/// Builds a RESP array frame: the `*<count>\r\n` header followed by each
/// element.
///
/// Elements are appended verbatim, in order. Each one is expected to already
/// be a complete encoded frame (e.g. from `resp_bulk`); this is not checked.
pub fn resp_array(items: Vec<Vec<u8>>) -> Vec<u8> {
    let mut frame = format!("*{}\r\n", items.len()).into_bytes();
    let body_len: usize = items.iter().map(Vec::len).sum();
    frame.reserve(body_len);
    for element in &items {
        frame.extend_from_slice(element);
    }
    frame
}
/// Appends the RESP simple-string frame `+<s>\r\n` to `out`.
///
/// `s` is written verbatim. RESP simple strings must not contain CR or LF, and
/// this function does not check for them.
pub fn write_simple(s: &str, out: &mut BytesMut) {
    let text = s.as_bytes();
    out.reserve(text.len() + 3);
    out.extend_from_slice(b"+");
    out.extend_from_slice(text);
    out.extend_from_slice(b"\r\n");
}
/// Appends the RESP bulk-string frame `$<len>\r\n<b>\r\n` to `out`.
///
/// Space for the whole frame is reserved once before writing.
pub fn write_bulk(b: &[u8], out: &mut BytesMut) {
    let header = format!("${}\r\n", b.len());
    out.reserve(header.len() + b.len() + 2);
    out.extend_from_slice(header.as_bytes());
    out.extend_from_slice(b);
    out.extend_from_slice(b"\r\n");
}
/// Appends the RESP null bulk-string frame, `$-1\r\n`, to `out`.
pub fn write_null(out: &mut BytesMut) {
    let frame: &[u8] = b"$-1\r\n";
    out.put_slice(frame);
}
/// Appends the RESP integer frame `:<decimal i>\r\n` to `out`.
pub fn write_integer(i: i64, out: &mut BytesMut) {
    let frame = format!(":{}\r\n", i);
    out.reserve(frame.len());
    out.extend_from_slice(frame.as_bytes());
}
/// Appends a RESP array header `*<n>\r\n` to `out`.
///
/// Only the header is written. The caller is expected to follow it with
/// exactly `n` encoded elements.
pub fn write_array_len(n: usize, out: &mut BytesMut) {
    let header = format!("*{}\r\n", n);
    out.reserve(header.len());
    out.extend_from_slice(header.as_bytes());
}