use serde::{Deserialize, Serialize};
#[cfg(feature = "openapi")]
use utoipa::ToSchema;
/// Maximum number of checks a single guardrails config may declare.
pub const MAX_CHECKS: usize = 64;
/// Maximum number of entries (patterns, words, or tools) allowed in one check.
pub const MAX_ENTRIES_PER_CHECK: usize = 64;
/// Maximum length, in bytes, of a single pattern / word / tool entry.
pub const MAX_ENTRY_LEN: usize = 512;
/// Maximum length, in bytes, of a check's `replacement` text.
pub const MAX_REPLACEMENT_LEN: usize = 2_000;
/// Maximum length, in characters (not bytes), of a check's `id`.
pub const MAX_CHECK_ID_LEN: usize = 64;
/// Maximum length, in bytes, of an `llm_judge` prompt.
pub const MAX_JUDGE_PROMPT_LEN: usize = 4_000;
/// Maximum length, in bytes, of an `mcp` check's `server` or `tool` value.
pub const MAX_MCP_REF_LEN: usize = 128;
// Value passed to `regex::RegexBuilder::size_limit` for every compiled pattern (1 << 20 bytes).
const REGEX_SIZE_LIMIT: usize = 1 << 20;
// Upper bound, in bytes, on the text stored in a hit's `matched` field.
const MAX_MATCH_SNIPPET: usize = 200;
/// Default replacement text for a withheld response.
/// NOTE(review): not referenced in this file — presumably used by callers when a blocking
/// check has no `replacement` of its own; confirm against the call sites.
pub const DEFAULT_OUTPUT_REPLACEMENT: &str = "[Response withheld by a guardrail.]";
/// Default replacement text for withheld tool output.
/// NOTE(review): not referenced in this file — presumably the tool-output counterpart of
/// `DEFAULT_OUTPUT_REPLACEMENT`; confirm against the call sites.
pub const DEFAULT_TOOL_OUTPUT_REPLACEMENT: &str = "[Tool output withheld by a guardrail.]";
/// Global enforcement mode of a guardrails config (serialized in snake_case).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum GuardrailMode {
/// Default mode: a hit on a check with `on_fail = block` resolves to a block action.
#[default]
Active,
/// Every hit resolves to a log action, whatever the check's `on_fail` says.
Advisory,
}
/// Stage a check is attached to; evaluation only considers checks whose stage
/// equals the stage being evaluated. Serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[cfg_attr(feature = "openapi", schema(example = "output"))]
#[serde(rename_all = "snake_case")]
pub enum GuardrailStage {
/// Serialized as `"output"`. `llm_judge` and `mcp` checks are rejected on this stage.
Output,
/// Serialized as `"tool_use"`. The only stage on which `tool_pattern` checks are accepted.
ToolUse,
/// Serialized as `"tool_output"`.
ToolOutput,
}
impl GuardrailStage {
pub fn as_str(&self) -> &'static str {
match self {
GuardrailStage::Output => "output",
GuardrailStage::ToolUse => "tool_use",
GuardrailStage::ToolOutput => "tool_output",
}
}
}
/// Per-check policy describing what should happen when the check matches
/// (serialized in snake_case). The effective action also depends on the
/// config's `GuardrailMode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[serde(rename_all = "snake_case")]
pub enum GuardrailOnFail {
/// Default: request a block action (downgraded to log in advisory mode).
#[default]
Block,
/// Request a log action only.
Log,
}
/// The matching rule of a check. Internally tagged: the variant is selected by
/// a `"type"` field whose value is the snake_case variant name.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum GuardrailRule {
/// `"regex"`: matches when any of `patterns` finds a match in the evaluated text.
Regex { patterns: Vec<String> },
/// `"blocklist"`: matches when the evaluated text contains any of `words` as a substring.
Blocklist {
words: Vec<String>,
/// When omitted this is `false`, and both the words and the text are lowercased before comparing.
#[serde(default)]
case_sensitive: bool,
},
/// `"tool_pattern"`: matches the tool name against `*` wildcard patterns.
ToolPattern { tools: Vec<String> },
/// `"llm_judge"`: carries a prompt; compiled separately and not run by the synchronous evaluator.
LlmJudge { prompt: String },
/// `"mcp"`: names a `server` and a `tool`; compiled separately and not run by the synchronous evaluator.
Mcp { server: String, tool: String },
}
impl GuardrailRule {
pub fn rule_type(&self) -> &'static str {
match self {
GuardrailRule::Regex { .. } => "regex",
GuardrailRule::Blocklist { .. } => "blocklist",
GuardrailRule::ToolPattern { .. } => "tool_pattern",
GuardrailRule::LlmJudge { .. } => "llm_judge",
GuardrailRule::Mcp { .. } => "mcp",
}
}
}
/// One configured check: a stage, a failure policy, and a rule.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct GuardrailCheck {
/// Optional identifier used as the check's label; omitted from output when `None`.
/// When absent, compilation derives a label of the form `<rule_type>#<index>`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
/// Stage this check applies to.
pub stage: GuardrailStage,
/// What to do on a match; defaults to `block` when omitted.
#[serde(default)]
pub on_fail: GuardrailOnFail,
/// Optional replacement text carried through to hits; omitted from output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub replacement: Option<String>,
/// The rule itself. Flattened, so its `type` tag and fields sit directly on the check object.
#[serde(flatten)]
pub rule: GuardrailRule,
}
/// Declarative guardrails configuration: a global mode plus an ordered list of checks.
/// Both fields are optional when deserializing.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct GuardrailsConfig {
/// Global enforcement mode; defaults to `active`.
#[serde(default)]
pub mode: GuardrailMode,
/// Checks in declaration order; a check's position here is its index in compiled output and hits.
#[serde(default)]
pub checks: Vec<GuardrailCheck>,
}
impl GuardrailsConfig {
    /// Parses a config from a JSON value.
    ///
    /// A JSON `null` yields the default config (active mode, no checks).
    ///
    /// # Errors
    ///
    /// Returns `"invalid guardrails config: <serde error>"` when the value does
    /// not deserialize into a [`GuardrailsConfig`].
    pub fn from_value(value: &serde_json::Value) -> Result<Self, String> {
        if value.is_null() {
            return Ok(Self::default());
        }
        // `&serde_json::Value` is itself a `Deserializer`, so the tree is read
        // in place instead of being deep-cloned first.
        Self::deserialize(value).map_err(|e| format!("invalid guardrails config: {e}"))
    }

