use reinhardt_core::security::xss::strip_tags_safe;
use std::borrow::Cow;
/// Escapes the characters that are significant in HTML text and attribute
/// values.
///
/// The following replacements are made; every other character is copied
/// through unchanged:
///
/// | input | output   |
/// |-------|----------|
/// | `&`   | `&amp;`  |
/// | `<`   | `&lt;`   |
/// | `>`   | `&gt;`   |
/// | `"`   | `&quot;` |
/// | `'`   | `&#x27;` |
///
/// The ampersand is escaped as well, so already-escaped input is escaped a
/// second time (`&amp;` becomes `&amp;amp;`).
pub fn escape(text: &str) -> String {
    // A little head-room so that a handful of replacements do not force a
    // reallocation.
    let mut result = String::with_capacity(text.len() + 10);
    for ch in text.chars() {
        match ch {
            '&' => result.push_str("&amp;"),
            '<' => result.push_str("&lt;"),
            '>' => result.push_str("&gt;"),
            '"' => result.push_str("&quot;"),
            // `&#x27;` rather than `&apos;`: the latter is not defined in HTML 4.
            '\'' => result.push_str("&#x27;"),
            _ => result.push(ch),
        }
    }
    result
}
/// Decodes HTML character references in `text`.
///
/// Recognised references:
///
/// * the named entities `&amp;`, `&lt;`, `&gt;`, `&quot;` and `&apos;`;
/// * decimal numeric references such as `&#39;`;
/// * hexadecimal numeric references such as `&#x27;` or `&#X3C;`.
///
/// Anything else that starts with `&` is written back as `&`, the collected
/// name and a `;`. The name is collected up to the next `;` or the end of the
/// input, so an unterminated reference is still decoded (`"&lt"` yields
/// `"<"`), and an unterminated unknown one gains a trailing `;` (`"&"`
/// yields `"&;"`).
pub fn unescape(text: &str) -> String {
    let mut result = String::with_capacity(text.len());
    let mut chars = text.chars();

    // `while let` rather than `for`: the iterator is advanced inside the body
    // while the entity name is collected.
    while let Some(ch) = chars.next() {
        if ch != '&' {
            result.push(ch);
            continue;
        }

        // Consumes the terminating ';' as well (or everything that is left
        // when no ';' follows).
        let entity: String = chars.by_ref().take_while(|&c| c != ';').collect();

        let decoded = match entity.as_str() {
            "amp" => Some('&'),
            "lt" => Some('<'),
            "gt" => Some('>'),
            "quot" => Some('"'),
            "apos" => Some('\''),
            other => other.strip_prefix('#').and_then(decode_numeric_reference),
        };

        match decoded {
            Some(c) => result.push(c),
            None => {
                // Unknown or malformed reference: emit it in `&name;` form.
                result.push('&');
                result.push_str(&entity);
                result.push(';');
            }
        }
    }
    result
}

