use std::collections::HashSet;
use regex::Regex;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Action attached to a guardrail rule and copied onto every match that rule produces.
///
/// Serialized with `snake_case` variant names (`force_approval`, `block`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum GuardrailAction {
/// NOTE(review): presumably routes the message to a human approval step — confirm with the caller.
ForceApproval,
/// NOTE(review): presumably stops the message outright — confirm with the caller.
Block,
}
/// Uncompiled, (de)serializable description of one guardrail rule.
///
/// Turned into its compiled form by `GuardrailSet::build`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GuardrailRule {
/// Identifier of the rule; `GuardrailSet::build` rejects sets where two rules share it.
pub id: String,
/// Display name, copied verbatim onto every match as `rule_name`.
pub name: String,
/// Regex sources; `GuardrailSet::build` rejects a rule whose list is empty.
pub patterns: Vec<String>,
/// Action reported when any pattern of this rule matches.
pub action: GuardrailAction,
}
/// Reasons `GuardrailSet::build` can reject a list of rules.
#[derive(Debug, Error)]
pub enum GuardrailLoadError {
/// A pattern failed to compile as a regex.
#[error("guardrail {rule_id:?} pattern {index} invalid: {error}")]
InvalidPattern {
/// Id of the rule that owns the offending pattern.
rule_id: String,
/// Zero-based position of the pattern inside the rule's `patterns` list.
index: usize,
/// Stringified error reported by the regex compiler.
error: String,
},
/// Two rules in the same set share this id.
#[error("guardrail rule id {0:?} duplicated")]
DuplicateId(String),
/// The rule with this id has an empty `patterns` list.
#[error("guardrail {0:?} has no patterns")]
EmptyRule(String),
}
/// A collection of compiled guardrail rules that can be run against text.
#[derive(Debug, Clone)]
pub struct GuardrailSet {
// Compiled rules, kept in the order they were supplied to `build`.
rules: Vec<CompiledRule>,
}
/// Internal form of a `GuardrailRule` whose patterns have already been compiled.
#[derive(Debug, Clone)]
struct CompiledRule {
// Rule identifier, copied from the source rule.
id: String,
// Display name, copied from the source rule.
name: String,
// Action reported on a match.
action: GuardrailAction,
// Compiled regexes, in the same order as the source rule's `patterns`.
patterns: Vec<Regex>,
}
impl GuardrailSet {
    /// Compiles `rules` into a set that is ready to scan text.
    ///
    /// Every pattern is made case-insensitive by prepending `(?i)`, unless the
    /// pattern itself opens with an inline flag group such as `(?-i)`, `(?s)` or
    /// `(?i:...)`; in that case it is compiled exactly as written.
    ///
    /// A pattern that merely starts with a non-capturing group (`(?:a|b)`) or a
    /// named group (`(?P<name>...)`) carries no flags, so it still receives the
    /// case-insensitive default.
    ///
    /// # Errors
    ///
    /// Rules are validated in order and the first problem is returned:
    /// * [`GuardrailLoadError::EmptyRule`] if a rule has no patterns,
    /// * [`GuardrailLoadError::DuplicateId`] if a rule id was already seen,
    /// * [`GuardrailLoadError::InvalidPattern`] if a pattern fails to compile.
    pub fn build(rules: Vec<GuardrailRule>) -> Result<Self, GuardrailLoadError> {
        let mut seen = HashSet::with_capacity(rules.len());
        let mut compiled = Vec::with_capacity(rules.len());
        for rule in rules {
            if rule.patterns.is_empty() {
                return Err(GuardrailLoadError::EmptyRule(rule.id));
            }
            if !seen.insert(rule.id.clone()) {
                return Err(GuardrailLoadError::DuplicateId(rule.id));
            }
            let patterns = rule
                .patterns
                .iter()
                .enumerate()
                .map(|(index, raw)| {
                    // Only allocate when the default flag actually has to be added.
                    let with_default;
                    let prepared: &str = if Self::has_leading_inline_flags(raw) {
                        raw
                    } else {
                        with_default = format!("(?i){raw}");
                        &with_default
                    };
                    Regex::new(prepared).map_err(|e| GuardrailLoadError::InvalidPattern {
                        rule_id: rule.id.clone(),
                        index,
                        error: e.to_string(),
                    })
                })
                .collect::<Result<Vec<_>, _>>()?;
            compiled.push(CompiledRule {
                id: rule.id,
                name: rule.name,
                action: rule.action,
                patterns,
            });
        }
        Ok(Self { rules: compiled })
    }

