use std::{
io::{self, Read},
str,
};
use num::{CheckedAdd, CheckedMul};
use serde::{
de::{DeserializeOwned, Unexpected},
Deserialize,
};
use crate::{
types::{AttributeSkip, PushSkip},
Error, Result,
};
/// A view of (usually byte) data handed out by a [`Reader`].
///
/// `Borrowed` data lives for `'b` (the input lifetime `'de` at use sites) and
/// can be passed to serde's zero-copy `visit_borrowed_*` methods. `Copied`
/// data only lives for `'c`, the borrow of the reader itself (for example, a
/// reader-owned scratch buffer), so visitors receive it through the
/// non-borrowing `visit_*` methods and must copy it if they keep it.
pub enum Reference<'b, 'c, T: ?Sized + 'static> {
    /// Data borrowed directly from the deserializer's input.
    Borrowed(&'b T),
    /// Data borrowed from transient storage owned by the reader.
    Copied(&'c T),
}
pub trait Reader<'de> {
fn read_slice<'a>(
&'a mut self,
len: usize,
consume_crlf: bool,
) -> Result<Reference<'de, 'a, [u8]>>;
fn read_slice_until<'a, F>(
&'a mut self,
until_fn: F,
consume_crlf: bool,
) -> Result<Reference<'de, 'a, [u8]>>
where
F: Fn(u8) -> bool;
fn peek_u8(&mut self) -> Result<Option<u8>>;
fn read_u8(&mut self) -> Result<Option<u8>>;
fn read_length(&mut self) -> Result<usize> {
self.read_unsigned()
}
fn read_unsigned<T>(&mut self) -> Result<T>
where
T: CheckedMul + CheckedAdd + From<u8>,
{
let peek = self.peek_u8()?.ok_or_else(Error::eof)?;
match peek {
b'0' => {
self.read_u8()?;
match self.peek_u8()? {
Some(b'0'..=b'9') => Err(Error::unexpected_value("number after 0")),
_ => Ok(T::from(0)),
}
}
ch @ b'1'..=b'9' => {
self.read_u8()?;
let mut num = T::from(ch - b'0');
loop {
match self.peek_u8()? {
Some(c @ b'0'..=b'9') => {
let digit = T::from(c - b'0');
let ten = T::from(10);
if let Some(r) =
num.checked_mul(&ten).and_then(|n| n.checked_add(&digit))
{
num = r;
} else {
return Err(Error::overflow());
}
self.read_u8()?;
}
_ => {
return Ok(num);
}
}
}
}
_ => Err(Error::expected_value("number")),
}
}
fn read_double(&mut self) -> Result<f64> {
let mut buf = Vec::new();
let mut negative = false;
let mut inf = false;
if let Some(b'-') = self.peek_u8()? {
negative = true;
self.read_u8()?;
}
if let Some(b'i') = self.peek_u8()? {
self.read_ident(b"inf")?;
inf = true;
}
if inf {
if negative {
return Ok(f64::NEG_INFINITY);
} else {
return Ok(f64::INFINITY);
}
}
loop {
match self.peek_u8()? {
Some(ch) if ch != b'\r' && ch != b'\n' => {
self.read_u8()?;
buf.push(ch);
}
None => return Err(Error::eof()),
_ => break,
}
}
let str = str::from_utf8(&buf[..]).map_err(|e| Error::utf8(e.valid_up_to()))?;
let result = str.parse::<f64>().map_err(|_e| Error::parse())?;
Ok(result)
}
fn read_bool(&mut self) -> Result<bool> {
match self.peek_u8()? {
Some(b't') => {
self.read_u8()?;
self.read_crlf()?;
Ok(true)
}
Some(b'f') => {
self.read_u8()?;
self.read_crlf()?;
Ok(false)
}
_ => Err(Error::expected_value("bool")),
}
}
fn read_ident(&mut self, ident: &[u8]) -> Result<()>;
fn read_crlf(&mut self) -> Result<()> {
self.read_ident(b"\r\n")
}
}
/// [`Reader`] over any [`io::Read`] source.
///
/// Bytes are pulled one at a time. Slices are returned as
/// `Reference::Copied` views of an internal scratch buffer.
pub struct ReadReader<R: Read> {
    /// Byte iterator over the underlying reader.
    r: io::Bytes<R>,
    /// Cached lookahead byte shared by the `peek_u8`/`read_u8` helpers.
    ch: Option<u8>,
    /// Scratch buffer that backs the slices returned by `read_slice*`.
    buf: Vec<u8>,
}
fn peek_u8<R: Read>(r: &mut io::Bytes<R>, ch: &mut Option<u8>) -> Result<Option<u8>> {
match ch {
Some(next) => Ok(Some(*next)),
None => read_u8(r, ch),
}
}
fn read_u8<R: Read>(r: &mut io::Bytes<R>, ch: &mut Option<u8>) -> Result<Option<u8>> {
r.next().transpose().map_err(Error::io).map(|next| {
*ch = next;
next
})
}
fn read_reader_ident<R: Read>(
r: &mut io::Bytes<R>,
ch: &mut Option<u8>,
ident: &[u8],
) -> Result<()> {
for expected in ident {
match peek_u8(r, ch)? {
None => return Err(Error::eof()),
Some(next) => {
if next != *expected {
return Err(Error::expected_value("ident"));
}
read_u8(r, ch)?;
}
}
}
Ok(())
}
impl<'de, R: Read> Reader<'de> for ReadReader<R> {
fn read_slice<'a>(
&'a mut self,
len: usize,
consume_crlf: bool,
) -> Result<Reference<'de, 'a, [u8]>> {
self.buf.clear();
for _count in 0..len {
let ch = peek_u8(&mut self.r, &mut self.ch)?.ok_or_else(Error::eof)?;
self.buf.push(ch);
read_u8(&mut self.r, &mut self.ch)?;
}
if consume_crlf {
read_reader_ident(&mut self.r, &mut self.ch, b"\r\n")?;
}
Ok(Reference::Copied(&self.buf[..]))
}
fn read_slice_until<'a, F>(
&'a mut self,
until_fn: F,
consume_crlf: bool,
) -> Result<Reference<'de, 'a, [u8]>>
where
F: Fn(u8) -> bool,
{
self.buf.clear();
loop {
let ch = peek_u8(&mut self.r, &mut self.ch)?.ok_or_else(Error::eof)?;
if until_fn(ch) {
break;
}
self.buf.push(ch);
read_u8(&mut self.r, &mut self.ch)?;
}
if consume_crlf {
read_reader_ident(&mut self.r, &mut self.ch, b"\r\n")?;
}
Ok(Reference::Copied(&self.buf[..]))
}
fn peek_u8(&mut self) -> Result<Option<u8>> {
peek_u8(&mut self.r, &mut self.ch)
}
fn read_u8(&mut self) -> Result<Option<u8>> {
read_u8(&mut self.r, &mut self.ch)
}
fn read_ident(&mut self, ident: &[u8]) -> Result<()> {
read_reader_ident(&mut self.r, &mut self.ch, ident)
}
}
/// [`Reader`] over an in-memory input that lends borrowed sub-slices
/// (zero-copy) for the input lifetime `'de`.
pub struct RefReader<'de, R: AsRef<[u8]> + ?Sized> {
    /// The original input value, returned by `Deserializer::get_ref`.
    slice: &'de R,
    /// The complete byte view of `slice`; never advanced.
    src: &'de [u8],
    /// The unconsumed remainder. It is only ever narrowed from the front,
    /// so it is always a suffix of `src`.
    buf: &'de [u8],
}

