use crate::{Guardrail, GuardrailResult, Severity};
use adk_core::Content;
use async_trait::async_trait;
use regex::RegexSet;
/// Settings that control which checks a [`ContentFilter`] performs.
///
/// Every check is optional: an empty list or a `None` limit disables the
/// corresponding check.
#[derive(Debug, Clone)]
pub struct ContentFilterConfig {
    /// Keywords that cause validation to fail when found in the content.
    /// They are compiled into case-insensitive patterns by [`ContentFilter::new`].
    pub blocked_keywords: Vec<String>,

    /// Topic keywords of which at least one must occur in the content
    /// (case-insensitive substring search). Ignored when empty.
    pub required_topics: Vec<String>,

    /// Upper bound on the extracted text length, compared against `str::len`
    /// (i.e. UTF-8 bytes). `None` disables the check.
    pub max_length: Option<usize>,

    /// Lower bound on the extracted text length, compared against `str::len`
    /// (i.e. UTF-8 bytes). `None` disables the check.
    pub min_length: Option<usize>,

    /// Severity attached to every failure this filter reports.
    pub severity: Severity,
}
impl Default for ContentFilterConfig {
fn default() -> Self {
Self {
blocked_keywords: Vec::new(),
required_topics: Vec::new(),
max_length: None,
min_length: None,
severity: Severity::High,
}
}
}
/// Guardrail that checks content against blocked keywords, required topics
/// and length limits, as described by a [`ContentFilterConfig`].
///
/// `Debug` and `Clone` are derived so the filter can be logged and duplicated
/// like its configuration; every field already implements both traits.
#[derive(Debug, Clone)]
pub struct ContentFilter {
    /// Name reported through the `Guardrail::name` implementation.
    name: String,
    /// Configuration the filter was built from.
    config: ContentFilterConfig,
    /// Pre-compiled patterns for `config.blocked_keywords`; `None` when there
    /// is nothing to match (or when compilation failed in `new`).
    blocked_regex: Option<RegexSet>,
}
impl ContentFilter {
    /// Creates a filter called `name` that enforces `config`.
    ///
    /// Every non-empty entry of `config.blocked_keywords` is escaped and
    /// compiled into a case-insensitive pattern (see `keyword_pattern`).
    /// Empty keywords are ignored: a pattern built from an empty string would
    /// match almost any text and block everything.
    pub fn new(name: impl Into<String>, config: ContentFilterConfig) -> Self {
        let patterns: Vec<String> = config
            .blocked_keywords
            .iter()
            .filter(|keyword| !keyword.is_empty())
            .map(|keyword| Self::keyword_pattern(keyword))
            .collect();

        let blocked_regex = if patterns.is_empty() {
            None
        } else {
            // NOTE(review): a compilation failure (e.g. the set exceeding the
            // regex size limit) is swallowed here, which silently disables
            // keyword blocking. The infallible signature is kept for
            // backward compatibility; consider surfacing this error.
            RegexSet::new(&patterns).ok()
        };

        Self { name: name.into(), config, blocked_regex }
    }

    /// Builds the case-insensitive pattern for a single blocked keyword.
    ///
    /// A `\b` word-boundary assertion is only emitted next to an edge of the
    /// keyword that is itself a word character. Next to a non-word character
    /// (as in "c++" or ".net") `\b` would demand an adjacent word character in
    /// the text, so the keyword could never match when surrounded by spaces,
    /// punctuation or the ends of the text.
    fn keyword_pattern(keyword: &str) -> String {
        // Approximates the regex crate's Unicode-aware `\w` class.
        let is_word_char = |c: char| c.is_alphanumeric() || c == '_';

        let starts_with_word = keyword.chars().next().map_or(false, is_word_char);
        let ends_with_word = keyword.chars().next_back().map_or(false, is_word_char);

        let prefix = if starts_with_word { r"\b" } else { "" };
        let suffix = if ends_with_word { r"\b" } else { "" };

        format!("(?i){}{}{}", prefix, regex::escape(keyword), suffix)
    }

    /// Preset named `"harmful_content"` that blocks a fixed list of
    /// violence- and intrusion-related keywords with [`Severity::Critical`].
    pub fn harmful_content() -> Self {
        let blocked_keywords: Vec<String> = [
            "kill",
            "murder",
            "bomb",
            "terrorist",
            "hack",
            "exploit",
            "malware",
            "ransomware",
        ]
        .iter()
        .map(|keyword| keyword.to_string())
        .collect();

        let config = ContentFilterConfig {
            blocked_keywords,
            severity: Severity::Critical,
            ..Default::default()
        };
        Self::new("harmful_content", config)
    }

    /// Preset named `"on_topic_<topic>"` that requires at least one of
    /// `keywords` to appear in the content; failures are [`Severity::Medium`].
    pub fn on_topic(topic: impl Into<String>, keywords: Vec<String>) -> Self {
        let name = format!("on_topic_{}", topic.into());
        let config = ContentFilterConfig {
            required_topics: keywords,
            severity: Severity::Medium,
            ..Default::default()
        };
        Self::new(name, config)
    }

