use std::borrow::Cow;
/// Append-only XML serializer that accumulates its output in an in-memory `String`.
pub struct XmlWriter {
    /// Everything written so far, in emission order.
    buf: String,
}
impl XmlWriter {
    /// Creates a writer with an empty output buffer.
    pub fn new() -> Self {
        XmlWriter { buf: String::new() }
    }

    /// Creates a writer whose buffer is pre-allocated for at least `capacity` bytes.
    pub fn with_capacity(capacity: usize) -> Self {
        XmlWriter {
            buf: String::with_capacity(capacity),
        }
    }

    /// Appends the fixed declaration `<?xml version="1.0" encoding="UTF-8"?>`.
    pub fn write_declaration(&mut self) {
        self.buf
            .push_str("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
    }

    /// Appends an XML declaration built from the given parts.
    ///
    /// * `version` is filtered through `safe_xml_version`: only `"1.0"` and
    ///   `"1.1"` are written as given, anything else is replaced by `"1.0"`.
    /// * `encoding`, when present, is filtered through `safe_xml_encoding`:
    ///   a name that fails `is_valid_xml_encoding` is replaced by `"UTF-8"`.
    /// * `standalone`, when present, is written as `yes` or `no`.
    ///
    /// Pseudo-attributes that are `None` are omitted entirely.
    pub fn write_declaration_full(
        &mut self,
        version: &str,
        encoding: Option<&str>,
        standalone: Option<bool>,
    ) {
        self.buf.push_str("<?xml version=\"");
        self.buf.push_str(&safe_xml_version(version));
        self.buf.push('"');
        if let Some(enc) = encoding {
            self.buf.push_str(" encoding=\"");
            self.buf.push_str(&safe_xml_encoding(enc));
            self.buf.push('"');
        }
        if let Some(sa) = standalone {
            self.buf.push_str(" standalone=\"");
            self.buf.push_str(if sa { "yes" } else { "no" });
            self.buf.push('"');
        }
        self.buf.push_str("?>");
    }

    /// Appends `<name` followed by every attribute as ` key="value"`, leaving
    /// the tag unterminated so the caller decides how to close it.
    ///
    /// Shared by all element-opening methods so attribute serialisation lives
    /// in exactly one place. `name` and each key are written verbatim (no
    /// validation); each value goes through `write_escaped_attr_to_string`.
    fn push_tag_head<I, K, V>(&mut self, name: &str, attrs: I)
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        self.buf.push('<');
        self.buf.push_str(name);
        for (key, val) in attrs {
            self.buf.push(' ');
            self.buf.push_str(key.as_ref());
            self.buf.push_str("=\"");
            write_escaped_attr_to_string(&mut self.buf, val.as_ref());
            self.buf.push('"');
        }
    }

    /// Appends a start tag `<name attr="value" ...>`.
    ///
    /// `name` and attribute keys are written verbatim; the caller is
    /// responsible for supplying valid XML names.
    pub fn start_element(&mut self, name: &str, attrs: &[(&str, &str)]) {
        self.push_tag_head(name, attrs.iter().copied());
        self.buf.push('>');
    }

    /// Appends a self-closing tag `<name attr="value" .../>`.
    ///
    /// `name` and attribute keys are written verbatim.
    pub fn empty_element(&mut self, name: &str, attrs: &[(&str, &str)]) {
        self.push_tag_head(name, attrs.iter().copied());
        self.buf.push_str("/>");
    }

