Skip to main content

instant_xml/
de.rs

1//! XML deserialization support code
2
3use std::borrow::Cow;
4use std::collections::{BTreeMap, VecDeque};
5use std::str::{self, FromStr};
6
7use xmlparser::{ElementEnd, Token, Tokenizer};
8
9use crate::impls::CowStrAccumulator;
10use crate::{Error, Id};
11
/// XML deserializer for iterating over nodes in an element
///
/// Yields the nodes nested inside a single element and stops (returns `None`)
/// once the close tag of that element has been read from the shared `Context`.
pub struct Deserializer<'cx, 'xml> {
    /// Local name of the element this deserializer is scoped to
    pub(crate) local: &'xml str,
    /// Namespace prefix of that element, if any
    prefix: Option<&'xml str>,
    /// Depth of the context's element stack when this deserializer was created
    level: usize,
    /// Set once this element's close tag has been seen; iteration then stays ended
    done: bool,
    /// Shared parsing state (tokenizer, open-element stack, queued nodes)
    context: &'cx mut Context<'xml>,
}
20
21impl<'cx, 'xml> Deserializer<'cx, 'xml> {
22    pub(crate) fn new(element: Element<'xml>, context: &'cx mut Context<'xml>) -> Self {
23        let level = context.stack.len();
24        Self {
25            local: element.local,
26            prefix: element.prefix,
27            level,
28            done: false,
29            context,
30        }
31    }
32
33    /// Extract a string value from the current node
34    ///
35    /// Consumes a text node or attribute value, returning the content as a string.
36    pub fn take_str(&mut self) -> Result<Option<Cow<'xml, str>>, Error> {
37        loop {
38            match self.next() {
39                Some(Ok(Node::AttributeValue(s))) => return Ok(Some(s)),
40                Some(Ok(Node::Text(s))) => return Ok(Some(s)),
41                Some(Ok(Node::Attribute(_))) => continue,
42                Some(Ok(node)) => return Err(Error::ExpectedScalar(format!("{node:?}"))),
43                Some(Err(e)) => return Err(e),
44                None => return Ok(None),
45            }
46        }
47    }
48
49    /// Create a nested deserializer for a child element
50    pub fn nested<'a>(&'a mut self, element: Element<'xml>) -> Deserializer<'a, 'xml>
51    where
52        'cx: 'a,
53    {
54        Deserializer::new(element, self.context)
55    }
56
57    /// Skip all remaining nodes in the current element
58    pub fn ignore(&mut self) -> Result<(), Error> {
59        loop {
60            match self.next() {
61                Some(Err(e)) => return Err(e),
62                Some(Ok(Node::Open(element))) => {
63                    let mut nested = self.nested(element);
64                    nested.ignore()?;
65                }
66                Some(_) => continue,
67                None => return Ok(()),
68            }
69        }
70    }
71
72    /// Create a deserializer that will yield the given node first
73    pub fn for_node<'a>(&'a mut self, node: Node<'xml>) -> Deserializer<'a, 'xml>
74    where
75        'cx: 'a,
76    {
77        self.context.records.push_front(node);
78        Deserializer {
79            local: self.local,
80            prefix: self.prefix,
81            level: self.level,
82            done: self.done,
83            context: self.context,
84        }
85    }
86
87    /// Get the identifier of the parent element
88    pub fn parent(&self) -> Id<'xml> {
89        Id {
90            ns: match self.prefix {
91                Some(ns) => self.context.lookup(ns).unwrap(),
92                None => self.context.default_ns(),
93            },
94            name: self.local,
95        }
96    }
97
98    pub fn namespace(&self, prefix: &str) -> Option<&'xml str> {
99        self.context.lookup(prefix)
100    }
101
102    /// Get the identifier of an element (name and namespace)
103    #[inline]
104    pub fn element_id(&self, element: &Element<'xml>) -> Result<Id<'xml>, Error> {
105        self.context.element_id(element)
106    }
107
108    /// Get the identifier of an attribute (name and namespace)
109    #[inline]
110    pub fn attribute_id(&self, attr: &Attribute<'xml>) -> Result<Id<'xml>, Error> {
111        self.context.attribute_id(attr)
112    }
113}
114
115impl<'xml> Iterator for Deserializer<'_, 'xml> {
116    type Item = Result<Node<'xml>, Error>;
117
118    fn next(&mut self) -> Option<Self::Item> {
119        if self.done {
120            return None;
121        }
122
123        let (prefix, local) = match self.context.next() {
124            Some(Ok(Node::Close { prefix, local })) => (prefix, local),
125            item => return item,
126        };
127
128        if self.context.stack.len() == self.level - 1
129            && local == self.local
130            && prefix == self.prefix
131        {
132            self.done = true;
133            return None;
134        }
135
136        Some(Err(Error::UnexpectedState("close element mismatch")))
137    }
138}
139
/// Shared parsing state behind every `Deserializer` for one document
pub(crate) struct Context<'xml> {
    /// Tokenizer over the raw input
    parser: Tokenizer<'xml>,
    /// Currently open elements, innermost last, with their namespace declarations
    stack: Vec<Level<'xml>>,
    /// Nodes queued to be yielded before the tokenizer is consulted again:
    /// buffered attributes, close tags synthesized for self-closing elements,
    /// and nodes pushed back via `Deserializer::for_node`
    records: VecDeque<Node<'xml>>,
}
145
146impl<'xml> Context<'xml> {
147    pub(crate) fn new(input: &'xml str) -> Result<(Self, Element<'xml>), Error> {
148        let mut new = Self {
149            parser: Tokenizer::from(input),
150            stack: Vec::new(),
151            records: VecDeque::new(),
152        };
153
154        let root = match new.next() {
155            Some(result) => match result? {
156                Node::Open(element) => element,
157                _ => return Err(Error::UnexpectedState("first node does not open element")),
158            },
159            None => return Err(Error::UnexpectedEndOfStream),
160        };
161
162        Ok((new, root))
163    }
164
165    pub(crate) fn element_id(&self, element: &Element<'xml>) -> Result<Id<'xml>, Error> {
166        Ok(Id {
167            ns: match (element.default_ns, element.prefix) {
168                (_, Some(prefix)) => match self.lookup(prefix) {
169                    Some(ns) => ns,
170                    None => return Err(Error::UnknownPrefix(prefix.to_owned())),
171                },
172                (Some(ns), None) => ns,
173                (None, None) => self.default_ns(),
174            },
175            name: element.local,
176        })
177    }
178
179    fn attribute_id(&self, attr: &Attribute<'xml>) -> Result<Id<'xml>, Error> {
180        Ok(Id {
181            ns: match attr.prefix {
182                Some(ns) => self
183                    .lookup(ns)
184                    .ok_or_else(|| Error::UnknownPrefix(ns.to_owned()))?,
185                None => "",
186            },
187            name: attr.local,
188        })
189    }
190
191    fn default_ns(&self) -> &'xml str {
192        self.stack
193            .iter()
194            .rev()
195            .find_map(|level| level.default_ns)
196            .unwrap_or("")
197    }
198
199    fn lookup(&self, prefix: &str) -> Option<&'xml str> {
200        // The prefix xml is by definition bound to the namespace
201        // name http://www.w3.org/XML/1998/namespace
202        // See https://www.w3.org/TR/xml-names/#ns-decl
203        if prefix == "xml" {
204            return Some("http://www.w3.org/XML/1998/namespace");
205        }
206
207        self.stack
208            .iter()
209            .rev()
210            .find_map(|level| level.prefixes.get(prefix).copied())
211    }
212}
213
impl<'xml> Iterator for Context<'xml> {
    type Item = Result<Node<'xml>, Error>;

