1use std::io::Cursor;
11
12use bytes::Bytes;
13
14use crate::error::ProtocolError;
15use crate::types::Frame;
16
17pub fn parse_frame(buf: &[u8]) -> Result<Option<(Frame, usize)>, ProtocolError> {
23 if buf.is_empty() {
24 return Ok(None);
25 }
26
27 let mut cursor = Cursor::new(buf);
28
29 match check(&mut cursor) {
30 Ok(()) => {
31 cursor.set_position(0);
33 let frame = parse(&mut cursor)?;
34 let consumed = cursor.position() as usize;
35 Ok(Some((frame, consumed)))
36 }
37 Err(ProtocolError::Incomplete) => Ok(None),
38 Err(e) => Err(e),
39 }
40}
41
42fn check(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
49 let prefix = read_byte(cursor)?;
50
51 match prefix {
52 b'+' | b'-' => check_line(cursor),
53 b':' => check_line(cursor),
54 b'$' => check_bulk(cursor),
55 b'*' => check_array(cursor),
56 b'_' => check_line(cursor),
57 b'%' => check_map(cursor),
58 other => Err(ProtocolError::InvalidPrefix(other)),
59 }
60}
61
62fn check_line(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
63 find_crlf(cursor)?;
64 Ok(())
65}
66
67fn check_bulk(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
68 let len = read_integer_line(cursor)?;
69 if len < 0 {
70 return Err(ProtocolError::InvalidFrameLength(len));
71 }
72 let len = len as usize;
73
74 let remaining = remaining(cursor);
76 if remaining < len + 2 {
77 return Err(ProtocolError::Incomplete);
78 }
79
80 let pos = cursor.position() as usize;
81 let buf = cursor.get_ref();
83 if buf[pos + len] != b'\r' || buf[pos + len + 1] != b'\n' {
84 return Err(ProtocolError::InvalidFrameLength(len as i64));
85 }
86
87 cursor.set_position((pos + len + 2) as u64);
88 Ok(())
89}
90
91fn check_array(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
92 let count = read_integer_line(cursor)?;
93 if count < 0 {
94 return Err(ProtocolError::InvalidFrameLength(count));
95 }
96
97 for _ in 0..count {
98 check(cursor)?;
99 }
100 Ok(())
101}
102
103fn check_map(cursor: &mut Cursor<&[u8]>) -> Result<(), ProtocolError> {
104 let count = read_integer_line(cursor)?;
105 if count < 0 {
106 return Err(ProtocolError::InvalidFrameLength(count));
107 }
108
109 for _ in 0..count {
110 check(cursor)?; check(cursor)?; }
113 Ok(())
114}
115
116fn parse(cursor: &mut Cursor<&[u8]>) -> Result<Frame, ProtocolError> {
121 let prefix = read_byte(cursor)?;
122
123 match prefix {
124 b'+' => {
125 let line = read_line(cursor)?;
126 let s = std::str::from_utf8(line).map_err(|_| {
127 ProtocolError::InvalidCommandFrame("invalid utf-8 in simple string".into())
128 })?;
129 Ok(Frame::Simple(s.to_owned()))
130 }
131 b'-' => {
132 let line = read_line(cursor)?;
133 let s = std::str::from_utf8(line).map_err(|_| {
134 ProtocolError::InvalidCommandFrame("invalid utf-8 in error string".into())
135 })?;
136 Ok(Frame::Error(s.to_owned()))
137 }
138 b':' => {
139 let val = read_integer_line(cursor)?;
140 Ok(Frame::Integer(val))
141 }
142 b'$' => {
143 let len = read_integer_line(cursor)? as usize;
144 let pos = cursor.position() as usize;
145 let data = &cursor.get_ref()[pos..pos + len];
146 cursor.set_position((pos + len + 2) as u64); Ok(Frame::Bulk(Bytes::copy_from_slice(data)))
148 }
149 b'*' => {
150 let count = read_integer_line(cursor)? as usize;
151 let mut frames = Vec::with_capacity(count);
152 for _ in 0..count {
153 frames.push(parse(cursor)?);
154 }
155 Ok(Frame::Array(frames))
156 }
157 b'_' => {
158 let _ = read_line(cursor)?;
160 Ok(Frame::Null)
161 }
162 b'%' => {
163 let count = read_integer_line(cursor)? as usize;
164 let mut pairs = Vec::with_capacity(count);
165 for _ in 0..count {
166 let key = parse(cursor)?;
167 let val = parse(cursor)?;
168 pairs.push((key, val));
169 }
170 Ok(Frame::Map(pairs))
171 }
172 other => Err(ProtocolError::InvalidPrefix(other)),
174 }
175}
176
177fn read_byte(cursor: &mut Cursor<&[u8]>) -> Result<u8, ProtocolError> {
182 let pos = cursor.position() as usize;
183 if pos >= cursor.get_ref().len() {
184 return Err(ProtocolError::Incomplete);
185 }
186 cursor.set_position((pos + 1) as u64);
187 Ok(cursor.get_ref()[pos])
188}
189
190fn read_line<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], ProtocolError> {
193 let start = cursor.position() as usize;
194 let end = find_crlf(cursor)?;
195 Ok(&cursor.get_ref()[start..end])
196}
197
198fn read_integer_line(cursor: &mut Cursor<&[u8]>) -> Result<i64, ProtocolError> {
200 let line = read_line(cursor)?;
201 parse_i64(line)
202}
203
204fn find_crlf(cursor: &mut Cursor<&[u8]>) -> Result<usize, ProtocolError> {
207 let buf = cursor.get_ref();
208 let start = cursor.position() as usize;
209
210 if start >= buf.len() {
211 return Err(ProtocolError::Incomplete);
212 }
213
214 for i in start..buf.len().saturating_sub(1) {
216 if buf[i] == b'\r' && buf[i + 1] == b'\n' {
217 cursor.set_position((i + 2) as u64);
218 return Ok(i);
219 }
220 }
221
222 Err(ProtocolError::Incomplete)
223}
224
/// Number of bytes between the cursor and the end of the buffer.
///
/// Yields 0 when the cursor sits at or beyond the end.
fn remaining(cursor: &Cursor<&[u8]>) -> usize {
    let consumed = cursor.position() as usize;
    cursor.get_ref().len().saturating_sub(consumed)
}
230
231fn parse_i64(buf: &[u8]) -> Result<i64, ProtocolError> {
232 let s = std::str::from_utf8(buf).map_err(|_| ProtocolError::InvalidInteger)?;
233 s.parse::<i64>().map_err(|_| ProtocolError::InvalidInteger)
234}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, asserting that it holds exactly one complete frame.
    fn must_parse(input: &[u8]) -> Frame {
        let (frame, consumed) = parse_frame(input)
            .expect("parse should not error")
            .expect("parse should return a frame");
        assert_eq!(consumed, input.len(), "should consume entire input");
        frame
    }

    #[test]
    fn simple_string() {
        assert_eq!(must_parse(b"+OK\r\n"), Frame::Simple("OK".into()));
        assert_eq!(
            must_parse(b"+hello world\r\n"),
            Frame::Simple("hello world".into())
        );
    }

    #[test]
    fn simple_error() {
        assert_eq!(
            must_parse(b"-ERR unknown command\r\n"),
            Frame::Error("ERR unknown command".into())
        );
    }

    #[test]
    fn integer() {
        assert_eq!(must_parse(b":42\r\n"), Frame::Integer(42));
        assert_eq!(must_parse(b":0\r\n"), Frame::Integer(0));
        assert_eq!(must_parse(b":-1\r\n"), Frame::Integer(-1));
        assert_eq!(
            must_parse(b":9223372036854775807\r\n"),
            Frame::Integer(i64::MAX)
        );
        assert_eq!(
            must_parse(b":-9223372036854775808\r\n"),
            Frame::Integer(i64::MIN)
        );
    }

    #[test]
    fn bulk_string() {
        assert_eq!(
            must_parse(b"$5\r\nhello\r\n"),
            Frame::Bulk(Bytes::from_static(b"hello"))
        );
    }

    #[test]
    fn empty_bulk_string() {
        assert_eq!(
            must_parse(b"$0\r\n\r\n"),
            Frame::Bulk(Bytes::from_static(b""))
        );
    }

    #[test]
    fn bulk_string_with_binary() {
        let input = b"$4\r\n\x00\x01\x02\x03\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Bulk(Bytes::copy_from_slice(&[0, 1, 2, 3]))
        );
    }

    // A bulk payload is delimited by its declared length, so a CRLF inside the
    // payload must not end the frame early.
    #[test]
    fn bulk_string_containing_crlf() {
        assert_eq!(
            must_parse(b"$4\r\na\r\nb\r\n"),
            Frame::Bulk(Bytes::from_static(b"a\r\nb"))
        );
    }

    #[test]
    fn null() {
        assert_eq!(must_parse(b"_\r\n"), Frame::Null);
    }

    #[test]
    fn array() {
        let input = b"*2\r\n+hello\r\n+world\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("hello".into()),
                Frame::Simple("world".into()),
            ])
        );
    }

    #[test]
    fn empty_array() {
        assert_eq!(must_parse(b"*0\r\n"), Frame::Array(vec![]));
    }

    #[test]
    fn nested_array() {
        let input = b"*2\r\n*2\r\n:1\r\n:2\r\n*2\r\n:3\r\n:4\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Array(vec![Frame::Integer(1), Frame::Integer(2)]),
                Frame::Array(vec![Frame::Integer(3), Frame::Integer(4)]),
            ])
        );
    }

    #[test]
    fn array_with_null() {
        let input = b"*3\r\n+OK\r\n_\r\n:1\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("OK".into()),
                Frame::Null,
                Frame::Integer(1),
            ])
        );
    }

    #[test]
    fn map() {
        let input = b"%2\r\n+key1\r\n:1\r\n+key2\r\n:2\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Map(vec![
                (Frame::Simple("key1".into()), Frame::Integer(1)),
                (Frame::Simple("key2".into()), Frame::Integer(2)),
            ])
        );
    }

    #[test]
    fn empty_map() {
        assert_eq!(must_parse(b"%0\r\n"), Frame::Map(vec![]));
    }

    #[test]
    fn incomplete_returns_none() {
        assert_eq!(parse_frame(b"").unwrap(), None);
        assert_eq!(parse_frame(b"+OK").unwrap(), None);
        assert_eq!(parse_frame(b"+OK\r").unwrap(), None);
        assert_eq!(parse_frame(b"$5\r\nhel").unwrap(), None);
        assert_eq!(parse_frame(b"*2\r\n+OK\r\n").unwrap(), None);
    }

    // A map entry needs both a key and a value; a key on its own is not enough.
    #[test]
    fn incomplete_map_returns_none() {
        assert_eq!(parse_frame(b"%1\r\n").unwrap(), None);
        assert_eq!(parse_frame(b"%1\r\n+key\r\n").unwrap(), None);
    }

    #[test]
    fn invalid_prefix() {
        let err = parse_frame(b"~invalid\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidPrefix(b'~'));
    }

    #[test]
    fn invalid_prefix_inside_array() {
        let err = parse_frame(b"*1\r\n~x\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidPrefix(b'~'));
    }

    #[test]
    fn invalid_integer() {
        let err = parse_frame(b":abc\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidInteger);
    }

    #[test]
    fn invalid_utf8_in_simple_string() {
        let err = parse_frame(b"+\xff\xfe\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidCommandFrame(_)));
    }

    #[test]
    fn negative_bulk_length() {
        let err = parse_frame(b"$-1\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(-1)));
    }

    #[test]
    fn negative_aggregate_length() {
        let err = parse_frame(b"*-1\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(-1)));
        let err = parse_frame(b"%-2\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(-2)));
    }

    // The declared length says 3 bytes, but the payload is not followed by CRLF.
    #[test]
    fn bulk_missing_terminator() {
        let err = parse_frame(b"$3\r\nabcXY").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(3)));
    }

    #[test]
    fn parse_consumes_exact_bytes() {
        let buf = b"+OK\r\ntrailing";
        let (frame, consumed) = parse_frame(buf).unwrap().unwrap();
        assert_eq!(frame, Frame::Simple("OK".into()));
        assert_eq!(consumed, 5);
    }

    // `consumed` must be usable as the offset of the next frame in the buffer.
    #[test]
    fn back_to_back_frames() {
        let buf = b"+OK\r\n:7\r\n";
        let (first, used) = parse_frame(buf).unwrap().unwrap();
        assert_eq!(first, Frame::Simple("OK".into()));
        assert_eq!(used, 5);

        let (second, used) = parse_frame(&buf[used..]).unwrap().unwrap();
        assert_eq!(second, Frame::Integer(7));
        assert_eq!(used, 4);
    }
}