Skip to main content

ai_lib_rust/guardrails/
filters.rs

//! Content filtering implementations: keyword-based and regex pattern-based filters.
2
3use super::config::{FilterAction, FilterRule};
4use super::result::{Violation, ViolationType};
5
6/// Trait for content filters
7pub trait ContentFilter: Send + Sync {
8    /// Check content for violations
9    fn check(&self, content: &str) -> Vec<Violation>;
10    
11    /// Sanitize content by replacing violations
12    fn sanitize(&self, content: &str, replacement: &str) -> String;
13}
14
15/// Keyword-based content filter
16#[derive(Debug, Clone, Default)]
17pub struct KeywordFilter {
18    rules: Vec<CompiledKeywordRule>,
19}
20
21#[derive(Debug, Clone)]
22struct CompiledKeywordRule {
23    keyword: String,
24    keyword_lower: String,
25    case_sensitive: bool,
26    action: FilterAction,
27    category: Option<String>,
28    description: Option<String>,
29}
30
31impl KeywordFilter {
32    /// Create a new empty keyword filter
33    pub fn new() -> Self {
34        Self { rules: Vec::new() }
35    }
36
37    /// Create from filter rules
38    pub fn from_rules(rules: &[FilterRule]) -> Self {
39        let compiled_rules: Vec<CompiledKeywordRule> = rules
40            .iter()
41            .filter(|r| !r.is_regex)
42            .map(|r| CompiledKeywordRule {
43                keyword: r.pattern.clone(),
44                keyword_lower: r.pattern.to_lowercase(),
45                case_sensitive: r.case_sensitive,
46                action: r.action,
47                category: r.category.clone(),
48                description: r.description.clone(),
49            })
50            .collect();
51
52        Self { rules: compiled_rules }
53    }
54
55    /// Add a keyword rule
56    pub fn add_keyword(&mut self, keyword: impl Into<String>, action: FilterAction) {
57        let keyword = keyword.into();
58        self.rules.push(CompiledKeywordRule {
59            keyword_lower: keyword.to_lowercase(),
60            keyword,
61            case_sensitive: false,
62            action,
63            category: None,
64            description: None,
65        });
66    }
67}
68
69impl ContentFilter for KeywordFilter {
70    fn check(&self, content: &str) -> Vec<Violation> {
71        let content_lower = content.to_lowercase();
72        let mut violations = Vec::new();
73
74        for rule in &self.rules {
75            let matched = if rule.case_sensitive {
76                content.contains(&rule.keyword)
77            } else {
78                content_lower.contains(&rule.keyword_lower)
79            };
80
81            if matched {
82                violations.push(Violation {
83                    violation_type: ViolationType::Keyword,
84                    pattern: rule.keyword.clone(),
85                    action: rule.action,
86                    category: rule.category.clone(),
87                    description: rule.description.clone(),
88                    matched_text: Some(rule.keyword.clone()),
89                });
90            }
91        }
92
93        violations
94    }
95
96    fn sanitize(&self, content: &str, replacement: &str) -> String {
97        let mut result = content.to_string();
98
99        for rule in &self.rules {
100            if matches!(rule.action, FilterAction::Sanitize | FilterAction::Block) {
101                if rule.case_sensitive {
102                    result = result.replace(&rule.keyword, replacement);
103                } else {
104                    // Case-insensitive replacement
105                    let lower = result.to_lowercase();
106                    let keyword_lower = &rule.keyword_lower;
107                    
108                    let mut new_result = String::new();
109                    let mut last_end = 0;
110                    
111                    for (start, _) in lower.match_indices(keyword_lower) {
112                        new_result.push_str(&result[last_end..start]);
113                        new_result.push_str(replacement);
114                        last_end = start + keyword_lower.len();
115                    }
116                    new_result.push_str(&result[last_end..]);
117                    result = new_result;
118                }
119            }
120        }
121
122        result
123    }
124}
125
126/// Regex pattern-based content filter
127#[derive(Debug, Clone, Default)]
128pub struct PatternFilter {
129    rules: Vec<CompiledPatternRule>,
130}
131
132#[derive(Debug, Clone)]
133struct CompiledPatternRule {
134    pattern_str: String,
135    case_sensitive: bool,
136    action: FilterAction,
137    category: Option<String>,
138    description: Option<String>,
139}
140
141impl PatternFilter {
142    /// Create a new empty pattern filter
143    pub fn new() -> Self {
144        Self { rules: Vec::new() }
145    }
146
147    /// Create from filter rules
148    pub fn from_rules(rules: &[FilterRule]) -> Self {
149        let compiled_rules: Vec<CompiledPatternRule> = rules
150            .iter()
151            .filter(|r| r.is_regex)
152            .map(|r| CompiledPatternRule {
153                pattern_str: r.pattern.clone(),
154                case_sensitive: r.case_sensitive,
155                action: r.action,
156                category: r.category.clone(),
157                description: r.description.clone(),
158            })
159            .collect();
160
161        Self { rules: compiled_rules }
162    }
163
164    /// Add a pattern rule
165    pub fn add_pattern(&mut self, pattern: impl Into<String>, action: FilterAction) {
166        self.rules.push(CompiledPatternRule {
167            pattern_str: pattern.into(),
168            case_sensitive: true,
169            action,
170            category: None,
171            description: None,
172        });
173    }
174
175    /// Compile a pattern string to regex
176    fn compile_pattern(pattern: &str, case_sensitive: bool) -> Option<regex::Regex> {
177        let pattern_str = if case_sensitive {
178            pattern.to_string()
179        } else {
180            format!("(?i){}", pattern)
181        };
182        
183        regex::Regex::new(&pattern_str).ok()
184    }
185}
186
187impl ContentFilter for PatternFilter {
188    fn check(&self, content: &str) -> Vec<Violation> {
189        let mut violations = Vec::new();
190
191        for rule in &self.rules {
192            if let Some(re) = Self::compile_pattern(&rule.pattern_str, rule.case_sensitive) {
193                if let Some(m) = re.find(content) {
194                    violations.push(Violation {
195                        violation_type: ViolationType::Pattern,
196                        pattern: rule.pattern_str.clone(),
197                        action: rule.action,
198                        category: rule.category.clone(),
199                        description: rule.description.clone(),
200                        matched_text: Some(m.as_str().to_string()),
201                    });
202                }
203            }
204        }
205
206        violations
207    }
208
209    fn sanitize(&self, content: &str, replacement: &str) -> String {
210        let mut result = content.to_string();
211
212        for rule in &self.rules {
213            if matches!(rule.action, FilterAction::Sanitize | FilterAction::Block) {
214                if let Some(re) = Self::compile_pattern(&rule.pattern_str, rule.case_sensitive) {
215                    result = re.replace_all(&result, replacement).to_string();
216                }
217            }
218        }
219
220        result
221    }
222}