Skip to main content

instant_xml/
de.rs

1use std::borrow::Cow;
2use std::collections::{BTreeMap, VecDeque};
3use std::str::{self, FromStr};
4
5use xmlparser::{ElementEnd, Token, Tokenizer};
6
7use crate::impls::CowStrAccumulator;
8use crate::{Error, Id};
9
10/// XML deserializer for iterating over nodes in an element
11pub struct Deserializer<'cx, 'xml> {
12    pub(crate) local: &'xml str,
13    prefix: Option<&'xml str>,
14    level: usize,
15    done: bool,
16    context: &'cx mut Context<'xml>,
17}
18
19impl<'cx, 'xml> Deserializer<'cx, 'xml> {
20    pub(crate) fn new(element: Element<'xml>, context: &'cx mut Context<'xml>) -> Self {
21        let level = context.stack.len();
22        Self {
23            local: element.local,
24            prefix: element.prefix,
25            level,
26            done: false,
27            context,
28        }
29    }
30
31    /// Extract a string value from the current node
32    ///
33    /// Consumes a text node or attribute value, returning the content as a string.
34    pub fn take_str(&mut self) -> Result<Option<Cow<'xml, str>>, Error> {
35        loop {
36            match self.next() {
37                Some(Ok(Node::AttributeValue(s))) => return Ok(Some(s)),
38                Some(Ok(Node::Text(s))) => return Ok(Some(s)),
39                Some(Ok(Node::Attribute(_))) => continue,
40                Some(Ok(node)) => return Err(Error::ExpectedScalar(format!("{node:?}"))),
41                Some(Err(e)) => return Err(e),
42                None => return Ok(None),
43            }
44        }
45    }
46
47    /// Create a nested deserializer for a child element
48    pub fn nested<'a>(&'a mut self, element: Element<'xml>) -> Deserializer<'a, 'xml>
49    where
50        'cx: 'a,
51    {
52        Deserializer::new(element, self.context)
53    }
54
55    /// Skip all remaining nodes in the current element
56    pub fn ignore(&mut self) -> Result<(), Error> {
57        loop {
58            match self.next() {
59                Some(Err(e)) => return Err(e),
60                Some(Ok(Node::Open(element))) => {
61                    let mut nested = self.nested(element);
62                    nested.ignore()?;
63                }
64                Some(_) => continue,
65                None => return Ok(()),
66            }
67        }
68    }
69
70    /// Create a deserializer that will yield the given node first
71    pub fn for_node<'a>(&'a mut self, node: Node<'xml>) -> Deserializer<'a, 'xml>
72    where
73        'cx: 'a,
74    {
75        self.context.records.push_front(node);
76        Deserializer {
77            local: self.local,
78            prefix: self.prefix,
79            level: self.level,
80            done: self.done,
81            context: self.context,
82        }
83    }
84
85    /// Get the identifier of the parent element
86    pub fn parent(&self) -> Id<'xml> {
87        Id {
88            ns: match self.prefix {
89                Some(ns) => self.context.lookup(ns).unwrap(),
90                None => self.context.default_ns(),
91            },
92            name: self.local,
93        }
94    }
95
96    /// Get the identifier of an element (name and namespace)
97    #[inline]
98    pub fn element_id(&self, element: &Element<'xml>) -> Result<Id<'xml>, Error> {
99        self.context.element_id(element)
100    }
101
102    /// Get the identifier of an attribute (name and namespace)
103    #[inline]
104    pub fn attribute_id(&self, attr: &Attribute<'xml>) -> Result<Id<'xml>, Error> {
105        self.context.attribute_id(attr)
106    }
107}
108
109impl<'xml> Iterator for Deserializer<'_, 'xml> {
110    type Item = Result<Node<'xml>, Error>;
111
112    fn next(&mut self) -> Option<Self::Item> {
113        if self.done {
114            return None;
115        }
116
117        let (prefix, local) = match self.context.next() {
118            Some(Ok(Node::Close { prefix, local })) => (prefix, local),
119            item => return item,
120        };
121
122        if self.context.stack.len() == self.level - 1
123            && local == self.local
124            && prefix == self.prefix
125        {
126            self.done = true;
127            return None;
128        }
129
130        Some(Err(Error::UnexpectedState("close element mismatch")))
131    }
132}
133
134pub(crate) struct Context<'xml> {
135    parser: Tokenizer<'xml>,
136    stack: Vec<Level<'xml>>,
137    records: VecDeque<Node<'xml>>,
138}
139
140impl<'xml> Context<'xml> {
141    pub(crate) fn new(input: &'xml str) -> Result<(Self, Element<'xml>), Error> {
142        let mut new = Self {
143            parser: Tokenizer::from(input),
144            stack: Vec::new(),
145            records: VecDeque::new(),
146        };
147
148        let root = match new.next() {
149            Some(result) => match result? {
150                Node::Open(element) => element,
151                _ => return Err(Error::UnexpectedState("first node does not open element")),
152            },
153            None => return Err(Error::UnexpectedEndOfStream),
154        };
155
156        Ok((new, root))
157    }
158
159    pub(crate) fn element_id(&self, element: &Element<'xml>) -> Result<Id<'xml>, Error> {
160        Ok(Id {
161            ns: match (element.default_ns, element.prefix) {
162                (_, Some(prefix)) => match self.lookup(prefix) {
163                    Some(ns) => ns,
164                    None => return Err(Error::UnknownPrefix(prefix.to_owned())),
165                },
166                (Some(ns), None) => ns,
167                (None, None) => self.default_ns(),
168            },
169            name: element.local,
170        })
171    }
172
173    fn attribute_id(&self, attr: &Attribute<'xml>) -> Result<Id<'xml>, Error> {
174        Ok(Id {
175            ns: match attr.prefix {
176                Some(ns) => self
177                    .lookup(ns)
178                    .ok_or_else(|| Error::UnknownPrefix(ns.to_owned()))?,
179                None => "",
180            },
181            name: attr.local,
182        })
183    }
184
185    fn default_ns(&self) -> &'xml str {
186        self.stack
187            .iter()
188            .rev()
189            .find_map(|level| level.default_ns)
190            .unwrap_or("")
191    }
192
193    fn lookup(&self, prefix: &str) -> Option<&'xml str> {
194        // The prefix xml is by definition bound to the namespace
195        // name http://www.w3.org/XML/1998/namespace
196        // See https://www.w3.org/TR/xml-names/#ns-decl
197        if prefix == "xml" {
198            return Some("http://www.w3.org/XML/1998/namespace");
199        }
200
201        self.stack
202            .iter()
203            .rev()
204            .find_map(|level| level.prefixes.get(prefix).copied())
205    }
206}
207
208impl<'xml> Iterator for Context<'xml> {
209    type Item = Result<Node<'xml>, Error>;
210
211    fn next(&mut self) -> Option<Self::Item> {
212        if let Some(record) = self.records.pop_front() {
213            if let Node::Close { .. } = &record {
214                self.stack.pop();
215            }
216            return Some(Ok(record));
217        }
218
219        loop {
220            match self.parser.next()? {
221                Ok(Token::ElementStart { prefix, local, .. }) => {
222                    let prefix = prefix.as_str();
223                    self.stack.push(Level {
224                        local: local.as_str(),
225                        prefix: match prefix.is_empty() {
226                            true => None,
227                            false => Some(prefix),
228                        },
229                        default_ns: None,
230                        prefixes: BTreeMap::new(),
231                    });
232                }
233                Ok(Token::ElementEnd { end, .. }) => match end {
234                    ElementEnd::Open => {
235                        let level = match self.stack.last() {
236                            Some(level) => level,
237                            None => {
238                                return Some(Err(Error::UnexpectedState(
239                                    "opening element with no parent",
240                                )))
241                            }
242                        };
243
244                        let element = Element {
245                            local: level.local,
246                            prefix: level.prefix,
247                            default_ns: level.default_ns,
248                        };
249
250                        return Some(Ok(Node::Open(element)));
251                    }
252                    ElementEnd::Close(prefix, v) => {
253                        let level = match self.stack.pop() {
254                            Some(level) => level,
255                            None => {
256                                return Some(Err(Error::UnexpectedState(
257                                    "closing element without parent",
258                                )))
259                            }
260                        };
261
262                        let prefix = match prefix.is_empty() {
263                            true => None,
264                            false => Some(prefix.as_str()),
265                        };
266
267                        match v.as_str() == level.local && prefix == level.prefix {
268                            true => {
269                                return Some(Ok(Node::Close {
270                                    prefix,
271                                    local: level.local,
272                                }))
273                            }
274                            false => {
275                                return Some(Err(Error::UnexpectedState("close element mismatch")))
276                            }
277                        }
278                    }
279                    ElementEnd::Empty => {
280                        let level = match self.stack.last() {
281                            Some(level) => level,
282                            None => {
283                                return Some(Err(Error::UnexpectedState(
284                                    "opening element with no parent",
285                                )))
286                            }
287                        };
288
289                        self.records.push_back(Node::Close {
290                            prefix: level.prefix,
291                            local: level.local,
292                        });
293
294                        let element = Element {
295                            local: level.local,
296                            prefix: level.prefix,
297                            default_ns: level.default_ns,
298                        };
299
300                        return Some(Ok(Node::Open(element)));
301                    }
302                },
303                Ok(Token::Attribute {
304                    prefix,
305                    local,
306                    value,
307                    ..
308                }) => {
309                    if prefix.is_empty() && local.as_str() == "xmlns" {
310                        match self.stack.last_mut() {
311                            Some(level) => level.default_ns = Some(value.as_str()),
312                            None => {
313                                return Some(Err(Error::UnexpectedState(
314                                    "attribute without element context",
315                                )))
316                            }
317                        }
318                    } else if prefix.as_str() == "xmlns" {
319                        match self.stack.last_mut() {
320                            Some(level) => {
321                                level.prefixes.insert(local.as_str(), value.as_str());
322                            }
323                            None => {
324                                return Some(Err(Error::UnexpectedState(
325                                    "attribute without element context",
326                                )))
327                            }
328                        }
329                    } else {
330                        let value = match decode(value.as_str()) {
331                            Ok(value) => value,
332                            Err(e) => return Some(Err(e)),
333                        };
334
335                        self.records.push_back(Node::Attribute(Attribute {
336                            prefix: match prefix.is_empty() {
337                                true => None,
338                                false => Some(prefix.as_str()),
339                            },
340                            local: local.as_str(),
341                            value,
342                        }));
343                    }
344                }
345                Ok(Token::Text { text }) => {
346                    return Some(decode(text.as_str()).map(Node::Text));
347                }
348                Ok(Token::Cdata { text, .. }) => {
349                    return Some(Ok(Node::Text(Cow::Borrowed(text.as_str()))));
350                }
351                Ok(token @ Token::Declaration { .. }) => {
352                    if !self.stack.is_empty() {
353                        return Some(Err(Error::UnexpectedToken(format!("{token:?}"))));
354                    }
355                }
356                Ok(Token::Comment { .. }) => continue,
357                Ok(token) => return Some(Err(Error::UnexpectedToken(format!("{token:?}")))),
358                Err(e) => return Some(Err(Error::Parse(e))),
359            }
360        }
361    }
362}
363
364/// Deserialize a borrowed `Cow<str>` value
365///
366/// Helper function for deserializing `Cow<str>` with zero-copy borrowing from the input.
367pub fn borrow_cow_str<'a, 'xml: 'a>(
368    into: &mut CowStrAccumulator<'xml, 'a>,
369    field: &'static str,
370    deserializer: &mut Deserializer<'_, 'xml>,
371) -> Result<(), Error> {
372    if into.inner.is_some() {
373        return Err(Error::DuplicateValue(field));
374    }
375
376    match deserializer.take_str()? {
377        Some(value) => into.inner = Some(value),
378        None => return Ok(()),
379    };
380
381    deserializer.ignore()?;
382    Ok(())
383}
384
385/// Deserialize a borrowed `Cow<[u8]>` value
386///
387/// Helper function for deserializing `Cow<[u8]>` with zero-copy borrowing from the input.
388pub fn borrow_cow_slice_u8<'xml>(
389    into: &mut Option<Cow<'xml, [u8]>>,
390    field: &'static str,
391    deserializer: &mut Deserializer<'_, 'xml>,
392) -> Result<(), Error> {
393    if into.is_some() {
394        return Err(Error::DuplicateValue(field));
395    }
396
397    if let Some(value) = deserializer.take_str()? {
398        *into = Some(match value {
399            Cow::Borrowed(v) => Cow::Borrowed(v.as_bytes()),
400            Cow::Owned(v) => Cow::Owned(v.into_bytes()),
401        });
402    }
403
404    deserializer.ignore()?;
405    Ok(())
406}
407
408fn decode(input: &str) -> Result<Cow<'_, str>, Error> {
409    let mut result = String::with_capacity(input.len());
410    let (mut state, mut last_end) = (DecodeState::Normal, 0);
411    for (i, &b) in input.as_bytes().iter().enumerate() {
412        // use a state machine to find entities
413        state = match (state, b) {
414            (DecodeState::Normal, b'&') => DecodeState::Entity([0; 6], 0),
415            (DecodeState::Normal, _) => DecodeState::Normal,
416            (DecodeState::Entity(chars, len), b';') => {
417                let decoded = match &chars[..len] {
418                    [b'a', b'm', b'p'] => '&',
419                    [b'a', b'p', b'o', b's'] => '\'',
420                    [b'g', b't'] => '>',
421                    [b'l', b't'] => '<',
422                    [b'q', b'u', b'o', b't'] => '"',
423                    [b'#', b'x' | b'X', hex @ ..] => {
424                        // Hexadecimal character reference e.g. "&#x007c;" -> '|'
425                        str::from_utf8(hex)
426                            .ok()
427                            .and_then(|hex_str| u32::from_str_radix(hex_str, 16).ok())
428                            .and_then(char::from_u32)
429                            .filter(valid_xml_character)
430                            .ok_or_else(|| {
431                                Error::InvalidEntity(
432                                    String::from_utf8_lossy(&chars[..len]).into_owned(),
433                                )
434                            })?
435                    }
436                    [b'#', decimal @ ..] => {
437                        // Decimal character reference e.g. "&#1234;" -> 'Ӓ'
438                        str::from_utf8(decimal)
439                            .ok()
440                            .and_then(|decimal_str| u32::from_str(decimal_str).ok())
441                            .and_then(char::from_u32)
442                            .filter(valid_xml_character)
443                            .ok_or_else(|| {
444                                Error::InvalidEntity(
445                                    String::from_utf8_lossy(&chars[..len]).into_owned(),
446                                )
447                            })?
448                    }
449                    _ => {
450                        return Err(Error::InvalidEntity(
451                            String::from_utf8_lossy(&chars[..len]).into_owned(),
452                        ))
453                    }
454                };
455
456                let start = i - (len + 1); // current position - (length of entity characters + 1 for '&')
457                if last_end < start {
458                    // Unwrap should be safe: `last_end` and `start` must be at character boundaries.
459                    result.push_str(input.get(last_end..start).unwrap());
460                }
461
462                last_end = i + 1;
463                result.push(decoded);
464                DecodeState::Normal
465            }
466            (DecodeState::Entity(mut chars, len), b) => {
467                if len >= 6 {
468                    let mut bytes = Vec::with_capacity(7);
469                    bytes.extend(&chars[..len]);
470                    bytes.push(b);
471                    return Err(Error::InvalidEntity(
472                        String::from_utf8_lossy(&bytes).into_owned(),
473                    ));
474                }
475
476                chars[len] = b;
477                DecodeState::Entity(chars, len + 1)
478            }
479        };
480    }
481
482    // Unterminated entity (& without ;) at end of input
483    if let DecodeState::Entity(chars, len) = state {
484        return Err(Error::InvalidEntity(
485            String::from_utf8_lossy(&chars[..len]).into_owned(),
486        ));
487    }
488
489    Ok(match result.is_empty() {
490        true => Cow::Borrowed(input),
491        false => {
492            // Unwrap should be safe: `last_end` and `input.len()` must be at character boundaries.
493            result.push_str(input.get(last_end..input.len()).unwrap());
494            Cow::Owned(result)
495        }
496    })
497}
498
499#[derive(Debug)]
500enum DecodeState {
501    Normal,
502    Entity([u8; 6], usize),
503}
504
/// Valid character ranges per <https://www.w3.org/TR/xml/#NT-Char>
///
/// `Char ::= #x9 | #xA | #xD | [#x20-#xD7FF] | [#xE000-#xFFFD] | [#x10000-#x10FFFF]`
fn valid_xml_character(c: &char) -> bool {
    let code_point = u32::from(*c);
    matches!(
        code_point,
        0x9 | 0xA | 0xD | 0x20..=0xD7FF | 0xE000..=0xFFFD | 0x10000..=0x10FFFF
    )
}
509
510/// An XML node during deserialization
511#[derive(Debug)]
512pub enum Node<'xml> {
513    /// An attribute name (value follows in a separate AttributeValue node)
514    Attribute(Attribute<'xml>),
515    /// The value of the preceding Attribute node
516    AttributeValue(Cow<'xml, str>),
517    /// Closing tag for an element
518    Close {
519        /// The namespace prefix, if any
520        prefix: Option<&'xml str>,
521        /// The local name
522        local: &'xml str,
523    },
524    /// Text content
525    Text(Cow<'xml, str>),
526    /// Opening tag for an element
527    Open(Element<'xml>),
528}
529
/// An XML element during deserialization
///
/// Produced by the parser as the payload of [`Node::Open`]. Use
/// [`Deserializer::element_id`] to resolve its namespace.
#[derive(Debug)]
pub struct Element<'xml> {
    /// Local (unprefixed) name from the start tag.
    local: &'xml str,
    /// Default namespace declared by an `xmlns` attribute on this element itself, if any.
    default_ns: Option<&'xml str>,
    /// Namespace prefix used in the tag name, if any.
    prefix: Option<&'xml str>,
}
537
/// One entry of the parser's stack of open elements.
#[derive(Debug)]
struct Level<'xml> {
    /// Local name from the start tag, used to validate the matching end tag.
    local: &'xml str,
    /// Prefix from the start tag, if any.
    prefix: Option<&'xml str>,
    /// Value of an `xmlns="..."` attribute on this element, if present.
    default_ns: Option<&'xml str>,
    /// Bindings declared by `xmlns:prefix="..."` attributes on this element (prefix -> namespace).
    prefixes: BTreeMap<&'xml str, &'xml str>,
}
545
/// An XML attribute during deserialization
///
/// Namespace declarations (`xmlns`, `xmlns:*`) are consumed by the parser and never
/// appear as attributes.
#[derive(Debug)]
pub struct Attribute<'xml> {
    /// The namespace prefix, if any; resolve it with [`Deserializer::attribute_id`].
    pub prefix: Option<&'xml str>,
    /// The local name
    pub local: &'xml str,
    /// The attribute value, with entity and character references decoded
    pub value: Cow<'xml, str>,
}
556
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode() {
        let accepted: &[(&str, &str)] = &[
            ("foo", "foo"),
            ("foo &amp; bar", "foo & bar"),
            ("foo &lt; bar", "foo < bar"),
            ("foo &gt; bar", "foo > bar"),
            ("foo &quot; bar", "foo \" bar"),
            ("foo &apos; bar", "foo ' bar"),
            ("foo &amp;lt; bar", "foo &lt; bar"),
            ("&amp; foo", "& foo"),
            ("foo &amp;", "foo &"),
            ("cbdtéda&amp;sü", "cbdtéda&sü"),
            // Decimal character references
            ("&#1234;", "Ӓ"),
            ("foo &#9; bar", "foo \t bar"),
            ("foo &#124; bar", "foo | bar"),
            ("foo &#1234; bar", "foo Ӓ bar"),
            // Hexadecimal character references
            ("&#xc4;", "Ä"),
            ("&#x00c4;", "Ä"),
            ("foo &#x9; bar", "foo \t bar"),
            ("foo &#x007c; bar", "foo | bar"),
            ("foo &#xc4; bar", "foo Ä bar"),
            ("foo &#x00c4; bar", "foo Ä bar"),
            ("foo &#x10de; bar", "foo პ bar"),
        ];
        for &(input, expected) in accepted {
            assert_eq!(super::decode(input).unwrap(), expected, "{input:?}");
        }

        let rejected: &[&str] = &[
            "&",
            "&#",
            "&#;",
            "foo&",
            "&bar",
            "&foo;",
            "&foobar;",
            "cbdtéd&ampü",
        ];
        for &input in rejected {
            assert!(super::decode(input).is_err(), "{input:?}");
        }
    }
}