Skip to main content

kevy_resp/
request.rs

1//! Request-side parser: turns a byte stream from a client into an [`Argv`].
2//!
3//! Handles the two RESP2 request forms — `*N\r\n$L\r\n…` multi-bulk (the
4//! normal client encoding) and the inline form (whitespace-separated, a
5//! convenience for raw-typed PING / DEBUG / etc). Parsing is incremental:
6//! returning `Ok(None)` asks the caller to read more bytes and retry.
7
8use crate::argv::{Argv, Command};
9use crate::error::ProtocolError;
10
11/// Attempt to parse one command from the front of `buf`.
12///
13/// - `Ok(Some((cmd, consumed)))` — a full command; `consumed` bytes may be dropped.
14/// - `Ok(None)` — need more bytes; call again after reading more.
15/// - `Err(_)` — the stream is corrupt; the caller should reply with an error
16///   and close the connection.
17///
18/// This is the convenience form that allocates a fresh `Argv` per call. The
19/// reactor's hot path uses [`parse_command_into`] with a reused scratch
20/// `Argv` to keep per-cmd malloc rate at 0.
21pub fn parse_command(buf: &[u8]) -> Result<Option<(Command, usize)>, ProtocolError> {
22    let mut argv = Argv::default();
23    match parse_command_into(buf, &mut argv)? {
24        Some(consumed) => Ok(Some((argv, consumed))),
25        None => Ok(None),
26    }
27}
28
29/// Same as [`parse_command`], but writes into a caller-provided scratch
30/// `Argv` instead of allocating a new one each call. The reactor stores one
31/// `Argv` per shard and reuses it for every cmd on the local hot path; the
32/// internal `Vec<u8>` + `Vec<u32>` capacities amortise to zero allocations
33/// per command after the first few cmds warm them.
34///
35/// `dst` is cleared at the start of every call; on `Ok(None)` and `Err`, `dst`
36/// is left empty (so the caller doesn't see partial state).
37pub fn parse_command_into(buf: &[u8], dst: &mut Argv) -> Result<Option<usize>, ProtocolError> {
38    dst.clear();
39    if buf.is_empty() {
40        return Ok(None);
41    }
42    if buf[0] == b'*' {
43        parse_multibulk_into(buf, dst)
44    } else {
45        parse_inline_into(buf, dst)
46    }
47}
48
49fn parse_inline_into(buf: &[u8], dst: &mut Argv) -> Result<Option<usize>, ProtocolError> {
50    let Some(eol) = find_crlf(buf, 0) else {
51        return Ok(None);
52    };
53    let line = &buf[..eol];
54    for tok in line
55        .split(|b| b.is_ascii_whitespace())
56        .filter(|s| !s.is_empty())
57    {
58        dst.push(tok);
59    }
60    Ok(Some(eol + 2))
61}
62
63/// Validate the multi-bulk frame is fully present and report `(end_pos,
64/// total_arg_bytes)` if so. `start_pos` is the offset of the first `$`
65/// after the `*N\r\n` header. `Ok(None)` = need more bytes; `Err` = malformed.
66pub(crate) fn validate_multibulk_frame(
67    buf: &[u8],
68    start_pos: usize,
69    count: usize,
70) -> Result<Option<(usize, usize)>, ProtocolError> {
71    let mut pos = start_pos;
72    let mut total = 0usize;
73    for _ in 0..count {
74        if pos >= buf.len() {
75            return Ok(None);
76        }
77        if buf[pos] != b'$' {
78            return Err(ProtocolError::Malformed("expected bulk string"));
79        }
80        let Some(len_end) = find_crlf(buf, pos + 1) else {
81            return Ok(None);
82        };
83        let len = parse_int(&buf[pos + 1..len_end])
84            .ok_or(ProtocolError::Malformed("bad bulk length"))?;
85        if len < 0 {
86            return Err(ProtocolError::Malformed("negative bulk length in request"));
87        }
88        let len = len as usize;
89        let data_end = len_end + 2 + len;
90        if buf.len() < data_end + 2 {
91            return Ok(None);
92        }
93        if &buf[data_end..data_end + 2] != b"\r\n" {
94            return Err(ProtocolError::Malformed("bulk string not CRLF-terminated"));
95        }
96        total += len;
97        pos = data_end + 2;
98    }
99    Ok(Some((pos, total)))
100}
101
102/// Copy `count` already-validated bulk args from `buf[start_pos..]` into `dst`.
103/// Caller must have called [`validate_multibulk_frame`] first.
104fn copy_multibulk_args(buf: &[u8], start_pos: usize, count: usize, dst: &mut Argv) {
105    let mut p = start_pos;
106    for _ in 0..count {
107        let len_end = find_crlf(buf, p + 1).expect("validated in pass 1");
108        let len = parse_int(&buf[p + 1..len_end]).expect("validated in pass 1") as usize;
109        let data_start = len_end + 2;
110        dst.push(&buf[data_start..data_start + len]);
111        p = data_start + len + 2;
112    }
113}
114
115fn parse_multibulk_into(buf: &[u8], dst: &mut Argv) -> Result<Option<usize>, ProtocolError> {
116    let Some(hdr_end) = find_crlf(buf, 1) else {
117        return Ok(None);
118    };
119    let count =
120        parse_int(&buf[1..hdr_end]).ok_or(ProtocolError::Malformed("bad multibulk count"))?;
121    if count < 0 {
122        // Null array → empty argv (already cleared).
123        return Ok(Some(hdr_end + 2));
124    }
125    let count = count as usize;
126    let start = hdr_end + 2;
127
128    let (end_pos, total) = match validate_multibulk_frame(buf, start, count)? {
129        Some(t) => t,
130        None => return Ok(None),
131    };
132
133    // `reserve` is a no-op when the scratch Argv has already amortised
134    // enough capacity from earlier cmds.
135    dst.reserve_for(count, total);
136    copy_multibulk_args(buf, start, count, dst);
137    Ok(Some(end_pos))
138}
139
140/// Parse a bulk-string length header `$<len>\r\n` whose `$` sits at
141/// `buf[pos]` (the caller has already checked that byte). One fused pass:
142/// the digits accumulate while the same loop walks to the terminating
143/// CRLF — bulk headers are 2-21 bytes, so this short byte loop beats the
144/// `find_crlf` + [`parse_int`] double scan the two-pass parser paid per
145/// arg. Accepts the same shapes as `parse_int` (optional `+`/`-` sign,
146/// checked i64 accumulation); a negative length is malformed in a
147/// request, matching [`validate_multibulk_frame`].
148///
149/// Returns `(len, data_start)`; `Ok(None)` = need more bytes.
150pub(crate) fn parse_bulk_len(
151    buf: &[u8],
152    pos: usize,
153) -> Result<Option<(usize, usize)>, ProtocolError> {
154    let mut q = pos + 1;
155    let neg = match buf.get(q) {
156        None => return Ok(None),
157        Some(b'-') => {
158            q += 1;
159            true
160        }
161        Some(b'+') => {
162            q += 1;
163            false
164        }
165        _ => false,
166    };
167    let digits_start = q;
168    let mut acc: i64 = 0;
169    loop {
170        match buf.get(q) {
171            None => return Ok(None),
172            Some(&b) if b.is_ascii_digit() => {
173                acc = acc
174                    .checked_mul(10)
175                    .and_then(|a| a.checked_add((b - b'0') as i64))
176                    .ok_or(ProtocolError::Malformed("bad bulk length"))?;
177                q += 1;
178            }
179            Some(b'\r') => break,
180            Some(_) => return Err(ProtocolError::Malformed("bad bulk length")),
181        }
182    }
183    if q == digits_start {
184        return Err(ProtocolError::Malformed("bad bulk length"));
185    }
186    match buf.get(q + 1) {
187        None => return Ok(None),
188        Some(b'\n') => {}
189        Some(_) => return Err(ProtocolError::Malformed("bad bulk length")),
190    }
191    if neg {
192        return Err(ProtocolError::Malformed("negative bulk length in request"));
193    }
194    Ok(Some((acc as usize, q + 2)))
195}
196
/// Return the offset of the first `\r\n` pair at or after `start` (the index
/// of its `\r`), or `None` if there is no such pair.
///
/// The scan runs in two phases:
/// - Word phase: bytes are examined eight at a time with the SWAR
///   "has-zero-byte" trick. XOR-ing a word with a broadcast `\r` turns every
///   CR into a zero byte, and `(x - 0x01..) & !x & 0x80..` flags zero bytes;
///   the lowest flagged lane is the first CR in the window. Each candidate CR
///   is confirmed by peeking at the byte after it, and a lone `\r` simply
///   restarts the scan one byte later.
/// - Tail phase: the last few bytes are checked with a plain pairwise scan.
///
/// Safe Rust only — keeps `kevy-resp`'s `forbid(unsafe_code)` guarantee.
pub(crate) fn find_crlf(buf: &[u8], start: usize) -> Option<usize> {
    const CR_LANES: u64 = u64::from_le_bytes([b'\r'; 8]);
    const LOW_BITS: u64 = u64::from_le_bytes([0x01; 8]);
    const HIGH_BITS: u64 = u64::from_le_bytes([0x80; 8]);

    let len = buf.len();
    if start + 1 >= len {
        // Fewer than two bytes remain: no room for a CR + LF pair.
        return None;
    }

    let mut cursor = start;
    // `cursor + 8 < len` keeps the whole 8-byte window in bounds *and*
    // guarantees one more byte after it. That way a CR in the window's last
    // lane can still be confirmed by reading its successor.
    while cursor + 8 < len {
        let lanes: [u8; 8] = buf[cursor..cursor + 8].try_into().expect("8 bytes");
        let cr_as_zero = u64::from_le_bytes(lanes) ^ CR_LANES;
        let hits = cr_as_zero.wrapping_sub(LOW_BITS) & !cr_as_zero & HIGH_BITS;
        if hits == 0 {
            cursor += 8;
            continue;
        }
        let cr = cursor + (hits.trailing_zeros() / 8) as usize;
        // cr <= cursor + 7 < len - 1, so the successor read is in bounds.
        if buf[cr + 1] == b'\n' {
            return Some(cr);
        }
        cursor = cr + 1;
    }

    // Tail: fewer than nine bytes remain, or we resumed after a lone CR.
    buf[cursor..]
        .windows(2)
        .position(|pair| pair == b"\r\n")
        .map(|offset| cursor + offset)
}
250
/// Parse a base-10 signed integer from ASCII bytes (no surrounding whitespace).
///
/// Accepts an optional leading `+` or `-` followed by at least one digit.
/// Returns `None` on empty input, a bare sign, any non-digit byte, or a
/// value outside the `i64` range.
///
/// Digits are accumulated toward the sign (subtracting for negatives) rather
/// than accumulating a positive magnitude and negating it at the end. The
/// magnitude of `i64::MIN` does not fit in an `i64`, so the latter approach
/// rejected `-9223372036854775808`.
pub(crate) fn parse_int(bytes: &[u8]) -> Option<i64> {
    let (neg, digits) = match bytes.split_first()? {
        (b'-', rest) => (true, rest),
        (b'+', rest) => (false, rest),
        _ => (false, bytes),
    };
    if digits.is_empty() {
        return None;
    }
    digits.iter().try_fold(0i64, |acc, &b| {
        if !b.is_ascii_digit() {
            return None;
        }
        let digit = i64::from(b - b'0');
        let scaled = acc.checked_mul(10)?;
        if neg {
            scaled.checked_sub(digit)
        } else {
            scaled.checked_add(digit)
        }
    })
}
273
#[cfg(test)]
mod tests {
    //! Request-parser unit tests. The SWAR `find_crlf` gets an offset sweep
    //! plus lone-CR / window-boundary cases, because its 8-byte window makes
    //! alignment the likeliest place for bugs.
    use super::*;
    use crate::encode_command;

