use std::borrow::Cow;
/// Neutralises URLs whose scheme can execute script or reach local files.
///
/// Invisible formatting characters are removed first (see
/// `remove_disallowed_blank_chars`). The scheme is then compared,
/// ASCII-case-insensitively, against `javascript:`, `data:`, `vbscript:` and
/// `file:`. Before comparing, the characters that browsers' URL parsers
/// discard are skipped: leading C0 controls and whitespace, and ASCII tab,
/// LF and CR anywhere in the URL. Without that step, inputs such as
/// `"java\tscript:alert(1)"` or `"\u{1}javascript:alert(1)"` would pass the
/// check while a browser still treats them as `javascript:` URLs.
///
/// Returns `"#blocked-url"` for a blocked URL. Any other URL is returned with
/// only the invisible characters removed; surrounding whitespace is kept.
pub fn sanitize_url(url: &str) -> Cow<'_, str> {
    const BLOCKED_SCHEMES: [&str; 4] = ["javascript:", "data:", "vbscript:", "file:"];

    let normalized = remove_disallowed_blank_chars(url);
    // Only the leading characters decide the scheme, so lower-case just enough
    // of them to compare against the longest blocked prefix.
    let probe_len = BLOCKED_SCHEMES.iter().map(|s| s.len()).max().unwrap_or(0);
    let probe: String = normalized
        .trim_start_matches(|c: char| c <= ' ' || c.is_whitespace())
        .chars()
        .filter(|&c| !matches!(c, '\t' | '\n' | '\r'))
        .take(probe_len)
        .map(|c| c.to_ascii_lowercase())
        .collect();
    if BLOCKED_SCHEMES
        .iter()
        .any(|&scheme| probe.starts_with(scheme))
    {
        return Cow::Borrowed("#blocked-url");
    }
    normalized
}

/// Escapes text for safe inclusion in HTML element content.
///
/// Invisible formatting characters are removed first (see
/// `remove_disallowed_blank_chars`). Then `<` becomes `&lt;`, `>` becomes
/// `&gt;`, and `&` becomes `&amp;`. The exception is an `&` that already
/// starts a reference accepted by `is_valid_entity`, such as `&nbsp;` or
/// `&#123;`; it is left intact so pre-escaped text is not escaped twice.
/// Quotes are not escaped, so the output is not safe inside attribute values.
///
/// Returns the normalised input unchanged (borrowed when possible) if it
/// contains no `<`, `>` or `&`.
pub fn sanitize(input: &str) -> Cow<'_, str> {
    let normalized = remove_disallowed_blank_chars(input);
    if !normalized.contains(&['<', '>', '&'][..]) {
        return normalized;
    }
    let source: &str = &normalized;
    let mut escaped = String::with_capacity(source.len() + 32);
    for (index, ch) in source.char_indices() {
        match ch {
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            // `&` is one byte, so `index + 1` is always a char boundary.
            '&' if is_html_entity(&source[index + 1..]) => escaped.push('&'),
            '&' => escaped.push_str("&amp;"),
            other => escaped.push(other),
        }
    }
    Cow::Owned(escaped)
}

/// Removes every character matched by `is_disallowed_blank_char`, borrowing
/// `input` unchanged when there is nothing to remove.
fn remove_disallowed_blank_chars(input: &str) -> Cow<'_, str> {
    if input.chars().any(is_disallowed_blank_char) {
        Cow::Owned(
            input
                .chars()
                .filter(|&ch| !is_disallowed_blank_char(ch))
                .collect(),
        )
    } else {
        Cow::Borrowed(input)
    }
}

/// Reports whether `ch` is an invisible character that can disguise text.
///
/// The set is:
/// - zero-width space, non-joiner and joiner (U+200B–U+200D);
/// - U+FEFF;
/// - the Hangul filler U+3164;
/// - the bidi embedding/override controls (U+202A–U+202E);
/// - the bidi isolates (U+2066–U+2069).
fn is_disallowed_blank_char(ch: char) -> bool {
    matches!(
        ch,
        '\u{200B}'..='\u{200D}'
            | '\u{FEFF}'
            | '\u{3164}'
            | '\u{202A}'..='\u{202E}'
            | '\u{2066}'..='\u{2069}'
    )
}

