use bytes::{Bytes, BytesMut};
use crate::codec::DecodeOutcome;
use crate::error::Error;
use crate::response::Response;
use crate::types::{Op, Protocol, ReplyMode, Request, RequestMeta, StatLine};
/// Size limits applied while decoding binary-protocol frames.
#[derive(Debug, Clone, Copy)]
pub struct BinaryLimits {
    /// Largest body length, in bytes, that a frame header may declare.
    ///
    /// The decoder compares this against the header's body-length field, so
    /// the fixed 24-byte header itself is not counted.
    pub max_frame_len: usize,
}

impl Default for BinaryLimits {
    /// Returns limits that allow bodies of up to 2 MiB (2 * 1024 * 1024 bytes).
    fn default() -> Self {
        BinaryLimits {
            max_frame_len: 2 * 1024 * 1024,
        }
    }
}
#[derive(Debug, Clone, Default)]
pub struct BinaryDecoder;
impl BinaryDecoder {
pub fn new() -> Self {
Self
}
pub fn decode(&mut self, buf: &mut BytesMut, limits: BinaryLimits) -> Option<DecodeOutcome> {
if buf.len() < 24 {
return None;
}
let header = &buf[..24];
let magic = header[0];
let opcode = header[1];
let key_len = u16::from_be_bytes([header[2], header[3]]) as usize;
let extras_len = header[4] as usize;
let body_len = u32::from_be_bytes([header[8], header[9], header[10], header[11]]) as usize;
let opaque = u32::from_be_bytes([header[12], header[13], header[14], header[15]]);
let cas = u64::from_be_bytes([
header[16], header[17], header[18], header[19], header[20], header[21], header[22],
header[23],
]);
let mut meta = RequestMeta {
protocol: Protocol::Binary,
reply: ReplyMode::Always,
opaque: Some(opaque),
return_key: false,
opcode,
};
if magic != 0x80 {
let err = Error::server("invalid magic");
return Some(DecodeOutcome::Response(meta, Response::Error(err)));
}
if body_len > limits.max_frame_len {
let err = Error::server("frame too large");
return Some(DecodeOutcome::Response(meta, Response::Error(err)));
}
if buf.len() < 24 + body_len {
return None;
}
let frame = buf.split_to(24 + body_len).freeze();
let body = frame.slice(24..);
if extras_len + key_len > body_len {
let err = Error::client("invalid lengths");
return Some(DecodeOutcome::Response(meta, Response::Error(err)));
}
let extras = body.slice(0..extras_len);
let key = body.slice(extras_len..extras_len + key_len);
let value = body.slice(extras_len + key_len..body_len);
let (op, quiet, return_key) = opcode_to_op(opcode);
meta.reply = if quiet {
ReplyMode::QuietBuffered
} else {
ReplyMode::Always
};
meta.return_key = return_key;
let mut req = Request::new(op);
if key_len > 0 {
if key_len > 250 {
let err = Error::client("key too long");
return Some(DecodeOutcome::Response(meta, Response::Error(err)));
}
if !is_valid_key(&key) {
let err = Error::client("invalid key");
return Some(DecodeOutcome::Response(meta, Response::Error(err)));
}
req.key = Some(key.clone());
}
if cas != 0 {
req.cas = Some(cas);
}
let result = match op {
Op::Get => {
if key_len == 0 || extras_len != 0 || !value.is_empty() {
Err("invalid get")
} else {
Ok(())
}
}
Op::Set | Op::Add | Op::Replace => {
if key_len == 0 || extras_len != 8 {
Err("invalid storage")
} else {
parse_flags_exptime(&extras, &mut req);
req.value = Some(value);
Ok(())
}
}
Op::Append | Op::Prepend => {
if key_len == 0 || (extras_len != 0 && extras_len != 8) {
Err("invalid append")
} else {
if extras_len == 8 {
parse_flags_exptime(&extras, &mut req);
}
req.value = Some(value);
Ok(())
}
}
Op::Delete => {
if key_len == 0 || extras_len != 0 || !value.is_empty() {
Err("invalid delete")
} else {
Ok(())
}
}
Op::Flush => {
if key_len != 0 || (extras_len != 0 && extras_len != 4) || !value.is_empty() {
Err("invalid flush")
} else {
if extras_len == 4 {
req.exptime =
Some(
u32::from_be_bytes([extras[0], extras[1], extras[2], extras[3]])
as i64,
);
}
Ok(())
}
}
Op::Incr | Op::Decr => {
if key_len == 0 || extras_len != 20 || !value.is_empty() {
Err("invalid incr")
} else {
parse_delta(&extras, &mut req);
Ok(())
}
}
Op::Touch => {
if key_len == 0 || extras_len != 4 || !value.is_empty() {
Err("invalid touch")
} else {
req.exptime =
Some(
u32::from_be_bytes([extras[0], extras[1], extras[2], extras[3]]) as i64,
);
Ok(())
}
}
Op::Stats => {
if extras_len != 0 {
Err("invalid stats")
} else {
if key_len > 0 {
req.key = Some(key);
}
Ok(())
}
}
Op::SaslListMechs => {
if extras_len != 0 || key_len != 0 || !value.is_empty() {
Err("invalid sasl list")
} else {
Ok(())
}
}
Op::SaslAuth | Op::SaslStep => {
if extras_len != 0 || key_len == 0 || value.is_empty() {
Err("invalid sasl auth")
} else {
req.value = Some(value);
Ok(())
}
}
Op::Version | Op::Noop | Op::Quit => {
if extras_len != 0 || key_len != 0 || !value.is_empty() {
Err("invalid command")
} else {
Ok(())
}
}
Op::Gets | Op::Gat | Op::Gats | Op::Cas => Ok(()),
Op::Unknown
| Op::MetaGet
| Op::MetaSet
| Op::MetaDelete
| Op::MetaArithmetic
| Op::MetaDebug
| Op::MetaNoop => Ok(()),
};
if result.is_err() {
let err = Error::client("invalid arguments");
return Some(DecodeOutcome::Response(meta, Response::Error(err)));
}
Some(DecodeOutcome::Request(req, meta))
}
}
/// Copies the flags and expiration out of a storage command's extras.
///
/// Bytes 0..4 are read as the flags and bytes 4..8 as the expiration, both as
/// big-endian `u32`; the expiration is widened to `i64` before it is stored.
/// When `extras` is shorter than 8 bytes the request is left unmodified.
fn parse_flags_exptime(extras: &Bytes, req: &mut Request) {
    let raw: &[u8] = extras.as_ref();
    if raw.len() >= 8 {
        let (flag_bytes, exptime_bytes) = raw[..8].split_at(4);
        let mut word = [0u8; 4];
        word.copy_from_slice(flag_bytes);
        req.flags = Some(u32::from_be_bytes(word));
        word.copy_from_slice(exptime_bytes);
        req.exptime = Some(i64::from(u32::from_be_bytes(word)));
    }
}
/// Copies the delta, initial value and expiration out of an incr/decr extras section.
///
/// Bytes 0..8 are read as the delta and bytes 8..16 as the initial value (both
/// big-endian `u64`); bytes 16..20 are the expiration (big-endian `u32`,
/// widened to `i64`). When `extras` is shorter than 20 bytes the request is
/// left unmodified.
fn parse_delta(extras: &Bytes, req: &mut Request) {
    let raw: &[u8] = extras.as_ref();
    if raw.len() >= 20 {
        let mut wide = [0u8; 8];
        wide.copy_from_slice(&raw[..8]);
        req.delta = Some(u64::from_be_bytes(wide));
        wide.copy_from_slice(&raw[8..16]);
        req.initial = Some(u64::from_be_bytes(wide));

        let mut narrow = [0u8; 4];
        narrow.copy_from_slice(&raw[16..20]);
        req.exptime = Some(i64::from(u32::from_be_bytes(narrow)));
    }
}
/// Maps a binary-protocol opcode byte to `(operation, quiet, return_key)`.
///
/// * `quiet` is `true` for the "Q" variants of a command (for example 0x09,
///   quiet get), which the decoder turns into `ReplyMode::QuietBuffered`.
/// * `return_key` is `true` for the "K" variants of get (0x0c and 0x0d), whose
///   responses echo the key.
///
/// Opcodes that are not listed map to `(Op::Unknown, false, false)`.
///
/// The touch opcode (0x1c) is mapped to `Op::Touch`; without it the decoder's
/// touch validation could never be reached. The get-and-touch opcodes
/// (0x1d, 0x1e) are intentionally still unmapped: the decoder does not read
/// an expiration from their extras.
fn opcode_to_op(opcode: u8) -> (Op, bool, bool) {
    match opcode {
        // Retrieval: plain, quiet, with key, quiet with key.
        0x00 => (Op::Get, false, false),
        0x09 => (Op::Get, true, false),
        0x0c => (Op::Get, false, true),
        0x0d => (Op::Get, true, true),
        // Storage.
        0x01 => (Op::Set, false, false),
        0x11 => (Op::Set, true, false),
        0x02 => (Op::Add, false, false),
        0x12 => (Op::Add, true, false),
        0x03 => (Op::Replace, false, false),
        0x13 => (Op::Replace, true, false),
        0x0e => (Op::Append, false, false),
        0x19 => (Op::Append, true, false),
        0x0f => (Op::Prepend, false, false),
        0x1a => (Op::Prepend, true, false),
        // Removal and arithmetic.
        0x04 => (Op::Delete, false, false),
        0x14 => (Op::Delete, true, false),
        0x05 => (Op::Incr, false, false),
        0x15 => (Op::Incr, true, false),
        0x06 => (Op::Decr, false, false),
        0x16 => (Op::Decr, true, false),
        // Expiration update.
        0x1c => (Op::Touch, false, false),
        // Connection and server-level commands.
        0x07 => (Op::Quit, false, false),
        0x17 => (Op::Quit, true, false),
        0x08 => (Op::Flush, false, false),
        0x18 => (Op::Flush, true, false),
        0x0a => (Op::Noop, false, false),
        0x0b => (Op::Version, false, false),
        0x10 => (Op::Stats, false, false),
        // SASL authentication.
        0x20 => (Op::SaslListMechs, false, false),
        0x21 => (Op::SaslAuth, false, false),
        0x22 => (Op::SaslStep, false, false),
        _ => (Op::Unknown, false, false),
    }
}
/// Reports whether `key` is acceptable as a request key.
///
/// A key is accepted when it is 1 to 250 bytes long and every byte is greater
/// than the ASCII space (0x20) and different from DEL (0x7f). Bytes at or
/// above 0x80 are accepted.
fn is_valid_key(key: &Bytes) -> bool {
    let bytes: &[u8] = key.as_ref();
    let length_ok = (1..=250).contains(&bytes.len());
    length_ok && bytes.iter().all(|&b| b > b' ' && b != 0x7f)
}
// Response status codes written into bytes 6..8 of a response header,
// listed in ascending numeric order.

