mod config;
mod filters;
mod pii;
mod result;
pub use config::{FilterAction, FilterRule, GuardrailsConfig, GuardrailsConfigBuilder};
pub use filters::{ContentFilter, KeywordFilter, PatternFilter};
pub use pii::PiiDetector;
pub use result::{CheckResult, Violation, ViolationType};
use ai_lib_core::types::message::Message;
/// Content guardrails: bundles a keyword filter, a pattern filter and an
/// optional PII detector together with the configuration they were built from.
#[derive(Debug, Clone)]
pub struct Guardrails {
/// Configuration supplied at construction; consulted on every check and
/// sanitize call for the per-direction flags and replacement strings.
config: GuardrailsConfig,
/// Filter built from `config.keyword_rules`.
keyword_filter: KeywordFilter,
/// Filter built from `config.pattern_rules`.
pattern_filter: PatternFilter,
/// Present only when `config.enable_pii_detection` was set at construction.
pii_detector: Option<PiiDetector>,
}
impl Guardrails {
pub fn new(config: GuardrailsConfig) -> Self {
let keyword_filter = KeywordFilter::from_rules(&config.keyword_rules);
let pattern_filter = PatternFilter::from_rules(&config.pattern_rules);
let pii_detector = if config.enable_pii_detection {
Some(PiiDetector::new())
} else {
None
};
Self {
config,
keyword_filter,
pattern_filter,
pii_detector,
}
}
pub fn permissive() -> Self {
Self::new(GuardrailsConfig::permissive())
}
pub fn strict() -> Self {
Self::new(GuardrailsConfig::strict())
}
pub fn check_input(&self, content: &str) -> CheckResult {
self.check_content(content, true)
}
pub fn check_output(&self, content: &str) -> CheckResult {
self.check_content(content, false)
}
pub fn check_message(&self, message: &Message) -> CheckResult {
let content = extract_text_content(message);
self.check_content(&content, true)
}
pub fn check_messages(&self, messages: &[Message]) -> CheckResult {
let mut combined_result = CheckResult::passed();
for message in messages {
let result = self.check_message(message);
combined_result = combined_result.merge(result);
if combined_result.is_blocked() && self.config.stop_on_first_block {
break;
}
}
combined_result
}
fn check_content(&self, content: &str, is_input: bool) -> CheckResult {
let mut violations = Vec::new();
if (is_input && self.config.filter_input) || (!is_input && self.config.filter_output) {
violations.extend(self.keyword_filter.check(content));
violations.extend(self.pattern_filter.check(content));
}
if let Some(ref pii_detector) = self.pii_detector {
if (is_input && self.config.check_pii_input)
|| (!is_input && self.config.check_pii_output)
{
violations.extend(pii_detector.check(content));
}
}
CheckResult::from_violations(violations)
}
pub fn sanitize(&self, content: &str) -> String {
let mut sanitized = content.to_string();
sanitized = self
.keyword_filter
.sanitize(&sanitized, &self.config.sanitize_replacement);
sanitized = self
.pattern_filter
.sanitize(&sanitized, &self.config.sanitize_replacement);
if let Some(ref pii_detector) = self.pii_detector {
sanitized = pii_detector.sanitize(&sanitized, &self.config.pii_replacement);
}
sanitized
}
pub fn config(&self) -> &GuardrailsConfig {
&self.config
}
}
/// Returns the textual content of `message` as a single string.
///
/// A plain-text message yields a copy of its text. For block content, only
/// `ContentBlock::Text` blocks contribute: their texts are joined with a
/// single space and every other block kind is ignored.
fn extract_text_content(message: &Message) -> String {
    use ai_lib_core::types::message::{ContentBlock, MessageContent};

    match &message.content {
        MessageContent::Text(text) => text.clone(),
        MessageContent::Blocks(blocks) => {
            let mut texts = Vec::new();
            for block in blocks.iter() {
                if let ContentBlock::Text { text } = block {
                    texts.push(text.clone());
                }
            }
            texts.join(" ")
        }
    }
}
impl Default for Guardrails {
fn default() -> Self {
Self::permissive()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// With the permissive preset, ordinary text passes the input check.
    #[test]
    fn test_permissive_allows_all() {
        let guard = Guardrails::permissive();
        let outcome = guard.check_input("Any content should pass");
        assert!(outcome.is_passed());
    }

    /// A keyword rule with `FilterAction::Block` blocks matching input.
    #[test]
    fn test_keyword_blocking() {
        let guard = Guardrails::new(
            GuardrailsConfig::builder()
                .add_keyword_filter("blocked_word", FilterAction::Block)
                .build(),
        );
        let outcome = guard.check_input("This contains blocked_word in it");
        assert!(outcome.is_blocked());
    }

    /// A keyword rule with `FilterAction::Warn` warns without blocking.
    #[test]
    fn test_keyword_warning() {
        let guard = Guardrails::new(
            GuardrailsConfig::builder()
                .add_keyword_filter("warn_word", FilterAction::Warn)
                .build(),
        );
        let outcome = guard.check_input("This contains warn_word in it");
        assert!(outcome.is_warned());
        assert!(!outcome.is_blocked());
    }

    /// Sanitizing replaces the matched keyword with the configured text.
    #[test]
    fn test_sanitization() {
        let guard = Guardrails::new(
            GuardrailsConfig::builder()
                .add_keyword_filter("secret", FilterAction::Sanitize)
                .sanitize_replacement("[REDACTED]".to_string())
                .build(),
        );
        let cleaned = guard.sanitize("My secret is here");
        assert!(cleaned.contains("[REDACTED]"));
        assert!(!cleaned.contains("secret"));
    }

    /// With PII detection enabled, an e-mail address is reported as a violation.
    #[test]
    fn test_pii_detection() {
        let guard = Guardrails::new(
            GuardrailsConfig::builder()
                .enable_pii_detection(true)
                .build(),
        );
        let outcome = guard.check_input("My email is test@example.com");
        assert!(outcome.has_violations());
    }
}