use serde::Deserialize;
use serde_json::Value as JsonValue;
use std::sync::OnceLock;
use crate::error::NikaError;
/// Action attached to a guardrail that says what should happen when its check fails.
///
/// Deserialized from the snake_case strings `retry`, `escalate` and `fail`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OnFailure {
/// Retry with feedback. This is the default when no action is configured.
#[default]
Retry,
/// Escalate to a human or supervisor.
Escalate,
/// Fail the task.
Fail,
}
impl OnFailure {
pub fn description(&self) -> &'static str {
match self {
OnFailure::Retry => "retry with feedback",
OnFailure::Escalate => "escalate to human/supervisor",
OnFailure::Fail => "fail the task",
}
}
}
/// A single guardrail definition.
///
/// Internally tagged for serde: the `type` key (lowercase variant name) selects the
/// variant, and the remaining keys populate that variant's struct.
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum GuardrailConfig {
/// Word/character count limits.
Length(LengthGuardrail),
/// JSON parse plus required-field check.
Schema(SchemaGuardrail),
/// Regular-expression match (optionally negated).
Regex(RegexGuardrail),
/// LLM-judge guardrail; the only variant reported as async by `is_async`.
Llm(LlmGuardrail),
}
impl GuardrailConfig {
pub fn guardrail_type(&self) -> &'static str {
match self {
GuardrailConfig::Length(_) => "length",
GuardrailConfig::Schema(_) => "schema",
GuardrailConfig::Regex(_) => "regex",
GuardrailConfig::Llm(_) => "llm",
}
}
pub fn id(&self) -> String {
match self {
GuardrailConfig::Length(g) => g.id.clone().unwrap_or_else(|| "length".to_string()),
GuardrailConfig::Schema(g) => g.id.clone().unwrap_or_else(|| "schema".to_string()),
GuardrailConfig::Regex(g) => g.id.clone().unwrap_or_else(|| "regex".to_string()),
GuardrailConfig::Llm(g) => g.id.clone().unwrap_or_else(|| "llm".to_string()),
}
}
pub fn on_failure(&self) -> OnFailure {
match self {
GuardrailConfig::Length(g) => g.on_failure,
GuardrailConfig::Schema(g) => g.on_failure,
GuardrailConfig::Regex(g) => g.on_failure,
GuardrailConfig::Llm(g) => g.on_failure,
}
}
pub fn is_async(&self) -> bool {
matches!(self, GuardrailConfig::Llm(_))
}
pub fn validate(&self) -> Result<(), NikaError> {
match self {
GuardrailConfig::Length(g) => g.validate(),
GuardrailConfig::Schema(g) => g.validate(),
GuardrailConfig::Regex(g) => g.validate(),
GuardrailConfig::Llm(g) => g.validate(),
}
}
}
/// Guardrail that bounds the size of an output in words and/or characters.
///
/// Every field is optional when deserializing; `validate` requires at least one limit.
#[derive(Debug, Clone, Deserialize)]
pub struct LengthGuardrail {
/// Optional identifier; `check` reports `"length"` when absent.
#[serde(default)]
pub id: Option<String>,
/// Minimum number of whitespace-separated words.
#[serde(default)]
pub min_words: Option<u32>,
/// Maximum number of whitespace-separated words.
#[serde(default)]
pub max_words: Option<u32>,
/// Minimum number of characters (counted with `str::chars`).
#[serde(default)]
pub min_chars: Option<u32>,
/// Maximum number of characters (counted with `str::chars`).
#[serde(default)]
pub max_chars: Option<u32>,
/// Custom failure message; replaces the generated message when set.
#[serde(default)]
pub message: Option<String>,
/// Action reported on failed results.
#[serde(default)]
pub on_failure: OnFailure,
}
impl LengthGuardrail {
    /// Converts a `usize` count to `u32`, saturating at `u32::MAX`.
    ///
    /// A plain `as u32` cast would wrap for counts above `u32::MAX`, which could let an
    /// enormous output slip under a `max_*` limit.
    fn saturating_count(count: usize) -> u32 {
        count.min(u32::MAX as usize) as u32
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    /// Returns `NikaError::ValidationError` when no limit is configured, or when a
    /// minimum exceeds its matching maximum (words are checked before chars).
    pub fn validate(&self) -> Result<(), NikaError> {
        let limits = [self.min_words, self.max_words, self.min_chars, self.max_chars];
        if limits.iter().all(Option::is_none) {
            return Err(NikaError::ValidationError {
                reason: "length guardrail requires at least one of: min_words, max_words, min_chars, max_chars".into(),
            });
        }

        let ranges = [
            ("words", self.min_words, self.max_words),
            ("chars", self.min_chars, self.max_chars),
        ];
        for (unit, min, max) in ranges {
            if let (Some(min), Some(max)) = (min, max) {
                if min > max {
                    return Err(NikaError::ValidationError {
                        reason: format!(
                            "length guardrail: min_{0} ({1}) > max_{0} ({2})",
                            unit, min, max
                        ),
                    });
                }
            }
        }
        Ok(())
    }

    /// Evaluates `output` against the configured limits.
    ///
    /// Limits are tested in the order `min_words`, `max_words`, `min_chars`,
    /// `max_chars`; the first violated limit produces the failure. Words are
    /// whitespace-separated tokens and characters are counted with `str::chars`.
    /// A configured `message` replaces the generated failure text.
    pub fn check(&self, output: &str) -> GuardrailResult {
        let word_count = Self::saturating_count(output.split_whitespace().count());
        let char_count = Self::saturating_count(output.chars().count());
        let id = self.id.clone().unwrap_or_else(|| "length".to_string());

        let default_message = if let Some(min) = self.min_words.filter(|&min| word_count < min) {
            format!("Output has {} words, minimum is {}", word_count, min)
        } else if let Some(max) = self.max_words.filter(|&max| word_count > max) {
            format!("Output has {} words, maximum is {}", word_count, max)
        } else if let Some(min) = self.min_chars.filter(|&min| char_count < min) {
            format!("Output has {} chars, minimum is {}", char_count, min)
        } else if let Some(max) = self.max_chars.filter(|&max| char_count > max) {
            format!("Output has {} chars, maximum is {}", char_count, max)
        } else {
            return GuardrailResult::passed(id, "length");
        };

        GuardrailResult::failed_with_action(
            id,
            "length",
            self.message.clone().unwrap_or(default_message),
            self.on_failure,
        )
    }
}
/// Guardrail that requires the output to be JSON satisfying a (partially checked) schema.
///
/// `check` only parses the output as JSON and enforces the schema's top-level
/// `required` array; other schema keywords are not evaluated by the code in this file.
#[derive(Debug, Clone, Deserialize)]
pub struct SchemaGuardrail {
/// Optional identifier; `check` reports `"schema"` when absent.
#[serde(default)]
pub id: Option<String>,
/// The schema document; `validate` requires it to be a JSON object.
pub json_schema: JsonValue,
/// Custom failure message; replaces the generated message when set.
#[serde(default)]
pub message: Option<String>,
/// Action reported on failed results.
#[serde(default)]
pub on_failure: OnFailure,
}
impl SchemaGuardrail {
pub fn validate(&self) -> Result<(), NikaError> {
if !self.json_schema.is_object() {
return Err(NikaError::ValidationError {
reason: "schema guardrail: json_schema must be an object".into(),
});
}
Ok(())
}
pub fn check(&self, output: &str) -> GuardrailResult {
let id = self.id.clone().unwrap_or_else(|| "schema".to_string());
let parsed: Result<JsonValue, _> = serde_json::from_str(output);
let json = match parsed {
Ok(v) => v,
Err(e) => {
return GuardrailResult::failed_with_action(
id,
"schema",
self.message
.clone()
.unwrap_or_else(|| format!("Invalid JSON: {}", e)),
self.on_failure,
);
}
};
if let Some(required) = self.json_schema.get("required").and_then(|r| r.as_array()) {
if let Some(obj) = json.as_object() {
for field in required {
if let Some(field_name) = field.as_str() {
if !obj.contains_key(field_name) {
return GuardrailResult::failed_with_action(
id,
"schema",
self.message.clone().unwrap_or_else(|| {
format!("Missing required field: {}", field_name)
}),
self.on_failure,
);
}
}
}
} else {
return GuardrailResult::failed_with_action(
id,
"schema",
self.message
.clone()
.unwrap_or_else(|| "Expected JSON object".to_string()),
self.on_failure,
);
}
}
GuardrailResult::passed(id, "schema")
}
}
/// Guardrail that tests the output against a regular expression.
#[derive(Debug, Deserialize)]
pub struct RegexGuardrail {
/// Optional identifier; `check` reports `"regex"` when absent.
#[serde(default)]
pub id: Option<String>,
/// Regular-expression source, compiled with the `regex` crate on first use.
pub pattern: String,
/// When `true`, the guardrail passes only if the pattern does NOT match.
#[serde(default)]
pub negate: bool,
/// Custom failure message; replaces the generated message on a match failure.
#[serde(default)]
pub message: Option<String>,
/// Action reported on failed results.
#[serde(default)]
pub on_failure: OnFailure,
/// Cache of the compiled pattern (`None` inside means compilation failed).
/// Skipped by serde, so it always starts empty after deserialization.
#[serde(skip)]
compiled: OnceLock<Option<regex::Regex>>,
}
impl Clone for RegexGuardrail {
    /// Copies the configuration fields. The compiled-pattern cache is not carried
    /// over: the clone starts with an empty cache.
    fn clone(&self) -> Self {
        let Self {
            id,
            pattern,
            negate,
            message,
            on_failure,
            compiled: _,
        } = self;
        Self {
            id: id.clone(),
            pattern: pattern.clone(),
            negate: *negate,
            message: message.clone(),
            on_failure: *on_failure,
            compiled: OnceLock::new(),
        }
    }
}
impl Default for RegexGuardrail {
    /// Builds a guardrail with no id, an empty pattern, `negate` off, no custom
    /// message, the default failure action and an empty compile cache.
    fn default() -> Self {
        let compiled = OnceLock::new();
        let on_failure = OnFailure::default();
        Self {
            id: Option::None,
            pattern: String::default(),
            negate: bool::default(),
            message: Option::None,
            on_failure,
            compiled,
        }
    }
}
impl RegexGuardrail {
    /// Returns the compiled form of `pattern`, compiling it on first use.
    ///
    /// The outcome (including a failed compilation, stored as `None`) is cached in
    /// `compiled`, so the pattern is compiled at most once per instance and the
    /// warning for an invalid pattern is logged only on that first attempt.
    fn get_compiled(&self) -> Option<&regex::Regex> {
        self.compiled
            .get_or_init(|| match regex::Regex::new(&self.pattern) {
                Ok(re) => Some(re),
                Err(e) => {
                    tracing::warn!(pattern = %self.pattern, error = %e, "Invalid guardrail regex pattern");
                    None
                }
            })
            .as_ref()
    }

