use super::command::BinaryCommand;
use super::header::{HEADER_SIZE, Opcode, RequestHeader};
use crate::error::ParseError;
/// Value size in bytes (64 KiB) intended as the `streaming_threshold` argument
/// of [`parse_streaming`]. A storage request whose value is at least this large
/// is returned as [`BinaryParseProgress::NeedValue`] once its header, extras
/// and key are buffered, instead of waiting for the whole frame.
pub const BINARY_STREAMING_THRESHOLD: usize = 64 * 1024;
/// Header fields of a storage request (Set/Add/Replace/Append/Prepend and their
/// quiet forms) whose value is delivered separately; see
/// [`BinaryParseProgress::NeedValue`] and [`complete_set`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BinarySetHeader<'a> {
    /// Key bytes; `parse_streaming` borrows them directly from its input buffer.
    pub key: &'a [u8],
    /// Item flags: the first 4 extras bytes (big-endian) for Set/Add/Replace,
    /// 0 for Append/Prepend.
    pub flags: u32,
    /// Expiration: extras bytes 4..8 (big-endian) for Set/Add/Replace,
    /// 0 for Append/Prepend.
    pub expiration: u32,
    /// CAS value copied from the request header.
    pub cas: u64,
    /// Opaque value copied from the request header.
    pub opaque: u32,
    /// Opcode of the storage request this header came from.
    pub opcode: Opcode,
}
/// Result of one [`parse_streaming`] call against the bytes buffered so far.
#[derive(Debug)]
pub enum BinaryParseProgress<'a> {
    /// Not enough bytes are buffered yet to produce either a command or a
    /// streaming header.
    Incomplete,
    /// A storage request whose value is at least the streaming threshold.
    /// Header, extras and key are buffered; the value may be only partially present.
    NeedValue {
        /// Parsed header fields and key.
        header: BinarySetHeader<'a>,
        /// Declared length of the value in bytes.
        value_len: usize,
        /// The part of the value already in the buffer (at most `value_len` bytes).
        value_prefix: &'a [u8],
        /// Length of the whole request frame (header + body); may exceed the
        /// number of bytes currently buffered.
        total_consumed: usize,
    },
    /// A fully decoded command and the byte count reported by `BinaryCommand::parse`.
    Complete(BinaryCommand<'a>, usize),
}
pub fn parse_streaming(
buffer: &[u8],
streaming_threshold: usize,
) -> Result<BinaryParseProgress<'_>, ParseError> {
if buffer.len() < HEADER_SIZE {
return Ok(BinaryParseProgress::Incomplete);
}
let header = RequestHeader::parse(buffer)?;
let total_len = HEADER_SIZE + header.total_body_length as usize;
let is_storage_command = matches!(
header.opcode,
Opcode::Set
| Opcode::SetQ
| Opcode::Add
| Opcode::AddQ
| Opcode::Replace
| Opcode::ReplaceQ
| Opcode::Append
| Opcode::AppendQ
| Opcode::Prepend
| Opcode::PrependQ
);
if !is_storage_command {
if buffer.len() < total_len {
return Ok(BinaryParseProgress::Incomplete);
}
return match BinaryCommand::parse(buffer) {
Ok((cmd, consumed)) => Ok(BinaryParseProgress::Complete(cmd, consumed)),
Err(ParseError::Incomplete) => Ok(BinaryParseProgress::Incomplete),
Err(e) => Err(e),
};
}
let extras_len = header.extras_length as usize;
let key_len = header.key_length as usize;
if extras_len + key_len > header.total_body_length as usize {
return Err(ParseError::Protocol("header lengths exceed body length"));
}
let value_len = header.total_body_length as usize - extras_len - key_len;
if value_len < streaming_threshold {
if buffer.len() < total_len {
return Ok(BinaryParseProgress::Incomplete);
}
return match BinaryCommand::parse(buffer) {
Ok((cmd, consumed)) => Ok(BinaryParseProgress::Complete(cmd, consumed)),
Err(ParseError::Incomplete) => Ok(BinaryParseProgress::Incomplete),
Err(e) => Err(e),
};
}
let header_and_key_len = HEADER_SIZE + extras_len + key_len;
if buffer.len() < header_and_key_len {
return Ok(BinaryParseProgress::Incomplete);
}
let body = &buffer[HEADER_SIZE..];
let extras = &body[..extras_len];
let (flags, expiration) = if matches!(
header.opcode,
Opcode::Set
| Opcode::SetQ
| Opcode::Add
| Opcode::AddQ
| Opcode::Replace
| Opcode::ReplaceQ
) {
if extras.len() < 8 {
return Err(ParseError::Protocol(
"storage command requires 8 bytes of extras",
));
}
let flags = u32::from_be_bytes([extras[0], extras[1], extras[2], extras[3]]);
let expiration = u32::from_be_bytes([extras[4], extras[5], extras[6], extras[7]]);
(flags, expiration)
} else {
(0, 0)
};
let key_start = extras_len;
let key_end = key_start + key_len;
let key = &body[key_start..key_end];
let value_start = header_and_key_len;
let available = buffer.len().saturating_sub(value_start);
let prefix_len = std::cmp::min(available, value_len);
let value_prefix = &buffer[value_start..value_start + prefix_len];
Ok(BinaryParseProgress::NeedValue {
header: BinarySetHeader {
key,
flags,
expiration,
cas: header.cas,
opaque: header.opaque,
opcode: header.opcode,
},
value_len,
value_prefix,
total_consumed: total_len,
})
}
/// Builds the final [`BinaryCommand`] for a streamed storage request once its
/// value has been fully collected.
///
/// `header` is the [`BinarySetHeader`] from a [`BinaryParseProgress::NeedValue`]
/// result and `value` is the complete value. `SetQ` keeps its quiet variant.
/// `AddQ`, `ReplaceQ`, `AppendQ` and `PrependQ` map to the non-quiet variant of
/// the same command. Any other opcode falls back to [`BinaryCommand::Set`].
///
/// Key and value are both borrowed for `'a`, so the returned command cannot
/// outlive either buffer. When the key and value live in different buffers,
/// variance picks the shorter of the two lifetimes at the call site.
pub fn complete_set<'a>(header: &BinarySetHeader<'a>, value: &'a [u8]) -> BinaryCommand<'a> {
    // The signature guarantees `key` lives for `'a`; no unsafe lifetime
    // extension is needed (extending it would allow a dangling key).
    let key = header.key;
    match header.opcode {
        Opcode::Set => BinaryCommand::Set {
            key,
            value,
            flags: header.flags,
            expiration: header.expiration,
            cas: header.cas,
            opaque: header.opaque,
        },
        Opcode::SetQ => BinaryCommand::SetQ {
            key,
            value,
            flags: header.flags,
            expiration: header.expiration,
            cas: header.cas,
            opaque: header.opaque,
        },
        Opcode::Add | Opcode::AddQ => BinaryCommand::Add {
            key,
            value,
            flags: header.flags,
            expiration: header.expiration,
            opaque: header.opaque,
        },
        Opcode::Replace | Opcode::ReplaceQ => BinaryCommand::Replace {
            key,
            value,
            flags: header.flags,
            expiration: header.expiration,
            cas: header.cas,
            opaque: header.opaque,
        },
        Opcode::Append | Opcode::AppendQ => BinaryCommand::Append {
            key,
            value,
            cas: header.cas,
            opaque: header.opaque,
        },
        Opcode::Prepend | Opcode::PrependQ => BinaryCommand::Prepend {
            key,
            value,
            cas: header.cas,
            opaque: header.opaque,
        },
        // Non-storage opcodes are not produced by `parse_streaming`; kept as a
        // Set fallback for compatibility.
        _ => BinaryCommand::Set {
            key,
            value,
            flags: header.flags,
            expiration: header.expiration,
            cas: header.cas,
            opaque: header.opaque,
        },
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A value length comfortably above `BINARY_STREAMING_THRESHOLD`.
    const LARGE_VALUE_LEN: usize = 100 * 1024;

    /// Encodes `flags` and `expiration` as the 8 big-endian extras bytes
    /// carried by Set/Add/Replace requests.
    fn storage_extras(flags: u32, expiration: u32) -> [u8; 8] {
        let mut extras = [0u8; 8];
        extras[..4].copy_from_slice(&flags.to_be_bytes());
        extras[4..].copy_from_slice(&expiration.to_be_bytes());
        extras
    }

    /// Serialises a request whose header declares a `value_len`-byte value but
    /// whose buffer contains only `value_bytes` of it, so both complete and
    /// deliberately truncated frames can be built. The length fields of
    /// `header` are overwritten; opcode, opaque and cas are kept as given.
    fn encode_request(
        mut header: RequestHeader,
        extras: &[u8],
        key: &[u8],
        value_len: usize,
        value_bytes: &[u8],
    ) -> Vec<u8> {
        header.key_length = key.len() as u16;
        header.extras_length = extras.len() as u8;
        header.total_body_length = (extras.len() + key.len() + value_len) as u32;
        let mut buf = vec![0u8; HEADER_SIZE];
        header.encode(&mut buf);
        buf.extend_from_slice(extras);
        buf.extend_from_slice(key);
        buf.extend_from_slice(value_bytes);
        buf
    }

    /// A complete Set request carrying `value` in full.
    fn make_set_request(key: &[u8], value: &[u8], flags: u32, expiration: u32) -> Vec<u8> {
        let extras = storage_extras(flags, expiration);
        encode_request(RequestHeader::new(Opcode::Set), &extras, key, value.len(), value)
    }

    /// A complete request with a key but no extras or value (Get, Delete, ...).
    fn make_key_only_request(opcode: Opcode, key: &[u8]) -> Vec<u8> {
        encode_request(RequestHeader::new(opcode), &[], key, 0, &[])
    }

    /// Header, extras and key of a storage request declaring a `value_len`-byte
    /// value, with none of the value present yet.
    fn storage_prefix(opcode: Opcode, extras: &[u8], key: &[u8], value_len: usize) -> Vec<u8> {
        encode_request(RequestHeader::new(opcode), extras, key, value_len, &[])
    }

    /// A minimal streamed-storage header for exercising `complete_set`.
    fn header_for(opcode: Opcode) -> BinarySetHeader<'static> {
        BinarySetHeader {
            key: b"key",
            flags: 0,
            expiration: 0,
            cas: 0,
            opaque: 0,
            opcode,
        }
    }

    #[test]
    fn test_small_set_complete() {
        let data = make_set_request(b"mykey", b"myvalue", 0, 3600);
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::Complete(cmd, consumed) => {
                if let BinaryCommand::Set {
                    key,
                    value,
                    flags,
                    expiration,
                    ..
                } = cmd
                {
                    assert_eq!(key, b"mykey");
                    assert_eq!(value, b"myvalue");
                    assert_eq!(flags, 0);
                    assert_eq!(expiration, 3600);
                } else {
                    panic!("Expected Set command");
                }
                assert_eq!(consumed, data.len());
            }
            other => panic!("expected Complete, got {:?}", other),
        }
    }

    #[test]
    fn test_large_set_needs_value() {
        let key = b"mykey";
        let extras = storage_extras(42, 3600);
        let mut header = RequestHeader::new(Opcode::Set);
        header.opaque = 123;
        header.cas = 456;
        // Only the first 1000 bytes of the large value have arrived.
        let data = encode_request(header, &extras, key, LARGE_VALUE_LEN, &[b'x'; 1000]);
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::NeedValue {
                header: set_header,
                value_len,
                value_prefix,
                total_consumed,
            } => {
                assert_eq!(set_header.key, b"mykey");
                assert_eq!(set_header.flags, 42);
                assert_eq!(set_header.expiration, 3600);
                assert_eq!(set_header.opaque, 123);
                assert_eq!(set_header.cas, 456);
                assert_eq!(set_header.opcode, Opcode::Set);
                assert_eq!(value_len, LARGE_VALUE_LEN);
                assert_eq!(value_prefix.len(), 1000);
                assert!(value_prefix.iter().all(|&b| b == b'x'));
                assert_eq!(
                    total_consumed,
                    HEADER_SIZE + extras.len() + key.len() + LARGE_VALUE_LEN
                );
            }
            other => panic!("expected NeedValue, got {:?}", other),
        }
    }

    #[test]
    fn test_large_set_prefix_stops_at_frame_end() {
        // A fully buffered large Set followed by a pipelined Get: the value
        // prefix must not spill into the next request.
        let value = vec![b'v'; LARGE_VALUE_LEN];
        let mut data = make_set_request(b"k", &value, 0, 0);
        let frame_len = data.len();
        data.extend_from_slice(&make_key_only_request(Opcode::Get, b"next"));
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::NeedValue {
                value_prefix,
                total_consumed,
                ..
            } => {
                assert_eq!(value_prefix.len(), LARGE_VALUE_LEN);
                assert!(value_prefix.iter().all(|&b| b == b'v'));
                assert_eq!(total_consumed, frame_len);
            }
            other => panic!("expected NeedValue, got {:?}", other),
        }
    }

    #[test]
    fn test_large_set_without_extras_is_error() {
        let data = storage_prefix(Opcode::Set, &[], b"key", LARGE_VALUE_LEN);
        assert!(parse_streaming(&data, BINARY_STREAMING_THRESHOLD).is_err());
    }

    #[test]
    fn test_lengths_exceeding_body_is_error() {
        let mut header = RequestHeader::new(Opcode::Set);
        header.key_length = 16;
        header.extras_length = 8;
        header.total_body_length = 4;
        let mut data = vec![0u8; HEADER_SIZE + 4];
        header.encode(&mut data);
        assert!(parse_streaming(&data, BINARY_STREAMING_THRESHOLD).is_err());
    }

    #[test]
    fn test_get_uses_normal_path() {
        let data = make_key_only_request(Opcode::Get, b"mykey");
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::Complete(cmd, consumed) => {
                assert!(matches!(cmd, BinaryCommand::Get { key: b"mykey", .. }));
                assert_eq!(consumed, data.len());
            }
            other => panic!("expected Complete, got {:?}", other),
        }
    }

    #[test]
    fn test_incomplete_header() {
        let data = [0x80u8, 0x01];
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::Incomplete => {}
            other => panic!("expected Incomplete, got {:?}", other),
        }
    }

    #[test]
    fn test_incomplete_small_value() {
        // 100-byte value declared, only 10 bytes buffered: below the threshold,
        // so the parser must wait for the whole frame.
        let data = encode_request(
            RequestHeader::new(Opcode::Set),
            &[0u8; 8],
            b"mykey",
            100,
            &[0u8; 10],
        );
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::Incomplete => {}
            other => panic!("expected Incomplete, got {:?}", other),
        }
    }

    #[test]
    fn test_threshold_boundary() {
        let extras = [0u8; 8];

        let at_threshold =
            storage_prefix(Opcode::Set, &extras, b"mykey", BINARY_STREAMING_THRESHOLD);
        match parse_streaming(&at_threshold, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::NeedValue { value_len, .. } => {
                assert_eq!(value_len, BINARY_STREAMING_THRESHOLD);
            }
            other => panic!("expected NeedValue at threshold, got {:?}", other),
        }

        let below_threshold =
            storage_prefix(Opcode::Set, &extras, b"mykey", BINARY_STREAMING_THRESHOLD - 1);
        match parse_streaming(&below_threshold, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::Incomplete => {}
            other => panic!(
                "expected Incomplete for sub-threshold without data, got {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_setq_streaming() {
        let data = storage_prefix(Opcode::SetQ, &[0u8; 8], b"key", LARGE_VALUE_LEN);
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::NeedValue { header, .. } => {
                assert_eq!(header.opcode, Opcode::SetQ);
            }
            other => panic!("expected NeedValue, got {:?}", other),
        }
    }

    #[test]
    fn test_add_streaming() {
        let data = storage_prefix(Opcode::Add, &[0u8; 8], b"key", LARGE_VALUE_LEN);
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::NeedValue { header, .. } => {
                assert_eq!(header.opcode, Opcode::Add);
            }
            other => panic!("expected NeedValue, got {:?}", other),
        }
    }

    #[test]
    fn test_append_streaming() {
        let data = storage_prefix(Opcode::Append, &[], b"key", LARGE_VALUE_LEN);
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::NeedValue { header, .. } => {
                assert_eq!(header.opcode, Opcode::Append);
                assert_eq!(header.flags, 0);
                assert_eq!(header.expiration, 0);
            }
            other => panic!("expected NeedValue, got {:?}", other),
        }
    }

    #[test]
    fn test_prepend_streaming() {
        let data = storage_prefix(Opcode::Prepend, &[], b"key", LARGE_VALUE_LEN);
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::NeedValue { header, .. } => {
                assert_eq!(header.opcode, Opcode::Prepend);
                assert_eq!(header.key, b"key");
            }
            other => panic!("expected NeedValue, got {:?}", other),
        }
    }

    #[test]
    fn test_complete_set_helper() {
        let header = BinarySetHeader {
            key: b"mykey",
            flags: 42,
            expiration: 3600,
            cas: 123,
            opaque: 456,
            opcode: Opcode::Set,
        };
        match complete_set(&header, b"myvalue") {
            BinaryCommand::Set {
                key,
                value,
                flags,
                expiration,
                cas,
                opaque,
            } => {
                assert_eq!(key, b"mykey");
                assert_eq!(value, b"myvalue");
                assert_eq!(flags, 42);
                assert_eq!(expiration, 3600);
                assert_eq!(cas, 123);
                assert_eq!(opaque, 456);
            }
            other => panic!("expected Set command, got {:?}", other),
        }
    }

    #[test]
    fn test_complete_set_borrows_owned_buffers() {
        // Key and value live in separate, non-'static buffers.
        let key_buf = b"owned-key".to_vec();
        let value_buf = vec![7u8; 16];
        let header = BinarySetHeader {
            key: &key_buf,
            flags: 0,
            expiration: 0,
            cas: 0,
            opaque: 0,
            opcode: Opcode::Set,
        };
        match complete_set(&header, &value_buf) {
            BinaryCommand::Set { key, value, .. } => {
                assert_eq!(key, &key_buf[..]);
                assert_eq!(value, &value_buf[..]);
            }
            other => panic!("expected Set command, got {:?}", other),
        }
    }

    #[test]
    fn test_complete_set_helper_setq() {
        let cmd = complete_set(&header_for(Opcode::SetQ), b"val");
        assert!(matches!(cmd, BinaryCommand::SetQ { .. }));
    }

    #[test]
    fn test_complete_set_helper_add() {
        let cmd = complete_set(&header_for(Opcode::Add), b"val");
        assert!(matches!(cmd, BinaryCommand::Add { .. }));
    }

    #[test]
    fn test_complete_set_helper_replace() {
        let cmd = complete_set(&header_for(Opcode::Replace), b"val");
        assert!(matches!(cmd, BinaryCommand::Replace { .. }));
    }

    #[test]
    fn test_complete_set_helper_append() {
        let cmd = complete_set(&header_for(Opcode::Append), b"val");
        assert!(matches!(cmd, BinaryCommand::Append { .. }));
    }

    #[test]
    fn test_complete_set_helper_prepend() {
        let cmd = complete_set(&header_for(Opcode::Prepend), b"val");
        assert!(matches!(cmd, BinaryCommand::Prepend { .. }));
    }

    #[test]
    fn test_complete_set_helper_quiet_variants_fold() {
        assert!(matches!(
            complete_set(&header_for(Opcode::AddQ), b"val"),
            BinaryCommand::Add { .. }
        ));
        assert!(matches!(
            complete_set(&header_for(Opcode::ReplaceQ), b"val"),
            BinaryCommand::Replace { .. }
        ));
        assert!(matches!(
            complete_set(&header_for(Opcode::AppendQ), b"val"),
            BinaryCommand::Append { .. }
        ));
        assert!(matches!(
            complete_set(&header_for(Opcode::PrependQ), b"val"),
            BinaryCommand::Prepend { .. }
        ));
    }

    #[test]
    fn test_noop_command() {
        let data = encode_request(RequestHeader::new(Opcode::Noop), &[], &[], 0, &[]);
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::Complete(cmd, _) => {
                assert!(matches!(cmd, BinaryCommand::Noop { .. }));
            }
            other => panic!("expected Complete, got {:?}", other),
        }
    }

    #[test]
    fn test_delete_command() {
        let data = make_key_only_request(Opcode::Delete, b"mykey");
        match parse_streaming(&data, BINARY_STREAMING_THRESHOLD).unwrap() {
            BinaryParseProgress::Complete(cmd, consumed) => {
                assert!(matches!(cmd, BinaryCommand::Delete { key: b"mykey", .. }));
                assert_eq!(consumed, data.len());
            }
            other => panic!("expected Complete, got {:?}", other),
        }
    }

    #[test]
    fn test_header_traits() {
        let header = BinarySetHeader {
            key: b"test",
            flags: 0,
            expiration: 0,
            cas: 0,
            opaque: 0,
            opcode: Opcode::Set,
        };
        let header2 = header.clone();
        assert_eq!(header, header2);
        let debug_str = format!("{:?}", header);
        assert!(debug_str.contains("BinarySetHeader"));
    }
}