    /// Reports whether `raw` opens with an inline flag group: `(?`, then one or
    /// more flag characters (`i m s R U u x -`), then `)` or `:`.
    ///
    /// Checking for a bare `(?` prefix is not enough: `(?:...)` and
    /// `(?P<name>...)` share that prefix but set no flags.
    fn has_leading_inline_flags(raw: &str) -> bool {
        raw.strip_prefix("(?").map_or(false, |rest| {
            let flags_len = rest
                .bytes()
                .take_while(|b| b"imsRUux-".contains(b))
                .count();
            flags_len > 0
                && matches!(
                    rest.as_bytes().get(flags_len),
                    Some(b')') | Some(b':')
                )
        })
    }

    /// Returns a set with no rules; scanning it never produces a match.
    pub fn empty() -> Self {
        Self { rules: Vec::new() }
    }

    /// Number of compiled rules in the set.
    pub fn rule_count(&self) -> usize {
        self.rules.len()
    }

    /// Runs every rule against `text`.
    ///
    /// Each rule contributes at most one [`GuardrailMatch`]: the one for its
    /// first pattern (in declaration order) that matches. Results follow rule
    /// order.
    pub fn scan(&self, text: &str) -> Vec<GuardrailMatch> {
        self.rules
            .iter()
            .filter_map(|rule| {
                rule.patterns.iter().enumerate().find_map(|(index, pat)| {
                    pat.find(text).map(|m| GuardrailMatch {
                        rule_id: rule.id.clone(),
                        rule_name: rule.name.clone(),
                        action: rule.action,
                        matched_pattern_index: index,
                        excerpt: extract_excerpt(text, m.start(), m.end()),
                    })
                })
            })
            .collect()
    }

    /// True when at least one match carries [`GuardrailAction::Block`].
    pub fn has_block_match(matches: &[GuardrailMatch]) -> bool {
        matches.iter().any(|m| m.action == GuardrailAction::Block)
    }

