Skip to main content

heartbit_core/agent/guardrails/
compose.rs

1//! Composition operators for guardrails.
2//!
3//! - [`GuardrailChain`]: Ordered pipeline — first `Deny` wins.
4//! - [`WarnToDeny`]: Graduated containment — N consecutive `Warn` → `Deny`.
5//! - [`ConditionalGuardrail`]: Predicate-gated — only runs when condition is true.
6
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU32, Ordering};
11
12use crate::agent::guardrail::{GuardAction, Guardrail};
13use crate::error::Error;
14use crate::llm::types::{CompletionRequest, CompletionResponse, ToolCall};
15use crate::tool::ToolOutput;
16
17// ---------------------------------------------------------------------------
18// GuardrailChain
19// ---------------------------------------------------------------------------
20
21/// Ordered pipeline of guardrails — first `Deny` wins.
22///
23/// Equivalent to the default `Vec<Arc<dyn Guardrail>>` behavior in the agent
24/// loop, but wrapped as a single `Guardrail` for nested composition.
25///
26/// **Implementation note**: The `Guardrail` trait's lifetime elision ties the
27/// returned future to `&self` only (not to reference parameters like `call`
28/// or `response`). This means inner guardrails perform their work (including
29/// mutations) synchronously during the call, and the returned futures only
30/// carry no-op cleanup. We eagerly evaluate all inner guardrails and collect
31/// their futures for awaiting.
32pub struct GuardrailChain {
33    guardrails: Vec<Arc<dyn Guardrail>>,
34}
35
36impl GuardrailChain {
37    /// Create a new guardrail chain from a list of guardrails.
38    pub fn new(guardrails: Vec<Arc<dyn Guardrail>>) -> Self {
39        Self { guardrails }
40    }
41}
42
43impl Guardrail for GuardrailChain {
44    fn name(&self) -> &str {
45        "chain"
46    }
47
48    fn pre_llm(
49        &self,
50        request: &mut CompletionRequest,
51    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
52        // Eagerly call each guardrail's pre_llm (synchronous mutations run now).
53        let futs: Vec<_> = self.guardrails.iter().map(|g| g.pre_llm(request)).collect();
54        Box::pin(async move {
55            for fut in futs {
56                fut.await?;
57            }
58            Ok(())
59        })
60    }
61
62    fn post_llm(
63        &self,
64        response: &mut CompletionResponse,
65    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
66        // Eagerly call each guardrail's `post_llm`. By contract the synchronous
67        // portion of every implementation (including any redaction mutation)
68        // runs *before* the returned future is built, so each call here sees
69        // the cumulative effect of prior ones — correct semantics for chained
70        // redaction. The collected futures only carry the resulting action.
71        let futs: Vec<_> = self
72            .guardrails
73            .iter()
74            .map(|g| g.post_llm(response))
75            .collect();
76        Box::pin(async move {
77            let mut worst = GuardAction::Allow;
78            for fut in futs {
79                let action = fut.await?;
80                if action.is_killed() {
81                    return Ok(action);
82                }
83                if action.is_denied() {
84                    return Ok(action);
85                }
86                if matches!(action, GuardAction::Warn { .. }) && matches!(worst, GuardAction::Allow)
87                {
88                    worst = action;
89                }
90            }
91            Ok(worst)
92        })
93    }
94
95    fn pre_tool(
96        &self,
97        call: &ToolCall,
98    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
99        // Eagerly call each guardrail's pre_tool.
100        let futs: Vec<_> = self.guardrails.iter().map(|g| g.pre_tool(call)).collect();
101        Box::pin(async move {
102            let mut worst = GuardAction::Allow;
103            for fut in futs {
104                let action = fut.await?;
105                if action.is_killed() {
106                    return Ok(action);
107                }
108                if action.is_denied() {
109                    return Ok(action);
110                }
111                if matches!(action, GuardAction::Warn { .. }) && matches!(worst, GuardAction::Allow)
112                {
113                    worst = action;
114                }
115            }
116            Ok(worst)
117        })
118    }
119
120    fn post_tool(
121        &self,
122        call: &ToolCall,
123        output: &mut ToolOutput,
124    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
125        // Eagerly call each guardrail's post_tool (mutations run synchronously).
126        let futs: Vec<_> = self
127            .guardrails
128            .iter()
129            .map(|g| g.post_tool(call, output))
130            .collect();
131        Box::pin(async move {
132            for fut in futs {
133                fut.await?;
134            }
135            Ok(())
136        })
137    }
138}
139
140// ---------------------------------------------------------------------------
141// WarnToDeny
142// ---------------------------------------------------------------------------
143
144/// Graduated containment: converts N consecutive `Warn` actions to `Deny`.
145///
146/// Wraps an inner guardrail. Tracks consecutive `Warn` actions across calls.
147/// When the count reaches `threshold`, the `Warn` is escalated to `Deny`.
148/// Any `Allow` resets the counter.
149pub struct WarnToDeny {
150    inner: Arc<dyn Guardrail>,
151    threshold: u32,
152    consecutive_warns: AtomicU32,
153}
154
155impl WarnToDeny {
156    /// Create a `WarnToDeny` adaptor that escalates after `threshold` consecutive warnings.
157    pub fn new(inner: Arc<dyn Guardrail>, threshold: u32) -> Self {
158        Self {
159            inner,
160            threshold,
161            consecutive_warns: AtomicU32::new(0),
162        }
163    }
164
165    fn escalate_if_needed(&self, action: GuardAction) -> GuardAction {
166        match &action {
167            GuardAction::Warn { reason } => {
168                let prev = self.consecutive_warns.fetch_add(1, Ordering::Relaxed);
169                if prev + 1 >= self.threshold {
170                    self.consecutive_warns.store(0, Ordering::Relaxed);
171                    GuardAction::deny(format!(
172                        "Escalated after {} consecutive warnings: {reason}",
173                        self.threshold
174                    ))
175                } else {
176                    action
177                }
178            }
179            // Allow/Deny reset warn counter; Kill passes through unchanged
180            GuardAction::Kill { .. } => action,
181            _ => {
182                self.consecutive_warns.store(0, Ordering::Relaxed);
183                action
184            }
185        }
186    }
187}
188
189impl Guardrail for WarnToDeny {
190    fn name(&self) -> &str {
191        "warn_to_deny"
192    }
193
194    fn pre_llm(
195        &self,
196        request: &mut CompletionRequest,
197    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
198        self.inner.pre_llm(request)
199    }
200
201    fn post_llm(
202        &self,
203        response: &mut CompletionResponse,
204    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
205        let fut = self.inner.post_llm(response);
206        Box::pin(async move {
207            let action = fut.await?;
208            Ok(self.escalate_if_needed(action))
209        })
210    }
211
212    fn pre_tool(
213        &self,
214        call: &ToolCall,
215    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
216        // Eagerly call inner (doesn't capture `call` per trait elision).
217        let fut = self.inner.pre_tool(call);
218        Box::pin(async move {
219            let action = fut.await?;
220            Ok(self.escalate_if_needed(action))
221        })
222    }
223
224    fn post_tool(
225        &self,
226        call: &ToolCall,
227        output: &mut ToolOutput,
228    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
229        self.inner.post_tool(call, output)
230    }
231}
232
233// ---------------------------------------------------------------------------
234// ConditionalGuardrail
235// ---------------------------------------------------------------------------
236
237/// Predicate-gated guardrail — only runs the inner guardrail when the
238/// predicate returns `true` for the tool name.
239///
240/// Use for patterns like "apply this guardrail only to MCP tools" or
241/// "only check bash tool calls".
242///
243/// The predicate receives the tool name for `pre_tool`/`post_tool` hooks.
244/// For `pre_llm`/`post_llm` hooks (no tool name), the inner guardrail
245/// always runs.
246pub struct ConditionalGuardrail {
247    inner: Arc<dyn Guardrail>,
248    predicate: Arc<dyn Fn(&str) -> bool + Send + Sync>,
249}
250
251impl ConditionalGuardrail {
252    /// Create a guardrail that only activates for tool calls matching `predicate`.
253    pub fn new(
254        inner: Arc<dyn Guardrail>,
255        predicate: Arc<dyn Fn(&str) -> bool + Send + Sync>,
256    ) -> Self {
257        Self { inner, predicate }
258    }
259}
260
261impl Guardrail for ConditionalGuardrail {
262    fn name(&self) -> &str {
263        "conditional"
264    }
265
266    fn pre_llm(
267        &self,
268        request: &mut CompletionRequest,
269    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
270        self.inner.pre_llm(request)
271    }
272
273    fn post_llm(
274        &self,
275        response: &mut CompletionResponse,
276    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
277        self.inner.post_llm(response)
278    }
279
280    fn pre_tool(
281        &self,
282        call: &ToolCall,
283    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
284        if (self.predicate)(&call.name) {
285            self.inner.pre_tool(call)
286        } else {
287            Box::pin(async { Ok(GuardAction::Allow) })
288        }
289    }
290
291    fn post_tool(
292        &self,
293        call: &ToolCall,
294        output: &mut ToolOutput,
295    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
296        if (self.predicate)(&call.name) {
297            self.inner.post_tool(call, output)
298        } else {
299            Box::pin(async { Ok(()) })
300        }
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{StopReason, TokenUsage};
    use std::sync::atomic::AtomicUsize;

    /// A guardrail that always denies pre_tool calls.
    struct AlwaysDenyGuardrail;
    impl Guardrail for AlwaysDenyGuardrail {
        fn pre_tool(
            &self,
            _call: &ToolCall,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::deny("blocked")) })
        }
        fn post_llm(
            &self,
            _response: &mut CompletionResponse,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::deny("blocked")) })
        }
    }

    /// A guardrail that always allows.
    struct AlwaysAllowGuardrail;
    impl Guardrail for AlwaysAllowGuardrail {}

    /// A guardrail that always warns.
    struct AlwaysWarnGuardrail;
    impl Guardrail for AlwaysWarnGuardrail {
        fn pre_tool(
            &self,
            _call: &ToolCall,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::warn("suspicious")) })
        }
        fn post_llm(
            &self,
            _response: &mut CompletionResponse,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::warn("suspicious")) })
        }
    }

    /// A guardrail whose `pre_tool` verdicts follow a fixed script:
    /// `true` yields `Warn`, `false` yields `Allow`. Calls past the end allow.
    struct ScriptedWarnGuardrail {
        script: Vec<bool>,
        next: AtomicUsize,
    }

    impl ScriptedWarnGuardrail {
        fn new(script: Vec<bool>) -> Self {
            Self {
                script,
                next: AtomicUsize::new(0),
            }
        }
    }

    impl Guardrail for ScriptedWarnGuardrail {
        fn pre_tool(
            &self,
            _call: &ToolCall,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            let step = self.next.fetch_add(1, Ordering::Relaxed);
            let warn = self.script.get(step).copied().unwrap_or(false);
            Box::pin(async move {
                if warn {
                    Ok(GuardAction::warn("suspicious"))
                } else {
                    Ok(GuardAction::Allow)
                }
            })
        }
    }

    /// A guardrail whose `post_tool` overwrites the tool output, so tests can
    /// observe whether the hook actually ran.
    struct OverwriteOutputGuardrail;
    impl Guardrail for OverwriteOutputGuardrail {
        fn post_tool(
            &self,
            _call: &ToolCall,
            output: &mut ToolOutput,
        ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
            output.content = "scrubbed".into();
            Box::pin(async { Ok(()) })
        }
    }

    fn test_call(name: &str) -> ToolCall {
        ToolCall {
            id: "c1".into(),
            name: name.into(),
            input: serde_json::json!({}),
        }
    }

    fn test_response() -> CompletionResponse {
        CompletionResponse {
            content: vec![],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            model: None,
        }
    }

    // --- GuardrailChain tests ---

    #[tokio::test]
    async fn chain_first_deny_wins() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysDenyGuardrail),
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn chain_all_allow() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn chain_post_llm_first_deny_wins() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysDenyGuardrail),
        ]);
        let action = chain.post_llm(&mut test_response()).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn chain_empty_allows() {
        let chain = GuardrailChain::new(vec![]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn chain_propagates_warn() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysWarnGuardrail),
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert!(
            matches!(action, GuardAction::Warn { .. }),
            "expected Warn, got: {action:?}"
        );
    }

    #[tokio::test]
    async fn chain_deny_trumps_warn() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysDenyGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied(), "Deny should win over Warn");
    }

    #[tokio::test]
    async fn chain_post_llm_propagates_warn() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.post_llm(&mut test_response()).await.unwrap();
        assert!(matches!(action, GuardAction::Warn { .. }));
    }

    #[tokio::test]
    async fn chain_post_tool_applies_inner_mutation() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(OverwriteOutputGuardrail),
        ]);
        let call = test_call("bash");
        let mut output = ToolOutput::success("data".to_string());
        chain.post_tool(&call, &mut output).await.unwrap();
        assert_eq!(output.content, "scrubbed");
    }

    #[tokio::test]
    async fn chain_post_llm_propagates_pii_redaction() {
        // Issue #7 regression guard: a `PiiGuardrail` wrapped in a chain must
        // still mutate `response` in place. This pins the contract that
        // post_llm impls perform any mutation synchronously, before the
        // chain's collected futures are awaited.
        use crate::agent::guardrails::pii::{PiiAction, PiiGuardrail};
        use crate::llm::types::ContentBlock;

        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(PiiGuardrail::all_builtin(PiiAction::Redact)),
        ]);

        let mut response = CompletionResponse {
            content: vec![ContentBlock::Text {
                text: "Contact john@example.com about it".into(),
            }],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            model: None,
        };

        let action = chain.post_llm(&mut response).await.unwrap();
        assert!(matches!(action, GuardAction::Warn { .. }));

        let ContentBlock::Text { text } = &response.content[0] else {
            panic!("expected text block");
        };
        assert!(
            !text.contains("john@example.com"),
            "PiiGuardrail mutation didn't propagate through GuardrailChain: {text}"
        );
        assert!(text.contains("[REDACTED:email]"));
    }

    // --- WarnToDeny tests ---

    #[tokio::test]
    async fn warn_to_deny_escalates_after_threshold() {
        let inner = Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>;
        let g = WarnToDeny::new(inner, 3);
        let call = test_call("bash");

        // First two: still Warn
        let a1 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a1, GuardAction::Warn { .. }));
        let a2 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a2, GuardAction::Warn { .. }));

        // Third: escalated to Deny
        let a3 = g.pre_tool(&call).await.unwrap();
        assert!(a3.is_denied());
        if let GuardAction::Deny { reason } = &a3 {
            assert!(reason.contains("3 consecutive warnings"));
        }
    }

    #[tokio::test]
    async fn warn_to_deny_resets_on_allow() {
        // Warn, Warn, Allow, Warn, Warn, Warn: the inner Allow breaks the run,
        // so escalation only happens on the third warning after it.
        let inner = Arc::new(ScriptedWarnGuardrail::new(vec![
            true, true, false, true, true, true,
        ])) as Arc<dyn Guardrail>;
        let g = WarnToDeny::new(inner, 3);
        let call = test_call("bash");

        let a1 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a1, GuardAction::Warn { .. }));
        let a2 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a2, GuardAction::Warn { .. }));

        // Real Allow from the inner guardrail resets the counter.
        let a3 = g.pre_tool(&call).await.unwrap();
        assert_eq!(a3, GuardAction::Allow);
        assert_eq!(g.consecutive_warns.load(Ordering::Relaxed), 0);

        // Two more warns — should not escalate yet
        let a4 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a4, GuardAction::Warn { .. }));
        let a5 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a5, GuardAction::Warn { .. }));

        // Third consecutive → escalate
        let a6 = g.pre_tool(&call).await.unwrap();
        assert!(a6.is_denied());
    }

    #[tokio::test]
    async fn warn_to_deny_allow_resets_counter() {
        let g = WarnToDeny::new(Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>, 1);
        let call = test_call("bash");
        // Set counter artificially
        g.consecutive_warns.store(5, Ordering::Relaxed);
        let action = g.pre_tool(&call).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        assert_eq!(g.consecutive_warns.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn warn_to_deny_deny_resets_counter() {
        let g = WarnToDeny::new(Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>, 3);
        let call = test_call("bash");
        g.consecutive_warns.store(2, Ordering::Relaxed);
        let action = g.pre_tool(&call).await.unwrap();
        assert!(action.is_denied());
        assert_eq!(g.consecutive_warns.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn warn_to_deny_post_llm_escalates() {
        let inner = Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>;
        let g = WarnToDeny::new(inner, 2);
        let mut resp = test_response();

        let a1 = g.post_llm(&mut resp).await.unwrap();
        assert!(matches!(a1, GuardAction::Warn { .. }));

        let a2 = g.post_llm(&mut resp).await.unwrap();
        assert!(a2.is_denied());
    }

    // --- ConditionalGuardrail tests ---

    #[tokio::test]
    async fn conditional_runs_when_predicate_true() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn conditional_skips_when_false() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn conditional_post_tool_skips_when_false() {
        // The inner guardrail would overwrite the output, so an unchanged
        // output proves the hook was actually skipped.
        let g = ConditionalGuardrail::new(
            Arc::new(OverwriteOutputGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let call = test_call("read");
        let mut output = ToolOutput::success("data".to_string());
        g.post_tool(&call, &mut output).await.unwrap();
        assert_eq!(output.content, "data");
    }

    #[tokio::test]
    async fn conditional_post_tool_runs_when_true() {
        let g = ConditionalGuardrail::new(
            Arc::new(OverwriteOutputGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let call = test_call("bash");
        let mut output = ToolOutput::success("data".to_string());
        g.post_tool(&call, &mut output).await.unwrap();
        assert_eq!(output.content, "scrubbed");
    }

    #[tokio::test]
    async fn conditional_llm_hooks_always_run() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|_name: &str| false),
        );
        let action = g.post_llm(&mut test_response()).await.unwrap();
        assert!(action.is_denied());
    }

    // --- Meta tests ---

    #[test]
    fn chain_meta_name() {
        let chain = GuardrailChain::new(vec![]);
        assert_eq!(chain.name(), "chain");
    }

    #[test]
    fn warn_to_deny_meta_name() {
        let g = WarnToDeny::new(Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>, 3);
        assert_eq!(g.name(), "warn_to_deny");
    }

    #[test]
    fn conditional_meta_name() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|_: &str| true),
        );
        assert_eq!(g.name(), "conditional");
    }
}