use crate::error::ProtocolError;
use crate::request::{find_crlf, parse_int};
/// A single decoded RESP2/RESP3 reply, as produced by [`parse_reply`].
///
/// Textual payloads are kept as raw bytes (`Vec<u8>`); callers decide how to
/// decode them. RESP3 attribute frames (`|`) have no variant of their own:
/// [`parse_reply`] discards them and returns the reply they annotate.
#[derive(Debug, Clone, PartialEq)]
pub enum Reply {
    /// Simple string (`+<text>\r\n`), without the tag byte and CRLF.
    Simple(Vec<u8>),
    /// Simple error (`-<text>\r\n`), without the tag byte and CRLF.
    Error(Vec<u8>),
    /// Integer (`:<n>\r\n`).
    Int(i64),
    /// Bulk string (`$<len>\r\n<bytes>\r\n`).
    Bulk(Vec<u8>),
    /// RESP2-style null: a bulk string or array whose length is negative
    /// (e.g. `$-1\r\n` or `*-1\r\n`). Distinct from the RESP3 [`Reply::Null`].
    Nil,
    /// Array (`*<count>\r\n` followed by `count` replies).
    Array(Vec<Reply>),
    /// RESP3 map (`%<count>\r\n` followed by `count` key/value pairs), kept
    /// as pairs in wire order; duplicate keys are not merged.
    Map(Vec<(Reply, Reply)>),
    /// RESP3 set (`~<count>\r\n` followed by `count` replies), kept in wire
    /// order; duplicates are not removed.
    Set(Vec<Reply>),
    /// RESP3 double (`,<value>\r\n`); may be infinite or NaN.
    Double(f64),
    /// RESP3 boolean (`#t\r\n` or `#f\r\n`).
    Boolean(bool),
    /// RESP3 verbatim string (`=<len>\r\n<fmt>:<data>\r\n`).
    Verbatim {
        /// Three-byte format tag, e.g. `txt`.
        fmt: [u8; 3],
        /// Payload following the `<fmt>:` prefix.
        data: Vec<u8>,
    },
    /// RESP3 big number (`(<digits>\r\n`): the digits as received, not
    /// validated or converted.
    BigNumber(Vec<u8>),
    /// RESP3 null (`_\r\n`).
    Null,
    /// RESP3 out-of-band push frame (`><count>\r\n` followed by `count`
    /// replies).
    Push(Vec<Reply>),
    /// RESP3 blob error (`!<len>\r\n<bytes>\r\n`).
    BlobError(Vec<u8>),
}
/// Parses one RESP2/RESP3 reply from the start of `buf`.
///
/// Returns `Ok(Some((reply, consumed)))` when `buf` begins with a complete
/// reply, where `consumed` is the number of bytes that reply occupies (any
/// trailing bytes are left untouched). Returns `Ok(None)` when more input is
/// needed. RESP3 attribute frames (`|`) are parsed and discarded; the reply
/// they annotate is returned in their place, with `consumed` covering both.
///
/// # Errors
///
/// Returns `ProtocolError::Malformed` in three cases:
/// - the type byte is unknown;
/// - a payload cannot be parsed;
/// - replies are nested more than 128 levels deep, counting the outermost
///   reply.
pub fn parse_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
    // The aggregate parsers (`*`, `%`, `~`, `>`, `|`) call back into this
    // function once per nesting level. Untrusted input such as "*1\r\n"
    // repeated many times would otherwise recurse until the thread's stack
    // overflows, which aborts the process. A per-thread depth counter bounds
    // that without changing this signature or the sibling parsers.
    const MAX_DEPTH: usize = 128;
    thread_local! {
        static DEPTH: std::cell::Cell<usize> = const { std::cell::Cell::new(0) };
    }
    // Restores the counter on every exit path: return, `?`, or unwinding.
    struct DepthGuard;
    impl Drop for DepthGuard {
        fn drop(&mut self) {
            DEPTH.with(|depth| depth.set(depth.get() - 1));
        }
    }

    let depth = DEPTH.with(|depth| {
        let next = depth.get() + 1;
        depth.set(next);
        next
    });
    // Created immediately after the increment so the two always pair up.
    let _guard = DepthGuard;
    if depth > MAX_DEPTH {
        return Err(ProtocolError::Malformed("reply nesting too deep"));
    }

    let Some(&tag) = buf.first() else {
        return Ok(None);
    };
    match tag {
        b'+' => Ok(reply_line(buf).map(|(b, used)| (Reply::Simple(b.to_vec()), used))),
        b'-' => Ok(reply_line(buf).map(|(b, used)| (Reply::Error(b.to_vec()), used))),
        b':' => match reply_line(buf) {
            None => Ok(None),
            Some((b, used)) => {
                let n = parse_int(b).ok_or(ProtocolError::Malformed("bad integer reply"))?;
                Ok(Some((Reply::Int(n), used)))
            }
        },
        b'$' => parse_bulk_reply(buf),
        b'*' => parse_array_reply(buf, false),
        b'%' => parse_map_reply(buf),
        b'~' => parse_set_reply(buf),
        b',' => parse_double_reply(buf),
        b'#' => parse_boolean_reply(buf),
        b'=' => parse_verbatim_reply(buf),
        b'(' => match reply_line(buf) {
            None => Ok(None),
            Some((b, used)) => Ok(Some((Reply::BigNumber(b.to_vec()), used))),
        },
        b'_' => parse_null_reply(buf),
        b'>' => parse_array_reply(buf, true),
        b'!' => parse_blob_error_reply(buf),
        b'|' => parse_attributed_reply(buf),
        _ => Err(ProtocolError::Malformed("unknown reply type")),
    }
}
/// Splits a single-line reply (`<tag><payload>\r\n`) into the payload and the
/// total number of bytes the line occupies.
///
/// The payload excludes both the tag byte and the CRLF. Returns `None` when
/// `find_crlf` locates no line terminator after the tag, i.e. the line has not
/// fully arrived yet.
fn reply_line(buf: &[u8]) -> Option<(&[u8], usize)> {
    let line_end = find_crlf(buf, 1)?;
    let payload = &buf[1..line_end];
    Some((payload, line_end + 2))
}
/// Parses a RESP bulk string, `$<len>\r\n<len bytes>\r\n`.
///
/// Returns `Ok(None)` while `buf` does not yet hold the whole frame. A
/// negative length yields `Reply::Nil` (RESP2's `$-1\r\n` null bulk string).
///
/// # Errors
///
/// Returns `ProtocolError::Malformed` in three cases:
/// - the length is not an integer;
/// - the length is too large to address;
/// - the payload is not followed by `\r\n`.
fn parse_bulk_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
    let Some(hdr_end) = find_crlf(buf, 1) else {
        return Ok(None);
    };
    let len = parse_int(&buf[1..hdr_end]).ok_or(ProtocolError::Malformed("bad bulk length"))?;
    if len < 0 {
        return Ok(Some((Reply::Nil, hdr_end + 2)));
    }
    let data_start = hdr_end + 2;
    // `len` comes off the wire. Convert and add with checks so an absurd value
    // is an error, not a silent wrap or truncation (`as usize` truncates on
    // 32-bit targets).
    let frame_end = usize::try_from(len)
        .ok()
        .and_then(|n| data_start.checked_add(n))
        .and_then(|end| end.checked_add(2))
        .ok_or(ProtocolError::Malformed("bulk length too large"))?;
    if buf.len() < frame_end {
        return Ok(None);
    }
    let data_end = frame_end - 2;
    // Check the terminator instead of skipping two bytes blindly. Otherwise a
    // length that disagrees with the payload desynchronises the stream
    // without any error.
    if &buf[data_end..frame_end] != b"\r\n" {
        return Err(ProtocolError::Malformed("bulk string missing trailing CRLF"));
    }
    Ok(Some((Reply::Bulk(buf[data_start..data_end].to_vec()), frame_end)))
}
/// Parses a RESP array (`*<count>\r\n` followed by `count` replies). When
/// `push` is true, it parses a RESP3 push frame with the same layout instead.
/// The tag byte itself is not inspected.
///
/// A negative count is the null array and yields `Reply::Nil`; push frames
/// reject it with an error. Returns `Ok(None)` if the header or any element is
/// still incomplete.
fn parse_array_reply(buf: &[u8], push: bool) -> Result<Option<(Reply, usize)>, ProtocolError> {
    let Some(hdr_end) = find_crlf(buf, 1) else {
        return Ok(None);
    };
    let count = parse_int(&buf[1..hdr_end]).ok_or(ProtocolError::Malformed("bad array length"))?;
    let body_start = hdr_end + 2;
    if count < 0 {
        return if push {
            Err(ProtocolError::Malformed("push frame cannot be null"))
        } else {
            Ok(Some((Reply::Nil, body_start)))
        };
    }
    // Bound the pre-allocation by the bytes actually buffered, so an inflated
    // count cannot force a huge up-front allocation.
    let hint = (count as usize).min(buf.len().saturating_sub(body_start));
    let mut elements = Vec::with_capacity(hint);
    let mut cursor = body_start;
    let mut remaining = count;
    while remaining > 0 {
        let Some((element, consumed)) = parse_reply(&buf[cursor..])? else {
            return Ok(None);
        };
        elements.push(element);
        cursor += consumed;
        remaining -= 1;
    }
    let reply = if push { Reply::Push(elements) } else { Reply::Array(elements) };
    Ok(Some((reply, cursor)))
}
/// Parses a RESP3 map (`%<count>\r\n` followed by `count` key/value pairs).
///
/// The tag byte is not inspected, so this also reads attribute frames (`|`),
/// which share the layout. Pairs are kept in wire order. Returns `Ok(None)`
/// if the header or any key/value is still incomplete; a negative count is an
/// error.
fn parse_map_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
    let hdr_end = match find_crlf(buf, 1) {
        Some(end) => end,
        None => return Ok(None),
    };
    let count = parse_int(&buf[1..hdr_end]).ok_or(ProtocolError::Malformed("bad map length"))?;
    if count < 0 {
        return Err(ProtocolError::Malformed("map length cannot be negative"));
    }
    let mut cursor = hdr_end + 2;
    // Bound the pre-allocation by the buffered bytes (halved, as each entry
    // holds two replies) so an inflated count cannot force a huge allocation.
    let hint = (count as usize).min(buf.len().saturating_sub(cursor) / 2);
    let mut pairs = Vec::with_capacity(hint);
    for _ in 0..count {
        let key = match parse_reply(&buf[cursor..])? {
            Some((reply, consumed)) => {
                cursor += consumed;
                reply
            }
            None => return Ok(None),
        };
        let value = match parse_reply(&buf[cursor..])? {
            Some((reply, consumed)) => {
                cursor += consumed;
                reply
            }
            None => return Ok(None),
        };
        pairs.push((key, value));
    }
    Ok(Some((Reply::Map(pairs), cursor)))
}
/// Parses a RESP3 set (`~<count>\r\n` followed by `count` replies).
///
/// Members are kept in wire order and duplicates are not removed. Returns
/// `Ok(None)` if the header or any member is still incomplete; a negative
/// count is an error.
fn parse_set_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
    let Some(hdr_end) = find_crlf(buf, 1) else {
        return Ok(None);
    };
    let count = match parse_int(&buf[1..hdr_end]) {
        None => return Err(ProtocolError::Malformed("bad set length")),
        Some(n) if n < 0 => return Err(ProtocolError::Malformed("set length cannot be negative")),
        Some(n) => n,
    };
    let mut offset = hdr_end + 2;
    // Cap the pre-allocation by the buffered bytes, not just the claimed count.
    let mut members = Vec::with_capacity((count as usize).min(buf.len().saturating_sub(offset)));
    for _ in 0..count {
        let Some((member, consumed)) = parse_reply(&buf[offset..])? else {
            return Ok(None);
        };
        members.push(member);
        offset += consumed;
    }
    Ok(Some((Reply::Set(members), offset)))
}
/// Parses a RESP3 double (`,<value>\r\n`).
///
/// The payload must be UTF-8 and is accepted if Rust's `f64` parser accepts
/// it, which covers the `inf`, `-inf` and `nan` spellings. Returns `Ok(None)`
/// until the line is complete.
fn parse_double_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
    match reply_line(buf) {
        None => Ok(None),
        Some((payload, consumed)) => {
            let text = std::str::from_utf8(payload)
                .map_err(|_| ProtocolError::Malformed("bad double utf8"))?;
            let value = text
                .parse::<f64>()
                .map_err(|_| ProtocolError::Malformed("bad double"))?;
            Ok(Some((Reply::Double(value), consumed)))
        }
    }
}
/// Parses a RESP3 boolean: the payload must be exactly `t` or `f`.
///
/// Returns `Ok(None)` until the line is complete; any other payload is an
/// error.
fn parse_boolean_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
    let (payload, consumed) = match reply_line(buf) {
        Some(line) => line,
        None => return Ok(None),
    };
    let flag = match payload {
        [b't'] => true,
        [b'f'] => false,
        _ => return Err(ProtocolError::Malformed("bad boolean payload")),
    };
    Ok(Some((Reply::Boolean(flag), consumed)))
}
/// Parses a RESP3 verbatim string, `=<len>\r\n<fmt>:<data>\r\n`.
///
/// `fmt` is exactly three bytes, and `len` counts `fmt`, the `:` and `data`.
/// Returns `Ok(None)` while `buf` does not yet hold the whole frame.
///
/// # Errors
///
/// Returns `ProtocolError::Malformed` in any of these cases:
/// - the length is not an integer;
/// - the length is below 4;
/// - the length is too large to address;
/// - the fourth payload byte is not `:`;
/// - the payload is not followed by `\r\n`.
fn parse_verbatim_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
    let Some(hdr_end) = find_crlf(buf, 1) else {
        return Ok(None);
    };
    let len = parse_int(&buf[1..hdr_end])
        .ok_or(ProtocolError::Malformed("bad verbatim length"))?;
    if len < 4 {
        return Err(ProtocolError::Malformed("verbatim length < 4 (fmt + ':')"));
    }
    let data_start = hdr_end + 2;
    // `len` comes off the wire. Checked conversion and addition turn an absurd
    // value into an error instead of a silent wrap or truncation (`as usize`
    // truncates on 32-bit targets).
    let frame_end = usize::try_from(len)
        .ok()
        .and_then(|n| data_start.checked_add(n))
        .and_then(|end| end.checked_add(2))
        .ok_or(ProtocolError::Malformed("verbatim length too large"))?;
    if buf.len() < frame_end {
        return Ok(None);
    }
    let data_end = frame_end - 2;
    let body = &buf[data_start..data_end];
    if body[3] != b':' {
        return Err(ProtocolError::Malformed("verbatim missing fmt:data separator"));
    }
    // Check the terminator rather than skipping it, so a wrong length is
    // reported instead of desynchronising the stream.
    if &buf[data_end..frame_end] != b"\r\n" {
        return Err(ProtocolError::Malformed("verbatim string missing trailing CRLF"));
    }
    let mut fmt = [0u8; 3];
    fmt.copy_from_slice(&body[..3]);
    let data = body[4..].to_vec();
    Ok(Some((Reply::Verbatim { fmt, data }, frame_end)))
}
/// Parses the RESP3 null reply, which is always exactly `_\r\n`.
///
/// Returns `Ok(None)` while fewer than three bytes are buffered, and an error
/// if the first three bytes are anything else.
fn parse_null_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
    const FRAME: &[u8] = b"_\r\n";
    match buf.get(..FRAME.len()) {
        None => Ok(None),
        Some(head) if head == FRAME => Ok(Some((Reply::Null, FRAME.len()))),
        Some(_) => Err(ProtocolError::Malformed("bad null payload")),
    }
}
fn parse_blob_error_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
let Some(hdr_end) = find_crlf(buf, 1) else {
return Ok(None);
};
let len = parse_int(&buf[1..hdr_end])
.ok_or(ProtocolError::Malformed("bad blob error length"))?;
if len < 0 {
return Err(ProtocolError::Malformed("blob error length cannot be negative"));
}
let data_start = hdr_end + 2;
let data_end = data_start + len as usize;
if buf.len() < data_end + 2 {
return Ok(None);
}
Ok(Some((Reply::BlobError(buf[data_start..data_end].to_vec()), data_end + 2)))
}
/// Parses a RESP3 attribute frame (`|<count>\r\n` + key/value pairs) and
/// returns the reply that follows it.
///
/// The attribute map itself is discarded. The returned byte count covers the
/// attribute map and the annotated reply together. Returns `Ok(None)` if
/// either part is still incomplete.
fn parse_attributed_reply(buf: &[u8]) -> Result<Option<(Reply, usize)>, ProtocolError> {
    // Attributes share the `%` map layout; only their length matters here.
    let attrs_len = match parse_map_reply(buf)? {
        Some((_, consumed)) => consumed,
        None => return Ok(None),
    };
    let annotated = parse_reply(&buf[attrs_len..])?;
    Ok(annotated.map(|(reply, consumed)| (reply, attrs_len + consumed)))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_replies() {
        let r = |b: &[u8]| parse_reply(b).unwrap().unwrap().0;
        assert_eq!(r(b"+OK\r\n"), Reply::Simple(b"OK".to_vec()));
        assert_eq!(r(b"-ERR bad\r\n"), Reply::Error(b"ERR bad".to_vec()));
        assert_eq!(r(b":42\r\n"), Reply::Int(42));
        assert_eq!(r(b"$5\r\nhello\r\n"), Reply::Bulk(b"hello".to_vec()));
        assert_eq!(r(b"$-1\r\n"), Reply::Nil);
        assert_eq!(r(b"*-1\r\n"), Reply::Nil);
        let (arr, used) = parse_reply(b"*2\r\n:1\r\n$2\r\nhi\r\n").unwrap().unwrap();
        assert_eq!(
            arr,
            Reply::Array(vec![Reply::Int(1), Reply::Bulk(b"hi".to_vec())])
        );
        assert_eq!(used, 16);
        assert_eq!(parse_reply(b"$5\r\nhel").unwrap(), None);
        assert_eq!(parse_reply(b"*2\r\n:1\r\n").unwrap(), None);
        assert!(parse_reply(b"@huh\r\n").is_err());
    }

    #[test]
    fn parse_resp3_scalars() {
        let r = |b: &[u8]| parse_reply(b).unwrap().unwrap().0;
        assert_eq!(r(b"_\r\n"), Reply::Null);
        assert_eq!(r(b"#t\r\n"), Reply::Boolean(true));
        assert_eq!(r(b"#f\r\n"), Reply::Boolean(false));
        assert_eq!(r(b",1.5\r\n"), Reply::Double(1.5));
        assert_eq!(r(b",inf\r\n"), Reply::Double(f64::INFINITY));
        assert_eq!(r(b",-inf\r\n"), Reply::Double(f64::NEG_INFINITY));
        match r(b",nan\r\n") {
            Reply::Double(v) => assert!(v.is_nan()),
            other => panic!("expected Double(nan), got {other:?}"),
        }
        assert_eq!(
            r(b"(170141183460469231731687303715884105727\r\n"),
            Reply::BigNumber(b"170141183460469231731687303715884105727".to_vec())
        );
        assert_eq!(
            r(b"!11\r\nERR bad cmd\r\n"),
            Reply::BlobError(b"ERR bad cmd".to_vec())
        );
    }

    #[test]
    fn parse_resp3_verbatim() {
        let r = |b: &[u8]| parse_reply(b).unwrap().unwrap().0;
        assert_eq!(
            r(b"=15\r\ntxt:Some string\r\n"),
            Reply::Verbatim { fmt: *b"txt", data: b"Some string".to_vec() }
        );
        assert!(parse_reply(b"=3\r\ntxt\r\n").is_err());
        assert!(parse_reply(b"=7\r\ntxt+abc\r\n").is_err());
    }

    #[test]
    fn parse_resp3_map_and_set() {
        let r = |b: &[u8]| parse_reply(b).unwrap().unwrap().0;
        let m = r(b"%2\r\n:1\r\n$1\r\na\r\n:2\r\n$1\r\nb\r\n");
        assert_eq!(
            m,
            Reply::Map(vec![
                (Reply::Int(1), Reply::Bulk(b"a".to_vec())),
                (Reply::Int(2), Reply::Bulk(b"b".to_vec())),
            ])
        );
        let s = r(b"~3\r\n:1\r\n:2\r\n:3\r\n");
        assert_eq!(s, Reply::Set(vec![Reply::Int(1), Reply::Int(2), Reply::Int(3)]));
        assert_eq!(r(b"%0\r\n"), Reply::Map(vec![]));
        assert_eq!(r(b"~0\r\n"), Reply::Set(vec![]));
        assert!(parse_reply(b"%-1\r\n").is_err());
        assert!(parse_reply(b"~-1\r\n").is_err());
    }

    #[test]
    fn parse_resp3_push_frame() {
        let r = |b: &[u8]| parse_reply(b).unwrap().unwrap().0;
        let push = r(b">3\r\n+message\r\n$4\r\nnews\r\n$5\r\nhello\r\n");
        assert_eq!(
            push,
            Reply::Push(vec![
                Reply::Simple(b"message".to_vec()),
                Reply::Bulk(b"news".to_vec()),
                Reply::Bulk(b"hello".to_vec()),
            ])
        );
        assert!(parse_reply(b">-1\r\n").is_err());
    }

    #[test]
    fn parse_resp3_attributes_are_skipped() {
        let frame =
            b"|1\r\n+key-popularity\r\n%2\r\n$1\r\na\r\n,0.5\r\n$1\r\nb\r\n,0.3\r\n*2\r\n:1\r\n:2\r\n";
        let (r, used) = parse_reply(frame).unwrap().unwrap();
        assert_eq!(r, Reply::Array(vec![Reply::Int(1), Reply::Int(2)]));
        assert_eq!(used, frame.len());
    }

    #[test]
    fn parse_resp3_partial_returns_none() {
        for cut in [b"_".as_slice(), b"_\r", b"#t", b"#t\r"].iter() {
            assert_eq!(parse_reply(cut).unwrap(), None);
        }
        assert_eq!(parse_reply(b"=15\r\ntxt:Some str").unwrap(), None);
        assert_eq!(parse_reply(b"%2\r\n:1\r\n$1\r\na\r\n:2\r\n").unwrap(), None);
    }

    /// A reader hands `parse_reply` whatever bytes have arrived so far. Every
    /// strict prefix of a well-formed frame must therefore report "incomplete"
    /// (`Ok(None)`), never an error or a truncated reply. The complete frame
    /// must be consumed exactly.
    #[test]
    fn every_prefix_of_a_frame_is_incomplete() {
        let frames: &[&[u8]] = &[
            b"+OK\r\n",
            b"-ERR bad\r\n",
            b":42\r\n",
            b"$5\r\nhello\r\n",
            b"$-1\r\n",
            b"*2\r\n:1\r\n$2\r\nhi\r\n",
            b"*-1\r\n",
            b"%1\r\n+k\r\n:1\r\n",
            b"~2\r\n:1\r\n:2\r\n",
            b",1.5\r\n",
            b"#t\r\n",
            b"=15\r\ntxt:Some string\r\n",
            b"(12345678901234567890\r\n",
            b"_\r\n",
            b">2\r\n+message\r\n$4\r\nnews\r\n",
            b"!11\r\nERR bad cmd\r\n",
            b"|1\r\n+ttl\r\n:3\r\n$1\r\nv\r\n",
        ];
        for &frame in frames {
            let shown = String::from_utf8_lossy(frame);
            for cut in 0..frame.len() {
                assert_eq!(
                    parse_reply(&frame[..cut]).unwrap(),
                    None,
                    "prefix of length {cut} of {shown:?}"
                );
            }
            let (_, used) = parse_reply(frame).unwrap().unwrap();
            assert_eq!(used, frame.len(), "frame {shown:?}");
        }
    }

    /// Pipelined replies share one buffer: each call must consume only its
    /// own reply and leave the rest for the next call.
    #[test]
    fn pipelined_replies_are_consumed_one_at_a_time() {
        let buf = b"+OK\r\n:7\r\n$2\r\nhi\r\n";
        let (first, used_first) = parse_reply(buf).unwrap().unwrap();
        assert_eq!(first, Reply::Simple(b"OK".to_vec()));
        assert_eq!(used_first, 5);
        let rest = &buf[used_first..];
        let (second, used_second) = parse_reply(rest).unwrap().unwrap();
        assert_eq!(second, Reply::Int(7));
        assert_eq!(used_second, 4);
        let rest = &rest[used_second..];
        let (third, used_third) = parse_reply(rest).unwrap().unwrap();
        assert_eq!(third, Reply::Bulk(b"hi".to_vec()));
        assert_eq!(used_third, rest.len());
    }
}