1use std::{cmp, collections::HashMap, convert::TryFrom, hash::BuildHasher, hash::Hash, str};
3use crate::codec::{Decoder, Encoder};
4use crate::bytes::{BufMut, Bytes, BytesMut, ByteString};
5use super::errors::Error;
/// Stateless codec value: it encodes [`Request`] items into a byte buffer and
/// decodes [`Response`] items out of one (see the `Encoder` / `Decoder` impls).
pub struct Codec;
8use crate::bytes::Buf;
9
10impl Encoder for Codec {
11 type EncodeItem = Request;
12 type EncodeError = Error;
13
14 fn encode(&self, msg: Request, buf: &mut BytesMut) -> Result<(), Self::EncodeError> {
15 match msg {
16 Request::Array(ary) => {
17 write_header(b'*', ary.len() as i64, buf, 0);
18 for v in ary {
19 self.encode(v, buf)?;
20 }
21 }
22 Request::BulkString(bstr) => {
23 let len = bstr.0.len();
24 write_header(b'$', len as i64, buf, len + 2);
25 buf.extend_from_slice(&bstr.0[..]);
26 write_rn(buf);
27 }
28 Request::BulkStatic(bstr) => {
29 let len = bstr.len();
30 write_header(b'$', len as i64, buf, len + 2);
31 buf.extend_from_slice(bstr);
32 write_rn(buf);
33 }
34 Request::BulkInteger(i) => {
35 let mut len_buf = [0; 32];
36 let size = itoa::write(&mut len_buf[..], i).unwrap();
37 write_header(b'$', size as i64, buf, size + 2);
38 buf.extend_from_slice(&len_buf[..size]);
39 write_rn(buf);
40 }
41 Request::String(ref string) => {
42 write_string(b'+', string, buf);
43 }
44 Request::Integer(val) => {
45 write_header(b':', val, buf, 0);
47 }
48 }
49 Ok(())
50 }
51}
52
53impl Decoder for Codec {
54 type DecodeItem = Response;
55 type DecodeError = Error;
56
57 fn decode(&self, buf: &mut BytesMut) -> Result<Option<Self::DecodeItem>, Self::DecodeError> {
58 match decode(buf, 0)? {
59 Some((pos, item)) => {
60 buf.advance(pos);
61 Ok(Some(item))
62 }
63 None => Ok(None),
64 }
65 }
66}
67
68#[derive(Debug, Clone, Eq, PartialEq, Hash)]
69pub struct BulkString(Bytes);
74
75impl BulkString {
76 pub fn from_static(data: &'static str) -> Self {
78 BulkString(Bytes::from_static(data.as_ref()))
79 }
80
81 pub fn from_bstatic(data: &'static [u8]) -> Self {
83 BulkString(Bytes::from_static(data))
84 }
85}
86
87impl From<ByteString> for BulkString {
88 fn from(val: ByteString) -> BulkString {
89 BulkString(val.into_bytes())
90 }
91}
92
93impl From<String> for BulkString {
94 fn from(val: String) -> BulkString {
95 BulkString(Bytes::from(val))
96 }
97}
98
99impl<'a> From<&'a String> for BulkString {
100 fn from(val: &'a String) -> BulkString {
101 BulkString(Bytes::copy_from_slice(val.as_ref()))
102 }
103}
104
105impl<'a> From<&'a str> for BulkString {
106 fn from(val: &'a str) -> BulkString {
107 BulkString(Bytes::copy_from_slice(val.as_bytes()))
108 }
109}
110
111impl<'a> From<&&'a str> for BulkString {
112 fn from(val: &&'a str) -> BulkString {
113 BulkString(Bytes::copy_from_slice(val.as_bytes()))
114 }
115}
116
117impl From<Bytes> for BulkString {
118 fn from(val: Bytes) -> BulkString {
119 BulkString(val)
120 }
121}
122
123impl From<BytesMut> for BulkString {
124 fn from(val: BytesMut) -> BulkString {
125 BulkString(val.freeze())
126 }
127}
128
129impl<'a> From<&'a Bytes> for BulkString {
130 fn from(val: &'a Bytes) -> BulkString {
131 BulkString(val.clone())
132 }
133}
134
135impl<'a> From<&'a ByteString> for BulkString {
136 fn from(val: &'a ByteString) -> BulkString {
137 BulkString(val.clone().into_bytes())
138 }
139}
140
141impl<'a> From<&'a [u8]> for BulkString {
142 fn from(val: &'a [u8]) -> BulkString {
143 BulkString(Bytes::copy_from_slice(val))
144 }
145}
146
147impl From<Vec<u8>> for BulkString {
148 fn from(val: Vec<u8>) -> BulkString {
149 BulkString(Bytes::from(val))
150 }
151}
152
153#[derive(Debug, Clone, Eq, PartialEq, Hash)]
158pub enum Request {
159 Array(Vec<Request>),
161
162 BulkString(BulkString),
165
166 BulkStatic(&'static [u8]),
169
170 BulkInteger(i64),
172
173 String(ByteString),
175
176 Integer(i64),
179}
180
181impl Request {
182 pub fn from_static(data: &'static str) -> Self {
184 Request::BulkStatic(data.as_ref())
185 }
186
187 pub fn from_bstatic(data: &'static [u8]) -> Self {
189 Request::BulkStatic(data)
190 }
191
192 #[allow(clippy::should_implement_trait)]
193 pub fn add<T>(mut self, other: T) -> Self
198 where
199 Request: From<T>,
200 {
201 match self {
202 Request::Array(ref mut vals) => {
203 vals.push(other.into());
204 self
205 }
206 _ => Request::Array(vec![self, other.into()]),
207 }
208 }
209
210 pub fn extend<T>(mut self, other: impl IntoIterator<Item = T>) -> Self
215 where
216 Request: From<T>,
217 {
218 match self {
219 Request::Array(ref mut vals) => {
220 vals.extend(other.into_iter().map(|t| t.into()));
221 self
222 }
223 _ => {
224 let mut vals = vec![self];
225 vals.extend(other.into_iter().map(|t| t.into()));
226 Request::Array(vals)
227 }
228 }
229 }
230}
231
232impl<T> From<T> for Request
233 where
234 BulkString: From<T>,
235{
236 fn from(val: T) -> Request {
237 Request::BulkString(val.into())
238 }
239}
240
241impl From<i8> for Request {
242 fn from(val: i8) -> Request {
243 Request::Integer(val as i64)
244 }
245}
246
247impl From<i16> for Request {
248 fn from(val: i16) -> Request {
249 Request::Integer(val as i64)
250 }
251}
252
253impl From<i32> for Request {
254 fn from(val: i32) -> Request {
255 Request::Integer(val as i64)
256 }
257}
258
259impl From<i64> for Request {
260 fn from(val: i64) -> Request {
261 Request::Integer(val)
262 }
263}
264
265impl From<u8> for Request {
266 fn from(val: u8) -> Request {
267 Request::Integer(val as i64)
268 }
269}
270
271impl From<u16> for Request {
272 fn from(val: u16) -> Request {
273 Request::Integer(val as i64)
274 }
275}
276
277impl From<u32> for Request {
278 fn from(val: u32) -> Request {
279 Request::Integer(val as i64)
280 }
281}
282
283impl From<usize> for Request {
284 fn from(val: usize) -> Request {
285 Request::Integer(val as i64)
286 }
287}
288
289#[derive(Debug, Clone, Eq, PartialEq, Hash)]
291pub enum Response {
292 Nil,
293
294 Array(Vec<Response>),
296
297 Bytes(Bytes),
300
301 String(ByteString),
303
304 Error(ByteString),
306
307 Integer(i64),
310}
311
312impl Response {
313 pub fn into_result(self) -> Result<Response, ByteString> {
315 match self {
316 Response::Error(val) => Err(val),
317 val => Ok(val),
318 }
319 }
320}
321
322impl TryFrom<Response> for Bytes {
323 type Error = (&'static str, Response);
324
325 fn try_from(val: Response) -> Result<Self, Self::Error> {
326 if let Response::Bytes(bytes) = val {
327 Ok(bytes)
328 } else {
329 Err(("Not a bytes object", val))
330 }
331 }
332}
333
334impl TryFrom<Response> for ByteString {
335 type Error = (&'static str, Response);
336
337 fn try_from(val: Response) -> Result<Self, Self::Error> {
338 match val {
339 Response::String(val) => Ok(val),
340 Response::Bytes(val) => {
341 match ByteString::try_from(val.as_ref()){
342 Ok(v) => {
343 Ok(v)
344 }
345 Err(e) => {
346 Err(("from response fail",Response::Bytes(val)))
347 }
348 }
349 }
350 _ => Err(("Cannot convert into a string", val)),
351 }
352 }
353}
354
355impl TryFrom<Response> for i64 {
356 type Error = (&'static str, Response);
357
358 fn try_from(val: Response) -> Result<Self, Self::Error> {
359 if let Response::Integer(i) = val {
360 Ok(i)
361 } else {
362 Err(("Cannot be converted into an i64", val))
363 }
364 }
365}
366
367impl TryFrom<Response> for bool {
368 type Error = (&'static str, Response);
369
370 fn try_from(val: Response) -> Result<bool, Self::Error> {
371 i64::try_from(val).and_then(|x| match x {
372 0 => Ok(false),
373 1 => Ok(true),
374 _ => Err((
375 "i64 value cannot be represented as bool",
376 Response::Integer(x),
377 )),
378 })
379 }
380}
381
382impl<T> TryFrom<Response> for Vec<T>
383 where
384 T: TryFrom<Response, Error = (&'static str, Response)>,
385{
386 type Error = (&'static str, Response);
387
388 fn try_from(val: Response) -> Result<Vec<T>, Self::Error> {
389 if let Response::Array(ary) = val {
390 let mut ar = Vec::with_capacity(ary.len());
391 for value in ary {
392 ar.push(T::try_from(value)?);
393 }
394 Ok(ar)
395 } else {
396 Err(("Cannot be converted into a vector", val))
397 }
398 }
399}
400
401impl TryFrom<Response> for () {
402 type Error = (&'static str, Response);
403
404 fn try_from(val: Response) -> Result<(), Self::Error> {
405 if let Response::String(string) = val {
406 match string.as_ref() {
407 "OK" => Ok(()),
408 _ => Err(("Unexpected value within String", Response::String(string))),
409 }
410 } else {
411 Err(("Unexpected value", val))
412 }
413 }
414}
415
416impl<A, B> TryFrom<Response> for (A, B)
417 where
418 A: TryFrom<Response, Error = (&'static str, Response)>,
419 B: TryFrom<Response, Error = (&'static str, Response)>,
420{
421 type Error = (&'static str, Response);
422
423 fn try_from(val: Response) -> Result<(A, B), Self::Error> {
424 match val {
425 Response::Array(ary) => {
426 if ary.len() == 2 {
427 let mut ary_iter = ary.into_iter();
428 Ok((
429 A::try_from(ary_iter.next().expect("No value"))?,
430 B::try_from(ary_iter.next().expect("No value"))?,
431 ))
432 } else {
433 Err(("Array needs to be 2 elements", Response::Array(ary)))
434 }
435 }
436 _ => Err(("Unexpected value", val)),
437 }
438 }
439}
440
441impl<A, B, C> TryFrom<Response> for (A, B, C)
442 where
443 A: TryFrom<Response, Error = (&'static str, Response)>,
444 B: TryFrom<Response, Error = (&'static str, Response)>,
445 C: TryFrom<Response, Error = (&'static str, Response)>,
446{
447 type Error = (&'static str, Response);
448
449 fn try_from(val: Response) -> Result<(A, B, C), Self::Error> {
450 match val {
451 Response::Array(ary) => {
452 if ary.len() == 3 {
453 let mut ary_iter = ary.into_iter();
454 Ok((
455 A::try_from(ary_iter.next().expect("No value"))?,
456 B::try_from(ary_iter.next().expect("No value"))?,
457 C::try_from(ary_iter.next().expect("No value"))?,
458 ))
459 } else {
460 Err(("Array needs to be 3 elements", Response::Array(ary)))
461 }
462 }
463 _ => Err(("Unexpected value", val)),
464 }
465 }
466}
467
468impl<K, T, S> TryFrom<Response> for HashMap<K, T, S>
469 where
470 K: TryFrom<Response, Error = (&'static str, Response)> + Hash + Eq,
471 T: TryFrom<Response, Error = (&'static str, Response)>,
472 S: BuildHasher + Default,
473{
474 type Error = (&'static str, Response);
475
476 fn try_from(val: Response) -> Result<HashMap<K, T, S>, Self::Error> {
477 match val {
478 Response::Array(ary) => {
479 let mut map = HashMap::with_capacity_and_hasher(ary.len() / 2, S::default());
480 let mut items = ary.into_iter();
481
482 while let Some(k) = items.next() {
483 let key = K::try_from(k)?;
484 let value = T::try_from(items.next().ok_or((
485 "Cannot convert an odd number of elements into a hashmap",
486 Response::Nil,
487 ))?)?;
488 map.insert(key, value);
489 }
490
491 Ok(map)
492 }
493 _ => Err(("Cannot be converted into a hashmap", val)),
494 }
495 }
496}
497
498macro_rules! impl_tryfrom_integers {
499 ($($int_ty:ident),* $(,)*) => {
500 $(
501 #[allow(clippy::cast_lossless)]
502 impl TryFrom<Response> for $int_ty {
503 type Error = (&'static str, Response);
504
505 fn try_from(val: Response) -> Result<Self, Self::Error> {
506 i64::try_from(val).and_then(|x| {
507 if x < ($int_ty::min_value() as i64)
510 || ($int_ty::max_value() as i64 > 0
511 && x > ($int_ty::max_value() as i64))
512 {
513 Err((
514 concat!(
515 "i64 value cannot be represented as {}",
516 stringify!($int_ty),
517 ),
518 Response::Integer(x),
519 ))
520 } else {
521 Ok(x as $int_ty)
522 }
523 })
524 }
525 }
526 )*
527 };
528}
529
530impl_tryfrom_integers!(isize, usize, i32, u32, u64);
531
532fn write_rn(buf: &mut BytesMut) {
533 buf.extend_from_slice(b"\r\n");
534}
535
536fn write_header(symb: u8, len: i64, buf: &mut BytesMut, body_size: usize) {
537 let mut len_buf = [0; 32];
538 let size = itoa::write(&mut len_buf[..], len).unwrap();
539 buf.reserve(3 + size + body_size);
540 buf.put_u8(symb);
541 buf.extend_from_slice(&len_buf[..size]);
542 write_rn(buf);
543}
544
545fn write_string(symb: u8, string: &str, buf: &mut BytesMut) {
546 let bytes = string.as_bytes();
547 buf.reserve(3 + bytes.len());
548 buf.put_u8(symb);
549 buf.extend_from_slice(bytes);
550 write_rn(buf);
551}
552
553type DecodeResult = Result<Option<(usize, Response)>, Error>;
554
555fn decode(buf: &mut BytesMut, idx: usize) -> DecodeResult {
556 if buf.len() > idx {
557 match buf[idx] {
558 b'$' => decode_bytes(buf, idx + 1),
559 b'*' => decode_array(buf, idx + 1),
560 b':' => decode_integer(buf, idx + 1),
561 b'+' => decode_string(buf, idx + 1),
562 b'-' => decode_error(buf, idx + 1),
563 _ => Err(Error::Parse(format!("Unexpected byte: {}", buf[idx]))),
564 }
565 } else {
566 Ok(None)
567 }
568}
569
570fn decode_length(buf: &mut BytesMut, idx: usize) -> Result<Option<(usize, i64)>, Error> {
571 let (pos, int_str) = if let Some(pos) = buf[idx..].windows(2).position(|w| w == b"\r\n") {
573 (idx + pos + 2, &buf[idx..idx + pos])
574 } else {
575 return Ok(None);
576 };
577
578 match btoi::btoi(int_str) {
580 Ok(int) => Ok(Some((pos, int))),
581 Err(_) => Err(Error::Parse(format!(
582 "Not an integer: {:?}",
583 &int_str[..cmp::min(int_str.len(), 10)]
584 ))),
585 }
586}
587
588fn decode_bytes(buf: &mut BytesMut, idx: usize) -> DecodeResult {
589 match decode_length(buf, idx)? {
590 Some((pos, -1)) => Ok(Some((pos, Response::Nil))),
591 Some((pos, size)) if size >= 0 => {
592 let size = size as usize;
593 let remaining = buf.len() - pos;
594 let required_bytes = size + 2;
595
596 if remaining < required_bytes {
597 return Ok(None);
598 }
599 buf.advance(pos);
600 Ok(Some((2, Response::Bytes(buf.split_to(size).freeze()))))
601 }
602 Some((_, size)) => Err(Error::Parse(format!("Invalid string size: {}", size))),
603 None => Ok(None),
604 }
605}
606
607fn decode_array(buf: &mut BytesMut, idx: usize) -> DecodeResult {
608 match decode_length(buf, idx)? {
609 Some((pos, -1)) => Ok(Some((pos, Response::Nil))),
610 Some((pos, size)) if size >= 0 => {
611 let size = size as usize;
612 let mut pos = pos;
613 let mut values = Vec::with_capacity(size);
614 for _ in 0..size {
615 match decode(buf, pos) {
616 Ok(None) => return Ok(None),
617 Ok(Some((new_pos, value))) => {
618 values.push(value);
619 pos = new_pos;
620 }
621 Err(e) => return Err(e),
622 }
623 }
624 Ok(Some((pos, Response::Array(values))))
625 }
626 Some((_, size)) => Err(Error::Parse(format!("Invalid array size: {}", size))),
627 None => Ok(None),
628 }
629}
630
631fn decode_integer(buf: &mut BytesMut, idx: usize) -> DecodeResult {
632 if let Some((pos, int)) = decode_length(buf, idx)? {
633 Ok(Some((pos, Response::Integer(int))))
634 } else {
635 Ok(None)
636 }
637}
638
639fn decode_string(buf: &mut BytesMut, idx: usize) -> DecodeResult {
641 if let Some((pos, string)) = scan_string(buf, idx)? {
642 Ok(Some((pos, Response::String(string))))
643 } else {
644 Ok(None)
645 }
646}
647
648fn decode_error(buf: &mut BytesMut, idx: usize) -> DecodeResult {
649 if let Some((pos, string)) = scan_string(buf, idx)? {
650 Ok(Some((pos, Response::Error(string))))
651 } else {
652 Ok(None)
653 }
654}
655
656fn scan_string(buf: &mut BytesMut, idx: usize) -> Result<Option<(usize, ByteString)>, Error> {
657 if let Some(pos) = buf[idx..].windows(2).position(|w| w == b"\r\n") {
658 buf.advance(idx);
659 match ByteString::try_from(buf.split_to(pos)) {
660 Ok(s) => Ok(Some((2, s))),
661 Err(_) => Err(Error::Parse(format!(
662 "Not a valid string: {:?}",
663 &buf[idx..idx + cmp::min(pos, 10)]
664 ))),
665 }
666 } else {
667 Ok(None)
668 }
669}
670
#[cfg(test)]
mod tests {
    use std::convert::TryFrom;

    use ntex::codec::{Decoder, Encoder};
    use ntex::util::{ByteString, Bytes, BytesMut, HashMap};

    use super::*;
    use crate::array;

    /// Encodes `obj` with `Codec` into a fresh buffer and returns the bytes.
    fn obj_to_bytes(obj: Request) -> Bytes {
        let mut out = BytesMut::new();
        Codec.encode(obj, &mut out).unwrap();
        out.freeze()
    }

    /// `array!` builds an array request, and `extend` accepts both owned and
    /// borrowed collections of string slices.
    #[test]
    fn test_array_macro() {
        let encoded = obj_to_bytes(array!["SET", "x"]);
        assert_eq!(encoded, b"*2\r\n$3\r\nSET\r\n$1\r\nx\r\n".as_ref());

        let encoded = obj_to_bytes(array!["RPUSH", "wyz"].extend(vec!["a", "b"]));
        assert_eq!(
            encoded,
            b"*4\r\n$5\r\nRPUSH\r\n$3\r\nwyz\r\n$1\r\na\r\n$1\r\nb\r\n".as_ref(),
        );

        let vals = vec!["a", "b"];
        let encoded = obj_to_bytes(array!["RPUSH", "xyz"].extend(&vals));
        assert_eq!(
            encoded,
            &b"*4\r\n$5\r\nRPUSH\r\n$3\r\nxyz\r\n$1\r\na\r\n$1\r\nb\r\n"[..],
        );
    }

    /// A bulk string survives an encode/decode round trip.
    #[test]
    fn test_bulk_string() {
        let codec = Codec;
        let mut wire = BytesMut::new();

        let request = Request::BulkString(Bytes::from_static(b"THISISATEST").into());
        codec.encode(request, &mut wire).unwrap();
        assert_eq!(b"$11\r\nTHISISATEST\r\n".to_vec(), wire.to_vec());

        let decoded = codec.decode(&mut wire).unwrap().unwrap();
        assert_eq!(decoded, Response::Bytes(Bytes::from_static(b"THISISATEST")));
    }

    /// An array of bulk strings survives an encode/decode round trip.
    #[test]
    fn test_array() {
        let codec = Codec;
        let mut wire = BytesMut::new();

        let request = Request::Array(vec![
            Request::from(&b"TEST1"[..]),
            Request::from(&b"TEST2"[..]),
        ]);
        codec.encode(request, &mut wire).unwrap();
        assert_eq!(
            b"*2\r\n$5\r\nTEST1\r\n$5\r\nTEST2\r\n".to_vec(),
            wire.to_vec()
        );

        let expected = Response::Array(vec![
            Response::Bytes(Bytes::from_static(b"TEST1")),
            Response::Bytes(Bytes::from_static(b"TEST2")),
        ]);
        let decoded = codec.decode(&mut wire).unwrap().unwrap();
        assert_eq!(decoded, expected);
    }

    /// A bulk header with length -1 decodes to `Nil`.
    #[test]
    fn test_nil_string() {
        let mut wire = BytesMut::new();
        wire.extend_from_slice(&b"$-1\r\n"[..]);

        let decoded = Codec.decode(&mut wire).unwrap().unwrap();
        assert_eq!(decoded, Response::Nil);
    }

    /// A value above the target type's maximum is rejected.
    #[test]
    fn test_integer_overflow() {
        let converted = i32::try_from(Response::Integer(i64::max_value()));
        assert!(converted.is_err());
    }

    /// A negative value is rejected for an unsigned target type.
    #[test]
    fn test_integer_underflow() {
        let converted = u64::try_from(Response::Integer(-2));
        assert!(converted.is_err());
    }

    /// An in-range value converts unchanged.
    #[test]
    fn test_integer_convesion() {
        let converted = u32::try_from(Response::Integer(50)).unwrap();
        assert_eq!(converted, 50);
    }

    /// A flat key/value array converts into the matching map.
    #[test]
    fn test_hashmap_conversion() {
        let mut expected: HashMap<ByteString, ByteString> = HashMap::default();
        expected.insert(ByteString::from("KEY1"), ByteString::from("VALUE1"));
        expected.insert(ByteString::from("KEY2"), ByteString::from("VALUE2"));

        let response = Response::Array(vec![
            Response::String(ByteString::from_static("KEY1")),
            Response::String(ByteString::from_static("VALUE1")),
            Response::String(ByteString::from_static("KEY2")),
            Response::String(ByteString::from_static("VALUE2")),
        ]);
        let converted = HashMap::<ByteString, ByteString>::try_from(response).unwrap();
        assert_eq!(converted, expected);
    }

    /// A key without a value makes the map conversion fail.
    #[test]
    fn test_hashmap_conversion_fails_with_odd_length_array() {
        let response = Response::Array(vec![
            Response::String(ByteString::from_static("KEY1")),
            Response::String(ByteString::from_static("VALUE1")),
            Response::String(ByteString::from_static("KEY2")),
            Response::String(ByteString::from_static("VALUE2")),
            Response::String(ByteString::from_static("KEY3")),
        ]);
        let converted = HashMap::<ByteString, ByteString>::try_from(response);

        assert!(
            converted.is_err(),
            "Should not be able to convert an odd number of elements to a hashmap"
        );
    }
}