impl<'de, R: AsRef<[u8]> + ?Sized> RefReader<'de, R> {
    /// Creates a reader positioned at the start of `slice`.
    pub fn from_slice(slice: &'de R) -> Self {
        let bytes = slice.as_ref();
        RefReader {
            slice,
            src: bytes,
            buf: bytes,
        }
    }

    /// Returns how many bytes have been consumed from the start of the input.
    pub fn consumed_bytes(&self) -> usize {
        // `buf` is always a suffix of `src` (it is only reassigned to
        // `&buf[n..]` or to the tail of `split_at`). The consumed prefix is
        // therefore just the difference in lengths. This replaces the former
        // raw-pointer `offset_from` with no `unsafe` required.
        self.src.len() - self.buf.len()
    }
}
fn read_slice_ident<'de>(buf: &mut &'de [u8], ident: &[u8]) -> Result<()> {
if buf.starts_with(ident) {
*buf = &buf[ident.len()..];
Ok(())
} else {
Err(Error::expected_value("ident"))
}
}
impl<'de, R: AsRef<[u8]> + ?Sized> Reader<'de> for RefReader<'de, R> {
fn read_slice<'a>(
&'a mut self,
len: usize,
consume_crlf: bool,
) -> Result<Reference<'de, 'a, [u8]>> {
if len > self.buf.len() {
return Err(Error::eof());
}
let (a, b) = self.buf.split_at(len);
self.buf = b;
if consume_crlf {
read_slice_ident(&mut self.buf, b"\r\n")?;
}
Ok(Reference::Borrowed(a))
}
fn read_slice_until<'a, F>(
&'a mut self,
until_fn: F,
consume_crlf: bool,
) -> Result<Reference<'de, 'a, [u8]>>
where
F: Fn(u8) -> bool,
{
let len = self
.buf
.iter()
.position(|ch| until_fn(*ch))
.ok_or_else(Error::eof)?;
let (a, b) = self.buf.split_at(len);
self.buf = b;
if consume_crlf {
read_slice_ident(&mut self.buf, b"\r\n")?;
}
Ok(Reference::Borrowed(a))
}
fn peek_u8(&mut self) -> Result<Option<u8>> {
Ok(self.buf.iter().copied().next())
}
fn read_u8(&mut self) -> Result<Option<u8>> {
if self.buf.is_empty() {
return Ok(None);
}
let ch = self.buf[0];
self.buf = &self.buf[1..];
Ok(Some(ch))
}
fn read_ident(&mut self, ident: &[u8]) -> Result<()> {
read_slice_ident(&mut self.buf, ident)
}
}
/// Serde deserializer driven by a [`Reader`].
///
/// The type markers it handles (`+ - $ ! = : , # _ * ~ > % |`) are those of
/// the RESP3 protocol.
pub struct Deserializer<R> {
    /// Source of input bytes.
    reader: R,
    /// While `true`, attribute frames (`|`) found where a value is expected
    /// are skipped (see `peek_skip_attribute`).
    skip_attribute: bool,
    /// While `true`, push frames (`>`) found where a value is expected are
    /// skipped (see `peek_skip_attribute`).
    skip_push: bool,
}
impl<R> ReadReader<R>
where
R: Read,
{
fn from_read(r: R) -> Self {
ReadReader {
r: r.bytes(),
ch: None,
buf: Vec::new(),
}
}
}
impl<R: Read> Deserializer<ReadReader<R>> {
    /// Creates a deserializer that reads from `r`. Attribute and push frames
    /// are skipped by default.
    pub fn from_read(r: R) -> Self {
        let reader = ReadReader::from_read(r);
        Deserializer {
            reader,
            skip_push: true,
            skip_attribute: true,
        }
    }
}
impl<'a, R: AsRef<[u8]> + ?Sized> Deserializer<RefReader<'a, R>> {
    /// Creates a zero-copy deserializer over `slice`. Attribute and push
    /// frames are skipped by default.
    pub fn from_slice(slice: &'a R) -> Self {
        let reader = RefReader::from_slice(slice);
        Deserializer {
            reader,
            skip_push: true,
            skip_attribute: true,
        }
    }
}
impl<'de, R> Deserializer<RefReader<'de, R>>
where
    R: AsRef<[u8]> + ?Sized,
{
    /// Returns the input this deserializer was created from.
    pub fn get_ref(&self) -> &R {
        let input: &'de R = self.reader.slice;
        input
    }

    /// Returns how many bytes of the input have been consumed so far.
    pub fn get_consumed_bytes(&self) -> usize {
        let reader = &self.reader;
        reader.consumed_bytes()
    }
}
/// Deserializes one owned value of type `T` from the reader `rd`.
pub fn from_read<R, T>(rd: R) -> Result<T>
where
    R: Read,
    T: DeserializeOwned,
{
    T::deserialize(&mut Deserializer::from_read(rd))
}
/// Deserializes one value of type `T` from `input`. `T` may borrow from
/// `input` for `'a`.
pub fn from_slice<'a, R, T>(input: &'a R) -> Result<T>
where
    R: AsRef<[u8]> + ?Sized,
    T: Deserialize<'a>,
{
    T::deserialize(&mut Deserializer::from_slice(input))
}
impl<'de, R: Reader<'de>> Deserializer<R> {
    /// Parses `<len>\r\n<len bytes>\r\n`. The type marker must already have
    /// been consumed.
    fn parse_blob_string<'a>(&'a mut self) -> Result<Reference<'de, 'a, [u8]>> {
        let len = self.reader.read_length()?;
        self.reader.read_crlf()?;
        self.reader.read_slice(len, true)
    }

    /// Parses the bytes up to the first `\r` or `\n`, then requires `\r\n`.
    /// The type marker must already have been consumed.
    fn parse_simple_string<'a>(&'a mut self) -> Result<Reference<'de, 'a, [u8]>> {
        self.reader
            .read_slice_until(|ch| ch == b'\r' || ch == b'\n', true)
    }

    /// Parses a double followed by `\r\n`. The `,` marker must already have
    /// been consumed.
    fn parse_double(&mut self) -> Result<f64> {
        let val = self.reader.read_double()?;
        self.reader.read_crlf()?;
        Ok(val)
    }

    /// Reads and discards one attribute frame via `AttributeSkip`.
    fn skip_attribute(&mut self) -> Result<()> {
        let _s: AttributeSkip = Deserialize::deserialize(self)?;
        Ok(())
    }

    /// Reads and discards one push frame via `PushSkip`.
    fn skip_push(&mut self) -> Result<()> {
        let _s: PushSkip = Deserialize::deserialize(self)?;
        Ok(())
    }

    /// Peeks the next type marker after discarding any attribute (`|`) and
    /// push (`>`) frames that the `skip_attribute` / `skip_push` flags say to
    /// skip.
    ///
    /// Skipping repeats until a marker that should not be skipped appears.
    /// This allows several consecutive attributes and/or pushes ahead of a
    /// reply. Every iteration consumes a frame or fails, so the loop
    /// terminates.
    fn peek_skip_attribute(&mut self) -> Result<u8> {
        loop {
            let marker = self.peek()?;
            match marker {
                b'|' if self.skip_attribute => self.skip_attribute()?,
                b'>' if self.skip_push => self.skip_push()?,
                _ => return Ok(marker),
            }
        }
    }

    /// Peeks the next byte, treating end of input as an EOF error.
    fn peek(&mut self) -> Result<u8> {
        self.reader.peek_u8()?.ok_or_else(Error::eof)
    }
}
/// Hands bytes to `visitor`. Input-borrowed data uses the zero-copy
/// `visit_borrowed_bytes`; transient data uses `visit_bytes`.
fn visit_ref_bytes<'de, 'a, V>(r: Reference<'de, 'a, [u8]>, visitor: V) -> Result<V::Value>
where
    V: serde::de::Visitor<'de>,
{
    match r {
        Reference::Borrowed(input) => visitor.visit_borrowed_bytes(input),
        Reference::Copied(scratch) => visitor.visit_bytes(scratch),
    }
}
fn visit_ref_str<'de, 'a, V>(r: Reference<'de, 'a, [u8]>, visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
match r {
Reference::Copied(s) => {
let string = str::from_utf8(s).map_err(|e| Error::utf8(e.valid_up_to()))?;
visitor.visit_str(string)
}
Reference::Borrowed(s) => {
let string = str::from_utf8(s).map_err(|e| Error::utf8(e.valid_up_to()))?;
visitor.visit_borrowed_str(string)
}
}
}
/// Reads the body of a number whose `:` marker has already been consumed: an
/// optional `-`, the digits, then the trailing `\r\n`.
///
/// A negative number's magnitude is parsed as `u64`. This means `i64::MIN`
/// (magnitude 2^63, one past `i64::MAX`) is accepted rather than reported as
/// an overflow.
///
/// # Errors
/// * `expected_value("number")` if neither `-` nor a digit follows.
/// * `overflow` if the value does not fit in `i64`.
/// * Any error from the reader (EOF, missing CRLF, ...).
fn read_signed_integer<'de, R: Reader<'de>>(reader: &mut R) -> Result<i64> {
    match reader.peek_u8()? {
        Some(b'-') => {
            reader.read_u8()?;
            let magnitude: u64 = reader.read_unsigned()?;
            if magnitude > i64::MIN.unsigned_abs() {
                return Err(Error::overflow());
            }
            reader.read_crlf()?;
            // For magnitude <= i64::MAX the cast is exact. For 2^63 it yields
            // i64::MIN, which `wrapping_neg` maps to itself; that is the
            // correct result.
            Ok((magnitude as i64).wrapping_neg())
        }
        Some(b'0'..=b'9') => {
            let num: i64 = reader.read_unsigned()?;
            reader.read_crlf()?;
            Ok(num)
        }
        _ => Err(Error::expected_value("number")),
    }
}