    // SWAR find_crlf fuzz: planted CRLFs at every offset 0..40, lone-CR
    // distractors, no-CRLF inputs, near-end boundaries. The SWAR window is
    // 8 bytes, so transitions at offsets 0/7/8/15/16/… stress alignment.
    #[test]
    fn find_crlf_at_every_offset() {
        for off in 0..40 {
            let mut buf = vec![b'a'; 60];
            buf[off] = b'\r';
            buf[off + 1] = b'\n';
            assert_eq!(find_crlf(&buf, 0), Some(off), "off={off}");
        }
    }

    #[test]
    fn find_crlf_skips_lone_cr() {
        // Lone \r at the front, then a real CRLF later.
        let mut buf = vec![b'a'; 32];
        buf[3] = b'\r';
        buf[4] = b'b'; // not \n → skip
        buf[20] = b'\r';
        buf[21] = b'\n';
        assert_eq!(find_crlf(&buf, 0), Some(20));
    }

    #[test]
    fn find_crlf_none_when_absent() {
        let buf = vec![b'a'; 32];
        assert_eq!(find_crlf(&buf, 0), None);
        let buf = b"";
        assert_eq!(find_crlf(buf, 0), None);
        let buf = b"\r"; // only CR, no LF available
        assert_eq!(find_crlf(buf, 0), None);
    }