/// Reports whether `ch` is an ASCII control character other than tab
/// (U+0009), line feed (U+000A) or carriage return (U+000D).
fn is_ascii_control_char(ch: char) -> bool {
    matches!(ch, '\x00'..='\x08' | '\x0B' | '\x0C' | '\x0E'..='\x1F' | '\x7F')
}

/// Removes ASCII control characters (other than tab, LF and CR) from
/// Markdown-like text while leaving fenced code blocks untouched.
///
/// A code block opens at a line that, after leading whitespace, starts with a
/// run of three backticks or three tildes. It closes at the next such line
/// that uses the same fence character. Fence lines and lines inside a block
/// are copied verbatim.
///
/// Line endings, including `\r\n`, are preserved exactly. The input is
/// returned borrowed when it contains no removable characters.
pub fn remove_ascii_control_chars_from_markup(input: &str) -> Cow<'_, str> {
    if !input.chars().any(is_ascii_control_char) {
        return Cow::Borrowed(input);
    }
    let mut result = String::with_capacity(input.len());
    // Fence character of the currently open code block, if any.
    let mut open_fence: Option<char> = None;
    // `split_inclusive` keeps each line's own terminator, so `\r\n` survives;
    // `str::lines` would drop the `\r`.
    for line in input.split_inclusive('\n') {
        let trimmed = line.trim_start();
        let fence_char = if trimmed.starts_with("```") {
            Some('`')
        } else if trimmed.starts_with("~~~") {
            Some('~')
        } else {
            None
        };
        if let Some(fence) = fence_char {
            open_fence = match open_fence {
                None => Some(fence),
                Some(open) if open == fence => None,
                // A different fence character inside a block does not close it.
                still_open => still_open,
            };
            result.push_str(line);
        } else if open_fence.is_some() {
            result.push_str(line);
        } else {
            result.extend(line.chars().filter(|&c| !is_ascii_control_char(c)));
        }
    }
    Cow::Owned(result)
}

/// Reports whether `rest` (the text immediately after an `&`) begins with a
/// `;`-terminated reference accepted by `is_valid_entity`.
///
/// At most 12 characters are examined: a reference body of up to 11
/// characters plus its `;`. The cost per `&` therefore stays constant however
/// long the remaining input is.
fn is_html_entity(rest: &str) -> bool {
    const MAX_BODY_CHARS: usize = 11;
    for (offset, ch) in rest.char_indices().take(MAX_BODY_CHARS + 1) {
        if ch == ';' {
            return is_valid_entity(&rest[..offset]);
        }
        if !(ch.is_alphanumeric() || ch == '#') {
            return false;
        }
    }
    false
}

