1use std::io::Cursor;
27
28use bytes::Bytes;
29
30use crate::error::ProtocolError;
31use crate::types::Frame;
32
/// Maximum number of nested aggregates (arrays and maps) accepted within one frame.
const MAX_NESTING_DEPTH: usize = 64;

/// Maximum element count accepted in an array header, or pair count in a map header.
const MAX_ARRAY_ELEMENTS: usize = 1 << 20;

/// Maximum accepted bulk string payload length in bytes (512 MiB).
const MAX_BULK_LEN: i64 = 512 << 20;

/// Cap on the capacity preallocated for an aggregate's children, so a large declared
/// count does not by itself trigger a large up-front allocation.
const PREALLOC_CAP: usize = 1024;
51#[inline]
61pub fn parse_frame_bytes(buf: &Bytes) -> Result<Option<(Frame, usize)>, ProtocolError> {
62 if buf.is_empty() {
63 return Ok(None);
64 }
65
66 let mut cursor = Cursor::new(buf.as_ref());
67
68 match try_parse(&mut cursor, Some(buf), 0) {
69 Ok(frame) => {
70 let consumed = cursor.position() as usize;
71 Ok(Some((frame, consumed)))
72 }
73 Err(ProtocolError::Incomplete) => Ok(None),
74 Err(e) => Err(e),
75 }
76}
77
78#[inline]
87pub fn parse_frame(buf: &[u8]) -> Result<Option<(Frame, usize)>, ProtocolError> {
88 if buf.is_empty() {
89 return Ok(None);
90 }
91
92 let mut cursor = Cursor::new(buf);
93
94 match try_parse(&mut cursor, None, 0) {
95 Ok(frame) => {
96 let consumed = cursor.position() as usize;
97 Ok(Some((frame, consumed)))
98 }
99 Err(ProtocolError::Incomplete) => Ok(None),
100 Err(e) => Err(e),
101 }
102}
103
104fn try_parse(
114 cursor: &mut Cursor<&[u8]>,
115 src: Option<&Bytes>,
116 depth: usize,
117) -> Result<Frame, ProtocolError> {
118 let prefix = read_byte(cursor)?;
119
120 match prefix {
121 b'+' => {
122 let line = read_line(cursor)?;
123 let s = std::str::from_utf8(line).map_err(|_| {
124 ProtocolError::InvalidCommandFrame("invalid utf-8 in simple string".into())
125 })?;
126 Ok(Frame::Simple(s.to_owned()))
127 }
128 b'-' => {
129 let line = read_line(cursor)?;
130 let s = std::str::from_utf8(line).map_err(|_| {
131 ProtocolError::InvalidCommandFrame("invalid utf-8 in error string".into())
132 })?;
133 Ok(Frame::Error(s.to_owned()))
134 }
135 b':' => {
136 let val = read_integer_line(cursor)?;
137 Ok(Frame::Integer(val))
138 }
139 b'$' => {
140 let len = read_integer_line(cursor)?;
141 if len < 0 {
142 return Err(ProtocolError::InvalidFrameLength(len));
143 }
144 if len > MAX_BULK_LEN {
145 return Err(ProtocolError::BulkStringTooLarge(len as usize));
146 }
147 let len = len as usize;
148
149 let remaining = remaining(cursor);
151 if remaining < len + 2 {
152 return Err(ProtocolError::Incomplete);
153 }
154
155 let pos = cursor.position() as usize;
156
157 {
159 let buf = cursor.get_ref();
160 if buf[pos + len] != b'\r' || buf[pos + len + 1] != b'\n' {
161 return Err(ProtocolError::InvalidFrameLength(len as i64));
162 }
163 }
164
165 cursor.set_position((pos + len + 2) as u64);
166
167 let data = match src {
169 Some(b) => b.slice(pos..pos + len),
170 None => Bytes::copy_from_slice(&cursor.get_ref()[pos..pos + len]),
171 };
172 Ok(Frame::Bulk(data))
173 }
174 b'*' => {
175 let next_depth = depth + 1;
176 if next_depth > MAX_NESTING_DEPTH {
177 return Err(ProtocolError::NestingTooDeep(MAX_NESTING_DEPTH));
178 }
179
180 let count = read_integer_line(cursor)?;
181 if count < 0 {
182 return Err(ProtocolError::InvalidFrameLength(count));
183 }
184 if count as usize > MAX_ARRAY_ELEMENTS {
185 return Err(ProtocolError::TooManyElements(count as usize));
186 }
187
188 let count = count as usize;
189 let mut frames = Vec::with_capacity(count.min(PREALLOC_CAP));
190 for _ in 0..count {
191 frames.push(try_parse(cursor, src, next_depth)?);
192 }
193 Ok(Frame::Array(frames))
194 }
195 b'_' => {
196 let _ = read_line(cursor)?;
198 Ok(Frame::Null)
199 }
200 b'%' => {
201 let next_depth = depth + 1;
202 if next_depth > MAX_NESTING_DEPTH {
203 return Err(ProtocolError::NestingTooDeep(MAX_NESTING_DEPTH));
204 }
205
206 let count = read_integer_line(cursor)?;
207 if count < 0 {
208 return Err(ProtocolError::InvalidFrameLength(count));
209 }
210 if count as usize > MAX_ARRAY_ELEMENTS {
211 return Err(ProtocolError::TooManyElements(count as usize));
212 }
213
214 let count = count as usize;
215 let mut pairs = Vec::with_capacity(count.min(PREALLOC_CAP));
216 for _ in 0..count {
217 let key = try_parse(cursor, src, next_depth)?;
218 let val = try_parse(cursor, src, next_depth)?;
219 pairs.push((key, val));
220 }
221 Ok(Frame::Map(pairs))
222 }
223 other => Err(ProtocolError::InvalidPrefix(other)),
224 }
225}
226
227fn read_byte(cursor: &mut Cursor<&[u8]>) -> Result<u8, ProtocolError> {
232 let pos = cursor.position() as usize;
233 if pos >= cursor.get_ref().len() {
234 return Err(ProtocolError::Incomplete);
235 }
236 cursor.set_position((pos + 1) as u64);
237 Ok(cursor.get_ref()[pos])
238}
239
240fn read_line<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], ProtocolError> {
243 let start = cursor.position() as usize;
244 let end = find_crlf(cursor)?;
245 Ok(&cursor.get_ref()[start..end])
246}
247
248fn read_integer_line(cursor: &mut Cursor<&[u8]>) -> Result<i64, ProtocolError> {
250 let line = read_line(cursor)?;
251 parse_i64_bytes(line)
252}
253
254fn find_crlf(cursor: &mut Cursor<&[u8]>) -> Result<usize, ProtocolError> {
257 let buf = cursor.get_ref();
258 let start = cursor.position() as usize;
259
260 if start >= buf.len() {
261 return Err(ProtocolError::Incomplete);
262 }
263
264 let mut pos = start;
267 while let Some(offset) = memchr::memchr(b'\r', &buf[pos..]) {
268 let cr = pos + offset;
269 if cr + 1 < buf.len() && buf[cr + 1] == b'\n' {
270 cursor.set_position((cr + 2) as u64);
271 return Ok(cr);
272 }
273 pos = cr + 1;
275 }
276
277 Err(ProtocolError::Incomplete)
278}
279
/// Number of unread bytes between the cursor position and the end of its buffer.
///
/// Returns 0, rather than underflowing, when the position is at or beyond the end.
fn remaining(cursor: &Cursor<&[u8]>) -> usize {
    let (total, used) = (cursor.get_ref().len(), cursor.position() as usize);
    if used >= total {
        0
    } else {
        total - used
    }
}
285
286fn parse_i64_bytes(buf: &[u8]) -> Result<i64, ProtocolError> {
291 if buf.is_empty() {
292 return Err(ProtocolError::InvalidInteger);
293 }
294
295 let (negative, digits) = if buf[0] == b'-' {
296 (true, &buf[1..])
297 } else {
298 (false, buf)
299 };
300
301 if digits.is_empty() {
302 return Err(ProtocolError::InvalidInteger);
303 }
304
305 if negative {
306 let mut n: i64 = 0;
308 for &b in digits {
309 if !b.is_ascii_digit() {
310 return Err(ProtocolError::InvalidInteger);
311 }
312 n = n
313 .checked_mul(10)
314 .and_then(|n| n.checked_sub((b - b'0') as i64))
315 .ok_or(ProtocolError::InvalidInteger)?;
316 }
317 Ok(n)
318 } else {
319 let mut n: i64 = 0;
320 for &b in digits {
321 if !b.is_ascii_digit() {
322 return Err(ProtocolError::InvalidInteger);
323 }
324 n = n
325 .checked_mul(10)
326 .and_then(|n| n.checked_add((b - b'0') as i64))
327 .ok_or(ProtocolError::InvalidInteger)?;
328 }
329 Ok(n)
330 }
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` with the copying parser and asserts the whole input was consumed.
    fn must_parse(input: &[u8]) -> Frame {
        let (frame, consumed) = parse_frame(input)
            .expect("parse should not error")
            .expect("parse should return a frame");
        assert_eq!(consumed, input.len(), "should consume entire input");
        frame
    }

    /// Parses `input` with the zero-copy parser and asserts the whole input was consumed.
    fn must_parse_zerocopy(input: &Bytes) -> Frame {
        let (frame, consumed) = parse_frame_bytes(input)
            .expect("parse should not error")
            .expect("parse should return a frame");
        assert_eq!(consumed, input.len(), "should consume entire input");
        frame
    }

    #[test]
    fn simple_string() {
        assert_eq!(must_parse(b"+OK\r\n"), Frame::Simple("OK".into()));
        assert_eq!(
            must_parse(b"+hello world\r\n"),
            Frame::Simple("hello world".into())
        );
    }

    #[test]
    fn simple_error() {
        assert_eq!(
            must_parse(b"-ERR unknown command\r\n"),
            Frame::Error("ERR unknown command".into())
        );
    }

    #[test]
    fn integer() {
        assert_eq!(must_parse(b":42\r\n"), Frame::Integer(42));
        assert_eq!(must_parse(b":0\r\n"), Frame::Integer(0));
        assert_eq!(must_parse(b":-1\r\n"), Frame::Integer(-1));
        assert_eq!(
            must_parse(b":9223372036854775807\r\n"),
            Frame::Integer(i64::MAX)
        );
        assert_eq!(
            must_parse(b":-9223372036854775808\r\n"),
            Frame::Integer(i64::MIN)
        );
    }

    #[test]
    fn integer_overflow_rejected() {
        let err = parse_frame(b":9223372036854775808\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidInteger);
    }

    #[test]
    fn bulk_string() {
        assert_eq!(
            must_parse(b"$5\r\nhello\r\n"),
            Frame::Bulk(Bytes::from_static(b"hello"))
        );
    }

    #[test]
    fn empty_bulk_string() {
        assert_eq!(
            must_parse(b"$0\r\n\r\n"),
            Frame::Bulk(Bytes::from_static(b""))
        );
    }

    #[test]
    fn bulk_string_with_binary() {
        let input = b"$4\r\n\x00\x01\x02\x03\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Bulk(Bytes::copy_from_slice(&[0, 1, 2, 3]))
        );
    }

    #[test]
    fn bulk_string_too_large() {
        // One byte over the 512 MiB limit; rejected before any payload is needed.
        let err = parse_frame(b"$536870913\r\n").unwrap_err();
        assert!(
            matches!(err, ProtocolError::BulkStringTooLarge(536_870_913)),
            "expected BulkStringTooLarge, got {err:?}"
        );
    }

    #[test]
    fn bulk_string_bad_terminator() {
        let err = parse_frame(b"$3\r\nabcXY").unwrap_err();
        assert!(
            matches!(err, ProtocolError::InvalidFrameLength(3)),
            "expected InvalidFrameLength(3), got {err:?}"
        );
    }

    #[test]
    fn null() {
        assert_eq!(must_parse(b"_\r\n"), Frame::Null);
    }

    #[test]
    fn array() {
        let input = b"*2\r\n+hello\r\n+world\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("hello".into()),
                Frame::Simple("world".into()),
            ])
        );
    }

    #[test]
    fn empty_array() {
        assert_eq!(must_parse(b"*0\r\n"), Frame::Array(vec![]));
    }

    #[test]
    fn nested_array() {
        let input = b"*2\r\n*2\r\n:1\r\n:2\r\n*2\r\n:3\r\n:4\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Array(vec![Frame::Integer(1), Frame::Integer(2)]),
                Frame::Array(vec![Frame::Integer(3), Frame::Integer(4)]),
            ])
        );
    }

    #[test]
    fn array_with_null() {
        let input = b"*3\r\n+OK\r\n_\r\n:1\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("OK".into()),
                Frame::Null,
                Frame::Integer(1),
            ])
        );
    }

    #[test]
    fn negative_array_count() {
        let err = parse_frame(b"*-1\r\n").unwrap_err();
        assert!(
            matches!(err, ProtocolError::InvalidFrameLength(-1)),
            "expected InvalidFrameLength(-1), got {err:?}"
        );
    }

    #[test]
    fn too_many_array_elements() {
        let err = parse_frame(b"*1048577\r\n").unwrap_err();
        assert!(
            matches!(err, ProtocolError::TooManyElements(1_048_577)),
            "expected TooManyElements, got {err:?}"
        );
    }

    #[test]
    fn map() {
        let input = b"%2\r\n+key1\r\n:1\r\n+key2\r\n:2\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Map(vec![
                (Frame::Simple("key1".into()), Frame::Integer(1)),
                (Frame::Simple("key2".into()), Frame::Integer(2)),
            ])
        );
    }

    #[test]
    fn incomplete_returns_none() {
        assert_eq!(parse_frame(b"").unwrap(), None);
        assert_eq!(parse_frame(b"+OK").unwrap(), None);
        assert_eq!(parse_frame(b"+OK\r").unwrap(), None);
        assert_eq!(parse_frame(b"$5\r\nhel").unwrap(), None);
        assert_eq!(parse_frame(b"*2\r\n+OK\r\n").unwrap(), None);
    }

    #[test]
    fn incomplete_map_returns_none() {
        // Key present, value missing.
        assert_eq!(parse_frame(b"%1\r\n+key\r\n").unwrap(), None);
    }

    #[test]
    fn invalid_prefix() {
        let err = parse_frame(b"~invalid\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidPrefix(b'~'));
    }

    #[test]
    fn invalid_integer() {
        let err = parse_frame(b":abc\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidInteger);
    }

    #[test]
    fn negative_bulk_length() {
        let err = parse_frame(b"$-1\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(-1)));
    }

    #[test]
    fn parse_consumes_exact_bytes() {
        let buf = b"+OK\r\ntrailing";
        let (frame, consumed) = parse_frame(buf).unwrap().unwrap();
        assert_eq!(frame, Frame::Simple("OK".into()));
        assert_eq!(consumed, 5);
    }

    #[test]
    fn deeply_nested_array_rejected() {
        let mut buf = Vec::new();
        // One level beyond MAX_NESTING_DEPTH.
        for _ in 0..65 {
            buf.extend_from_slice(b"*1\r\n");
        }
        buf.extend_from_slice(b":1\r\n");

        let err = parse_frame(&buf).unwrap_err();
        assert!(
            matches!(err, ProtocolError::NestingTooDeep(64)),
            "expected NestingTooDeep, got {err:?}"
        );
    }

    #[test]
    fn deeply_nested_map_rejected() {
        let mut buf = Vec::new();
        // Each map's value is the next map; 65 levels exceeds MAX_NESTING_DEPTH.
        for _ in 0..65 {
            buf.extend_from_slice(b"%1\r\n:0\r\n");
        }
        buf.extend_from_slice(b":1\r\n");

        let err = parse_frame(&buf).unwrap_err();
        assert!(
            matches!(err, ProtocolError::NestingTooDeep(64)),
            "expected NestingTooDeep, got {err:?}"
        );
    }

    #[test]
    fn nesting_at_limit_accepted() {
        let mut buf = Vec::new();
        for _ in 0..64 {
            buf.extend_from_slice(b"*1\r\n");
        }
        buf.extend_from_slice(b":1\r\n");

        let result = parse_frame(&buf);
        assert!(result.is_ok(), "64 levels of nesting should be accepted");
        assert!(result.unwrap().is_some());
    }

    #[test]
    fn zerocopy_bulk_string() {
        let input = Bytes::from_static(b"$5\r\nhello\r\n");
        assert_eq!(
            must_parse_zerocopy(&input),
            Frame::Bulk(Bytes::from_static(b"hello"))
        );
    }

    #[test]
    fn zerocopy_array() {
        let input = Bytes::from_static(b"*2\r\n$3\r\nGET\r\n$5\r\nmykey\r\n");
        let frame = must_parse_zerocopy(&input);
        assert_eq!(
            frame,
            Frame::Array(vec![
                Frame::Bulk(Bytes::from_static(b"GET")),
                Frame::Bulk(Bytes::from_static(b"mykey")),
            ])
        );
    }

    #[test]
    fn zerocopy_incomplete_returns_none() {
        let input = Bytes::from_static(b"$5\r\nhel");
        assert_eq!(parse_frame_bytes(&input).unwrap(), None);
    }

    #[test]
    fn parse_i64_bytes_valid() {
        assert_eq!(parse_i64_bytes(b"0").unwrap(), 0);
        assert_eq!(parse_i64_bytes(b"42").unwrap(), 42);
        assert_eq!(parse_i64_bytes(b"-1").unwrap(), -1);
        assert_eq!(parse_i64_bytes(b"9223372036854775807").unwrap(), i64::MAX);
        assert_eq!(parse_i64_bytes(b"-9223372036854775808").unwrap(), i64::MIN);
    }

    #[test]
    fn parse_i64_bytes_invalid() {
        assert!(parse_i64_bytes(b"").is_err());
        assert!(parse_i64_bytes(b"-").is_err());
        assert!(parse_i64_bytes(b"abc").is_err());
        assert!(parse_i64_bytes(b"12a").is_err());
    }
}