pub type Result<'a> = std::result::Result<(RESP<'a>, &'a [u8]), RError>;
const NIL_VALUE_SIZE: usize = 4;
const CR: u8 = b'\r';
const LF: u8 = b'\n';
pub struct RedisProtocolParser;
#[derive(Debug, Eq, PartialEq)]
pub enum RESP<'a> {
String(&'a [u8]),
Error(&'a [u8]),
Integer(&'a [u8]),
BulkString(&'a [u8]),
Nil,
Array(Vec<RESP<'a>>),
}
#[derive(Debug)]
pub enum RErrorType {
UnknownSymbol,
EmptyInput,
NoCrlf,
IncorrectFormat,
Other(Box<dyn std::error::Error>),
}
#[derive(Debug)]
pub struct RError {
err_type: RErrorType,
}
impl RError {
fn unknown_symbol() -> Self {
Self {
err_type: RErrorType::UnknownSymbol,
}
}
fn empty_input() -> Self {
Self {
err_type: RErrorType::EmptyInput,
}
}
fn no_crlf() -> Self {
Self {
err_type: RErrorType::NoCrlf,
}
}
fn incorrect_format() -> Self {
Self {
err_type: RErrorType::IncorrectFormat,
}
}
}
impl<'a> std::fmt::Display for RError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self)
}
}
impl<'a> std::error::Error for RError {}
impl<'a> From<std::str::Utf8Error> for RError {
fn from(from: std::str::Utf8Error) -> Self {
Self {
err_type: RErrorType::Other(Box::new(from)),
}
}
}
impl<'a> From<std::num::ParseIntError> for RError {
fn from(from: std::num::ParseIntError) -> Self {
Self {
err_type: RErrorType::Other(Box::new(from)),
}
}
}
impl RedisProtocolParser {
pub fn parse_resp(input: &[u8]) -> Result {
if let Some(first) = input.get(0) {
let first = *first as char;
let input = &input[1..];
let (resp, left) = match first {
'+' => RedisProtocolParser::parse_simple_string(input)?,
':' => RedisProtocolParser::parse_integers(input)?,
'$' => RedisProtocolParser::parse_bulk_strings(input)?,
'*' => RedisProtocolParser::parse_arrays(input)?,
'-' => RedisProtocolParser::parse_errors(input)?,
_ => return Err(RError::unknown_symbol()),
};
Ok((resp, left))
} else {
Err(RError::empty_input())
}
}
fn parse_everything_until_crlf(input: &[u8]) -> std::result::Result<(&[u8], &[u8]), RError> {
for (index, (first, second)) in input.iter().zip(input.iter().skip(1)).enumerate() {
if first == &CR && second == &LF {
return Ok((&input[0..index], &input[index + 2..]));
}
}
Err(RError::no_crlf())
}
pub fn parse_simple_string(input: &[u8]) -> Result {
RedisProtocolParser::parse_everything_until_crlf(input).map(|(x, y)| (RESP::String(x), y))
}
pub fn parse_errors(input: &[u8]) -> Result {
RedisProtocolParser::parse_everything_until_crlf(input).map(|(x, y)| (RESP::Error(x), y))
}
pub fn parse_integers(input: &[u8]) -> Result {
RedisProtocolParser::parse_everything_until_crlf(input).map(|(x, y)| (RESP::Integer(x), y))
}
pub fn parse_bulk_strings(input: &[u8]) -> Result {
if RedisProtocolParser::check_null_value(input) {
Ok((RESP::Nil, &input[NIL_VALUE_SIZE..]))
} else {
let (size_str, input_after_size) =
RedisProtocolParser::parse_everything_until_crlf(input)?;
let size = std::str::from_utf8(size_str)?.parse::<u64>()? as usize;
if RedisProtocolParser::check_crlf_at_index(input_after_size, size) {
Ok((
RESP::BulkString(&input_after_size[..size]),
&input_after_size[size + 2..],
))
} else {
Err(RError::incorrect_format())
}
}
}
fn check_crlf_at_index(input: &[u8], index: usize) -> bool {
input.len() >= index && input[index] == b'\r' && input[index + 1] == b'\n'
}
fn check_null_value(input: &[u8]) -> bool {
input.len() >= 4 && input[0] == b'-' && input[1] == b'1' && input[2] == CR && input[3] == LF
}
pub fn parse_arrays(input: &[u8]) -> Result {
let (size_str, input) = RedisProtocolParser::parse_everything_until_crlf(input)?;
let size = std::str::from_utf8(size_str)?.parse::<u64>()?;
let sizes = size as usize;
let mut left = input;
let mut result = Vec::with_capacity(sizes);
for _ in 0..sizes {
let (element, tmp) = RedisProtocolParser::parse_resp(left)?;
result.push(element);
left = tmp;
}
Ok((RESP::Array(result), left))
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
pub fn test_simple_string() -> std::result::Result<(), RError> {
let input = "+hello\r\n".as_bytes();
let (resp, left) = RedisProtocolParser::parse_resp(input)?;
assert_eq!(resp, RESP::String("hello".as_bytes()));
assert!(left.is_empty());
Ok(())
}
#[test]
pub fn test_errors() -> std::result::Result<(), RError> {
let input = "+hello".as_bytes();
let err = RedisProtocolParser::parse_resp(input).unwrap_err();
assert!(matches!(err.err_type, RErrorType::NoCrlf));
let input = "*2\r\n$3\r\nfoo\r\n)hello".as_bytes();
let err = RedisProtocolParser::parse_resp(input).unwrap_err();
assert!(matches!(err.err_type, RErrorType::UnknownSymbol));
let input = "".as_bytes();
let err = RedisProtocolParser::parse_resp(input).unwrap_err();
assert!(matches!(err.err_type, RErrorType::EmptyInput));
let input = "$4\r\nfoo\r\n".as_bytes();
let err = RedisProtocolParser::parse_resp(input).unwrap_err();
assert!(matches!(err.err_type, RErrorType::IncorrectFormat));
let input = "*2\r\n$3\r\nfoo+hello\r\n".as_bytes();
let err = RedisProtocolParser::parse_resp(input).unwrap_err();
assert!(matches!(err.err_type, RErrorType::IncorrectFormat));
Ok(())
}
#[test]
pub fn test_nil() -> std::result::Result<(), RError> {
let input = "$-1\r\n".as_bytes();
let (resp, left) = RedisProtocolParser::parse_resp(input)?;
assert_eq!(resp, RESP::Nil);
assert!(left.is_empty());
Ok(())
}
#[test]
pub fn test_bulk_string() -> std::result::Result<(), RError> {
let input = "$6\r\nfoobar\r\n".as_bytes();
let (resp, left) = RedisProtocolParser::parse_resp(input)?;
assert_eq!(resp, RESP::BulkString("foobar".as_bytes()));
assert!(left.is_empty());
let input = "$0\r\n\r\n".as_bytes();
let (resp, left) = RedisProtocolParser::parse_resp(input)?;
assert_eq!(resp, RESP::BulkString("".as_bytes()));
assert!(left.is_empty());
Ok(())
}
#[test]
pub fn test_arrays() -> std::result::Result<(), RError> {
let input = "*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n".as_bytes();
let (resp, left) = RedisProtocolParser::parse_resp(input)?;
assert_eq!(
resp,
RESP::Array(vec![
RESP::BulkString("foo".as_bytes()),
RESP::BulkString("bar".as_bytes())
])
);
assert!(left.is_empty());
let input = "*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n".as_bytes();
let (resp, left) = RedisProtocolParser::parse_resp(input)?;
assert_eq!(
resp,
RESP::Array(vec![
RESP::Integer("1".as_bytes()),
RESP::Integer("2".as_bytes()),
RESP::Integer("3".as_bytes()),
RESP::Integer("4".as_bytes()),
RESP::BulkString("foobar".as_bytes()),
])
);
assert!(left.is_empty());
Ok(())
}
#[test]
pub fn test_array_of_arrays() -> std::result::Result<(), RError> {
let input = "*2\r\n*3\r\n:1\r\n:2\r\n:3\r\n*2\r\n+Foo\r\n-Bar\r\n".as_bytes();
let (resp, left) = RedisProtocolParser::parse_resp(input)?;
assert_eq!(
resp,
RESP::Array(vec![
RESP::Array(vec![
RESP::Integer("1".as_bytes()),
RESP::Integer("2".as_bytes()),
RESP::Integer("3".as_bytes()),
]),
RESP::Array(vec![
RESP::String("Foo".as_bytes()),
RESP::Error("Bar".as_bytes()),
]),
])
);
assert!(left.is_empty());
Ok(())
}
}