use crate::{
de::key::QNameDeserializer,
de::simple_type::SimpleTypeDeserializer,
de::{str2bool, DeEvent, Deserializer, XmlRead, TEXT_KEY, VALUE_KEY},
encoding::Decoder,
errors::serialize::DeError,
events::attributes::IterState,
events::BytesStart,
name::QName,
};
use serde::de::{self, DeserializeSeed, IntoDeserializer, SeqAccess, Visitor};
use serde::serde_if_integer128;
use std::borrow::Cow;
use std::ops::Range;
/// Records where the value belonging to the most recently returned map key
/// lives, so that `next_value_seed` knows how to read it.
#[derive(PartialEq, Debug)]
enum ValueSource {
    /// No key is pending: either no key was read yet, or its value was
    /// already taken by `next_value_seed`.
    Unknown,
    /// The value is an attribute value; the range indexes into the start
    /// tag's buffer (`start.buf`).
    Attribute(Range<usize>),
    /// The value is a peeked `Text`/`CData` event reported under `TEXT_KEY`.
    Text,
    /// The value is element content reported under `VALUE_KEY`.
    Content,
    /// The value is the child element whose name was reported as the key.
    Nested,
}
/// Serde map accessor over one XML element.
///
/// Keys are produced first from the attributes of `start`, then from the
/// events of the element content, until the element's closing tag.
pub(crate) struct MapAccess<'de, 'a, R: XmlRead<'de>> {
    /// Opening tag of the element being read; its buffer holds the attributes.
    start: BytesStart<'de>,
    /// Underlying deserializer that supplies content events.
    de: &'a mut Deserializer<'de, R>,
    /// Iteration position over the attributes stored in `start.buf`.
    iter: IterState,
    /// Location of the value for the last key returned by `next_key_seed`.
    source: ValueSource,
    /// Field names passed to `new` (presumably the target struct's fields — confirm at call sites).
    fields: &'static [&'static str],
    /// Cached result of `fields.contains(&VALUE_KEY)`.
    has_value_field: bool,
}
impl<'de, 'a, R: XmlRead<'de>> MapAccess<'de, 'a, R> {
    /// Creates an accessor for the element opened by `start`.
    ///
    /// Attribute iteration starts right after the tag name. `fields` is kept
    /// for routing child elements, and its membership of `VALUE_KEY` is cached.
    /// This constructor currently never returns `Err`.
    pub fn new(
        de: &'a mut Deserializer<'de, R>,
        start: BytesStart<'de>,
        fields: &'static [&'static str],
    ) -> Result<Self, DeError> {
        // Attributes begin after the element name inside `start.buf`.
        let name_len = start.name().as_ref().len();
        let has_value_field = fields.contains(&VALUE_KEY);
        // NOTE(review): the `false` flag presumably selects non-HTML attribute parsing — confirm in `IterState`.
        let iter = IterState::new(name_len, false);

        Ok(Self {
            start,
            de,
            iter,
            source: ValueSource::Unknown,
            fields,
            has_value_field,
        })
    }
}
impl<'de, 'a, R> de::MapAccess<'de> for MapAccess<'de, 'a, R>
where
    R: XmlRead<'de>,
{
    type Error = DeError;

    /// Produces the next key of the element.
    ///
    /// Attributes of the start tag come first. Once they are exhausted, the
    /// next content event is peeked (not consumed) and classified:
    /// - `Text`/`CData` is reported as `VALUE_KEY` when `fields` contains it,
    ///   otherwise as `TEXT_KEY`;
    /// - a child element is reported as `VALUE_KEY` when `fields` contains it
    ///   and the element name is not one of `fields`, otherwise by its own name;
    /// - the closing tag of this element ends the map with `Ok(None)`.
    ///
    /// The location of the matching value is stored in `self.source`.
    ///
    /// # Errors
    /// `DeError::UnexpectedEnd` for the closing tag of some other element,
    /// `DeError::UnexpectedEof` at end of input, plus any error raised while
    /// iterating attributes, peeking or decoding names.
    fn next_key_seed<K: DeserializeSeed<'de>>(
        &mut self,
        seed: K,
    ) -> Result<Option<K::Value>, Self::Error> {
        debug_assert_eq!(self.source, ValueSource::Unknown);

        let decoder = self.de.reader.decoder();
        let buf = &self.start.buf;

        if let Some(attr) = self.iter.next(buf).transpose()? {
            let (name, value) = attr.into();
            // A missing value range (presumably a value-less attribute) becomes `0..0`.
            self.source = ValueSource::Attribute(value.unwrap_or_default());
            let key = QNameDeserializer::from_attr(QName(&buf[name]), decoder)?;
            return seed.deserialize(key).map(Some);
        }

        // No attributes left: classify the next content event without consuming it.
        match self.de.peek()? {
            DeEvent::Text(_) | DeEvent::CData(_) => {
                let key = if self.has_value_field {
                    self.source = ValueSource::Content;
                    VALUE_KEY
                } else {
                    self.source = ValueSource::Text;
                    TEXT_KEY
                };
                seed.deserialize(key.into_deserializer()).map(Some)
            }
            DeEvent::Start(child) if self.has_value_field && not_in(self.fields, child, decoder)? => {
                self.source = ValueSource::Content;
                seed.deserialize(VALUE_KEY.into_deserializer()).map(Some)
            }
            DeEvent::Start(child) => {
                self.source = ValueSource::Nested;
                let key = QNameDeserializer::from_elem(child.name(), decoder)?;
                seed.deserialize(key).map(Some)
            }
            DeEvent::End(end) => {
                if end.name() == self.start.name() {
                    Ok(None)
                } else {
                    Err(DeError::UnexpectedEnd(end.name().as_ref().to_owned()))
                }
            }
            DeEvent::Eof => Err(DeError::UnexpectedEof),
        }
    }

    /// Deserializes the value for the key returned by the previous
    /// `next_key_seed` call, resetting `source` so it is consumed only once.
    ///
    /// # Errors
    /// `DeError::KeyNotRead` if no key is pending, plus any error from
    /// reading or decoding the value.
    fn next_value_seed<K: DeserializeSeed<'de>>(
        &mut self,
        seed: K,
    ) -> Result<K::Value, Self::Error> {
        let source = std::mem::replace(&mut self.source, ValueSource::Unknown);
        match source {
            ValueSource::Unknown => Err(DeError::KeyNotRead),
            ValueSource::Attribute(range) => {
                let decoder = self.de.reader.decoder();
                let value =
                    SimpleTypeDeserializer::from_part(&self.start.buf, range, true, decoder);
                seed.deserialize(value)
            }
            // `next_key_seed` picks `Text` only after peeking a `Text` or `CData` event.
            ValueSource::Text => match self.de.next()? {
                DeEvent::Text(text) => {
                    seed.deserialize(SimpleTypeDeserializer::from_text_content(text.decode(true)?))
                }
                DeEvent::CData(cdata) => {
                    seed.deserialize(SimpleTypeDeserializer::from_text_content(cdata.decode()?))
                }
                _ => unreachable!(),
            },
            ValueSource::Content => seed.deserialize(MapValueDeserializer {
                allow_start: false,
                map: self,
            }),
            ValueSource::Nested => seed.deserialize(MapValueDeserializer {
                allow_start: true,
                map: self,
            }),
        }
    }
}
// Generates a `Deserializer` method named `$deserialize` that forwards the
// call, with the same extra arguments and the visitor, to the method of the
// same name on `self.map.de`.
//
// Usage: `forward!(deserialize_unit);` or
// `forward!(deserialize_struct(name: &'static str, fields: &'static [&'static str]));`
// The optional parenthesised list declares the extra `name: Type` parameters
// placed before `visitor` in both the signature and the forwarded call.
// NOTE(review): expects to be expanded inside an `impl` whose `self` has a
// `map` field with a `de` deserializer, and where `'de` and `Self::Error` are
// in scope.
macro_rules! forward {
(
$deserialize:ident
$(
($($name:ident : $type:ty),*)
)?
) => {
#[inline]
fn $deserialize<V: Visitor<'de>>(
self,
$($($name: $type,)*)?
visitor: V
) -> Result<V::Value, Self::Error> {
self.map.de.$deserialize($($($name,)*)? visitor)
}
};
}
/// Deserializer for the value of a key produced by `MapAccess` when that value
/// is element content (`ValueSource::Content`) or a child element
/// (`ValueSource::Nested`).
struct MapValueDeserializer<'de, 'a, 'm, R: XmlRead<'de>> {
    /// Map accessor whose pending key this value belongs to.
    map: &'m mut MapAccess<'de, 'a, R>,
    /// `true` for `ValueSource::Nested`, `false` for `ValueSource::Content`.
    /// Passed to `read_string_impl`, and selects the list filter in
    /// `deserialize_seq`.
    allow_start: bool,
}
impl<'de, 'a, 'm, R: XmlRead<'de>> MapValueDeserializer<'de, 'a, 'm, R> {
    /// Reads a string value through the underlying deserializer's
    /// `read_string_impl`, passing along this value's `allow_start` flag.
    #[inline]
    fn read_string(&mut self, unescape: bool) -> Result<Cow<'de, str>, DeError> {
        let allow_start = self.allow_start;
        self.map.de.read_string_impl(unescape, allow_start)
    }
}
impl<'de, 'a, 'm, R> de::Deserializer<'de> for MapValueDeserializer<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    type Error = DeError;

    // `deserialize_primitives!` is defined elsewhere; presumably it routes
    // scalar types through `read_string` above — confirm in its definition.
    deserialize_primitives!(mut);

    // Everything else is handed straight to the underlying deserializer.
    forward!(deserialize_any);
    forward!(deserialize_ignored_any);
    forward!(deserialize_option);
    forward!(deserialize_unit);
    forward!(deserialize_unit_struct(name: &'static str));
    forward!(deserialize_newtype_struct(name: &'static str));
    forward!(deserialize_map);
    forward!(deserialize_struct(
        name: &'static str,
        fields: &'static [&'static str]
    ));
    forward!(deserialize_enum(
        name: &'static str,
        variants: &'static [&'static str]
    ));

    /// Tuples are read exactly like sequences; the length is not checked here.
    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, DeError>
    where
        V: Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Tuple structs are read exactly like sequences.
    fn deserialize_tuple_struct<V>(
        self,
        _name: &'static str,
        _len: usize,
        visitor: V,
    ) -> Result<V::Value, DeError>
    where
        V: Visitor<'de>,
    {
        self.deserialize_seq(visitor)
    }

    /// Reads a sequence made of elements of the current map.
    ///
    /// For a nested value, list items are elements named like the peeked
    /// start tag. For `VALUE_KEY` content, list items are elements whose
    /// names are not among the map's `fields`.
    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        let filter = if !self.allow_start {
            TagFilter::Exclude(self.map.fields)
        } else if let DeEvent::Start(first) = self.map.de.peek()? {
            TagFilter::Include(first.clone())
        } else {
            // `allow_start` is set only for `ValueSource::Nested`, which was
            // chosen after peeking a `Start` event.
            unreachable!()
        };

        // Taken after the peek above, matching the original evaluation order.
        #[cfg(feature = "overlapped-lists")]
        let checkpoint = self.map.de.skip_checkpoint();

        visitor.visit_seq(MapValueSeqAccess {
            map: self.map,
            filter,
            #[cfg(feature = "overlapped-lists")]
            checkpoint,
        })
    }

    #[inline]
    fn is_human_readable(&self) -> bool {
        self.map.de.is_human_readable()
    }
}
/// Returns `true` when the decoded name of `start` matches none of `fields`.
///
/// The name is taken from `start.name()` and decoded with `decoder` before
/// comparison.
///
/// # Errors
/// Propagates any error from decoding the tag name.
fn not_in(
    fields: &'static [&'static str],
    start: &BytesStart,
    decoder: Decoder,
) -> Result<bool, DeError> {
    let decoded = decoder.decode(start.name().into_inner())?;
    let name: &str = &decoded;
    let listed = fields.iter().any(|&field| field == name);
    Ok(!listed)
}
/// Decides which child elements belong to a list being read from a map.
#[derive(Debug)]
enum TagFilter<'de> {
    /// Only elements with the same name as this start tag are list items.
    Include(BytesStart<'de>),
    /// Elements whose names are not in this field list are list items.
    Exclude(&'static [&'static str]),
}
impl<'de> TagFilter<'de> {
    /// Reports whether the element opened by `start` passes this filter.
    ///
    /// # Errors
    /// For `Exclude`, propagates name-decoding errors from `not_in`.
    fn is_suitable(&self, start: &BytesStart, decoder: Decoder) -> Result<bool, DeError> {
        match *self {
            Self::Exclude(fields) => not_in(fields, start, decoder),
            Self::Include(ref expected) => Ok(expected.name() == start.name()),
        }
    }
}
/// Serde sequence accessor yielding list items found inside a map's element.
struct MapValueSeqAccess<'de, 'a, 'm, R: XmlRead<'de>> {
    /// Map accessor whose element content supplies the items.
    map: &'m mut MapAccess<'de, 'a, R>,
    /// Selects which child elements are list items.
    filter: TagFilter<'de>,
    /// Value of `skip_checkpoint()` taken when this accessor was created;
    /// handed back to `start_replay` on drop.
    #[cfg(feature = "overlapped-lists")]
    checkpoint: usize,
}
#[cfg(feature = "overlapped-lists")]
impl<'de, 'a, 'm, R> Drop for MapValueSeqAccess<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    /// Passes the stored checkpoint to the deserializer's `start_replay`.
    /// NOTE(review): presumably this re-exposes events skipped by
    /// `next_element_seed` to the enclosing map — confirm in `Deserializer`.
    fn drop(&mut self) {
        let checkpoint = self.checkpoint;
        self.map.de.start_replay(checkpoint);
    }
}
impl<'de, 'a, 'm, R> SeqAccess<'de> for MapValueSeqAccess<'de, 'a, 'm, R>
where
    R: XmlRead<'de>,
{
    type Error = DeError;

    /// Returns the next list item, or `None` when the list ends.
    ///
    /// The list ends at the closing tag of the map's element. Without
    /// `overlapped-lists` it also ends at the first element rejected by the
    /// filter. With that feature, such elements are skipped instead.
    ///
    /// # Errors
    /// `DeError::UnexpectedEnd` for a foreign closing tag and
    /// `DeError::UnexpectedEof` at end of input, plus reader errors.
    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, DeError>
    where
        T: DeserializeSeed<'de>,
    {
        let decoder = self.map.de.reader.decoder();
        loop {
            return match self.map.de.peek()? {
                #[cfg(feature = "overlapped-lists")]
                DeEvent::Start(e) if !self.filter.is_suitable(e, decoder)? => {
                    self.map.de.skip()?;
                    continue;
                }
                #[cfg(not(feature = "overlapped-lists"))]
                DeEvent::Start(e) if !self.filter.is_suitable(e, decoder)? => Ok(None),
                DeEvent::End(end) => {
                    if end.name() == self.map.start.name() {
                        Ok(None)
                    } else {
                        Err(DeError::UnexpectedEnd(end.name().as_ref().to_owned()))
                    }
                }
                DeEvent::Eof => Err(DeError::UnexpectedEof),
                // Text, CData or an accepted start tag: one list item.
                _ => {
                    let item = SeqItemDeserializer { map: self.map };
                    seed.deserialize(item).map(Some)
                }
            };
        }
    }
}
/// Deserializer for a single list item yielded by `MapValueSeqAccess`.
///
/// It is created only when the next event is `Text`, `CData` or a start tag
/// accepted by the list filter.
struct SeqItemDeserializer<'de, 'a, 'm, R: XmlRead<'de>> {
    /// Map accessor whose element content supplies the item.
    map: &'m mut MapAccess<'de, 'a, R>,
}
impl<'de, 'a, 'm, R: XmlRead<'de>> SeqItemDeserializer<'de, 'a, 'm, R> {
    /// Reads a string value through `read_string_impl`; list items always
    /// permit a start tag, so `allow_start` is fixed to `true`.
    #[inline]
    fn read_string(&mut self, unescape: bool) -> Result<Cow<'de, str>, DeError> {
        const ALLOW_START: bool = true;
        self.map.de.read_string_impl(unescape, ALLOW_START)
    }
}
impl<'de, 'a, 'm, R> de::Deserializer<'de> for SeqItemDeserializer<'de, 'a, 'm, R>
where
R: XmlRead<'de>,
{
type Error = DeError;
deserialize_primitives!(mut);
forward!(deserialize_option);
forward!(deserialize_unit);
forward!(deserialize_unit_struct(name: &'static str));
forward!(deserialize_newtype_struct(name: &'static str));
forward!(deserialize_map);
forward!(deserialize_struct(
name: &'static str,
fields: &'static [&'static str]
));
forward!(deserialize_enum(
name: &'static str,
variants: &'static [&'static str]
));
forward!(deserialize_any);
forward!(deserialize_ignored_any);
fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, DeError>
where
V: Visitor<'de>,
{
self.deserialize_seq(visitor)
}
fn deserialize_tuple_struct<V>(
self,
_name: &'static str,
len: usize,
visitor: V,
) -> Result<V::Value, DeError>
where
V: Visitor<'de>,
{
self.deserialize_tuple(len, visitor)
}
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
match self.map.de.next()? {
DeEvent::Text(e) => SimpleTypeDeserializer::from_text_content(
e.decode(true)?,
)
.deserialize_seq(visitor),
DeEvent::CData(e) => SimpleTypeDeserializer::from_text_content(
e.decode()?,
)
.deserialize_seq(visitor),
DeEvent::Start(e) => {
let value = match self.map.de.next()? {
DeEvent::Text(e) => SimpleTypeDeserializer::from_text_content(
e.decode(true)?,
)
.deserialize_seq(visitor),
DeEvent::CData(e) => SimpleTypeDeserializer::from_text_content(
e.decode()?,
)
.deserialize_seq(visitor),
e => Err(DeError::Unsupported(
format!("unsupported event {:?}", e).into(),
)),
};
self.map.de.read_to_end(e.name())?;
value
}
_ => unreachable!(),
}
}
#[inline]
fn is_human_readable(&self) -> bool {
self.map.de.is_human_readable()
}
}
/// `not_in` is `true` exactly when the tag name is absent from `fields`.
#[test]
fn test_not_in() {
    let tag = BytesStart::new("tag");

    // An empty field list contains nothing.
    assert!(not_in(&[], &tag, Decoder::utf8()).unwrap());
    // The tag name is not among the fields.
    assert!(not_in(&["no", "such", "tags"], &tag, Decoder::utf8()).unwrap());
    // The tag name is one of the fields.
    assert!(!not_in(&["some", "tag", "included"], &tag, Decoder::utf8()).unwrap());
}