Skip to main content

resp_proto/
streaming.rs

//! Streaming command parser for zero-copy receive optimization.
//!
//! This module provides incremental parsing that can pause after parsing
//! a command header, allowing the caller to receive large values directly
//! into a target buffer (e.g., cache segment memory) without intermediate copies.
//!
//! # Example
//!
//! ```ignore
//! use protocol_resp::streaming::{complete_set, parse_streaming, ParseProgress, STREAMING_THRESHOLD};
//!
//! // Feed data as it arrives
//! match parse_streaming(buffer, &options, STREAMING_THRESHOLD)? {
//!     ParseProgress::Incomplete => {
//!         // Need more data
//!     }
//!     ParseProgress::NeedValue { header, value_len, value_prefix, .. } => {
//!         // Allocate target buffer for value
//!         let mut target = cache.reserve_set(header.key, value_len)?;
//!
//!         // Copy the value bytes that are already buffered, then receive
//!         // the remaining bytes directly into target
//!         let value = target.value_mut();
//!         value[..value_prefix.len()].copy_from_slice(value_prefix);
//!         recv_into(&mut value[value_prefix.len()..])?;
//!
//!         // Parse the trailing CRLF and any options (EX/PX/NX/XX) from the
//!         // bytes received after the value
//!         let (cmd, consumed) = complete_set(trailer, &header, value)?;
//!     }
//!     ParseProgress::ValueTooLarge { value_len, .. } => {
//!         // Reply with an error, then drain the value and its trailing CRLF
//!     }
//!     ParseProgress::Complete(cmd, consumed) => {
//!         // Handle complete command
//!     }
//! }
//! ```
34
35use crate::Command;
36use crate::error::ParseError;
37use crate::value::ParseOptions;
38use std::time::Duration;
39
/// Smallest SET value size, in bytes, that is received via streaming (64 KiB).
///
/// Values below this size gain too little from zero-copy receive to be
/// worth taking the streaming path.
pub const STREAMING_THRESHOLD: usize = 1 << 16;
43
44/// Result of incremental parsing.
45#[derive(Debug)]
46pub enum ParseProgress<'a> {
47    /// Need more data to continue parsing.
48    Incomplete,
49
50    /// Command header parsed, waiting for value data.
51    ///
52    /// The caller should:
53    /// 1. Allocate a buffer for the value (e.g., in cache segment)
54    /// 2. Copy `value_prefix` to the start of that buffer
55    /// 3. Receive remaining `value_len - value_prefix.len()` bytes into the buffer
56    /// 4. Call `complete_with_value()` to finish parsing
57    NeedValue {
58        /// Parsed command header with metadata.
59        header: SetHeader<'a>,
60        /// Total size of the value in bytes.
61        value_len: usize,
62        /// Bytes of value already in the parse buffer (may be empty).
63        /// These must be copied to the target buffer before receiving more.
64        value_prefix: &'a [u8],
65        /// Bytes consumed from buffer so far (header only).
66        header_consumed: usize,
67    },
68
69    /// Value exceeds maximum allowed size - needs to be drained.
70    ///
71    /// The caller should:
72    /// 1. Send an error response to the client
73    /// 2. Drain `value_len + 2` bytes (value + trailing CRLF) from the connection
74    /// 3. Resume normal parsing
75    ValueTooLarge {
76        /// Total size of the value in bytes.
77        value_len: usize,
78        /// Bytes of value already in the parse buffer (to be discarded).
79        value_prefix_len: usize,
80        /// Bytes consumed from buffer so far (header only).
81        header_consumed: usize,
82        /// Maximum allowed value size.
83        max_value_size: usize,
84    },
85
86    /// Fully parsed command (used for non-SET commands or small values).
87    Complete(Command<'a>, usize),
88}
89
/// Header of a SET command whose value has not been parsed yet.
///
/// Produced by `parse_streaming` inside `ParseProgress::NeedValue` and passed
/// back to `complete_set`, which starts from these option values and fills
/// in the ones that follow the value on the wire.
#[derive(Debug, Clone)]
pub struct SetHeader<'a> {
    /// Key being written.
    pub key: &'a [u8],
    /// `EX` option: expiry in seconds.
    pub ex: Option<u64>,
    /// `PX` option: expiry in milliseconds.
    pub px: Option<u64>,
    /// `NX` option: write only if the key is absent.
    pub nx: bool,
    /// `XX` option: write only if the key is present.
    pub xx: bool,
    /// Count of option arguments that follow the value on the wire.
    remaining_args: usize,
}

