#![allow(missing_docs)]
use std::future::Future;
use std::pin::Pin;
use crate::error::Error;
use crate::llm::types::{CompletionRequest, CompletionResponse, ToolCall};
use crate::tool::ToolOutput;
/// Verdict produced by a [`Guardrail`] hook.
///
/// Every non-`Allow` variant carries a human-readable `reason`. The type only
/// describes the verdict; how a caller reacts to each variant is decided
/// outside this file.
///
/// `Eq` is derived alongside `PartialEq` because every field is a `String`,
/// so equality is total and the value can be used wherever `Eq` is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardAction {
    /// No objection was raised.
    Allow,
    /// A rejection; counted by `is_denied`, but not by `is_killed`.
    Deny { reason: String },
    /// An advisory concern; counted by neither `is_denied` nor `is_killed`.
    Warn { reason: String },
    /// The strongest rejection; counted by both `is_denied` and `is_killed`.
    // NOTE(review): presumably the caller aborts the whole run on `Kill` —
    // confirm against the code that consumes this value.
    Kill { reason: String },
}
impl GuardAction {
pub fn deny(reason: impl Into<String>) -> Self {
GuardAction::Deny {
reason: reason.into(),
}
}
pub fn warn(reason: impl Into<String>) -> Self {
GuardAction::Warn {
reason: reason.into(),
}
}
pub fn kill(reason: impl Into<String>) -> Self {
GuardAction::Kill {
reason: reason.into(),
}
}
pub fn is_denied(&self) -> bool {
matches!(self, GuardAction::Deny { .. } | GuardAction::Kill { .. })
}
pub fn is_killed(&self) -> bool {
matches!(self, GuardAction::Kill { .. })
}
}
/// A set of hooks that can inspect, rewrite, or veto LLM requests/responses
/// and tool calls/outputs.
///
/// Every method has a default body, so an implementor overrides only the
/// hooks it needs (an empty `impl Guardrail for T {}` is valid). Implementors
/// must be `Send + Sync`.
///
/// The async hooks return a boxed `Send` future whose elided `'_` lifetime is
/// bound to `&self` only (lifetime elision picks the receiver). The future may
/// therefore borrow from `self`, but it cannot hold the other reference
/// arguments: any inspection or mutation of them has to happen before the
/// future is built, or work on owned copies moved into it.
///
// NOTE(review): the hook names suggest they run before/after each LLM
// completion and each tool invocation; the call sites are not in this file —
// confirm against the caller.
pub trait Guardrail: Send + Sync {
/// Identifier for this guardrail. Default: `"unnamed"`.
fn name(&self) -> &str {
"unnamed"
}
/// Hook that receives the completion request by mutable reference, so an
/// implementation may edit it in place. The future resolves to `Ok(())` or
/// to an `Error`.
///
/// Default: leaves the request untouched and resolves to `Ok(())`.
fn pre_llm(
&self,
_request: &mut CompletionRequest,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
Box::pin(async { Ok(()) })
}
/// Hook that receives the completion response read-only and resolves to a
/// [`GuardAction`] verdict, or to an `Error`.
///
/// Default: resolves to `Ok(GuardAction::Allow)`.
fn post_llm(
&self,
_response: &CompletionResponse,
) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
Box::pin(async { Ok(GuardAction::Allow) })
}
/// Hook that receives a tool call read-only and resolves to a
/// [`GuardAction`] verdict, or to an `Error`.
///
/// Default: resolves to `Ok(GuardAction::Allow)`.
fn pre_tool(
&self,
_call: &ToolCall,
) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
Box::pin(async { Ok(GuardAction::Allow) })
}
/// Hook that receives a tool call read-only together with its output by
/// mutable reference, so an implementation may rewrite the output in place
/// (for example, redacting text). The future resolves to `Ok(())` or to an
/// `Error`.
///
/// Default: leaves the output untouched and resolves to `Ok(())`.
fn post_tool(
&self,
_call: &ToolCall,
_output: &mut ToolOutput,
) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
Box::pin(async { Ok(()) })
}
/// Passes a turn number to the guardrail. The receiver is `&self`, so an
/// implementation that wants to remember the value needs interior
/// mutability.
///
/// Default: ignores the value.
// NOTE(review): presumably the index of the current agent-loop turn —
// confirm against the caller.
fn set_turn(&self, _turn: usize) {}
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- GuardAction constructors -------------------------------------------

    #[test]
    fn guard_action_deny_constructor() {
        if let GuardAction::Deny { reason } = GuardAction::deny("PII detected") {
            assert_eq!(reason, "PII detected");
        } else {
            panic!("expected Deny");
        }
    }

    #[test]
    fn guard_action_warn_constructor() {
        if let GuardAction::Warn { reason } = GuardAction::warn("suspicious pattern") {
            assert_eq!(reason, "suspicious pattern");
        } else {
            panic!("expected Warn");
        }
    }

    #[test]
    fn guard_action_kill_constructor() {
        if let GuardAction::Kill { reason } = GuardAction::kill("CSAM detected") {
            assert_eq!(reason, "CSAM detected");
        } else {
            panic!("expected Kill");
        }
    }

    // --- GuardAction predicates ---------------------------------------------

    #[test]
    fn guard_action_is_denied() {
        // (verdict, expected result of `is_denied`)
        let cases = [
            (GuardAction::deny("blocked"), true),
            (GuardAction::kill("critical"), true),
            (GuardAction::Allow, false),
            (GuardAction::warn("suspicious"), false),
        ];
        for (verdict, expected) in cases.iter() {
            assert_eq!(verdict.is_denied(), *expected);
        }
    }

    #[test]
    fn guard_action_is_killed() {
        // (verdict, expected result of `is_killed`)
        let cases = [
            (GuardAction::kill("critical"), true),
            (GuardAction::deny("blocked"), false),
            (GuardAction::Allow, false),
            (GuardAction::warn("suspicious"), false),
        ];
        for (verdict, expected) in cases.iter() {
            assert_eq!(verdict.is_killed(), *expected);
        }
    }

    // --- Guardrail defaults and overrides -----------------------------------

    #[test]
    fn guardrail_default_name() {
        struct MyGuardrail;
        impl Guardrail for MyGuardrail {}

        assert_eq!(MyGuardrail.name(), "unnamed");
    }

    #[test]
    fn guardrail_custom_name() {
        struct NamedGuardrail;
        impl Guardrail for NamedGuardrail {
            fn name(&self) -> &str {
                "pii_detector"
            }
        }

        assert_eq!(NamedGuardrail.name(), "pii_detector");
    }

    /// With no hook overridden, every hook succeeds, verdicts are `Allow`,
    /// and the tool output is left untouched.
    #[tokio::test]
    async fn default_guardrail_allows_everything() {
        use crate::llm::types::{StopReason, TokenUsage};

        struct NoOpGuardrail;
        impl Guardrail for NoOpGuardrail {}
        let guard = NoOpGuardrail;

        let mut request = CompletionRequest {
            system: "sys".into(),
            messages: vec![],
            tools: vec![],
            max_tokens: 1024,
            tool_choice: None,
            reasoning_effort: None,
        };
        guard.pre_llm(&mut request).await.unwrap();

        let response = CompletionResponse {
            content: vec![],
            stop_reason: StopReason::EndTurn,
            usage: TokenUsage::default(),
            model: None,
        };
        let llm_verdict = guard.post_llm(&response).await.unwrap();
        assert_eq!(llm_verdict, GuardAction::Allow);

        let call = ToolCall {
            id: "c1".into(),
            name: "test".into(),
            input: serde_json::json!({}),
        };
        let tool_verdict = guard.pre_tool(&call).await.unwrap();
        assert_eq!(tool_verdict, GuardAction::Allow);

        let mut output = ToolOutput::success("result".to_string());
        guard.post_tool(&call, &mut output).await.unwrap();
        assert_eq!(output.content, "result");
    }

    /// A `post_tool` override can rewrite the tool output in place.
    #[tokio::test]
    async fn post_tool_can_mutate_output() {
        struct RedactGuardrail;
        impl Guardrail for RedactGuardrail {
            fn post_tool(
                &self,
                _call: &ToolCall,
                output: &mut ToolOutput,
            ) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + '_>> {
                // The returned future cannot borrow `output`, so the rewrite
                // happens here, before the future is created.
                let redacted = output.content.replace("secret", "[REDACTED]");
                output.content = redacted;
                Box::pin(async { Ok(()) })
            }
        }

        let guard = RedactGuardrail;
        let call = ToolCall {
            id: "c1".into(),
            name: "test".into(),
            input: serde_json::json!({}),
        };
        let mut output = ToolOutput::success("the secret is 42".to_string());

        guard.post_tool(&call, &mut output).await.unwrap();

        assert_eq!(output.content, "the [REDACTED] is 42");
    }

    /// A `pre_tool` override can deny one tool by name and allow the rest.
    #[tokio::test]
    async fn pre_tool_deny_returns_deny_action() {
        struct BlockBashGuardrail;
        impl Guardrail for BlockBashGuardrail {
            fn pre_tool(
                &self,
                call: &ToolCall,
            ) -> Pin<Box<dyn Future<Output = Result<GuardAction, Error>> + Send + '_>> {
                // Decide up front so the future owns the verdict and does not
                // need to borrow `call`.
                let verdict = if call.name == "bash" {
                    GuardAction::deny("bash tool is disabled")
                } else {
                    GuardAction::Allow
                };
                Box::pin(async move { Ok(verdict) })
            }
        }

        let guard = BlockBashGuardrail;

        let bash_call = ToolCall {
            id: "c1".into(),
            name: "bash".into(),
            input: serde_json::json!({"command": "rm -rf /"}),
        };
        let bash_verdict = guard.pre_tool(&bash_call).await.unwrap();
        assert_eq!(bash_verdict, GuardAction::deny("bash tool is disabled"));

        let read_call = ToolCall {
            id: "c2".into(),
            name: "read".into(),
            input: serde_json::json!({"path": "/tmp/test.txt"}),
        };
        let read_verdict = guard.pre_tool(&read_call).await.unwrap();
        assert_eq!(read_verdict, GuardAction::Allow);
    }
}