1use crate::Command;
36use crate::error::ParseError;
37use crate::value::ParseOptions;
38use std::time::Duration;
/// Default value length, in bytes, at or above which a `SET` value is streamed
/// instead of being parsed from a fully buffered command (64 KiB).
pub const STREAMING_THRESHOLD: usize = 64 << 10;

44#[derive(Debug)]
46pub enum ParseProgress<'a> {
47 Incomplete,
49
50 NeedValue {
58 header: SetHeader<'a>,
60 value_len: usize,
62 value_prefix: &'a [u8],
65 header_consumed: usize,
67 },
68
69 ValueTooLarge {
76 value_len: usize,
78 value_prefix_len: usize,
80 header_consumed: usize,
82 max_value_size: usize,
84 },
85
86 Complete(Command<'a>, usize),
88}
89
/// Parsed prefix of a streamed `SET`: the key plus the option state known so far.
///
/// [`parse_streaming`] fills in only the key. The option fields stay at their
/// defaults because options follow the value on the wire.
#[derive(Debug, Clone)]
pub struct SetHeader<'a> {
    /// Key being written.
    pub key: &'a [u8],
    /// `EX` expiry in seconds, if present.
    pub ex: Option<u64>,
    /// `PX` expiry in milliseconds, if present.
    pub px: Option<u64>,
    /// `NX` flag.
    pub nx: bool,
    /// `XX` flag.
    pub xx: bool,
    /// Number of option arguments that still follow the value.
    remaining_args: usize,
}

impl SetHeader<'_> {
    /// Expiry as a [`Duration`]. `EX` wins over `PX`; `None` when neither is set.
    pub fn ttl(&self) -> Option<Duration> {
        match (self.ex, self.px) {
            (Some(secs), _) => Some(Duration::from_secs(secs)),
            (None, millis) => millis.map(Duration::from_millis),
        }
    }
}