/// Decodes the part of a numeric character reference that follows `&#`.
///
/// Accepts decimal digits, or `x`/`X` followed by hexadecimal digits. Returns
/// `None` for empty or malformed digit strings, for values that overflow
/// `u32`, and for code points that are not valid `char`s (e.g. surrogates).
fn decode_numeric_reference(body: &str) -> Option<char> {
    let (digits, radix) = match body.strip_prefix('x').or_else(|| body.strip_prefix('X')) {
        Some(hex) => (hex, 16),
        None => (body, 10),
    };
    // `from_str_radix` tolerates a leading '+', which is not valid inside a
    // character reference, so the digits are validated explicitly.
    if digits.is_empty() || !digits.chars().all(|c| c.is_digit(radix)) {
        return None;
    }
    u32::from_str_radix(digits, radix)
        .ok()
        .and_then(char::from_u32)
}
/// Returns `html` with its tags removed.
///
/// This is a thin wrapper: the input is passed straight to
/// `strip_tags_safe` (imported from `reinhardt_core::security::xss`) and its
/// result is returned unchanged. The exact stripping rules are defined there.
// NOTE(review): the unit tests in this file expect `>` inside quoted
// attribute values not to end a tag — confirm against `strip_tags_safe`.
pub fn strip_tags(html: &str) -> String {
strip_tags_safe(html)
}
/// Removes whitespace that sits directly between two tags.
///
/// A run of spaces, tabs, newlines or carriage returns is dropped only when
/// it is immediately preceded by `>` and immediately followed by `<`.
/// Whitespace next to text (`"Hello <b>World</b>"`), inside a tag, or at the
/// very start or end of the input is preserved.
///
/// No HTML parsing is performed: any `>` counts as the end of a tag and any
/// `<` as the start of one.
pub fn strip_spaces_between_tags(html: &str) -> String {
    let mut result = String::with_capacity(html.len());
    // Whitespace seen since the last '>' whose fate is not yet known.
    let mut pending = String::new();
    // True while we are inside a whitespace run that began right after '>'.
    let mut after_close = false;

    for ch in html.chars() {
        let is_space = matches!(ch, ' ' | '\t' | '\n' | '\r');
        if after_close && is_space {
            pending.push(ch);
            continue;
        }

        // The run (if any) has ended. It is discarded only when the next
        // character opens another tag; otherwise it belongs to the text.
        if ch != '<' {
            result.push_str(&pending);
        }
        pending.clear();
        result.push(ch);
        after_close = ch == '>';
    }

    // Trailing whitespace after the final '>' is not between two tags.
    result.push_str(&pending);
    result
}
/// Escapes `text` for use inside an HTML attribute value.
///
/// The input is first passed through [`escape`]; line feeds, carriage
/// returns and tabs in the result are then replaced with the numeric
/// character references `&#10;`, `&#13;` and `&#9;`, so the output contains
/// none of those three characters.
pub fn escape_attr(text: &str) -> String {
    let escaped = escape(text);
    let mut result = String::with_capacity(escaped.len());
    for ch in escaped.chars() {
        match ch {
            '\n' => result.push_str("&#10;"),
            '\r' => result.push_str("&#13;"),
            '\t' => result.push_str("&#9;"),
            _ => result.push(ch),
        }
    }
    result
}
/// Fills `{key}` placeholders in `template` with HTML-escaped values.
///
/// `args` is a list of `(key, value)` pairs. Every occurrence of `{key}` in
/// the template is replaced by `escape(value)`. When a key appears more than
/// once in `args`, the first pair wins. Brace sequences that do not match
/// any key are copied through unchanged.
///
/// The template is scanned exactly once and substituted values are never
/// re-scanned, so a value that itself contains `{other_key}` is emitted
/// (escaped) as-is instead of being expanded by another argument.
pub fn format_html(template: &str, args: &[(&str, &str)]) -> String {
    let mut result = String::with_capacity(template.len());
    let mut rest = template;

    while let Some(open) = rest.find('{') {
        // Literal text in front of the brace.
        result.push_str(&rest[..open]);
        let after_open = &rest[open + 1..];

        // First argument whose key is followed by '}' at this position.
        let matched = args.iter().find(|(name, _)| {
            after_open
                .strip_prefix(*name)
                .is_some_and(|tail| tail.starts_with('}'))
        });

        match matched {
            Some((name, value)) => {
                result.push_str(&escape(value));
                // Skip the key and the closing brace.
                rest = &after_open[name.len() + 1..];
            }
            None => {
                // Not a placeholder: keep the brace and continue after it.
                result.push('{');
                rest = after_open;
            }
        }
    }

    result.push_str(rest);
    result
}
/// Escapes `text` only when `autoescape` is enabled.
///
/// With `autoescape == false` the input is handed back as a borrowed
/// [`Cow`], so no allocation takes place; otherwise the result of
/// [`escape`] is returned as an owned value.
pub fn conditional_escape(text: &str, autoescape: bool) -> Cow<'_, str> {
    if !autoescape {
        return Cow::Borrowed(text);
    }
    Cow::Owned(escape(text))
}
/// Newtype wrapper around an owned `String`.
///
/// The wrapper stores the string exactly as given; nothing here escapes or
/// validates it.
// NOTE(review): presumably marks content as already safe to emit as HTML —
// confirm against the call sites that consume `SafeString`.
#[derive(Debug, Clone)]
pub struct SafeString(String);

impl SafeString {
    /// Wraps anything convertible into a `String`.
    pub fn new(s: impl Into<String>) -> Self {
        let inner: String = s.into();
        SafeString(inner)
    }

    /// Borrows the wrapped text.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl From<String> for SafeString {
    /// Takes ownership of `s` without copying.
    fn from(s: String) -> Self {
        SafeString(s)
    }
}

