Skip to main content

ember_protocol/
parse.rs

1//! Zero-copy RESP3 parser.
2//!
3//! Operates on buffered byte slices. The caller is responsible for reading
4//! data from the network into a buffer — this parser is purely synchronous.
5//!
6//! The parser uses a `Cursor<&[u8]>` to track its position through the
7//! input buffer without consuming it, allowing the caller to retry once
8//! more data arrives.
9//!
10//! # Single-pass design
11//!
12//! Earlier versions used a two-pass approach: `check()` to validate a
13//! complete frame exists, then `parse()` to build Frame values. This
14//! scanned every byte twice. The current implementation does a single
15//! pass that builds Frame values directly, returning `Incomplete` if
16//! the buffer doesn't contain enough data yet.
17//!
18//! # Zero-copy bulk strings
19//!
20//! When parsing from a `Bytes` buffer via [`parse_frame_bytes`], bulk
21//! string data is returned as a zero-copy `Bytes::slice()` into the
22//! original buffer. This avoids a heap allocation per bulk string.
23//! The fallback [`parse_frame`] copies bulk data for callers that
24//! only have a `&[u8]`.
25
26use std::io::Cursor;
27
28use bytes::Bytes;
29
30use crate::error::ProtocolError;
31use crate::types::Frame;
32
/// Deepest permitted nesting of arrays and maps. Bounding the recursion
/// keeps malicious or malformed deeply-nested frames from overflowing the
/// stack.
const MAX_NESTING_DEPTH: usize = 64;

/// Largest element count an array or map may declare (2^20). Guards against
/// memory amplification, where tiny elements (3 bytes each) would otherwise
/// force disproportionately large Vec allocations.
const MAX_ARRAY_ELEMENTS: usize = 1 << 20;

/// Largest accepted bulk string length in bytes (512 MB, matching Redis).
const MAX_BULK_LEN: i64 = 512 << 20;

