Skip to main content

resp_rs/
resp2.rs

//! RESP2 protocol parser and serializer.
//!
//! RESP2 supports five data types:
//! - Simple String: `+OK\r\n`
//! - Error: `-ERR message\r\n`
//! - Integer: `:42\r\n`
//! - Bulk String: `$6\r\nfoobar\r\n` (or `$-1\r\n` for null)
//! - Array: `*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n` (or `*-1\r\n` for null)

10use bytes::{BufMut, Bytes, BytesMut};
11
12use crate::ParseError;
13
/// Upper bound on the element count accepted in an array header.
///
/// Larger counts are rejected with `ParseError::BadLength` before any element is parsed, so a
/// hostile peer cannot make the parser commit to an enormous collection (DoS protection).
const MAX_COLLECTION_SIZE: usize = 10 * 1000 * 1000;
16
17/// A parsed RESP2 frame.
18#[derive(Debug, Clone, PartialEq)]
19pub enum Frame {
20    /// Simple string: `+OK\r\n`
21    SimpleString(Bytes),
22    /// Error: `-ERR message\r\n`
23    Error(Bytes),
24    /// Integer: `:42\r\n`
25    Integer(i64),
26    /// Bulk string: `$6\r\nfoobar\r\n`
27    BulkString(Option<Bytes>),
28    /// Array: `*N\r\n...`
29    Array(Option<Vec<Frame>>),
30}
31
32/// Parse a single RESP2 frame from the provided bytes.
33///
34/// Returns the parsed frame and any remaining unconsumed bytes.
35///
36/// # Errors
37///
38/// Returns `ParseError::Incomplete` if there isn't enough data for a complete frame.
39/// Returns other `ParseError` variants for malformed input.
40///
41/// # Examples
42///
43/// ```
44/// use bytes::Bytes;
45/// use resp_rs::resp2::parse_frame;
46///
47/// let data = Bytes::from("+OK\r\nrest");
48/// let (frame, rest) = parse_frame(data).unwrap();
49/// assert_eq!(rest, Bytes::from("rest"));
50/// ```
51pub fn parse_frame(input: Bytes) -> Result<(Frame, Bytes), ParseError> {
52    let (frame, consumed) = parse_frame_inner(&input, 0)?;
53    Ok((frame, input.slice(consumed..)))
54}
55
56/// Offset-based internal parser. Works with byte positions to avoid creating
57/// intermediate `Bytes::slice()` objects. Only slices for actual frame data.
58fn parse_frame_inner(input: &Bytes, pos: usize) -> Result<(Frame, usize), ParseError> {
59    let buf = input.as_ref();
60    if pos >= buf.len() {
61        return Err(ParseError::Incomplete);
62    }
63
64    let tag = buf[pos];
65
66    match tag {
67        b'+' => {
68            let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
69            Ok((
70                Frame::SimpleString(input.slice(pos + 1..line_end)),
71                after_crlf,
72            ))
73        }
74        b'-' => {
75            let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
76            Ok((Frame::Error(input.slice(pos + 1..line_end)), after_crlf))
77        }
78        b':' => {
79            let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
80            let v = parse_i64(&buf[pos + 1..line_end])?;
81            Ok((Frame::Integer(v), after_crlf))
82        }
83        b'$' => {
84            let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
85            let len_bytes = &buf[pos + 1..line_end];
86            // null bulk string: $-1\r\n
87            if len_bytes == b"-1" {
88                return Ok((Frame::BulkString(None), after_crlf));
89            }
90            let len = parse_usize(len_bytes)?;
91            if len == 0 {
92                if after_crlf + 1 >= buf.len() {
93                    return Err(ParseError::Incomplete);
94                }
95                if buf[after_crlf] == b'\r' && buf[after_crlf + 1] == b'\n' {
96                    return Ok((Frame::BulkString(Some(Bytes::new())), after_crlf + 2));
97                } else {
98                    return Err(ParseError::InvalidFormat);
99                }
100            }
101            let data_start = after_crlf;
102            let data_end = data_start.checked_add(len).ok_or(ParseError::BadLength)?;
103            if data_end + 1 >= buf.len() || buf[data_end] != b'\r' || buf[data_end + 1] != b'\n' {
104                return Err(ParseError::Incomplete);
105            }
106            Ok((
107                Frame::BulkString(Some(input.slice(data_start..data_end))),
108                data_end + 2,
109            ))
110        }
111        b'*' => {
112            let (line_end, after_crlf) = find_crlf(buf, pos + 1)?;
113            let len_bytes = &buf[pos + 1..line_end];
114            // null array: *-1\r\n
115            if len_bytes == b"-1" {
116                return Ok((Frame::Array(None), after_crlf));
117            }
118            let count = parse_usize(len_bytes)?;
119            if count > MAX_COLLECTION_SIZE {
120                return Err(ParseError::BadLength);
121            }
122            if count == 0 {
123                return Ok((Frame::Array(Some(Vec::new())), after_crlf));
124            }
125            let mut cursor = after_crlf;
126            let mut items = Vec::with_capacity(count);
127            for _ in 0..count {
128                let (item, next) = parse_frame_inner(input, cursor)?;
129                items.push(item);
130                cursor = next;
131            }
132            Ok((Frame::Array(Some(items)), cursor))
133        }
134        _ => Err(ParseError::InvalidTag(tag)),
135    }
136}
137
138/// Serialize a RESP2 frame to bytes.
139///
140/// # Examples
141///
142/// ```
143/// use bytes::Bytes;
144/// use resp_rs::resp2::{Frame, frame_to_bytes};
145///
146/// let frame = Frame::SimpleString(Bytes::from("OK"));
147/// assert_eq!(frame_to_bytes(&frame), Bytes::from("+OK\r\n"));
148/// ```
149pub fn frame_to_bytes(frame: &Frame) -> Bytes {
150    let mut buf = BytesMut::new();
151    serialize_frame(frame, &mut buf);
152    buf.freeze()
153}
154
155fn serialize_frame(frame: &Frame, buf: &mut BytesMut) {
156    match frame {
157        Frame::SimpleString(s) => {
158            buf.put_u8(b'+');
159            buf.extend_from_slice(s);
160            buf.extend_from_slice(b"\r\n");
161        }
162        Frame::Error(s) => {
163            buf.put_u8(b'-');
164            buf.extend_from_slice(s);
165            buf.extend_from_slice(b"\r\n");
166        }
167        Frame::Integer(i) => {
168            buf.put_u8(b':');
169            buf.extend_from_slice(i.to_string().as_bytes());
170            buf.extend_from_slice(b"\r\n");
171        }
172        Frame::BulkString(opt) => {
173            buf.put_u8(b'$');
174            match opt {
175                Some(data) => {
176                    buf.extend_from_slice(data.len().to_string().as_bytes());
177                    buf.extend_from_slice(b"\r\n");
178                    buf.extend_from_slice(data);
179                    buf.extend_from_slice(b"\r\n");
180                }
181                None => buf.extend_from_slice(b"-1\r\n"),
182            }
183        }
184        Frame::Array(opt) => {
185            buf.put_u8(b'*');
186            match opt {
187                Some(items) => {
188                    buf.extend_from_slice(items.len().to_string().as_bytes());
189                    buf.extend_from_slice(b"\r\n");
190                    for item in items {
191                        serialize_frame(item, buf);
192                    }
193                }
194                None => buf.extend_from_slice(b"-1\r\n"),
195            }
196        }
197    }
198}
199
200/// Streaming RESP2 parser.
201///
202/// Feed data incrementally and extract frames as they become available.
203///
204/// # Examples
205///
206/// ```
207/// use bytes::Bytes;
208/// use resp_rs::resp2::{Parser, Frame};
209///
210/// let mut parser = Parser::new();
211/// parser.feed(Bytes::from("+HEL"));
212/// assert!(parser.next_frame().unwrap().is_none());
213///
214/// parser.feed(Bytes::from("LO\r\n"));
215/// let frame = parser.next_frame().unwrap().unwrap();
216/// assert_eq!(frame, Frame::SimpleString(Bytes::from("HELLO")));
217/// ```
218#[derive(Default, Debug)]
219pub struct Parser {
220    buffer: BytesMut,
221}
222
223impl Parser {
224    /// Create a new empty parser.
225    pub fn new() -> Self {
226        Self {
227            buffer: BytesMut::new(),
228        }
229    }
230
231    /// Feed data into the parser buffer.
232    pub fn feed(&mut self, data: Bytes) {
233        self.buffer.extend_from_slice(&data);
234    }
235
236    /// Try to extract the next complete frame.
237    ///
238    /// Returns `Ok(None)` if there isn't enough data yet.
239    /// Returns `Err` on protocol errors (buffer is cleared).
240    pub fn next_frame(&mut self) -> Result<Option<Frame>, ParseError> {
241        if self.buffer.is_empty() {
242            return Ok(None);
243        }
244
245        let bytes = self.buffer.split().freeze();
246
247        match parse_frame_inner(&bytes, 0) {
248            Ok((frame, consumed)) => {
249                if consumed < bytes.len() {
250                    self.buffer.unsplit(BytesMut::from(&bytes[consumed..]));
251                }
252                Ok(Some(frame))
253            }
254            Err(ParseError::Incomplete) => {
255                self.buffer.unsplit(bytes.into());
256                Ok(None)
257            }
258            Err(e) => Err(e),
259        }
260    }
261
262    /// Number of bytes currently buffered.
263    pub fn buffered_bytes(&self) -> usize {
264        self.buffer.len()
265    }
266
267    /// Clear the internal buffer.
268    pub fn clear(&mut self) {
269        self.buffer.clear();
270    }
271}
272
273/// Find `\r\n` in `buf` starting at `from`. Returns `(line_end, after_crlf)` where
274/// `line_end` is the position of `\r` and `after_crlf` is the position after `\n`.
275#[inline]
276fn find_crlf(buf: &[u8], from: usize) -> Result<(usize, usize), ParseError> {
277    let mut i = from;
278    let len = buf.len();
279    while i + 1 < len {
280        if buf[i] == b'\r' && buf[i + 1] == b'\n' {
281            return Ok((i, i + 2));
282        }
283        i += 1;
284    }
285    Err(ParseError::Incomplete)
286}
287
288/// Parse a `usize` directly from ASCII digit bytes, no UTF-8 validation needed.
289#[inline]
290fn parse_usize(buf: &[u8]) -> Result<usize, ParseError> {
291    if buf.is_empty() {
292        return Err(ParseError::BadLength);
293    }
294    let mut v: usize = 0;
295    for &b in buf {
296        if !b.is_ascii_digit() {
297            return Err(ParseError::BadLength);
298        }
299        v = v.checked_mul(10).ok_or(ParseError::BadLength)?;
300        v = v
301            .checked_add((b - b'0') as usize)
302            .ok_or(ParseError::BadLength)?;
303    }
304    Ok(v)
305}
306
307/// Parse an `i64` directly from ASCII bytes (optional leading `-`), no UTF-8 validation.
308#[inline]
309fn parse_i64(buf: &[u8]) -> Result<i64, ParseError> {
310    if buf.is_empty() {
311        return Err(ParseError::InvalidFormat);
312    }
313    let (neg, digits) = if buf[0] == b'-' {
314        (true, &buf[1..])
315    } else {
316        (false, buf)
317    };
318    if digits.is_empty() {
319        return Err(ParseError::InvalidFormat);
320    }
321    let mut v: i64 = 0;
322    for (i, &d) in digits.iter().enumerate() {
323        if !d.is_ascii_digit() {
324            return Err(ParseError::InvalidFormat);
325        }
326        let digit = (d - b'0') as i64;
327        if neg && v == i64::MAX / 10 && digit == 8 && i == digits.len() - 1 {
328            return Ok(i64::MIN);
329        }
330        if v > i64::MAX / 10 || (v == i64::MAX / 10 && digit > i64::MAX % 10) {
331            return Err(ParseError::Overflow);
332        }
333        v = v * 10 + digit;
334    }
335    if neg { Ok(-v) } else { Ok(v) }
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn simple_string() {
        let (frame, rest) = parse_frame(Bytes::from("+OK\r\nrest")).unwrap();
        assert_eq!(frame, Frame::SimpleString(Bytes::from("OK")));
        assert_eq!(rest, Bytes::from("rest"));
    }

    #[test]
    fn error() {
        let (frame, _) = parse_frame(Bytes::from("-ERR fail\r\n")).unwrap();
        assert_eq!(frame, Frame::Error(Bytes::from("ERR fail")));
    }

    #[test]
    fn integer() {
        let (frame, _) = parse_frame(Bytes::from(":42\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(42));

        let (frame, _) = parse_frame(Bytes::from(":-123\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(-123));
    }

    #[test]
    fn integer_bounds() {
        let (frame, _) = parse_frame(Bytes::from(":9223372036854775807\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(i64::MAX));

        let (frame, _) = parse_frame(Bytes::from(":-9223372036854775808\r\n")).unwrap();
        assert_eq!(frame, Frame::Integer(i64::MIN));

        assert_eq!(
            parse_frame(Bytes::from(":9223372036854775808\r\n")),
            Err(ParseError::Overflow)
        );
        assert_eq!(
            parse_frame(Bytes::from(":-9223372036854775809\r\n")),
            Err(ParseError::Overflow)
        );
    }

    #[test]
    fn malformed_integer() {
        assert_eq!(
            parse_frame(Bytes::from(":\r\n")),
            Err(ParseError::InvalidFormat)
        );
        assert_eq!(
            parse_frame(Bytes::from(":-\r\n")),
            Err(ParseError::InvalidFormat)
        );
        assert_eq!(
            parse_frame(Bytes::from(":12a\r\n")),
            Err(ParseError::InvalidFormat)
        );
    }

    #[test]
    fn bulk_string() {
        let (frame, rest) = parse_frame(Bytes::from("$5\r\nhello\r\nX")).unwrap();
        assert_eq!(frame, Frame::BulkString(Some(Bytes::from("hello"))));
        assert_eq!(rest, Bytes::from("X"));
    }

    #[test]
    fn null_bulk_string() {
        let (frame, _) = parse_frame(Bytes::from("$-1\r\n")).unwrap();
        assert_eq!(frame, Frame::BulkString(None));
    }

    #[test]
    fn empty_bulk_string() {
        let (frame, rest) = parse_frame(Bytes::from("$0\r\n\r\nX")).unwrap();
        assert_eq!(frame, Frame::BulkString(Some(Bytes::new())));
        assert_eq!(rest, Bytes::from("X"));
    }

    #[test]
    fn empty_bulk_string_bad_terminator() {
        assert_eq!(
            parse_frame(Bytes::from("$0\r\nXY")),
            Err(ParseError::InvalidFormat)
        );
        assert_eq!(
            parse_frame(Bytes::from("$0\r\n\r")),
            Err(ParseError::Incomplete)
        );
    }

    #[test]
    fn malformed_lengths() {
        assert_eq!(
            parse_frame(Bytes::from("$abc\r\n")),
            Err(ParseError::BadLength)
        );
        assert_eq!(
            parse_frame(Bytes::from("$-2\r\n")),
            Err(ParseError::BadLength)
        );
        assert_eq!(
            parse_frame(Bytes::from("*-2\r\n")),
            Err(ParseError::BadLength)
        );
        assert_eq!(
            parse_frame(Bytes::from("*10000001\r\n")),
            Err(ParseError::BadLength)
        );
    }

    #[test]
    fn array() {
        let input = Bytes::from("*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n");
        let (frame, _) = parse_frame(input).unwrap();
        assert_eq!(
            frame,
            Frame::Array(Some(vec![
                Frame::BulkString(Some(Bytes::from("foo"))),
                Frame::BulkString(Some(Bytes::from("bar"))),
            ]))
        );
    }

    #[test]
    fn null_array() {
        let (frame, _) = parse_frame(Bytes::from("*-1\r\n")).unwrap();
        assert_eq!(frame, Frame::Array(None));
    }

    #[test]
    fn empty_array() {
        let (frame, _) = parse_frame(Bytes::from("*0\r\n")).unwrap();
        assert_eq!(frame, Frame::Array(Some(vec![])));
    }

    #[test]
    fn nested_array() {
        let input = Bytes::from("*2\r\n*1\r\n:1\r\n+OK\r\n");
        let (frame, _) = parse_frame(input).unwrap();
        assert_eq!(
            frame,
            Frame::Array(Some(vec![
                Frame::Array(Some(vec![Frame::Integer(1)])),
                Frame::SimpleString(Bytes::from("OK")),
            ]))
        );
    }

    #[test]
    fn incomplete() {
        assert_eq!(parse_frame(Bytes::new()), Err(ParseError::Incomplete));
        assert_eq!(
            parse_frame(Bytes::from("+OK\r")),
            Err(ParseError::Incomplete)
        );
        assert_eq!(
            parse_frame(Bytes::from("$5\r\nhel")),
            Err(ParseError::Incomplete)
        );
    }

    #[test]
    fn incomplete_array() {
        assert_eq!(
            parse_frame(Bytes::from("*2\r\n")),
            Err(ParseError::Incomplete)
        );
        assert_eq!(
            parse_frame(Bytes::from("*2\r\n:1\r\n")),
            Err(ParseError::Incomplete)
        );
    }

    #[test]
    fn invalid_tag() {
        assert_eq!(
            parse_frame(Bytes::from("X\r\n")),
            Err(ParseError::InvalidTag(b'X'))
        );
    }

    #[test]
    fn roundtrip() {
        let frames = vec![
            Frame::SimpleString(Bytes::from("OK")),
            Frame::Error(Bytes::from("ERR bad")),
            Frame::Integer(42),
            Frame::Integer(-1),
            Frame::Integer(i64::MIN),
            Frame::BulkString(Some(Bytes::from("hello"))),
            Frame::BulkString(Some(Bytes::new())),
            Frame::BulkString(None),
            Frame::Array(Some(vec![
                Frame::Integer(1),
                Frame::BulkString(Some(Bytes::from("two"))),
            ])),
            Frame::Array(Some(vec![])),
            Frame::Array(Some(vec![Frame::Array(Some(vec![Frame::Array(None)]))])),
            Frame::Array(None),
        ];
        for frame in &frames {
            let bytes = frame_to_bytes(frame);
            let (parsed, rest) = parse_frame(bytes).unwrap();
            assert_eq!(&parsed, frame);
            assert!(rest.is_empty());
        }
    }

    #[test]
    fn streaming_parser() {
        let mut parser = Parser::new();
        parser.feed(Bytes::from("+HEL"));
        assert!(parser.next_frame().unwrap().is_none());

        parser.feed(Bytes::from("LO\r\n:42\r\n"));
        let f1 = parser.next_frame().unwrap().unwrap();
        assert_eq!(f1, Frame::SimpleString(Bytes::from("HELLO")));

        let f2 = parser.next_frame().unwrap().unwrap();
        assert_eq!(f2, Frame::Integer(42));

        assert!(parser.next_frame().unwrap().is_none());
    }

    #[test]
    fn streaming_bulk_string_across_feeds() {
        let mut parser = Parser::new();
        parser.feed(Bytes::from("$5\r\nhe"));
        assert!(parser.next_frame().unwrap().is_none());
        assert_eq!(parser.buffered_bytes(), 6);

        parser.feed(Bytes::from("llo\r\n"));
        assert_eq!(
            parser.next_frame().unwrap(),
            Some(Frame::BulkString(Some(Bytes::from("hello"))))
        );
        assert_eq!(parser.buffered_bytes(), 0);
    }

    #[test]
    fn streaming_error_clears_buffer() {
        let mut parser = Parser::new();
        parser.feed(Bytes::from("X\r\n+OK\r\n"));
        assert_eq!(parser.next_frame(), Err(ParseError::InvalidTag(b'X')));
        assert_eq!(parser.buffered_bytes(), 0);
        assert!(parser.next_frame().unwrap().is_none());
    }

    #[test]
    fn chained_frames() {
        let input = Bytes::from("+OK\r\n:1\r\n$3\r\nfoo\r\n");
        let (f1, rest) = parse_frame(input).unwrap();
        assert_eq!(f1, Frame::SimpleString(Bytes::from("OK")));
        let (f2, rest) = parse_frame(rest).unwrap();
        assert_eq!(f2, Frame::Integer(1));
        let (f3, rest) = parse_frame(rest).unwrap();
        assert_eq!(f3, Frame::BulkString(Some(Bytes::from("foo"))));
        assert!(rest.is_empty());
    }

    #[test]
    fn binary_bulk_string() {
        let mut data = Vec::new();
        data.extend_from_slice(b"$5\r\n");
        data.extend_from_slice(&[0x00, 0x01, 0xFF, 0xFE, 0x42]);
        data.extend_from_slice(b"\r\n");
        let (frame, _) = parse_frame(Bytes::from(data)).unwrap();
        match frame {
            Frame::BulkString(Some(b)) => {
                assert_eq!(b.as_ref(), &[0x00, 0x01, 0xFF, 0xFE, 0x42]);
            }
            _ => panic!("expected bulk string"),
        }
    }

    #[test]
    fn rejects_resp3_types() {
        // RESP3-only types should fail with InvalidTag in RESP2 mode
        assert!(parse_frame(Bytes::from("_\r\n")).is_err()); // Null
        assert!(parse_frame(Bytes::from(",3.14\r\n")).is_err()); // Double
        assert!(parse_frame(Bytes::from("#t\r\n")).is_err()); // Boolean
        assert!(parse_frame(Bytes::from("(123\r\n")).is_err()); // Big number
    }
}