impl<'de, 'a, R> serde::Deserializer<'de> for &'a mut Deserializer<R>
where
    R: Reader<'de>,
{
    type Error = super::Error;

    /// Dispatches on the next type marker (after skipping attributes/pushes
    /// as configured).
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b'$' | b'=' | b'!' | b'+' | b'-' => self.deserialize_str(visitor),
            b'#' => self.deserialize_bool(visitor),
            b':' => self.deserialize_i64(visitor),
            b',' => self.deserialize_f64(visitor),
            // Null must be accepted here too. Otherwise `ignored_any`, and
            // struct fields that are skipped, fail on a `_` value.
            b'_' => self.deserialize_unit(visitor),
            b'*' | b'~' | b'>' => self.deserialize_seq(visitor),
            b'%' | b'|' => self.deserialize_map(visitor),
            _ => Err(Error::expected_value("type header")),
        }
    }

    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b'#' => {
                self.reader.read_u8()?;
                let val = self.reader.read_bool()?;
                visitor.visit_bool(val)
            }
            _ => Err(Error::expected_marker("bool")),
        }
    }

    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }

    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }

    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_i64(visitor)
    }

    /// Parses a `:` number, including `i64::MIN`.
    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b':' => {
                self.reader.read_u8()?;
                let num = read_signed_integer(&mut self.reader)?;
                visitor.visit_i64(num)
            }
            _ => Err(Error::expected_marker("number")),
        }
    }

    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_u64(visitor)
    }

    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_u64(visitor)
    }

    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_u64(visitor)
    }

    /// Parses a non-negative `:` number; a leading `-` is rejected.
    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b':' => {
                self.reader.read_u8()?;
                match self.reader.peek_u8()? {
                    Some(b'-') => Err(Error::unexpected_value("signed")),
                    Some(b'0'..=b'9') => {
                        let num: u64 = self.reader.read_unsigned()?;
                        self.reader.read_crlf()?;
                        visitor.visit_u64(num)
                    }
                    _ => Err(Error::expected_value("number")),
                }
            }
            _ => Err(Error::expected_marker("number")),
        }
    }

    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_f64(visitor)
    }

    /// Accepts either a `:` integer (converted to `f64`) or a `,` double.
    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b':' => {
                self.reader.read_u8()?;
                let num = read_signed_integer(&mut self.reader)?;
                visitor.visit_f64(num as f64)
            }
            b',' => {
                self.reader.read_u8()?;
                let num = self.parse_double()?;
                visitor.visit_f64(num)
            }
            _ => Err(Error::expected_marker("number|double")),
        }
    }

    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_str(visitor)
    }

    /// Accepts simple strings/errors (`+`, `-`) and blob strings/errors or
    /// verbatim strings (`$`, `!`, `=`).
    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b'+' | b'-' => {
                self.reader.read_u8()?;
                let bytes = self.parse_simple_string()?;
                visit_ref_str(bytes, visitor)
            }
            b'$' | b'!' | b'=' => {
                self.reader.read_u8()?;
                let bytes = self.parse_blob_string()?;
                visit_ref_str(bytes, visitor)
            }
            _ => Err(Error::expected_marker("string|error")),
        }
    }

    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_str(visitor)
    }

    /// Same markers as `deserialize_str`, without UTF-8 validation.
    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b'+' | b'-' => {
                self.reader.read_u8()?;
                let bytes = self.parse_simple_string()?;
                visit_ref_bytes(bytes, visitor)
            }
            b'$' | b'!' | b'=' => {
                self.reader.read_u8()?;
                let bytes = self.parse_blob_string()?;
                visit_ref_bytes(bytes, visitor)
            }
            _ => Err(Error::expected_marker("string|error")),
        }
    }

    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_bytes(visitor)
    }

    /// `_\r\n` is `None`; anything else is deserialized as `Some`.
    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b'_' => {
                self.reader.read_u8()?;
                self.reader.read_crlf()?;
                visitor.visit_none()
            }
            _ => visitor.visit_some(self),
        }
    }

    /// Requires a null (`_\r\n`).
    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b'_' => {
                self.reader.read_u8()?;
                self.reader.read_crlf()?;
                visitor.visit_unit()
            }
            _ => Err(Error::expected_marker("null")),
        }
    }

    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_unit(visitor)
    }

    /// Special-cases the crate's marker newtypes (identified by `name`) so
    /// that they match one specific wire type. Attributes are deliberately
    /// *not* skipped here: the attribute/push handlers must see their own
    /// markers.
    fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        let peek = self.reader.peek_u8()?.ok_or_else(Error::eof)?;
        match name {
            crate::types::SIMPLE_ERROR_TOKEN => {
                if peek == b'-' {
                    self.reader.read_u8()?;
                    let bytes = self.parse_simple_string()?;
                    visit_ref_str(bytes, visitor)
                } else {
                    Err(Error::expected_marker("simple error"))
                }
            }
            crate::types::BLOB_ERROR_TOKEN => {
                if peek == b'!' {
                    self.reader.read_u8()?;
                    let bytes = self.parse_blob_string()?;
                    visit_ref_str(bytes, visitor)
                } else {
                    Err(Error::expected_marker("blob error"))
                }
            }
            crate::types::SIMPLE_STRING_TOKEN => {
                if peek == b'+' {
                    self.reader.read_u8()?;
                    let bytes = self.parse_simple_string()?;
                    visit_ref_str(bytes, visitor)
                } else {
                    Err(Error::expected_marker("simple string"))
                }
            }
            crate::types::BLOB_STRING_TOKEN => {
                if peek == b'$' {
                    self.reader.read_u8()?;
                    let bytes = self.parse_blob_string()?;
                    visit_ref_str(bytes, visitor)
                } else {
                    Err(Error::expected_marker("blob string"))
                }
            }
            crate::types::ATTRIBUTE_SKIP_TOKEN => {
                if peek == b'|' {
                    self.reader.read_u8()?;
                    let len = self.reader.read_length()?;
                    self.reader.read_crlf()?;
                    visitor.visit_map(CountMapAccess::new(self, len))
                } else {
                    Err(Error::expected_marker("attribute"))
                }
            }
            crate::types::PUSH_TOKEN => {
                if peek != b'>' {
                    return Err(Error::expected_marker("push"));
                }
                // Let the following `deserialize_seq` read the push instead of
                // skipping it; it re-enables `skip_push` itself.
                self.skip_push = false;
                visitor.visit_newtype_struct(self)
            }
            crate::types::PUSH_OR_VALUE_TOKEN => {
                if peek == b'>' {
                    visitor.visit_map(PushOrValueAccess::new_push(self))
                } else {
                    visitor.visit_map(PushOrValueAccess::new_value(self))
                }
            }
            _ => visitor.visit_newtype_struct(self),
        }
    }

    /// Accepts arrays (`*`), sets (`~`) and, when not skipped, pushes (`>`).
    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b'*' | b'~' => {
                self.reader.read_u8()?;
                let len = self.reader.read_length()?;
                self.reader.read_crlf()?;
                visitor.visit_seq(CountSeqAccess::new(self, len))
            }
            b'>' => {
                self.skip_push = true;
                self.reader.read_u8()?;
                let len = self.reader.read_length()?;
                self.reader.read_crlf()?;
                visitor.visit_seq(CountSeqAccess::new(self, len))
            }
            _ => Err(Error::expected_marker("array|set|push")),
        }
    }

    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// For the crate's `WITH_ATTRIBUTE_TOKEN` type, presents the attribute and
    /// the following value as a 2-element sequence (attribute skipping is
    /// disabled meanwhile). Otherwise it behaves like a sequence.
    fn deserialize_tuple_struct<V>(
        self,
        name: &'static str,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        let peek = self.reader.peek_u8()?.ok_or_else(Error::eof)?;
        match name {
            crate::types::WITH_ATTRIBUTE_TOKEN => {
                if peek == b'|' {
                    let last_skip = self.skip_attribute;
                    self.skip_attribute = false;
                    let r = visitor.visit_seq(CountSeqAccess::new(self, 2));
                    self.skip_attribute = last_skip;
                    r
                } else {
                    Err(Error::expected_marker("attribute"))
                }
            }
            _ => self.deserialize_seq(visitor),
        }
    }

    /// Accepts maps (`%`) and, when not skipped, attributes (`|`). Attributes
    /// nested inside an attribute map are skipped.
    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b'%' => {
                self.reader.read_u8()?;
                let len = self.reader.read_length()?;
                self.reader.read_crlf()?;
                visitor.visit_map(CountMapAccess::new(self, len))
            }
            b'|' => {
                self.reader.read_u8()?;
                let len = self.reader.read_length()?;
                self.reader.read_crlf()?;
                let last_skip = self.skip_attribute;
                self.skip_attribute = true;
                let r = visitor.visit_map(CountMapAccess::new(self, len));
                self.skip_attribute = last_skip;
                r
            }
            _ => Err(Error::expected_marker("map")),
        }
    }

    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        _fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_any(visitor)
    }

    /// Enums are encoded either as a single-entry map `{variant: payload}` or,
    /// for unit variants, as a bare simple/blob string naming the variant.
    fn deserialize_enum<V>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        match self.peek_skip_attribute()? {
            b'%' => {
                self.reader.read_u8()?;
                let len = self.reader.read_length()?;
                self.reader.read_crlf()?;
                // Exactly one entry is required. An empty map has no variant
                // key, and reading one anyway would consume unrelated input.
                if len != 1 {
                    return Err(Error::expected_value("1-length map"));
                }
                visitor.visit_enum(VariantAccess::new(self))
            }
            b'+' | b'$' => visitor.visit_enum(UnitVariantAccess::new(self)),
            _ => Err(Error::expected_marker("map")),
        }
    }

    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_str(visitor)
    }

    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        self.deserialize_any(visitor)
    }
}
/// `SeqAccess` over a sequence whose element count was read from its header.
struct CountSeqAccess<'a, R> {
    de: &'a mut Deserializer<R>,
    /// Elements still to be yielded.
    len: usize,
}