    #[test]
    fn find_crlf_at_buffer_end() {
        let buf = b"abcdefghij\r\n"; // CRLF at offset 10
        assert_eq!(find_crlf(buf, 0), Some(10));
        // Start past the CR.
        assert_eq!(find_crlf(buf, 11), None);
    }

    #[test]
    fn find_crlf_with_many_lone_crs() {
        // 7 lone CRs followed by a real CRLF. SWAR finds one CR per iter
        // but must keep going until it finds the real pair.
        let mut buf = Vec::new();
        for _ in 0..7 {
            buf.push(b'\r');
            buf.push(b'x'); // not \n
        }
        buf.extend_from_slice(b"\r\n");
        // Real CRLF starts at offset 14 (7 * 2).
        assert_eq!(find_crlf(&buf, 0), Some(14));
    }

    #[test]
    fn find_crlf_from_nonzero_start() {
        let buf = b"\r\n\r\n\r\n";
        // Starts at offset 0 → first CRLF.
        assert_eq!(find_crlf(buf, 0), Some(0));
        // Skip the first CRLF.
        assert_eq!(find_crlf(buf, 2), Some(2));
        assert_eq!(find_crlf(buf, 4), Some(4));
    }

    #[test]
    fn find_crlf_cr_at_window_edge() {
        // CR in the last lane of the first 8-byte window; its LF is byte 8.
        assert_eq!(find_crlf(b"aaaaaaa\r\n", 0), Some(7));
        // Lone CR in the last lane, real CRLF right after it.
        assert_eq!(find_crlf(b"aaaaaaa\rx\r\naaaaaaaa", 0), Some(9));
    }

