heartbit_core/agent/
guardrail.rs1#![allow(missing_docs)]
4use std::future::Future;
5use std::pin::Pin;
6
7use crate::error::Error;
8use crate::llm::types::{CompletionRequest, CompletionResponse, ToolCall};
9use crate::tool::ToolOutput;
10
/// Verdict produced by a guardrail check.
///
/// Within this file it is the success value of [`Guardrail::post_llm`] and
/// [`Guardrail::pre_tool`]. How each verdict is enforced is decided by the
/// caller of those hooks, which is not part of this file.
///
/// `Eq` is derived alongside `PartialEq` because every payload is a `String`,
/// so equality is total and the type can be used wherever an `Eq` bound is
/// required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardAction {
    /// No objection; this is what the default hook implementations return.
    Allow,
    /// The check rejected the operation. Counted by [`GuardAction::is_denied`].
    Deny {
        /// Human-readable explanation for the rejection.
        reason: String,
    },
    /// The check flagged something but did not reject it: neither
    /// [`GuardAction::is_denied`] nor [`GuardAction::is_killed`] reports it.
    Warn {
        /// Human-readable description of what was flagged.
        reason: String,
    },
    /// The strongest verdict: reported by both [`GuardAction::is_denied`] and
    /// [`GuardAction::is_killed`].
    // NOTE(review): presumably the caller aborts the whole run on `Kill`
    // rather than just the single operation — confirm against the agent loop.
    Kill {
        /// Human-readable explanation for the kill.
        reason: String,
    },
}
30
31impl GuardAction {
32 pub fn deny(reason: impl Into<String>) -> Self {
34 GuardAction::Deny {
35 reason: reason.into(),
36 }
37 }
38
39 pub fn warn(reason: impl Into<String>) -> Self {
41 GuardAction::Warn {
42 reason: reason.into(),
43 }
44 }
45
46 pub fn kill(reason: impl Into<String>) -> Self {
48 GuardAction::Kill {
49 reason: reason.into(),
50 }
51 }
52
53 pub fn is_denied(&self) -> bool {
55 matches!(self, GuardAction::Deny { .. } | GuardAction::Kill { .. })
56 }
57
58 pub fn is_killed(&self) -> bool {
60 matches!(self, GuardAction::Kill { .. })
61 }
62}
63
64pub trait Guardrail: Send + Sync {
109 fn name(&self) -> &str {
112 "unnamed"
113 }
114
115 fn pre_llm(
118 &self,
119 _request: &mut CompletionRequest,
120 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
121 Box::pin(async { Ok(()) })
122 }
123
124 fn post_llm(
139 &self,
140 _response: &mut CompletionResponse,
141 ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
142 Box::pin(async { Ok(GuardAction::Allow) })
143 }
144
145 fn pre_tool(
148 &self,
149 _call: &ToolCall,
150 ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
151 Box::pin(async { Ok(GuardAction::Allow) })
152 }
153
154 fn post_tool(
158 &self,
159 _call: &ToolCall,
160 _output: &mut ToolOutput,
161 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
162 Box::pin(async { Ok(()) })
163 }
164
165 fn set_turn(&self, _turn: usize) {}
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{StopReason, TokenUsage};

    /// Builds a `ToolCall` fixture from its three fields.
    fn tool_call(id: &'static str, name: &'static str, input: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            input,
        }
    }

    #[test]
    fn guard_action_deny_constructor() {
        if let GuardAction::Deny { reason } = GuardAction::deny("PII detected") {
            assert_eq!(reason, "PII detected");
        } else {
            panic!("expected Deny");
        }
    }

    #[test]
    fn guard_action_warn_constructor() {
        if let GuardAction::Warn { reason } = GuardAction::warn("suspicious pattern") {
            assert_eq!(reason, "suspicious pattern");
        } else {
            panic!("expected Warn");
        }
    }

    #[test]
    fn guard_action_is_denied() {
        // (verdict, expected result of `is_denied`)
        let cases = [
            (GuardAction::deny("blocked"), true),
            (GuardAction::kill("critical"), true),
            (GuardAction::Allow, false),
            (GuardAction::warn("suspicious"), false),
        ];
        for (action, expected) in cases.iter() {
            assert_eq!(action.is_denied(), *expected, "{:?}", action);
        }
    }

    #[test]
    fn guard_action_kill_constructor() {
        if let GuardAction::Kill { reason } = GuardAction::kill("CSAM detected") {
            assert_eq!(reason, "CSAM detected");
        } else {
            panic!("expected Kill");
        }
    }

    #[test]
    fn guard_action_is_killed() {
        // (verdict, expected result of `is_killed`)
        let cases = [
            (GuardAction::kill("critical"), true),
            (GuardAction::deny("blocked"), false),
            (GuardAction::Allow, false),
            (GuardAction::warn("suspicious"), false),
        ];
        for (action, expected) in cases.iter() {
            assert_eq!(action.is_killed(), *expected, "{:?}", action);
        }
    }

    #[test]
    fn guardrail_default_name() {
        struct MyGuardrail;
        impl Guardrail for MyGuardrail {}

        assert_eq!(MyGuardrail.name(), "unnamed");
    }

    #[test]
    fn guardrail_custom_name() {
        struct NamedGuardrail;
        impl Guardrail for NamedGuardrail {
            fn name(&self) -> &str {
                "pii_detector"
            }
        }

        assert_eq!(NamedGuardrail.name(), "pii_detector");
    }

    #[tokio::test]
    async fn default_guardrail_allows_everything() {
        // Relies on every default hook; nothing is overridden.
        struct NoOpGuardrail;
        impl Guardrail for NoOpGuardrail {}

        let guard = NoOpGuardrail;

        // pre_llm: succeeds on an otherwise empty request.
        let mut request = CompletionRequest {
            system: "sys".into(),
            messages: vec![],
            tools: vec![],
            max_tokens: 1024,
            tool_choice: None,
            reasoning_effort: None,
        };
        guard.pre_llm(&mut request).await.unwrap();

        // post_llm: the default verdict is Allow.
        let mut response = CompletionResponse {
            content: vec![],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            model: None,
        };
        let llm_verdict = guard.post_llm(&mut response).await.unwrap();
        assert_eq!(llm_verdict, GuardAction::Allow);

        // pre_tool: the default verdict is Allow.
        let call = tool_call("c1", "test", serde_json::json!({}));
        let tool_verdict = guard.pre_tool(&call).await.unwrap();
        assert_eq!(tool_verdict, GuardAction::Allow);

        // post_tool: the output comes back unchanged.
        let mut output = ToolOutput::success("result".to_string());
        guard.post_tool(&call, &mut output).await.unwrap();
        assert_eq!(output.content, "result");
    }

    #[tokio::test]
    async fn post_tool_can_mutate_output() {
        struct RedactGuardrail;
        impl Guardrail for RedactGuardrail {
            fn post_tool(
                &self,
                _call: &ToolCall,
                output: &mut ToolOutput,
            ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
                // The rewrite happens synchronously, before the future is built,
                // so the returned future does not need to hold the `output` borrow.
                let redacted = output.content.replace("secret", "[REDACTED]");
                output.content = redacted;
                Box::pin(async { Ok(()) })
            }
        }

        let guard = RedactGuardrail;
        let call = tool_call("c1", "test", serde_json::json!({}));
        let mut output = ToolOutput::success("the secret is 42".to_string());

        guard.post_tool(&call, &mut output).await.unwrap();

        assert_eq!(output.content, "the [REDACTED] is 42");
    }

    #[tokio::test]
    async fn pre_tool_deny_returns_deny_action() {
        struct BlockBashGuardrail;
        impl Guardrail for BlockBashGuardrail {
            fn pre_tool(
                &self,
                call: &ToolCall,
            ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
                // Decide up front so the future owns a plain bool instead of
                // borrowing from `call`.
                let is_bash = call.name == "bash";
                Box::pin(async move {
                    let verdict = if is_bash {
                        GuardAction::deny("bash tool is disabled")
                    } else {
                        GuardAction::Allow
                    };
                    Ok(verdict)
                })
            }
        }

        let guard = BlockBashGuardrail;

        let bash_call = tool_call("c1", "bash", serde_json::json!({"command": "rm -rf /"}));
        let bash_verdict = guard.pre_tool(&bash_call).await.unwrap();
        assert_eq!(bash_verdict, GuardAction::deny("bash tool is disabled"));

        let read_call = tool_call("c2", "read", serde_json::json!({"path": "/tmp/test.txt"}));
        let read_verdict = guard.pre_tool(&read_call).await.unwrap();
        assert_eq!(read_verdict, GuardAction::Allow);
    }
}