    /// Produce the next node of the document
    ///
    /// Queued `records` are drained first. Otherwise tokens are pulled from the
    /// tokenizer until one produces a node. Element starts and namespace
    /// declarations only update the element stack. Regular attributes are queued,
    /// so they are yielded right after the element's `Open` node.
    fn next(&mut self) -> Option<Self::Item> {
        // Queued nodes take priority. A queued `Close` pops the element stack,
        // just as a parsed close tag does below.
        if let Some(record) = self.records.pop_front() {
            if let Node::Close { .. } = &record {
                self.stack.pop();
            }
            return Some(Ok(record));
        }

        loop {
            // `?` ends iteration once the tokenizer is exhausted.
            match self.parser.next()? {
                // Start of a tag (`<p:a`): push a level. The `Open` node is only
                // emitted at `ElementEnd`, after any `xmlns` attributes are known.
                Ok(Token::ElementStart { prefix, local, .. }) => {
                    let prefix = prefix.as_str();
                    self.stack.push(Level {
                        local: local.as_str(),
                        prefix: match prefix.is_empty() {
                            true => None,
                            false => Some(prefix),
                        },
                        default_ns: None,
                        prefixes: BTreeMap::new(),
                    });
                }
                Ok(Token::ElementEnd { end, .. }) => match end {
                    // End of a start tag (`>`): announce the element. Its queued
                    // attributes follow on subsequent calls.
                    ElementEnd::Open => {
                        let Some(level) = self.stack.last() else {
                            return Some(Err(Error::UnexpectedState(
                                "opening element with no parent",
                            )));
                        };

                        let element = Element {
                            local: level.local,
                            prefix: level.prefix,
                            default_ns: level.default_ns,
                        };

                        return Some(Ok(Node::Open(element)));
                    }
                    // Close tag (`</p:a>`): pop the level, then verify the name
                    // matches. On a mismatch the level has already been popped.
                    ElementEnd::Close(prefix, v) => {
                        let Some(level) = self.stack.pop() else {
                            return Some(Err(Error::UnexpectedState(
                                "closing element without parent",
                            )));
                        };

                        let prefix = match prefix.is_empty() {
                            true => None,
                            false => Some(prefix.as_str()),
                        };

                        return Some(match v.as_str() == level.local && prefix == level.prefix {
                            true => Ok(Node::Close {
                                prefix,
                                local: level.local,
                            }),
                            false => Err(Error::UnexpectedState("close element mismatch")),
                        });
                    }
                    // Self-closing tag (`/>`): emit `Open` now and queue a
                    // matching `Close` behind the already-queued attributes. The
                    // level stays on the stack until that `Close` is dequeued.
                    ElementEnd::Empty => {
                        let Some(level) = self.stack.last() else {
                            return Some(Err(Error::UnexpectedState(
                                "opening element with no parent",
                            )));
                        };

                        self.records.push_back(Node::Close {
                            prefix: level.prefix,
                            local: level.local,
                        });

                        let element = Element {
                            local: level.local,
                            prefix: level.prefix,
                            default_ns: level.default_ns,
                        };

                        return Some(Ok(Node::Open(element)));
                    }
                },
                Ok(Token::Attribute {
                    prefix,
                    local,
                    value,
                    ..
                }) => {
                    // `xmlns="..."` sets the default namespace of the current element.
                    // NOTE(review): namespace values are stored raw (not entity-decoded).
                    if prefix.is_empty() && local.as_str() == "xmlns" {
                        match self.stack.last_mut() {
                            Some(level) => level.default_ns = Some(value.as_str()),
                            None => {
                                return Some(Err(Error::UnexpectedState(
                                    "attribute without element context",
                                )))
                            }
                        }
                    // `xmlns:p="..."` binds prefix `p` on the current element.
                    } else if prefix.as_str() == "xmlns" {
                        match self.stack.last_mut() {
                            Some(level) => {
                                level.prefixes.insert(local.as_str(), value.as_str());
                            }
                            None => {
                                return Some(Err(Error::UnexpectedState(
                                    "attribute without element context",
                                )))
                            }
                        }
                    // Any other attribute: decode its value and queue it to be
                    // yielded after the element's `Open` node.
                    } else {
                        let value = match decode(value.as_str()) {
                            Ok(value) => value,
                            Err(e) => return Some(Err(e)),
                        };

                        self.records.push_back(Node::Attribute(Attribute {
                            prefix: match prefix.is_empty() {
                                true => None,
                                false => Some(prefix.as_str()),
                            },
                            local: local.as_str(),
                            value,
                        }));
                    }
                }
                // Text has entity and character references decoded.
                Ok(Token::Text { text }) => {
                    return Some(decode(text.as_str()).map(Node::Text));
                }
                // CDATA content is passed through verbatim, without decoding.
                Ok(Token::Cdata { text, .. }) => {
                    return Some(Ok(Node::Text(Cow::Borrowed(text.as_str()))));
                }
                // An XML declaration is only accepted while no element is open.
                Ok(token @ Token::Declaration { .. }) => {
                    if !self.stack.is_empty() {
                        return Some(Err(Error::UnexpectedToken(format!("{token:?}"))));
                    }
                }
                Ok(Token::Comment { .. }) => continue,
                // Any other token kind (presumably processing instructions and DTD
                // parts) is rejected.
                Ok(token) => return Some(Err(Error::UnexpectedToken(format!("{token:?}")))),
                Err(e) => return Some(Err(Error::Parse(e))),
            }
        }
    }
}
356
357/// Deserialize a borrowed `Cow<str>` value
358///
359/// Helper function for deserializing `Cow<str>` with zero-copy borrowing from the input.
360pub fn borrow_cow_str<'a, 'xml: 'a>(
361    into: &mut CowStrAccumulator<'xml, 'a>,
362    field: &'static str,
363    deserializer: &mut Deserializer<'_, 'xml>,
364) -> Result<(), Error> {
365    if into.inner.is_some() {
366        return Err(Error::DuplicateValue(field));
367    }
368
369    match deserializer.take_str()? {
370        Some(value) => into.inner = Some(value),
371        None => return Ok(()),
372    };
373
374    deserializer.ignore()?;
375    Ok(())
376}
377
378/// Deserialize a borrowed `Cow<[u8]>` value
379///
380/// Helper function for deserializing `Cow<[u8]>` with zero-copy borrowing from the input.
381pub fn borrow_cow_slice_u8<'xml>(
382    into: &mut Option<Cow<'xml, [u8]>>,
383    field: &'static str,
384    deserializer: &mut Deserializer<'_, 'xml>,
385) -> Result<(), Error> {
386    if into.is_some() {
387        return Err(Error::DuplicateValue(field));
388    }
389
390    if let Some(value) = deserializer.take_str()? {
391        *into = Some(match value {
392            Cow::Borrowed(v) => Cow::Borrowed(v.as_bytes()),
393            Cow::Owned(v) => Cow::Owned(v.into_bytes()),
394        });
395    }
396
397    deserializer.ignore()?;
398    Ok(())
399}
400
401fn decode(input: &str) -> Result<Cow<'_, str>, Error> {
402    let mut result = String::with_capacity(input.len());
403    let (mut state, mut last_end) = (DecodeState::Normal, 0);
404    for (i, &b) in input.as_bytes().iter().enumerate() {
405        // use a state machine to find entities
406        state = match (state, b) {
407            (DecodeState::Normal, b'&') => DecodeState::Entity([0; 6], 0),
408            (DecodeState::Normal, _) => DecodeState::Normal,
409            (DecodeState::Entity(chars, len), b';') => {
410                let decoded = match &chars[..len] {
411                    [b'a', b'm', b'p'] => '&',
412                    [b'a', b'p', b'o', b's'] => '\'',
413                    [b'g', b't'] => '>',
414                    [b'l', b't'] => '<',
415                    [b'q', b'u', b'o', b't'] => '"',
416                    [b'#', b'x' | b'X', hex @ ..] => {
417                        // Hexadecimal character reference e.g. "&#x007c;" -> '|'
418                        str::from_utf8(hex)
419                            .ok()
420                            .and_then(|hex_str| u32::from_str_radix(hex_str, 16).ok())
421                            .and_then(char::from_u32)
422                            .filter(valid_xml_character)
423                            .ok_or_else(|| {
424                                Error::InvalidEntity(
425                                    String::from_utf8_lossy(&chars[..len]).into_owned(),
426                                )
427                            })?
428                    }
429                    [b'#', decimal @ ..] => {
430                        // Decimal character reference e.g. "&#1234;" -> 'Ӓ'
431                        str::from_utf8(decimal)
432                            .ok()
433                            .and_then(|decimal_str| u32::from_str(decimal_str).ok())
434                            .and_then(char::from_u32)
435                            .filter(valid_xml_character)
436                            .ok_or_else(|| {
437                                Error::InvalidEntity(
438                                    String::from_utf8_lossy(&chars[..len]).into_owned(),
439                                )
440                            })?
441                    }
442                    _ => {
443                        return Err(Error::InvalidEntity(
444                            String::from_utf8_lossy(&chars[..len]).into_owned(),
445                        ))
446                    }
447                };
448
449                let start = i - (len + 1); // current position - (length of entity characters + 1 for '&')
450                if last_end < start {
451                    // Unwrap should be safe: `last_end` and `start` must be at character boundaries.
452                    result.push_str(input.get(last_end..start).unwrap());
453                }
454
455                last_end = i + 1;
456                result.push(decoded);
457                DecodeState::Normal
458            }
459            (DecodeState::Entity(mut chars, len), b) => {
460                if len >= 6 {
461                    let mut bytes = Vec::with_capacity(7);
462                    bytes.extend(&chars[..len]);
463                    bytes.push(b);
464                    return Err(Error::InvalidEntity(
465                        String::from_utf8_lossy(&bytes).into_owned(),
466                    ));
467                }
468
469                chars[len] = b;
470                DecodeState::Entity(chars, len + 1)
471            }
472        };
473    }
474
475    // Unterminated entity (& without ;) at end of input
476    if let DecodeState::Entity(chars, len) = state {
477        return Err(Error::InvalidEntity(
478            String::from_utf8_lossy(&chars[..len]).into_owned(),
479        ));
480    }
481
482    Ok(match result.is_empty() {
483        true => Cow::Borrowed(input),
484        false => {
485            // Unwrap should be safe: `last_end` and `input.len()` must be at character boundaries.
486            result.push_str(input.get(last_end..input.len()).unwrap());
487            Cow::Owned(result)
488        }
489    })
490}
491
492#[derive(Debug)]
493enum DecodeState {
494    Normal,
495    Entity([u8; 6], usize),
496}
497
498/// Valid character ranges per <https://www.w3.org/TR/xml/#NT-Char>
499fn valid_xml_character(c: &char) -> bool {
500    matches!(c, '\u{9}' | '\u{A}' | '\u{D}' | '\u{20}'..='\u{D7FF}' | '\u{E000}'..='\u{FFFD}' | '\u{10000}'..='\u{10FFFF}')
501}
502
/// An XML node during deserialization
#[derive(Debug)]
pub enum Node<'xml> {
    /// An attribute of the most recently opened element, carrying its decoded value
    Attribute(Attribute<'xml>),
    /// A standalone attribute value
    ///
    /// Not produced by the tokenizer-driven iteration in this module; it only
    /// appears if pushed back through `Deserializer::for_node` (presumably by
    /// generated code — verify against callers). `take_str` treats it like text.
    AttributeValue(Cow<'xml, str>),
    /// Closing tag for an element (also synthesized for self-closing `<a/>` tags)
    Close {
        /// The namespace prefix, if any
        prefix: Option<&'xml str>,
        /// The local name
        local: &'xml str,
    },
    /// Text content (entity-decoded) or CDATA content (verbatim)
    Text(Cow<'xml, str>),
    /// Opening tag for an element; its attributes follow as `Attribute` nodes
    Open(Element<'xml>),
}
522
/// An XML element during deserialization
///
/// Holds the element's name parts; `Deserializer::element_id` resolves its namespace.
#[derive(Debug)]
pub struct Element<'xml> {
    /// The local name (without prefix)
    local: &'xml str,
    /// Default namespace declared on this element itself via `xmlns="..."`, if any
    default_ns: Option<&'xml str>,
    /// The namespace prefix, if any
    prefix: Option<&'xml str>,
}
530
/// Bookkeeping for one currently open element on the context's stack
#[derive(Debug)]
struct Level<'xml> {
    /// The element's local name
    local: &'xml str,
    /// The element's namespace prefix, if any
    prefix: Option<&'xml str>,
    /// Default namespace declared on this element (`xmlns="..."`), if any
    default_ns: Option<&'xml str>,
    /// Prefix bindings declared on this element (`xmlns:prefix="..."`), prefix to namespace
    prefixes: BTreeMap<&'xml str, &'xml str>,
}
538
/// An XML attribute during deserialization
///
/// Namespace declarations (`xmlns`, `xmlns:*`) are consumed by the parser and
/// never surface as attributes.
#[derive(Debug)]
pub struct Attribute<'xml> {
    /// The namespace prefix, if any
    pub prefix: Option<&'xml str>,
    /// The local name
    pub local: &'xml str,
    /// The attribute value, with entity and character references decoded
    pub value: Cow<'xml, str>,
}
549
#[cfg(test)]
mod tests {
    use super::*;

