1use sawp::error::Result;
45use sawp::parser::{Direction, Parse};
46use sawp::probe::{Probe, Status};
47use sawp::protocol::Protocol;
48use sawp_flags::{BitFlags, Flag, Flags};
49
50use nom::bytes::streaming::{take, take_until};
51use nom::character::streaming::crlf;
52use nom::number::streaming::be_u8;
53use nom::{AsBytes, FindToken, InputTakeAtPosition};
54
55use num_enum::TryFromPrimitive;
56
57use std::convert::TryFrom;
58
59#[cfg(feature = "ffi")]
61mod ffi;
62
63#[cfg(feature = "ffi")]
64use sawp_ffi::GenerateFFI;
65
/// Line terminator that separates RESP fields; searched for and consumed by the parser.
pub const CRLF: &[u8] = b"\r\n";
/// The five recognised type-token bytes. Also used to find where a run of
/// unrecognised data ends (at the next byte that could start a new entry).
pub const DATA_TYPE_TOKENS: &str = "$*+-:";
/// Array nesting depth at which parsing stops and `MaxArrayDepthReached` is raised.
pub const MAX_ARRAY_DEPTH: usize = 64;
/// Declared bulk string length (in bytes) above which `BulkStringExceedsMaxLen` is raised.
pub const MAX_BULK_STRING_LEN: usize = 1024 * 512;
71
#[repr(u8)]
/// Anomalies noticed while parsing; stored as a bit set on each [`Message`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, BitFlags)]
pub enum ErrorFlags {
    /// An integer or length field was not a valid `i64`, a length was negative
    /// (other than -1), or the type token was not recognised.
    InvalidData = 0b0000_0001,
    /// A bulk string declared a length greater than `MAX_BULK_STRING_LEN`.
    BulkStringExceedsMaxLen = 0b0000_0010,
    /// An array was found at nesting depth `MAX_ARRAY_DEPTH`; parsing stopped there.
    MaxArrayDepthReached = 0b0000_0100,
}
86
#[derive(Clone, Copy, Debug, PartialEq, Eq, TryFromPrimitive)]
/// First byte of a RESP entry, identifying how the rest of the entry is encoded.
#[repr(u8)]
pub enum DataTypeToken {
    /// `$` — length-prefixed bulk string.
    BulkString = b'$',
    /// `*` — array whose element count follows.
    Array = b'*',
    /// `+` — simple string terminated by CRLF.
    SimpleString = b'+',
    /// `-` — error text terminated by CRLF.
    Error = b'-',
    /// `:` — signed integer terminated by CRLF.
    Integer = b':',
    /// Fallback produced by `from_raw` for any byte that is not one of the tokens above.
    // NOTE(review): with no explicit value this variant's discriminant is `b':' + 1`
    // (i.e. `b';'`), so the derived `TryFrom<u8>` also maps that byte here directly.
    Unknown,
}
110
111impl DataTypeToken {
112 pub fn from_raw(val: u8) -> Self {
113 DataTypeToken::try_from(val).unwrap_or(DataTypeToken::Unknown)
114 }
115}
116
#[cfg_attr(feature = "ffi", derive(GenerateFFI))]
/// A single parsed RESP value.
#[cfg_attr(feature = "ffi", sawp_ffi(prefix = "sawp_resp"))]
#[derive(Debug, PartialEq, Eq)]
pub enum Entry {
    /// Array (`*`) of entries, which may themselves be arrays.
    Array(Vec<Entry>),
    /// Error text (`-`), without the trailing CRLF.
    Error(Vec<u8>),
    /// Integer value (`:`).
    Integer(i64),
    /// Raw bytes that could not be interpreted: an unknown type token, or an
    /// integer / length field that did not parse as an `i64`.
    Invalid(Vec<u8>),
    /// Null bulk string or null array (declared length of -1).
    Nil,
    /// Contents of a simple (`+`) or bulk (`$`) string.
    String(Vec<u8>),
}
137
/// Outcome of reading an integer field (a value or a length prefix).
pub enum IntegerResult<'a> {
    /// The field parsed as an `i64`.
    Integer(i64),
    /// The field did not parse; holds the raw bytes found before the CRLF.
    Data(&'a [u8]),
}
142
/// Outcome of reading a bulk string.
pub enum StringResult<'a> {
    /// The string payload (empty when the declared length was invalid-negative).
    String(&'a [u8]),
    /// Null bulk string (declared length of -1).
    Nil,
    /// The length prefix did not parse: raw length bytes, followed by any data
    /// bytes (the bulk string parser in this file always supplies an empty slice).
    Invalid(&'a [u8], &'a [u8]),
}
148
#[cfg_attr(feature = "ffi", derive(GenerateFFI))]
/// Result of parsing one top-level RESP entry.
#[cfg_attr(feature = "ffi", sawp_ffi(prefix = "sawp_resp"))]
#[derive(Debug, PartialEq, Eq)]
pub struct Message {
    /// The parsed value (possibly `Entry::Invalid` for malformed input).
    pub entry: Entry,
    /// Union of every anomaly flagged while parsing `entry`, including nested entries.
    #[cfg_attr(feature = "ffi", sawp_ffi(flag = "u8"))]
    pub error_flags: Flags<ErrorFlags>,
}

