Skip to main content

ai_lib_rust/guardrails/
config.rs

//! Guardrails configuration: filter actions, filter rules, and the configuration builder.
2
3use serde::{Deserialize, Serialize};
4
5/// Action to take when a filter rule matches
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7#[serde(rename_all = "snake_case")]
8pub enum FilterAction {
9    /// Block the content entirely
10    Block,
11    /// Allow but log a warning
12    Warn,
13    /// Log for audit purposes only
14    Log,
15    /// Sanitize (remove/replace) the matched content
16    Sanitize,
17    /// Allow without any action
18    Allow,
19}
20
21impl Default for FilterAction {
22    fn default() -> Self {
23        FilterAction::Warn
24    }
25}
26
27/// A filter rule definition
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct FilterRule {
30    /// The pattern or keyword to match
31    pub pattern: String,
32    /// Whether this is a regex pattern
33    pub is_regex: bool,
34    /// Case-sensitive matching
35    pub case_sensitive: bool,
36    /// Action to take when matched
37    pub action: FilterAction,
38    /// Optional category for grouping
39    pub category: Option<String>,
40    /// Optional description
41    pub description: Option<String>,
42}
43
44impl FilterRule {
45    /// Create a simple keyword rule
46    pub fn keyword(pattern: impl Into<String>, action: FilterAction) -> Self {
47        Self {
48            pattern: pattern.into(),
49            is_regex: false,
50            case_sensitive: false,
51            action,
52            category: None,
53            description: None,
54        }
55    }
56
57    /// Create a regex pattern rule
58    pub fn regex(pattern: impl Into<String>, action: FilterAction) -> Self {
59        Self {
60            pattern: pattern.into(),
61            is_regex: true,
62            case_sensitive: true,
63            action,
64            category: None,
65            description: None,
66        }
67    }
68
69    /// Set case sensitivity
70    pub fn case_sensitive(mut self, sensitive: bool) -> Self {
71        self.case_sensitive = sensitive;
72        self
73    }
74
75    /// Set category
76    pub fn with_category(mut self, category: impl Into<String>) -> Self {
77        self.category = Some(category.into());
78        self
79    }
80
81    /// Set description
82    pub fn with_description(mut self, description: impl Into<String>) -> Self {
83        self.description = Some(description.into());
84        self
85    }
86}
87
88/// Configuration for the Guardrails system
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct GuardrailsConfig {
91    /// Whether to filter input content
92    pub filter_input: bool,
93    /// Whether to filter output content
94    pub filter_output: bool,
95    /// Keyword-based filter rules
96    pub keyword_rules: Vec<FilterRule>,
97    /// Regex pattern-based filter rules
98    pub pattern_rules: Vec<FilterRule>,
99    /// Enable PII (Personally Identifiable Information) detection
100    pub enable_pii_detection: bool,
101    /// Check PII in input
102    pub check_pii_input: bool,
103    /// Check PII in output
104    pub check_pii_output: bool,
105    /// Replacement string for sanitization
106    pub sanitize_replacement: String,
107    /// Replacement string for PII
108    pub pii_replacement: String,
109    /// Stop checking on first block
110    pub stop_on_first_block: bool,
111}
112
113impl GuardrailsConfig {
114    /// Create a builder for GuardrailsConfig
115    pub fn builder() -> GuardrailsConfigBuilder {
116        GuardrailsConfigBuilder::default()
117    }
118
119    /// Create a permissive configuration (no filtering)
120    pub fn permissive() -> Self {
121        Self {
122            filter_input: false,
123            filter_output: false,
124            keyword_rules: Vec::new(),
125            pattern_rules: Vec::new(),
126            enable_pii_detection: false,
127            check_pii_input: false,
128            check_pii_output: false,
129            sanitize_replacement: "[FILTERED]".to_string(),
130            pii_replacement: "[PII]".to_string(),
131            stop_on_first_block: false,
132        }
133    }
134
135    /// Create a strict configuration with common safety rules
136    pub fn strict() -> Self {
137        let mut config = Self::permissive();
138        config.filter_input = true;
139        config.filter_output = true;
140        config.enable_pii_detection = true;
141        config.check_pii_input = true;
142        config.check_pii_output = true;
143        config.stop_on_first_block = true;
144        
145        // Add common sensitive keyword patterns
146        config.keyword_rules = vec![
147            FilterRule::keyword("password", FilterAction::Warn)
148                .with_category("credentials")
149                .with_description("Password mention"),
150            FilterRule::keyword("api_key", FilterAction::Warn)
151                .with_category("credentials")
152                .with_description("API key mention"),
153            FilterRule::keyword("secret_key", FilterAction::Warn)
154                .with_category("credentials")
155                .with_description("Secret key mention"),
156            FilterRule::keyword("access_token", FilterAction::Warn)
157                .with_category("credentials")
158                .with_description("Access token mention"),
159        ];
160
161        config
162    }
163}
164
165impl Default for GuardrailsConfig {
166    fn default() -> Self {
167        Self::permissive()
168    }
169}
170
171/// Builder for GuardrailsConfig
172#[derive(Debug, Default)]
173pub struct GuardrailsConfigBuilder {
174    filter_input: bool,
175    filter_output: bool,
176    keyword_rules: Vec<FilterRule>,
177    pattern_rules: Vec<FilterRule>,
178    enable_pii_detection: bool,
179    check_pii_input: bool,
180    check_pii_output: bool,
181    sanitize_replacement: Option<String>,
182    pii_replacement: Option<String>,
183    stop_on_first_block: bool,
184}
185
186impl GuardrailsConfigBuilder {
187    /// Enable input filtering
188    pub fn filter_input(mut self, enable: bool) -> Self {
189        self.filter_input = enable;
190        self
191    }
192
193    /// Enable output filtering
194    pub fn filter_output(mut self, enable: bool) -> Self {
195        self.filter_output = enable;
196        self
197    }
198
199    /// Add a keyword filter rule
200    pub fn add_keyword_filter(mut self, keyword: impl Into<String>, action: FilterAction) -> Self {
201        self.filter_input = true; // Auto-enable input filtering
202        self.keyword_rules.push(FilterRule::keyword(keyword, action));
203        self
204    }
205
206    /// Add a regex pattern filter rule
207    pub fn add_pattern_filter(mut self, pattern: impl Into<String>, action: FilterAction) -> Self {
208        self.filter_input = true; // Auto-enable input filtering
209        self.pattern_rules.push(FilterRule::regex(pattern, action));
210        self
211    }
212
213    /// Add a custom filter rule
214    pub fn add_rule(mut self, rule: FilterRule) -> Self {
215        self.filter_input = true;
216        if rule.is_regex {
217            self.pattern_rules.push(rule);
218        } else {
219            self.keyword_rules.push(rule);
220        }
221        self
222    }
223
224    /// Enable PII detection
225    pub fn enable_pii_detection(mut self, enable: bool) -> Self {
226        self.enable_pii_detection = enable;
227        self.check_pii_input = enable;
228        self.check_pii_output = enable;
229        self
230    }
231
232    /// Set the sanitize replacement string
233    pub fn sanitize_replacement(mut self, replacement: String) -> Self {
234        self.sanitize_replacement = Some(replacement);
235        self
236    }
237
238    /// Set the PII replacement string
239    pub fn pii_replacement(mut self, replacement: String) -> Self {
240        self.pii_replacement = Some(replacement);
241        self
242    }
243
244    /// Stop on first block
245    pub fn stop_on_first_block(mut self, stop: bool) -> Self {
246        self.stop_on_first_block = stop;
247        self
248    }
249
250    /// Build the configuration
251    pub fn build(self) -> GuardrailsConfig {
252        GuardrailsConfig {
253            filter_input: self.filter_input,
254            filter_output: self.filter_output,
255            keyword_rules: self.keyword_rules,
256            pattern_rules: self.pattern_rules,
257            enable_pii_detection: self.enable_pii_detection,
258            check_pii_input: self.check_pii_input,
259            check_pii_output: self.check_pii_output,
260            sanitize_replacement: self.sanitize_replacement.unwrap_or_else(|| "[FILTERED]".to_string()),
261            pii_replacement: self.pii_replacement.unwrap_or_else(|| "[PII]".to_string()),
262            stop_on_first_block: self.stop_on_first_block,
263        }
264    }
265}