redis_oxide/protocol/
resp2_optimized.rs

1//! Optimized RESP2 protocol implementation
2//!
3//! This module provides optimized versions of RESP2 encoding and decoding
4//! with focus on memory allocation reduction and performance improvements.
5
6#![allow(unused_variables)]
7#![allow(dead_code)]
8#![allow(missing_docs)]
9
10use crate::core::{
11    error::{RedisError, RedisResult},
12    value::RespValue,
13};
14use bytes::{Buf, BufMut, Bytes, BytesMut};
15use std::io::Cursor;
16
/// RESP line terminator; every RESP2 header and simple value written by this module ends with it.
const CRLF: &[u8] = b"\r\n";

// Payloads of common server replies. Nothing in this module references them at the moment
// (the file-level `allow(dead_code)` keeps the compiler quiet); the encoder writes
// `SimpleString` payloads directly from the value it is given.
const OK_RESPONSE: &str = "OK";
const PONG_RESPONSE: &str = "PONG";
const QUEUED_RESPONSE: &str = "QUEUED";
23
24/// Optimized RESP2 encoder with buffer pre-sizing and zero-copy optimizations
25pub struct OptimizedRespEncoder {
26    // Reusable buffer to avoid allocations
27    buffer: BytesMut,
28}
29
30impl OptimizedRespEncoder {
31    /// Create a new optimized encoder
32    pub fn new() -> Self {
33        Self {
34            buffer: BytesMut::with_capacity(1024), // Start with reasonable capacity
35        }
36    }
37
38    /// Create a new encoder with specific initial capacity
39    pub fn with_capacity(capacity: usize) -> Self {
40        Self {
41            buffer: BytesMut::with_capacity(capacity),
42        }
43    }
44
45    /// Estimate buffer size needed for a RESP value
46    fn estimate_size(value: &RespValue) -> usize {
47        match value {
48            RespValue::SimpleString(s) => 1 + s.len() + 2, // +str\r\n
49            RespValue::Error(e) => 1 + e.len() + 2,        // -err\r\n
50            RespValue::Integer(i) => 1 + i.to_string().len() + 2, // :num\r\n
51            RespValue::BulkString(b) => {
52                let len_str = b.len().to_string();
53                1 + len_str.len() + 2 + b.len() + 2 // $len\r\ndata\r\n
54            }
55            RespValue::Null => 5, // $-1\r\n
56            RespValue::Array(arr) => {
57                let len_str = arr.len().to_string();
58                let mut size = 1 + len_str.len() + 2; // *len\r\n
59                for item in arr {
60                    size += Self::estimate_size(item);
61                }
62                size
63            }
64        }
65    }
66
67    /// Estimate buffer size needed for a command with arguments
68    fn estimate_command_size(command: &str, args: &[RespValue]) -> usize {
69        let total_items = 1 + args.len();
70        let array_header = 1 + total_items.to_string().len() + 2; // *count\r\n
71
72        // Command size
73        let cmd_size = 1 + command.len().to_string().len() + 2 + command.len() + 2; // $len\r\ncmd\r\n
74
75        // Arguments size
76        let args_size: usize = args.iter().map(Self::estimate_size).sum();
77
78        array_header + cmd_size + args_size
79    }
80
81    /// Encode a RESP value into the internal buffer with pre-sizing
82    pub fn encode(&mut self, value: &RespValue) -> RedisResult<Bytes> {
83        let estimated_size = Self::estimate_size(value);
84
85        // Reserve capacity if needed
86        if self.buffer.capacity() < estimated_size {
87            self.buffer.reserve(estimated_size);
88        }
89
90        self.buffer.clear();
91        self.encode_value(value)?;
92        Ok(self.buffer.split().freeze())
93    }
94
95    /// Encode a command with arguments using pre-sizing
96    pub fn encode_command(&mut self, command: &str, args: &[RespValue]) -> RedisResult<Bytes> {
97        let estimated_size = Self::estimate_command_size(command, args);
98
99        // Reserve capacity if needed
100        if self.buffer.capacity() < estimated_size {
101            self.buffer.reserve(estimated_size);
102        }
103
104        self.buffer.clear();
105
106        // Create array with command + args
107        let total_len = 1 + args.len();
108        self.buffer.put_u8(b'*');
109        self.put_integer_bytes(total_len);
110        self.buffer.put_slice(CRLF);
111
112        // Encode command as bulk string
113        self.buffer.put_u8(b'$');
114        self.put_integer_bytes(command.len());
115        self.buffer.put_slice(CRLF);
116        self.buffer.put_slice(command.as_bytes());
117        self.buffer.put_slice(CRLF);
118
119        // Encode arguments
120        for arg in args {
121            self.encode_value(arg)?;
122        }
123
124        Ok(self.buffer.split().freeze())
125    }
126
127    /// Internal method to encode a value into the buffer
128    fn encode_value(&mut self, value: &RespValue) -> RedisResult<()> {
129        match value {
130            RespValue::SimpleString(s) => {
131                self.buffer.put_u8(b'+');
132                // Use static strings for common responses
133                self.buffer.put_slice(s.as_bytes());
134                self.buffer.put_slice(CRLF);
135            }
136            RespValue::Error(e) => {
137                self.buffer.put_u8(b'-');
138                self.buffer.put_slice(e.as_bytes());
139                self.buffer.put_slice(CRLF);
140            }
141            RespValue::Integer(i) => {
142                self.buffer.put_u8(b':');
143                self.put_integer_bytes(*i);
144                self.buffer.put_slice(CRLF);
145            }
146            RespValue::BulkString(data) => {
147                self.buffer.put_u8(b'$');
148                self.put_integer_bytes(data.len());
149                self.buffer.put_slice(CRLF);
150                self.buffer.put_slice(data);
151                self.buffer.put_slice(CRLF);
152            }
153            RespValue::Null => {
154                self.buffer.put_slice(b"$-1\r\n");
155            }
156            RespValue::Array(arr) => {
157                self.buffer.put_u8(b'*');
158                self.put_integer_bytes(arr.len());
159                self.buffer.put_slice(CRLF);
160                for item in arr {
161                    self.encode_value(item)?;
162                }
163            }
164        }
165        Ok(())
166    }
167
168    /// Optimized integer to bytes conversion
169    fn put_integer_bytes<T: itoa::Integer>(&mut self, value: T) {
170        let mut buffer = itoa::Buffer::new();
171        let s = buffer.format(value);
172        self.buffer.put_slice(s.as_bytes());
173    }
174
175    /// Get the current buffer capacity
176    pub fn capacity(&self) -> usize {
177        self.buffer.capacity()
178    }
179
180    /// Clear the internal buffer
181    pub fn clear(&mut self) {
182        self.buffer.clear();
183    }
184
185    /// Reserve additional capacity
186    pub fn reserve(&mut self, additional: usize) {
187        self.buffer.reserve(additional);
188    }
189}
190
191impl Default for OptimizedRespEncoder {
192    fn default() -> Self {
193        Self::new()
194    }
195}
196
197/// Optimized RESP2 decoder with streaming support and reduced allocations
198pub struct OptimizedRespDecoder {
199    buffer: BytesMut,
200    // String cache for frequently used strings
201    string_cache: std::collections::HashMap<Vec<u8>, String>,
202    max_cache_size: usize,
203}
204
205impl OptimizedRespDecoder {
206    /// Create a new optimized decoder
207    pub fn new() -> Self {
208        Self {
209            buffer: BytesMut::with_capacity(4096),
210            string_cache: std::collections::HashMap::new(),
211            max_cache_size: 1000, // Limit cache size to prevent memory leaks
212        }
213    }
214
215    /// Create a new decoder with specific buffer capacity and cache size
216    pub fn with_config(buffer_capacity: usize, max_cache_size: usize) -> Self {
217        Self {
218            buffer: BytesMut::with_capacity(buffer_capacity),
219            string_cache: std::collections::HashMap::new(),
220            max_cache_size,
221        }
222    }
223
224    /// Decode data with streaming support
225    pub fn decode_streaming(&mut self, data: &[u8]) -> RedisResult<Vec<RespValue>> {
226        self.buffer.extend_from_slice(data);
227        let mut results = Vec::new();
228
229        loop {
230            let buffer_len = self.buffer.len();
231            if buffer_len == 0 {
232                break;
233            }
234
235            let buffer_slice = self.buffer.clone().freeze();
236            let mut cursor = Cursor::new(&buffer_slice[..]);
237
238            match self.try_decode_value(&mut cursor)? {
239                Some(value) => {
240                    let consumed = cursor.position() as usize;
241                    self.buffer.advance(consumed);
242                    results.push(value);
243                }
244                None => break, // Need more data
245            }
246        }
247
248        Ok(results)
249    }
250
251    /// Try to decode a single value, returning None if more data is needed
252    fn try_decode_value(&mut self, cursor: &mut Cursor<&[u8]>) -> RedisResult<Option<RespValue>> {
253        if !cursor.has_remaining() {
254            return Ok(None);
255        }
256
257        let type_byte = cursor.chunk()[0];
258        cursor.advance(1);
259
260        match type_byte {
261            b'+' => self.try_decode_simple_string(cursor),
262            b'-' => self.try_decode_error(cursor),
263            b':' => self.try_decode_integer(cursor),
264            b'$' => self.try_decode_bulk_string(cursor),
265            b'*' => self.try_decode_array(cursor),
266            _ => Err(RedisError::Protocol(format!(
267                "Invalid RESP type byte: {}",
268                type_byte as char
269            ))),
270        }
271    }
272
273    /// Try to decode a simple string with caching
274    fn try_decode_simple_string(
275        &mut self,
276        cursor: &mut Cursor<&[u8]>,
277    ) -> RedisResult<Option<RespValue>> {
278        if let Some(line) = self.try_read_line(cursor)? {
279            let string = self.bytes_to_string_cached(&line)?;
280            Ok(Some(RespValue::SimpleString(string)))
281        } else {
282            Ok(None)
283        }
284    }
285
286    /// Try to decode an error string
287    fn try_decode_error(&mut self, cursor: &mut Cursor<&[u8]>) -> RedisResult<Option<RespValue>> {
288        if let Some(line) = self.try_read_line(cursor)? {
289            let string = String::from_utf8(line)
290                .map_err(|e| RedisError::Protocol(format!("Invalid UTF-8 in error: {e}")))?;
291            Ok(Some(RespValue::Error(string)))
292        } else {
293            Ok(None)
294        }
295    }
296
297    /// Try to decode an integer
298    fn try_decode_integer(&mut self, cursor: &mut Cursor<&[u8]>) -> RedisResult<Option<RespValue>> {
299        if let Some(line) = self.try_read_line(cursor)? {
300            let s = String::from_utf8(line)
301                .map_err(|e| RedisError::Protocol(format!("Invalid UTF-8 in integer: {e}")))?;
302            let num = s
303                .parse::<i64>()
304                .map_err(|e| RedisError::Protocol(format!("Invalid integer format: {e}")))?;
305            Ok(Some(RespValue::Integer(num)))
306        } else {
307            Ok(None)
308        }
309    }
310
311    /// Try to decode a bulk string
312    fn try_decode_bulk_string(
313        &mut self,
314        cursor: &mut Cursor<&[u8]>,
315    ) -> RedisResult<Option<RespValue>> {
316        let len_line = match self.try_read_line(cursor)? {
317            Some(line) => line,
318            None => return Ok(None),
319        };
320
321        let len_str = String::from_utf8(len_line).map_err(|e| {
322            RedisError::Protocol(format!("Invalid UTF-8 in bulk string length: {e}"))
323        })?;
324        let len = len_str
325            .parse::<isize>()
326            .map_err(|e| RedisError::Protocol(format!("Invalid bulk string length: {e}")))?;
327
328        if len == -1 {
329            return Ok(Some(RespValue::Null));
330        }
331
332        if len < 0 {
333            return Err(RedisError::Protocol(
334                "Invalid bulk string length".to_string(),
335            ));
336        }
337
338        let len = len as usize;
339        if cursor.remaining() < len + 2 {
340            return Ok(None); // Need more data
341        }
342
343        let data = cursor.chunk()[..len].to_vec();
344        cursor.advance(len);
345
346        // Check for CRLF
347        if cursor.remaining() < 2 || &cursor.chunk()[..2] != CRLF {
348            return Err(RedisError::Protocol(
349                "Missing CRLF after bulk string".to_string(),
350            ));
351        }
352        cursor.advance(2);
353
354        Ok(Some(RespValue::BulkString(Bytes::from(data))))
355    }
356
357    /// Try to decode an array
358    fn try_decode_array(&mut self, cursor: &mut Cursor<&[u8]>) -> RedisResult<Option<RespValue>> {
359        let len_line = match self.try_read_line(cursor)? {
360            Some(line) => line,
361            None => return Ok(None),
362        };
363
364        let len_str = String::from_utf8(len_line)
365            .map_err(|e| RedisError::Protocol(format!("Invalid UTF-8 in array length: {e}")))?;
366        let len = len_str
367            .parse::<isize>()
368            .map_err(|e| RedisError::Protocol(format!("Invalid array length: {e}")))?;
369
370        if len == -1 {
371            return Ok(Some(RespValue::Null));
372        }
373
374        if len < 0 {
375            return Err(RedisError::Protocol("Invalid array length".to_string()));
376        }
377
378        let len = len as usize;
379        let mut elements = Vec::with_capacity(len);
380
381        for _ in 0..len {
382            match self.try_decode_value(cursor)? {
383                Some(element) => elements.push(element),
384                None => return Ok(None), // Need more data
385            }
386        }
387
388        Ok(Some(RespValue::Array(elements)))
389    }
390
391    /// Try to read a line, returning None if incomplete
392    fn try_read_line(&self, cursor: &mut Cursor<&[u8]>) -> RedisResult<Option<Vec<u8>>> {
393        let start_pos = cursor.position() as usize;
394        let remaining = cursor.get_ref();
395
396        // Look for CRLF
397        for (i, window) in remaining[start_pos..].windows(2).enumerate() {
398            if window == CRLF {
399                let line_end = start_pos + i;
400                let line = remaining[start_pos..line_end].to_vec();
401                cursor.advance(i + 2); // Skip line + CRLF
402                return Ok(Some(line));
403            }
404        }
405
406        Ok(None) // No complete line found
407    }
408
409    /// Convert bytes to string with caching for frequently used strings
410    fn bytes_to_string_cached(&mut self, bytes: &[u8]) -> RedisResult<String> {
411        // Check cache first
412        if let Some(cached) = self.string_cache.get(bytes) {
413            return Ok(cached.clone());
414        }
415
416        let string = String::from_utf8(bytes.to_vec())
417            .map_err(|e| RedisError::Protocol(format!("Invalid UTF-8: {e}")))?;
418
419        // Cache the string if cache isn't full
420        if self.string_cache.len() < self.max_cache_size {
421            self.string_cache.insert(bytes.to_vec(), string.clone());
422        }
423
424        Ok(string)
425    }
426
427    /// Clear the string cache
428    pub fn clear_cache(&mut self) {
429        self.string_cache.clear();
430    }
431
432    /// Get cache statistics
433    pub fn cache_stats(&self) -> (usize, usize) {
434        (self.string_cache.len(), self.max_cache_size)
435    }
436}
437
438impl Default for OptimizedRespDecoder {
439    fn default() -> Self {
440        Self::new()
441    }
442}
443
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimized_encoder_simple_string() {
        let mut encoder = OptimizedRespEncoder::new();
        let value = RespValue::SimpleString("OK".to_string());
        let encoded_result = encoder.encode(&value).unwrap();
        assert_eq!(encoded_result, Bytes::from("+OK\r\n"));
    }

    #[test]
    fn test_optimized_encoder_other_types() {
        // Reusing one encoder also checks that frames do not leak into each other.
        let mut encoder = OptimizedRespEncoder::new();
        let error = RespValue::Error("ERR bad".to_string());
        assert_eq!(encoder.encode(&error).unwrap(), Bytes::from("-ERR bad\r\n"));
        assert_eq!(encoder.encode(&RespValue::Integer(-7)).unwrap(), Bytes::from(":-7\r\n"));
        assert_eq!(encoder.encode(&RespValue::Null).unwrap(), Bytes::from("$-1\r\n"));

        let array = RespValue::Array(vec![
            RespValue::BulkString(Bytes::from("a")),
            RespValue::Null,
        ]);
        assert_eq!(
            encoder.encode(&array).unwrap(),
            Bytes::from("*2\r\n$1\r\na\r\n$-1\r\n")
        );
    }

    #[test]
    fn test_optimized_encoder_command() {
        let mut encoder = OptimizedRespEncoder::new();
        let args = vec![RespValue::from("mykey")];
        let encoded_result = encoder.encode_command("GET", &args).unwrap();

        let expected = "*2\r\n$3\r\nGET\r\n$5\r\nmykey\r\n";
        assert_eq!(encoded_result, Bytes::from(expected));

        // A command without arguments is a one-element array.
        let ping = encoder.encode_command("PING", &[]).unwrap();
        assert_eq!(ping, Bytes::from("*1\r\n$4\r\nPING\r\n"));
    }

    #[test]
    fn test_optimized_decoder_streaming() {
        let mut decoder = OptimizedRespDecoder::new();

        // Two complete values followed by a truncated bulk string.
        let partial1 = b"+OK\r\n:42\r\n$5\r\nhel";
        let results1 = decoder.decode_streaming(partial1).unwrap();
        assert_eq!(results1.len(), 2);
        assert!(matches!(&results1[0], RespValue::SimpleString(s) if s.as_str() == "OK"));
        assert!(matches!(&results1[1], RespValue::Integer(42)));

        // Completing the bulk string releases exactly one more value.
        let partial2 = b"lo\r\n";
        let results2 = decoder.decode_streaming(partial2).unwrap();
        assert_eq!(results2.len(), 1);

        match &results2[0] {
            RespValue::BulkString(b) => assert_eq!(b, &Bytes::from("hello")),
            _ => panic!("Expected bulk string"),
        }
    }

    #[test]
    fn test_optimized_decoder_array_across_chunks() {
        let mut decoder = OptimizedRespDecoder::new();
        assert!(decoder.decode_streaming(b"*2\r\n$1\r\na\r\n").unwrap().is_empty());

        let results = decoder.decode_streaming(b":5\r\n").unwrap();
        assert_eq!(results.len(), 1);
        match &results[0] {
            RespValue::Array(items) => {
                assert_eq!(items.len(), 2);
                assert!(matches!(&items[0], RespValue::BulkString(b) if b == &Bytes::from("a")));
                assert!(matches!(&items[1], RespValue::Integer(5)));
            }
            _ => panic!("Expected array"),
        }
    }

    #[test]
    fn test_optimized_decoder_nulls_and_errors() {
        let mut decoder = OptimizedRespDecoder::new();
        let results = decoder.decode_streaming(b"$-1\r\n*-1\r\n-ERR bad\r\n").unwrap();
        assert_eq!(results.len(), 3);
        assert!(matches!(results[0], RespValue::Null));
        assert!(matches!(results[1], RespValue::Null));
        assert!(matches!(&results[2], RespValue::Error(e) if e.as_str() == "ERR bad"));
    }

    #[test]
    fn test_optimized_decoder_rejects_unknown_type_byte() {
        let mut decoder = OptimizedRespDecoder::new();
        assert!(decoder.decode_streaming(b"?oops\r\n").is_err());
    }

    #[test]
    fn test_size_estimation() {
        let cases = [
            (RespValue::SimpleString("OK".to_string()), 5), // +OK\r\n
            (RespValue::BulkString(Bytes::from("hello")), 11), // $5\r\nhello\r\n
            (RespValue::Integer(0), 4),                     // :0\r\n
            (RespValue::Integer(-123), 7),                  // :-123\r\n
            (RespValue::Null, 5),                           // $-1\r\n
            (RespValue::Array(vec![RespValue::Integer(1)]), 8), // *1\r\n:1\r\n
        ];

        let mut encoder = OptimizedRespEncoder::new();
        for (value, expected) in &cases {
            assert_eq!(OptimizedRespEncoder::estimate_size(value), *expected);
            // The estimate must match what is actually written.
            assert_eq!(encoder.encode(value).unwrap().len(), *expected);
        }
    }

    #[test]
    fn test_string_caching() {
        let mut decoder = OptimizedRespDecoder::new();

        // Decode the same string multiple times
        let data = b"+OK\r\n+OK\r\n";
        let results = decoder.decode_streaming(data).unwrap();

        assert_eq!(results.len(), 2);
        let (cache_size, _) = decoder.cache_stats();
        assert_eq!(cache_size, 1); // "OK" should be cached
    }
}