Skip to main content

ember_protocol/
parse.rs

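//! RESP3 parser. Frames are located by scanning the input buffer in place;
//! decoded payloads (simple strings, errors, bulk data) are copied into owned values.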
2//!
3//! Operates on buffered byte slices. The caller is responsible for reading
4//! data from the network into a buffer — this parser is purely synchronous.
5//!
6//! The parser uses a `Cursor<&[u8]>` to track its position through the
7//! input buffer without consuming it, allowing the caller to retry once
8//! more data arrives.
9//!
10//! # Single-pass design
11//!
12//! Earlier versions used a two-pass approach: `check()` to validate a
13//! complete frame exists, then `parse()` to build Frame values. This
14//! scanned every byte twice. The current implementation does a single
15//! pass that builds Frame values directly, returning `Incomplete` if
16//! the buffer doesn't contain enough data yet.
17
18use std::io::Cursor;
19
20use bytes::Bytes;
21
22use crate::error::ProtocolError;
23use crate::types::Frame;
24
/// Deepest allowed nesting of aggregate frames (arrays and maps).
///
/// The parser recurses once per nesting level, so bounding the depth keeps
/// malicious or malformed input from exhausting the stack.
const MAX_NESTING_DEPTH: usize = 64;

/// Largest element count accepted for a single array or map.
///
/// An element can be as small as three bytes on the wire (`_\r\n`) yet still
/// occupy a full `Frame` in memory; this bound stops a tiny payload from
/// forcing a disproportionately large `Vec` allocation.
const MAX_ARRAY_ELEMENTS: usize = 1 << 20;

/// Largest bulk-string payload accepted, in bytes (512 MiB, matching Redis).
const MAX_BULK_LEN: i64 = 512 << 20;

/// Upper bound on the capacity reserved up front for an array or map.
///
/// Trusting the declared count would reserve `count * size_of::<Frame>()`
/// bytes before a single child had been parsed. Reserving at most this many
/// slots keeps the initial allocation small; the `Vec` still grows as
/// elements actually arrive.
const PREALLOC_CAP: usize = 1 << 10;
42
43/// Checks whether `buf` contains a complete RESP3 frame and parses it.
44///
45/// Returns `Ok(Some(frame))` if a complete frame was parsed,
46/// `Ok(None)` if the buffer doesn't contain enough data yet,
47/// or `Err(...)` if the data is malformed.
48#[inline]
49pub fn parse_frame(buf: &[u8]) -> Result<Option<(Frame, usize)>, ProtocolError> {
50    if buf.is_empty() {
51        return Ok(None);
52    }
53
54    let mut cursor = Cursor::new(buf);
55
56    match try_parse(&mut cursor, 0) {
57        Ok(frame) => {
58            let consumed = cursor.position() as usize;
59            Ok(Some((frame, consumed)))
60        }
61        Err(ProtocolError::Incomplete) => Ok(None),
62        Err(e) => Err(e),
63    }
64}
65
66// ---------------------------------------------------------------------------
67// single-pass parser: validates and builds Frame values in one traversal
68// ---------------------------------------------------------------------------
69
70/// Parses a complete RESP3 frame from the cursor position, returning
71/// `Incomplete` if the buffer doesn't contain enough data.
72fn try_parse(cursor: &mut Cursor<&[u8]>, depth: usize) -> Result<Frame, ProtocolError> {
73    let prefix = read_byte(cursor)?;
74
75    match prefix {
76        b'+' => {
77            let line = read_line(cursor)?;
78            let s = std::str::from_utf8(line).map_err(|_| {
79                ProtocolError::InvalidCommandFrame("invalid utf-8 in simple string".into())
80            })?;
81            Ok(Frame::Simple(s.to_owned()))
82        }
83        b'-' => {
84            let line = read_line(cursor)?;
85            let s = std::str::from_utf8(line).map_err(|_| {
86                ProtocolError::InvalidCommandFrame("invalid utf-8 in error string".into())
87            })?;
88            Ok(Frame::Error(s.to_owned()))
89        }
90        b':' => {
91            let val = read_integer_line(cursor)?;
92            Ok(Frame::Integer(val))
93        }
94        b'$' => {
95            let len = read_integer_line(cursor)?;
96            if len < 0 {
97                return Err(ProtocolError::InvalidFrameLength(len));
98            }
99            if len > MAX_BULK_LEN {
100                return Err(ProtocolError::BulkStringTooLarge(len as usize));
101            }
102            let len = len as usize;
103
104            // need `len` bytes of data + \r\n
105            let remaining = remaining(cursor);
106            if remaining < len + 2 {
107                return Err(ProtocolError::Incomplete);
108            }
109
110            let pos = cursor.position() as usize;
111            let buf = cursor.get_ref();
112
113            // verify trailing \r\n
114            if buf[pos + len] != b'\r' || buf[pos + len + 1] != b'\n' {
115                return Err(ProtocolError::InvalidFrameLength(len as i64));
116            }
117
118            let data = &buf[pos..pos + len];
119            cursor.set_position((pos + len + 2) as u64);
120            Ok(Frame::Bulk(Bytes::copy_from_slice(data)))
121        }
122        b'*' => {
123            let next_depth = depth + 1;
124            if next_depth > MAX_NESTING_DEPTH {
125                return Err(ProtocolError::NestingTooDeep(MAX_NESTING_DEPTH));
126            }
127
128            let count = read_integer_line(cursor)?;
129            if count < 0 {
130                return Err(ProtocolError::InvalidFrameLength(count));
131            }
132            if count as usize > MAX_ARRAY_ELEMENTS {
133                return Err(ProtocolError::TooManyElements(count as usize));
134            }
135
136            let count = count as usize;
137            let mut frames = Vec::with_capacity(count.min(PREALLOC_CAP));
138            for _ in 0..count {
139                frames.push(try_parse(cursor, next_depth)?);
140            }
141            Ok(Frame::Array(frames))
142        }
143        b'_' => {
144            // consume the trailing \r\n
145            let _ = read_line(cursor)?;
146            Ok(Frame::Null)
147        }
148        b'%' => {
149            let next_depth = depth + 1;
150            if next_depth > MAX_NESTING_DEPTH {
151                return Err(ProtocolError::NestingTooDeep(MAX_NESTING_DEPTH));
152            }
153
154            let count = read_integer_line(cursor)?;
155            if count < 0 {
156                return Err(ProtocolError::InvalidFrameLength(count));
157            }
158            if count as usize > MAX_ARRAY_ELEMENTS {
159                return Err(ProtocolError::TooManyElements(count as usize));
160            }
161
162            let count = count as usize;
163            let mut pairs = Vec::with_capacity(count.min(PREALLOC_CAP));
164            for _ in 0..count {
165                let key = try_parse(cursor, next_depth)?;
166                let val = try_parse(cursor, next_depth)?;
167                pairs.push((key, val));
168            }
169            Ok(Frame::Map(pairs))
170        }
171        other => Err(ProtocolError::InvalidPrefix(other)),
172    }
173}
174
175// ---------------------------------------------------------------------------
176// low-level cursor helpers
177// ---------------------------------------------------------------------------
178
179fn read_byte(cursor: &mut Cursor<&[u8]>) -> Result<u8, ProtocolError> {
180    let pos = cursor.position() as usize;
181    if pos >= cursor.get_ref().len() {
182        return Err(ProtocolError::Incomplete);
183    }
184    cursor.set_position((pos + 1) as u64);
185    Ok(cursor.get_ref()[pos])
186}
187
188/// Returns the slice of bytes up to (but not including) the next `\r\n`,
189/// and advances the cursor past the `\r\n`.
190fn read_line<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], ProtocolError> {
191    let start = cursor.position() as usize;
192    let end = find_crlf(cursor)?;
193    Ok(&cursor.get_ref()[start..end])
194}
195
196/// Reads a line and parses it as an i64.
197fn read_integer_line(cursor: &mut Cursor<&[u8]>) -> Result<i64, ProtocolError> {
198    let line = read_line(cursor)?;
199    parse_i64(line)
200}
201
202/// Finds the next `\r\n` in the buffer starting from the cursor position.
203/// Returns the index of `\r` and advances the cursor past the `\n`.
204fn find_crlf(cursor: &mut Cursor<&[u8]>) -> Result<usize, ProtocolError> {
205    let buf = cursor.get_ref();
206    let start = cursor.position() as usize;
207
208    if start >= buf.len() {
209        return Err(ProtocolError::Incomplete);
210    }
211
212    // SIMD-accelerated scan for \r, then verify \n follows.
213    // memchr processes 16-32 bytes per cycle vs 1 byte in a naive loop.
214    let mut pos = start;
215    while let Some(offset) = memchr::memchr(b'\r', &buf[pos..]) {
216        let cr = pos + offset;
217        if cr + 1 < buf.len() && buf[cr + 1] == b'\n' {
218            cursor.set_position((cr + 2) as u64);
219            return Ok(cr);
220        }
221        // bare \r without \n — keep scanning past it
222        pos = cr + 1;
223    }
224
225    Err(ProtocolError::Incomplete)
226}
227
228fn remaining(cursor: &Cursor<&[u8]>) -> usize {
229    let len = cursor.get_ref().len();
230    let pos = cursor.position() as usize;
231    len.saturating_sub(pos)
232}
233
234fn parse_i64(buf: &[u8]) -> Result<i64, ProtocolError> {
235    let s = std::str::from_utf8(buf).map_err(|_| ProtocolError::InvalidInteger)?;
236    s.parse::<i64>().map_err(|_| ProtocolError::InvalidInteger)
237}
238
239#[cfg(test)]
240mod tests {
241    use super::*;
242
243    fn must_parse(input: &[u8]) -> Frame {
244        let (frame, consumed) = parse_frame(input)
245            .expect("parse should not error")
246            .expect("parse should return a frame");
247        assert_eq!(consumed, input.len(), "should consume entire input");
248        frame
249    }
250
251    #[test]
252    fn simple_string() {
253        assert_eq!(must_parse(b"+OK\r\n"), Frame::Simple("OK".into()));
254        assert_eq!(
255            must_parse(b"+hello world\r\n"),
256            Frame::Simple("hello world".into())
257        );
258    }
259
260    #[test]
261    fn simple_error() {
262        assert_eq!(
263            must_parse(b"-ERR unknown command\r\n"),
264            Frame::Error("ERR unknown command".into())
265        );
266    }
267
268    #[test]
269    fn integer() {
270        assert_eq!(must_parse(b":42\r\n"), Frame::Integer(42));
271        assert_eq!(must_parse(b":0\r\n"), Frame::Integer(0));
272        assert_eq!(must_parse(b":-1\r\n"), Frame::Integer(-1));
273        assert_eq!(
274            must_parse(b":9223372036854775807\r\n"),
275            Frame::Integer(i64::MAX)
276        );
277        assert_eq!(
278            must_parse(b":-9223372036854775808\r\n"),
279            Frame::Integer(i64::MIN)
280        );
281    }
282
283    #[test]
284    fn bulk_string() {
285        assert_eq!(
286            must_parse(b"$5\r\nhello\r\n"),
287            Frame::Bulk(Bytes::from_static(b"hello"))
288        );
289    }
290
291    #[test]
292    fn empty_bulk_string() {
293        assert_eq!(
294            must_parse(b"$0\r\n\r\n"),
295            Frame::Bulk(Bytes::from_static(b""))
296        );
297    }
298
299    #[test]
300    fn bulk_string_with_binary() {
301        let input = b"$4\r\n\x00\x01\x02\x03\r\n";
302        assert_eq!(
303            must_parse(input),
304            Frame::Bulk(Bytes::copy_from_slice(&[0, 1, 2, 3]))
305        );
306    }
307
308    #[test]
309    fn null() {
310        assert_eq!(must_parse(b"_\r\n"), Frame::Null);
311    }
312
313    #[test]
314    fn array() {
315        let input = b"*2\r\n+hello\r\n+world\r\n";
316        assert_eq!(
317            must_parse(input),
318            Frame::Array(vec![
319                Frame::Simple("hello".into()),
320                Frame::Simple("world".into()),
321            ])
322        );
323    }
324
325    #[test]
326    fn empty_array() {
327        assert_eq!(must_parse(b"*0\r\n"), Frame::Array(vec![]));
328    }
329
330    #[test]
331    fn nested_array() {
332        let input = b"*2\r\n*2\r\n:1\r\n:2\r\n*2\r\n:3\r\n:4\r\n";
333        assert_eq!(
334            must_parse(input),
335            Frame::Array(vec![
336                Frame::Array(vec![Frame::Integer(1), Frame::Integer(2)]),
337                Frame::Array(vec![Frame::Integer(3), Frame::Integer(4)]),
338            ])
339        );
340    }
341
342    #[test]
343    fn array_with_null() {
344        let input = b"*3\r\n+OK\r\n_\r\n:1\r\n";
345        assert_eq!(
346            must_parse(input),
347            Frame::Array(vec![
348                Frame::Simple("OK".into()),
349                Frame::Null,
350                Frame::Integer(1),
351            ])
352        );
353    }
354
355    #[test]
356    fn map() {
357        let input = b"%2\r\n+key1\r\n:1\r\n+key2\r\n:2\r\n";
358        assert_eq!(
359            must_parse(input),
360            Frame::Map(vec![
361                (Frame::Simple("key1".into()), Frame::Integer(1)),
362                (Frame::Simple("key2".into()), Frame::Integer(2)),
363            ])
364        );
365    }
366
367    #[test]
368    fn incomplete_returns_none() {
369        assert_eq!(parse_frame(b"").unwrap(), None);
370        assert_eq!(parse_frame(b"+OK").unwrap(), None);
371        assert_eq!(parse_frame(b"+OK\r").unwrap(), None);
372        assert_eq!(parse_frame(b"$5\r\nhel").unwrap(), None);
373        assert_eq!(parse_frame(b"*2\r\n+OK\r\n").unwrap(), None);
374    }
375
376    #[test]
377    fn invalid_prefix() {
378        let err = parse_frame(b"~invalid\r\n").unwrap_err();
379        assert_eq!(err, ProtocolError::InvalidPrefix(b'~'));
380    }
381
382    #[test]
383    fn invalid_integer() {
384        let err = parse_frame(b":abc\r\n").unwrap_err();
385        assert_eq!(err, ProtocolError::InvalidInteger);
386    }
387
388    #[test]
389    fn negative_bulk_length() {
390        let err = parse_frame(b"$-1\r\n").unwrap_err();
391        assert!(matches!(err, ProtocolError::InvalidFrameLength(-1)));
392    }
393
394    #[test]
395    fn parse_consumes_exact_bytes() {
396        // buffer contains a full frame plus trailing garbage
397        let buf = b"+OK\r\ntrailing";
398        let (frame, consumed) = parse_frame(buf).unwrap().unwrap();
399        assert_eq!(frame, Frame::Simple("OK".into()));
400        assert_eq!(consumed, 5);
401    }
402
403    #[test]
404    fn deeply_nested_array_rejected() {
405        // build a frame nested 65 levels deep (exceeds MAX_NESTING_DEPTH of 64)
406        let mut buf = Vec::new();
407        for _ in 0..65 {
408            buf.extend_from_slice(b"*1\r\n");
409        }
410        buf.extend_from_slice(b":1\r\n"); // leaf value
411
412        let err = parse_frame(&buf).unwrap_err();
413        assert!(
414            matches!(err, ProtocolError::NestingTooDeep(64)),
415            "expected NestingTooDeep, got {err:?}"
416        );
417    }
418
419    #[test]
420    fn nesting_at_limit_accepted() {
421        // exactly 64 levels deep — should succeed
422        let mut buf = Vec::new();
423        for _ in 0..64 {
424            buf.extend_from_slice(b"*1\r\n");
425        }
426        buf.extend_from_slice(b":1\r\n");
427
428        let result = parse_frame(&buf);
429        assert!(result.is_ok(), "64 levels of nesting should be accepted");
430        assert!(result.unwrap().is_some());
431    }
432}