Skip to main content

robinson/serde/
mod.rs

1//! ```
2//! use serde::Deserialize;
3//! use robinson::serde::from_str;
4//!
5//! #[derive(Deserialize)]
6//! struct Record {
7//!     field: String,
8//! }
9//!
10//! let record = from_str::<Record>("<record><field>foobar</field></record>").unwrap();
11//!
12//! assert_eq!(record.field, "foobar");
13//! ```
14#[cfg(feature = "raw-node")]
15mod raw_node;
16
17use std::borrow::Cow;
18use std::char::ParseCharError;
19use std::error::Error as StdError;
20use std::fmt;
21use std::iter::{Peekable, once};
22use std::marker::PhantomData;
23use std::num::{NonZeroUsize, ParseFloatError, ParseIntError};
24use std::str::{FromStr, ParseBoolError};
25
26use serde_core::de;
27
28use crate::{Attribute, Document, Error as XmlError, Node, strings::cmp_opt_names};
29
30#[cfg(feature = "raw-node")]
31pub use raw_node::RawNode;
32
33pub fn from_str<T>(text: &str) -> Result<T, Box<Error>>
34where
35    T: de::DeserializeOwned,
36{
37    defaults().from_str(text)
38}
39
40pub fn from_doc<'de, 'input, T>(document: &'de Document<'input>) -> Result<T, Box<Error>>
41where
42    T: de::Deserialize<'de>,
43{
44    defaults().from_doc(document)
45}
46
47pub fn from_node<'de, 'input, T>(node: Node<'de, 'input>) -> Result<T, Box<Error>>
48where
49    T: de::Deserialize<'de>,
50{
51    defaults().from_node(node)
52}
53
54pub trait Options: Sized {
55    #[allow(clippy::wrong_self_convention)]
56    fn from_str<T>(self, text: &str) -> Result<T, Box<Error>>
57    where
58        T: de::DeserializeOwned,
59    {
60        let document = Document::parse(text).map_err(Error::ParseXml)?;
61        self.from_doc(&document)
62    }
63
64    #[allow(clippy::wrong_self_convention)]
65    fn from_doc<'de, 'input, T>(self, document: &'de Document<'input>) -> Result<T, Box<Error>>
66    where
67        T: de::Deserialize<'de>,
68    {
69        let node = document.root_element();
70        self.from_node(node)
71    }
72
73    #[allow(clippy::wrong_self_convention)]
74    fn from_node<'de, 'input, T>(self, node: Node<'de, 'input>) -> Result<T, Box<Error>>
75    where
76        T: de::Deserialize<'de>,
77    {
78        let len = node.document().len();
79
80        let deserializer = Deserializer {
81            source: Source::Node(node),
82            temp: &mut Temp::new(len),
83            options: PhantomData::<Self>,
84        };
85
86        T::deserialize(deserializer)
87    }
88
89    fn namespaces(self) -> Namespaces<Self> {
90        Namespaces(PhantomData)
91    }
92
93    fn prefix_attr(self) -> PrefixAttr<Self> {
94        PrefixAttr(PhantomData)
95    }
96
97    fn only_children(self) -> OnlyChildren<Self> {
98        OnlyChildren(PhantomData)
99    }
100
101    #[doc(hidden)]
102    const NAMESPACES: bool = false;
103
104    #[doc(hidden)]
105    const PREFIX_ATTR: bool = false;
106
107    #[doc(hidden)]
108    const ONLY_CHILDREN: bool = false;
109}
110
111#[doc(hidden)]
112#[derive(Clone, Copy, Default, Debug)]
113pub struct Defaults;
114
115pub fn defaults() -> Defaults {
116    Defaults
117}
118
119impl Options for Defaults {}
120
121#[doc(hidden)]
122#[derive(Clone, Copy, Default, Debug)]
123pub struct Namespaces<O>(PhantomData<O>);
124
125impl<O> Options for Namespaces<O>
126where
127    O: Options,
128{
129    const NAMESPACES: bool = true;
130    const PREFIX_ATTR: bool = O::PREFIX_ATTR;
131    const ONLY_CHILDREN: bool = O::ONLY_CHILDREN;
132}
133
134#[doc(hidden)]
135#[derive(Clone, Copy, Default, Debug)]
136pub struct PrefixAttr<O>(PhantomData<O>);
137
138impl<O> Options for PrefixAttr<O>
139where
140    O: Options,
141{
142    const NAMESPACES: bool = O::NAMESPACES;
143    const PREFIX_ATTR: bool = true;
144    const ONLY_CHILDREN: bool = O::ONLY_CHILDREN;
145}
146
147#[doc(hidden)]
148#[derive(Clone, Copy, Default, Debug)]
149pub struct OnlyChildren<O>(PhantomData<O>);
150
151impl<O> Options for OnlyChildren<O>
152where
153    O: Options,
154{
155    const NAMESPACES: bool = O::NAMESPACES;
156    const PREFIX_ATTR: bool = O::PREFIX_ATTR;
157    const ONLY_CHILDREN: bool = true;
158}
159
160struct Deserializer<'doc, 'input, 'temp, O> {
161    source: Source<'doc, 'input>,
162    temp: &'temp mut Temp,
163    options: PhantomData<O>,
164}
165
166#[derive(Clone, Copy)]
167enum Source<'doc, 'input> {
168    Node(Node<'doc, 'input>),
169    Attribute(Attribute<'doc, 'input>),
170    Content(Node<'doc, 'input>),
171}
172
173impl Source<'_, '_> {
174    fn name<'a, O>(&'a self, buffer: &'a mut String) -> &'a str
175    where
176        O: Options,
177    {
178        match self {
179            Self::Node(node) => {
180                let name = node.name().unwrap();
181
182                match name.namespace {
183                    Some(namespace) if O::NAMESPACES => {
184                        buffer.clear();
185
186                        buffer.reserve(namespace.len() + 2 + name.local.len());
187
188                        buffer.push('{');
189                        buffer.push_str(namespace);
190                        buffer.push('}');
191
192                        buffer.push_str(name.local);
193
194                        &*buffer
195                    }
196                    _ => name.local,
197                }
198            }
199            Self::Attribute(attr) => {
200                let name = attr.name();
201
202                match name.namespace {
203                    Some(namespace) if O::NAMESPACES => {
204                        buffer.clear();
205
206                        if O::PREFIX_ATTR {
207                            buffer.reserve(3 + namespace.len() + name.local.len());
208
209                            buffer.push('@');
210                        } else {
211                            buffer.reserve(2 + namespace.len() + name.local.len());
212                        }
213
214                        buffer.push('{');
215                        buffer.push_str(namespace);
216                        buffer.push('}');
217
218                        buffer.push_str(name.local);
219
220                        &*buffer
221                    }
222                    _ => {
223                        if O::PREFIX_ATTR {
224                            buffer.clear();
225
226                            buffer.reserve(1 + name.local.len());
227
228                            buffer.push('@');
229                            buffer.push_str(name.local);
230
231                            &*buffer
232                        } else {
233                            name.local
234                        }
235                    }
236                }
237            }
238            Self::Content(_) => "#content",
239        }
240    }
241}
242
/// Scratch state shared by all deserializers created for one `from_node` call.
struct Temp {
    /// Bitset with one bit per node id. A bit is set once `SeqAccess` has
    /// consumed that element, so that `MapAccess` does not yield it again.
    visited: Box<[usize]>,
    /// Reusable buffer for assembled names such as `{ns}local` or `@attr`.
    buffer: String,
}