/// The request succeeded.
pub const STATUS_SUCCESS: u16 = 0x0000;
/// The requested key does not exist.
pub const STATUS_KEY_NOT_FOUND: u16 = 0x0001;
/// The key exists (used for `Response::Exists`).
pub const STATUS_KEY_EXISTS: u16 = 0x0002;
/// The request was malformed; used for client-side errors.
pub const STATUS_INVALID_ARGUMENTS: u16 = 0x0004;
/// The item was not stored (used for `Response::NotStored`).
pub const STATUS_ITEM_NOT_STORED: u16 = 0x0005;
/// Authentication failed.
pub const STATUS_AUTH_ERROR: u16 = 0x0020;
/// The opcode is not a known command.
pub const STATUS_UNKNOWN_COMMAND: u16 = 0x0081;
/// Server-side failure; also used for responses the binary encoder cannot represent.
pub const STATUS_INTERNAL_ERROR: u16 = 0x0084;
pub fn encode_binary_response(
meta: RequestMeta,
response: &Response,
out: &mut BytesMut,
return_key: bool,
) -> (u16, usize) {
let opcode = meta.opcode;
let opaque = meta.opaque.unwrap_or(0);
match response {
Response::Stored
| Response::Deleted
| Response::Touched
| Response::Noop
| Response::Ok => {
encode_header(
out,
HeaderFields::new(opcode, 0, 0, STATUS_SUCCESS, 0, opaque, 0),
);
(STATUS_SUCCESS, 24)
}
Response::NotStored => {
encode_header(
out,
HeaderFields::new(opcode, 0, 0, STATUS_ITEM_NOT_STORED, 0, opaque, 0),
);
(STATUS_ITEM_NOT_STORED, 24)
}
Response::Exists => {
encode_header(
out,
HeaderFields::new(opcode, 0, 0, STATUS_KEY_EXISTS, 0, opaque, 0),
);
(STATUS_KEY_EXISTS, 24)
}
Response::NotFound => {
encode_header(
out,
HeaderFields::new(opcode, 0, 0, STATUS_KEY_NOT_FOUND, 0, opaque, 0),
);
(STATUS_KEY_NOT_FOUND, 24)
}
Response::Numeric(value) => {
let extras_len = 0u8;
let key_len = 0u16;
let body_len = 8u32;
encode_header(
out,
HeaderFields::new(
opcode,
extras_len,
key_len,
STATUS_SUCCESS,
body_len,
opaque,
0,
),
);
out.extend_from_slice(&value.to_be_bytes());
(STATUS_SUCCESS, 24 + 8)
}
Response::Value(entry) => {
let extras_len = 4u8;
let key = if return_key { entry.key.as_ref() } else { b"" };
let key_len = key.len() as u16;
let body_len = extras_len as u32 + key_len as u32 + entry.value.len() as u32;
encode_header(
out,
HeaderFields::new(
opcode,
extras_len,
key_len,
STATUS_SUCCESS,
body_len,
opaque,
entry.cas.unwrap_or(0),
),
);
out.extend_from_slice(&entry.flags.to_be_bytes());
if return_key {
out.extend_from_slice(key);
}
out.extend_from_slice(entry.value.as_ref());
(STATUS_SUCCESS, 24 + body_len as usize)
}
Response::Values(entries) => {
if let Some(entry) = entries.first() {
encode_binary_response(meta, &Response::Value(entry.clone()), out, return_key)
} else {
encode_header(
out,
HeaderFields::new(opcode, 0, 0, STATUS_KEY_NOT_FOUND, 0, opaque, 0),
);
(STATUS_KEY_NOT_FOUND, 24)
}
}
Response::Stats(lines) => {
let mut total = 0usize;
for line in lines {
total += encode_stat_line(meta, line, out);
}
total += encode_header(
out,
HeaderFields::new(opcode, 0, 0, STATUS_SUCCESS, 0, opaque, 0),
);
(STATUS_SUCCESS, total)
}
Response::Version(version) => {
let body_len = version.len() as u32;
encode_header(
out,
HeaderFields::new(opcode, 0, 0, STATUS_SUCCESS, body_len, opaque, 0),
);
out.extend_from_slice(version.as_ref());
(STATUS_SUCCESS, 24 + version.len())
}
Response::Error(err) => {
let status = match err.kind {
crate::error::ErrorKind::UnknownCommand => STATUS_UNKNOWN_COMMAND,
crate::error::ErrorKind::Client => STATUS_INVALID_ARGUMENTS,
crate::error::ErrorKind::Server => STATUS_INTERNAL_ERROR,
crate::error::ErrorKind::Auth => STATUS_AUTH_ERROR,
};
let body_len = err.message.len() as u32;
encode_header(
out,
HeaderFields::new(opcode, 0, 0, status, body_len, opaque, 0),
);
out.extend_from_slice(err.message.as_ref());
(status, 24 + err.message.len())
}
Response::Meta(_) | Response::ValuesStream(_) | Response::StatsStream(_) => {
encode_header(
out,
HeaderFields::new(opcode, 0, 0, STATUS_INTERNAL_ERROR, 0, opaque, 0),
);
(STATUS_INTERNAL_ERROR, 24)
}
}
}
/// The variable fields of a response header, as consumed by `encode_header`.
struct HeaderFields {
    /// Opcode echoed back to the client.
    opcode: u8,
    /// Length of the extras section at the start of the body.
    extras_len: u8,
    /// Length of the key that follows the extras.
    key_len: u16,
    /// Response status code.
    status: u16,
    /// Total body length (extras + key + value).
    body_len: u32,
    /// Opaque value echoed back to the client.
    opaque: u32,
    /// CAS value placed in the header.
    cas: u64,
}

