jupiter_rs/
request.rs

1use std::fmt::{Display, Formatter};
2
3use anyhow::{anyhow, Context, Result};
4use bytes::{Buf, Bytes, BytesMut};
5
/// An inclusive span `start..=end` of byte offsets into a request buffer.
///
/// A zero-length payload is described by `end == start - 1`.
#[derive(Copy, Clone, Debug)]
struct Range {
    start: usize,
    end: usize,
}

impl Range {
    /// Number of bytes in the `\r\n` sequence that follows every parsed value.
    const TERMINATOR_LEN: usize = 2;

    /// Returns the offset of the first byte after this range and its trailing `\r\n`.
    fn next_offset(&self) -> usize {
        let one_past_end = self.end + 1;
        one_past_end + Self::TERMINATOR_LEN
    }
}

impl Display for Range {
    /// Renders the range as `start..end` (note: both bounds are inclusive).
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let Range { start, end } = *self;
        write!(f, "{}..{}", start, end)
    }
}
23
24pub struct Request {
25    data: Bytes,
26    command: Range,
27    arguments: Vec<Range>,
28}
29
// Framing bytes of the request format (RESP-style arrays of bulk strings).

/// Introduces an array, i.e. a whole request (`*<count>\r\n…`).
const ASTERISK: u8 = b'*';
/// Introduces a bulk string (`$<length>\r\n<payload>\r\n`).
const DOLLAR: u8 = b'$';
/// Carriage return: the first byte of the `\r\n` line terminator.
const CR: u8 = b'\r';

// Inclusive bounds of the ASCII digits accepted when reading integers.