impl From<&str> for SafeString {
    /// Copies `s` into a freshly allocated `String`.
    fn from(s: &str) -> Self {
        SafeString(String::from(s))
    }
}
/// Truncates `html` after `num_words` words of text, copying tags through.
///
/// A word is a maximal run of characters outside a tag that contains no
/// space, tab, newline, carriage return or `<`. Everything between `<` and
/// the next `>` is treated as a tag: it is copied verbatim and never counted.
///
/// As soon as the `num_words`-th word is terminated by whitespace or by the
/// start of a tag, the function stops and appends `"..."`. This means:
///
/// * tags that are still open at the cut-off point are **not** closed;
/// * the ellipsis is also appended when the last counted word is followed
///   only by markup (`"<p>One two three</p>"` with `num_words == 3` yields
///   `"<p>One two three..."`);
/// * input that ends in the middle of a word is returned without ellipsis.
///
/// `num_words == 0` yields an empty string.
pub fn truncate_html_words(html: &str, num_words: usize) -> String {
    if num_words == 0 {
        return String::new();
    }

    let mut result = String::with_capacity(html.len());
    let mut word_count = 0usize;
    let mut in_tag = false;
    let mut current_word = String::new();

    for ch in html.chars() {
        if in_tag {
            // Tag contents are copied verbatim; '>' ends the tag.
            result.push(ch);
            in_tag = ch != '>';
            continue;
        }

        let ends_word = matches!(ch, '<' | ' ' | '\t' | '\n' | '\r');
        if !ends_word {
            // Includes a stray '>' in text, which must stay in order with
            // the surrounding characters.
            current_word.push(ch);
            continue;
        }

        if !current_word.is_empty() {
            result.push_str(&current_word);
            current_word.clear();
            word_count += 1;
            if word_count >= num_words {
                result.push_str("...");
                return result;
            }
        }

        in_tag = ch == '<';
        result.push(ch);
    }

    // Reaching this point means fewer than `num_words` words were completed,
    // so a word still being collected belongs in the output.
    if !current_word.is_empty() {
        result.push_str(&current_word);
    }
    result
}
/// Unit tests for the HTML helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_escape() {
        assert_eq!(escape("Hello, World!"), "Hello, World!");
        assert_eq!(
            escape("<script>alert('XSS')</script>"),
            "&lt;script&gt;alert(&#x27;XSS&#x27;)&lt;/script&gt;"
        );
        assert_eq!(escape("5 < 10 & 10 > 5"), "5 &lt; 10 &amp; 10 &gt; 5");
        assert_eq!(escape("\"quoted\""), "&quot;quoted&quot;");
    }

    #[test]
    fn test_unescape() {
        assert_eq!(unescape("&lt;div&gt;"), "<div>");
        assert_eq!(unescape("&amp;"), "&");
        assert_eq!(unescape("&quot;test&quot;"), "\"test\"");
        assert_eq!(unescape("&#x27;"), "'");
        assert_eq!(unescape("&#39;"), "'");
    }

    #[test]
    fn test_strip_tags() {
        assert_eq!(strip_tags("<p>Hello <b>World</b></p>"), "Hello World");
        assert_eq!(strip_tags("<div><span>Test</span></div>"), "Test");
        assert_eq!(strip_tags("No tags here"), "No tags here");
        assert_eq!(strip_tags("<a href=\"#\">Link</a>"), "Link");
    }

    #[test]
    fn test_strip_spaces_between_tags() {
        assert_eq!(
            strip_spaces_between_tags("<div> <span>Test</span> </div>"),
            "<div><span>Test</span></div>"
        );
    }

    #[test]
    fn test_escape_attr() {
        assert_eq!(escape_attr("value"), "value");
        assert_eq!(
            escape_attr("value with \"quotes\""),
            "value with &quot;quotes&quot;"
        );
        assert_eq!(escape_attr("line\nbreak"), "line&#10;break");
        assert_eq!(escape_attr("tab\there"), "tab&#9;here");
    }

    #[test]
    fn test_format_html() {
        let template = "<div class=\"{class}\">{content}</div>";
        let args = [("class", "container"), ("content", "Hello")];
        assert_eq!(
            format_html(template, &args),
            "<div class=\"container\">Hello</div>"
        );
    }

    #[test]
    fn test_conditional_escape() {
        assert_eq!(conditional_escape("<script>", true), "&lt;script&gt;");
        assert_eq!(conditional_escape("<script>", false), "<script>");
    }

    #[test]
    fn test_safe_string() {
        let safe = SafeString::new("<b>Bold</b>");
        assert_eq!(safe.as_str(), "<b>Bold</b>");
    }

    #[test]
    fn test_truncate_html_words() {
        let html = "<p>This is a <b>test</b> sentence with many words.</p>";
        let truncated = truncate_html_words(html, 5);
        assert!(truncated.contains("This"));
        assert!(truncated.contains("is"));
        assert!(truncated.contains("..."));
    }

    #[test]
    fn test_truncate_html_preserves_tags() {
        let html = "<div>Hello <strong>world</strong> test</div>";
        let truncated = truncate_html_words(html, 2);
        assert!(truncated.contains("<div>"));
        assert!(truncated.contains("<strong>"));
    }

    #[test]
    fn test_safe_string_from_string() {
        let s = String::from("<b>Bold</b>");
        let safe = SafeString::from(s);
        assert_eq!(safe.as_str(), "<b>Bold</b>");
    }

    #[test]
    fn test_safe_string_from_str() {
        let safe = SafeString::from("<i>Italic</i>");
        assert_eq!(safe.as_str(), "<i>Italic</i>");
    }

    #[test]
    fn test_escape_empty_string() {
        assert_eq!(escape(""), "");
    }

    #[test]
    fn test_escape_multibyte() {
        assert_eq!(escape("こんにちは<>&"), "こんにちは&lt;&gt;&amp;");
    }

    #[test]
    fn test_unescape_incomplete_entity() {
        // A known entity without its ';' is still decoded.
        assert_eq!(unescape("&lt"), "<");
        // A bare '&' has an empty, unknown name and is written back with ';'.
        assert_eq!(unescape("&"), "&;");
    }

    #[test]
    fn test_unescape_unknown_entity() {
        assert_eq!(unescape("&unknown;"), "&unknown;");
    }

    #[test]
    fn test_strip_tags_nested() {
        assert_eq!(strip_tags("<div><p><span>Test</span></p></div>"), "Test");
    }

    #[test]
    fn test_strip_tags_empty() {
        assert_eq!(strip_tags(""), "");
    }

    #[test]
    fn test_strip_tags_quoted_attributes_with_angle_brackets() {
        assert_eq!(strip_tags(r#"<a title="x>y">Link</a>"#), "Link");
        assert_eq!(strip_tags("<a title='x>y'>Link</a>"), "Link");
        assert_eq!(
            strip_tags(r#"<a title="a>b" data-value="c>d">Text</a>"#),
            "Text"
        );
        assert_eq!(strip_tags(r#"<a title='x"y'>Link</a>"#), "Link");
        assert_eq!(strip_tags(r#"<a title="x'y">Link</a>"#), "Link");
    }

    #[test]
    fn test_strip_spaces_between_tags_multiple_spaces() {
        assert_eq!(
            strip_spaces_between_tags("<div> \n\t <span>Test</span> \n\t </div>"),
            "<div><span>Test</span></div>"
        );
    }

    #[test]
    fn test_escape_attr_carriage_return() {
        assert_eq!(escape_attr("test\rvalue"), "test&#13;value");
    }

    #[test]
    fn test_format_html_multiple_replacements() {
        let template = "<div id=\"{id}\" class=\"{class}\">{content}</div>";
        let args = [("id", "main"), ("class", "container"), ("content", "Hello")];
        assert_eq!(
            format_html(template, &args),
            "<div id=\"main\" class=\"container\">Hello</div>"
        );
    }

    #[test]
    fn test_format_html_no_replacements() {
        let template = "<div>Static content</div>";
        let args: [(&str, &str); 0] = [];
        assert_eq!(format_html(template, &args), "<div>Static content</div>");
    }

    #[test]
    fn test_format_html_xss_prevention_script_tag() {
        let template = "<p>{content}</p>";
        let args = [("content", "<script>alert('xss')</script>")];
        let result = format_html(template, &args);
        assert!(!result.contains("<script>"));
        assert!(result.contains("&lt;script&gt;"));
        assert!(result.contains("&lt;/script&gt;"));
        assert!(result.contains("&#x27;xss&#x27;"));
    }

    #[test]
    fn test_format_html_xss_prevention_event_handler() {
        let template = r#"<div class="{class}">{content}</div>"#;
        let args = [
            ("class", r#"container" onclick="alert('xss')"#),
            ("content", "Safe content"),
        ];
        let result = format_html(template, &args);
        assert!(result.contains("&quot;"));
        assert!(!result.contains(r#"onclick="alert"#));
    }

    #[test]
    fn test_format_html_xss_prevention_ampersand() {
        let template = "<a href=\"/search?q={query}\">Search</a>";
        let args = [("query", "test&redirect=evil.com")];
        let result = format_html(template, &args);
        assert!(result.contains("&amp;"));
        assert!(!result.contains("test&redirect"));
    }

    #[test]
    fn test_format_html_xss_prevention_angle_brackets() {
        let template = "<span>{text}</span>";
        let args = [("text", "<<SCRIPT>alert('XSS');//<</SCRIPT>")];
        let result = format_html(template, &args);
        assert!(!result.contains("<SCRIPT>"));
        assert!(result.contains("&lt;"));
        assert!(result.contains("&gt;"));
    }

    #[test]
    fn test_format_html_safe_values_unchanged() {
        let template = "<div id=\"{id}\" class=\"{class}\">{content}</div>";
        let args = [
            ("id", "main"),
            ("class", "container"),
            ("content", "Hello World"),
        ];
        let result = format_html(template, &args);
        assert_eq!(
            result,
            "<div id=\"main\" class=\"container\">Hello World</div>"
        );
    }

    #[test]
    fn test_truncate_html_words_exact_count() {
        let html = "<p>One two three</p>";
        let truncated = truncate_html_words(html, 3);
        assert!(truncated.contains("..."));
    }

    #[test]
    fn test_truncate_html_words_empty() {
        let html = "";
        let truncated = truncate_html_words(html, 5);
        assert_eq!(truncated, "");
    }
}
// Property-based tests (via the `proptest` crate) for the HTML helpers.
// The `"\\PC*"` strategy used below generates arbitrary strings made of
// non-control characters.
#[cfg(test)]
mod proptests {
use super::*;
use proptest::prelude::*;
proptest! {
// Input that contains none of the five special characters must not gain
// any of them through escaping.
#[test]
fn prop_escape_no_special_chars(s in "[^<>&\"']*") {
let escaped = escape(&s);
assert!(!escaped.contains('<'));
assert!(!escaped.contains('>'));
assert!(!escaped.contains('&'));
assert!(!escaped.contains('"'));
assert!(!escaped.contains('\''));
}
// Stripped output never contains an opening angle bracket.
#[test]
fn prop_strip_tags_no_angle_brackets(s in "\\PC*") {
let stripped = strip_tags(&s);
assert!(!stripped.contains('<'));
}
// Stripping tags never makes the string longer (byte length).
#[test]
fn prop_strip_tags_length_decrease(s in "\\PC*") {
let stripped = strip_tags(&s);
assert!(stripped.len() <= s.len());
}
// Loose upper bound on the number of word-like tokens left after
// truncation; the `+ 5` slack tolerates differences between this
// tokenisation and the one used by `truncate_html_words`.
#[test]
fn prop_truncate_html_words_respects_limit(html in "\\PC*", n in 1usize..20) {
let truncated = truncate_html_words(&html, n);
let word_count = truncated
.split(|c: char| c.is_whitespace() || c == '<' || c == '>')
.filter(|w| !w.is_empty() && !w.starts_with('/'))
.filter(|w| !w.chars().all(|c| !c.is_alphanumeric()))
.count();
assert!(word_count <= n + 5);
}
// Attribute-escaped output contains no raw newline, carriage return or tab.
#[test]
fn prop_escape_attr_no_newlines(s in "\\PC*") {
let escaped = escape_attr(&s);
assert!(!escaped.contains('\n'));
assert!(!escaped.contains('\r'));
assert!(!escaped.contains('\t'));
}
// With autoescape on, the result equals a direct call to `escape`.
#[test]
fn prop_conditional_escape_when_true(s in "\\PC*") {
let escaped_cond = conditional_escape(&s, true);
let escaped_direct = escape(&s);
assert_eq!(escaped_cond, escaped_direct);
}
// With autoescape off, the input comes back unchanged.
#[test]
fn prop_conditional_escape_when_false(s in "\\PC*") {
let escaped = conditional_escape(&s, false);
assert_eq!(escaped, s);
}
// Wrapping a string in `SafeString` does not alter it.
#[test]
fn prop_safe_string_roundtrip(s in "\\PC*") {
let safe = SafeString::from(s.clone());
assert_eq!(safe.as_str(), &s);
}
// Without arguments the template is returned verbatim.
#[test]
fn prop_format_html_preserves_non_placeholders(template in "\\PC*") {
let args: [(&str, &str); 0] = [];
let result = format_html(&template, &args);
assert_eq!(result, template);
}
// Sanity bound only: the output is never much longer than the input.
#[test]
fn prop_strip_spaces_reduces_whitespace(s in "\\PC*") {
let stripped = strip_spaces_between_tags(&s);
assert!(stripped.len() <= s.len() + 100); }
}
}