impl HeaderFields {
    /// Bundles the header fields; every argument is stored unchanged in the
    /// field of the same name.
    fn new(
        opcode: u8,
        extras_len: u8,
        key_len: u16,
        status: u16,
        body_len: u32,
        opaque: u32,
        cas: u64,
    ) -> Self {
        HeaderFields {
            opcode,
            extras_len,
            key_len,
            status,
            body_len,
            opaque,
            cas,
        }
    }
}
fn encode_header(out: &mut BytesMut, header: HeaderFields) -> usize {
let HeaderFields {
opcode,
extras_len,
key_len,
status,
body_len,
opaque,
cas,
} = header;
out.extend_from_slice(&[
0x81,
opcode,
(key_len >> 8) as u8,
(key_len & 0xff) as u8,
extras_len,
0x00,
(status >> 8) as u8,
(status & 0xff) as u8,
((body_len >> 24) & 0xff) as u8,
((body_len >> 16) & 0xff) as u8,
((body_len >> 8) & 0xff) as u8,
(body_len & 0xff) as u8,
((opaque >> 24) & 0xff) as u8,
((opaque >> 16) & 0xff) as u8,
((opaque >> 8) & 0xff) as u8,
(opaque & 0xff) as u8,
((cas >> 56) & 0xff) as u8,
((cas >> 48) & 0xff) as u8,
((cas >> 40) & 0xff) as u8,
((cas >> 32) & 0xff) as u8,
((cas >> 24) & 0xff) as u8,
((cas >> 16) & 0xff) as u8,
((cas >> 8) & 0xff) as u8,
(cas & 0xff) as u8,
]);
24
}
/// Appends one stat packet (header, stat name as key, stat value as value) to
/// `out` and returns 24 plus the declared body length.
///
/// The packet has no extras, a success status and a zero CAS; opcode and
/// opaque come from `meta` (a missing opaque is written as 0). The key length
/// is narrowed to `u16` and the value length to `u32` with `as` casts.
fn encode_stat_line(meta: RequestMeta, line: &StatLine, out: &mut BytesMut) -> usize {
    let key_len = line.key.len() as u16;
    let body_len = u32::from(key_len) + line.value.len() as u32;
    let header = HeaderFields::new(
        meta.opcode,
        0,
        key_len,
        STATUS_SUCCESS,
        body_len,
        meta.opaque.unwrap_or(0),
        0,
    );
    let header_len = encode_header(out, header);
    out.extend_from_slice(line.key.as_ref());
    out.extend_from_slice(line.value.as_ref());
    header_len + body_len as usize
}
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// Builds a 24-byte request header with magic `0x80`. Fields that the
    /// tests never vary (bytes 5..8 and the CAS) are left as zero.
    fn request_header(
        opcode: u8,
        key_len: usize,
        extras_len: u8,
        body_len: usize,
        opaque: u32,
    ) -> [u8; 24] {
        let mut header = [0u8; 24];
        header[0] = 0x80;
        header[1] = opcode;
        header[2..4].copy_from_slice(&(key_len as u16).to_be_bytes());
        header[4] = extras_len;
        header[8..12].copy_from_slice(&(body_len as u32).to_be_bytes());
        header[12..16].copy_from_slice(&opaque.to_be_bytes());
        header
    }

    /// Concatenates a header and body sections into a decode buffer.
    fn frame(header: [u8; 24], sections: &[&[u8]]) -> BytesMut {
        let body_len: usize = sections.iter().map(|section| section.len()).sum();
        let mut buf = BytesMut::with_capacity(header.len() + body_len);
        buf.extend_from_slice(&header);
        for section in sections {
            buf.extend_from_slice(section);
        }
        buf
    }

    /// Decodes `buf` with default limits and unwraps the request outcome.
    fn expect_request(buf: &mut BytesMut) -> (Request, RequestMeta) {
        let mut decoder = BinaryDecoder::new();
        match decoder.decode(buf, BinaryLimits::default()) {
            Some(DecodeOutcome::Request(req, meta)) => (req, meta),
            _ => panic!("unexpected outcome"),
        }
    }

    #[test]
    fn decode_get_request() {
        let key = b"foo";
        let header = request_header(0x00, key.len(), 0, key.len(), 0xdead_beef);
        let mut buf = frame(header, &[key]);

        let (req, meta) = expect_request(&mut buf);

        assert_eq!(req.op, Op::Get);
        assert_eq!(req.key.unwrap(), Bytes::from_static(b"foo"));
        assert_eq!(meta.protocol, Protocol::Binary);
    }

    #[test]
    fn decode_invalid_magic() {
        // 0x81 is the response magic; a request must start with 0x80.
        let mut header = [0u8; 24];
        header[0] = 0x81;
        let mut buf = frame(header, &[]);

        let mut decoder = BinaryDecoder::new();
        let outcome = decoder.decode(&mut buf, BinaryLimits::default());

        match outcome {
            Some(DecodeOutcome::Response(_, Response::Error(_))) => {}
            _ => panic!("unexpected outcome"),
        }
    }

    #[test]
    fn decode_flush_with_exptime() {
        let exptime = [0x00, 0x00, 0x00, 0x0a];
        let header = request_header(0x08, 0, 4, exptime.len(), 0);
        let mut buf = frame(header, &[&exptime]);

        let (req, _meta) = expect_request(&mut buf);

        assert_eq!(req.op, Op::Flush);
        assert_eq!(req.exptime, Some(10));
    }

    #[test]
    fn decode_sasl_auth() {
        let mechanism = b"PLAIN";
        let payload = b"\0user\0pass";
        let body_len = mechanism.len() + payload.len();
        let header = request_header(0x21, mechanism.len(), 0, body_len, 0);
        let mut buf = frame(header, &[mechanism, payload]);

        let (req, _meta) = expect_request(&mut buf);

        assert_eq!(req.op, Op::SaslAuth);
        assert_eq!(req.key.unwrap(), Bytes::from_static(b"PLAIN"));
        assert_eq!(req.value.unwrap(), Bytes::from_static(b"\0user\0pass"));
    }
}