/// Reports whether `entity` (the text between `&` and `;`) is an accepted
/// character reference.
///
/// Accepted forms:
/// - a decimal reference `#<digits>`;
/// - a hexadecimal reference `#x<hex>` or `#X<hex>`;
/// - one of a fixed allow-list of named entities.
///
/// Numeric references need at least one digit. Their value is not
/// range-checked.
fn is_valid_entity(entity: &str) -> bool {
    if let Some(number) = entity.strip_prefix('#') {
        let (digits, radix) = match number.strip_prefix(|c: char| c == 'x' || c == 'X') {
            Some(hex) => (hex, 16),
            None => (number, 10),
        };
        return !digits.is_empty() && digits.chars().all(|c| c.is_digit(radix));
    }
    matches!(
        entity,
        "nbsp"
            | "lt"
            | "gt"
            | "amp"
            | "quot"
            | "apos"
            | "copy"
            | "reg"
            | "trade"
            | "ndash"
            | "mdash"
            | "lsquo"
            | "rsquo"
            | "ldquo"
            | "rdquo"
            | "hellip"
            | "prime"
            | "Prime"
            | "euro"
            | "yen"
            | "pound"
            | "cent"
            | "times"
            | "divide"
            | "plusmn"
            | "minus"
            | "alpha"
            | "beta"
            | "gamma"
            | "delta"
            | "epsilon"
            | "Alpha"
            | "Beta"
            | "Gamma"
            | "Delta"
            | "Epsilon"
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_no_html() {
        let input = "Hello World";
        assert_eq!(sanitize(input), "Hello World");
    }

    #[test]
    fn test_remove_disallowed_blank_like_chars() {
        let input = "A\u{200B}B\u{200C}C\u{200D}D\u{FEFF}E\u{3164}F";
        assert_eq!(sanitize(input), "ABCDEF");
    }

    #[test]
    fn test_remove_bidi_control_chars() {
        let input = "A\u{202A}B\u{202E}C\u{2066}D\u{2069}E";
        assert_eq!(sanitize(input), "ABCDE");
    }

    #[test]
    fn test_preserve_allowed_spaces_only() {
        let input = "A B C";
        assert_eq!(sanitize(input), "A B C");
    }

    #[test]
    fn test_escape_tags() {
        let input = "<script>alert('xss')</script>";
        assert_eq!(
            sanitize(input),
            "&lt;script&gt;alert('xss')&lt;/script&gt;"
        );
    }

    #[test]
    fn test_preserve_entities() {
        let input = "Hello&nbsp;World";
        assert_eq!(sanitize(input), "Hello&nbsp;World");
    }

    #[test]
    fn test_escape_ampersand() {
        let input = "A & B";
        assert_eq!(sanitize(input), "A &amp; B");
    }

    #[test]
    fn test_mixed_content() {
        let input = "<div>Hello&nbsp;World & stuff</div>";
        assert_eq!(
            sanitize(input),
            "&lt;div&gt;Hello&nbsp;World &amp; stuff&lt;/div&gt;"
        );
    }

    #[test]
    fn test_numeric_entities() {
        let input = "&#123; &#x7B;";
        assert_eq!(sanitize(input), "&#123; &#x7B;");
    }

    #[test]
    fn test_invalid_entity() {
        let input = "&invalid;";
        assert_eq!(sanitize(input), "&amp;invalid;");
    }

    #[test]
    fn test_incomplete_numeric_entities_are_escaped() {
        assert_eq!(sanitize("&#; &#x;"), "&amp;#; &amp;#x;");
    }

    #[test]
    fn test_overlong_entity_is_escaped() {
        assert_eq!(sanitize("&#0000000000123;"), "&amp;#0000000000123;");
    }

    #[test]
    fn test_xss_attempts() {
        let test_cases = vec![
            (
                "<img src=x onerror=alert(1)>",
                "&lt;img src=x onerror=alert(1)&gt;",
            ),
            ("<svg/onload=alert(1)>", "&lt;svg/onload=alert(1)&gt;"),
            (
                "<iframe src=javascript:alert(1)>",
                "&lt;iframe src=javascript:alert(1)&gt;",
            ),
        ];
        for (input, expected) in test_cases {
            assert_eq!(sanitize(input), expected);
        }
    }

    #[test]
    fn test_entity_validation() {
        assert!(is_valid_entity("nbsp"));
        assert!(is_valid_entity("lt"));
        assert!(is_valid_entity("gt"));
        assert!(is_valid_entity("#123"));
        assert!(is_valid_entity("#x7B"));
        assert!(is_valid_entity("#X7b"));
        assert!(!is_valid_entity("invalid"));
        assert!(!is_valid_entity(""));
        assert!(!is_valid_entity("#"));
        assert!(!is_valid_entity("#x"));
        assert!(!is_valid_entity("#12a"));
    }

    #[test]
    fn test_sanitize_url_safe_schemes() {
        assert_eq!(sanitize_url("https://example.com"), "https://example.com");
        assert_eq!(sanitize_url("http://example.com"), "http://example.com");
        assert_eq!(
            sanitize_url("mailto:user@example.com"),
            "mailto:user@example.com"
        );
        assert_eq!(sanitize_url("ftp://example.com"), "ftp://example.com");
        assert_eq!(sanitize_url("/relative/path"), "/relative/path");
        assert_eq!(sanitize_url("./relative"), "./relative");
        assert_eq!(sanitize_url("#anchor"), "#anchor");
    }

    #[test]
    fn test_sanitize_url_custom_app_schemes() {
        assert_eq!(sanitize_url("spotify:track:123"), "spotify:track:123");
        assert_eq!(sanitize_url("steam://open/game"), "steam://open/game");
        assert_eq!(sanitize_url("discord://invite/123"), "discord://invite/123");
        assert_eq!(
            sanitize_url("slack://channel?id=123"),
            "slack://channel?id=123"
        );
        assert_eq!(sanitize_url("zoom:meeting:123"), "zoom:meeting:123");
        assert_eq!(sanitize_url("vscode://file/path"), "vscode://file/path");
    }

    #[test]
    fn test_sanitize_url_blocked_schemes() {
        assert_eq!(sanitize_url("javascript:alert(1)"), "#blocked-url");
        assert_eq!(sanitize_url("JavaScript:alert(1)"), "#blocked-url");
        assert_eq!(sanitize_url("JAVASCRIPT:alert(1)"), "#blocked-url");
        assert_eq!(
            sanitize_url("data:text/html,<script>alert(1)</script>"),
            "#blocked-url"
        );
        assert_eq!(sanitize_url("Data:text/html,test"), "#blocked-url");
        assert_eq!(sanitize_url("vbscript:msgbox(1)"), "#blocked-url");
        assert_eq!(sanitize_url("VBScript:msgbox(1)"), "#blocked-url");
        assert_eq!(sanitize_url("file:///etc/passwd"), "#blocked-url");
        assert_eq!(sanitize_url("FILE:///C:/Windows"), "#blocked-url");
    }

    #[test]
    fn test_sanitize_url_with_whitespace() {
        assert_eq!(sanitize_url(" javascript:alert(1) "), "#blocked-url");
        assert_eq!(sanitize_url("\tdata:text/html,test\n"), "#blocked-url");
        assert_eq!(
            sanitize_url(" https://example.com "),
            " https://example.com "
        );
    }

    #[test]
    fn test_sanitize_url_blocks_browser_ignored_chars_in_scheme() {
        assert_eq!(sanitize_url("java\tscript:alert(1)"), "#blocked-url");
        assert_eq!(sanitize_url("jav\nascript:alert(1)"), "#blocked-url");
        assert_eq!(sanitize_url("JaVa\r\nScRiPt:alert(1)"), "#blocked-url");
        assert_eq!(sanitize_url("\x01javascript:alert(1)"), "#blocked-url");
        assert_eq!(sanitize_url("\x00 data:text/html,x"), "#blocked-url");
    }

    #[test]
    fn test_sanitize_url_removes_disallowed_blank_like_chars() {
        assert_eq!(
            sanitize_url("https://exa\u{200B}mple.com/\u{3164}path"),
            "https://example.com/path"
        );
    }

    #[test]
    fn test_sanitize_url_blocks_scheme_after_normalization() {
        assert_eq!(sanitize_url("java\u{200B}script:alert(1)"), "#blocked-url");
        assert_eq!(sanitize_url("data:\u{FEFF}text/html,test"), "#blocked-url");
        assert_eq!(sanitize_url("java\u{202E}script:alert(1)"), "#blocked-url");
    }

    #[test]
    fn test_ascii_control_chars_removed_from_text() {
        let input = "hello\x00\x01\x07\x08\x0B\x0C\x0E\x7Fworld";
        assert_eq!(remove_ascii_control_chars_from_markup(input), "helloworld");
    }

    #[test]
    fn test_ascii_control_chars_preserved_tab_lf_cr() {
        let input = "col1\tcol2\nline2\r\nline3";
        assert_eq!(remove_ascii_control_chars_from_markup(input), input);
    }

    #[test]
    fn test_crlf_preserved_when_control_chars_removed() {
        let input = "a\x01b\r\nc\x7F\r\nd";
        assert_eq!(remove_ascii_control_chars_from_markup(input), "ab\r\nc\r\nd");
    }

    #[test]
    fn test_ascii_control_chars_preserved_inside_code_fence() {
        let input = "text\n```\nhello\x01world\n```\nafter";
        assert_eq!(remove_ascii_control_chars_from_markup(input), input);
    }

    #[test]
    fn test_ascii_control_chars_removed_outside_code_fence() {
        let input = "be\x01fore\n```\nclean\n```\naf\x01ter";
        assert_eq!(
            remove_ascii_control_chars_from_markup(input),
            "before\n```\nclean\n```\nafter"
        );
    }

    #[test]
    fn test_mismatched_fence_does_not_close_block() {
        let input = "```\n~~~\na\x01\n```\nb\x01";
        assert_eq!(
            remove_ascii_control_chars_from_markup(input),
            "```\n~~~\na\x01\n```\nb"
        );
    }

    #[test]
    fn test_ascii_control_fast_path_no_change() {
        let input = "hello world\n\ttab here";
        assert_eq!(remove_ascii_control_chars_from_markup(input), input);
    }

    #[test]
    fn test_tilde_fence_also_protected() {
        let input = "~~~\nhello\x01world\n~~~\n";
        assert_eq!(remove_ascii_control_chars_from_markup(input), input);
    }
}