Skip to main content

lo_core/
xml_parser.rs

//! Minimal pure-Rust XML parser used by importers (DOCX, XLSX, ODF, …).
//!
//! It is intentionally lenient. Namespace prefixes are kept verbatim on
//! element and attribute names (they are never resolved to namespace URIs),
//! and the lookup accessors also match on `local_name()` for convenience.

5
6use std::collections::BTreeMap;
7
8use crate::{LoError, Result};
9
/// One entry in an element's ordered content: a run of character data or a
/// child element.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum XmlItem {
    /// Character data. As produced by `parse_xml_document`, this is
    /// entity-decoded text or verbatim CDATA content.
    Text(String),
    /// A child element.
    Node(XmlNode),
}

/// An XML element.
///
/// Two views of the content coexist: `children` + `text` (child elements,
/// and all direct text concatenated) and `items` (text and child elements
/// interleaved in document order). `parse_xml_document` fills both, so each
/// child element is stored twice; code that mutates a tree should keep the
/// two views in sync.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct XmlNode {
    /// Qualified name as written, including any namespace prefix (e.g. `w:p`).
    pub name: String,
    /// Attributes keyed by qualified name; `parse_xml_document` stores the
    /// values entity-decoded.
    pub attributes: BTreeMap<String, String>,
    /// Child elements in document order.
    pub children: Vec<XmlNode>,
    /// Text runs and child elements interleaved in document order.
    pub items: Vec<XmlItem>,
    /// Concatenation of this element's own text (descendant text excluded).
    pub text: String,
}
24
25impl XmlNode {
26    pub fn local_name(&self) -> &str {
27        local_name(&self.name)
28    }
29
30    pub fn attr(&self, name: &str) -> Option<&str> {
31        self.attributes.get(name).map(String::as_str).or_else(|| {
32            self.attributes
33                .iter()
34                .find(|(key, _)| key.as_str() == name || local_name(key.as_str()) == name)
35                .map(|(_, value)| value.as_str())
36        })
37    }
38
39    pub fn child(&self, name: &str) -> Option<&XmlNode> {
40        self.children
41            .iter()
42            .find(|child| child.local_name() == name || child.name == name)
43    }
44
45    pub fn children_named<'a>(&'a self, name: &'a str) -> impl Iterator<Item = &'a XmlNode> + 'a {
46        self.children
47            .iter()
48            .filter(move |child| child.local_name() == name || child.name == name)
49    }
50
51    pub fn descendants_named<'a>(&'a self, name: &'a str, out: &mut Vec<&'a XmlNode>) {
52        for child in &self.children {
53            if child.local_name() == name || child.name == name {
54                out.push(child);
55            }
56            child.descendants_named(name, out);
57        }
58    }
59
60    pub fn text_content(&self) -> String {
61        let mut out = String::new();
62        collect_text(self, &mut out);
63        out
64    }
65}
66
67fn collect_text(node: &XmlNode, out: &mut String) {
68    if !node.text.is_empty() {
69        out.push_str(&node.text);
70    }
71    for child in &node.children {
72        collect_text(child, out);
73    }
74}
75
/// Strip the namespace prefix from a qualified XML name.
///
/// Everything up to and including the last `:` is dropped (`w:p` → `p`);
/// a name without a colon is returned unchanged.
pub fn local_name(name: &str) -> &str {
    match name.rfind(':') {
        Some(colon) => &name[colon + 1..],
        None => name,
    }
}
81
82/// Serialize an `XmlNode` tree back into a self-closed/open XML string.
83/// Includes the XML declaration. Useful when an importer parses, mutates,
84/// and re-emits the document (e.g. recalculating an XLSX in place).
85pub fn serialize_xml_document(root: &XmlNode) -> String {
86    let mut out = String::new();
87    out.push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
88    serialize_xml_node(root, &mut out);
89    out
90}
91
92pub fn serialize_xml_node(node: &XmlNode, out: &mut String) {
93    out.push('<');
94    out.push_str(&node.name);
95    for (key, value) in &node.attributes {
96        out.push(' ');
97        out.push_str(key);
98        out.push_str("=\"");
99        out.push_str(&xml_attr_escape(value));
100        out.push('"');
101    }
102    if node.items.is_empty() {
103        out.push_str("/>");
104        return;
105    }
106    out.push('>');
107    for item in &node.items {
108        match item {
109            XmlItem::Text(text) => out.push_str(&xml_text_escape(text)),
110            XmlItem::Node(child) => serialize_xml_node(child, out),
111        }
112    }
113    out.push_str("</");
114    out.push_str(&node.name);
115    out.push('>');
116}
117
/// Escape `value` for use as XML character data.
///
/// `&`, `<` and `>` become entity references. A carriage return is written as
/// `&#13;` because a conforming XML reader rewrites literal CR / CRLF line
/// endings to LF (XML 1.0 §2.11, end-of-line handling); a character reference
/// survives that step, so the CR round-trips. All other characters, including
/// non-ASCII ones, are copied unchanged.
fn xml_text_escape(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '\r' => out.push_str("&#13;"),
            _ => out.push(ch),
        }
    }
    out
}
130
/// Escape `value` for use inside a double- or single-quoted XML attribute.
///
/// Markup characters and both quote characters become entity references.
/// Tab, LF and CR are written as character references (`&#9;`, `&#10;`,
/// `&#13;`). A conforming reader would otherwise replace the literal
/// characters with spaces during attribute-value normalization (XML 1.0
/// §3.3.3), silently corrupting values such as multi-line cell comments.
fn xml_attr_escape(value: &str) -> String {
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&apos;"),
            '\t' => out.push_str("&#9;"),
            '\n' => out.push_str("&#10;"),
            '\r' => out.push_str("&#13;"),
            _ => out.push(ch),
        }
    }
    out
}
145
146pub fn parse_xml_document(xml: &str) -> Result<XmlNode> {
147    let bytes = xml.as_bytes();
148    let mut stack: Vec<XmlNode> = Vec::new();
149    let mut root: Option<XmlNode> = None;
150    let mut index = 0usize;
151
152    while index < bytes.len() {
153        if bytes[index] == b'<' {
154            if bytes[index..].starts_with(b"<!--") {
155                let end = find_bytes(bytes, index + 4, b"-->")?;
156                index = end + 3;
157                continue;
158            }
159            if bytes[index..].starts_with(b"<![CDATA[") {
160                let end = find_bytes(bytes, index + 9, b"]]>")?;
161                let text = String::from_utf8(bytes[index + 9..end].to_vec())
162                    .map_err(|err| LoError::Parse(format!("invalid cdata utf-8: {err}")))?;
163                if let Some(current) = stack.last_mut() {
164                    current.text.push_str(&text);
165                    current.items.push(XmlItem::Text(text));
166                }
167                index = end + 3;
168                continue;
169            }
170            if bytes[index..].starts_with(b"<?") {
171                let end = find_bytes(bytes, index + 2, b"?>")?;
172                index = end + 2;
173                continue;
174            }
175            if bytes[index..].starts_with(b"<!") {
176                let end = find_byte(bytes, index + 2, b'>')?;
177                index = end + 1;
178                continue;
179            }
180            if bytes[index..].starts_with(b"</") {
181                let end = find_byte(bytes, index + 2, b'>')?;
182                let name = String::from_utf8(bytes[index + 2..end].to_vec())
183                    .map_err(|err| LoError::Parse(format!("invalid closing tag name: {err}")))?;
184                let node = stack.pop().ok_or_else(|| {
185                    LoError::Parse("xml closing tag without opening tag".to_string())
186                })?;
187                if local_name(name.trim()) != node.local_name() {
188                    return Err(LoError::Parse(format!(
189                        "xml closing tag mismatch: expected {}, found {}",
190                        node.name,
191                        name.trim()
192                    )));
193                }
194                if let Some(parent) = stack.last_mut() {
195                    parent.children.push(node.clone());
196                    parent.items.push(XmlItem::Node(node));
197                } else if root.is_none() {
198                    root = Some(node);
199                } else {
200                    return Err(LoError::Parse("multiple xml roots".to_string()));
201                }
202                index = end + 1;
203                continue;
204            }
205
206            let end = find_tag_end(bytes, index + 1)?;
207            let raw = String::from_utf8(bytes[index + 1..end].to_vec())
208                .map_err(|err| LoError::Parse(format!("invalid tag utf-8: {err}")))?;
209            let self_closing = raw.trim_end().ends_with('/');
210            let raw = if self_closing {
211                raw.trim_end().trim_end_matches('/').trim_end().to_string()
212            } else {
213                raw
214            };
215            let (name, attributes) = parse_start_tag(&raw)?;
216            let node = XmlNode {
217                name,
218                attributes,
219                children: Vec::new(),
220                items: Vec::new(),
221                text: String::new(),
222            };
223            if self_closing {
224                if let Some(parent) = stack.last_mut() {
225                    parent.children.push(node.clone());
226                    parent.items.push(XmlItem::Node(node));
227                } else if root.is_none() {
228                    root = Some(node);
229                } else {
230                    return Err(LoError::Parse("multiple xml roots".to_string()));
231                }
232            } else {
233                stack.push(node);
234            }
235            index = end + 1;
236        } else {
237            let next = find_byte_optional(bytes, index, b'<').unwrap_or(bytes.len());
238            let raw_text = String::from_utf8(bytes[index..next].to_vec())
239                .map_err(|err| LoError::Parse(format!("invalid text utf-8: {err}")))?;
240            let decoded = decode_entities(&raw_text);
241            if let Some(current) = stack.last_mut() {
242                current.text.push_str(&decoded);
243                if !decoded.is_empty() {
244                    current.items.push(XmlItem::Text(decoded));
245                }
246            }
247            index = next;
248        }
249    }
250
251    while let Some(node) = stack.pop() {
252        if let Some(parent) = stack.last_mut() {
253            parent.children.push(node.clone());
254            parent.items.push(XmlItem::Node(node));
255        } else if root.is_none() {
256            root = Some(node);
257        } else {
258            return Err(LoError::Parse("multiple xml roots".to_string()));
259        }
260    }
261
262    root.ok_or_else(|| LoError::Parse("empty xml document".to_string()))
263}
264
265fn find_bytes(haystack: &[u8], start: usize, needle: &[u8]) -> Result<usize> {
266    haystack[start..]
267        .windows(needle.len())
268        .position(|window| window == needle)
269        .map(|offset| start + offset)
270        .ok_or_else(|| LoError::Parse("unterminated xml construct".to_string()))
271}
272
273fn find_byte(bytes: &[u8], start: usize, byte: u8) -> Result<usize> {
274    find_byte_optional(bytes, start, byte)
275        .ok_or_else(|| LoError::Parse("unterminated xml tag".to_string()))
276}
277
/// Absolute index of the first `byte` at or after `start`, or `None` if it
/// does not occur. Panics if `start > bytes.len()`.
fn find_byte_optional(bytes: &[u8], start: usize, byte: u8) -> Option<usize> {
    let tail = &bytes[start..];
    tail.iter()
        .enumerate()
        .find(|&(_, &candidate)| candidate == byte)
        .map(|(offset, _)| start + offset)
}
284
285fn find_tag_end(bytes: &[u8], start: usize) -> Result<usize> {
286    let mut quote: Option<u8> = None;
287    for index in start..bytes.len() {
288        let byte = bytes[index];
289        match quote {
290            Some(current) if byte == current => quote = None,
291            Some(_) => {}
292            None if byte == b'\'' || byte == b'"' => quote = Some(byte),
293            None if byte == b'>' => return Ok(index),
294            None => {}
295        }
296    }
297    Err(LoError::Parse("unterminated xml start tag".to_string()))
298}
299
300fn parse_start_tag(raw: &str) -> Result<(String, BTreeMap<String, String>)> {
301    let mut chars = raw.chars().peekable();
302    let mut name = String::new();
303    while let Some(&ch) = chars.peek() {
304        if ch.is_whitespace() {
305            break;
306        }
307        name.push(ch);
308        chars.next();
309    }
310    if name.is_empty() {
311        return Err(LoError::Parse("empty xml tag name".to_string()));
312    }
313    while matches!(chars.peek(), Some(ch) if ch.is_whitespace()) {
314        chars.next();
315    }
316    let mut attrs = BTreeMap::new();
317    while chars.peek().is_some() {
318        let mut key = String::new();
319        while let Some(&ch) = chars.peek() {
320            if ch.is_whitespace() || ch == '=' {
321                break;
322            }
323            key.push(ch);
324            chars.next();
325        }
326        while matches!(chars.peek(), Some(ch) if ch.is_whitespace()) {
327            chars.next();
328        }
329        if chars.next() != Some('=') {
330            return Err(LoError::Parse(format!("malformed xml attribute {key}")));
331        }
332        while matches!(chars.peek(), Some(ch) if ch.is_whitespace()) {
333            chars.next();
334        }
335        let quote = chars
336            .next()
337            .ok_or_else(|| LoError::Parse("unexpected end of xml attribute".to_string()))?;
338        if quote != '\'' && quote != '"' {
339            return Err(LoError::Parse("xml attribute must be quoted".to_string()));
340        }
341        let mut value = String::new();
342        for ch in chars.by_ref() {
343            if ch == quote {
344                break;
345            }
346            value.push(ch);
347        }
348        attrs.insert(key, decode_entities(&value));
349        while matches!(chars.peek(), Some(ch) if ch.is_whitespace()) {
350            chars.next();
351        }
352    }
353    Ok((name, attrs))
354}
355
/// Decode XML character and entity references in `text`.
///
/// Recognises the five predefined entities (`&amp;`, `&lt;`, `&gt;`,
/// `&quot;`, `&apos;`) and decimal/hexadecimal character references
/// (`&#65;`, `&#x41;`, `&#X41;`). Everything else is copied through
/// verbatim, because this module is deliberately lenient. That includes
/// unknown named entities (e.g. `&nbsp;`), malformed or out-of-range
/// references, and bare `&` characters.
///
/// The input is handled as `&str` slices, so multi-byte UTF-8 characters
/// (e.g. `æ`) are copied unchanged.
pub fn decode_entities(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;
    while let Some(amp) = rest.find('&') {
        out.push_str(&rest[..amp]);
        let after = &rest[amp + 1..];
        // A reference body never contains whitespace, '&' or '<'. Stopping the
        // scan there has two effects:
        // - a bare '&' (tolerated by the lenient parser) cannot swallow a
        //   later valid reference, as in "Tom & Jerry &amp; co";
        // - the overall scan stays linear instead of quadratic when the input
        //   has many '&' and no ';'.
        let stop = after.find(|c: char| c == ';' || c == '&' || c == '<' || c.is_whitespace());
        match stop {
            Some(len) if after.as_bytes()[len] == b';' => {
                let entity = &after[..len];
                match resolve_entity(entity) {
                    Some(ch) => out.push(ch),
                    None => {
                        out.push('&');
                        out.push_str(entity);
                        out.push(';');
                    }
                }
                rest = &after[len + 1..];
            }
            _ => {
                // Not a terminated reference: keep the '&' literally.
                out.push('&');
                rest = after;
            }
        }
    }
    out.push_str(rest);
    out
}

