kresp/
parser.rs

1use anyhow::Result;
2use thiserror::Error;
3
4use super::buffer::*;
5use super::RespConfig;
6use super::RespType;
7
8/// Error enumeration used when a parsing error occurs
9#[derive(Error, Debug)]
10pub enum ParserError {
11    /// Error occured when reading a simple string that should end in \r\n
12    #[error("Invalid RESP line read: {0}")]
13    ReadlineError(String),
14
15    /// Error when an invalid size is read
16    #[error("Invalid RESP size: {0}")]
17    ReadsizeError(i64),
18
19    /// An internal state machine error, this should not happen, please report
20    /// if it does!
21    #[error("State error: {0}")]
22    StateError(String),
23
24    /// Next byte read was not a type token
25    #[error("Invalid RESP type token: {0:#?}")]
26    TypeTokenError(char),
27
28    /// Size limit hit
29    #[error("RESP size exceeded")]
30    SizeExceededError,
31}
32
33/// The parser itself, use [`RespParser::read`] to provide it buffers to parse
34pub struct RespParser {
35    buffer: Vec<u8>,
36    state: Option<Box<State>>,
37    /// Configuration structure for memory limits
38    pub config: RespConfig,
39}
40
41#[derive(Debug)]
42enum State {
43    GetType {
44        cursor: usize,
45    },
46    Simple {
47        cursor: usize,
48        start: usize,
49        simple_type: SimpleType,
50    },
51    BulkString {
52        cursor: usize,
53        start: usize,
54        size: Option<usize>,
55    },
56    Array {
57        cursor: usize,
58        start: usize,
59        size: Option<usize>,
60        elements: Option<Vec<RespType>>,
61        substate: Option<Box<State>>,
62    },
63}
64
65impl State {
66    fn boxed(self) -> Box<State> {
67        Box::new(self)
68    }
69
70    fn get_type(cursor: usize) -> Box<State> {
71        Box::new(State::GetType { cursor })
72    }
73
74    fn get_simple(cursor: usize, simple_type: SimpleType) -> Box<State> {
75        Box::new(State::Simple {
76            cursor,
77            start: cursor,
78            simple_type,
79        })
80    }
81
82    fn get_bulk_string(cursor: usize) -> Box<State> {
83        Box::new(State::BulkString {
84            cursor,
85            start: cursor,
86            size: None,
87        })
88    }
89
90    fn get_array(cursor: usize) -> Box<State> {
91        Box::new(State::Array {
92            cursor,
93            start: cursor,
94            size: None,
95            elements: None,
96            substate: None,
97        })
98    }
99}
100
101#[derive(Debug)]
102enum StateResult {
103    Incomplete(Box<State>),
104    Done(RespType, usize),
105}
106
/// The kinds of value that consist of a single `\r\n`-terminated line.
#[derive(Debug)]
enum SimpleType {
    /// Introduced by the `+` token.
    String,
    /// Introduced by the `-` token.
    Error,
    /// Introduced by the `:` token.
    Integer,
}
113
114impl Default for RespParser {
115    fn default() -> Self {
116        Self::new(RespConfig::default())
117    }
118}
119
120impl RespParser {
121    /// Creates a new instance, can use [`RespParser`.`default`] for common setups
122    pub fn new(config: RespConfig) -> Self {
123        RespParser {
124            buffer: Vec::new(),
125            state: None,
126            config,
127        }
128    }
129
130    /// Copy and parses the provided buffer, returns a list of [`RespType`] variant results
131    pub fn read(&mut self, buffer: &[u8]) -> Result<Vec<RespType>> {
132        for byte in buffer {
133            self.buffer.push(*byte);
134        }
135
136        if self.buffer.len() > self.config.max_buffer_size {
137            self.buffer.clear();
138            return Err(ParserError::SizeExceededError.into());
139        }
140
141        let mut items = Vec::new();
142        if let Some(state) = self.state.take() {
143            match self.process_state(state) {
144                Ok(result) => match result {
145                    StateResult::Incomplete(state) => {
146                        self.state = Some(state.boxed());
147                        return Ok(items);
148                    }
149                    StateResult::Done(item, end) => {
150                        self.buffer.drain(..end);
151                        items.push(item)
152                    }
153                },
154                Err(error) => {
155                    self.buffer.clear();
156                    return Err(error);
157                }
158            }
159        }
160
161        loop {
162            match self.get_next() {
163                Ok(result) => match result {
164                    Some(item) => items.push(item),
165                    None => return Ok(items),
166                },
167                Err(error) => {
168                    self.buffer.clear();
169                    return Err(error);
170                }
171            }
172        }
173    }
174
175    fn get_next(&mut self) -> Result<Option<RespType>> {
176        match self.get_type(State::get_type(0))? {
177            StateResult::Incomplete(state) => {
178                self.state = Some(state.boxed());
179                Ok(None)
180            }
181            StateResult::Done(item, end) => {
182                self.buffer.drain(..end);
183                Ok(Some(item))
184            }
185        }
186    }
187
188    fn process_state(&self, state: Box<State>) -> Result<StateResult> {
189        match *state {
190            State::GetType { .. } => self.get_type(state),
191            State::Simple { .. } => self.get_simple(state),
192            State::BulkString { .. } => self.get_bulk_string(state),
193            State::Array { .. } => self.get_array(state),
194        }
195    }
196
197    fn get_type(&self, state: Box<State>) -> Result<StateResult> {
198        if let State::GetType { cursor } = *state {
199            if self.buffer.len() <= cursor {
200                return Ok(StateResult::Incomplete(State::get_type(cursor)));
201            }
202
203            let next_cursor = cursor + 1;
204            let state = match &self.buffer[cursor] {
205                b'+' => State::get_simple(next_cursor, SimpleType::String),
206                b'-' => State::get_simple(next_cursor, SimpleType::Error),
207                b':' => State::get_simple(next_cursor, SimpleType::Integer),
208                b'$' => State::get_bulk_string(next_cursor),
209                b'*' => State::get_array(next_cursor),
210                other => return Err(ParserError::TypeTokenError(*other as char).into()),
211            };
212
213            if self.buffer.len() > cursor + 1 {
214                self.process_state(state)
215            } else {
216                Ok(StateResult::Incomplete(state.boxed()))
217            }
218        } else {
219            Err(ParserError::StateError(format!(
220                "get_type received wrong state type: {:#?}",
221                state
222            ))
223            .into())
224        }
225    }
226
227    fn get_simple(&self, state: Box<State>) -> Result<StateResult> {
228        if let State::Simple {
229            cursor,
230            start,
231            simple_type,
232        } = *state
233        {
234            match readline(&self.buffer, cursor, start)? {
235                ReadlineResult::Line { line, cursor } => {
236                    if line.len() > self.config.max_resp_size {
237                        return Err(ParserError::SizeExceededError.into());
238                    }
239                    let result = match simple_type {
240                        SimpleType::String => RespType::SimpleString(line),
241                        SimpleType::Error => RespType::Error(line),
242                        SimpleType::Integer => RespType::Integer(line.parse()?),
243                    };
244                    Ok(StateResult::Done(result, cursor))
245                }
246                ReadlineResult::None { cursor } => Ok(StateResult::Incomplete(
247                    State::Simple {
248                        cursor,
249                        start,
250                        simple_type,
251                    }
252                    .boxed(),
253                )),
254            }
255        } else {
256            Err(ParserError::StateError(format!(
257                "get_simple received wrong state type: {:#?}",
258                state
259            ))
260            .into())
261        }
262    }
263
264    fn get_bulk_string(&self, state: Box<State>) -> Result<StateResult> {
265        if let State::BulkString {
266            cursor,
267            start,
268            size: string_length,
269        } = *state
270        {
271            let (cursor, size) = match string_length {
272                None => match readsize(&self.buffer, cursor, start)? {
273                    ReadsizeResult::None(cursor) => {
274                        let state = State::BulkString {
275                            cursor,
276                            start,
277                            size: None,
278                        };
279                        return Ok(StateResult::Incomplete(state.boxed()));
280                    }
281                    ReadsizeResult::Null(cursor) => {
282                        let result = RespType::Null;
283                        return Ok(StateResult::Done(result, cursor));
284                    }
285                    ReadsizeResult::Size { end, size } => (end, size),
286                },
287                Some(size) => (cursor, size),
288            };
289            if size > self.config.max_resp_size {
290                return Err(ParserError::SizeExceededError.into());
291            }
292
293            match readbuffer(&self.buffer, cursor, size) {
294                Some((vector, end)) => {
295                    let result = RespType::BulkString(vector);
296                    Ok(StateResult::Done(result, end))
297                }
298                None => {
299                    let state = State::BulkString {
300                        cursor,
301                        start,
302                        size: Some(size),
303                    }
304                    .boxed();
305                    Ok(StateResult::Incomplete(state))
306                }
307            }
308        } else {
309            Err(ParserError::StateError(format!(
310                "get_bulk_string received wrong state type: {:#?}",
311                state
312            ))
313            .into())
314        }
315    }
316
317    fn get_array(&self, state: Box<State>) -> Result<StateResult> {
318        if let State::Array {
319            cursor,
320            start,
321            size: array_size,
322            elements,
323            mut substate,
324        } = *state
325        {
326            let (cursor, size) = match array_size {
327                Some(size) => (cursor, size),
328                None => match readsize(&self.buffer, cursor, start)? {
329                    ReadsizeResult::None(cursor) => {
330                        let state = State::Array {
331                            cursor,
332                            start,
333                            size: None,
334                            elements: None,
335                            substate: None,
336                        };
337                        return Ok(StateResult::Incomplete(state.boxed()));
338                    }
339                    ReadsizeResult::Null(cursor) => {
340                        let result = RespType::NullArray;
341                        return Ok(StateResult::Done(result, cursor));
342                    }
343                    ReadsizeResult::Size { end, size } => {
344                        if size == 0 {
345                            let result = RespType::Array(Vec::new());
346                            return Ok(StateResult::Done(result, end));
347                        } else {
348                            (end, size)
349                        }
350                    }
351                },
352            };
353            if size > self.config.max_resp_size {
354                return Err(ParserError::SizeExceededError.into());
355            }
356
357            let mut elements = match elements {
358                Some(elements) => elements,
359                None => Vec::new(),
360            };
361            let mut cursor = cursor;
362            while elements.len() < size {
363                let state = match substate {
364                    Some(_) => substate.take().unwrap(),
365                    None => State::get_type(cursor),
366                };
367                match self.process_state(state)? {
368                    StateResult::Done(result, end) => {
369                        cursor = end;
370                        elements.push(result);
371                    }
372                    StateResult::Incomplete(substate) => {
373                        let state = State::Array {
374                            cursor,
375                            start,
376                            size: Some(size),
377                            elements: Some(elements),
378                            substate: Some(substate),
379                        };
380                        return Ok(StateResult::Incomplete(state.boxed()));
381                    }
382                }
383            }
384            let result = RespType::Array(elements);
385            Ok(StateResult::Done(result, cursor))
386        } else {
387            Err(ParserError::StateError(format!(
388                "get_array received wrong state type: {:#?}",
389                state
390            ))
391            .into())
392        }
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;
    use RespType::*;

    /// Input with nested arrays that between them contain every value kind.
    const COMPLEX_NESTED: &[u8] = b"*3\r\n*-1\r\n*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n*5\r\n+test\r\n-test3\r\n:-12345\r\n$2\r\nab\r\n$-1\r\n";

    /// Parses `buffer` with a default parser and returns the values,
    /// panicking when the parser reports an error.
    fn test_parser_ok(buffer: &[u8]) -> Vec<RespType> {
        let mut parser = RespParser::default();
        match parser.read(buffer) {
            Ok(results) => results,
            other => panic!("result was not Ok(), was {:#?}", other),
        }
    }

    /// Asserts that a default parser rejects `buffer`.
    fn test_parser_err(buffer: &[u8]) {
        let mut parser = RespParser::default();
        let result = parser.read(buffer);
        assert!(result.is_err());
    }

    /// Asserts that no value was produced.
    fn assert_empty_result(results: &[RespType]) {
        let result_length = results.len();
        assert_eq!(
            result_length, 0,
            "result was not empty, contained {} elements",
            result_length
        );
    }

    /// Asserts that exactly `expected` values were produced.
    fn assert_num_results(results: &[RespType], expected: usize) {
        let result_length = results.len();
        assert_eq!(
            result_length, expected,
            "result was of unexpected length, contained {} elements, expected {}",
            result_length, expected
        )
    }

    /// Returns the elements of `value`, panicking with `message` when it is
    /// not an array.
    fn expect_array<'a>(value: &'a RespType, message: &str) -> &'a [RespType] {
        match value {
            Array(items) => items.as_slice(),
            _ => panic!("{}", message),
        }
    }

    /// Checks the single value expected from parsing [`COMPLEX_NESTED`].
    fn assert_complex_nested(results: &[RespType]) {
        assert_num_results(results, 1);

        let array = expect_array(&results[0], "Array type expected");
        assert_eq!(array.len(), 3);
        assert_eq!(array[0], RespType::NullArray);

        let nested = expect_array(&array[1], "Nested array at pos 1 expected");
        assert_eq!(nested.len(), 2);
        assert_eq!(nested[0], BulkString("hello".into()));
        assert_eq!(nested[1], BulkString("world".into()));

        let mixed = expect_array(&array[2], "Mixed array at pos 2 expected");
        assert_eq!(mixed.len(), 5);
        assert_eq!(mixed[0], SimpleString("test".into()));
        assert_eq!(mixed[1], Error("test3".into()));
        assert_eq!(mixed[2], Integer(-12345));
        assert_eq!(mixed[3], BulkString("ab".into()));
        assert_eq!(mixed[4], Null);
    }

    #[test]
    fn empty_start() {
        let results = test_parser_ok(b"");

        assert_empty_result(&results);
    }

    #[test]
    fn complex_nested() {
        let results = test_parser_ok(COMPLEX_NESTED);

        assert_complex_nested(&results);
    }

    #[test]
    fn complex_nested_onebyte() -> Result<()> {
        let mut parser = RespParser::default();
        let (last, head) = COMPLEX_NESTED
            .split_last()
            .expect("test input is not empty");

        // Nothing may be produced until the very last byte has arrived.
        for byte in head {
            let results = parser.read(&[*byte])?;
            assert_eq!(results.len(), 0);
        }

        let results = parser.read(&[*last])?;

        assert_complex_nested(&results);
        Ok(())
    }

    mod simple_string {
        use super::*;

        /// Asserts that `elements[index]` is a simple string equal to `expected`.
        fn assert_simple_string(elements: &[RespType], index: usize, expected: &str) {
            match elements.get(index) {
                Some(SimpleString(string)) => assert_eq!(string, expected),
                Some(_) => panic!("Expected SimpleString variant"),
                None => panic!("no element at index {}", index),
            }
        }

        #[test]
        fn valid() {
            let results = test_parser_ok(b"+Valid!\r\n");

            assert_num_results(&results, 1);
            assert_simple_string(&results, 0, "Valid!");
        }

        #[test]
        fn valid_remainder() {
            let results = test_parser_ok(b"+valid and then some\r\n+");

            assert_num_results(&results, 1);
            assert_simple_string(&results, 0, "valid and then some");
        }

        #[test]
        fn valid_incomplete() {
            let results = test_parser_ok(b"+OK\r");

            assert_empty_result(&results);
        }

        #[test]
        fn invalid_char_after_cr() {
            test_parser_err(b"+OK\rx");
        }

        #[test]
        fn invalid_newline() {
            test_parser_err(b"+OK\n\r\n");
        }
    }

    mod error {
        use super::*;

        /// Asserts that `elements[index]` is an error equal to `expected`.
        fn assert_error(elements: &[RespType], index: usize, expected: &str) {
            match elements.get(index) {
                Some(Error(string)) => assert_eq!(string, expected),
                Some(_) => panic!("Expected Error variant"),
                None => panic!("no element at index {}", index),
            }
        }

        #[test]
        fn valid() {
            let results = test_parser_ok(b"-Valid!\r\n");

            assert_num_results(&results, 1);
            assert_error(&results, 0, "Valid!");
        }

        #[test]
        fn remainder() {
            let results = test_parser_ok(b"-Valid!\r\n:");

            assert_num_results(&results, 1);
            assert_error(&results, 0, "Valid!");
        }

        #[test]
        fn two() {
            let results = test_parser_ok(b"-Valid!\r\n-andmore\r\n");

            assert_num_results(&results, 2);
            assert_error(&results, 0, "Valid!");
            assert_error(&results, 1, "andmore");
        }
    }

    mod integer {
        use super::*;

        /// Asserts that `elements[index]` is an integer equal to `expected`.
        fn assert_integer(elements: &[RespType], index: usize, expected: i64) {
            match elements.get(index) {
                Some(Integer(int)) => assert_eq!(*int, expected),
                Some(_) => panic!("Expected Integer variant"),
                None => panic!("no element at index {}", index),
            }
        }

        #[test]
        fn valid() {
            let results = test_parser_ok(b":1234\r\n");

            assert_num_results(&results, 1);
            assert_integer(&results, 0, 1234);
        }

        #[test]
        fn valid_negative() {
            let results = test_parser_ok(b":-1234\r\n");

            assert_num_results(&results, 1);
            assert_integer(&results, 0, -1234);
        }

        #[test]
        fn invalid() {
            test_parser_err(b":hi\r\n");
        }
    }

    mod bulk_string {
        use super::*;

        /// Asserts that `results[index]` is a bulk string equal to `expected`.
        fn assert_bulk_string(results: &[RespType], index: usize, expected: &[u8]) {
            match results.get(index) {
                Some(BulkString(string)) => assert_eq!(string, expected),
                Some(_) => panic!("Expected BulkString variant"),
                None => panic!("no element at index {}", index),
            }
        }

        #[test]
        fn valid() {
            let results = test_parser_ok(b"$6\r\nValid!\r\n");

            assert_num_results(&results, 1);
            assert_bulk_string(&results, 0, b"Valid!");
        }

        #[test]
        fn two() {
            let results = test_parser_ok(b"$6\r\nValid!\r\n$5\r\nwooo!\r\n");

            assert_num_results(&results, 2);
            assert_bulk_string(&results, 0, b"Valid!");
            assert_bulk_string(&results, 1, b"wooo!");
        }

        #[test]
        fn remainder() {
            let results = test_parser_ok(b"$6\r\nValid!\r\n+OK");

            assert_num_results(&results, 1);
            assert_bulk_string(&results, 0, b"Valid!");
        }

        #[test]
        fn empty() {
            let results = test_parser_ok(b"$0\r\n\r\n");

            assert_num_results(&results, 1);
            assert_bulk_string(&results, 0, b"");
        }

        #[test]
        fn null() {
            let results = test_parser_ok(b"$-1\r\n");

            assert_num_results(&results, 1);
            assert_eq!(results[0], Null);
        }
    }

    mod array {
        use super::*;

        /// Asserts that `array` is an array holding `length` elements.
        fn assert_array_length(array: &RespType, length: usize) {
            match array {
                Array(array) => assert_eq!(array.len(), length),
                _ => panic!("Expected Array variant"),
            }
        }

        #[test]
        fn start() {
            let results = test_parser_ok(b"*");
            assert_empty_result(&results);
        }

        #[test]
        fn hello_world() {
            let results = test_parser_ok(b"*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n");

            assert_num_results(&results, 1);
            assert_array_length(&results[0], 2);
        }

        #[test]
        fn nested() {
            let results = test_parser_ok(b"*1\r\n*3\r\n$5\r\nhello\r\n+ok\r\n*-1\r\n");

            assert_num_results(&results, 1);
            let outer = expect_array(&results[0], "Array type expected");
            assert_eq!(outer.len(), 1);
            assert_array_length(&outer[0], 3);
        }

        #[test]
        fn null() {
            let results = test_parser_ok(b"*-1\r\n");

            assert_num_results(&results, 1);
            assert_eq!(results[0], NullArray, "null array expected");
        }

        #[test]
        fn empty() {
            let results = test_parser_ok(b"*0\r\n");

            assert_num_results(&results, 1);
            assert_array_length(&results[0], 0);
        }
    }
}