Skip to main content

bitrouter_guardrails/
config.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5use crate::pattern::PatternId;
6use crate::rule::Action;
7
/// Home of the bitrouter project.
///
/// [`GuardrailConfig::format_block_message`] appends this URL to a block
/// message when `block_message.include_help_link` is `true`.
pub const REPO_URL: &str = "https://github.com/bitrouter/bitrouter";
11
12/// Top-level guardrail configuration, embedded in the bitrouter config under
13/// the `guardrails` key.
14///
15/// ```yaml
16/// guardrails:
17///   enabled: true
18///   disabled_patterns:
19///     - ip_addresses
20///     - pii_phone_numbers
21///   custom_patterns:
22///     - name: my_token
23///       regex: "myapp_[A-Za-z0-9]{32}"
24///       direction: upgoing
25///   upgoing:
26///     api_keys: redact
27///     private_keys: block
28///   downgoing:
29///     suspicious_commands: block
30///   block_message:
31///     include_details: true
32///     include_help_link: true
33/// ```
34///
35/// Any pattern not explicitly listed uses the default action for its
36/// direction (`Warn` for both upgoing and downgoing).
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct GuardrailConfig {
39    /// Master switch. When `false` the guardrail engine is a no-op.
40    #[serde(default = "default_enabled")]
41    pub enabled: bool,
42
43    /// Built-in patterns to disable. Any [`PatternId`] listed here will be
44    /// skipped during inspection, regardless of the action configured for it.
45    #[serde(default)]
46    pub disabled_patterns: Vec<PatternId>,
47
48    /// User-defined custom patterns appended to the built-in set.
49    #[serde(default)]
50    pub custom_patterns: Vec<CustomPatternDef>,
51
52    /// Per-pattern action overrides for **outbound** traffic (user → LLM).
53    #[serde(default)]
54    pub upgoing: HashMap<PatternId, Action>,
55
56    /// Per-pattern action overrides for **inbound** traffic (LLM → user).
57    #[serde(default)]
58    pub downgoing: HashMap<PatternId, Action>,
59
60    /// Per-custom-pattern action overrides for **outbound** traffic.
61    #[serde(default)]
62    pub custom_upgoing: HashMap<String, Action>,
63
64    /// Per-custom-pattern action overrides for **inbound** traffic.
65    #[serde(default)]
66    pub custom_downgoing: HashMap<String, Action>,
67
68    /// Controls the content of error messages produced when content is blocked.
69    #[serde(default)]
70    pub block_message: BlockMessageConfig,
71}
72
73/// Configuration for a user-defined custom pattern.
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct CustomPatternDef {
76    /// Unique name for this custom pattern (used to reference it in
77    /// `custom_upgoing` / `custom_downgoing` action maps).
78    pub name: String,
79
80    /// The regex pattern string.
81    pub regex: String,
82
83    /// Whether the pattern applies to outbound (`upgoing`), inbound
84    /// (`downgoing`), or `both` directions.
85    #[serde(default)]
86    pub direction: PatternDirection,
87}
88
89/// Direction for a custom pattern.
90#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
91#[serde(rename_all = "snake_case")]
92pub enum PatternDirection {
93    /// Apply only to outbound (user → LLM) traffic.
94    #[default]
95    Upgoing,
96    /// Apply only to inbound (LLM → user) traffic.
97    Downgoing,
98    /// Apply in both directions.
99    Both,
100}
101
102/// Controls what information is included in block error messages.
103#[derive(Debug, Clone, Serialize, Deserialize)]
104pub struct BlockMessageConfig {
105    /// Include a human-readable description of why the content was blocked.
106    #[serde(default = "default_true")]
107    pub include_details: bool,
108
109    /// Include a link to the bitrouter repository for further information.
110    #[serde(default = "default_true")]
111    pub include_help_link: bool,
112}
113
114impl Default for BlockMessageConfig {
115    fn default() -> Self {
116        Self {
117            include_details: true,
118            include_help_link: true,
119        }
120    }
121}
122
/// Serde default for `GuardrailConfig::enabled`: guardrails are on unless the
/// configuration says otherwise.
fn default_enabled() -> bool {
    true
}
126
/// Serde default for boolean options that are on unless explicitly switched
/// off (used by the `BlockMessageConfig` fields).
fn default_true() -> bool {
    true
}
130
131impl Default for GuardrailConfig {
132    fn default() -> Self {
133        Self {
134            enabled: true,
135            disabled_patterns: Vec::new(),
136            custom_patterns: Vec::new(),
137            upgoing: HashMap::new(),
138            downgoing: HashMap::new(),
139            custom_upgoing: HashMap::new(),
140            custom_downgoing: HashMap::new(),
141            block_message: BlockMessageConfig::default(),
142        }
143    }
144}
145
146impl GuardrailConfig {
147    /// Returns `true` if the given built-in pattern has been disabled by the
148    /// user.
149    pub fn is_pattern_disabled(&self, id: PatternId) -> bool {
150        self.disabled_patterns.contains(&id)
151    }
152
153    /// Resolve the effective action for an upgoing pattern.
154    ///
155    /// Returns the user-configured action if present, otherwise `Warn`.
156    pub fn upgoing_action(&self, id: PatternId) -> Action {
157        self.upgoing.get(&id).copied().unwrap_or(Action::Warn)
158    }
159
160    /// Resolve the effective action for a downgoing pattern.
161    ///
162    /// Returns the user-configured action if present, otherwise `Warn`.
163    pub fn downgoing_action(&self, id: PatternId) -> Action {
164        self.downgoing.get(&id).copied().unwrap_or(Action::Warn)
165    }
166
167    /// Resolve the effective action for a custom upgoing pattern.
168    pub fn custom_upgoing_action(&self, name: &str) -> Action {
169        self.custom_upgoing
170            .get(name)
171            .copied()
172            .unwrap_or(Action::Warn)
173    }
174
175    /// Resolve the effective action for a custom downgoing pattern.
176    pub fn custom_downgoing_action(&self, name: &str) -> Action {
177        self.custom_downgoing
178            .get(name)
179            .copied()
180            .unwrap_or(Action::Warn)
181    }
182
183    /// Format a block error message, respecting `block_message` config.
184    pub fn format_block_message(&self, direction: &str, description: &str) -> String {
185        let mut msg = format!("guardrail blocked {direction} content");
186
187        if self.block_message.include_details {
188            msg.push_str(&format!(": {description}"));
189        }
190
191        if self.block_message.include_help_link {
192            msg.push_str(&format!(". For more information, see {REPO_URL}"));
193        }
194
195        msg
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a YAML snippet into a [`GuardrailConfig`], panicking (and
    /// so failing the calling test) when the snippet is rejected.
    #[track_caller]
    fn parse(yaml: &str) -> GuardrailConfig {
        serde_saphyr::from_str(yaml).unwrap()
    }

    /// Default config with only the block-message section replaced.
    fn with_block_message(block_message: BlockMessageConfig) -> GuardrailConfig {
        GuardrailConfig {
            block_message,
            ..GuardrailConfig::default()
        }
    }

    #[test]
    fn default_config_is_enabled_with_warn() {
        let defaults = GuardrailConfig::default();

        assert!(defaults.enabled);
        assert_eq!(defaults.upgoing_action(PatternId::ApiKeys), Action::Warn);
        assert_eq!(
            defaults.downgoing_action(PatternId::SuspiciousCommands),
            Action::Warn
        );
        assert!(defaults.disabled_patterns.is_empty());
        assert!(defaults.custom_patterns.is_empty());
        assert!(defaults.block_message.include_details);
        assert!(defaults.block_message.include_help_link);
    }

    #[test]
    fn override_action_takes_precedence() {
        let overridden = GuardrailConfig {
            upgoing: HashMap::from([
                (PatternId::ApiKeys, Action::Redact),
                (PatternId::PrivateKeys, Action::Block),
            ]),
            ..GuardrailConfig::default()
        };

        assert_eq!(
            overridden.upgoing_action(PatternId::ApiKeys),
            Action::Redact
        );
        assert_eq!(
            overridden.upgoing_action(PatternId::PrivateKeys),
            Action::Block
        );
        // A pattern without an override keeps resolving to Warn.
        assert_eq!(
            overridden.upgoing_action(PatternId::Credentials),
            Action::Warn
        );
    }

    #[test]
    fn config_round_trips_through_yaml() {
        let yaml = r#"
enabled: true
upgoing:
  api_keys: redact
  private_keys: block
downgoing:
  suspicious_commands: block
"#;
        let parsed = parse(yaml);
        assert!(parsed.enabled);
        assert_eq!(parsed.upgoing_action(PatternId::ApiKeys), Action::Redact);
        assert_eq!(parsed.upgoing_action(PatternId::PrivateKeys), Action::Block);
        assert_eq!(
            parsed.downgoing_action(PatternId::SuspiciousCommands),
            Action::Block
        );

        // Serialize the parsed value and read it back in.
        let serialized = serde_saphyr::to_string(&parsed).unwrap();
        let reparsed = parse(&serialized);
        assert_eq!(reparsed.upgoing_action(PatternId::ApiKeys), Action::Redact);
    }

    #[test]
    fn empty_yaml_deserializes_to_defaults() {
        let parsed = parse("{}");

        assert!(parsed.enabled);
        assert!(parsed.upgoing.is_empty());
        assert!(parsed.downgoing.is_empty());
        assert!(parsed.disabled_patterns.is_empty());
        assert!(parsed.custom_patterns.is_empty());
    }

    #[test]
    fn disabled_patterns_from_yaml() {
        let yaml = r#"
disabled_patterns:
  - ip_addresses
  - pii_phone_numbers
"#;
        let parsed = parse(yaml);

        assert!(parsed.is_pattern_disabled(PatternId::IpAddresses));
        assert!(parsed.is_pattern_disabled(PatternId::PiiPhoneNumbers));
        assert!(!parsed.is_pattern_disabled(PatternId::ApiKeys));
    }

    #[test]
    fn custom_patterns_from_yaml() {
        let yaml = r#"
custom_patterns:
  - name: my_token
    regex: "myapp_[A-Za-z0-9]{32}"
    direction: upgoing
  - name: bad_url
    regex: "https://evil\\.com"
    direction: downgoing
  - name: both_dirs
    regex: "secret_value"
    direction: both
"#;
        let parsed = parse(yaml);
        let patterns = &parsed.custom_patterns;

        assert_eq!(patterns.len(), 3);
        assert_eq!(patterns[0].name, "my_token");

        let directions: Vec<PatternDirection> =
            patterns.iter().map(|pattern| pattern.direction).collect();
        assert_eq!(
            directions,
            [
                PatternDirection::Upgoing,
                PatternDirection::Downgoing,
                PatternDirection::Both,
            ]
        );
    }

    #[test]
    fn custom_pattern_action_overrides() {
        let yaml = r#"
custom_patterns:
  - name: my_token
    regex: "myapp_[A-Za-z0-9]{32}"
custom_upgoing:
  my_token: block
"#;
        let parsed = parse(yaml);

        assert_eq!(parsed.custom_upgoing_action("my_token"), Action::Block);
        assert_eq!(parsed.custom_upgoing_action("nonexistent"), Action::Warn);
    }

    #[test]
    fn block_message_config_from_yaml() {
        let yaml = r#"
block_message:
  include_details: false
  include_help_link: false
"#;
        let parsed = parse(yaml);

        assert!(!parsed.block_message.include_details);
        assert!(!parsed.block_message.include_help_link);
    }

    #[test]
    fn format_block_message_full() {
        let text =
            GuardrailConfig::default().format_block_message("upgoing", "API keys detected");

        assert!(text.contains("API keys detected"));
        assert!(text.contains(REPO_URL));
    }

    #[test]
    fn format_block_message_no_details() {
        let config = with_block_message(BlockMessageConfig {
            include_details: false,
            ..BlockMessageConfig::default()
        });
        let text = config.format_block_message("upgoing", "API keys detected");

        assert!(!text.contains("API keys detected"));
        assert!(text.contains(REPO_URL));
    }

    #[test]
    fn format_block_message_no_link() {
        let config = with_block_message(BlockMessageConfig {
            include_help_link: false,
            ..BlockMessageConfig::default()
        });
        let text = config.format_block_message("upgoing", "API keys detected");

        assert!(text.contains("API keys detected"));
        assert!(!text.contains(REPO_URL));
    }

    #[test]
    fn format_block_message_bare() {
        let config = with_block_message(BlockMessageConfig {
            include_details: false,
            include_help_link: false,
        });
        let text = config.format_block_message("upgoing", "API keys detected");

        assert_eq!(text, "guardrail blocked upgoing content");
    }
}