    #[test]
    fn find_crlf_start_past_end() {
        assert_eq!(find_crlf(b"ab\r\n", 4), None);
        assert_eq!(find_crlf(b"ab\r\n", 100), None);
    }

    #[test]
    fn parse_int_shapes() {
        assert_eq!(parse_int(b"0"), Some(0));
        assert_eq!(parse_int(b"+12"), Some(12));
        assert_eq!(parse_int(b"-12"), Some(-12));
        assert_eq!(parse_int(b"9223372036854775807"), Some(i64::MAX));
        assert_eq!(parse_int(b"9223372036854775808"), None);
        assert_eq!(parse_int(b""), None);
        assert_eq!(parse_int(b"-"), None);
        assert_eq!(parse_int(b"1a"), None);
    }

    #[test]
    fn parse_bulk_len_shapes() {
        assert!(matches!(parse_bulk_len(b"$5\r\nhello\r\n", 0), Ok(Some((5, 4)))));
        assert!(matches!(parse_bulk_len(b"xx$0\r\n\r\n", 2), Ok(Some((0, 6)))));
        assert!(matches!(parse_bulk_len(b"$12", 0), Ok(None)));
        assert!(matches!(parse_bulk_len(b"$12\r", 0), Ok(None)));
        assert!(parse_bulk_len(b"$-1\r\n", 0).is_err());
        assert!(parse_bulk_len(b"$\r\n", 0).is_err());
        assert!(parse_bulk_len(b"$1x\r\n", 0).is_err());
        assert!(parse_bulk_len(b"$1\rx", 0).is_err());
    }

    #[test]
    fn parse_multibulk_ping() {
        let (cmd, used) = parse_command(b"*1\r\n$4\r\nPING\r\n").unwrap().unwrap();
        assert_eq!(cmd, vec![b"PING".to_vec()]);
        assert_eq!(used, 14);
    }

