Skip to main content

instant_xml/
de.rs

1//! XML deserialization support code
2
3use std::borrow::Cow;
4use std::collections::{BTreeMap, VecDeque};
5use std::ops::{Deref, DerefMut};
6use std::str::{self, FromStr};
7
8use xmlparser::{ElementEnd, Token, Tokenizer};
9
10use crate::impls::CowStrAccumulator;
11use crate::{Accumulate, Error, FromXml, Id};
12
13/// XML deserializer for iterating over nodes in an element
14pub struct Deserializer<'cx, 'xml> {
15    parent: Element<'xml>,
16    level: usize,
17    done: bool,
18    context: Mut<'cx, Context<'xml>>,
19}
20
21impl<'cx, 'xml> Deserializer<'cx, 'xml> {
22    /// Create a new `Deserializer` from the given XML input string
23    ///
24    /// Errors if the input is empty or does not start with an opening element.
25    pub fn new(input: &'xml str) -> Result<Self, Error> {
26        let mut context = Mut::Owned(Context::new(input));
27        let parent = match context.next() {
28            Some(result) => match result? {
29                Node::Open(element) => element,
30                _ => return Err(Error::UnexpectedState("first node does not open element")),
31            },
32            None => return Err(Error::UnexpectedEndOfStream),
33        };
34
35        Ok(Self::with_context(parent, context))
36    }
37
38    /// Create a nested deserializer for a child element
39    pub fn nested<'a>(&'a mut self, element: Element<'xml>) -> Deserializer<'a, 'xml>
40    where
41        'cx: 'a,
42    {
43        Deserializer::with_context(element, self.context.borrow())
44    }
45
46    fn with_context(parent: Element<'xml>, context: Mut<'cx, Context<'xml>>) -> Self {
47        let level = context.stack.len();
48        Self {
49            parent,
50            level,
51            done: false,
52            context,
53        }
54    }
55
56    /// Override default limits for this deserializer
57    pub fn with_limits(mut self, limits: Limits) -> Self {
58        self.context.limits = limits;
59        self
60    }
61
62    /// Deserialize a value of type `T` from the deserializer's XML input
63    pub fn deserialize<T: FromXml<'xml>>(&mut self) -> Result<T, Error> {
64        let id = self.context.element_id(&self.parent)?;
65        if !T::matches(id, None) {
66            return Err(Error::UnexpectedValue(match id.ns.is_empty() {
67                true => format!("unexpected root element {:?}", id.name),
68                false => format!(
69                    "unexpected root element {:?} in namespace {:?}",
70                    id.name, id.ns
71                ),
72            }));
73        }
74
75        let mut value = T::Accumulator::default();
76        T::deserialize(&mut value, "<root element>", self)?;
77        value.try_done("<root element>")
78    }
79
80    /// Skip all remaining nodes in the current element
81    pub fn ignore(&mut self) -> Result<(), Error> {
82        loop {
83            match self.next() {
84                Some(Err(e)) => return Err(e),
85                Some(Ok(Node::Open(element))) => {
86                    let mut nested = self.nested(element);
87                    nested.ignore()?;
88                }
89                Some(_) => continue,
90                None => return Ok(()),
91            }
92        }
93    }
94
95    /// Create a deserializer that will yield the given node first
96    pub fn for_node<'a>(&'a mut self, node: Node<'xml>) -> Deserializer<'a, 'xml>
97    where
98        'cx: 'a,
99    {
100        self.context.records.push_front(node);
101        Deserializer {
102            parent: self.parent,
103            level: self.level,
104            done: self.done,
105            context: self.context.borrow(),
106        }
107    }
108
109    /// Extract a string value from the current node
110    ///
111    /// Consumes a text node or attribute value, returning the content as a string.
112    pub fn take_str(&mut self) -> Result<Option<Cow<'xml, str>>, Error> {
113        loop {
114            match self.next() {
115                Some(Ok(Node::AttributeValue(s))) => return Ok(Some(s)),
116                Some(Ok(Node::Text(s))) => return Ok(Some(s)),
117                Some(Ok(Node::Attribute(_))) => continue,
118                Some(Ok(node)) => return Err(Error::ExpectedScalar(format!("{node:?}"))),
119                Some(Err(e)) => return Err(e),
120                None => return Ok(None),
121            }
122        }
123    }
124
125    /// Get the identifier of the parent element
126    pub fn parent(&self) -> Id<'xml> {
127        Id {
128            ns: match self.parent.prefix {
129                Some(ns) => self.context.lookup(ns).unwrap(),
130                None => self.context.default_ns(),
131            },
132            name: self.parent.local,
133        }
134    }
135
136    pub fn namespace(&self, prefix: &str) -> Option<&'xml str> {
137        self.context.lookup(prefix)
138    }
139
140    /// Get the identifier of an element (name and namespace)
141    #[inline]
142    pub fn element_id(&self, element: &Element<'xml>) -> Result<Id<'xml>, Error> {
143        self.context.element_id(element)
144    }
145
146    /// Get the identifier of an attribute (name and namespace)
147    #[inline]
148    pub fn attribute_id(&self, attr: &Attribute<'xml>) -> Result<Id<'xml>, Error> {
149        self.context.attribute_id(attr)
150    }
151}
152
153impl<'xml> Iterator for Deserializer<'_, 'xml> {
154    type Item = Result<Node<'xml>, Error>;
155
156    fn next(&mut self) -> Option<Self::Item> {
157        if self.done {
158            return None;
159        }
160
161        let (prefix, local) = match self.context.next() {
162            Some(Ok(Node::Close { prefix, local })) => (prefix, local),
163            item => return item,
164        };
165
166        if self.context.stack.len() == self.level - 1
167            && local == self.parent.local
168            && prefix == self.parent.prefix
169        {
170            self.done = true;
171            return None;
172        }
173
174        Some(Err(Error::UnexpectedState("close element mismatch")))
175    }
176}
177
178struct Context<'xml> {
179    parser: Tokenizer<'xml>,
180    stack: Vec<Level<'xml>>,
181    records: VecDeque<Node<'xml>>,
182    limits: Limits,
183}
184
185impl<'xml> Context<'xml> {
186    fn new(input: &'xml str) -> Self {
187        Self {
188            parser: Tokenizer::from(input),
189            stack: Vec::new(),
190            records: VecDeque::new(),
191            limits: Limits::default(),
192        }
193    }
194
195    fn element_id(&self, element: &Element<'xml>) -> Result<Id<'xml>, Error> {
196        Ok(Id {
197            ns: match (element.default_ns, element.prefix) {
198                (_, Some(prefix)) => match self.lookup(prefix) {
199                    Some(ns) => ns,
200                    None => return Err(Error::UnknownPrefix(prefix.to_owned())),
201                },
202                (Some(ns), None) => ns,
203                (None, None) => self.default_ns(),
204            },
205            name: element.local,
206        })
207    }
208
209    fn attribute_id(&self, attr: &Attribute<'xml>) -> Result<Id<'xml>, Error> {
210        Ok(Id {
211            ns: match attr.prefix {
212                Some(ns) => self
213                    .lookup(ns)
214                    .ok_or_else(|| Error::UnknownPrefix(ns.to_owned()))?,
215                None => "",
216            },
217            name: attr.local,
218        })
219    }
220
221    fn default_ns(&self) -> &'xml str {
222        self.stack
223            .iter()
224            .rev()
225            .find_map(|level| level.default_ns)
226            .unwrap_or("")
227    }
228
229    fn lookup(&self, prefix: &str) -> Option<&'xml str> {
230        // The prefix xml is by definition bound to the namespace
231        // name http://www.w3.org/XML/1998/namespace
232        // See https://www.w3.org/TR/xml-names/#ns-decl
233        if prefix == "xml" {
234            return Some("http://www.w3.org/XML/1998/namespace");
235        }
236
237        self.stack
238            .iter()
239            .rev()
240            .find_map(|level| level.prefixes.get(prefix).copied())
241    }
242}
243
244impl<'xml> Iterator for Context<'xml> {
245    type Item = Result<Node<'xml>, Error>;
246
247    fn next(&mut self) -> Option<Self::Item> {
248        if let Some(record) = self.records.pop_front() {
249            if let Node::Close { .. } = &record {
250                self.stack.pop();
251            }
252            return Some(Ok(record));
253        }
254
255        loop {
256            match self.parser.next()? {
257                Ok(Token::ElementStart { prefix, local, .. }) => {
258                    if self.stack.len() >= self.limits.max_levels {
259                        return Some(Err(Error::Other(
260                            "maximum number of nested element levels exceeded".to_owned(),
261                        )));
262                    }
263
264                    let prefix = prefix.as_str();
265                    self.stack.push(Level {
266                        local: local.as_str(),
267                        prefix: match prefix.is_empty() {
268                            true => None,
269                            false => Some(prefix),
270                        },
271                        default_ns: None,
272                        prefixes: BTreeMap::new(),
273                    });
274                }
275                Ok(Token::ElementEnd { end, .. }) => match end {
276                    ElementEnd::Open => {
277                        let Some(level) = self.stack.last() else {
278                            return Some(Err(Error::UnexpectedState(
279                                "opening element with no parent",
280                            )));
281                        };
282
283                        let element = Element {
284                            local: level.local,
285                            prefix: level.prefix,
286                            default_ns: level.default_ns,
287                        };
288
289                        return Some(Ok(Node::Open(element)));
290                    }
291                    ElementEnd::Close(prefix, v) => {
292                        let Some(level) = self.stack.pop() else {
293                            return Some(Err(Error::UnexpectedState(
294                                "closing element without parent",
295                            )));
296                        };
297
298                        let prefix = match prefix.is_empty() {
299                            true => None,
300                            false => Some(prefix.as_str()),
301                        };
302
303                        return Some(match v.as_str() == level.local && prefix == level.prefix {
304                            true => Ok(Node::Close {
305                                prefix,
306                                local: level.local,
307                            }),
308                            false => Err(Error::UnexpectedState("close element mismatch")),
309                        });
310                    }
311                    ElementEnd::Empty => {
312                        let Some(level) = self.stack.last() else {
313                            return Some(Err(Error::UnexpectedState(
314                                "opening element with no parent",
315                            )));
316                        };
317
318                        self.records.push_back(Node::Close {
319                            prefix: level.prefix,
320                            local: level.local,
321                        });
322
323                        let element = Element {
324                            local: level.local,
325                            prefix: level.prefix,
326                            default_ns: level.default_ns,
327                        };
328
329                        return Some(Ok(Node::Open(element)));
330                    }
331                },
332                Ok(Token::Attribute {
333                    prefix,
334                    local,
335                    value,
336                    ..
337                }) => {
338                    if prefix.is_empty() && local.as_str() == "xmlns" {
339                        match self.stack.last_mut() {
340                            Some(level) => level.default_ns = Some(value.as_str()),
341                            None => {
342                                return Some(Err(Error::UnexpectedState(
343                                    "attribute without element context",
344                                )))
345                            }
346                        }
347                    } else if prefix.as_str() == "xmlns" {
348                        match self.stack.last_mut() {
349                            Some(level) => {
350                                if level.prefixes.len() >= self.limits.max_ns_declarations {
351                                    return Some(Err(Error::Other(
352                                        "maximum number of namespace declarations exceeded"
353                                            .to_owned(),
354                                    )));
355                                }
356
357                                level.prefixes.insert(local.as_str(), value.as_str());
358                            }
359                            None => {
360                                return Some(Err(Error::UnexpectedState(
361                                    "attribute without element context",
362                                )))
363                            }
364                        }
365                    } else {
366                        if self.records.len() >= self.limits.max_attributes {
367                            return Some(Err(Error::Other(
368                                "maximum number of attributes exceeded".to_owned(),
369                            )));
370                        }
371
372                        let value = match decode(value.as_str()) {
373                            Ok(value) => value,
374                            Err(e) => return Some(Err(e)),
375                        };
376
377                        self.records.push_back(Node::Attribute(Attribute {
378                            prefix: match prefix.is_empty() {
379                                true => None,
380                                false => Some(prefix.as_str()),
381                            },
382                            local: local.as_str(),
383                            value,
384                        }));
385                    }
386                }
387                Ok(Token::Text { text }) => {
388                    return Some(decode(text.as_str()).map(Node::Text));
389                }
390                Ok(Token::Cdata { text, .. }) => {
391                    return Some(Ok(Node::Text(Cow::Borrowed(text.as_str()))));
392                }
393                Ok(token @ Token::Declaration { .. }) => {
394                    if !self.stack.is_empty() {
395                        return Some(Err(Error::UnexpectedToken(format!("{token:?}"))));
396                    }
397                }
398                Ok(Token::Comment { .. }) => continue,
399                Ok(token) => return Some(Err(Error::UnexpectedToken(format!("{token:?}")))),
400                Err(e) => return Some(Err(Error::Parse(e))),
401            }
402        }
403    }
404}
405
406/// Deserialize a borrowed `Cow<str>` value
407///
408/// Helper function for deserializing `Cow<str>` with zero-copy borrowing from the input.
409pub fn borrow_cow_str<'a, 'xml: 'a>(
410    into: &mut CowStrAccumulator<'xml, 'a>,
411    field: &'static str,
412    deserializer: &mut Deserializer<'_, 'xml>,
413) -> Result<(), Error> {
414    if into.inner.is_some() {
415        return Err(Error::DuplicateValue(field));
416    }
417
418    match deserializer.take_str()? {
419        Some(value) => into.inner = Some(value),
420        None => return Ok(()),
421    };
422
423    deserializer.ignore()?;
424    Ok(())
425}
426
427/// Deserialize a borrowed `Cow<[u8]>` value
428///
429/// Helper function for deserializing `Cow<[u8]>` with zero-copy borrowing from the input.
430pub fn borrow_cow_slice_u8<'xml>(
431    into: &mut Option<Cow<'xml, [u8]>>,
432    field: &'static str,
433    deserializer: &mut Deserializer<'_, 'xml>,
434) -> Result<(), Error> {
435    if into.is_some() {
436        return Err(Error::DuplicateValue(field));
437    }
438
439    if let Some(value) = deserializer.take_str()? {
440        *into = Some(match value {
441            Cow::Borrowed(v) => Cow::Borrowed(v.as_bytes()),
442            Cow::Owned(v) => Cow::Owned(v.into_bytes()),
443        });
444    }
445
446    deserializer.ignore()?;
447    Ok(())
448}
449
450fn decode(input: &str) -> Result<Cow<'_, str>, Error> {
451    let mut result = String::new();
452    let (mut state, mut last_end) = (DecodeState::Normal, 0);
453    for (i, &b) in input.as_bytes().iter().enumerate() {
454        // use a state machine to find entities
455        state = match (state, b) {
456            (DecodeState::Normal, b'&') => DecodeState::Entity([0; MAX_ENTITY_LEN], 0),
457            (DecodeState::Normal, _) => DecodeState::Normal,
458            (DecodeState::Entity(chars, len), b';') => {
459                let decoded = match &chars[..len] {
460                    [b'a', b'm', b'p'] => '&',
461                    [b'a', b'p', b'o', b's'] => '\'',
462                    [b'g', b't'] => '>',
463                    [b'l', b't'] => '<',
464                    [b'q', b'u', b'o', b't'] => '"',
465                    [b'#', b'x' | b'X', hex @ ..] => {
466                        // Hexadecimal character reference e.g. "&#x007c;" -> '|'
467                        str::from_utf8(hex)
468                            .ok()
469                            .and_then(|hex_str| u32::from_str_radix(hex_str, 16).ok())
470                            .and_then(char::from_u32)
471                            .filter(valid_xml_character)
472                            .ok_or_else(|| {
473                                Error::InvalidEntity(
474                                    String::from_utf8_lossy(&chars[..len]).into_owned(),
475                                )
476                            })?
477                    }
478                    [b'#', decimal @ ..] => {
479                        // Decimal character reference e.g. "&#1234;" -> 'Ӓ'
480                        str::from_utf8(decimal)
481                            .ok()
482                            .and_then(|decimal_str| u32::from_str(decimal_str).ok())
483                            .and_then(char::from_u32)
484                            .filter(valid_xml_character)
485                            .ok_or_else(|| {
486                                Error::InvalidEntity(
487                                    String::from_utf8_lossy(&chars[..len]).into_owned(),
488                                )
489                            })?
490                    }
491                    _ => {
492                        return Err(Error::InvalidEntity(
493                            String::from_utf8_lossy(&chars[..len]).into_owned(),
494                        ))
495                    }
496                };
497
498                let start = i - (len + 1); // current position - (length of entity characters + 1 for '&')
499                if result.capacity() == 0 {
500                    result.reserve(input.len());
501                }
502
503                if last_end < start {
504                    // Unwrap should be safe: `last_end` and `start` must be at character boundaries.
505                    result.push_str(input.get(last_end..start).unwrap());
506                }
507
508                last_end = i + 1;
509                result.push(decoded);
510                DecodeState::Normal
511            }
512            (DecodeState::Entity(mut chars, len), b) => {
513                if len >= MAX_ENTITY_LEN {
514                    let mut bytes = Vec::with_capacity(MAX_ENTITY_LEN + 1);
515                    bytes.extend(&chars[..len]);
516                    bytes.push(b);
517                    return Err(Error::InvalidEntity(
518                        String::from_utf8_lossy(&bytes).into_owned(),
519                    ));
520                }
521
522                chars[len] = b;
523                DecodeState::Entity(chars, len + 1)
524            }
525        };
526    }
527
528    // Unterminated entity (& without ;) at end of input
529    if let DecodeState::Entity(chars, len) = state {
530        return Err(Error::InvalidEntity(
531            String::from_utf8_lossy(&chars[..len]).into_owned(),
532        ));
533    }
534
535    Ok(match result.is_empty() {
536        true => Cow::Borrowed(input),
537        false => {
538            // Unwrap should be safe: `last_end` and `input.len()` must be at character boundaries.
539            result.push_str(input.get(last_end..input.len()).unwrap());
540            Cow::Owned(result)
541        }
542    })
543}
544
/// Scanner state for `decode`.
#[derive(Debug)]
enum DecodeState {
    /// Copying ordinary characters; no entity is open.
    Normal,
    /// Inside an entity: the bytes seen after `&` so far, and how many are filled in.
    Entity([u8; MAX_ENTITY_LEN], usize),
}

