1use bytes::{BufMut, Bytes, BytesMut};
11
12use crate::ParseError;
13
/// Largest element count accepted in a `*` array header.
///
/// Larger counts are rejected with `ParseError::BadLength` before any element
/// is parsed. This limit is not applied to `$` bulk-string lengths.
const MAX_COLLECTION_SIZE: usize = 10_000_000;
16
/// A single RESP2 protocol value.
///
/// Payloads are held as `Bytes`, so frames produced by the parser can share
/// the input buffer rather than owning copies of it.
#[derive(Debug, Clone, PartialEq)]
pub enum Frame {
    /// `+` line: status text up to, but not including, the terminating CRLF.
    SimpleString(Bytes),
    /// `-` line: error text up to, but not including, the terminating CRLF.
    Error(Bytes),
    /// `:` line: a signed 64-bit integer.
    Integer(i64),
    /// `$` length-prefixed, binary-safe string; `None` is the null bulk string (`$-1`).
    BulkString(Option<Bytes>),
    /// `*` array of nested frames; `None` is the null array (`*-1`).
    Array(Option<Vec<Frame>>),
}
31
32pub fn parse_frame(input: Bytes) -> Result<(Frame, Bytes), ParseError> {
52 let (frame, consumed) = parse_frame_inner(&input, 0)?;
53 Ok((frame, input.slice(consumed..)))
54}
55
56fn parse_frame_inner(input: &Bytes, pos: usize) -> Result<(Frame, usize), ParseError> {
59 let buf = input.as_ref();
60 if pos >= buf.len() {
61 return Err(ParseError::Incomplete);
62 }
63
64 let tag = buf[pos];
65
66 match tag {
67 b'+' => {
68 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
69 Ok((
70 Frame::SimpleString(input.slice(pos + 1..line_end)),
71 after_crlf,
72 ))
73 }
74 b'-' => {
75 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
76 Ok((Frame::Error(input.slice(pos + 1..line_end)), after_crlf))
77 }
78 b':' => {
79 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
80 let v = parse_i64(&buf[pos + 1..line_end])?;
81 Ok((Frame::Integer(v), after_crlf))
82 }
83 b'$' => {
84 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
85 let len_bytes = &buf[pos + 1..line_end];
86 if len_bytes == b"-1" {
88 return Ok((Frame::BulkString(None), after_crlf));
89 }
90 let len = parse_usize(len_bytes)?;
91 if len == 0 {
92 if after_crlf + 1 >= buf.len() {
93 return Err(ParseError::Incomplete);
94 }
95 if buf[after_crlf] == b'\r' && buf[after_crlf + 1] == b'\n' {
96 return Ok((Frame::BulkString(Some(Bytes::new())), after_crlf + 2));
97 } else {
98 return Err(ParseError::InvalidFormat);
99 }
100 }
101 let data_start = after_crlf;
102 let data_end = data_start.checked_add(len).ok_or(ParseError::BadLength)?;
103 if data_end + 1 >= buf.len() || buf[data_end] != b'\r' || buf[data_end + 1] != b'\n' {
104 return Err(ParseError::Incomplete);
105 }
106 Ok((
107 Frame::BulkString(Some(input.slice(data_start..data_end))),
108 data_end + 2,
109 ))
110 }
111 b'*' => {
112 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
113 let len_bytes = &buf[pos + 1..line_end];
114 if len_bytes == b"-1" {
116 return Ok((Frame::Array(None), after_crlf));
117 }
118 let count = parse_usize(len_bytes)?;
119 if count > MAX_COLLECTION_SIZE {
120 return Err(ParseError::BadLength);
121 }
122 if count == 0 {
123 return Ok((Frame::Array(Some(Vec::new())), after_crlf));
124 }
125 let mut cursor = after_crlf;
126 let mut items = Vec::with_capacity(count);
127 for _ in 0..count {
128 let (item, next) = parse_frame_inner(input, cursor)?;
129 items.push(item);
130 cursor = next;
131 }
132 Ok((Frame::Array(Some(items)), cursor))
133 }
134 _ => Err(ParseError::InvalidTag(tag)),
135 }
136}
137
138pub fn frame_to_bytes(frame: &Frame) -> Bytes {
150 let mut buf = BytesMut::new();
151 serialize_frame(frame, &mut buf);
152 buf.freeze()
153}
154
155fn serialize_frame(frame: &Frame, buf: &mut BytesMut) {
156 match frame {
157 Frame::SimpleString(s) => {
158 buf.put_u8(b'+');
159 buf.extend_from_slice(s);
160 buf.extend_from_slice(b"\r\n");
161 }
162 Frame::Error(s) => {
163 buf.put_u8(b'-');
164 buf.extend_from_slice(s);
165 buf.extend_from_slice(b"\r\n");
166 }
167 Frame::Integer(i) => {
168 buf.put_u8(b':');
169 buf.extend_from_slice(i.to_string().as_bytes());
170 buf.extend_from_slice(b"\r\n");
171 }
172 Frame::BulkString(opt) => {
173 buf.put_u8(b'$');
174 match opt {
175 Some(data) => {
176 buf.extend_from_slice(data.len().to_string().as_bytes());
177 buf.extend_from_slice(b"\r\n");
178 buf.extend_from_slice(data);
179 buf.extend_from_slice(b"\r\n");
180 }
181 None => buf.extend_from_slice(b"-1\r\n"),
182 }
183 }
184 Frame::Array(opt) => {
185 buf.put_u8(b'*');
186 match opt {
187 Some(items) => {
188 buf.extend_from_slice(items.len().to_string().as_bytes());
189 buf.extend_from_slice(b"\r\n");
190 for item in items {
191 serialize_frame(item, buf);
192 }
193 }
194 None => buf.extend_from_slice(b"-1\r\n"),
195 }
196 }
197 }
198}
199
200#[derive(Default, Debug)]
219pub struct Parser {
220 buffer: BytesMut,
221}
222
223impl Parser {
224 pub fn new() -> Self {
226 Self {
227 buffer: BytesMut::new(),
228 }
229 }
230
231 pub fn feed(&mut self, data: Bytes) {
233 self.buffer.extend_from_slice(&data);
234 }
235
236 pub fn next_frame(&mut self) -> Result<Option<Frame>, ParseError> {
241 if self.buffer.is_empty() {
242 return Ok(None);
243 }
244
245 let bytes = self.buffer.split().freeze();
246
247 match parse_frame_inner(&bytes, 0) {
248 Ok((frame, consumed)) => {
249 if consumed < bytes.len() {
250 self.buffer.unsplit(BytesMut::from(&bytes[consumed..]));
251 }
252 Ok(Some(frame))
253 }
254 Err(ParseError::Incomplete) => {
255 self.buffer.unsplit(bytes.into());
256 Ok(None)
257 }
258 Err(e) => Err(e),
259 }
260 }
261
262 pub fn buffered_bytes(&self) -> usize {
264 self.buffer.len()
265 }
266
267 pub fn clear(&mut self) {
269 self.buffer.clear();
270 }
271}
272
273#[inline]
276fn find_crlf(buf: &[u8], from: usize) -> Result<(usize, usize), ParseError> {
277 let mut i = from;
278 let len = buf.len();
279 while i + 1 < len {
280 if buf[i] == b'\r' && buf[i + 1] == b'\n' {
281 return Ok((i, i + 2));
282 }
283 i += 1;
284 }
285 Err(ParseError::Incomplete)
286}
287
288#[inline]
290fn parse_usize(buf: &[u8]) -> Result<usize, ParseError> {
291 if buf.is_empty() {
292 return Err(ParseError::BadLength);
293 }
294 let mut v: usize = 0;
295 for &b in buf {
296 if !b.is_ascii_digit() {
297 return Err(ParseError::BadLength);
298 }
299 v = v.checked_mul(10).ok_or(ParseError::BadLength)?;
300 v = v
301 .checked_add((b - b'0') as usize)
302 .ok_or(ParseError::BadLength)?;
303 }
304 Ok(v)
305}
306
307#[inline]
309fn parse_i64(buf: &[u8]) -> Result<i64, ParseError> {
310 if buf.is_empty() {
311 return Err(ParseError::InvalidFormat);
312 }
313 let (neg, digits) = if buf[0] == b'-' {
314 (true, &buf[1..])
315 } else {
316 (false, buf)
317 };
318 if digits.is_empty() {
319 return Err(ParseError::InvalidFormat);
320 }
321 let mut v: i64 = 0;
322 for (i, &d) in digits.iter().enumerate() {
323 if !d.is_ascii_digit() {
324 return Err(ParseError::InvalidFormat);
325 }
326 let digit = (d - b'0') as i64;
327 if neg && v == i64::MAX / 10 && digit == 8 && i == digits.len() - 1 {
328 return Ok(i64::MIN);
329 }
330 if v > i64::MAX / 10 || (v == i64::MAX / 10 && digit > i64::MAX % 10) {
331 return Err(ParseError::Overflow);
332 }
333 v = v * 10 + digit;
334 }
335 if neg { Ok(-v) } else { Ok(v) }
336}
337
#[cfg(test)]
mod tests {
    // Unit tests for `parse_frame`, `frame_to_bytes` and the streaming `Parser`.
    // Inputs are literal RESP2 byte sequences and act as the protocol fixtures.
    use super::*;

