use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use crate::agent::guardrail::{GuardAction, Guardrail};
use crate::error::Error;
use crate::llm::types::{CompletionRequest, CompletionResponse, ToolCall};
use crate::tool::ToolOutput;
/// A guardrail built from an ordered list of other guardrails.
///
/// Its [`Guardrail`] implementation forwards every hook to each member in
/// list order and combines their results.
pub struct GuardrailChain {
/// Member guardrails, in the order they are invoked and awaited.
guardrails: Vec<Arc<dyn Guardrail>>,
}
impl GuardrailChain {
pub fn new(guardrails: Vec<Arc<dyn Guardrail>>) -> Self {
Self { guardrails }
}
}
impl Guardrail for GuardrailChain {
    /// Fixed identifier of this combinator.
    fn name(&self) -> &str {
        "chain"
    }

    /// Runs every member's `pre_llm` hook against `request`.
    ///
    /// All hooks are invoked up front, while `request` is still borrowed;
    /// only the futures they return are awaited inside the boxed future, one
    /// after another in list order. The first error ends the chain.
    fn pre_llm(
        &self,
        request: &mut CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
        let mut pending = Vec::with_capacity(self.guardrails.len());
        for guard in &self.guardrails {
            pending.push(guard.pre_llm(request));
        }
        Box::pin(async move {
            for step in pending {
                step.await?;
            }
            Ok(())
        })
    }

    /// Runs every member's `post_llm` hook against `response` and combines
    /// the verdicts.
    ///
    /// All hooks are invoked up front; their futures are then awaited in list
    /// order. The first killed or denied verdict is returned immediately.
    /// Otherwise the first `Warn` seen is returned, or `Allow` if there was
    /// none. Errors propagate as soon as they are encountered.
    fn post_llm(
        &self,
        response: &mut CompletionResponse,
    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
        let mut pending = Vec::with_capacity(self.guardrails.len());
        for guard in &self.guardrails {
            pending.push(guard.post_llm(response));
        }
        Box::pin(async move {
            // Only the earliest warning is remembered; later ones are dropped.
            let mut first_warn: Option<GuardAction> = None;
            for check in pending {
                let verdict = check.await?;
                if verdict.is_killed() || verdict.is_denied() {
                    return Ok(verdict);
                }
                if first_warn.is_none() && matches!(verdict, GuardAction::Warn { .. }) {
                    first_warn = Some(verdict);
                }
            }
            Ok(first_warn.unwrap_or(GuardAction::Allow))
        })
    }

    /// Runs every member's `pre_tool` hook for `call` and combines the
    /// verdicts.
    ///
    /// Same combination rule as `post_llm`: the first killed or denied
    /// verdict short-circuits, otherwise the first `Warn` wins, otherwise
    /// `Allow`.
    fn pre_tool(
        &self,
        call: &ToolCall,
    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
        let mut pending = Vec::with_capacity(self.guardrails.len());
        for guard in &self.guardrails {
            pending.push(guard.pre_tool(call));
        }
        Box::pin(async move {
            // Only the earliest warning is remembered; later ones are dropped.
            let mut first_warn: Option<GuardAction> = None;
            for check in pending {
                let verdict = check.await?;
                if verdict.is_killed() || verdict.is_denied() {
                    return Ok(verdict);
                }
                if first_warn.is_none() && matches!(verdict, GuardAction::Warn { .. }) {
                    first_warn = Some(verdict);
                }
            }
            Ok(first_warn.unwrap_or(GuardAction::Allow))
        })
    }

    /// Runs every member's `post_tool` hook for `call` against `output`.
    ///
    /// All hooks are invoked up front; their futures are then awaited in list
    /// order and the first error ends the chain.
    fn post_tool(
        &self,
        call: &ToolCall,
        output: &mut ToolOutput,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
        let mut pending = Vec::with_capacity(self.guardrails.len());
        for guard in &self.guardrails {
            pending.push(guard.post_tool(call, output));
        }
        Box::pin(async move {
            for step in pending {
                step.await?;
            }
            Ok(())
        })
    }
}
/// Wraps a guardrail and turns a run of consecutive `Warn` verdicts into a
/// `Deny` once the run reaches `threshold`.
pub struct WarnToDeny {
/// The wrapped guardrail whose verdicts are inspected.
inner: Arc<dyn Guardrail>,
/// Number of consecutive warnings at which a `Warn` is replaced by a `Deny`.
threshold: u32,
/// Current run of consecutive warnings; starts at 0.
consecutive_warns: AtomicU32,
}
impl WarnToDeny {
    /// Wraps `inner`, escalating to a deny once `threshold` consecutive
    /// warnings have been observed. The warning counter starts at zero.
    pub fn new(inner: Arc<dyn Guardrail>, threshold: u32) -> Self {
        let consecutive_warns = AtomicU32::new(0);
        Self {
            inner,
            threshold,
            consecutive_warns,
        }
    }