    /// Validates every check and builds the runtime representation.
    ///
    /// Regex, blocklist and tool-pattern checks go to the synchronous check
    /// list; `llm_judge` and `mcp` checks are collected into their own lists.
    /// Each compiled check keeps its position in `self.checks` as its index.
    ///
    /// # Errors
    ///
    /// Returns a message when there are more than [`MAX_CHECKS`] checks, or
    /// the first validation error reported for an individual check.
    pub fn compile(&self) -> Result<CompiledGuardrails, String> {
        if self.checks.len() > MAX_CHECKS {
            return Err(format!(
                "too many checks: {} (max {MAX_CHECKS})",
                self.checks.len()
            ));
        }
        let mut compiled = Vec::with_capacity(self.checks.len());
        let mut judge_checks = Vec::new();
        let mut mcp_checks = Vec::new();
        for (index, check) in self.checks.iter().enumerate() {
            // Deliberately exhaustive (no `_` arm): adding a rule variant must
            // force a routing decision here at compile time.
            match &check.rule {
                GuardrailRule::LlmJudge { prompt } => {
                    judge_checks.push(compile_judge_check(index, check, prompt)?);
                }
                GuardrailRule::Mcp { server, tool } => {
                    mcp_checks.push(compile_mcp_check(index, check, server, tool)?);
                }
                GuardrailRule::Regex { .. }
                | GuardrailRule::Blocklist { .. }
                | GuardrailRule::ToolPattern { .. } => {
                    compiled.push(compile_check(index, check)?);
                }
            }
        }
        Ok(CompiledGuardrails {
            mode: self.mode,
            checks: compiled,
            judge_checks,
            mcp_checks,
        })
    }
}
/// Effective action attached to a hit, after the config mode and the check's
/// `on_fail` policy have been combined. Serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
#[cfg_attr(feature = "openapi", schema(example = "block"))]
#[serde(rename_all = "snake_case")]
pub enum GuardrailAction {
/// Produced only in active mode for checks whose `on_fail` is `block`.
Block,
/// Produced in advisory mode, or for checks whose `on_fail` is `log`.
Log,
}
/// A single check that matched during evaluation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GuardrailHit {
/// Position of the check in the config's `checks` list.
pub check_index: usize,
/// The check's `id`, or a generated `<rule_type>#<index>` label when no id was given.
pub check_label: String,
/// Stage of the check that matched.
pub stage: GuardrailStage,
/// Snake_case rule tag, e.g. `"regex"`.
pub rule_type: &'static str,
/// Effective action after applying mode and `on_fail`.
pub action: GuardrailAction,
/// Reason identifier; the synchronous evaluator emits `guardrail.<rule_type>`.
pub reason_code: String,
/// The check's configured replacement text, if any.
pub replacement: Option<String>,
/// What matched, truncated to a bounded snippet: the regex match, the blocklist
/// word (lowercased for case-insensitive lists), or the tool name.
pub matched: Option<String>,
}
/// A validated `llm_judge` check. The synchronous evaluator does not run these;
/// they are exposed per stage for the caller to execute.
#[derive(Debug)]
pub struct CompiledJudgeCheck {
/// Position of the check in the config's `checks` list.
pub index: usize,
/// The check's `id`, or `llm_judge#<index>` when no id was given.
pub label: String,
/// Stage the check applies to (`tool_use` or `tool_output`; `output` is rejected at compile time).
pub stage: GuardrailStage,
/// Failure policy as configured; combine with the mode to get the action.
pub on_fail: GuardrailOnFail,
/// Configured replacement text, if any.
pub replacement: Option<String>,
/// Non-empty judge prompt, at most `MAX_JUDGE_PROMPT_LEN` bytes.
pub prompt: String,
}
/// A validated `mcp` check. The synchronous evaluator does not run these;
/// they are exposed per stage for the caller to execute.
#[derive(Debug)]
pub struct CompiledMcpCheck {
/// Position of the check in the config's `checks` list.
pub index: usize,
/// The check's `id`, or `mcp#<index>` when no id was given.
pub label: String,
/// Stage the check applies to (`tool_use` or `tool_output`; `output` is rejected at compile time).
pub stage: GuardrailStage,
/// Failure policy as configured; combine with the mode to get the action.
pub on_fail: GuardrailOnFail,
/// Configured replacement text, if any.
pub replacement: Option<String>,
/// Non-empty server reference, at most `MAX_MCP_REF_LEN` bytes.
pub server: String,
/// Non-empty tool reference, at most `MAX_MCP_REF_LEN` bytes.
pub tool: String,
}
/// Validated, ready-to-evaluate form of a `GuardrailsConfig`.
#[derive(Debug)]
pub struct CompiledGuardrails {
// Global mode copied from the config.
mode: GuardrailMode,
// Regex / blocklist / tool-pattern checks, evaluated synchronously.
checks: Vec<CompiledCheck>,
// `llm_judge` checks, kept apart from the synchronous list.
judge_checks: Vec<CompiledJudgeCheck>,
// `mcp` checks, kept apart from the synchronous list.
mcp_checks: Vec<CompiledMcpCheck>,
}
/// A validated synchronous check (regex, blocklist or tool pattern).
#[derive(Debug)]
struct CompiledCheck {
// Position of the check in the config's `checks` list.
index: usize,
// The check's `id`, or `<rule_type>#<index>` when no id was given.
label: String,
stage: GuardrailStage,
on_fail: GuardrailOnFail,
replacement: Option<String>,
// Snake_case rule tag, copied from `GuardrailRule::rule_type`.
rule_type: &'static str,
// Pre-processed matcher used at evaluation time.
matcher: CompiledRule,
}
/// Matcher payload of a synchronous check.
#[derive(Debug)]
enum CompiledRule {
// Patterns compiled once at config-compile time.
Regex(Vec<regex::Regex>),
Blocklist {
// Already lowercased when `case_sensitive` is false.
words: Vec<String>,
case_sensitive: bool,
},
// Wildcard patterns, stored as written in the config.
ToolPattern(Vec<String>),
}
impl CompiledGuardrails {
    /// Returns the global enforcement mode this set was compiled with.
    pub fn mode(&self) -> GuardrailMode {
        self.mode
    }

