/// Outcome of parsing one RESP value: on success, the decoded value paired
/// with the bytes of the input that were not consumed.
pub type Result<'a> = std::result::Result<(RESP<'a>, &'a [u8]), RError>;
2
/// Carriage-return byte: first half of the CRLF line terminator used by RESP.
const CR: u8 = b'\r';
/// Line-feed byte: second half of the CRLF line terminator used by RESP.
const LF: u8 = b'\n';
/// Length in bytes of the null marker `-1\r\n` that follows a `$` type byte.
const NIL_VALUE_SIZE: usize = b"-1\r\n".len();
6
/// Stateless entry point for decoding RESP (REdis Serialization Protocol) data.
///
/// Parsing is exposed through associated functions; the type itself carries
/// no state, so deriving the standard value traits is free.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RedisProtocolParser;
8
/// A single decoded RESP value.
///
/// Every byte payload is a slice borrowed from the buffer that was parsed;
/// `Clone` is therefore cheap apart from the `Vec` backing an `Array`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum RESP<'a> {
    /// Simple string (`+`): the bytes between the marker and the CRLF.
    String(&'a [u8]),
    /// Error reply (`-`): the bytes between the marker and the CRLF.
    Error(&'a [u8]),
    /// Integer reply (`:`): the raw, unparsed bytes between the marker and the CRLF.
    Integer(&'a [u8]),
    /// Bulk string (`$`): exactly the declared number of payload bytes.
    BulkString(&'a [u8]),
    /// Null value, e.g. the null bulk string `$-1\r\n`.
    Nil,
    /// Array (`*`): the decoded elements, in order.
    Array(Vec<RESP<'a>>),
}
18
/// Category of failure reported by [`RError`].
#[derive(Debug)]
pub enum RErrorType {
    /// The first byte of a value is not one of the RESP type markers
    /// (`+`, `:`, `$`, `*`, `-`).
    UnknownSymbol,
    /// A value was expected but no input bytes remained.
    EmptyInput,
    /// A line that must be terminated by CRLF contains no CRLF.
    NoCrlf,
    /// A bulk string payload is not followed by CRLF at its declared length.
    IncorrectFormat,
    /// A lower-level error, such as a length prefix that is not valid UTF-8
    /// or not a valid unsigned integer.
    ///
    /// Bounded by `Send + Sync` so that the error can be moved across threads
    /// and converted into `Box<dyn std::error::Error + Send + Sync>`.
    Other(Box<dyn std::error::Error + Send + Sync>),
}
31
/// Error produced when RESP input cannot be decoded; carried in the `Err`
/// side of this module's `Result` alias.
#[derive(Debug)]
pub struct RError {
    // Category of the failure; private, so only code in this module reads it directly.
    err_type: RErrorType,
}
36
37impl RError {
38 fn unknown_symbol() -> Self {
39 Self {
40 err_type: RErrorType::UnknownSymbol,
41 }
42 }
43
44 fn empty_input() -> Self {
45 Self {
46 err_type: RErrorType::EmptyInput,
47 }
48 }
49
50 fn no_crlf() -> Self {
51 Self {
52 err_type: RErrorType::NoCrlf,
53 }
54 }
55 fn incorrect_format() -> Self {
56 Self {
57 err_type: RErrorType::IncorrectFormat,
58 }
59 }
60}
61
62impl<'a> std::fmt::Display for RError {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 write!(f, "{}", self)
65 }
66}
67
68impl<'a> std::error::Error for RError {}
69
70impl<'a> From<std::str::Utf8Error> for RError {
71 fn from(from: std::str::Utf8Error) -> Self {
72 Self {
73 err_type: RErrorType::Other(Box::new(from)),
74 }
75 }
76}
77
78impl<'a> From<std::num::ParseIntError> for RError {
79 fn from(from: std::num::ParseIntError) -> Self {
80 Self {
81 err_type: RErrorType::Other(Box::new(from)),
82 }
83 }
84}
85
86impl RedisProtocolParser {
87 pub fn parse_resp(input: &[u8]) -> Result {
88 if let Some(first) = input.get(0) {
89 let first = *first as char;
90 let input = &input[1..];
91 let (resp, left) = match first {
92 '+' => RedisProtocolParser::parse_simple_string(input)?,
93 ':' => RedisProtocolParser::parse_integers(input)?,
94 '$' => RedisProtocolParser::parse_bulk_strings(input)?,
95 '*' => RedisProtocolParser::parse_arrays(input)?,
96 '-' => RedisProtocolParser::parse_errors(input)?,
97 _ => return Err(RError::unknown_symbol()),
98 };
99 Ok((resp, left))
100 } else {
101 Err(RError::empty_input())
102 }
103 }
104
105 fn parse_everything_until_crlf(input: &[u8]) -> std::result::Result<(&[u8], &[u8]), RError> {
106 for (index, (first, second)) in input.iter().zip(input.iter().skip(1)).enumerate() {
107 if first == &CR && second == &LF {
108 return Ok((&input[0..index], &input[index + 2..]));
109 }
110 }
111 Err(RError::no_crlf())
112 }
113
114 pub fn parse_simple_string(input: &[u8]) -> Result {
115 RedisProtocolParser::parse_everything_until_crlf(input).map(|(x, y)| (RESP::String(x), y))
116 }
117
118 pub fn parse_errors(input: &[u8]) -> Result {
119 RedisProtocolParser::parse_everything_until_crlf(input).map(|(x, y)| (RESP::Error(x), y))
120 }
121
122 pub fn parse_integers(input: &[u8]) -> Result {
123 RedisProtocolParser::parse_everything_until_crlf(input).map(|(x, y)| (RESP::Integer(x), y))
124 }
125
126 pub fn parse_bulk_strings(input: &[u8]) -> Result {
127 if RedisProtocolParser::check_null_value(input) {
129 Ok((RESP::Nil, &input[NIL_VALUE_SIZE..]))
130 } else {
131 let (size_str, input_after_size) =
132 RedisProtocolParser::parse_everything_until_crlf(input)?;
133 let size = std::str::from_utf8(size_str)?.parse::<u64>()? as usize;
134 if RedisProtocolParser::check_crlf_at_index(input_after_size, size) {
135 Ok((
136 RESP::BulkString(&input_after_size[..size]),
137 &input_after_size[size + 2..],
138 ))
139 } else {
140 Err(RError::incorrect_format())
141 }
142 }
143 }
144
145 fn check_crlf_at_index(input: &[u8], index: usize) -> bool {
146 input.len() >= index && input[index] == b'\r' && input[index + 1] == b'\n'
147 }
148
149 fn check_null_value(input: &[u8]) -> bool {
150 input.len() >= 4 && input[0] == b'-' && input[1] == b'1' && input[2] == CR && input[3] == LF
151 }
152
153 pub fn parse_arrays(input: &[u8]) -> Result {
154 let (size_str, input) = RedisProtocolParser::parse_everything_until_crlf(input)?;
155 let size = std::str::from_utf8(size_str)?.parse::<u64>()?;
156 let sizes = size as usize;
157 let mut left = input;
158 let mut result = Vec::with_capacity(sizes);
159 for _ in 0..sizes {
160 let (element, tmp) = RedisProtocolParser::parse_resp(left)?;
161 result.push(element);
162 left = tmp;
163 }
164 Ok((RESP::Array(result), left))
165 }
166}
167
#[cfg(test)]
mod test {
    use super::*;