impl<'a, R> CountSeqAccess<'a, R> {
    fn new(de: &'a mut Deserializer<R>, len: usize) -> Self {
        Self { de, len }
    }
}

impl<'de, 'a, R: Reader<'de> + 'a> serde::de::SeqAccess<'de> for CountSeqAccess<'a, R> {
    type Error = Error;

    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: serde::de::DeserializeSeed<'de>,
    {
        if self.len == 0 {
            return Ok(None);
        }
        // The element counts as consumed even if deserializing it fails.
        self.len -= 1;
        seed.deserialize(&mut *self.de).map(Some)
    }

    fn size_hint(&self) -> Option<usize> {
        Some(self.len)
    }
}
/// `MapAccess` over a map whose entry count was read from its header.
struct CountMapAccess<'a, R> {
    de: &'a mut Deserializer<R>,
    /// Entries whose key has not been read yet.
    len: usize,
}

impl<'a, R> CountMapAccess<'a, R> {
    fn new(de: &'a mut Deserializer<R>, len: usize) -> Self {
        CountMapAccess { de, len }
    }
}

impl<'de, 'a, R: Reader<'de> + 'a> serde::de::MapAccess<'de> for CountMapAccess<'a, R> {
    type Error = Error;

    /// Reads the next key, or returns `None` once `len` keys have been read.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: serde::de::DeserializeSeed<'de>,
    {
        if self.len > 0 {
            let key = seed.deserialize(&mut *self.de).map(Some);
            self.len -= 1;
            key
        } else {
            Ok(None)
        }
    }

    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: serde::de::DeserializeSeed<'de>,
    {
        seed.deserialize(&mut *self.de)
    }

