use std::borrow::Cow;
use unicode_normalization::UnicodeNormalization;
/// Identifies where a chunk of input text came from.
///
/// Passed alongside the raw bytes to the sanitizing entry points in this file.
/// The `Display` implementation renders each variant as its lowercase name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputSource {
/// Displayed as `keyboard`.
Keyboard,
/// Displayed as `clipboard`.
Clipboard,
/// Displayed as `file`.
File,
/// Displayed as `network`.
Network,
/// Displayed as `unknown`.
Unknown,
}
/// Renders an [`InputSource`] as a short lowercase label.
impl std::fmt::Display for InputSource {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pick the label first, then emit it with a single write.
        let label = match *self {
            InputSource::Keyboard => "keyboard",
            InputSource::Clipboard => "clipboard",
            InputSource::File => "file",
            InputSource::Network => "network",
            InputSource::Unknown => "unknown",
        };
        f.write_str(label)
    }
}
/// Bookkeeping produced by `sanitize_input_with_stats` describing what was
/// changed while cleaning a byte buffer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SanitizeStats {
/// Length in bytes of the raw input.
pub original_bytes: usize,
/// Length in bytes of the sanitized output string.
pub sanitized_bytes: usize,
/// Number of invalid UTF-8 sequences reported for the input; zero when the
/// input decoded cleanly.
pub invalid_sequences: usize,
/// Number of NUL (`\0`) characters stripped from the text.
pub null_bytes_removed: usize,
/// Tally of control characters stripped from the text (tab, line feed and
/// carriage return are never stripped by the control-character filter).
pub control_chars_removed: usize,
/// Number of carriage returns rewritten while converting `\r\n` and lone
/// `\r` to `\n`.
pub line_endings_normalized: usize,
/// `true` when Unicode NFC normalization changed the text.
pub unicode_normalized: bool,
/// `true` when no invalid UTF-8 sequences were found.
pub was_valid: bool,
}
impl SanitizeStats {
    /// Returns `true` when the input was invalid UTF-8 or when anything was
    /// removed or rewritten (invalid sequences, NULs, control characters,
    /// line endings).
    ///
    /// NFC normalization on its own is *not* counted as an issue.
    pub fn had_issues(&self) -> bool {
        let counters = [
            self.invalid_sequences,
            self.null_bytes_removed,
            self.control_chars_removed,
            self.line_endings_normalized,
        ];
        !self.was_valid || counters.iter().any(|&count| count > 0)
    }

    /// Builds a one-line, human-readable description of what the sanitizer did.
    ///
    /// Returns `"Input was clean UTF-8"` when nothing was changed; otherwise a
    /// `"Sanitized: ..."` list of the non-zero counters, followed by a note if
    /// NFC normalization altered the text.
    pub fn summary(&self) -> String {
        if !(self.had_issues() || self.unicode_normalized) {
            return "Input was clean UTF-8".to_string();
        }

        // `Some(n)` only for counters that actually fired.
        let nonzero = |count: usize| Some(count).filter(|&n| n > 0);

        let parts: Vec<String> = vec![
            nonzero(self.invalid_sequences).map(|n| format!("{} invalid UTF-8 sequences", n)),
            nonzero(self.null_bytes_removed).map(|n| format!("{} null bytes", n)),
            nonzero(self.control_chars_removed).map(|n| format!("{} control chars", n)),
            nonzero(self.line_endings_normalized).map(|n| format!("{} line endings", n)),
            if self.unicode_normalized {
                Some("Unicode NFC normalized".to_string())
            } else {
                None
            },
        ]
        .into_iter()
        .flatten()
        .collect();

        // Reachable only when `was_valid` is false but every counter is zero.
        if parts.is_empty() {
            return "Input was clean".to_string();
        }
        format!("Sanitized: {}", parts.join(", "))
    }
}
/// Sanitizes `bytes` and returns only the cleaned text.
///
/// Convenience wrapper around [`sanitize_input_with_stats`] that discards the
/// accompanying statistics.
pub fn sanitize_input(bytes: &[u8], source: InputSource) -> String {
    sanitize_input_with_stats(bytes, source).0
}
/// Converts raw bytes into clean text and reports what had to be changed.
///
/// Processing steps, in order:
/// 1. Decode as UTF-8; invalid sequences are replaced with U+FFFD.
/// 2. Apply Unicode NFC normalization.
/// 3. Drop NUL characters and every other control character except `\n`,
///    `\r` and `\t`.
/// 4. Convert `\r\n` and lone `\r` line endings to `\n`.
///
/// `_source` is accepted for API symmetry but is not consulted.
///
/// Returns the sanitized string together with a [`SanitizeStats`] record.
/// `null_bytes_removed` and `control_chars_removed` count *characters*, and
/// `invalid_sequences` counts only sequences that were actually invalid
/// (a U+FFFD already present in the input is not counted).
pub fn sanitize_input_with_stats(bytes: &[u8], _source: InputSource) -> (String, SanitizeStats) {
    let original_bytes = bytes.len();

    let (utf8_str, invalid_sequences) = match std::str::from_utf8(bytes) {
        Ok(s) => (Cow::Borrowed(s), 0),
        Err(_) => (
            String::from_utf8_lossy(bytes),
            count_invalid_utf8_sequences(bytes),
        ),
    };
    let was_valid = invalid_sequences == 0;

    let normalized_unicode: String = utf8_str.nfc().collect();
    let unicode_normalized = normalized_unicode != *utf8_str;

    // Single pass over the text: tally removals per character rather than by
    // byte-length difference, so multi-byte C1 controls (U+0080..=U+009F) are
    // counted once each.
    let mut null_bytes_removed = 0usize;
    let mut control_chars_removed = 0usize;
    let filtered: String = normalized_unicode
        .chars()
        .filter(|&c| match c {
            '\0' => {
                null_bytes_removed += 1;
                false
            }
            '\n' | '\r' | '\t' => true,
            c if c.is_control() => {
                control_chars_removed += 1;
                false
            }
            _ => true,
        })
        .collect();

    let (normalized, line_endings_normalized) = normalize_line_endings(&filtered);
    let sanitized_bytes = normalized.len();

    let stats = SanitizeStats {
        original_bytes,
        sanitized_bytes,
        invalid_sequences,
        null_bytes_removed,
        control_chars_removed,
        line_endings_normalized,
        unicode_normalized,
        was_valid,
    };

    if stats.had_issues() {
        #[cfg(debug_assertions)]
        log::debug!("[UTF-8 Sanitizer] {}", stats.summary());
    }

    (normalized.into_owned(), stats)
}

