use crate::dom::{DomArrayIter, DomEntryKind, DomObjectIter, DomRef};
use serde::Deserialize;
use serde::de::{self, DeserializeSeed, EnumAccess, MapAccess, SeqAccess, VariantAccess, Visitor};
use std::fmt;
/// Error produced while deserializing a value from a DOM tape.
///
/// It carries only a human-readable message, which `Display` writes out verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Error(String);

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Error(message) = self;
        // `write_str` bypasses padding, so width/alignment flags on `f` have no effect.
        f.write_str(message)
    }
}

impl std::error::Error for Error {}
/// Lets serde (and the `Deserialize` impls it drives) build an [`Error`] from any
/// displayable message.
impl de::Error for Error {
    fn custom<M>(msg: M) -> Self
    where
        M: fmt::Display,
    {
        let rendered = msg.to_string();
        Self(rendered)
    }
}
/// Serde `Deserializer` that reads one value straight off the DOM tape.
///
/// A `DomRef` designates a single tape entry (`self.tape[self.pos]`). Every hint inspects
/// that entry's [`DomEntryKind`] and drives the visitor from it:
/// - `String` entries are lent out as `&'de str` via `visit_borrowed_str`.
/// - `EscapedString` entries go through `visit_str`.
/// - Numbers are kept as text on the tape and parsed on demand for the requested type.
impl<'de> de::Deserializer<'de> for DomRef<'de, 'de> {
    type Error = Error;

