use std::collections::BTreeMap;
use crate::{LoError, Result};
/// One entry in an element's ordered content.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum XmlItem {
    /// A run of character data, or the body of a CDATA section.
    Text(String),
    /// A child element.
    Node(XmlNode),
}

/// An XML element together with its attributes and content.
///
/// The content is stored twice: flat (`children` + `text`) for simple lookups, and interleaved
/// (`items`) so that the original order of text runs and child elements can be recovered.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct XmlNode {
    /// Tag name as written in the source, including any `prefix:`.
    pub name: String,
    /// Attribute values keyed by the attribute name as written.
    pub attributes: BTreeMap<String, String>,
    /// Child elements, in order.
    pub children: Vec<XmlNode>,
    /// Text runs and child elements interleaved in order.
    pub items: Vec<XmlItem>,
    /// Concatenation of this element's own text runs (descendant text is not included).
    pub text: String,
}

impl XmlNode {
    /// Returns the element name without its namespace prefix.
    pub fn local_name(&self) -> &str {
        local_name(&self.name)
    }

    /// Returns the value of the attribute called `name`.
    ///
    /// An exact key match wins. Otherwise the first attribute (in key order) whose local name
    /// equals `name` is returned, so `attr("style-name")` finds `text:style-name`.
    pub fn attr(&self, name: &str) -> Option<&str> {
        if let Some(exact) = self.attributes.get(name) {
            return Some(exact.as_str());
        }
        // No key equals `name` here, so only the prefix-stripped comparison can still match.
        self.attributes
            .iter()
            .find(|(key, _)| local_name(key.as_str()) == name)
            .map(|(_, value)| value.as_str())
    }

    /// Returns the first direct child whose full or local name equals `name`.
    pub fn child(&self, name: &str) -> Option<&XmlNode> {
        self.children
            .iter()
            .find(|child| child.local_name() == name || child.name == name)
    }

    /// Iterates over the direct children whose full or local name equals `name`.
    pub fn children_named<'a>(&'a self, name: &'a str) -> impl Iterator<Item = &'a XmlNode> + 'a {
        self.children
            .iter()
            .filter(move |child| child.local_name() == name || child.name == name)
    }

    /// Appends to `out` every descendant whose full or local name equals `name`.
    ///
    /// Descendants are visited depth-first, each element before its own children. The node the
    /// method is called on is never added.
    pub fn descendants_named<'a>(&'a self, name: &'a str, out: &mut Vec<&'a XmlNode>) {
        for child in &self.children {
            if child.local_name() == name || child.name == name {
                out.push(child);
            }
            child.descendants_named(name, out);
        }
    }

    /// Returns the text of this element and all of its descendants as one string.
    ///
    /// When `items` is populated consistently with `text` and `children`, the text is produced
    /// in document order. Otherwise the element's own text comes first, followed by the text of
    /// each child.
    pub fn text_content(&self) -> String {
        let mut out = String::new();
        collect_text(self, &mut out);
        out
    }
}

/// Appends the text of `node` and of all its descendants to `out`.
///
/// `items` records where each text run sits relative to the child elements, so it is used to
/// emit mixed content such as `Hello <b>big</b> world` in document order. Nodes whose `items`
/// do not mirror `text` and `children` (for example nodes built by hand with only the flat
/// fields filled in) keep the older order: own text first, then each child.
fn collect_text(node: &XmlNode, out: &mut String) {
    let mut item_nodes = 0usize;
    let mut item_text_len = 0usize;
    for item in &node.items {
        match item {
            XmlItem::Text(text) => item_text_len += text.len(),
            XmlItem::Node(_) => item_nodes += 1,
        }
    }
    let items_mirror_fields =
        item_nodes == node.children.len() && item_text_len == node.text.len();

    if !items_mirror_fields {
        out.push_str(&node.text);
        for child in &node.children {
            collect_text(child, out);
        }
        return;
    }

    // `items` only supplies the positions; the elements themselves are read from `children`
    // so that edits made to `children` after parsing are still reflected.
    let mut children = node.children.iter();
    for item in &node.items {
        match item {
            XmlItem::Text(text) => out.push_str(text),
            XmlItem::Node(_) => {
                if let Some(child) = children.next() {
                    collect_text(child, out);
                }
            }
        }
    }
}

