1use std::io::Cursor;
19
20use bytes::Bytes;
21
22use crate::error::ProtocolError;
23use crate::types::Frame;
24
/// Deepest aggregate (`*` array / `%` map) nesting `try_parse` will descend into.
/// The top-level frame sits at depth 0; each aggregate adds one level.
const MAX_NESTING_DEPTH: usize = 64;

/// Upper bound on the declared element count of an array, or the declared
/// pair count of a map (2^20).
const MAX_ARRAY_ELEMENTS: usize = 1 << 20;

/// Upper bound, in bytes, on a bulk-string payload (512 MiB).
const MAX_BULK_LEN: i64 = 512 << 20;

/// Ceiling on the up-front `Vec` capacity reserved for an aggregate, so a
/// large declared count cannot force a large allocation before any element
/// has actually arrived.
const PREALLOC_CAP: usize = 1 << 10;
42
43#[inline]
49pub fn parse_frame(buf: &[u8]) -> Result<Option<(Frame, usize)>, ProtocolError> {
50 if buf.is_empty() {
51 return Ok(None);
52 }
53
54 let mut cursor = Cursor::new(buf);
55
56 match try_parse(&mut cursor, 0) {
57 Ok(frame) => {
58 let consumed = cursor.position() as usize;
59 Ok(Some((frame, consumed)))
60 }
61 Err(ProtocolError::Incomplete) => Ok(None),
62 Err(e) => Err(e),
63 }
64}
65
66fn try_parse(cursor: &mut Cursor<&[u8]>, depth: usize) -> Result<Frame, ProtocolError> {
73 let prefix = read_byte(cursor)?;
74
75 match prefix {
76 b'+' => {
77 let line = read_line(cursor)?;
78 let s = std::str::from_utf8(line).map_err(|_| {
79 ProtocolError::InvalidCommandFrame("invalid utf-8 in simple string".into())
80 })?;
81 Ok(Frame::Simple(s.to_owned()))
82 }
83 b'-' => {
84 let line = read_line(cursor)?;
85 let s = std::str::from_utf8(line).map_err(|_| {
86 ProtocolError::InvalidCommandFrame("invalid utf-8 in error string".into())
87 })?;
88 Ok(Frame::Error(s.to_owned()))
89 }
90 b':' => {
91 let val = read_integer_line(cursor)?;
92 Ok(Frame::Integer(val))
93 }
94 b'$' => {
95 let len = read_integer_line(cursor)?;
96 if len < 0 {
97 return Err(ProtocolError::InvalidFrameLength(len));
98 }
99 if len > MAX_BULK_LEN {
100 return Err(ProtocolError::BulkStringTooLarge(len as usize));
101 }
102 let len = len as usize;
103
104 let remaining = remaining(cursor);
106 if remaining < len + 2 {
107 return Err(ProtocolError::Incomplete);
108 }
109
110 let pos = cursor.position() as usize;
111 let buf = cursor.get_ref();
112
113 if buf[pos + len] != b'\r' || buf[pos + len + 1] != b'\n' {
115 return Err(ProtocolError::InvalidFrameLength(len as i64));
116 }
117
118 let data = &buf[pos..pos + len];
119 cursor.set_position((pos + len + 2) as u64);
120 Ok(Frame::Bulk(Bytes::copy_from_slice(data)))
121 }
122 b'*' => {
123 let next_depth = depth + 1;
124 if next_depth > MAX_NESTING_DEPTH {
125 return Err(ProtocolError::NestingTooDeep(MAX_NESTING_DEPTH));
126 }
127
128 let count = read_integer_line(cursor)?;
129 if count < 0 {
130 return Err(ProtocolError::InvalidFrameLength(count));
131 }
132 if count as usize > MAX_ARRAY_ELEMENTS {
133 return Err(ProtocolError::TooManyElements(count as usize));
134 }
135
136 let count = count as usize;
137 let mut frames = Vec::with_capacity(count.min(PREALLOC_CAP));
138 for _ in 0..count {
139 frames.push(try_parse(cursor, next_depth)?);
140 }
141 Ok(Frame::Array(frames))
142 }
143 b'_' => {
144 let _ = read_line(cursor)?;
146 Ok(Frame::Null)
147 }
148 b'%' => {
149 let next_depth = depth + 1;
150 if next_depth > MAX_NESTING_DEPTH {
151 return Err(ProtocolError::NestingTooDeep(MAX_NESTING_DEPTH));
152 }
153
154 let count = read_integer_line(cursor)?;
155 if count < 0 {
156 return Err(ProtocolError::InvalidFrameLength(count));
157 }
158 if count as usize > MAX_ARRAY_ELEMENTS {
159 return Err(ProtocolError::TooManyElements(count as usize));
160 }
161
162 let count = count as usize;
163 let mut pairs = Vec::with_capacity(count.min(PREALLOC_CAP));
164 for _ in 0..count {
165 let key = try_parse(cursor, next_depth)?;
166 let val = try_parse(cursor, next_depth)?;
167 pairs.push((key, val));
168 }
169 Ok(Frame::Map(pairs))
170 }
171 other => Err(ProtocolError::InvalidPrefix(other)),
172 }
173}
174
175fn read_byte(cursor: &mut Cursor<&[u8]>) -> Result<u8, ProtocolError> {
180 let pos = cursor.position() as usize;
181 if pos >= cursor.get_ref().len() {
182 return Err(ProtocolError::Incomplete);
183 }
184 cursor.set_position((pos + 1) as u64);
185 Ok(cursor.get_ref()[pos])
186}
187
188fn read_line<'a>(cursor: &mut Cursor<&'a [u8]>) -> Result<&'a [u8], ProtocolError> {
191 let start = cursor.position() as usize;
192 let end = find_crlf(cursor)?;
193 Ok(&cursor.get_ref()[start..end])
194}
195
196fn read_integer_line(cursor: &mut Cursor<&[u8]>) -> Result<i64, ProtocolError> {
198 let line = read_line(cursor)?;
199 parse_i64(line)
200}
201
202fn find_crlf(cursor: &mut Cursor<&[u8]>) -> Result<usize, ProtocolError> {
205 let buf = cursor.get_ref();
206 let start = cursor.position() as usize;
207
208 if start >= buf.len() {
209 return Err(ProtocolError::Incomplete);
210 }
211
212 let mut pos = start;
215 while let Some(offset) = memchr::memchr(b'\r', &buf[pos..]) {
216 let cr = pos + offset;
217 if cr + 1 < buf.len() && buf[cr + 1] == b'\n' {
218 cursor.set_position((cr + 2) as u64);
219 return Ok(cr);
220 }
221 pos = cr + 1;
223 }
224
225 Err(ProtocolError::Incomplete)
226}
227
/// Number of unread bytes between the cursor and the end of its buffer.
///
/// Clamps to zero if the cursor has been positioned past the end.
fn remaining(cursor: &Cursor<&[u8]>) -> usize {
    cursor
        .get_ref()
        .len()
        .saturating_sub(cursor.position() as usize)
}
233
234fn parse_i64(buf: &[u8]) -> Result<i64, ProtocolError> {
235 let s = std::str::from_utf8(buf).map_err(|_| ProtocolError::InvalidInteger)?;
236 s.parse::<i64>().map_err(|_| ProtocolError::InvalidInteger)
237}
238
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `input`, asserting it holds exactly one complete frame.
    fn must_parse(input: &[u8]) -> Frame {
        let (frame, consumed) = parse_frame(input)
            .expect("parse should not error")
            .expect("parse should return a frame");
        assert_eq!(consumed, input.len(), "should consume entire input");
        frame
    }

    #[test]
    fn simple_string() {
        assert_eq!(must_parse(b"+OK\r\n"), Frame::Simple("OK".into()));
        assert_eq!(
            must_parse(b"+hello world\r\n"),
            Frame::Simple("hello world".into())
        );
    }

    #[test]
    fn simple_error() {
        assert_eq!(
            must_parse(b"-ERR unknown command\r\n"),
            Frame::Error("ERR unknown command".into())
        );
    }

    #[test]
    fn integer() {
        assert_eq!(must_parse(b":42\r\n"), Frame::Integer(42));
        assert_eq!(must_parse(b":0\r\n"), Frame::Integer(0));
        assert_eq!(must_parse(b":-1\r\n"), Frame::Integer(-1));
        assert_eq!(
            must_parse(b":9223372036854775807\r\n"),
            Frame::Integer(i64::MAX)
        );
        assert_eq!(
            must_parse(b":-9223372036854775808\r\n"),
            Frame::Integer(i64::MIN)
        );
    }

    #[test]
    fn bulk_string() {
        assert_eq!(
            must_parse(b"$5\r\nhello\r\n"),
            Frame::Bulk(Bytes::from_static(b"hello"))
        );
    }

    #[test]
    fn empty_bulk_string() {
        assert_eq!(
            must_parse(b"$0\r\n\r\n"),
            Frame::Bulk(Bytes::from_static(b""))
        );
    }

    #[test]
    fn bulk_string_with_binary() {
        let input = b"$4\r\n\x00\x01\x02\x03\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Bulk(Bytes::copy_from_slice(&[0, 1, 2, 3]))
        );
    }

    #[test]
    fn bulk_string_may_contain_crlf() {
        // Bulk payloads are length-delimited, so an embedded CRLF is just data.
        assert_eq!(
            must_parse(b"$4\r\na\r\nb\r\n"),
            Frame::Bulk(Bytes::from_static(b"a\r\nb"))
        );
    }

    #[test]
    fn bulk_string_missing_terminator_rejected() {
        let err = parse_frame(b"$3\r\nabcXY").unwrap_err();
        assert!(
            matches!(err, ProtocolError::InvalidFrameLength(3)),
            "expected InvalidFrameLength(3), got {err:?}"
        );
    }

    #[test]
    fn bulk_string_over_limit_rejected() {
        let input = format!("${}\r\n", MAX_BULK_LEN + 1);
        let err = parse_frame(input.as_bytes()).unwrap_err();
        assert!(
            matches!(err, ProtocolError::BulkStringTooLarge(n) if n == (MAX_BULK_LEN + 1) as usize),
            "expected BulkStringTooLarge, got {err:?}"
        );
    }

    #[test]
    fn null() {
        assert_eq!(must_parse(b"_\r\n"), Frame::Null);
    }

    #[test]
    fn array() {
        let input = b"*2\r\n+hello\r\n+world\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("hello".into()),
                Frame::Simple("world".into()),
            ])
        );
    }

    #[test]
    fn empty_array() {
        assert_eq!(must_parse(b"*0\r\n"), Frame::Array(vec![]));
    }

    #[test]
    fn nested_array() {
        let input = b"*2\r\n*2\r\n:1\r\n:2\r\n*2\r\n:3\r\n:4\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Array(vec![Frame::Integer(1), Frame::Integer(2)]),
                Frame::Array(vec![Frame::Integer(3), Frame::Integer(4)]),
            ])
        );
    }

    #[test]
    fn array_with_null() {
        let input = b"*3\r\n+OK\r\n_\r\n:1\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Array(vec![
                Frame::Simple("OK".into()),
                Frame::Null,
                Frame::Integer(1),
            ])
        );
    }

    #[test]
    fn negative_array_length_rejected() {
        let err = parse_frame(b"*-1\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(-1)));
    }

    #[test]
    fn array_over_element_limit_rejected() {
        let input = format!("*{}\r\n", MAX_ARRAY_ELEMENTS + 1);
        let err = parse_frame(input.as_bytes()).unwrap_err();
        assert!(
            matches!(err, ProtocolError::TooManyElements(n) if n == MAX_ARRAY_ELEMENTS + 1),
            "expected TooManyElements, got {err:?}"
        );
    }

    #[test]
    fn large_declared_count_with_short_input_is_incomplete() {
        // A count at the limit is accepted; the missing elements mean "wait".
        let input = format!("*{}\r\n:1\r\n", MAX_ARRAY_ELEMENTS);
        assert_eq!(parse_frame(input.as_bytes()).unwrap(), None);
    }

    #[test]
    fn map() {
        let input = b"%2\r\n+key1\r\n:1\r\n+key2\r\n:2\r\n";
        assert_eq!(
            must_parse(input),
            Frame::Map(vec![
                (Frame::Simple("key1".into()), Frame::Integer(1)),
                (Frame::Simple("key2".into()), Frame::Integer(2)),
            ])
        );
    }

    #[test]
    fn map_over_element_limit_rejected() {
        let input = format!("%{}\r\n", MAX_ARRAY_ELEMENTS + 1);
        let err = parse_frame(input.as_bytes()).unwrap_err();
        assert!(
            matches!(err, ProtocolError::TooManyElements(n) if n == MAX_ARRAY_ELEMENTS + 1),
            "expected TooManyElements, got {err:?}"
        );
    }

    #[test]
    fn incomplete_returns_none() {
        assert_eq!(parse_frame(b"").unwrap(), None);
        assert_eq!(parse_frame(b"+OK").unwrap(), None);
        assert_eq!(parse_frame(b"+OK\r").unwrap(), None);
        assert_eq!(parse_frame(b"$5\r\nhel").unwrap(), None);
        assert_eq!(parse_frame(b"*2\r\n+OK\r\n").unwrap(), None);
        assert_eq!(parse_frame(b"%1\r\n+key\r\n").unwrap(), None);
    }

    #[test]
    fn invalid_prefix() {
        let err = parse_frame(b"~invalid\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidPrefix(b'~'));
    }

    #[test]
    fn invalid_integer() {
        let err = parse_frame(b":abc\r\n").unwrap_err();
        assert_eq!(err, ProtocolError::InvalidInteger);
    }

    #[test]
    fn negative_bulk_length() {
        let err = parse_frame(b"$-1\r\n").unwrap_err();
        assert!(matches!(err, ProtocolError::InvalidFrameLength(-1)));
    }

    #[test]
    fn parse_consumes_exact_bytes() {
        let buf = b"+OK\r\ntrailing";
        let (frame, consumed) = parse_frame(buf).unwrap().unwrap();
        assert_eq!(frame, Frame::Simple("OK".into()));
        assert_eq!(consumed, 5);
    }

    #[test]
    fn deeply_nested_array_rejected() {
        // One level past MAX_NESTING_DEPTH.
        let mut buf = Vec::new();
        for _ in 0..65 {
            buf.extend_from_slice(b"*1\r\n");
        }
        buf.extend_from_slice(b":1\r\n");

        let err = parse_frame(&buf).unwrap_err();
        assert!(
            matches!(err, ProtocolError::NestingTooDeep(64)),
            "expected NestingTooDeep, got {err:?}"
        );
    }

    #[test]
    fn nesting_at_limit_accepted() {
        // Exactly MAX_NESTING_DEPTH levels.
        let mut buf = Vec::new();
        for _ in 0..64 {
            buf.extend_from_slice(b"*1\r\n");
        }
        buf.extend_from_slice(b":1\r\n");

        let result = parse_frame(&buf);
        assert!(result.is_ok(), "64 levels of nesting should be accepted");
        assert!(result.unwrap().is_some());
    }
}