redis_protocol_parser/
lib.rs

1pub type Result<'a> = std::result::Result<(RESP<'a>, &'a [u8]), RError>;
2
/// Number of bytes in the null marker `-1\r\n` that is skipped once a null value is recognised.
const NIL_VALUE_SIZE: usize = 4;
/// Carriage return, the first byte of the CRLF terminator.
const CR: u8 = b'\r';
/// Line feed, the second byte of the CRLF terminator.
const LF: u8 = b'\n';
6
/// Stateless namespace type: the RESP parsing functions are associated functions of this unit struct.
pub struct RedisProtocolParser;
8
/// A parsed RESP value.
///
/// Byte payloads borrow from the input buffer, so a value cannot outlive the bytes it was
/// parsed from.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum RESP<'a> {
    /// Simple string (`+` type byte): the bytes before the terminating CRLF.
    String(&'a [u8]),
    /// Error (`-` type byte): the bytes before the terminating CRLF.
    Error(&'a [u8]),
    /// Integer (`:` type byte), kept as the raw bytes before the CRLF; they are not
    /// converted to a numeric type.
    Integer(&'a [u8]),
    /// Bulk string (`$` type byte): exactly the declared number of payload bytes.
    BulkString(&'a [u8]),
    /// Null value, written as a length of `-1`.
    Nil,
    /// Array (`*` type byte): the parsed elements in input order.
    Array(Vec<RESP<'a>>),
}
18
/// The kinds of failure a parse can report.
#[derive(Debug)]
pub enum RErrorType {
    /// The leading type byte is not one of the recognised RESP type bytes.
    UnknownSymbol,
    /// Attempting to parse an empty input.
    EmptyInput,
    /// No CRLF terminator could be found.
    NoCrlf,
    /// Incorrect format detected.
    IncorrectFormat,
    /// A failure raised by another error type, boxed as-is.
    Other(Box<dyn std::error::Error>),
}
31
32#[derive(Debug)]
33pub struct RError {
34    err_type: RErrorType,
35}
36
37impl RError {
38    fn unknown_symbol() -> Self {
39        Self {
40            err_type: RErrorType::UnknownSymbol,
41        }
42    }
43
44    fn empty_input() -> Self {
45        Self {
46            err_type: RErrorType::EmptyInput,
47        }
48    }
49
50    fn no_crlf() -> Self {
51        Self {
52            err_type: RErrorType::NoCrlf,
53        }
54    }
55    fn incorrect_format() -> Self {
56        Self {
57            err_type: RErrorType::IncorrectFormat,
58        }
59    }
60}
61
62impl<'a> std::fmt::Display for RError {
63    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64        write!(f, "{}", self)
65    }
66}
67
68impl<'a> std::error::Error for RError {}
69
70impl<'a> From<std::str::Utf8Error> for RError {
71    fn from(from: std::str::Utf8Error) -> Self {
72        Self {
73            err_type: RErrorType::Other(Box::new(from)),
74        }
75    }
76}
77
78impl<'a> From<std::num::ParseIntError> for RError {
79    fn from(from: std::num::ParseIntError) -> Self {
80        Self {
81            err_type: RErrorType::Other(Box::new(from)),
82        }
83    }
84}
85
86impl RedisProtocolParser {
87    pub fn parse_resp(input: &[u8]) -> Result {
88        if let Some(first) = input.get(0) {
89            let first = *first as char;
90            let input = &input[1..];
91            let (resp, left) = match first {
92                '+' => RedisProtocolParser::parse_simple_string(input)?,
93                ':' => RedisProtocolParser::parse_integers(input)?,
94                '$' => RedisProtocolParser::parse_bulk_strings(input)?,
95                '*' => RedisProtocolParser::parse_arrays(input)?,
96                '-' => RedisProtocolParser::parse_errors(input)?,
97                _ => return Err(RError::unknown_symbol()),
98            };
99            Ok((resp, left))
100        } else {
101            Err(RError::empty_input())
102        }
103    }
104
105    fn parse_everything_until_crlf(input: &[u8]) -> std::result::Result<(&[u8], &[u8]), RError> {
106        for (index, (first, second)) in input.iter().zip(input.iter().skip(1)).enumerate() {
107            if first == &CR && second == &LF {
108                return Ok((&input[0..index], &input[index + 2..]));
109            }
110        }
111        Err(RError::no_crlf())
112    }
113
114    pub fn parse_simple_string(input: &[u8]) -> Result {
115        RedisProtocolParser::parse_everything_until_crlf(input).map(|(x, y)| (RESP::String(x), y))
116    }
117
118    pub fn parse_errors(input: &[u8]) -> Result {
119        RedisProtocolParser::parse_everything_until_crlf(input).map(|(x, y)| (RESP::Error(x), y))
120    }
121
122    pub fn parse_integers(input: &[u8]) -> Result {
123        RedisProtocolParser::parse_everything_until_crlf(input).map(|(x, y)| (RESP::Integer(x), y))
124    }
125
126    pub fn parse_bulk_strings(input: &[u8]) -> Result {
127        // Check Null Strings.
128        if RedisProtocolParser::check_null_value(input) {
129            Ok((RESP::Nil, &input[NIL_VALUE_SIZE..]))
130        } else {
131            let (size_str, input_after_size) =
132                RedisProtocolParser::parse_everything_until_crlf(input)?;
133            let size = std::str::from_utf8(size_str)?.parse::<u64>()? as usize;
134            if RedisProtocolParser::check_crlf_at_index(input_after_size, size) {
135                Ok((
136                    RESP::BulkString(&input_after_size[..size]),
137                    &input_after_size[size + 2..],
138                ))
139            } else {
140                Err(RError::incorrect_format())
141            }
142        }
143    }
144
145    fn check_crlf_at_index(input: &[u8], index: usize) -> bool {
146        input.len() >= index && input[index] == b'\r' && input[index + 1] == b'\n'
147    }
148
149    fn check_null_value(input: &[u8]) -> bool {
150        input.len() >= 4 && input[0] == b'-' && input[1] == b'1' && input[2] == CR && input[3] == LF
151    }
152
153    pub fn parse_arrays(input: &[u8]) -> Result {
154        let (size_str, input) = RedisProtocolParser::parse_everything_until_crlf(input)?;
155        let size = std::str::from_utf8(size_str)?.parse::<u64>()?;
156        let sizes = size as usize;
157        let mut left = input;
158        let mut result = Vec::with_capacity(sizes);
159        for _ in 0..sizes {
160            let (element, tmp) = RedisProtocolParser::parse_resp(left)?;
161            result.push(element);
162            left = tmp;
163        }
164        Ok((RESP::Array(result), left))
165    }
166}
167
#[cfg(test)]
mod test {
    use super::*;

