use crate::error::{BlocksError, Result};
use regex::Regex;
use std::collections::HashSet;
/// HTML element names used to seed the tag allow-list of a sanitizer built
/// with `ContentSanitizer::new()`. All entries are lowercase.
const ALLOWED_TAGS: &[&str] = &[
    // Inline text and phrasing.
    "p", "br", "strong", "em", "u", "s", "code", "pre",
    // Headings.
    "h1", "h2", "h3", "h4", "h5", "h6",
    // Lists and quotations.
    "ul", "ol", "li", "blockquote",
    // Links and images.
    "a", "img",
    // Tables.
    "table", "tr", "td", "th", "thead", "tbody",
    // Generic and sectioning containers.
    "span", "div", "section", "article", "header", "footer", "nav",
];
/// Attribute names used to seed the attribute allow-list of a sanitizer built
/// with `ContentSanitizer::new()`. All entries are lowercase.
const ALLOWED_ATTRIBUTES: &[&str] = &[
    // Link and media sources.
    "href", "src", "alt", "title",
    // Styling and identification hooks.
    "class", "id", "data-block-id", "data-type",
    // Media sizing and loading hints.
    "width", "height", "loading",
];
/// Substrings treated as dangerous wherever they occur.
///
/// `ContentSanitizer::sanitize` deletes every case-insensitive occurrence of
/// each entry from the content, and `ContentSanitizer::validate_url` rejects
/// URLs whose lowercased form starts with any entry.
const DANGEROUS_KEYWORDS: &[&str] = &[
    // URL schemes.
    "javascript:", "data:", "vbscript:",
    // Inline event-handler attribute names.
    "onerror", "onload", "onclick", "onmouseover", "onfocus",
    "onblur", "onchange", "onsubmit", "onkeydown", "onkeyup",
];
/// Sanitizer for user-supplied HTML content.
///
/// Holds a tag allow-list and an attribute allow-list (both stored lowercase)
/// plus a flag selecting strict mode, in which all markup is stripped.
#[derive(Debug, Clone)]
pub struct ContentSanitizer {
    /// Lowercase element names reported as allowed by `is_tag_allowed`.
    allowed_tags: HashSet<String>,
    /// Lowercase attribute names reported as allowed by `is_attribute_allowed`.
    allowed_attributes: HashSet<String>,
    /// When `true`, `sanitize` removes every tag instead of filtering.
    strict_mode: bool,
}
impl ContentSanitizer {
pub fn new() -> Self {
Self {
allowed_tags: ALLOWED_TAGS.iter().map(|s| s.to_string()).collect(),
allowed_attributes: ALLOWED_ATTRIBUTES.iter().map(|s| s.to_string()).collect(),
strict_mode: false,
}
}
pub fn strict() -> Self {
Self {
allowed_tags: HashSet::new(),
allowed_attributes: HashSet::new(),
strict_mode: true,
}
}
pub fn sanitize(&self, content: &str) -> Result<String> {
if self.strict_mode {
return Ok(self.strip_html(content));
}
let mut result = content.to_string();
for keyword in DANGEROUS_KEYWORDS {
result = self.remove_dangerous_pattern(keyword, &result)?;
}
result = self.remove_script_tags(&result);
result = self.remove_tags("iframe", &result);
result = self.remove_tags("style", &result);
result = self.remove_event_handlers(&result);
result = self.clean_attributes(&result);
Ok(result)
}
pub fn sanitize_text(&self, content: &str) -> Result<String> {
let escaped = html_escape::encode_text(content).to_string();
Ok(escaped)
}
fn remove_dangerous_pattern(&self, pattern: &str, content: &str) -> Result<String> {
let regex = Regex::new(&format!(r"(?i){}", regex::escape(pattern))).map_err(|e| {
BlocksError::CssError {
reason: format!("Regex error: {}", e),
}
})?;
Ok(regex.replace_all(content, "").to_string())
}
fn remove_script_tags(&self, content: &str) -> String {
let regex = Regex::new(r"(?i)<script[^>]*>.*?</script>")
.unwrap_or_else(|_| Regex::new(r"<>").unwrap());
regex.replace_all(content, "").to_string()
}
fn remove_tags(&self, tag: &str, content: &str) -> String {
let regex = Regex::new(&format!(
r"(?i)<{tag}[^>]*>.*?</{tag}>",
tag = regex::escape(tag)
))
.unwrap_or_else(|_| Regex::new(r"<>").unwrap());
regex.replace_all(content, "").to_string()
}
fn remove_event_handlers(&self, content: &str) -> String {
let regex = Regex::new(r#"\s+on[a-z]+\s*=\s*(?:"[^"]*"|'[^']*'|[^\s>]*)?"#)
.unwrap_or_else(|_| Regex::new(r"<>").unwrap());
regex.replace_all(content, "").to_string()
}
fn clean_attributes(&self, content: &str) -> String {
let mut result = content.to_string();
result = Regex::new(r#"href\s*=\s*"javascript:[^"]*""#)
.map(|r| r.replace_all(&result, r#"href=""#).to_string())
.unwrap_or(result);
result = Regex::new(r#"src\s*=\s*"data:[^"]*""#)
.map(|r| r.replace_all(&result, r#"src=""#).to_string())
.unwrap_or(result);
result
}
fn strip_html(&self, content: &str) -> String {
let regex = Regex::new(r"<[^>]*>").unwrap_or_else(|_| Regex::new(r"<>").unwrap());
regex.replace_all(content, "").to_string()
}
pub fn validate_url(&self, url: &str) -> Result<()> {
if url.is_empty() {
return Ok(());
}
for keyword in DANGEROUS_KEYWORDS {
if url.to_lowercase().starts_with(keyword) {
return Err(BlocksError::ValidationError {
message: format!("Dangerous URL protocol detected: {}", keyword),
});
}
}
if url.contains("javascript:") || url.contains("data:text/html") {
return Err(BlocksError::ValidationError {
message: "URL contains potentially dangerous content".to_string(),
});
}
Ok(())
}
}
impl Default for ContentSanitizer {
fn default() -> Self {
Self::new()
}
}
impl ContentSanitizer {
    /// Returns `true` if the lowercased `tag` is on the tag allow-list.
    pub fn is_tag_allowed(&self, tag: &str) -> bool {
        let key = tag.to_lowercase();
        self.allowed_tags.contains(key.as_str())
    }

    /// Returns `true` if the lowercased `attr` is on the attribute allow-list.
    pub fn is_attribute_allowed(&self, attr: &str) -> bool {
        let key = attr.to_lowercase();
        self.allowed_attributes.contains(key.as_str())
    }

    /// Adds the lowercased `tag` to the tag allow-list.
    pub fn allow_tag(&mut self, tag: &str) {
        let key = tag.to_lowercase();
        self.allowed_tags.insert(key);
    }

    /// Adds the lowercased `attr` to the attribute allow-list.
    pub fn allow_attribute(&mut self, attr: &str) {
        let key = attr.to_lowercase();
        self.allowed_attributes.insert(key);
    }

    /// Removes the lowercased `tag` from the tag allow-list, if present.
    pub fn disallow_tag(&mut self, tag: &str) {
        let key = tag.to_lowercase();
        self.allowed_tags.remove(key.as_str());
    }

    /// Removes the lowercased `attr` from the attribute allow-list, if present.
    pub fn disallow_attribute(&mut self, attr: &str) {
        let key = attr.to_lowercase();
        self.allowed_attributes.remove(key.as_str());
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete `<script>` element is dropped; surrounding text survives.
    #[test]
    fn test_remove_script_tags() {
        let input = r#"<p>Hello</p><script>alert('xss')</script><p>World</p>"#;
        let output = ContentSanitizer::new().sanitize(input).unwrap();
        assert!(!output.contains("script"));
        assert!(output.contains("Hello"));
        assert!(output.contains("World"));
    }

    /// An inline `onclick` handler name does not survive sanitisation.
    #[test]
    fn test_remove_event_handlers() {
        let input = r#"<p onclick="alert('xss')">Click me</p>"#;
        let output = ContentSanitizer::new().sanitize(input).unwrap();
        assert!(!output.contains("onclick"));
        assert!(!output.to_lowercase().contains("onclick="));
    }

    /// The `javascript:` scheme is removed from link targets.
    #[test]
    fn test_javascript_protocol() {
        let input = r#"<a href="javascript:alert('xss')">Click</a>"#;
        let output = ContentSanitizer::new().sanitize(input).unwrap();
        assert!(!output.to_lowercase().contains("javascript:"));
    }

    /// Strict mode strips all tags and keeps only the text.
    #[test]
    fn test_strict_mode() {
        let strict = ContentSanitizer::strict();
        let stripped = strict.sanitize("<p>Hello <b>World</b></p>").unwrap();
        assert_eq!(stripped, "Hello World");
    }

    /// Ordinary absolute and relative URLs are accepted.
    #[test]
    fn test_validate_safe_url() {
        let checker = ContentSanitizer::new();
        assert!(checker.validate_url("https://example.com").is_ok());
        assert!(checker.validate_url("http://example.com/path").is_ok());
        assert!(checker.validate_url("/relative/path").is_ok());
    }

    /// URLs using the `javascript:` or `data:` scheme are rejected.
    #[test]
    fn test_validate_dangerous_url() {
        let checker = ContentSanitizer::new();
        assert!(checker.validate_url("javascript:alert('xss')").is_err());
        let data_url = "data:text/html,<script>alert('xss')</script>";
        assert!(checker.validate_url(data_url).is_err());
    }

    /// Text sanitisation escapes angle brackets rather than removing them.
    #[test]
    fn test_sanitize_text() {
        let input = "<script>alert('xss')</script>";
        let escaped = ContentSanitizer::new().sanitize_text(input).unwrap();
        assert!(!escaped.contains('<'));
        assert!(!escaped.contains('>'));
    }
}