    /// Updates the warning streak for `action` and returns the verdict to
    /// hand back to the caller.
    ///
    /// * `Kill` passes through and leaves the streak untouched.
    /// * `Warn` increments the streak. If it has reached `threshold`, the
    ///   streak is reset and a deny carrying the original reason is returned
    ///   instead; otherwise the warning passes through.
    /// * Any other verdict resets the streak and passes through.
    fn escalate_if_needed(&self, action: GuardAction) -> GuardAction {
        if matches!(action, GuardAction::Kill { .. }) {
            return action;
        }

        let GuardAction::Warn { reason } = &action else {
            // Neither a warning nor a kill: the streak is broken.
            self.consecutive_warns.store(0, Ordering::Relaxed);
            return action;
        };

        // `fetch_add` yields the previous value, so add one to include this warning.
        let streak = self.consecutive_warns.fetch_add(1, Ordering::Relaxed) + 1;
        if streak < self.threshold {
            return action;
        }

        self.consecutive_warns.store(0, Ordering::Relaxed);
        GuardAction::deny(format!(
            "Escalated after {} consecutive warnings: {reason}",
            self.threshold
        ))
    }
}
impl Guardrail for WarnToDeny {
    /// Fixed identifier of this combinator.
    fn name(&self) -> &str {
        "warn_to_deny"
    }

    /// Forwarded unchanged to the wrapped guardrail.
    fn pre_llm(
        &self,
        request: &mut CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
        self.inner.pre_llm(request)
    }

    /// Asks the wrapped guardrail for a verdict on `response`, then applies
    /// the warn-to-deny escalation to it. Errors pass through untouched.
    fn post_llm(
        &self,
        response: &mut CompletionResponse,
    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
        // The inner hook is called here, while `response` is still borrowed.
        let pending = self.inner.post_llm(response);
        Box::pin(async move {
            pending
                .await
                .map(|verdict| self.escalate_if_needed(verdict))
        })
    }

    /// Asks the wrapped guardrail for a verdict on `call`, then applies the
    /// warn-to-deny escalation to it. Errors pass through untouched.
    fn pre_tool(
        &self,
        call: &ToolCall,
    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
        let pending = self.inner.pre_tool(call);
        Box::pin(async move {
            pending
                .await
                .map(|verdict| self.escalate_if_needed(verdict))
        })
    }

    /// Forwarded unchanged to the wrapped guardrail.
    fn post_tool(
        &self,
        call: &ToolCall,
        output: &mut ToolOutput,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
        self.inner.post_tool(call, output)
    }
}
/// Wraps a guardrail so that its tool hooks only run for tool calls whose
/// name satisfies a predicate.
pub struct ConditionalGuardrail {
/// The wrapped guardrail.
inner: Arc<dyn Guardrail>,
/// Called with the tool call's name; `true` means the wrapped tool hook runs.
predicate: Arc<dyn Fn(&str) -> bool + Send + Sync>,
}
impl ConditionalGuardrail {
pub fn new(
inner: Arc<dyn Guardrail>,
predicate: Arc<dyn Fn(&str) -> bool + Send + Sync>,
) -> Self {
Self { inner, predicate }
}
}
impl Guardrail for ConditionalGuardrail {
    /// Fixed identifier of this combinator.
    fn name(&self) -> &str {
        "conditional"
    }

    /// Always forwarded to the wrapped guardrail; the predicate is not consulted.
    fn pre_llm(
        &self,
        request: &mut CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
        self.inner.pre_llm(request)
    }

    /// Always forwarded to the wrapped guardrail; the predicate is not consulted.
    fn post_llm(
        &self,
        response: &mut CompletionResponse,
    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
        self.inner.post_llm(response)
    }

    /// Forwards to the wrapped guardrail when the predicate accepts
    /// `call.name`; otherwise resolves to `Allow` without consulting it.
    fn pre_tool(
        &self,
        call: &ToolCall,
    ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
        let applies = (self.predicate)(&call.name);
        if !applies {
            return Box::pin(async { Ok(GuardAction::Allow) });
        }
        self.inner.pre_tool(call)
    }

