heartbit_core/agent/guardrails/
compose.rs1use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU32, Ordering};
11
12use crate::agent::guardrail::{GuardAction, Guardrail};
13use crate::error::Error;
14use crate::llm::types::{CompletionRequest, CompletionResponse, ToolCall};
15use crate::tool::ToolOutput;
16
17pub struct GuardrailChain {
33 guardrails: Vec<Arc<dyn Guardrail>>,
34}
35
36impl GuardrailChain {
37 pub fn new(guardrails: Vec<Arc<dyn Guardrail>>) -> Self {
39 Self { guardrails }
40 }
41}
42
43impl Guardrail for GuardrailChain {
44 fn name(&self) -> &str {
45 "chain"
46 }
47
48 fn pre_llm(
49 &self,
50 request: &mut CompletionRequest,
51 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
52 let futs: Vec<_> = self.guardrails.iter().map(|g| g.pre_llm(request)).collect();
54 Box::pin(async move {
55 for fut in futs {
56 fut.await?;
57 }
58 Ok(())
59 })
60 }
61
62 fn post_llm(
63 &self,
64 response: &mut CompletionResponse,
65 ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
66 let futs: Vec<_> = self
72 .guardrails
73 .iter()
74 .map(|g| g.post_llm(response))
75 .collect();
76 Box::pin(async move {
77 let mut worst = GuardAction::Allow;
78 for fut in futs {
79 let action = fut.await?;
80 if action.is_killed() {
81 return Ok(action);
82 }
83 if action.is_denied() {
84 return Ok(action);
85 }
86 if matches!(action, GuardAction::Warn { .. }) && matches!(worst, GuardAction::Allow)
87 {
88 worst = action;
89 }
90 }
91 Ok(worst)
92 })
93 }
94
95 fn pre_tool(
96 &self,
97 call: &ToolCall,
98 ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
99 let futs: Vec<_> = self.guardrails.iter().map(|g| g.pre_tool(call)).collect();
101 Box::pin(async move {
102 let mut worst = GuardAction::Allow;
103 for fut in futs {
104 let action = fut.await?;
105 if action.is_killed() {
106 return Ok(action);
107 }
108 if action.is_denied() {
109 return Ok(action);
110 }
111 if matches!(action, GuardAction::Warn { .. }) && matches!(worst, GuardAction::Allow)
112 {
113 worst = action;
114 }
115 }
116 Ok(worst)
117 })
118 }
119
120 fn post_tool(
121 &self,
122 call: &ToolCall,
123 output: &mut ToolOutput,
124 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
125 let futs: Vec<_> = self
127 .guardrails
128 .iter()
129 .map(|g| g.post_tool(call, output))
130 .collect();
131 Box::pin(async move {
132 for fut in futs {
133 fut.await?;
134 }
135 Ok(())
136 })
137 }
138}
139
140pub struct WarnToDeny {
150 inner: Arc<dyn Guardrail>,
151 threshold: u32,
152 consecutive_warns: AtomicU32,
153}
154
155impl WarnToDeny {
156 pub fn new(inner: Arc<dyn Guardrail>, threshold: u32) -> Self {
158 Self {
159 inner,
160 threshold,
161 consecutive_warns: AtomicU32::new(0),
162 }
163 }
164
165 fn escalate_if_needed(&self, action: GuardAction) -> GuardAction {
166 match &action {
167 GuardAction::Warn { reason } => {
168 let prev = self.consecutive_warns.fetch_add(1, Ordering::Relaxed);
169 if prev + 1 >= self.threshold {
170 self.consecutive_warns.store(0, Ordering::Relaxed);
171 GuardAction::deny(format!(
172 "Escalated after {} consecutive warnings: {reason}",
173 self.threshold
174 ))
175 } else {
176 action
177 }
178 }
179 GuardAction::Kill { .. } => action,
181 _ => {
182 self.consecutive_warns.store(0, Ordering::Relaxed);
183 action
184 }
185 }
186 }
187}
188
189impl Guardrail for WarnToDeny {
190 fn name(&self) -> &str {
191 "warn_to_deny"
192 }
193
194 fn pre_llm(
195 &self,
196 request: &mut CompletionRequest,
197 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
198 self.inner.pre_llm(request)
199 }
200
201 fn post_llm(
202 &self,
203 response: &mut CompletionResponse,
204 ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
205 let fut = self.inner.post_llm(response);
206 Box::pin(async move {
207 let action = fut.await?;
208 Ok(self.escalate_if_needed(action))
209 })
210 }
211
212 fn pre_tool(
213 &self,
214 call: &ToolCall,
215 ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
216 let fut = self.inner.pre_tool(call);
218 Box::pin(async move {
219 let action = fut.await?;
220 Ok(self.escalate_if_needed(action))
221 })
222 }
223
224 fn post_tool(
225 &self,
226 call: &ToolCall,
227 output: &mut ToolOutput,
228 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
229 self.inner.post_tool(call, output)
230 }
231}
232
233pub struct ConditionalGuardrail {
247 inner: Arc<dyn Guardrail>,
248 predicate: Arc<dyn Fn(&str) -> bool + Send + Sync>,
249}
250
251impl ConditionalGuardrail {
252 pub fn new(
254 inner: Arc<dyn Guardrail>,
255 predicate: Arc<dyn Fn(&str) -> bool + Send + Sync>,
256 ) -> Self {
257 Self { inner, predicate }
258 }
259}
260
261impl Guardrail for ConditionalGuardrail {
262 fn name(&self) -> &str {
263 "conditional"
264 }
265
266 fn pre_llm(
267 &self,
268 request: &mut CompletionRequest,
269 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
270 self.inner.pre_llm(request)
271 }
272
273 fn post_llm(
274 &self,
275 response: &mut CompletionResponse,
276 ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
277 self.inner.post_llm(response)
278 }
279
280 fn pre_tool(
281 &self,
282 call: &ToolCall,
283 ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
284 if (self.predicate)(&call.name) {
285 self.inner.pre_tool(call)
286 } else {
287 Box::pin(async { Ok(GuardAction::Allow) })
288 }
289 }
290
291 fn post_tool(
292 &self,
293 call: &ToolCall,
294 output: &mut ToolOutput,
295 ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
296 if (self.predicate)(&call.name) {
297 self.inner.post_tool(call, output)
298 } else {
299 Box::pin(async { Ok(()) })
300 }
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{StopReason, TokenUsage};
    use std::sync::atomic::AtomicBool;

    // ----- Test doubles ---------------------------------------------------

    /// Denies in both `pre_tool` and `post_llm`; other hooks use the trait defaults.
    struct AlwaysDenyGuardrail;
    impl Guardrail for AlwaysDenyGuardrail {
        fn pre_tool(
            &self,
            _call: &ToolCall,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::deny("blocked")) })
        }
        fn post_llm(
            &self,
            _response: &mut CompletionResponse,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::deny("blocked")) })
        }
    }

    /// Relies entirely on the trait's default hook implementations.
    struct AlwaysAllowGuardrail;
    impl Guardrail for AlwaysAllowGuardrail {}

    /// Warns in both `pre_tool` and `post_llm`; other hooks use the trait defaults.
    struct AlwaysWarnGuardrail;
    impl Guardrail for AlwaysWarnGuardrail {
        fn pre_tool(
            &self,
            _call: &ToolCall,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::warn("suspicious")) })
        }
        fn post_llm(
            &self,
            _response: &mut CompletionResponse,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::warn("suspicious")) })
        }
    }

    /// Warns in `pre_tool` until `allow` is switched on, then allows.
    ///
    /// Lets a test interleave a real `Allow` between warnings.
    #[derive(Default)]
    struct SwitchableGuardrail {
        allow: AtomicBool,
    }
    impl Guardrail for SwitchableGuardrail {
        fn pre_tool(
            &self,
            _call: &ToolCall,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            let action = if self.allow.load(Ordering::Relaxed) {
                GuardAction::Allow
            } else {
                GuardAction::warn("suspicious")
            };
            Box::pin(async move { Ok(action) })
        }
    }

    // ----- Fixtures -------------------------------------------------------

    /// A tool call for `name` with an empty JSON object as input.
    fn test_call(name: &str) -> ToolCall {
        ToolCall {
            id: "c1".into(),
            name: name.into(),
            input: serde_json::json!({}),
        }
    }

    /// A completion response with no content blocks.
    fn test_response() -> CompletionResponse {
        CompletionResponse {
            content: vec![],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            model: None,
        }
    }

    // ----- GuardrailChain -------------------------------------------------

    #[tokio::test]
    async fn chain_first_deny_wins() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysDenyGuardrail),
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn chain_all_allow() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn chain_post_llm_first_deny_wins() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysDenyGuardrail),
        ]);
        let action = chain.post_llm(&mut test_response()).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn chain_empty_allows() {
        let chain = GuardrailChain::new(vec![]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn chain_propagates_warn() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysWarnGuardrail),
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert!(
            matches!(action, GuardAction::Warn { .. }),
            "expected Warn, got: {action:?}"
        );
    }

    #[tokio::test]
    async fn chain_deny_trumps_warn() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysDenyGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied(), "Deny should win over Warn");
    }

    #[tokio::test]
    async fn chain_post_llm_propagates_warn() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.post_llm(&mut test_response()).await.unwrap();
        assert!(matches!(action, GuardAction::Warn { .. }));
    }

    /// A member's in-place edit of the response must be visible to the caller
    /// of the chain.
    #[tokio::test]
    async fn chain_post_llm_propagates_pii_redaction() {
        use crate::agent::guardrails::pii::{PiiAction, PiiGuardrail};
        use crate::llm::types::ContentBlock;

        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(PiiGuardrail::all_builtin(PiiAction::Redact)),
        ]);

        let mut response = CompletionResponse {
            content: vec![ContentBlock::Text {
                text: "Contact john@example.com about it".into(),
            }],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            model: None,
        };

        let action = chain.post_llm(&mut response).await.unwrap();
        assert!(matches!(action, GuardAction::Warn { .. }));

        let ContentBlock::Text { text } = &response.content[0] else {
            panic!("expected text block");
        };
        assert!(
            !text.contains("john@example.com"),
            "PiiGuardrail mutation didn't propagate through GuardrailChain: {text}"
        );
        assert!(text.contains("[REDACTED:email]"));
    }

    // ----- WarnToDeny -----------------------------------------------------

    #[tokio::test]
    async fn warn_to_deny_escalates_after_threshold() {
        let inner = Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>;
        let g = WarnToDeny::new(inner, 3);
        let call = test_call("bash");

        // The first two warnings pass through unchanged.
        let a1 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a1, GuardAction::Warn { .. }));
        let a2 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a2, GuardAction::Warn { .. }));

        // The third one completes the streak and is escalated.
        let a3 = g.pre_tool(&call).await.unwrap();
        assert!(a3.is_denied());
        // `let … else` so that an unexpected variant fails the test instead of
        // silently skipping the message check.
        let GuardAction::Deny { reason } = &a3 else {
            panic!("expected Deny, got: {a3:?}");
        };
        assert!(reason.contains("3 consecutive warnings"));
    }

    /// An `Allow` coming from the inner guardrail must break the streak, so
    /// that a full `threshold` of fresh warnings is needed again.
    #[tokio::test]
    async fn warn_to_deny_resets_on_allow() {
        let inner = Arc::new(SwitchableGuardrail::default());
        let shared: Arc<dyn Guardrail> = inner.clone();
        let g = WarnToDeny::new(shared, 3);
        let call = test_call("bash");

        // Two warnings: one short of the threshold.
        let w1 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(w1, GuardAction::Warn { .. }));
        let w2 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(w2, GuardAction::Warn { .. }));

        // A genuine Allow from the inner guardrail interrupts the streak.
        inner.allow.store(true, Ordering::Relaxed);
        let allowed = g.pre_tool(&call).await.unwrap();
        assert_eq!(allowed, GuardAction::Allow);
        assert_eq!(g.consecutive_warns.load(Ordering::Relaxed), 0);
        inner.allow.store(false, Ordering::Relaxed);

        // The streak starts from scratch: two more warnings pass through...
        let a1 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a1, GuardAction::Warn { .. }));
        let a2 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a2, GuardAction::Warn { .. }));

        // ...and only the third is escalated.
        let a3 = g.pre_tool(&call).await.unwrap();
        assert!(a3.is_denied());
    }

    #[tokio::test]
    async fn warn_to_deny_allow_resets_counter() {
        let g = WarnToDeny::new(Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>, 1);
        let call = test_call("bash");
        // Pretend a streak is already in progress.
        g.consecutive_warns.store(5, Ordering::Relaxed);
        let action = g.pre_tool(&call).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        assert_eq!(g.consecutive_warns.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn warn_to_deny_post_llm_escalates() {
        let inner = Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>;
        let g = WarnToDeny::new(inner, 2);
        let mut resp = test_response();

        let a1 = g.post_llm(&mut resp).await.unwrap();
        assert!(matches!(a1, GuardAction::Warn { .. }));

        let a2 = g.post_llm(&mut resp).await.unwrap();
        assert!(a2.is_denied());
    }

    // ----- ConditionalGuardrail -------------------------------------------

    #[tokio::test]
    async fn conditional_runs_when_predicate_true() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn conditional_skips_when_false() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn conditional_post_tool_skips_when_false() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let call = test_call("read");
        let mut output = ToolOutput::success("data".to_string());
        g.post_tool(&call, &mut output).await.unwrap();
        assert_eq!(output.content, "data");
    }

    #[tokio::test]
    async fn conditional_llm_hooks_always_run() {
        // The predicate rejects everything, yet the LLM hook still reaches
        // the inner guardrail.
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|_name: &str| false),
        );
        let action = g.post_llm(&mut test_response()).await.unwrap();
        assert!(action.is_denied());
    }

    // ----- Names ----------------------------------------------------------

    #[test]
    fn chain_meta_name() {
        let chain = GuardrailChain::new(vec![]);
        assert_eq!(chain.name(), "chain");
    }

    #[test]
    fn warn_to_deny_meta_name() {
        let g = WarnToDeny::new(Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>, 3);
        assert_eq!(g.name(), "warn_to_deny");
    }

    #[test]
    fn conditional_meta_name() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|_: &str| true),
        );
        assert_eq!(g.name(), "conditional");
    }
}