use crate::{RedisError, Result};
use bytes::{Buf, Bytes, BytesMut};
use std::io::{BufRead, Cursor};
/// Arbitrary-precision integer payload carried by [`Frame::BigNumber`] (RESP type `(`).
///
/// NOTE(review): BigNumber serialization and parsing are not implemented in this
/// module, so the encoding of these fields is not established anywhere here.
/// Confirm two things before implementing it:
/// * whether `sign == true` means negative or non-negative;
/// * whether `data` holds ASCII digits, digit values, or raw magnitude bytes.
#[derive(Debug, PartialEq)]
pub struct BigInt {
    // Sign flag; polarity convention still to be defined (see note above).
    sign: bool,
    // Magnitude; representation still to be defined (see note above).
    data: Vec<u8>,
}
/// A single RESP (REdis Serialization Protocol) value.
///
/// Each variant corresponds to one RESP type prefix byte, as written by
/// [`Frame::serialize`] and recognised by [`Frame::try_parse`].
///
/// Because `PartialEq` is derived, `Double(f64::NAN) != Double(f64::NAN)`.
#[derive(Debug, PartialEq)]
pub enum Frame {
    /// `+` simple string. The wire form is a single CRLF-terminated line, so
    /// the text should not contain `\r` or `\n`.
    SimpleString(String),
    /// `-` simple error, same line-based wire form as `SimpleString`.
    SimpleError(String),
    /// `:` signed 64-bit integer.
    Integer(i64),
    /// `$` length-prefixed, binary-safe string.
    BulkString(Bytes),
    /// `*` ordered sequence of frames.
    Array(Vec<Frame>),
    /// `_` null. The parser also produces this for the `$-1` and `!-1` null markers.
    Null,
    /// `#` boolean, written as `t` / `f`.
    Boolean(bool),
    /// `,` double; NaN and the infinities are written/read as `nan`, `inf`, `-inf`.
    Double(f64),
    /// `(` big number. Serialization and parsing are not implemented yet.
    BigNumber(BigInt),
    /// `!` length-prefixed error.
    BulkError(Bytes),
    /// `=` verbatim string as `(encoding, text)`, e.g. encoding `txt`.
    VerbatimString(Bytes, Bytes),
    /// `%` map, kept as an ordered list of key/value pairs.
    Map(Vec<(Frame, Frame)>),
    /// `&` attribute. Placeholder: carries no payload and is not implemented yet.
    Attribute,
    /// `~` set, kept as a list (duplicates are not removed).
    Set(Vec<Frame>),
    /// `>` push. Placeholder: carries no payload and is not implemented yet.
    Push,
}
impl Frame {
    /// Deepest aggregate nesting that [`Frame::try_parse`] will follow.
    ///
    /// Aggregates (`*`, `%`, `~`) are parsed recursively. Without a cap, a peer
    /// could send `*1\r\n*1\r\n…` and exhaust the stack. 128 matches the default
    /// recursion limit used by `serde_json`.
    const MAX_NESTING_DEPTH: usize = 128;

    /// Returns an empty [`Frame::Array`].
    pub const fn array() -> Self {
        Frame::Array(Vec::new())
    }

    /// Appends `frame` to `self` when it is a [`Frame::Array`] or a [`Frame::Set`].
    ///
    /// # Errors
    ///
    /// Returns [`RedisError::Unknown`] for any other variant; `self` is left untouched.
    pub fn push_frame_to_array(&mut self, frame: Frame) -> Result<()> {
        match self {
            Frame::Array(frames) | Frame::Set(frames) => {
                frames.push(frame);
                Ok(())
            }
            _ => Err(RedisError::Unknown),
        }
    }

    /// Appends a `key`/`value` pair to `self` when it is a [`Frame::Map`].
    ///
    /// # Errors
    ///
    /// Returns [`RedisError::Unknown`] for any other variant; `self` is left untouched.
    pub fn push_frame_to_map(&mut self, key: Frame, value: Frame) -> Result<()> {
        match self {
            Frame::Map(pairs) => {
                pairs.push((key, value));
                Ok(())
            }
            _ => Err(RedisError::Unknown),
        }
    }

