1use std::fmt::{Display, Formatter};
2
3use anyhow::{anyhow, Context, Result};
4use bytes::{Buf, Bytes, BytesMut};
5
/// A span of byte positions within a request buffer.
///
/// Both `start` and `end` are inclusive; the buffer is sliced with `start..=end`. An empty
/// span (e.g. the payload of a zero-length bulk string) is represented as `end == start - 1`.
#[derive(Copy, Clone, Debug)]
struct Range {
    /// Position of the first byte of the span.
    start: usize,
    /// Position of the last byte of the span (inclusive).
    end: usize,
}

impl Range {
    /// Returns the position just past the `\r\n` terminator that follows this span.
    ///
    /// The spans produced by the parser (integer digits, bulk string payloads) are directly
    /// followed by `\r\n`. The next element therefore starts three bytes after the inclusive
    /// `end`: one byte to move past the span, plus the two terminator bytes.
    fn next_offset(&self) -> usize {
        self.end + 3
    }
}

impl Display for Range {
    /// Formats the span in Rust's inclusive range notation (`start..=end`), which matches
    /// how it is used to slice the buffer.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}..={}", self.start, self.end)
    }
}
23
24pub struct Request {
25 data: Bytes,
26 command: Range,
27 arguments: Vec<Range>,
28}
29
// Type markers: the first byte of an element tells the parser what follows.

/// Introduces an array header (`*<count>\r\n`); every request must start with it.
const ASTERISK: u8 = b'*';
/// Introduces a bulk string (`$<length>\r\n<payload>\r\n`).
const DOLLAR: u8 = b'$';

/// Carriage return, the first byte of every `\r\n` line terminator.
const CR: u8 = b'\r';