    /// Checks that `pattern` is a valid regular expression.
    ///
    /// # Errors
    /// Returns `NikaError::ValidationError` naming the pattern when it fails to compile.
    pub fn validate(&self) -> Result<(), NikaError> {
        match self.get_compiled() {
            Some(_) => Ok(()),
            None => Err(NikaError::ValidationError {
                reason: format!("regex guardrail: invalid pattern '{}'", self.pattern),
            }),
        }
    }

    /// Tests `output` against the pattern.
    ///
    /// Passes when the pattern matches (or, with `negate`, when it does not match).
    /// An uncompilable pattern yields a failed result whose message always names the
    /// pattern; the custom `message` is only used for genuine match failures.
    pub fn check(&self, output: &str) -> GuardrailResult {
        let id = self.id.clone().unwrap_or_else(|| "regex".to_string());
        let re = match self.get_compiled() {
            Some(r) => r,
            None => {
                return GuardrailResult::failed_with_action(
                    id,
                    "regex",
                    format!("Invalid regex pattern: {}", self.pattern),
                    self.on_failure,
                );
            }
        };

        let matches = re.is_match(output);
        // `negate` inverts the verdict: a match becomes a failure and vice versa.
        let passed = matches != self.negate;
        if passed {
            return GuardrailResult::passed(id, "regex");
        }

        let default_msg = if self.negate {
            format!("Output must NOT match pattern: {}", self.pattern)
        } else {
            format!("Output must match pattern: {}", self.pattern)
        };
        GuardrailResult::failed_with_action(
            id,
            "regex",
            self.message.clone().unwrap_or(default_msg),
            self.on_failure,
        )
    }
}
/// Guardrail evaluated by an LLM judge.
///
/// This type builds the judge prompt and interprets the judge's response; the call
/// to the model itself is not part of the code visible here.
#[derive(Debug, Deserialize)]
pub struct LlmGuardrail {
/// Optional identifier; results report `"llm"` when absent.
#[serde(default)]
pub id: Option<String>,
/// Instructions for the judge; must not be blank (see `validate`).
pub judge_prompt: String,
/// Regex the judge response must match to pass. Defaults to `^PASS`.
#[serde(default = "default_pass_pattern")]
pub pass_pattern: String,
/// Optional model name. NOTE(review): presumably selects the judge model — confirm against the caller.
#[serde(default)]
pub model: Option<String>,
/// Token limit; must be greater than zero. Defaults to 150.
#[serde(default = "default_judge_max_tokens")]
pub max_tokens: u32,
/// Sampling temperature; must lie in `0.0..=2.0`. Defaults to 0.0.
#[serde(default = "default_judge_temperature")]
pub temperature: f64,
/// Custom failure message; replaces the generated message when the judge does not pass.
#[serde(default)]
pub message: Option<String>,
/// Action reported on failed results.
#[serde(default)]
pub on_failure: OnFailure,
/// Cache of the compiled `pass_pattern` (`None` inside means compilation failed).
/// Skipped by serde, so it always starts empty after deserialization.
#[serde(skip)]
compiled_pass_pattern: OnceLock<Option<regex::Regex>>,
}
impl Clone for LlmGuardrail {
    /// Copies the configuration fields. The compiled pass-pattern cache is not
    /// carried over: the clone starts with an empty cache.
    fn clone(&self) -> Self {
        let Self {
            id,
            judge_prompt,
            pass_pattern,
            model,
            max_tokens,
            temperature,
            message,
            on_failure,
            compiled_pass_pattern: _,
        } = self;
        Self {
            id: id.clone(),
            judge_prompt: judge_prompt.clone(),
            pass_pattern: pass_pattern.clone(),
            model: model.clone(),
            max_tokens: *max_tokens,
            temperature: *temperature,
            message: message.clone(),
            on_failure: *on_failure,
            compiled_pass_pattern: OnceLock::new(),
        }
    }
}
impl Default for LlmGuardrail {
    /// Builds a guardrail with an empty judge prompt and the same defaults serde
    /// applies for omitted keys (pass pattern, max tokens, temperature).
    fn default() -> Self {
        let pass_pattern = default_pass_pattern();
        let max_tokens = default_judge_max_tokens();
        let temperature = default_judge_temperature();
        Self {
            id: Option::None,
            judge_prompt: String::default(),
            pass_pattern,
            model: Option::None,
            max_tokens,
            temperature,
            message: Option::None,
            on_failure: OnFailure::default(),
            compiled_pass_pattern: OnceLock::new(),
        }
    }
}
/// Default `pass_pattern`: the judge response must start with `PASS`.
fn default_pass_pattern() -> String {
    String::from("^PASS")
}
/// Default `max_tokens` for the judge.
fn default_judge_max_tokens() -> u32 {
    const DEFAULT_MAX_TOKENS: u32 = 150;
    DEFAULT_MAX_TOKENS
}
/// Default `temperature` for the judge.
fn default_judge_temperature() -> f64 {
    const DEFAULT_TEMPERATURE: f64 = 0.0;
    DEFAULT_TEMPERATURE
}
impl LlmGuardrail {
    /// Returns the compiled form of `pass_pattern`, compiling it on first use.
    ///
    /// The outcome (including a failed compilation, stored as `None`) is cached in
    /// `compiled_pass_pattern`, so compilation and the invalid-pattern warning
    /// happen at most once per instance.
    fn get_compiled_pass_pattern(&self) -> Option<&regex::Regex> {
        self.compiled_pass_pattern
            .get_or_init(|| match regex::Regex::new(&self.pass_pattern) {
                Ok(re) => Some(re),
                Err(e) => {
                    tracing::warn!(pattern = %self.pass_pattern, error = %e, "Invalid guardrail pass_pattern regex");
                    None
                }
            })
            .as_ref()
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    /// Returns `NikaError::ValidationError` for, in order: a blank `judge_prompt`,
    /// a blank `pass_pattern`, a `pass_pattern` that is not a valid regex,
    /// `max_tokens == 0`, or a `temperature` outside `0.0..=2.0` (NaN is rejected
    /// because it is contained in no range).
    pub fn validate(&self) -> Result<(), NikaError> {
        if self.judge_prompt.trim().is_empty() {
            return Err(NikaError::ValidationError {
                reason: "llm guardrail: judge_prompt cannot be empty".into(),
            });
        }
        if self.pass_pattern.trim().is_empty() {
            return Err(NikaError::ValidationError {
                reason: "llm guardrail: pass_pattern cannot be empty".into(),
            });
        }
        if self.get_compiled_pass_pattern().is_none() {
            return Err(NikaError::ValidationError {
                reason: format!(
                    "llm guardrail: invalid pass_pattern '{}'",
                    self.pass_pattern
                ),
            });
        }
        if self.max_tokens == 0 {
            return Err(NikaError::ValidationError {
                reason: "llm guardrail: max_tokens must be > 0".into(),
            });
        }
        if !(0.0..=2.0).contains(&self.temperature) {
            return Err(NikaError::ValidationError {
                reason: format!(
                    "llm guardrail: temperature must be 0.0-2.0, got {}",
                    self.temperature
                ),
            });
        }
        Ok(())
    }

    /// Builds the text sent to the judge: the trimmed `judge_prompt` followed by the
    /// output between `---` delimiters.
    ///
    /// The output is appended verbatim; no placeholder substitution is performed on
    /// `judge_prompt` here.
    pub fn build_judge_prompt(&self, output: &str) -> String {
        format!(
            "{}\n\n---\nOutput to evaluate:\n{}\n---",
            self.judge_prompt.trim(),
            output
        )
    }

    /// Interprets the judge's response.
    ///
    /// Passes when `pass_pattern` matches `judge_response`. Otherwise fails with the
    /// custom `message`, or with a generated message quoting at most the first 200
    /// characters of the response. An uncompilable `pass_pattern` always fails with
    /// a message naming the pattern.
    pub fn check_judge_response(&self, judge_response: &str) -> GuardrailResult {
        let id = self.id.clone().unwrap_or_else(|| "llm".to_string());
        let re = match self.get_compiled_pass_pattern() {
            Some(r) => r,
            None => {
                return GuardrailResult::failed_with_action(
                    id,
                    "llm",
                    format!("Invalid pass_pattern: {}", self.pass_pattern),
                    self.on_failure,
                );
            }
        };

        if re.is_match(judge_response) {
            return GuardrailResult::passed(id, "llm");
        }

        // Truncate by characters (not bytes) so multi-byte text is never split.
        let excerpt: String = judge_response.chars().take(200).collect();
        let default_msg = format!("LLM judge did not pass. Response: {}", excerpt);
        GuardrailResult::failed_with_action(
            id,
            "llm",
            self.message.clone().unwrap_or(default_msg),
            self.on_failure,
        )
    }
}
/// Outcome of evaluating one guardrail against an output.
#[derive(Debug, Clone)]
pub struct GuardrailResult {
/// Whether the guardrail passed.
pub passed: bool,
/// Identifier of the guardrail that produced this result.
pub guardrail_id: String,
/// Type name of the guardrail (e.g. `"length"`, `"schema"`, `"regex"`, `"llm"`).
pub guardrail_type: String,
/// Failure message; `None` for results built with `passed`.
pub message: Option<String>,
/// Action to take on failure; only consulted by the helper methods when `passed` is false.
pub on_failure: OnFailure,
}
impl GuardrailResult {
    /// Builds a passing result with no message. `on_failure` is set to `Retry`,
    /// which is irrelevant while `passed` is true.
    pub fn passed(guardrail_id: String, guardrail_type: &str) -> Self {
        Self {
            on_failure: OnFailure::Retry,
            message: None,
            guardrail_type: guardrail_type.to_owned(),
            guardrail_id,
            passed: true,
        }
    }

