use ammonia::Builder;
use std::collections::HashSet;
/// Sanitizes untrusted HTML against a fixed allow-list.
///
/// The input is cleaned by an `ammonia::Builder` that starts from the crate
/// defaults and is then configured with:
///
/// - the set of tags listed below (inline formatting, block structure,
///   headings, lists, tables, quotations, `pre` and `img`);
/// - `alt`, `cite`, `class`, `href`, `id`, `src` and `title` as attributes
///   accepted on any tag;
/// - `http`, `https` and `mailto` as the only accepted URL schemes;
/// - `nofollow noopener noreferrer` as the link `rel` value.
///
/// Returns the cleaned document serialized to a `String`.
pub fn sanitize_html(input: &str) -> String {
    let allowed_tags: HashSet<&str> = HashSet::from([
        // Inline text formatting.
        "a",
        "abbr",
        "acronym",
        "b",
        "cite",
        "code",
        "em",
        "i",
        "kbd",
        "mark",
        "s",
        "samp",
        "small",
        "strike",
        "strong",
        "sub",
        "sup",
        "u",
        "var",
        // Breaks and generic containers.
        "br",
        "div",
        "hr",
        "p",
        "span",
        // Headings.
        "h1",
        "h2",
        "h3",
        "h4",
        "h5",
        "h6",
        // Lists.
        "dd",
        "dl",
        "dt",
        "li",
        "ol",
        "ul",
        // Tables.
        "caption",
        "table",
        "tbody",
        "td",
        "tfoot",
        "th",
        "thead",
        "tr",
        // Quotations, preformatted text and images.
        "blockquote",
        "q",
        "pre",
        "img",
    ]);
    let allowed_attributes: HashSet<&str> =
        HashSet::from(["alt", "cite", "class", "href", "id", "src", "title"]);
    let allowed_schemes: HashSet<&str> = HashSet::from(["http", "https", "mailto"]);

    // The setters are independent of one another, so their order does not matter.
    let mut policy = Builder::default();
    policy
        .tags(allowed_tags)
        .generic_attributes(allowed_attributes)
        .url_schemes(allowed_schemes)
        .link_rel(Some("nofollow noopener noreferrer"));

    policy.clean(input).to_string()
}
/// Decodes HTML character references in `input` and returns the result as an
/// owned `String`.
///
/// Decoding is delegated to `html_escape::decode_html_entities`. The returned
/// text can contain raw markup characters such as `<` and `&`, so it must be
/// sanitized or re-escaped before it is embedded in HTML.
pub fn decode_entities(input: &str) -> String {
    // The decoder hands back a `Cow<str>`. `into_owned` takes over the buffer
    // when the decoder already allocated one and copies only when it returned
    // a borrow of `input`; `to_string` would copy the text in both cases.
    html_escape::decode_html_entities(input).into_owned()
}
/// Decodes HTML entities in `input`, then sanitizes the decoded text.
///
/// Decoding runs first, so markup that arrives entity-encoded is turned into
/// real markup and then goes through the same filtering as [`sanitize_html`].
pub fn sanitize_and_decode(input: &str) -> String {
    sanitize_html(&decode_entities(input))
}
/// Removes HTML tags from `input` by cleaning it with an empty tag allow-list.
///
/// Every other `ammonia::Builder` setting is left at its default.
///
/// NOTE(review): the result is a serialized HTML document, so special
/// characters in the remaining text presumably come back entity-escaped —
/// confirm before treating the output as plain text.
pub fn strip_tags(input: &str) -> String {
    let mut text_only = Builder::default();
    text_only.tags(HashSet::new());
    text_only.clean(input).to_string()
}
/// Unit tests for the sanitizing and entity-decoding helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sanitize_removes_script() {
        let html = r"<p>Hello</p><script>alert('XSS')</script>";
        let clean = sanitize_html(html);
        assert!(!clean.contains("script"));
        assert!(clean.contains("Hello"));
    }

    #[test]
    fn test_sanitize_allows_safe_tags() {
        let html = r#"<p>Hello <b>world</b> <a href="http://example.com">link</a></p>"#;
        let clean = sanitize_html(html);
        assert!(clean.contains("<p>"));
        assert!(clean.contains("<b>"));
        assert!(clean.contains("<a"));
    }

    #[test]
    fn test_sanitize_removes_onclick() {
        let html = r#"<a href="/" onclick="alert('XSS')">Click</a>"#;
        let clean = sanitize_html(html);
        assert!(!clean.contains("onclick"));
        assert!(clean.contains("href"));
    }

    #[test]
    fn test_decode_entities() {
        // Inputs must be entity-encoded; feeding already-decoded text would
        // leave the decoder untested.
        assert_eq!(decode_entities("&lt;p&gt;"), "<p>");
        assert_eq!(decode_entities("&amp;"), "&");
        assert_eq!(decode_entities("&quot;"), "\"");
        assert_eq!(decode_entities("&#39;"), "'");
    }

    #[test]
    fn test_decode_numeric_entities() {
        // Decimal and hexadecimal references to the same code point.
        assert_eq!(decode_entities("&#60;"), "<");
        assert_eq!(decode_entities("&#x3C;"), "<");
    }

    #[test]
    fn test_sanitize_and_decode() {
        // Entity-encoded markup has to be decoded first and filtered afterwards.
        let input = "&lt;p&gt;Safe&lt;/p&gt;&lt;script&gt;Bad&lt;/script&gt;";
        let output = sanitize_and_decode(input);
        assert!(output.contains("<p>Safe</p>"));
        assert!(!output.contains("script"));
    }

    #[test]
    fn test_strip_tags() {
        let html = "<p>Hello <b>world</b></p>";
        assert_eq!(strip_tags(html), "Hello world");
    }

    #[test]
    fn test_xss_img_onerror() {
        let html = r#"<img src="x" onerror="alert('XSS')">"#;
        let clean = sanitize_html(html);
        assert!(!clean.contains("onerror"));
    }

    #[test]
    fn test_xss_javascript_url() {
        let html = r#"<a href="javascript:alert('XSS')">Click</a>"#;
        let clean = sanitize_html(html);
        assert!(!clean.contains("javascript:"));
    }

    #[test]
    fn test_xss_iframe() {
        let html = r#"<iframe src="http://evil.com"></iframe>"#;
        let clean = sanitize_html(html);
        assert!(!clean.contains("iframe"));
    }

    #[test]
    fn test_xss_data_url() {
        let html = r#"<a href="data:text/html,<script>alert('XSS')</script>">Click</a>"#;
        let clean = sanitize_html(html);
        assert!(!clean.contains("data:"));
    }

    #[test]
    fn test_sanitize_empty_string() {
        assert_eq!(sanitize_html(""), "");
    }

    #[test]
    fn test_sanitize_plain_text() {
        let text = "Plain text with no tags";
        assert_eq!(sanitize_html(text), text);
    }

    #[test]
    fn test_decode_entities_no_entities() {
        let text = "No entities here";
        assert_eq!(decode_entities(text), text);
    }

    #[test]
    fn test_strip_tags_nested() {
        let html = "<div><p>Hello <span><b>world</b></span></p></div>";
        assert_eq!(strip_tags(html), "Hello world");
    }

    #[test]
    fn test_sanitize_link_rel_attribute() {
        let html = r#"<a href="http://example.com">Link</a>"#;
        let clean = sanitize_html(html);
        assert!(clean.contains("nofollow"));
        assert!(clean.contains("noopener"));
        assert!(clean.contains("noreferrer"));
    }
}