use anyhow::{anyhow, Result};
use bstr::{ByteSlice, ByteVec};
/// A single RESP value.
///
/// Each variant's wire form is produced by [`RespType::as_bytes`]; the prefix byte and
/// layout are noted on each variant below.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RespType {
    /// Encoded as `+<text>\r\n`. The text must not contain `\r` or `\n`; this is checked by
    /// [`RespType::simple_string`], but not when the variant is constructed directly.
    SimpleString(String),
    /// Encoded as `-<text>\r\n`. The text must not contain `\r` or `\n`; this is checked by
    /// [`RespType::error`], but not when the variant is constructed directly.
    Error(String),
    /// Encoded as `:<decimal>\r\n`.
    Integer(i64),
    /// Length-prefixed raw bytes: `$<byte length>\r\n<bytes>\r\n`. May contain any bytes.
    BulkString(Vec<u8>),
    /// Encoded as `*<element count>\r\n` followed by each element's own encoding.
    Array(Vec<RespType>),
    /// Null bulk string, encoded as `$-1\r\n`.
    Null,
    /// Null array, encoded as `*-1\r\n`.
    NullArray,
}
impl<'a> RespType {
pub fn as_bytes(self) -> Vec<u8> {
use RespType::*;
let mut bytes = Vec::new();
match self {
SimpleString(string) => {
bytes.push_str("+");
bytes.push_str(string);
bytes.push_str("\r\n")
}
Error(string) => {
bytes.push_char('-');
bytes.push_str(string);
bytes.push_str("\r\n")
}
Integer(string) => {
bytes.push_char(':');
bytes.push_str(string.to_string());
bytes.push_str("\r\n")
}
BulkString(string) => {
bytes.push_char('$');
bytes.push_str(string.len().to_string());
bytes.push_str("\r\n");
bytes.push_str(string);
bytes.push_str("\r\n")
}
Array(array) => {
bytes.push_char('*');
bytes.push_str(array.len().to_string());
bytes.push_str("\r\n");
for i in array {
bytes.push_str(i.as_bytes())
}
}
Null => bytes.push_str("$-1\r\n"),
NullArray => bytes.push_str("*-1\r\n"),
};
bytes
}
pub fn simple_string(string: String) -> Result<Self> {
if string.contains('\r') || string.contains('\n') {
Err(anyhow!("Simple string contains \\r or \\n"))
} else {
Ok(RespType::SimpleString(string))
}
}
pub fn error(string: String) -> Result<Self> {
if string.contains('\r') || string.contains('\n') {
Err(anyhow!("Error type contains \\r or \\n"))
} else {
Ok(RespType::Error(string))
}
}
pub fn integer(int: i64) -> Self {
RespType::Integer(int)
}
pub fn bulk_string(string: Vec<u8>) -> Self {
RespType::BulkString(string)
}
pub fn array(array: Vec<RespType>) -> Self {
RespType::Array(array)
}
pub fn command(command: Vec<Vec<u8>>) -> Self {
let mut cmd = Vec::new();
for i in command {
cmd.push(RespType::bulk_string(i.as_bytes().to_vec()))
}
RespType::array(cmd)
}
}
impl TryInto<String> for RespType {
type Error = anyhow::Error;
fn try_into(self) -> Result<String> {
let result = match self {
RespType::SimpleString(string) => string,
RespType::Error(error) => error,
RespType::Integer(integer) => integer.to_string(),
RespType::BulkString(string) => {
std::str::from_utf8(&string)?.to_string()
},
RespType::Array(array) => {
let mut converted = Vec::<String>::new();
for i in array {
converted.push(i.try_into()?)
}
format!("[{}]", converted.join(", "))
},
RespType::Null => "(null)".to_string(),
RespType::NullArray => "[null]".to_string(),
};
Ok(result)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `resp` and asserts that the wire bytes equal `expected`.
    fn assert_expected_encode(resp: RespType, expected: &str) {
        let encoded = resp.as_bytes();
        assert_eq!(expected.as_bytes(), encoded);
    }

    #[test]
    fn simple_string() -> Result<()> {
        let input = "test";
        let resp = RespType::simple_string(input.into())?;
        assert_expected_encode(resp, "+test\r\n");
        Ok(())
    }

    #[test]
    fn simple_string_rejects_line_breaks() {
        let err = RespType::simple_string("bad\r\nvalue".into()).unwrap_err();
        assert_eq!(err.to_string(), "Simple string contains \\r or \\n");
        assert!(RespType::simple_string("bad\nvalue".into()).is_err());
        assert!(RespType::simple_string("bad\rvalue".into()).is_err());
    }

    #[test]
    fn error() -> Result<()> {
        let input = "error";
        let resp = RespType::error(input.into())?;
        assert_expected_encode(resp, "-error\r\n");
        Ok(())
    }

    #[test]
    fn error_rejects_line_breaks() {
        let err = RespType::error("bad\rvalue".into()).unwrap_err();
        assert_eq!(err.to_string(), "Error type contains \\r or \\n");
        assert!(RespType::error("bad\nvalue".into()).is_err());
    }

    #[test]
    fn integer() {
        let resp = RespType::integer(42);
        assert_expected_encode(resp, ":42\r\n");
    }

    #[test]
    fn negative_integer() {
        let resp = RespType::integer(-7);
        assert_expected_encode(resp, ":-7\r\n");
    }

    #[test]
    fn bulk_string() {
        let input = "test";
        let resp = RespType::bulk_string(input.into());
        assert_expected_encode(resp, "$4\r\ntest\r\n");
    }

    #[test]
    fn empty_bulk_string() {
        let resp = RespType::bulk_string("".into());
        assert_expected_encode(resp, "$0\r\n\r\n");
    }

    #[test]
    fn null_bulk_string() {
        let resp = RespType::Null;
        assert_expected_encode(resp, "$-1\r\n");
    }

    #[test]
    fn array() {
        let resp = RespType::Array(vec![RespType::SimpleString("test!".into())]);
        assert_expected_encode(resp, "*1\r\n+test!\r\n");
    }

    #[test]
    fn nested_array() {
        let resp = RespType::array(vec![
            RespType::array(vec![RespType::integer(1)]),
            RespType::Null,
        ]);
        assert_expected_encode(resp, "*2\r\n*1\r\n:1\r\n$-1\r\n");
    }

    #[test]
    fn empty_array() {
        let resp = RespType::Array(Vec::new());
        assert_expected_encode(resp, "*0\r\n");
    }

    #[test]
    fn null_array() {
        let resp = RespType::NullArray;
        assert_expected_encode(resp, "*-1\r\n");
    }

    #[test]
    fn command_encoding() {
        let resp = RespType::command(vec![b"GET".to_vec(), b"key".to_vec()]);
        assert_expected_encode(resp, "*2\r\n$3\r\nGET\r\n$3\r\nkey\r\n");
    }

    #[test]
    fn converts_to_string() -> Result<()> {
        let resp = RespType::array(vec![
            RespType::simple_string("ok".into())?,
            RespType::integer(2),
            RespType::bulk_string("bulk".into()),
            RespType::Null,
            RespType::NullArray,
        ]);
        let text: String = resp.try_into()?;
        assert_eq!(text, "[ok, 2, bulk, (null), [null]]");

        let nested: String = RespType::array(vec![RespType::array(vec![RespType::integer(1)])])
            .try_into()?;
        assert_eq!(nested, "[[1]]");
        Ok(())
    }

    #[test]
    fn invalid_utf8_bulk_string_fails_conversion() {
        let resp = RespType::bulk_string(vec![0xff, 0xfe]);
        let converted: Result<String> = resp.try_into();
        assert!(converted.is_err());
    }
}