use std::collections::HashMap;
use crate::error::{ErrorSeverity, ParseDiagnostic, ParseError, SourceLocation};
use crate::parser::{EntityResolver, ExternalEntityRequest};
/// Default nesting-depth limit; `ParserInput::new` uses it as the initial `max_depth`.
pub(crate) const DEFAULT_MAX_DEPTH: u32 = 256;
/// Default attribute-count limit. Not referenced in this part of the file;
/// presumably a per-element cap — confirm at the use site.
pub(crate) const DEFAULT_MAX_ATTRIBUTES: u32 = 256;
/// Default attribute length limit (10 * 1024 * 1024). Not referenced in this part of the
/// file; presumably measured in bytes per attribute value — confirm at the use site.
pub(crate) const DEFAULT_MAX_ATTRIBUTE_LENGTH: usize = 10 * 1024 * 1024;
/// Default text length limit (10 * 1024 * 1024). Not referenced in this part of the
/// file; presumably measured in bytes per text node — confirm at the use site.
pub(crate) const DEFAULT_MAX_TEXT_LENGTH: usize = 10 * 1024 * 1024;
/// Default limit on the byte length of a parsed name; initial `max_name_length` of `ParserInput`.
pub(crate) const DEFAULT_MAX_NAME_LENGTH: usize = 50_000;
/// Default limit on the number of references expanded per parse; initial
/// `max_entity_expansions` of `ParserInput`.
pub(crate) const DEFAULT_MAX_ENTITY_EXPANSIONS: u32 = 10_000;
/// Returns `true` when `c` falls in one of the code-point ranges accepted as an XML
/// character: tab, line feed, carriage return, U+0020..=U+D7FF, U+E000..=U+FFFD and
/// U+10000..=U+10FFFF.
pub(crate) fn is_xml_char(c: char) -> bool {
    matches!(
        c,
        '\t' | '\n'
            | '\r'
            | ' '..='\u{D7FF}'
            | '\u{E000}'..='\u{FFFD}'
            | '\u{10000}'..='\u{10FFFF}'
    )
}
/// Returns `true` when `c` may begin an XML name: `:`, `_`, an ASCII letter, or a
/// character in one of the listed non-ASCII code-point ranges.
pub(crate) fn is_name_start_char(c: char) -> bool {
    // ASCII is decided up front so the range table below only lists non-ASCII blocks.
    if c.is_ascii() {
        return c.is_ascii_alphabetic() || c == ':' || c == '_';
    }
    matches!(
        u32::from(c),
        0xC0..=0xD6
            | 0xD8..=0xF6
            | 0xF8..=0x2FF
            | 0x370..=0x37D
            | 0x37F..=0x1FFF
            | 0x200C..=0x200D
            | 0x2070..=0x218F
            | 0x2C00..=0x2FEF
            | 0x3001..=0xD7FF
            | 0xF900..=0xFDCF
            | 0xFDF0..=0xFFFD
            | 0x1_0000..=0xE_FFFF
    )
}
/// Returns `true` when `c` may appear after the first character of an XML name:
/// anything accepted by `is_name_start_char`, plus `-`, `.`, ASCII digits, U+00B7,
/// U+0300..=U+036F and U+203F..=U+2040.
pub(crate) fn is_name_char(c: char) -> bool {
    match c {
        '-' | '.' | '0'..='9' | '\u{B7}' => true,
        '\u{300}'..='\u{36F}' | '\u{203F}' | '\u{2040}' => true,
        _ => is_name_start_char(c),
    }
}
/// ASCII-only test for a byte that may start a name: a letter, `_` or `:`.
fn is_ascii_name_start(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'_' | b':')
}
/// ASCII-only test for a byte that may continue a name: a letter, digit, `_`, `:`,
/// `-` or `.`.
fn is_ascii_name_char(b: u8) -> bool {
    matches!(
        b,
        b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'_' | b':' | b'-' | b'.'
    )
}
/// Recognizes one of the five predefined entities at the start of `bytes`.
///
/// `bytes` is the input immediately after the `&`. On a match, returns the replacement
/// text together with the number of bytes consumed (entity name plus the `;`).
fn match_builtin_entity(bytes: &[u8]) -> Option<(&'static str, usize)> {
    // (reference text after '&', replacement)
    const BUILTINS: [(&str, &str); 5] = [
        ("amp;", "&"),
        ("apos;", "'"),
        ("lt;", "<"),
        ("gt;", ">"),
        ("quot;", "\""),
    ];
    BUILTINS
        .iter()
        .find(|(reference, _)| bytes.starts_with(reference.as_bytes()))
        .map(|&(reference, replacement)| (replacement, reference.len()))
}
/// Returns the first character of `s` rejected by `is_xml_char`, or `None` when every
/// character is acceptable.
pub(crate) fn find_invalid_xml_char(s: &str) -> Option<char> {
    for candidate in s.chars() {
        if !is_xml_char(candidate) {
            return Some(candidate);
        }
    }
    None
}
/// Cheap byte-level pre-filter used before a full per-character validity scan.
///
/// Returns `true` when `bytes` contains a control byte below 0x20 other than tab, line
/// feed or carriage return, the byte 0x7F, or the three-byte sequences `EF BF BE` /
/// `EF BF BF` (the UTF-8 encodings of U+FFFE and U+FFFF). A `true` result only means
/// the data deserves a closer look.
pub(crate) fn may_contain_invalid_xml_chars(bytes: &[u8]) -> bool {
    bytes.iter().enumerate().any(|(idx, &byte)| match byte {
        b'\t' | b'\n' | b'\r' => false,
        0x00..=0x1F | 0x7F => true,
        // Needs two following bytes; `get` yields `None` near the end of the slice.
        0xEF => matches!(bytes.get(idx + 1..idx + 3), Some([0xBF, 0xBE | 0xBF])),
        _ => false,
    })
}
/// Splits a name at its first `:` into `(prefix, local part)`.
///
/// A name without a colon yields `(None, name)`.
pub(crate) fn split_name(name: &str) -> (Option<&str>, &str) {
    name.split_once(':')
        .map_or((None, name), |(prefix, local)| (Some(prefix), local))
}
/// Owned counterpart of `split_name`: splits at the first `:` into
/// `(prefix, local part)`, reusing the incoming allocation for the prefix.
///
/// A name without a colon is returned unchanged as the local part.
pub(crate) fn split_owned_name(mut name: String) -> (Option<String>, String) {
    if let Some(colon) = name.find(':') {
        let local = name.split_off(colon + 1);
        // Drop the trailing ':' so only the prefix remains.
        name.truncate(colon);
        (Some(name), local)
    } else {
        (None, name)
    }
}
/// Checks the colon structure of a qualified name.
///
/// Returns a description of the problem when the name has more than one colon, or has
/// exactly one colon at its very start or end; returns `None` otherwise.
#[allow(dead_code)]
pub(crate) fn validate_qname(name: &str) -> Option<&'static str> {
    let mut colons = name.match_indices(':');
    match (colons.next(), colons.next()) {
        (None, _) => None,
        (Some(_), Some(_)) => Some("QName contains multiple colons"),
        (Some((at, _)), None) => {
            if at == 0 || at + 1 == name.len() {
                Some("QName has empty prefix or local part")
            } else {
                None
            }
        }
    }
}
/// Namespace URI associated with the `xmlns` prefix.
pub(crate) const XMLNS_NAMESPACE: &str = "http://www.w3.org/2000/xmlns/";
/// Returns `true` when `c` is permitted inside a public identifier: space, carriage
/// return, line feed, ASCII letters and digits, or one of ``-'()+,./:=?;!*#@$_%``.
pub(crate) fn is_pubid_char(c: char) -> bool {
    const PUNCTUATION_AND_SPACE: &str = " \r\n-'()+,./:=?;!*#@$_%";
    c.is_ascii_alphanumeric() || PUNCTUATION_AND_SPACE.contains(c)
}
/// Validates a public identifier literal.
///
/// Returns `None` when every character passes `is_pubid_char`; otherwise returns a
/// message naming the first offending character and its code point.
pub(crate) fn validate_pubid(s: &str) -> Option<String> {
    let offender = s.chars().find(|&c| !is_pubid_char(c))?;
    Some(format!(
        "invalid character '{}' (U+{:04X}) in public ID",
        offender.escape_default(),
        offender as u32
    ))
}
/// Snapshot of a `ParserInput` cursor, produced by `ParserInput::save_position` and
/// consumed by `ParserInput::restore_position`.
#[derive(Debug, Clone, Copy)]
#[allow(dead_code)]
pub(crate) struct SavedPosition {
// Byte offset into the input.
pos: usize,
// Line number at the time of the snapshot.
line: u32,
// Column number at the time of the snapshot.
column: u32,
}
/// Identifiers of an external entity, handed to the entity resolver (as an
/// `ExternalEntityRequest`) when a reference to the entity is expanded.
#[derive(Debug, Clone)]
pub(crate) struct ExternalEntityInfo {
/// System identifier of the entity.
pub system_id: String,
/// Public identifier of the entity, when one was given.
pub public_id: Option<String>,
}
/// Byte-level cursor over the document being parsed, together with position tracking,
/// resource limits, collected diagnostics and the entity tables used for reference
/// expansion.
pub(crate) struct ParserInput<'a> {
// Raw bytes of the input (taken from a `&str` in `new`).
input: &'a [u8],
// Current byte offset into `input`.
pos: usize,
// Current line, starting at 1.
line: u32,
// Current column, starting at 1 and reset to 1 after each `\n`.
column: u32,
// Current nesting depth, maintained by `increment_depth` / `decrement_depth`.
depth: u32,
// Depth beyond which `increment_depth` fails.
max_depth: u32,
// Maximum accepted byte length of a parsed name.
max_name_length: usize,
// Number of references processed so far; compared against `max_entity_expansions`.
pub(crate) entity_expansions: u32,
// Limit for `entity_expansions`.
max_entity_expansions: u32,
// When true, certain errors are recorded as diagnostics instead of aborting the parse.
recover: bool,
// Diagnostics recorded so far; copied into every `ParseError` built by `fatal`.
pub(crate) diagnostics: Vec<ParseDiagnostic>,
// Entity name -> replacement text for entities expanded from an in-memory value.
pub(crate) entity_map: HashMap<String, String>,
// Entity name -> external identifiers for entities that need the resolver.
pub(crate) entity_external: HashMap<String, ExternalEntityInfo>,
// When set, a reference to an unknown entity is a warning rather than a fatal error.
pub(crate) has_pe_references: bool,
// When set, a reference to an unknown entity is a warning rather than a fatal error.
pub(crate) has_external_dtd: bool,
// Names of entities whose replacement text has already gone through
// `validate_entity_content`, so each is validated once.
validated_entities: std::collections::HashSet<String>,
// Optional callback used to fetch the text of external entities.
entity_resolver: Option<EntityResolver>,
}
impl<'a> ParserInput<'a> {
/// Creates a cursor over `input`, positioned at byte 0 (line 1, column 1).
///
/// Limits start at the `DEFAULT_MAX_*` values, recovery mode is off, and no entities,
/// diagnostics or entity resolver are registered.
pub fn new(input: &'a str) -> Self {
Self {
input: input.as_bytes(),
pos: 0,
line: 1,
column: 1,
depth: 0,
max_depth: DEFAULT_MAX_DEPTH,
max_name_length: DEFAULT_MAX_NAME_LENGTH,
entity_expansions: 0,
max_entity_expansions: DEFAULT_MAX_ENTITY_EXPANSIONS,
recover: false,
diagnostics: Vec::new(),
entity_map: HashMap::new(),
entity_external: HashMap::new(),
has_pe_references: false,
has_external_dtd: false,
validated_entities: std::collections::HashSet::new(),
entity_resolver: None,
}
}
/// Installs (or, with `None`, removes) the callback used to resolve external entities.
pub fn set_entity_resolver(&mut self, resolver: Option<EntityResolver>) {
self.entity_resolver = resolver;
}
/// Sets the nesting depth above which `increment_depth` reports a fatal error.
pub fn set_max_depth(&mut self, max: u32) {
self.max_depth = max;
}
/// Sets the maximum byte length accepted by the name-parsing methods.
pub fn set_max_name_length(&mut self, max: usize) {
self.max_name_length = max;
}
/// Sets the maximum number of references that may be expanded before parsing fails.
pub fn set_max_entity_expansions(&mut self, max: u32) {
self.max_entity_expansions = max;
}
/// Enables or disables recovery mode, in which some errors are recorded as diagnostics
/// instead of aborting the parse.
pub fn set_recover(&mut self, recover: bool) {
self.recover = recover;
}
/// Returns whether recovery mode is enabled.
pub fn recover(&self) -> bool {
self.recover
}
/// Records entry into one more nesting level.
///
/// # Errors
/// Returns a fatal error when the new depth exceeds the configured maximum. The depth
/// counter stays incremented in that case.
pub fn increment_depth(&mut self) -> Result<(), ParseError> {
    self.depth += 1;
    if self.depth <= self.max_depth {
        return Ok(());
    }
    let message = format!("maximum nesting depth exceeded ({})", self.max_depth);
    Err(self.fatal(message))
}
/// Records exit from one nesting level; a depth of zero is left unchanged.
pub fn decrement_depth(&mut self) {
    if self.depth > 0 {
        self.depth -= 1;
    }
}
/// Returns the current nesting depth.
#[allow(dead_code)]
pub fn depth(&self) -> u32 {
self.depth
}
/// Returns the current cursor position (line, column and byte offset) as a
/// `SourceLocation`.
pub fn location(&self) -> SourceLocation {
SourceLocation {
line: self.line,
column: self.column,
byte_offset: self.pos,
}
}
/// Returns `true` once the cursor has reached (or passed) the end of the input.
#[inline]
pub fn at_end(&self) -> bool {
self.pos >= self.input.len()
}
/// Returns the current byte offset into the input.
#[allow(dead_code)]
#[inline]
pub fn pos(&self) -> usize {
self.pos
}
/// Returns the input bytes in `start..end` (absolute byte offsets).
///
/// # Panics
/// Panics when the range is out of bounds or `start > end`.
#[allow(dead_code)]
#[inline]
pub fn slice(&self, start: usize, end: usize) -> &[u8] {
&self.input[start..end]
}
/// Returns all input bytes from the cursor to the end of the input.
#[allow(dead_code)]
#[inline]
pub fn remaining(&self) -> &[u8] {
&self.input[self.pos..]
}
/// Captures the cursor (byte offset, line, column) so it can be restored later with
/// `restore_position`. Depth, diagnostics and entity state are not captured.
#[allow(dead_code)]
pub fn save_position(&self) -> SavedPosition {
SavedPosition {
pos: self.pos,
line: self.line,
column: self.column,
}
}
/// Rewinds (or forwards) the cursor to a snapshot taken by `save_position`.
/// Only the byte offset, line and column are restored.
#[allow(dead_code)]
pub fn restore_position(&mut self, saved: SavedPosition) {
self.pos = saved.pos;
self.line = saved.line;
self.column = saved.column;
}
/// Returns the byte at the cursor without consuming it, or `None` at end of input.
#[inline]
pub fn peek(&self) -> Option<u8> {
self.input.get(self.pos).copied()
}
/// Returns the byte `offset` positions past the cursor without consuming anything, or
/// `None` when that position lies beyond the end of the input.
pub fn peek_at(&self, offset: usize) -> Option<u8> {
self.input.get(self.pos + offset).copied()
}
/// Decodes the character at the cursor without consuming it.
///
/// Returns `None` at end of input, when the lead byte cannot start a UTF-8 sequence,
/// when the sequence is cut short by the end of the input, or when the bytes do not
/// form valid UTF-8.
#[inline]
pub fn peek_char(&self) -> Option<char> {
    let lead = *self.input.get(self.pos)?;
    // Sequence length implied by the lead byte; ASCII is returned directly.
    let width = match lead {
        0x00..=0x7F => return Some(char::from(lead)),
        0xC0..=0xDF => 2,
        0xE0..=0xEF => 3,
        0xF0..=0xF7 => 4,
        _ => return None,
    };
    let encoded = self.input.get(self.pos..self.pos + width)?;
    let decoded = std::str::from_utf8(encoded).ok()?;
    decoded.chars().next()
}
/// Moves the cursor forward by up to `count` bytes, clamped to the end of the input,
/// updating line and column (a `\n` starts a new line at column 1; every other byte
/// adds one column).
#[inline]
pub fn advance(&mut self, count: usize) {
    if count != 1 {
        let available = self.input.len() - self.pos;
        self.advance_counting_lines(count.min(available));
        return;
    }
    // Single-byte fast path; does nothing at end of input.
    match self.input.get(self.pos) {
        Some(b'\n') => {
            self.line += 1;
            self.column = 1;
            self.pos += 1;
        }
        Some(_) => {
            self.column += 1;
            self.pos += 1;
        }
        None => {}
    }
}
/// Moves the cursor past `ch`, which the caller has already decoded at the cursor.
///
/// The byte offset grows by the UTF-8 length of `ch`; the column grows by one, or the
/// line counter advances when `ch` is `\n`. No bounds check is performed.
#[inline]
pub fn advance_char(&mut self, ch: char) {
    match ch {
        '\n' => {
            self.line += 1;
            self.column = 1;
        }
        _ => self.column += 1,
    }
    self.pos += ch.len_utf8();
}
/// Consumes and returns the byte at the cursor.
///
/// # Errors
/// Returns a fatal error at end of input.
#[inline]
pub fn next_byte(&mut self) -> Result<u8, ParseError> {
    match self.input.get(self.pos).copied() {
        Some(byte) => {
            self.advance(1);
            Ok(byte)
        }
        None => Err(self.fatal("unexpected end of input")),
    }
}
/// Consumes and returns the next character, updating line/column tracking.
///
/// A `\r` (optionally followed by `\n`) is consumed as a single line break and returned
/// as `'\n'`. A character that is not a valid XML character is a fatal error, unless
/// recovery mode is on, in which case an error diagnostic is recorded and the
/// character is still returned.
///
/// # Errors
/// Returns a fatal error at end of input, when the bytes at the cursor cannot be
/// decoded as a character, or on an invalid XML character outside recovery mode.
#[inline]
pub fn next_char(&mut self) -> Result<char, ParseError> {
if self.pos >= self.input.len() {
return Err(self.fatal("unexpected end of input"));
}
let first = self.input[self.pos];
// ASCII fast path: no UTF-8 decoding needed.
if first < 0x80 {
self.pos += 1;
if first == b'\n' {
self.line += 1;
self.column = 1;
} else if first == b'\r' {
self.line += 1;
self.column = 1;
// Swallow the `\n` of a CRLF pair so the pair counts as one line break.
if self.pos < self.input.len() && self.input[self.pos] == b'\n' {
self.pos += 1;
}
return Ok('\n');
} else {
self.column += 1;
// `\n` and `\r` were handled above, so only tab remains as an allowed control byte.
if first < 0x20 && first != b'\t' {
let ch = first as char;
if self.recover {
self.push_diagnostic(
ErrorSeverity::Error,
format!("invalid XML character: U+{:04X}", ch as u32),
);
} else {
return Err(
self.fatal(format!("invalid XML character: U+{:04X}", ch as u32))
);
}
}
}
return Ok(first as char);
}
// Multi-byte character: decode, consume, then validate.
let ch = self
.peek_char()
.ok_or_else(|| self.fatal("unexpected end of input"))?;
self.advance_char(ch);
if !is_xml_char(ch) {
if self.recover {
self.push_diagnostic(
ErrorSeverity::Error,
format!("invalid XML character: U+{:04X}", ch as u32),
);
} else {
return Err(self.fatal(format!("invalid XML character: U+{:04X}", ch as u32)));
}
}
Ok(ch)
}
/// Measures, without consuming, how many bytes from the cursor can be taken as plain
/// character data.
///
/// The run stops before `<`, `&`, the sequence `]]>`, or a control byte below 0x20
/// other than tab, line feed and carriage return. Returns the length of the remaining
/// input when none of these occurs.
#[inline]
pub fn scan_char_data(&self) -> usize {
    let rest = &self.input[self.pos..];
    (0..rest.len())
        .find(|&idx| match rest[idx] {
            b'<' | b'&' => true,
            b']' if rest[idx + 1..].starts_with(b"]>") => true,
            b'\t' | b'\n' | b'\r' => false,
            byte => byte < 0x20,
        })
        .unwrap_or(rest.len())
}
/// Measures, without consuming, how many bytes from the cursor can be copied verbatim
/// into an attribute value delimited by `quote`.
///
/// The run stops before the quote byte, `&`, `<`, or a control byte below 0x20 other
/// than tab, line feed and carriage return. Returns the length of the remaining input
/// when none of these occurs.
#[inline]
pub fn scan_attr_value(&self, quote: u8) -> usize {
    let rest = &self.input[self.pos..];
    let stops_run = |byte: u8| {
        byte == quote
            || matches!(byte, b'&' | b'<')
            || (byte < 0x20 && !matches!(byte, b'\t' | b'\n' | b'\r'))
    };
    rest.iter()
        .position(|&byte| stops_run(byte))
        .unwrap_or(rest.len())
}
/// Returns the distance from the cursor to the first byte contained in `markers`, or
/// the length of the remaining input when no marker byte occurs. Nothing is consumed.
#[allow(dead_code)]
pub fn scan_until_any(&self, markers: &[u8]) -> usize {
    // 256-entry membership table so each input byte costs one lookup.
    let mut is_marker = [false; 256];
    markers
        .iter()
        .for_each(|&marker| is_marker[usize::from(marker)] = true);
    let rest = &self.input[self.pos..];
    rest.iter()
        .position(|&byte| is_marker[usize::from(byte)])
        .unwrap_or(rest.len())
}
/// Searches forward from the cursor for the two-byte sequence `t0 t1`.
///
/// Returns the offset of its first occurrence relative to the cursor, or `None` when
/// the sequence does not occur in the remaining input. Nothing is consumed.
pub fn scan_for_2byte_terminator(&self, t0: u8, t1: u8) -> Option<usize> {
    self.input[self.pos..]
        .windows(2)
        .position(|pair| pair[0] == t0 && pair[1] == t1)
}
/// Searches forward from the cursor for the three-byte sequence `t0 t1 t2`.
///
/// Returns the offset of its first occurrence relative to the cursor, or `None` when
/// the sequence does not occur in the remaining input. Nothing is consumed.
pub fn scan_for_3byte_terminator(&self, t0: u8, t1: u8, t2: u8) -> Option<usize> {
    self.input[self.pos..]
        .windows(3)
        .position(|triple| triple[0] == t0 && triple[1] == t1 && triple[2] == t2)
}
/// Moves the cursor forward by exactly `count` bytes, keeping line and column in sync.
///
/// Each `\n` starts a new line at column 1; every other byte adds one column. Spans
/// without a newline are handled with a single column adjustment.
///
/// # Panics
/// Panics when fewer than `count` bytes remain.
#[inline]
pub fn advance_counting_lines(&mut self, count: usize) {
    let end = self.pos + count;
    let span = &self.input[self.pos..end];
    if !span.contains(&b'\n') {
        // Fast path: no line break, so only the column moves.
        #[allow(clippy::cast_possible_truncation)]
        {
            self.column += count as u32;
        }
        self.pos = end;
        return;
    }
    for byte in span.iter().copied() {
        match byte {
            b'\n' => {
                self.line += 1;
                self.column = 1;
            }
            _ => self.column += 1,
        }
    }
    self.pos = end;
}
/// Consumes one byte and checks that it equals `expected`.
///
/// # Errors
/// Returns a fatal error at end of input or when a different byte was found. The
/// mismatching byte has already been consumed when the error is built.
#[inline]
pub fn expect_byte(&mut self, expected: u8) -> Result<(), ParseError> {
    let found = self.next_byte()?;
    if found == expected {
        return Ok(());
    }
    Err(self.fatal(format!(
        "expected '{}', found '{}'",
        char::from(expected),
        char::from(found)
    )))
}
/// Consumes the byte sequence `expected` at the cursor.
///
/// # Errors
/// Returns "unexpected end of input" when fewer than `expected.len()` bytes remain.
/// On a mismatch, the bytes are re-checked one at a time so the error names the first
/// differing byte.
#[inline]
pub fn expect_str(&mut self, expected: &[u8]) -> Result<(), ParseError> {
    let end = self.pos + expected.len();
    let window = match self.input.get(self.pos..end) {
        Some(window) => window,
        None => return Err(self.fatal("unexpected end of input")),
    };
    if window == expected {
        self.advance_counting_lines(expected.len());
        return Ok(());
    }
    // Slow path, only taken on mismatch: produces the per-byte error message.
    expected
        .iter()
        .try_for_each(|&byte| self.expect_byte(byte))
}
/// Returns `true` when the input at the cursor starts with `s`. Nothing is consumed.
#[inline]
pub fn looking_at(&self, s: &[u8]) -> bool {
self.input[self.pos..].starts_with(s)
}
/// Like `looking_at`, but compares ASCII letters case-insensitively.
///
/// Returns `false` when fewer than `expected.len()` bytes remain.
#[allow(dead_code)]
pub fn looking_at_ci(&self, expected: &[u8]) -> bool {
    self.input
        .get(self.pos..self.pos + expected.len())
        .map_or(false, |window| window.eq_ignore_ascii_case(expected))
}
/// Consumes any run of space, tab, carriage return and line feed at the cursor.
///
/// Returns `true` when at least one byte was consumed.
#[inline]
pub fn skip_whitespace(&mut self) -> bool {
    let before = self.pos;
    while let Some(&byte) = self.input.get(self.pos) {
        match byte {
            b'\n' => {
                self.line += 1;
                self.column = 1;
            }
            b' ' | b'\t' | b'\r' => self.column += 1,
            _ => break,
        }
        self.pos += 1;
    }
    self.pos != before
}
/// Consumes any run of space, tab, carriage return and line feed at the cursor and
/// returns the consumed text (empty when there was none).
pub fn consume_whitespace(&mut self) -> &str {
    let begin = self.pos;
    while let Some(&byte) = self.input.get(self.pos) {
        match byte {
            b'\n' => {
                self.line += 1;
                self.column = 1;
            }
            b' ' | b'\t' | b'\r' => self.column += 1,
            _ => break,
        }
        self.pos += 1;
    }
    match std::str::from_utf8(&self.input[begin..self.pos]) {
        Ok(whitespace) => whitespace,
        Err(_) => "",
    }
}
/// Consumes whitespace at the cursor and insists that there was some.
///
/// # Errors
/// Returns a fatal error when no whitespace byte was present.
pub fn skip_whitespace_required(&mut self) -> Result<(), ParseError> {
    if self.skip_whitespace() {
        Ok(())
    } else {
        Err(self.fatal("whitespace required"))
    }
}
/// Consumes bytes while `pred` accepts them and returns the consumed text.
///
/// Line and column are updated per byte. If the consumed bytes are not valid UTF-8
/// the returned string is empty (the bytes are still consumed).
pub fn take_while(&mut self, pred: impl Fn(u8) -> bool) -> String {
    let begin = self.pos;
    while let Some(&byte) = self.input.get(self.pos) {
        if !pred(byte) {
            break;
        }
        match byte {
            b'\n' => {
                self.line += 1;
                self.column = 1;
            }
            _ => self.column += 1,
        }
        self.pos += 1;
    }
    let taken = std::str::from_utf8(&self.input[begin..self.pos]).unwrap_or("");
    String::from(taken)
}
/// Parses an XML name at the cursor and returns it as an owned string.
///
/// Names made only of ASCII bytes are handled by a byte-wise fast path; as soon as a
/// non-ASCII byte is seen, parsing continues character by character using
/// `is_name_char`.
///
/// # Errors
/// Returns a fatal error at end of input, when the first character cannot start a
/// name, or when the name is longer (in bytes) than the configured maximum.
#[inline]
pub fn parse_name(&mut self) -> Result<String, ParseError> {
let start = self.pos;
if self.pos >= self.input.len() {
return Err(self.fatal("expected name, found end of input"));
}
let first = self.input[self.pos];
if is_ascii_name_start(first) {
self.pos += 1;
self.column += 1;
// ASCII fast path: name bytes never contain `\n`, so only the column moves.
while self.pos < self.input.len() && is_ascii_name_char(self.input[self.pos]) {
self.pos += 1;
self.column += 1;
}
// The name ended on an ASCII delimiter (or end of input): it is complete.
if self.pos >= self.input.len() || self.input[self.pos] < 0x80 {
let len = self.pos - start;
if len > self.max_name_length {
return Err(self.fatal(format!(
"name length ({len}) exceeds maximum ({})",
self.max_name_length
)));
}
let name = std::str::from_utf8(&self.input[start..self.pos])
.map_err(|_| self.fatal("invalid UTF-8 in name"))?;
return Ok(name.to_string());
}
// Otherwise a non-ASCII byte follows; fall through to the character loop below.
} else {
let ch = self
.peek_char()
.ok_or_else(|| self.fatal("expected name"))?;
if !is_name_start_char(ch) {
return Err(self.fatal(format!("invalid name start character: '{ch}'")));
}
self.advance_char(ch);
}
// General path: continue with full Unicode name characters.
while let Some(ch) = self.peek_char() {
if is_name_char(ch) {
self.advance_char(ch);
} else {
break;
}
}
let len = self.pos - start;
if len > self.max_name_length {
return Err(self.fatal(format!(
"name length ({len}) exceeds maximum ({})",
self.max_name_length
)));
}
let name = std::str::from_utf8(&self.input[start..self.pos])
.map_err(|_| self.fatal("invalid UTF-8 in name"))?;
Ok(name.to_string())
}
/// Parses an XML name at the cursor and compares it with `expected`.
///
/// Returns `Ok(None)` when the parsed name is byte-for-byte equal to `expected`, which
/// avoids allocating a string for the name. Otherwise returns `Ok(Some(name))` with
/// the name that was actually found. Scanning follows the same two paths as
/// `parse_name`.
///
/// # Errors
/// Returns a fatal error at end of input, when the first character cannot start a
/// name, or when the name is longer (in bytes) than the configured maximum.
#[allow(dead_code)]
pub fn parse_name_eq(&mut self, expected: &str) -> Result<Option<String>, ParseError> {
let start = self.pos;
if self.pos >= self.input.len() {
return Err(self.fatal("expected name, found end of input"));
}
let first = self.input[self.pos];
if is_ascii_name_start(first) {
self.pos += 1;
self.column += 1;
// ASCII fast path.
while self.pos < self.input.len() && is_ascii_name_char(self.input[self.pos]) {
self.pos += 1;
self.column += 1;
}
// The name ended on an ASCII delimiter (or end of input): it is complete.
if self.pos >= self.input.len() || self.input[self.pos] < 0x80 {
let len = self.pos - start;
if len > self.max_name_length {
return Err(self.fatal(format!(
"name length ({len}) exceeds maximum ({})",
self.max_name_length
)));
}
if len == expected.len() && &self.input[start..self.pos] == expected.as_bytes() {
return Ok(None); }
let name = std::str::from_utf8(&self.input[start..self.pos])
.map_err(|_| self.fatal("invalid UTF-8 in name"))?;
return Ok(Some(name.to_string()));
}
// Otherwise a non-ASCII byte follows; fall through to the character loop below.
} else {
let ch = self
.peek_char()
.ok_or_else(|| self.fatal("expected name"))?;
if !is_name_start_char(ch) {
return Err(self.fatal(format!("invalid name start character: '{ch}'")));
}
self.advance_char(ch);
}
// General path: continue with full Unicode name characters.
while let Some(ch) = self.peek_char() {
if is_name_char(ch) {
self.advance_char(ch);
} else {
break;
}
}
let len = self.pos - start;
if len > self.max_name_length {
return Err(self.fatal(format!(
"name length ({len}) exceeds maximum ({})",
self.max_name_length
)));
}
if len == expected.len() && &self.input[start..self.pos] == expected.as_bytes() {
return Ok(None);
}
let name = std::str::from_utf8(&self.input[start..self.pos])
.map_err(|_| self.fatal("invalid UTF-8 in name"))?;
Ok(Some(name.to_string()))
}
/// Parses an XML name at the cursor and compares it with an expected name supplied as
/// separate `prefix` and `local` parts.
///
/// With `Some(prefix)` the expected name is `prefix:local`; with `None` it is just
/// `local`. Returns `Ok(None)` when the parsed name matches, without allocating;
/// otherwise returns `Ok(Some(name))` with the name actually found. Scanning follows
/// the same two paths as `parse_name`.
///
/// # Errors
/// Returns a fatal error at end of input, when the first character cannot start a
/// name, or when the name is longer (in bytes) than the configured maximum.
pub fn parse_name_eq_parts(
&mut self,
prefix: Option<&str>,
local: &str,
) -> Result<Option<String>, ParseError> {
let start = self.pos;
if self.pos >= self.input.len() {
return Err(self.fatal("expected name, found end of input"));
}
let first = self.input[self.pos];
if is_ascii_name_start(first) {
self.pos += 1;
self.column += 1;
// ASCII fast path.
while self.pos < self.input.len() && is_ascii_name_char(self.input[self.pos]) {
self.pos += 1;
self.column += 1;
}
// The name ended on an ASCII delimiter (or end of input): it is complete.
if self.pos >= self.input.len() || self.input[self.pos] < 0x80 {
let len = self.pos - start;
if len > self.max_name_length {
return Err(self.fatal(format!(
"name length ({len}) exceeds maximum ({})",
self.max_name_length
)));
}
let parsed = &self.input[start..self.pos];
let matches = match prefix {
Some(pfx) => {
let expected_len = pfx.len() + 1 + local.len();
// The length test comes first so the index expressions below stay in bounds.
len == expected_len
&& parsed[..pfx.len()] == *pfx.as_bytes()
&& parsed[pfx.len()] == b':'
&& parsed[pfx.len() + 1..] == *local.as_bytes()
}
None => len == local.len() && parsed == local.as_bytes(),
};
if matches {
return Ok(None);
}
let name =
std::str::from_utf8(parsed).map_err(|_| self.fatal("invalid UTF-8 in name"))?;
return Ok(Some(name.to_string()));
}
// Otherwise a non-ASCII byte follows; fall through to the character loop below.
} else {
let ch = self
.peek_char()
.ok_or_else(|| self.fatal("expected name"))?;
if !is_name_start_char(ch) {
return Err(self.fatal(format!("invalid name start character: '{ch}'")));
}
self.advance_char(ch);
}
// General path: continue with full Unicode name characters.
while let Some(ch) = self.peek_char() {
if is_name_char(ch) {
self.advance_char(ch);
} else {
break;
}
}
let len = self.pos - start;
if len > self.max_name_length {
return Err(self.fatal(format!(
"name length ({len}) exceeds maximum ({})",
self.max_name_length
)));
}
let parsed = &self.input[start..self.pos];
let matches = match prefix {
Some(pfx) => {
let expected_len = pfx.len() + 1 + local.len();
// The length test comes first so the index expressions below stay in bounds.
len == expected_len
&& parsed[..pfx.len()] == *pfx.as_bytes()
&& parsed[pfx.len()] == b':'
&& parsed[pfx.len() + 1..] == *local.as_bytes()
}
None => len == local.len() && parsed == local.as_bytes(),
};
if matches {
return Ok(None);
}
let name = std::str::from_utf8(parsed).map_err(|_| self.fatal("invalid UTF-8 in name"))?;
Ok(Some(name.to_string()))
}
/// Test-only convenience wrapper around `parse_reference_into` that returns the
/// expansion as a fresh `String`.
#[cfg(test)]
pub fn parse_reference(&mut self) -> Result<String, ParseError> {
let mut buf = String::new();
self.parse_reference_into(&mut buf)?;
Ok(buf)
}
/// Parses one reference starting at `&` and appends its expansion to `buf`.
///
/// Handles, in this order: the five predefined entities, numeric character references
/// (`&#N;` / `&#xH;`), external entities (via the entity resolver) and entities from
/// `entity_map`. Returns the slice of `buf` that was appended by this call.
///
/// Every call counts toward the entity expansion limit, including predefined entities
/// and character references.
///
/// # Errors
/// Returns a fatal error when the expansion limit is exceeded, the reference is
/// malformed, a character reference does not denote a valid XML character, an
/// external entity cannot be resolved, an entity's replacement text fails validation,
/// or the entity is unknown. An unknown entity is downgraded to a warning with an
/// empty expansion when recovery mode is on or when `has_pe_references` /
/// `has_external_dtd` is set.
pub fn parse_reference_into<'b>(&mut self, buf: &'b mut String) -> Result<&'b str, ParseError> {
self.entity_expansions += 1;
if self.entity_expansions > self.max_entity_expansions {
return Err(self.fatal(format!(
"entity expansion limit exceeded ({})",
self.max_entity_expansions
)));
}
self.expect_byte(b'&')?;
// Fast path for `amp;`, `lt;`, `gt;`, `apos;` and `quot;`.
let remaining = &self.input[self.pos..];
if let Some(result) = match_builtin_entity(remaining) {
let advance_len = result.1;
self.advance_counting_lines(advance_len);
let start = buf.len();
buf.push_str(result.0);
return Ok(&buf[start..]);
}
if self.peek() == Some(b'#') {
// Numeric character reference; only a lowercase `x` introduces the hex form.
self.advance(1);
let value = if self.peek() == Some(b'x') {
self.advance(1);
let hex = self.take_while(|b| b.is_ascii_hexdigit());
if hex.is_empty() {
return Err(self.fatal("empty hex character reference"));
}
u32::from_str_radix(&hex, 16)
.map_err(|_| self.fatal("invalid hex character reference"))?
} else {
let dec = self.take_while(|b| b.is_ascii_digit());
if dec.is_empty() {
return Err(self.fatal("empty decimal character reference"));
}
dec.parse::<u32>()
.map_err(|_| self.fatal("invalid decimal character reference"))?
};
self.expect_byte(b';')?;
let ch = char::from_u32(value)
.ok_or_else(|| self.fatal(format!("invalid character reference: U+{value:04X}")))?;
if !is_xml_char(ch) {
return Err(self.fatal(format!(
"character reference &#x{value:X}; does not refer to a valid XML character"
)));
}
let start = buf.len();
buf.push(ch);
Ok(&buf[start..])
} else {
// Named entity reference.
let name = self.parse_name()?;
self.expect_byte(b';')?;
let expanded = match name.as_str() {
"amp" | "lt" | "gt" | "apos" | "quot" => {
// A predefined name followed by `;` was already consumed by the fast path above.
unreachable!("builtin entity should be caught by fast path")
}
_ => {
// External entities take precedence over entries in `entity_map`.
if let Some(info) = self.entity_external.get(&name).cloned() {
if let Some(ref resolver) = self.entity_resolver.clone() {
let request = ExternalEntityRequest {
name: &name,
system_id: &info.system_id,
public_id: info.public_id.as_deref(),
};
if let Some(resolved) = resolver(request) {
self.expand_entity_text(&resolved)?
} else {
return Err(self.fatal(format!(
"reference to external entity '{name}' is not supported"
)));
}
} else {
return Err(self.fatal(format!(
"reference to external entity '{name}' is not supported"
)));
}
} else if let Some(value) = self.entity_map.get(&name).cloned() {
// Validate the replacement text once, on the first reference to this entity.
if !self.validated_entities.contains(&name) {
self.validated_entities.insert(name.clone());
self.validate_entity_content(&name, &value)?;
}
self.expand_entity_text(&value)?
} else if self.recover || self.has_pe_references || self.has_external_dtd {
// Unknown entity tolerated: record a warning and expand to nothing.
self.push_diagnostic(
ErrorSeverity::Warning,
format!("unknown entity reference: &{name};"),
);
String::new()
} else {
return Err(self.fatal(format!("unknown entity reference: &{name};")));
}
}
};
let start = buf.len();
buf.push_str(&expanded);
Ok(&buf[start..])
}
}
/// Expands the references contained in an entity's replacement `text`.
///
/// Character references and predefined entities are replaced by the character they
/// denote; references to other entities are expanded recursively (external ones
/// through the entity resolver). Text inside `<![CDATA[ ... ]]>` is copied through
/// untouched, markers included.
///
/// Every named reference counts toward the entity expansion limit, which also bounds
/// the recursion.
///
/// # Errors
/// Returns a fatal error when a reference is incomplete or malformed, the expansion
/// limit is exceeded, an external entity cannot be resolved, or an entity is unknown.
/// An unknown entity is downgraded to a warning with an empty expansion when recovery
/// mode is on or when `has_pe_references` / `has_external_dtd` is set.
#[allow(clippy::too_many_lines)]
fn expand_entity_text(&mut self, text: &str) -> Result<String, ParseError> {
    // Nothing to expand: return a plain copy.
    if !text.contains('&') {
        return Ok(text.to_string());
    }
    let bytes = text.as_bytes();
    let mut result = String::with_capacity(text.len());
    let mut i = 0;
    let mut in_cdata = false;
    while i < bytes.len() {
        if !in_cdata && i + 8 < bytes.len() && &bytes[i..i + 9] == b"<![CDATA[" {
            in_cdata = true;
            result.push_str("<![CDATA[");
            i += 9;
            continue;
        }
        if in_cdata {
            if i + 2 < bytes.len() && &bytes[i..i + 3] == b"]]>" {
                in_cdata = false;
                result.push_str("]]>");
                i += 3;
            } else {
                // Copy one whole UTF-8 sequence. Pushing each byte as a `char` would
                // reinterpret multi-byte characters as Latin-1 and corrupt the text.
                let start = i;
                i += 1;
                while i < bytes.len() && bytes[i] & 0xC0 == 0x80 {
                    i += 1;
                }
                if let Ok(s) = std::str::from_utf8(&bytes[start..i]) {
                    result.push_str(s);
                }
            }
            continue;
        }
        if bytes[i] == b'&' {
            i += 1;
            if i < bytes.len() && bytes[i] == b'#' {
                // Numeric character reference.
                i += 1;
                let char_val = if i < bytes.len() && bytes[i] == b'x' {
                    i += 1;
                    let start = i;
                    while i < bytes.len() && bytes[i].is_ascii_hexdigit() {
                        i += 1;
                    }
                    let hex = std::str::from_utf8(&bytes[start..i])
                        .map_err(|_| self.fatal("invalid UTF-8 in entity value"))?;
                    u32::from_str_radix(hex, 16)
                        .map_err(|_| self.fatal("invalid hex character reference"))?
                } else {
                    let start = i;
                    while i < bytes.len() && bytes[i].is_ascii_digit() {
                        i += 1;
                    }
                    let dec = std::str::from_utf8(&bytes[start..i])
                        .map_err(|_| self.fatal("invalid UTF-8 in entity value"))?;
                    dec.parse::<u32>()
                        .map_err(|_| self.fatal("invalid decimal character reference"))?
                };
                if i >= bytes.len() || bytes[i] != b';' {
                    return Err(self.fatal("incomplete character reference in entity value"));
                }
                i += 1;
                let ch = char::from_u32(char_val).ok_or_else(|| {
                    self.fatal(format!("invalid character reference: U+{char_val:04X}"))
                })?;
                result.push(ch);
            } else {
                // Named reference: everything up to the next ';' is the entity name.
                let start = i;
                while i < bytes.len() && bytes[i] != b';' {
                    i += 1;
                }
                if i >= bytes.len() {
                    return Err(self.fatal("incomplete entity reference in entity value"));
                }
                let name = std::str::from_utf8(&bytes[start..i])
                    .map_err(|_| self.fatal("invalid UTF-8 in entity name"))?;
                i += 1;
                self.entity_expansions += 1;
                if self.entity_expansions > self.max_entity_expansions {
                    return Err(self.fatal(format!(
                        "entity expansion limit exceeded ({})",
                        self.max_entity_expansions
                    )));
                }
                match name {
                    "amp" => {
                        result.push('&');
                        continue;
                    }
                    "lt" => {
                        result.push('<');
                        continue;
                    }
                    "gt" => {
                        result.push('>');
                        continue;
                    }
                    "apos" => {
                        result.push('\'');
                        continue;
                    }
                    "quot" => {
                        result.push('"');
                        continue;
                    }
                    _ => {}
                }
                // External entities take precedence over entries in `entity_map`.
                let expanded = if let Some(info) = self.entity_external.get(name).cloned() {
                    if let Some(ref resolver) = self.entity_resolver.clone() {
                        let request = ExternalEntityRequest {
                            name,
                            system_id: &info.system_id,
                            public_id: info.public_id.as_deref(),
                        };
                        if let Some(resolved) = resolver(request) {
                            self.expand_entity_text(&resolved)?
                        } else {
                            return Err(self.fatal(format!(
                                "reference to external entity '{name}' is not supported"
                            )));
                        }
                    } else {
                        return Err(self.fatal(format!(
                            "reference to external entity '{name}' is not supported"
                        )));
                    }
                } else if let Some(value) = self.entity_map.get(name).cloned() {
                    self.expand_entity_text(&value)?
                } else if self.recover || self.has_pe_references || self.has_external_dtd {
                    // Unknown entity tolerated: record a warning and expand to nothing.
                    self.push_diagnostic(
                        ErrorSeverity::Warning,
                        format!("unknown entity reference: &{name};"),
                    );
                    String::new()
                } else {
                    return Err(self.fatal(format!("unknown entity reference: &{name};")));
                };
                result.push_str(&expanded);
            }
        } else {
            // Ordinary text: copy one whole UTF-8 sequence.
            let start = i;
            i += 1;
            while i < bytes.len() && bytes[i] & 0xC0 == 0x80 {
                i += 1;
            }
            if let Ok(s) = std::str::from_utf8(&bytes[start..i]) {
                result.push_str(s);
            }
        }
    }
    Ok(result)
}
/// Checks that the replacement text of entity `name` is well-formed XML content.
///
/// The raw value is passed through `expand_char_refs_only`; if the result contains no
/// `<` it is accepted immediately. Otherwise it is passed through
/// `replace_entity_refs`, wrapped in a synthetic `<_r>` element and parsed with default
/// options. (Both helpers live in `crate::validation::dtd`; their exact behavior is
/// not visible from this file.)
///
/// # Errors
/// Returns a fatal error when the wrapped text fails to parse.
fn validate_entity_content(&self, name: &str, raw_value: &str) -> Result<(), ParseError> {
let replacement = crate::validation::dtd::expand_char_refs_only(raw_value);
// Without markup there is nothing that could be ill-formed.
if !replacement.contains('<') {
return Ok(());
}
let sanitized = crate::validation::dtd::replace_entity_refs(&replacement);
let wrapped = format!("<_r>{sanitized}</_r>");
let options = super::ParseOptions::default();
if super::parse_str_with_options(&wrapped, &options).is_err() {
return Err(self.fatal(format!(
"entity '{name}' replacement text is not \
well-formed XML content"
)));
}
Ok(())
}
/// Parses a quoted attribute value at the cursor and returns its normalized text.
///
/// The opening quote (`"` or `'`) is consumed first and the value runs to the matching
/// quote. Literal tab, line feed and carriage return each become a single space, and a
/// literal CR LF pair counts as one line break and therefore one space (XML 1.0 §2.11
/// and §3.3.3). References are expanded through `parse_reference_into`.
///
/// # Errors
/// Returns a fatal error when the value is not quoted, the input ends before the
/// closing quote, a literal `<` appears, a non-predefined entity expands to text
/// containing `<`, a reference cannot be expanded, or an invalid XML character is
/// found outside recovery mode (in recovery mode a diagnostic is recorded instead).
pub fn parse_attribute_value(&mut self) -> Result<String, ParseError> {
    let quote = self.next_byte()?;
    if quote != b'"' && quote != b'\'' {
        return Err(self.fatal("attribute value must be quoted"));
    }
    let mut value = String::new();
    loop {
        // Bulk-copy the run of bytes that needs no special handling.
        let safe_len = self.scan_attr_value(quote);
        if safe_len > 0 {
            let start = self.pos;
            let chunk = std::str::from_utf8(&self.input[start..start + safe_len])
                .map_err(|_| self.fatal("invalid UTF-8 in attribute value"))?;
            // Cheap byte pre-filter first; the per-character scan runs only when needed.
            if let Some(bad) = may_contain_invalid_xml_chars(chunk.as_bytes())
                .then(|| find_invalid_xml_char(chunk))
                .flatten()
            {
                if self.recover {
                    self.push_diagnostic(
                        ErrorSeverity::Error,
                        format!("invalid XML character: U+{:04X}", bad as u32),
                    );
                } else {
                    return Err(
                        self.fatal(format!("invalid XML character: U+{:04X}", bad as u32))
                    );
                }
            }
            if chunk
                .as_bytes()
                .iter()
                .any(|&b| b == b'\t' || b == b'\n' || b == b'\r')
            {
                let mut chars = chunk.chars().peekable();
                while let Some(ch) = chars.next() {
                    match ch {
                        '\r' => {
                            // CR LF is a single line break, so it yields one space.
                            // The scanner never stops between the two bytes, so the
                            // pair is always inside one chunk.
                            if chars.peek() == Some(&'\n') {
                                chars.next();
                            }
                            value.push(' ');
                        }
                        '\t' | '\n' => value.push(' '),
                        _ => value.push(ch),
                    }
                }
            } else {
                value.push_str(chunk);
            }
            self.advance_counting_lines(safe_len);
        }
        if self.at_end() {
            return Err(self.fatal("unexpected end of input in attribute value"));
        }
        let b = self.input[self.pos];
        if b == quote {
            self.advance(1);
            break;
        }
        if b == b'&' {
            // Only user-defined entities are checked for a smuggled '<'; `&lt;` and
            // character references may legitimately produce one.
            let is_custom_entity = self.input.get(self.pos + 1) != Some(&b'#')
                && !self.input[self.pos + 1..].starts_with(b"lt;")
                && !self.input[self.pos + 1..].starts_with(b"gt;")
                && !self.input[self.pos + 1..].starts_with(b"amp;")
                && !self.input[self.pos + 1..].starts_with(b"apos;")
                && !self.input[self.pos + 1..].starts_with(b"quot;");
            let resolved = self.parse_reference_into(&mut value)?;
            if is_custom_entity && resolved.contains('<') {
                return Err(
                    self.fatal("'<' not allowed in attribute values (from entity expansion)")
                );
            }
        } else if b == b'<' {
            return Err(self.fatal("'<' not allowed in attribute values"));
        } else {
            // Remaining case: a byte the scanner refused (e.g. a control character).
            // `next_char` reports it and, in recovery mode, hands it back.
            let ch = self.next_char()?;
            if ch == '\r' || ch == '\n' || ch == '\t' {
                value.push(' ');
            } else {
                value.push(ch);
            }
        }
    }
    Ok(value)
}
/// Parses a quoted literal at the cursor and returns the text between the quotes
/// verbatim (no reference expansion, no whitespace normalization).
///
/// # Errors
/// Returns a fatal error when the cursor is not at `"` or `'`, or when the input ends
/// before the closing quote.
pub fn parse_quoted_value(&mut self) -> Result<String, ParseError> {
    let quote = self.next_byte()?;
    if !matches!(quote, b'"' | b'\'') {
        return Err(self.fatal("expected quoted value"));
    }
    let begin = self.pos;
    while let Some(byte) = self.peek() {
        if byte == quote {
            break;
        }
        self.advance(1);
    }
    let value = match std::str::from_utf8(&self.input[begin..self.pos]) {
        Ok(text) => text.to_string(),
        Err(_) => return Err(self.fatal("invalid UTF-8 in quoted value")),
    };
    self.expect_byte(quote)?;
    Ok(value)
}
/// Builds a `ParseError` carrying `message`, the current cursor location and a copy of
/// all diagnostics recorded so far. The error is returned, not raised.
pub fn fatal(&self, message: impl Into<String>) -> ParseError {
ParseError {
message: message.into(),
location: self.location(),
diagnostics: self.diagnostics.clone(),
}
}
/// Records a diagnostic with the given severity and message at the current cursor
/// location.
pub fn push_diagnostic(&mut self, severity: ErrorSeverity, message: String) {
self.diagnostics.push(ParseDiagnostic {
severity,
message,
location: self.location(),
});
}
}
/// Scoped table of namespace bindings.
///
/// Bindings are recorded per scope in `stack`; `default_ns` and `prefixed_ns` hold the
/// bindings currently in effect so lookups do not have to walk the stack.
pub(crate) struct NamespaceResolver {
// One frame per open scope; each frame lists the (prefix, uri) pairs bound in that
// scope, in binding order. A prefix of `None` is the default namespace.
stack: Vec<Vec<(Option<String>, String)>>,
// Default-namespace URI currently in effect, if any.
default_ns: Option<String>,
// Prefix -> URI bindings currently in effect.
prefixed_ns: HashMap<String, String>,
}
/// Namespace URI bound to the `xml` prefix; `NamespaceResolver::new` pre-binds it.
pub(crate) const XML_NAMESPACE: &str = "http://www.w3.org/XML/1998/namespace";
impl NamespaceResolver {
    /// Creates a resolver with a single outermost scope in which the `xml` prefix is
    /// bound to [`XML_NAMESPACE`]. No default namespace is set.
    pub fn new() -> Self {
        let mut resolver = Self {
            stack: vec![Vec::new()],
            default_ns: None,
            prefixed_ns: HashMap::new(),
        };
        resolver.bind(Some("xml".to_string()), XML_NAMESPACE.to_string());
        resolver
    }

