1use bytes::{BufMut, Bytes, BytesMut};
11
12use crate::ParseError;
13
14#[derive(Debug, Clone, PartialEq)]
16pub enum Frame {
17 SimpleString(Bytes),
19 Error(Bytes),
21 Integer(i64),
23 BulkString(Option<Bytes>),
25 Array(Option<Vec<Frame>>),
27}
28
29pub fn parse_frame(input: Bytes) -> Result<(Frame, Bytes), ParseError> {
49 let (frame, consumed) = parse_frame_inner(&input, 0)?;
50 Ok((frame, input.slice(consumed..)))
51}
52
53fn parse_frame_inner(input: &Bytes, pos: usize) -> Result<(Frame, usize), ParseError> {
56 let buf = input.as_ref();
57 if pos >= buf.len() {
58 return Err(ParseError::Incomplete);
59 }
60
61 let tag = buf[pos];
62
63 match tag {
64 b'+' => {
65 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
66 Ok((
67 Frame::SimpleString(input.slice(pos + 1..line_end)),
68 after_crlf,
69 ))
70 }
71 b'-' => {
72 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
73 Ok((Frame::Error(input.slice(pos + 1..line_end)), after_crlf))
74 }
75 b':' => {
76 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
77 let v = parse_i64(&buf[pos + 1..line_end])?;
78 Ok((Frame::Integer(v), after_crlf))
79 }
80 b'$' => {
81 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
82 let len_bytes = &buf[pos + 1..line_end];
83 if len_bytes == b"-1" {
85 return Ok((Frame::BulkString(None), after_crlf));
86 }
87 let len = parse_usize(len_bytes)?;
88 if len == 0 {
89 if after_crlf + 1 >= buf.len() {
90 return Err(ParseError::Incomplete);
91 }
92 if buf[after_crlf] == b'\r' && buf[after_crlf + 1] == b'\n' {
93 return Ok((Frame::BulkString(Some(Bytes::new())), after_crlf + 2));
94 } else {
95 return Err(ParseError::InvalidFormat);
96 }
97 }
98 let data_start = after_crlf;
99 let data_end = data_start + len;
100 if data_end + 1 >= buf.len() || buf[data_end] != b'\r' || buf[data_end + 1] != b'\n' {
101 return Err(ParseError::Incomplete);
102 }
103 Ok((
104 Frame::BulkString(Some(input.slice(data_start..data_end))),
105 data_end + 2,
106 ))
107 }
108 b'*' => {
109 let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
110 let len_bytes = &buf[pos + 1..line_end];
111 if len_bytes == b"-1" {
113 return Ok((Frame::Array(None), after_crlf));
114 }
115 let count = parse_usize(len_bytes)?;
116 if count == 0 {
117 return Ok((Frame::Array(Some(Vec::new())), after_crlf));
118 }
119 let mut cursor = after_crlf;
120 let mut items = Vec::with_capacity(count);
121 for _ in 0..count {
122 let (item, next) = parse_frame_inner(input, cursor)?;
123 items.push(item);
124 cursor = next;
125 }
126 Ok((Frame::Array(Some(items)), cursor))
127 }
128 _ => Err(ParseError::InvalidTag(tag)),
129 }
130}
131
132pub fn frame_to_bytes(frame: &Frame) -> Bytes {
144 let mut buf = BytesMut::new();
145 serialize_frame(frame, &mut buf);
146 buf.freeze()
147}
148
149fn serialize_frame(frame: &Frame, buf: &mut BytesMut) {
150 match frame {
151 Frame::SimpleString(s) => {
152 buf.put_u8(b'+');
153 buf.extend_from_slice(s);
154 buf.extend_from_slice(b"\r\n");
155 }
156 Frame::Error(s) => {
157 buf.put_u8(b'-');
158 buf.extend_from_slice(s);
159 buf.extend_from_slice(b"\r\n");
160 }
161 Frame::Integer(i) => {
162 buf.put_u8(b':');
163 buf.extend_from_slice(i.to_string().as_bytes());
164 buf.extend_from_slice(b"\r\n");
165 }
166 Frame::BulkString(opt) => {
167 buf.put_u8(b'$');
168 match opt {
169 Some(data) => {
170 buf.extend_from_slice(data.len().to_string().as_bytes());
171 buf.extend_from_slice(b"\r\n");
172 buf.extend_from_slice(data);
173 buf.extend_from_slice(b"\r\n");
174 }
175 None => buf.extend_from_slice(b"-1\r\n"),
176 }
177 }
178 Frame::Array(opt) => {
179 buf.put_u8(b'*');
180 match opt {
181 Some(items) => {
182 buf.extend_from_slice(items.len().to_string().as_bytes());
183 buf.extend_from_slice(b"\r\n");
184 for item in items {
185 serialize_frame(item, buf);
186 }
187 }
188 None => buf.extend_from_slice(b"-1\r\n"),
189 }
190 }
191 }
192}
193
194#[derive(Default, Debug)]
213pub struct Parser {
214 buffer: BytesMut,
215}
216
217impl Parser {
218 pub fn new() -> Self {
220 Self {
221 buffer: BytesMut::new(),
222 }
223 }
224
225 pub fn feed(&mut self, data: Bytes) {
227 self.buffer.extend_from_slice(&data);
228 }
229
230 pub fn next_frame(&mut self) -> Result<Option<Frame>, ParseError> {
235 if self.buffer.is_empty() {
236 return Ok(None);
237 }
238
239 let bytes = self.buffer.split().freeze();
240
241 match parse_frame_inner(&bytes, 0) {
242 Ok((frame, consumed)) => {
243 if consumed < bytes.len() {
244 self.buffer.unsplit(BytesMut::from(&bytes[consumed..]));
245 }
246 Ok(Some(frame))
247 }
248 Err(ParseError::Incomplete) => {
249 self.buffer.unsplit(bytes.into());
250 Ok(None)
251 }
252 Err(e) => Err(e),
253 }
254 }
255
256 pub fn buffered_bytes(&self) -> usize {
258 self.buffer.len()
259 }
260
261 pub fn clear(&mut self) {
263 self.buffer.clear();
264 }
265}
266
267#[inline]
270fn find_crlf(buf: &[u8], from: usize) -> Result<(usize, usize), ParseError> {
271 let mut i = from;
272 let len = buf.len();
273 while i + 1 < len {
274 if buf[i] == b'\r' && buf[i + 1] == b'\n' {
275 return Ok((i, i + 2));
276 }
277 i += 1;
278 }
279 Err(ParseError::Incomplete)
280}
281
282#[inline]
284fn parse_usize(buf: &[u8]) -> Result<usize, ParseError> {
285 if buf.is_empty() {
286 return Err(ParseError::BadLength);
287 }
288 let mut v: usize = 0;
289 for &b in buf {
290 if !b.is_ascii_digit() {
291 return Err(ParseError::BadLength);
292 }
293 v = v.checked_mul(10).ok_or(ParseError::BadLength)?;
294 v = v
295 .checked_add((b - b'0') as usize)
296 .ok_or(ParseError::BadLength)?;
297 }
298 Ok(v)
299}
300
301#[inline]
303fn parse_i64(buf: &[u8]) -> Result<i64, ParseError> {
304 if buf.is_empty() {
305 return Err(ParseError::InvalidFormat);
306 }
307 let (neg, digits) = if buf[0] == b'-' {
308 (true, &buf[1..])
309 } else {
310 (false, buf)
311 };
312 if digits.is_empty() {
313 return Err(ParseError::InvalidFormat);
314 }
315 let mut v: i64 = 0;
316 for (i, &d) in digits.iter().enumerate() {
317 if !d.is_ascii_digit() {
318 return Err(ParseError::InvalidFormat);
319 }
320 let digit = (d - b'0') as i64;
321 if neg && v == i64::MAX / 10 && digit == 8 && i == digits.len() - 1 {
322 return Ok(i64::MIN);
323 }
324 if v > i64::MAX / 10 || (v == i64::MAX / 10 && digit > i64::MAX % 10) {
325 return Err(ParseError::Overflow);
326 }
327 v = v * 10 + digit;
328 }
329 if neg { Ok(-v) } else { Ok(v) }
330}
331
#[cfg(test)]
mod tests {
    //! Unit tests for RESP2 parsing, serialization and the streaming `Parser`.

