Skip to main content

jupiter/
request.rs

1//! Provides a parser and wrapper for handling incoming RESP requests.
2//!
3//! A RESP request (that is "REdis Serialization Protocol") is quite simple. It starts with a
4//! "*" followed by the number of arguments. For each argument, there is a "$" followed by the
5//! number of bytes in the argument string followed by "\r\n". After the string data, yet another
6//! CRLF (\r\n) is output.
7//!
8//! Therefore a simple request might look like:
9//! * "PING" => `*1\r\n$4\r\nPING\r\n`
//! * "SET my-key 5" => `*3\r\n$3\r\nSET\r\n$6\r\nmy-key\r\n$1\r\n5\r\n`
11//!
//! As we receive these requests via a network interface, which might also provide partial requests,
13//! we require a very fast and efficient algorithm to detect if the given data is a valid request
14//! and what its parameters are.
15//!
16//! [Request::parse](Request::parse) implements this algorithm which can detect partial requests
17//! in less than 100ns and also parse full requests well below 500ns! As internally all results are
18//! only indices into the given byte buffer, only a single allocation for the list of offsets
19//! is performed and no data is copied whatsoever.
20//!
21//! # Examples
22//!
23//! Parsing a simple request:
24//! ```
25//! # use bytes::BytesMut;
26//! # use jupiter::request::Request;
27//! let mut bytes = BytesMut::from("*1\r\n$4\r\nPING\r\n");
28//! let result = Request::parse(&mut bytes).unwrap().unwrap();
29//!
30//! assert_eq!(result.command(), "PING");
31//! assert_eq!(result.parameter_count(), 0);
32//! ```
33//!
34//! Parsing a partial request:
35//! ```
36//! # use bytes::BytesMut;
37//! # use jupiter::request::Request;
38//! let mut bytes = BytesMut::from("*2\r\n$4\r\nPING\r\n$7\r\nTESTP");
39//! let result = Request::parse(&mut bytes).unwrap();
40//!
41//! assert_eq!(result.is_none(), true);
42//! ```
43//!
44//! Parsing an invalid request:
45//! ```
46//! # use bytes::BytesMut;
47//! # use jupiter::request::Request;
48//! let mut bytes = BytesMut::from("$4\r\nPING\r\n");
49//! let result = Request::parse(&mut bytes);
50//!
51//! assert_eq!(result.is_err(), true);
52//! ```
53//!
54//! Building a request for test environments:
55//! ```
56//! # use jupiter::request::Request;
57//! let request = Request::example(vec!("PING"));
58//! assert_eq!(request.command(), "PING");
59//! ```
60use std::fmt::{Display, Formatter, Write};
61
62use anyhow::{anyhow, Context, Result};
63use bytes::{Bytes, BytesMut};
64
/// Marks the location of a single request element (the command or one of its parameters).
///
/// Only the positions within the byte buffer of the request are stored, the bytes themselves
/// remain in that buffer.
#[derive(Clone, Copy, Debug)]
struct Range {
    /// Index of the first byte of the element.
    start: usize,
    /// Index of the last byte of the element (inclusive).
    end: usize,
}

impl Range {
    /// Distance from the last byte of an element to the first byte after its trailing CRLF:
    /// one step past the (inclusive) end plus the two bytes of the CRLF itself.
    const STEP_PAST_CRLF: usize = 3;

    /// Determines the position at which the data following this range (and its CRLF) starts.
    fn next_offset(&self) -> usize {
        self.end + Range::STEP_PAST_CRLF
    }
}