impl Temp {
    /// Creates scratch state for a document with `len` nodes.
    ///
    /// The bitset addresses every bit index in `0..=len`.
    // NOTE(review): the range produced by `Node::id().get()` is not visible
    // here. Sizing for `0..=len` (rather than `0..len`) keeps the bitset
    // valid whether ids are zero- or one-based. The latter would otherwise
    // index out of bounds whenever `len` is a multiple of `USIZE_BITS`.
    // The cost is at most one extra word.
    fn new(len: NonZeroUsize) -> Self {
        let words = len.get() / USIZE_BITS + 1;

        Self {
            visited: vec![0; words].into_boxed_slice(),
            buffer: String::new(),
        }
    }

    /// Splits a node id into its word index and the bit mask within that word.
    fn locate(node: usize) -> (usize, usize) {
        (node / USIZE_BITS, 1 << (node % USIZE_BITS))
    }

    /// Records that `node` has been consumed.
    fn set_visited(&mut self, node: usize) {
        let (idx, mask) = Self::locate(node);
        self.visited[idx] |= mask;
    }

    /// Returns whether `node` has been recorded by [`Temp::set_visited`].
    fn is_visited(&self, node: usize) -> bool {
        let (idx, mask) = Self::locate(node);
        self.visited[idx] & mask != 0
    }
}