    use super::*;

    #[test]
    fn simple_string() {
        let (frame, rest) = parse_frame(Bytes::from("+OK\r\nrest")).unwrap();
        assert_eq!(frame, Frame::SimpleString(Bytes::from("OK")));
        assert_eq!(rest, Bytes::from("rest"));
    }

    #[test]
    fn error() {
        let (frame, _) = parse_frame(Bytes::from("-ERR fail\r\n")).unwrap();
        assert_eq!(frame, Frame::Error(Bytes::from("ERR fail")));
    }

    #[test]
    fn integer() {
        let (frame, _) = parse_frame(Bytes::from(":42\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(42));

        let (frame, _) = parse_frame(Bytes::from(":-123\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(-123));
    }

    #[test]
    fn integer_extremes() {
        let (frame, _) = parse_frame(Bytes::from(":9223372036854775807\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(i64::MAX));

        let (frame, _) = parse_frame(Bytes::from(":-9223372036854775808\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(i64::MIN));

        assert_eq!(
            parse_frame(Bytes::from(":9223372036854775808\r\n")),
            Err(ParseError::Overflow)
        );
    }

    #[test]
    fn bulk_string() {
        let (frame, rest) = parse_frame(Bytes::from("$5\r\nhello\r\nX")).unwrap();
        assert_eq!(frame, Frame::BulkString(Some(Bytes::from("hello"))));
        assert_eq!(rest, Bytes::from("X"));
    }

    #[test]
    fn null_bulk_string() {
        let (frame, _) = parse_frame(Bytes::from("$-1\r\n")).unwrap();
        assert_eq!(frame, Frame::BulkString(None));
    }

    #[test]
    fn empty_bulk_string() {
        let (frame, rest) = parse_frame(Bytes::from("$0\r\n\r\nX")).unwrap();
        assert_eq!(frame, Frame::BulkString(Some(Bytes::new())));
        assert_eq!(rest, Bytes::from("X"));
    }

    #[test]
    fn array() {
        let input = Bytes::from("*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n");
        let (frame, _) = parse_frame(input).unwrap();
        assert_eq!(
            frame,
            Frame::Array(Some(vec![
                Frame::BulkString(Some(Bytes::from("foo"))),
                Frame::BulkString(Some(Bytes::from("bar"))),
            ]))
        );
    }

    #[test]
    fn null_array() {
        let (frame, _) = parse_frame(Bytes::from("*-1\r\n")).unwrap();
        assert_eq!(frame, Frame::Array(None));
    }

    #[test]
    fn empty_array() {
        let (frame, _) = parse_frame(Bytes::from("*0\r\n")).unwrap();
        assert_eq!(frame, Frame::Array(Some(vec![])));
    }

    #[test]
    fn nested_array() {
        let input = Bytes::from("*2\r\n*1\r\n:1\r\n+OK\r\n");
        let (frame, _) = parse_frame(input).unwrap();
        assert_eq!(
            frame,
            Frame::Array(Some(vec![
                Frame::Array(Some(vec![Frame::Integer(1)])),
                Frame::SimpleString(Bytes::from("OK")),
            ]))
        );
    }

    #[test]
    fn incomplete() {
        assert_eq!(parse_frame(Bytes::new()), Err(ParseError::Incomplete));
        assert_eq!(
            parse_frame(Bytes::from("+OK\r")),
            Err(ParseError::Incomplete)
        );
        assert_eq!(
            parse_frame(Bytes::from("$5\r\nhel")),
            Err(ParseError::Incomplete)
        );
    }

    #[test]
    fn invalid_tag() {
        assert_eq!(
            parse_frame(Bytes::from("X\r\n")),
            Err(ParseError::InvalidTag(b'X'))
        );
    }

    #[test]
    fn roundtrip() {
        let frames = vec![
            Frame::SimpleString(Bytes::from("OK")),
            Frame::Error(Bytes::from("ERR bad")),
            Frame::Integer(42),
            Frame::Integer(-7),
            Frame::Integer(i64::MIN),
            Frame::BulkString(Some(Bytes::from("hello"))),
            Frame::BulkString(Some(Bytes::new())),
            Frame::BulkString(None),
            Frame::Array(Some(vec![
                Frame::Integer(1),
                Frame::BulkString(Some(Bytes::from("two"))),
            ])),
            Frame::Array(Some(vec![Frame::Array(Some(vec![Frame::Array(None)]))])),
            Frame::Array(Some(vec![])),
            Frame::Array(None),
        ];
        for frame in &frames {
            let bytes = frame_to_bytes(frame);
            let (parsed, rest) = parse_frame(bytes).unwrap();
            assert_eq!(&parsed, frame);
            assert!(rest.is_empty());
        }
    }

    #[test]
    fn streaming_parser() {
        let mut parser = Parser::new();
        parser.feed(Bytes::from("+HEL"));
        assert!(parser.next_frame().unwrap().is_none());

        parser.feed(Bytes::from("LO\r\n:42\r\n"));
        let f1 = parser.next_frame().unwrap().unwrap();
        assert_eq!(f1, Frame::SimpleString(Bytes::from("HELLO")));

        let f2 = parser.next_frame().unwrap().unwrap();
        assert_eq!(f2, Frame::Integer(42));

        assert!(parser.next_frame().unwrap().is_none());
    }

    #[test]
    fn streaming_bulk_split_across_feeds() {
        let mut parser = Parser::new();
        parser.feed(Bytes::from("$5\r\nhe"));
        assert!(parser.next_frame().unwrap().is_none());
        // A partial frame must stay buffered untouched.
        assert_eq!(parser.buffered_bytes(), 6);

        parser.feed(Bytes::from("llo\r\n"));
        let frame = parser.next_frame().unwrap().unwrap();
        assert_eq!(frame, Frame::BulkString(Some(Bytes::from("hello"))));
        assert_eq!(parser.buffered_bytes(), 0);
    }

    #[test]
    fn streaming_clear() {
        let mut parser = Parser::new();
        parser.feed(Bytes::from("+partial"));
        parser.clear();
        assert_eq!(parser.buffered_bytes(), 0);
        assert!(parser.next_frame().unwrap().is_none());
    }

    #[test]
    fn chained_frames() {
        let input = Bytes::from("+OK\r\n:1\r\n$3\r\nfoo\r\n");
        let (f1, rest) = parse_frame(input).unwrap();
        assert_eq!(f1, Frame::SimpleString(Bytes::from("OK")));
        let (f2, rest) = parse_frame(rest).unwrap();
        assert_eq!(f2, Frame::Integer(1));
        let (f3, rest) = parse_frame(rest).unwrap();
        assert_eq!(f3, Frame::BulkString(Some(Bytes::from("foo"))));
        assert!(rest.is_empty());
    }

    #[test]
    fn binary_bulk_string() {
        let mut data = Vec::new();
        data.extend_from_slice(b"$5\r\n");
        data.extend_from_slice(&[0x00, 0x01, 0xFF, 0xFE, 0x42]);
        data.extend_from_slice(b"\r\n");
        let (frame, _) = parse_frame(Bytes::from(data)).unwrap();
        match frame {
            Frame::BulkString(Some(b)) => {
                assert_eq!(b.as_ref(), &[0x00, 0x01, 0xFF, 0xFE, 0x42]);
            }
            _ => panic!("expected bulk string"),
        }
    }

    #[test]
    fn rejects_resp3_types() {
        // RESP3-only tags (null, double, boolean, big number) are not valid RESP2.
        assert!(parse_frame(Bytes::from("_\r\n")).is_err());
        assert!(parse_frame(Bytes::from(",3.14\r\n")).is_err());
        assert!(parse_frame(Bytes::from("#t\r\n")).is_err());
        assert!(parse_frame(Bytes::from("(123\r\n")).is_err());
    }
}