bitrouter_guardrails/
config.rs1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::pattern::PatternId;
6use crate::rule::Action;
7
/// Project repository URL, appended to block messages when
/// `block_message.include_help_link` is enabled.
pub const REPO_URL: &'static str = "https://github.com/bitrouter/bitrouter";
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct GuardrailConfig {
39 #[serde(default = "default_enabled")]
41 pub enabled: bool,
42
43 #[serde(default)]
46 pub disabled_patterns: Vec<PatternId>,
47
48 #[serde(default)]
50 pub custom_patterns: Vec<CustomPatternDef>,
51
52 #[serde(default)]
54 pub upgoing: HashMap<PatternId, Action>,
55
56 #[serde(default)]
58 pub downgoing: HashMap<PatternId, Action>,
59
60 #[serde(default)]
62 pub custom_upgoing: HashMap<String, Action>,
63
64 #[serde(default)]
66 pub custom_downgoing: HashMap<String, Action>,
67
68 #[serde(default)]
70 pub block_message: BlockMessageConfig,
71}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct CustomPatternDef {
76 pub name: String,
79
80 pub regex: String,
82
83 #[serde(default)]
86 pub direction: PatternDirection,
87}
88
89#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
91#[serde(rename_all = "snake_case")]
92pub enum PatternDirection {
93 #[default]
95 Upgoing,
96 Downgoing,
98 Both,
100}
101
102#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct BlockMessageConfig {
105 #[serde(default = "default_true")]
107 pub include_details: bool,
108
109 #[serde(default = "default_true")]
111 pub include_help_link: bool,
112}
113
114impl Default for BlockMessageConfig {
115 fn default() -> Self {
116 Self {
117 include_details: true,
118 include_help_link: true,
119 }
120 }
121}
122
/// Serde default for [`GuardrailConfig::enabled`]: on unless explicitly disabled.
fn default_enabled() -> bool {
    default_true()
}