/// Number of node ids tracked per bitset word.
const USIZE_BITS: usize = usize::BITS as usize;
274
275impl<'doc, 'input, O> Deserializer<'doc, 'input, '_, O>
276where
277    O: Options,
278{
279    fn name(&mut self) -> &str {
280        self.source.name::<O>(&mut self.temp.buffer)
281    }
282
283    fn node(&self) -> Result<Node<'doc, 'input>, Box<Error>> {
284        match self.source {
285            Source::Node(node) | Source::Content(node) => Ok(node),
286            Source::Attribute(_) => Error::MissingNode.into(),
287        }
288    }
289
290    fn children(
291        &self,
292    ) -> Result<impl Iterator<Item = Source<'doc, 'input>> + use<'doc, 'input, O>, Box<Error>> {
293        let node = self.node()?;
294
295        let children = node
296            .children()
297            .filter(|node| node.is_element())
298            .map(Source::Node);
299
300        Ok(children)
301    }
302
303    fn children_and_attributes(
304        &self,
305    ) -> Result<impl Iterator<Item = Source<'doc, 'input>> + use<'doc, 'input, O>, Box<Error>> {
306        let node = self.node()?;
307
308        let children = node
309            .children()
310            .filter(|node| node.is_element())
311            .map(Source::Node);
312
313        let attributes = node.attributes().map(Source::Attribute);
314
315        let content = once(Source::Content(node));
316
317        Ok(children.chain(attributes).chain(content))
318    }
319
320    fn siblings(
321        &self,
322    ) -> Result<impl Iterator<Item = Node<'doc, 'input>> + use<'doc, 'input, O>, Box<Error>> {
323        let node = self.node()?;
324        let name = node.name();
325
326        Ok(node.next_siblings().filter(move |node1| {
327            let name1 = node1.name();
328
329            if O::NAMESPACES {
330                name == name1
331            } else {
332                cmp_opt_names(name.map(|name| name.local), name1.map(|name1| name1.local))
333            }
334        }))
335    }
336
337    fn text(&self) -> Cow<'doc, str> {
338        match self.source {
339            Source::Node(node) | Source::Content(node) => {
340                node.child_text().unwrap_or(Cow::Borrowed(""))
341            }
342            Source::Attribute(attr) => Cow::Borrowed(attr.value()),
343        }
344    }
345
346    fn parse<T>(&self, map_err: fn(T::Err) -> Error) -> Result<T, Box<Error>>
347    where
348        T: FromStr,
349    {
350        self.text()
351            .trim()
352            .parse()
353            .map_err(|err| map_err(err).into())
354    }
355}
356
357impl<'de, O> de::Deserializer<'de> for Deserializer<'de, '_, '_, O>
358where
359    O: Options,
360{
361    type Error = Box<Error>;
362
363    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
364    where
365        V: de::Visitor<'de>,
366    {
367        visitor.visit_bool(self.parse(Error::ParseBool)?)
368    }
369
370    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
371    where
372        V: de::Visitor<'de>,
373    {
374        visitor.visit_i8(self.parse(Error::ParseInt)?)
375    }
376
377    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
378    where
379        V: de::Visitor<'de>,
380    {
381        visitor.visit_i16(self.parse(Error::ParseInt)?)
382    }
383
384    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
385    where
386        V: de::Visitor<'de>,
387    {
388        visitor.visit_i32(self.parse(Error::ParseInt)?)
389    }
390
391    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
392    where
393        V: de::Visitor<'de>,
394    {
395        visitor.visit_i64(self.parse(Error::ParseInt)?)
396    }
397
398    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
399    where
400        V: de::Visitor<'de>,
401    {
402        visitor.visit_u8(self.parse(Error::ParseInt)?)
403    }
404
405    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
406    where
407        V: de::Visitor<'de>,
408    {
409        visitor.visit_u16(self.parse(Error::ParseInt)?)
410    }
411
412    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
413    where
414        V: de::Visitor<'de>,
415    {
416        visitor.visit_u32(self.parse(Error::ParseInt)?)
417    }
418
419    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
420    where
421        V: de::Visitor<'de>,
422    {
423        visitor.visit_u64(self.parse(Error::ParseInt)?)
424    }
425
426    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
427    where
428        V: de::Visitor<'de>,
429    {
430        visitor.visit_f32(self.parse(Error::ParseFloat)?)
431    }
432
433    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
434    where
435        V: de::Visitor<'de>,
436    {
437        visitor.visit_f64(self.parse(Error::ParseFloat)?)
438    }
439
440    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
441    where
442        V: de::Visitor<'de>,
443    {
444        visitor.visit_char(self.parse(Error::ParseChar)?)
445    }
446
447    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
448    where
449        V: de::Visitor<'de>,
450    {
451        match self.text() {
452            Cow::Borrowed(text) => visitor.visit_borrowed_str(text),
453            Cow::Owned(text) => visitor.visit_string(text),
454        }
455    }
456
457    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
458    where
459        V: de::Visitor<'de>,
460    {
461        self.deserialize_str(visitor)
462    }
463
464    fn deserialize_bytes<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
465    where
466        V: de::Visitor<'de>,
467    {
468        Error::NotSupported.into()
469    }
470
471    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
472    where
473        V: de::Visitor<'de>,
474    {
475        Error::NotSupported.into()
476    }
477
478    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
479    where
480        V: de::Visitor<'de>,
481    {
482        visitor.visit_some(self)
483    }
484
485    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
486    where
487        V: de::Visitor<'de>,
488    {
489        visitor.visit_unit()
490    }
491
492    fn deserialize_unit_struct<V>(
493        self,
494        _name: &'static str,
495        visitor: V,
496    ) -> Result<V::Value, Self::Error>
497    where
498        V: de::Visitor<'de>,
499    {
500        self.deserialize_unit(visitor)
501    }
502
503    fn deserialize_newtype_struct<V>(
504        self,
505        _name: &'static str,
506        visitor: V,
507    ) -> Result<V::Value, Self::Error>
508    where
509        V: de::Visitor<'de>,
510    {
511        visitor.visit_newtype_struct(self)
512    }
513
514    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
515    where
516        V: de::Visitor<'de>,
517    {
518        visitor.visit_seq(SeqAccess {
519            source: self.siblings()?,
520            temp: self.temp,
521            options: PhantomData::<O>,
522        })
523    }
524
525    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
526    where
527        V: de::Visitor<'de>,
528    {
529        self.deserialize_seq(visitor)
530    }
531
532    fn deserialize_tuple_struct<V>(
533        self,
534        _name: &'static str,
535        _len: usize,
536        visitor: V,
537    ) -> Result<V::Value, Self::Error>
538    where
539        V: de::Visitor<'de>,
540    {
541        self.deserialize_seq(visitor)
542    }
543
544    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
545    where
546        V: de::Visitor<'de>,
547    {
548        if O::ONLY_CHILDREN {
549            visitor.visit_map(MapAccess {
550                source: self.children()?.peekable(),
551                temp: self.temp,
552                options: PhantomData::<O>,
553            })
554        } else {
555            visitor.visit_map(MapAccess {
556                source: self.children_and_attributes()?.peekable(),
557                temp: self.temp,
558                options: PhantomData::<O>,
559            })
560        }
561    }
562
563    fn deserialize_struct<V>(
564        self,
565        #[allow(unused_variables)] name: &'static str,
566        _fields: &'static [&'static str],
567        visitor: V,
568    ) -> Result<V::Value, Self::Error>
569    where
570        V: de::Visitor<'de>,
571    {
572        #[cfg(feature = "raw-node")]
573        let res =
574            raw_node::deserialize_struct(self, name, move |this| this.deserialize_map(visitor));
575
576        #[cfg(not(feature = "raw-node"))]
577        let res = self.deserialize_map(visitor);
578
579        res
580    }
581
582    fn deserialize_enum<V>(
583        self,
584        _name: &'static str,
585        variants: &'static [&'static str],
586        visitor: V,
587    ) -> Result<V::Value, Self::Error>
588    where
589        V: de::Visitor<'de>,
590    {
591        if O::ONLY_CHILDREN {
592            visitor.visit_enum(EnumAccess {
593                source: self.children()?,
594                variants,
595                temp: self.temp,
596                options: PhantomData::<O>,
597            })
598        } else {
599            visitor.visit_enum(EnumAccess {
600                source: self.children_and_attributes()?,
601                variants,
602                temp: self.temp,
603                options: PhantomData::<O>,
604            })
605        }
606    }
607
608    fn deserialize_identifier<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
609    where
610        V: de::Visitor<'de>,
611    {
612        visitor.visit_str(self.name())
613    }
614
615    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
616    where
617        V: de::Visitor<'de>,
618    {
619        Error::NotSupported.into()
620    }
621
622    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
623    where
624        V: de::Visitor<'de>,
625    {
626        self.deserialize_unit(visitor)
627    }
628}
629
630struct SeqAccess<'doc, 'input, 'temp, I, O>
631where
632    I: Iterator<Item = Node<'doc, 'input>>,
633    'input: 'doc,
634{
635    source: I,
636    temp: &'temp mut Temp,
637    options: PhantomData<O>,
638}
639
640impl<'de, 'input, I, O> de::SeqAccess<'de> for SeqAccess<'de, 'input, '_, I, O>
641where
642    I: Iterator<Item = Node<'de, 'input>>,
643    O: Options,
644{
645    type Error = Box<Error>;
646
647    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
648    where
649        T: de::DeserializeSeed<'de>,
650    {
651        match self.source.next() {
652            None => Ok(None),
653            Some(node) => {
654                self.temp.set_visited(node.id().get());
655
656                let deserializer = Deserializer {
657                    source: Source::Node(node),
658                    temp: &mut *self.temp,
659                    options: PhantomData::<O>,
660                };
661                seed.deserialize(deserializer).map(Some)
662            }
663        }
664    }
665}
666
667struct MapAccess<'doc, 'input, 'temp, I, O>
668where
669    I: Iterator<Item = Source<'doc, 'input>>,
670    'input: 'doc,
671{
672    source: Peekable<I>,
673    temp: &'temp mut Temp,
674    options: PhantomData<O>,
675}
676
677impl<'de, 'input, I, O> de::MapAccess<'de> for MapAccess<'de, 'input, '_, I, O>
678where
679    I: Iterator<Item = Source<'de, 'input>>,
680    O: Options,
681{
682    type Error = Box<Error>;
683
684    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
685    where
686        K: de::DeserializeSeed<'de>,
687    {
688        loop {
689            match self.source.peek() {
690                None => return Ok(None),
691                Some(source) => {
692                    if let Source::Node(node) = source
693                        && self.temp.is_visited(node.id().get())
694                    {
695                        self.source.next().unwrap();
696                        continue;
697                    }
698
699                    let deserailizer = Deserializer {
700                        source: *source,
701                        temp: &mut *self.temp,
702                        options: PhantomData::<O>,
703                    };
704                    return seed.deserialize(deserailizer).map(Some);
705                }
706            }
707        }
708    }
709
710    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
711    where
712        V: de::DeserializeSeed<'de>,
713    {
714        let source = self.source.next().unwrap();
715
716        let deserializer = Deserializer {
717            source,
718            temp: &mut *self.temp,
719            options: PhantomData::<O>,
720        };
721        seed.deserialize(deserializer)
722    }
723}
724
725struct EnumAccess<'doc, 'input, 'temp, I, O>
726where
727    I: Iterator<Item = Source<'doc, 'input>>,
728    'input: 'doc,
729{
730    source: I,
731    variants: &'static [&'static str],
732    temp: &'temp mut Temp,
733    options: PhantomData<O>,
734}
735
736impl<'de, 'input, 'temp, I, O> de::EnumAccess<'de> for EnumAccess<'de, 'input, 'temp, I, O>
737where
738    I: Iterator<Item = Source<'de, 'input>>,
739    O: Options,
740{
741    type Error = Box<Error>;
742    type Variant = Deserializer<'de, 'input, 'temp, O>;
743
744    fn variant_seed<V>(mut self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
745    where
746        V: de::DeserializeSeed<'de>,
747    {
748        let source = self
749            .source
750            .find(|source| {
751                self.variants
752                    .contains(&source.name::<O>(&mut self.temp.buffer))
753            })
754            .ok_or(Error::MissingChildOrAttribute)?;
755
756        let deserializer = Deserializer {
757            source,
758            temp: &mut *self.temp,
759            options: PhantomData::<O>,
760        };
761        let value = seed.deserialize(deserializer)?;
762
763        let deserializer = Deserializer {
764            source,
765            temp: &mut *self.temp,
766            options: PhantomData::<O>,
767        };
768        Ok((value, deserializer))
769    }
770}
771
772impl<'de, O> de::VariantAccess<'de> for Deserializer<'de, '_, '_, O>
773where
774    O: Options,
775{
776    type Error = Box<Error>;
777
778    fn unit_variant(self) -> Result<(), Self::Error> {
779        Ok(())
780    }
781
782    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
783    where
784        T: de::DeserializeSeed<'de>,
785    {
786        seed.deserialize(self)
787    }
788
789    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
790    where
791        V: de::Visitor<'de>,
792    {
793        de::Deserializer::deserialize_tuple(self, len, visitor)
794    }
795
796    fn struct_variant<V>(
797        self,
798        fields: &'static [&'static str],
799        visitor: V,
800    ) -> Result<V::Value, Self::Error>
801    where
802        V: de::Visitor<'de>,
803    {
804        de::Deserializer::deserialize_struct(self, "", fields, visitor)
805    }
806}
807
808#[derive(Debug)]
809pub enum Error {
810    MissingNode,
811    MissingChildOrAttribute,
812    ParseXml(Box<XmlError>),
813    ParseBool(ParseBoolError),
814    ParseInt(ParseIntError),
815    ParseFloat(ParseFloatError),
816    ParseChar(ParseCharError),
817    NotSupported,
818    Custom(String),
819}
820
821impl<T> From<Error> for Result<T, Box<Error>> {
822    #[cold]
823    #[inline(never)]
824    fn from(err: Error) -> Self {
825        Err(Box::new(err))
826    }
827}
828
829impl de::Error for Box<Error> {
830    #[cold]
831    #[inline(never)]
832    fn custom<T: fmt::Display>(msg: T) -> Self {
833        Box::new(Error::Custom(msg.to_string()))
834    }
835}
836
837impl fmt::Display for Error {
838    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
839        match self {
840            Self::MissingNode => write!(fmt, "missing node"),
841            Self::MissingChildOrAttribute => write!(fmt, "missing child or attribute"),
842            Self::ParseXml(err) => write!(fmt, "XML parse error: {err}"),
843            Self::ParseBool(err) => write!(fmt, "bool parse error: {err}"),
844            Self::ParseInt(err) => write!(fmt, "int parse error: {err}"),
845            Self::ParseFloat(err) => write!(fmt, "float parse error: {err}"),
846            Self::ParseChar(err) => write!(fmt, "char parse error: {err}"),
847            Self::NotSupported => write!(fmt, "not supported"),
848            Self::Custom(msg) => write!(fmt, "custom error: {msg}"),
849        }
850    }
851}
852
853impl StdError for Error {
854    fn source(&self) -> Option<&(dyn StdError + 'static)> {
855        match self {
856            Self::ParseXml(err) => Some(err),
857            Self::ParseBool(err) => Some(err),
858            Self::ParseInt(err) => Some(err),
859            Self::ParseFloat(err) => Some(err),
860            Self::ParseChar(err) => Some(err),
861            _ => None,
862        }
863    }
864}
865
866#[cfg(test)]
867mod tests {
868    use super::*;
869
870    use serde::Deserialize;
871
872    #[test]
873    fn parse_bool() {
874        let val = from_str::<bool>("<root>false</root>").unwrap();
875        assert!(!val);
876        let val = from_str::<bool>("<root>\n\ttrue\n</root>").unwrap();
877        assert!(val);
878
879        let res = from_str::<bool>("<root>foobar</root>");
880        assert!(matches!(*res.unwrap_err(), Error::ParseBool(_err)));
881    }
882
883    #[test]
884    fn parse_char() {
885        let val = from_str::<char>("<root>x</root>").unwrap();
886        assert_eq!(val, 'x');
887        let val = from_str::<char>("<root>\n\ty\n</root>").unwrap();
888        assert_eq!(val, 'y');
889
890        let res = from_str::<char>("<root>xyz</root>");
891        assert!(matches!(*res.unwrap_err(), Error::ParseChar(_err)));
892    }
893
894    #[test]
895    fn empty_text() {
896        let val = from_str::<String>("<root></root>").unwrap();
897        assert!(val.is_empty());
898    }
899
900    #[test]
901    fn children_and_attributes() {
902        #[derive(Deserialize)]
903        struct Root {
904            attr: i32,
905            child: u64,
906        }
907
908        let val = from_str::<Root>(r#"<root attr="23"><child>42</child></root>"#).unwrap();
909        assert_eq!(val.attr, 23);
910        assert_eq!(val.child, 42);
911    }
912
913    #[test]
914    fn children_with_attributes() {
915        #[derive(Deserialize)]
916        struct Root {
917            child: Child,
918        }
919
920        #[derive(Deserialize)]
921        struct Child {
922            attr: i32,
923            #[serde(rename = "#content")]
924            text: u64,
925        }
926
927        let val = from_str::<Root>(r#"<root><child attr="23">42</child></root>"#).unwrap();
928        assert_eq!(val.child.attr, 23);
929        assert_eq!(val.child.text, 42);
930    }
931
932    #[test]
933    fn multiple_children() {
934        #[derive(Deserialize)]
935        struct Root {
936            child: Vec<i32>,
937            another_child: String,
938        }
939
940        let val = from_str::<Root>(r#"<root><child>23</child><another_child>foobar</another_child><child>42</child></root>"#).unwrap();
941        assert_eq!(val.child, [23, 42]);
942        assert_eq!(val.another_child, "foobar");
943    }
944
945    #[test]
946    fn multiple_lists_of_multiple_children() {
947        #[derive(Deserialize)]
948        struct Root {
949            child: Vec<i32>,
950            another_child: Vec<String>,
951        }
952
953        let val = from_str::<Root>(r#"<root><child>23</child><another_child>foo</another_child><child>42</child><another_child>bar</another_child></root>"#).unwrap();
954        assert_eq!(val.child, [23, 42]);
955        assert_eq!(val.another_child, ["foo", "bar"]);
956    }
957
958    #[test]
959    fn zero_of_multiple_children() {
960        #[derive(Deserialize)]
961        struct Root {
962            #[serde(default)]
963            child: Vec<i32>,
964        }
965
966        let val = from_str::<Root>(r#"<root></root>"#).unwrap();
967        assert_eq!(val.child, []);
968    }
969
970    #[test]
971    fn optional_child() {
972        #[derive(Deserialize)]
973        struct Root {
974            child: Option<f32>,
975        }
976
977        let val = from_str::<Root>(r#"<root><child>23.42</child></root>"#).unwrap();
978        assert_eq!(val.child, Some(23.42));
979
980        let val = from_str::<Root>(r#"<root></root>"#).unwrap();
981        assert_eq!(val.child, None);
982    }
983
984    #[test]
985    fn optional_attribute() {
986        #[derive(Deserialize)]
987        struct Root {
988            attr: Option<f64>,
989        }
990
991        let val = from_str::<Root>(r#"<root attr="23.42"></root>"#).unwrap();
992        assert_eq!(val.attr, Some(23.42));
993
994        let val = from_str::<Root>(r#"<root></root>"#).unwrap();
995        assert_eq!(val.attr, None);
996    }
997
998    #[test]
999    fn child_variants() {
1000        #[derive(Debug, PartialEq, Deserialize)]
1001        enum Root {
1002            Foo(Foo),
1003            Bar(Bar),
1004        }
1005
1006        #[derive(Debug, PartialEq, Deserialize)]
1007        struct Foo {
1008            attr: i64,
1009        }
1010
1011        #[derive(Debug, PartialEq, Deserialize)]
1012        struct Bar {
1013            child: u32,
1014        }
1015
1016        let val = from_str::<Root>(r#"<root><Foo attr="23" /></root>"#).unwrap();
1017        assert_eq!(val, Root::Foo(Foo { attr: 23 }));
1018
1019        let val = from_str::<Root>(r#"<root><Bar><child>42</child></Bar></root>"#).unwrap();
1020        assert_eq!(val, Root::Bar(Bar { child: 42 }));
1021    }
1022
1023    #[test]
1024    fn attribute_variants() {
1025        #[derive(Debug, PartialEq, Deserialize)]
1026        enum Root {
1027            Foo(u32),
1028            Bar(i64),
1029        }
1030
1031        let val = from_str::<Root>(r#"<root Foo="23" />"#).unwrap();
1032        assert_eq!(val, Root::Foo(23));
1033
1034        let val = from_str::<Root>(r#"<root Bar="42" />"#).unwrap();
1035        assert_eq!(val, Root::Bar(42));
1036    }
1037
1038    #[test]
1039    fn mixed_enum_and_struct_children() {
1040        #[derive(Debug, PartialEq, Deserialize)]
1041        enum Foobar {
1042            Foo(u32),
1043            Bar(i64),
1044        }
1045
1046        #[derive(Deserialize)]
1047        struct Root {
1048            #[serde(rename = "#content")]
1049            foobar: Foobar,
1050            qux: f32,
1051        }
1052
1053        let val = from_str::<Root>(r#"<root><qux>42.0</qux><Foo>23</Foo></root>"#).unwrap();
1054        assert_eq!(val.foobar, Foobar::Foo(23));
1055        assert_eq!(val.qux, 42.0);
1056    }
1057
1058    #[test]
1059    fn mixed_enum_and_repeated_struct_children() {
1060        #[derive(Debug, PartialEq, Deserialize)]
1061        enum Foobar {
1062            Foo(u32),
1063            Bar(i64),
1064        }
1065
1066        #[derive(Deserialize)]
1067        struct Root {
1068            #[serde(rename = "#content")]
1069            foobar: Foobar,
1070            qux: Vec<f32>,
1071            baz: String,
1072        }
1073
1074        let val = from_str::<Root>(
1075            r#"<root><Bar>42</Bar><qux>1.0</qux><baz>baz</baz><qux>2.0</qux><qux>3.0</qux></root>"#,
1076        )
1077        .unwrap();
1078        assert_eq!(val.foobar, Foobar::Bar(42));
1079        assert_eq!(val.qux, [1.0, 2.0, 3.0]);
1080        assert_eq!(val.baz, "baz");
1081    }
1082
1083    #[test]
1084    fn repeated_enum_and_struct_children() {
1085        #[derive(Debug, PartialEq, Deserialize)]
1086        enum Foobar {
1087            Foo(Vec<u32>),
1088            Bar(i64),
1089        }
1090
1091        #[derive(Deserialize)]
1092        struct Root {
1093            #[serde(rename = "#content")]
1094            foobar: Foobar,
1095            baz: String,
1096        }
1097
1098        let val =
1099            from_str::<Root>(r#"<root><Foo>42</Foo><baz>baz</baz><Foo>23</Foo></root>"#).unwrap();
1100        assert_eq!(val.foobar, Foobar::Foo(vec![42, 23]));
1101        assert_eq!(val.baz, "baz");
1102    }
1103
1104    #[test]
1105    fn borrowed_str() {
1106        let doc = Document::parse("<root><child>foobar</child></root>").unwrap();
1107
1108        #[derive(Deserialize)]
1109        struct Root<'a> {
1110            child: &'a str,
1111        }
1112
1113        let val = from_doc::<Root>(&doc).unwrap();
1114        assert_eq!(val.child, "foobar");
1115    }
1116
1117    #[test]
1118    fn unit_struct() {
1119        #[derive(Deserialize)]
1120        #[allow(dead_code)]
1121        struct Root {
1122            child: Child,
1123        }
1124
1125        #[derive(Deserialize)]
1126        struct Child;
1127
1128        from_str::<Root>(r#"<root><child /></root>"#).unwrap();
1129
1130        from_str::<Root>(r#"<root><child>foobar</child></root>"#).unwrap();
1131    }
1132
1133    #[test]
1134    fn unit_variant() {
1135        #[derive(Debug, Deserialize)]
1136        enum Root {
1137            Child,
1138        }
1139
1140        from_str::<Root>(r#"<root><Child /></root>"#).unwrap();
1141
1142        from_str::<Root>(r#"<root><Child>foobar</Child></root>"#).unwrap();
1143    }
1144
1145    #[test]
1146    fn children_with_namespaces() {
1147        #[derive(Deserialize)]
1148        struct Root {
1149            #[serde(rename = "{http://name.space}child")]
1150            child: u64,
1151        }
1152
1153        let val = defaults()
1154            .namespaces()
1155            .from_str::<Root>(r#"<root xmlns="http://name.space"><child>42</child></root>"#)
1156            .unwrap();
1157        assert_eq!(val.child, 42);
1158
1159        let val = defaults()
1160            .namespaces()
1161            .from_str::<Root>(r#"<root xmlns:namespace="http://name.space"><namespace:child>42</namespace:child></root>"#)
1162            .unwrap();
1163        assert_eq!(val.child, 42);
1164    }
1165
1166    #[test]
1167    fn attributes_with_namespaces() {
1168        #[derive(Deserialize)]
1169        struct Root {
1170            #[serde(rename = "{http://name.space}attr")]
1171            attr: i32,
1172        }
1173
1174        let val = defaults()
1175            .namespaces()
1176            .from_str::<Root>(
1177                r#"<root xmlns:namespace="http://name.space" namespace:attr="23"></root>"#,
1178            )
1179            .unwrap();
1180        assert_eq!(val.attr, 23);
1181    }
1182
1183    #[test]
1184    fn prefixed_attributes() {
1185        #[derive(Deserialize)]
1186        struct Root {
1187            #[serde(rename = "@attr")]
1188            attr: i32,
1189        }
1190
1191        let val = defaults()
1192            .prefix_attr()
1193            .from_str::<Root>(r#"<root attr="23"></root>"#)
1194            .unwrap();
1195        assert_eq!(val.attr, 23);
1196    }
1197
1198    #[test]
1199    fn prefixed_attributes_with_namespaces() {
1200        #[derive(Deserialize)]
1201        struct Root {
1202            #[serde(rename = "@{http://name.space}attr")]
1203            attr: i32,
1204        }
1205
1206        let val = defaults()
1207            .namespaces()
1208            .prefix_attr()
1209            .from_str::<Root>(
1210                r#"<root xmlns:namespace="http://name.space" namespace:attr="23"></root>"#,
1211            )
1212            .unwrap();
1213        assert_eq!(val.attr, 23);
1214    }
1215
1216    #[test]
1217    fn only_children_skips_attributes() {
1218        #[derive(Deserialize)]
1219        struct Root {
1220            child: u64,
1221            attr: Option<i32>,
1222        }
1223
1224        let val = defaults()
1225            .from_str::<Root>(r#"<root attr="23"><child>42</child></root>"#)
1226            .unwrap();
1227        assert_eq!(val.child, 42);
1228        assert_eq!(val.attr, Some(23));
1229
1230        let val = defaults()
1231            .only_children()
1232            .from_str::<Root>(r#"<root attr="23"><child>42</child></root>"#)
1233            .unwrap();
1234        assert_eq!(val.child, 42);
1235        assert_eq!(val.attr, None);
1236    }
1237
1238    #[test]
1239    fn only_children_skips_content() {
1240        #[derive(Deserialize)]
1241        struct Root {
1242            child: u64,
1243            #[serde(rename = "#content")]
1244            text: Option<String>,
1245        }
1246
1247        let val = defaults()
1248            .from_str::<Root>(r#"<root>text<child>42</child></root>"#)
1249            .unwrap();
1250        assert_eq!(val.child, 42);
1251        assert_eq!(val.text.as_deref(), Some("text"));
1252
1253        let val = defaults()
1254            .only_children()
1255            .from_str::<Root>(r#"<root>text<child>42</child></root>"#)
1256            .unwrap();
1257        assert_eq!(val.child, 42);
1258        assert_eq!(val.text.as_deref(), None);
1259    }
1260
1261    #[test]
1262    fn repeated_namespaced_elements() {
1263        #[derive(Deserialize)]
1264        struct Root {
1265            #[serde(rename = "{http://foo}child")]
1266            foo: Vec<u64>,
1267            #[serde(rename = "{http://bar}child")]
1268            bar: Vec<u64>,
1269        }
1270
1271        let val = defaults()
1272            .namespaces()
1273            .from_str::<Root>(
1274                r#"<root xmlns:foo="http://foo" xmlns:bar="http://bar">
1275    <foo:child>1</foo:child>
1276    <bar:child>2</bar:child>
1277    <bar:child>3</bar:child>
1278    <foo:child>4</foo:child>
1279</root>"#,
1280            )
1281            .unwrap();
1282        assert_eq!(val.foo, [1, 4]);
1283        assert_eq!(val.bar, [2, 3]);
1284    }
1285}