    /// Opens a new, empty scope; later `bind` calls are recorded in it.
    pub fn push_scope(&mut self) {
        self.stack.push(Vec::new());
    }

    /// Closes the innermost scope and reinstates, for every prefix bound in it, the
    /// binding from the nearest enclosing scope (or removes the binding when no
    /// enclosing scope has one). Does nothing when no scope is open.
    pub fn pop_scope(&mut self) {
        let closed = match self.stack.pop() {
            Some(frame) => frame,
            None => return,
        };
        for (prefix, _) in closed.into_iter().rev() {
            // Most recent binding of this prefix in the scopes that remain open.
            let outer = self
                .stack
                .iter()
                .rev()
                .flat_map(|frame| frame.iter().rev())
                .find(|(bound, _)| *bound == prefix)
                .map(|(_, uri)| uri.clone());
            match (prefix, outer) {
                (None, outer) => self.default_ns = outer,
                (Some(pfx), Some(uri)) => {
                    self.prefixed_ns.insert(pfx, uri);
                }
                (Some(pfx), None) => {
                    self.prefixed_ns.remove(&pfx);
                }
            }
        }
    }

    /// Binds `prefix` (or the default namespace when `None`) to `uri` in the innermost
    /// scope, replacing the binding currently in effect.
    pub fn bind(&mut self, prefix: Option<String>, uri: String) {
        // Remember the binding in the current frame so `pop_scope` can undo it.
        if let Some(frame) = self.stack.last_mut() {
            frame.push((prefix.clone(), uri.clone()));
        }
        if let Some(pfx) = prefix {
            self.prefixed_ns.insert(pfx, uri);
        } else {
            self.default_ns = Some(uri);
        }
    }

