ai_lib_rust/guardrails/
filters.rs1use super::config::{FilterAction, FilterRule};
4use super::result::{Violation, ViolationType};
5
6pub trait ContentFilter: Send + Sync {
8 fn check(&self, content: &str) -> Vec<Violation>;
10
11 fn sanitize(&self, content: &str, replacement: &str) -> String;
13}
14
15#[derive(Debug, Clone, Default)]
17pub struct KeywordFilter {
18 rules: Vec<CompiledKeywordRule>,
19}
20
21#[derive(Debug, Clone)]
22struct CompiledKeywordRule {
23 keyword: String,
24 keyword_lower: String,
25 case_sensitive: bool,
26 action: FilterAction,
27 category: Option<String>,
28 description: Option<String>,
29}
30
31impl KeywordFilter {
32 pub fn new() -> Self {
34 Self { rules: Vec::new() }
35 }
36
37 pub fn from_rules(rules: &[FilterRule]) -> Self {
39 let compiled_rules: Vec<CompiledKeywordRule> = rules
40 .iter()
41 .filter(|r| !r.is_regex)
42 .map(|r| CompiledKeywordRule {
43 keyword: r.pattern.clone(),
44 keyword_lower: r.pattern.to_lowercase(),
45 case_sensitive: r.case_sensitive,
46 action: r.action,
47 category: r.category.clone(),
48 description: r.description.clone(),
49 })
50 .collect();
51
52 Self { rules: compiled_rules }
53 }
54
55 pub fn add_keyword(&mut self, keyword: impl Into<String>, action: FilterAction) {
57 let keyword = keyword.into();
58 self.rules.push(CompiledKeywordRule {
59 keyword_lower: keyword.to_lowercase(),
60 keyword,
61 case_sensitive: false,
62 action,
63 category: None,
64 description: None,
65 });
66 }
67}
68
69impl ContentFilter for KeywordFilter {
70 fn check(&self, content: &str) -> Vec<Violation> {
71 let content_lower = content.to_lowercase();
72 let mut violations = Vec::new();
73
74 for rule in &self.rules {
75 let matched = if rule.case_sensitive {
76 content.contains(&rule.keyword)
77 } else {
78 content_lower.contains(&rule.keyword_lower)
79 };
80
81 if matched {
82 violations.push(Violation {
83 violation_type: ViolationType::Keyword,
84 pattern: rule.keyword.clone(),
85 action: rule.action,
86 category: rule.category.clone(),
87 description: rule.description.clone(),
88 matched_text: Some(rule.keyword.clone()),
89 });
90 }
91 }
92
93 violations
94 }
95
96 fn sanitize(&self, content: &str, replacement: &str) -> String {
97 let mut result = content.to_string();
98
99 for rule in &self.rules {
100 if matches!(rule.action, FilterAction::Sanitize | FilterAction::Block) {
101 if rule.case_sensitive {
102 result = result.replace(&rule.keyword, replacement);
103 } else {
104 let lower = result.to_lowercase();
106 let keyword_lower = &rule.keyword_lower;
107
108 let mut new_result = String::new();
109 let mut last_end = 0;
110
111 for (start, _) in lower.match_indices(keyword_lower) {
112 new_result.push_str(&result[last_end..start]);
113 new_result.push_str(replacement);
114 last_end = start + keyword_lower.len();
115 }
116 new_result.push_str(&result[last_end..]);
117 result = new_result;
118 }
119 }
120 }
121
122 result
123 }
124}
125
126#[derive(Debug, Clone, Default)]
128pub struct PatternFilter {
129 rules: Vec<CompiledPatternRule>,
130}
131
132#[derive(Debug, Clone)]
133struct CompiledPatternRule {
134 pattern_str: String,
135 case_sensitive: bool,
136 action: FilterAction,
137 category: Option<String>,
138 description: Option<String>,
139}
140
141impl PatternFilter {
142 pub fn new() -> Self {
144 Self { rules: Vec::new() }
145 }
146
147 pub fn from_rules(rules: &[FilterRule]) -> Self {
149 let compiled_rules: Vec<CompiledPatternRule> = rules
150 .iter()
151 .filter(|r| r.is_regex)
152 .map(|r| CompiledPatternRule {
153 pattern_str: r.pattern.clone(),
154 case_sensitive: r.case_sensitive,
155 action: r.action,
156 category: r.category.clone(),
157 description: r.description.clone(),
158 })
159 .collect();
160
161 Self { rules: compiled_rules }
162 }
163
164 pub fn add_pattern(&mut self, pattern: impl Into<String>, action: FilterAction) {
166 self.rules.push(CompiledPatternRule {
167 pattern_str: pattern.into(),
168 case_sensitive: true,
169 action,
170 category: None,
171 description: None,
172 });
173 }
174
175 fn compile_pattern(pattern: &str, case_sensitive: bool) -> Option<regex::Regex> {
177 let pattern_str = if case_sensitive {
178 pattern.to_string()
179 } else {
180 format!("(?i){}", pattern)
181 };
182
183 regex::Regex::new(&pattern_str).ok()
184 }
185}
186
187impl ContentFilter for PatternFilter {
188 fn check(&self, content: &str) -> Vec<Violation> {
189 let mut violations = Vec::new();
190
191 for rule in &self.rules {
192 if let Some(re) = Self::compile_pattern(&rule.pattern_str, rule.case_sensitive) {
193 if let Some(m) = re.find(content) {
194 violations.push(Violation {
195 violation_type: ViolationType::Pattern,
196 pattern: rule.pattern_str.clone(),
197 action: rule.action,
198 category: rule.category.clone(),
199 description: rule.description.clone(),
200 matched_text: Some(m.as_str().to_string()),
201 });
202 }
203 }
204 }
205
206 violations
207 }
208
209 fn sanitize(&self, content: &str, replacement: &str) -> String {
210 let mut result = content.to_string();
211
212 for rule in &self.rules {
213 if matches!(rule.action, FilterAction::Sanitize | FilterAction::Block) {
214 if let Some(re) = Self::compile_pattern(&rule.pattern_str, rule.case_sensitive) {
215 result = re.replace_all(&result, replacement).to_string();
216 }
217 }
218 }
219
220 result
221 }
222}