    /// Builds a failing result whose action is `Retry`.
    pub fn failed(guardrail_id: String, guardrail_type: &str, message: String) -> Self {
        Self::failed_with_action(guardrail_id, guardrail_type, message, OnFailure::Retry)
    }

    /// Builds a failing result with an explicit failure action.
    pub fn failed_with_action(
        guardrail_id: String,
        guardrail_type: &str,
        message: String,
        on_failure: OnFailure,
    ) -> Self {
        Self {
            on_failure,
            message: Some(message),
            guardrail_type: guardrail_type.to_owned(),
            guardrail_id,
            passed: false,
        }
    }

    /// `true` when this result failed and its action is `Escalate`.
    pub fn requires_escalation(&self) -> bool {
        match self.on_failure {
            OnFailure::Escalate => !self.passed,
            _ => false,
        }
    }

    /// `true` when this result failed and its action is `Fail`.
    pub fn should_fail(&self) -> bool {
        match self.on_failure {
            OnFailure::Fail => !self.passed,
            _ => false,
        }
    }

    /// `true` when this result failed and its action is `Retry`.
    pub fn should_retry(&self) -> bool {
        match self.on_failure {
            OnFailure::Retry => !self.passed,
            _ => false,
        }
    }
}
/// Runs every synchronous guardrail (length, schema, regex) against `output` and
/// returns their results in configuration order. `Llm` guardrails are skipped, and
/// evaluation does not stop on failures.
pub fn run_sync_guardrails(guardrails: &[GuardrailConfig], output: &str) -> Vec<GuardrailResult> {
    let mut outcomes = Vec::new();
    for config in guardrails {
        let outcome = match config {
            GuardrailConfig::Length(length) => length.check(output),
            GuardrailConfig::Schema(schema) => schema.check(output),
            GuardrailConfig::Regex(regex) => regex.check(output),
            // LLM guardrails need a judge call and are not evaluated here.
            GuardrailConfig::Llm(_) => continue,
        };
        outcomes.push(outcome);
    }
    outcomes
}
/// Returns `true` when no result in `results` failed (vacuously true for an empty slice).
pub fn all_guardrails_passed(results: &[GuardrailResult]) -> bool {
    !results.iter().any(|outcome| !outcome.passed)
}
/// Returns the first failed result in slice order, or `None` when everything passed.
pub fn first_failed_guardrail(results: &[GuardrailResult]) -> Option<&GuardrailResult> {
    for outcome in results {
        if !outcome.passed {
            return Some(outcome);
        }
    }
    None
}
/// Collects, in order, the results for which `requires_escalation()` is true.
pub fn escalation_required(results: &[GuardrailResult]) -> Vec<&GuardrailResult> {
    let mut escalations = Vec::new();
    for outcome in results {
        if outcome.requires_escalation() {
            escalations.push(outcome);
        }
    }
    escalations
}
/// Collects, in order, the results for which `should_fail()` is true.
pub fn immediate_failures(results: &[GuardrailResult]) -> Vec<&GuardrailResult> {
    let mut failures = Vec::new();
    for outcome in results {
        if outcome.should_fail() {
            failures.push(outcome);
        }
    }
    failures
}
/// Returns `true` when at least one guardrail reports itself as async
/// (per `GuardrailConfig::is_async`).
pub fn has_llm_guardrails(guardrails: &[GuardrailConfig]) -> bool {
    for config in guardrails {
        if config.is_async() {
            return true;
        }
    }
    false
}
/// Splits `guardrails` into `(sync, async)` reference lists, preserving the
/// original order within each list.
pub fn partition_guardrails(
    guardrails: &[GuardrailConfig],
) -> (Vec<&GuardrailConfig>, Vec<&GuardrailConfig>) {
    let mut sync_guardrails = Vec::new();
    let mut async_guardrails = Vec::new();
    for config in guardrails {
        if config.is_async() {
            async_guardrails.push(config);
        } else {
            sync_guardrails.push(config);
        }
    }
    (sync_guardrails, async_guardrails)
}
/// Aggregate outcome of running a chain of guardrails.
#[derive(Debug, Clone)]
pub struct GuardrailChainResult {
/// Results gathered so far, in evaluation order.
pub results: Vec<GuardrailResult>,
/// `true` when the chain stopped before evaluating every guardrail.
pub early_terminated: bool,
/// Why the chain stopped early; `None` when it ran to completion.
pub termination_reason: Option<ChainTerminationReason>,
}
impl GuardrailChainResult {
    /// Wraps the results of a chain that ran to completion.
    pub fn completed(results: Vec<GuardrailResult>) -> Self {
        Self {
            termination_reason: None,
            early_terminated: false,
            results,
        }
    }

