Skip to main content

heartbit_core/agent/
guardrail.rs

1//! `Guardrail` trait and `GuardAction` — base types for all guardrail implementations.
2
3#![allow(missing_docs)]
4use std::future::Future;
5use std::pin::Pin;
6
7use crate::error::Error;
8use crate::llm::types::{CompletionRequest, CompletionResponse, ToolCall};
9use crate::tool::ToolOutput;
10
/// Action returned by guardrail hooks that can deny operations.
///
/// Use [`GuardAction::is_denied`] to test for a blocking verdict (`Deny` or
/// `Kill`) and [`GuardAction::is_killed`] to test for run termination.
///
/// Every payload is an owned `String`, so equality is total and the type
/// derives `Eq` in addition to `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardAction {
    /// Allow the operation to proceed.
    Allow,
    /// Deny the operation with a reason.
    Deny { reason: String },
    /// Log the concern but allow the operation to proceed.
    ///
    /// The agent loop treats `Warn` like `Allow` but emits
    /// `AgentEvent::GuardrailWarned` and an audit record. This enables
    /// monitoring mode (shadow enforcement) without blocking production.
    Warn { reason: String },
    /// Immediately terminate the agent run. Used for critical detections
    /// (e.g., CSAM, active exploitation) where the agent loop must stop
    /// without further processing. The agent emits `KillSwitchActivated`
    /// and returns `Error::KillSwitch`.
    Kill { reason: String },
}
30
31impl GuardAction {
32    /// Create a `Deny` action with the given reason.
33    pub fn deny(reason: impl Into<String>) -> Self {
34        GuardAction::Deny {
35            reason: reason.into(),
36        }
37    }
38
39    /// Create a `Warn` action with the given reason.
40    pub fn warn(reason: impl Into<String>) -> Self {
41        GuardAction::Warn {
42            reason: reason.into(),
43        }
44    }
45
46    /// Create a `Kill` action with the given reason.
47    pub fn kill(reason: impl Into<String>) -> Self {
48        GuardAction::Kill {
49            reason: reason.into(),
50        }
51    }
52
53    /// Returns `true` if this action blocks the operation (`Deny` or `Kill`).
54    pub fn is_denied(&self) -> bool {
55        matches!(self, GuardAction::Deny { .. } | GuardAction::Kill { .. })
56    }
57
58    /// Returns `true` if this action terminates the agent run (`Kill`).
59    pub fn is_killed(&self) -> bool {
60        matches!(self, GuardAction::Kill { .. })
61    }
62}
63
64/// Interceptor hooks for LLM calls and tool executions.
65///
66/// All methods have default no-op implementations so guardrails only need to
67/// override the hooks they care about. Methods use `Pin<Box<dyn Future>>` for
68/// dyn-compatibility (same pattern as the `Tool` trait).
69///
70/// Multiple guardrails are registered via `Vec<Arc<dyn Guardrail>>` — first
71/// `Deny` wins.
72///
73/// # Example
74///
75/// A trivial guardrail that denies any LLM response containing a forbidden
76/// substring:
77///
78/// ```rust
79/// use std::future::Future;
80/// use std::pin::Pin;
81/// use heartbit_core::{GuardAction, Guardrail};
82/// use heartbit_core::llm::types::CompletionResponse;
83///
84/// struct NoSecrets;
85///
86/// impl Guardrail for NoSecrets {
87///     fn name(&self) -> &str { "no-secrets" }
88///
89///     fn post_llm(
90///         &self,
91///         response: &mut CompletionResponse,
92///     ) -> Pin<Box<dyn Future<Output = Result<GuardAction, heartbit_core::Error>> + Send + '_>> {
93///         let leaked = response
94///             .content
95///             .iter()
96///             .any(|block| matches!(block, heartbit_core::llm::types::ContentBlock::Text { text }
97///                 if text.contains("sk-")));
98///         Box::pin(async move {
99///             Ok(if leaked {
100///                 GuardAction::deny("response contained an API key prefix")
101///             } else {
102///                 GuardAction::Allow
103///             })
104///         })
105///     }
106/// }
107/// ```
108pub trait Guardrail: Send + Sync {
109    /// Human-readable name for this guardrail, used in events and audit.
110    /// Override to attribute which guardrail fired in logs.
111    fn name(&self) -> &str {
112        "unnamed"
113    }
114
115    /// Called before each LLM call. Can mutate the request (e.g., inject safety
116    /// instructions, redact content). `Err` aborts the run.
117    fn pre_llm(
118        &self,
119        _request: &mut CompletionRequest,
120    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
121        Box::pin(async { Ok(()) })
122    }
123
124    /// Called after each LLM response. Can inspect *and mutate* the response
125    /// (e.g. redact PII in `ContentBlock::Text` blocks before it reaches the
126    /// caller, audit log, or downstream tools).
127    ///
128    /// `Deny` discards the response and injects the denial reason as a user
129    /// message (consumes a turn). `Warn` lets the (possibly mutated) response
130    /// flow through but raises an audit signal. `Err` aborts the run.
131    ///
132    /// **Mutations must run synchronously inside this method body** — the
133    /// returned future's lifetime is tied to `&self`, not to `response`, so
134    /// it cannot capture `&mut response`. Apply any changes to
135    /// `response.content` before constructing the `Box::pin(async move { … })`.
136    /// This is also what lets `GuardrailChain` pipe each guardrail's mutations
137    /// through before any future is awaited.
138    fn post_llm(
139        &self,
140        _response: &mut CompletionResponse,
141    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
142        Box::pin(async { Ok(GuardAction::Allow) })
143    }
144
145    /// Called before each tool execution. Can deny individual tool calls.
146    /// `Deny` returns a `ToolResult::error` for that call. `Err` aborts the run.
147    fn pre_tool(
148        &self,
149        _call: &ToolCall,
150    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
151        Box::pin(async { Ok(GuardAction::Allow) })
152    }
153
154    /// Called after each tool execution (after truncation). Can mutate the
155    /// output (e.g., redact sensitive data). `Err` converts to a tool error
156    /// (consistent with tool execution errors — the agent loop continues).
157    fn post_tool(
158        &self,
159        _call: &ToolCall,
160        _output: &mut ToolOutput,
161    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
162        Box::pin(async { Ok(()) })
163    }
164
165    /// Called by the agent loop before each guardrail evaluation to provide
166    /// the current turn number. Stateful guardrails (e.g., `BehavioralMonitorGuardrail`)
167    /// can override this to track turn context.
168    fn set_turn(&self, _turn: usize) {}
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture helper: builds a `ToolCall` from literal id/name and a JSON input.
    fn tool_call(id: &'static str, name: &'static str, input: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: name.into(),
            input,
        }
    }

    #[test]
    fn guard_action_deny_constructor() {
        assert_eq!(
            GuardAction::deny("PII detected"),
            GuardAction::Deny {
                reason: "PII detected".to_string()
            }
        );
    }

    #[test]
    fn guard_action_warn_constructor() {
        assert_eq!(
            GuardAction::warn("suspicious pattern"),
            GuardAction::Warn {
                reason: "suspicious pattern".to_string()
            }
        );
    }

    #[test]
    fn guard_action_is_denied() {
        // (action, expected is_denied())
        let table = [
            (GuardAction::deny("blocked"), true),
            (GuardAction::kill("critical"), true),
            (GuardAction::Allow, false),
            (GuardAction::warn("suspicious"), false),
        ];
        for (action, expected) in table {
            assert_eq!(action.is_denied(), expected, "{:?}", action);
        }
    }

    #[test]
    fn guard_action_kill_constructor() {
        assert_eq!(
            GuardAction::kill("CSAM detected"),
            GuardAction::Kill {
                reason: "CSAM detected".to_string()
            }
        );
    }

    #[test]
    fn guard_action_is_killed() {
        // (action, expected is_killed())
        let table = [
            (GuardAction::kill("critical"), true),
            (GuardAction::deny("blocked"), false),
            (GuardAction::Allow, false),
            (GuardAction::warn("suspicious"), false),
        ];
        for (action, expected) in table {
            assert_eq!(action.is_killed(), expected, "{:?}", action);
        }
    }

    #[test]
    fn guardrail_default_name() {
        struct Anonymous;
        impl Guardrail for Anonymous {}
        assert_eq!(Anonymous.name(), "unnamed");
    }

    #[test]
    fn guardrail_custom_name() {
        struct Labelled;
        impl Guardrail for Labelled {
            fn name(&self) -> &str {
                "pii_detector"
            }
        }
        assert_eq!(Labelled.name(), "pii_detector");
    }

    #[tokio::test]
    async fn default_guardrail_allows_everything() {
        struct Passthrough;
        impl Guardrail for Passthrough {}
        let guard = Passthrough;

        // pre_llm: the default hook succeeds without touching the request.
        let mut req = CompletionRequest {
            system: "sys".into(),
            messages: vec![],
            tools: vec![],
            max_tokens: 1024,
            tool_choice: None,
            reasoning_effort: None,
        };
        guard.pre_llm(&mut req).await.unwrap();

        // post_llm: the default verdict is Allow.
        let mut resp = CompletionResponse {
            content: vec![],
            stop_reason: crate::llm::types::StopReason::EndTurn,
            usage: crate::llm::types::TokenUsage::default(),
            model: None,
        };
        assert_eq!(guard.post_llm(&mut resp).await.unwrap(), GuardAction::Allow);

        // pre_tool: the default verdict is Allow.
        let call = tool_call("c1", "test", serde_json::json!({}));
        assert_eq!(guard.pre_tool(&call).await.unwrap(), GuardAction::Allow);

        // post_tool: the default hook leaves the output unchanged.
        let mut out = ToolOutput::success("result".to_string());
        guard.post_tool(&call, &mut out).await.unwrap();
        assert_eq!(out.content, "result");
    }

    #[tokio::test]
    async fn post_tool_can_mutate_output() {
        struct Scrubber;
        impl Guardrail for Scrubber {
            fn post_tool(
                &self,
                _call: &ToolCall,
                output: &mut ToolOutput,
            ) -> Pin<Box<dyn std::future::Future<Output = Result<(), Error>> + Send + '_>>
            {
                // The future cannot borrow `output`, so redact before boxing.
                let scrubbed = output.content.replace("secret", "[REDACTED]");
                output.content = scrubbed;
                Box::pin(async { Ok(()) })
            }
        }

        let call = tool_call("c1", "test", serde_json::json!({}));
        let mut out = ToolOutput::success("the secret is 42".to_string());
        Scrubber.post_tool(&call, &mut out).await.unwrap();
        assert_eq!(out.content, "the [REDACTED] is 42");
    }

    #[tokio::test]
    async fn pre_tool_deny_returns_deny_action() {
        struct NoBash;
        impl Guardrail for NoBash {
            fn pre_tool(
                &self,
                call: &ToolCall,
            ) -> Pin<Box<dyn std::future::Future<Output = Result<GuardAction, Error>> + Send + '_>>
            {
                // Decide synchronously: the future may not borrow `call`.
                let verdict = if call.name == "bash" {
                    GuardAction::deny("bash tool is disabled")
                } else {
                    GuardAction::Allow
                };
                Box::pin(async move { Ok(verdict) })
            }
        }

        let guard = NoBash;

        let bash_call = tool_call("c1", "bash", serde_json::json!({"command": "rm -rf /"}));
        assert_eq!(
            guard.pre_tool(&bash_call).await.unwrap(),
            GuardAction::deny("bash tool is disabled")
        );

        let read_call = tool_call("c2", "read", serde_json::json!({"path": "/tmp/test.txt"}));
        assert_eq!(guard.pre_tool(&read_call).await.unwrap(), GuardAction::Allow);
    }
}