use crate::{
de::key::QNameDeserializer,
de::resolver::EntityResolver,
de::simple_type::SimpleTypeDeserializer,
de::text::TextDeserializer,
de::{DeEvent, Deserializer, XmlRead, TEXT_KEY, VALUE_KEY},
errors::serialize::DeError,
errors::Error,
events::attributes::IterState,
events::BytesStart,
name::QName,
};
use serde::de::value::BorrowedStrDeserializer;
use serde::de::{self, DeserializeSeed, Deserializer as _, MapAccess, SeqAccess, Visitor};
use std::borrow::Cow;
use std::ops::Range;
/// Where the value belonging to the key most recently returned by
/// `ElementMapAccess::next_key_seed` has to be read from.
#[derive(PartialEq, Debug)]
enum ValueSource {
    /// No key is pending: either no key was read yet, or the value of the last
    /// key was already consumed by `next_value_seed`.
    Unknown,
    /// The value is an attribute value; the range indexes into the raw buffer
    /// of the element's start tag.
    Attribute(Range<usize>),
    /// The value is the text event that is next in the stream, delivered under
    /// the `$text` key.
    Text,
    /// The value is delivered under the `$value` key and is read from the
    /// element content (a text event, or a child element whose name is not one
    /// of the struct fields).
    Content,
    /// The value is the child element that is next in the stream; its name was
    /// used as the key.
    Nested,
}
/// A `MapAccess` over a single XML element.
///
/// The attributes of the start tag are yielded first (with keys prefixed by
/// `@`), then keys are derived from the element content until the closing tag
/// of `start` is reached.
pub(crate) struct ElementMapAccess<'de, 'd, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    /// Deserializer that supplies the events of the element content.
    de: &'d mut Deserializer<'de, R, E>,
    /// Start tag of the element that is deserialized as a map.
    start: BytesStart<'de>,
    /// Iteration state over the attributes stored in `start`.
    iter: IterState,
    /// Where the value of the last returned key has to be read from.
    source: ValueSource,
    /// Names of the fields of the target struct.
    fields: &'static [&'static str],
    /// Whether `fields` contains the special `$text` key.
    has_text_field: bool,
    /// Whether `fields` contains the special `$value` key.
    has_value_field: bool,
}
impl<'de, 'd, R, E> ElementMapAccess<'de, 'd, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    /// Creates a map access over the element opened by `start`.
    ///
    /// `fields` lists the names of the target struct fields. It is used to
    /// decide whether text and child elements are routed to the special
    /// `$value` / `$text` keys.
    pub fn new(
        de: &'d mut Deserializer<'de, R, E>,
        start: BytesStart<'de>,
        fields: &'static [&'static str],
    ) -> Self {
        // The attribute iterator is seeded with the tag-name length —
        // presumably the offset in `start.buf` where attributes begin.
        let name_len = start.name().as_ref().len();
        let has_value_field = fields.contains(&VALUE_KEY);
        let has_text_field = fields.contains(&TEXT_KEY);
        Self {
            de,
            start,
            iter: IterState::new(name_len, false),
            source: ValueSource::Unknown,
            fields,
            has_value_field,
            has_text_field,
        }
    }

    /// Returns `true` if the subtree opened by `start` must be skipped (and
    /// reported as `None`) because the reader's `has_nil_attr` holds for the
    /// start tag of this map or for `start` itself.
    fn should_skip_subtree(&self, start: &BytesStart) -> bool {
        let reader = &self.de.reader.reader;
        reader.has_nil_attr(&self.start) || reader.has_nil_attr(start)
    }

    /// Forwards to the underlying deserializer's `skip_whitespaces`.
    #[inline]
    fn skip_whitespaces(&mut self) -> Result<(), DeError> {
        let de = &mut *self.de;
        de.skip_whitespaces()
    }
}
/// `MapAccess` over an element: the attributes of the start tag are returned
/// first, then keys derived from the element content, until the end tag.
impl<'de, 'd, R, E> MapAccess<'de> for ElementMapAccess<'de, 'd, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
type Error = DeError;
/// Returns the next key and records in `self.source` where its value lives.
///
/// Attributes come first; their keys are built in `key_buf` with a `@`
/// prefix. After the attributes, whitespace is skipped and the next event is
/// peeked (not consumed):
/// - text: reported as `$value` when the struct has a `$value` field but no
///   `$text` field, otherwise as `$text`;
/// - a start tag: reported as `$value` when the struct has a `$value` field
///   and the tag's local name is not one of the fields, otherwise the tag
///   name itself is the key;
/// - the end tag: consumed, and the map ends (`Ok(None)`);
/// - end of input: a "missed end" error for this element.
fn next_key_seed<K: DeserializeSeed<'de>>(
&mut self,
seed: K,
) -> Result<Option<K::Value>, Self::Error> {
// A new key must not be requested while the value of the previous key is
// still pending (`next_value_seed` resets the source to `Unknown`)
debug_assert_eq!(self.source, ValueSource::Unknown);
let slice = &self.start.buf;
let decoder = self.start.decoder();
if let Some(a) = self.iter.next(slice).transpose()? {
let (key, value) = a.into();
// When no value range is given (presumably an attribute without a value),
// the default empty range is stored
self.source = ValueSource::Attribute(value.unwrap_or_default());
// Attribute keys get a `@` prefix; element keys below are not prefixed
self.de.key_buf.clear();
self.de.key_buf.push('@');
let de =
QNameDeserializer::from_attr(QName(&slice[key]), decoder, &mut self.de.key_buf)?;
seed.deserialize(de).map(Some)
} else {
self.skip_whitespaces()?;
match self.de.peek()? {
// Text goes to `$value` only if there is no dedicated `$text` field
DeEvent::Text(_) if self.has_value_field && !self.has_text_field => {
self.source = ValueSource::Content;
let de = BorrowedStrDeserializer::<DeError>::new(VALUE_KEY);
seed.deserialize(de).map(Some)
}
DeEvent::Text(_) => {
self.source = ValueSource::Text;
let de = BorrowedStrDeserializer::<DeError>::new(TEXT_KEY);
seed.deserialize(de).map(Some)
}
// A child tag that is not a struct field is routed to `$value`
DeEvent::Start(e) if self.has_value_field && not_in(self.fields, e)? => {
self.source = ValueSource::Content;
let de = BorrowedStrDeserializer::<DeError>::new(VALUE_KEY);
seed.deserialize(de).map(Some)
}
// Otherwise the child tag name itself is the key
DeEvent::Start(e) => {
self.source = ValueSource::Nested;
let de = QNameDeserializer::from_elem(e)?;
seed.deserialize(de).map(Some)
}
DeEvent::End(e) => {
// Checked only in debug builds: the end tag must close `self.start`
debug_assert_eq!(self.start.name(), e.name());
// Consume the end tag of this element; the map is finished
self.de.next()?;
Ok(None)
}
// Input ended before the end tag of this element
DeEvent::Eof => {
Err(Error::missed_end(self.start.name(), self.start.decoder()).into())
}
}
}
}
/// Deserializes the value for the key returned by the last `next_key_seed`
/// call, according to `self.source`, and resets the source to `Unknown`.
///
/// Returns `DeError::KeyNotRead` if no key is pending.
fn next_value_seed<K: DeserializeSeed<'de>>(
&mut self,
seed: K,
) -> Result<K::Value, Self::Error> {
match std::mem::replace(&mut self.source, ValueSource::Unknown) {
// The attribute value is taken from the raw buffer of the start tag
ValueSource::Attribute(value) => seed.deserialize(SimpleTypeDeserializer::from_part(
&self.start.buf,
value,
self.start.decoder(),
)),
// `Text` is recorded only after a `Text` event was peeked
ValueSource::Text => match self.de.next()? {
DeEvent::Text(e) => seed.deserialize(SimpleTypeDeserializer::from_text_content(e)),
_ => unreachable!(),
},
// `$value` content: the next event is not bound to a known tag name
ValueSource::Content => seed.deserialize(MapValueDeserializer {
map: self,
fixed_name: false,
}),
// Nested element: the peeked start tag named the key
ValueSource::Nested => seed.deserialize(MapValueDeserializer {
map: self,
fixed_name: true,
}),
ValueSource::Unknown => Err(DeError::KeyNotRead),
}
}
}
/// Deserializer for a value located in the content of an `ElementMapAccess`.
///
/// `fixed_name` is `true` when the value is a single child element whose start
/// tag was already peeked and named (e.g. a `ValueSource::Nested` value). It is
/// `false` for `$value` content.
struct MapValueDeserializer<'de, 'd, 'm, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    /// `true` for a named child element, `false` for `$value` content.
    fixed_name: bool,
    /// The map whose content holds the value.
    map: &'m mut ElementMapAccess<'de, 'd, R, E>,
}
impl<'de, 'd, 'm, R, E> MapValueDeserializer<'de, 'd, 'm, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    /// Reads a string value through the underlying deserializer's
    /// `read_string_impl`, passing `fixed_name` along.
    #[inline]
    fn read_string(&mut self) -> Result<Cow<'de, str>, DeError> {
        let fixed_name = self.fixed_name;
        let de = &mut *self.map.de;
        de.read_string_impl(fixed_name)
    }
}
/// Deserializes a value located in the content of an `ElementMapAccess`.
/// The primitive methods are generated by `deserialize_primitives!`
/// (presumably reading the value through `read_string` — confirm in the macro).
impl<'de, 'd, 'm, R, E> de::Deserializer<'de> for MapValueDeserializer<'de, 'd, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
type Error = DeError;
deserialize_primitives!(mut);
/// Delegates to the underlying deserializer.
#[inline]
fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.map.de.deserialize_unit(visitor)
}
/// Returns `None` for an empty text event, or for a start tag for which
/// `should_skip_subtree` holds (that subtree is then skipped); otherwise the
/// value is deserialized as `Some`.
fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
// The result of `peek()` is discarded and the event is re-read through
// `last_peeked()`. NOTE(review): presumably to avoid holding the borrow
// returned by `peek()` across the match — confirm.
let _ = self.map.de.peek()?;
match self.map.de.last_peeked() {
DeEvent::Text(t) if t.is_empty() => visitor.visit_none(),
DeEvent::Start(start) if self.map.should_skip_subtree(start) => {
self.map.de.skip_next_tree()?;
visitor.visit_none()
}
_ => visitor.visit_some(self),
}
}
/// The wrapped value is deserialized from the same content.
fn deserialize_newtype_struct<V>(
self,
_name: &'static str,
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
visitor.visit_newtype_struct(self)
}
/// For a named child (`fixed_name`), the sequence items are the tags whose
/// name equals the peeked start tag. Otherwise the items are the content that
/// is not claimed by other struct fields (see `TagFilter::Exclude`).
fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
let filter = if self.fixed_name {
match self.map.de.peek()? {
DeEvent::Start(e) => TagFilter::Include(e.clone()),
// `fixed_name == true` is only used after a `Start` event was peeked
_ => unreachable!(),
}
} else {
TagFilter::Exclude(self.map.fields, self.map.has_text_field)
};
visitor.visit_seq(MapValueSeqAccess {
#[cfg(feature = "overlapped-lists")]
checkpoint: self.map.de.skip_checkpoint(),
map: self.map,
filter,
})
}
/// Delegates to the underlying deserializer.
#[inline]
fn deserialize_struct<V>(
self,
name: &'static str,
fields: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
self.map.de.deserialize_struct(name, fields, visitor)
}
/// With `fixed_name`, the child tag is consumed and its text content is the
/// variant name; empty content selects the special `$text` variant.
/// Otherwise the variant is determined through `EnumAccess` on the next
/// event.
fn deserialize_enum<V>(
self,
_name: &'static str,
_variants: &'static [&'static str],
visitor: V,
) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
if self.fixed_name {
match self.map.de.next()? {
DeEvent::Start(e) => {
let text = self.map.de.read_text(e.name())?;
if text.is_empty() {
visitor.visit_enum(SimpleTypeDeserializer::from_text(TEXT_KEY.into()))
} else {
visitor.visit_enum(SimpleTypeDeserializer::from_text(text))
}
}
// `fixed_name == true` is only used after a `Start` event was peeked
_ => unreachable!(),
}
} else {
visitor.visit_enum(self)
}
}
/// Text is deserialized as a string; anything else as a map.
fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
where
V: Visitor<'de>,
{
match self.map.de.peek()? {
DeEvent::Text(_) => self.deserialize_str(visitor),
_ => self.deserialize_map(visitor),
}
}
}
impl<'de, 'd, 'm, R, E> de::EnumAccess<'de> for MapValueDeserializer<'de, 'd, 'm, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    type Error = DeError;
    type Variant = MapValueVariantAccess<'de, 'd, 'm, R, E>;

    /// Selects the variant from the next event without consuming it.
    ///
    /// A text event selects the special `$text` variant, and a start tag
    /// selects the variant named by the tag.
    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
    where
        V: DeserializeSeed<'de>,
    {
        let (name, is_text) = match self.map.de.peek()? {
            DeEvent::Text(_) => {
                let key = BorrowedStrDeserializer::<DeError>::new(TEXT_KEY);
                (seed.deserialize(key)?, true)
            }
            DeEvent::Start(start) => {
                let key = QNameDeserializer::from_elem(start)?;
                (seed.deserialize(key)?, false)
            }
            // Callers are expected to reach here only with `Start` or `Text` next
            _ => unreachable!(),
        };
        let variant = MapValueVariantAccess {
            map: self.map,
            is_text,
        };
        Ok((name, variant))
    }
}
/// `VariantAccess` for an enum variant found in the content of an
/// `ElementMapAccess`.
struct MapValueVariantAccess<'de, 'd, 'm, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    /// `true` when the variant was selected by a text event (the `$text`
    /// variant) rather than by a start tag.
    is_text: bool,
    /// The map whose content holds the variant.
    map: &'m mut ElementMapAccess<'de, 'd, R, E>,
}
impl<'de, 'd, 'm, R, E> de::VariantAccess<'de> for MapValueVariantAccess<'de, 'd, 'm, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    type Error = DeError;

    /// Consumes a unit variant.
    ///
    /// A text event is consumed and ignored. For a tag, `read_to_end` is called
    /// with its name.
    fn unit_variant(self) -> Result<(), Self::Error> {
        let event = self.map.de.next()?;
        match event {
            DeEvent::Text(_) => Ok(()),
            DeEvent::Start(start) => self.map.de.read_to_end(start.name()),
            _ => unreachable!("Only `Start` or `Text` events are possible here"),
        }
    }

    /// Deserializes the payload of a newtype variant.
    ///
    /// The payload is either the text event or the content of the variant tag.
    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        if !self.is_text {
            // The variant is a tag: deserialize it as a named child element
            return seed.deserialize(MapValueDeserializer {
                map: self.map,
                fixed_name: true,
            });
        }
        match self.map.de.next()? {
            DeEvent::Text(text) => {
                seed.deserialize(SimpleTypeDeserializer::from_text_content(text))
            }
            _ => unreachable!("Only `Text` events are possible here"),
        }
    }

    /// Deserializes the payload of a tuple variant.
    ///
    /// The payload is either the text event or the variant tag.
    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        if !self.is_text {
            let content = MapValueDeserializer {
                map: self.map,
                fixed_name: true,
            };
            return content.deserialize_tuple(len, visitor);
        }
        match self.map.de.next()? {
            DeEvent::Text(text) => {
                SimpleTypeDeserializer::from_text_content(text).deserialize_tuple(len, visitor)
            }
            _ => unreachable!("Only `Text` events are possible here"),
        }
    }

    /// Deserializes a struct variant.
    ///
    /// A tag is read as a nested element map, and text goes through the simple
    /// type deserializer.
    fn struct_variant<V>(
        self,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        let event = self.map.de.next()?;
        match event {
            DeEvent::Text(text) => SimpleTypeDeserializer::from_text_content(text)
                .deserialize_struct("", fields, visitor),
            DeEvent::Start(start) => {
                let nested = ElementMapAccess::new(self.map.de, start, fields);
                visitor.visit_map(nested)
            }
            _ => unreachable!("Only `Start` or `Text` events are possible here"),
        }
    }
}
/// Returns `true` if the local name of `start` matches none of `fields`.
///
/// The name is decoded with the tag's own decoder, and the namespace prefix is
/// ignored.
///
/// # Errors
/// Returns an error if the tag name cannot be decoded.
fn not_in(fields: &'static [&'static str], start: &BytesStart) -> Result<bool, DeError> {
    let local = start.local_name();
    let tag = start.decoder().decode(local.into_inner())?;
    let tag: &str = &tag;
    Ok(!fields.iter().any(|field| *field == tag))
}
/// Decides which child tags a `MapValueSeqAccess` treats as sequence items.
#[derive(Debug)]
enum TagFilter<'de> {
    /// Only tags with the same qualified name as this start tag are accepted.
    Include(BytesStart<'de>),
    /// Tags whose local name is not among the listed struct field names are
    /// accepted. The flag records whether the struct has a `$text` field.
    Exclude(&'static [&'static str], bool),
}
impl<'de> TagFilter<'de> {
    /// Checks whether the tag `start` belongs to the sequence.
    fn is_suitable(&self, start: &BytesStart) -> Result<bool, DeError> {
        match *self {
            Self::Exclude(fields, _) => not_in(fields, start),
            Self::Include(ref expected) => Ok(expected.name() == start.name()),
        }
    }