/// Ceiling for `Vec::with_capacity` when parsing arrays and maps. Honouring a
/// declared count of 1M elements would cost ~72 MB upfront, before any child
/// data has been parsed. Capping the initial reservation avoids that while
/// the Vec still grows organically as elements arrive.
const PREALLOC_CAP: usize = 1 << 10;
50
51/// Zero-copy frame parser. Bulk string data is returned as `Bytes::slice()`
52/// into the input buffer, avoiding a heap copy per bulk string.
53///
54/// Use this on the hot path when the caller has a `Bytes` (e.g. from
55/// `BytesMut::freeze()`).
56///
57/// Returns `Ok(Some((frame, consumed)))` if a complete frame was parsed,
58/// `Ok(None)` if the buffer doesn't contain enough data yet,
59/// or `Err(...)` if the data is malformed.
60#[inline]
61pub fn parse_frame_bytes(buf: &Bytes) -> Result<Option<(Frame, usize)>, ProtocolError> {
62    if buf.is_empty() {
63        return Ok(None);
64    }
65
66    let mut cursor = Cursor::new(buf.as_ref());
67
68    match try_parse(&mut cursor, Some(buf), 0) {
69        Ok(frame) => {
70            let consumed = cursor.position() as usize;
71            Ok(Some((frame, consumed)))
72        }
73        Err(ProtocolError::Incomplete) => Ok(None),
74        Err(e) => Err(e),
75    }
76}
77
78/// Checks whether `buf` contains a complete RESP3 frame and parses it.
79///
80/// Bulk string data is copied out of the buffer. Prefer [`parse_frame_bytes`]
81/// on hot paths when a `Bytes` reference is available.
82///
83/// Returns `Ok(Some(frame))` if a complete frame was parsed,
84/// `Ok(None)` if the buffer doesn't contain enough data yet,
85/// or `Err(...)` if the data is malformed.
86#[inline]
87pub fn parse_frame(buf: &[u8]) -> Result<Option<(Frame, usize)>, ProtocolError> {
88    if buf.is_empty() {
89        return Ok(None);
90    }
91
92    let mut cursor = Cursor::new(buf);
93
94    match try_parse(&mut cursor, None, 0) {
95        Ok(frame) => {
96            let consumed = cursor.position() as usize;
97            Ok(Some((frame, consumed)))
98        }
99        Err(ProtocolError::Incomplete) => Ok(None),
100        Err(e) => Err(e),
101    }
102}
103
104// ---------------------------------------------------------------------------
105// single-pass parser: validates and builds Frame values in one traversal
106// ---------------------------------------------------------------------------
107
108/// Parses a complete RESP3 frame from the cursor position, returning
109/// `Incomplete` if the buffer doesn't contain enough data.
110///
111/// When `src` is `Some`, bulk string bytes are sliced zero-copy from the
112/// source buffer. When `None`, they are copied.
113fn try_parse(
114    cursor: &mut Cursor<&[u8]>,
115    src: Option<&Bytes>,
116    depth: usize,
117) -> Result<Frame, ProtocolError> {
118    let prefix = read_byte(cursor)?;
119
120    match prefix {
121        b'+' => {
122            let line = read_line(cursor)?;
123            let s = std::str::from_utf8(line).map_err(|_| {
124                ProtocolError::InvalidCommandFrame("invalid utf-8 in simple string".into())
125            })?;
126            Ok(Frame::Simple(s.to_owned()))
127        }
128        b'-' => {
129            let line = read_line(cursor)?;
130            let s = std::str::from_utf8(line).map_err(|_| {
131                ProtocolError::InvalidCommandFrame("invalid utf-8 in error string".into())
132            })?;
133            Ok(Frame::Error(s.to_owned()))
134        }
135        b':' => {
136            let val = read_integer_line(cursor)?;
137            Ok(Frame::Integer(val))
138        }
139        b'$' => {
140            let len = read_integer_line(cursor)?;
141            if len < 0 {
142                return Err(ProtocolError::InvalidFrameLength(len));
143            }
144            if len > MAX_BULK_LEN {
145                return Err(ProtocolError::BulkStringTooLarge(len as usize));
146            }
147            let len = len as usize;
148
149            // need `len` bytes of data + \r\n
150            let remaining = remaining(cursor);
151            if remaining < len + 2 {
152                return Err(ProtocolError::Incomplete);
153            }
154
155            let pos = cursor.position() as usize;
156
157            // verify trailing \r\n (scope the borrow so we can mutate cursor after)
158            {
159                let buf = cursor.get_ref();
160                if buf[pos + len] != b'\r' || buf[pos + len + 1] != b'\n' {
161                    return Err(ProtocolError::InvalidFrameLength(len as i64));
162                }
163            }
164
165            cursor.set_position((pos + len + 2) as u64);
166
167            // zero-copy when source Bytes is available, copy otherwise
168            let data = match src {
169                Some(b) => b.slice(pos..pos + len),
170                None => Bytes::copy_from_slice(&cursor.get_ref()[pos..pos + len]),
171            };
172            Ok(Frame::Bulk(data))
173        }
174        b'*' => {
175            let next_depth = depth + 1;
176            if next_depth > MAX_NESTING_DEPTH {
177                return Err(ProtocolError::NestingTooDeep(MAX_NESTING_DEPTH));
178            }
179
180            let count = read_integer_line(cursor)?;
181            if count < 0 {
182                return Err(ProtocolError::InvalidFrameLength(count));
183            }
184            if count as usize > MAX_ARRAY_ELEMENTS {
185                return Err(ProtocolError::TooManyElements(count as usize));
186            }
187
188            let count = count as usize;
189            let mut frames = Vec::with_capacity(count.min(PREALLOC_CAP));
190            for _ in 0..count {
191                frames.push(try_parse(cursor, src, next_depth)?);
192            }
193            Ok(Frame::Array(frames))
194        }
195        b'_' => {
196            // consume the trailing \r\n
197            let _ = read_line(cursor)?;
198            Ok(Frame::Null)
199        }
200        b'%' => {
201            let next_depth = depth + 1;
202            if next_depth > MAX_NESTING_DEPTH {
203                return Err(ProtocolError::NestingTooDeep(MAX_NESTING_DEPTH));
204            }
205
206            let count = read_integer_line(cursor)?;
207            if count < 0 {
208                return Err(ProtocolError::InvalidFrameLength(count));
209            }
210            if count as usize > MAX_ARRAY_ELEMENTS {
211                return Err(ProtocolError::TooManyElements(count as usize));
212            }
213
214            let count = count as usize;
215            let mut pairs = Vec::with_capacity(count.min(PREALLOC_CAP));
216            for _ in 0..count {
217                let key = try_parse(cursor, src, next_depth)?;
218                let val = try_parse(cursor, src, next_depth)?;
219                pairs.push((key, val));
220            }
221            Ok(Frame::Map(pairs))
222        }
223        other => Err(ProtocolError::InvalidPrefix(other)),
224    }
225}
226
227// ---------------------------------------------------------------------------
228// low-level cursor helpers
229// ---------------------------------------------------------------------------
230
231fn read_byte(cursor: &mut Cursor<&[u8]>) -> Result<u8, ProtocolError> {
232    let pos = cursor.position() as usize;
233    if pos >= cursor.get_ref().len() {
234        return Err(ProtocolError::Incomplete);
235    }
236    cursor.set_position((pos + 1) as u64);
237    Ok(cursor.get_ref()[pos])
238}
239
240/// Returns the slice of bytes up to (but not including) the next `\r\n`,
241/// and advances the cursor past the `\r\n`.
242fn read_line<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], ProtocolError> {
243    let start = cursor.position() as usize;
244    let end = find_crlf(cursor)?;
245    Ok(&cursor.get_ref()[start..end])
246}
247
248/// Reads a line and parses it as an i64.
249fn read_integer_line(cursor: &mut Cursor<&[u8]>) -> Result<i64, ProtocolError> {
250    let line = read_line(cursor)?;
251    parse_i64_bytes(line)
252}
253
254/// Finds the next `\r\n` in the buffer starting from the cursor position.
255/// Returns the index of `\r` and advances the cursor past the `\n`.
256fn find_crlf(cursor: &mut Cursor<&[u8]>) -> Result<usize, ProtocolError> {
257    let buf = cursor.get_ref();
258    let start = cursor.position() as usize;
259
260    if start >= buf.len() {
261        return Err(ProtocolError::Incomplete);
262    }
263
264    // SIMD-accelerated scan for \r, then verify \n follows.
265    // memchr processes 16-32 bytes per cycle vs 1 byte in a naive loop.
266    let mut pos = start;
267    while let Some(offset) = memchr::memchr(b'\r', &buf[pos..]) {
268        let cr = pos + offset;
269        if cr + 1 < buf.len() && buf[cr + 1] == b'\n' {
270            cursor.set_position((cr + 2) as u64);
271            return Ok(cr);
272        }
273        // bare \r without \n — keep scanning past it
274        pos = cr + 1;
275    }
276
277    Err(ProtocolError::Incomplete)
278}
279
/// Number of unread bytes between the cursor position and the end of the
/// buffer; zero when the position is at or beyond the end.
fn remaining(cursor: &Cursor<&[u8]>) -> usize {
    let consumed = cursor.position() as usize;
    cursor.get_ref().len().saturating_sub(consumed)
}
285
286/// Parses an i64 directly from a byte slice without allocating a String.
287///
288/// Negative numbers are accumulated in the negative direction so that
289/// `i64::MIN` (-9223372036854775808) is representable without overflow.
290fn parse_i64_bytes(buf: &[u8]) -> Result<i64, ProtocolError> {
291    if buf.is_empty() {
292        return Err(ProtocolError::InvalidInteger);
293    }
294
295    let (negative, digits) = if buf[0] == b'-' {
296        (true, &buf[1..])
297    } else {
298        (false, buf)
299    };
300
301    if digits.is_empty() {
302        return Err(ProtocolError::InvalidInteger);
303    }
304
305    if negative {
306        // accumulate in the negative direction to handle i64::MIN
307        let mut n: i64 = 0;
308        for &b in digits {
309            if !b.is_ascii_digit() {
310                return Err(ProtocolError::InvalidInteger);
311            }
312            n = n
313                .checked_mul(10)
314                .and_then(|n| n.checked_sub((b - b'0') as i64))
315                .ok_or(ProtocolError::InvalidInteger)?;
316        }
317        Ok(n)
318    } else {
319        let mut n: i64 = 0;
320        for &b in digits {
321            if !b.is_ascii_digit() {
322                return Err(ProtocolError::InvalidInteger);
323            }
324            n = n
325                .checked_mul(10)
326                .and_then(|n| n.checked_add((b - b'0') as i64))
327                .ok_or(ProtocolError::InvalidInteger)?;
328        }
329        Ok(n)
330    }
331}
332
// Unit tests for the frame parser: one test per frame type, plus incomplete
// input, malformed input, the nesting limit, the zero-copy entry point and
// the integer helper.
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input` with `parse_frame`, panicking on an error or an
    /// incomplete frame, and asserts that the whole input was consumed.
    fn must_parse(input: &[u8]) -> Frame {
        let (frame, consumed) = parse_frame(input)
            .expect("parse should not error")
            .expect("parse should return a frame");
        assert_eq!(consumed, input.len(), "should consume entire input");
        frame
    }

    /// Same as `must_parse`, but goes through `parse_frame_bytes`.
    fn must_parse_zerocopy(input: &Bytes) -> Frame {
        let (frame, consumed) = parse_frame_bytes(input)
            .expect("parse should not error")
            .expect("parse should return a frame");
        assert_eq!(consumed, input.len(), "should consume entire input");
        frame
    }

    // `+` lines become Frame::Simple without the prefix or the CRLF.
    #[test]
    fn simple_string() {
        assert_eq!(must_parse(b"+OK\r\n"), Frame::Simple("OK".into()));
        assert_eq!(
            must_parse(b"+hello world\r\n"),
            Frame::Simple("hello world".into())
        );
    }

    // `-` lines become Frame::Error.
    #[test]
    fn simple_error() {
        assert_eq!(
            must_parse(b"-ERR unknown command\r\n"),
            Frame::Error("ERR unknown command".into())
        );
    }

    // `:` lines cover zero, negatives and both ends of the i64 range.
    #[test]
    fn integer() {
        assert_eq!(must_parse(b":42\r\n"), Frame::Integer(42));
        assert_eq!(must_parse(b":0\r\n"), Frame::Integer(0));
        assert_eq!(must_parse(b":-1\r\n"), Frame::Integer(-1));
        assert_eq!(
            must_parse(b":9223372036854775807\r\n"),
            Frame::Integer(i64::MAX)
        );
        assert_eq!(
            must_parse(b":-9223372036854775808\r\n"),
            Frame::Integer(i64::MIN)
        );
    }

    // `$<len>` followed by exactly `len` payload bytes and a CRLF.
    #[test]
    fn bulk_string() {
        assert_eq!(
            must_parse(b"$5\r\nhello\r\n"),
            Frame::Bulk(Bytes::from_static(b"hello"))
        );
    }

    // A zero-length bulk string still needs its trailing CRLF.
    #[test]
    fn empty_bulk_string() {
        assert_eq!(
            must_parse(b"$0\r\n\r\n"),
            Frame::Bulk(Bytes::from_static(b""))
        );
    }

    // Bulk payloads are raw bytes; NUL and control bytes pass through.
    #[test]
    fn bulk_string_with_binary() {
        let input = b"$4\r\n\x00\x01\x02\x03\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Bulk(Bytes::copy_from_slice(&[0, 1, 2, 3]))
        );
    }

    // `_` followed by CRLF is the null frame.
    #[test]
    fn null() {
        assert_eq!(must_parse(b"_\r\n"), Frame::Null);
    }

    // `*<count>` followed by `count` child frames.
    #[test]
    fn array() {
        let input = b"*2\r\n+hello\r\n+world\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("hello".into()),
                Frame::Simple("world".into()),
            ])
        );
    }

    // A count of zero yields an empty array.
    #[test]
    fn empty_array() {
        assert_eq!(must_parse(b"*0\r\n"), Frame::Array(vec![]));
    }

    // Arrays may contain arrays.
    #[test]
    fn nested_array() {
        let input = b"*2\r\n*2\r\n:1\r\n:2\r\n*2\r\n:3\r\n:4\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Array(vec![Frame::Integer(1), Frame::Integer(2)]),
                Frame::Array(vec![Frame::Integer(3), Frame::Integer(4)]),
            ])
        );
    }

    // Arrays may mix frame types, including null.
    #[test]
    fn array_with_null() {
        let input = b"*3\r\n+OK\r\n_\r\n:1\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("OK".into()),
                Frame::Null,
                Frame::Integer(1),
            ])
        );
    }

    // `%<count>` is followed by `count` key/value pairs, kept in input order.
    #[test]
    fn map() {
        let input = b"%2\r\n+key1\r\n:1\r\n+key2\r\n:2\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Map(vec![
                (Frame::Simple("key1".into()), Frame::Integer(1)),
                (Frame::Simple("key2".into()), Frame::Integer(2)),
            ])
        );
    }

    // Truncated input is reported as Ok(None), not as an error: empty buffer,
    // missing CRLF, lone \r, short bulk payload, array missing an element.
    #[test]
    fn incomplete_returns_none() {
        assert_eq!(parse_frame(b"").unwrap(), None);
        assert_eq!(parse_frame(b"+OK").unwrap(), None);
        assert_eq!(parse_frame(b"+OK\r").unwrap(), None);
        assert_eq!(parse_frame(b"$5\r\nhel").unwrap(), None);
        assert_eq!(parse_frame(b"*2\r\n+OK\r\n").unwrap(), None);
    }

    // An unsupported type byte is rejected and echoed back in the error.
    #[test]
    fn invalid_prefix() {
        let err = parse_frame(b"~invalid\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidPrefix(b'~'));
    }

    // Non-numeric content after `:` is rejected.
    #[test]
    fn invalid_integer() {
        let err = parse_frame(b":abc\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidInteger);
    }

    // A negative bulk length is an error; the parser does not treat `$-1` as null.
    #[test]
    fn negative_bulk_length() {
        let err = parse_frame(b"$-1\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(-1)));
    }

    // `consumed` covers only the parsed frame, so callers can keep the rest.
    #[test]
    fn parse_consumes_exact_bytes() {
        // buffer contains a full frame plus trailing garbage
        let buf = b"+OK\r\ntrailing";
        let (frame, consumed) = parse_frame(buf).unwrap().unwrap();
        assert_eq!(frame, Frame::Simple("OK".into()));
        assert_eq!(consumed, 5);
    }

    // One level past the nesting limit must fail with NestingTooDeep.
    #[test]
    fn deeply_nested_array_rejected() {
        // build a frame nested 65 levels deep (exceeds MAX_NESTING_DEPTH of 64)
        let mut buf = Vec::new();
        for _ in 0..65 {
            buf.extend_from_slice(b"*1\r\n");
        }
        buf.extend_from_slice(b":1\r\n"); // leaf value

        let err = parse_frame(&buf).unwrap_err();
        assert!(
            matches!(err, ProtocolError::NestingTooDeep(64)),
            "expected NestingTooDeep, got {err:?}"
        );
    }

    // Nesting exactly at the limit is still accepted.
    #[test]
    fn nesting_at_limit_accepted() {
        // exactly 64 levels deep — should succeed
        let mut buf = Vec::new();
        for _ in 0..64 {
            buf.extend_from_slice(b"*1\r\n");
        }
        buf.extend_from_slice(b":1\r\n");

        let result = parse_frame(&buf);
        assert!(result.is_ok(), "64 levels of nesting should be accepted");
        assert!(result.unwrap().is_some());
    }

    // The Bytes entry point yields the same frame value as the copying one.
    #[test]
    fn zerocopy_bulk_string() {
        let input = Bytes::from_static(b"$5\r\nhello\r\n");
        assert_eq!(
            must_parse_zerocopy(&input),
            Frame::Bulk(Bytes::from_static(b"hello"))
        );
    }

    // Bulk strings nested inside an array also parse via the Bytes entry point.
    #[test]
    fn zerocopy_array() {
        let input = Bytes::from_static(b"*2\r\n$3\r\nGET\r\n$5\r\nmykey\r\n");
        let frame = must_parse_zerocopy(&input);
        assert_eq!(
            frame,
            Frame::Array(vec![
                Frame::Bulk(Bytes::from_static(b"GET")),
                Frame::Bulk(Bytes::from_static(b"mykey")),
            ])
        );
    }

    // Direct checks of the integer helper, including both i64 extremes.
    #[test]
    fn parse_i64_bytes_valid() {
        assert_eq!(parse_i64_bytes(b"0").unwrap(), 0);
        assert_eq!(parse_i64_bytes(b"42").unwrap(), 42);
        assert_eq!(parse_i64_bytes(b"-1").unwrap(), -1);
        assert_eq!(parse_i64_bytes(b"9223372036854775807").unwrap(), i64::MAX);
        assert_eq!(parse_i64_bytes(b"-9223372036854775808").unwrap(), i64::MIN);
    }

    // Empty input, a lone sign, and non-digit bytes are all rejected.
    #[test]
    fn parse_i64_bytes_invalid() {
        assert!(parse_i64_bytes(b"").is_err());
        assert!(parse_i64_bytes(b"-").is_err());
        assert!(parse_i64_bytes(b"abc").is_err());
        assert!(parse_i64_bytes(b"12a").is_err());
    }
}
573}