    /// A simple string is returned without its type byte or CRLF terminator.
    #[test]
    pub fn test_simple_string() -> std::result::Result<(), RError> {
        let input = b"+hello\r\n";
        let (resp, left) = RedisProtocolParser::parse_resp(input)?;
        assert_eq!(resp, RESP::String(b"hello"));
        assert!(left.is_empty());
        Ok(())
    }

    /// Each malformed input is rejected with the matching error kind.
    #[test]
    pub fn test_errors() -> std::result::Result<(), RError> {
        // Missing CRLF terminator.
        let input = b"+hello";
        let err = RedisProtocolParser::parse_resp(input).unwrap_err();
        assert!(matches!(err.err_type, RErrorType::NoCrlf));
        // Second array element starts with an unknown type byte.
        let input = b"*2\r\n$3\r\nfoo\r\n)hello";
        let err = RedisProtocolParser::parse_resp(input).unwrap_err();
        assert!(matches!(err.err_type, RErrorType::UnknownSymbol));
        // Nothing to parse at all.
        let input = b"";
        let err = RedisProtocolParser::parse_resp(input).unwrap_err();
        assert!(matches!(err.err_type, RErrorType::EmptyInput));
        // Declared length is longer than the payload.
        let input = b"$4\r\nfoo\r\n";
        let err = RedisProtocolParser::parse_resp(input).unwrap_err();
        assert!(matches!(err.err_type, RErrorType::IncorrectFormat));
        // Bulk string payload is not followed by CRLF.
        let input = b"*2\r\n$3\r\nfoo+hello\r\n";
        let err = RedisProtocolParser::parse_resp(input).unwrap_err();
        assert!(matches!(err.err_type, RErrorType::IncorrectFormat));
        Ok(())
    }

    /// A bulk string of length `-1` parses to `Nil`.
    #[test]
    pub fn test_nil() -> std::result::Result<(), RError> {
        let input = b"$-1\r\n";
        let (resp, left) = RedisProtocolParser::parse_resp(input)?;
        assert_eq!(resp, RESP::Nil);
        assert!(left.is_empty());
        Ok(())
    }

    /// Bulk strings yield exactly the declared number of bytes, including zero.
    #[test]
    pub fn test_bulk_string() -> std::result::Result<(), RError> {
        let input = b"$6\r\nfoobar\r\n";
        let (resp, left) = RedisProtocolParser::parse_resp(input)?;
        assert_eq!(resp, RESP::BulkString(b"foobar"));
        assert!(left.is_empty());
        let input = b"$0\r\n\r\n";
        let (resp, left) = RedisProtocolParser::parse_resp(input)?;
        assert_eq!(resp, RESP::BulkString(b""));
        assert!(left.is_empty());
        Ok(())
    }

    /// Arrays keep their elements in input order and may mix element types.
    #[test]
    pub fn test_arrays() -> std::result::Result<(), RError> {
        let input = b"*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n";
        let (resp, left) = RedisProtocolParser::parse_resp(input)?;
        assert_eq!(
            resp,
            RESP::Array(vec![RESP::BulkString(b"foo"), RESP::BulkString(b"bar")])
        );
        assert!(left.is_empty());
        let input = b"*5\r\n:1\r\n:2\r\n:3\r\n:4\r\n$6\r\nfoobar\r\n";
        let (resp, left) = RedisProtocolParser::parse_resp(input)?;
        assert_eq!(
            resp,
            RESP::Array(vec![
                RESP::Integer(b"1"),
                RESP::Integer(b"2"),
                RESP::Integer(b"3"),
                RESP::Integer(b"4"),
                RESP::BulkString(b"foobar"),
            ])
        );
        assert!(left.is_empty());
        Ok(())
    }

    /// Arrays nested inside an array are parsed recursively.
    #[test]
    pub fn test_array_of_arrays() -> std::result::Result<(), RError> {
        let input = b"*2\r\n*3\r\n:1\r\n:2\r\n:3\r\n*2\r\n+Foo\r\n-Bar\r\n";
        let (resp, left) = RedisProtocolParser::parse_resp(input)?;
        assert_eq!(
            resp,
            RESP::Array(vec![
                RESP::Array(vec![
                    RESP::Integer(b"1"),
                    RESP::Integer(b"2"),
                    RESP::Integer(b"3"),
                ]),
                RESP::Array(vec![RESP::String(b"Foo"), RESP::Error(b"Bar")]),
            ])
        );
        assert!(left.is_empty());
        Ok(())
    }
}