    /// Checks whether text events are excluded from the sequence.
    ///
    /// This is always the case when filtering by a concrete tag name. When
    /// excluding field names, it is the case only if the struct has a `$text`
    /// field. Excluded text is skipped or ends the sequence, depending on the
    /// `overlapped-lists` feature.
    const fn need_skip_text(&self) -> bool {
        if let Self::Exclude(_, has_text_field) = self {
            *has_text_field
        } else {
            true
        }
    }
}
/// `SeqAccess` over sequence items stored in the content of an
/// `ElementMapAccess`, selected by a `TagFilter`.
struct MapValueSeqAccess<'de, 'd, 'm, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    /// Decides which tags (and whether text) belong to the sequence.
    filter: TagFilter<'de>,
    /// The map whose content supplies the sequence items.
    map: &'m mut ElementMapAccess<'de, 'd, R, E>,
    /// Value of `skip_checkpoint()` taken when the sequence was created;
    /// handed to `start_replay` on drop.
    #[cfg(feature = "overlapped-lists")]
    checkpoint: usize,
}
/// On drop, calls `start_replay` with the checkpoint taken when the sequence
/// was created. NOTE(review): presumably this lets events skipped while
/// searching for sequence items be read again — confirm in `Deserializer`.
#[cfg(feature = "overlapped-lists")]
impl<'de, 'd, 'm, R, E> Drop for MapValueSeqAccess<'de, 'd, 'm, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    fn drop(&mut self) {
        let checkpoint = self.checkpoint;
        self.map.de.start_replay(checkpoint);
    }
}
/// Yields sequence items from the content of the map's element.
impl<'de, 'd, 'm, R, E> SeqAccess<'de> for MapValueSeqAccess<'de, 'd, 'm, R, E>
where
R: XmlRead<'de>,
E: EntityResolver,
{
type Error = DeError;
/// Returns the next sequence item, or `None` when the sequence ends.
///
/// Whitespace is skipped before each item. A start tag rejected by the
/// filter, or a text event the filter excludes, is skipped when the
/// `overlapped-lists` feature is enabled and ends the sequence otherwise.
/// The end tag ends the sequence without being consumed; end of input is a
/// "missed end" error. Accepted text is deserialized with `TextDeserializer`,
/// accepted tags with `ElementDeserializer`.
fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, DeError>
where
T: DeserializeSeed<'de>,
{
// `continue` is only taken by the `overlapped-lists` skipping arms; every
// other arm produces the result and leaves the loop through `break`
loop {
self.map.skip_whitespaces()?;
break match self.map.de.peek()? {
#[cfg(feature = "overlapped-lists")]
DeEvent::Start(e) if !self.filter.is_suitable(e)? => {
self.map.de.skip()?;
continue;
}
#[cfg(feature = "overlapped-lists")]
DeEvent::Text(_) if self.filter.need_skip_text() => {
self.map.de.skip()?;
continue;
}
#[cfg(not(feature = "overlapped-lists"))]
DeEvent::Start(e) if !self.filter.is_suitable(e)? => Ok(None),
#[cfg(not(feature = "overlapped-lists"))]
DeEvent::Text(_) if self.filter.need_skip_text() => Ok(None),
// The end tag is left in place for the owning map to consume
DeEvent::End(e) => {
debug_assert_eq!(self.map.start.name(), e.name());
Ok(None)
}
DeEvent::Eof => {
Err(Error::missed_end(self.map.start.name(), self.map.start.decoder()).into())
}
// The peeked event is a `Text`, so `next()` returns that same event
DeEvent::Text(_) => match self.map.de.next()? {
DeEvent::Text(e) => seed.deserialize(TextDeserializer(e)).map(Some),
_ => unreachable!(),
},
// The peeked event is a `Start`, so `next()` returns that same event
DeEvent::Start(_) => match self.map.de.next()? {
DeEvent::Start(start) => seed
.deserialize(ElementDeserializer {
start,
de: self.map.de,
})
.map(Some),
_ => unreachable!(),
},
};
}
}
}
/// Deserializer for a single element whose start tag has already been taken
/// from the event stream.
struct ElementDeserializer<'de, 'd, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    /// Deserializer that supplies the events of the element content.
    de: &'d mut Deserializer<'de, R, E>,
    /// Start tag of the element.
    start: BytesStart<'de>,
}
impl<'de, 'd, R, E> ElementDeserializer<'de, 'd, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    /// Reads the text content of the element through the deserializer's
    /// `read_text`, keyed by the name of the start tag.
    #[inline]
    fn read_string(&mut self) -> Result<Cow<'de, str>, DeError> {
        let name = self.start.name();
        self.de.read_text(name)
    }
}
impl<'de, 'd, R, E> de::Deserializer<'de> for ElementDeserializer<'de, 'd, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    type Error = DeError;

    deserialize_primitives!(mut);

    /// A unit ignores the element content: `read_to_end` is called with the
    /// element name before visiting the unit.
    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        let name = self.start.name();
        self.de.read_to_end(name)?;
        visitor.visit_unit()
    }

    /// An element that is present is always `Some`.
    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_some(self)
    }

    /// The wrapped value is deserialized from the same element.
    fn deserialize_newtype_struct<V>(
        self,
        _name: &'static str,
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_newtype_struct(self)
    }

    /// The text content of the element is handed to `SimpleTypeDeserializer`
    /// as a list.
    fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        SimpleTypeDeserializer::from_text(self.read_string()?).deserialize_seq(visitor)
    }

    /// The element is read as a map of its attributes and content.
    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        let access = ElementMapAccess::new(self.de, self.start, fields);
        visitor.visit_map(access)
    }

    /// The element name selects the variant (see the `EnumAccess` impl).
    fn deserialize_enum<V>(
        self,
        _name: &'static str,
        _variants: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        visitor.visit_enum(self)
    }

    /// Without type hints, an element is read as a map.
    #[inline]
    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        self.deserialize_map(visitor)
    }
}
impl<'de, 'd, R, E> de::EnumAccess<'de> for ElementDeserializer<'de, 'd, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    type Error = DeError;
    type Variant = Self;

    /// The element name selects the variant, and the element itself then
    /// provides the variant content.
    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
    where
        V: DeserializeSeed<'de>,
    {
        let tag_name = QNameDeserializer::from_elem(&self.start)?;
        let variant = seed.deserialize(tag_name)?;
        Ok((variant, self))
    }
}
impl<'de, 'd, R, E> de::VariantAccess<'de> for ElementDeserializer<'de, 'd, R, E>
where
    R: XmlRead<'de>,
    E: EntityResolver,
{
    type Error = DeError;

    /// A unit variant ignores the element content: `read_to_end` is called
    /// with the element name.
    fn unit_variant(self) -> Result<(), Self::Error> {
        let name = self.start.name();
        self.de.read_to_end(name)
    }

    /// The element itself is the payload of a newtype variant.
    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
    where
        T: DeserializeSeed<'de>,
    {
        seed.deserialize(self)
    }

    /// The element is deserialized as a tuple.
    #[inline]
    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        de::Deserializer::deserialize_tuple(self, len, visitor)
    }

    /// The element is deserialized as an anonymous struct with `fields`.
    #[inline]
    fn struct_variant<V>(
        self,
        fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value, Self::Error>
    where
        V: Visitor<'de>,
    {
        de::Deserializer::deserialize_struct(self, "", fields, visitor)
    }
}
/// `not_in` compares only the local name of a tag against the field list;
/// a namespace prefix on the tag is ignored, and a field spelled with the
/// prefix does not match.
#[test]
fn test_not_in() {
    let tag = BytesStart::new("tag");
    assert!(not_in(&[], &tag).unwrap());
    assert!(not_in(&["no", "such", "tags"], &tag).unwrap());
    assert!(!not_in(&["some", "tag", "included"], &tag).unwrap());

    // Only the local name `tag` takes part in the comparison
    let tag_ns = BytesStart::new("ns1:tag");
    assert!(not_in(&[], &tag_ns).unwrap());
    assert!(not_in(&["no", "such", "tags"], &tag_ns).unwrap());
    assert!(!not_in(&["some", "tag", "included"], &tag_ns).unwrap());
    assert!(not_in(&["some", "namespace", "ns1:tag"], &tag_ns).unwrap());
}