    /// True when at least one match carries [`GuardrailAction::ForceApproval`].
    pub fn has_force_approval_match(matches: &[GuardrailMatch]) -> bool {
        matches
            .iter()
            .any(|m| m.action == GuardrailAction::ForceApproval)
    }
}
/// One rule firing against a scanned text, as returned by `GuardrailSet::scan`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GuardrailMatch {
/// Id of the rule that matched.
pub rule_id: String,
/// Display name of the rule that matched.
pub rule_name: String,
/// Action configured on the rule.
pub action: GuardrailAction,
/// Zero-based index, within the rule's pattern list, of the pattern that matched.
pub matched_pattern_index: usize,
/// Snippet of the scanned text around the match, produced by `extract_excerpt`.
pub excerpt: String,
}
/// Returns the part of `text` surrounding the byte range `start..end`, widened
/// by up to `RADIUS` bytes on each side.
///
/// Both edges are moved forward to the next UTF-8 character boundary so the
/// slice is always valid. A leading or trailing `…` is added only on a side
/// where text was actually cut off.
///
/// `start` and `end` are byte offsets into `text`, as reported by a regex match.
fn extract_excerpt(text: &str, start: usize, end: usize) -> String {
    const RADIUS: usize = 30;

    // First char boundary at or after `target`, clamped to the end of `text`.
    // `is_char_boundary(text.len())` is always true, so the loop terminates.
    let snap_forward = |target: usize| {
        let mut at = target.min(text.len());
        while !text.is_char_boundary(at) {
            at += 1;
        }
        at
    };

    let lo = snap_forward(start.saturating_sub(RADIUS));
    // `max(lo)` keeps the range well-formed even if a caller passes `end < start`.
    let hi = snap_forward(end.saturating_add(RADIUS)).max(lo);

    let prefix = if lo > 0 { "…" } else { "" };
    let suffix = if hi < text.len() { "…" } else { "" };
    format!("{prefix}{}{suffix}", &text[lo..hi])
}
// Unit tests for building a `GuardrailSet` and scanning text with it.
#[cfg(test)]
mod tests {
use super::*;
// Builds a rule whose display name equals its id.
fn rule(id: &str, action: GuardrailAction, patterns: &[&str]) -> GuardrailRule {
GuardrailRule {
id: id.into(),
name: id.into(),
patterns: patterns.iter().map(|s| s.to_string()).collect(),
action,
}
}
// Two-rule fixture: a force-approval rule for pricing words and a block rule for legal words.
fn pricing_set() -> GuardrailSet {
GuardrailSet::build(vec![
rule(
"pricing_quotes",
GuardrailAction::ForceApproval,
&[r"\bprecio\b", r"\bcotizaci[oó]n\b", r"\bpricing\b"],
),
rule(
"legal",
GuardrailAction::Block,
&[r"\bcontrato\b", r"\bnda\b", r"\bclausula\b"],
),
])
.unwrap()
}
// The fixture compiles and keeps both rules.
#[test]
fn build_accepts_canonical_set() {
let s = pricing_set();
assert_eq!(s.rule_count(), 2);
}
// A rule with no patterns is rejected with `EmptyRule`.
#[test]
fn build_rejects_empty_pattern_list() {
let r = GuardrailSet::build(vec![rule("x", GuardrailAction::Block, &[])]);
assert!(matches!(r, Err(GuardrailLoadError::EmptyRule(_))));
}
// Two rules sharing an id are rejected with `DuplicateId`.
#[test]
fn build_rejects_duplicate_ids() {
let r = GuardrailSet::build(vec![
rule("dup", GuardrailAction::Block, &["a"]),
rule("dup", GuardrailAction::ForceApproval, &["b"]),
]);
assert!(matches!(r, Err(GuardrailLoadError::DuplicateId(_))));
}
// A pattern that is not a valid regex is rejected with `InvalidPattern`.
#[test]
fn build_rejects_invalid_regex() {
let r = GuardrailSet::build(vec![rule("x", GuardrailAction::Block, &["[unclosed"])]);
assert!(matches!(r, Err(GuardrailLoadError::InvalidPattern { .. })));
}
// A pricing word yields one force-approval match that reports pattern index 0.
#[test]
fn scan_pricing_match_force_approval() {
let s = pricing_set();
let m = s.scan("Necesito el precio del plan enterprise");
assert_eq!(m.len(), 1);
assert_eq!(m[0].rule_id, "pricing_quotes");
assert_eq!(m[0].action, GuardrailAction::ForceApproval);
assert_eq!(m[0].matched_pattern_index, 0);
assert!(m[0].excerpt.contains("precio"));
}
// A legal word yields one block match.
#[test]
fn scan_legal_match_block() {
let s = pricing_set();
let m = s.scan("Mándame el contrato firmado");
assert_eq!(m.len(), 1);
assert_eq!(m[0].rule_id, "legal");
assert_eq!(m[0].action, GuardrailAction::Block);
}
// When both rules match, results follow rule declaration order.
#[test]
fn scan_multiple_rules_fire_in_order() {
let s = pricing_set();
let m = s.scan("Necesito el precio + envíame el contrato");
assert_eq!(m.len(), 2);
assert_eq!(m[0].rule_id, "pricing_quotes");
assert_eq!(m[1].rule_id, "legal");
}
// Two patterns of the same rule match, but only the first pattern is reported.
#[test]
fn scan_one_rule_fires_at_most_once() {
let s = pricing_set();
let m = s.scan("El precio y la cotización ya las tengo");
assert_eq!(m.len(), 1);
assert_eq!(m[0].matched_pattern_index, 0);
}
// Patterns written in lowercase still match uppercase text.
#[test]
fn scan_case_insensitive_by_default() {
let s = pricing_set();
let m = s.scan("PRECIO total del proyecto");
assert_eq!(m.len(), 1);
assert_eq!(m[0].rule_id, "pricing_quotes");
}
// Text containing none of the configured words produces no matches.
#[test]
fn scan_no_match_returns_empty() {
let s = pricing_set();
let m = s.scan("Hola, gracias por tu mensaje.");
assert!(m.is_empty());
}
// An empty set produces no matches even for text the fixture would flag.
#[test]
fn scan_empty_set_never_fires() {
let s = GuardrailSet::empty();
let m = s.scan("Necesito el precio del plan");
assert!(m.is_empty());
}
// The excerpt contains the matched word and stays under 200 characters.
#[test]
fn scan_excerpt_carries_match_context() {
let s = pricing_set();
let m = s.scan(
"Hola equipo, después de revisar el plan el precio que ofrecen es competitivo, ¿podemos avanzar?",
);
assert_eq!(m.len(), 1);
assert!(m[0].excerpt.contains("precio"));
assert!(m[0].excerpt.chars().count() < 200);
}
// The two helper predicates each react only to their own action kind.
#[test]
fn has_block_match_distinguishes_action_kinds() {
let s = pricing_set();
let force_only = s.scan("Necesito el precio");
assert!(GuardrailSet::has_force_approval_match(&force_only));
assert!(!GuardrailSet::has_block_match(&force_only));
let block_too = s.scan("Necesito el precio + el contrato");
assert!(GuardrailSet::has_force_approval_match(&block_too));
assert!(GuardrailSet::has_block_match(&block_too));
}
// A pattern that opens with `(?-i)` keeps its case-sensitive behavior.
#[test]
fn pattern_with_explicit_flags_is_left_alone() {
let s = GuardrailSet::build(vec![rule(
"case_sensitive",
GuardrailAction::Block,
&[r"(?-i)PII"],
)])
.unwrap();
assert!(s.scan("This carries PII").len() == 1);
assert!(s.scan("this carries pii").is_empty());
}
// `rule_count` is 0 for the empty set and 2 for the fixture.
#[test]
fn rule_count_reports_compiled_rules() {
assert_eq!(GuardrailSet::empty().rule_count(), 0);
assert_eq!(pricing_set().rule_count(), 2);
}
}