/// Returns the part of `name` after its last `:`, or all of `name` when it contains no `:`.
pub fn local_name(name: &str) -> &str {
    name.rsplit_once(':')
        .map(|(_, local)| local)
        .unwrap_or(name)
}
pub fn parse_xml_document(xml: &str) -> Result<XmlNode> {
let bytes = xml.as_bytes();
let mut stack: Vec<XmlNode> = Vec::new();
let mut root: Option<XmlNode> = None;
let mut index = 0usize;
while index < bytes.len() {
if bytes[index] == b'<' {
if bytes[index..].starts_with(b"<!--") {
let end = find_bytes(bytes, index + 4, b"-->")?;
index = end + 3;
continue;
}
if bytes[index..].starts_with(b"<![CDATA[") {
let end = find_bytes(bytes, index + 9, b"]]>")?;
let text = String::from_utf8(bytes[index + 9..end].to_vec())
.map_err(|err| LoError::Parse(format!("invalid cdata utf-8: {err}")))?;
if let Some(current) = stack.last_mut() {
current.text.push_str(&text);
current.items.push(XmlItem::Text(text));
}
index = end + 3;
continue;
}
if bytes[index..].starts_with(b"<?") {
let end = find_bytes(bytes, index + 2, b"?>")?;
index = end + 2;
continue;
}
if bytes[index..].starts_with(b"<!") {
let end = find_byte(bytes, index + 2, b'>')?;
index = end + 1;
continue;
}
if bytes[index..].starts_with(b"</") {
let end = find_byte(bytes, index + 2, b'>')?;
let name = String::from_utf8(bytes[index + 2..end].to_vec())
.map_err(|err| LoError::Parse(format!("invalid closing tag name: {err}")))?;
let node = stack.pop().ok_or_else(|| {
LoError::Parse("xml closing tag without opening tag".to_string())
})?;
if local_name(name.trim()) != node.local_name() {
return Err(LoError::Parse(format!(
"xml closing tag mismatch: expected {}, found {}",
node.name,
name.trim()
)));
}
if let Some(parent) = stack.last_mut() {
parent.children.push(node.clone());
parent.items.push(XmlItem::Node(node));
} else if root.is_none() {
root = Some(node);
} else {
return Err(LoError::Parse("multiple xml roots".to_string()));
}
index = end + 1;
continue;
}
let end = find_tag_end(bytes, index + 1)?;
let raw = String::from_utf8(bytes[index + 1..end].to_vec())
.map_err(|err| LoError::Parse(format!("invalid tag utf-8: {err}")))?;
let self_closing = raw.trim_end().ends_with('/');
let raw = if self_closing {
raw.trim_end().trim_end_matches('/').trim_end().to_string()
} else {
raw
};
let (name, attributes) = parse_start_tag(&raw)?;
let node = XmlNode {
name,
attributes,
children: Vec::new(),
items: Vec::new(),
text: String::new(),
};
if self_closing {
if let Some(parent) = stack.last_mut() {
parent.children.push(node.clone());
parent.items.push(XmlItem::Node(node));
} else if root.is_none() {
root = Some(node);
} else {
return Err(LoError::Parse("multiple xml roots".to_string()));
}
} else {
stack.push(node);
}
index = end + 1;
} else {
let next = find_byte_optional(bytes, index, b'<').unwrap_or(bytes.len());
let raw_text = String::from_utf8(bytes[index..next].to_vec())
.map_err(|err| LoError::Parse(format!("invalid text utf-8: {err}")))?;
let decoded = decode_entities(&raw_text);
if let Some(current) = stack.last_mut() {
current.text.push_str(&decoded);
if !decoded.is_empty() {
current.items.push(XmlItem::Text(decoded));
}
}
index = next;
}
}
while let Some(node) = stack.pop() {
if let Some(parent) = stack.last_mut() {
parent.children.push(node.clone());
parent.items.push(XmlItem::Node(node));
} else if root.is_none() {
root = Some(node);
} else {
return Err(LoError::Parse("multiple xml roots".to_string()));
}
}
root.ok_or_else(|| LoError::Parse("empty xml document".to_string()))
}
/// Returns the index of the first occurrence of `needle` in `haystack` at or after `start`.
///
/// Fails with a parse error when `needle` does not occur. Panics if `start` is past the end
/// of `haystack` or if `needle` is empty.
fn find_bytes(haystack: &[u8], start: usize, needle: &[u8]) -> Result<usize> {
    let tail = &haystack[start..];
    for (offset, candidate) in tail.windows(needle.len()).enumerate() {
        if candidate == needle {
            return Ok(start + offset);
        }
    }
    Err(LoError::Parse("unterminated xml construct".to_string()))
}
/// Returns the index of the first `byte` in `bytes` at or after `start`, or a parse error
/// when there is none.
fn find_byte(bytes: &[u8], start: usize, byte: u8) -> Result<usize> {
    match find_byte_optional(bytes, start, byte) {
        Some(position) => Ok(position),
        None => Err(LoError::Parse("unterminated xml tag".to_string())),
    }
}
/// Returns the index of the first `byte` in `bytes` at or after `start`, if any.
///
/// Panics if `start` is past the end of `bytes`.
fn find_byte_optional(bytes: &[u8], start: usize, byte: u8) -> Option<usize> {
    let tail = &bytes[start..];
    for (offset, &candidate) in tail.iter().enumerate() {
        if candidate == byte {
            return Some(start + offset);
        }
    }
    None
}
/// Returns the index of the `>` that ends the start tag whose content begins at `start`.
///
/// A `>` inside a single- or double-quoted attribute value does not end the tag. Fails with a
/// parse error when the input ends first (including when a quote is never closed).
fn find_tag_end(bytes: &[u8], start: usize) -> Result<usize> {
    // The quote byte that opened the attribute value being scanned, if any.
    let mut open_quote: Option<u8> = None;
    for (position, &byte) in bytes.iter().enumerate().skip(start) {
        if let Some(delimiter) = open_quote {
            if byte == delimiter {
                open_quote = None;
            }
            continue;
        }
        match byte {
            b'\'' | b'"' => open_quote = Some(byte),
            b'>' => return Ok(position),
            _ => {}
        }
    }
    Err(LoError::Parse("unterminated xml start tag".to_string()))
}
/// Splits the inside of a start tag (without `<`, `>` or a trailing `/`) into the tag name and
/// its attributes.
///
/// The name is everything up to the first whitespace. Each attribute must have the form
/// `key = "value"` or `key = 'value'` (whitespace around `=` is optional); values are passed
/// through `decode_entities`. A later attribute with the same key replaces an earlier one.
///
/// Fails with a parse error when the name is empty, an attribute has no `=`, or a value is
/// missing or not quoted.
fn parse_start_tag(raw: &str) -> Result<(String, BTreeMap<String, String>)> {
    /// Advances past any run of whitespace.
    fn skip_whitespace(chars: &mut std::iter::Peekable<std::str::Chars<'_>>) {
        while chars.next_if(|ch| ch.is_whitespace()).is_some() {}
    }

    let mut chars = raw.chars().peekable();

    let mut name = String::new();
    while let Some(ch) = chars.next_if(|ch| !ch.is_whitespace()) {
        name.push(ch);
    }
    if name.is_empty() {
        return Err(LoError::Parse("empty xml tag name".to_string()));
    }
    skip_whitespace(&mut chars);

    let mut attrs = BTreeMap::new();
    while chars.peek().is_some() {
        let mut key = String::new();
        while let Some(ch) = chars.next_if(|ch| !ch.is_whitespace() && *ch != '=') {
            key.push(ch);
        }
        skip_whitespace(&mut chars);

        match chars.next() {
            Some('=') => {}
            _ => return Err(LoError::Parse(format!("malformed xml attribute {key}"))),
        }
        skip_whitespace(&mut chars);

        let quote = match chars.next() {
            Some(ch) if ch == '\'' || ch == '"' => ch,
            Some(_) => return Err(LoError::Parse("xml attribute must be quoted".to_string())),
            None => {
                return Err(LoError::Parse(
                    "unexpected end of xml attribute".to_string(),
                ))
            }
        };

        // `take_while` also consumes the closing quote, leaving the cursor just after it.
        let value: String = chars.by_ref().take_while(|&ch| ch != quote).collect();
        attrs.insert(key, decode_entities(&value));
        skip_whitespace(&mut chars);
    }

    Ok((name, attrs))
}
/// Replaces XML character and entity references in `text` with the characters they denote.
///
/// Recognised references are the five predefined entities (`&amp;`, `&lt;`, `&gt;`, `&quot;`,
/// `&apos;`) and numeric references in decimal (`&#65;`) or hexadecimal (`&#x41;`) form.
///
/// Everything else is copied through unchanged: text without references (including non-ASCII
/// characters), unknown named entities, numeric references that do not name a valid `char`,
/// and a bare `&` that does not start a recognised reference.
pub fn decode_entities(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;

    while let Some(amp) = rest.find('&') {
        // Copy the literal text in front of the ampersand as a whole `str` slice so that
        // multi-byte characters stay intact.
        out.push_str(&rest[..amp]);
        let after = &rest[amp + 1..];

        // A reference ends at the first ';'. Reaching another '&' first means this ampersand
        // cannot start a reference; stopping there also keeps the scan linear.
        let decoded = match after.find(|ch: char| ch == ';' || ch == '&') {
            Some(end) if after.as_bytes()[end] == b';' => {
                decode_reference(&after[..end]).map(|ch| (ch, end))
            }
            _ => None,
        };

        match decoded {
            Some((ch, end)) => {
                out.push(ch);
                rest = &after[end + 1..];
            }
            None => {
                // Not a reference we understand: keep the '&' and carry on right after it, so
                // the following text is preserved verbatim.
                out.push('&');
                rest = after;
            }
        }
    }

    out.push_str(rest);
    out
}

