Skip to main content

instant_xml/
de.rs

1//! XML deserialization support code
2
3use std::borrow::Cow;
4use std::collections::{BTreeMap, VecDeque};
5use std::str::{self, FromStr};
6
7use xmlparser::{ElementEnd, Token, Tokenizer};
8
9use crate::impls::CowStrAccumulator;
10use crate::{Error, Id};
11
12/// XML deserializer for iterating over nodes in an element
13pub struct Deserializer<'cx, 'xml> {
14    pub(crate) local: &'xml str,
15    prefix: Option<&'xml str>,
16    level: usize,
17    done: bool,
18    context: &'cx mut Context<'xml>,
19}
20
21impl<'cx, 'xml> Deserializer<'cx, 'xml> {
22    pub(crate) fn new(element: Element<'xml>, context: &'cx mut Context<'xml>) -> Self {
23        let level = context.stack.len();
24        Self {
25            local: element.local,
26            prefix: element.prefix,
27            level,
28            done: false,
29            context,
30        }
31    }
32
33    /// Extract a string value from the current node
34    ///
35    /// Consumes a text node or attribute value, returning the content as a string.
36    pub fn take_str(&mut self) -> Result<Option<Cow<'xml, str>>, Error> {
37        loop {
38            match self.next() {
39                Some(Ok(Node::AttributeValue(s))) => return Ok(Some(s)),
40                Some(Ok(Node::Text(s))) => return Ok(Some(s)),
41                Some(Ok(Node::Attribute(_))) => continue,
42                Some(Ok(node)) => return Err(Error::ExpectedScalar(format!("{node:?}"))),
43                Some(Err(e)) => return Err(e),
44                None => return Ok(None),
45            }
46        }
47    }
48
49    /// Create a nested deserializer for a child element
50    pub fn nested<'a>(&'a mut self, element: Element<'xml>) -> Deserializer<'a, 'xml>
51    where
52        'cx: 'a,
53    {
54        Deserializer::new(element, self.context)
55    }
56
57    /// Skip all remaining nodes in the current element
58    pub fn ignore(&mut self) -> Result<(), Error> {
59        loop {
60            match self.next() {
61                Some(Err(e)) => return Err(e),
62                Some(Ok(Node::Open(element))) => {
63                    let mut nested = self.nested(element);
64                    nested.ignore()?;
65                }
66                Some(_) => continue,
67                None => return Ok(()),
68            }
69        }
70    }
71
72    /// Create a deserializer that will yield the given node first
73    pub fn for_node<'a>(&'a mut self, node: Node<'xml>) -> Deserializer<'a, 'xml>
74    where
75        'cx: 'a,
76    {
77        self.context.records.push_front(node);
78        Deserializer {
79            local: self.local,
80            prefix: self.prefix,
81            level: self.level,
82            done: self.done,
83            context: self.context,
84        }
85    }
86
87    /// Get the identifier of the parent element
88    pub fn parent(&self) -> Id<'xml> {
89        Id {
90            ns: match self.prefix {
91                Some(ns) => self.context.lookup(ns).unwrap(),
92                None => self.context.default_ns(),
93            },
94            name: self.local,
95        }
96    }
97
98    /// Get the identifier of an element (name and namespace)
99    #[inline]
100    pub fn element_id(&self, element: &Element<'xml>) -> Result<Id<'xml>, Error> {
101        self.context.element_id(element)
102    }
103
104    /// Get the identifier of an attribute (name and namespace)
105    #[inline]
106    pub fn attribute_id(&self, attr: &Attribute<'xml>) -> Result<Id<'xml>, Error> {
107        self.context.attribute_id(attr)
108    }
109}
110
111impl<'xml> Iterator for Deserializer<'_, 'xml> {
112    type Item = Result<Node<'xml>, Error>;
113
114    fn next(&mut self) -> Option<Self::Item> {
115        if self.done {
116            return None;
117        }
118
119        let (prefix, local) = match self.context.next() {
120            Some(Ok(Node::Close { prefix, local })) => (prefix, local),
121            item => return item,
122        };
123
124        if self.context.stack.len() == self.level - 1
125            && local == self.local
126            && prefix == self.prefix
127        {
128            self.done = true;
129            return None;
130        }
131
132        Some(Err(Error::UnexpectedState("close element mismatch")))
133    }
134}
135
136pub(crate) struct Context<'xml> {
137    parser: Tokenizer<'xml>,
138    stack: Vec<Level<'xml>>,
139    records: VecDeque<Node<'xml>>,
140}
141
142impl<'xml> Context<'xml> {
143    pub(crate) fn new(input: &'xml str) -> Result<(Self, Element<'xml>), Error> {
144        let mut new = Self {
145            parser: Tokenizer::from(input),
146            stack: Vec::new(),
147            records: VecDeque::new(),
148        };
149
150        let root = match new.next() {
151            Some(result) => match result? {
152                Node::Open(element) => element,
153                _ => return Err(Error::UnexpectedState("first node does not open element")),
154            },
155            None => return Err(Error::UnexpectedEndOfStream),
156        };
157
158        Ok((new, root))
159    }
160
161    pub(crate) fn element_id(&self, element: &Element<'xml>) -> Result<Id<'xml>, Error> {
162        Ok(Id {
163            ns: match (element.default_ns, element.prefix) {
164                (_, Some(prefix)) => match self.lookup(prefix) {
165                    Some(ns) => ns,
166                    None => return Err(Error::UnknownPrefix(prefix.to_owned())),
167                },
168                (Some(ns), None) => ns,
169                (None, None) => self.default_ns(),
170            },
171            name: element.local,
172        })
173    }
174
175    fn attribute_id(&self, attr: &Attribute<'xml>) -> Result<Id<'xml>, Error> {
176        Ok(Id {
177            ns: match attr.prefix {
178                Some(ns) => self
179                    .lookup(ns)
180                    .ok_or_else(|| Error::UnknownPrefix(ns.to_owned()))?,
181                None => "",
182            },
183            name: attr.local,
184        })
185    }
186
187    fn default_ns(&self) -> &'xml str {
188        self.stack
189            .iter()
190            .rev()
191            .find_map(|level| level.default_ns)
192            .unwrap_or("")
193    }
194
195    fn lookup(&self, prefix: &str) -> Option<&'xml str> {
196        // The prefix xml is by definition bound to the namespace
197        // name http://www.w3.org/XML/1998/namespace
198        // See https://www.w3.org/TR/xml-names/#ns-decl
199        if prefix == "xml" {
200            return Some("http://www.w3.org/XML/1998/namespace");
201        }
202
203        self.stack
204            .iter()
205            .rev()
206            .find_map(|level| level.prefixes.get(prefix).copied())
207    }
208}
209
impl<'xml> Iterator for Context<'xml> {
    type Item = Result<Node<'xml>, Error>;