    /// Encodes this frame (recursively for aggregates) into its RESP wire form.
    ///
    /// # Errors
    ///
    /// Returns [`RedisError::InvalidFrame`] when the frame cannot be represented
    /// on the wire:
    /// * a simple string or simple error containing `\r` or `\n`;
    /// * a verbatim string whose encoding is not exactly 3 bytes.
    ///
    /// # Panics
    ///
    /// Serialization of `BigNumber`, `Attribute` and `Push` is not implemented
    /// yet and panics via `todo!`.
    pub async fn serialize(&self) -> Result<Bytes> {
        // Every nested frame is appended to this one buffer. Serializing each
        // child into its own `Bytes` and copying it into the parent would
        // re-copy every byte once per nesting level.
        let mut buf = BytesMut::new();
        self.write_to(&mut buf)?;
        Ok(buf.freeze())
    }

    /// Appends the RESP encoding of `self` to `buf`.
    fn write_to(&self, buf: &mut BytesMut) -> Result<()> {
        match self {
            Frame::SimpleString(text) => Self::put_line(buf, b'+', text)?,
            Frame::SimpleError(text) => Self::put_line(buf, b'-', text)?,
            Frame::Integer(n) => {
                buf.extend_from_slice(b":");
                buf.extend_from_slice(n.to_string().as_bytes());
                buf.extend_from_slice(b"\r\n");
            }
            Frame::BulkString(data) => Self::put_blob(buf, b'$', data),
            Frame::Array(frames) => {
                Self::put_header(buf, b'*', frames.len());
                for frame in frames {
                    frame.write_to(buf)?;
                }
            }
            Frame::Null => buf.extend_from_slice(b"_\r\n"),
            Frame::Boolean(flag) => {
                buf.extend_from_slice(if *flag { b"#t\r\n" } else { b"#f\r\n" })
            }
            Frame::Double(val) => {
                buf.extend_from_slice(b",");
                if val.is_nan() {
                    buf.extend_from_slice(b"nan");
                } else if val.is_infinite() {
                    let text: &[u8] = if val.is_sign_negative() { b"-inf" } else { b"inf" };
                    buf.extend_from_slice(text);
                } else {
                    buf.extend_from_slice(val.to_string().as_bytes());
                }
                buf.extend_from_slice(b"\r\n");
            }
            Frame::BigNumber(val) => {
                todo!("BigNumber serialization is not implemented yet {:?}", val)
            }
            Frame::BulkError(data) => Self::put_blob(buf, b'!', data),
            Frame::VerbatimString(encoding, text) => {
                // The payload layout is `<3-byte encoding>:<text>`. Any other
                // encoding length would make the declared length, and the
                // reader's split, wrong.
                if encoding.len() != 3 {
                    return Err(RedisError::InvalidFrame);
                }
                Self::put_header(buf, b'=', encoding.len() + 1 + text.len());
                buf.extend_from_slice(encoding);
                buf.extend_from_slice(b":");
                buf.extend_from_slice(text);
                buf.extend_from_slice(b"\r\n");
            }
            Frame::Map(pairs) => {
                Self::put_header(buf, b'%', pairs.len());
                for (key, value) in pairs {
                    key.write_to(buf)?;
                    value.write_to(buf)?;
                }
            }
            Frame::Attribute => todo!("Attribute serialization is not implemented yet"),
            Frame::Set(frames) => {
                Self::put_header(buf, b'~', frames.len());
                for frame in frames {
                    frame.write_to(buf)?;
                }
            }
            Frame::Push => todo!("Push serialization is not implemented yet"),
        }
        Ok(())
    }

    /// Writes `<prefix><text>\r\n`, rejecting text that would break line framing.
    fn put_line(buf: &mut BytesMut, prefix: u8, text: &str) -> Result<()> {
        if text.bytes().any(|b| b == b'\r' || b == b'\n') {
            return Err(RedisError::InvalidFrame);
        }
        buf.extend_from_slice(&[prefix]);
        buf.extend_from_slice(text.as_bytes());
        buf.extend_from_slice(b"\r\n");
        Ok(())
    }

    /// Writes an aggregate or blob header: `<prefix><len>\r\n`.
    fn put_header(buf: &mut BytesMut, prefix: u8, len: usize) {
        buf.extend_from_slice(&[prefix]);
        buf.extend_from_slice(len.to_string().as_bytes());
        buf.extend_from_slice(b"\r\n");
    }

    /// Writes a length-prefixed blob: `<prefix><len>\r\n<data>\r\n`.
    fn put_blob(buf: &mut BytesMut, prefix: u8, data: &[u8]) {
        Self::put_header(buf, prefix, data.len());
        buf.extend_from_slice(data);
        buf.extend_from_slice(b"\r\n");
    }