/// Buffer size for the characters of an entity between `&` and `;`.
///
/// Sized to the longest reference accepted: the widest numeric character references
/// are `#x10FFFF` and `#1114111` (the maximum Unicode scalar value), both 8 bytes.
const MAX_ENTITY_LEN: usize = "#x10FFFF".len();
556
/// Valid character ranges per <https://www.w3.org/TR/xml/#NT-Char>
///
/// `Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]`
fn valid_xml_character(c: &char) -> bool {
    const ALLOWED: [(char, char); 6] = [
        ('\u{9}', '\u{9}'),
        ('\u{A}', '\u{A}'),
        ('\u{D}', '\u{D}'),
        ('\u{20}', '\u{D7FF}'),
        ('\u{E000}', '\u{FFFD}'),
        ('\u{10000}', '\u{10FFFF}'),
    ];

    ALLOWED.iter().any(|&(low, high)| (low..=high).contains(c))
}
561
562/// An XML node during deserialization
563#[derive(Debug)]
564pub enum Node<'xml> {
565    /// An attribute name (value follows in a separate AttributeValue node)
566    Attribute(Attribute<'xml>),
567    /// The value of the preceding Attribute node
568    AttributeValue(Cow<'xml, str>),
569    /// Closing tag for an element
570    Close {
571        /// The namespace prefix, if any
572        prefix: Option<&'xml str>,
573        /// The local name
574        local: &'xml str,
575    },
576    /// Text content
577    Text(Cow<'xml, str>),
578    /// Opening tag for an element
579    Open(Element<'xml>),
580}
581
/// Structural limits enforced while parsing XML input
///
/// Parsing fails with an error as soon as any limit is exceeded. Start from
/// `Limits::default()`, adjust the fields you need, and pass the result to
/// `Deserializer::with_limits`.
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
pub struct Limits {
    /// Maximum number of attributes allowed per element
    ///
    /// Namespace declarations (`xmlns="..."` and `xmlns:prefix="..."`) are not counted.
    pub max_attributes: usize,
    /// Maximum number of nested element levels allowed
    pub max_levels: usize,
    /// Maximum number of namespace declarations allowed per element
    ///
    /// Only prefixed declarations (`xmlns:prefix="..."`) are counted; a default
    /// namespace declaration (`xmlns="..."`) is not.
    pub max_ns_declarations: usize,
}

impl Default for Limits {
    /// 64 attributes per element, 32 levels of nesting and 32 namespace declarations
    /// per element.
    fn default() -> Self {
        Self {
            max_attributes: 64,
            max_levels: 32,
            max_ns_declarations: 32,
        }
    }
}
602
/// An XML element during deserialization
///
/// A small, copyable snapshot of an opening tag: its local name, its namespace prefix,
/// and the default namespace it declared itself (if any).
#[derive(Clone, Copy, Debug)]
pub struct Element<'xml> {
    /// Local name, without prefix.
    local: &'xml str,
    /// Value of an `xmlns="..."` attribute on this element, if present.
    default_ns: Option<&'xml str>,
    /// Namespace prefix (`p` in `<p:name>`), if any.
    prefix: Option<&'xml str>,
}

/// Bookkeeping for one open element on the context's stack.
#[derive(Debug)]
struct Level<'xml> {
    /// Local name; the end tag must match it.
    local: &'xml str,
    /// Prefix; the end tag must match it.
    prefix: Option<&'xml str>,
    /// Default namespace declared on this element via `xmlns="..."`.
    default_ns: Option<&'xml str>,
    /// Prefix-to-namespace bindings declared on this element via `xmlns:p="..."`.
    prefixes: BTreeMap<&'xml str, &'xml str>,
}

/// An XML attribute during deserialization
#[derive(Debug)]
pub struct Attribute<'xml> {
    /// The namespace prefix, if any
    pub prefix: Option<&'xml str>,
    /// The local name
    pub local: &'xml str,
    /// The attribute value, with entity and character references decoded
    pub value: Cow<'xml, str>,
}
629
/// Either an exclusive borrow of a `T` or an owned `T`, usable through `Deref` and
/// `DerefMut` in both cases.
///
/// Lets the root `Deserializer` own its `Context` while nested deserializers reborrow
/// the same one.
enum Mut<'a, T> {
    Ref(&'a mut T),
    Owned(T),
}