    /// Parses `input`, checks that nothing is left over, and returns the
    /// single decoded value.
    fn parse_all(input: &[u8]) -> std::result::Result<RESP<'_>, RError> {
        let (value, rest) = RedisProtocolParser::parse_resp(input)?;
        assert!(rest.is_empty());
        Ok(value)
    }

    /// Parses `input`, which must be rejected, and returns the error category.
    fn parse_failure(input: &[u8]) -> RErrorType {
        RedisProtocolParser::parse_resp(input).unwrap_err().err_type
    }

    #[test]
    pub fn test_simple_string() -> std::result::Result<(), RError> {
        assert_eq!(parse_all(b"+hello\r\n")?, RESP::String(b"hello"));
        Ok(())
    }

    #[test]
    pub fn test_errors() -> std::result::Result<(), RError> {
        assert!(matches!(parse_failure(b"+hello"), RErrorType::NoCrlf));
        assert!(matches!(
            parse_failure(b"*2\r\n$3\r\nfoo\r\n)hello"),
            RErrorType::UnknownSymbol
        ));
        assert!(matches!(parse_failure(b""), RErrorType::EmptyInput));
        assert!(matches!(
            parse_failure(b"$4\r\nfoo\r\n"),
            RErrorType::IncorrectFormat
        ));
        assert!(matches!(
            parse_failure(b"*2\r\n$3\r\nfoo+hello\r\n"),
            RErrorType::IncorrectFormat
        ));
        Ok(())
    }

    #[test]
    pub fn test_nil() -> std::result::Result<(), RError> {
        assert_eq!(parse_all(b"$-1\r\n")?, RESP::Nil);
        Ok(())
    }

    #[test]
    pub fn test_bulk_string() -> std::result::Result<(), RError> {
        assert_eq!(parse_all(b"$6\r\nfoobar\r\n")?, RESP::BulkString(b"foobar"));
        assert_eq!(parse_all(b"$0\r\n\r\n")?, RESP::BulkString(b""));
        Ok(())
    }

    #[test]
    pub fn test_arrays() -> std::result::Result<(), RError> {
        let two_bulks = parse_all(b"*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n")?;
        let expected_bulks = RESP::Array(vec![RESP::BulkString(b"foo"), RESP::BulkString(b"bar")]);
        assert_eq!(two_bulks, expected_bulks);

        let mixed = parse_all(b"*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n")?;
        let expected_mixed = RESP::Array(vec![
            RESP::Integer(b"1"),
            RESP::Integer(b"2"),
            RESP::Integer(b"3"),
            RESP::Integer(b"4"),
            RESP::BulkString(b"foobar"),
        ]);
        assert_eq!(mixed, expected_mixed);
        Ok(())
    }

    #[test]
    pub fn test_array_of_arrays() -> std::result::Result<(), RError> {
        let nested = parse_all(b"*2\r\n*3\r\n:1\r\n:2\r\n:3\r\n*2\r\n+Foo\r\n-Bar\r\n")?;
        let integers = RESP::Array(vec![
            RESP::Integer(b"1"),
            RESP::Integer(b"2"),
            RESP::Integer(b"3"),
        ]);
        let replies = RESP::Array(vec![RESP::String(b"Foo"), RESP::Error(b"Bar")]);
        assert_eq!(nested, RESP::Array(vec![integers, replies]));
        Ok(())
    }
}
272}