impl SetHeader<'_> {
    /// Expiry requested by the command, as a `Duration`.
    ///
    /// `ex` takes precedence over `px` when both are set; `None` when
    /// neither is present.
    pub fn ttl(&self) -> Option<Duration> {
        self.ex
            .map(Duration::from_secs)
            .or_else(|| self.px.map(Duration::from_millis))
    }
}
117
118/// Parse a command, potentially yielding early for large SET values.
119///
120/// This function provides the streaming parse capability. For SET commands
121/// with values >= `STREAMING_THRESHOLD`, it returns `NeedValue` after parsing
122/// the header, allowing the caller to receive the value directly into a
123/// target buffer.
124///
125/// For all other commands (including small SETs), it behaves identically to
126/// `Command::parse()`.
127///
128/// # Arguments
129///
130/// * `buffer` - The input buffer containing RESP data
131/// * `options` - Parse options (max lengths, etc.)
132/// * `streaming_threshold` - Minimum value size for streaming (use `STREAMING_THRESHOLD`)
133///
134/// # Returns
135///
136/// * `Ok(ParseProgress::Incomplete)` - Need more data
137/// * `Ok(ParseProgress::NeedValue { .. })` - SET header parsed, value pending
138/// * `Ok(ParseProgress::Complete(cmd, consumed))` - Fully parsed command
139/// * `Err(ParseError)` - Parse error
140pub fn parse_streaming<'a>(
141    buffer: &'a [u8],
142    options: &ParseOptions,
143    streaming_threshold: usize,
144) -> Result<ParseProgress<'a>, ParseError> {
145    let mut cursor = StreamingCursor::new(buffer, options.max_bulk_string_len);
146
147    // Read array header
148    if cursor.remaining() < 1 {
149        return Ok(ParseProgress::Incomplete);
150    }
151    if cursor.peek() != b'*' {
152        return Err(ParseError::Protocol("expected array".to_string()));
153    }
154    cursor.advance(1);
155
156    // Read array length
157    let count = match cursor.read_integer() {
158        Ok(n) => n,
159        Err(ParseError::Incomplete) => return Ok(ParseProgress::Incomplete),
160        Err(e) => return Err(e),
161    };
162
163    if count < 1 {
164        return Err(ParseError::Protocol(
165            "array must have at least 1 element".to_string(),
166        ));
167    }
168
169    const MAX_ARRAY_LEN: usize = 1024 * 1024;
170    if count > MAX_ARRAY_LEN {
171        return Err(ParseError::Protocol("array too large".to_string()));
172    }
173
174    // Read command name
175    let cmd_name = match cursor.read_bulk_string() {
176        Ok(s) => s,
177        Err(ParseError::Incomplete) => return Ok(ParseProgress::Incomplete),
178        Err(e) => return Err(e),
179    };
180
181    let cmd_str = std::str::from_utf8(cmd_name)
182        .map_err(|_| ParseError::Protocol("invalid UTF-8 in command".to_string()))?;
183
184    // Only handle SET specially; other commands use normal path
185    if !cmd_str.eq_ignore_ascii_case("set") {
186        // Fall back to normal parsing for non-SET commands
187        return match Command::parse_with_options(buffer, options) {
188            Ok((cmd, consumed)) => Ok(ParseProgress::Complete(cmd, consumed)),
189            Err(ParseError::Incomplete) => Ok(ParseProgress::Incomplete),
190            Err(e) => Err(e),
191        };
192    }
193
194    // Parse SET command
195    if count < 3 {
196        return Err(ParseError::WrongArity(
197            "SET requires at least 2 arguments".to_string(),
198        ));
199    }
200
201    // Read key
202    let key = match cursor.read_bulk_string() {
203        Ok(s) => s,
204        Err(ParseError::Incomplete) => return Ok(ParseProgress::Incomplete),
205        Err(e) => return Err(e),
206    };
207
208    // Read value length header (but not the value itself)
209    if cursor.remaining() < 1 {
210        return Ok(ParseProgress::Incomplete);
211    }
212    if cursor.peek() != b'$' {
213        return Err(ParseError::Protocol(
214            "expected bulk string for value".to_string(),
215        ));
216    }
217    cursor.advance(1);
218
219    let value_len = match cursor.read_integer() {
220        Ok(n) => n,
221        Err(ParseError::Incomplete) => return Ok(ParseProgress::Incomplete),
222        Err(e) => return Err(e),
223    };
224
225    // Check bulk string length limit - return ValueTooLarge to allow draining
226    if value_len > cursor.max_bulk_string_len {
227        let header_consumed = cursor.position();
228        let remaining_in_buffer = cursor.remaining();
229        let value_prefix_len = remaining_in_buffer.min(value_len);
230
231        return Ok(ParseProgress::ValueTooLarge {
232            value_len,
233            value_prefix_len,
234            header_consumed,
235            max_value_size: cursor.max_bulk_string_len,
236        });
237    }
238
239    // If value is small, use normal parsing path
240    if value_len < streaming_threshold {
241        return match Command::parse_with_options(buffer, options) {
242            Ok((cmd, consumed)) => Ok(ParseProgress::Complete(cmd, consumed)),
243            Err(ParseError::Incomplete) => Ok(ParseProgress::Incomplete),
244            Err(e) => Err(e),
245        };
246    }
247
248    // Large value: return NeedValue for streaming receive
249    let header_consumed = cursor.position();
250    let remaining_in_buffer = cursor.remaining();
251
252    // Calculate how much of the value (if any) is already in the buffer
253    let value_prefix_len = remaining_in_buffer.min(value_len);
254    let value_prefix = &buffer[header_consumed..header_consumed + value_prefix_len];
255
256    // Parse any options that come BEFORE the value in the command
257    // For standard SET, options come AFTER the value, so remaining_args = count - 3
258    let remaining_args = count.saturating_sub(3);
259
260    Ok(ParseProgress::NeedValue {
261        header: SetHeader {
262            key,
263            ex: None,
264            px: None,
265            nx: false,
266            xx: false,
267            remaining_args,
268        },
269        value_len,
270        value_prefix,
271        header_consumed,
272    })
273}
274
275/// Complete parsing a SET command after the value has been received.
276///
277/// This function parses any remaining options (EX, PX, NX, XX) that follow
278/// the value in the command.
279///
280/// # Arguments
281///
282/// * `buffer` - Buffer containing data after the value (options + CRLF)
283/// * `header` - The SET header from `ParseProgress::NeedValue`
284/// * `value` - The received value data
285///
286/// # Returns
287///
288/// * `Ok((Command, consumed))` - Fully parsed command
289/// * `Err(ParseError::Incomplete)` - Need more data for options
290/// * `Err(ParseError)` - Parse error
291pub fn complete_set<'a>(
292    buffer: &'a [u8],
293    header: &SetHeader<'a>,
294    value: &'a [u8],
295) -> Result<(Command<'a>, usize), ParseError> {
296    let mut cursor = StreamingCursor::new(buffer, usize::MAX);
297
298    // Expect CRLF after value
299    if cursor.remaining() < 2 {
300        return Err(ParseError::Incomplete);
301    }
302    if cursor.peek() != b'\r' {
303        return Err(ParseError::Protocol(
304            "expected CRLF after bulk string".to_string(),
305        ));
306    }
307    cursor.advance(1);
308    if cursor.peek() != b'\n' {
309        return Err(ParseError::Protocol(
310            "expected CRLF after bulk string".to_string(),
311        ));
312    }
313    cursor.advance(1);
314
315    // Parse remaining options
316    let mut ex = header.ex;
317    let mut px = header.px;
318    let mut nx = header.nx;
319    let mut xx = header.xx;
320
321    let mut remaining_args = header.remaining_args;
322    while remaining_args > 0 {
323        let option = match cursor.read_bulk_string() {
324            Ok(s) => s,
325            Err(ParseError::Incomplete) => return Err(ParseError::Incomplete),
326            Err(e) => return Err(e),
327        };
328
329        let option_str = std::str::from_utf8(option)
330            .map_err(|_| ParseError::Protocol("invalid UTF-8 in option".to_string()))?;
331
332        if option_str.eq_ignore_ascii_case("ex") {
333            if remaining_args < 2 {
334                return Err(ParseError::Protocol("EX requires a value".to_string()));
335            }
336            let ttl_bytes = cursor.read_bulk_string()?;
337            let ttl_str = std::str::from_utf8(ttl_bytes)
338                .map_err(|_| ParseError::Protocol("invalid UTF-8 in TTL".to_string()))?;
339            let ttl_secs = ttl_str
340                .parse::<u64>()
341                .map_err(|_| ParseError::Protocol("invalid TTL value".to_string()))?;
342            ex = Some(ttl_secs);
343            remaining_args -= 2;
344        } else if option_str.eq_ignore_ascii_case("px") {
345            if remaining_args < 2 {
346                return Err(ParseError::Protocol("PX requires a value".to_string()));
347            }
348            let ttl_bytes = cursor.read_bulk_string()?;
349            let ttl_str = std::str::from_utf8(ttl_bytes)
350                .map_err(|_| ParseError::Protocol("invalid UTF-8 in TTL".to_string()))?;
351            let ttl_ms = ttl_str
352                .parse::<u64>()
353                .map_err(|_| ParseError::Protocol("invalid TTL value".to_string()))?;
354            px = Some(ttl_ms);
355            remaining_args -= 2;
356        } else if option_str.eq_ignore_ascii_case("nx") {
357            nx = true;
358            remaining_args -= 1;
359        } else if option_str.eq_ignore_ascii_case("xx") {
360            xx = true;
361            remaining_args -= 1;
362        } else {
363            return Err(ParseError::Protocol(format!(
364                "unknown SET option: {}",
365                option_str
366            )));
367        }
368    }
369
370    Ok((
371        Command::Set {
372            key: header.key,
373            value,
374            ex,
375            px,
376            nx,
377            xx,
378        },
379        cursor.position(),
380    ))
381}
382
383/// Internal cursor for streaming parsing.
384struct StreamingCursor<'a> {
385    buffer: &'a [u8],
386    pos: usize,
387    max_bulk_string_len: usize,
388}
389
390impl<'a> StreamingCursor<'a> {
391    fn new(buffer: &'a [u8], max_bulk_string_len: usize) -> Self {
392        Self {
393            buffer,
394            pos: 0,
395            max_bulk_string_len,
396        }
397    }
398
399    #[inline]
400    fn remaining(&self) -> usize {
401        self.buffer.len() - self.pos
402    }
403
404    #[inline]
405    fn position(&self) -> usize {
406        self.pos
407    }
408
409    #[inline]
410    fn peek(&self) -> u8 {
411        self.buffer[self.pos]
412    }
413
414    #[inline]
415    fn advance(&mut self, n: usize) {
416        self.pos += n;
417    }
418
419    fn read_integer(&mut self) -> Result<usize, ParseError> {
420        let line = self.read_line()?;
421
422        if line.is_empty() {
423            return Err(ParseError::InvalidInteger("empty integer".to_string()));
424        }
425
426        if line.len() > 19 {
427            return Err(ParseError::InvalidInteger("integer too large".to_string()));
428        }
429
430        let mut result = 0usize;
431        for &byte in line {
432            if !byte.is_ascii_digit() {
433                return Err(ParseError::InvalidInteger(
434                    "non-digit character".to_string(),
435                ));
436            }
437            result = result
438                .checked_mul(10)
439                .and_then(|r| r.checked_add((byte - b'0') as usize))
440                .ok_or_else(|| ParseError::InvalidInteger("integer overflow".to_string()))?;
441        }
442        Ok(result)
443    }
444
445    fn read_bulk_string(&mut self) -> Result<&'a [u8], ParseError> {
446        if self.remaining() < 1 {
447            return Err(ParseError::Incomplete);
448        }
449
450        if self.peek() != b'$' {
451            return Err(ParseError::Protocol("expected bulk string".to_string()));
452        }
453        self.advance(1);
454
455        let len = self.read_integer()?;
456
457        if len > self.max_bulk_string_len {
458            return Err(ParseError::BulkStringTooLong {
459                len,
460                max: self.max_bulk_string_len,
461            });
462        }
463
464        if self.remaining() < len + 2 {
465            return Err(ParseError::Incomplete);
466        }
467
468        let data = &self.buffer[self.pos..self.pos + len];
469        self.pos += len;
470
471        if self.remaining() < 2 {
472            return Err(ParseError::Incomplete);
473        }
474        if self.peek() != b'\r' {
475            return Err(ParseError::Protocol(
476                "expected CRLF after bulk string".to_string(),
477            ));
478        }
479        self.advance(1);
480        if self.peek() != b'\n' {
481            return Err(ParseError::Protocol(
482                "expected CRLF after bulk string".to_string(),
483            ));
484        }
485        self.advance(1);
486
487        Ok(data)
488    }
489
490    fn read_line(&mut self) -> Result<&'a [u8], ParseError> {
491        let start = self.pos;
492        let slice = &self.buffer[start..];
493
494        for i in 0..slice.len().saturating_sub(1) {
495            if slice[i] == b'\r' && slice[i + 1] == b'\n' {
496                let line = &self.buffer[start..start + i];
497                self.pos = start + i + 2;
498                return Ok(line);
499            }
500        }
501
502        Err(ParseError::Incomplete)
503    }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_small_set_uses_normal_path() {
        let data = b"*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n$7\r\nmyvalue\r\n";
        let result = parse_streaming(data, &ParseOptions::default(), STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::Complete(cmd, consumed) => {
                assert_eq!(
                    cmd,
                    Command::Set {
                        key: b"mykey",
                        value: b"myvalue",
                        ex: None,
                        px: None,
                        nx: false,
                        xx: false,
                    }
                );
                assert_eq!(consumed, data.len());
            }
            _ => panic!("expected Complete"),
        }
    }

    #[test]
    fn test_large_set_yields_need_value() {
        // SET with 100KB value (above threshold)
        let value_len = 100 * 1024;
        // Named `wire_header` so it is not shadowed by the `header` binding below.
        let wire_header = format!("*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n${}\r\n", value_len);
        let mut data = wire_header.as_bytes().to_vec();
        // Add some bytes of the value (simulating partial receive)
        data.extend_from_slice(&[b'x'; 1000]);

        let result = parse_streaming(&data, &ParseOptions::default(), STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::NeedValue {
                header,
                value_len: vl,
                value_prefix,
                header_consumed,
            } => {
                assert_eq!(header.key, b"mykey");
                assert_eq!(vl, 100 * 1024);
                assert_eq!(value_prefix.len(), 1000);
                assert!(value_prefix.iter().all(|&b| b == b'x'));
                // The header ends exactly where the value bytes begin.
                assert_eq!(header_consumed, wire_header.len());
                assert_eq!(header.remaining_args, 0);
            }
            _ => panic!("expected NeedValue, got {:?}", result),
        }
    }

    #[test]
    fn test_large_set_counts_trailing_options() {
        // SET key <value> EX 10 -> two option arguments follow the value.
        let wire_header = format!(
            "*5\r\n$3\r\nSET\r\n$5\r\nmykey\r\n${}\r\n",
            STREAMING_THRESHOLD
        );

        let result = parse_streaming(
            wire_header.as_bytes(),
            &ParseOptions::default(),
            STREAMING_THRESHOLD,
        )
        .unwrap();

        match result {
            ParseProgress::NeedValue {
                header,
                value_prefix,
                header_consumed,
                ..
            } => {
                assert_eq!(header.remaining_args, 2);
                assert!(value_prefix.is_empty());
                assert_eq!(header_consumed, wire_header.len());
            }
            _ => panic!("expected NeedValue, got {:?}", result),
        }
    }

    #[test]
    fn test_get_uses_normal_path() {
        let data = b"*2\r\n$3\r\nGET\r\n$5\r\nmykey\r\n";
        let result = parse_streaming(data, &ParseOptions::default(), STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::Complete(cmd, _) => {
                assert_eq!(cmd, Command::Get { key: b"mykey" });
            }
            _ => panic!("expected Complete"),
        }
    }

    #[test]
    fn test_incomplete_header() {
        let data = b"*3\r\n$3\r\nSET\r\n$5\r\nmyk";
        let result = parse_streaming(data, &ParseOptions::default(), STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::Incomplete => {}
            _ => panic!("expected Incomplete"),
        }
    }

    #[test]
    fn test_complete_set_with_options() {
        let header = SetHeader {
            key: b"mykey",
            ex: None,
            px: None,
            nx: false,
            xx: false,
            remaining_args: 2, // EX 3600
        };

        let value = b"myvalue";
        let options_data = b"\r\n$2\r\nEX\r\n$4\r\n3600\r\n";

        let (cmd, consumed) = complete_set(options_data, &header, value).unwrap();

        match cmd {
            Command::Set {
                key, value: v, ex, ..
            } => {
                assert_eq!(key, b"mykey");
                assert_eq!(v, b"myvalue");
                assert_eq!(ex, Some(3600));
            }
            _ => panic!("expected Set command"),
        }
        assert_eq!(consumed, options_data.len());
    }

    #[test]
    fn test_complete_set_without_options() {
        let header = SetHeader {
            key: b"mykey",
            ex: None,
            px: None,
            nx: false,
            xx: false,
            remaining_args: 0,
        };

        let (cmd, consumed) = complete_set(b"\r\n", &header, b"myvalue").unwrap();

        assert_eq!(
            cmd,
            Command::Set {
                key: b"mykey",
                value: b"myvalue",
                ex: None,
                px: None,
                nx: false,
                xx: false,
            }
        );
        assert_eq!(consumed, 2);
    }

    #[test]
    fn test_complete_set_incomplete_and_errors() {
        let header = SetHeader {
            key: b"mykey",
            ex: None,
            px: None,
            nx: false,
            xx: false,
            remaining_args: 1,
        };

        // Trailing CRLF not fully received yet.
        assert!(matches!(
            complete_set(b"\r", &header, b"v"),
            Err(ParseError::Incomplete)
        ));
        // EX announced as the last argument, so its value is missing.
        assert!(matches!(
            complete_set(b"\r\n$2\r\nEX\r\n", &header, b"v"),
            Err(ParseError::Protocol(_))
        ));
        // A single flag option consumes exactly its bytes.
        assert!(matches!(
            complete_set(b"\r\n$2\r\nNX\r\n", &header, b"v"),
            Ok((Command::Set { nx: true, .. }, 10))
        ));
    }

    #[test]
    fn test_set_header_ttl() {
        let mut header = SetHeader {
            key: b"k",
            ex: None,
            px: Some(1500),
            nx: false,
            xx: false,
            remaining_args: 0,
        };
        assert_eq!(header.ttl(), Some(Duration::from_millis(1500)));

        // EX wins over PX when both are present.
        header.ex = Some(2);
        assert_eq!(header.ttl(), Some(Duration::from_secs(2)));
    }

    #[test]
    fn test_streaming_threshold_boundary() {
        // Exactly at threshold - should use streaming
        let value_len = STREAMING_THRESHOLD;
        let header = format!("*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n${}\r\n", value_len);

        let result = parse_streaming(
            header.as_bytes(),
            &ParseOptions::default(),
            STREAMING_THRESHOLD,
        )
        .unwrap();

        match result {
            ParseProgress::NeedValue { value_len: vl, .. } => {
                assert_eq!(vl, STREAMING_THRESHOLD);
            }
            _ => panic!("expected NeedValue at threshold"),
        }

        // Just below threshold - should use normal path (but incomplete)
        let value_len = STREAMING_THRESHOLD - 1;
        let header = format!("*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n${}\r\n", value_len);

        let result = parse_streaming(
            header.as_bytes(),
            &ParseOptions::default(),
            STREAMING_THRESHOLD,
        )
        .unwrap();

        // Will be Incomplete because value data isn't present
        match result {
            ParseProgress::Incomplete => {}
            _ => panic!("expected Incomplete for sub-threshold without data"),
        }
    }

    #[test]
    fn test_value_too_large_yields_value_too_large() {
        // Create options with a small max bulk string length
        let options = ParseOptions::new().max_bulk_string_len(1024); // 1KB limit

        // SET with 2KB value (above limit)
        let header = "*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n$2048\r\n".to_string();
        let mut data = header.as_bytes().to_vec();
        // Add some bytes of the value (simulating partial receive)
        data.extend_from_slice(&[b'x'; 500]);

        let result = parse_streaming(&data, &options, STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::ValueTooLarge {
                value_len,
                value_prefix_len,
                header_consumed,
                max_value_size,
            } => {
                assert_eq!(value_len, 2048);
                assert_eq!(value_prefix_len, 500);
                assert_eq!(max_value_size, 1024);
                // header_consumed should be everything up to the value data
                assert_eq!(header_consumed, header.len());
            }
            _ => panic!("expected ValueTooLarge, got {:?}", result),
        }
    }

    #[test]
    fn test_value_too_large_with_no_prefix() {
        // Create options with a small max bulk string length
        let options = ParseOptions::new().max_bulk_string_len(1024); // 1KB limit

        // SET with 2KB value, but no value bytes in buffer yet
        let header = "*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n$2048\r\n";

        let result = parse_streaming(header.as_bytes(), &options, STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::ValueTooLarge {
                value_len,
                value_prefix_len,
                max_value_size,
                ..
            } => {
                assert_eq!(value_len, 2048);
                assert_eq!(value_prefix_len, 0); // No value bytes in buffer
                assert_eq!(max_value_size, 1024);
            }
            _ => panic!("expected ValueTooLarge, got {:?}", result),
        }
    }
}