use ego_tree::NodeId;
use scraper::{Html, Node};
use crate::dom;
/// Attribute names that `clean_attributes` keeps.
///
/// Any attribute whose local name is not listed here is dropped from every
/// element that `clean_attributes` processes (it skips SVG-related elements).
/// The list is searched linearly with `contains`, so entry order does not
/// affect the result.
const ALLOWED_ATTRIBUTES: &[&str] = &[
"accent",
"accentunder",
"align",
"alt",
"allow",
"allowfullscreen",
"aria-label",
"checked",
"class",
"colspan",
"columnalign",
"columnlines",
"columnspacing",
"columnspan",
"controls",
"data-callout",
"data-callout-fold",
"data-callout-title",
"data-lang",
"data-latex",
"data-mjx-texclass",
"data-src",
"data-srcset",
"depth",
"dir",
"display",
"displaystyle",
"fence",
"frame",
"frameborder",
"framespacing",
"headers",
"height",
"href",
"id",
"kind",
"label",
"lang",
"linethickness",
"lspace",
"mathsize",
"mathvariant",
"maxsize",
"minsize",
"movablelimits",
"notation",
"poster",
"role",
"rowalign",
"rowlines",
"rowspacing",
"rowspan",
"rspace",
"scriptlevel",
"separator",
"src",
"srclang",
"srcset",
"stretchy",
"symmetric",
"title",
"type",
"voffset",
"width",
"xmlns",
];
/// Tag names that `remove_empty_elements` never removes, even when the
/// element has no text and no children.
const ALLOWED_EMPTY: &[&str] = &[
"area", "audio", "base", "br", "circle", "col", "defs", "ellipse", "embed", "figure", "g",
"hr", "iframe", "img", "input", "line", "link", "mask", "meta", "object", "param", "path",
"pattern", "picture", "polygon", "polyline", "rect", "source", "stop", "svg", "td", "th",
"track", "use", "video", "wbr",
];
/// Selectors whose matches `remove_unsafe_tags` deletes from the whole document.
///
/// NOTE(review): `script` is not listed here — presumably scripts are removed
/// by an earlier pass; confirm, because `sanitize_html_string` only runs
/// `strip_unsafe_elements`.
const UNSAFE_ELEMENTS: &[&str] = &["frame", "frameset", "object", "embed", "applet", "base"];
/// Attribute names whose value `remove_dangerous_attributes` checks with
/// `is_dangerous_uri`; the attribute is dropped when the check reports danger.
const URL_ATTRS: &[&str] = &["href", "src", "action", "formaction"];
pub fn standardize_content(html: &mut Html, main_content: NodeId) {
remove_wbr_elements(html, main_content);
clean_attributes(html, main_content);
remove_empty_elements(html, main_content);
normalize_headings(html, main_content);
unwrap_wrapper_divs(html, main_content);
}
/// Removes every `<wbr>` element found below `main_content`.
fn remove_wbr_elements(html: &mut Html, main_content: NodeId) {
    // The id list is fully collected before the tree is mutated.
    for wbr_id in dom::descendant_elements_by_tag(html, main_content, "wbr") {
        dom::remove_node(html, wbr_id);
    }
}
/// Public entry point for the attribute clean-up pass.
///
/// Forwards directly to the private `clean_attributes` with the same arguments.
pub fn clean_attributes_on(html: &mut Html, main_content: NodeId) {
clean_attributes(html, main_content);
}
/// Drops every attribute that is not in `ALLOWED_ATTRIBUTES` from the
/// elements below `main_content`.
///
/// Elements whose tag name appears in the SVG-related table below, or that
/// have an `<svg>` ancestor, are skipped and keep all of their attributes.
fn clean_attributes(html: &mut Html, main_content: NodeId) {
    /// Tag names treated as SVG-related regardless of where they appear.
    const SVG_RELATED_TAGS: &[&str] = &[
        "svg",
        "path",
        "circle",
        "rect",
        "line",
        "polygon",
        "polyline",
        "g",
        "defs",
        "use",
        "mask",
        "ellipse",
        "stop",
        "pattern",
        "text",
        "tspan",
        "clippath",
        "lineargradient",
        "radialgradient",
        "filter",
        "fegaussianblur",
        "feoffset",
        "feblend",
        "marker",
        "symbol",
        "image",
        "foreignobject",
        "desc",
        "metadata",
        "style",
    ];

    for element_id in dom::all_descendant_elements(html, main_content) {
        // Shared lookup first; it is scoped so the borrow ends before `get_mut`.
        let tag_is_svg_related = {
            let Some(node) = html.tree.get(element_id) else {
                continue;
            };
            let Node::Element(element) = node.value() else {
                continue;
            };
            SVG_RELATED_TAGS.contains(&element.name.local.as_ref())
        };
        if tag_is_svg_related || is_inside_svg(html, element_id) {
            continue;
        }

        let Some(mut slot) = html.tree.get_mut(element_id) else {
            continue;
        };
        if let Node::Element(element) = slot.value() {
            element
                .attrs
                .retain(|(attr_name, _)| ALLOWED_ATTRIBUTES.contains(&attr_name.local.as_ref()));
        }
    }
}
/// Removes elements below `main_content` that carry no content.
///
/// An element is removed when its tag is not in `ALLOWED_EMPTY`, its text
/// content is blank, and either it has no children at all or it is a `<p>`
/// whose element children are all `<br>`. Candidates are collected in one
/// pass and removed afterwards.
fn remove_empty_elements(html: &mut Html, main_content: NodeId) {
    let doomed: Vec<NodeId> = dom::all_descendant_elements(html, main_content)
        .into_iter()
        .filter(|&candidate| {
            let Some(node) = html.tree.get(candidate) else {
                return false;
            };
            let Node::Element(element) = node.value() else {
                return false;
            };
            let tag = element.name.local.as_ref();
            if ALLOWED_EMPTY.contains(&tag) {
                return false;
            }
            // Both removal rules require blank text, so check it once.
            if !dom::text_content(html, candidate).trim().is_empty() {
                return false;
            }
            !node.has_children() || (tag == "p" && has_only_br_children(html, candidate))
        })
        .collect();

    for id in doomed {
        dom::remove_node(html, id);
    }
}
/// Returns `true` when `node_id` has at least one element child and every
/// element child is a `<br>`.
///
/// Non-element children (text, comments, ...) are ignored. Returns `false`
/// when the node id is not in the tree.
fn has_only_br_children(html: &Html, node_id: NodeId) -> bool {
    let Some(parent) = html.tree.get(node_id) else {
        return false;
    };
    let mut saw_element = false;
    let every_element_is_br = parent
        .children()
        .filter_map(|child| match child.value() {
            Node::Element(element) => Some(element),
            _ => None,
        })
        .all(|element| {
            saw_element = true;
            element.name.local.as_ref() == "br"
        });
    every_element_is_br && saw_element
}
/// Renames every `<h1>` below `main_content` to `<h2>`, except the first one
/// returned by `dom::descendant_elements_by_tag`.
///
/// Does nothing when there is at most one `<h1>`.
fn normalize_headings(html: &mut Html, main_content: NodeId) {
    let extra_h1s = dom::descendant_elements_by_tag(html, main_content, "h1")
        .into_iter()
        .skip(1);
    for heading_id in extra_h1s {
        if let Some(mut heading) = html.tree.get_mut(heading_id) {
            if let Node::Element(element) = heading.value() {
                element.name.local = markup5ever::local_name!("h2");
            }
        }
    }
}
/// Replaces redundant wrapper `<div>`s below `main_content` with their children.
///
/// A `<div>` is unwrapped when all of these hold:
/// - it is not `main_content` itself,
/// - its `id` is not `footnotes` and does not start with `fn:`,
/// - it has exactly one element child, and that child is a block-level tag
///   from the table below,
/// - its trimmed text is the same length as the child's trimmed text, i.e.
///   the wrapper contributes no text of its own.
fn unwrap_wrapper_divs(html: &mut Html, main_content: NodeId) {
    /// Child tags that make a single-child `<div>` a pure wrapper.
    const BLOCK_TAGS: &[&str] = &[
        "article",
        "section",
        "div",
        "main",
        "p",
        "blockquote",
        "figure",
        "table",
        "ul",
        "ol",
        "dl",
        "h1",
        "h2",
        "h3",
        "h4",
        "h5",
        "h6",
    ];

    let mut wrappers = Vec::new();
    for candidate in dom::all_descendant_elements(html, main_content) {
        if candidate == main_content {
            continue;
        }
        match dom::tag_name(html, candidate) {
            Some(tag) if tag == "div" => {}
            _ => continue,
        }

        let dom_id = dom::get_attr(html, candidate, "id").unwrap_or_default();
        if dom_id == "footnotes" || dom_id.starts_with("fn:") {
            continue;
        }

        let element_children = dom::child_elements(html, candidate);
        if element_children.len() != 1 {
            continue;
        }
        let only_child = element_children[0];

        let child_is_block = dom::tag_name(html, only_child)
            .as_ref()
            .is_some_and(|t| BLOCK_TAGS.contains(&t.as_str()));
        if !child_is_block {
            continue;
        }

        let wrapper_text = dom::text_content(html, candidate);
        let inner_text = dom::text_content(html, only_child);
        if wrapper_text.trim().len() == inner_text.trim().len() {
            wrappers.push(candidate);
        }
    }

    // Popping walks the candidates back to front, matching a reverse iteration.
    while let Some(wrapper) = wrappers.pop() {
        for moved in collect_child_ids(html, wrapper) {
            match html.tree.get_mut(wrapper) {
                Some(mut anchor) => {
                    // Re-parent the child as a sibling placed just before the wrapper.
                    anchor.insert_id_before(moved);
                }
                None => break,
            }
        }
        dom::remove_node(html, wrapper);
    }
}
/// Returns the ids of all direct children of `node_id` (any node kind), in
/// tree order, or an empty vector when the node is not in the tree.
fn collect_child_ids(html: &Html, node_id: NodeId) -> Vec<NodeId> {
    html.tree.get(node_id).map_or_else(Vec::new, |parent| {
        parent.children().map(|child| child.id()).collect()
    })
}
/// Sanitizes the whole document in two steps: first removes the elements
/// matched by `UNSAFE_ELEMENTS`, then strips dangerous attributes from the
/// elements that remain.
pub fn strip_unsafe_elements(html: &mut Html) {
remove_unsafe_tags(html);
remove_dangerous_attributes(html);
}
/// Removes every element matched by a selector in `UNSAFE_ELEMENTS`.
///
/// All matching ids are gathered before anything is removed.
fn remove_unsafe_tags(html: &mut Html) {
    let doomed: Vec<_> = UNSAFE_ELEMENTS
        .iter()
        .flat_map(|selector| dom::select_ids(html, selector))
        .collect();
    for id in doomed {
        dom::remove_node(html, id);
    }
}
/// Strips attributes that can execute script or load unsafe content from
/// every element in the document.
///
/// For each element this:
/// - removes attributes whose name starts with `on` (event handlers),
/// - removes `srcdoc`,
/// - removes any attribute in `URL_ATTRS` whose value `is_dangerous_uri` rejects,
/// - rewrites `srcset` / `data-srcset`, keeping only candidates whose URL
///   passes `is_dangerous_uri`; survivors are re-joined with `,`.
///
/// `<style>` elements inside `<svg>` get the same treatment as every other
/// element. This function never removes elements, only attributes, so
/// exempting them would merely leave event handlers in place.
fn remove_dangerous_attributes(html: &mut Html) {
    let all = dom::all_descendant_elements(html, html.tree.root().id());
    for node_id in all {
        let Some(mut node_mut) = html.tree.get_mut(node_id) else {
            continue;
        };
        let Node::Element(el) = node_mut.value() else {
            continue;
        };

        el.attrs.retain(|(name, value)| {
            let n = name.local.as_ref();
            let is_event_handler = n.starts_with("on");
            let is_unsafe_url = URL_ATTRS.contains(&n) && is_dangerous_uri(value);
            !(is_event_handler || n == "srcdoc" || is_unsafe_url)
        });

        for (name, value) in &mut el.attrs {
            let n = name.local.as_ref();
            if n == "srcset" || n == "data-srcset" {
                let candidates = split_srcset_candidates(value);
                let safe: Vec<&str> = candidates
                    .into_iter()
                    .filter(|entry| {
                        // The URL is the first whitespace-separated token of a candidate.
                        let url = entry.split_whitespace().next().unwrap_or("");
                        !is_dangerous_uri(url)
                    })
                    .collect();
                *value = safe.join(",").into();
            }
        }
    }
}
/// Returns `true` when `value` is a URL that must not be kept in an attribute.
///
/// Dangerous values are:
/// - `javascript:` and `vbscript:` URLs,
/// - `data:` URLs whose media type is anything other than
///   `image/png`, `image/jpeg`, `image/gif`, `image/webp` or `image/avif`
///   (so `data:image/svg+xml`, `data:text/html`, `data:,...` are rejected).
///
/// Everything else, including the empty string and relative URLs, is
/// considered safe. Matching is ASCII case-insensitive.
fn is_dangerous_uri(value: &str) -> bool {
    // URL parsers in browsers remove ASCII tab/LF/CR from anywhere in the
    // input, so `java\tscript:` still runs as `javascript:`. Drop them before
    // looking at the scheme.
    let normalized = value
        .chars()
        .filter(|c| !matches!(c, '\t' | '\n' | '\r'))
        .collect::<String>()
        .to_ascii_lowercase();
    // Leading/trailing C0 controls and spaces are also ignored by URL parsers.
    // Unicode whitespace is trimmed as well so this stays at least as strict
    // as a plain `trim()`.
    let normalized = normalized.trim_matches(|c: char| c <= ' ' || c.is_whitespace());

    if normalized.starts_with("javascript:") || normalized.starts_with("vbscript:") {
        return true;
    }
    if let Some(data_uri) = normalized.strip_prefix("data:") {
        // `data:<media type>[;params],<payload>` — only the bare media type matters.
        let media_type = data_uri.split(',').next().unwrap_or("").trim();
        let essence = media_type.split(';').next().unwrap_or("").trim();
        let safe = matches!(
            essence,
            "image/png" | "image/jpeg" | "image/gif" | "image/webp" | "image/avif"
        );
        return !safe;
    }
    false
}
/// Returns `true` when any ancestor of `node_id` is an `<svg>` element.
///
/// The node itself is not considered, only its parents. Returns `false`
/// when the id is not in the tree.
fn is_inside_svg(html: &Html, node_id: NodeId) -> bool {
    let mut cursor = html.tree.get(node_id).and_then(|node| node.parent());
    while let Some(ancestor) = cursor {
        if matches!(ancestor.value(), Node::Element(el) if el.name.local.as_ref() == "svg") {
            return true;
        }
        cursor = ancestor.parent();
    }
    false
}
/// Rewrites relative URLs below `main_content` into absolute ones.
///
/// The base comes from `resolve_base_url`; when no base can be determined
/// the document is left unchanged. Only the tag/attribute pairs in the table
/// below are processed, in the order listed.
pub fn resolve_urls(html: &mut Html, main_content: NodeId, base_url: &str) {
    /// `(tag, attribute)` pairs whose values are resolved against the base.
    const TARGETS: [(&str, &str); 7] = [
        ("a", "href"),
        ("img", "src"),
        ("img", "srcset"),
        ("video", "poster"),
        ("source", "src"),
        ("source", "srcset"),
        ("iframe", "src"),
    ];

    let Some(base) = resolve_base_url(html, base_url) else {
        return;
    };
    for (tag, attr) in TARGETS {
        for element_id in dom::descendant_elements_by_tag(html, main_content, tag) {
            resolve_single_attr(html, element_id, &base, attr);
        }
    }
}
/// Picks the base URL used for resolving relative links.
///
/// `base_url` wins when it parses as an absolute URL. Otherwise the first
/// `<base href>` element whose `href` parses as an absolute URL is used.
/// Returns `None` when neither source yields a URL.
fn resolve_base_url(html: &Html, base_url: &str) -> Option<url::Url> {
    url::Url::parse(base_url).ok().or_else(|| {
        dom::select_ids(html, "base[href]")
            .into_iter()
            .find_map(|base_id| {
                dom::get_attr(html, base_id, "href").and_then(|href| url::Url::parse(&href).ok())
            })
    })
}
/// Resolves one attribute of one element against `base`.
///
/// Nothing happens when the attribute is absent. `srcset` is delegated to
/// `resolve_srcset`. Values that already start with `http://`, `https://`
/// or `//` are left as they are, as are values that `base.join` cannot resolve.
fn resolve_single_attr(html: &mut Html, node_id: NodeId, base: &url::Url, attr: &str) {
    let Some(current) = dom::get_attr(html, node_id, attr) else {
        return;
    };
    if attr == "srcset" {
        resolve_srcset(html, node_id, base);
        return;
    }

    let already_absolute = ["http://", "https://", "//"]
        .iter()
        .any(|prefix| current.starts_with(*prefix));
    if already_absolute {
        return;
    }

    if let Ok(absolute) = base.join(&current) {
        dom::set_attr(html, node_id, attr, absolute.as_ref());
    }
}
/// Rewrites the `srcset` attribute of `node_id` with absolute URLs.
///
/// Each candidate is split into its URL (first token) and descriptor
/// (remaining tokens, re-joined with single spaces). Candidates with no URL,
/// or whose URL `is_dangerous_uri` rejects, are dropped. URLs starting with
/// `http://` or `https://` are kept verbatim; others are joined onto `base`,
/// falling back to the original text when joining fails. The surviving
/// candidates are written back joined with `", "`.
fn resolve_srcset(html: &mut Html, node_id: NodeId, base: &url::Url) {
    let Some(raw) = dom::get_attr(html, node_id, "srcset") else {
        return;
    };

    let rebuilt: Vec<String> = split_srcset_candidates(&raw)
        .into_iter()
        .filter_map(|candidate| {
            let mut words = candidate.split_whitespace();
            let location = words.next()?;
            let descriptor = words.collect::<Vec<_>>().join(" ");
            if is_dangerous_uri(location) {
                return None;
            }

            let is_absolute = location.starts_with("http://") || location.starts_with("https://");
            let absolute = if is_absolute {
                location.to_string()
            } else {
                match base.join(location) {
                    Ok(joined) => joined.to_string(),
                    Err(_) => location.to_string(),
                }
            };

            Some(if descriptor.is_empty() {
                absolute
            } else {
                format!("{absolute} {descriptor}")
            })
        })
        .collect();

    let new_val = rebuilt.join(", ");
    dom::set_attr(html, node_id, "srcset", &new_val);
}
/// Splits a `srcset` value into its comma-separated candidates.
///
/// A comma is treated as part of the URL, not as a separator, when the text
/// since the last separator starts with a `data:` URL and does not yet end
/// in a descriptor (a trailing whitespace-separated token ending in `x` or
/// `w`). Returned slices are not trimmed; a trailing empty piece is omitted.
fn split_srcset_candidates(srcset: &str) -> Vec<&str> {
    let mut pieces = Vec::new();
    let mut begin = 0;

    for (comma_at, _) in srcset.match_indices(',') {
        let segment = &srcset[begin..comma_at];
        let trimmed = segment.trim();

        let starts_with_data_uri = trimmed
            .split_whitespace()
            .next()
            .is_some_and(|url| url.starts_with("data:"));
        let ends_with_descriptor = trimmed
            .rsplit_once(char::is_whitespace)
            .is_some_and(|(_, tail)| tail.ends_with('x') || tail.ends_with('w'));

        // Still inside a data: URL (e.g. the comma after `;base64`): keep scanning.
        if starts_with_data_uri && !ends_with_descriptor {
            continue;
        }

        pieces.push(segment);
        begin = comma_at + 1;
    }

    if begin < srcset.len() {
        pieces.push(&srcset[begin..]);
    }
    pieces
}
/// Parses `html_str` as an HTML fragment, runs `strip_unsafe_elements` on it,
/// and returns the inner HTML of the tree root as a string.
///
/// Only the elements and attributes handled by `strip_unsafe_elements` are
/// removed; no other clean-up pass is applied.
#[must_use]
pub fn sanitize_html_string(html_str: &str) -> String {
let mut html = Html::parse_fragment(html_str);
strip_unsafe_elements(&mut html);
dom::inner_html(&html, html.tree.root().id())
}
// Unit tests for URL resolution, unsafe-element stripping, dangerous-URI
// detection and srcset splitting.
#[cfg(test)]
mod tests {
use super::*;
// Relative URLs in <source srcset> and <img src> are resolved against the base.
#[test]
fn resolve_source_srcset() {
let html_str = r#"<html><body>
<picture>
<source srcset="/small.jpg 1x, /large.jpg 2x">
<img src="/fallback.jpg">
</picture>
</body></html>"#;
let mut doc = Html::parse_document(html_str);
let body_ids = dom::select_ids(&doc, "body");
let body = body_ids[0];
resolve_urls(&mut doc, body, "https://example.com");
let output = dom::outer_html(&doc, body);
assert!(output.contains("https://example.com/small.jpg 1x"));
assert!(output.contains("https://example.com/large.jpg 2x"));
assert!(output.contains("https://example.com/fallback.jpg"));
}
// Absolute srcset candidates are kept; relative ones are resolved.
#[test]
fn resolve_mixed_srcset() {
let html_str = r#"<html><body>
<img srcset="https://cdn.example.com/abs.jpg 1x, /relative.jpg 2x">
</body></html>"#;
let mut doc = Html::parse_document(html_str);
let body_ids = dom::select_ids(&doc, "body");
let body = body_ids[0];
resolve_urls(&mut doc, body, "https://example.com");
let output = dom::outer_html(&doc, body);
assert!(output.contains("https://cdn.example.com/abs.jpg 1x"));
assert!(output.contains("https://example.com/relative.jpg 2x"));
}
// <frame> and <frameset> must not appear in the output; other content stays.
#[test]
fn strip_removes_frame_and_frameset() {
let mut doc = Html::parse_document(
"<html><body><frame src=\"x.html\"><frameset><frame src=\"y.html\"></frameset><p>keep</p></body></html>",
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(!out.contains("<frame"));
assert!(!out.contains("<frameset"));
assert!(out.contains("keep"));
}
// on* attributes are removed while the element and its text are kept.
#[test]
fn strip_removes_event_handler_attributes() {
let mut doc = Html::parse_document(
r#"<html><body><div onclick="evil()" onerror="bad()">text</div></body></html>"#,
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(!out.contains("onclick"));
assert!(!out.contains("onerror"));
assert!(out.contains("text"));
}
// A javascript: href is dropped.
#[test]
fn strip_removes_javascript_uri_from_href() {
let mut doc = Html::parse_document(
r#"<html><body><a href="javascript:alert(1)">link</a></body></html>"#,
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(!out.contains("javascript:"));
}
// A data:text/html src is dropped.
#[test]
fn strip_removes_data_text_html_from_src() {
let mut doc = Html::parse_document(
r#"<html><body><img src="data:text/html,<script>alert(1)</script>"></body></html>"#,
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(!out.contains("data:text/html"));
}
// The srcdoc attribute is dropped from iframes.
#[test]
fn strip_removes_srcdoc_from_iframes() {
let mut doc = Html::parse_document(
r#"<html><body><iframe srcdoc="<script>x</script>"></iframe></body></html>"#,
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(!out.contains("srcdoc"));
}
// A <style> element inside <svg> survives stripping.
#[test]
fn strip_preserves_style_inside_svg() {
let mut doc = Html::parse_document(
r"<html><body><svg><style>.cls{fill:red}</style><rect/></svg></body></html>",
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(out.contains("<style>"));
}
// <object>, <embed> and <applet> are removed; sibling content stays.
#[test]
fn strip_removes_object_embed_applet() {
let mut doc = Html::parse_document(
r#"<html><body>
<object data="x.swf"></object>
<embed src="y.swf">
<applet code="z.class"></applet>
<p>safe</p>
</body></html>"#,
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(!out.contains("<object"));
assert!(!out.contains("<embed"));
assert!(!out.contains("<applet"));
assert!(out.contains("safe"));
}
// javascript: is detected regardless of case and surrounding spaces.
#[test]
fn dangerous_uri_blocks_javascript() {
assert!(is_dangerous_uri("javascript:alert(1)"));
assert!(is_dangerous_uri(" JavaScript:void(0) "));
}
// Textual and script data: media types are rejected.
#[test]
fn dangerous_uri_blocks_data_text_types() {
assert!(is_dangerous_uri("data:text/html,<script>x</script>"));
assert!(is_dangerous_uri("data:text/javascript,alert(1)"));
assert!(is_dangerous_uri("data:application/javascript,x"));
}
// The five allow-listed raster image types are accepted.
#[test]
fn dangerous_uri_allows_raster_images() {
assert!(!is_dangerous_uri("data:image/png;base64,iVBOR"));
assert!(!is_dangerous_uri("data:image/jpeg;base64,/9j/4"));
assert!(!is_dangerous_uri("data:image/gif;base64,R0lGO"));
assert!(!is_dangerous_uri("data:image/webp;base64,UklG"));
assert!(!is_dangerous_uri("data:image/avif;base64,AAAA"));
}
// SVG data: URIs are rejected whatever their parameters.
#[test]
fn dangerous_uri_blocks_svg_data_uri() {
assert!(is_dangerous_uri("data:image/svg+xml;utf8,<svg/>"));
assert!(is_dangerous_uri("data:image/svg+xml;charset=utf-8,<svg/>"));
assert!(is_dangerous_uri("data:image/svg+xml;base64,PHN2Zz4="));
}
// Any data: media type outside the allow-list is rejected, including an empty one.
#[test]
fn dangerous_uri_blocks_other_data_types() {
assert!(is_dangerous_uri("data:application/pdf,x"));
assert!(is_dangerous_uri("data:,arbitrary"));
}
// Ordinary absolute, relative and empty values are accepted.
#[test]
fn dangerous_uri_allows_normal_urls() {
assert!(!is_dangerous_uri("https://example.com"));
assert!(!is_dangerous_uri("/relative/path"));
assert!(!is_dangerous_uri(""));
}
// A data:text/javascript src is dropped.
#[test]
fn strip_removes_dangerous_data_uri_from_src() {
let mut doc = Html::parse_document(
r#"<html><body><img src="data:text/javascript,alert(1)"></body></html>"#,
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(!out.contains("data:text/javascript"));
}
// An allow-listed data:image src is kept.
#[test]
fn strip_preserves_data_image_in_src() {
let mut doc = Html::parse_document(
r#"<html><body><img src="data:image/gif;base64,R0lGODlh"></body></html>"#,
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(out.contains("data:image/gif"));
}
// resolve_urls drops dangerous srcset candidates and resolves the rest.
#[test]
fn srcset_filters_dangerous_uris() {
let html_str = r#"<html><body>
<img srcset="javascript:alert(1) 1x, /safe.jpg 2x">
</body></html>"#;
let mut doc = Html::parse_document(html_str);
let body_ids = dom::select_ids(&doc, "body");
let body = body_ids[0];
resolve_urls(&mut doc, body, "https://example.com");
let out = dom::outer_html(&doc, body);
assert!(!out.contains("javascript:"));
assert!(out.contains("https://example.com/safe.jpg 2x"));
}
// strip_unsafe_elements filters srcset on its own, without URL resolution.
#[test]
fn strip_sanitizes_srcset_without_url_resolution() {
let mut doc = Html::parse_document(
r#"<html><body><img srcset="javascript:x 1x, /safe.jpg 2x"></body></html>"#,
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(!out.contains("javascript:"));
assert!(out.contains("/safe.jpg 2x"));
}
// An SVG data: URI in src is dropped.
#[test]
fn strip_blocks_svg_data_uri_in_src() {
let mut doc = Html::parse_document(
r#"<html><body><img src="data:image/svg+xml;utf8,<svg><script>alert(1)</script></svg>"></body></html>"#,
);
strip_unsafe_elements(&mut doc);
let out = dom::outer_html(&doc, doc.tree.root().id());
assert!(!out.contains("data:image/svg+xml"));
}
// Two plain candidates split into two pieces.
#[test]
fn split_srcset_simple() {
let result = split_srcset_candidates("/a.jpg 1x, /b.jpg 2x");
assert_eq!(result.len(), 2);
}
// The comma inside a data: URI does not split the candidate.
#[test]
fn split_srcset_preserves_data_uri() {
let srcset = "data:image/gif;base64,R0lGODlh 1x, /fallback.jpg 2x";
let result = split_srcset_candidates(srcset);
assert_eq!(result.len(), 2);
assert!(result[0].trim().starts_with("data:image/gif"));
assert!(result[1].trim().contains("fallback.jpg"));
}
}