    /// Parses a single frame from the start of `buf`.
    ///
    /// Any bytes after the first frame are ignored.
    ///
    /// # Errors
    ///
    /// Same as [`Frame::try_parse`].
    pub async fn deserialize(buf: Bytes) -> Result<Frame> {
        Frame::try_parse(&mut Cursor::new(&buf[..]))
    }

    /// Parses one frame starting at the cursor's position and advances past it.
    ///
    /// # Errors
    ///
    /// * [`RedisError::IncompleteFrame`]: the buffer ends before the frame does
    ///   (more input is needed).
    /// * [`RedisError::InvalidFrame`]: malformed input, an unsupported type
    ///   (`(`, `&`, `>`), or aggregates nested deeper than `MAX_NESTING_DEPTH`.
    ///
    /// On error the cursor position is unspecified.
    pub fn try_parse(cursor: &mut Cursor<&[u8]>) -> Result<Frame> {
        Self::parse_at_depth(cursor, 0)
    }

    /// Recursive worker for [`Frame::try_parse`]. `depth` is the number of
    /// enclosing aggregates.
    fn parse_at_depth(cursor: &mut Cursor<&[u8]>, depth: usize) -> Result<Frame> {
        if depth > Self::MAX_NESTING_DEPTH {
            return Err(RedisError::InvalidFrame);
        }
        if !cursor.has_remaining() {
            return Err(RedisError::IncompleteFrame);
        }
        match cursor.get_u8() {
            b'+' => Ok(Frame::SimpleString(Self::read_text_line(cursor)?)),
            b'-' => Ok(Frame::SimpleError(Self::read_text_line(cursor)?)),
            b':' => Self::read_text_line(cursor)?
                .parse::<i64>()
                .map(Frame::Integer)
                .map_err(|_| RedisError::InvalidFrame),
            b'$' => match Self::read_length(cursor)? {
                None => Ok(Frame::Null),
                Some(len) => Ok(Frame::BulkString(Self::read_blob(cursor, len)?)),
            },
            b'*' => match Self::read_length(cursor)? {
                // RESP2 null array (`*-1\r\n`), treated like `$-1`.
                None => Ok(Frame::Null),
                Some(count) => Ok(Frame::Array(Self::parse_elements(cursor, count, depth)?)),
            },
            b'_' => {
                // Consume the CRLF so a following frame starts at the right offset.
                if Self::read_line(cursor)?.is_empty() {
                    Ok(Frame::Null)
                } else {
                    Err(RedisError::InvalidFrame)
                }
            }
            b'#' => match Self::read_line(cursor)?.as_slice() {
                b"t" => Ok(Frame::Boolean(true)),
                b"f" => Ok(Frame::Boolean(false)),
                _ => Err(RedisError::InvalidFrame),
            },
            b',' => {
                let text = Self::read_text_line(cursor)?;
                let value = match text.as_str() {
                    "nan" => f64::NAN,
                    "inf" => f64::INFINITY,
                    "-inf" => f64::NEG_INFINITY,
                    other => other
                        .parse::<f64>()
                        .map_err(|_| RedisError::InvalidFrame)?,
                };
                Ok(Frame::Double(value))
            }
            b'!' => match Self::read_length(cursor)? {
                // Kept for backward compatibility: `!-1\r\n` has always parsed as Null.
                None => Ok(Frame::Null),
                Some(len) => Ok(Frame::BulkError(Self::read_blob(cursor, len)?)),
            },
            b'=' => {
                let len = Self::read_length(cursor)?.ok_or(RedisError::InvalidFrame)?;
                let payload = Self::read_blob(cursor, len)?;
                // Payload layout is `<3-byte encoding>:<text>`.
                if payload.len() < 4 || payload[3] != b':' {
                    return Err(RedisError::InvalidFrame);
                }
                Ok(Frame::VerbatimString(payload.slice(..3), payload.slice(4..)))
            }
            b'%' => {
                let count = Self::read_length(cursor)?.ok_or(RedisError::InvalidFrame)?;
                // `count` is untrusted: never reserve more slots than bytes remain.
                let mut pairs = Vec::with_capacity(count.min(cursor.remaining()));
                for _ in 0..count {
                    let key = Self::parse_at_depth(cursor, depth + 1)?;
                    let value = Self::parse_at_depth(cursor, depth + 1)?;
                    pairs.push((key, value));
                }
                Ok(Frame::Map(pairs))
            }
            b'~' => {
                let count = Self::read_length(cursor)?.ok_or(RedisError::InvalidFrame)?;
                Ok(Frame::Set(Self::parse_elements(cursor, count, depth)?))
            }
            // BigNumber, Attribute and Push parsing is not implemented yet. These
            // bytes come from the network, so reject them instead of panicking.
            b'(' | b'&' | b'>' => Err(RedisError::InvalidFrame),
            _ => Err(RedisError::InvalidFrame),
        }
    }