    // A simple string ends at the first CRLF; following bytes are returned as `rest`.
    #[test]
    fn simple_string() {
        let (frame, rest) = parse_frame(Bytes::from("+OK\r\nrest")).unwrap();
        assert_eq!(frame, Frame::SimpleString(Bytes::from("OK")));
        assert_eq!(rest, Bytes::from("rest"));
    }

    // Error text keeps its internal spaces.
    #[test]
    fn error() {
        let (frame, _) = parse_frame(Bytes::from("-ERR fail\r\n")).unwrap();
        assert_eq!(frame, Frame::Error(Bytes::from("ERR fail")));
    }

    // Positive and negative integers.
    #[test]
    fn integer() {
        let (frame, _) = parse_frame(Bytes::from(":42\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(42));

        let (frame, _) = parse_frame(Bytes::from(":-123\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(-123));
    }

    // A bulk string consumes exactly its declared length plus the trailing CRLF.
    #[test]
    fn bulk_string() {
        let (frame, rest) = parse_frame(Bytes::from("$5\r\nhello\r\nX")).unwrap();
        assert_eq!(frame, Frame::BulkString(Some(Bytes::from("hello"))));
        assert_eq!(rest, Bytes::from("X"));
    }

    // `$-1` is the null bulk string.
    #[test]
    fn null_bulk_string() {
        let (frame, _) = parse_frame(Bytes::from("$-1\r\n")).unwrap();
        assert_eq!(frame, Frame::BulkString(None));
    }

    // `$0` still requires its (empty) payload's CRLF terminator.
    #[test]
    fn empty_bulk_string() {
        let (frame, rest) = parse_frame(Bytes::from("$0\r\n\r\nX")).unwrap();
        assert_eq!(frame, Frame::BulkString(Some(Bytes::new())));
        assert_eq!(rest, Bytes::from("X"));
    }

    #[test]
    fn array() {
        let input = Bytes::from("*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n");
        let (frame, _) = parse_frame(input).unwrap();
        assert_eq!(
            frame,
            Frame::Array(Some(vec![
                Frame::BulkString(Some(Bytes::from("foo"))),
                Frame::BulkString(Some(Bytes::from("bar"))),
            ]))
        );
    }

    // `*-1` is the null array, distinct from the empty array below.
    #[test]
    fn null_array() {
        let (frame, _) = parse_frame(Bytes::from("*-1\r\n")).unwrap();
        assert_eq!(frame, Frame::Array(None));
    }

    #[test]
    fn empty_array() {
        let (frame, _) = parse_frame(Bytes::from("*0\r\n")).unwrap();
        assert_eq!(frame, Frame::Array(Some(vec![])));
    }

    // Array elements may themselves be arrays; mixed element types are allowed.
    #[test]
    fn nested_array() {
        let input = Bytes::from("*2\r\n*1\r\n:1\r\n+OK\r\n");
        let (frame, _) = parse_frame(input).unwrap();
        assert_eq!(
            frame,
            Frame::Array(Some(vec![
                Frame::Array(Some(vec![Frame::Integer(1)])),
                Frame::SimpleString(Bytes::from("OK")),
            ]))
        );
    }

    // Truncated input (empty, missing `\n`, short bulk payload) reports
    // `Incomplete` rather than a format error.
    #[test]
    fn incomplete() {
        assert_eq!(parse_frame(Bytes::new()), Err(ParseError::Incomplete));
        assert_eq!(
            parse_frame(Bytes::from("+OK\r")),
            Err(ParseError::Incomplete)
        );
        assert_eq!(
            parse_frame(Bytes::from("$5\r\nhel")),
            Err(ParseError::Incomplete)
        );
    }

    // An unknown leading byte is reported back in the error.
    #[test]
    fn invalid_tag() {
        assert_eq!(
            parse_frame(Bytes::from("X\r\n")),
            Err(ParseError::InvalidTag(b'X'))
        );
    }

    // Serializing then parsing yields an equal frame with nothing left over.
    #[test]
    fn roundtrip() {
        let frames = vec![
            Frame::SimpleString(Bytes::from("OK")),
            Frame::Error(Bytes::from("ERR bad")),
            Frame::Integer(42),
            Frame::BulkString(Some(Bytes::from("hello"))),
            Frame::BulkString(None),
            Frame::Array(Some(vec![
                Frame::Integer(1),
                Frame::BulkString(Some(Bytes::from("two"))),
            ])),
            Frame::Array(None),
        ];
        for frame in &frames {
            let bytes = frame_to_bytes(frame);
            let (parsed, rest) = parse_frame(bytes).unwrap();
            assert_eq!(&parsed, frame);
            assert!(rest.is_empty());
        }
    }

    // The first feed ends mid-frame, so nothing is returned until the rest
    // arrives; the second feed completes one frame and carries a whole second one.
    #[test]
    fn streaming_parser() {
        let mut parser = Parser::new();
        parser.feed(Bytes::from("+HEL"));
        assert!(parser.next_frame().unwrap().is_none());

        parser.feed(Bytes::from("LO\r\n:42\r\n"));
        let f1 = parser.next_frame().unwrap().unwrap();
        assert_eq!(f1, Frame::SimpleString(Bytes::from("HELLO")));

        let f2 = parser.next_frame().unwrap().unwrap();
        assert_eq!(f2, Frame::Integer(42));

        assert!(parser.next_frame().unwrap().is_none());
    }

    // Back-to-back frames are peeled off one at a time via the returned remainder.
    #[test]
    fn chained_frames() {
        let input = Bytes::from("+OK\r\n:1\r\n$3\r\nfoo\r\n");
        let (f1, rest) = parse_frame(input).unwrap();
        assert_eq!(f1, Frame::SimpleString(Bytes::from("OK")));
        let (f2, rest) = parse_frame(rest).unwrap();
        assert_eq!(f2, Frame::Integer(1));
        let (f3, rest) = parse_frame(rest).unwrap();
        assert_eq!(f3, Frame::BulkString(Some(Bytes::from("foo"))));
        assert!(rest.is_empty());
    }

    // Bulk strings are length-delimited, so arbitrary bytes (NUL, 0xFF) survive.
    #[test]
    fn binary_bulk_string() {
        let mut data = Vec::new();
        data.extend_from_slice(b"$5\r\n");
        data.extend_from_slice(&[0x00, 0x01, 0xFF, 0xFE, 0x42]);
        data.extend_from_slice(b"\r\n");
        let (frame, _) = parse_frame(Bytes::from(data)).unwrap();
        match frame {
            Frame::BulkString(Some(b)) => {
                assert_eq!(b.as_ref(), &[0x00, 0x01, 0xFF, 0xFE, 0x42]);
            }
            _ => panic!("expected bulk string"),
        }
    }

    // RESP3-only tags (null `_`, double `,`, boolean `#`, big number `(`) are
    // not part of RESP2 and must be rejected.
    #[test]
    fn rejects_resp3_types() {
        assert!(parse_frame(Bytes::from("_\r\n")).is_err());
        assert!(parse_frame(Bytes::from(",3.14\r\n")).is_err());
        assert!(parse_frame(Bytes::from("#t\r\n")).is_err());
        assert!(parse_frame(Bytes::from("(123\r\n")).is_err());
    }
}