// bastion_toolkit/text_guard.rs
use std::sync::OnceLock;

use regex::Regex;

#[cfg(feature = "text")]
use unicode_normalization::UnicodeNormalization;

/// Outcome of checking a piece of text with `Guard::analyze`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValidationResult {
    /// The input passed every check.
    Valid,
    /// The input was rejected; the payload is a human-readable reason.
    Blocked(String),
}

/// Default maximum input length, in bytes, used by `Guard::default`.
const DEFAULT_MAX_LEN: usize = 4096;

/// Checks and cleans untrusted text.
///
/// Construct with `Guard::new()` / `Guard::default()`; adjust the byte limit
/// with the `max_len` builder method.
#[derive(Debug, Clone)]
pub struct Guard {
    /// Maximum accepted input length, in bytes.
    max_len: usize,
}

impl Default for Guard {
    /// Creates a guard that accepts inputs of up to 4096 bytes.
    fn default() -> Self {
        Self {
            max_len: DEFAULT_MAX_LEN,
        }
    }
}

32static INJECTION_PATTERNS: OnceLock<Vec<Regex>> = OnceLock::new();
33
34fn get_patterns() -> &'static Vec<Regex> {
35 INJECTION_PATTERNS.get_or_init(|| {
36 vec![
37 Regex::new(r"(?i)ignore previous instructions").unwrap(),
39 Regex::new(r"(?i)system prompt").unwrap(),
40 Regex::new(r"(?i)you are an ai").unwrap(),
41 Regex::new(r"(?i)<script").unwrap(),
43 Regex::new(r"(?i)javascript:").unwrap(),
44 Regex::new(r"(?i)vbscript:").unwrap(),
45 Regex::new(r"(?i)data:text/html").unwrap(),
46 Regex::new(r#"(?i)alert\("#).unwrap(),
47 ]
48 })
49}
50
51impl Guard {
52 pub fn new() -> Self {
53 Self::default()
54 }
55
56 pub fn max_len(mut self, len: usize) -> Self {
58 self.max_len = len;
59 self
60 }
61
62 pub fn analyze(&self, input: &str) -> ValidationResult {
64 if input.len() > self.max_len {
66 return ValidationResult::Blocked(format!(
67 "Input too long (max {} bytes, got {})",
68 self.max_len,
69 input.len()
70 ));
71 }
72
73 let patterns = get_patterns();
75 for re in patterns {
76 if re.is_match(input) {
77 return ValidationResult::Blocked("Potential injection detected".to_string());
78 }
79 }
80
81 ValidationResult::Valid
82 }
83
84 pub fn sanitize(&self, input: &str) -> String {
86 let mut text = if input.len() > self.max_len {
88 input[..self.max_len].to_string()
89 } else {
90 input.to_string()
91 };
92
93 #[cfg(feature = "text")]
95 {
96 text = text.nfc().collect::<String>();
97 }
98
99 text = text.chars().filter(|&c| !self.is_forbidden_char(c)).collect();
101
102 text = self.mask_windows_reserved(&text);
104
105 text
106 }
107
108 fn is_forbidden_char(&self, c: char) -> bool {
109 if c.is_control() {
110 return true;
111 }
112 match c {
113 '\u{200E}' | '\u{200F}' | '\u{202A}'..='\u{202A}' | '\u{202B}'..='\u{202B}' |
114 '\u{202C}'..='\u{202C}' | '\u{202D}'..='\u{202D}' | '\u{202E}'..='\u{202E}' |
115 '\u{2066}'..='\u{2069}' => return true,
116 _ => {}
117 }
118 matches!(c, '/' | '\\' | ':' | '*' | '?' | '"' | '<' | '>' | '|')
120 }
121
122 fn mask_windows_reserved(&self, name: &str) -> String {
123 let upper = name.to_uppercase();
124 let reserved = [
125 "CON", "PRN", "AUX", "NUL",
126 "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7", "COM8", "COM9",
127 "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
128 ];
129
130 if reserved.contains(&upper.as_str()) {
131 format!("_{}", name)
132 } else {
133 name.to_string()
134 }
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test covering both the analysis and sanitization entry points.
    #[test]
    fn test_analyze_and_sanitize() {
        let g = Guard::new().max_len(20);

        // Plain text is accepted; a script tag is rejected.
        assert_eq!(g.analyze("Hello"), ValidationResult::Valid);
        let verdict = g.analyze("<script>");
        assert!(matches!(verdict, ValidationResult::Blocked(_)));

        // Path separators are dropped and device names are masked.
        let cleaned = g.sanitize("file/name.txt");
        assert_eq!(cleaned, "filename.txt");
        let masked = g.sanitize("CON");
        assert_eq!(masked, "_CON");
    }
}