impl Display for Range {
    /// Renders the range as `start..end`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Range { start, end } = *self;
        write!(f, "{}..{}", start, end)
    }
}
87
88/// Represents a parsed RESP request.
89///
90/// Note that we treat the 1st parameter as "command" and re-number all other parameters
91/// accordingly. Therefore "SET x y" will have "SET" as command, "x" as first parameter
92/// (index: 0) and "y" as second (index: 1).
93pub struct Request {
94    len: usize,
95    data: Bytes,
96    command: Range,
97    arguments: Vec<Range>,
98}
99
100impl Request {
101    const DOLLAR: u8 = b'$';
102    const ASTERISK: u8 = b'*';
103    const CR: u8 = b'\r';
104    const ZERO_DIGIT: u8 = b'0';
105    const NINE_DIGIT: u8 = b'9';
106
107    /// Tries to parse a RESP request from the given byte buffer.
108    ///
109    /// If malformed data is detected, we return an **Err**. Otherwise we either return an
110    /// empty optional, in case only a partial request is present or otherwise a full request
111    /// which has then the form `Ok(Some(Request))`.
112    ///
113    ///
114    /// # Examples
115    ///
116    /// Parsing a simple request:
117    /// ```
118    /// # use bytes::BytesMut;
119    /// # use jupiter::request::Request;
120    /// let mut bytes = BytesMut::from("*3\r\n$3\r\nSET\r\n$1\r\nx\r\n$1\r\ny\r\n");
121    /// let result = Request::parse(&mut bytes).unwrap().unwrap();
122    ///
123    /// assert_eq!(result.command(), "SET");
124    /// assert_eq!(result.parameter_count(), 2);
125    /// assert_eq!(result.str_parameter(0).unwrap(), "x");
126    /// assert_eq!(result.str_parameter(1).unwrap(), "y");
127    /// assert_eq!(result.str_parameter(2).is_err(), true);
128    /// ```
129    pub fn parse(data: &BytesMut) -> anyhow::Result<Option<Request>> {
130        let len = data.len();
131
132        // Abort as early as possible if a partial request is present...
133        if len < 4 || data[data.len() - 2] != Request::CR {
134            Ok(None)
135        } else {
136            Request::parse_inner(data)
137        }
138    }
139
140    /// Provides a helper function to create an example request in test environments.
141    ///
142    /// # Example
143    /// ```
144    /// # use jupiter::request::Request;
145    /// let request = Request::example(vec!("PING"));
146    /// assert_eq!(request.command(), "PING");
147    /// ```
148    pub fn example(data: Vec<&str>) -> Request {
149        let mut input = String::new();
150        let _ = write!(input, "*{}\r\n", data.len());
151        for param in data {
152            let _ = write!(input, "${}\r\n{}\r\n", param.len(), param);
153        }
154
155        Request::parse(&BytesMut::from(input.as_str()))
156            .unwrap()
157            .unwrap()
158    }
159
160    fn parse_inner(data: &BytesMut) -> anyhow::Result<Option<Request>> {
161        // Check general validity of the request...
162        let mut offset = 0;
163        if data[0] != Request::ASTERISK {
164            return Err(anyhow!("A request must be an array of bulk strings!"));
165        } else {
166            offset += 1;
167        }
168
169        // Parse the number of arguments...
170        let (mut num_args, range) = match Request::read_int(data, offset)? {
171            Some((num_args, range)) => (num_args - 1, range),
172            _ => return Ok(None),
173        };
174        offset = range.next_offset();
175
176        // Parse the first parameter as "command"...
177        let command = match Request::read_bulk_string(data, offset)? {
178            Some(range) => range,
179            _ => return Ok(None),
180        };
181        offset = command.next_offset();
182
183        // Parse the remaining arguments...
184        let mut arguments = Vec::with_capacity(num_args as usize);
185        while num_args > 0 {
186            if let Some(range) = Request::read_bulk_string(data, offset)? {
187                arguments.push(range);
188                num_args -= 1;
189                offset = range.next_offset();
190            } else {
191                return Ok(None);
192            }
193        }
194
195        Ok(Some(Request {
196            len: offset,
197            data: data.clone().freeze(),
198            command,
199            arguments,
200        }))
201    }
202
203    /// Tries to parse a number.
204    ///
205    /// This is either the number of arguments or the length of an argument string. Note that
206    /// the algorithm and also the return type is a bit more complex as we have to handle the
207    /// happy path (valid number being read) as well as an error (invalid number found) and also
208    /// the partial request case (we didn't discover the final CR which marks the end of the
209    /// number).
210    fn read_int(buffer: &BytesMut, offset: usize) -> anyhow::Result<Option<(i32, Range)>> {
211        let mut value: i32 = 0;
212        let mut index = offset;
213        while index < buffer.len() {
214            let digit = buffer[index];
215            if digit == Request::CR {
216                return if buffer.len() > index {
217                    Ok(Some((
218                        value,
219                        Range {
220                            start: offset,
221                            end: index - 1,
222                        },
223                    )))
224                } else {
225                    Ok(None)
226                };
227            }
228            if !(Request::ZERO_DIGIT..=Request::NINE_DIGIT).contains(&digit) {
229                return Err(anyhow!("Malformed integer at position {}", index));
230            }
231
232            value = value * 10 + (digit - Request::ZERO_DIGIT) as i32;
233            index += 1;
234        }
235
236        Ok(None)
237    }
238
239    fn read_bulk_string(buffer: &BytesMut, offset: usize) -> anyhow::Result<Option<Range>> {
240        if offset >= buffer.len() {
241            return Ok(None);
242        }
243        if buffer[offset] != Request::DOLLAR {
244            return Err(anyhow!("Expected a bulk string at {}", offset));
245        }
246
247        if let Some((length, range)) = Request::read_int(buffer, offset + 1)? {
248            let next_offset = range.next_offset();
249            if buffer.len() >= next_offset + length as usize + 2 {
250                return Ok(Some(Range {
251                    start: next_offset,
252                    end: next_offset + length as usize - 1,
253                }));
254            }
255        }
256
257        Ok(None)
258    }
259
260    /// Returns the command in the request (this is the first parameter).
261    pub fn command(&self) -> &str {
262        std::str::from_utf8(&self.data[self.command.start..=self.command.end]).unwrap()
263    }
264
265    /// Returns the number of parameters (not counting the command itself).
266    pub fn parameter_count(&self) -> usize {
267        self.arguments.len()
268    }
269
270    /// Returns the n-th parameter (not including the command).
271    ///
272    /// Returns an **Err** if the requested index is outside of the range of detected
273    /// parameters.
274    pub fn parameter(&self, index: usize) -> Result<Bytes> {
275        if index < self.arguments.len() {
276            Ok(self
277                .data
278                .slice(self.arguments[index].start..=self.arguments[index].end))
279        } else {
280            Err(anyhow!(
281                "Invalid parameter index {} (only {} are present)",
282                index,
283                self.arguments.len()
284            ))
285        }
286    }
287
288    /// Returns the n-th parameter as UTF-8 string.
289    ///
290    /// Returns an **Err** if either the requested index is out of range or if the parameter
291    /// data isn't a valid UTF-8 sequence.
292    pub fn str_parameter(&self, index: usize) -> Result<&str> {
293        if index < self.arguments.len() {
294            let range = self.arguments[index];
295            std::str::from_utf8(&self.data[range.start..=range.end]).with_context(|| {
296                format!(
297                    "Failed to parse parameter {} (range {}) as UTF-8 string!",
298                    index, range
299                )
300            })
301        } else {
302            Err(anyhow!(
303                "Invalid parameter index {} (only {} are present)",
304                index,
305                self.arguments.len()
306            ))
307        }
308    }
309
310    /// Returns the n-th parameter as integer.
311    ///
312    /// Returns an **Err** if either the requested index is out of range or if the parameter
313    /// data isn't a valid integer number.
314    pub fn int_parameter(&self, index: usize) -> Result<i32> {
315        let string = self.str_parameter(index)?;
316        string.parse().with_context(|| {
317            format!(
318                "Failed to parse parameter {} ('{}') as integer!",
319                index, string
320            )
321        })
322    }
323
324    /// Returns the total length on bytes for this request.
325    pub fn len(&self) -> usize {
326        self.len
327    }
328
329    /// Determines if the request is completely empty.
330    pub fn is_empty(&self) -> bool {
331        self.len == 0
332    }
333}
334
#[cfg(test)]
mod tests {
    use crate::request::Request;
    use bytes::BytesMut;