/// Serde default for opt-out boolean fields such as those in [`BlockMessageConfig`].
fn default_true() -> bool {
    true
}
130
131impl Default for GuardrailConfig {
132 fn default() -> Self {
133 Self {
134 enabled: true,
135 disabled_patterns: Vec::new(),
136 custom_patterns: Vec::new(),
137 upgoing: HashMap::new(),
138 downgoing: HashMap::new(),
139 custom_upgoing: HashMap::new(),
140 custom_downgoing: HashMap::new(),
141 block_message: BlockMessageConfig::default(),
142 }
143 }
144}
145
146impl GuardrailConfig {
147 pub fn is_pattern_disabled(&self, id: PatternId) -> bool {
150 self.disabled_patterns.contains(&id)
151 }
152
153 pub fn upgoing_action(&self, id: PatternId) -> Action {
157 self.upgoing.get(&id).copied().unwrap_or(Action::Warn)
158 }
159
160 pub fn downgoing_action(&self, id: PatternId) -> Action {
164 self.downgoing.get(&id).copied().unwrap_or(Action::Warn)
165 }
166
167 pub fn custom_upgoing_action(&self, name: &str) -> Action {
169 self.custom_upgoing
170 .get(name)
171 .copied()
172 .unwrap_or(Action::Warn)
173 }
174
175 pub fn custom_downgoing_action(&self, name: &str) -> Action {
177 self.custom_downgoing
178 .get(name)
179 .copied()
180 .unwrap_or(Action::Warn)
181 }
182
183 pub fn format_block_message(&self, direction: &str, description: &str) -> String {
185 let mut msg = format!("guardrail blocked {direction} content");
186
187 if self.block_message.include_details {
188 msg.push_str(&format!(": {description}"));
189 }
190
191 if self.block_message.include_help_link {
192 msg.push_str(&format!(". For more information, see {REPO_URL}"));
193 }
194
195 msg
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_is_enabled_with_warn() {
        let config = GuardrailConfig::default();
        assert!(config.enabled);
        assert_eq!(config.upgoing_action(PatternId::ApiKeys), Action::Warn);
        assert_eq!(
            config.downgoing_action(PatternId::SuspiciousCommands),
            Action::Warn
        );
        assert_eq!(config.custom_upgoing_action("anything"), Action::Warn);
        assert_eq!(config.custom_downgoing_action("anything"), Action::Warn);
        assert!(config.disabled_patterns.is_empty());
        assert!(config.custom_patterns.is_empty());
        assert!(config.block_message.include_details);
        assert!(config.block_message.include_help_link);
    }

    #[test]
    fn override_action_takes_precedence() {
        let mut config = GuardrailConfig::default();
        config.upgoing.insert(PatternId::ApiKeys, Action::Redact);
        config.upgoing.insert(PatternId::PrivateKeys, Action::Block);
        config
            .downgoing
            .insert(PatternId::SuspiciousCommands, Action::Block);
        assert_eq!(config.upgoing_action(PatternId::ApiKeys), Action::Redact);
        assert_eq!(config.upgoing_action(PatternId::PrivateKeys), Action::Block);
        assert_eq!(config.upgoing_action(PatternId::Credentials), Action::Warn);
        assert_eq!(
            config.downgoing_action(PatternId::SuspiciousCommands),
            Action::Block
        );
        // Overrides for one direction must not leak into the other.
        assert_eq!(config.downgoing_action(PatternId::ApiKeys), Action::Warn);
        assert_eq!(
            config.upgoing_action(PatternId::SuspiciousCommands),
            Action::Warn
        );
    }

    #[test]
    fn config_round_trips_through_yaml() {
        let yaml = r#"
enabled: true
upgoing:
  api_keys: redact
  private_keys: block
downgoing:
  suspicious_commands: block
"#;
        let config: GuardrailConfig = serde_saphyr::from_str(yaml).unwrap();
        assert!(config.enabled);
        assert_eq!(config.upgoing_action(PatternId::ApiKeys), Action::Redact);
        assert_eq!(config.upgoing_action(PatternId::PrivateKeys), Action::Block);
        assert_eq!(
            config.downgoing_action(PatternId::SuspiciousCommands),
            Action::Block
        );

        let serialized = serde_saphyr::to_string(&config).unwrap();
        let reparsed: GuardrailConfig = serde_saphyr::from_str(&serialized).unwrap();
        assert_eq!(reparsed.enabled, config.enabled);
        assert_eq!(reparsed.upgoing, config.upgoing);
        assert_eq!(reparsed.downgoing, config.downgoing);
        assert_eq!(reparsed.upgoing_action(PatternId::ApiKeys), Action::Redact);
    }

    #[test]
    fn empty_yaml_deserializes_to_defaults() {
        let config: GuardrailConfig = serde_saphyr::from_str("{}").unwrap();
        assert!(config.enabled);
        assert!(config.upgoing.is_empty());
        assert!(config.downgoing.is_empty());
        assert!(config.custom_upgoing.is_empty());
        assert!(config.custom_downgoing.is_empty());
        assert!(config.disabled_patterns.is_empty());
        assert!(config.custom_patterns.is_empty());
        assert!(config.block_message.include_details);
        assert!(config.block_message.include_help_link);
    }

    #[test]
    fn disabled_patterns_from_yaml() {
        let yaml = r#"
disabled_patterns:
  - ip_addresses
  - pii_phone_numbers
"#;
        let config: GuardrailConfig = serde_saphyr::from_str(yaml).unwrap();
        assert!(config.is_pattern_disabled(PatternId::IpAddresses));
        assert!(config.is_pattern_disabled(PatternId::PiiPhoneNumbers));
        assert!(!config.is_pattern_disabled(PatternId::ApiKeys));
    }

    #[test]
    fn custom_patterns_from_yaml() {
        let yaml = r#"
custom_patterns:
  - name: my_token
    regex: "myapp_[A-Za-z0-9]{32}"
    direction: upgoing
  - name: bad_url
    regex: "https://evil\\.com"
    direction: downgoing
  - name: both_dirs
    regex: "secret_value"
    direction: both
"#;
        let config: GuardrailConfig = serde_saphyr::from_str(yaml).unwrap();
        assert_eq!(config.custom_patterns.len(), 3);
        assert_eq!(config.custom_patterns[0].name, "my_token");
        assert_eq!(config.custom_patterns[0].regex, "myapp_[A-Za-z0-9]{32}");
        assert_eq!(
            config.custom_patterns[0].direction,
            PatternDirection::Upgoing
        );
        assert_eq!(
            config.custom_patterns[1].direction,
            PatternDirection::Downgoing
        );
        assert_eq!(config.custom_patterns[2].direction, PatternDirection::Both);
    }

    #[test]
    fn custom_pattern_action_overrides() {
        let yaml = r#"
custom_patterns:
  - name: my_token
    regex: "myapp_[A-Za-z0-9]{32}"
custom_upgoing:
  my_token: block
"#;
        let config: GuardrailConfig = serde_saphyr::from_str(yaml).unwrap();
        // `direction` was omitted, so the default applies.
        assert_eq!(
            config.custom_patterns[0].direction,
            PatternDirection::Upgoing
        );
        assert_eq!(config.custom_upgoing_action("my_token"), Action::Block);
        assert_eq!(config.custom_upgoing_action("nonexistent"), Action::Warn);
        assert_eq!(config.custom_downgoing_action("my_token"), Action::Warn);
    }

    #[test]
    fn block_message_config_from_yaml() {
        let yaml = r#"
block_message:
  include_details: false
  include_help_link: false
"#;
        let config: GuardrailConfig = serde_saphyr::from_str(yaml).unwrap();
        assert!(!config.block_message.include_details);
        assert!(!config.block_message.include_help_link);
    }

    #[test]
    fn block_message_partial_yaml_keeps_other_default() {
        let yaml = r#"
block_message:
  include_details: false
"#;
        let config: GuardrailConfig = serde_saphyr::from_str(yaml).unwrap();
        assert!(!config.block_message.include_details);
        assert!(config.block_message.include_help_link);
    }

    #[test]
    fn format_block_message_full() {
        let config = GuardrailConfig::default();
        let msg = config.format_block_message("upgoing", "API keys detected");
        assert_eq!(
            msg,
            format!(
                "guardrail blocked upgoing content: API keys detected. For more information, see {REPO_URL}"
            )
        );
    }

    #[test]
    fn format_block_message_no_details() {
        let mut config = GuardrailConfig::default();
        config.block_message.include_details = false;
        let msg = config.format_block_message("upgoing", "API keys detected");
        assert_eq!(
            msg,
            format!("guardrail blocked upgoing content. For more information, see {REPO_URL}")
        );
    }

    #[test]
    fn format_block_message_no_link() {
        let mut config = GuardrailConfig::default();
        config.block_message.include_help_link = false;
        let msg = config.format_block_message("upgoing", "API keys detected");
        assert_eq!(msg, "guardrail blocked upgoing content: API keys detected");
    }

    #[test]
    fn format_block_message_bare() {
        let mut config = GuardrailConfig::default();
        config.block_message.include_details = false;
        config.block_message.include_help_link = false;
        let msg = config.format_block_message("upgoing", "API keys detected");
        assert_eq!(msg, "guardrail blocked upgoing content");
    }
}