    /// Like [`XmlWriter::start_element`], but accepts any iterator of
    /// `(key, value)` pairs whose parts can be viewed as `str`.
    pub fn start_element_with<I, K, V>(&mut self, name: &str, attrs: I)
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        self.push_tag_head(name, attrs);
        self.buf.push('>');
    }

    /// Like [`XmlWriter::empty_element`], but accepts any iterator of
    /// `(key, value)` pairs whose parts can be viewed as `str`.
    pub fn empty_element_with<I, K, V>(&mut self, name: &str, attrs: I)
    where
        I: IntoIterator<Item = (K, V)>,
        K: AsRef<str>,
        V: AsRef<str>,
    {
        self.push_tag_head(name, attrs);
        self.buf.push_str("/>");
    }

    /// Appends an element with no content in expanded form:
    /// `<name attr="value" ...></name>`.
    pub fn empty_element_expanded(&mut self, name: &str, attrs: &[(&str, &str)]) {
        self.push_tag_head(name, attrs.iter().copied());
        self.buf.push_str("></");
        self.buf.push_str(name);
        self.buf.push('>');
    }

    /// Appends an end tag `</name>`. `name` is written verbatim; the writer
    /// does not check that it matches a previously opened element.
    pub fn end_element(&mut self, name: &str) {
        self.buf.push_str("</");
        self.buf.push_str(name);
        self.buf.push('>');
    }

    /// Appends character data, passing `content` through
    /// `write_escaped_text_to_string`.
    pub fn text(&mut self, content: &str) {
        write_escaped_text_to_string(&mut self.buf, content);
    }

    /// Appends `content` wrapped in `<![CDATA[` ... `]]>`.
    ///
    /// The content is first passed through `split_cdata_content`, which
    /// breaks any embedded `]]>` across two adjacent CDATA sections.
    pub fn cdata(&mut self, content: &str) {
        self.buf.push_str("<![CDATA[");
        self.buf.push_str(&split_cdata_content(content));
        self.buf.push_str("]]>");
    }

    /// Appends `content` wrapped in `<!--` ... `-->`.
    ///
    /// The content is first passed through `sanitize_comment_content`, which
    /// separates consecutive dashes and pads a trailing dash with a space.
    pub fn comment(&mut self, content: &str) {
        self.buf.push_str("<!--");
        self.buf.push_str(&sanitize_comment_content(content));
        self.buf.push_str("-->");
    }

    /// Appends a processing instruction `<?target data?>`.
    ///
    /// `target` goes through `sanitize_pi_target` (which only renames the
    /// reserved name `xml`; the target is otherwise written as given) and
    /// `data`, when present, goes through `sanitize_pi_data` and is
    /// separated from the target by a single space.
    pub fn processing_instruction(&mut self, target: &str, data: Option<&str>) {
        self.buf.push_str("<?");
        self.buf.push_str(&sanitize_pi_target(target));
        if let Some(d) = data {
            self.buf.push(' ');
            self.buf.push_str(&sanitize_pi_data(d));
        }
        self.buf.push_str("?>");
    }

    /// Appends `xml` verbatim, with no escaping or validation of any kind.
    pub fn raw(&mut self, xml: &str) {
        self.buf.push_str(xml);
    }

    /// Borrows everything written so far.
    pub fn as_str(&self) -> &str {
        &self.buf
    }

    /// Consumes the writer and returns the accumulated output.
    pub fn into_string(self) -> String {
        self.buf
    }

    /// Consumes the writer and returns the accumulated output as bytes.
    pub fn into_bytes(self) -> Vec<u8> {
        self.buf.into_bytes()
    }

    /// Length of the accumulated output in bytes.
    pub fn len(&self) -> usize {
        self.buf.len()
    }

    /// Returns `true` when nothing has been written yet.
    pub fn is_empty(&self) -> bool {
        self.buf.is_empty()
    }
}
impl Default for XmlWriter {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for XmlWriter {
    /// Writes the accumulated output to the formatter unchanged.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str(self.buf.as_str())
    }
}
/// Makes `s` safe to place between `<!--` and `-->`.
///
/// A space is inserted between any two consecutive dashes, and a single
/// trailing space is appended when the result would otherwise end in a dash.
/// Input that contains no `--` and does not end in `-` is returned borrowed.
pub(crate) fn sanitize_comment_content(s: &str) -> Cow<'_, str> {
    let needs_rewrite = s.contains("--") || s.ends_with('-');
    if !needs_rewrite {
        return Cow::Borrowed(s);
    }

    let mut sanitized = String::with_capacity(s.len() + 4);
    // Last character emitted from the input; `None` until the first one.
    let mut last: Option<char> = None;
    for ch in s.chars() {
        if ch == '-' && last == Some('-') {
            sanitized.push(' ');
        }
        sanitized.push(ch);
        last = Some(ch);
    }
    // A final dash would fuse with the closing `-->`.
    if last == Some('-') {
        sanitized.push(' ');
    }
    Cow::Owned(sanitized)
}
/// Makes `s` safe to use as processing-instruction data by rewriting every
/// `?>` as `? >`. Input without `?>` is returned borrowed.
pub(crate) fn sanitize_pi_data(s: &str) -> Cow<'_, str> {
    match s.find("?>") {
        None => Cow::Borrowed(s),
        Some(_) => Cow::Owned(s.replace("?>", "? >")),
    }
}
/// Renames the reserved processing-instruction target `xml` (compared
/// ASCII-case-insensitively) by prefixing it with `_`. Every other target,
/// including the empty string, is returned borrowed and unchanged.
pub(crate) fn sanitize_pi_target(s: &str) -> Cow<'_, str> {
    if !s.eq_ignore_ascii_case("xml") {
        return Cow::Borrowed(s);
    }
    let mut renamed = String::with_capacity(s.len() + 1);
    renamed.push('_');
    renamed.push_str(s);
    Cow::Owned(renamed)
}
/// Makes `s` safe to place inside a CDATA section.
///
/// Each `]]>` is rewritten as `]]]]><![CDATA[>`, which closes the current
/// section after `]]` and reopens a new one before `>`. Input without `]]>`
/// is returned borrowed.
pub(crate) fn split_cdata_content(s: &str) -> Cow<'_, str> {
    const TERMINATOR: &str = "]]>";
    if s.contains(TERMINATOR) {
        Cow::Owned(s.replace(TERMINATOR, "]]]]><![CDATA[>"))
    } else {
        Cow::Borrowed(s)
    }
}
/// Returns `s` when it is exactly `"1.0"` or `"1.1"`; any other input is
/// replaced by the literal `"1.0"`. The result is always borrowed.
pub(crate) fn safe_xml_version(s: &str) -> Cow<'_, str> {
    match s {
        "1.0" | "1.1" => Cow::Borrowed(s),
        _ => Cow::Borrowed("1.0"),
    }
}
/// Returns `s` when `is_valid_xml_encoding` accepts it; any other input is
/// replaced by the literal `"UTF-8"`. The result is always borrowed.
pub(crate) fn safe_xml_encoding(s: &str) -> Cow<'_, str> {
    let name = if is_valid_xml_encoding(s) { s } else { "UTF-8" };
    Cow::Borrowed(name)
}
/// Checks that `s` is a non-empty name whose first byte is an ASCII letter
/// and whose remaining bytes are ASCII letters, digits, `.`, `_` or `-`.
fn is_valid_xml_encoding(s: &str) -> bool {
    match s.as_bytes().split_first() {
        Some((head, tail)) if head.is_ascii_alphabetic() => tail
            .iter()
            .all(|&b| b.is_ascii_alphanumeric() || b == b'.' || b == b'_' || b == b'-'),
        // Empty input, or a first byte that is not a letter.
        _ => false,
    }
}
/// Appends `s` to `buf` as XML character data, escaping the characters that
/// would otherwise be interpreted as markup.
///
/// * `&` becomes `&amp;` and `<` becomes `&lt;` so no entity reference or tag
///   can be started from the content.
/// * `>` becomes `&gt;` so the sequence `]]>` can never appear literally.
/// * Carriage return becomes the character reference `&#xD;`, because a
///   literal CR would be normalised to a line feed by a conforming parser.
///
/// Every other character is copied through unchanged.
fn write_escaped_text_to_string(buf: &mut String, s: &str) {
    for c in s.chars() {
        match c {
            '&' => buf.push_str("&amp;"),
            '<' => buf.push_str("&lt;"),
            '>' => buf.push_str("&gt;"),
            '\r' => buf.push_str("&#xD;"),
            _ => buf.push(c),
        }
    }
}
/// Appends `s` to `buf` as the content of a double-quoted attribute value,
/// escaping everything that could end the value or be altered on re-parse.
///
/// * `&`, `<` and `>` become `&amp;`, `&lt;` and `&gt;`.
/// * `"` becomes `&quot;` so the value cannot close its own quotes.
/// * Tab, line feed and carriage return become the character references
///   `&#x9;`, `&#xA;` and `&#xD;`; written literally they would be turned into
///   spaces by attribute-value normalisation.
///
/// Every other character (including `'`, which is harmless inside double
/// quotes) is copied through unchanged.
fn write_escaped_attr_to_string(buf: &mut String, s: &str) {
    for c in s.chars() {
        match c {
            '&' => buf.push_str("&amp;"),
            '<' => buf.push_str("&lt;"),
            '>' => buf.push_str("&gt;"),
            '"' => buf.push_str("&quot;"),
            '\t' => buf.push_str("&#x9;"),
            '\n' => buf.push_str("&#xA;"),
            '\r' => buf.push_str("&#xD;"),
            _ => buf.push(c),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when `value` is the borrowed variant, i.e. the helper made no copy.
    fn is_borrowed(value: &Cow<'_, str>) -> bool {
        matches!(value, Cow::Borrowed(_))
    }

    /// Serialises `<r>…</r>`, letting `fill` write the element's content.
    fn render_in_root<F>(fill: F) -> String
    where
        F: FnOnce(&mut XmlWriter),
    {
        let mut writer = XmlWriter::new();
        writer.start_element("r", &[]);
        fill(&mut writer);
        writer.end_element("r");
        writer.into_string()
    }

    /// Serialises a full declaration (no `standalone`) followed by `<r></r>`.
    fn render_with_declaration(version: &str, encoding: &str) -> String {
        let mut writer = XmlWriter::new();
        writer.write_declaration_full(version, Some(encoding), None);
        writer.start_element("r", &[]);
        writer.end_element("r");
        writer.into_string()
    }

    // ---- comment sanitizer -------------------------------------------------

    #[test]
    fn sanitize_comment_passes_safe_content() {
        for &input in &["hello world", "", "single - dash"] {
            assert!(is_borrowed(&sanitize_comment_content(input)));
        }
    }

    #[test]
    fn sanitize_comment_separates_consecutive_dashes() {
        let cases = [
            ("a--b", "a- -b"),
            ("a---b", "a- - -b"),
            ("--", "- - "),
            ("-->", "- ->"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(&*sanitize_comment_content(input), expected);
        }
    }

    #[test]
    fn sanitize_comment_pads_trailing_dash() {
        let cases = [("foo-", "foo- "), ("-", "- "), ("a--", "a- - ")];
        for &(input, expected) in cases.iter() {
            assert_eq!(&*sanitize_comment_content(input), expected);
        }
    }

    // ---- processing-instruction sanitizers ---------------------------------

    #[test]
    fn sanitize_pi_data_inserts_space_in_terminator() {
        assert!(is_borrowed(&sanitize_pi_data("safe data")));
        let cases = [("a?>b", "a? >b"), ("?>?>", "? >? >"), ("", "")];
        for &(input, expected) in cases.iter() {
            assert_eq!(&*sanitize_pi_data(input), expected);
        }
    }

    #[test]
    fn sanitize_pi_target_renames_reserved_xml() {
        let cases = [
            ("xml", "_xml"),
            ("XML", "_XML"),
            ("Xml", "_Xml"),
            ("xMl", "_xMl"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(&*sanitize_pi_target(input), expected);
        }
    }

    #[test]
    fn sanitize_pi_target_passes_legitimate_names() {
        for &target in &["xsl", "xml-stylesheet", "xmlrpc", ""] {
            assert!(is_borrowed(&sanitize_pi_target(target)));
        }
    }

    // ---- CDATA splitter ----------------------------------------------------

    #[test]
    fn split_cdata_preserves_safe_content() {
        for &input in &["hello world", ""] {
            assert!(is_borrowed(&split_cdata_content(input)));
        }
    }

    #[test]
    fn split_cdata_splits_terminator() {
        let cases = [
            ("hello]]>world", "hello]]]]><![CDATA[>world"),
            ("]]>", "]]]]><![CDATA[>"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(&*split_cdata_content(input), expected);
        }
    }

    // ---- writer round-trips: comments and processing instructions ----------

    #[test]
    fn roundtrip_comment_smuggle_is_blocked() {
        let out = render_in_root(|w| w.comment("safe --> <injected/> <!--trailing"));
        let doc = crate::parse(&out).expect("sanitized output must reparse");
        let root = doc.document_element().unwrap();
        // The root must contain the comment only, never a real element.
        let smuggled = doc
            .children(root)
            .into_iter()
            .any(|child| matches!(doc.node_kind(child), Some(crate::NodeKind::Element(_))));
        assert!(
            !smuggled,
            "comment sanitization failed; output smuggled an element: {:?}",
            out
        );
    }

    #[test]
    fn roundtrip_pi_smuggle_is_blocked() {
        let out = render_in_root(|w| w.processing_instruction("x", Some("?><injected/>")));
        let doc = crate::parse(&out).expect("sanitized output must reparse");
        let root = doc.document_element().unwrap();
        let smuggled = doc
            .children(root)
            .into_iter()
            .any(|child| matches!(doc.node_kind(child), Some(crate::NodeKind::Element(_))));
        assert!(
            !smuggled,
            "PI sanitization failed; output smuggled an element: {:?}",
            out
        );
    }

    #[test]
    fn roundtrip_pi_reserved_xml_target_is_renamed() {
        let out = render_in_root(|w| {
            w.processing_instruction("xml", Some("version=\"1.0\" standalone=\"yes\""))
        });
        assert!(
            !out.contains("<?xml "),
            "reserved `xml` target must not reach the output: {:?}",
            out
        );
        let doc = crate::parse(&out).expect("sanitized output must reparse");
        let root = doc.document_element().unwrap();
        // Gather every processing-instruction child of the root.
        let mut pi_children = Vec::new();
        for child in doc.children(root) {
            if let Some(crate::NodeKind::ProcessingInstruction(pi)) = doc.node_kind(child) {
                pi_children.push(pi);
            }
        }
        assert_eq!(pi_children.len(), 1, "expected exactly one PI child");
        assert_eq!(&*pi_children[0].target, "_xml");
    }

    // ---- declaration filters -----------------------------------------------

    #[test]
    fn safe_xml_version_passes_valid() {
        for &version in &["1.0", "1.1"] {
            let filtered = safe_xml_version(version);
            assert!(is_borrowed(&filtered));
            assert_eq!(&*filtered, version);
        }
    }

    #[test]
    fn safe_xml_version_rejects_invalid() {
        let rejected = [
            "",
            "1",
            "1.",
            "2.0",
            "1.10",
            "1.42",
            "1.0a",
            "1.0\"?><x/><?y ",
            "1.0 ",
        ];
        for &version in rejected.iter() {
            assert_eq!(&*safe_xml_version(version), "1.0");
        }
    }

    #[test]
    fn safe_xml_encoding_passes_valid() {
        for &encoding in &["UTF-8", "utf-8", "ISO-8859-1", "US_ASCII.1"] {
            assert!(is_borrowed(&safe_xml_encoding(encoding)));
        }
        assert_eq!(&*safe_xml_encoding("UTF-8"), "UTF-8");
    }

    #[test]
    fn safe_xml_encoding_rejects_invalid() {
        let rejected = ["", "1UTF", "-foo", "UTF-8\"?><x/>", "utf 8", "utf\x00"];
        for &encoding in rejected.iter() {
            assert_eq!(&*safe_xml_encoding(encoding), "UTF-8");
        }
    }

    // ---- writer round-trips: declaration -----------------------------------

    #[test]
    fn roundtrip_xml_writer_declaration_version_injection_blocked() {
        let out = render_with_declaration("1.0\"?><!-- smuggled -->", "UTF-8");
        assert!(
            !out.contains("smuggled"),
            "attacker-controlled version must not reach output: {:?}",
            out
        );
        let doc = crate::parse(&out).expect("sanitized output must reparse");
        let declaration = doc.xml_declaration.as_ref().unwrap();
        assert_eq!(declaration.version, "1.0");
    }

    #[test]
    fn roundtrip_xml_writer_declaration_encoding_injection_blocked() {
        let out = render_with_declaration("1.0", "UTF-8\"?><inject/><?x ");
        assert!(
            !out.contains("<inject"),
            "attacker-controlled encoding must not reach output: {:?}",
            out
        );
        let doc = crate::parse(&out).expect("sanitized output must reparse");
        let root = doc.document_element().unwrap();
        match doc.node_kind(root) {
            Some(crate::NodeKind::Element(e)) => {
                assert_eq!(&*e.name.local_name, "r");
            }
            _ => panic!("expected element root"),
        }
        let declaration = doc.xml_declaration.as_ref().unwrap();
        assert_eq!(declaration.encoding.as_deref(), Some("UTF-8"));
    }

    // ---- DOM serialisation round-trips: declaration ------------------------

    #[test]
    fn roundtrip_dom_declaration_version_injection_blocked() {
        let forged = crate::dom::XmlDeclaration {
            version: "1.0\"?><forged/><?y ".into(),
            encoding: Some("UTF-8".into()),
            standalone: None,
        };
        let mut doc = crate::parse("<r/>").expect("parse");
        doc.xml_declaration = Some(forged);
        let out = doc.to_xml();
        assert!(
            !out.contains("<forged"),
            "DOM-mutation version injection not blocked: {:?}",
            out
        );
        let reparsed = crate::parse(&out).expect("sanitized output must reparse");
        let declaration = reparsed.xml_declaration.as_ref().unwrap();
        assert_eq!(declaration.version, "1.0");
    }

    #[test]
    fn roundtrip_dom_declaration_encoding_injection_blocked() {
        let forged = crate::dom::XmlDeclaration {
            version: "1.0".into(),
            encoding: Some("UTF-8\"?><forged/><?y ".into()),
            standalone: None,
        };
        let mut doc = crate::parse("<r/>").expect("parse");
        doc.xml_declaration = Some(forged);
        let out = doc.to_xml();
        assert!(
            !out.contains("<forged"),
            "DOM-mutation encoding injection not blocked: {:?}",
            out
        );
        let reparsed = crate::parse(&out).expect("sanitized output must reparse");
        let declaration = reparsed.xml_declaration.as_ref().unwrap();
        assert_eq!(declaration.encoding.as_deref(), Some("UTF-8"));
    }

    // ---- writer round-trip: CDATA ------------------------------------------

    #[test]
    fn roundtrip_cdata_smuggle_is_blocked() {
        let out = render_in_root(|w| w.cdata("safe]]><injected/>more"));
        let doc = crate::parse(&out).expect("split CDATA must reparse");
        let root = doc.document_element().unwrap();
        let smuggled = doc
            .children(root)
            .into_iter()
            .any(|child| matches!(doc.node_kind(child), Some(crate::NodeKind::Element(_))));
        assert!(
            !smuggled,
            "CDATA split failed; output smuggled an element: {:?}",
            out
        );
        assert_eq!(
            doc.text_content_deep(root),
            "safe]]><injected/>more",
            "CDATA split must preserve the original text semantically"
        );
    }
}