Skip to main content

fast_cache/protocol/
resp.rs

1use std::fmt;
2use std::ops::Range;
3
4use crate::{FastCacheError, Result};
5
/// A decoded RESP value.
///
/// Each variant corresponds to one wire form written by `RespCodec::encode`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Single-line text, written with a `+` prefix.
    SimpleString(String),
    /// Length-prefixed binary payload, written with a `$` prefix.
    BlobString(Vec<u8>),
    /// Signed 64-bit integer, written with a `:` prefix.
    Integer(i64),
    /// Sequence of nested frames, written with a `*` prefix.
    Array(Vec<Frame>),
    /// Absent value, written as `_\r\n`.
    Null,
    /// Boolean, written as `#t\r\n` or `#f\r\n`.
    Boolean(bool),
    /// Single-line error message, written with a `-` prefix.
    Error(String),
}
16
/// An owned command: the list of parts extracted from an array frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandFrame {
    /// Raw bytes of each command part, in the order they appeared in the array.
    pub parts: Vec<Vec<u8>>,
}
21
/// Inline storage for the borrowed command's parts. The benchmarked multi-key
/// shape is MGET with 8 keys or MSET with 8 key/value pairs, so inline enough
/// parts for those requests to avoid a per-command heap allocation.
pub type BorrowedCommandParts<'a> = smallvec::SmallVec<[&'a [u8]; 18]>;
/// Byte ranges of a command's parts within the buffer they were decoded from,
/// stored with the same inline capacity (18 entries) as [`BorrowedCommandParts`].
pub type CommandPartSpans = smallvec::SmallVec<[Range<usize>; 18]>;
27
/// A command whose parts are slices borrowed from the decoded input buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BorrowedCommandFrame<'a> {
    /// Borrowed bytes of each command part, in wire order.
    pub parts: BorrowedCommandParts<'a>,
}
32
/// A command described by byte ranges into the decoded input buffer rather
/// than by borrowed slices, so it carries no lifetime.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CommandSpanFrame {
    /// Range of each command part's payload, in wire order. Index the buffer
    /// that was passed to the decoder with these ranges to recover the bytes.
    pub parts: CommandPartSpans,
}
37
/// Outcome of decoding a frame: `None` when the buffer does not yet hold a
/// complete frame, otherwise the frame and the number of bytes it consumed.
pub type RespDecodeResult = Option<(Frame, usize)>;
/// Outcome of decoding a borrowed command: `None` when the buffer is
/// incomplete, otherwise the command and the number of bytes it consumed.
pub type RespCommandDecodeResult<'a> = Option<(BorrowedCommandFrame<'a>, usize)>;
/// Outcome of decoding a command as spans: `None` when the buffer is
/// incomplete, otherwise the spans and the number of bytes consumed.
pub type RespCommandSpanDecodeResult = Option<(CommandSpanFrame, usize)>;
41
/// Stateless RESP encoder/decoder; all functionality is exposed through
/// associated functions.
#[derive(Debug, Default, Clone, Copy)]
pub struct RespCodec;
44
45impl RespCodec {
46    pub fn decode(buffer: &[u8]) -> Result<RespDecodeResult> {
47        if buffer.is_empty() {
48            return Ok(None);
49        }
50        parse_frame(buffer, 0)
51    }
52
53    pub fn encode(frame: &Frame, out: &mut Vec<u8>) {
54        match frame {
55            Frame::SimpleString(value) => {
56                out.push(b'+');
57                out.extend_from_slice(value.as_bytes());
58                out.extend_from_slice(b"\r\n");
59            }
60            Frame::BlobString(value) => {
61                let mut buf = itoa::Buffer::new();
62                out.push(b'$');
63                out.extend_from_slice(buf.format(value.len()).as_bytes());
64                out.extend_from_slice(b"\r\n");
65                out.extend_from_slice(value);
66                out.extend_from_slice(b"\r\n");
67            }
68            Frame::Integer(value) => {
69                let mut buf = itoa::Buffer::new();
70                out.push(b':');
71                out.extend_from_slice(buf.format(*value).as_bytes());
72                out.extend_from_slice(b"\r\n");
73            }
74            Frame::Array(items) => {
75                let mut buf = itoa::Buffer::new();
76                out.push(b'*');
77                out.extend_from_slice(buf.format(items.len()).as_bytes());
78                out.extend_from_slice(b"\r\n");
79                for item in items {
80                    Self::encode(item, out);
81                }
82            }
83            Frame::Null => {
84                out.extend_from_slice(b"_\r\n");
85            }
86            Frame::Boolean(value) => {
87                out.extend_from_slice(if *value { b"#t\r\n" } else { b"#f\r\n" });
88            }
89            Frame::Error(message) => {
90                out.push(b'-');
91                out.extend_from_slice(message.as_bytes());
92                out.extend_from_slice(b"\r\n");
93            }
94        }
95    }
96
97    pub fn decode_command(buffer: &[u8]) -> Result<RespCommandDecodeResult<'_>> {
98        if buffer.is_empty() {
99            return Ok(None);
100        }
101        parse_command_frame(buffer, 0)
102    }
103
104    pub fn decode_command_spans(buffer: &[u8]) -> Result<RespCommandSpanDecodeResult> {
105        if buffer.is_empty() {
106            return Ok(None);
107        }
108        parse_command_span_frame(buffer, 0)
109    }
110
111    pub fn as_command(frame: Frame) -> Result<CommandFrame> {
112        match frame {
113            Frame::Array(parts) => {
114                let mut output = Vec::with_capacity(parts.len());
115                for part in parts {
116                    match part {
117                        Frame::BlobString(bytes) => output.push(bytes),
118                        Frame::SimpleString(text) => output.push(text.into_bytes()),
119                        Frame::Integer(value) => output.push(value.to_string().into_bytes()),
120                        other => {
121                            return Err(FastCacheError::Protocol(format!(
122                                "command arrays may only contain bulk strings, simple strings, or integers; got {other:?}"
123                            )));
124                        }
125                    }
126                }
127                Ok(CommandFrame { parts: output })
128            }
129            other => Err(FastCacheError::Protocol(format!(
130                "expected command array, got {other:?}"
131            ))),
132        }
133    }
134}
135
136fn parse_command_frame(buffer: &[u8], offset: usize) -> Result<RespCommandDecodeResult<'_>> {
137    if offset >= buffer.len() {
138        return Ok(None);
139    }
140    if buffer[offset] != b'*' {
141        return Err(FastCacheError::Protocol(
142            "expected RESP array for command frame".into(),
143        ));
144    }
145    let Some((count, header_consumed)) = parse_isize_line(&buffer[offset + 1..])? else {
146        return Ok(None);
147    };
148    if count < 0 {
149        return Err(FastCacheError::Protocol(
150            "null command arrays are not supported".into(),
151        ));
152    }
153
154    let mut cursor = offset + 1 + header_consumed;
155    let mut parts: BorrowedCommandParts<'_> = smallvec::SmallVec::with_capacity(count as usize);
156    for _ in 0..count as usize {
157        let Some((part, consumed)) = parse_command_part(buffer, cursor)? else {
158            return Ok(None);
159        };
160        parts.push(part);
161        cursor += consumed;
162    }
163
164    Ok(Some((BorrowedCommandFrame { parts }, cursor - offset)))
165}
166
167fn parse_command_span_frame(buffer: &[u8], offset: usize) -> Result<RespCommandSpanDecodeResult> {
168    if offset >= buffer.len() {
169        return Ok(None);
170    }
171    if buffer[offset] != b'*' {
172        return Err(FastCacheError::Protocol(
173            "expected RESP array for command frame".into(),
174        ));
175    }
176    let Some((count, header_consumed)) = parse_isize_line(&buffer[offset + 1..])? else {
177        return Ok(None);
178    };
179    if count < 0 {
180        return Err(FastCacheError::Protocol(
181            "null command arrays are not supported".into(),
182        ));
183    }
184
185    let mut cursor = offset + 1 + header_consumed;
186    let mut parts = CommandPartSpans::with_capacity(count as usize);
187    for _ in 0..count as usize {
188        let Some((part, consumed)) = parse_command_part_span(buffer, cursor)? else {
189            return Ok(None);
190        };
191        parts.push(part);
192        cursor += consumed;
193    }
194
195    Ok(Some((CommandSpanFrame { parts }, cursor - offset)))
196}
197
198fn parse_command_part(buffer: &[u8], offset: usize) -> Result<Option<(&[u8], usize)>> {
199    if offset >= buffer.len() {
200        return Ok(None);
201    }
202
203    match buffer[offset] {
204        b'$' => parse_command_blob_string(buffer, offset),
205        b'+' => parse_command_simple_string(buffer, offset),
206        b':' => parse_command_integer(buffer, offset),
207        other => Err(FastCacheError::Protocol(format!(
208            "unsupported RESP command part prefix byte: {other:#x}"
209        ))),
210    }
211}
212
213fn parse_command_part_span(buffer: &[u8], offset: usize) -> Result<Option<(Range<usize>, usize)>> {
214    if offset >= buffer.len() {
215        return Ok(None);
216    }
217
218    match buffer[offset] {
219        b'$' => parse_command_blob_string_span(buffer, offset),
220        b'+' => parse_command_line_span(buffer, offset),
221        b':' => parse_command_line_span(buffer, offset),
222        other => Err(FastCacheError::Protocol(format!(
223            "unsupported RESP command part prefix byte: {other:#x}"
224        ))),
225    }
226}
227
228fn parse_command_blob_string(buffer: &[u8], offset: usize) -> Result<Option<(&[u8], usize)>> {
229    let Some((length, header_consumed)) = parse_isize_line(&buffer[offset + 1..])? else {
230        return Ok(None);
231    };
232    if length < 0 {
233        return Err(FastCacheError::Protocol(
234            "null bulk strings are not supported in command frames".into(),
235        ));
236    }
237    let length = length as usize;
238    let start = offset + 1 + header_consumed;
239    let end = start + length;
240    if buffer.len() < end + 2 {
241        return Ok(None);
242    }
243    if &buffer[end..end + 2] != b"\r\n" {
244        return Err(FastCacheError::Protocol(
245            "blob string missing CRLF terminator".into(),
246        ));
247    }
248    Ok(Some((&buffer[start..end], (end + 2) - offset)))
249}
250
251fn parse_command_blob_string_span(
252    buffer: &[u8],
253    offset: usize,
254) -> Result<Option<(Range<usize>, usize)>> {
255    let Some((length, header_consumed)) = parse_isize_line(&buffer[offset + 1..])? else {
256        return Ok(None);
257    };
258    if length < 0 {
259        return Err(FastCacheError::Protocol(
260            "null bulk strings are not supported in command frames".into(),
261        ));
262    }
263    let length = length as usize;
264    let start = offset + 1 + header_consumed;
265    let end = start + length;
266    if buffer.len() < end + 2 {
267        return Ok(None);
268    }
269    if &buffer[end..end + 2] != b"\r\n" {
270        return Err(FastCacheError::Protocol(
271            "blob string missing CRLF terminator".into(),
272        ));
273    }
274    Ok(Some((start..end, (end + 2) - offset)))
275}
276
277fn parse_command_simple_string(buffer: &[u8], offset: usize) -> Result<Option<(&[u8], usize)>> {
278    let Some(line_end) = find_crlf(&buffer[offset + 1..]) else {
279        return Ok(None);
280    };
281    let start = offset + 1;
282    let end = start + line_end;
283    Ok(Some((&buffer[start..end], end + 2 - offset)))
284}
285
286fn parse_command_line_span(buffer: &[u8], offset: usize) -> Result<Option<(Range<usize>, usize)>> {
287    let Some(line_end) = find_crlf(&buffer[offset + 1..]) else {
288        return Ok(None);
289    };
290    let start = offset + 1;
291    let end = start + line_end;
292    Ok(Some((start..end, end + 2 - offset)))
293}
294
295fn parse_command_integer(buffer: &[u8], offset: usize) -> Result<Option<(&[u8], usize)>> {
296    let Some(line_end) = find_crlf(&buffer[offset + 1..]) else {
297        return Ok(None);
298    };
299    let start = offset + 1;
300    let end = start + line_end;
301    Ok(Some((&buffer[start..end], end + 2 - offset)))
302}
303
304fn parse_frame(buffer: &[u8], offset: usize) -> Result<RespDecodeResult> {
305    if offset >= buffer.len() {
306        return Ok(None);
307    }
308    match buffer[offset] {
309        b'+' => parse_simple_string(buffer, offset),
310        b'-' => parse_error(buffer, offset),
311        b':' => parse_integer(buffer, offset),
312        b'$' => parse_blob_string(buffer, offset),
313        b'*' => parse_array(buffer, offset),
314        b'_' => parse_null(buffer, offset),
315        b'#' => parse_boolean(buffer, offset),
316        other => Err(FastCacheError::Protocol(format!(
317            "unsupported RESP prefix byte: {other:#x}"
318        ))),
319    }
320}
321
322fn parse_simple_string(buffer: &[u8], offset: usize) -> Result<RespDecodeResult> {
323    let Some((line, consumed)) = parse_line(&buffer[offset + 1..])? else {
324        return Ok(None);
325    };
326    Ok(Some((Frame::SimpleString(line.to_string()), consumed + 1)))
327}
328
329fn parse_error(buffer: &[u8], offset: usize) -> Result<RespDecodeResult> {
330    let Some((line, consumed)) = parse_line(&buffer[offset + 1..])? else {
331        return Ok(None);
332    };
333    Ok(Some((Frame::Error(line.to_string()), consumed + 1)))
334}
335
336fn parse_integer(buffer: &[u8], offset: usize) -> Result<RespDecodeResult> {
337    let Some((value, consumed)) = parse_i64_line(&buffer[offset + 1..])? else {
338        return Ok(None);
339    };
340    Ok(Some((Frame::Integer(value), consumed + 1)))
341}
342
343fn parse_blob_string(buffer: &[u8], offset: usize) -> Result<RespDecodeResult> {
344    let Some((length, header_consumed)) = parse_isize_line(&buffer[offset + 1..])? else {
345        return Ok(None);
346    };
347    if length < 0 {
348        return Ok(Some((Frame::Null, header_consumed + 1)));
349    }
350    let length = length as usize;
351    let start = offset + 1 + header_consumed;
352    let end = start + length;
353    if buffer.len() < end + 2 {
354        return Ok(None);
355    }
356    if &buffer[end..end + 2] != b"\r\n" {
357        return Err(FastCacheError::Protocol(
358            "blob string missing CRLF terminator".into(),
359        ));
360    }
361    Ok(Some((
362        Frame::BlobString(buffer[start..end].to_vec()),
363        (end + 2) - offset,
364    )))
365}
366
367fn parse_array(buffer: &[u8], offset: usize) -> Result<RespDecodeResult> {
368    let Some((count, header_consumed)) = parse_isize_line(&buffer[offset + 1..])? else {
369        return Ok(None);
370    };
371    if count < 0 {
372        return Ok(Some((Frame::Null, header_consumed + 1)));
373    }
374    let count = count as usize;
375    let mut cursor = offset + 1 + header_consumed;
376    let mut items = Vec::with_capacity(count);
377    for _ in 0..count {
378        let Some((frame, consumed)) = parse_frame(buffer, cursor)? else {
379            return Ok(None);
380        };
381        items.push(frame);
382        cursor += consumed;
383    }
384    Ok(Some((Frame::Array(items), cursor - offset)))
385}
386
387fn parse_null(buffer: &[u8], offset: usize) -> Result<RespDecodeResult> {
388    if buffer.len() < offset + 3 {
389        return Ok(None);
390    }
391    if &buffer[offset + 1..offset + 3] != b"\r\n" {
392        return Err(FastCacheError::Protocol("invalid null frame".into()));
393    }
394    Ok(Some((Frame::Null, 3)))
395}
396
397fn parse_boolean(buffer: &[u8], offset: usize) -> Result<RespDecodeResult> {
398    if buffer.len() < offset + 4 {
399        return Ok(None);
400    }
401    let value = match buffer[offset + 1] {
402        b't' => true,
403        b'f' => false,
404        other => {
405            return Err(FastCacheError::Protocol(format!(
406                "invalid boolean marker: {other:#x}"
407            )));
408        }
409    };
410    if &buffer[offset + 2..offset + 4] != b"\r\n" {
411        return Err(FastCacheError::Protocol("invalid boolean frame".into()));
412    }
413    Ok(Some((Frame::Boolean(value), 4)))
414}
415
416fn parse_line(buffer: &[u8]) -> Result<Option<(&str, usize)>> {
417    let Some(end) = find_crlf(buffer) else {
418        return Ok(None);
419    };
420    let line = std::str::from_utf8(&buffer[..end])
421        .map_err(|error| FastCacheError::Protocol(format!("invalid utf8 in RESP line: {error}")))?;
422    Ok(Some((line, end + 2)))
423}
424
425#[inline]
426fn find_crlf(buffer: &[u8]) -> Option<usize> {
427    memchr::memmem::find(buffer, b"\r\n")
428}
429
430#[inline]
431fn parse_isize_line(buffer: &[u8]) -> Result<Option<(isize, usize)>> {
432    // Fast path: 1-4 digit non-negative integers terminated by `\r\n`. This
433    // covers the overwhelming majority of RESP integers in practice (array
434    // counts, blob string lengths up to 9999) without invoking memchr at all.
435    if let Some((value, consumed)) = try_parse_short_uint_line(buffer) {
436        return Ok(Some((value as isize, consumed)));
437    }
438    let Some(end) = find_crlf(buffer) else {
439        return Ok(None);
440    };
441    let value = parse_ascii_isize(&buffer[..end])?;
442    Ok(Some((value, end + 2)))
443}
444
/// Fast path for integer headers made of 1-4 ASCII digits immediately followed
/// by `\r\n`.
///
/// Returns `Some((value, consumed))`, where `consumed` includes the CRLF, when
/// the buffer starts with such a header. Returns `None` for everything else
/// (signs, non-digits, five or more digits, or a terminator that has not fully
/// arrived); the caller's general path decides whether that is an error or
/// just incomplete input.
#[inline(always)]
fn try_parse_short_uint_line(buffer: &[u8]) -> Option<(usize, usize)> {
    let mut value = 0usize;
    for (index, &byte) in buffer.iter().take(4).enumerate() {
        if !byte.is_ascii_digit() {
            return None;
        }
        value = value * 10 + (byte - b'0') as usize;
        // A header ends as soon as the two bytes after the current digit are
        // CRLF; `get` is `None` when those bytes are not in the buffer.
        if buffer.get(index + 1..index + 3) == Some(&b"\r\n"[..]) {
            return Some((value, index + 3));
        }
    }
    None
}
501
502#[inline]
503fn parse_i64_line(buffer: &[u8]) -> Result<Option<(i64, usize)>> {
504    let Some(end) = find_crlf(buffer) else {
505        return Ok(None);
506    };
507    let value = parse_ascii_i64(&buffer[..end])?;
508    Ok(Some((value, end + 2)))
509}
510
511#[inline]
512fn parse_ascii_isize(bytes: &[u8]) -> Result<isize> {
513    let (negative, digits) = split_sign(bytes)?;
514    if digits.is_empty() {
515        return Err(FastCacheError::Protocol(
516            "empty integer in RESP header".into(),
517        ));
518    }
519    let mut value: isize = 0;
520    for &b in digits {
521        if !b.is_ascii_digit() {
522            return Err(FastCacheError::Protocol(format!(
523                "non-digit byte in RESP integer: {b:#x}"
524            )));
525        }
526        value = value
527            .checked_mul(10)
528            .and_then(|v| v.checked_add((b - b'0') as isize))
529            .ok_or_else(|| FastCacheError::Protocol("RESP integer overflow".into()))?;
530    }
531    Ok(if negative { -value } else { value })
532}
533
534#[inline]
535fn parse_ascii_i64(bytes: &[u8]) -> Result<i64> {
536    let (negative, digits) = split_sign(bytes)?;
537    if digits.is_empty() {
538        return Err(FastCacheError::Protocol(
539            "empty integer in RESP header".into(),
540        ));
541    }
542    let mut value: i64 = 0;
543    for &b in digits {
544        if !b.is_ascii_digit() {
545            return Err(FastCacheError::Protocol(format!(
546                "non-digit byte in RESP integer: {b:#x}"
547            )));
548        }
549        value = value
550            .checked_mul(10)
551            .and_then(|v| v.checked_add((b - b'0') as i64))
552            .ok_or_else(|| FastCacheError::Protocol("RESP integer overflow".into()))?;
553    }
554    Ok(if negative { -value } else { value })
555}
556
557#[inline]
558fn split_sign(bytes: &[u8]) -> Result<(bool, &[u8])> {
559    Ok(match bytes.first() {
560        Some(b'-') => (true, &bytes[1..]),
561        Some(b'+') => (false, &bytes[1..]),
562        _ => (false, bytes),
563    })
564}
565
566impl fmt::Display for Frame {
567    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
568        match self {
569            Frame::SimpleString(value) => write!(f, "{value}"),
570            Frame::BlobString(value) => write!(f, "{}", String::from_utf8_lossy(value)),
571            Frame::Integer(value) => write!(f, "{value}"),
572            Frame::Array(value) => write!(f, "{value:?}"),
573            Frame::Null => write!(f, "null"),
574            Frame::Boolean(value) => write!(f, "{value}"),
575            Frame::Error(value) => write!(f, "ERR {value}"),
576        }
577    }
578}
579
#[cfg(test)]
mod tests {
    use super::{Frame, RespCodec};

