1use anyhow::Result;
2use thiserror::Error;
3
4use super::buffer::*;
5use super::RespConfig;
6use super::RespType;
7
8#[derive(Error, Debug)]
10pub enum ParserError {
11 #[error("Invalid RESP line read: {0}")]
13 ReadlineError(String),
14
15 #[error("Invalid RESP size: {0}")]
17 ReadsizeError(i64),
18
19 #[error("State error: {0}")]
22 StateError(String),
23
24 #[error("Invalid RESP type token: {0:#?}")]
26 TypeTokenError(char),
27
28 #[error("RESP size exceeded")]
30 SizeExceededError,
31}
32
33pub struct RespParser {
35 buffer: Vec<u8>,
36 state: Option<Box<State>>,
37 pub config: RespConfig,
39}
40
41#[derive(Debug)]
42enum State {
43 GetType {
44 cursor: usize,
45 },
46 Simple {
47 cursor: usize,
48 start: usize,
49 simple_type: SimpleType,
50 },
51 BulkString {
52 cursor: usize,
53 start: usize,
54 size: Option<usize>,
55 },
56 Array {
57 cursor: usize,
58 start: usize,
59 size: Option<usize>,
60 elements: Option<Vec<RespType>>,
61 substate: Option<Box<State>>,
62 },
63}
64
65impl State {
66 fn boxed(self) -> Box<State> {
67 Box::new(self)
68 }
69
70 fn get_type(cursor: usize) -> Box<State> {
71 Box::new(State::GetType { cursor })
72 }
73
74 fn get_simple(cursor: usize, simple_type: SimpleType) -> Box<State> {
75 Box::new(State::Simple {
76 cursor,
77 start: cursor,
78 simple_type,
79 })
80 }
81
82 fn get_bulk_string(cursor: usize) -> Box<State> {
83 Box::new(State::BulkString {
84 cursor,
85 start: cursor,
86 size: None,
87 })
88 }
89
90 fn get_array(cursor: usize) -> Box<State> {
91 Box::new(State::Array {
92 cursor,
93 start: cursor,
94 size: None,
95 elements: None,
96 substate: None,
97 })
98 }
99}
100
101#[derive(Debug)]
102enum StateResult {
103 Incomplete(Box<State>),
104 Done(RespType, usize),
105}
106
/// The single-line RESP kinds, selected by the type byte.
#[derive(Debug)]
enum SimpleType {
    /// Introduced by `+`.
    String,
    /// Introduced by `-`.
    Error,
    /// Introduced by `:`.
    Integer,
}
113
114impl Default for RespParser {
115 fn default() -> Self {
116 Self::new(RespConfig::default())
117 }
118}
119
120impl RespParser {
121 pub fn new(config: RespConfig) -> Self {
123 RespParser {
124 buffer: Vec::new(),
125 state: None,
126 config,
127 }
128 }
129
130 pub fn read(&mut self, buffer: &[u8]) -> Result<Vec<RespType>> {
132 for byte in buffer {
133 self.buffer.push(*byte);
134 }
135
136 if self.buffer.len() > self.config.max_buffer_size {
137 self.buffer.clear();
138 return Err(ParserError::SizeExceededError.into());
139 }
140
141 let mut items = Vec::new();
142 if let Some(state) = self.state.take() {
143 match self.process_state(state) {
144 Ok(result) => match result {
145 StateResult::Incomplete(state) => {
146 self.state = Some(state.boxed());
147 return Ok(items);
148 }
149 StateResult::Done(item, end) => {
150 self.buffer.drain(..end);
151 items.push(item)
152 }
153 },
154 Err(error) => {
155 self.buffer.clear();
156 return Err(error);
157 }
158 }
159 }
160
161 loop {
162 match self.get_next() {
163 Ok(result) => match result {
164 Some(item) => items.push(item),
165 None => return Ok(items),
166 },
167 Err(error) => {
168 self.buffer.clear();
169 return Err(error);
170 }
171 }
172 }
173 }
174
175 fn get_next(&mut self) -> Result<Option<RespType>> {
176 match self.get_type(State::get_type(0))? {
177 StateResult::Incomplete(state) => {
178 self.state = Some(state.boxed());
179 Ok(None)
180 }
181 StateResult::Done(item, end) => {
182 self.buffer.drain(..end);
183 Ok(Some(item))
184 }
185 }
186 }
187
188 fn process_state(&self, state: Box<State>) -> Result<StateResult> {
189 match *state {
190 State::GetType { .. } => self.get_type(state),
191 State::Simple { .. } => self.get_simple(state),
192 State::BulkString { .. } => self.get_bulk_string(state),
193 State::Array { .. } => self.get_array(state),
194 }
195 }
196
197 fn get_type(&self, state: Box<State>) -> Result<StateResult> {
198 if let State::GetType { cursor } = *state {
199 if self.buffer.len() <= cursor {
200 return Ok(StateResult::Incomplete(State::get_type(cursor)));
201 }
202
203 let next_cursor = cursor + 1;
204 let state = match &self.buffer[cursor] {
205 b'+' => State::get_simple(next_cursor, SimpleType::String),
206 b'-' => State::get_simple(next_cursor, SimpleType::Error),
207 b':' => State::get_simple(next_cursor, SimpleType::Integer),
208 b'$' => State::get_bulk_string(next_cursor),
209 b'*' => State::get_array(next_cursor),
210 other => return Err(ParserError::TypeTokenError(*other as char).into()),
211 };
212
213 if self.buffer.len() > cursor + 1 {
214 self.process_state(state)
215 } else {
216 Ok(StateResult::Incomplete(state.boxed()))
217 }
218 } else {
219 Err(ParserError::StateError(format!(
220 "get_type received wrong state type: {:#?}",
221 state
222 ))
223 .into())
224 }
225 }
226
227 fn get_simple(&self, state: Box<State>) -> Result<StateResult> {
228 if let State::Simple {
229 cursor,
230 start,
231 simple_type,
232 } = *state
233 {
234 match readline(&self.buffer, cursor, start)? {
235 ReadlineResult::Line { line, cursor } => {
236 if line.len() > self.config.max_resp_size {
237 return Err(ParserError::SizeExceededError.into());
238 }
239 let result = match simple_type {
240 SimpleType::String => RespType::SimpleString(line),
241 SimpleType::Error => RespType::Error(line),
242 SimpleType::Integer => RespType::Integer(line.parse()?),
243 };
244 Ok(StateResult::Done(result, cursor))
245 }
246 ReadlineResult::None { cursor } => Ok(StateResult::Incomplete(
247 State::Simple {
248 cursor,
249 start,
250 simple_type,
251 }
252 .boxed(),
253 )),
254 }
255 } else {
256 Err(ParserError::StateError(format!(
257 "get_simple received wrong state type: {:#?}",
258 state
259 ))
260 .into())
261 }
262 }
263
264 fn get_bulk_string(&self, state: Box<State>) -> Result<StateResult> {
265 if let State::BulkString {
266 cursor,
267 start,
268 size: string_length,
269 } = *state
270 {
271 let (cursor, size) = match string_length {
272 None => match readsize(&self.buffer, cursor, start)? {
273 ReadsizeResult::None(cursor) => {
274 let state = State::BulkString {
275 cursor,
276 start,
277 size: None,
278 };
279 return Ok(StateResult::Incomplete(state.boxed()));
280 }
281 ReadsizeResult::Null(cursor) => {
282 let result = RespType::Null;
283 return Ok(StateResult::Done(result, cursor));
284 }
285 ReadsizeResult::Size { end, size } => (end, size),
286 },
287 Some(size) => (cursor, size),
288 };
289 if size > self.config.max_resp_size {
290 return Err(ParserError::SizeExceededError.into());
291 }
292
293 match readbuffer(&self.buffer, cursor, size) {
294 Some((vector, end)) => {
295 let result = RespType::BulkString(vector);
296 Ok(StateResult::Done(result, end))
297 }
298 None => {
299 let state = State::BulkString {
300 cursor,
301 start,
302 size: Some(size),
303 }
304 .boxed();
305 Ok(StateResult::Incomplete(state))
306 }
307 }
308 } else {
309 Err(ParserError::StateError(format!(
310 "get_bulk_string received wrong state type: {:#?}",
311 state
312 ))
313 .into())
314 }
315 }
316
317 fn get_array(&self, state: Box<State>) -> Result<StateResult> {
318 if let State::Array {
319 cursor,
320 start,
321 size: array_size,
322 elements,
323 mut substate,
324 } = *state
325 {
326 let (cursor, size) = match array_size {
327 Some(size) => (cursor, size),
328 None => match readsize(&self.buffer, cursor, start)? {
329 ReadsizeResult::None(cursor) => {
330 let state = State::Array {
331 cursor,
332 start,
333 size: None,
334 elements: None,
335 substate: None,
336 };
337 return Ok(StateResult::Incomplete(state.boxed()));
338 }
339 ReadsizeResult::Null(cursor) => {
340 let result = RespType::NullArray;
341 return Ok(StateResult::Done(result, cursor));
342 }
343 ReadsizeResult::Size { end, size } => {
344 if size == 0 {
345 let result = RespType::Array(Vec::new());
346 return Ok(StateResult::Done(result, end));
347 } else {
348 (end, size)
349 }
350 }
351 },
352 };
353 if size > self.config.max_resp_size {
354 return Err(ParserError::SizeExceededError.into());
355 }
356
357 let mut elements = match elements {
358 Some(elements) => elements,
359 None => Vec::new(),
360 };
361 let mut cursor = cursor;
362 while elements.len() < size {
363 let state = match substate {
364 Some(_) => substate.take().unwrap(),
365 None => State::get_type(cursor),
366 };
367 match self.process_state(state)? {
368 StateResult::Done(result, end) => {
369 cursor = end;
370 elements.push(result);
371 }
372 StateResult::Incomplete(substate) => {
373 let state = State::Array {
374 cursor,
375 start,
376 size: Some(size),
377 elements: Some(elements),
378 substate: Some(substate),
379 };
380 return Ok(StateResult::Incomplete(state.boxed()));
381 }
382 }
383 }
384 let result = RespType::Array(elements);
385 Ok(StateResult::Done(result, cursor))
386 } else {
387 Err(ParserError::StateError(format!(
388 "get_array received wrong state type: {:#?}",
389 state
390 ))
391 .into())
392 }
393 }
394}
395
396#[cfg(test)]
397mod tests {
398 use super::*;
399 use RespType::*;
400
401 fn test_parser_ok<'a, T>(buffer: T) -> Vec<RespType>
402 where
403 &'a [u8]: From<T>,
404 {
405 let mut parser = RespParser::default();
406 match parser.read(buffer.into()) {
407 Ok(results) => results,
408 other => panic!("result was not Ok(), was {:#?}", other),
409 }
410 }
411
412 fn test_parser_err<'a, T>(buffer: T)
413 where
414 &'a [u8]: From<T>,
415 {
416 let mut parser = RespParser::default();
417 let result = parser.read(buffer.into());
418 assert!(result.is_err());
419 }
420
421 fn assert_empty_result(results: Vec<RespType>) {
422 let result_length = results.len();
423 assert_eq!(
424 result_length, 0,
425 "result was not empty, contained {} elements",
426 result_length
427 );
428 }
429
430 fn assert_num_results(results: &Vec<RespType>, expected: usize) {
431 let result_length = results.len();
432 assert_eq!(
433 result_length, expected,
434 "result was of unexpected length, contained {} elements, expected {}",
435 result_length, expected
436 )
437 }
438
/// Empty input yields no values and no error.
#[test]
fn empty_start() {
    assert_empty_result(test_parser_ok(b""));
}
445
/// A three-element array holding a null array, an array of bulk strings and
/// an array mixing every value kind, delivered in a single read.
#[test]
fn complex_nested() {
    let results = test_parser_ok(b"*3\r\n*-1\r\n*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n*5\r\n+test\r\n-test3\r\n:-12345\r\n$2\r\nab\r\n$-1\r\n");

    assert_num_results(&results, 1);

    let array = match &results[0] {
        Array(array) => array,
        _ => panic!("Array type expected"),
    };
    assert_eq!(array.len(), 3);
    assert_eq!(array[0], RespType::NullArray);

    let nested = match &array[1] {
        Array(nested) => nested,
        _ => panic!("Nested array at pos 1 expected"),
    };
    assert_eq!(nested.len(), 2);
    assert_eq!(nested[0], BulkString("hello".into()));
    assert_eq!(nested[1], BulkString("world".into()));

    let mixed = match &array[2] {
        Array(mixed) => mixed,
        _ => panic!("Mixed array at pos 2 expected"),
    };
    assert_eq!(mixed.len(), 5);
    assert_eq!(mixed[0], SimpleString("test".into()));
    assert_eq!(mixed[1], Error("test3".into()));
    assert_eq!(mixed[2], Integer(-12345));
    assert_eq!(mixed[3], BulkString("ab".into()));
    assert_eq!(mixed[4], Null);
}
475
/// Same message as `complex_nested`, but delivered one byte per read: no
/// value may appear before the final byte arrives.
#[test]
fn complex_nested_onebyte() -> Result<()> {
    let mut parser = RespParser::default();
    let all_but_last: &[u8] = b"*3\r\n*-1\r\n*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n*5\r\n+test\r\n-test3\r\n:-12345\r\n$2\r\nab\r\n$-1\r";
    for single in all_but_last.chunks(1) {
        let results = parser.read(single)?;
        assert_eq!(results.len(), 0);
    }

    let results = parser.read(b"\n")?;

    assert_num_results(&results, 1);

    let array = match &results[0] {
        Array(array) => array,
        _ => panic!("Array type expected"),
    };
    assert_eq!(array.len(), 3);
    assert_eq!(array[0], RespType::NullArray);

    let nested = match &array[1] {
        Array(nested) => nested,
        _ => panic!("Nested array at pos 1 expected"),
    };
    assert_eq!(nested.len(), 2);
    assert_eq!(nested[0], BulkString("hello".into()));
    assert_eq!(nested[1], BulkString("world".into()));

    let mixed = match &array[2] {
        Array(mixed) => mixed,
        _ => panic!("Mixed array at pos 2 expected"),
    };
    assert_eq!(mixed.len(), 5);
    assert_eq!(mixed[0], SimpleString("test".into()));
    assert_eq!(mixed[1], Error("test3".into()));
    assert_eq!(mixed[2], Integer(-12345));
    assert_eq!(mixed[3], BulkString("ab".into()));
    assert_eq!(mixed[4], Null);
    Ok(())
}
512
513 mod simple_string {
514 use super::*;
515
516 fn assert_simple_string(elements: &Vec<RespType>, index: usize, expected: &str) {
517 let element = &elements.get(index);
518 assert!(element.is_some());
519
520 match element.unwrap() {
521 SimpleString(string) => {
522 assert_eq!(string, expected);
523 }
524 _ => {
525 panic!("Expected SimpleString variant")
526 }
527 };
528 }
529
530 #[test]
531 fn valid() {
532 let results = test_parser_ok(b"+Valid!\r\n");
533
534 assert_num_results(&results, 1);
535 assert_simple_string(&results, 0, "Valid!");
536 }
537
538 #[test]
539 fn valid_remainder() {
540 let results = test_parser_ok(b"+valid and then some\r\n+");
541
542 assert_num_results(&results, 1);
543 assert_simple_string(&results, 0, "valid and then some");
544 }
545
546 #[test]
547 fn valid_incomplete() {
548 let results = test_parser_ok(b"+OK\r");
549
550 assert_empty_result(results);
551 }
552
553 #[test]
554 fn invalid_char_after_cr() {
555 test_parser_err(b"+OK\rx");
556 }
557
558 #[test]
559 fn invalid_newline() {
560 test_parser_err(b"+OK\n\r\n");
561 }
562 }
563
564 mod error {
565 use super::*;
566
567 fn assert_error(elements: &Vec<RespType>, index: usize, expected: &str) {
568 let element = &elements.get(index);
569 assert!(element.is_some());
570
571 match element.unwrap() {
572 Error(string) => {
573 assert_eq!(string, expected);
574 }
575 _ => {
576 panic!("Expected Error variant")
577 }
578 };
579 }
580
581 #[test]
582 fn valid() {
583 let results = test_parser_ok(b"-Valid!\r\n");
584
585 assert_num_results(&results, 1);
586 assert_error(&results, 0, "Valid!");
587 }
588
589 #[test]
590 fn remainder() {
591 let results = test_parser_ok(b"-Valid!\r\n:");
592
593 assert_num_results(&results, 1);
594 assert_error(&results, 0, "Valid!");
595 }
596
597 #[test]
598 fn two() {
599 let results = test_parser_ok(b"-Valid!\r\n-andmore\r\n");
600
601 assert_num_results(&results, 2);
602 assert_error(&results, 0, "Valid!");
603 assert_error(&results, 1, "andmore");
604 }
605 }
606
607 mod integer {
608 use super::*;
609
610 fn assert_integer(elements: &Vec<RespType>, index: usize, expected: i64) {
611 let element = &elements.get(index);
612 assert!(element.is_some());
613
614 match element.unwrap() {
615 Integer(int) => {
616 assert_eq!(*int, expected);
617 }
618 _ => {
619 panic!("Expected Integer variant")
620 }
621 };
622 }
623
624 #[test]
625 fn valid() {
626 let results = test_parser_ok(b":1234\r\n");
627
628 assert_num_results(&results, 1);
629 assert_integer(&results, 0, 1234);
630 }
631
632 #[test]
633 fn valid_negative() {
634 let results = test_parser_ok(b":-1234\r\n");
635
636 assert_num_results(&results, 1);
637 assert_integer(&results, 0, -1234);
638 }
639
640 #[test]
641 fn invalid() {
642 test_parser_err(b":hi\r\n");
643 }
644 }
645
646 mod bulk_string {
647 use super::*;
648
649 fn assert_bulk_string(results: &Vec<RespType>, index: usize, expected: &[u8]) {
650 let element = &results.get(index);
651 assert!(element.is_some());
652
653 match element.unwrap() {
654 BulkString(string) => {
655 assert_eq!(string, expected);
656 }
657 _ => {
658 panic!("Expected BulkString variant")
659 }
660 };
661 }
662
663 #[test]
664 fn valid() {
665 let results = test_parser_ok(b"$6\r\nValid!\r\n");
666
667 assert_num_results(&results, 1);
668 assert_bulk_string(&results, 0, "Valid!".as_bytes());
669 }
670
671 #[test]
672 fn two() {
673 let results = test_parser_ok(b"$6\r\nValid!\r\n$5\r\nwooo!\r\n");
674
675 assert_num_results(&results, 2);
676 assert_bulk_string(&results, 0, "Valid!".as_bytes());
677 assert_bulk_string(&results, 1, "wooo!".as_bytes());
678 }
679
680 #[test]
681 fn remainder() {
682 let results = test_parser_ok(b"$6\r\nValid!\r\n+OK");
683
684 assert_num_results(&results, 1);
685 assert_bulk_string(&results, 0, "Valid!".as_bytes());
686 }
687
688 #[test]
689 fn empty() {
690 let results = test_parser_ok(b"$0\r\n\r\n");
691
692 assert_num_results(&results, 1);
693 }
694
695 #[test]
696 fn null() {
697 let results = test_parser_ok(b"$-1\r\n");
698
699 assert_num_results(&results, 1);
700 }
701 }
702
703 mod array {
704 use super::*;
705
706 fn _assert_array_length(array: &RespType, length: usize) {
707 match array {
708 Array(array) => {
709 assert_eq!(array.len(), length);
710 }
711 _ => {
712 panic!("Expected Array variant")
713 }
714 };
715 }
716
717 #[test]
718 fn start() {
719 let results = test_parser_ok(b"*");
720 assert_empty_result(results);
721 }
722
723 #[test]
724 fn hello_world() {
725 let results = test_parser_ok(b"*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n");
726
727 assert_num_results(&results, 1);
728 }
729
730 #[test]
731 fn nested() {
732 let results = test_parser_ok(b"*1\r\n*3\r\n$5\r\nhello\r\n+ok\r\n*-1\r\n");
733 assert_num_results(&results, 1);
734 }
735
736 #[test]
737 fn null() -> Result<()> {
738 let results = test_parser_ok(b"*-1\r\n");
739
740 assert_num_results(&results, 1);
741 match results.first().unwrap() {
742 RespType::NullArray => {
743 }
745 _ => {
746 panic!("null array expected");
747 }
748 };
749 Ok(())
750 }
751
752 #[test]
753 fn empty() {
754 let results = test_parser_ok(b"*0\r\n");
755
756 assert_num_results(&results, 1);
757 }
758 }
759}