    /// Wraps the results of a chain that stopped early for `reason`.
    pub fn terminated(results: Vec<GuardrailResult>, reason: ChainTerminationReason) -> Self {
        Self {
            termination_reason: Some(reason),
            early_terminated: true,
            results,
        }
    }

    /// `true` when no collected result failed (vacuously true when empty).
    pub fn all_passed(&self) -> bool {
        self.first_failure().is_none()
    }

    /// `true` when some collected result demands immediate failure.
    pub fn has_immediate_failure(&self) -> bool {
        !self.results.iter().all(|outcome| !outcome.should_fail())
    }

    /// `true` when some collected result demands escalation.
    pub fn has_escalation(&self) -> bool {
        !self
            .results
            .iter()
            .all(|outcome| !outcome.requires_escalation())
    }

    /// Returns the first failed result in evaluation order, if any.
    pub fn first_failure(&self) -> Option<&GuardrailResult> {
        for outcome in &self.results {
            if !outcome.passed {
                return Some(outcome);
            }
        }
        None
    }
}
/// Reason a guardrail chain stopped before evaluating every guardrail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChainTerminationReason {
/// A guardrail failed with the `Fail` action, so the chain was aborted.
ImmediateFailure {
/// Identifier of the guardrail that triggered the termination.
guardrail_id: String,
/// Failure message taken from that guardrail's result.
message: String,
},
}
/// Runs the synchronous guardrails in order, stopping at the first result whose
/// action is `Fail`.
///
/// `Llm` guardrails are skipped. Failures with other actions are recorded and the
/// chain continues. On early termination the triggering result is included as the
/// last element of `results`.
pub fn run_sync_guardrails_chain(
    guardrails: &[GuardrailConfig],
    output: &str,
) -> GuardrailChainResult {
    let mut collected = Vec::with_capacity(guardrails.len());
    for config in guardrails {
        let outcome = match config {
            GuardrailConfig::Length(length) => length.check(output),
            GuardrailConfig::Schema(schema) => schema.check(output),
            GuardrailConfig::Regex(regex) => regex.check(output),
            GuardrailConfig::Llm(_) => continue,
        };

        if !outcome.should_fail() {
            collected.push(outcome);
            continue;
        }

        // Immediate failure: capture the reason before moving the result into the list.
        let message = match &outcome.message {
            Some(text) => text.clone(),
            None => "Guardrail failed".to_string(),
        };
        let reason = ChainTerminationReason::ImmediateFailure {
            guardrail_id: outcome.guardrail_id.clone(),
            message,
        };
        collected.push(outcome);
        return GuardrailChainResult::terminated(collected, reason);
    }
    GuardrailChainResult::completed(collected)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::serde_yaml;
#[test]
fn test_length_guardrail_min_words_pass() {
    // Boundary case: the output contains exactly `min_words` words and must pass,
    // because only counts strictly below the minimum are rejected.
    let guardrail = LengthGuardrail {
        id: Some("len1".to_string()),
        min_words: Some(5),
        max_words: None,
        min_chars: None,
        max_chars: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let result = guardrail.check("This has exactly five words");
    assert!(result.passed);
}
#[test]
fn test_length_guardrail_min_words_fail() {
    // Two words against a ten-word minimum; the generated message reports the count.
    let ten_word_minimum = LengthGuardrail {
        min_words: Some(10),
        max_words: None,
        min_chars: None,
        max_chars: None,
        id: Some(String::from("len1")),
        message: None,
        on_failure: OnFailure::Retry,
    };
    let outcome = ten_word_minimum.check("Too short");
    assert!(!outcome.passed);
    let text = outcome.message.unwrap();
    assert!(text.contains("2 words"));
}
#[test]
fn test_length_guardrail_max_words_pass() {
    // Three words fit under a ten-word maximum.
    let ten_word_maximum = LengthGuardrail {
        max_words: Some(10),
        min_words: None,
        min_chars: None,
        max_chars: None,
        id: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let outcome = ten_word_maximum.check("This is short");
    assert!(outcome.passed);
}
#[test]
fn test_length_guardrail_max_words_fail() {
    // Six words exceed a three-word maximum.
    let three_word_maximum = LengthGuardrail {
        max_words: Some(3),
        min_words: None,
        min_chars: None,
        max_chars: None,
        id: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let outcome = three_word_maximum.check("This has more than three words");
    assert!(!outcome.passed);
}
#[test]
fn test_length_guardrail_chars() {
    // Character window of 10..=20: one input inside, one below, one above.
    let char_window = LengthGuardrail {
        min_chars: Some(10),
        max_chars: Some(20),
        min_words: None,
        max_words: None,
        id: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let cases = [
        ("Hello, World!!!", true),
        ("Hi", false),
        ("This is way too long for the limit", false),
    ];
    for (input, expected) in cases {
        assert_eq!(char_window.check(input).passed, expected);
    }
}
#[test]
fn test_length_guardrail_custom_message() {
    // A configured message replaces the generated failure text.
    let with_message = LengthGuardrail {
        message: Some(String::from("Response too brief")),
        min_words: Some(100),
        max_words: None,
        min_chars: None,
        max_chars: None,
        id: None,
        on_failure: OnFailure::Retry,
    };
    let outcome = with_message.check("Short");
    assert!(!outcome.passed);
    assert_eq!(outcome.message.unwrap(), "Response too brief");
}
#[test]
fn test_length_guardrail_validation() {
    // Builds a guardrail that only sets word limits.
    let word_limits = |min_words: Option<u32>, max_words: Option<u32>| LengthGuardrail {
        id: None,
        min_words,
        max_words,
        min_chars: None,
        max_chars: None,
        message: None,
        on_failure: OnFailure::Retry,
    };

    // Well-formed range.
    assert!(word_limits(Some(10), Some(100)).validate().is_ok());
    // No limit configured at all.
    assert!(word_limits(None, None).validate().is_err());
    // Minimum above maximum.
    assert!(word_limits(Some(100), Some(10)).validate().is_err());
}
#[test]
fn test_schema_guardrail_valid_json() {
    // The required field is present, extra fields are tolerated.
    let schema = serde_json::json!({
        "type": "object",
        "required": ["name"]
    });
    let requires_name = SchemaGuardrail {
        json_schema: schema,
        id: Some(String::from("schema1")),
        message: None,
        on_failure: OnFailure::Retry,
    };
    let outcome = requires_name.check(r#"{"name": "test", "value": 42}"#);
    assert!(outcome.passed);
}
#[test]
fn test_schema_guardrail_missing_required() {
    // `value` is required but absent; the message names the missing field.
    let schema = serde_json::json!({
        "type": "object",
        "required": ["name", "value"]
    });
    let requires_both = SchemaGuardrail {
        json_schema: schema,
        id: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let outcome = requires_both.check(r#"{"name": "test"}"#);
    assert!(!outcome.passed);
    let text = outcome.message.unwrap();
    assert!(text.contains("value"));
}
#[test]
fn test_schema_guardrail_invalid_json() {
    // Unparseable output fails with an "Invalid JSON" message.
    let any_object = SchemaGuardrail {
        json_schema: serde_json::json!({"type": "object"}),
        id: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let outcome = any_object.check("not valid json {");
    assert!(!outcome.passed);
    let text = outcome.message.unwrap();
    assert!(text.contains("Invalid JSON"));
}
#[test]
fn test_schema_guardrail_not_object() {
    // With required fields configured, a JSON array is rejected as a non-object.
    let schema = serde_json::json!({
        "type": "object",
        "required": ["x"]
    });
    let requires_x = SchemaGuardrail {
        json_schema: schema,
        id: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let outcome = requires_x.check(r#"[1, 2, 3]"#);
    assert!(!outcome.passed);
    let text = outcome.message.unwrap();
    assert!(text.contains("Expected JSON object"));
}
#[test]
fn test_regex_guardrail_match() {
    // Output starting with the anchored prefix passes. Remaining fields use defaults
    // (not negated, no custom message, retry on failure).
    let starts_with_summary = RegexGuardrail {
        pattern: String::from(r"^Summary:"),
        id: Some(String::from("regex1")),
        ..Default::default()
    };
    let outcome = starts_with_summary.check("Summary: This is the summary.");
    assert!(outcome.passed);
}
#[test]
fn test_regex_guardrail_no_match() {
    // A non-matching output fails and reports the custom message.
    let starts_with_summary = RegexGuardrail {
        pattern: String::from(r"^Summary:"),
        message: Some(String::from("Must start with Summary:")),
        ..Default::default()
    };
    let outcome = starts_with_summary.check("This doesn't start with Summary");
    assert!(!outcome.passed);
    assert_eq!(outcome.message.unwrap(), "Must start with Summary:");
}
#[test]
fn test_regex_guardrail_negate() {
    // Negated pattern: passing means the forbidden markers are absent.
    let forbids_markers = RegexGuardrail {
        pattern: String::from(r"TODO|FIXME"),
        negate: true,
        ..Default::default()
    };
    let cases = [("This is clean code", true), ("This has a TODO in it", false)];
    for (input, expected) in cases {
        assert_eq!(forbids_markers.check(input).passed, expected);
    }
}
#[test]
fn test_regex_guardrail_validation() {
    // Builds a guardrail from a pattern, leaving every other field at its default.
    let with_pattern = |pattern: &str| RegexGuardrail {
        pattern: pattern.to_string(),
        ..Default::default()
    };

    // A well-formed pattern validates.
    assert!(with_pattern(r"^\w+$").validate().is_ok());
    // An unbalanced pattern is rejected.
    assert!(with_pattern(r"[invalid(").validate().is_err());
}
#[test]
fn test_parse_length_guardrail() {
    // `type: length` selects the Length variant and the limits are carried through.
    let yaml = r#"
type: length
id: len1
min_words: 50
max_words: 200
"#;
    let parsed: GuardrailConfig = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(parsed.guardrail_type(), "length");
    assert_eq!(parsed.id(), "len1");
    let GuardrailConfig::Length(length) = parsed else {
        panic!("Expected Length guardrail");
    };
    assert_eq!(length.min_words, Some(50));
    assert_eq!(length.max_words, Some(200));
}
#[test]
fn test_parse_schema_guardrail() {
    // The schema body must be nested under `json_schema`. Without indentation the
    // inner `type: object` would be a second top-level `type` key and `json_schema`
    // would be null, so the nesting is spelled out explicitly here.
    let yaml = r#"
type: schema
json_schema:
  type: object
  required:
    - name
    - value
"#;
    let config: GuardrailConfig = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(config.guardrail_type(), "schema");
    if let GuardrailConfig::Schema(g) = config {
        assert!(g.json_schema.is_object());
    } else {
        panic!("Expected Schema guardrail");
    }
}
#[test]
fn test_parse_regex_guardrail() {
    // `type: regex` selects the Regex variant; pattern and negate are carried through.
    let yaml = r#"
type: regex
pattern: "^Result:"
negate: false
message: "Must start with Result:"
"#;
    let parsed: GuardrailConfig = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(parsed.guardrail_type(), "regex");
    let GuardrailConfig::Regex(regex) = parsed else {
        panic!("Expected Regex guardrail");
    };
    assert_eq!(regex.pattern, "^Result:");
    assert!(!regex.negate);
}
#[test]
fn test_run_guardrails_all_pass() {
    // Both sync guardrails accept the sample output.
    let length = LengthGuardrail {
        min_words: Some(3),
        max_words: None,
        min_chars: None,
        max_chars: None,
        id: Some(String::from("len")),
        message: None,
        on_failure: OnFailure::Retry,
    };
    let regex = RegexGuardrail {
        pattern: String::from(r"\w+"),
        id: Some(String::from("re")),
        ..Default::default()
    };
    let configs = vec![GuardrailConfig::Length(length), GuardrailConfig::Regex(regex)];

    let outcomes = run_sync_guardrails(&configs, "This is a valid response");
    assert_eq!(outcomes.len(), 2);
    assert!(all_guardrails_passed(&outcomes));
    assert!(first_failed_guardrail(&outcomes).is_none());
}
#[test]
fn test_run_guardrails_one_fails() {
    // The length guardrail rejects the output, the regex guardrail accepts it; both
    // results are still returned.
    let length = LengthGuardrail {
        min_words: Some(100),
        max_words: None,
        min_chars: None,
        max_chars: None,
        id: Some(String::from("len")),
        message: None,
        on_failure: OnFailure::Retry,
    };
    let regex = RegexGuardrail {
        pattern: String::from(r"\w+"),
        id: Some(String::from("re")),
        ..Default::default()
    };
    let configs = vec![GuardrailConfig::Length(length), GuardrailConfig::Regex(regex)];

    let outcomes = run_sync_guardrails(&configs, "Short response");
    assert_eq!(outcomes.len(), 2);
    assert!(!all_guardrails_passed(&outcomes));
    let first = first_failed_guardrail(&outcomes).unwrap();
    assert_eq!(first.guardrail_id, "len");
    assert_eq!(first.guardrail_type, "length");
}
#[test]
fn test_llm_guardrail_creation() {
    // Fields set on construction are readable back unchanged.
    let safety = LlmGuardrail {
        judge_prompt: String::from("Is this content safe and appropriate? Respond PASS or FAIL."),
        pass_pattern: String::from("^PASS"),
        on_failure: OnFailure::Escalate,
        id: Some(String::from("content_safety")),
        model: Some(String::from("gpt-4o-mini")),
        message: Some(String::from("Content failed safety check")),
        max_tokens: 100,
        temperature: 0.0,
        ..Default::default()
    };
    assert_eq!(safety.id.as_deref(), Some("content_safety"));
    assert!(safety.judge_prompt.contains("safe"));
    assert_eq!(safety.on_failure, OnFailure::Escalate);
}
#[test]
fn test_llm_guardrail_defaults() {
    // Only the required keys are given; every omitted key takes its serde default.
    let yaml = r#"
type: llm
judge_prompt: "Is this valid? Respond PASS or FAIL."
"#;
    let parsed: GuardrailConfig = serde_yaml::from_str(yaml).unwrap();
    let GuardrailConfig::Llm(judge) = parsed else {
        panic!("Expected Llm guardrail");
    };
    assert_eq!(judge.pass_pattern, "^PASS");
    assert_eq!(judge.max_tokens, 150);
    assert_eq!(judge.temperature, 0.0);
    assert!(judge.model.is_none());
    assert_eq!(judge.on_failure, OnFailure::Retry);
}
#[test]
fn test_llm_guardrail_validation() {
    // Builds a judge that differs only in prompt and pass pattern.
    let judge = |prompt: &str, pattern: &str| LlmGuardrail {
        judge_prompt: prompt.to_string(),
        pass_pattern: pattern.to_string(),
        max_tokens: 100,
        temperature: 0.5,
        ..Default::default()
    };

    // Fully valid configuration.
    assert!(judge("Is this OK? PASS or FAIL.", "^PASS").validate().is_ok());
    // Empty judge prompt.
    assert!(judge("", "^PASS").validate().is_err());
    // Empty pass pattern.
    assert!(judge("Is this OK?", "").validate().is_err());
    // Pass pattern that is not a valid regex.
    assert!(judge("Is this OK?", "[invalid(").validate().is_err());
}
#[test]
fn test_llm_guardrail_build_judge_prompt() {
    // The built prompt contains both the instructions and the evaluated output.
    let instructions = "Evaluate this output for quality:\n{{output}}\nRespond PASS or FAIL.";
    let judge = LlmGuardrail {
        judge_prompt: String::from(instructions),
        pass_pattern: String::from("^PASS"),
        max_tokens: 100,
        temperature: 0.0,
        ..Default::default()
    };
    let built = judge.build_judge_prompt("Hello world");
    for expected in ["Hello world", "Evaluate this output"] {
        assert!(built.contains(expected));
    }
}
#[test]
fn test_llm_guardrail_check_judge_response() {
    // With the `^PASS` pattern only responses beginning with uppercase PASS are accepted.
    let judge = LlmGuardrail {
        judge_prompt: String::from("Check this."),
        pass_pattern: String::from("^PASS"),
        max_tokens: 100,
        temperature: 0.0,
        ..Default::default()
    };
    let cases = [
        ("PASS", true),
        ("PASS - looks good", true),
        ("FAIL - not good", false),
        ("fail", false),
        ("The output is okay", false),
    ];
    for (response, expected) in cases {
        assert_eq!(judge.check_judge_response(response).passed, expected);
    }
}
#[test]
fn test_llm_guardrail_custom_pass_pattern() {
    // A score-based pattern accepts 8, 9 and 10 out of 10 only.
    let scorer = LlmGuardrail {
        judge_prompt: String::from("Score 1-10."),
        pass_pattern: String::from(r"^(8|9|10)/10"),
        max_tokens: 100,
        temperature: 0.0,
        ..Default::default()
    };
    let cases = [
        ("9/10 - excellent", true),
        ("10/10", true),
        ("7/10 - good", false),
        ("5/10", false),
    ];
    for (response, expected) in cases {
        assert_eq!(scorer.check_judge_response(response).passed, expected);
    }
}
#[test]
fn test_on_failure_default() {
    // `Retry` carries the `#[default]` attribute.
    let default_action: OnFailure = Default::default();
    assert_eq!(default_action, OnFailure::Retry);
}
#[test]
fn test_on_failure_parsing() {
    // Each snake_case string maps to its variant.
    let cases = [
        ("retry", OnFailure::Retry),
        ("escalate", OnFailure::Escalate),
        ("fail", OnFailure::Fail),
    ];
    for (text, expected) in cases {
        let parsed = serde_yaml::from_str::<OnFailure>(text).unwrap();
        assert_eq!(parsed, expected);
    }
}
#[test]
fn test_guardrail_config_on_failure() {
    // `on_failure` from YAML is exposed through the enum-level accessor.
    let yaml = r#"
type: length
min_words: 10
on_failure: escalate
"#;
    let parsed = serde_yaml::from_str::<GuardrailConfig>(yaml).unwrap();
    let action = parsed.on_failure();
    assert_eq!(action, OnFailure::Escalate);
}
#[test]
fn test_partition_guardrails_sync_only() {
    // With no LLM guardrail configured, everything lands in the sync list.
    let length = LengthGuardrail {
        min_words: Some(10),
        max_words: None,
        min_chars: None,
        max_chars: None,
        id: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let regex = RegexGuardrail {
        pattern: String::from(r"\w+"),
        ..Default::default()
    };
    let configs = vec![GuardrailConfig::Length(length), GuardrailConfig::Regex(regex)];

    let (sync_part, async_part) = partition_guardrails(&configs);
    assert_eq!(sync_part.len(), 2);
    assert_eq!(async_part.len(), 0);
}
#[test]
fn test_partition_guardrails_mixed() {
    // One sync and one LLM guardrail are split into separate lists.
    let length = LengthGuardrail {
        min_words: Some(10),
        max_words: None,
        min_chars: None,
        max_chars: None,
        id: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let judge = LlmGuardrail {
        judge_prompt: String::from("Check this."),
        pass_pattern: String::from("^PASS"),
        on_failure: OnFailure::Escalate,
        max_tokens: 100,
        temperature: 0.0,
        ..Default::default()
    };
    let configs = vec![GuardrailConfig::Length(length), GuardrailConfig::Llm(judge)];

    let (sync_part, async_part) = partition_guardrails(&configs);
    assert_eq!(sync_part.len(), 1);
    assert_eq!(async_part.len(), 1);
}
#[test]
fn test_has_llm_guardrails_false() {
    // A list holding only a length guardrail has no LLM guardrails.
    let length = LengthGuardrail {
        min_words: Some(10),
        max_words: None,
        min_chars: None,
        max_chars: None,
        id: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let configs = vec![GuardrailConfig::Length(length)];
    assert!(!has_llm_guardrails(&configs));
}
#[test]
fn test_has_llm_guardrails_true() {
    // A single LLM guardrail is enough to report `true`.
    let judge = LlmGuardrail {
        judge_prompt: String::from("Check."),
        pass_pattern: String::from("^PASS"),
        max_tokens: 100,
        temperature: 0.0,
        ..Default::default()
    };
    let configs = vec![GuardrailConfig::Llm(judge)];
    assert!(has_llm_guardrails(&configs));
}
#[test]
fn test_guardrail_result_with_action() {
    // A failure tagged Escalate triggers only the escalation predicate.
    let escalated = GuardrailResult::failed_with_action(
        String::from("safety"),
        "llm",
        String::from("Content unsafe"),
        OnFailure::Escalate,
    );
    assert!(!escalated.passed);
    assert_eq!(escalated.on_failure, OnFailure::Escalate);
    assert!(escalated.requires_escalation());
    assert!(!escalated.should_retry());
    assert!(!escalated.should_fail());
}
#[test]
fn test_guardrail_result_should_fail() {
    // A failure tagged Fail triggers only the immediate-failure predicate.
    let fatal = GuardrailResult::failed_with_action(
        String::from("critical"),
        "schema",
        String::from("Invalid format"),
        OnFailure::Fail,
    );
    assert!(fatal.should_fail());
    assert!(!fatal.should_retry());
    assert!(!fatal.requires_escalation());
}
#[test]
fn test_guardrail_result_should_retry() {
    // A failure tagged Retry triggers only the retry predicate.
    let retryable = GuardrailResult::failed_with_action(
        String::from("length"),
        "length",
        String::from("Too short"),
        OnFailure::Retry,
    );
    assert!(retryable.should_retry());
    assert!(!retryable.should_fail());
    assert!(!retryable.requires_escalation());
}
#[test]
fn test_escalation_required_filter() {
    // Of one pass, one escalation and one retry, only the escalation is selected.
    let passing = GuardrailResult::passed(String::from("g1"), "length");
    let escalating = GuardrailResult::failed_with_action(
        String::from("g2"),
        "llm",
        String::from("Safety issue"),
        OnFailure::Escalate,
    );
    let retrying = GuardrailResult::failed_with_action(
        String::from("g3"),
        "length",
        String::from("Too short"),
        OnFailure::Retry,
    );
    let outcomes = vec![passing, escalating, retrying];

    let selected = escalation_required(&outcomes);
    assert_eq!(selected.len(), 1);
    assert_eq!(selected[0].guardrail_id, "g2");
}
#[test]
fn test_immediate_failures_filter() {
    // Of one pass, one hard failure and one escalation, only the hard failure is selected.
    let passing = GuardrailResult::passed(String::from("g1"), "length");
    let failing = GuardrailResult::failed_with_action(
        String::from("g2"),
        "schema",
        String::from("Invalid"),
        OnFailure::Fail,
    );
    let escalating = GuardrailResult::failed_with_action(
        String::from("g3"),
        "llm",
        String::from("Safety"),
        OnFailure::Escalate,
    );
    let outcomes = vec![passing, failing, escalating];

    let selected = immediate_failures(&outcomes);
    assert_eq!(selected.len(), 1);
    assert_eq!(selected[0].guardrail_id, "g2");
}
#[test]
fn test_parse_llm_guardrail_full() {
    // The literal block scalar (`|`) needs its content indented deeper than the
    // `judge_prompt` key; unindented lines are not part of the scalar and make the
    // document invalid YAML.
    let yaml = r#"
type: llm
id: content_safety
judge_prompt: |
  Evaluate if this content is appropriate.
  Output: {{output}}
  Respond with PASS or FAIL.
pass_pattern: "^PASS"
model: gpt-4o-mini
max_tokens: 200
temperature: 0.1
message: "Content failed safety review"
on_failure: escalate
"#;
    let config: GuardrailConfig = serde_yaml::from_str(yaml).unwrap();
    assert_eq!(config.guardrail_type(), "llm");
    assert_eq!(config.id(), "content_safety");
    assert_eq!(config.on_failure(), OnFailure::Escalate);
    assert!(config.is_async());
    if let GuardrailConfig::Llm(g) = config {
        assert!(g.judge_prompt.contains("appropriate"));
        assert_eq!(g.model, Some("gpt-4o-mini".to_string()));
        assert_eq!(g.max_tokens, 200);
    } else {
        panic!("Expected Llm guardrail");
    }
}
#[test]
/// A chain result built via `GuardrailChainResult::completed` from two passing
/// results carries no termination info and reports no failure of any kind.
fn test_chain_result_completed() {
    let all_green = vec![
        GuardrailResult::passed(String::from("g1"), "length"),
        GuardrailResult::passed(String::from("g2"), "schema"),
    ];
    let outcome = GuardrailChainResult::completed(all_green);

    // Termination metadata stays unset.
    assert!(!outcome.early_terminated);
    assert!(outcome.termination_reason.is_none());

    // Everything passed, so no failure-related query may fire.
    assert!(outcome.all_passed());
    let any_failure_flag = outcome.has_immediate_failure() || outcome.has_escalation();
    assert!(!any_failure_flag);
    assert!(outcome.first_failure().is_none());
}
#[test]
/// A chain result built via `GuardrailChainResult::terminated` with a single
/// `OnFailure::Fail` result is flagged as early-terminated and exposes that
/// failure.
fn test_chain_result_terminated() {
    let hard_failure = GuardrailResult::failed_with_action(
        String::from("g1"),
        "length",
        String::from("Too short"),
        OnFailure::Fail,
    );
    let why = ChainTerminationReason::ImmediateFailure {
        guardrail_id: String::from("g1"),
        message: String::from("Too short"),
    };
    let outcome = GuardrailChainResult::terminated(vec![hard_failure], why);

    // Termination metadata is populated.
    assert!(outcome.early_terminated);
    assert!(outcome.termination_reason.is_some());

    // The failure itself is visible through the query helpers.
    assert!(!outcome.all_passed());
    assert!(outcome.has_immediate_failure());
    assert!(outcome.first_failure().is_some());
}
#[test]
/// Runs a length guardrail and a regex guardrail over an input both are
/// expected to accept; the chain should finish with two passing results.
fn test_chain_all_pass() {
    let word_floor = LengthGuardrail {
        id: Some(String::from("g1")),
        min_words: Some(2),
        max_words: None,
        min_chars: None,
        max_chars: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let mentions_test = RegexGuardrail {
        id: Some(String::from("g2")),
        pattern: String::from("test"),
        negate: false,
        message: None,
        on_failure: OnFailure::Retry,
        ..Default::default()
    };
    let guardrails = vec![
        GuardrailConfig::Length(word_floor),
        GuardrailConfig::Regex(mentions_test),
    ];

    let outcome = run_sync_guardrails_chain(&guardrails, "This is a test");
    assert!(!outcome.early_terminated);
    assert!(outcome.all_passed());
    assert_eq!(outcome.results.len(), 2);
}
#[test]
/// When the first guardrail fails with `OnFailure::Fail`, the chain is expected
/// to stop immediately: one result only, and an `ImmediateFailure` reason that
/// names the failing guardrail.
fn test_chain_early_termination_on_fail() {
    // 100-word minimum that the short input below is expected to miss.
    let blocker = LengthGuardrail {
        id: Some(String::from("critical")),
        min_words: Some(100),
        max_words: None,
        min_chars: None,
        max_chars: None,
        message: None,
        on_failure: OnFailure::Fail,
    };
    // Catch-all regex placed second; it should never be evaluated.
    let unreached = RegexGuardrail {
        id: Some(String::from("should_not_run")),
        pattern: String::from(".*"),
        negate: false,
        message: None,
        on_failure: OnFailure::Retry,
        ..Default::default()
    };
    let guardrails = vec![
        GuardrailConfig::Length(blocker),
        GuardrailConfig::Regex(unreached),
    ];

    let chain = run_sync_guardrails_chain(&guardrails, "Short text");
    assert!(chain.early_terminated);
    assert_eq!(chain.results.len(), 1);
    assert!(chain.has_immediate_failure());
    assert_eq!(chain.first_failure().unwrap().guardrail_id, "critical");

    let Some(ChainTerminationReason::ImmediateFailure { guardrail_id, .. }) =
        &chain.termination_reason
    else {
        panic!("Expected ImmediateFailure reason");
    };
    assert_eq!(guardrail_id, "critical");
}
#[test]
/// A failing guardrail marked `OnFailure::Retry` must not cut the chain short:
/// both guardrails produce a result and no immediate failure is reported.
fn test_chain_no_early_termination_on_retry() {
    // 100-word minimum that the four-word input is expected to miss.
    let too_strict = LengthGuardrail {
        id: Some(String::from("g1")),
        min_words: Some(100),
        max_words: None,
        min_chars: None,
        max_chars: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let mentions_test = RegexGuardrail {
        id: Some(String::from("g2")),
        pattern: String::from("test"),
        negate: false,
        message: None,
        on_failure: OnFailure::Retry,
        ..Default::default()
    };
    let guardrails = vec![
        GuardrailConfig::Length(too_strict),
        GuardrailConfig::Regex(mentions_test),
    ];

    let outcome = run_sync_guardrails_chain(&guardrails, "This is a test");
    assert!(!outcome.early_terminated);
    assert_eq!(outcome.results.len(), 2);
    // Something failed, but not in the "stop now" sense.
    assert!(!outcome.all_passed());
    assert!(!outcome.has_immediate_failure());
}
#[test]
/// A failing guardrail marked `OnFailure::Escalate` must not cut the chain
/// short: both guardrails produce a result and the chain reports an escalation.
fn test_chain_no_early_termination_on_escalate() {
    // 100-word minimum that the four-word input is expected to miss.
    let too_strict = LengthGuardrail {
        id: Some(String::from("g1")),
        min_words: Some(100),
        max_words: None,
        min_chars: None,
        max_chars: None,
        message: None,
        on_failure: OnFailure::Escalate,
    };
    let mentions_test = RegexGuardrail {
        id: Some(String::from("g2")),
        pattern: String::from("test"),
        negate: false,
        message: None,
        on_failure: OnFailure::Retry,
        ..Default::default()
    };
    let guardrails = vec![
        GuardrailConfig::Length(too_strict),
        GuardrailConfig::Regex(mentions_test),
    ];

    let outcome = run_sync_guardrails_chain(&guardrails, "This is a test");
    assert!(!outcome.early_terminated);
    assert_eq!(outcome.results.len(), 2);
    assert!(outcome.has_escalation());
}
#[test]
/// Mixes an LLM guardrail between two synchronous ones. The test expects the
/// sync chain to yield only two results, all passing, with no early
/// termination, even though the LLM guardrail is configured with
/// `OnFailure::Fail`.
fn test_chain_skips_llm_guardrails() {
    let word_floor = LengthGuardrail {
        id: Some(String::from("g1")),
        min_words: Some(2),
        max_words: None,
        min_chars: None,
        max_chars: None,
        message: None,
        on_failure: OnFailure::Retry,
    };
    let llm_judge = LlmGuardrail {
        id: Some(String::from("llm_guard")),
        judge_prompt: String::from("Evaluate"),
        pass_pattern: String::from("^PASS"),
        model: None,
        max_tokens: 150,
        temperature: 0.0,
        message: None,
        on_failure: OnFailure::Fail,
        ..Default::default()
    };
    let mentions_test = RegexGuardrail {
        id: Some(String::from("g2")),
        pattern: String::from("test"),
        negate: false,
        message: None,
        on_failure: OnFailure::Retry,
        ..Default::default()
    };
    let guardrails = vec![
        GuardrailConfig::Length(word_floor),
        GuardrailConfig::Llm(llm_judge),
        GuardrailConfig::Regex(mentions_test),
    ];

    let outcome = run_sync_guardrails_chain(&guardrails, "This is a test");
    assert!(!outcome.early_terminated);
    // Three guardrails configured, two results expected.
    assert_eq!(outcome.results.len(), 2);
    assert!(outcome.all_passed());
}
#[test]
/// Builds an `ImmediateFailure` reason and checks that destructuring it yields
/// the same `guardrail_id` and `message` it was constructed with.
fn test_chain_termination_reason_fields() {
    let reason = ChainTerminationReason::ImmediateFailure {
        guardrail_id: String::from("critical_check"),
        message: String::from("Content validation failed"),
    };

    // Move both payload fields out under local names.
    let ChainTerminationReason::ImmediateFailure {
        guardrail_id: id,
        message: text,
    } = reason;

    assert_eq!(id, "critical_check");
    assert_eq!(text, "Content validation failed");
}
}