    /// Self-describing dispatch on the entry kind.
    ///
    /// A number is offered as `u64` if its text parses as one, else `i64`, else `f64`, and
    /// as a plain string as a last resort.
    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        match self.tape[self.pos].kind() {
            DomEntryKind::Null => visitor.visit_unit(),
            DomEntryKind::Bool => visitor.visit_bool(self.tape[self.pos].payload() != 0),
            DomEntryKind::Number => {
                let s = self.tape[self.pos].as_number().unwrap();
                if let Ok(v) = s.parse::<u64>() {
                    return visitor.visit_u64(v);
                }
                if let Ok(v) = s.parse::<i64>() {
                    return visitor.visit_i64(v);
                }
                if let Ok(v) = s.parse::<f64>() {
                    return visitor.visit_f64(v);
                }
                visitor.visit_str(s)
            }
            // Both string kinds are handled in one place, by `deserialize_str`.
            DomEntryKind::String | DomEntryKind::EscapedString => self.deserialize_str(visitor),
            DomEntryKind::StartObject => visitor.visit_map(TapeMapAccess::new(self)),
            DomEntryKind::StartArray => visitor.visit_seq(TapeSeqAccess::new(self)),
            _ => Err(de::Error::custom("unexpected tape entry at value position")),
        }
    }

    fn deserialize_bool<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        match self.tape[self.pos].kind() {
            DomEntryKind::Bool => visitor.visit_bool(self.tape[self.pos].payload() != 0),
            _ => self.deserialize_any(visitor),
        }
    }

    // Narrow integers are read as i64; the primitive visitors range-check the value.
    fn deserialize_i8<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_i64(v)
    }

    fn deserialize_i16<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_i64(v)
    }

    fn deserialize_i32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_i64(v)
    }

    /// Fast path for numbers whose text is a valid `i64`.
    ///
    /// Anything else (non-numbers, fractions, out-of-range values) goes through
    /// `deserialize_any`. The visitor then reports a typed "invalid type/value" error
    /// instead of a bare parse error, and visitors that accept other numeric forms still
    /// succeed.
    fn deserialize_i64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        if let Some(n) = self.tape[self.pos]
            .as_number()
            .and_then(|s| s.parse::<i64>().ok())
        {
            visitor.visit_i64(n)
        } else {
            self.deserialize_any(visitor)
        }
    }

    // Narrow unsigned integers are read as u64; the primitive visitors range-check.
    fn deserialize_u8<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_u64(v)
    }

    fn deserialize_u16<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_u64(v)
    }

    fn deserialize_u32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_u64(v)
    }

    /// Fast path for numbers whose text is a valid `u64`.
    ///
    /// Anything else (e.g. `-1` or `1.5`) is routed through `deserialize_any`, so the
    /// visitor produces a descriptive error or accepts the alternative form.
    fn deserialize_u64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        if let Some(n) = self.tape[self.pos]
            .as_number()
            .and_then(|s| s.parse::<u64>().ok())
        {
            visitor.visit_u64(n)
        } else {
            self.deserialize_any(visitor)
        }
    }

    fn deserialize_f32<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_f64(v)
    }

    /// Fast path for numbers whose text parses as `f64`; otherwise defers to
    /// `deserialize_any` so the visitor reports the mismatch.
    fn deserialize_f64<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        if let Some(n) = self.tape[self.pos]
            .as_number()
            .and_then(|s| s.parse::<f64>().ok())
        {
            visitor.visit_f64(n)
        } else {
            self.deserialize_any(visitor)
        }
    }

    fn deserialize_char<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_str(v)
    }

    /// Hands strings to the visitor.
    ///
    /// Unescaped (`String`) entries are borrowed for `'de`; `EscapedString` entries are
    /// passed as a transient `&str`.
    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        match self.tape[self.pos].kind() {
            DomEntryKind::String => {
                let s: &'de str = self.tape[self.pos].source_string().unwrap();
                visitor.visit_borrowed_str(s)
            }
            DomEntryKind::EscapedString => {
                visitor.visit_str(self.tape[self.pos].as_string().unwrap())
            }
            _ => self.deserialize_any(visitor),
        }
    }

    fn deserialize_string<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_str(v)
    }

    fn deserialize_bytes<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_any(v)
    }

    fn deserialize_byte_buf<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_any(v)
    }

    /// `null` maps to `None`; any other entry is `Some` of that value.
    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        if self.tape[self.pos].kind() == DomEntryKind::Null {
            visitor.visit_none()
        } else {
            visitor.visit_some(self)
        }
    }

    fn deserialize_unit<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        if self.tape[self.pos].kind() == DomEntryKind::Null {
            visitor.visit_unit()
        } else {
            Err(de::Error::custom("expected null"))
        }
    }

    fn deserialize_unit_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value, Error> {
        self.deserialize_unit(visitor)
    }

    fn deserialize_newtype_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value, Error> {
        visitor.visit_newtype_struct(self)
    }

    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        if self.tape[self.pos].kind() == DomEntryKind::StartArray {
            visitor.visit_seq(TapeSeqAccess::new(self))
        } else {
            Err(de::Error::custom("expected JSON array"))
        }
    }

    fn deserialize_tuple<V: Visitor<'de>>(
        self,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value, Error> {
        self.deserialize_seq(visitor)
    }

    fn deserialize_tuple_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value, Error> {
        self.deserialize_seq(visitor)
    }

    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
        if self.tape[self.pos].kind() == DomEntryKind::StartObject {
            visitor.visit_map(TapeMapAccess::new(self))
        } else {
            Err(de::Error::custom("expected JSON object"))
        }
    }

    fn deserialize_struct<V: Visitor<'de>>(
        self,
        _name: &'static str,
        _fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Error> {
        self.deserialize_map(visitor)
    }

    /// Enums are either a bare string (a unit variant) or an externally tagged
    /// `{"Variant": payload}` object.
    fn deserialize_enum<V: Visitor<'de>>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Error> {
        match self.tape[self.pos].kind() {
            DomEntryKind::String | DomEntryKind::EscapedString => {
                visitor.visit_enum(UnitVariantAccess(self))
            }
            DomEntryKind::StartObject => visitor.visit_enum(TapeEnumAccess::new(self)),
            _ => Err(de::Error::custom(
                "expected string or single-key object for enum",
            )),
        }
    }

    fn deserialize_identifier<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        self.deserialize_str(v)
    }

    /// The value is being skipped, so it is reported as unit without walking its contents.
    ///
    /// This is safe because the array/object iterators advance to the next sibling
    /// regardless of how (or whether) each element was deserialized.
    fn deserialize_ignored_any<V: Visitor<'de>>(self, v: V) -> Result<V::Value, Error> {
        v.visit_unit()
    }
}
/// `SeqAccess` over the elements of a JSON array on the tape.
struct TapeSeqAccess<'de> {
    /// Yields one `DomRef` per array element, in order.
    iter: DomArrayIter<'de, 'de>,
}

impl<'de> TapeSeqAccess<'de> {
    /// Wraps the array that `r` points at.
    ///
    /// # Panics
    /// If `array_iter` fails on `r`. The call sites in this file only pass `StartArray`
    /// entries.
    fn new(r: DomRef<'de, 'de>) -> Self {
        let iter = r.array_iter().expect("expected StartArray entry");
        TapeSeqAccess { iter }
    }
}

