use std::borrow::Cow;
/// Replaces URLs that use a dangerous scheme with the inert `#blocked-url`.
///
/// The blocked schemes are `javascript:`, `data:`, `vbscript:` and `file:`,
/// matched ASCII-case-insensitively. Before matching, the check mirrors the
/// pre-processing browsers apply to URLs: leading whitespace and control
/// characters are skipped, and ASCII tab / line feed / carriage return are
/// ignored wherever they occur, so obfuscated forms such as
/// `"java\tscript:alert(1)"` are blocked as well.
///
/// Any URL that is not blocked is returned exactly as passed in (surrounding
/// whitespace included). This function never allocates.
pub fn sanitize_url(url: &str) -> Cow<'_, str> {
    // Schemes (including the trailing colon) that must never be passed through.
    const BLOCKED_SCHEMES: [&str; 4] = ["javascript:", "data:", "vbscript:", "file:"];

    // Leading whitespace/control characters do not stop a browser from
    // recognising the scheme, so they must not stop this check either.
    let candidate = url.trim_start_matches(|c: char| c.is_whitespace() || c.is_control());

    let is_blocked = BLOCKED_SCHEMES.iter().any(|scheme| {
        // Tabs and newlines are dropped by URL parsers wherever they appear,
        // so skip them while comparing against the scheme.
        let mut significant = candidate
            .chars()
            .filter(|c| !matches!(c, '\t' | '\n' | '\r'));
        scheme.chars().all(|expected| {
            matches!(
                significant.next(),
                Some(actual) if actual.eq_ignore_ascii_case(&expected)
            )
        })
    });

    if is_blocked {
        Cow::Borrowed("#blocked-url")
    } else {
        Cow::Borrowed(url)
    }
}
/// Escapes HTML markup characters in `input` so it renders as plain text.
///
/// * `<` becomes `&lt;` and `>` becomes `&gt;`.
/// * `&` becomes `&amp;`, unless it already starts an entity reference that
///   `is_html_entity` accepts (for example `&nbsp;` or `&#123;`), in which
///   case it is kept as-is so existing entities are not double-escaped.
///
/// Quote characters are left untouched.
///
/// Returns the input borrowed, without allocating, when it contains none of
/// `<`, `>` or `&`; otherwise returns a newly built `String`.
pub fn sanitize(input: &str) -> Cow<'_, str> {
    // Fast path: nothing that could need escaping.
    if !input.contains(&['<', '>', '&'][..]) {
        return Cow::Borrowed(input);
    }

    // Escapes lengthen the text; reserve some headroom up front.
    let mut escaped = String::with_capacity(input.len() + 32);
    for (idx, ch) in input.char_indices() {
        match ch {
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '&' => {
                // Look ahead at the text after the ampersand without
                // consuming it from the main loop.
                let mut rest = input[idx + ch.len_utf8()..].chars().peekable();
                if is_html_entity(&mut rest) {
                    escaped.push('&');
                } else {
                    escaped.push_str("&amp;");
                }
            }
            other => escaped.push(other),
        }
    }
    Cow::Owned(escaped)
}
/// Reports whether the upcoming characters spell a recognised entity body.
///
/// `chars` is expected to be positioned just after an `&`. The characters up
/// to the first `;` are collected and handed to `is_valid_entity`. The scan
/// runs on a clone, so the caller's iterator is never advanced.
///
/// Returns `false` when no `;` is found, when a character other than an
/// alphanumeric or `#` appears before the `;`, or when more than 10 bytes have
/// been collected and the next character is still not `;`.
fn is_html_entity(chars: &mut std::iter::Peekable<std::str::Chars>) -> bool {
    let mut name = String::new();
    for next in chars.clone() {
        if next == ';' {
            return is_valid_entity(&name);
        }
        // Bound the lookahead so a stray `&` in long text stays cheap.
        if name.len() > 10 {
            return false;
        }
        let allowed = next.is_alphanumeric() || next == '#';
        if !allowed {
            return false;
        }
        name.push(next);
    }
    // Ran out of input before reaching a terminating `;`.
    false
}
/// Checks whether `entity` (the text between `&` and `;`) is an accepted
/// entity body.
///
/// Accepted forms:
/// * decimal numeric references: `#` followed by one or more ASCII digits;
/// * hexadecimal numeric references: `#x` or `#X` followed by one or more
///   ASCII hex digits;
/// * one of the named entities in the table below (case-sensitive).
fn is_valid_entity(entity: &str) -> bool {
    /// Named entities that are passed through unescaped.
    const NAMED_ENTITIES: &[&str] = &[
        "nbsp", "lt", "gt", "amp", "quot", "apos", "copy", "reg", "trade", "ndash", "mdash",
        "lsquo", "rsquo", "ldquo", "rdquo", "hellip", "prime", "Prime", "euro", "yen", "pound",
        "cent", "times", "divide", "plusmn", "minus", "alpha", "beta", "gamma", "delta",
        "epsilon", "Alpha", "Beta", "Gamma", "Delta", "Epsilon",
    ];

    match entity.strip_prefix('#') {
        // Numeric reference: decide between the hex and decimal forms.
        Some(numeric) => match numeric.strip_prefix(|c: char| c == 'x' || c == 'X') {
            Some(hex) => !hex.is_empty() && hex.chars().all(|c| c.is_ascii_hexdigit()),
            None => !numeric.is_empty() && numeric.chars().all(|c| c.is_ascii_digit()),
        },
        // Named reference: must match the table exactly.
        None => NAMED_ENTITIES.contains(&entity),
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for HTML escaping (`sanitize`), entity recognition and URL
    //! scheme blocking (`sanitize_url`).
    use super::*;

    #[test]
    fn test_no_html() {
        let input = "Hello World";
        assert_eq!(sanitize(input), "Hello World");
    }

    #[test]
    fn test_escape_tags() {
        let input = "<script>alert('xss')</script>";
        assert_eq!(
            sanitize(input),
            "&lt;script&gt;alert('xss')&lt;/script&gt;"
        );
    }

    #[test]
    fn test_preserve_entities() {
        // A recognised entity must not be double-escaped.
        let input = "Hello&nbsp;World";
        assert_eq!(sanitize(input), "Hello&nbsp;World");
    }

    #[test]
    fn test_escape_ampersand() {
        let input = "A & B";
        assert_eq!(sanitize(input), "A &amp; B");
    }

    #[test]
    fn test_mixed_content() {
        let input = "<div>Hello&nbsp;World & stuff</div>";
        assert_eq!(
            sanitize(input),
            "&lt;div&gt;Hello&nbsp;World &amp; stuff&lt;/div&gt;"
        );
    }

    #[test]
    fn test_numeric_entities() {
        // Decimal and hexadecimal character references are both kept.
        let input = "&#123; &#x7B;";
        assert_eq!(sanitize(input), "&#123; &#x7B;");
    }

    #[test]
    fn test_invalid_entity() {
        // Unknown entity names are treated as plain text.
        let input = "&invalid;";
        assert_eq!(sanitize(input), "&amp;invalid;");
    }

    #[test]
    fn test_xss_attempts() {
        let test_cases = vec![
            (
                "<img src=x onerror=alert(1)>",
                "&lt;img src=x onerror=alert(1)&gt;",
            ),
            ("<svg/onload=alert(1)>", "&lt;svg/onload=alert(1)&gt;"),
            (
                "<iframe src=javascript:alert(1)>",
                "&lt;iframe src=javascript:alert(1)&gt;",
            ),
        ];
        for (input, expected) in test_cases {
            assert_eq!(sanitize(input), expected);
        }
    }

    #[test]
    fn test_entity_validation() {
        assert!(is_valid_entity("nbsp"));
        assert!(is_valid_entity("lt"));
        assert!(is_valid_entity("gt"));
        assert!(is_valid_entity("#123"));
        assert!(is_valid_entity("#x7B"));
        assert!(!is_valid_entity("invalid"));
        assert!(!is_valid_entity(""));
    }

    #[test]
    fn test_sanitize_url_safe_schemes() {
        assert_eq!(sanitize_url("https://example.com"), "https://example.com");
        assert_eq!(sanitize_url("http://example.com"), "http://example.com");
        assert_eq!(
            sanitize_url("mailto:user@example.com"),
            "mailto:user@example.com"
        );
        assert_eq!(sanitize_url("ftp://example.com"), "ftp://example.com");
        assert_eq!(sanitize_url("/relative/path"), "/relative/path");
        assert_eq!(sanitize_url("./relative"), "./relative");
        assert_eq!(sanitize_url("#anchor"), "#anchor");
    }

    #[test]
    fn test_sanitize_url_custom_app_schemes() {
        assert_eq!(sanitize_url("spotify:track:123"), "spotify:track:123");
        assert_eq!(sanitize_url("steam://open/game"), "steam://open/game");
        assert_eq!(sanitize_url("discord://invite/123"), "discord://invite/123");
        assert_eq!(
            sanitize_url("slack://channel?id=123"),
            "slack://channel?id=123"
        );
        assert_eq!(sanitize_url("zoom:meeting:123"), "zoom:meeting:123");
        assert_eq!(sanitize_url("vscode://file/path"), "vscode://file/path");
    }

    #[test]
    fn test_sanitize_url_blocked_schemes() {
        assert_eq!(sanitize_url("javascript:alert(1)"), "#blocked-url");
        assert_eq!(sanitize_url("JavaScript:alert(1)"), "#blocked-url");
        assert_eq!(sanitize_url("JAVASCRIPT:alert(1)"), "#blocked-url");
        assert_eq!(
            sanitize_url("data:text/html,<script>alert(1)</script>"),
            "#blocked-url"
        );
        assert_eq!(sanitize_url("Data:text/html,test"), "#blocked-url");
        assert_eq!(sanitize_url("vbscript:msgbox(1)"), "#blocked-url");
        assert_eq!(sanitize_url("VBScript:msgbox(1)"), "#blocked-url");
        assert_eq!(sanitize_url("file:///etc/passwd"), "#blocked-url");
        assert_eq!(sanitize_url("FILE:///C:/Windows"), "#blocked-url");
    }

    #[test]
    fn test_sanitize_url_with_whitespace() {
        assert_eq!(sanitize_url(" javascript:alert(1) "), "#blocked-url");
        assert_eq!(sanitize_url("\tdata:text/html,test\n"), "#blocked-url");
        // Safe URLs are returned untouched, surrounding whitespace included.
        assert_eq!(
            sanitize_url(" https://example.com "),
            " https://example.com "
        );
    }
}