Skip to main content

adk_guardrail/
pii.rs

1use crate::{Guardrail, GuardrailResult};
2use adk_core::{Content, Part};
3use async_trait::async_trait;
4use regex::Regex;
5
/// Categories of personally identifiable information (PII) that [`PiiRedactor`] can detect
/// and replace with a placeholder.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum PiiType {
    /// Email addresses, e.g. `user@example.com`.
    Email,
    /// Phone numbers in a 3-3-4 digit layout with an optional `+1` prefix, e.g. `555-123-4567`.
    Phone,
    /// Social Security numbers in a 3-2-4 digit layout, e.g. `123-45-6789`.
    Ssn,
    /// Card numbers written as four groups of four digits, e.g. `4111-1111-1111-1111`.
    CreditCard,
    /// Four dot-separated groups of 1–3 digits (IPv4-style addresses).
    IpAddress,
}
15
16impl PiiType {
17    fn pattern(&self) -> &'static str {
18        match self {
19            PiiType::Email => r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}\b",
20            PiiType::Phone => r"\b(?:\+?1[-.\s]?)?\(?[0-9]{3}\)?[-.\s]?[0-9]{3}[-.\s]?[0-9]{4}\b",
21            PiiType::Ssn => r"\b\d{3}[-\s]?\d{2}[-\s]?\d{4}\b",
22            PiiType::CreditCard => r"\b(?:\d{4}[-\s]?){3}\d{4}\b",
23            PiiType::IpAddress => r"\b(?:\d{1,3}\.){3}\d{1,3}\b",
24        }
25    }
26
27    fn redaction(&self) -> &'static str {
28        match self {
29            PiiType::Email => "[EMAIL REDACTED]",
30            PiiType::Phone => "[PHONE REDACTED]",
31            PiiType::Ssn => "[SSN REDACTED]",
32            PiiType::CreditCard => "[CREDIT CARD REDACTED]",
33            PiiType::IpAddress => "[IP REDACTED]",
34        }
35    }
36}
37
38/// PII detection and redaction guardrail
39pub struct PiiRedactor {
40    patterns: Vec<(PiiType, Regex)>,
41}
42
43impl PiiRedactor {
44    /// Create a new PII redactor with all PII types enabled
45    pub fn new() -> Self {
46        Self::with_types(&[PiiType::Email, PiiType::Phone, PiiType::Ssn, PiiType::CreditCard])
47    }
48
49    /// Create a PII redactor with specific types
50    pub fn with_types(types: &[PiiType]) -> Self {
51        let patterns =
52            types.iter().filter_map(|t| Regex::new(t.pattern()).ok().map(|r| (*t, r))).collect();
53
54        Self { patterns }
55    }
56
57    /// Redact PII from text, returns (redacted_text, found_types)
58    pub fn redact(&self, text: &str) -> (String, Vec<PiiType>) {
59        let mut result = text.to_string();
60        let mut found = Vec::new();
61
62        for (pii_type, regex) in &self.patterns {
63            if regex.is_match(&result) {
64                found.push(*pii_type);
65                result = regex.replace_all(&result, pii_type.redaction()).to_string();
66            }
67        }
68
69        (result, found)
70    }
71}
72
73impl Default for PiiRedactor {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79#[async_trait]
80impl Guardrail for PiiRedactor {
81    fn name(&self) -> &str {
82        "pii_redactor"
83    }
84
85    async fn validate(&self, content: &Content) -> GuardrailResult {
86        let mut new_parts = Vec::new();
87        let mut any_redacted = false;
88        let mut redacted_types = Vec::new();
89
90        for part in &content.parts {
91            match part {
92                Part::Text { text } => {
93                    let (redacted, found) = self.redact(text);
94                    if !found.is_empty() {
95                        any_redacted = true;
96                        redacted_types.extend(found);
97                        new_parts.push(Part::Text { text: redacted });
98                    } else {
99                        new_parts.push(part.clone());
100                    }
101                }
102                _ => new_parts.push(part.clone()),
103            }
104        }
105
106        if any_redacted {
107            let types_str: Vec<_> = redacted_types.iter().map(|t| format!("{:?}", t)).collect();
108            GuardrailResult::Transform {
109                new_content: Content { role: content.role.clone(), parts: new_parts },
110                reason: format!("Redacted PII types: {}", types_str.join(", ")),
111            }
112        } else {
113            GuardrailResult::Pass
114        }
115    }
116
117    fn run_parallel(&self) -> bool {
118        false // Must run sequentially to transform content
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    /// Redacts `input` with a freshly built default redactor.
    fn redact_with_defaults(input: &str) -> (String, Vec<PiiType>) {
        let redactor = PiiRedactor::new();
        redactor.redact(input)
    }

    #[test]
    fn test_email_redaction() {
        let (output, kinds) = redact_with_defaults("Contact me at test@example.com");
        assert_eq!(output, "Contact me at [EMAIL REDACTED]");
        assert!(kinds.contains(&PiiType::Email));
    }

    #[test]
    fn test_phone_redaction() {
        let (output, kinds) = redact_with_defaults("Call me at 555-123-4567");
        assert_eq!(output, "Call me at [PHONE REDACTED]");
        assert!(kinds.contains(&PiiType::Phone));
    }

    #[test]
    fn test_ssn_redaction() {
        let (output, kinds) = redact_with_defaults("SSN: 123-45-6789");
        assert_eq!(output, "SSN: [SSN REDACTED]");
        assert!(kinds.contains(&PiiType::Ssn));
    }

    #[test]
    fn test_credit_card_redaction() {
        let (output, kinds) = redact_with_defaults("Card: 4111-1111-1111-1111");
        assert_eq!(output, "Card: [CREDIT CARD REDACTED]");
        assert!(kinds.contains(&PiiType::CreditCard));
    }

    #[test]
    fn test_multiple_pii() {
        let (output, kinds) = redact_with_defaults("Email: a@b.com, Phone: 555-123-4567");
        for placeholder in ["[EMAIL REDACTED]", "[PHONE REDACTED]"].iter() {
            assert!(output.contains(*placeholder));
        }
        assert_eq!(kinds.len(), 2);
    }

    #[test]
    fn test_no_pii() {
        let (output, kinds) = redact_with_defaults("Hello world");
        assert_eq!(output, "Hello world");
        assert!(kinds.is_empty());
    }

    #[tokio::test]
    async fn test_guardrail_transform() {
        let redactor = PiiRedactor::new();
        let message = Content::new("user").with_text("Email: test@example.com");

        if let GuardrailResult::Transform { new_content, .. } = redactor.validate(&message).await {
            let text = new_content.parts[0].text().unwrap();
            assert!(text.contains("[EMAIL REDACTED]"));
        } else {
            panic!("Expected Transform result");
        }
    }

    #[tokio::test]
    async fn test_guardrail_pass() {
        let redactor = PiiRedactor::new();
        let message = Content::new("user").with_text("Hello world");
        assert!(redactor.validate(&message).await.is_pass());
    }
}