use crate::{
de::escape::EscapedDeserializer,
de::seq::{not_in, TagFilter},
de::simple_type::SimpleTypeDeserializer,
de::{str2bool, DeEvent, Deserializer, XmlRead, INNER_VALUE, UNFLATTEN_PREFIX},
errors::serialize::DeError,
events::attributes::IterState,
events::BytesStart,
};
use serde::de::{self, DeserializeSeed, IntoDeserializer, SeqAccess, Visitor};
use serde::serde_if_integer128;
use std::borrow::Cow;
use std::ops::Range;
/// Where the value for the key most recently returned by
/// `MapAccess::next_key_seed` has to be read from.
///
/// Set by `next_key_seed` and taken (reset to `Unknown`) by `next_value_seed`.
#[derive(Debug, PartialEq)]
enum ValueSource {
    /// No key has been read yet, or its value was already consumed.
    /// `next_value_seed` reports `DeError::KeyNotRead` in this state.
    Unknown,
    /// The value of an attribute, given as a byte range inside `start.buf`
    /// of the enclosing `MapAccess`.
    Attribute(Range<usize>),
    /// The next event is a Text or CDATA event, reported under the
    /// `INNER_VALUE` key; the value is read from that event.
    Text,
    /// The next event is a start tag that is not one of the struct fields and
    /// the struct has an `INNER_VALUE` field. The value is deserialized by a
    /// `MapValueDeserializer` with `allow_start: false`.
    Content,
    /// The next event is a start tag treated as a (nested) field. The value is
    /// deserialized by a `MapValueDeserializer` with `allow_start: true`.
    Nested,
}
/// Serde `MapAccess` over one XML element: its attributes become map entries
/// first, followed by entries produced from its child events until the
/// element's closing tag.
pub(crate) struct MapAccess<'de, 'a, R>
where
    R: XmlRead<'de>,
{
    /// Opening tag of the element being deserialized; its raw buffer is
    /// scanned for attributes.
    start: BytesStart<'de>,
    /// Underlying deserializer that provides child events.
    de: &'a mut Deserializer<'de, R>,
    /// Attribute iteration state over `start.buf`.
    iter: IterState,
    /// Where the value of the last returned key must be read from.
    source: ValueSource,
    /// Field names of the target type, as passed by serde.
    fields: &'static [&'static str],
    /// Whether `fields` contains the special `INNER_VALUE` field.
    has_value_field: bool,
    /// Fields whose names start with `UNFLATTEN_PREFIX`, as bytes. An entry
    /// is removed once a matching tag has been seen.
    unflatten_fields: Vec<&'static [u8]>,
}
impl<'de, 'a, R> MapAccess<'de, 'a, R>
where
    R: XmlRead<'de>,
{
    /// Creates a map accessor for the element opened by `start`.
    ///
    /// `fields` are the field names of the target type. They determine whether
    /// an `INNER_VALUE` field exists and which fields are "unflatten" fields
    /// (prefixed with `UNFLATTEN_PREFIX`).
    ///
    /// # Errors
    /// Currently never fails; the `Result` is part of the signature only.
    pub fn new(
        de: &'a mut Deserializer<'de, R>,
        start: BytesStart<'de>,
        fields: &'static [&'static str],
    ) -> Result<Self, DeError> {
        // `IterState` receives the length of the tag name, presumably as the
        // offset where attribute scanning begins.
        let name_len = start.name().as_ref().len();
        let has_value_field = fields.contains(&INNER_VALUE);
        let unflatten_fields = fields
            .iter()
            .copied()
            .filter(|field| field.starts_with(UNFLATTEN_PREFIX))
            .map(str::as_bytes)
            .collect();

        Ok(Self {
            start,
            de,
            iter: IterState::new(name_len, false),
            source: ValueSource::Unknown,
            fields,
            has_value_field,
            unflatten_fields,
        })
    }
}
impl<'de, 'a, R> de::MapAccess<'de> for MapAccess<'de, 'a, R>
where
    R: XmlRead<'de>,
{
    type Error = DeError;

    /// Returns the next key, or `None` at the closing tag of `start`.
    ///
    /// Keys are produced first from the attributes of `start`, then from the
    /// next child event. The location of the corresponding value is stored in
    /// `self.source` for `next_value_seed`.
    ///
    /// # Errors
    /// `DeError::UnexpectedEnd` for a closing tag that does not match `start`,
    /// `DeError::UnexpectedEof` if input ends first, plus any error from the
    /// attribute iterator, the underlying deserializer or `seed`.
    fn next_key_seed<K: DeserializeSeed<'de>>(
        &mut self,
        seed: K,
    ) -> Result<Option<K::Value>, Self::Error> {
        // Keys and values must alternate; checked in debug builds only.
        debug_assert_eq!(self.source, ValueSource::Unknown);
        let slice = &self.start.buf;
        let decoder = self.de.reader.decoder();
        if let Some(a) = self.iter.next(slice).transpose()? {
            // Attribute: the key is the attribute name. The value range is
            // stored for later; a missing range (`None`, presumably a value-less
            // attribute) becomes the empty default range.
            let (key, value) = a.into();
            self.source = ValueSource::Attribute(value.unwrap_or_default());
            seed.deserialize(EscapedDeserializer::new(
                Cow::Borrowed(&slice[key]),
                decoder,
                false,
            ))
            .map(Some)
        } else {
            match self.de.peek()? {
                // Text/CDATA content is reported under the `INNER_VALUE` key.
                DeEvent::Text(_) | DeEvent::CData(_) => {
                    self.source = ValueSource::Text;
                    seed.deserialize(INNER_VALUE.into_deserializer()).map(Some)
                }
                // A tag that is not one of the fields goes to the `INNER_VALUE`
                // field, but only if the struct declares one.
                DeEvent::Start(e) if self.has_value_field && not_in(self.fields, e, decoder)? => {
                    self.source = ValueSource::Content;
                    seed.deserialize(INNER_VALUE.into_deserializer()).map(Some)
                }
                // Any other tag is a nested field.
                DeEvent::Start(e) => {
                    self.source = ValueSource::Nested;
                    // A tag matching an unflatten field (field name with
                    // `UNFLATTEN_PREFIX` stripped) is reported under the full
                    // prefixed field name; the entry is removed so it matches
                    // only once. NOTE(review): this compares against `name()`
                    // while the fallback below uses `local_name()`.
                    let key = if let Some(p) = self
                        .unflatten_fields
                        .iter()
                        .position(|f| e.name().as_ref() == &f[UNFLATTEN_PREFIX.len()..])
                    {
                        seed.deserialize(self.unflatten_fields.remove(p).into_deserializer())
                    } else {
                        let name = Cow::Borrowed(e.local_name().into_inner());
                        seed.deserialize(EscapedDeserializer::new(name, decoder, false))
                    };
                    key.map(Some)
                }
                // Closing tag of this element: the map is finished.
                DeEvent::End(e) if e.name() == self.start.name() => Ok(None),
                // A closing tag of some other element is unexpected here.
                DeEvent::End(e) => Err(DeError::UnexpectedEnd(e.name().as_ref().to_owned())),
                DeEvent::Eof => Err(DeError::UnexpectedEof),
            }
        }
    }

    /// Deserializes the value for the key returned by the previous
    /// `next_key_seed` call, resetting `self.source` to `Unknown`.
    ///
    /// # Errors
    /// `DeError::KeyNotRead` if no key was read before, plus any error from
    /// the underlying deserializer or `seed`.
    fn next_value_seed<K: DeserializeSeed<'de>>(
        &mut self,
        seed: K,
    ) -> Result<K::Value, Self::Error> {
        match std::mem::replace(&mut self.source, ValueSource::Unknown) {
            // Attribute value: a slice of the start tag buffer.
            ValueSource::Attribute(value) => seed.deserialize(SimpleTypeDeserializer::from_part(
                &self.start.buf,
                value,
                true,
                self.de.reader.decoder(),
            )),
            // `next_key_seed` peeked a Text/CDATA event; consume it now. Text is
            // passed with `true` and CDATA with `false` (presumably "needs
            // unescaping").
            ValueSource::Text => match self.de.next()? {
                DeEvent::Text(e) => seed.deserialize(SimpleTypeDeserializer::from_cow(
                    e.into_inner(),
                    true,
                    self.de.reader.decoder(),
                )),
                DeEvent::CData(e) => seed.deserialize(SimpleTypeDeserializer::from_cow(
                    e.into_inner(),
                    false,
                    self.de.reader.decoder(),
                )),
                // The peeked event was Text or CData, so nothing else can follow.
                _ => unreachable!(),
            },
            ValueSource::Content => seed.deserialize(MapValueDeserializer {
                map: self,
                allow_start: false,
            }),
            ValueSource::Nested => seed.deserialize(MapValueDeserializer {
                map: self,
                allow_start: true,
            }),
            ValueSource::Unknown => Err(DeError::KeyNotRead),
        }
    }
}
// Generates a `Deserializer` trait method named `$deserialize` that passes its
// extra arguments (if any) and the visitor unchanged to the same-named method
// called on `self.map.de`.
//
// Usage: `forward!(deserialize_unit);` or
// `forward!(deserialize_struct(name: &'static str, fields: &'static [&'static str]));`
macro_rules! forward {
    (
        $deserialize:ident
        $(
            ($($name:ident : $type:ty),*)
        )?
    ) => {
        #[inline]
        fn $deserialize<V: Visitor<'de>>(
            self,
            $($($name: $type,)*)?
            visitor: V
        ) -> Result<V::Value, Self::Error> {
            self.map.de.$deserialize($($($name,)*)? visitor)
        }
    };
}
/// Deserializer for the value of a map entry whose source is
/// `ValueSource::Content` or `ValueSource::Nested`.
struct MapValueDeserializer<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    /// The map access that produced the key for this value.
    map: &'m mut MapAccess<'de, 'a, R>,
    /// `true` for `Nested` values, `false` for `Content` values (set in
    /// `MapAccess::next_value_seed`). Passed to `next_text_impl` and used by
    /// `deserialize_seq` to pick the tag filter.
    allow_start: bool,
}
impl<'de, 'a, 'm, R> MapValueDeserializer<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    /// Reads the next textual value through `Deserializer::next_text_impl`,
    /// passing along this deserializer's `allow_start` flag.
    #[inline]
    fn next_text(&mut self, unescape: bool) -> Result<Cow<'de, str>, DeError> {
        let allow_start = self.allow_start;
        self.map.de.next_text_impl(unescape, allow_start)
    }
}
impl<'de, 'a, 'm, R> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    type Error = DeError;

    deserialize_primitives!(mut);

    // Methods not covered by `deserialize_primitives!` or the sequence methods
    // below are handled directly by the underlying deserializer.
    forward!(deserialize_any);
    forward!(deserialize_ignored_any);
    forward!(deserialize_option);
    forward!(deserialize_unit);
    forward!(deserialize_unit_struct(name: &'static str));
    forward!(deserialize_newtype_struct(name: &'static str));
    forward!(deserialize_map);
    forward!(deserialize_struct(
        name: &'static str,
        fields: &'static [&'static str]
    ));
    forward!(deserialize_enum(
        name: &'static str,
        variants: &'static [&'static str]
    ));

    /// Tuples are read as sequences; the length is not checked here.
    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, DeError>
    where
        V: Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Tuple structs are read as sequences; name and length are ignored.
    fn deserialize_tuple_struct<V>(
        self,
        _name: &'static str,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value, DeError>
    where
        V: Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Deserializes a sequence of sibling elements.
    ///
    /// For a nested field (`allow_start`), only tags accepted by
    /// `TagFilter::Include` with the current start tag become items. Otherwise,
    /// tags that are not struct fields are accepted (`TagFilter::Exclude`).
    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        let filter = if !self.allow_start {
            TagFilter::Exclude(self.map.fields)
        } else if let DeEvent::Start(first) = self.map.de.peek()? {
            TagFilter::Include(first.clone())
        } else {
            // A `Nested` value is only recorded after peeking a start tag.
            unreachable!()
        };

        #[cfg(feature = "overlapped-lists")]
        let checkpoint = self.map.de.skip_checkpoint();

        let seq = MapValueSeqAccess {
            #[cfg(feature = "overlapped-lists")]
            checkpoint,
            map: self.map,
            filter,
        };
        visitor.visit_seq(seq)
    }

    #[inline]
    fn is_human_readable(&self) -> bool {
        self.map.de.is_human_readable()
    }
}
/// Serde `SeqAccess` over the list items that form one map value.
struct MapValueSeqAccess<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    /// The map access whose element contains the list items.
    map: &'m mut MapAccess<'de, 'a, R>,
    /// Decides, through `is_suitable`, which start tags belong to the list.
    filter: TagFilter<'de>,
    /// Value of `skip_checkpoint()` taken when the sequence was created;
    /// passed to `start_replay` on drop.
    #[cfg(feature = "overlapped-lists")]
    checkpoint: usize,
}
#[cfg(feature = "overlapped-lists")]
impl<'de, 'a, 'm, R> Drop for MapValueSeqAccess<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    /// Calls `start_replay` with the checkpoint recorded at creation,
    /// presumably so events skipped while searching for list items are
    /// delivered again to later consumers.
    fn drop(&mut self) {
        let saved = self.checkpoint;
        self.map.de.start_replay(saved);
    }
}
impl<'de, 'a, 'm, R> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    type Error = DeError;

    /// Returns the next list item, or `None` when the list ends.
    ///
    /// The list ends at the closing tag of `self.map.start`. Without the
    /// `overlapped-lists` feature it also ends at the first start tag that
    /// `filter` rejects; with the feature such tags are skipped instead.
    ///
    /// # Errors
    /// `DeError::UnexpectedEnd` for a closing tag of another element,
    /// `DeError::UnexpectedEof` if input ends, plus errors from the underlying
    /// deserializer, the filter or `seed`.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, DeError>
    where
        T: DeserializeSeed<'de>,
    {
        let decoder = self.map.de.reader.decoder();
        loop {
            break match self.map.de.peek()? {
                // Tag not belonging to the list: skip it and keep searching.
                #[cfg(feature = "overlapped-lists")]
                DeEvent::Start(e) if !self.filter.is_suitable(e, decoder)? => {
                    self.map.de.skip()?;
                    continue;
                }
                // Tag not belonging to the list: the list ends here.
                #[cfg(not(feature = "overlapped-lists"))]
                DeEvent::Start(e) if !self.filter.is_suitable(e, decoder)? => Ok(None),
                // Closing tag of the enclosing element ends the list.
                DeEvent::End(e) if e.name() == self.map.start.name() => Ok(None),
                // Closing tag that does not match the enclosing element.
                DeEvent::End(e) => Err(DeError::UnexpectedEnd(e.name().as_ref().to_owned())),
                // Input ended while still inside `self.map.start`.
                DeEvent::Eof => Err(DeError::UnexpectedEof),
                // Accepted start tag, Text or CDATA: deserialize one item.
                _ => seed
                    .deserialize(SeqValueDeserializer { map: self.map })
                    .map(Some),
            };
        }
    }
}
/// Deserializer for a single item yielded by `MapValueSeqAccess`.
struct SeqValueDeserializer<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    /// The map access whose element contains the item.
    map: &'m mut MapAccess<'de, 'a, R>,
}
impl<'de, 'a, 'm, R> SeqValueDeserializer<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    /// Reads the next textual value through `Deserializer::next_text_impl`.
    /// Unlike `MapValueDeserializer`, the start-tag flag is always `true`.
    #[inline]
    fn next_text(&mut self, unescape: bool) -> Result<Cow<'de, str>, DeError> {
        const ALLOW_START: bool = true;
        self.map.de.next_text_impl(unescape, ALLOW_START)
    }
}
impl<'de, 'a, 'm, R> de::Deserializer<'de> for SeqValueDeserializer<'de, 'a, 'm, R>
where
R: XmlRead<'de>,
{
type Error = DeError;
deserialize_primitives!(mut);
forward!(deserialize_option);
forward!(deserialize_unit);
forward!(deserialize_unit_struct(name: &'static str));
forward!(deserialize_newtype_struct(name: &'static str));
forward!(deserialize_map);
forward!(deserialize_struct(
name: &'static str,
fields: &'static [&'static str]
));
forward!(deserialize_enum(
name: &'static str,
variants: &'static [&'static str]
));
forward!(deserialize_any);
forward!(deserialize_ignored_any);
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, DeError>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
len: usize,
visitor: V,
) -> Result<V::Value, DeError>
where
V: Visitor<'de>,
{
self.deserialize_tuple(len, visitor)
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
match self.map.de.next()? {
DeEvent::Text(e) => SimpleTypeDeserializer::from_cow(
e.into_inner(),
true,
self.map.de.reader.decoder(),
)
.deserialize_seq(visitor),
DeEvent::CData(e) => SimpleTypeDeserializer::from_cow(
e.into_inner(),
false,
self.map.de.reader.decoder(),
)
.deserialize_seq(visitor),
DeEvent::Start(e) => {
let value = match self.map.de.next()? {
DeEvent::Text(e) => SimpleTypeDeserializer::from_cow(
e.into_inner(),
true,
self.map.de.reader.decoder(),
)
.deserialize_seq(visitor),
DeEvent::CData(e) => SimpleTypeDeserializer::from_cow(
e.into_inner(),
false,
self.map.de.reader.decoder(),
)
.deserialize_seq(visitor),
e => Err(DeError::Unsupported(
format!("unsupported event {:?}", e).into(),
)),
};
self.map.de.read_to_end(e.name())?;
value
}
_ => unreachable!(),
}
}
#[inline]
fn is_human_readable(&self) -> bool {
self.map.de.is_human_readable()
}
}