    /// Remaining entry count, matching `CountSeqAccess::size_hint`. Serde's
    /// built-in map visitors use it (capped) to pre-size collections.
    fn size_hint(&self) -> Option<usize> {
        Some(self.len)
    }
}
/// Enum access for the single-entry map form: the key names the variant and
/// the value carries its payload.
struct VariantAccess<'a, R> {
    de: &'a mut Deserializer<R>,
}

impl<'a, R> VariantAccess<'a, R> {
    fn new(de: &'a mut Deserializer<R>) -> Self {
        Self { de }
    }
}

impl<'de, 'a, R: Reader<'de> + 'a> serde::de::EnumAccess<'de> for VariantAccess<'a, R> {
    type Error = Error;
    type Variant = Self;

    /// Reads the map key that identifies the variant.
    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
    where
        V: serde::de::DeserializeSeed<'de>,
    {
        seed.deserialize(&mut *self.de).map(|variant| (variant, self))
    }
}

impl<'de, 'a, R: Reader<'de> + 'a> serde::de::VariantAccess<'de> for VariantAccess<'a, R> {
    type Error = Error;

    /// A unit variant's payload must be a null.
    fn unit_variant(self) -> Result<()> {
        <() as serde::de::Deserialize>::deserialize(self.de)
    }

    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
    where
        T: serde::de::DeserializeSeed<'de>,
    {
        seed.deserialize(self.de)
    }

    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        let de = self.de;
        serde::de::Deserializer::deserialize_tuple(de, len, visitor)
    }

    fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        let de = self.de;
        serde::de::Deserializer::deserialize_struct(de, "", fields, visitor)
    }
}
struct UnitVariantAccess<'a, R> {
de: &'a mut Deserializer<R>,
}
impl<'a, R> UnitVariantAccess<'a, R> {
fn new(de: &'a mut Deserializer<R>) -> Self {
UnitVariantAccess { de }
}
}
impl<'de, 'a, R: Reader<'de> + 'a> serde::de::EnumAccess<'de> for UnitVariantAccess<'a, R> {
type Error = Error;
type Variant = Self;
fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
where
V: serde::de::DeserializeSeed<'de>,
{
let variant_key = seed.deserialize(&mut *self.de)?;
Ok((variant_key, self))
}
}
impl<'de, 'a, R: Reader<'de> + 'a> serde::de::VariantAccess<'de> for UnitVariantAccess<'a, R> {
type Error = Error;
fn unit_variant(self) -> Result<()> {
Ok(())
}
fn newtype_variant_seed<T>(self, _seed: T) -> Result<T::Value>
where
T: serde::de::DeserializeSeed<'de>,
{
Err(serde::de::Error::invalid_type(
Unexpected::UnitVariant,
&"newtype variant",
))
}
fn tuple_variant<V>(self, _len: usize, _visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
Err(serde::de::Error::invalid_type(
Unexpected::UnitVariant,
&"tuple variant",
))
}
fn struct_variant<V>(self, _fields: &'static [&'static str], _visitor: V) -> Result<V::Value>
where
V: serde::de::Visitor<'de>,
{
Err(serde::de::Error::invalid_type(
Unexpected::UnitVariant,
&"struct variant",
))
}
}
/// Map access exposing a single synthetic entry. Its key is `PUSH_TOKEN` when
/// the upcoming frame is a push and `VALUE_TOKEN` otherwise; its value is read
/// from `de`.
struct PushOrValueAccess<'a, R> {
    de: &'a mut Deserializer<R>,
    /// `true` when the upcoming frame was seen to start with `>`.
    is_push: bool,
    /// Exhaustion flag consulted by `next_key_seed`.
    done: bool,
}