    /// Looks up the namespace URI currently bound to `prefix` (`None` selects the
    /// default namespace). An unbound prefix and a binding to the empty string both
    /// yield `None`.
    pub fn resolve(&self, prefix: Option<&str>) -> Option<&str> {
        let bound = match prefix {
            Some(pfx) => self.prefixed_ns.get(pfx),
            None => self.default_ns.as_ref(),
        }?;
        if bound.is_empty() {
            None
        } else {
            Some(bound.as_str())
        }
    }
}
/// Parses a comment starting at `<!--` and returns the text between the delimiters.
///
/// The content is collected in chunks, each running up to the next `--`. A `--` that
/// is not the start of the closing `-->` is a fatal error; in recovery mode it is
/// recorded as an error diagnostic and kept in the content. Invalid XML characters
/// are handled the same way (fatal, or a diagnostic in recovery mode).
///
/// # Errors
/// Returns a fatal error when the input does not start with `<!--`, ends before the
/// comment is closed, or violates one of the rules above outside recovery mode.
pub(crate) fn parse_comment_content(input: &mut ParserInput<'_>) -> Result<String, ParseError> {
input.expect_str(b"<!--")?;
let mut content = String::new();
loop {
match input.scan_for_2byte_terminator(b'-', b'-') {
Some(safe_len) => {
// Everything before the next `--` is plain comment text.
if safe_len > 0 {
let start = input.pos();
// Cheap byte pre-filter; the per-character scan runs only when it fires.
let has_bad =
may_contain_invalid_xml_chars(input.slice(start, start + safe_len));
let chunk = std::str::from_utf8(input.slice(start, start + safe_len))
.map_err(|_| input.fatal("invalid UTF-8 in comment"))?
.to_string();
if has_bad {
if let Some(bad) = find_invalid_xml_char(&chunk) {
if input.recover() {
input.push_diagnostic(
ErrorSeverity::Error,
format!("invalid XML character: U+{:04X}", bad as u32),
);
} else {
return Err(input.fatal(format!(
"invalid XML character: U+{:04X}",
bad as u32
)));
}
}
}
content.push_str(&chunk);
input.advance_counting_lines(safe_len);
}
// The cursor is now at `--`: either the closing `-->` or an illegal `--`.
if input.looking_at(b"-->") {
input.advance_counting_lines(3);
break;
}
if input.recover() {
input.push_diagnostic(
ErrorSeverity::Error,
"'--' not allowed inside comments".to_string(),
);
content.push_str("--");
input.advance_counting_lines(2);
} else {
return Err(input.fatal("'--' not allowed inside comments"));
}
}
None => {
return Err(input.fatal("unexpected end of input in comment"));
}
}
}
Ok(content)
}
pub(crate) fn parse_cdata_content(input: &mut ParserInput<'_>) -> Result<String, ParseError> {
input.expect_str(b"<![CDATA[")?;
match input.scan_for_3byte_terminator(b']', b']', b'>') {
Some(safe_len) => {
let start = input.pos();
let has_bad = may_contain_invalid_xml_chars(input.slice(start, start + safe_len));
let content = std::str::from_utf8(input.slice(start, start + safe_len))
.map_err(|_| input.fatal("invalid UTF-8 in CDATA section"))?
.to_string();
if has_bad {
if let Some(bad) = find_invalid_xml_char(&content) {
if input.recover() {
input.push_diagnostic(
ErrorSeverity::Error,
format!("invalid XML character: U+{:04X}", bad as u32),
);
} else {
return Err(
input.fatal(format!("invalid XML character: U+{:04X}", bad as u32))
);
}
}
}
input.advance_counting_lines(safe_len + 3); Ok(content)
}
None => Err(input.fatal("unexpected end of input in CDATA section")),
}
}
/// Parses a processing instruction starting at `<?` and returns `(target, data)`.
///
/// `data` is `None` when the target is followed directly by `?>`, or when nothing but
/// whitespace separates the target from `?>`.
///
/// # Errors
/// Returns a fatal error when the input does not start with `<?`, the target is not a
/// valid name, the target is `xml` in any letter case, the target contains a colon,
/// the closing `?>` is missing, or the data holds an invalid XML character outside
/// recovery mode (in recovery mode an error diagnostic is recorded instead).
pub(crate) fn parse_pi_content(
input: &mut ParserInput<'_>,
) -> Result<(String, Option<String>), ParseError> {
input.expect_str(b"<?")?;
let target = input.parse_name()?;
if target.eq_ignore_ascii_case("xml") {
return Err(input.fatal("PI target 'xml' is reserved"));
}
if target.contains(':') {
return Err(input.fatal("PI target must not contain a colon"));
}
// Data is only possible when whitespace follows the target.
let data = if input.skip_whitespace() {
match input.scan_for_2byte_terminator(b'?', b'>') {
Some(data_len) => {
let start = input.pos();
// Cheap byte pre-filter; the per-character scan runs only when it fires.
let has_bad = may_contain_invalid_xml_chars(input.slice(start, start + data_len));
let data = std::str::from_utf8(input.slice(start, start + data_len))
.map_err(|_| input.fatal("invalid UTF-8 in processing instruction"))?
.to_string();
if has_bad {
if let Some(bad) = find_invalid_xml_char(&data) {
if input.recover() {
input.push_diagnostic(
ErrorSeverity::Error,
format!("invalid XML character: U+{:04X}", bad as u32),
);
} else {
return Err(
input.fatal(format!("invalid XML character: U+{:04X}", bad as u32))
);
}
}
}
// Skip the data and the two-byte `?>` terminator.
input.advance_counting_lines(data_len + 2); if data.is_empty() {
None
} else {
Some(data)
}
}
None => {
return Err(input.fatal("unexpected end of input in processing instruction"));
}
}
} else {
// No whitespace after the target: the instruction must end right here.
input.expect_str(b"?>")?;
None
};
Ok((target, data))
}
/// The pseudo-attributes read from an `<?xml ... ?>` declaration by
/// `parse_xml_decl`.
#[derive(Debug, Clone)]
pub(crate) struct XmlDeclaration {
/// Value of the mandatory `version` pseudo-attribute (accepted only in the
/// form `1.` followed by one or more ASCII digits).
pub version: String,
/// Value of the optional `encoding` pseudo-attribute; `None` when absent.
pub encoding: Option<String>,
/// `Some(true)` for `standalone="yes"`, `Some(false)` for `standalone="no"`,
/// `None` when the pseudo-attribute is absent.
pub standalone: Option<bool>,
}
/// Parses an XML declaration (`<?xml version=... [encoding=...] [standalone=...] ?>`).
///
/// `version` is mandatory and must come first; `encoding` and `standalone`
/// are optional, in that order, and each must be preceded by whitespace.
///
/// # Errors
///
/// Returns a fatal [`ParseError`] when a delimiter, keyword, `=` or quoted
/// value is missing, when the version or encoding name is malformed, when
/// required whitespace is missing, or when `standalone` is neither `yes`
/// nor `no`.
pub(crate) fn parse_xml_decl(input: &mut ParserInput<'_>) -> Result<XmlDeclaration, ParseError> {
    /// Consumes optional whitespace, `=`, optional whitespace and a quoted value.
    fn eq_then_value(input: &mut ParserInput<'_>) -> Result<String, ParseError> {
        input.skip_whitespace();
        input.expect_byte(b'=')?;
        input.skip_whitespace();
        let value = input.parse_quoted_value()?;
        Ok(value)
    }

    input.expect_str(b"<?xml")?;
    input.skip_whitespace_required()?;

    input.expect_str(b"version")?;
    let version = eq_then_value(input)?;
    if !is_valid_version_num(&version) {
        return Err(input.fatal(format!("invalid version number: '{version}'")));
    }

    let ws_after_version = input.skip_whitespace();
    let mut encoding = None;
    if input.looking_at(b"encoding") {
        if !ws_after_version {
            return Err(input.fatal("whitespace required before encoding"));
        }
        input.expect_str(b"encoding")?;
        let enc = eq_then_value(input)?;
        if !is_valid_encoding_name(&enc) {
            return Err(input.fatal(format!("invalid encoding name: '{enc}'")));
        }
        encoding = Some(enc);
    }

    // When no encoding was parsed, the whitespace already consumed after the
    // version value is the separator in front of `standalone`.
    let ws_before_standalone =
        input.skip_whitespace() || (encoding.is_none() && ws_after_version);
    let mut standalone = None;
    if input.looking_at(b"standalone") {
        if !ws_before_standalone {
            return Err(input.fatal("whitespace required before standalone"));
        }
        input.expect_str(b"standalone")?;
        let val = eq_then_value(input)?;
        standalone = Some(match val.as_str() {
            "yes" => true,
            "no" => false,
            _ => return Err(input.fatal("standalone must be 'yes' or 'no'")),
        });
    }

    input.skip_whitespace();
    input.expect_str(b"?>")?;

    Ok(XmlDeclaration {
        version,
        encoding,
        standalone,
    })
}
/// Returns `true` when `s` is `1.` followed by at least one ASCII digit and
/// nothing else.
fn is_valid_version_num(s: &str) -> bool {
    s.strip_prefix("1.").map_or(false, |minor| {
        !minor.is_empty() && minor.bytes().all(|b| b.is_ascii_digit())
    })
}
/// Returns `true` when `s` starts with an ASCII letter and every following
/// byte is an ASCII letter, ASCII digit, `.`, `_` or `-`.
///
/// The empty string is rejected.
fn is_valid_encoding_name(s: &str) -> bool {
    match s.as_bytes().split_first() {
        None => false,
        Some((first, rest)) => {
            first.is_ascii_alphabetic()
                && rest
                    .iter()
                    .all(|&b| b.is_ascii_alphanumeric() || matches!(b, b'.' | b'_' | b'-'))
        }
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
use super::*;
#[test]
fn test_peek_and_advance() {
let mut input = ParserInput::new("abc");
assert_eq!(input.peek(), Some(b'a'));
assert_eq!(input.peek_at(1), Some(b'b'));
input.advance(1);
assert_eq!(input.peek(), Some(b'b'));
input.advance(2);
assert!(input.at_end());
}
#[test]
fn test_line_column_tracking() {
let mut input = ParserInput::new("ab\ncd");
assert_eq!(input.location().line, 1);
assert_eq!(input.location().column, 1);
input.advance(2); assert_eq!(input.location().column, 3);
input.advance(1); assert_eq!(input.location().line, 2);
assert_eq!(input.location().column, 1);
}
#[test]
fn test_next_char_cr_normalization() {
let mut input = ParserInput::new("a\r\nb");
assert_eq!(input.next_char().unwrap(), 'a');
assert_eq!(input.next_char().unwrap(), '\n'); assert_eq!(input.next_char().unwrap(), 'b');
}
#[test]
fn test_parse_name() {
let mut input = ParserInput::new("foo:bar ");
let name = input.parse_name().unwrap();
assert_eq!(name, "foo:bar");
}
#[test]
fn test_parse_name_length_limit() {
let long_name = "a".repeat(100);
let mut input = ParserInput::new(&long_name);
input.set_max_name_length(50);
let result = input.parse_name();
assert!(result.is_err());
assert!(result.unwrap_err().message.contains("name length"));
}
#[test]
fn test_parse_reference_builtin() {
    // Each of the five predefined entities, written as a complete `&name;`
    // reference, must expand to its single replacement character.
    for (reference, expected) in [
        ("&amp;", "&"),
        ("&lt;", "<"),
        ("&gt;", ">"),
        ("&apos;", "'"),
        ("&quot;", "\""),
    ] {
        let mut input = ParserInput::new(reference);
        assert_eq!(input.parse_reference().unwrap(), expected);
    }
}
#[test]
fn test_parse_reference_char_decimal() {
    // Decimal character reference for U+0041.
    let mut input = ParserInput::new("&#65;");
    assert_eq!(input.parse_reference().unwrap(), "A");
}
#[test]
fn test_parse_reference_char_hex() {
    // Hexadecimal character reference for U+0041.
    let mut input = ParserInput::new("&#x41;");
    assert_eq!(input.parse_reference().unwrap(), "A");
}
#[test]
fn test_parse_reference_unknown_error() {
let mut input = ParserInput::new("&bogus;");
assert!(input.parse_reference().is_err());
}
#[test]
fn test_parse_reference_unknown_recovery() {
let mut input = ParserInput::new("&bogus;");
input.set_recover(true);
let result = input.parse_reference().unwrap();
assert_eq!(result, "");
assert_eq!(input.diagnostics.len(), 1);
}
#[test]
fn test_entity_expansion_limit() {
    // Three complete references against a limit of two: the third must fail.
    let mut input = ParserInput::new("&amp;&amp;&amp;");
    input.set_max_entity_expansions(2);
    assert!(input.parse_reference().is_ok());
    assert!(input.parse_reference().is_ok());
    assert!(input.parse_reference().is_err());
}
#[test]
fn test_depth_limit() {
let mut input = ParserInput::new("");
input.set_max_depth(2);
assert!(input.increment_depth().is_ok()); assert!(input.increment_depth().is_ok()); assert!(input.increment_depth().is_err()); }
#[test]
fn test_parse_attribute_value() {
    // The `&amp;` reference inside the quoted value must expand to `&`.
    let mut input = ParserInput::new("\"hello &amp; world\"");
    let value = input.parse_attribute_value().unwrap();
    assert_eq!(value, "hello & world");
}
#[test]
fn test_parse_attribute_value_whitespace_normalization() {
let mut input = ParserInput::new("\"a\tb\nc\"");
let value = input.parse_attribute_value().unwrap();
assert_eq!(value, "a b c");
}
#[test]
fn test_parse_quoted_value() {
let mut input = ParserInput::new("'hello'");
let value = input.parse_quoted_value().unwrap();
assert_eq!(value, "hello");
}
#[test]
fn test_skip_whitespace() {
let mut input = ParserInput::new(" \t\n abc");
assert!(input.skip_whitespace());
assert_eq!(input.peek(), Some(b'a'));
}
#[test]
fn test_looking_at() {
let input = ParserInput::new("<!--comment-->");
assert!(input.looking_at(b"<!--"));
assert!(!input.looking_at(b"<![CDATA["));
}
#[test]
fn test_take_while() {
let mut input = ParserInput::new("12345abc");
let digits = input.take_while(|b| b.is_ascii_digit());
assert_eq!(digits, "12345");
assert_eq!(input.peek(), Some(b'a'));
}
#[test]
fn test_split_name() {
assert_eq!(split_name("foo:bar"), (Some("foo"), "bar"));
assert_eq!(split_name("bar"), (None, "bar"));
assert_eq!(split_name(":bar"), (Some(""), "bar"));
}
#[test]
fn test_namespace_resolver() {
let mut ns = NamespaceResolver::new();
assert_eq!(ns.resolve(Some("xml")), Some(XML_NAMESPACE));
assert_eq!(ns.resolve(None), None);
ns.push_scope();
ns.bind(None, "http://default".to_string());
ns.bind(Some("foo".to_string()), "http://foo".to_string());
assert_eq!(ns.resolve(None), Some("http://default"));
assert_eq!(ns.resolve(Some("foo")), Some("http://foo"));
ns.pop_scope();
assert_eq!(ns.resolve(None), None);
assert_eq!(ns.resolve(Some("foo")), None);
}
#[test]
fn test_namespace_undeclare_default() {
let mut ns = NamespaceResolver::new();
ns.push_scope();
ns.bind(None, "http://default".to_string());
assert_eq!(ns.resolve(None), Some("http://default"));
ns.push_scope();
ns.bind(None, String::new()); assert_eq!(ns.resolve(None), None);
ns.pop_scope();
assert_eq!(ns.resolve(None), Some("http://default"));
}
#[test]
fn test_parse_comment_content() {
let mut input = ParserInput::new("<!-- hello -->");
let content = parse_comment_content(&mut input).unwrap();
assert_eq!(content, " hello ");
}
#[test]
fn test_parse_cdata_content() {
let mut input = ParserInput::new("<![CDATA[some <data>]]>");
let content = parse_cdata_content(&mut input).unwrap();
assert_eq!(content, "some <data>");
}
#[test]
fn test_parse_pi_content() {
let mut input = ParserInput::new("<?target data?>");
let (target, data) = parse_pi_content(&mut input).unwrap();
assert_eq!(target, "target");
assert_eq!(data.as_deref(), Some("data"));
}
#[test]
fn test_parse_pi_no_data() {
let mut input = ParserInput::new("<?target?>");
let (target, data) = parse_pi_content(&mut input).unwrap();
assert_eq!(target, "target");
assert_eq!(data, None);
}
#[test]
fn test_parse_xml_decl() {
let mut input = ParserInput::new("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
let decl = parse_xml_decl(&mut input).unwrap();
assert_eq!(decl.version, "1.0");
assert_eq!(decl.encoding.as_deref(), Some("UTF-8"));
assert_eq!(decl.standalone, None);
}
#[test]
fn test_parse_xml_decl_standalone() {
let mut input =
ParserInput::new("<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"yes\"?>");
let decl = parse_xml_decl(&mut input).unwrap();
assert_eq!(decl.standalone, Some(true));
}
#[test]
fn test_is_name_chars() {
assert!(is_name_start_char('a'));
assert!(is_name_start_char('Z'));
assert!(is_name_start_char('_'));
assert!(is_name_start_char(':'));
assert!(!is_name_start_char('0'));
assert!(!is_name_start_char('-'));
assert!(is_name_char('a'));
assert!(is_name_char('0'));
assert!(is_name_char('-'));
assert!(is_name_char('.'));
assert!(!is_name_char(' '));
}
#[test]
fn test_increment_depth_exact_boundary() {
let mut input = ParserInput::new("");
input.set_max_depth(3);
assert!(input.increment_depth().is_ok()); assert!(input.increment_depth().is_ok()); assert!(input.increment_depth().is_ok()); assert!(input.increment_depth().is_err()); }
#[test]
fn test_increment_depth_max_depth_one() {
let mut input = ParserInput::new("");
input.set_max_depth(1);
assert!(input.increment_depth().is_ok()); let err = input.increment_depth().unwrap_err();
assert!(err.message.contains("maximum nesting depth exceeded"));
}
#[test]
fn test_increment_depth_max_depth_zero() {
let mut input = ParserInput::new("");
input.set_max_depth(0);
let err = input.increment_depth().unwrap_err();
assert!(err.message.contains("maximum nesting depth exceeded"));
}
#[test]
fn test_decrement_depth_saturates_at_zero() {
let mut input = ParserInput::new("");
input.decrement_depth();
assert_eq!(input.depth(), 0);
input.increment_depth().unwrap();
assert_eq!(input.depth(), 1);
input.decrement_depth();
assert_eq!(input.depth(), 0);
input.decrement_depth();
assert_eq!(input.depth(), 0);
}
#[test]
fn test_depth_resets_after_decrement_allows_reentry() {
let mut input = ParserInput::new("");
input.set_max_depth(2);
assert!(input.increment_depth().is_ok()); assert!(input.increment_depth().is_ok()); input.decrement_depth(); assert!(input.increment_depth().is_ok()); assert!(input.increment_depth().is_err()); }
#[test]
fn test_entity_expansion_limit_exact_boundary() {
    // Exactly `limit` references succeed; the next one is rejected.
    let mut input = ParserInput::new("&amp;&amp;&amp;&amp;");
    input.set_max_entity_expansions(3);
    assert!(input.parse_reference().is_ok());
    assert!(input.parse_reference().is_ok());
    assert!(input.parse_reference().is_ok());
    let err = input.parse_reference().unwrap_err();
    assert!(err.message.contains("entity expansion limit exceeded"));
}
#[test]
fn test_entity_expansion_limit_zero() {
    // With a limit of zero even the first reference is rejected.
    let mut input = ParserInput::new("&amp;");
    input.set_max_entity_expansions(0);
    let err = input.parse_reference().unwrap_err();
    assert!(err.message.contains("entity expansion limit exceeded"));
}
#[test]
fn test_entity_expansion_limit_one() {
    // One reference is allowed; the second (a different builtin) is rejected.
    let mut input = ParserInput::new("&amp;&lt;");
    input.set_max_entity_expansions(1);
    assert!(input.parse_reference().is_ok());
    let err = input.parse_reference().unwrap_err();
    assert!(err.message.contains("entity expansion limit exceeded"));
}
#[test]
fn test_entity_expansion_counter_includes_char_refs() {
    // Numeric character references must count towards the expansion limit.
    let mut input = ParserInput::new("&#65;&#66;&#67;");
    input.set_max_entity_expansions(2);
    assert!(input.parse_reference().is_ok());
    assert!(input.parse_reference().is_ok());
    let err = input.parse_reference().unwrap_err();
    assert!(err.message.contains("entity expansion limit exceeded"));
}
#[test]
fn test_entity_expansion_limit_via_parse_str() {
    use crate::parser::{parse_str_with_options, ParseOptions};
    // Fifty character references in text content against a limit of ten.
    let refs: String = (0..50).map(|_| "&#65;").collect();
    let xml = format!("<r>{refs}</r>");
    let opts = ParseOptions::default().max_entity_expansions(10);
    let result = parse_str_with_options(&xml, &opts);
    assert!(result.is_err());
    assert!(result
        .unwrap_err()
        .message
        .contains("entity expansion limit"));
}
#[test]
fn test_entity_expansion_dtd_internal_entity() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY greet "Hello">
]>
<r>&greet;</r>"#;
let doc = parse_str_with_options(xml, &ParseOptions::default()).unwrap();
let root = doc.root_element().unwrap();
assert_eq!(doc.text_content(root), "Hello");
}
#[test]
fn test_entity_expansion_nested_dtd_entities() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY a "world">
<!ENTITY b "hello &a;">
]>
<r>&b;</r>"#;
let doc = parse_str_with_options(xml, &ParseOptions::default()).unwrap();
let root = doc.root_element().unwrap();
let content = doc.text_content(root);
assert!(
content.contains("hello"),
"entity value should contain 'hello', got: {content}"
);
}
#[test]
fn test_entity_expansion_limit_nested_dtd_entities_in_attributes() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY a "x">
<!ENTITY b "&a;&a;&a;">
<!ENTITY c "&b;&b;&b;">
]>
<r v="&c;"/>"#;
let opts = ParseOptions::default().max_entity_expansions(5);
let result = parse_str_with_options(xml, &opts);
assert!(result.is_err());
assert!(result
.unwrap_err()
.message
.contains("entity expansion limit"));
}
#[test]
fn test_billion_laughs_entity_bomb_in_attribute() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY lol "lol">
<!ENTITY lol2 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
<!ENTITY lol3 "&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;">
<!ENTITY lol4 "&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;&lol3;">
]>
<r v="&lol4;"/>"#;
let opts = ParseOptions::default().max_entity_expansions(100);
let result = parse_str_with_options(xml, &opts);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.message.contains("entity expansion limit"),
"billion laughs should be caught by expansion limit, got: {}",
err.message
);
}
#[test]
fn test_billion_laughs_in_text_content_with_markup() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY lol "lol">
<!ENTITY lol2 "&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;&lol;">
<!ENTITY lol3 "<i>&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;&lol2;</i>">
]>
<r>&lol3;</r>"#;
let opts = ParseOptions::default().max_entity_expansions(50);
let result = parse_str_with_options(xml, &opts);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.message.contains("entity expansion limit"),
"billion laughs with markup should be caught, got: {}",
err.message
);
}
#[test]
fn test_xxe_external_entity_rejected_by_default() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY xxe SYSTEM "file:///etc/passwd">
]>
<r>&xxe;</r>"#;
let result = parse_str_with_options(xml, &ParseOptions::default());
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.message.contains("external entity"),
"XXE should be rejected by default, got: {}",
err.message
);
}
#[test]
fn test_xxe_external_entity_with_public_id_rejected() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY xxe PUBLIC "-//Evil//EN" "http://evil.com/payload">
]>
<r>&xxe;</r>"#;
let result = parse_str_with_options(xml, &ParseOptions::default());
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
err.message.contains("external entity"),
"XXE with PUBLIC id should be rejected, got: {}",
err.message
);
}
#[test]
fn test_xxe_external_entity_in_attribute_rejected() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY xxe SYSTEM "file:///etc/shadow">
]>
<r a="&xxe;"/>"#;
let result = parse_str_with_options(xml, &ParseOptions::default());
assert!(result.is_err(), "XXE in attribute should be rejected");
}
#[test]
fn test_xxe_multiple_external_entities_all_rejected() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY safe "ok">
<!ENTITY evil SYSTEM "file:///etc/passwd">
]>
<r>&safe;&evil;</r>"#;
let result = parse_str_with_options(xml, &ParseOptions::default());
assert!(result.is_err());
}
#[test]
fn test_char_ref_null_character_rejected() {
    // A decimal reference to U+0000 is not a valid XML character.
    let mut input = ParserInput::new("&#0;");
    let err = input.parse_reference().unwrap_err();
    assert!(
        err.message.contains("valid XML character"),
        "null char ref should be rejected as invalid XML char, got: {}",
        err.message
    );
}
#[test]
fn test_char_ref_null_hex_rejected() {
    // A hexadecimal reference to U+0000 is not a valid XML character.
    let mut input = ParserInput::new("&#x0;");
    let err = input.parse_reference().unwrap_err();
    assert!(
        err.message.contains("valid XML character"),
        "&#x0; should be rejected, got: {}",
        err.message
    );
}
#[test]
fn test_char_ref_control_chars_rejected() {
for codepoint in [1u32, 2, 7, 8, 0x0B, 0x0C, 0x0E, 0x1F] {
let ref_str = format!("&#x{codepoint:X};");
let mut input = ParserInput::new(&ref_str);
let result = input.parse_reference();
assert!(
result.is_err(),
"&#x{codepoint:X}; should be rejected as invalid XML character"
);
}
}
#[test]
fn test_char_ref_allowed_control_chars() {
    // Tab, line feed and carriage return are the only C0 controls XML allows,
    // and they must be reachable through character references.
    let mut input = ParserInput::new("&#x9;");
    assert_eq!(input.parse_reference().unwrap(), "\t");
    let mut input = ParserInput::new("&#xA;");
    assert_eq!(input.parse_reference().unwrap(), "\n");
    let mut input = ParserInput::new("&#xD;");
    assert_eq!(input.parse_reference().unwrap(), "\r");
}
#[test]
fn test_char_ref_surrogate_codepoints_rejected() {
    // Both ends of the surrogate range must be rejected.
    let mut input = ParserInput::new("&#xD800;");
    let err = input.parse_reference().unwrap_err();
    assert!(
        err.message.contains("invalid character reference"),
        "surrogate &#xD800; should be rejected, got: {}",
        err.message
    );
    let mut input = ParserInput::new("&#xDFFF;");
    let err = input.parse_reference().unwrap_err();
    assert!(
        err.message.contains("invalid character reference"),
        "surrogate &#xDFFF; should be rejected, got: {}",
        err.message
    );
}
#[test]
fn test_char_ref_fffe_and_ffff_rejected() {
    // U+FFFE and U+FFFF are Unicode scalar values but not XML characters.
    let mut input = ParserInput::new("&#xFFFE;");
    let err = input.parse_reference().unwrap_err();
    assert!(
        err.message.contains("valid XML character"),
        "&#xFFFE; should be rejected, got: {}",
        err.message
    );
    let mut input = ParserInput::new("&#xFFFF;");
    let err = input.parse_reference().unwrap_err();
    assert!(
        err.message.contains("valid XML character"),
        "&#xFFFF; should be rejected, got: {}",
        err.message
    );
}
#[test]
fn test_char_ref_max_valid_codepoint() {
    // U+10FFFF is the highest code point accepted as an XML character.
    let mut input = ParserInput::new("&#x10FFFF;");
    let result = input.parse_reference().unwrap();
    assert_eq!(result, "\u{10FFFF}");
}
#[test]
fn test_char_ref_beyond_unicode_range() {
    // One past the last Unicode code point.
    let mut input = ParserInput::new("&#x110000;");
    let err = input.parse_reference().unwrap_err();
    assert!(
        err.message.contains("invalid character reference"),
        "codepoint beyond Unicode range should be rejected, got: {}",
        err.message
    );
}
#[test]
fn test_char_ref_very_large_decimal_rejected() {
    // Twenty decimal digits: too large for any fixed-width integer up to 64 bits.
    let mut input = ParserInput::new("&#99999999999999999999;");
    let err = input.parse_reference().unwrap_err();
    assert!(
        err.message.contains("invalid decimal character reference"),
        "overflowing decimal char ref should be rejected, got: {}",
        err.message
    );
}
#[test]
fn test_char_ref_very_large_hex_rejected() {
    // Seventeen hex digits: too large for any fixed-width integer up to 64 bits.
    let mut input = ParserInput::new("&#xFFFFFFFFFFFFFFFFF;");
    let err = input.parse_reference().unwrap_err();
    assert!(
        err.message.contains("invalid hex character reference"),
        "overflowing hex char ref should be rejected, got: {}",
        err.message
    );
}
#[test]
fn test_char_ref_empty_decimal_rejected() {
let mut input = ParserInput::new("&#;");
let err = input.parse_reference().unwrap_err();
assert!(
err.message.contains("empty decimal character reference"),
"empty decimal ref should be rejected, got: {}",
err.message
);
}
#[test]
fn test_char_ref_empty_hex_rejected() {
let mut input = ParserInput::new("&#x;");
let err = input.parse_reference().unwrap_err();
assert!(
err.message.contains("empty hex character reference"),
"empty hex ref should be rejected, got: {}",
err.message
);
}
#[test]
fn test_char_ref_valid_bmp_characters() {
    // Representative valid references inside the Basic Multilingual Plane.
    let mut input = ParserInput::new("&#x20;");
    assert_eq!(input.parse_reference().unwrap(), " ");
    let mut input = ParserInput::new("&#x41;");
    assert_eq!(input.parse_reference().unwrap(), "A");
    let mut input = ParserInput::new("&#x4E2D;");
    assert_eq!(input.parse_reference().unwrap(), "\u{4E2D}");
}
#[test]
fn test_char_ref_supplementary_plane() {
    // A code point above U+FFFF must be accepted through a reference.
    let mut input = ParserInput::new("&#x1D11E;");
    assert_eq!(input.parse_reference().unwrap(), "\u{1D11E}");
}
#[test]
fn test_parse_name_at_exact_length_limit() {
let name = "a".repeat(50);
let input_str = format!("{name} ");
let mut input = ParserInput::new(&input_str);
input.set_max_name_length(50);
let result = input.parse_name().unwrap();
assert_eq!(result.len(), 50);
}
#[test]
fn test_parse_name_one_over_length_limit() {
let name = "a".repeat(51);
let input_str = format!("{name} ");
let mut input = ParserInput::new(&input_str);
input.set_max_name_length(50);
let result = input.parse_name();
assert!(result.is_err());
assert!(result.unwrap_err().message.contains("name length"));
}
#[test]
fn test_parse_name_length_limit_one() {
let mut input = ParserInput::new("a ");
input.set_max_name_length(1);
assert_eq!(input.parse_name().unwrap(), "a");
let mut input = ParserInput::new("ab ");
input.set_max_name_length(1);
assert!(input.parse_name().is_err());
}
#[test]
fn test_parse_name_unicode_length_counted_in_bytes() {
let name = "\u{C0}\u{C0}\u{C0}"; let input_str = format!("{name} ");
let mut input = ParserInput::new(&input_str);
input.set_max_name_length(5);
let result = input.parse_name();
assert!(
result.is_err(),
"6-byte unicode name should exceed 5-byte limit"
);
let mut input = ParserInput::new(&input_str);
input.set_max_name_length(6);
assert!(
input.parse_name().is_ok(),
"6-byte unicode name should fit 6-byte limit"
);
}
#[test]
fn test_parse_name_eq_length_limit() {
let name = "a".repeat(51);
let input_str = format!("{name} ");
let mut input = ParserInput::new(&input_str);
input.set_max_name_length(50);
let result = input.parse_name_eq("something");
assert!(result.is_err());
assert!(result.unwrap_err().message.contains("name length"));
}
#[test]
fn test_parse_name_eq_parts_length_limit() {
let name = "a".repeat(51);
let input_str = format!("{name} ");
let mut input = ParserInput::new(&input_str);
input.set_max_name_length(50);
let result = input.parse_name_eq_parts(None, "something");
assert!(result.is_err());
assert!(result.unwrap_err().message.contains("name length"));
}
#[test]
fn test_comment_double_dash_rejected() {
let mut input = ParserInput::new("<!-- bad -- comment -->");
let result = parse_comment_content(&mut input);
assert!(result.is_err());
assert!(
result.unwrap_err().message.contains("'--' not allowed"),
"double dash inside comment should be rejected"
);
}
#[test]
fn test_comment_double_dash_recovery() {
let mut input = ParserInput::new("<!-- bad -- comment -->");
input.set_recover(true);
let content = parse_comment_content(&mut input).unwrap();
assert!(content.contains("--"));
assert!(!input.diagnostics.is_empty());
}
#[test]
fn test_comment_unterminated() {
let mut input = ParserInput::new("<!-- no end");
let result = parse_comment_content(&mut input);
assert!(result.is_err());
assert!(result
.unwrap_err()
.message
.contains("unexpected end of input in comment"));
}
#[test]
fn test_comment_empty() {
let mut input = ParserInput::new("<!---->");
let content = parse_comment_content(&mut input).unwrap();
assert_eq!(content, "");
}
#[test]
fn test_comment_single_dash_allowed() {
let mut input = ParserInput::new("<!-- a - b -->");
let content = parse_comment_content(&mut input).unwrap();
assert_eq!(content, " a - b ");
}
#[test]
fn test_comment_ending_with_triple_dash_rejected() {
    // `<!----->` reaches `--` that is not immediately followed by `>`, which
    // is an error outside recovery mode.
    let mut input = ParserInput::new("<!----->");
    let result = parse_comment_content(&mut input);
    assert!(
        result.is_err(),
        "comment ending in '--->' should be rejected"
    );
}
#[test]
fn test_cdata_unterminated() {
let mut input = ParserInput::new("<![CDATA[no end");
let result = parse_cdata_content(&mut input);
assert!(result.is_err());
assert!(result
.unwrap_err()
.message
.contains("unexpected end of input in CDATA"));
}
#[test]
fn test_cdata_empty() {
let mut input = ParserInput::new("<![CDATA[]]>");
let content = parse_cdata_content(&mut input).unwrap();
assert_eq!(content, "");
}
#[test]
fn test_cdata_with_angle_brackets() {
let mut input = ParserInput::new("<![CDATA[<div>hello</div>]]>");
let content = parse_cdata_content(&mut input).unwrap();
assert_eq!(content, "<div>hello</div>");
}
#[test]
fn test_cdata_with_double_bracket_not_terminator() {
let mut input = ParserInput::new("<![CDATA[a]]b]]>");
let content = parse_cdata_content(&mut input).unwrap();
assert_eq!(content, "a]]b");
}
#[test]
fn test_cdata_with_ampersand() {
let mut input = ParserInput::new("<![CDATA[& <]]>");
let content = parse_cdata_content(&mut input).unwrap();
assert_eq!(content, "& <");
}
#[test]
fn test_pi_target_xml_reserved() {
let mut input = ParserInput::new("<?xml data?>");
let result = parse_pi_content(&mut input);
assert!(result.is_err());
assert!(result
.unwrap_err()
.message
.contains("PI target 'xml' is reserved"));
}
#[test]
fn test_pi_target_xml_case_insensitive() {
for target in ["XML", "Xml", "xMl", "xmL"] {
let pi = format!("<?{target} data?>");
let mut input = ParserInput::new(&pi);
let result = parse_pi_content(&mut input);
assert!(result.is_err(), "PI target '{target}' should be reserved");
}
}
#[test]
fn test_pi_target_with_colon_rejected() {
let mut input = ParserInput::new("<?ns:target data?>");
let result = parse_pi_content(&mut input);
assert!(result.is_err());
assert!(result
.unwrap_err()
.message
.contains("must not contain a colon"));
}
#[test]
fn test_pi_unterminated() {
let mut input = ParserInput::new("<?target no end");
let result = parse_pi_content(&mut input);
assert!(result.is_err());
}
#[test]
fn test_pi_empty_data_after_whitespace() {
let mut input = ParserInput::new("<?target ?>");
let (target, data) = parse_pi_content(&mut input).unwrap();
assert_eq!(target, "target");
assert_eq!(data, None); }
#[test]
fn test_attribute_value_less_than_rejected() {
let mut input = ParserInput::new("\"abc<def\"");
let result = input.parse_attribute_value();
assert!(result.is_err());
assert!(result
.unwrap_err()
.message
.contains("'<' not allowed in attribute values"));
}
#[test]
fn test_attribute_value_unterminated() {
let mut input = ParserInput::new("\"no closing quote");
let result = input.parse_attribute_value();
assert!(result.is_err());
assert!(result
.unwrap_err()
.message
.contains("unexpected end of input"));
}
#[test]
fn test_attribute_value_not_quoted() {
let mut input = ParserInput::new("unquoted");
let result = input.parse_attribute_value();
assert!(result.is_err());
assert!(result
.unwrap_err()
.message
.contains("attribute value must be quoted"));
}
#[test]
fn test_attribute_value_single_quotes() {
let mut input = ParserInput::new("'hello'");
let value = input.parse_attribute_value().unwrap();
assert_eq!(value, "hello");
}
#[test]
fn test_attribute_value_entity_with_less_than_rejected() {
use crate::parser::{parse_str_with_options, ParseOptions};
let xml = r#"<!DOCTYPE r [
<!ENTITY bad "a<b">
]>
<r a="&bad;"/>"#;
let result = parse_str_with_options(xml, &ParseOptions::default());
assert!(result.is_err());
}
#[test]
fn test_is_xml_char_boundary_values() {
assert!(is_xml_char('\t')); assert!(is_xml_char('\n')); assert!(is_xml_char('\r')); assert!(is_xml_char(' ')); assert!(is_xml_char('\u{D7FF}'));
assert!(is_xml_char('\u{E000}'));
assert!(is_xml_char('\u{FFFD}'));
assert!(is_xml_char('\u{10000}'));
assert!(is_xml_char('\u{10FFFF}'));
assert!(!is_xml_char('\0')); assert!(!is_xml_char('\u{0001}')); assert!(!is_xml_char('\u{0008}')); assert!(!is_xml_char('\u{000B}')); assert!(!is_xml_char('\u{000C}')); assert!(!is_xml_char('\u{000E}')); assert!(!is_xml_char('\u{001F}')); assert!(!is_xml_char('\u{FFFE}')); assert!(!is_xml_char('\u{FFFF}')); }
#[test]
fn test_next_char_rejects_control_characters() {
let input_bytes = "\x01";
let mut input = ParserInput::new(input_bytes);
let result = input.next_char();
assert!(result.is_err());
assert!(result
.unwrap_err()
.message
.contains("invalid XML character"));
}
#[test]
fn test_next_char_control_char_recovery() {
let input_bytes = "\x01X";
let mut input = ParserInput::new(input_bytes);
input.set_recover(true);
let ch = input.next_char().unwrap();
assert_eq!(ch, '\x01');
assert!(!input.diagnostics.is_empty());
assert_eq!(input.next_char().unwrap(), 'X');
}
#[test]
fn test_scan_char_data_cdata_end_marker() {
let input = ParserInput::new("text]]>more");
let len = input.scan_char_data();
assert_eq!(len, 4); }
#[test]
fn test_scan_char_data_empty() {
let input = ParserInput::new("<");
assert_eq!(input.scan_char_data(), 0);
}
#[test]
fn test_scan_char_data_stops_at_ampersand() {
let input = ParserInput::new("text&ref;");
assert_eq!(input.scan_char_data(), 4);
}
#[test]
fn test_scan_char_data_stops_at_less_than() {
let input = ParserInput::new("text<elem");
assert_eq!(input.scan_char_data(), 4);
}
#[test]
fn test_scan_for_2byte_terminator_at_end() {
let input = ParserInput::new("x");
assert_eq!(input.scan_for_2byte_terminator(b'-', b'-'), None);
}
#[test]
fn test_scan_for_2byte_terminator_exact_2_bytes() {
let input = ParserInput::new("--");
assert_eq!(input.scan_for_2byte_terminator(b'-', b'-'), Some(0));
}
#[test]
fn test_scan_for_3byte_terminator_at_end() {
let input = ParserInput::new("]]");
assert_eq!(input.scan_for_3byte_terminator(b']', b']', b'>'), None);
}
#[test]
fn test_scan_for_3byte_terminator_exact_3_bytes() {
let input = ParserInput::new("]]>");
assert_eq!(input.scan_for_3byte_terminator(b']', b']', b'>'), Some(0));
}
#[test]
fn test_namespace_resolver_nested_override() {
let mut ns = NamespaceResolver::new();
ns.push_scope();
ns.bind(Some("p".to_string()), "http://outer".to_string());
assert_eq!(ns.resolve(Some("p")), Some("http://outer"));
ns.push_scope();
ns.bind(Some("p".to_string()), "http://inner".to_string());
assert_eq!(ns.resolve(Some("p")), Some("http://inner"));
ns.pop_scope();
assert_eq!(ns.resolve(Some("p")), Some("http://outer"));
ns.pop_scope();
assert_eq!(ns.resolve(Some("p")), None);
}
#[test]
fn test_namespace_resolver_default_ns_override_and_restore() {
let mut ns = NamespaceResolver::new();
ns.push_scope();
ns.bind(None, "http://a".to_string());
ns.push_scope();
ns.bind(None, "http://b".to_string());
assert_eq!(ns.resolve(None), Some("http://b"));
ns.pop_scope();
assert_eq!(ns.resolve(None), Some("http://a"));
ns.pop_scope();
assert_eq!(ns.resolve(None), None);
}
#[test]
fn test_namespace_resolver_undeclare_default_then_redeclare() {
let mut ns = NamespaceResolver::new();
ns.push_scope();
ns.bind(None, "http://ns".to_string());
ns.push_scope();
ns.bind(None, String::new()); assert_eq!(ns.resolve(None), None);
ns.push_scope();
ns.bind(None, "http://new".to_string()); assert_eq!(ns.resolve(None), Some("http://new"));
ns.pop_scope();
assert_eq!(ns.resolve(None), None); ns.pop_scope();
assert_eq!(ns.resolve(None), Some("http://ns")); ns.pop_scope();
assert_eq!(ns.resolve(None), None);
}
#[test]
fn test_namespace_resolver_xml_prefix_always_bound() {
let ns = NamespaceResolver::new();
assert_eq!(ns.resolve(Some("xml")), Some(XML_NAMESPACE));
}
#[test]
fn test_namespace_resolver_unbound_prefix() {
let ns = NamespaceResolver::new();
assert_eq!(ns.resolve(Some("foo")), None);
assert_eq!(ns.resolve(Some("xmlns")), None);
}
// Stacks 100 scopes that each rebind the same prefix, then unwinds them one
// at a time, checking that every pop exposes the binding of the scope below.
#[test]
fn test_namespace_resolver_many_scopes() {
    let mut resolver = NamespaceResolver::new();

    for level in 0..100 {
        resolver.push_scope();
        resolver.bind(Some(String::from("p")), format!("http://ns/{level}"));
    }
    assert_eq!(resolver.resolve(Some("p")), Some("http://ns/99"));

    // Pop scopes 99 down to 1; after popping scope `level`, the binding from
    // scope `level - 1` must be the visible one.
    for level in (1..100).rev() {
        resolver.pop_scope();
        let outer_uri = format!("http://ns/{}", level - 1);
        assert_eq!(resolver.resolve(Some("p")), Some(outer_uri.as_str()));
    }

    // Popping the last scope (0) leaves the prefix unbound.
    resolver.pop_scope();
    assert_eq!(resolver.resolve(Some("p")), None);
}
// Well-formed names, with and without a prefix, produce no diagnostic.
#[test]
fn test_validate_qname_valid() {
    for name in ["foo", "ns:local", "a"] {
        assert_eq!(validate_qname(name), None);
    }
}
// A name with more than one colon is rejected with a message that mentions
// "multiple colons".
#[test]
fn test_validate_qname_multiple_colons() {
    // `unwrap` fails the test if no diagnostic was produced at all.
    let diagnostic = validate_qname("a:b:c").unwrap();
    assert!(diagnostic.contains("multiple colons"));
}
// A leading colon (nothing before it) is rejected with the
// "empty prefix or local part" message.
#[test]
fn test_validate_qname_empty_prefix() {
    let outcome = validate_qname(":local");
    let mentions_empty_part =
        outcome.map_or(false, |msg| msg.contains("empty prefix or local part"));
    assert!(mentions_empty_part);
}
// A trailing colon (nothing after it) is rejected with the
// "empty prefix or local part" message.
#[test]
fn test_validate_qname_empty_local() {
    match validate_qname("prefix:") {
        Some(msg) => assert!(msg.contains("empty prefix or local part")),
        None => panic!(),
    }
}
// "ns:elem" splits into prefix "ns" and local part "elem".
#[test]
fn test_split_owned_name_with_prefix() {
    let parts = split_owned_name(String::from("ns:elem"));
    assert_eq!(parts.0.as_deref(), Some("ns"));
    assert_eq!(parts.1, "elem");
}
// A name without a colon has no prefix and is returned whole as the local
// part.
#[test]
fn test_split_owned_name_no_prefix() {
    let parts = split_owned_name(String::from("elem"));
    assert_eq!(parts.0, None);
    assert_eq!(parts.1, "elem");
}
// A typical W3C public identifier yields no diagnostic.
#[test]
fn test_validate_pubid_valid() {
    let verdict = validate_pubid("-//W3C//DTD XML 1.0//EN");
    assert_eq!(verdict, None);
}
// A public identifier containing U+0001 is rejected with a message that
// mentions "invalid character".
#[test]
fn test_validate_pubid_invalid_char() {
    // `unwrap` fails the test if the identifier was accepted.
    let diagnostic = validate_pubid("bad\x01char").unwrap();
    assert!(diagnostic.contains("invalid character"));
}
// An XML declaration with version "2.0" is rejected with an
// "invalid version number" error.
#[test]
fn test_xml_decl_invalid_version() {
    let mut source = ParserInput::new("<?xml version=\"2.0\"?>");
    // `unwrap_err` fails the test if the declaration was accepted.
    let error = parse_xml_decl(&mut source).unwrap_err();
    assert!(error.message.contains("invalid version number"));
}
// An encoding name that starts with a digit is rejected with an
// "invalid encoding name" error.
#[test]
fn test_xml_decl_invalid_encoding() {
    let mut source = ParserInput::new("<?xml version=\"1.0\" encoding=\"123bad\"?>");
    match parse_xml_decl(&mut source) {
        Err(error) => assert!(error.message.contains("invalid encoding name")),
        Ok(_) => panic!(),
    }
}
// A `standalone` value other than yes/no is rejected with the matching
// error message.
#[test]
fn test_xml_decl_standalone_invalid() {
    let mut source = ParserInput::new("<?xml version=\"1.0\" standalone=\"maybe\"?>");
    let outcome = parse_xml_decl(&mut source);
    assert!(outcome.is_err());
    let error = outcome.unwrap_err();
    assert!(error.message.contains("standalone must be 'yes' or 'no'"));
}
// `standalone="no"` parses successfully and is reported as `Some(false)`.
#[test]
fn test_xml_decl_standalone_no() {
    let mut source = ParserInput::new("<?xml version=\"1.0\" standalone=\"no\"?>");
    let parsed = parse_xml_decl(&mut source).unwrap();
    let standalone = parsed.standalone;
    assert_eq!(standalone, Some(false));
}
// Saving a position, advancing past it, and restoring brings the cursor back
// to the saved byte, and `location()` reports column 4 there.
#[test]
fn test_save_restore_position() {
    let mut cursor = ParserInput::new("abcdef");

    // Move onto 'd' and remember that spot.
    cursor.advance(3);
    assert_eq!(cursor.peek(), Some(b'd'));
    let checkpoint = cursor.save_position();

    // Move on to 'f'.
    cursor.advance(2);
    assert_eq!(cursor.peek(), Some(b'f'));

    // Rewind to the checkpoint.
    cursor.restore_position(checkpoint);
    assert_eq!(cursor.peek(), Some(b'd'));
    let column = cursor.location().column;
    assert_eq!(column, 4);
}
// With `max_depth(3)`, a document nested three elements deep parses, while
// one nested four deep fails with an error that mentions "depth".
#[test]
fn test_depth_limit_via_parse_str_with_options() {
    use crate::parser::{parse_str_with_options, ParseOptions};

    let limited = ParseOptions::default().max_depth(3);

    let three_levels = "<a><b><c/></b></a>";
    assert!(parse_str_with_options(three_levels, &limited).is_ok());

    let four_levels = "<a><b><c><d/></c></b></a>";
    // `unwrap_err` fails the test if the over-deep document was accepted.
    let error = parse_str_with_options(four_levels, &limited).unwrap_err();
    assert!(error.message.contains("depth"));
}
// Each of the five predefined entity names (given without the leading `&`)
// maps to its replacement text and the number of bytes consumed.
#[test]
fn test_match_builtin_entity_all() {
    let cases: [(&[u8], (&str, usize)); 5] = [
        (b"amp;", ("&", 4)),
        (b"lt;", ("<", 3)),
        (b"gt;", (">", 3)),
        (b"apos;", ("'", 5)),
        (b"quot;", ("\"", 5)),
    ];
    for (name, expected) in cases {
        assert_eq!(match_builtin_entity(name), Some(expected));
    }
}
// Truncated entity names without the terminating `;` never match.
#[test]
fn test_match_builtin_entity_partial_no_match() {
    let truncated: [&[u8]; 4] = [b"am", b"l", b"apo", b"quo"];
    for fragment in truncated {
        assert_eq!(match_builtin_entity(fragment), None);
    }
}
// Names that are not predefined entities, including empty input, yield
// `None`.
#[test]
fn test_match_builtin_entity_unknown() {
    let not_builtin: [&[u8]; 3] = [b"foo;", b"", b"x"];
    for candidate in not_builtin {
        assert_eq!(match_builtin_entity(candidate), None);
    }
}
// Runs a nested expansion (`&b;` -> `&a; &a;`) with the expansion limit set
// to 2 and checks the outcome is coherent.
//
// The previous assertion, `result.is_ok() || result.is_err()`, was a
// tautology and could never fail. The check below can fail, yet holds
// whether the implementation accepts or rejects this input at the limit.
//
// NOTE(review): once the intended accounting is confirmed (does this input
// exceed a limit of 2?), pin the exact outcome with `assert!(result.is_err())`
// or an exact `assert_eq!` on the expanded text.
#[test]
fn test_expand_entity_text_counts_against_limit() {
    let mut input = ParserInput::new("");
    input
        .entity_map
        .insert("a".to_string(), "hello".to_string());
    input
        .entity_map
        .insert("b".to_string(), "&a; &a;".to_string());
    input.set_max_entity_expansions(2);

    // An error is an acceptable outcome here; only a successful result is
    // inspected.
    if let Ok(expanded) = input.expand_entity_text("&b;") {
        assert!(
            !expanded.contains("&b;"),
            "a successful expansion must not leave `&b;` unexpanded, got {expanded:?}"
        );
    }
}
// Text without any `&` reference comes back unchanged.
#[test]
fn test_expand_entity_text_no_references() {
    let mut expander = ParserInput::new("");
    let expanded = expander.expand_entity_text("plain text").unwrap();
    assert_eq!(expanded, "plain text");
}
// Predefined entity references are replaced by the characters they stand
// for.
//
// The input must contain real references (`&amp;`, `&lt;`). The earlier
// fixture held the already-decoded text "a & b < c", which contains no
// reference at all and so never exercised builtin expansion.
#[test]
fn test_expand_entity_text_builtin_entities() {
    let mut input = ParserInput::new("");
    let result = input.expand_entity_text("a &amp; b &lt; c").unwrap();
    assert_eq!(result, "a & b < c");
}
// Decimal and hexadecimal character references are replaced by the
// characters they denote (`&#65;` is 'A', `&#x42;` is 'B').
//
// The earlier fixture held the already-decoded text "A B", which contains
// no character reference and so never exercised this path.
#[test]
fn test_expand_entity_text_char_refs() {
    let mut input = ParserInput::new("");
    let result = input.expand_entity_text("&#65; &#x42;").unwrap();
    assert_eq!(result, "A B");
}
// A reference to an entity that was never declared is an error.
#[test]
fn test_expand_entity_text_unknown_entity_strict() {
    let mut expander = ParserInput::new("");
    let outcome = expander.expand_entity_text("&unknown;");
    assert!(outcome.is_err());
}
// References inside a CDATA section are left untouched: the section comes
// back byte-for-byte, including the literal `&amp;`.
//
// The earlier fixture held a bare `&` inside the section, which is not a
// reference, so nothing could have been wrongly expanded. With a real
// `&amp;` in place, the test fails if CDATA content is ever expanded.
#[test]
fn test_expand_entity_text_cdata_not_expanded() {
    let mut input = ParserInput::new("");
    let result = input
        .expand_entity_text("<![CDATA[&amp; not expanded]]>")
        .unwrap();
    assert_eq!(result, "<![CDATA[&amp; not expanded]]>");
}
// Plain text, a tab, and a newline are all accepted: no offending character
// is reported.
#[test]
fn test_find_invalid_xml_char_clean() {
    for text in ["hello world", "tab\there", "newline\nhere"] {
        assert_eq!(find_invalid_xml_char(text), None);
    }
}
// An embedded NUL is reported as the offending character.
#[test]
fn test_find_invalid_xml_char_with_null() {
    let offending = find_invalid_xml_char("bad\x00char");
    assert_eq!(offending, Some('\x00'));
}
// The control characters U+0001 and U+0008 are each reported as the
// offending character.
#[test]
fn test_find_invalid_xml_char_with_control() {
    let cases = [("bad\x01char", '\x01'), ("bad\x08char", '\x08')];
    for (text, offending) in cases {
        assert_eq!(find_invalid_xml_char(text), Some(offending));
    }
}
// The byte-level pre-check returns false for ordinary text (including tab
// and newline) and true for NUL, U+0001, and 0x7F.
#[test]
fn test_may_contain_invalid_xml_chars_fast_check() {
    let clean: [&[u8]; 3] = [b"hello world", b"tab\there", b"newline\nhere"];
    for bytes in clean {
        assert!(!may_contain_invalid_xml_chars(bytes));
    }

    let flagged: [&[u8]; 3] = [b"bad\x00char", b"bad\x01char", b"\x7F"];
    for bytes in flagged {
        assert!(may_contain_invalid_xml_chars(bytes));
    }
}
}