// No inherent methods are defined on `Message`.
impl Message {}
160
/// Stateless RESP parser; implements the sawp `Protocol`, `Probe` and `Parse` traits.
#[derive(Debug)]
pub struct Resp {}
163
/// Identifies `Resp` to the sawp framework.
impl<'a> Protocol<'a> for Resp {
    /// Type produced by a successful parse.
    type Message = Message;

    /// Protocol name: `"resp"`.
    fn name() -> &'static str {
        "resp"
    }
}
171
172impl<'a> Probe<'a> for Resp {
173 fn probe(&self, input: &'a [u8], direction: Direction) -> Status {
179 match self.parse(input, direction) {
180 Ok((
181 _,
182 Some(Message {
183 entry: Entry::Invalid(_),
184 error_flags: _,
185 }),
186 )) => Status::Unrecognized, Ok(_) => Status::Recognized,
188 Err(sawp::error::Error {
189 kind: sawp::error::ErrorKind::Incomplete(_),
190 }) => Status::Incomplete,
191 Err(_) => Status::Unrecognized,
192 }
193 }
194}
195
196impl Resp {
197 fn advance_if_crlf(input: &[u8]) -> &[u8] {
198 crlf::<_, (&[u8], nom::error::ErrorKind)>(input)
199 .map(|(rem, _)| rem)
200 .unwrap_or(input)
201 }
202
203 fn parse_integer(input: &[u8]) -> Result<(&[u8], IntegerResult, Flags<ErrorFlags>)> {
204 let (rem, raw_len) = take_until(CRLF)(input)?;
205 match std::str::from_utf8(raw_len) {
207 Ok(len_str) => match len_str.parse::<i64>() {
208 Ok(len) => Ok((
209 Resp::advance_if_crlf(rem),
210 IntegerResult::Integer(len),
211 ErrorFlags::none(),
212 )),
213 Err(_) => Ok((
214 Resp::advance_if_crlf(rem),
215 IntegerResult::Data(raw_len),
216 ErrorFlags::InvalidData.into(),
217 )),
218 },
219 Err(_) => Ok((
220 Resp::advance_if_crlf(rem),
221 IntegerResult::Data(raw_len),
222 ErrorFlags::InvalidData.into(),
223 )),
224 }
225 }
226
227 fn parse_bulk_string(input: &[u8]) -> Result<(&[u8], StringResult, Flags<ErrorFlags>)> {
231 let (rem, wrapped_length, mut error_flags) = Resp::parse_integer(input)?;
232 match wrapped_length {
233 IntegerResult::Integer(length) => {
234 if length >= 0 {
235 if length > MAX_BULK_STRING_LEN as i64 {
236 error_flags |= ErrorFlags::BulkStringExceedsMaxLen
237 }
238 let (rem, ret) = take(length as usize)(rem)?;
239 Ok((
241 Resp::advance_if_crlf(rem),
242 StringResult::String(ret),
243 error_flags,
244 ))
245 } else {
246 if length == -1 {
248 return Ok((Resp::advance_if_crlf(rem), StringResult::Nil, error_flags));
249 }
250 error_flags |= ErrorFlags::InvalidData;
251 Ok((
252 Resp::advance_if_crlf(rem),
253 StringResult::String(b""),
254 error_flags,
255 ))
256 }
257 }
258 IntegerResult::Data(bytes) => Ok((
259 Resp::advance_if_crlf(rem),
260 StringResult::Invalid(bytes, b""),
261 error_flags,
262 )),
263 }
264 }
265
266 fn parse_simple_string(input: &[u8]) -> Result<(&[u8], &[u8])> {
267 let (rem, ret) = take_until(CRLF)(input)?;
268 Ok((Resp::advance_if_crlf(rem), ret))
270 }
271
272 fn parse_entry(input: &[u8], array_depth: usize) -> Result<(&[u8], Entry, Flags<ErrorFlags>)> {
273 let (input, raw_token) = be_u8(input)?;
274 let token = DataTypeToken::from_raw(raw_token);
275 match token {
276 DataTypeToken::BulkString => {
277 let (rem, parsed_data, error_flags) = Resp::parse_bulk_string(input)?;
278 match parsed_data {
279 StringResult::String(string_data) => {
280 Ok((rem, Entry::String(string_data.to_vec()), error_flags))
281 }
282 StringResult::Nil => Ok((rem, Entry::Nil, error_flags)),
283 StringResult::Invalid(len, data) => {
284 Ok((rem, Entry::Invalid([len, data].concat()), error_flags))
285 }
286 }
287 }
288 DataTypeToken::Array => {
289 if array_depth < MAX_ARRAY_DEPTH {
290 let (mut local_input, length, mut error_flags) = Resp::parse_integer(input)?;
291 match length {
292 IntegerResult::Integer(length) if length >= 0 => {
293 let mut entries: Vec<Entry> = Vec::with_capacity(length as usize);
294
295 for _ in 0..length {
296 let (rem, entry, inner_error_flags) =
297 Resp::parse_entry(local_input, array_depth + 1)?;
298 error_flags |= inner_error_flags;
299 if error_flags.contains(ErrorFlags::MaxArrayDepthReached) {
300 return Ok((input, Entry::Array(entries), error_flags));
301 }
302 entries.push(entry);
303 local_input = rem;
304 }
305 Ok((local_input, Entry::Array(entries), error_flags))
306 }
307 IntegerResult::Integer(-1) => Ok((local_input, Entry::Nil, error_flags)),
308 IntegerResult::Integer(_length) => {
309 error_flags |= ErrorFlags::InvalidData;
310 Ok((
311 Resp::advance_if_crlf(local_input),
312 Entry::Array(vec![]),
313 error_flags,
314 ))
315 }
316 IntegerResult::Data(invalid_length) => Ok((
317 Resp::advance_if_crlf(local_input),
318 Entry::Invalid(
319 [b"*", invalid_length].concat(), ),
321 error_flags,
322 )),
323 }
324 } else {
325 Ok((
326 input,
327 Entry::Invalid(vec![]),
328 ErrorFlags::MaxArrayDepthReached.into(),
329 ))
330 }
331 }
332 DataTypeToken::SimpleString => {
333 let (rem, ret) = Resp::parse_simple_string(input)?;
334 Ok((rem, Entry::String(ret.to_vec()), ErrorFlags::none()))
335 }
336 DataTypeToken::Error => {
337 let (rem, ret) = Resp::parse_simple_string(input)?;
338 Ok((
339 rem,
340 Entry::Error(ret.as_bytes().to_vec()),
341 ErrorFlags::none(),
342 ))
343 }
344 DataTypeToken::Integer => {
345 let (rem, ret, error_flags) = Resp::parse_integer(input)?;
346 match ret {
347 IntegerResult::Integer(ret) => Ok((rem, Entry::Integer(ret), error_flags)),
348 IntegerResult::Data(ret) => Ok((
349 rem,
350 Entry::Invalid([b":", ret].concat()), error_flags,
352 )),
353 }
354 }
355 DataTypeToken::Unknown => {
356 let (rem, data) =
359 input.split_at_position_complete(|e: u8| DATA_TYPE_TOKENS.find_token(e))?;
360 Ok((
361 Resp::advance_if_crlf(rem),
362 Entry::Invalid([&[raw_token], data].concat()),
363 ErrorFlags::InvalidData.into(),
364 ))
365 }
366 }
367 }
368}
369
370impl<'a> Parse<'a> for Resp {
374 fn parse(
375 &self,
376 input: &'a [u8],
377 _direction: Direction,
378 ) -> Result<(&'a [u8], Option<Self::Message>)> {
379 let (rem, entry, error_flags) = Resp::parse_entry(input, 0)?;
380
381 Ok((rem, Some(Message { entry, error_flags })))
382 }
383}
384
// Table-driven tests for `Resp::parse`. Each case supplies raw input bytes and
// the expected result as `(number of unconsumed bytes, parsed message)`.
#[cfg(test)]
mod test {
    use crate::{Entry, ErrorFlags, Message, Resp};
    use rstest::rstest;
    use sawp::error::Result;
    use sawp::parser::{Direction, Parse};
    use sawp_flags::Flag;

