Skip to main content

adk_guardrail/
content.rs

1use crate::{Guardrail, GuardrailResult, Severity};
2use adk_core::Content;
3use async_trait::async_trait;
4use regex::RegexSet;
5
/// Configuration for content filtering.
///
/// Each field enables one independent check in [`ContentFilter`]; a check whose
/// field is empty or `None` is skipped during validation.
#[derive(Debug, Clone)]
pub struct ContentFilterConfig {
    /// Keywords that cause validation to fail when found in the content.
    /// Matching is case-insensitive.
    pub blocked_keywords: Vec<String>,
    /// Required topic keywords: at least one must appear in the content
    /// (compared case-insensitively as a substring), otherwise validation fails.
    pub required_topics: Vec<String>,
    /// Maximum length of the content text; `None` disables the check.
    pub max_length: Option<usize>,
    /// Minimum length of the content text; `None` disables the check.
    pub min_length: Option<usize>,
    /// Severity reported with every failure produced by this configuration.
    pub severity: Severity,
}
20
21impl Default for ContentFilterConfig {
22    fn default() -> Self {
23        Self {
24            blocked_keywords: Vec::new(),
25            required_topics: Vec::new(),
26            max_length: None,
27            min_length: None,
28            severity: Severity::High,
29        }
30    }
31}
32
33/// Content filter guardrail for blocking harmful or off-topic content
34pub struct ContentFilter {
35    name: String,
36    config: ContentFilterConfig,
37    blocked_regex: Option<RegexSet>,
38}
39
40impl ContentFilter {
41    /// Create a new content filter with custom config
42    pub fn new(name: impl Into<String>, config: ContentFilterConfig) -> Self {
43        let blocked_regex = if config.blocked_keywords.is_empty() {
44            None
45        } else {
46            let patterns: Vec<_> = config
47                .blocked_keywords
48                .iter()
49                .map(|k| format!(r"(?i)\b{}\b", regex::escape(k)))
50                .collect();
51            RegexSet::new(&patterns).ok()
52        };
53
54        Self { name: name.into(), config, blocked_regex }
55    }
56
57    /// Create a filter that blocks common harmful content patterns.
58    ///
59    /// This default filter excludes developer-common terms like "hack" and "exploit"
60    /// to avoid false positives in developer contexts. Use [`harmful_content_strict`](Self::harmful_content_strict)
61    /// for the full keyword list.
62    pub fn harmful_content() -> Self {
63        Self::new(
64            "harmful_content",
65            ContentFilterConfig {
66                blocked_keywords: vec![
67                    "kill".into(),
68                    "murder".into(),
69                    "bomb".into(),
70                    "terrorist".into(),
71                    "malware".into(),
72                    "ransomware".into(),
73                ],
74                severity: Severity::Critical,
75                ..Default::default()
76            },
77        )
78    }
79
80    /// Create a strict filter that blocks all harmful content patterns,
81    /// including terms like "hack" and "exploit" that may produce false
82    /// positives in developer contexts.
83    pub fn harmful_content_strict() -> Self {
84        Self::new(
85            "harmful_content_strict",
86            ContentFilterConfig {
87                blocked_keywords: vec![
88                    "kill".into(),
89                    "murder".into(),
90                    "bomb".into(),
91                    "terrorist".into(),
92                    "hack".into(),
93                    "exploit".into(),
94                    "malware".into(),
95                    "ransomware".into(),
96                ],
97                severity: Severity::Critical,
98                ..Default::default()
99            },
100        )
101    }
102
103    /// Create a filter that ensures content is on-topic
104    pub fn on_topic(topic: impl Into<String>, keywords: Vec<String>) -> Self {
105        Self::new(
106            format!("on_topic_{}", topic.into()),
107            ContentFilterConfig {
108                required_topics: keywords,
109                severity: Severity::Medium,
110                ..Default::default()
111            },
112        )
113    }
114
115    /// Create a filter with maximum length
116    pub fn max_length(max: usize) -> Self {
117        Self::new(
118            "max_length",
119            ContentFilterConfig {
120                max_length: Some(max),
121                severity: Severity::Medium,
122                ..Default::default()
123            },
124        )
125    }
126
127    /// Create a filter with blocked keywords
128    pub fn blocked_keywords(keywords: Vec<String>) -> Self {
129        Self::new(
130            "blocked_keywords",
131            ContentFilterConfig {
132                blocked_keywords: keywords,
133                severity: Severity::High,
134                ..Default::default()
135            },
136        )
137    }
138
139    fn extract_text(&self, content: &Content) -> String {
140        content.parts.iter().filter_map(|p| p.text()).collect::<Vec<_>>().join(" ")
141    }
142}
143
144#[async_trait]
145impl Guardrail for ContentFilter {
146    fn name(&self) -> &str {
147        &self.name
148    }
149
150    async fn validate(&self, content: &Content) -> GuardrailResult {
151        let text = self.extract_text(content);
152        let text_lower = text.to_lowercase();
153
154        // Check blocked keywords
155        if let Some(ref regex_set) = self.blocked_regex {
156            if regex_set.is_match(&text) {
157                let matches: Vec<_> = regex_set.matches(&text).iter().collect();
158                return GuardrailResult::Fail {
159                    reason: format!(
160                        "Content contains blocked keywords (matched {} patterns)",
161                        matches.len()
162                    ),
163                    severity: self.config.severity,
164                };
165            }
166        }
167
168        // Check required topics
169        if !self.config.required_topics.is_empty() {
170            let has_topic =
171                self.config.required_topics.iter().any(|t| text_lower.contains(&t.to_lowercase()));
172            if !has_topic {
173                return GuardrailResult::Fail {
174                    reason: format!(
175                        "Content is off-topic. Expected topics: {:?}",
176                        self.config.required_topics
177                    ),
178                    severity: self.config.severity,
179                };
180            }
181        }
182
183        // Check length limits
184        if let Some(max) = self.config.max_length {
185            if text.len() > max {
186                return GuardrailResult::Fail {
187                    reason: format!("Content exceeds maximum length ({} > {})", text.len(), max),
188                    severity: self.config.severity,
189                };
190            }
191        }
192
193        if let Some(min) = self.config.min_length {
194            if text.len() < min {
195                return GuardrailResult::Fail {
196                    reason: format!("Content below minimum length ({} < {})", text.len(), min),
197                    severity: self.config.severity,
198                };
199            }
200        }
201
202        GuardrailResult::Pass
203    }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    /// Validates a single-part user message containing `text` with `filter`.
    async fn check(filter: &ContentFilter, text: &str) -> GuardrailResult {
        let content = Content::new("user").with_text(text);
        filter.validate(&content).await
    }

    /// Topic filter shared by the on-topic tests.
    fn cooking_filter() -> ContentFilter {
        ContentFilter::on_topic("cooking", vec!["recipe".into(), "cook".into(), "bake".into()])
    }

    #[tokio::test]
    async fn test_harmful_content_blocks() {
        let filter = ContentFilter::harmful_content();
        assert!(check(&filter, "How to deploy malware on a server").await.is_fail());
    }

    #[tokio::test]
    async fn test_harmful_content_passes() {
        let filter = ContentFilter::harmful_content();
        assert!(check(&filter, "How to bake a cake").await.is_pass());
    }

    #[tokio::test]
    async fn test_harmful_content_passes_hackathon() {
        let filter = ContentFilter::harmful_content();
        assert!(check(&filter, "Join our hackathon event").await.is_pass());
    }

    #[tokio::test]
    async fn test_harmful_content_passes_exploit_a_bug() {
        let filter = ContentFilter::harmful_content();
        assert!(check(&filter, "How to exploit a bug in the code").await.is_pass());
    }

    #[tokio::test]
    async fn test_harmful_content_strict_blocks_hack() {
        let filter = ContentFilter::harmful_content_strict();
        assert!(check(&filter, "How to hack a computer").await.is_fail());
    }

    // Keywords match whole words only, so "hackathon" must not trip "hack".
    #[tokio::test]
    async fn test_harmful_content_strict_passes_hackathon() {
        let filter = ContentFilter::harmful_content_strict();
        assert!(check(&filter, "Join our hackathon event").await.is_pass());
    }

    #[tokio::test]
    async fn test_on_topic_passes() {
        let filter = cooking_filter();
        assert!(check(&filter, "Give me a recipe for cookies").await.is_pass());
    }

    #[tokio::test]
    async fn test_on_topic_fails() {
        let filter = cooking_filter();
        assert!(check(&filter, "What is the weather today?").await.is_fail());
    }

    #[tokio::test]
    async fn test_max_length() {
        let filter = ContentFilter::max_length(10);
        assert!(check(&filter, "This is a very long message").await.is_fail());
    }

    #[tokio::test]
    async fn test_max_length_passes_short_text() {
        let filter = ContentFilter::max_length(10);
        assert!(check(&filter, "short").await.is_pass());
    }

    #[tokio::test]
    async fn test_min_length() {
        let filter = ContentFilter::new(
            "min_length",
            ContentFilterConfig { min_length: Some(10), ..Default::default() },
        );
        assert!(check(&filter, "short").await.is_fail());
        assert!(check(&filter, "this text is long enough").await.is_pass());
    }

    #[tokio::test]
    async fn test_blocked_keywords() {
        let filter = ContentFilter::blocked_keywords(vec!["forbidden".into(), "banned".into()]);
        assert!(check(&filter, "This is forbidden content").await.is_fail());
    }

    #[tokio::test]
    async fn test_blocked_keywords_case_insensitive() {
        let filter = ContentFilter::blocked_keywords(vec!["forbidden".into()]);
        assert!(check(&filter, "This is FORBIDDEN content").await.is_fail());
    }
}
283}