    /// Forwards to the wrapped guardrail when the predicate accepts
    /// `call.name`; otherwise resolves to `Ok(())` and leaves `output` alone.
    fn post_tool(
        &self,
        call: &ToolCall,
        output: &mut ToolOutput,
    ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
        let applies = (self.predicate)(&call.name);
        if !applies {
            return Box::pin(async { Ok(()) });
        }
        self.inner.post_tool(call, output)
    }
}
#[cfg(test)]
mod tests {
    //! Tests for the guardrail combinators, using small stub guardrails that
    //! return a fixed verdict from `pre_tool` / `post_llm`.

    use super::*;
    use crate::llm::types::{StopReason, TokenUsage};

    /// Denies every `pre_tool` and `post_llm` check.
    struct AlwaysDenyGuardrail;
    impl Guardrail for AlwaysDenyGuardrail {
        fn pre_tool(
            &self,
            _call: &ToolCall,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::deny("blocked")) })
        }
        fn post_llm(
            &self,
            _response: &mut CompletionResponse,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::deny("blocked")) })
        }
    }

    /// Relies entirely on the trait's default hook implementations.
    struct AlwaysAllowGuardrail;
    impl Guardrail for AlwaysAllowGuardrail {}

    /// Warns on every `pre_tool` and `post_llm` check.
    struct AlwaysWarnGuardrail;
    impl Guardrail for AlwaysWarnGuardrail {
        fn pre_tool(
            &self,
            _call: &ToolCall,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::warn("suspicious")) })
        }
        fn post_llm(
            &self,
            _response: &mut CompletionResponse,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            Box::pin(async { Ok(GuardAction::warn("suspicious")) })
        }
    }

    /// Warns on every `pre_tool` call except the one whose zero-based index
    /// equals `allow_on`, which is allowed. Lets a test interleave a real
    /// `Allow` into a run of warnings.
    struct WarnExceptOnce {
        calls: AtomicU32,
        allow_on: u32,
    }
    impl Guardrail for WarnExceptOnce {
        fn pre_tool(
            &self,
            _call: &ToolCall,
        ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
            let index = self.calls.fetch_add(1, Ordering::Relaxed);
            let action = if index == self.allow_on {
                GuardAction::Allow
            } else {
                GuardAction::warn("suspicious")
            };
            Box::pin(async move { Ok(action) })
        }
    }

    /// Builds a tool call with the given tool name and empty JSON input.
    fn test_call(name: &str) -> ToolCall {
        ToolCall {
            id: "c1".into(),
            name: name.into(),
            input: serde_json::json!({}),
        }
    }

    /// Builds an empty completion response.
    fn test_response() -> CompletionResponse {
        CompletionResponse {
            content: vec![],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            model: None,
        }
    }

    #[tokio::test]
    async fn chain_first_deny_wins() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysDenyGuardrail),
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn chain_all_allow() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn chain_post_llm_first_deny_wins() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysDenyGuardrail),
        ]);
        let action = chain.post_llm(&mut test_response()).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn chain_empty_allows() {
        let chain = GuardrailChain::new(vec![]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn chain_propagates_warn() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysWarnGuardrail),
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert!(
            matches!(action, GuardAction::Warn { .. }),
            "expected Warn, got: {action:?}"
        );
    }

    #[tokio::test]
    async fn chain_deny_trumps_warn() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysDenyGuardrail),
        ]);
        let action = chain.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied(), "Deny should win over Warn");
    }

    #[tokio::test]
    async fn chain_post_llm_propagates_warn() {
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>,
            Arc::new(AlwaysAllowGuardrail),
        ]);
        let action = chain.post_llm(&mut test_response()).await.unwrap();
        assert!(matches!(action, GuardAction::Warn { .. }));
    }

    #[tokio::test]
    async fn chain_post_llm_propagates_pii_redaction() {
        use crate::agent::guardrails::pii::{PiiAction, PiiGuardrail};
        use crate::llm::types::ContentBlock;
        let chain = GuardrailChain::new(vec![
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(PiiGuardrail::all_builtin(PiiAction::Redact)),
        ]);
        let mut response = CompletionResponse {
            content: vec![ContentBlock::Text {
                text: "Contact john@example.com about it".into(),
            }],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            model: None,
        };
        let action = chain.post_llm(&mut response).await.unwrap();
        assert!(matches!(action, GuardAction::Warn { .. }));
        let ContentBlock::Text { text } = &response.content[0] else {
            panic!("expected text block");
        };
        assert!(
            !text.contains("john@example.com"),
            "PiiGuardrail mutation didn't propagate through GuardrailChain: {text}"
        );
        assert!(text.contains("[REDACTED:email]"));
    }

    #[tokio::test]
    async fn warn_to_deny_escalates_after_threshold() {
        let inner = Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>;
        let g = WarnToDeny::new(inner, 3);
        let call = test_call("bash");
        let a1 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a1, GuardAction::Warn { .. }));
        let a2 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a2, GuardAction::Warn { .. }));
        let a3 = g.pre_tool(&call).await.unwrap();
        assert!(a3.is_denied());
        // Fail loudly instead of silently skipping the message check.
        let GuardAction::Deny { reason } = &a3 else {
            panic!("expected Deny, got: {a3:?}");
        };
        assert!(reason.contains("3 consecutive warnings"));
    }

    #[tokio::test]
    async fn warn_to_deny_resets_on_allow() {
        // Inner verdicts, in order: Warn, Warn, Allow, Warn, Warn, Warn.
        let inner = Arc::new(WarnExceptOnce {
            calls: AtomicU32::new(0),
            allow_on: 2,
        }) as Arc<dyn Guardrail>;
        let g = WarnToDeny::new(inner, 3);
        let call = test_call("bash");

        let w1 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(w1, GuardAction::Warn { .. }));
        let w2 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(w2, GuardAction::Warn { .. }));

        // The Allow must break the streak on its own; the counter is not
        // touched by the test.
        let allowed = g.pre_tool(&call).await.unwrap();
        assert_eq!(allowed, GuardAction::Allow);
        assert_eq!(g.consecutive_warns.load(Ordering::Relaxed), 0);

        // A full fresh run of three warnings is needed before escalation.
        let a1 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a1, GuardAction::Warn { .. }));
        let a2 = g.pre_tool(&call).await.unwrap();
        assert!(matches!(a2, GuardAction::Warn { .. }));
        let a3 = g.pre_tool(&call).await.unwrap();
        assert!(a3.is_denied());
    }

    #[tokio::test]
    async fn warn_to_deny_allow_resets_counter() {
        let g = WarnToDeny::new(Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>, 1);
        let call = test_call("bash");
        g.consecutive_warns.store(5, Ordering::Relaxed);
        let action = g.pre_tool(&call).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
        assert_eq!(g.consecutive_warns.load(Ordering::Relaxed), 0);
    }

    #[tokio::test]
    async fn warn_to_deny_post_llm_escalates() {
        let inner = Arc::new(AlwaysWarnGuardrail) as Arc<dyn Guardrail>;
        let g = WarnToDeny::new(inner, 2);
        let mut resp = test_response();
        let a1 = g.post_llm(&mut resp).await.unwrap();
        assert!(matches!(a1, GuardAction::Warn { .. }));
        let a2 = g.post_llm(&mut resp).await.unwrap();
        assert!(a2.is_denied());
    }

    #[tokio::test]
    async fn conditional_runs_when_predicate_true() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let action = g.pre_tool(&test_call("bash")).await.unwrap();
        assert!(action.is_denied());
    }

    #[tokio::test]
    async fn conditional_skips_when_false() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let action = g.pre_tool(&test_call("read")).await.unwrap();
        assert_eq!(action, GuardAction::Allow);
    }

    #[tokio::test]
    async fn conditional_post_tool_skips_when_false() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|name: &str| name == "bash"),
        );
        let call = test_call("read");
        let mut output = ToolOutput::success("data".to_string());
        g.post_tool(&call, &mut output).await.unwrap();
        assert_eq!(output.content, "data");
    }

    #[tokio::test]
    async fn conditional_llm_hooks_always_run() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysDenyGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|_name: &str| false),
        );
        let action = g.post_llm(&mut test_response()).await.unwrap();
        assert!(action.is_denied());
    }

    #[test]
    fn chain_meta_name() {
        let chain = GuardrailChain::new(vec![]);
        assert_eq!(chain.name(), "chain");
    }

    #[test]
    fn warn_to_deny_meta_name() {
        let g = WarnToDeny::new(Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>, 3);
        assert_eq!(g.name(), "warn_to_deny");
    }

    #[test]
    fn conditional_meta_name() {
        let g = ConditionalGuardrail::new(
            Arc::new(AlwaysAllowGuardrail) as Arc<dyn Guardrail>,
            Arc::new(|_: &str| true),
        );
        assert_eq!(g.name(), "conditional");
    }
}