impl<'a, R> PushOrValueAccess<'a, R> {
    /// Shared constructor for a fresh, not-yet-exhausted access.
    fn with_kind(de: &'a mut Deserializer<R>, is_push: bool) -> Self {
        Self {
            de,
            is_push,
            done: false,
        }
    }

    fn new_push(de: &'a mut Deserializer<R>) -> Self {
        Self::with_kind(de, true)
    }

    fn new_value(de: &'a mut Deserializer<R>) -> Self {
        Self::with_kind(de, false)
    }
}
/// Deserializer that yields the fixed `'static` string `s` for every request.
/// It is used to feed synthetic map keys to a visitor.
struct ConstantStrDeserializer {
    s: &'static str,
}

impl<'de> serde::de::Deserializer<'de> for ConstantStrDeserializer {
    type Error = Error;

    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: serde::de::Visitor<'de>,
    {
        let ConstantStrDeserializer { s: text } = self;
        visitor.visit_borrowed_str(text)
    }

    serde::forward_to_deserialize_any! {
        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
        bytes byte_buf option unit unit_struct newtype_struct seq tuple
        tuple_struct map struct enum identifier ignored_any
    }
}
impl<'de, 'a, R: Reader<'de> + 'a> serde::de::MapAccess<'de> for PushOrValueAccess<'a, R> {
    type Error = Error;