impl<'a, T> Mut<'a, T> {
    /// Reborrow the target mutably for a shorter lifetime, always as `Mut::Ref`.
    fn borrow(&mut self) -> Mut<'_, T> {
        Mut::Ref(&mut **self)
    }
}

impl<T> Deref for Mut<'_, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        match self {
            Self::Ref(target) => &**target,
            Self::Owned(target) => target,
        }
    }
}

impl<T> DerefMut for Mut<'_, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            Self::Ref(target) => &mut **target,
            Self::Owned(target) => target,
        }
    }
}
663
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode() {
        decode_ok("foo", "foo");
        decode_ok("foo &amp; bar", "foo & bar");
        decode_ok("foo &lt; bar", "foo < bar");
        decode_ok("foo &gt; bar", "foo > bar");
        decode_ok("foo &quot; bar", "foo \" bar");
        decode_ok("foo &apos; bar", "foo ' bar");
        decode_ok("foo &amp;lt; bar", "foo &lt; bar");
        decode_ok("&amp; foo", "& foo");
        decode_ok("foo &amp;", "foo &");
        decode_ok("cbdtéda&amp;sü", "cbdtéda&sü");
        // Adjacent entities and entities between text
        decode_ok("&lt;&gt;", "<>");
        decode_ok("a&lt;b&gt;c", "a<b>c");
        // Decimal character references
        decode_ok("&#1234;", "Ӓ");
        decode_ok("foo &#9; bar", "foo \t bar");
        decode_ok("foo &#124; bar", "foo | bar");
        decode_ok("foo &#1234; bar", "foo Ӓ bar");
        // Hexadecimal character references
        decode_ok("&#xc4;", "Ä");
        decode_ok("&#x00c4;", "Ä");
        decode_ok("foo &#x9; bar", "foo \t bar");
        decode_ok("foo &#x007c; bar", "foo | bar");
        decode_ok("foo &#xc4; bar", "foo Ä bar");
        decode_ok("foo &#x00c4; bar", "foo Ä bar");
        decode_ok("foo &#x10de; bar", "foo პ bar");
        // Upper-case hexadecimal marker
        decode_ok("&#X41;", "A");
        // Widest numeric references: the maximum Unicode scalar value
        decode_ok("&#x10FFFF;", "\u{10FFFF}");
        decode_ok("&#1114111;", "\u{10FFFF}");

        decode_err("&");
        decode_err("&#");
        decode_err("&#;");
        decode_err("foo&");
        decode_err("&bar");
        decode_err("&foo;");
        decode_err("&foobar;");
        decode_err("cbdtéd&ampü");
        // Unterminated entity at end of input
        decode_err("foo &amp");
        // Hexadecimal reference without digits
        decode_err("&#x;");
        // Code points that are not XML characters
        decode_err("&#0;");
        decode_err("&#x1;");
        decode_err("&#xFFFE;");
        // Surrogates and values beyond U+10FFFF are not Unicode scalar values
        decode_err("&#xD800;");
        decode_err("&#x110000;");
        // References longer than `MAX_ENTITY_LEN` bytes are rejected
        decode_err("&#x0010FFFF;");
    }

    #[test]
    fn test_decode_borrows_when_unchanged() {
        assert!(matches!(decode("plain text").unwrap(), Cow::Borrowed("plain text")));
        assert!(matches!(decode("").unwrap(), Cow::Borrowed("")));
        assert!(matches!(decode("a &amp; b").unwrap(), Cow::Owned(_)));
    }

    fn decode_ok(input: &str, expected: &'static str) {
        assert_eq!(decode(input).unwrap(), expected, "{input:?}");
    }

    fn decode_err(input: &str) {
        assert!(decode(input).is_err(), "{input:?}");
    }
}