    /// Parses `count` consecutive child frames, one nesting level below `depth`.
    fn parse_elements(
        cursor: &mut Cursor<&[u8]>,
        count: usize,
        depth: usize,
    ) -> Result<Vec<Frame>> {
        // `count` is untrusted: every element occupies at least one byte, so
        // never reserve more slots than there are bytes left.
        let mut frames = Vec::with_capacity(count.min(cursor.remaining()));
        for _ in 0..count {
            frames.push(Self::parse_at_depth(cursor, depth + 1)?);
        }
        Ok(frames)
    }

    /// Reads one CRLF-terminated line and returns it without the terminator.
    ///
    /// Returns [`RedisError::IncompleteFrame`] when no `\n` is buffered yet, and
    /// [`RedisError::InvalidFrame`] when the `\n` is not preceded by `\r`.
    fn read_line(cursor: &mut Cursor<&[u8]>) -> Result<Vec<u8>> {
        let mut line = Vec::new();
        // Reading from an in-memory cursor cannot fail; map defensively anyway.
        cursor
            .read_until(b'\n', &mut line)
            .map_err(|_| RedisError::InvalidFrame)?;
        if line.last() != Some(&b'\n') {
            return Err(RedisError::IncompleteFrame);
        }
        if !line.ends_with(b"\r\n") {
            return Err(RedisError::InvalidFrame);
        }
        line.truncate(line.len() - 2);
        Ok(line)
    }

    /// Like [`Frame::read_line`], but requires the line to be valid UTF-8.
    fn read_text_line(cursor: &mut Cursor<&[u8]>) -> Result<String> {
        String::from_utf8(Self::read_line(cursor)?).map_err(|_| RedisError::InvalidFrame)
    }

    /// Parses a length/count header line. `Ok(None)` is the `-1` null marker;
    /// any other negative or non-numeric value is [`RedisError::InvalidFrame`].
    fn read_length(cursor: &mut Cursor<&[u8]>) -> Result<Option<usize>> {
        let line = Self::read_text_line(cursor)?;
        if line == "-1" {
            return Ok(None);
        }
        line.parse::<usize>()
            .map(Some)
            .map_err(|_| RedisError::InvalidFrame)
    }

    /// Reads `len` payload bytes followed by the mandatory CRLF.
    fn read_blob(cursor: &mut Cursor<&[u8]>, len: usize) -> Result<Bytes> {
        // `len` comes straight off the wire; avoid overflow on absurd values.
        let framed_len = len.checked_add(2).ok_or(RedisError::InvalidFrame)?;
        if cursor.remaining() < framed_len {
            return Err(RedisError::IncompleteFrame);
        }
        let rest = cursor.chunk();
        if !rest[len..].starts_with(b"\r\n") {
            return Err(RedisError::InvalidFrame);
        }
        let data = Bytes::copy_from_slice(&rest[..len]);
        cursor.advance(framed_len);
        Ok(data)
    }
}
#[cfg(test)]
mod tests {
    //! Wire-format fixtures for `Frame::serialize` and `Frame::deserialize`.
    use super::*;

    /// Serializes `frame`, panicking if serialization fails.
    async fn encode(frame: &Frame) -> Bytes {
        frame.serialize().await.unwrap()
    }

    /// Parses a complete static RESP payload, panicking if parsing fails.
    async fn decode(wire: &'static [u8]) -> Frame {
        Frame::deserialize(Bytes::from_static(wire)).await.unwrap()
    }

    /// Bulk-string frame over static bytes.
    fn bulk(data: &'static [u8]) -> Frame {
        Frame::BulkString(Bytes::from_static(data))
    }

    /// Appends the bulk strings `Hello` and `Redis` via `push_frame_to_array`.
    fn with_hello_redis(mut container: Frame) -> Frame {
        container.push_frame_to_array(bulk(b"Hello")).unwrap();
        container.push_frame_to_array(bulk(b"Redis")).unwrap();
        container
    }