    #[test]
    fn parse_multibulk_echo() {
        let frame = b"*2\r\n$4\r\nECHO\r\n$5\r\nhello\r\n";
        let (cmd, used) = parse_command(frame).unwrap().unwrap();
        assert_eq!(cmd, vec![b"ECHO".to_vec(), b"hello".to_vec()]);
        assert_eq!(used, frame.len());
    }

    #[test]
    fn parse_pipelined_commands() {
        let frame = b"*1\r\n$4\r\nPING\r\n*2\r\n$4\r\nECHO\r\n$2\r\nhi\r\n";
        let (first, n) = parse_command(frame).unwrap().unwrap();
        assert_eq!(first, vec![b"PING".to_vec()]);
        let (second, m) = parse_command(&frame[n..]).unwrap().unwrap();
        assert_eq!(second, vec![b"ECHO".to_vec(), b"hi".to_vec()]);
        assert_eq!(n + m, frame.len());
    }

    #[test]
    fn parse_null_and_empty_arrays() {
        let (cmd, used) = parse_command(b"*-1\r\n").unwrap().unwrap();
        assert_eq!(cmd, Vec::<Vec<u8>>::new());
        assert_eq!(used, 5);
        let (cmd, used) = parse_command(b"*0\r\n").unwrap().unwrap();
        assert_eq!(cmd, Vec::<Vec<u8>>::new());
        assert_eq!(used, 4);
    }

    #[test]
    fn parse_incomplete_returns_none() {
        assert_eq!(parse_command(b"*1\r\n$4\r\nPI").unwrap(), None);
        assert_eq!(parse_command(b"*2\r\n$4\r\nECHO\r\n").unwrap(), None);
        assert_eq!(parse_command(b"").unwrap(), None);
    }

    #[test]
    fn parse_command_into_reuses_and_clears_scratch() {
        let mut argv = Argv::default();
        assert_eq!(
            parse_command_into(b"*1\r\n$4\r\nPING\r\n", &mut argv).unwrap(),
            Some(14)
        );
        assert_eq!(argv, vec![b"PING".to_vec()]);
        // Incomplete frame: the previous command must not leak through.
        assert_eq!(parse_command_into(b"*1\r\n$4\r\nPI", &mut argv).unwrap(), None);
        assert_eq!(argv, Vec::<Vec<u8>>::new());
        // Malformed frame: the scratch is left empty as well.
        assert!(parse_command_into(b"*1\r\n+OK\r\n", &mut argv).is_err());
        assert_eq!(argv, Vec::<Vec<u8>>::new());
    }

    #[test]
    fn parse_inline_command() {
        let (cmd, used) = parse_command(b"PING\r\n").unwrap().unwrap();
        assert_eq!(cmd, vec![b"PING".to_vec()]);
        assert_eq!(used, 6);
        let (cmd, _) = parse_command(b"ECHO  hi there\r\n").unwrap().unwrap();
        assert_eq!(
            cmd,
            vec![b"ECHO".to_vec(), b"hi".to_vec(), b"there".to_vec()]
        );
    }

    #[test]
    fn parse_malformed_errors() {
        assert!(parse_command(b"*1\r\n+OK\r\n").is_err());
        assert!(parse_command(b"*x\r\n").is_err());
    }

    #[test]
    fn parse_multibulk_rejects_bad_bulks() {
        // Negative bulk length.
        assert!(parse_command(b"*1\r\n$-1\r\n").is_err());
        // Payload longer than declared → missing CRLF after the payload.
        assert!(parse_command(b"*1\r\n$2\r\nhiX\r\n").is_err());
        // Non-numeric bulk length.
        assert!(parse_command(b"*1\r\n$x\r\nhi\r\n").is_err());
    }

    #[test]
    fn round_trip_command() {
        let mut buf = Vec::new();
        encode_command(&mut buf, &[b"SET".to_vec(), b"k".to_vec(), b"v".to_vec()]);
        let (cmd, used) = parse_command(&buf).unwrap().unwrap();
        assert_eq!(cmd, vec![b"SET".to_vec(), b"k".to_vec(), b"v".to_vec()]);
        assert_eq!(used, buf.len());
    }
}