use crate::command::{Command, ParseOptions};
use crate::error::ParseError;
/// A ready-made value (64 KiB, in bytes) for the `streaming_threshold`
/// argument of [`parse_streaming`].
///
/// `parse_streaming` never reads this constant itself; the threshold is always
/// whatever the caller passes in. The unit tests in this file pass this value.
pub const STREAMING_THRESHOLD: usize = 64 * 1024;
/// Fields parsed from the command line of a `set` whose value is handed back
/// to the caller via [`ParseProgress::NeedValue`] instead of being parsed
/// inline.
///
/// The header borrows its key from the buffer given to [`parse_streaming`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SetHeader<'a> {
/// Key token of the `set` line, borrowed from the input buffer.
pub key: &'a [u8],
/// The flags token, parsed as a decimal `u32`.
pub flags: u32,
/// The exptime token, parsed as a decimal `u32`.
pub exptime: u32,
/// `true` only when the token following the byte count is exactly `noreply`.
pub noreply: bool,
}
/// Result of a single [`parse_streaming`] call over a borrowed input buffer.
#[derive(Debug)]
pub enum ParseProgress<'a> {
/// The buffer does not yet hold a full command line, or (for a `set` below
/// the streaming threshold) does not yet hold the whole value plus its
/// trailing CRLF.
Incomplete,
/// A `set` header was parsed and its declared value length is at or above
/// the streaming threshold; the value itself has not been validated.
NeedValue {
/// Parsed fields of the `set` command line.
header: SetHeader<'a>,
/// Value length in bytes as declared on the command line.
value_len: usize,
/// The value bytes already present in the buffer after the command line,
/// capped at `value_len`; may be empty.
value_prefix: &'a [u8],
/// Number of buffer bytes taken by the command line, including its CRLF.
header_consumed: usize,
},
/// A fully parsed command together with the number of buffer bytes it
/// occupied.
Complete(Command<'a>, usize),
}
/// Parses the first command in `buffer`, returning large `set` values to the
/// caller for streaming instead of requiring them to be fully buffered.
///
/// Behaviour by case:
///
/// * No CRLF-terminated command line found yet: `Ok(ParseProgress::Incomplete)`.
/// * The first space-separated token is anything other than exactly `set` or
///   `SET`: the whole buffer is handed to `Command::parse_with_options`, whose
///   `ParseError::Incomplete` is mapped to `Ok(ParseProgress::Incomplete)`.
/// * A `set` whose declared length is `>= streaming_threshold`:
///   `Ok(ParseProgress::NeedValue { .. })` carrying the header, the declared
///   length, and whatever value bytes already follow the header in `buffer`.
///   The value and its terminator are not validated on this path.
/// * A smaller `set`: `Incomplete` until the value and its trailing CRLF are
///   buffered, then `Complete(Command::Set { .. }, total_bytes)`. The parsed
///   `noreply` token is not part of the returned `Command::Set`.
///
/// # Errors
///
/// Returns `ParseError::Protocol` for a missing/empty/oversized key, a missing
/// flags/exptime/bytes token, a declared length above `options.max_value_len`,
/// or a value not followed by CRLF; `ParseError::InvalidNumber` when a numeric
/// token fails to parse or the length arithmetic overflows; and any error
/// produced by `find_crlf` or `Command::parse_with_options`.
pub fn parse_streaming<'a>(
    buffer: &'a [u8],
    options: &ParseOptions,
    streaming_threshold: usize,
) -> Result<ParseProgress<'a>, ParseError> {
    let line_end = match find_crlf(buffer, options.max_line_len())? {
        Some(pos) => pos,
        None => return Ok(ParseProgress::Incomplete),
    };

    let mut tokens = buffer[..line_end].split(|&byte| byte == b' ');
    let verb = tokens.next().ok_or(ParseError::Protocol("empty command"))?;

    // Everything except `set` goes through the regular, non-streaming parser.
    let is_set = verb == b"set" || verb == b"SET";
    if !is_set {
        return match Command::parse_with_options(buffer, options) {
            Ok((command, used)) => Ok(ParseProgress::Complete(command, used)),
            Err(ParseError::Incomplete) => Ok(ParseProgress::Incomplete),
            Err(other) => Err(other),
        };
    }

    // The key is validated before the remaining tokens are even looked at.
    let key = match tokens.next() {
        None => return Err(ParseError::Protocol("set requires key")),
        Some(k) if k.is_empty() => return Err(ParseError::Protocol("empty key")),
        Some(k) if k.len() > options.max_key_len => {
            return Err(ParseError::Protocol("key too large"));
        }
        Some(k) => k,
    };

    // All three tokens must be present before any of them is parsed.
    let raw_flags = tokens
        .next()
        .ok_or(ParseError::Protocol("set requires flags"))?;
    let raw_exptime = tokens
        .next()
        .ok_or(ParseError::Protocol("set requires exptime"))?;
    let raw_bytes = tokens
        .next()
        .ok_or(ParseError::Protocol("set requires bytes"))?;

    let flags = parse_u32(raw_flags)?;
    let exptime = parse_u32(raw_exptime)?;
    let value_len = parse_usize(raw_bytes)?;
    if value_len > options.max_value_len {
        return Err(ParseError::Protocol("value too large"));
    }

    let noreply = tokens.next().map_or(false, |token| token == b"noreply");

    // Command line plus its two-byte CRLF.
    let header_consumed = line_end + 2;

    if value_len >= streaming_threshold {
        // Hand back whatever part of the value is already buffered.
        let buffered = &buffer[header_consumed..];
        let value_prefix = &buffered[..buffered.len().min(value_len)];
        let header = SetHeader {
            key,
            flags,
            exptime,
            noreply,
        };
        return Ok(ParseProgress::NeedValue {
            header,
            value_len,
            value_prefix,
            header_consumed,
        });
    }

    let data_end = header_consumed
        .checked_add(value_len)
        .ok_or(ParseError::InvalidNumber)?;
    let total_len = data_end.checked_add(2).ok_or(ParseError::InvalidNumber)?;
    if buffer.len() < total_len {
        return Ok(ParseProgress::Incomplete);
    }

    // The value must be followed immediately by CRLF.
    if &buffer[data_end..total_len] != b"\r\n" {
        return Err(ParseError::Protocol("missing data terminator"));
    }

    let command = Command::Set {
        key,
        flags,
        exptime,
        data: &buffer[header_consumed..data_end],
    };
    Ok(ParseProgress::Complete(command, total_len))
}
/// Builds a `Command::Set` from a previously parsed header and the fully
/// received value.
///
/// `header.noreply` is not carried over; the `Command::Set` built here has no
/// such field.
///
/// # Lifetimes
///
/// The header's key and `value` share the lifetime `'a`, so the returned
/// command cannot outlive the buffer backing either of them. Because both
/// borrows are covariant, callers whose key and value live for different
/// spans still compile; `'a` simply resolves to the shorter of the two.
pub fn complete_set<'a>(header: &SetHeader<'a>, value: &'a [u8]) -> Command<'a> {
    Command::Set {
        // `&'a [u8]` is `Copy`, so the key is read straight out of the header
        // with its real lifetime; no `unsafe` lifetime extension is involved.
        key: header.key,
        flags: header.flags,
        exptime: header.exptime,
        data: value,
    }
}
/// Locates the CRLF that terminates the first line of `buffer`.
///
/// Returns `Ok(Some(pos))` where `pos` is the index of the `\r` of the first
/// `\r\n` pair, `Ok(None)` when no terminator has arrived yet and the buffer
/// is still within the limit, and `ParseError::Protocol("line too long")`
/// when no terminator starts at an index `<= max_line_len` and the buffer
/// already holds more than `max_line_len` bytes.
///
/// A `\r` that is not followed by `\n` is treated as ordinary line content and
/// the search continues after it, so a stray carriage return cannot hide a
/// later terminator.
fn find_crlf(buffer: &[u8], max_line_len: usize) -> Result<Option<usize>, ParseError> {
    // A terminator whose `\r` sits at index <= max_line_len lies entirely
    // within the first `max_line_len + 2` bytes. Restricting the search to
    // that window enforces the limit on complete lines too, and avoids
    // scanning payload bytes that follow an over-long line.
    let window_len = buffer.len().min(max_line_len.saturating_add(2));
    let window = &buffer[..window_len];

    let mut search_from = 0;
    while let Some(offset) = memchr::memchr(b'\r', &window[search_from..]) {
        let pos = search_from + offset;
        match window.get(pos + 1) {
            Some(&b'\n') => return Ok(Some(pos)),
            // Lone `\r` inside the line: keep looking for a real terminator.
            Some(_) => search_from = pos + 1,
            // `\r` is the last byte of the window; nothing more to inspect.
            None => break,
        }
    }

    if buffer.len() > max_line_len {
        return Err(ParseError::Protocol("line too long"));
    }
    Ok(None)
}
fn parse_u32(data: &[u8]) -> Result<u32, ParseError> {
std::str::from_utf8(data)
.map_err(|_| ParseError::InvalidNumber)?
.parse()
.map_err(|_| ParseError::InvalidNumber)
}
fn parse_usize(data: &[u8]) -> Result<usize, ParseError> {
std::str::from_utf8(data)
.map_err(|_| ParseError::InvalidNumber)?
.parse()
.map_err(|_| ParseError::InvalidNumber)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_streaming` with default options and `STREAMING_THRESHOLD`,
    /// panicking if the parser reports an error.
    fn run_parser(input: &[u8]) -> ParseProgress<'_> {
        parse_streaming(input, &ParseOptions::default(), STREAMING_THRESHOLD).unwrap()
    }

    // A small `set` with its value fully buffered parses in one call.
    #[test]
    fn test_small_set_complete() {
        let input = b"set mykey 0 3600 7\r\nmyvalue\r\n";
        let (cmd, consumed) = match run_parser(input) {
            ParseProgress::Complete(cmd, consumed) => (cmd, consumed),
            _ => panic!("expected Complete"),
        };
        let expected = Command::Set {
            key: b"mykey",
            flags: 0,
            exptime: 3600,
            data: b"myvalue",
        };
        assert_eq!(cmd, expected);
        assert_eq!(consumed, input.len());
    }

    // A value above the threshold yields the header plus the buffered prefix.
    #[test]
    fn test_large_set_needs_value() {
        let value_len = 100 * 1024;
        let mut input = format!("set mykey 0 3600 {}\r\n", value_len).into_bytes();
        input.extend_from_slice(&[b'x'; 1000]);

        let result = run_parser(&input);
        match result {
            ParseProgress::NeedValue {
                header,
                value_len: declared_len,
                value_prefix,
                header_consumed,
            } => {
                assert_eq!(header.key, b"mykey");
                assert_eq!(header.flags, 0);
                assert_eq!(header.exptime, 3600);
                assert!(!header.noreply);
                assert_eq!(declared_len, 100 * 1024);
                assert_eq!(value_prefix.len(), 1000);
                assert!(value_prefix.iter().all(|&b| b == b'x'));
                // Length of "set mykey 0 3600 102400\r\n".
                assert_eq!(header_consumed, 25);
            }
            _ => panic!("expected NeedValue, got {:?}", result),
        }
    }

    // The trailing `noreply` token is surfaced on the streaming header.
    #[test]
    fn test_set_with_noreply() {
        let value_len = 100 * 1024;
        let line = format!("set mykey 0 3600 {} noreply\r\n", value_len);
        if let ParseProgress::NeedValue { header, .. } = run_parser(line.as_bytes()) {
            assert!(header.noreply);
        } else {
            panic!("expected NeedValue");
        }
    }

    // Non-`set` commands are delegated to the regular parser.
    #[test]
    fn test_get_uses_normal_path() {
        let input = b"get mykey\r\n";
        if let ParseProgress::Complete(cmd, consumed) = run_parser(input) {
            assert_eq!(cmd, Command::Get { key: b"mykey" });
            assert_eq!(consumed, input.len());
        } else {
            panic!("expected Complete");
        }
    }

    // A command line without its CRLF is reported as incomplete.
    #[test]
    fn test_incomplete_header() {
        let input = b"set mykey 0 360";
        match run_parser(input) {
            ParseProgress::Incomplete => {}
            _ => panic!("expected Incomplete"),
        }
    }

    // A small `set` whose value is only partially buffered is incomplete.
    #[test]
    fn test_incomplete_small_value() {
        let input = b"set mykey 0 3600 100\r\npartial";
        match run_parser(input) {
            ParseProgress::Incomplete => {}
            _ => panic!("expected Incomplete"),
        }
    }

    // The threshold is inclusive: exactly at it streams, one below does not.
    #[test]
    fn test_threshold_boundary() {
        let at_threshold = format!("set mykey 0 3600 {}\r\n", STREAMING_THRESHOLD);
        if let ParseProgress::NeedValue { value_len, .. } = run_parser(at_threshold.as_bytes()) {
            assert_eq!(value_len, STREAMING_THRESHOLD);
        } else {
            panic!("expected NeedValue at threshold");
        }

        let below_threshold = format!("set mykey 0 3600 {}\r\n", STREAMING_THRESHOLD - 1);
        match run_parser(below_threshold.as_bytes()) {
            ParseProgress::Incomplete => {}
            _ => panic!("expected Incomplete for sub-threshold without data"),
        }
    }

    #[test]
    fn test_delete_command() {
        let input = b"delete mykey\r\n";
        match run_parser(input) {
            ParseProgress::Complete(Command::Delete { key }, consumed) => {
                assert_eq!(key, b"mykey");
                assert_eq!(consumed, input.len());
            }
            _ => panic!("expected Complete Delete"),
        }
    }

    #[test]
    fn test_flush_all_command() {
        let input = b"flush_all\r\n";
        match run_parser(input) {
            ParseProgress::Complete(Command::FlushAll, consumed) => {
                assert_eq!(consumed, input.len());
            }
            _ => panic!("expected Complete FlushAll"),
        }
    }

    // `complete_set` copies the header fields and attaches the value.
    #[test]
    fn test_complete_set_helper() {
        let header = SetHeader {
            key: b"mykey",
            flags: 42,
            exptime: 3600,
            noreply: false,
        };
        let value = b"myvalue";
        match complete_set(&header, value) {
            Command::Set {
                key,
                flags,
                exptime,
                data,
            } => {
                assert_eq!(key, b"mykey");
                assert_eq!(flags, 42);
                assert_eq!(exptime, 3600);
                assert_eq!(data, b"myvalue");
            }
            _ => panic!("expected Set command"),
        }
    }
}