    /// Reports whether any check — synchronous, judge or MCP — targets `stage`.
    pub fn has_stage(&self, stage: GuardrailStage) -> bool {
        self.checks.iter().any(|c| c.stage == stage)
            || self.judge_checks.iter().any(|c| c.stage == stage)
            || self.mcp_checks.iter().any(|c| c.stage == stage)
    }

    /// Iterates over the `llm_judge` checks attached to `stage`, in config order.
    pub fn judge_checks_for_stage(
        &self,
        stage: GuardrailStage,
    ) -> impl Iterator<Item = &CompiledJudgeCheck> {
        self.judge_checks.iter().filter(move |c| c.stage == stage)
    }

    /// Iterates over the `mcp` checks attached to `stage`, in config order.
    pub fn mcp_checks_for_stage(
        &self,
        stage: GuardrailStage,
    ) -> impl Iterator<Item = &CompiledMcpCheck> {
        self.mcp_checks.iter().filter(move |c| c.stage == stage)
    }

    /// Resolves the effective action for a check with the given `on_fail`
    /// policy: `Block` only in active mode with `on_fail = block`, `Log` in
    /// every other combination.
    pub fn async_action(&self, on_fail: GuardrailOnFail) -> GuardrailAction {
        match (self.mode, on_fail) {
            (GuardrailMode::Advisory, _) | (_, GuardrailOnFail::Log) => GuardrailAction::Log,
            (GuardrailMode::Active, GuardrailOnFail::Block) => GuardrailAction::Block,
        }
    }

    /// Alias of [`Self::async_action`], kept for judge-check callers.
    pub fn judge_action(&self, on_fail: GuardrailOnFail) -> GuardrailAction {
        self.async_action(on_fail)
    }

