//! `pii_vault/recognizer.rs`: JSON-configurable regex recognizers and the
//! checksum validators they can apply to candidate matches.

1use crate::entity::{EntityType, RecognizerResult};
2use regex::Regex;
3use serde::{Deserialize, Serialize};
4
5pub trait Recognizer: Send + Sync {
6    fn name(&self) -> &str;
7    fn supported_entities(&self) -> &[EntityType];
8    fn analyze(&self, text: &str, entities: &[EntityType]) -> Vec<RecognizerResult>;
9}
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PatternDef {
13    pub name: String,
14    pub regex: String,
15    pub score: f64,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RecognizerDef {
20    pub name: String,
21    pub entity_type: String,
22    pub version: String,
23    pub patterns: Vec<PatternDef>,
24    #[serde(default)]
25    pub context_words: Vec<String>,
26    #[serde(default)]
27    pub context_score_boost: f64,
28    #[serde(default)]
29    pub deny_list: Vec<String>,
30    #[serde(default)]
31    pub validators: Vec<String>,
32    pub supported_languages: Option<Vec<String>>,
33}
34
35pub struct RegexRecognizer {
36    def: RecognizerDef,
37    compiled: Vec<(String, Regex, f64)>,
38    entity: EntityType,
39}
40
41impl RegexRecognizer {
42    pub fn from_def(def: RecognizerDef) -> Result<Self, regex::Error> {
43        let mut compiled = Vec::new();
44        for p in &def.patterns {
45            let re = Regex::new(&p.regex)?;
46            compiled.push((p.name.clone(), re, p.score));
47        }
48        let entity = EntityType::new(&def.entity_type);
49        Ok(Self { def, compiled, entity })
50    }
51
52    pub fn from_json(json: &str) -> Result<Self, Box<dyn std::error::Error>> {
53        let def: RecognizerDef = serde_json::from_str(json)?;
54        Ok(Self::from_def(def)?)
55    }
56
57    fn has_context(&self, text: &str, start: usize, end: usize) -> bool {
58        if self.def.context_words.is_empty() {
59            return false;
60        }
61        let window_start = start.saturating_sub(100);
62        let window_end = (end + 100).min(text.len());
63        let window = &text[window_start..window_end].to_lowercase();
64        self.def.context_words.iter().any(|w| window.contains(&w.to_lowercase()))
65    }
66
67    fn is_denied(&self, matched: &str) -> bool {
68        self.def.deny_list.iter().any(|d| matched == d)
69    }
70
71    fn validate(&self, matched: &str) -> bool {
72        for v in &self.def.validators {
73            match v.as_str() {
74                "luhn" => { if !luhn_check(matched) { return false; } }
75                "cn_id_checksum" => { if !cn_id_check(matched) { return false; } }
76                "iban" => { if !iban_check(matched) { return false; } }
77                "de_tax_id" => { if !de_tax_id_check(matched) { return false; } }
78                "au_abn" => { if !au_abn_check(matched) { return false; } }
79                "au_tfn" => { if !au_tfn_check(matched) { return false; } }
80                "au_acn" => { if !au_acn_check(matched) { return false; } }
81                "au_medicare" => { if !au_medicare_check(matched) { return false; } }
82                "uk_driving_licence" => { if !uk_driving_licence_check(matched) { return false; } }
83                _ => {}
84            }
85        }
86        true
87    }
88}
89
90impl Recognizer for RegexRecognizer {
91    fn name(&self) -> &str {
92        &self.def.name
93    }
94
95    fn supported_entities(&self) -> &[EntityType] {
96        std::slice::from_ref(&self.entity)
97    }
98
99    fn analyze(&self, text: &str, entities: &[EntityType]) -> Vec<RecognizerResult> {
100        if !entities.is_empty() && !entities.contains(&self.entity) {
101            return Vec::new();
102        }
103
104        let mut results = Vec::new();
105        for (pat_name, re, base_score) in &self.compiled {
106            for m in re.find_iter(text) {
107                let matched = m.as_str();
108
109                if self.is_denied(matched) {
110                    continue;
111                }
112
113                if !self.validate(matched) {
114                    continue;
115                }
116
117                let mut score = *base_score;
118                if self.has_context(text, m.start(), m.end()) {
119                    score = (score + self.def.context_score_boost).min(1.0);
120                }
121
122                results.push(RecognizerResult {
123                    entity_type: self.entity.clone(),
124                    start: m.start(),
125                    end: m.end(),
126                    score,
127                    recognizer_name: Some(pat_name.clone()),
128                });
129            }
130        }
131        results
132    }
133}
134
/// Luhn (mod 10) checksum over the decimal digits of `number`.
///
/// Every character that is not an ASCII digit (spaces, dashes, ...) is
/// ignored, so formatted card numbers validate the same as bare ones.
/// Fewer than two digits never validates.
fn luhn_check(number: &str) -> bool {
    let digits: Vec<u32> = number.chars().filter_map(|c| c.to_digit(10)).collect();
    if digits.len() < 2 {
        return false;
    }
    // Walking right-to-left, every second digit is doubled; a doubled value
    // above 9 contributes the sum of its digits, i.e. value - 9.
    let total: u32 = digits
        .iter()
        .rev()
        .enumerate()
        .map(|(pos, &digit)| {
            if pos % 2 == 0 {
                digit
            } else if digit * 2 > 9 {
                digit * 2 - 9
            } else {
                digit * 2
            }
        })
        .sum();
    total % 10 == 0
}
159
/// Checksum test for an 18-character Chinese resident identity number.
///
/// The first 17 characters must be ASCII digits; their weighted sum mod 11
/// selects the expected final character, which is compared
/// case-insensitively (so a trailing `x` is accepted as `X`).
fn cn_id_check(id: &str) -> bool {
    const WEIGHTS: [usize; 17] = [7, 9, 10, 5, 8, 4, 2, 1, 6, 3, 7, 9, 10, 5, 8, 4, 2];
    const CHECK_CHARS: [char; 11] = ['1', '0', 'X', '9', '8', '7', '6', '5', '4', '3', '2'];

    if id.len() != 18 {
        return false;
    }
    let chars: Vec<char> = id.chars().collect();
    let mut weighted_sum = 0usize;
    for (ch, weight) in chars.iter().zip(WEIGHTS.iter()) {
        match ch.to_digit(10) {
            Some(d) => weighted_sum += d as usize * weight,
            None => return false,
        }
    }
    let expected = CHECK_CHARS[weighted_sum % 11];
    chars
        .get(17)
        .map_or(false, |c| c.to_ascii_uppercase() == expected)
}
178
/// IBAN validation: structure check plus the ISO 7064 MOD 97-10 checksum.
///
/// Whitespace and hyphens are ignored and letters are compared
/// case-insensitively. The cleaned value must be 5..=34 ASCII alphanumerics,
/// start with a two-letter country code followed by two check digits, and
/// satisfy "rearranged number mod 97 == 1".
///
/// Anything that is not an ASCII alphanumeric is rejected up front. This keeps
/// the letter-to-number mapping in range (a character below `'A'` would
/// otherwise underflow) and makes the split at byte offset 4 safe (it could
/// otherwise cut through a multi-byte character).
fn iban_check(iban: &str) -> bool {
    let cleaned: String = iban
        .chars()
        .filter(|c| !c.is_whitespace() && *c != '-')
        .map(|c| c.to_ascii_uppercase())
        .collect();
    if cleaned.len() < 5
        || cleaned.len() > 34
        || !cleaned.bytes().all(|b| b.is_ascii_alphanumeric())
    {
        return false;
    }
    let bytes = cleaned.as_bytes();
    let country_ok = bytes[..2].iter().all(|b| b.is_ascii_uppercase());
    let check_digits_ok = bytes[2..4].iter().all(|b| b.is_ascii_digit());
    if !country_ok || !check_digits_ok {
        return false;
    }
    // Move the first four characters to the end, read A..Z as 10..35, and
    // reduce mod 97 incrementally so no big-integer arithmetic is needed.
    let remainder = bytes[4..]
        .iter()
        .chain(&bytes[..4])
        .fold(0u32, |acc, &b| {
            if b.is_ascii_digit() {
                (acc * 10 + u32::from(b - b'0')) % 97
            } else {
                (acc * 100 + u32::from(b - b'A') + 10) % 97
            }
        });
    remainder == 1
}
196
/// German tax identification number (Steuer-IdNr) check.
///
/// Non-digits are ignored. There must be exactly 11 digits, the first must be
/// non-zero, and the first ten must not all be identical. The 11th digit must
/// equal the ISO 7064 MOD 11,10 check digit computed over the first ten.
fn de_tax_id_check(id: &str) -> bool {
    let digits: Vec<u32> = id.chars().filter_map(|c| c.to_digit(10)).collect();
    if digits.len() != 11 || digits[0] == 0 {
        return false;
    }
    let (body, check) = digits.split_at(10);
    if body.iter().all(|&d| d == body[0]) {
        return false;
    }
    let product = body.iter().fold(10u32, |product, &digit| {
        let sum = match (digit + product) % 10 {
            0 => 10,
            s => s,
        };
        (sum * 2) % 11
    });
    let expected = match 11 - product {
        10 => 0,
        c => c,
    };
    expected == check[0]
}
215
/// Australian Business Number check.
///
/// Non-digits are ignored and exactly 11 digits are required. After
/// subtracting 1 from the first digit, the weighted digit sum must be
/// divisible by 89.
fn au_abn_check(abn: &str) -> bool {
    const WEIGHTS: [i64; 11] = [10, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19];
    let mut digits: Vec<i64> = abn
        .chars()
        .filter_map(|c| c.to_digit(10))
        .map(i64::from)
        .collect();
    if digits.len() != 11 {
        return false;
    }
    digits[0] -= 1;
    let weighted: i64 = digits.iter().zip(WEIGHTS.iter()).map(|(d, w)| d * w).sum();
    weighted % 89 == 0
}
227
/// Australian Tax File Number check: exactly 9 digits (non-digits ignored)
/// whose weighted sum is divisible by 11.
fn au_tfn_check(tfn: &str) -> bool {
    const WEIGHTS: [u32; 9] = [1, 4, 3, 7, 5, 8, 6, 9, 10];
    let digits: Vec<u32> = tfn.chars().filter_map(|c| c.to_digit(10)).collect();
    digits.len() == 9
        && digits
            .iter()
            .zip(WEIGHTS.iter())
            .map(|(d, w)| d * w)
            .sum::<u32>()
            % 11
            == 0
}
237
/// Australian Company Number check.
///
/// Non-digits are ignored and exactly 9 digits are required. The last digit
/// must equal the ten's complement (mod 10) of the weighted sum of the
/// first eight.
fn au_acn_check(acn: &str) -> bool {
    const WEIGHTS: [u32; 8] = [8, 7, 6, 5, 4, 3, 2, 1];
    let digits: Vec<u32> = acn.chars().filter_map(|c| c.to_digit(10)).collect();
    if digits.len() != 9 {
        return false;
    }
    let (body, check) = digits.split_at(8);
    let weighted: u32 = body.iter().zip(WEIGHTS.iter()).map(|(d, w)| d * w).sum();
    let complement = (10 - weighted % 10) % 10;
    complement == check[0]
}
248
/// Australian Medicare number check.
///
/// Non-digits are ignored and 10 or 11 digits are required. The first digit
/// must be in 2..=6, and the 9th digit must equal the weighted sum of the
/// first eight mod 10.
fn au_medicare_check(medicare: &str) -> bool {
    const WEIGHTS: [u32; 8] = [1, 3, 7, 9, 1, 3, 7, 9];
    let digits: Vec<u32> = medicare.chars().filter_map(|c| c.to_digit(10)).collect();
    if !(10..=11).contains(&digits.len()) {
        return false;
    }
    if !(2..=6).contains(&digits[0]) {
        return false;
    }
    // `zip` stops after the eight weights, so only the first eight digits count.
    let weighted: u32 = digits.iter().zip(WEIGHTS.iter()).map(|(d, w)| d * w).sum();
    weighted % 10 == digits[8]
}
261
/// Structural check for a 16-character UK driving licence number.
///
/// Only the surname field (first five characters) is validated. It must be
/// one or more letters, optionally right-padded with `9`s (e.g. `MORGA`,
/// `SMIT9`, `A9999`); case is ignored. The remaining eleven characters are
/// left to the recognizer's regex.
fn uk_driving_licence_check(licence: &str) -> bool {
    // The length test and the `[..5]` slice work on bytes. Rejecting non-ASCII
    // input keeps them meaningful and prevents a panic when a multi-byte
    // character straddles byte offset 5.
    if licence.len() != 16 || !licence.is_ascii() {
        return false;
    }
    let surname = licence[..5].to_ascii_uppercase();
    // Everything from the first '9' onward must be padding.
    let padding_start = surname.find('9').unwrap_or(surname.len());
    let (letters, padding) = surname.split_at(padding_start);
    !letters.is_empty()
        && letters.bytes().all(|b| b.is_ascii_uppercase())
        && padding.bytes().all(|b| b == b'9')
}
286
287pub fn load_recognizers_from_dir(dir: &std::path::Path) -> Vec<Box<dyn Recognizer>> {
288    let mut recognizers: Vec<Box<dyn Recognizer>> = Vec::new();
289    if let Ok(entries) = std::fs::read_dir(dir) {
290        for entry in entries.flatten() {
291            let path = entry.path();
292            if path.extension().map_or(false, |e| e == "json") {
293                if let Ok(json) = std::fs::read_to_string(&path) {
294                    match RegexRecognizer::from_json(&json) {
295                        Ok(r) => recognizers.push(Box::new(r)),
296                        Err(e) => eprintln!("Failed to load {:?}: {}", path, e),
297                    }
298                }
299            }
300        }
301    }
302    recognizers
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    // Email recognizer definition shared by the email and context-boost tests.
    const EMAIL_JSON: &str = r#"{
            "name": "email_recognizer",
            "entity_type": "EMAIL_ADDRESS",
            "version": "1.0.0",
            "patterns": [{"name": "email", "regex": "[A-Za-z0-9._%+\\-]+@[A-Za-z0-9.\\-]+\\.[A-Za-z]{2,}", "score": 0.5}],
            "context_words": ["email"],
            "context_score_boost": 0.4
        }"#;

    /// Builds the shared email recognizer; the fixture JSON is known-good.
    fn email_recognizer() -> RegexRecognizer {
        RegexRecognizer::from_json(EMAIL_JSON).unwrap()
    }

    #[test]
    fn test_luhn_valid() {
        for number in &["4532015112830366", "4111111111111111"] {
            assert!(luhn_check(number), "{} should pass Luhn", number);
        }
    }

    #[test]
    fn test_luhn_invalid() {
        assert!(!luhn_check("1234567890123456"));
    }

    #[test]
    fn test_cn_id_valid() {
        assert!(cn_id_check("11010519491231002X"));
    }

    #[test]
    fn test_cn_id_invalid() {
        assert!(!cn_id_check("110105194912310020"));
    }

    #[test]
    fn test_regex_recognizer_email() {
        let text = "Contact me at test@example.com please";
        let results = email_recognizer().analyze(text, &[]);
        assert_eq!(results.len(), 1);
        let hit = &results[0];
        assert_eq!(hit.entity_type.as_str(), "EMAIL_ADDRESS");
        assert_eq!(&text[hit.start..hit.end], "test@example.com");
    }

    #[test]
    fn test_context_boost() {
        let rec = email_recognizer();
        let with_ctx = rec.analyze("My email is test@example.com", &[]);
        let without_ctx = rec.analyze("test@example.com", &[]);
        assert!(with_ctx[0].score > without_ctx[0].score);
    }

    #[test]
    fn test_deny_list() {
        let json = r#"{
            "name": "ip_recognizer",
            "entity_type": "IP_ADDRESS",
            "version": "1.0.0",
            "patterns": [{"name": "ipv4", "regex": "\\b(?:(?:25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(?:25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\b", "score": 0.5}],
            "deny_list": ["0.0.0.0", "127.0.0.1"],
            "context_words": []
        }"#;
        let text = "Server at 127.0.0.1 and 192.168.1.1";
        let results = RegexRecognizer::from_json(json).unwrap().analyze(text, &[]);
        assert_eq!(results.len(), 1);
        assert_eq!(&text[results[0].start..results[0].end], "192.168.1.1");
    }
}