    /// Encoding an array of blob strings and decoding it again yields the
    /// original frame.
    #[test]
    fn round_trips_array() {
        let original = Frame::Array(
            [&b"SET"[..], &b"alpha"[..], &b"beta"[..]]
                .iter()
                .map(|part| Frame::BlobString(part.to_vec()))
                .collect(),
        );

        let mut wire = Vec::new();
        RespCodec::encode(&original, &mut wire);

        let (round_tripped, _consumed) = RespCodec::decode(&wire).unwrap().unwrap();
        assert_eq!(round_tripped, original);
    }

    /// Span decoding consumes the whole command and each span indexes the
    /// matching part's bytes in the encoded buffer.
    #[test]
    fn decodes_command_part_spans() {
        let expected_parts: [&[u8]; 3] = [b"MSET", b"long-key-name", b"value-body"];
        let command = Frame::Array(
            expected_parts
                .iter()
                .map(|part| Frame::BlobString(part.to_vec()))
                .collect(),
        );
        let mut wire = Vec::new();
        RespCodec::encode(&command, &mut wire);

        let (decoded, consumed) = RespCodec::decode_command_spans(&wire).unwrap().unwrap();

        assert_eq!(consumed, wire.len());
        for (index, expected) in expected_parts.iter().enumerate() {
            assert_eq!(&wire[decoded.parts[index].clone()], *expected);
        }
    }
}