    /// Runs every synchronous check attached to `stage` and returns one hit per
    /// matching check, in config order.
    ///
    /// * `text` — content inspected by regex and blocklist checks.
    /// * `tool_name` — name inspected by tool-pattern checks; with `None`,
    ///   those checks never match.
    /// * `skip` — called with a check's config index (only for checks on
    ///   `stage`); returning `true` suppresses that check.
    ///
    /// `llm_judge` and `mcp` checks are never evaluated here.
    pub fn evaluate(
        &self,
        stage: GuardrailStage,
        text: &str,
        tool_name: Option<&str>,
        skip: &dyn Fn(usize) -> bool,
    ) -> Vec<GuardrailHit> {
        // Lowercase the text at most once, and only if a case-insensitive
        // blocklist is actually reached.
        let lowercased: std::cell::OnceCell<String> = std::cell::OnceCell::new();
        let mut hits = Vec::new();
        for check in &self.checks {
            if check.stage != stage || skip(check.index) {
                continue;
            }
            let matched = match &check.matcher {
                CompiledRule::Regex(patterns) => patterns
                    .iter()
                    .find_map(|re| re.find(text).map(|m| snippet(m.as_str()))),
                CompiledRule::Blocklist {
                    words,
                    case_sensitive,
                } => {
                    let haystack: &str = if *case_sensitive {
                        text
                    } else {
                        lowercased.get_or_init(|| text.to_lowercase())
                    };
                    words
                        .iter()
                        .find(|w| haystack.contains(w.as_str()))
                        .map(|w| snippet(w))
                }
                CompiledRule::ToolPattern(patterns) => tool_name.and_then(|name| {
                    patterns
                        .iter()
                        .find(|p| wildcard_match(p, name))
                        .map(|_| snippet(name))
                }),
            };
            let Some(matched) = matched else {
                continue;
            };
            hits.push(GuardrailHit {
                check_index: check.index,
                check_label: check.label.clone(),
                stage: check.stage,
                rule_type: check.rule_type,
                // Same mode/on_fail resolution as judge and MCP checks, so the
                // action table lives in exactly one place.
                action: self.async_action(check.on_fail),
                reason_code: format!("guardrail.{}", check.rule_type),
                replacement: check.replacement.clone(),
                matched: Some(matched),
            });
        }
        hits
    }
}
/// Validates one synchronous check (regex, blocklist or tool pattern) and
/// builds its matcher.
///
/// Validation order: `id` length, `replacement` length, then rule-specific
/// checks. The first failure is returned as a message prefixed with the
/// check's label (or `#index` when the id itself is invalid).
///
/// # Panics
///
/// Panics if handed an `llm_judge` or `mcp` check; those are routed to their
/// own compile functions by the caller.
fn compile_check(index: usize, check: &GuardrailCheck) -> Result<CompiledCheck, String> {
    let label = match check.id.as_deref() {
        None => format!("{}#{}", check.rule.rule_type(), index),
        Some(id) if id.is_empty() || id.chars().count() > MAX_CHECK_ID_LEN => {
            return Err(format!(
                "check #{index}: id must be 1..={MAX_CHECK_ID_LEN} characters"
            ));
        }
        Some(id) => id.to_owned(),
    };

    let replacement_too_long = check
        .replacement
        .as_ref()
        .is_some_and(|text| text.len() > MAX_REPLACEMENT_LEN);
    if replacement_too_long {
        return Err(format!(
            "check '{label}': replacement exceeds {MAX_REPLACEMENT_LEN} bytes"
        ));
    }

    let matcher = match &check.rule {
        GuardrailRule::Regex { patterns } => {
            validate_entries(&label, "patterns", patterns)?;
            // Stops at the first pattern that fails to build.
            let regexes = patterns
                .iter()
                .map(|pattern| {
                    regex::RegexBuilder::new(pattern)
                        .size_limit(REGEX_SIZE_LIMIT)
                        .build()
                        .map_err(|e| format!("check '{label}': invalid regex '{pattern}': {e}"))
                })
                .collect::<Result<Vec<_>, String>>()?;
            CompiledRule::Regex(regexes)
        }
        GuardrailRule::Blocklist {
            words,
            case_sensitive,
        } => {
            validate_entries(&label, "words", words)?;
            // Case-insensitive lists are lowercased once here so evaluation
            // only has to lowercase the haystack.
            let normalized: Vec<String> = if *case_sensitive {
                words.to_vec()
            } else {
                words.iter().map(|word| word.to_lowercase()).collect()
            };
            CompiledRule::Blocklist {
                words: normalized,
                case_sensitive: *case_sensitive,
            }
        }
        GuardrailRule::ToolPattern { tools } => {
            if check.stage != GuardrailStage::ToolUse {
                return Err(format!(
                    "check '{label}': tool_pattern is only valid for the tool_use stage"
                ));
            }
            validate_entries(&label, "tools", tools)?;
            CompiledRule::ToolPattern(tools.to_vec())
        }
        GuardrailRule::LlmJudge { .. } => {
            unreachable!(
                "llm_judge checks are routed to compile_judge_check before compile_check is called"
            )
        }
        GuardrailRule::Mcp { .. } => {
            unreachable!(
                "mcp checks are routed to compile_mcp_check before compile_check is called"
            )
        }
    };

    Ok(CompiledCheck {
        index,
        label,
        stage: check.stage,
        on_fail: check.on_fail,
        replacement: check.replacement.clone(),
        rule_type: check.rule.rule_type(),
        matcher,
    })
}
/// Validates an `llm_judge` check and packages it for the caller to run.
///
/// Validation order: `id` length, empty prompt, prompt length, stage (the
/// `output` stage is rejected), `replacement` length. The first failure is
/// returned as an error message.
fn compile_judge_check(
    index: usize,
    check: &GuardrailCheck,
    prompt: &str,
) -> Result<CompiledJudgeCheck, String> {
    let label = match check.id.as_deref() {
        None => format!("llm_judge#{index}"),
        Some(id) if id.is_empty() || id.chars().count() > MAX_CHECK_ID_LEN => {
            return Err(format!(
                "check #{index}: id must be 1..={MAX_CHECK_ID_LEN} characters"
            ));
        }
        Some(id) => id.to_owned(),
    };

    if prompt.is_empty() {
        return Err(format!(
            "check '{label}': llm_judge prompt must not be empty"
        ));
    }
    if prompt.len() > MAX_JUDGE_PROMPT_LEN {
        return Err(format!(
            "check '{label}': llm_judge prompt exceeds {MAX_JUDGE_PROMPT_LEN} bytes"
        ));
    }
    if check.stage == GuardrailStage::Output {
        return Err(format!(
            "check '{label}': llm_judge is not supported on the 'output' stage in this phase; \
             use 'tool_use' or 'tool_output'"
        ));
    }

    let replacement_too_long = check
        .replacement
        .as_ref()
        .is_some_and(|text| text.len() > MAX_REPLACEMENT_LEN);
    if replacement_too_long {
        return Err(format!(
            "check '{label}': replacement exceeds {MAX_REPLACEMENT_LEN} bytes"
        ));
    }

    Ok(CompiledJudgeCheck {
        index,
        label,
        stage: check.stage,
        on_fail: check.on_fail,
        replacement: check.replacement.clone(),
        prompt: prompt.to_owned(),
    })
}
/// Validates an `mcp` check and packages it for the caller to run.
///
/// Validation order: `id` length, stage (the `output` stage is rejected),
/// `server` then `tool` (each: non-empty, then length), `replacement` length.
/// The first failure is returned as an error message.
fn compile_mcp_check(
    index: usize,
    check: &GuardrailCheck,
    server: &str,
    tool: &str,
) -> Result<CompiledMcpCheck, String> {
    let label = match check.id.as_deref() {
        None => format!("mcp#{index}"),
        Some(id) if id.is_empty() || id.chars().count() > MAX_CHECK_ID_LEN => {
            return Err(format!(
                "check #{index}: id must be 1..={MAX_CHECK_ID_LEN} characters"
            ));
        }
        Some(id) => id.to_owned(),
    };

    if check.stage == GuardrailStage::Output {
        return Err(format!(
            "check '{label}': mcp is not supported on the 'output' stage in this phase; \
             use 'tool_use' or 'tool_output'"
        ));
    }

    // Shared rule for both references: non-empty and within the byte budget.
    let check_ref = |field: &str, value: &str| -> Result<(), String> {
        if value.is_empty() {
            return Err(format!("check '{label}': mcp {field} must not be empty"));
        }
        if value.len() > MAX_MCP_REF_LEN {
            return Err(format!(
                "check '{label}': mcp {field} exceeds {MAX_MCP_REF_LEN} bytes"
            ));
        }
        Ok(())
    };
    check_ref("server", server)?;
    check_ref("tool", tool)?;

    let replacement_too_long = check
        .replacement
        .as_ref()
        .is_some_and(|text| text.len() > MAX_REPLACEMENT_LEN);
    if replacement_too_long {
        return Err(format!(
            "check '{label}': replacement exceeds {MAX_REPLACEMENT_LEN} bytes"
        ));
    }

    Ok(CompiledMcpCheck {
        index,
        label,
        stage: check.stage,
        on_fail: check.on_fail,
        replacement: check.replacement.clone(),
        server: server.to_owned(),
        tool: tool.to_owned(),
    })
}
/// Checks the entry list of a rule: it must be non-empty, hold at most
/// `MAX_ENTRIES_PER_CHECK` entries, and every entry must be non-empty and at
/// most `MAX_ENTRY_LEN` bytes long.
///
/// `label` and `field` are used only to build the error message. Entries are
/// checked in order and the first violation is reported.
fn validate_entries(label: &str, field: &str, entries: &[String]) -> Result<(), String> {
    let count = entries.len();
    if count == 0 {
        return Err(format!("check '{label}': {field} must not be empty"));
    }
    if count > MAX_ENTRIES_PER_CHECK {
        return Err(format!(
            "check '{label}': too many {field}: {} (max {MAX_ENTRIES_PER_CHECK})",
            count
        ));
    }
    entries.iter().try_for_each(|entry| match entry.len() {
        0 => Err(format!(
            "check '{label}': {field} entries must not be empty"
        )),
        len if len > MAX_ENTRY_LEN => Err(format!(
            "check '{label}': {field} entry exceeds {MAX_ENTRY_LEN} bytes"
        )),
        _ => Ok(()),
    })
}
/// Returns at most `MAX_MATCH_SNIPPET` bytes from the start of `s`, shortened
/// further if needed so the cut lands on a UTF-8 character boundary.
fn snippet(s: &str) -> String {
    let limit = s.len().min(MAX_MATCH_SNIPPET);
    // Offset 0 is always a boundary, so the search cannot come back empty.
    let cut = (0..=limit)
        .rev()
        .find(|&offset| s.is_char_boundary(offset))
        .unwrap_or(0);
    s[..cut].to_owned()
}
/// Matches `name` against a glob-style `pattern` in which `*` stands for any
/// run of characters (including none). Every other character is literal.
///
/// A pattern without `*` must equal `name` exactly. Otherwise the text before
/// the first `*` must be a prefix of `name`, the text after the last `*` must
/// be a suffix of what remains, and the pieces in between must occur in order.
pub fn wildcard_match(pattern: &str, name: &str) -> bool {
    let Some((head, tail)) = pattern.split_once('*') else {
        return pattern == name;
    };
    let Some(mut remaining) = name.strip_prefix(head) else {
        return false;
    };
    // Split what follows the first `*` into the inner pieces and the final,
    // end-anchored piece.
    let (inner, last) = tail.rsplit_once('*').unwrap_or(("", tail));
    for piece in inner.split('*').filter(|piece| !piece.is_empty()) {
        // Taking the leftmost occurrence leaves the most text for later pieces.
        let Some(at) = remaining.find(piece) else {
            return false;
        };
        remaining = &remaining[at + piece.len()..];
    }
    // An empty final piece (pattern ends in `*`) matches any remainder.
    remaining.ends_with(last)
}
// Unit tests for config parsing, compilation limits, synchronous evaluation,
// wildcard matching, and the separately-routed `llm_judge` / `mcp` checks.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
// Skip predicate that suppresses nothing.
fn no_skip() -> impl Fn(usize) -> bool {
|_| false
}
// Parses a JSON config and compiles it, surfacing either step's error.
fn compile(value: serde_json::Value) -> Result<CompiledGuardrails, String> {
GuardrailsConfig::from_value(&value)?.compile()
}
// A config with only `checks` defaults to active mode and reports its stage.
#[test]
fn parses_and_compiles_minimal_config() {
let compiled = compile(json!({
"checks": [
{"stage": "output", "type": "blocklist", "words": ["forbidden"]},
]
}))
.expect("compiles");
assert_eq!(compiled.mode(), GuardrailMode::Active);
assert!(compiled.has_stage(GuardrailStage::Output));
assert!(!compiled.has_stage(GuardrailStage::ToolUse));
}
// `{}` and `null` both compile to a set with no checks.
#[test]
fn empty_or_null_config_compiles_to_no_checks() {
let compiled = compile(json!({})).expect("compiles");
assert!(!compiled.has_stage(GuardrailStage::Output));
let compiled = GuardrailsConfig::from_value(&serde_json::Value::Null)
.unwrap()
.compile()
.unwrap();
assert!(!compiled.has_stage(GuardrailStage::Output));
}
// Without `case_sensitive`, matching ignores case and reports the lowercased word.
#[test]
fn blocklist_matches_case_insensitive_by_default() {
let compiled = compile(json!({
"checks": [
{"stage": "output", "type": "blocklist", "words": ["Secret Word"]},
]
}))
.unwrap();
let hits = compiled.evaluate(
GuardrailStage::Output,
"this contains a SECRET word inside",
None,
&no_skip(),
);
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].action, GuardrailAction::Block);
assert_eq!(hits[0].reason_code, "guardrail.blocklist");
assert_eq!(hits[0].matched.as_deref(), Some("secret word"));
}
// With `case_sensitive: true`, only the exact casing matches.
#[test]
fn blocklist_case_sensitive_only_matches_exact_case() {
let compiled = compile(json!({
"checks": [
{"stage": "output", "type": "blocklist", "words": ["Secret"], "case_sensitive": true},
]
}))
.unwrap();
assert!(
compiled
.evaluate(GuardrailStage::Output, "a secret here", None, &no_skip())
.is_empty()
);
assert_eq!(
compiled
.evaluate(GuardrailStage::Output, "a Secret here", None, &no_skip())
.len(),
1
);
}
// A regex check hits and carries the check's id and rule type.
// NOTE(review): despite the name, nothing here asserts on the pattern source or on
// `matched` — consider renaming or adding that assertion.
#[test]
fn regex_matches_and_reports_pattern_source() {
let compiled = compile(json!({
"checks": [
{"id": "ssn", "stage": "output", "type": "regex",
"patterns": ["\\b\\d{3}-\\d{2}-\\d{4}\\b"]},
]
}))
.unwrap();
let hits = compiled.evaluate(
GuardrailStage::Output,
"my ssn is 123-45-6789 ok",
None,
&no_skip(),
);
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].check_label, "ssn");
assert_eq!(hits[0].rule_type, "regex");
}
// An unparsable regex fails compilation with the check's label in the message.
#[test]
fn invalid_regex_fails_compile_with_check_label() {
let err = compile(json!({
"checks": [
{"id": "bad", "stage": "output", "type": "regex", "patterns": ["("]},
]
}))
.unwrap_err();
assert!(err.contains("check 'bad'"), "{err}");
}
// Tool patterns match the tool name (not the text) and report it as `matched`.
#[test]
fn tool_pattern_matches_wildcards_on_tool_use_stage() {
let compiled = compile(json!({
"checks": [
{"stage": "tool_use", "type": "tool_pattern", "tools": ["mcp_*", "bash*"]},
]
}))
.unwrap();
let hits = compiled.evaluate(
GuardrailStage::ToolUse,
"{\"cmd\":\"ls\"}",
Some("mcp_github__create_issue"),
&no_skip(),
);
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].matched.as_deref(), Some("mcp_github__create_issue"));
assert!(
compiled
.evaluate(GuardrailStage::ToolUse, "{}", Some("read_file"), &no_skip())
.is_empty()
);
}
// `tool_pattern` on any stage other than `tool_use` is a compile error.
#[test]
fn tool_pattern_rejected_outside_tool_use_stage() {
let err = compile(json!({
"checks": [
{"stage": "output", "type": "tool_pattern", "tools": ["bash*"]},
]
}))
.unwrap_err();
assert!(err.contains("only valid for the tool_use stage"), "{err}");
}
// Advisory mode turns a `block` policy into a log action.
#[test]
fn advisory_mode_downgrades_block_to_log() {
let compiled = compile(json!({
"mode": "advisory",
"checks": [
{"stage": "output", "type": "blocklist", "words": ["x"], "on_fail": "block"},
]
}))
.unwrap();
let hits = compiled.evaluate(GuardrailStage::Output, "x", None, &no_skip());
assert_eq!(hits[0].action, GuardrailAction::Log);
}
// `on_fail: log` yields a log action even in (default) active mode.
#[test]
fn on_fail_log_yields_log_action_in_active_mode() {
let compiled = compile(json!({
"checks": [
{"stage": "output", "type": "blocklist", "words": ["x"], "on_fail": "log"},
]
}))
.unwrap();
let hits = compiled.evaluate(GuardrailStage::Output, "x", None, &no_skip());
assert_eq!(hits[0].action, GuardrailAction::Log);
}
// The skip predicate suppresses checks by their config index.
#[test]
fn skip_suppresses_checks_by_index() {
let compiled = compile(json!({
"checks": [
{"stage": "output", "type": "blocklist", "words": ["x"]},
{"stage": "output", "type": "blocklist", "words": ["y"]},
]
}))
.unwrap();
let hits = compiled.evaluate(GuardrailStage::Output, "x y", None, &|i| i == 0);
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].check_index, 1);
}
// Covers three limits: too many checks, an over-long entry, and an empty entry list.
#[test]
fn enforces_limits() {
let too_many_checks: Vec<_> = (0..=MAX_CHECKS)
.map(|_| json!({"stage": "output", "type": "blocklist", "words": ["x"]}))
.collect();
assert!(
compile(json!({"checks": too_many_checks}))
.unwrap_err()
.contains("too many checks")
);
let long_entry = "a".repeat(MAX_ENTRY_LEN + 1);
assert!(
compile(json!({
"checks": [{"stage": "output", "type": "blocklist", "words": [long_entry]}]
}))
.unwrap_err()
.contains("exceeds")
);
assert!(
compile(json!({
"checks": [{"stage": "output", "type": "blocklist", "words": []}]
}))
.unwrap_err()
.contains("must not be empty")
);
}
// A wrongly-typed `checks` value produces the prefixed parse error.
// NOTE(review): the name mentions unknown fields, but the input exercises a type
// mismatch on a known field; unknown fields are not covered by this test.
#[test]
fn unknown_fields_are_rejected_gracefully_by_value_parse() {
let err = GuardrailsConfig::from_value(&json!({"checks": "nope"})).unwrap_err();
assert!(err.contains("invalid guardrails config"), "{err}");
}
// Prefix, suffix, inner-star, lone-star and exact patterns, plus non-matches.
#[test]
fn wildcard_match_covers_anchors_and_inner_stars() {
assert!(wildcard_match("bash*", "bashkit_exec"));
assert!(wildcard_match("*_file", "read_file"));
assert!(wildcard_match("mcp_*__delete_*", "mcp_github__delete_repo"));
assert!(wildcard_match("*", "anything"));
assert!(wildcard_match("exact", "exact"));
assert!(!wildcard_match("exact", "exact_no"));
assert!(!wildcard_match("bash*", "zsh"));
assert!(!wildcard_match(
"mcp_*__delete_*",
"mcp_github__create_repo"
));
}
// Serialize → deserialize returns an equal config; the rule is flattened into the check.
#[test]
fn config_roundtrips_serde() {
let cfg = GuardrailsConfig {
mode: GuardrailMode::Advisory,
checks: vec![GuardrailCheck {
id: Some("c1".into()),
stage: GuardrailStage::ToolUse,
on_fail: GuardrailOnFail::Log,
replacement: None,
rule: GuardrailRule::ToolPattern {
tools: vec!["bash*".into()],
},
}],
};
let value = serde_json::to_value(&cfg).unwrap();
assert_eq!(value["checks"][0]["type"], "tool_pattern");
assert_eq!(value["checks"][0]["stage"], "tool_use");
let back = GuardrailsConfig::from_value(&value).unwrap();
assert_eq!(back, cfg);
}
// Judge checks compile on both tool stages, are listed per stage, and stay out of `evaluate`.
#[test]
fn llm_judge_compiles_for_tool_stages() {
let compiled = compile(json!({
"checks": [
{"stage": "tool_use", "type": "llm_judge", "prompt": "Block requests to delete data."},
{"id": "tj2", "stage": "tool_output", "type": "llm_judge",
"prompt": "Block responses containing PII.", "on_fail": "log"},
]
}))
.expect("compiles");
assert!(compiled.has_stage(GuardrailStage::ToolUse));
assert!(compiled.has_stage(GuardrailStage::ToolOutput));
assert!(
compiled
.evaluate(
GuardrailStage::ToolUse,
"{}",
Some("delete_user"),
&no_skip()
)
.is_empty()
);
let use_checks: Vec<_> = compiled
.judge_checks_for_stage(GuardrailStage::ToolUse)
.collect();
assert_eq!(use_checks.len(), 1);
assert_eq!(use_checks[0].prompt, "Block requests to delete data.");
assert_eq!(use_checks[0].on_fail, GuardrailOnFail::Block);
let out_checks: Vec<_> = compiled
.judge_checks_for_stage(GuardrailStage::ToolOutput)
.collect();
assert_eq!(out_checks.len(), 1);
assert_eq!(out_checks[0].label, "tj2");
assert_eq!(out_checks[0].on_fail, GuardrailOnFail::Log);
}
// Judge checks are not accepted on the `output` stage.
#[test]
fn llm_judge_rejected_on_output_stage() {
let err = compile(json!({
"checks": [
{"stage": "output", "type": "llm_judge", "prompt": "Block bad content."},
]
}))
.unwrap_err();
assert!(err.contains("not supported on the 'output' stage"), "{err}");
}
// An empty judge prompt is a compile error.
#[test]
fn llm_judge_empty_prompt_rejected() {
let err = compile(json!({
"checks": [
{"stage": "tool_use", "type": "llm_judge", "prompt": ""},
]
}))
.unwrap_err();
assert!(err.contains("prompt must not be empty"), "{err}");
}
// A prompt one byte over the limit is a compile error.
#[test]
fn llm_judge_prompt_too_long_rejected() {
let long_prompt = "x".repeat(MAX_JUDGE_PROMPT_LEN + 1);
let err = compile(json!({
"checks": [
{"stage": "tool_use", "type": "llm_judge", "prompt": long_prompt},
]
}))
.unwrap_err();
assert!(err.contains("exceeds"), "{err}");
}
// The synchronous evaluator never produces hits for judge checks.
#[test]
fn llm_judge_not_in_sync_evaluate() {
let compiled = compile(json!({
"checks": [
{"stage": "tool_use", "type": "llm_judge", "prompt": "Block everything."},
]
}))
.unwrap();
assert!(
compiled
.evaluate(
GuardrailStage::ToolUse,
"anything",
Some("tool"),
&no_skip()
)
.is_empty()
);
}
// Advisory mode downgrades a judge check's `block` policy to log.
#[test]
fn llm_judge_advisory_downgrades_judge_action() {
let compiled = compile(json!({
"mode": "advisory",
"checks": [
{"stage": "tool_use", "type": "llm_judge", "prompt": "p", "on_fail": "block"},
]
}))
.unwrap();
let check = compiled
.judge_checks_for_stage(GuardrailStage::ToolUse)
.next()
.unwrap();
assert_eq!(compiled.judge_action(check.on_fail), GuardrailAction::Log);
}
// Active mode keeps a judge check's `block` policy as a block action.
#[test]
fn llm_judge_active_block_yields_block_action() {
let compiled = compile(json!({
"checks": [
{"stage": "tool_use", "type": "llm_judge", "prompt": "p", "on_fail": "block"},
]
}))
.unwrap();
let check = compiled
.judge_checks_for_stage(GuardrailStage::ToolUse)
.next()
.unwrap();
assert_eq!(compiled.judge_action(check.on_fail), GuardrailAction::Block);
}
// Judge rules round-trip through serde with the prompt flattened into the check.
#[test]
fn llm_judge_serde_roundtrip() {
let cfg = GuardrailsConfig {
mode: GuardrailMode::Active,
checks: vec![GuardrailCheck {
id: Some("pii-judge".into()),
stage: GuardrailStage::ToolOutput,
on_fail: GuardrailOnFail::Log,
replacement: None,
rule: GuardrailRule::LlmJudge {
prompt: "Block responses that contain PII.".into(),
},
}],
};
let value = serde_json::to_value(&cfg).unwrap();
assert_eq!(value["checks"][0]["type"], "llm_judge");
assert_eq!(value["checks"][0]["stage"], "tool_output");
assert_eq!(
value["checks"][0]["prompt"],
"Block responses that contain PII."
);
let back = GuardrailsConfig::from_value(&value).unwrap();
assert_eq!(back, cfg);
}
// Synchronous and judge checks on the same stage are compiled into separate lists.
#[test]
fn mixed_sync_and_judge_checks_compile_independently() {
let compiled = compile(json!({
"checks": [
{"stage": "tool_use", "type": "tool_pattern", "tools": ["bash*"]},
{"stage": "tool_use", "type": "llm_judge", "prompt": "Block policy violations."},
]
}))
.unwrap();
let hits = compiled.evaluate(GuardrailStage::ToolUse, "{}", Some("bash_exec"), &no_skip());
assert_eq!(hits.len(), 1);
assert_eq!(hits[0].rule_type, "tool_pattern");
let judges: Vec<_> = compiled
.judge_checks_for_stage(GuardrailStage::ToolUse)
.collect();
assert_eq!(judges.len(), 1);
}
// MCP checks compile on both tool stages, are listed per stage, and stay out of `evaluate`.
#[test]
fn mcp_compiles_for_tool_stages() {
let compiled = compile(json!({
"checks": [
{"stage": "tool_use", "type": "mcp", "server": "guard", "tool": "screen"},
{"id": "mc2", "stage": "tool_output", "type": "mcp",
"server": "guard", "tool": "scan", "on_fail": "log"},
]
}))
.expect("compiles");
assert!(compiled.has_stage(GuardrailStage::ToolUse));
assert!(compiled.has_stage(GuardrailStage::ToolOutput));
assert!(
compiled
.evaluate(
GuardrailStage::ToolUse,
"{}",
Some("delete_user"),
&no_skip()
)
.is_empty()
);
let use_checks: Vec<_> = compiled
.mcp_checks_for_stage(GuardrailStage::ToolUse)
.collect();
assert_eq!(use_checks.len(), 1);
assert_eq!(use_checks[0].server, "guard");
assert_eq!(use_checks[0].tool, "screen");
assert_eq!(use_checks[0].on_fail, GuardrailOnFail::Block);
let out_checks: Vec<_> = compiled
.mcp_checks_for_stage(GuardrailStage::ToolOutput)
.collect();
assert_eq!(out_checks.len(), 1);
assert_eq!(out_checks[0].label, "mc2");
assert_eq!(out_checks[0].on_fail, GuardrailOnFail::Log);
}
// MCP checks are not accepted on the `output` stage.
#[test]
fn mcp_rejected_on_output_stage() {
let err = compile(json!({
"checks": [
{"stage": "output", "type": "mcp", "server": "guard", "tool": "scan"},
]
}))
.unwrap_err();
assert!(err.contains("not supported on the 'output' stage"), "{err}");
}
// Empty `server` and empty `tool` are each rejected with a field-specific message.
#[test]
fn mcp_empty_server_or_tool_rejected() {
let err = compile(json!({
"checks": [
{"stage": "tool_use", "type": "mcp", "server": "", "tool": "scan"},
]
}))
.unwrap_err();
assert!(err.contains("server must not be empty"), "{err}");
let err = compile(json!({
"checks": [
{"stage": "tool_use", "type": "mcp", "server": "guard", "tool": ""},
]
}))
.unwrap_err();
assert!(err.contains("tool must not be empty"), "{err}");
}
// A server reference one byte over the limit is a compile error.
#[test]
fn mcp_ref_too_long_rejected() {
let long = "x".repeat(MAX_MCP_REF_LEN + 1);
let err = compile(json!({
"checks": [
{"stage": "tool_use", "type": "mcp", "server": long, "tool": "scan"},
]
}))
.unwrap_err();
assert!(err.contains("exceeds"), "{err}");
}
// The synchronous evaluator never produces hits for MCP checks.
#[test]
fn mcp_not_in_sync_evaluate() {
let compiled = compile(json!({
"checks": [
{"stage": "tool_use", "type": "mcp", "server": "guard", "tool": "scan"},
]
}))
.unwrap();
assert!(
compiled
.evaluate(
GuardrailStage::ToolUse,
"anything",
Some("tool"),
&no_skip()
)
.is_empty()
);
}
// Advisory mode downgrades an MCP check's `block` policy to log.
#[test]
fn mcp_advisory_downgrades_action() {
let compiled = compile(json!({
"mode": "advisory",
"checks": [
{"stage": "tool_use", "type": "mcp", "server": "g", "tool": "t", "on_fail": "block"},
]
}))
.unwrap();
let check = compiled
.mcp_checks_for_stage(GuardrailStage::ToolUse)
.next()
.unwrap();
assert_eq!(compiled.async_action(check.on_fail), GuardrailAction::Log);
}
// Active mode keeps an MCP check's `block` policy as a block action.
#[test]
fn mcp_active_block_yields_block_action() {
let compiled = compile(json!({
"checks": [
{"stage": "tool_use", "type": "mcp", "server": "g", "tool": "t", "on_fail": "block"},
]
}))
.unwrap();
let check = compiled
.mcp_checks_for_stage(GuardrailStage::ToolUse)
.next()
.unwrap();
assert_eq!(compiled.async_action(check.on_fail), GuardrailAction::Block);
}
// MCP rules round-trip through serde with `server` and `tool` flattened into the check.
#[test]
fn mcp_serde_roundtrip() {
let cfg = GuardrailsConfig {
mode: GuardrailMode::Active,
checks: vec![GuardrailCheck {
id: Some("ext-guard".into()),
stage: GuardrailStage::ToolOutput,
on_fail: GuardrailOnFail::Log,
replacement: None,
rule: GuardrailRule::Mcp {
server: "guard".into(),
tool: "scan".into(),
},
}],
};
let value = serde_json::to_value(&cfg).unwrap();
assert_eq!(value["checks"][0]["type"], "mcp");
assert_eq!(value["checks"][0]["stage"], "tool_output");
assert_eq!(value["checks"][0]["server"], "guard");
assert_eq!(value["checks"][0]["tool"], "scan");
let back = GuardrailsConfig::from_value(&value).unwrap();
assert_eq!(back, cfg);
}
}