118pub fn parse_streaming<'a>(
141 buffer: &'a [u8],
142 options: &ParseOptions,
143 streaming_threshold: usize,
144) -> Result<ParseProgress<'a>, ParseError> {
145 let mut cursor = StreamingCursor::new(buffer, options.max_bulk_string_len);
146
147 if cursor.remaining() < 1 {
149 return Ok(ParseProgress::Incomplete);
150 }
151 if cursor.peek() != b'*' {
152 return Err(ParseError::Protocol("expected array".to_string()));
153 }
154 cursor.advance(1);
155
156 let count = match cursor.read_integer() {
158 Ok(n) => n,
159 Err(ParseError::Incomplete) => return Ok(ParseProgress::Incomplete),
160 Err(e) => return Err(e),
161 };
162
163 if count < 1 {
164 return Err(ParseError::Protocol(
165 "array must have at least 1 element".to_string(),
166 ));
167 }
168
169 const MAX_ARRAY_LEN: usize = 1024 * 1024;
170 if count > MAX_ARRAY_LEN {
171 return Err(ParseError::Protocol("array too large".to_string()));
172 }
173
174 let cmd_name = match cursor.read_bulk_string() {
176 Ok(s) => s,
177 Err(ParseError::Incomplete) => return Ok(ParseProgress::Incomplete),
178 Err(e) => return Err(e),
179 };
180
181 let cmd_str = std::str::from_utf8(cmd_name)
182 .map_err(|_| ParseError::Protocol("invalid UTF-8 in command".to_string()))?;
183
184 if !cmd_str.eq_ignore_ascii_case("set") {
186 return match Command::parse_with_options(buffer, options) {
188 Ok((cmd, consumed)) => Ok(ParseProgress::Complete(cmd, consumed)),
189 Err(ParseError::Incomplete) => Ok(ParseProgress::Incomplete),
190 Err(e) => Err(e),
191 };
192 }
193
194 if count < 3 {
196 return Err(ParseError::WrongArity(
197 "SET requires at least 2 arguments".to_string(),
198 ));
199 }
200
201 let key = match cursor.read_bulk_string() {
203 Ok(s) => s,
204 Err(ParseError::Incomplete) => return Ok(ParseProgress::Incomplete),
205 Err(e) => return Err(e),
206 };
207
208 if cursor.remaining() < 1 {
210 return Ok(ParseProgress::Incomplete);
211 }
212 if cursor.peek() != b'$' {
213 return Err(ParseError::Protocol(
214 "expected bulk string for value".to_string(),
215 ));
216 }
217 cursor.advance(1);
218
219 let value_len = match cursor.read_integer() {
220 Ok(n) => n,
221 Err(ParseError::Incomplete) => return Ok(ParseProgress::Incomplete),
222 Err(e) => return Err(e),
223 };
224
225 if value_len > cursor.max_bulk_string_len {
227 let header_consumed = cursor.position();
228 let remaining_in_buffer = cursor.remaining();
229 let value_prefix_len = remaining_in_buffer.min(value_len);
230
231 return Ok(ParseProgress::ValueTooLarge {
232 value_len,
233 value_prefix_len,
234 header_consumed,
235 max_value_size: cursor.max_bulk_string_len,
236 });
237 }
238
239 if value_len < streaming_threshold {
241 return match Command::parse_with_options(buffer, options) {
242 Ok((cmd, consumed)) => Ok(ParseProgress::Complete(cmd, consumed)),
243 Err(ParseError::Incomplete) => Ok(ParseProgress::Incomplete),
244 Err(e) => Err(e),
245 };
246 }
247
248 let header_consumed = cursor.position();
250 let remaining_in_buffer = cursor.remaining();
251
252 let value_prefix_len = remaining_in_buffer.min(value_len);
254 let value_prefix = &buffer[header_consumed..header_consumed + value_prefix_len];
255
256 let remaining_args = count.saturating_sub(3);
259
260 Ok(ParseProgress::NeedValue {
261 header: SetHeader {
262 key,
263 ex: None,
264 px: None,
265 nx: false,
266 xx: false,
267 remaining_args,
268 },
269 value_len,
270 value_prefix,
271 header_consumed,
272 })
273}
274
275pub fn complete_set<'a>(
292 buffer: &'a [u8],
293 header: &SetHeader<'a>,
294 value: &'a [u8],
295) -> Result<(Command<'a>, usize), ParseError> {
296 let mut cursor = StreamingCursor::new(buffer, usize::MAX);
297
298 if cursor.remaining() < 2 {
300 return Err(ParseError::Incomplete);
301 }
302 if cursor.peek() != b'\r' {
303 return Err(ParseError::Protocol(
304 "expected CRLF after bulk string".to_string(),
305 ));
306 }
307 cursor.advance(1);
308 if cursor.peek() != b'\n' {
309 return Err(ParseError::Protocol(
310 "expected CRLF after bulk string".to_string(),
311 ));
312 }
313 cursor.advance(1);
314
315 let mut ex = header.ex;
317 let mut px = header.px;
318 let mut nx = header.nx;
319 let mut xx = header.xx;
320
321 let mut remaining_args = header.remaining_args;
322 while remaining_args > 0 {
323 let option = match cursor.read_bulk_string() {
324 Ok(s) => s,
325 Err(ParseError::Incomplete) => return Err(ParseError::Incomplete),
326 Err(e) => return Err(e),
327 };
328
329 let option_str = std::str::from_utf8(option)
330 .map_err(|_| ParseError::Protocol("invalid UTF-8 in option".to_string()))?;
331
332 if option_str.eq_ignore_ascii_case("ex") {
333 if remaining_args < 2 {
334 return Err(ParseError::Protocol("EX requires a value".to_string()));
335 }
336 let ttl_bytes = cursor.read_bulk_string()?;
337 let ttl_str = std::str::from_utf8(ttl_bytes)
338 .map_err(|_| ParseError::Protocol("invalid UTF-8 in TTL".to_string()))?;
339 let ttl_secs = ttl_str
340 .parse::<u64>()
341 .map_err(|_| ParseError::Protocol("invalid TTL value".to_string()))?;
342 ex = Some(ttl_secs);
343 remaining_args -= 2;
344 } else if option_str.eq_ignore_ascii_case("px") {
345 if remaining_args < 2 {
346 return Err(ParseError::Protocol("PX requires a value".to_string()));
347 }
348 let ttl_bytes = cursor.read_bulk_string()?;
349 let ttl_str = std::str::from_utf8(ttl_bytes)
350 .map_err(|_| ParseError::Protocol("invalid UTF-8 in TTL".to_string()))?;
351 let ttl_ms = ttl_str
352 .parse::<u64>()
353 .map_err(|_| ParseError::Protocol("invalid TTL value".to_string()))?;
354 px = Some(ttl_ms);
355 remaining_args -= 2;
356 } else if option_str.eq_ignore_ascii_case("nx") {
357 nx = true;
358 remaining_args -= 1;
359 } else if option_str.eq_ignore_ascii_case("xx") {
360 xx = true;
361 remaining_args -= 1;
362 } else {
363 return Err(ParseError::Protocol(format!(
364 "unknown SET option: {}",
365 option_str
366 )));
367 }
368 }
369
370 Ok((
371 Command::Set {
372 key: header.key,
373 value,
374 ex,
375 px,
376 nx,
377 xx,
378 },
379 cursor.position(),
380 ))
381}
382
383struct StreamingCursor<'a> {
385 buffer: &'a [u8],
386 pos: usize,
387 max_bulk_string_len: usize,
388}
389
390impl<'a> StreamingCursor<'a> {
391 fn new(buffer: &'a [u8], max_bulk_string_len: usize) -> Self {
392 Self {
393 buffer,
394 pos: 0,
395 max_bulk_string_len,
396 }
397 }
398
399 #[inline]
400 fn remaining(&self) -> usize {
401 self.buffer.len() - self.pos
402 }
403
404 #[inline]
405 fn position(&self) -> usize {
406 self.pos
407 }
408
409 #[inline]
410 fn peek(&self) -> u8 {
411 self.buffer[self.pos]
412 }
413
414 #[inline]
415 fn advance(&mut self, n: usize) {
416 self.pos += n;
417 }
418
419 fn read_integer(&mut self) -> Result<usize, ParseError> {
420 let line = self.read_line()?;
421
422 if line.is_empty() {
423 return Err(ParseError::InvalidInteger("empty integer".to_string()));
424 }
425
426 if line.len() > 19 {
427 return Err(ParseError::InvalidInteger("integer too large".to_string()));
428 }
429
430 let mut result = 0usize;
431 for &byte in line {
432 if !byte.is_ascii_digit() {
433 return Err(ParseError::InvalidInteger(
434 "non-digit character".to_string(),
435 ));
436 }
437 result = result
438 .checked_mul(10)
439 .and_then(|r| r.checked_add((byte - b'0') as usize))
440 .ok_or_else(|| ParseError::InvalidInteger("integer overflow".to_string()))?;
441 }
442 Ok(result)
443 }
444
445 fn read_bulk_string(&mut self) -> Result<&'a [u8], ParseError> {
446 if self.remaining() < 1 {
447 return Err(ParseError::Incomplete);
448 }
449
450 if self.peek() != b'$' {
451 return Err(ParseError::Protocol("expected bulk string".to_string()));
452 }
453 self.advance(1);
454
455 let len = self.read_integer()?;
456
457 if len > self.max_bulk_string_len {
458 return Err(ParseError::BulkStringTooLong {
459 len,
460 max: self.max_bulk_string_len,
461 });
462 }
463
464 if self.remaining() < len + 2 {
465 return Err(ParseError::Incomplete);
466 }
467
468 let data = &self.buffer[self.pos..self.pos + len];
469 self.pos += len;
470
471 if self.remaining() < 2 {
472 return Err(ParseError::Incomplete);
473 }
474 if self.peek() != b'\r' {
475 return Err(ParseError::Protocol(
476 "expected CRLF after bulk string".to_string(),
477 ));
478 }
479 self.advance(1);
480 if self.peek() != b'\n' {
481 return Err(ParseError::Protocol(
482 "expected CRLF after bulk string".to_string(),
483 ));
484 }
485 self.advance(1);
486
487 Ok(data)
488 }
489
490 fn read_line(&mut self) -> Result<&'a [u8], ParseError> {
491 let start = self.pos;
492 let slice = &self.buffer[start..];
493
494 for i in 0..slice.len().saturating_sub(1) {
495 if slice[i] == b'\r' && slice[i + 1] == b'\n' {
496 let line = &self.buffer[start..start + i];
497 self.pos = start + i + 2;
498 return Ok(line);
499 }
500 }
501
502 Err(ParseError::Incomplete)
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// RESP bytes for `SET mykey <value>` up to and including the `$<len>\r\n`
    /// line that announces the value.
    fn set_prefix(value_len: usize) -> String {
        format!("*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n${}\r\n", value_len)
    }

    /// Streamed `SET` header for key `mykey` with no options parsed yet.
    fn header_with_args(remaining_args: usize) -> SetHeader<'static> {
        SetHeader {
            key: b"mykey",
            ex: None,
            px: None,
            nx: false,
            xx: false,
            remaining_args,
        }
    }

    #[test]
    fn test_small_set_uses_normal_path() {
        let data = b"*3\r\n$3\r\nSET\r\n$5\r\nmykey\r\n$7\r\nmyvalue\r\n";
        let result = parse_streaming(data, &ParseOptions::default(), STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::Complete(cmd, consumed) => {
                assert_eq!(
                    cmd,
                    Command::Set {
                        key: b"mykey",
                        value: b"myvalue",
                        ex: None,
                        px: None,
                        nx: false,
                        xx: false,
                    }
                );
                assert_eq!(consumed, data.len());
            }
            other => panic!("expected Complete, got {:?}", other),
        }
    }

    #[test]
    fn test_large_set_yields_need_value() {
        let value_len = 100 * 1024;
        let header = set_prefix(value_len);
        let mut data = header.as_bytes().to_vec();
        data.extend_from_slice(&[b'x'; 1000]);

        let result = parse_streaming(&data, &ParseOptions::default(), STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::NeedValue {
                header: set_header,
                value_len: vl,
                value_prefix,
                header_consumed,
            } => {
                assert_eq!(set_header.key, b"mykey");
                assert_eq!(vl, value_len);
                assert_eq!(value_prefix.len(), 1000);
                assert!(value_prefix.iter().all(|&b| b == b'x'));
                // The value starts right after the `$<len>\r\n` line.
                assert_eq!(header_consumed, header.len());
            }
            other => panic!("expected NeedValue, got {:?}", other),
        }
    }

    #[test]
    fn test_lowercase_set_is_streamed() {
        let data = format!(
            "*3\r\n$3\r\nset\r\n$5\r\nmykey\r\n${}\r\n",
            STREAMING_THRESHOLD
        );
        let result = parse_streaming(
            data.as_bytes(),
            &ParseOptions::default(),
            STREAMING_THRESHOLD,
        )
        .unwrap();

        match result {
            ParseProgress::NeedValue {
                header,
                value_prefix,
                header_consumed,
                ..
            } => {
                assert_eq!(header.key, b"mykey");
                assert!(value_prefix.is_empty());
                assert_eq!(header_consumed, data.len());
            }
            other => panic!("expected NeedValue, got {:?}", other),
        }
    }

    #[test]
    fn test_get_uses_normal_path() {
        let data = b"*2\r\n$3\r\nGET\r\n$5\r\nmykey\r\n";
        let result = parse_streaming(data, &ParseOptions::default(), STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::Complete(cmd, _) => {
                assert_eq!(cmd, Command::Get { key: b"mykey" });
            }
            other => panic!("expected Complete, got {:?}", other),
        }
    }

    #[test]
    fn test_incomplete_header() {
        let data = b"*3\r\n$3\r\nSET\r\n$5\r\nmyk";
        let result = parse_streaming(data, &ParseOptions::default(), STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::Incomplete => {}
            other => panic!("expected Incomplete, got {:?}", other),
        }
    }

    #[test]
    fn test_complete_set_with_options() {
        // One option pair (`EX 3600`) follows the value.
        let header = header_with_args(2);
        let value = b"myvalue";
        let options_data = b"\r\n$2\r\nEX\r\n$4\r\n3600\r\n";

        let (cmd, consumed) = complete_set(options_data, &header, value).unwrap();

        match cmd {
            Command::Set {
                key, value: v, ex, ..
            } => {
                assert_eq!(key, b"mykey");
                assert_eq!(v, b"myvalue");
                assert_eq!(ex, Some(3600));
            }
            _ => panic!("expected Set command"),
        }
        assert_eq!(consumed, options_data.len());
    }

    #[test]
    fn test_complete_set_flags_and_px() {
        let header = header_with_args(3);
        let tail = b"\r\n$2\r\npx\r\n$3\r\n250\r\n$2\r\nXX\r\n";

        let (cmd, consumed) = complete_set(tail, &header, b"v").unwrap();

        assert_eq!(
            cmd,
            Command::Set {
                key: b"mykey",
                value: b"v",
                ex: None,
                px: Some(250),
                nx: false,
                xx: true,
            }
        );
        assert_eq!(consumed, tail.len());
    }

    #[test]
    fn test_complete_set_incomplete_tail() {
        let header = header_with_args(2);

        let partial_ttl = complete_set(b"\r\n$2\r\nEX\r\n$4\r\n36", &header, b"v");
        assert!(matches!(partial_ttl, Err(ParseError::Incomplete)));

        // Not even the value's CRLF has fully arrived.
        let partial_crlf = complete_set(b"\r", &header, b"v");
        assert!(matches!(partial_crlf, Err(ParseError::Incomplete)));
    }

    #[test]
    fn test_complete_set_rejects_bad_tail() {
        let no_options = header_with_args(0);
        let missing_crlf = complete_set(b"ab", &no_options, b"v");
        assert!(matches!(missing_crlf, Err(ParseError::Protocol(_))));

        let one_option = header_with_args(1);
        let unknown = complete_set(b"\r\n$3\r\nFOO\r\n", &one_option, b"v");
        assert!(matches!(unknown, Err(ParseError::Protocol(ref msg)) if msg.contains("FOO")));

        let dangling_ex = complete_set(b"\r\n$2\r\nEX\r\n", &one_option, b"v");
        assert!(matches!(dangling_ex, Err(ParseError::Protocol(_))));
    }

    #[test]
    fn test_streaming_threshold_boundary() {
        // A value exactly at the threshold is streamed.
        let header = set_prefix(STREAMING_THRESHOLD);
        let result = parse_streaming(
            header.as_bytes(),
            &ParseOptions::default(),
            STREAMING_THRESHOLD,
        )
        .unwrap();

        match result {
            ParseProgress::NeedValue { value_len: vl, .. } => {
                assert_eq!(vl, STREAMING_THRESHOLD);
            }
            other => panic!("expected NeedValue at threshold, got {:?}", other),
        }

        // One byte below goes through the normal parser, which needs the whole value.
        let header = set_prefix(STREAMING_THRESHOLD - 1);
        let result = parse_streaming(
            header.as_bytes(),
            &ParseOptions::default(),
            STREAMING_THRESHOLD,
        )
        .unwrap();

        match result {
            ParseProgress::Incomplete => {}
            other => panic!(
                "expected Incomplete for sub-threshold without data, got {:?}",
                other
            ),
        }
    }

    #[test]
    fn test_value_too_large_yields_value_too_large() {
        let options = ParseOptions::new().max_bulk_string_len(1024);
        let header = set_prefix(2048);
        let mut data = header.as_bytes().to_vec();
        data.extend_from_slice(&[b'x'; 500]);

        let result = parse_streaming(&data, &options, STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::ValueTooLarge {
                value_len,
                value_prefix_len,
                header_consumed,
                max_value_size,
            } => {
                assert_eq!(value_len, 2048);
                assert_eq!(value_prefix_len, 500);
                assert_eq!(max_value_size, 1024);
                assert_eq!(header_consumed, header.len());
            }
            other => panic!("expected ValueTooLarge, got {:?}", other),
        }
    }

    #[test]
    fn test_value_too_large_with_no_prefix() {
        let options = ParseOptions::new().max_bulk_string_len(1024);
        let header = set_prefix(2048);

        let result = parse_streaming(header.as_bytes(), &options, STREAMING_THRESHOLD).unwrap();

        match result {
            ParseProgress::ValueTooLarge {
                value_len,
                value_prefix_len,
                max_value_size,
                ..
            } => {
                assert_eq!(value_len, 2048);
                // Nothing of the value has been buffered yet.
                assert_eq!(value_prefix_len, 0);
                assert_eq!(max_value_size, 1024);
            }
            other => panic!("expected ValueTooLarge, got {:?}", other),
        }
    }
}