    /// Produce the next node of the document.
    ///
    /// Queued `records` are replayed first; otherwise tokens are pulled from
    /// `parser` until one of them yields a node. A start tag pushes a `Level`
    /// onto `stack`, which is popped again when the element closes. Namespace
    /// declarations (`xmlns`, `xmlns:*`) are recorded on that level instead of
    /// being returned as nodes. Other attributes are queued in `records` and are
    /// returned right after the element's `Open` node.
    fn next(&mut self) -> Option<Self::Item> {
        if let Some(record) = self.records.pop_front() {
            // A queued close (e.g. synthesized for a self-closing element)
            // ends its element, so that element's level leaves the stack now.
            if let Node::Close { .. } = &record {
                self.stack.pop();
            }
            return Some(Ok(record));
        }

        loop {
            // `?`: once the tokenizer is exhausted, this iterator is too.
            match self.parser.next()? {
                // Start tag: open a new scope; its attributes arrive as
                // separate tokens before the matching `ElementEnd`.
                Ok(Token::ElementStart { prefix, local, .. }) => {
                    let prefix = prefix.as_str();
                    self.stack.push(Level {
                        local: local.as_str(),
                        prefix: match prefix.is_empty() {
                            true => None,
                            false => Some(prefix),
                        },
                        default_ns: None,
                        prefixes: BTreeMap::new(),
                    });
                }
                Ok(Token::ElementEnd { end, .. }) => match end {
                    // `>` of a start tag: the element is now open. Its
                    // attributes are already queued in `records`.
                    ElementEnd::Open => {
                        let Some(level) = self.stack.last() else {
                            return Some(Err(Error::UnexpectedState(
                                "opening element with no parent",
                            )));
                        };

                        let element = Element {
                            local: level.local,
                            prefix: level.prefix,
                            default_ns: level.default_ns,
                        };

                        return Some(Ok(Node::Open(element)));
                    }
                    // `</name>`: pop the scope and check that the name matches
                    // the element being closed. On a mismatch the level has
                    // already been popped.
                    ElementEnd::Close(prefix, v) => {
                        let Some(level) = self.stack.pop() else {
                            return Some(Err(Error::UnexpectedState(
                                "closing element without parent",
                            )));
                        };

                        let prefix = match prefix.is_empty() {
                            true => None,
                            false => Some(prefix.as_str()),
                        };

                        return Some(match v.as_str() == level.local && prefix == level.prefix {
                            true => Ok(Node::Close {
                                prefix,
                                local: level.local,
                            }),
                            false => Err(Error::UnexpectedState("close element mismatch")),
                        });
                    }
                    // `/>`: report the element as opened now. A synthetic
                    // `Close` is queued behind the element's queued attributes,
                    // so callers see Open, attributes, then Close. The level
                    // stays on the stack until that `Close` is replayed.
                    ElementEnd::Empty => {
                        let Some(level) = self.stack.last() else {
                            return Some(Err(Error::UnexpectedState(
                                "opening element with no parent",
                            )));
                        };

                        self.records.push_back(Node::Close {
                            prefix: level.prefix,
                            local: level.local,
                        });

                        let element = Element {
                            local: level.local,
                            prefix: level.prefix,
                            default_ns: level.default_ns,
                        };

                        return Some(Ok(Node::Open(element)));
                    }
                },
                Ok(Token::Attribute {
                    prefix,
                    local,
                    value,
                    ..
                }) => {
                    // `xmlns="..."` sets this element's default namespace.
                    // The value is stored as written, without entity decoding.
                    if prefix.is_empty() && local.as_str() == "xmlns" {
                        match self.stack.last_mut() {
                            Some(level) => level.default_ns = Some(value.as_str()),
                            None => {
                                return Some(Err(Error::UnexpectedState(
                                    "attribute without element context",
                                )))
                            }
                        }
                    // `xmlns:p="..."` binds prefix `p` for this element and
                    // its descendants. The value is stored as written.
                    } else if prefix.as_str() == "xmlns" {
                        match self.stack.last_mut() {
                            Some(level) => {
                                level.prefixes.insert(local.as_str(), value.as_str());
                            }
                            None => {
                                return Some(Err(Error::UnexpectedState(
                                    "attribute without element context",
                                )))
                            }
                        }
                    // Ordinary attribute: decode its value and queue it to be
                    // returned after the element's `Open` node.
                    } else {
                        let value = match decode(value.as_str()) {
                            Ok(value) => value,
                            Err(e) => return Some(Err(e)),
                        };

                        self.records.push_back(Node::Attribute(Attribute {
                            prefix: match prefix.is_empty() {
                                true => None,
                                false => Some(prefix.as_str()),
                            },
                            local: local.as_str(),
                            value,
                        }));
                    }
                }
                // Character data has its entity/character references decoded.
                Ok(Token::Text { text }) => {
                    return Some(decode(text.as_str()).map(Node::Text));
                }
                // CDATA content is returned verbatim (no decoding).
                Ok(Token::Cdata { text, .. }) => {
                    return Some(Ok(Node::Text(Cow::Borrowed(text.as_str()))));
                }
                // An XML declaration is skipped only while no element is open.
                Ok(token @ Token::Declaration { .. }) => {
                    if !self.stack.is_empty() {
                        return Some(Err(Error::UnexpectedToken(format!("{token:?}"))));
                    }
                }
                Ok(Token::Comment { .. }) => continue,
                // Every other token kind is rejected.
                Ok(token) => return Some(Err(Error::UnexpectedToken(format!("{token:?}")))),
                Err(e) => return Some(Err(Error::Parse(e))),
            }
        }
    }
}
352
353/// Deserialize a borrowed `Cow<str>` value
354///
355/// Helper function for deserializing `Cow<str>` with zero-copy borrowing from the input.
356pub fn borrow_cow_str<'a, 'xml: 'a>(
357    into: &mut CowStrAccumulator<'xml, 'a>,
358    field: &'static str,
359    deserializer: &mut Deserializer<'_, 'xml>,
360) -> Result<(), Error> {
361    if into.inner.is_some() {
362        return Err(Error::DuplicateValue(field));
363    }
364
365    match deserializer.take_str()? {
366        Some(value) => into.inner = Some(value),
367        None => return Ok(()),
368    };
369
370    deserializer.ignore()?;
371    Ok(())
372}
373
374/// Deserialize a borrowed `Cow<[u8]>` value
375///
376/// Helper function for deserializing `Cow<[u8]>` with zero-copy borrowing from the input.
377pub fn borrow_cow_slice_u8<'xml>(
378    into: &mut Option<Cow<'xml, [u8]>>,
379    field: &'static str,
380    deserializer: &mut Deserializer<'_, 'xml>,
381) -> Result<(), Error> {
382    if into.is_some() {
383        return Err(Error::DuplicateValue(field));
384    }
385
386    if let Some(value) = deserializer.take_str()? {
387        *into = Some(match value {
388            Cow::Borrowed(v) => Cow::Borrowed(v.as_bytes()),
389            Cow::Owned(v) => Cow::Owned(v.into_bytes()),
390        });
391    }
392
393    deserializer.ignore()?;
394    Ok(())
395}
396
397fn decode(input: &str) -> Result<Cow<'_, str>, Error> {
398    let mut result = String::with_capacity(input.len());
399    let (mut state, mut last_end) = (DecodeState::Normal, 0);
400    for (i, &b) in input.as_bytes().iter().enumerate() {
401        // use a state machine to find entities
402        state = match (state, b) {
403            (DecodeState::Normal, b'&') => DecodeState::Entity([0; 6], 0),
404            (DecodeState::Normal, _) => DecodeState::Normal,
405            (DecodeState::Entity(chars, len), b';') => {
406                let decoded = match &chars[..len] {
407                    [b'a', b'm', b'p'] => '&',
408                    [b'a', b'p', b'o', b's'] => '\'',
409                    [b'g', b't'] => '>',
410                    [b'l', b't'] => '<',
411                    [b'q', b'u', b'o', b't'] => '"',
412                    [b'#', b'x' | b'X', hex @ ..] => {
413                        // Hexadecimal character reference e.g. "&#x007c;" -> '|'
414                        str::from_utf8(hex)
415                            .ok()
416                            .and_then(|hex_str| u32::from_str_radix(hex_str, 16).ok())
417                            .and_then(char::from_u32)
418                            .filter(valid_xml_character)
419                            .ok_or_else(|| {
420                                Error::InvalidEntity(
421                                    String::from_utf8_lossy(&chars[..len]).into_owned(),
422                                )
423                            })?
424                    }
425                    [b'#', decimal @ ..] => {
426                        // Decimal character reference e.g. "&#1234;" -> 'Ӓ'
427                        str::from_utf8(decimal)
428                            .ok()
429                            .and_then(|decimal_str| u32::from_str(decimal_str).ok())
430                            .and_then(char::from_u32)
431                            .filter(valid_xml_character)
432                            .ok_or_else(|| {
433                                Error::InvalidEntity(
434                                    String::from_utf8_lossy(&chars[..len]).into_owned(),
435                                )
436                            })?
437                    }
438                    _ => {
439                        return Err(Error::InvalidEntity(
440                            String::from_utf8_lossy(&chars[..len]).into_owned(),
441                        ))
442                    }
443                };
444
445                let start = i - (len + 1); // current position - (length of entity characters + 1 for '&')
446                if last_end < start {
447                    // Unwrap should be safe: `last_end` and `start` must be at character boundaries.
448                    result.push_str(input.get(last_end..start).unwrap());
449                }
450
451                last_end = i + 1;
452                result.push(decoded);
453                DecodeState::Normal
454            }
455            (DecodeState::Entity(mut chars, len), b) => {
456                if len >= 6 {
457                    let mut bytes = Vec::with_capacity(7);
458                    bytes.extend(&chars[..len]);
459                    bytes.push(b);
460                    return Err(Error::InvalidEntity(
461                        String::from_utf8_lossy(&bytes).into_owned(),
462                    ));
463                }
464
465                chars[len] = b;
466                DecodeState::Entity(chars, len + 1)
467            }
468        };
469    }
470
471    // Unterminated entity (& without ;) at end of input
472    if let DecodeState::Entity(chars, len) = state {
473        return Err(Error::InvalidEntity(
474            String::from_utf8_lossy(&chars[..len]).into_owned(),
475        ));
476    }
477
478    Ok(match result.is_empty() {
479        true => Cow::Borrowed(input),
480        false => {
481            // Unwrap should be safe: `last_end` and `input.len()` must be at character boundaries.
482            result.push_str(input.get(last_end..input.len()).unwrap());
483            Cow::Owned(result)
484        }
485    })
486}
487
488#[derive(Debug)]
489enum DecodeState {
490    Normal,
491    Entity([u8; 6], usize),
492}
493
494/// Valid character ranges per <https://www.w3.org/TR/xml/#NT-Char>
495fn valid_xml_character(c: &char) -> bool {
496    matches!(c, '\u{9}' | '\u{A}' | '\u{D}' | '\u{20}'..='\u{D7FF}' | '\u{E000}'..='\u{FFFD}' | '\u{10000}'..='\u{10FFFF}')
497}
498
499/// An XML node during deserialization
500#[derive(Debug)]
501pub enum Node<'xml> {
502    /// An attribute name (value follows in a separate AttributeValue node)
503    Attribute(Attribute<'xml>),
504    /// The value of the preceding Attribute node
505    AttributeValue(Cow<'xml, str>),
506    /// Closing tag for an element
507    Close {
508        /// The namespace prefix, if any
509        prefix: Option<&'xml str>,
510        /// The local name
511        local: &'xml str,
512    },
513    /// Text content
514    Text(Cow<'xml, str>),
515    /// Opening tag for an element
516    Open(Element<'xml>),
517}
518
/// An XML element during deserialization
///
/// Holds the element's name as written on its start tag, plus the default
/// namespace it declared itself, if any.
#[derive(Debug)]
pub struct Element<'xml> {
    /// Local (unprefixed) element name.
    local: &'xml str,
    /// Namespace from this element's own `xmlns="..."` attribute, if present.
    default_ns: Option<&'xml str>,
    /// Namespace prefix written on the tag, if any.
    prefix: Option<&'xml str>,
}
526
/// One open element on the parser's element stack, with the namespace
/// declarations made on it.
#[derive(Debug)]
struct Level<'xml> {
    /// Local name from the start tag.
    local: &'xml str,
    /// Prefix from the start tag, if any.
    prefix: Option<&'xml str>,
    /// Value of an `xmlns="..."` attribute on this element, if present.
    default_ns: Option<&'xml str>,
    /// Prefix-to-namespace bindings from `xmlns:prefix="..."` attributes on this element.
    prefixes: BTreeMap<&'xml str, &'xml str>,
}
534
/// An XML attribute during deserialization
///
/// Namespace declarations (`xmlns` and `xmlns:*`) are not surfaced as
/// attributes; they are recorded as namespace bindings instead.
#[derive(Debug)]
pub struct Attribute<'xml> {
    /// The namespace prefix, if any (`None` for an unprefixed name)
    pub prefix: Option<&'xml str>,
    /// The local name, without any prefix
    pub local: &'xml str,
    /// The attribute value, with entity and character references decoded
    pub value: Cow<'xml, str>,
}
545
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode() {
        let well_formed: &[(&str, &'static str)] = &[
            // Plain text and the predefined entities
            ("foo", "foo"),
            ("foo &amp; bar", "foo & bar"),
            ("foo &lt; bar", "foo < bar"),
            ("foo &gt; bar", "foo > bar"),
            ("foo &quot; bar", "foo \" bar"),
            ("foo &apos; bar", "foo ' bar"),
            ("foo &amp;lt; bar", "foo &lt; bar"),
            ("&amp; foo", "& foo"),
            ("foo &amp;", "foo &"),
            ("cbdtéda&amp;sü", "cbdtéda&sü"),
            // Decimal character references
            ("&#1234;", "Ӓ"),
            ("foo &#9; bar", "foo \t bar"),
            ("foo &#124; bar", "foo | bar"),
            ("foo &#1234; bar", "foo Ӓ bar"),
            // Hexadecimal character references
            ("&#xc4;", "Ä"),
            ("&#x00c4;", "Ä"),
            ("foo &#x9; bar", "foo \t bar"),
            ("foo &#x007c; bar", "foo | bar"),
            ("foo &#xc4; bar", "foo Ä bar"),
            ("foo &#x00c4; bar", "foo Ä bar"),
            ("foo &#x10de; bar", "foo პ bar"),
        ];
        for &(input, expected) in well_formed {
            decode_ok(input, expected);
        }

        let malformed = [
            "&",
            "&#",
            "&#;",
            "foo&",
            "&bar",
            "&foo;",
            "&foobar;",
            "cbdtéd&ampü",
        ];
        for input in malformed {
            decode_err(input);
        }
    }

    /// Assert that `input` decodes to exactly `expected`.
    fn decode_ok(input: &str, expected: &'static str) {
        assert_eq!(decode(input).unwrap(), expected, "{input:?}");
    }

    /// Assert that `decode` rejects `input`.
    fn decode_err(input: &str) {
        assert!(decode(input).is_err(), "{input:?}");
    }
}