impl<'de> SeqAccess<'de> for TapeSeqAccess<'de> {
    type Error = Error;

    /// Deserializes the next element with `seed`.
    ///
    /// Returns `Ok(None)` once the array is exhausted.
    fn next_element_seed<T: DeserializeSeed<'de>>(
        &mut self,
        seed: T,
    ) -> Result<Option<T::Value>, Error> {
        self.iter
            .next()
            .map(|element| seed.deserialize(element))
            .transpose()
    }
}
struct TapeMapAccess<'de> {
iter: DomObjectIter<'de, 'de>,
pending_value: Option<DomRef<'de, 'de>>,
}
impl<'de> TapeMapAccess<'de> {
fn new(r: DomRef<'de, 'de>) -> Self {
Self {
iter: r.object_iter().expect("expected StartObject entry"),
pending_value: None,
}
}
}
impl<'de> MapAccess<'de> for TapeMapAccess<'de> {
type Error = Error;
fn next_key_seed<K: DeserializeSeed<'de>>(
&mut self,
seed: K,
) -> Result<Option<K::Value>, Error> {
match self.iter.next() {
None => Ok(None),
Some((key, val)) => {
self.pending_value = Some(val);
seed.deserialize(KeyDeserializer(key)).map(Some)
}
}
}
fn next_value_seed<V: DeserializeSeed<'de>>(&mut self, seed: V) -> Result<V::Value, Error> {
let val = self
.pending_value
.take()
.ok_or_else(|| de::Error::custom("next_value called before next_key"))?;
seed.deserialize(val)
}
}
struct KeyDeserializer<'de>(&'de str);
impl<'de> de::Deserializer<'de> for KeyDeserializer<'de> {
type Error = Error;
fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Error> {
visitor.visit_borrowed_str(self.0)
}
serde::forward_to_deserialize_any! {
bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
map struct enum identifier ignored_any
}
}
/// `EnumAccess` for the bare-string enum form (`"Variant"`).
///
/// The string entry names the variant and carries no payload.
struct UnitVariantAccess<'de>(DomRef<'de, 'de>);

impl<'de> EnumAccess<'de> for UnitVariantAccess<'de> {
    type Error = Error;
    type Variant = UnitOnly;

    /// Identifies the variant by deserializing the string entry itself with `seed`.
    fn variant_seed<V: DeserializeSeed<'de>>(self, seed: V) -> Result<(V::Value, UnitOnly), Error> {
        let UnitVariantAccess(name_entry) = self;
        seed.deserialize(name_entry)
            .map(|variant| (variant, UnitOnly))
    }
}
struct UnitOnly;
impl<'de> VariantAccess<'de> for UnitOnly {
type Error = Error;
fn unit_variant(self) -> Result<(), Error> {
Ok(())
}
fn newtype_variant_seed<T: DeserializeSeed<'de>>(self, _: T) -> Result<T::Value, Error> {
Err(de::Error::custom("expected unit variant, got newtype"))
}
fn tuple_variant<V: Visitor<'de>>(self, _: usize, _: V) -> Result<V::Value, Error> {
Err(de::Error::custom("expected unit variant, got tuple"))
}
fn struct_variant<V: Visitor<'de>>(
self,
_: &'static [&'static str],
_: V,
) -> Result<V::Value, Error> {
Err(de::Error::custom("expected unit variant, got struct"))
}
}
struct TapeEnumAccess<'de> {
key: &'de str,
val: DomRef<'de, 'de>,
}
impl<'de> TapeEnumAccess<'de> {
fn new(r: DomRef<'de, 'de>) -> Self {
let mut iter = r.object_iter().expect("expected StartObject");
let (key, val) = iter
.next()
.expect("enum object must have at least one key-value pair");
Self { key, val }
}
}
impl<'de> EnumAccess<'de> for TapeEnumAccess<'de> {
type Error = Error;
type Variant = DomRef<'de, 'de>;
fn variant_seed<V: DeserializeSeed<'de>>(
self,
seed: V,
) -> Result<(V::Value, DomRef<'de, 'de>), Error> {
let variant = seed.deserialize(KeyDeserializer(self.key))?;
Ok((variant, self.val))
}
}
impl<'de> VariantAccess<'de> for DomRef<'de, 'de> {
type Error = Error;
fn unit_variant(self) -> Result<(), Error> {
Ok(())
}
fn newtype_variant_seed<T: DeserializeSeed<'de>>(self, seed: T) -> Result<T::Value, Error> {
seed.deserialize(self)
}
fn tuple_variant<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value, Error> {
de::Deserializer::deserialize_seq(self, visitor)
}
fn struct_variant<V: Visitor<'de>>(
self,
_fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Error> {
de::Deserializer::deserialize_map(self, visitor)
}
}
/// Deserializes a `T` from the tape value that `r` points at.
///
/// Unescaped string entries are lent out with lifetime `'de`, so `T` may borrow them
/// (e.g. `&'de str` fields).
///
/// # Errors
/// Returns an [`Error`] when the tape value does not have the shape `T` expects.
pub fn from_taperef<'de, T>(r: DomRef<'de, 'de>) -> Result<T, Error>
where
    T: Deserialize<'de>,
{
    <T as Deserialize<'de>>::deserialize(r)
}