Skip to main content

ember_protocol/
parse.rs

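//! RESP3 parser.
//!
//! Operates on buffered byte slices. The caller is responsible for reading
//! data from the network into a buffer — this parser is purely synchronous.
//!
//! Parsing happens in two passes. The first pass validates framing in place
//! without allocating. Only once a complete frame is known to be buffered
//! does the second pass decode it into an owned `Frame`; bulk payloads and
//! strings are copied out of the input buffer at that point.
//!
//! The parser uses a `Cursor<&[u8]>` to track its position through the
//! input buffer without consuming it, so the caller can retry once more
//! data arrives.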
9
10use std::io::Cursor;
11
12use bytes::Bytes;
13
14use crate::error::ProtocolError;
15use crate::types::Frame;
16
17/// Checks whether `buf` contains a complete RESP3 frame and parses it.
18///
19/// Returns `Ok(Some(frame))` if a complete frame was parsed,
20/// `Ok(None)` if the buffer doesn't contain enough data yet,
21/// or `Err(...)` if the data is malformed.
22pub fn parse_frame(buf: &[u8]) -> Result<Option<(Frame, usize)>, ProtocolError> {
23    if buf.is_empty() {
24        return Ok(None);
25    }
26
27    let mut cursor = Cursor::new(buf);
28
29    match check(&mut cursor) {
30        Ok(()) => {
31            // we know a complete frame exists — reset and parse it
32            cursor.set_position(0);
33            let frame = parse(&mut cursor)?;
34            let consumed = cursor.position() as usize;
35            Ok(Some((frame, consumed)))
36        }
37        Err(ProtocolError::Incomplete) => Ok(None),
38        Err(e) => Err(e),
39    }
40}
41
42// ---------------------------------------------------------------------------
43// check: validates a complete frame exists without allocating
44// ---------------------------------------------------------------------------
45
46/// Peeks through the buffer to verify a complete frame is present.
47/// Advances the cursor past the frame on success.
48fn check(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
49    let prefix = read_byte(cursor)?;
50
51    match prefix {
52        b'+' | b'-' => check_line(cursor),
53        b':' => check_line(cursor),
54        b'$' => check_bulk(cursor),
55        b'*' => check_array(cursor),
56        b'_' => check_line(cursor),
57        b'%' => check_map(cursor),
58        other => Err(ProtocolError::InvalidPrefix(other)),
59    }
60}
61
62fn check_line(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
63    find_crlf(cursor)?;
64    Ok(())
65}
66
67fn check_bulk(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
68    let len = read_integer_line(cursor)?;
69    if len < 0 {
70        return Err(ProtocolError::InvalidFrameLength(len));
71    }
72    let len = len as usize;
73
74    // need `len` bytes of data + \r\n
75    let remaining = remaining(cursor);
76    if remaining < len + 2 {
77        return Err(ProtocolError::Incomplete);
78    }
79
80    let pos = cursor.position() as usize;
81    // verify trailing \r\n
82    let buf = cursor.get_ref();
83    if buf[pos + len] != b'\r' || buf[pos + len + 1] != b'\n' {
84        return Err(ProtocolError::InvalidFrameLength(len as i64));
85    }
86
87    cursor.set_position((pos + len + 2) as u64);
88    Ok(())
89}
90
91fn check_array(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
92    let count = read_integer_line(cursor)?;
93    if count < 0 {
94        return Err(ProtocolError::InvalidFrameLength(count));
95    }
96
97    for _ in 0..count {
98        check(cursor)?;
99    }
100    Ok(())
101}
102
103fn check_map(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
104    let count = read_integer_line(cursor)?;
105    if count < 0 {
106        return Err(ProtocolError::InvalidFrameLength(count));
107    }
108
109    for _ in 0..count {
110        check(cursor)?; // key
111        check(cursor)?; // value
112    }
113    Ok(())
114}
115
116// ---------------------------------------------------------------------------
117// parse: actually builds Frame values (only called after check succeeds)
118// ---------------------------------------------------------------------------
119
120fn parse(cursor: &mut Cursor<&[u8]>) -> Result<Frame, ProtocolError> {
121    let prefix = read_byte(cursor)?;
122
123    match prefix {
124        b'+' => {
125            let line = read_line(cursor)?;
126            let s = std::str::from_utf8(line).map_err(|_| {
127                ProtocolError::InvalidCommandFrame("invalid utf-8 in simple string".into())
128            })?;
129            Ok(Frame::Simple(s.to_owned()))
130        }
131        b'-' => {
132            let line = read_line(cursor)?;
133            let s = std::str::from_utf8(line).map_err(|_| {
134                ProtocolError::InvalidCommandFrame("invalid utf-8 in error string".into())
135            })?;
136            Ok(Frame::Error(s.to_owned()))
137        }
138        b':' => {
139            let val = read_integer_line(cursor)?;
140            Ok(Frame::Integer(val))
141        }
142        b'$' => {
143            let len = read_integer_line(cursor)? as usize;
144            let pos = cursor.position() as usize;
145            let data = &cursor.get_ref()[pos..pos + len];
146            cursor.set_position((pos + len + 2) as u64); // skip data + \r\n
147            Ok(Frame::Bulk(Bytes::copy_from_slice(data)))
148        }
149        b'*' => {
150            let count = read_integer_line(cursor)? as usize;
151            let mut frames = Vec::with_capacity(count);
152            for _ in 0..count {
153                frames.push(parse(cursor)?);
154            }
155            Ok(Frame::Array(frames))
156        }
157        b'_' => {
158            // consume the trailing \r\n
159            let _ = read_line(cursor)?;
160            Ok(Frame::Null)
161        }
162        b'%' => {
163            let count = read_integer_line(cursor)? as usize;
164            let mut pairs = Vec::with_capacity(count);
165            for _ in 0..count {
166                let key = parse(cursor)?;
167                let val = parse(cursor)?;
168                pairs.push((key, val));
169            }
170            Ok(Frame::Map(pairs))
171        }
172        // check() already validated the prefix, so this shouldn't happen
173        other => Err(ProtocolError::InvalidPrefix(other)),
174    }
175}
176
177// ---------------------------------------------------------------------------
178// low-level cursor helpers
179// ---------------------------------------------------------------------------
180
181fn read_byte(cursor: &mut Cursor<&[u8]>) -> Result<u8, ProtocolError> {
182    let pos = cursor.position() as usize;
183    if pos >= cursor.get_ref().len() {
184        return Err(ProtocolError::Incomplete);
185    }
186    cursor.set_position((pos + 1) as u64);
187    Ok(cursor.get_ref()[pos])
188}
189
190/// Returns the slice of bytes up to (but not including) the next `\r\n`,
191/// and advances the cursor past the `\r\n`.
192fn read_line<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], ProtocolError> {
193    let start = cursor.position() as usize;
194    let end = find_crlf(cursor)?;
195    Ok(&cursor.get_ref()[start..end])
196}
197
198/// Reads a line and parses it as an i64.
199fn read_integer_line(cursor: &mut Cursor<&[u8]>) -> Result<i64, ProtocolError> {
200    let line = read_line(cursor)?;
201    parse_i64(line)
202}
203
204/// Finds the next `\r\n` in the buffer starting from the cursor position.
205/// Returns the index of `\r` and advances the cursor past the `\n`.
206fn find_crlf(cursor: &mut Cursor<&[u8]>) -> Result<usize, ProtocolError> {
207    let buf = cursor.get_ref();
208    let start = cursor.position() as usize;
209
210    if start >= buf.len() {
211        return Err(ProtocolError::Incomplete);
212    }
213
214    // scan for \r\n
215    for i in start..buf.len().saturating_sub(1) {
216        if buf[i] == b'\r' && buf[i + 1] == b'\n' {
217            cursor.set_position((i + 2) as u64);
218            return Ok(i);
219        }
220    }
221
222    Err(ProtocolError::Incomplete)
223}
224
/// Number of unread bytes after the cursor; zero if the cursor has been
/// positioned at or beyond the end of the buffer.
fn remaining(cursor: &Cursor<&[u8]>) -> usize {
    let consumed = cursor.position() as usize;
    cursor.get_ref().get(consumed..).map_or(0, |rest| rest.len())
}
230
231fn parse_i64(buf: &[u8]) -> Result<i64, ProtocolError> {
232    let s = std::str::from_utf8(buf).map_err(|_| ProtocolError::InvalidInteger)?;
233    s.parse::<i64>().map_err(|_| ProtocolError::InvalidInteger)
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, asserting that it yields exactly one frame that
    /// consumes the entire buffer.
    fn must_parse(input: &[u8]) -> Frame {
        let (frame, consumed) = parse_frame(input)
            .expect("parse should not error")
            .expect("parse should return a frame");
        assert_eq!(consumed, input.len(), "should consume entire input");
        frame
    }

    #[test]
    fn simple_string() {
        assert_eq!(must_parse(b"+OK\r\n"), Frame::Simple("OK".into()));
        assert_eq!(
            must_parse(b"+hello world\r\n"),
            Frame::Simple("hello world".into())
        );
    }

    #[test]
    fn simple_error() {
        assert_eq!(
            must_parse(b"-ERR unknown command\r\n"),
            Frame::Error("ERR unknown command".into())
        );
    }

    #[test]
    fn integer() {
        assert_eq!(must_parse(b":42\r\n"), Frame::Integer(42));
        assert_eq!(must_parse(b":0\r\n"), Frame::Integer(0));
        assert_eq!(must_parse(b":-1\r\n"), Frame::Integer(-1));
        assert_eq!(
            must_parse(b":9223372036854775807\r\n"),
            Frame::Integer(i64::MAX)
        );
        assert_eq!(
            must_parse(b":-9223372036854775808\r\n"),
            Frame::Integer(i64::MIN)
        );
    }

    #[test]
    fn bulk_string() {
        assert_eq!(
            must_parse(b"$5\r\nhello\r\n"),
            Frame::Bulk(Bytes::from_static(b"hello"))
        );
    }

    #[test]
    fn empty_bulk_string() {
        assert_eq!(
            must_parse(b"$0\r\n\r\n"),
            Frame::Bulk(Bytes::from_static(b""))
        );
    }

    #[test]
    fn bulk_string_with_binary() {
        let input = b"$4\r\n\x00\x01\x02\x03\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Bulk(Bytes::copy_from_slice(&[0, 1, 2, 3]))
        );
    }

    #[test]
    fn bulk_string_with_crlf_in_payload() {
        // the declared length, not a scan for \r\n, delimits bulk payloads
        assert_eq!(
            must_parse(b"$4\r\na\r\nb\r\n"),
            Frame::Bulk(Bytes::from_static(b"a\r\nb"))
        );
    }

    #[test]
    fn bulk_string_bad_terminator() {
        // enough bytes are present, but the payload isn't followed by \r\n
        let err = parse_frame(b"$3\r\nabcXY").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidFrameLength(3));
    }

    #[test]
    fn null() {
        assert_eq!(must_parse(b"_\r\n"), Frame::Null);
    }

    #[test]
    fn array() {
        let input = b"*2\r\n+hello\r\n+world\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("hello".into()),
                Frame::Simple("world".into()),
            ])
        );
    }

    #[test]
    fn empty_array() {
        assert_eq!(must_parse(b"*0\r\n"), Frame::Array(vec![]));
    }

    #[test]
    fn nested_array() {
        let input = b"*2\r\n*2\r\n:1\r\n:2\r\n*2\r\n:3\r\n:4\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Array(vec![Frame::Integer(1), Frame::Integer(2)]),
                Frame::Array(vec![Frame::Integer(3), Frame::Integer(4)]),
            ])
        );
    }

    #[test]
    fn array_with_null() {
        let input = b"*3\r\n+OK\r\n_\r\n:1\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("OK".into()),
                Frame::Null,
                Frame::Integer(1),
            ])
        );
    }

    #[test]
    fn negative_array_length() {
        let err = parse_frame(b"*-1\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(-1)));
    }

    #[test]
    fn map() {
        let input = b"%2\r\n+key1\r\n:1\r\n+key2\r\n:2\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Map(vec![
                (Frame::Simple("key1".into()), Frame::Integer(1)),
                (Frame::Simple("key2".into()), Frame::Integer(2)),
            ])
        );
    }

    #[test]
    fn incomplete_returns_none() {
        assert_eq!(parse_frame(b"").unwrap(), None);
        assert_eq!(parse_frame(b"+OK").unwrap(), None);
        assert_eq!(parse_frame(b"+OK\r").unwrap(), None);
        assert_eq!(parse_frame(b"$5\r\nhel").unwrap(), None);
        assert_eq!(parse_frame(b"*2\r\n+OK\r\n").unwrap(), None);
    }

    #[test]
    fn incomplete_map_returns_none() {
        // key present, value not yet received
        assert_eq!(parse_frame(b"%1\r\n+key\r\n").unwrap(), None);
    }

    #[test]
    fn invalid_prefix() {
        let err = parse_frame(b"~invalid\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidPrefix(b'~'));
    }

    #[test]
    fn invalid_integer() {
        let err = parse_frame(b":abc\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidInteger);
    }

    #[test]
    fn invalid_utf8_simple_string() {
        let err = parse_frame(b"+\xff\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn negative_bulk_length() {
        let err = parse_frame(b"$-1\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(-1)));
    }

    #[test]
    fn parse_consumes_exact_bytes() {
        // buffer contains a full frame plus trailing garbage
        let buf = b"+OK\r\ntrailing";
        let (frame, consumed) = parse_frame(buf).unwrap().unwrap();
        assert_eq!(frame, Frame::Simple("OK".into()));
        assert_eq!(consumed, 5);
    }

    #[test]
    fn pipelined_frames_parse_one_at_a_time() {
        let buf = b"+A\r\n:1\r\n";
        let (first, used) = parse_frame(buf).unwrap().unwrap();
        assert_eq!(first, Frame::Simple("A".into()));
        assert_eq!(used, 4);

        let (second, used_next) = parse_frame(&buf[used..]).unwrap().unwrap();
        assert_eq!(second, Frame::Integer(1));
        assert_eq!(used_next, 4);
    }
}