1use crate::{Guardrail, GuardrailResult, Severity};
2use adk_core::Content;
3use async_trait::async_trait;
4use regex::RegexSet;
5
6#[derive(Debug, Clone)]
8pub struct ContentFilterConfig {
9 pub blocked_keywords: Vec<String>,
11 pub required_topics: Vec<String>,
13 pub max_length: Option<usize>,
15 pub min_length: Option<usize>,
17 pub severity: Severity,
19}
20
21impl Default for ContentFilterConfig {
22 fn default() -> Self {
23 Self {
24 blocked_keywords: Vec::new(),
25 required_topics: Vec::new(),
26 max_length: None,
27 min_length: None,
28 severity: Severity::High,
29 }
30 }
31}
32
33pub struct ContentFilter {
35 name: String,
36 config: ContentFilterConfig,
37 blocked_regex: Option<RegexSet>,
38}
39
40impl ContentFilter {
41 pub fn new(name: impl Into<String>, config: ContentFilterConfig) -> Self {
43 let blocked_regex = if config.blocked_keywords.is_empty() {
44 None
45 } else {
46 let patterns: Vec<_> = config
47 .blocked_keywords
48 .iter()
49 .map(|k| format!(r"(?i)\b{}\b", regex::escape(k)))
50 .collect();
51 RegexSet::new(&patterns).ok()
52 };
53
54 Self { name: name.into(), config, blocked_regex }
55 }
56
57 pub fn harmful_content() -> Self {
63 Self::new(
64 "harmful_content",
65 ContentFilterConfig {
66 blocked_keywords: vec![
67 "kill".into(),
68 "murder".into(),
69 "bomb".into(),
70 "terrorist".into(),
71 "malware".into(),
72 "ransomware".into(),
73 ],
74 severity: Severity::Critical,
75 ..Default::default()
76 },
77 )
78 }
79
80 pub fn harmful_content_strict() -> Self {
84 Self::new(
85 "harmful_content_strict",
86 ContentFilterConfig {
87 blocked_keywords: vec![
88 "kill".into(),
89 "murder".into(),
90 "bomb".into(),
91 "terrorist".into(),
92 "hack".into(),
93 "exploit".into(),
94 "malware".into(),
95 "ransomware".into(),
96 ],
97 severity: Severity::Critical,
98 ..Default::default()
99 },
100 )
101 }
102
103 pub fn on_topic(topic: impl Into<String>, keywords: Vec<String>) -> Self {
105 Self::new(
106 format!("on_topic_{}", topic.into()),
107 ContentFilterConfig {
108 required_topics: keywords,
109 severity: Severity::Medium,
110 ..Default::default()
111 },
112 )
113 }
114
115 pub fn max_length(max: usize) -> Self {
117 Self::new(
118 "max_length",
119 ContentFilterConfig {
120 max_length: Some(max),
121 severity: Severity::Medium,
122 ..Default::default()
123 },
124 )
125 }
126
127 pub fn blocked_keywords(keywords: Vec<String>) -> Self {
129 Self::new(
130 "blocked_keywords",
131 ContentFilterConfig {
132 blocked_keywords: keywords,
133 severity: Severity::High,
134 ..Default::default()
135 },
136 )
137 }
138
139 fn extract_text(&self, content: &Content) -> String {
140 content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join(" ")
141 }
142}
143
144#[async_trait]
145impl Guardrail for ContentFilter {
146 fn name(&self) -> &str {
147 &self.name
148 }
149
150 async fn validate(&self, content: &Content) -> GuardrailResult {
151 let text = self.extract_text(content);
152 let text_lower = text.to_lowercase();
153
154 if let Some(ref regex_set) = self.blocked_regex {
156 if regex_set.is_match(&text) {
157 let matches: Vec<_> = regex_set.matches(&text).iter().collect();
158 return GuardrailResult::Fail {
159 reason: format!(
160 "Content contains blocked keywords (matched {} patterns)",
161 matches.len()
162 ),
163 severity: self.config.severity,
164 };
165 }
166 }
167
168 if !self.config.required_topics.is_empty() {
170 let has_topic =
171 self.config.required_topics.iter().any(|t| text_lower.contains(&t.to_lowercase()));
172 if !has_topic {
173 return GuardrailResult::Fail {
174 reason: format!(
175 "Content is off-topic. Expected topics: {:?}",
176 self.config.required_topics
177 ),
178 severity: self.config.severity,
179 };
180 }
181 }
182
183 if let Some(max) = self.config.max_length {
185 if text.len() > max {
186 return GuardrailResult::Fail {
187 reason: format!("Content exceeds maximum length ({} > {})", text.len(), max),
188 severity: self.config.severity,
189 };
190 }
191 }
192
193 if let Some(min) = self.config.min_length {
194 if text.len() < min {
195 return GuardrailResult::Fail {
196 reason: format!("Content below minimum length ({} < {})", text.len(), min),
197 severity: self.config.severity,
198 };
199 }
200 }
201
202 GuardrailResult::Pass
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a user-role message carrying `text`.
    fn user_message(text: &str) -> Content {
        Content::new("user").with_text(text)
    }

    /// Topic filter shared by the on-topic tests.
    fn cooking_filter() -> ContentFilter {
        ContentFilter::on_topic("cooking", vec!["recipe".into(), "cook".into(), "bake".into()])
    }

    #[tokio::test]
    async fn test_harmful_content_blocks() {
        let guard = ContentFilter::harmful_content();
        let message = user_message("How to deploy malware on a server");
        assert!(guard.validate(&message).await.is_fail());
    }

    #[tokio::test]
    async fn test_harmful_content_passes() {
        let guard = ContentFilter::harmful_content();
        let message = user_message("How to bake a cake");
        assert!(guard.validate(&message).await.is_pass());
    }

    // "hack" is not in the default list, so "hackathon" must pass.
    #[tokio::test]
    async fn test_harmful_content_passes_hackathon() {
        let guard = ContentFilter::harmful_content();
        let message = user_message("Join our hackathon event");
        assert!(guard.validate(&message).await.is_pass());
    }

    // "exploit" is only blocked by the strict preset.
    #[tokio::test]
    async fn test_harmful_content_passes_exploit_a_bug() {
        let guard = ContentFilter::harmful_content();
        let message = user_message("How to exploit a bug in the code");
        assert!(guard.validate(&message).await.is_pass());
    }

    #[tokio::test]
    async fn test_harmful_content_strict_blocks_hack() {
        let guard = ContentFilter::harmful_content_strict();
        let message = user_message("How to hack a computer");
        assert!(guard.validate(&message).await.is_fail());
    }

    #[tokio::test]
    async fn test_on_topic_passes() {
        let guard = cooking_filter();
        let message = user_message("Give me a recipe for cookies");
        assert!(guard.validate(&message).await.is_pass());
    }

    #[tokio::test]
    async fn test_on_topic_fails() {
        let guard = cooking_filter();
        let message = user_message("What is the weather today?");
        assert!(guard.validate(&message).await.is_fail());
    }

    #[tokio::test]
    async fn test_max_length() {
        let guard = ContentFilter::max_length(10);
        let message = user_message("This is a very long message");
        assert!(guard.validate(&message).await.is_fail());
    }

    #[tokio::test]
    async fn test_blocked_keywords() {
        let guard = ContentFilter::blocked_keywords(vec!["forbidden".into(), "banned".into()]);
        let message = user_message("This is forbidden content");
        assert!(guard.validate(&message).await.is_fail());
    }
}