/// Smallest accepted digit.
const ZERO_DIGIT: u8 = b'0';
/// Largest accepted digit.
const NINE_DIGIT: u8 = b'9';
35
36fn read_int(buffer: &BytesMut, offset: usize) -> anyhow::Result<Option<(i32, Range)>> {
37    let mut len: i32 = 0;
38    let mut index = offset;
39    while index < buffer.len() {
40        let digit = buffer[index];
41        if digit == CR {
42            return if buffer.len() >= index + 1 {
43                Ok(Some((
44                    len,
45                    Range {
46                        start: offset,
47                        end: index - 1,
48                    },
49                )))
50            } else {
51                Ok(None)
52            };
53        }
54        if digit < ZERO_DIGIT || digit > NINE_DIGIT {
55            return Err(anyhow!("Malformed integer at position {}", index));
56        }
57
58        len = len * 10 + (digit - ZERO_DIGIT) as i32;
59        index += 1;
60    }
61
62    Ok(None)
63}
64
65fn read_bulk_string(buffer: &BytesMut, offset: usize) -> anyhow::Result<Option<Range>> {
66    if offset >= buffer.len() {
67        return Ok(None);
68    }
69    if buffer[offset] != DOLLAR {
70        return Err(anyhow!("Expected a bulk string at {}", offset));
71    }
72
73    if let Some((length, range)) = read_int(buffer, offset + 1)? {
74        let next_offset = range.next_offset();
75        if buffer.len() >= next_offset + length as usize + 2 {
76            return Ok(Some(Range {
77                start: next_offset,
78                end: next_offset + length as usize - 1,
79            }));
80        }
81    }
82
83    Ok(None)
84}
85
86impl Request {
87    #[inline]
88    pub fn parse(data: &mut BytesMut) -> anyhow::Result<Option<Request>> {
89        let len = data.len();
90
91        if len < 4 || data[data.len() - 2] != b'\r' {
92            Ok(None)
93        } else {
94            Request::parse_inner(data)
95        }
96    }
97
98    fn parse_inner(data: &mut BytesMut) -> anyhow::Result<Option<Request>> {
99        let mut offset = 0;
100        if data[0] != ASTERISK {
101            return Err(anyhow!("A request must be an array of bulk strings!"));
102        } else {
103            offset += 1;
104        }
105
106        let (mut num_args, range) = match read_int(&data, offset)? {
107            Some((num_args, range)) => (num_args - 1, range),
108            _ => return Ok(None),
109        };
110        offset = range.next_offset();
111
112        let command = match read_bulk_string(&data, offset)? {
113            Some(range) => range,
114            _ => return Ok(None),
115        };
116        offset = command.next_offset();
117
118        let mut arguments = Vec::with_capacity(num_args as usize);
119        while num_args > 0 {
120            if let Some(range) = read_bulk_string(&data, offset)? {
121                arguments.push(range);
122                num_args -= 1;
123                offset = range.next_offset();
124            } else {
125                return Ok(None);
126            }
127        }
128
129        let result_data = data.to_bytes();
130        if offset >= data.len() {
131            data.clear();
132        } else {
133            data.advance(offset);
134        }
135
136        Ok(Some(Request {
137            data: result_data,
138            command,
139            arguments,
140        }))
141    }
142
143    pub fn command(&self) -> &str {
144        std::str::from_utf8(&self.data[self.command.start..=self.command.end]).unwrap()
145    }
146
147    pub fn parameter_count(&self) -> usize {
148        self.arguments.len()
149    }
150
151    pub fn parameter(&self, index: usize) -> Result<Bytes> {
152        if index < self.arguments.len() {
153            Ok(self
154                .data
155                .slice(self.arguments[index].start..=self.arguments[index].end))
156        } else {
157            Err(anyhow!(
158                "Invalid parameter index {} (only {} are present)",
159                index,
160                self.arguments.len()
161            ))
162        }
163    }
164
165    pub fn str_parameter(&self, index: usize) -> Result<&str> {
166        if index < self.arguments.len() {
167            let range = self.arguments[index];
168            std::str::from_utf8(&self.data[range.start..=range.end]).with_context(|| {
169                format!(
170                    "Failed to parse parameter {} (range {}) as UTF-8 string!",
171                    index, range
172                )
173            })
174        } else {
175            Err(anyhow!(
176                "Invalid parameter index {} (only {} are present)",
177                index,
178                self.arguments.len()
179            ))
180        }
181    }
182
183    pub fn int_parameter(&self, index: usize) -> Result<i32> {
184        let string = self.str_parameter(index)?;
185        string.parse().with_context(|| {
186            format!(
187                "Failed to parse parameter {} ('{}') as integer!",
188                index, string
189            )
190        })
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::Request;
    use bytes::BytesMut;

    /// Parses `input` and asserts that it yields neither a request nor an error.
    fn assert_incomplete(input: &str) {
        let result = Request::parse(&mut BytesMut::from(input)).unwrap();
        assert!(
            result.is_none(),
            "expected {:?} to be treated as incomplete",
            input
        );
    }

    #[test]
    fn a_command_is_successfully_parsed() {
        let request = Request::parse(&mut BytesMut::from(
            "*3\r\n$10\r\ntest.hello\r\n$5\r\nWorld\r\n$2\r\n42\r\n",
        ))
        .unwrap()
        .unwrap();

        assert_eq!(request.parameter_count(), 2);
        assert_eq!(request.command(), "test.hello");

        assert_eq!(request.str_parameter(0).unwrap(), "World");
        assert_eq!(
            std::str::from_utf8(request.parameter(0).unwrap().as_ref()).unwrap(),
            "World"
        );
        assert_eq!(request.int_parameter(1).unwrap(), 42);

        assert!(request.str_parameter(2).is_err());
        assert!(request.int_parameter(2).is_err());
        assert!(request.parameter(2).is_err());
    }

    #[test]
    fn an_empty_parameter_is_parsed() {
        let mut buffer = BytesMut::from("*2\r\n$4\r\nECHO\r\n$0\r\n\r\n");
        let request = Request::parse(&mut buffer).unwrap().unwrap();

        assert_eq!(request.command(), "ECHO");
        assert_eq!(request.parameter_count(), 1);
        assert_eq!(request.str_parameter(0).unwrap(), "");
        assert!(request.parameter(0).unwrap().is_empty());
        assert!(buffer.is_empty());
    }

    #[test]
    fn missing_array_is_detected() {
        let result = Request::parse(&mut BytesMut::from("+GET\r\n"));
        assert!(result.is_err());
    }

    #[test]
    fn non_bulk_string_is_detected() {
        let result = Request::parse(&mut BytesMut::from("*1\r\n+GET\r\n"));
        assert!(result.is_err());
    }

    #[test]
    fn invalid_number_is_detected() {
        let result = Request::parse(&mut BytesMut::from("*GET\r\n"));
        assert!(result.is_err());
    }

    #[test]
    fn an_incomplete_command_is_skipped() {
        for input in &[
            "",
            "*",
            "*1",
            "*1\r",
            "*1\r\n",
            "*2\r\n$10\r\ntest.h",
            "*2\r\n$10\r\ntest.hello\r\n",
        ] {
            assert_incomplete(input);
        }
    }
}