    #[tokio::test]
    async fn test_serialize_simple_string() {
        let wire = encode(&Frame::SimpleString("OK".to_string())).await;
        assert_eq!(wire, Bytes::from_static(b"+OK\r\n"));
    }

    #[tokio::test]
    async fn test_serialize_simple_error() {
        let wire = encode(&Frame::SimpleError("ERR".to_string())).await;
        assert_eq!(wire, Bytes::from_static(b"-ERR\r\n"));
    }

    #[tokio::test]
    async fn test_serialize_integer() {
        let cases: [(i64, &'static [u8]); 2] =
            [(123, &b":123\r\n"[..]), (-123, &b":-123\r\n"[..])];
        for (value, expected) in cases {
            assert_eq!(
                encode(&Frame::Integer(value)).await,
                Bytes::from_static(expected)
            );
        }
    }

    #[tokio::test]
    async fn test_serialize_bulk_string() {
        assert_eq!(
            encode(&bulk(b"Hello Redis")).await,
            Bytes::from_static(b"$11\r\nHello Redis\r\n")
        );
        assert_eq!(encode(&bulk(b"")).await, Bytes::from_static(b"$0\r\n\r\n"));
    }

    #[tokio::test]
    async fn test_serialize_array() {
        let flat = with_hello_redis(Frame::array());
        assert_eq!(
            encode(&flat).await,
            Bytes::from_static(b"*2\r\n$5\r\nHello\r\n$5\r\nRedis\r\n")
        );

        assert_eq!(encode(&Frame::array()).await, Bytes::from_static(b"*0\r\n"));

        let mut outer = Frame::array();
        if let Frame::Array(items) = &mut outer {
            items.push(with_hello_redis(Frame::array()));
        }
        assert_eq!(
            encode(&outer).await,
            Bytes::from_static(b"*1\r\n*2\r\n$5\r\nHello\r\n$5\r\nRedis\r\n")
        );
    }

    #[tokio::test]
    async fn test_serialize_null() {
        assert_eq!(encode(&Frame::Null).await, Bytes::from_static(b"_\r\n"));
    }

    #[tokio::test]
    async fn test_serialize_boolean() {
        assert_eq!(encode(&Frame::Boolean(true)).await, Bytes::from_static(b"#t\r\n"));
        assert_eq!(encode(&Frame::Boolean(false)).await, Bytes::from_static(b"#f\r\n"));
    }

    #[tokio::test]
    async fn test_serialize_double() {
        let cases: [(f64, &'static [u8]); 4] = [
            (123.456, &b",123.456\r\n"[..]),
            (f64::NAN, &b",nan\r\n"[..]),
            (f64::INFINITY, &b",inf\r\n"[..]),
            (f64::NEG_INFINITY, &b",-inf\r\n"[..]),
        ];
        for (value, expected) in cases {
            assert_eq!(
                encode(&Frame::Double(value)).await,
                Bytes::from_static(expected)
            );
        }
    }

    #[tokio::test]
    async fn test_serialize_bulk_error() {
        let populated = Frame::BulkError(Bytes::from_static(b"Hello Redis"));
        assert_eq!(
            encode(&populated).await,
            Bytes::from_static(b"!11\r\nHello Redis\r\n")
        );
        let empty = Frame::BulkError(Bytes::from_static(b""));
        assert_eq!(encode(&empty).await, Bytes::from_static(b"!0\r\n\r\n"));
    }

    #[tokio::test]
    async fn test_serialize_verbatim_string() {
        let txt = || Bytes::from_static(b"txt");
        let populated = Frame::VerbatimString(txt(), Bytes::from_static(b"Some string"));
        assert_eq!(
            encode(&populated).await,
            Bytes::from_static(b"=15\r\ntxt:Some string\r\n")
        );
        let empty = Frame::VerbatimString(txt(), Bytes::from_static(b""));
        assert_eq!(encode(&empty).await, Bytes::from_static(b"=4\r\ntxt:\r\n"));
    }

    #[tokio::test]
    async fn test_serialize_map() {
        let mut map = Frame::Map(Vec::new());
        map.push_frame_to_map(
            Frame::SimpleString("key".to_string()),
            Frame::SimpleString("value".to_string()),
        )
        .unwrap();
        assert_eq!(
            encode(&map).await,
            Bytes::from_static(b"%1\r\n+key\r\n+value\r\n")
        );
    }

    #[tokio::test]
    async fn test_serialize_set() {
        let set = with_hello_redis(Frame::Set(Vec::new()));
        assert_eq!(
            encode(&set).await,
            Bytes::from_static(b"~2\r\n$5\r\nHello\r\n$5\r\nRedis\r\n")
        );
    }

    #[tokio::test]
    async fn test_deserialize_simple_string() {
        assert_eq!(
            decode(b"+OK\r\n").await,
            Frame::SimpleString("OK".to_string())
        );
    }

    #[tokio::test]
    async fn test_deserialize_simple_error() {
        assert_eq!(
            decode(b"-ERR\r\n").await,
            Frame::SimpleError("ERR".to_string())
        );
    }

    #[tokio::test]
    async fn test_deserialize_integer() {
        assert_eq!(decode(b":123\r\n").await, Frame::Integer(123));
        assert_eq!(decode(b":-123\r\n").await, Frame::Integer(-123));
    }

    #[tokio::test]
    async fn test_deserialize_bulk_string() {
        assert_eq!(decode(b"$11\r\nHello Redis\r\n").await, bulk(b"Hello Redis"));
        assert_eq!(decode(b"$0\r\n\r\n").await, bulk(b""));
    }

    #[tokio::test]
    async fn test_deserialize_array() {
        assert_eq!(
            decode(b"*2\r\n$5\r\nHello\r\n$5\r\nRedis\r\n").await,
            with_hello_redis(Frame::array())
        );

        assert_eq!(decode(b"*0\r\n").await, Frame::array());

        let mut expected = Frame::array();
        expected
            .push_frame_to_array(with_hello_redis(Frame::array()))
            .unwrap();
        assert_eq!(
            decode(b"*1\r\n*2\r\n$5\r\nHello\r\n$5\r\nRedis\r\n").await,
            expected
        );
    }

    #[tokio::test]
    async fn test_deserialize_null() {
        assert_eq!(decode(b"_\r\n").await, Frame::Null);
    }

    #[tokio::test]
    async fn test_deserialize_boolean() {
        assert_eq!(decode(b"#t\r\n").await, Frame::Boolean(true));
        assert_eq!(decode(b"#f\r\n").await, Frame::Boolean(false));
    }

    #[tokio::test]
    async fn test_deserialize_double() {
        assert_eq!(decode(b",123.456\r\n").await, Frame::Double(123.456));

        // NaN never compares equal, so inspect the payload directly.
        match decode(b",nan\r\n").await {
            Frame::Double(val) => assert!(val.is_nan()),
            _ => panic!("Expected a Double frame"),
        }

        assert_eq!(decode(b",inf\r\n").await, Frame::Double(f64::INFINITY));
        assert_eq!(decode(b",-inf\r\n").await, Frame::Double(f64::NEG_INFINITY));
    }

    #[tokio::test]
    async fn test_deserialize_bulk_error() {
        assert_eq!(
            decode(b"!11\r\nHello Redis\r\n").await,
            Frame::BulkError(Bytes::from_static(b"Hello Redis"))
        );
        assert_eq!(
            decode(b"!0\r\n\r\n").await,
            Frame::BulkError(Bytes::from_static(b""))
        );
    }

    #[tokio::test]
    async fn test_deserialize_verbatim_string() {
        let txt = || Bytes::from_static(b"txt");
        assert_eq!(
            decode(b"=15\r\ntxt:Some string\r\n").await,
            Frame::VerbatimString(txt(), Bytes::from_static(b"Some string"))
        );
        assert_eq!(
            decode(b"=4\r\ntxt:\r\n").await,
            Frame::VerbatimString(txt(), Bytes::from_static(b""))
        );
    }

    #[tokio::test]
    async fn test_deserialize_map() {
        let mut expected = Frame::Map(Vec::new());
        expected
            .push_frame_to_map(
                Frame::SimpleString("key".to_string()),
                Frame::SimpleString("value".to_string()),
            )
            .unwrap();
        assert_eq!(decode(b"%1\r\n+key\r\n+value\r\n").await, expected);
    }

    #[tokio::test]
    async fn test_deserialize_set() {
        assert_eq!(
            decode(b"~2\r\n$5\r\nHello\r\n$5\r\nRedis\r\n").await,
            with_hello_redis(Frame::Set(Vec::new()))
        );
    }
}