    /// Yields the single synthetic key (`PUSH_TOKEN` for a push,
    /// `VALUE_TOKEN` otherwise) exactly once, then `None`.
    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: serde::de::DeserializeSeed<'de>,
    {
        if self.done {
            return Ok(None);
        }
        // Mark the one entry as produced. Previously `done` was never set,
        // so the map never ended: a visitor looping over keys got the same
        // key again and then read an unrelated frame as its value.
        self.done = true;
        let key = if self.is_push {
            crate::types::PUSH_TOKEN
        } else {
            crate::types::VALUE_TOKEN
        };
        seed.deserialize(ConstantStrDeserializer { s: key }).map(Some)
    }

    /// Reads the entry's value (the push or regular frame) from the input.
    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: serde::de::DeserializeSeed<'de>,
    {
        seed.deserialize(&mut *self.de)
    }

    fn size_hint(&self) -> Option<usize> {
        Some(if self.done { 0 } else { 1 })
    }
}
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use serde::Deserialize;

    use super::*;
    use crate::{
        test_utils::{test_deserialize, test_deserialize_result},
        types::owned::{BlobError, BlobString, SimpleError, SimpleString},
    };

    #[test]
    fn test_blob_string() {
        test_deserialize(b"$11\r\nhello world\r\n", |text: String| {
            assert_eq!(text, "hello world")
        });
        test_deserialize(b"$11\r\nhello world\r\n", |blob: BlobString| {
            assert_eq!(blob.0, "hello world")
        });
        // A simple string is not an acceptable encoding for `BlobString`.
        test_deserialize_result(b"+hello world\r\n", |outcome: Result<BlobString>| {
            assert!(outcome.is_err())
        });
    }

    #[test]
    fn test_simple_string() {
        test_deserialize(b"+hello world\r\n", |text: String| {
            assert_eq!(text, "hello world")
        });
        test_deserialize(b"+hello world\r\n", |simple: SimpleString| {
            assert_eq!(simple.0, "hello world")
        });
    }

    #[test]
    fn test_blob_error() {
        test_deserialize(b"!15\r\nERR hello world\r\n", |err: BlobError| {
            assert_eq!(err.0, "ERR hello world")
        });
    }

    #[test]
    fn test_simple_error() {
        test_deserialize(b"-ERR hello world\r\n", |err: SimpleError| {
            assert_eq!(err.0, "ERR hello world")
        });
    }

    #[test]
    fn test_bool() {
        test_deserialize(b"#t\r\n", |flag: bool| assert!(flag));
        test_deserialize(b"#f\r\n", |flag: bool| assert!(!flag));
    }

    #[test]
    fn test_number() {
        test_deserialize(b":12345\r\n", |n: i64| assert_eq!(n, 12345));
        test_deserialize(b":-12345\r\n", |n: i64| assert_eq!(n, -12345));
    }

    #[test]
    fn test_double() {
        test_deserialize(b",1.23\r\n", |x: f64| assert_eq!(x, 1.23));
        test_deserialize(b",10\r\n", |x: f64| assert_eq!(x, 10.0));
        test_deserialize(b",inf\r\n", |x: f64| assert_eq!(x, f64::INFINITY));
        test_deserialize(b",-inf\r\n", |x: f64| assert_eq!(x, f64::NEG_INFINITY));
    }

    #[test]
    fn test_char() {
        test_deserialize(b"+a\r\n", |c: char| assert_eq!(c, 'a'));
    }

    #[test]
    fn test_seq() {
        test_deserialize(b"*3\r\n:1\r\n:2\r\n:3\r\n", |items: Vec<u64>| {
            assert_eq!(items, [1, 2, 3])
        });
        test_deserialize(
            b"*2\r\n*3\r\n:1\r\n$5\r\nhello\r\n:2\r\n#f\r\n",
            |nested: ((u64, String, u64), bool)| {
                assert_eq!(nested, ((1, String::from("hello"), 2), false))
            },
        );
    }

    #[test]
    fn test_map() {
        test_deserialize(
            b"%2\r\n+first\r\n:1\r\n+second\r\n:2\r\n",
            |map: HashMap<String, usize>| {
                let pairs: Vec<_> = map.into_iter().collect();
                assert!(pairs.contains(&("first".to_string(), 1)));
                assert!(pairs.contains(&("second".to_string(), 2)));
            },
        );

        #[derive(PartialEq, Deserialize, Debug)]
        struct CustomMap {
            first: usize,
            second: f64,
        }

        test_deserialize(
            b"%2\r\n+first\r\n:1\r\n+second\r\n:2\r\n",
            |custom: CustomMap| {
                let expected = CustomMap {
                    first: 1,
                    second: 2.0,
                };
                assert_eq!(custom, expected)
            },
        );
    }

    #[test]
    fn test_enum() {
        #[derive(Debug, Deserialize, PartialEq)]
        enum TestEnum {
            One(usize),
            Two(usize, String),
            Three { value: usize },
            Four,
        }

        test_deserialize(b"%1\r\n+One\r\n:300\r\n", |e: TestEnum| {
            assert_eq!(e, TestEnum::One(300))
        });
        test_deserialize(
            b"%1\r\n+Two\r\n*2\r\n:300\r\n+teststring\r\n",
            |e: TestEnum| assert_eq!(e, TestEnum::Two(300, String::from("teststring"))),
        );
        test_deserialize(
            b"%1\r\n+Three\r\n%1\r\n+value\r\n:300\r\n",
            |e: TestEnum| assert_eq!(e, TestEnum::Three { value: 300 }),
        );
        // Unit variants: map form with a null payload, and both bare-string forms.
        test_deserialize(b"%1\r\n+Four\r\n_\r\n", |e: TestEnum| {
            assert_eq!(e, TestEnum::Four)
        });
        test_deserialize(b"+Four\r\n", |e: TestEnum| assert_eq!(e, TestEnum::Four));
        test_deserialize(b"$4\r\nFour\r\n", |e: TestEnum| assert_eq!(e, TestEnum::Four));
    }
}