/// Map a reference body (the text between `&` and `;`) to its character.
///
/// Returns `None` for unknown names, numeric bodies that are empty or contain
/// non-digits (including a sign, which `from_str_radix` alone would accept),
/// and code points that are not valid `char`s (e.g. surrogates).
fn resolve_entity(entity: &str) -> Option<char> {
    match entity {
        "amp" => Some('&'),
        "lt" => Some('<'),
        "gt" => Some('>'),
        "quot" => Some('"'),
        "apos" => Some('\''),
        _ => {
            let (digits, radix) = if let Some(hex) = entity
                .strip_prefix("#x")
                .or_else(|| entity.strip_prefix("#X"))
            {
                (hex, 16)
            } else {
                (entity.strip_prefix('#')?, 10)
            };
            if digits.is_empty() || !digits.chars().all(|c| c.is_digit(radix)) {
                return None;
            }
            u32::from_str_radix(digits, radix)
                .ok()
                .and_then(char::from_u32)
        }
    }
}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// A small tree with an attribute, text and a self-closed sibling.
    #[test]
    fn parser_handles_simple_tree() {
        let root = parse_xml_document("<root><a x=\"1\">hi</a><b/></root>").unwrap();
        assert_eq!(root.local_name(), "root");

        let a = root.child("a").unwrap();
        assert_eq!(a.text_content(), "hi");
        assert_eq!(a.attr("x"), Some("1"));
    }

    /// Named, decimal and hexadecimal references in one string.
    #[test]
    fn decode_entities_handles_named_and_numeric() {
        let decoded = decode_entities("&amp;&lt;&#65;&#x42;");
        assert_eq!(decoded, "&<AB");
    }
}