/// Counts the invalid UTF-8 sequences in `bytes`, i.e. the number of U+FFFD
/// characters a lossy decode would insert.
fn count_invalid_utf8_sequences(bytes: &[u8]) -> usize {
    let mut remaining = bytes;
    let mut count = 0;
    while let Err(err) = std::str::from_utf8(remaining) {
        count += 1;
        match err.error_len() {
            // Skip the valid prefix plus the offending sequence and resume.
            Some(len) => remaining = &remaining[err.valid_up_to() + len..],
            // Input ended in the middle of a sequence; nothing left to scan.
            None => break,
        }
    }
    count
}
/// Rewrites `\r\n` pairs and lone `\r` characters to `\n`.
///
/// Returns the (possibly borrowed) text together with the number of carriage
/// returns found in `s`. When `s` contains no `\r` it is returned unchanged
/// and unallocated.
fn normalize_line_endings(s: &str) -> (Cow<'_, str>, usize) {
    // 0x0D can only appear in UTF-8 as the CR character itself.
    let cr_total = s.bytes().filter(|&b| b == b'\r').count();
    if cr_total == 0 {
        return (Cow::Borrowed(s), 0);
    }

    let mut out = String::with_capacity(s.len());
    let mut chars = s.chars().peekable();
    while let Some(ch) = chars.next() {
        if ch != '\r' {
            out.push(ch);
            continue;
        }
        // Swallow the LF of a CRLF pair so it collapses to a single newline.
        if chars.peek() == Some(&'\n') {
            chars.next();
        }
        out.push('\n');
    }
    (Cow::Owned(out), cr_total)
}
/// Reports whether byte offset `index` falls on a UTF-8 character boundary of
/// `s`. Offsets `0` and `s.len()` are boundaries; anything past the end is not.
pub fn is_char_boundary(s: &str, index: usize) -> bool {
    str::is_char_boundary(s, index)
}
/// Returns the closest character boundary at or before byte offset `index`.
///
/// Offsets at or beyond the end of `s` are clamped to `s.len()`.
pub fn find_prev_boundary(s: &str, index: usize) -> usize {
    let len = s.len();
    if index >= len {
        return len;
    }
    // Walk backwards from `index`; offset 0 is always a boundary.
    (0..=index)
        .rev()
        .find(|&offset| s.is_char_boundary(offset))
        .unwrap_or(0)
}
/// Returns the closest character boundary at or after byte offset `index`.
///
/// Offsets at or beyond the end of `s` are clamped to `s.len()`.
pub fn find_next_boundary(s: &str, index: usize) -> usize {
    let len = s.len();
    if index >= len {
        return len;
    }
    // Scan forward; if no interior boundary is found the end of the string is one.
    (index..len)
        .find(|&offset| s.is_char_boundary(offset))
        .unwrap_or(len)
}
/// Returns the UTF-8 encoded length, in bytes, of the character that starts at
/// byte offset `index`.
///
/// Yields `0` when `index` is not a character boundary, lies past the end of
/// `s`, or points at the very end of the string.
pub fn char_byte_length(s: &str, index: usize) -> usize {
    // `str::get` is `None` for out-of-range or mid-character offsets.
    s.get(index..)
        .and_then(|tail| tail.chars().next())
        .map_or(0, char::len_utf8)
}
/// Returns the slice of `s` covering characters `char_start..char_end`.
///
/// Both positions are counted in Unicode scalar values (`char`s), not bytes,
/// and `char_end` is exclusive. Positions past the end of the string are
/// clamped to the end. An empty or reversed range (`char_start >= char_end`)
/// yields `""` instead of panicking.
pub fn substring_by_chars(s: &str, char_start: usize, char_end: usize) -> &str {
    if char_start >= char_end {
        return "";
    }

    let mut offsets = s.char_indices().map(|(offset, _)| offset);

    let byte_start = match offsets.nth(char_start) {
        Some(offset) => offset,
        // Start lies beyond the last character.
        None => return "",
    };

    // The iterator now sits just after `char_start`, so the end position is
    // `char_end - char_start - 1` items further along (no underflow: end > start).
    let byte_end = offsets
        .nth(char_end - char_start - 1)
        .unwrap_or(s.len());

    &s[byte_start..byte_end]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain ASCII passes through untouched and reports no issues.
    #[test]
    fn test_valid_utf8() {
        let (cleaned, stats) = sanitize_input_with_stats(b"Hello, World!", InputSource::Keyboard);
        assert_eq!(cleaned, "Hello, World!");
        assert!(!stats.had_issues());
        assert_eq!(stats.invalid_sequences, 0);
    }

    /// Broken byte sequences are replaced with U+FFFD and reported.
    #[test]
    fn test_invalid_utf8_replaced() {
        let raw = b"Hello \xF0\x28\x8C\x28 World";
        let (cleaned, stats) = sanitize_input_with_stats(raw, InputSource::Clipboard);
        assert!(cleaned.contains('�'));
        assert!(stats.had_issues());
        assert!(stats.invalid_sequences > 0);
    }

    /// NUL bytes are stripped and counted.
    #[test]
    fn test_null_bytes_removed() {
        let raw = b"Hello\x00World\x00";
        let (cleaned, stats) = sanitize_input_with_stats(raw, InputSource::File);
        assert_eq!(cleaned, "HelloWorld");
        assert!(stats.had_issues());
        assert_eq!(stats.null_bytes_removed, 2);
    }

    /// Windows-style CRLF becomes LF.
    #[test]
    fn test_line_ending_normalization_crlf() {
        let raw = b"Line1\r\nLine2\r\nLine3";
        let (cleaned, stats) = sanitize_input_with_stats(raw, InputSource::File);
        assert_eq!(cleaned, "Line1\nLine2\nLine3");
        assert!(stats.had_issues());
        assert!(stats.line_endings_normalized > 0);
    }

    /// A lone CR becomes LF.
    #[test]
    fn test_line_ending_normalization_cr() {
        let raw = b"Line1\rLine2\rLine3";
        let (cleaned, stats) = sanitize_input_with_stats(raw, InputSource::File);
        assert_eq!(cleaned, "Line1\nLine2\nLine3");
        assert!(stats.had_issues());
    }

    /// The em dash occupies bytes 6..9; only its edges are boundaries.
    #[test]
    fn test_em_dash_char_boundary() {
        let text = "Hello — World";
        for &offset in [6usize, 9].iter() {
            assert!(is_char_boundary(text, offset));
        }
        for &offset in [7usize, 8].iter() {
            assert!(!is_char_boundary(text, offset));
        }
    }

    /// Offsets inside the em dash snap to its start or end.
    #[test]
    fn test_find_boundaries() {
        let text = "Hello — World";
        assert_eq!(find_prev_boundary(text, 7), 6);
        assert_eq!(find_prev_boundary(text, 6), 6);
        assert_eq!(find_next_boundary(text, 7), 9);
        assert_eq!(find_next_boundary(text, 9), 9);
    }

    /// ASCII is 1 byte, the em dash 3 bytes, the emoji 4 bytes.
    #[test]
    fn test_char_byte_length() {
        let text = "Hello — World 😀";
        let expectations = [(0usize, 1usize), (6, 3), (16, 4)];
        for &(offset, width) in expectations.iter() {
            assert_eq!(char_byte_length(text, offset), width);
        }
    }

    /// Character-indexed slicing works across a multi-byte character.
    #[test]
    fn test_substring_by_chars() {
        let text = "Hello — World";
        assert_eq!(substring_by_chars(text, 0, 5), "Hello");
        assert_eq!(substring_by_chars(text, 6, 7), "—");
        assert_eq!(substring_by_chars(text, 8, 13), "World");
    }

    /// Emoji survive sanitizing unchanged.
    #[test]
    fn test_emoji_handling() {
        let raw = "Hello 😀 World 🎉".as_bytes();
        let (cleaned, stats) = sanitize_input_with_stats(raw, InputSource::Keyboard);
        assert_eq!(cleaned, "Hello 😀 World 🎉");
        assert!(!stats.had_issues());
    }

    /// CJK text survives sanitizing unchanged.
    #[test]
    fn test_cjk_characters() {
        let raw = "こんにちは世界".as_bytes();
        let (cleaned, stats) = sanitize_input_with_stats(raw, InputSource::Keyboard);
        assert_eq!(cleaned, "こんにちは世界");
        assert!(!stats.had_issues());
    }

    /// `e` + combining acute accent is composed into a single code point.
    #[test]
    fn test_unicode_nfc_normalization_precomposed() {
        let decomposed = "cafe\u{0301}";
        let (cleaned, stats) =
            sanitize_input_with_stats(decomposed.as_bytes(), InputSource::Keyboard);
        assert_eq!(cleaned, "café");
        assert!(stats.unicode_normalized);
    }

    /// Already-composed text is returned as is.
    #[test]
    fn test_unicode_nfc_already_normalized() {
        let raw = "café".as_bytes();
        let (cleaned, _stats) = sanitize_input_with_stats(raw, InputSource::Keyboard);
        assert_eq!(cleaned, "café");
    }

    /// The em dash (U+2014) is kept, not folded into another character.
    #[test]
    fn test_em_dash_preserved() {
        let raw = "Native performance — no login".as_bytes();
        let (cleaned, _stats) = sanitize_input_with_stats(raw, InputSource::Keyboard);
        assert_eq!(cleaned, "Native performance — no login");
        assert!(cleaned.contains('—'));
        let dash = cleaned.chars().find(|&c| c == '—').unwrap();
        assert_eq!(dash as u32, 0x2014);
    }

    /// ASCII hyphen and em dash stay distinct.
    #[test]
    fn test_hyphen_vs_em_dash() {
        let raw = "hyphen - and em dash —".as_bytes();
        let (cleaned, _stats) = sanitize_input_with_stats(raw, InputSource::Keyboard);
        assert_eq!(cleaned, "hyphen - and em dash —");
        assert_eq!(cleaned.matches('-').count(), 1);
        assert_eq!(cleaned.matches('—').count(), 1);
    }

    /// Control characters are dropped while newline and tab are kept.
    #[test]
    fn test_control_characters_filtered() {
        let raw = "Hello\x01\x02World\nNew\tLine\r\n".as_bytes();
        let (cleaned, stats) = sanitize_input_with_stats(raw, InputSource::Keyboard);
        assert_eq!(cleaned, "HelloWorld\nNew\tLine\n");
        assert!(stats.control_chars_removed > 0);
    }

    /// Markdown containing em dashes is left intact.
    #[test]
    fn test_complex_markdown_with_em_dashes() {
        let raw = "- **Bold** — description\n- *Italic* — another item".as_bytes();
        let (cleaned, _stats) = sanitize_input_with_stats(raw, InputSource::File);
        assert!(cleaned.contains("**Bold** — description"));
        assert!(cleaned.contains("*Italic* — another item"));
        assert_eq!(cleaned.matches('—').count(), 2);
    }

    /// A mix of 1-, 2-, 3- and 4-byte characters is untouched.
    #[test]
    fn test_mixed_multibyte() {
        let raw = "ASCII Café 日本語 😀".as_bytes();
        let (cleaned, stats) = sanitize_input_with_stats(raw, InputSource::File);
        assert_eq!(cleaned, "ASCII Café 日本語 😀");
        assert!(!stats.had_issues());
    }

    /// A dirty input yields a "Sanitized" summary line.
    #[test]
    fn test_stats_summary() {
        let raw = b"Hello\x00\xF0\x28World\r\n";
        let (_cleaned, stats) = sanitize_input_with_stats(raw, InputSource::Clipboard);
        assert!(stats.had_issues());
        assert!(stats.summary().contains("Sanitized"));
    }
}