// Inclusive bounds of the ASCII decimal digits accepted in length/count headers.
const ZERO_DIGIT: u8 = b'0';
const NINE_DIGIT: u8 = b'9';
35
36fn read_int(buffer: &BytesMut, offset: usize) -> anyhow::Result<Option<(i32, Range)>> {
37 let mut len: i32 = 0;
38 let mut index = offset;
39 while index < buffer.len() {
40 let digit = buffer[index];
41 if digit == CR {
42 return if buffer.len() >= index + 1 {
43 Ok(Some((
44 len,
45 Range {
46 start: offset,
47 end: index - 1,
48 },
49 )))
50 } else {
51 Ok(None)
52 };
53 }
54 if digit < ZERO_DIGIT || digit > NINE_DIGIT {
55 return Err(anyhow!("Malformed integer at position {}", index));
56 }
57
58 len = len * 10 + (digit - ZERO_DIGIT) as i32;
59 index += 1;
60 }
61
62 Ok(None)
63}
64
65fn read_bulk_string(buffer: &BytesMut, offset: usize) -> anyhow::Result<Option<Range>> {
66 if offset >= buffer.len() {
67 return Ok(None);
68 }
69 if buffer[offset] != DOLLAR {
70 return Err(anyhow!("Expected a bulk string at {}", offset));
71 }
72
73 if let Some((length, range)) = read_int(buffer, offset + 1)? {
74 let next_offset = range.next_offset();
75 if buffer.len() >= next_offset + length as usize + 2 {
76 return Ok(Some(Range {
77 start: next_offset,
78 end: next_offset + length as usize - 1,
79 }));
80 }
81 }
82
83 Ok(None)
84}
85
86impl Request {
87 #[inline]
88 pub fn parse(data: &mut BytesMut) -> anyhow::Result<Option<Request>> {
89 let len = data.len();
90
91 if len < 4 || data[data.len() - 2] != b'\r' {
92 Ok(None)
93 } else {
94 Request::parse_inner(data)
95 }
96 }
97
98 fn parse_inner(data: &mut BytesMut) -> anyhow::Result<Option<Request>> {
99 let mut offset = 0;
100 if data[0] != ASTERISK {
101 return Err(anyhow!("A request must be an array of bulk strings!"));
102 } else {
103 offset += 1;
104 }
105
106 let (mut num_args, range) = match read_int(&data, offset)? {
107 Some((num_args, range)) => (num_args - 1, range),
108 _ => return Ok(None),
109 };
110 offset = range.next_offset();
111
112 let command = match read_bulk_string(&data, offset)? {
113 Some(range) => range,
114 _ => return Ok(None),
115 };
116 offset = command.next_offset();
117
118 let mut arguments = Vec::with_capacity(num_args as usize);
119 while num_args > 0 {
120 if let Some(range) = read_bulk_string(&data, offset)? {
121 arguments.push(range);
122 num_args -= 1;
123 offset = range.next_offset();
124 } else {
125 return Ok(None);
126 }
127 }
128
129 let result_data = data.to_bytes();
130 if offset >= data.len() {
131 data.clear();
132 } else {
133 data.advance(offset);
134 }
135
136 Ok(Some(Request {
137 data: result_data,
138 command,
139 arguments,
140 }))
141 }
142
143 pub fn command(&self) -> &str {
144 std::str::from_utf8(&self.data[self.command.start..=self.command.end]).unwrap()
145 }
146
147 pub fn parameter_count(&self) -> usize {
148 self.arguments.len()
149 }
150
151 pub fn parameter(&self, index: usize) -> Result<Bytes> {
152 if index < self.arguments.len() {
153 Ok(self
154 .data
155 .slice(self.arguments[index].start..=self.arguments[index].end))
156 } else {
157 Err(anyhow!(
158 "Invalid parameter index {} (only {} are present)",
159 index,
160 self.arguments.len()
161 ))
162 }
163 }
164
165 pub fn str_parameter(&self, index: usize) -> Result<&str> {
166 if index < self.arguments.len() {
167 let range = self.arguments[index];
168 std::str::from_utf8(&self.data[range.start..=range.end]).with_context(|| {
169 format!(
170 "Failed to parse parameter {} (range {}) as UTF-8 string!",
171 index, range
172 )
173 })
174 } else {
175 Err(anyhow!(
176 "Invalid parameter index {} (only {} are present)",
177 index,
178 self.arguments.len()
179 ))
180 }
181 }
182
183 pub fn int_parameter(&self, index: usize) -> Result<i32> {
184 let string = self.str_parameter(index)?;
185 string.parse().with_context(|| {
186 format!(
187 "Failed to parse parameter {} ('{}') as integer!",
188 index, string
189 )
190 })
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::Request;
    use bytes::BytesMut;

    /// Parses `input` from a freshly allocated buffer.
    fn parse(input: &str) -> anyhow::Result<Option<Request>> {
        Request::parse(&mut BytesMut::from(input))
    }

    #[test]
    fn a_command_is_successfully_parsed() {
        let request = parse("*3\r\n$10\r\ntest.hello\r\n$5\r\nWorld\r\n$2\r\n42\r\n")
            .unwrap()
            .unwrap();

        assert_eq!(request.parameter_count(), 2);
        assert_eq!(request.command(), "test.hello");

        assert_eq!(request.str_parameter(0).unwrap(), "World");
        assert_eq!(
            std::str::from_utf8(request.parameter(0).unwrap().as_ref()).unwrap(),
            "World"
        );
        assert_eq!(request.int_parameter(1).unwrap(), 42);

        assert!(request.str_parameter(2).is_err());
        assert!(request.int_parameter(2).is_err());
        assert!(request.parameter(2).is_err());
    }

    #[test]
    fn an_empty_parameter_is_parsed() {
        let request = parse("*2\r\n$3\r\nGET\r\n$0\r\n\r\n").unwrap().unwrap();

        assert_eq!(request.command(), "GET");
        assert_eq!(request.parameter_count(), 1);
        assert_eq!(request.str_parameter(0).unwrap(), "");
        assert!(request.parameter(0).unwrap().is_empty());
    }

    #[test]
    fn missing_array_is_detected() {
        assert!(parse("+GET\r\n").is_err());
    }

    #[test]
    fn non_bulk_string_is_detected() {
        assert!(parse("*1\r\n+GET\r\n").is_err());
    }

    #[test]
    fn invalid_number_is_detected() {
        assert!(parse("*GET\r\n").is_err());
    }

    #[test]
    fn an_incomplete_command_is_skipped() {
        let incomplete_inputs = [
            "",
            "*",
            "*1",
            "*1\r",
            "*1\r\n",
            "*2\r\n$10\r\ntest.h",
            "*2\r\n$10\r\ntest.hello\r\n",
        ];
        for input in incomplete_inputs.iter() {
            let result = parse(input).unwrap();
            assert!(result.is_none(), "{:?} should be treated as incomplete", input);
        }
    }
}