/// Resolves the text between `&` and `;` to a character, or `None` when it is not a
/// predefined entity or a valid numeric character reference.
fn decode_reference(entity: &str) -> Option<char> {
    match entity {
        "amp" => Some('&'),
        "lt" => Some('<'),
        "gt" => Some('>'),
        "quot" => Some('"'),
        "apos" => Some('\''),
        _ => {
            let code = if let Some(hex) = entity.strip_prefix("#x") {
                u32::from_str_radix(hex, 16).ok()?
            } else {
                entity.strip_prefix('#')?.parse::<u32>().ok()?
            };
            char::from_u32(code)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A small document with an attribute, a text child and a self-closing child.
    #[test]
    fn parser_handles_simple_tree() {
        let root = parse_xml_document("<root><a x=\"1\">hi</a><b/></root>").unwrap();
        assert_eq!(root.local_name(), "root");
        assert_eq!(root.child("a").unwrap().text_content(), "hi");
        assert_eq!(root.child("a").unwrap().attr("x"), Some("1"));
        assert!(root.child("b").is_some());
    }

    /// Predefined entities and decimal/hexadecimal character references are decoded.
    #[test]
    fn decode_entities_handles_named_and_numeric() {
        assert_eq!(decode_entities("&amp;&lt;&#65;&#x42;"), "&<AB");
        assert_eq!(decode_entities("&gt;&quot;&apos;"), ">\"'");
        // Input without any reference is returned unchanged.
        assert_eq!(decode_entities("&<AB"), "&<AB");
    }
}