use std::borrow::Cow;
/// Escapes `value` for use as XML character data (element text content).
///
/// `&`, `<` and `>` are replaced by `&amp;`, `&lt;` and `&gt;`. Quotes are left
/// untouched because they carry no meaning in text content. Control characters
/// rejected by `is_xml_illegal_control` are dropped. When no byte needs
/// rewriting, the input is returned borrowed without allocating.
#[must_use]
pub fn escape_text(value: &str) -> Cow<'_, str> {
    if !value.bytes().any(needs_text_rewrite) {
        return Cow::Borrowed(value);
    }
    let mut out = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            // Disallowed control characters are silently removed.
            other if is_xml_illegal_control(other) => {}
            other => out.push(other),
        }
    }
    Cow::Owned(out)
}
/// Escapes `value` for use inside a double-quoted XML attribute value.
///
/// If no byte is flagged by `needs_attr_rewrite`, the input is handed back
/// borrowed. Otherwise an owned copy is built by [`push_escaped_attr`].
#[must_use]
pub fn escape_attr(value: &str) -> Cow<'_, str> {
    if value.bytes().any(needs_attr_rewrite) {
        let mut escaped = String::with_capacity(value.len());
        push_escaped_attr(&mut escaped, value);
        Cow::Owned(escaped)
    } else {
        Cow::Borrowed(value)
    }
}
pub(crate) fn push_escaped_attr(out: &mut String, value: &str) {
for ch in value.chars() {
match ch {
'&' => out.push_str("&"),
'<' => out.push_str("<"),
'>' => out.push_str(">"),
'"' => out.push_str("""),
other if is_xml_illegal_control(other) => {} other => out.push(other),
}
}
}
/// Appends `value` to `out`, escaped for use as XML character data.
///
/// `&`, `<` and `>` become `&amp;`, `&lt;` and `&gt;`. Characters rejected by
/// `is_xml_illegal_control` are dropped. Everything else, including quotes, is
/// copied unchanged. This is the appending counterpart of [`escape_text`].
pub(crate) fn push_escaped_text(out: &mut String, value: &str) {
    for ch in value.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            // Disallowed control characters are silently removed.
            other if is_xml_illegal_control(other) => {}
            other => out.push(other),
        }
    }
}
/// Byte-level pre-check used by [`escape_text`].
///
/// Returns `true` if `b` is `&`, `<` or `>`, or if it may begin a control
/// character that escaping would drop.
#[inline]
fn needs_text_rewrite(b: u8) -> bool {
    match b {
        b'&' | b'<' | b'>' => true,
        other => maybe_illegal_control_byte(other),
    }
}
/// Byte-level pre-check used by [`escape_attr`].
///
/// Returns `true` if `b` is `&`, `<`, `>` or `"`, or if it may begin a control
/// character that escaping would drop.
#[inline]
fn needs_attr_rewrite(b: u8) -> bool {
    match b {
        b'&' | b'<' | b'>' | b'"' => true,
        other => maybe_illegal_control_byte(other),
    }
}
/// Conservative byte test for the escaping fast paths.
///
/// Returns `true` for C0 control bytes other than tab, LF and CR, and for
/// DEL (`0x7F`). It also returns `true` for `0xC2`, the UTF-8 lead byte of
/// U+0080..=U+00BF, a range that contains the C1 controls. The check therefore
/// also flags harmless characters such as U+00A0. False positives only cost a
/// slow path; they never change the output.
#[inline]
fn maybe_illegal_control_byte(b: u8) -> bool {
    match b {
        b'\t' | b'\n' | b'\r' => false,
        0x00..=0x1F | 0x7F | 0xC2 => true,
        _ => false,
    }
}
/// Returns `true` for control characters that the escapers drop.
///
/// These are U+0000..=U+001F except tab, LF and CR, together with
/// U+007F..=U+009F (DEL and the C1 controls).
#[inline]
fn is_xml_illegal_control(ch: char) -> bool {
    match ch {
        '\t' | '\n' | '\r' => false,
        '\u{00}'..='\u{1F}' | '\u{7F}'..='\u{9F}' => true,
        _ => false,
    }
}
/// Returns `true` when `s` consists solely of the XML whitespace characters
/// space, tab, CR and LF. The empty string counts as whitespace-only.
#[inline]
#[must_use]
pub(crate) fn is_xml_whitespace_only(s: &str) -> bool {
    s.trim_start_matches([' ', '\t', '\r', '\n']).is_empty()
}
/// Returns `true` if `b` may start a name in this crate's restricted ASCII
/// name syntax: an ASCII letter or `_`.
#[inline]
#[must_use]
pub fn is_name_start(b: u8) -> bool {
    matches!(b, b'A'..=b'Z' | b'a'..=b'z' | b'_')
}
/// Returns `true` if `b` may appear after the first byte of a name: an ASCII
/// letter or digit, `-`, `_` or `.`.
#[inline]
#[must_use]
pub fn is_name_char(b: u8) -> bool {
    matches!(b, b'0'..=b'9' | b'A'..=b'Z' | b'a'..=b'z' | b'-' | b'_' | b'.')
}
#[must_use]
pub fn is_valid_name(name: &str) -> bool {
let bytes = name.as_bytes();
let Some((first, rest)) = bytes.split_first() else {
return false;
};
if !is_name_start(*first) {
return false;
}
rest.iter().all(|b| is_name_char(*b))
}
/// Decodes predefined entities and numeric character references in `value`.
///
/// For each `&`, the next 15 bytes at most are searched for a `;`. If one is
/// found and the text in between is accepted by `decode_one_entity`, the whole
/// reference is replaced by the decoded character. Otherwise the `&` is kept
/// literally and scanning resumes right after it.
///
/// Input without any `&` is returned borrowed; any other input yields an
/// owned string.
pub(crate) fn decode_entities(value: &str) -> Cow<'_, str> {
    // Number of bytes after `&` that are searched for the terminating `;`.
    const WINDOW: usize = 15;

    if !value.contains('&') {
        return Cow::Borrowed(value);
    }
    let mut out = String::with_capacity(value.len());
    let mut rest = value;
    while let Some(amp) = rest.find('&') {
        // Copy the plain text preceding the `&` in a single slice.
        out.push_str(&rest[..amp]);
        let tail = &rest[amp + 1..];
        let scan = &tail.as_bytes()[..tail.len().min(WINDOW)];
        // `;` is ASCII, so `semi` is always a valid char boundary in `tail`.
        let reference = scan
            .iter()
            .position(|&b| b == b';')
            .and_then(|semi| decode_one_entity(&tail[..semi]).map(|ch| (ch, semi)));
        if let Some((ch, semi)) = reference {
            out.push(ch);
            rest = &tail[semi + 1..];
        } else {
            out.push('&');
            rest = tail;
        }
    }
    out.push_str(rest);
    Cow::Owned(out)
}
/// Decodes the body of one entity or character reference, i.e. the text
/// between `&` and `;`. Returns `None` when the body is not recognised.
///
/// Accepted forms:
/// - The five predefined entities: `amp`, `lt`, `gt`, `apos` and `quot`.
/// - Decimal (`#65`) or hexadecimal (`#x41` / `#X41`) character references.
///
/// The numeric part must be non-empty and contain only digits of its radix.
/// In particular, a sign such as `+` is rejected, even though
/// `u32::from_str_radix` would accept it. Code points that are not Unicode
/// scalar values, or that `is_valid_xml_char` rejects, also yield `None`.
fn decode_one_entity(body: &str) -> Option<char> {
    match body {
        "amp" => Some('&'),
        "lt" => Some('<'),
        "gt" => Some('>'),
        "apos" => Some('\''),
        "quot" => Some('"'),
        _ => {
            let reference = body.strip_prefix('#')?;
            let (radix, digits) = match reference.strip_prefix(['x', 'X']) {
                Some(hex) => (16, hex),
                None => (10, reference),
            };
            // `from_str_radix` tolerates a leading `+`; character references
            // allow digits only, so validate them explicitly.
            if digits.is_empty() || !digits.chars().all(|c| c.is_digit(radix)) {
                return None;
            }
            let code = u32::from_str_radix(digits, radix).ok()?;
            char::from_u32(code).filter(|&ch| is_valid_xml_char(ch))
        }
    }
}
/// Returns `true` if `ch` may be produced by decoding a character reference.
///
/// The accepted set is the XML 1.0 `Char` production with U+007F..=U+009F
/// (DEL and the C1 controls) additionally excluded. That exclusion keeps it in
/// line with the characters the escapers drop.
#[inline]
fn is_valid_xml_char(ch: char) -> bool {
    match u32::from(ch) {
        0x09 | 0x0A | 0x0D => true,
        0x20..=0x7E | 0xA0..=0xD7FF | 0xE000..=0xFFFD | 0x1_0000..=0x10_FFFF => true,
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn text_escapes_lt_gt_amp() {
        assert_eq!(escape_text("a < b & c > d"), "a &lt; b &amp; c &gt; d");
    }

    #[test]
    fn text_leaves_quotes_alone() {
        assert_eq!(escape_text("he said \"hi\""), "he said \"hi\"");
    }

    #[test]
    fn attr_escapes_quotes_too() {
        assert_eq!(escape_attr("\" onclick=evil"), "&quot; onclick=evil");
    }

    #[test]
    fn push_helpers_append_escaped_output() {
        let mut text = String::from(">");
        push_escaped_text(&mut text, "<a & \"b\">");
        assert_eq!(text, ">&lt;a &amp; \"b\"&gt;");

        let mut attr = String::from(">");
        push_escaped_attr(&mut attr, "<a & \"b\">");
        assert_eq!(attr, ">&lt;a &amp; &quot;b&quot;&gt;");
    }

    #[test]
    fn strips_illegal_controls_but_keeps_whitespace() {
        let input = "a\u{0}b\u{7F}c\u{85}d\te\nf\rg";
        assert_eq!(escape_text(input), "abcd\te\nf\rg");
        assert_eq!(escape_attr(input), "abcd\te\nf\rg");
    }

    #[test]
    fn passthrough_for_safe_input() {
        let text = escape_text("hello world");
        assert!(matches!(text, Cow::Borrowed(_)));
        assert_eq!(text, "hello world");
        let attr = escape_attr("hello world");
        assert!(matches!(attr, Cow::Borrowed(_)));
        assert_eq!(attr, "hello world");
    }

    #[test]
    fn handles_multibyte_utf8() {
        assert_eq!(escape_text("café — 日本"), "café — 日本");
        assert_eq!(escape_attr("café — 日本"), "café — 日本");
    }

    #[test]
    fn decodes_predefined_and_numeric_references() {
        assert_eq!(decode_entities("&lt;&amp;&gt;&quot;&apos;"), "<&>\"'");
        assert_eq!(decode_entities("&#65;&#x42;&#X43;"), "ABC");
    }

    #[test]
    fn decode_keeps_unrecognised_ampersands() {
        assert_eq!(decode_entities("a &bogus; &#0; &"), "a &bogus; &#0; &");
        assert!(matches!(decode_entities("no refs"), Cow::Borrowed(_)));
    }

    #[test]
    fn attr_round_trips_through_decode() {
        let original = "<a href=\"x&y\">'q'</a>";
        assert_eq!(decode_entities(&escape_attr(original)), original);
    }

    #[test]
    fn names_predicates() {
        assert!(is_valid_name("task"));
        assert!(is_valid_name("_data"));
        assert!(is_valid_name("a-b_c.d"));
        assert!(!is_valid_name(""));
        assert!(!is_valid_name("1abc"));
        assert!(!is_valid_name("a b"));
        assert!(!is_valid_name("a\"b"));
        assert!(!is_valid_name("a=b"));
    }
}