use lol_html::{RewriteStrSettings, element, rewrite_str};
use scraper::{Html, Selector};
use std::collections::HashSet;
/// Cleans raw HTML for content extraction.
///
/// Always strips non-content machinery (`script`, `style`, `noscript`,
/// `iframe`, `svg`, `canvas`) and `data:`-URI images. When
/// `only_main_content` is set, additionally removes structural chrome
/// (`nav`, `footer`, …), elements whose `class`/`id` match known noise
/// patterns, and elements with landmark roles that mark non-content regions.
///
/// After rewriting, `include_tags` (if non-empty) restricts the output to
/// the matching CSS selectors, and `exclude_tags` removes matching subtrees.
///
/// # Errors
/// Returns the `lol_html` rewriter error rendered as a `String`.
pub fn clean_html(
    html: &str,
    only_main_content: bool,
    include_tags: &[String],
    exclude_tags: &[String],
) -> Result<String, String> {
    // Tags that never carry readable content: removed unconditionally.
    const ALWAYS_REMOVE: &[&str] = &["script", "style", "noscript", "iframe", "svg", "canvas"];
    // Page chrome: only removed when the caller asks for main content.
    const CHROME_REMOVE: &[&str] = &["nav", "footer", "header", "aside", "menu", "select"];

    let mut handlers: Vec<_> = ALWAYS_REMOVE
        .iter()
        .map(|tag| {
            element!(*tag, |el| {
                el.remove();
                Ok(())
            })
        })
        .collect();

    // Inline `data:` images are typically tracking pixels or huge base64
    // blobs; neither is useful in extracted content.
    handlers.push(element!("img", |el| {
        if let Some(src) = el.get_attribute("src")
            && src.starts_with("data:")
        {
            el.remove();
        }
        Ok(())
    }));

    if only_main_content {
        for tag in CHROME_REMOVE {
            handlers.push(element!(*tag, |el| {
                el.remove();
                Ok(())
            }));
        }
        // Heuristic noise filter keyed on class/id/role. Must run after the
        // tag-based handlers; structural elements are explicitly exempt.
        handlers.push(element!("*", |el| {
            let tag = el.tag_name();
            // Never remove the document skeleton or the main content root.
            if matches!(tag.as_str(), "html" | "head" | "body" | "main") {
                return Ok(());
            }
            let class = el.get_attribute("class").unwrap_or_default().to_lowercase();
            let id = el.get_attribute("id").unwrap_or_default().to_lowercase();
            // Substring matches anywhere in class or id. Includes
            // MediaWiki- and Sphinx-specific selectors.
            const NOISE_PATTERNS: &[&str] = &[
                "sidebar",
                "table-of-contents",
                "tableofcontents",
                "infobox",
                "navbox",
                "nav-box",
                "navigation",
                "breadcrumb",
                "cookie",
                "consent",
                "banner",
                "disqus",
                "advert",
                "popup",
                "modal",
                "newsletter",
                "subscribe",
                "printfooter",
                "catlinks",
                "mw-panel",
                "mw-navigation",
                "sitesub",
                "jump-to-nav",
                "mw-editsection",
                "reflist",
                "mw-references",
                "authority-control",
                "mw-indicators",
                "sistersitebox",
                "mbox",
                "ambox",
                "ombox",
                "hatnote",
                "shortdescription",
                "sphinxsidebar",
                "sphinxfooter",
                "copyright",
                "dropdown",
                "city-selector",
                "location-selector",
            ];
            // Whole-token matches only: "footer" as a class token is noise,
            // but "page-footer-note" is not caught here.
            const NOISE_EXACT_TOKENS: &[&str] = &[
                "toc", "share", "social", "related", "recommended", "comment", "footer",
            ];
            const NOISE_PREFIXES: &[&str] = &["ad-", "ads-"];
            // No pattern contains whitespace, so checking class and id
            // separately is equivalent to matching their concatenation,
            // without allocating a combined string per element.
            let is_noise = NOISE_PATTERNS
                .iter()
                .any(|p| class.contains(p) || id.contains(p))
                || class
                    .split_whitespace()
                    .chain(std::iter::once(id.as_str()))
                    .any(|tok| {
                        NOISE_EXACT_TOKENS.contains(&tok)
                            || NOISE_PREFIXES.iter().any(|pre| tok.starts_with(pre))
                    });
            if is_noise {
                el.remove();
                return Ok(());
            }
            // ARIA landmark roles that mark non-content regions.
            let role = el.get_attribute("role").unwrap_or_default().to_lowercase();
            if matches!(
                role.as_str(),
                "contentinfo" | "navigation" | "banner" | "complementary"
            ) {
                el.remove();
                return Ok(());
            }
            Ok(())
        }));
    }

    let mut result = rewrite_str(
        html,
        RewriteStrSettings {
            element_content_handlers: handlers,
            ..Default::default()
        },
    )
    .map_err(|e| e.to_string())?;

    if !include_tags.is_empty() {
        result = keep_only_selectors(&result, include_tags);
    }
    if !exclude_tags.is_empty() {
        result = remove_by_selectors(&result, exclude_tags);
    }
    Ok(result)
}
/// Extracts only the fragments matching `selectors`, joined by newlines.
///
/// Invalid selectors are logged and skipped. If nothing matches (or every
/// selector was invalid), the input is returned unchanged rather than
/// producing an empty document.
fn keep_only_selectors(html: &str, selectors: &[String]) -> String {
    let doc = Html::parse_document(html);
    let mut fragments: Vec<String> = Vec::new();
    for sel_str in selectors {
        let sel = match Selector::parse(sel_str) {
            Ok(sel) => sel,
            Err(e) => {
                tracing::warn!("Invalid CSS selector '{}': {:?}", sel_str, e);
                continue;
            }
        };
        fragments.extend(doc.select(&sel).map(|el| el.html()));
    }
    if fragments.is_empty() {
        html.to_string()
    } else {
        fragments.join("\n")
    }
}
/// Re-serializes `html` with every subtree matching `selectors` removed.
///
/// Matching elements are collected as raw pointers into the parsed
/// document — pointer identity is stable for the lifetime of `doc` — and
/// the tree is then re-serialized, skipping those nodes. Invalid selectors
/// are logged and skipped; if nothing matched, the input is returned as-is.
fn remove_by_selectors(html: &str, selectors: &[String]) -> String {
    let doc = Html::parse_document(html);
    let mut skip_ptrs: HashSet<*const scraper::node::Element> = HashSet::new();
    for sel_str in selectors {
        let sel = match Selector::parse(sel_str) {
            Ok(sel) => sel,
            Err(e) => {
                tracing::warn!("Invalid CSS selector '{}': {:?}", sel_str, e);
                continue;
            }
        };
        skip_ptrs.extend(doc.select(&sel).map(|el| el.value() as *const _));
    }
    if skip_ptrs.is_empty() {
        return html.to_string();
    }
    let mut out = String::with_capacity(html.len());
    collect_excluding(&doc.root_element(), &skip_ptrs, &mut out);
    out
}
/// Returns `true` if `el`'s node identity (raw pointer) is in the skip set.
fn is_excluded(
    el: &scraper::ElementRef,
    skip_ptrs: &HashSet<*const scraper::node::Element>,
) -> bool {
    skip_ptrs.contains(&(el.value() as *const scraper::node::Element))
}
fn collect_excluding(
element: &scraper::ElementRef,
skip_ptrs: &HashSet<*const scraper::node::Element>,
out: &mut String,
) {
if is_excluded(element, skip_ptrs) {
return;
}
let el = element.value();
out.push('<');
out.push_str(&el.name.local);
for (name, value) in el.attrs() {
out.push(' ');
out.push_str(name);
out.push_str("=\"");
out.push_str(&value.replace('"', """));
out.push('"');
}
out.push('>');
for child in element.children() {
match child.value() {
scraper::node::Node::Text(text) => {
out.push_str(text);
}
scraper::node::Node::Element(_) => {
if let Some(child_el) = scraper::ElementRef::wrap(child) {
collect_excluding(&child_el, skip_ptrs, out);
}
}
_ => {}
}
}
let self_closing = matches!(
&*el.name.local,
"br" | "hr"
| "img"
| "input"
| "meta"
| "link"
| "area"
| "base"
| "col"
| "embed"
| "source"
| "track"
| "wbr"
);
if !self_closing {
out.push_str("</");
out.push_str(&el.name.local);
out.push('>');
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn strips_scripts_and_styles() {
        let input =
            r#"<html><body><script>alert(1)</script><p>Hello</p><style>x{}</style></body></html>"#;
        let cleaned = clean_html(input, false, &[], &[]).unwrap();
        assert!(cleaned.contains("Hello"));
        assert!(!cleaned.contains("<script>"));
        assert!(!cleaned.contains("<style>"));
    }

    #[test]
    fn strips_nav_footer_in_main_content_mode() {
        let input =
            r#"<body><nav>Menu</nav><article>Content</article><footer>Foot</footer></body>"#;
        let cleaned = clean_html(input, true, &[], &[]).unwrap();
        assert!(cleaned.contains("Content"));
        assert!(!cleaned.contains("Menu"));
        assert!(!cleaned.contains("Foot"));
    }

    #[test]
    fn exclude_tags_removes_matching_elements() {
        let input = r#"<body><div class="ad">Ad stuff</div><p>Real content</p></body>"#;
        let excluded = vec!["div.ad".to_string()];
        let cleaned = clean_html(input, false, &[], &excluded).unwrap();
        assert!(cleaned.contains("Real content"));
        assert!(!cleaned.contains("Ad stuff"));
    }

    #[test]
    fn does_not_remove_html_body_with_noise_classes() {
        // "vector-toc-available" contains noise tokens, but <html>/<body>/<main>
        // must survive the wildcard noise handler.
        let input = r#"<html class="vector-toc-available"><body><main class="mw-body"><p>Content</p></main></body></html>"#;
        let cleaned = clean_html(input, true, &[], &[]).unwrap();
        assert!(
            cleaned.contains("Content"),
            "Structural elements must not be removed by noise patterns"
        );
    }

    #[test]
    fn strips_role_contentinfo_in_main_content_mode() {
        let input = r#"<body><div role="contentinfo">Copyright 2024</div><p>Content</p><div role="navigation">Nav</div></body>"#;
        let cleaned = clean_html(input, true, &[], &[]).unwrap();
        assert!(cleaned.contains("Content"));
        assert!(!cleaned.contains("Copyright"));
        assert!(!cleaned.contains("Nav"));
    }

    #[test]
    fn strips_sphinx_patterns_in_main_content_mode() {
        let input = r#"<body><div class="sphinxsidebar">Sidebar</div><p>Content</p><div class="copyright">Copyright</div></body>"#;
        let cleaned = clean_html(input, true, &[], &[]).unwrap();
        assert!(cleaned.contains("Content"));
        assert!(!cleaned.contains("Sidebar"));
        assert!(!cleaned.contains("Copyright"));
    }

    #[test]
    fn include_tags_keeps_only_matching() {
        let input =
            r#"<body><nav>Nav</nav><article><p>Article</p></article><footer>Foot</footer></body>"#;
        let included = vec!["article".to_string()];
        let cleaned = clean_html(input, false, &included, &[]).unwrap();
        assert!(cleaned.contains("Article"));
        assert!(!cleaned.contains("Nav"));
        assert!(!cleaned.contains("Foot"));
    }
}