    #[rstest(
        input,
        expected,
        // --- Well-formed entries of each type ---
        case::parse_simple_string(
            b"+OK\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::String(b"OK".to_vec()),
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        case::parse_error(
            b"-Error message\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Error(b"Error message".to_vec()),
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        case::parse_integer(
            b":1000\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Integer(1000),
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        case::parse_bulk_string(
            b"$6\r\nfoobar\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::String(b"foobar".to_vec()),
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        case::parse_array(
            b"*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Array(vec!(
                            Entry::String(b"foo".to_vec()),
                            Entry::String(b"bar".to_vec()),
                        )),
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        // --- Array length edge cases ---
        // A length of -1 is the null array.
        case::parse_null_value_array(
            b"*-1\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Nil,
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        // Any other negative length yields an empty array flagged as invalid.
        case::invalid_negative_array_length(
            b"*-2\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Array(vec![]),
                        error_flags: ErrorFlags::InvalidData.into(),
                    }
                )
            ))
        ),
        case::parse_nested_array(
            b"*1\r\n*2\r\n$3\r\nfoo\r\n$3\r\nbar\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry:
                            Entry::Array(vec!(
                                Entry::Array(vec!(
                                    Entry::String(b"foo".to_vec()),
                                    Entry::String(b"bar".to_vec()),
                                )),
                            ),
                        ),
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        case::parse_empty_array(
            b"*0\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Array(vec![]),
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        // A chain of nested single-element arrays deep enough to hit the depth
        // limit. Only the elements gathered before the limit are kept, and all
        // input except the leading token byte is reported as unconsumed.
        case::nested_array_exceeds_max_depth(
            b"*2\r\n$3\r\nfoo\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n\
            *1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n\
            *1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n\
            *1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n\
            *1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n",
            Ok((
                268,
                Some(
                    Message {
                        entry:
                            Entry::Array(vec![
                                Entry::String(b"foo".to_vec()),
                            ]),
                        error_flags: ErrorFlags::MaxArrayDepthReached.into(),
                    }
                )
            ))
        ),
        // --- Bulk string length edge cases ---
        case::parse_empty_bulk_string_with_trailing_negative_int(
            b"*2\r\n$0\r\n\r\n:-100\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Array(vec![
                            Entry::String(b"".to_vec()),
                            Entry::Integer(-100),
                        ]
                        ),
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        // A length of -1 is the null string.
        case::parse_null_value_string(
            b"$-1\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Nil,
                        error_flags: ErrorFlags::none(),
                    }
                )
            ))
        ),
        // Any other negative length yields an empty string flagged as invalid.
        case::invalid_negative_bulk_string_length(
            b"$-2\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::String(b"".to_vec()),
                        error_flags: ErrorFlags::InvalidData.into(),
                    }
                )
            ))
        ),
        // --- Unknown type tokens ---
        // Everything up to the next valid token byte (or end of input) is
        // captured as invalid data, CRLF included.
        case::invalid_type_token(
            b"!1\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Invalid(b"!1\r\n".to_vec()),
                        error_flags: ErrorFlags::InvalidData.into(),
                    }
                )
            ))
        ),
        // The 12 bytes of the following well-formed bulk string are left unconsumed.
        case::invalid_type_token_mixed_with_good_data(
            b"!1\r\n$6\r\nfoobar\r\n",
            Ok((
                12,
                Some(
                    Message {
                        entry: Entry::Invalid(b"!1\r\n".to_vec()),
                        error_flags: ErrorFlags::InvalidData.into(),
                    }
                )
            ))
        ),
        case::missing_type_token(
            b"1\r\n$6\r\nfoobar\r\n",
            Ok((
                12,
                Some(
                    Message {
                        entry: Entry::Invalid(b"1\r\n".to_vec()),
                        error_flags: ErrorFlags::InvalidData.into(),
                    }
                )
            ))
        ),
        // --- Integers that do not fit in, or are not, an i64 ---
        // One above i64::MAX.
        case::parse_too_big_integer(
            b":9223372036854775808\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Invalid(b":9223372036854775808".to_vec()),
                        error_flags: ErrorFlags::InvalidData.into(),
                    }
                )
            ))
        ),
        // One below i64::MIN.
        case::parse_too_small_integer(
            b":-9223372036854775809\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Invalid(b":-9223372036854775809".to_vec()),
                        error_flags: ErrorFlags::InvalidData.into(),
                    }
                )
            ))
        ),
        case::parse_invalid_integer(
            b":cats\r\n",
            Ok((
                0,
                Some(
                    Message {
                        entry: Entry::Invalid(b":cats".to_vec()),
                        error_flags: ErrorFlags::InvalidData.into(),
                    }
                )
            ))
        ),
    )]
    /// Parses `input` and compares the outcome with `expected`, after reducing
    /// the unconsumed slice to its length so cases can be written compactly.
    fn resp(input: &[u8], expected: Result<(usize, Option<Message>)>) {
        let resp = Resp {};
        assert_eq!(
            resp.parse(input, Direction::Unknown)
                .map(|(rem, msg)| (rem.len(), msg)),
            expected
        );
    }
}