use crate::argv::{Argv, Command};
use crate::error::ProtocolError;
/// Parses one complete request from the front of `buf` into a freshly allocated [`Command`].
///
/// On success returns the parsed command together with the number of bytes of `buf` it
/// occupied; returns `Ok(None)` when `buf` does not yet hold a complete request.
///
/// # Errors
///
/// Propagates any [`ProtocolError`] reported by [`parse_command_into`].
pub fn parse_command(buf: &[u8]) -> Result<Option<(Command, usize)>, ProtocolError> {
    let mut argv = Argv::default();
    let consumed = parse_command_into(buf, &mut argv)?;
    Ok(consumed.map(|used| (argv, used)))
}
/// Parses one request from the front of `buf` into `dst`, reusing `dst`'s storage.
///
/// `dst` is cleared first. A leading `*` selects the multibulk parser; any other first byte is
/// treated as an inline (whitespace-separated) command. Returns the number of bytes consumed,
/// or `Ok(None)` when `buf` is empty or does not yet hold a complete request.
///
/// # Errors
///
/// Returns a [`ProtocolError`] when the request is malformed.
pub fn parse_command_into(buf: &[u8], dst: &mut Argv) -> Result<Option<usize>, ProtocolError> {
    dst.clear();
    match buf.first() {
        None => Ok(None),
        Some(b'*') => parse_multibulk_into(buf, dst),
        Some(_) => parse_inline_into(buf, dst),
    }
}
/// Parses an inline command: one CRLF-terminated line split on ASCII whitespace.
///
/// Runs of whitespace produce no empty arguments. Returns the number of bytes consumed
/// (the line plus its `\r\n`), or `Ok(None)` if no `\r\n` is present yet.
fn parse_inline_into(buf: &[u8], dst: &mut Argv) -> Result<Option<usize>, ProtocolError> {
    match find_crlf(buf, 0) {
        None => Ok(None),
        Some(line_end) => {
            let words = buf[..line_end]
                .split(u8::is_ascii_whitespace)
                .filter(|word| !word.is_empty());
            for word in words {
                dst.push(word);
            }
            // Skip past the terminating "\r\n".
            Ok(Some(line_end + 2))
        }
    }
}
pub(crate) fn validate_multibulk_frame(
buf: &[u8],
start_pos: usize,
count: usize,
) -> Result<Option<(usize, usize)>, ProtocolError> {
let mut pos = start_pos;
let mut total = 0usize;
for _ in 0..count {
if pos >= buf.len() {
return Ok(None);
}
if buf[pos] != b'$' {
return Err(ProtocolError::Malformed("expected bulk string"));
}
let Some(len_end) = find_crlf(buf, pos + 1) else {
return Ok(None);
};
let len = parse_int(&buf[pos + 1..len_end])
.ok_or(ProtocolError::Malformed("bad bulk length"))?;
if len < 0 {
return Err(ProtocolError::Malformed("negative bulk length in request"));
}
let len = len as usize;
let data_end = len_end + 2 + len;
if buf.len() < data_end + 2 {
return Ok(None);
}
if &buf[data_end..data_end + 2] != b"\r\n" {
return Err(ProtocolError::Malformed("bulk string not CRLF-terminated"));
}
total += len;
pos = data_end + 2;
}
Ok(Some((pos, total)))
}
/// Second pass over a multibulk body: pushes each of the `count` payloads into `dst`.
///
/// Must only be called on a region that `validate_multibulk_frame` has already accepted; the
/// `expect`s below encode that precondition.
fn copy_multibulk_args(buf: &[u8], start_pos: usize, count: usize, dst: &mut Argv) {
    let mut cursor = start_pos;
    for _ in 0..count {
        // `cursor` sits on the `$` that opens `$<len>\r\n<payload>\r\n`.
        let header = cursor + 1;
        let header_end = find_crlf(buf, header).expect("validated in pass 1");
        let payload_len =
            parse_int(&buf[header..header_end]).expect("validated in pass 1") as usize;
        let payload_start = header_end + 2;
        let payload = &buf[payload_start..][..payload_len];
        dst.push(payload);
        // Jump over the payload and its trailing "\r\n".
        cursor = payload_start + payload_len + 2;
    }
}
fn parse_multibulk_into(buf: &[u8], dst: &mut Argv) -> Result<Option<usize>, ProtocolError> {
let Some(hdr_end) = find_crlf(buf, 1) else {
return Ok(None);
};
let count =
parse_int(&buf[1..hdr_end]).ok_or(ProtocolError::Malformed("bad multibulk count"))?;
if count < 0 {
return Ok(Some(hdr_end + 2));
}
let count = count as usize;
let start = hdr_end + 2;
let (end_pos, total) = match validate_multibulk_frame(buf, start, count)? {
Some(t) => t,
None => return Ok(None),
};
dst.reserve_for(count, total);
copy_multibulk_args(buf, start, count, dst);
Ok(Some(end_pos))
}
/// Returns the index of the first `\r\n` pair in `buf` whose `\r` is at or after `start`.
///
/// Returns `None` when `buf[start..]` contains no complete pair, including when fewer than two
/// bytes remain from `start`. A `\r` in the last byte of `buf` is never reported, because there
/// is no following byte to check.
pub(crate) fn find_crlf(buf: &[u8], start: usize) -> Option<usize> {
    // SWAR constants: `\r` (0x0D) in every byte, the low bit of every byte, and the high bit of
    // every byte.
    const CR_BCAST: u64 = 0x0D0D0D0D_0D0D0D0Du64;
    const ONES: u64 = 0x01010101_01010101u64;
    const HIGH: u64 = 0x80808080_80808080u64;
    let n = buf.len();
    let mut i = start;
    if i + 1 >= n {
        return None;
    }
    // Fast path: scan 8 bytes per step. The strict `<` (not `<=`) guarantees that
    // `buf[pos + 1]` below is in bounds for every `pos` inside the current word.
    while i + 8 < n {
        // Little-endian load so that the word's lowest byte is `buf[i]` on every host.
        let word = u64::from_le_bytes(buf[i..i + 8].try_into().expect("8 bytes"));
        // Bytes equal to `\r` become 0x00 after the XOR.
        let x = word ^ CR_BCAST;
        // "Has a zero byte" test: sets the high bit of each zero byte of `x`. Borrow
        // propagation can also flag bytes *above* a true zero, but the lowest flagged byte is
        // always exact, and only the lowest one is used.
        let zeroed = x.wrapping_sub(ONES) & !x & HIGH;
        if zeroed != 0 {
            let bit_idx = zeroed.trailing_zeros();
            let pos = i + (bit_idx / 8) as usize;
            if buf[pos + 1] == b'\n' {
                return Some(pos);
            }
            // Lone `\r`: resume scanning at the byte right after it.
            i = pos + 1;
            continue;
        }
        i += 8;
    }
    // Tail: byte-at-a-time over the final bytes the word loop did not cover.
    while i + 1 < n {
        if buf[i] == b'\r' && buf[i + 1] == b'\n' {
            return Some(i);
        }
        i += 1;
    }
    None
}
/// Parses an optionally signed (`+`/`-`) ASCII decimal integer spanning all of `bytes`.
///
/// Returns `None` when any of the following holds:
/// - `bytes` is empty or is only a sign;
/// - `bytes` contains a non-digit character;
/// - the value does not fit in an `i64`.
///
/// The full `i64` range is accepted, including `i64::MIN`.
pub(crate) fn parse_int(bytes: &[u8]) -> Option<i64> {
    let (neg, digits) = match bytes.split_first()? {
        (&b'-', rest) => (true, rest),
        (&b'+', rest) => (false, rest),
        _ => (false, bytes),
    };
    if digits.is_empty() {
        return None;
    }
    // Accumulate toward the final sign instead of negating at the end. `|i64::MIN|` is not
    // representable as a positive `i64`, so negating afterwards would reject `i64::MIN`.
    let mut acc: i64 = 0;
    for &b in digits {
        if !b.is_ascii_digit() {
            return None;
        }
        let digit = i64::from(b - b'0');
        acc = acc.checked_mul(10)?;
        acc = if neg {
            acc.checked_sub(digit)?
        } else {
            acc.checked_add(digit)?
        };
    }
    Some(acc)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::encode_command;

    #[test]
    fn find_crlf_at_every_offset() {
        for off in 0..40 {
            let mut buf = vec![b'a'; 60];
            buf[off] = b'\r';
            buf[off + 1] = b'\n';
            assert_eq!(find_crlf(&buf, 0), Some(off), "off={off}");
        }
    }

    #[test]
    fn find_crlf_skips_lone_cr() {
        let mut buf = vec![b'a'; 32];
        buf[3] = b'\r';
        buf[4] = b'b';
        buf[20] = b'\r';
        buf[21] = b'\n';
        assert_eq!(find_crlf(&buf, 0), Some(20));
    }

    #[test]
    fn find_crlf_none_when_absent() {
        let buf = vec![b'a'; 32];
        assert_eq!(find_crlf(&buf, 0), None);
        let buf = b"";
        assert_eq!(find_crlf(buf, 0), None);
        let buf = b"\r";
        assert_eq!(find_crlf(buf, 0), None);
    }

    #[test]
    fn find_crlf_at_buffer_end() {
        let buf = b"abcdefghij\r\n";
        assert_eq!(find_crlf(buf, 0), Some(10));
        assert_eq!(find_crlf(buf, 11), None);
    }

    #[test]
    fn find_crlf_with_many_lone_crs() {
        let mut buf = Vec::new();
        for _ in 0..7 {
            buf.push(b'\r');
            buf.push(b'x');
        }
        buf.extend_from_slice(b"\r\n");
        assert_eq!(find_crlf(&buf, 0), Some(14));
    }

    #[test]
    fn find_crlf_from_nonzero_start() {
        let buf = b"\r\n\r\n\r\n";
        assert_eq!(find_crlf(buf, 0), Some(0));
        assert_eq!(find_crlf(buf, 2), Some(2));
        assert_eq!(find_crlf(buf, 4), Some(4));
    }

    #[test]
    fn parse_int_accepts_signs_and_rejects_junk() {
        assert_eq!(parse_int(b"42"), Some(42));
        assert_eq!(parse_int(b"-7"), Some(-7));
        assert_eq!(parse_int(b"+3"), Some(3));
        assert_eq!(parse_int(b""), None);
        assert_eq!(parse_int(b"-"), None);
        assert_eq!(parse_int(b"4x"), None);
        assert_eq!(parse_int(b"9223372036854775808"), None);
    }

    #[test]
    fn parse_multibulk_ping() {
        let (cmd, used) = parse_command(b"*1\r\n$4\r\nPING\r\n").unwrap().unwrap();
        assert_eq!(cmd, vec![b"PING".to_vec()]);
        assert_eq!(used, 14);
    }

    #[test]
    fn parse_multibulk_echo() {
        let frame = b"*2\r\n$4\r\nECHO\r\n$5\r\nhello\r\n";
        let (cmd, used) = parse_command(frame).unwrap().unwrap();
        assert_eq!(cmd, vec![b"ECHO".to_vec(), b"hello".to_vec()]);
        assert_eq!(used, frame.len());
    }

    #[test]
    fn parse_multibulk_null_and_empty_headers() {
        // A negative count consumes only the header line.
        let (_, used) = parse_command(b"*-1\r\n").unwrap().unwrap();
        assert_eq!(used, 5);
        let (_, used) = parse_command(b"*0\r\n").unwrap().unwrap();
        assert_eq!(used, 4);
    }

    #[test]
    fn parse_multibulk_is_binary_safe() {
        // The payload itself contains "\r\n"; only the declared length must be honoured.
        let frame = b"*1\r\n$4\r\na\r\nb\r\n";
        let (cmd, used) = parse_command(frame).unwrap().unwrap();
        assert_eq!(cmd, vec![b"a\r\nb".to_vec()]);
        assert_eq!(used, frame.len());
    }

    #[test]
    fn parse_pipelined_consumes_only_first_frame() {
        let frames = b"*1\r\n$4\r\nPING\r\n*1\r\n$4\r\nPING\r\n";
        let (cmd, used) = parse_command(frames).unwrap().unwrap();
        assert_eq!(cmd, vec![b"PING".to_vec()]);
        assert_eq!(used, 14);
    }

    #[test]
    fn parse_incomplete_returns_none() {
        assert_eq!(parse_command(b"*1\r\n$4\r\nPI").unwrap(), None);
        assert_eq!(parse_command(b"*2\r\n$4\r\nECHO\r\n").unwrap(), None);
        assert_eq!(parse_command(b"").unwrap(), None);
    }

    #[test]
    fn parse_inline_command() {
        let (cmd, used) = parse_command(b"PING\r\n").unwrap().unwrap();
        assert_eq!(cmd, vec![b"PING".to_vec()]);
        assert_eq!(used, 6);
        let (cmd, _) = parse_command(b"ECHO hi there\r\n").unwrap().unwrap();
        assert_eq!(
            cmd,
            vec![b"ECHO".to_vec(), b"hi".to_vec(), b"there".to_vec()]
        );
    }

    #[test]
    fn parse_malformed_errors() {
        assert!(parse_command(b"*1\r\n+OK\r\n").is_err());
        assert!(parse_command(b"*x\r\n").is_err());
    }

    #[test]
    fn parse_bulk_framing_errors() {
        // Negative bulk length.
        assert!(parse_command(b"*1\r\n$-1\r\n").is_err());
        // Payload not followed by "\r\n".
        assert!(parse_command(b"*1\r\n$4\r\nPINGxx").is_err());
        // Non-numeric bulk length.
        assert!(parse_command(b"*1\r\n$x\r\n").is_err());
    }

    #[test]
    fn round_trip_command() {
        let mut buf = Vec::new();
        encode_command(&mut buf, &[b"SET".to_vec(), b"k".to_vec(), b"v".to_vec()]);
        let (cmd, used) = parse_command(&buf).unwrap().unwrap();
        assert_eq!(cmd, vec![b"SET".to_vec(), b"k".to_vec(), b"v".to_vec()]);
        assert_eq!(used, buf.len());
    }
}