    /// Inputs paired with the output `decode` must produce for them
    const DECODES: &[(&str, &str)] = &[
        ("foo", "foo"),
        ("foo &amp; bar", "foo & bar"),
        ("foo &lt; bar", "foo < bar"),
        ("foo &gt; bar", "foo > bar"),
        ("foo &quot; bar", "foo \" bar"),
        ("foo &apos; bar", "foo ' bar"),
        ("foo &amp;lt; bar", "foo &lt; bar"),
        ("&amp; foo", "& foo"),
        ("foo &amp;", "foo &"),
        ("cbdtéda&amp;sü", "cbdtéda&sü"),
        // Decimal character references
        ("&#1234;", "Ӓ"),
        ("foo &#9; bar", "foo \t bar"),
        ("foo &#124; bar", "foo | bar"),
        ("foo &#1234; bar", "foo Ӓ bar"),
        // Hexadecimal character references
        ("&#xc4;", "Ä"),
        ("&#x00c4;", "Ä"),
        ("foo &#x9; bar", "foo \t bar"),
        ("foo &#x007c; bar", "foo | bar"),
        ("foo &#xc4; bar", "foo Ä bar"),
        ("foo &#x00c4; bar", "foo Ä bar"),
        ("foo &#x10de; bar", "foo პ bar"),
    ];

    /// Inputs `decode` must reject
    const REJECTS: &[&str] = &[
        "&",
        "&#",
        "&#;",
        "foo&",
        "&bar",
        "&foo;",
        "&foobar;",
        "cbdtéd&ampü",
    ];

    #[test]
    fn test_decode() {
        for &(input, expected) in DECODES {
            assert_eq!(decode(input).unwrap(), expected, "{input:?}");
        }

        for &input in REJECTS {
            assert!(decode(input).is_err(), "{input:?}");
        }
    }
}