    /// Preset named `"max_length"` that rejects content whose extracted text
    /// is longer than `max`; failures are [`Severity::Medium`].
    pub fn max_length(max: usize) -> Self {
        let config = ContentFilterConfig {
            max_length: Some(max),
            severity: Severity::Medium,
            ..Default::default()
        };
        Self::new("max_length", config)
    }

    /// Preset named `"blocked_keywords"` that rejects content containing any
    /// of `keywords`; failures are [`Severity::High`].
    pub fn blocked_keywords(keywords: Vec<String>) -> Self {
        let config = ContentFilterConfig {
            blocked_keywords: keywords,
            severity: Severity::High,
            ..Default::default()
        };
        Self::new("blocked_keywords", config)
    }

    /// Collects the text of every part that has text and joins the pieces
    /// with a single space.
    fn extract_text(&self, content: &Content) -> String {
        let pieces: Vec<_> = content.parts.iter().filter_map(|part| part.text()).collect();
        pieces.join(" ")
    }
}
#[async_trait]
impl Guardrail for ContentFilter {
    /// Returns the name this filter was constructed with.
    fn name(&self) -> &str {
        &self.name
    }

    /// Validates `content` against the configured checks.
    ///
    /// Checks run in a fixed order and the first failure is returned:
    /// blocked keywords, required topics, maximum length, minimum length.
    /// Every failure carries `config.severity`. Lengths are compared using
    /// `str::len`, i.e. the UTF-8 byte length of the extracted text, which
    /// includes the single spaces inserted between parts.
    async fn validate(&self, content: &Content) -> GuardrailResult {
        let text = self.extract_text(content);
        let severity = self.config.severity;

        if let Some(regex_set) = &self.blocked_regex {
            // One scan yields both "did anything match" and the match count.
            let matched = regex_set.matches(&text);
            if matched.matched_any() {
                return GuardrailResult::Fail {
                    reason: format!(
                        "Content contains blocked keywords (matched {} patterns)",
                        matched.iter().count()
                    ),
                    severity,
                };
            }
        }

        let topics = &self.config.required_topics;
        if !topics.is_empty() {
            // Lowercase the text only when a topic check actually runs.
            let text_lower = text.to_lowercase();
            let on_topic = topics.iter().any(|topic| text_lower.contains(&topic.to_lowercase()));
            if !on_topic {
                return GuardrailResult::Fail {
                    reason: format!("Content is off-topic. Expected topics: {:?}", topics),
                    severity,
                };
            }
        }

        let length = text.len();

        if let Some(max) = self.config.max_length {
            if length > max {
                return GuardrailResult::Fail {
                    reason: format!("Content exceeds maximum length ({} > {})", length, max),
                    severity,
                };
            }
        }

        if let Some(min) = self.config.min_length {
            if length < min {
                return GuardrailResult::Fail {
                    reason: format!("Content below minimum length ({} < {})", length, min),
                    severity,
                };
            }
        }

        GuardrailResult::Pass
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `filter` against a single-text message with the "user" role.
    async fn run(filter: &ContentFilter, text: &str) -> GuardrailResult {
        let message = Content::new("user").with_text(text);
        filter.validate(&message).await
    }

    /// Topic filter shared by the on-topic tests.
    fn cooking_filter() -> ContentFilter {
        ContentFilter::on_topic("cooking", vec!["recipe".into(), "cook".into(), "bake".into()])
    }

    /// A message containing a preset harmful keyword is rejected.
    #[tokio::test]
    async fn test_harmful_content_blocks() {
        let outcome = run(&ContentFilter::harmful_content(), "How to hack a computer").await;
        assert!(outcome.is_fail());
    }

    /// A message without any preset harmful keyword is accepted.
    #[tokio::test]
    async fn test_harmful_content_passes() {
        let outcome = run(&ContentFilter::harmful_content(), "How to bake a cake").await;
        assert!(outcome.is_pass());
    }

    /// A message mentioning one of the required topics is accepted.
    #[tokio::test]
    async fn test_on_topic_passes() {
        let outcome = run(&cooking_filter(), "Give me a recipe for cookies").await;
        assert!(outcome.is_pass());
    }

    /// A message mentioning none of the required topics is rejected.
    #[tokio::test]
    async fn test_on_topic_fails() {
        let outcome = run(&cooking_filter(), "What is the weather today?").await;
        assert!(outcome.is_fail());
    }

    /// A message longer than the configured maximum is rejected.
    #[tokio::test]
    async fn test_max_length() {
        let outcome = run(&ContentFilter::max_length(10), "This is a very long message").await;
        assert!(outcome.is_fail());
    }

    /// A message containing a custom blocked keyword is rejected.
    #[tokio::test]
    async fn test_blocked_keywords() {
        let blocklist = vec!["forbidden".into(), "banned".into()];
        let outcome =
            run(&ContentFilter::blocked_keywords(blocklist), "This is forbidden content").await;
        assert!(outcome.is_fail());
    }
}