1use bytes::{BufMut, Bytes, BytesMut};
25
26use crate::ParseError;
27
/// Largest element count accepted in an array header.
const MAX_COLLECTION_SIZE: usize = 10_000_000;

/// Largest declared payload length accepted for a bulk string (512 MiB).
const MAX_BULK_STRING_SIZE: usize = 512 << 20;
33
34#[derive(Debug, Clone, PartialEq)]
36pub enum Frame {
37 SimpleString(Bytes),
39 Error(Bytes),
41 Integer(i64),
43 BulkString(Option<Bytes>),
45 Array(Option<Vec<Frame>>),
47}
48
49pub fn parse_frame(input: Bytes) -> Result<(Frame, Bytes), ParseError> {
69 let (frame, consumed) = parse_frame_inner(&input, 0)?;
70 Ok((frame, input.slice(consumed..)))
71}
72
73fn parse_frame_inner(input: &Bytes, pos: usize) -> Result<(Frame, usize), ParseError> {
76 let buf = input.as_ref();
77 if pos >= buf.len() {
78 return Err(ParseError::Incomplete);
79 }
80
81 let tag = buf[pos];
82
83 match tag {
84 b'+' => {
85 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
86 Ok((
87 Frame::SimpleString(input.slice(pos + 1..line_end)),
88 after_crlf,
89 ))
90 }
91 b'-' => {
92 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
93 Ok((Frame::Error(input.slice(pos + 1..line_end)), after_crlf))
94 }
95 b':' => {
96 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
97 let v = parse_i64(&buf[pos + 1..line_end])?;
98 Ok((Frame::Integer(v), after_crlf))
99 }
100 b'$' => {
101 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
102 let len_bytes = &buf[pos + 1..line_end];
103 if len_bytes == b"-1" {
105 return Ok((Frame::BulkString(None), after_crlf));
106 }
107 let len = parse_usize(len_bytes)?;
108 if len > MAX_BULK_STRING_SIZE {
109 return Err(ParseError::BadLength);
110 }
111 if len == 0 {
112 if after_crlf + 1 >= buf.len() {
113 return Err(ParseError::Incomplete);
114 }
115 if buf[after_crlf] == b'\r' && buf[after_crlf + 1] == b'\n' {
116 return Ok((Frame::BulkString(Some(Bytes::new())), after_crlf + 2));
117 } else {
118 return Err(ParseError::InvalidFormat);
119 }
120 }
121 let data_start = after_crlf;
122 let data_end = data_start.checked_add(len).ok_or(ParseError::BadLength)?;
123 if data_end + 1 >= buf.len() {
124 return Err(ParseError::Incomplete);
125 }
126 if buf[data_end] != b'\r' || buf[data_end + 1] != b'\n' {
127 return Err(ParseError::InvalidFormat);
128 }
129 Ok((
130 Frame::BulkString(Some(input.slice(data_start..data_end))),
131 data_end + 2,
132 ))
133 }
134 b'*' => {
135 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
136 let len_bytes = &buf[pos + 1..line_end];
137 if len_bytes == b"-1" {
139 return Ok((Frame::Array(None), after_crlf));
140 }
141 let count = parse_count(len_bytes)?;
142 if count == 0 {
143 return Ok((Frame::Array(Some(Vec::new())), after_crlf));
144 }
145 let mut cursor = after_crlf;
146 let mut items = Vec::with_capacity(count);
147 for _ in 0..count {
148 let (item, next) = parse_frame_inner(input, cursor)?;
149 items.push(item);
150 cursor = next;
151 }
152 Ok((Frame::Array(Some(items)), cursor))
153 }
154 _ => Err(ParseError::InvalidTag(tag)),
155 }
156}
157
158pub fn frame_to_bytes(frame: &Frame) -> Bytes {
170 let mut buf = BytesMut::new();
171 serialize_frame(frame, &mut buf);
172 buf.freeze()
173}
174
175fn serialize_frame(frame: &Frame, buf: &mut BytesMut) {
176 match frame {
177 Frame::SimpleString(s) => {
178 buf.put_u8(b'+');
179 buf.extend_from_slice(s);
180 buf.extend_from_slice(b"\r\n");
181 }
182 Frame::Error(s) => {
183 buf.put_u8(b'-');
184 buf.extend_from_slice(s);
185 buf.extend_from_slice(b"\r\n");
186 }
187 Frame::Integer(i) => {
188 buf.put_u8(b':');
189 buf.extend_from_slice(i.to_string().as_bytes());
190 buf.extend_from_slice(b"\r\n");
191 }
192 Frame::BulkString(opt) => {
193 buf.put_u8(b'$');
194 match opt {
195 Some(data) => {
196 buf.extend_from_slice(data.len().to_string().as_bytes());
197 buf.extend_from_slice(b"\r\n");
198 buf.extend_from_slice(data);
199 buf.extend_from_slice(b"\r\n");
200 }
201 None => buf.extend_from_slice(b"-1\r\n"),
202 }
203 }
204 Frame::Array(opt) => {
205 buf.put_u8(b'*');
206 match opt {
207 Some(items) => {
208 buf.extend_from_slice(items.len().to_string().as_bytes());
209 buf.extend_from_slice(b"\r\n");
210 for item in items {
211 serialize_frame(item, buf);
212 }
213 }
214 None => buf.extend_from_slice(b"-1\r\n"),
215 }
216 }
217 }
218}
219
220#[derive(Default, Debug)]
239pub struct Parser {
240 buffer: BytesMut,
241}
242
243impl Parser {
244 pub fn new() -> Self {
246 Self {
247 buffer: BytesMut::new(),
248 }
249 }
250
251 pub fn feed(&mut self, data: Bytes) {
253 self.buffer.extend_from_slice(&data);
254 }
255
256 pub fn next_frame(&mut self) -> Result<Option<Frame>, ParseError> {
261 if self.buffer.is_empty() {
262 return Ok(None);
263 }
264
265 let bytes = self.buffer.split().freeze();
266
267 match parse_frame_inner(&bytes, 0) {
268 Ok((frame, consumed)) => {
269 if consumed < bytes.len() {
270 self.buffer.unsplit(BytesMut::from(&bytes[consumed..]));
271 }
272 Ok(Some(frame))
273 }
274 Err(ParseError::Incomplete) => {
275 self.buffer.unsplit(bytes.into());
276 Ok(None)
277 }
278 Err(e) => {
279 Err(e)
282 }
283 }
284 }
285
286 pub fn buffered_bytes(&self) -> usize {
288 self.buffer.len()
289 }
290
291 pub fn clear(&mut self) {
293 self.buffer.clear();
294 }
295}
296
297#[inline]
300fn find_crlf(buf: &[u8], from: usize) -> Result<(usize, usize), ParseError> {
301 let mut i = from;
302 let len = buf.len();
303 while i + 1 < len {
304 if buf[i] == b'\r' && buf[i + 1] == b'\n' {
305 return Ok((i, i + 2));
306 }
307 i += 1;
308 }
309 Err(ParseError::Incomplete)
310}
311
312#[inline]
314fn parse_usize(buf: &[u8]) -> Result<usize, ParseError> {
315 if buf.is_empty() {
316 return Err(ParseError::BadLength);
317 }
318 let mut v: usize = 0;
319 for &b in buf {
320 if !b.is_ascii_digit() {
321 return Err(ParseError::BadLength);
322 }
323 v = v.checked_mul(10).ok_or(ParseError::BadLength)?;
324 v = v
325 .checked_add((b - b'0') as usize)
326 .ok_or(ParseError::BadLength)?;
327 }
328 Ok(v)
329}
330
331#[inline]
333fn parse_count(buf: &[u8]) -> Result<usize, ParseError> {
334 let count = parse_usize(buf)?;
335 if count > MAX_COLLECTION_SIZE {
336 return Err(ParseError::BadLength);
337 }
338 Ok(count)
339}
340
341#[inline]
343fn parse_i64(buf: &[u8]) -> Result<i64, ParseError> {
344 if buf.is_empty() {
345 return Err(ParseError::InvalidFormat);
346 }
347 let (neg, digits) = if buf[0] == b'-' {
348 (true, &buf[1..])
349 } else {
350 (false, buf)
351 };
352 if digits.is_empty() {
353 return Err(ParseError::InvalidFormat);
354 }
355 let mut v: i64 = 0;
356 for (i, &d) in digits.iter().enumerate() {
357 if !d.is_ascii_digit() {
358 return Err(ParseError::InvalidFormat);
359 }
360 let digit = (d - b'0') as i64;
361 if neg && v == i64::MAX / 10 && digit == 8 && i == digits.len() - 1 {
362 return Ok(i64::MIN);
363 }
364 if v > i64::MAX / 10 || (v == i64::MAX / 10 && digit > i64::MAX % 10) {
365 return Err(ParseError::Overflow);
366 }
367 v = v * 10 + digit;
368 }
369 if neg { Ok(-v) } else { Ok(v) }
370}
371
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `parse_frame` over a string literal.
    fn parse(wire: &'static str) -> Result<(Frame, Bytes), ParseError> {
        parse_frame(Bytes::from(wire))
    }

    /// Builds a simple-string frame from a literal.
    fn simple(text: &'static str) -> Frame {
        Frame::SimpleString(Bytes::from(text))
    }

    /// Builds a non-null bulk-string frame from a literal.
    fn bulk(payload: &'static str) -> Frame {
        Frame::BulkString(Some(Bytes::from(payload)))
    }

    #[test]
    fn simple_string() {
        let (frame, rest) = parse("+OK\r\nrest").unwrap();
        assert_eq!(frame, simple("OK"));
        assert_eq!(rest, Bytes::from("rest"));
    }

    #[test]
    fn error() {
        let (frame, _) = parse("-ERR fail\r\n").unwrap();
        assert_eq!(frame, Frame::Error(Bytes::from("ERR fail")));
    }

    #[test]
    fn integer() {
        for &(wire, expected) in &[(":42\r\n", 42), (":-123\r\n", -123)] {
            let (frame, _) = parse(wire).unwrap();
            assert_eq!(frame, Frame::Integer(expected));
        }
    }

    #[test]
    fn bulk_string() {
        let (frame, rest) = parse("$5\r\nhello\r\nX").unwrap();
        assert_eq!(frame, bulk("hello"));
        assert_eq!(rest, Bytes::from("X"));
    }

    #[test]
    fn null_bulk_string() {
        let (frame, _) = parse("$-1\r\n").unwrap();
        assert_eq!(frame, Frame::BulkString(None));
    }

    #[test]
    fn empty_bulk_string() {
        let (frame, rest) = parse("$0\r\n\r\nX").unwrap();
        assert_eq!(frame, Frame::BulkString(Some(Bytes::new())));
        assert_eq!(rest, Bytes::from("X"));
    }

    #[test]
    fn array() {
        let (frame, _) = parse("*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n").unwrap();
        let expected = Frame::Array(Some(vec![bulk("foo"), bulk("bar")]));
        assert_eq!(frame, expected);
    }

    #[test]
    fn null_array() {
        let (frame, _) = parse("*-1\r\n").unwrap();
        assert_eq!(frame, Frame::Array(None));
    }

    #[test]
    fn empty_array() {
        let (frame, _) = parse("*0\r\n").unwrap();
        assert_eq!(frame, Frame::Array(Some(vec![])));
    }

    #[test]
    fn nested_array() {
        let (frame, _) = parse("*2\r\n*1\r\n:1\r\n+OK\r\n").unwrap();
        let inner = Frame::Array(Some(vec![Frame::Integer(1)]));
        let expected = Frame::Array(Some(vec![inner, simple("OK")]));
        assert_eq!(frame, expected);
    }

    #[test]
    fn incomplete() {
        assert_eq!(parse_frame(Bytes::new()), Err(ParseError::Incomplete));
        for &wire in &["+OK\r", "$5\r\nhel"] {
            assert_eq!(parse(wire), Err(ParseError::Incomplete));
        }
    }

    #[test]
    fn invalid_tag() {
        assert_eq!(parse("X\r\n"), Err(ParseError::InvalidTag(b'X')));
    }

    #[test]
    fn roundtrip() {
        let frames = vec![
            simple("OK"),
            Frame::Error(Bytes::from("ERR bad")),
            Frame::Integer(42),
            bulk("hello"),
            Frame::BulkString(None),
            Frame::Array(Some(vec![Frame::Integer(1), bulk("two")])),
            Frame::Array(None),
        ];
        for original in frames.iter() {
            let wire = frame_to_bytes(original);
            let (decoded, rest) = parse_frame(wire).unwrap();
            assert_eq!(&decoded, original);
            assert!(rest.is_empty());
        }
    }

    #[test]
    fn streaming_parser() {
        let mut parser = Parser::new();

        // A partial frame yields nothing yet.
        parser.feed(Bytes::from("+HEL"));
        assert!(parser.next_frame().unwrap().is_none());

        // The rest of the first frame arrives together with a second one.
        parser.feed(Bytes::from("LO\r\n:42\r\n"));
        let first = parser.next_frame().unwrap().unwrap();
        assert_eq!(first, simple("HELLO"));

        let second = parser.next_frame().unwrap().unwrap();
        assert_eq!(second, Frame::Integer(42));

        assert!(parser.next_frame().unwrap().is_none());
    }

    #[test]
    fn chained_frames() {
        let (first, rest) = parse("+OK\r\n:1\r\n$3\r\nfoo\r\n").unwrap();
        assert_eq!(first, simple("OK"));

        let (second, rest) = parse_frame(rest).unwrap();
        assert_eq!(second, Frame::Integer(1));

        let (third, rest) = parse_frame(rest).unwrap();
        assert_eq!(third, bulk("foo"));
        assert!(rest.is_empty());
    }

    #[test]
    fn binary_bulk_string() {
        let payload = [0x00u8, 0x01, 0xFF, 0xFE, 0x42];
        let wire = [&b"$5\r\n"[..], &payload[..], &b"\r\n"[..]].concat();
        let (frame, _) = parse_frame(Bytes::from(wire)).unwrap();
        match frame {
            Frame::BulkString(Some(bytes)) => assert_eq!(&bytes[..], &payload[..]),
            _ => panic!("expected bulk string"),
        }
    }

    #[test]
    fn rejects_resp3_types() {
        for &wire in &["_\r\n", ",3.14\r\n", "#t\r\n", "(123\r\n"] {
            assert!(parse(wire).is_err());
        }
    }

    #[test]
    fn integer_overflow() {
        // One past i64::MAX.
        assert_eq!(
            parse(":9223372036854775808\r\n"),
            Err(ParseError::Overflow)
        );

        let (max, _) = parse(":9223372036854775807\r\n").unwrap();
        assert_eq!(max, Frame::Integer(i64::MAX));

        let (min, _) = parse(":-9223372036854775808\r\n").unwrap();
        assert_eq!(min, Frame::Integer(i64::MIN));

        // One below i64::MIN.
        assert!(parse(":-9223372036854775809\r\n").is_err());
    }

    #[test]
    fn zero_length_bulk_edge_cases() {
        let cases = [
            ("$0\r\n", ParseError::Incomplete),
            ("$0\r\n\r", ParseError::Incomplete),
            ("$0\r\nXY", ParseError::InvalidFormat),
        ];
        for (wire, expected) in cases.iter().cloned() {
            assert_eq!(parse(wire), Err(expected));
        }
    }

    #[test]
    fn nonempty_bulk_malformed_terminator() {
        let cases = [
            ("$3\r\nfoo", ParseError::Incomplete),
            ("$3\r\nfooX", ParseError::Incomplete),
            ("$3\r\nfooXY", ParseError::InvalidFormat),
        ];
        for (wire, expected) in cases.iter().cloned() {
            assert_eq!(parse(wire), Err(expected));
        }
    }

    #[test]
    fn array_size_limit() {
        // One element over the limit is rejected outright.
        assert_eq!(parse("*10000001\r\n"), Err(ParseError::BadLength));

        // Exactly at the limit is accepted, but the elements are missing.
        assert_eq!(parse("*10000000\r\n"), Err(ParseError::Incomplete));
    }

    #[test]
    fn bulk_string_size_limit() {
        assert_eq!(parse("$536870913\r\n"), Err(ParseError::BadLength));
    }

    #[test]
    fn streaming_parser_clears_buffer_on_error() {
        let mut parser = Parser::new();
        parser.feed(Bytes::from("X\r\n"));

        let outcome = parser.next_frame();
        assert_eq!(outcome, Err(ParseError::InvalidTag(b'X')));
        assert_eq!(parser.buffered_bytes(), 0);
    }

    #[test]
    fn streaming_parser_recovers_after_error() {
        let mut parser = Parser::new();

        parser.feed(Bytes::from("X\r\n"));
        assert!(parser.next_frame().is_err());
        assert_eq!(parser.buffered_bytes(), 0);

        parser.feed(Bytes::from("+OK\r\n"));
        let frame = parser.next_frame().unwrap().unwrap();
        assert_eq!(frame, simple("OK"));
    }
}