    /// Feeds the given raw data into the request parser.
    fn parse(input: &str) -> anyhow::Result<Option<Request>> {
        Request::parse(&BytesMut::from(input))
    }

    #[test]
    fn a_command_is_successfully_parsed() {
        let request = parse("*3\r\n$10\r\ntest.hello\r\n$5\r\nWorld\r\n$2\r\n42\r\n")
            .unwrap()
            .unwrap();

        assert_eq!(request.parameter_count(), 2);
        assert_eq!(request.command(), "test.hello");

        // The first parameter is available as string as well as raw bytes...
        assert_eq!(request.str_parameter(0).unwrap(), "World");
        let raw = request.parameter(0).unwrap();
        assert_eq!(std::str::from_utf8(raw.as_ref()).unwrap(), "World");

        // ...the second one can be read as number...
        assert_eq!(request.int_parameter(1).unwrap(), 42);

        // ...and there is no third parameter.
        assert!(request.str_parameter(2).is_err());
        assert!(request.int_parameter(2).is_err());
        assert!(request.parameter(2).is_err());
    }

    #[test]
    fn missing_array_is_detected() {
        assert!(parse("+GET\r\n").is_err());
    }

    #[test]
    fn non_bulk_string_is_detected() {
        assert!(parse("*1\r\n+GET\r\n").is_err());
    }

    #[test]
    fn invalid_number_is_detected() {
        assert!(parse("*GET\r\n").is_err());
    }

    #[test]
    fn an_incomplete_command_is_skipped() {
        // Each of these is a valid prefix of a request, but none of them is complete...
        let partial_requests = [
            "",
            "*",
            "*1",
            "*1\r",
            "*1\r\n",
            "*2\r\n$10\r\ntest.h",
            "*2\r\n$10\r\ntest.hello\r\n",
        ];

        for partial in partial_requests.iter() {
            let result = parse(partial).unwrap();
            assert!(
                result.is_none(),
                "{:?} must be treated as partial request",
                partial
            );
        }
    }
}