1use crate::entity::{EntityType, RecognizerResult};
2use regex::Regex;
3use serde::{Deserialize, Serialize};
4
5pub trait Recognizer: Send + Sync {
6 fn name(&self) -> &str;
7 fn supported_entities(&self) -> &[EntityType];
8 fn analyze(&self, text: &str, entities: &[EntityType]) -> Vec<RecognizerResult>;
9}
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct PatternDef {
13 pub name: String,
14 pub regex: String,
15 pub score: f64,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct RecognizerDef {
20 pub name: String,
21 pub entity_type: String,
22 pub version: String,
23 pub patterns: Vec<PatternDef>,
24 #[serde(default)]
25 pub context_words: Vec<String>,
26 #[serde(default)]
27 pub context_score_boost: f64,
28 #[serde(default)]
29 pub deny_list: Vec<String>,
30 #[serde(default)]
31 pub validators: Vec<String>,
32 pub supported_languages: Option<Vec<String>>,
33}
34
35pub struct RegexRecognizer {
36 def: RecognizerDef,
37 compiled: Vec<(String, Regex, f64)>,
38 entity: EntityType,
39}
40
41impl RegexRecognizer {
42 pub fn from_def(def: RecognizerDef) -> Result<Self, regex::Error> {
43 let mut compiled = Vec::new();
44 for p in &def.patterns {
45 let re = Regex::new(&p.regex)?;
46 compiled.push((p.name.clone(), re, p.score));
47 }
48 let entity = EntityType::new(&def.entity_type);
49 Ok(Self { def, compiled, entity })
50 }
51
52 pub fn from_json(json: &str) -> Result<Self, Box<dyn std::error::Error>> {
53 let def: RecognizerDef = serde_json::from_str(json)?;
54 Ok(Self::from_def(def)?)
55 }
56
57 fn has_context(&self, text: &str, start: usize, end: usize) -> bool {
58 if self.def.context_words.is_empty() {
59 return false;
60 }
61 let window_start = start.saturating_sub(100);
62 let window_end = (end + 100).min(text.len());
63 let window = &text[window_start..window_end].to_lowercase();
64 self.def.context_words.iter().any(|w| window.contains(&w.to_lowercase()))
65 }
66
67 fn is_denied(&self, matched: &str) -> bool {
68 self.def.deny_list.iter().any(|d| matched == d)
69 }
70
71 fn validate(&self, matched: &str) -> bool {
72 for v in &self.def.validators {
73 match v.as_str() {
74 "luhn" => { if !luhn_check(matched) { return false; } }
75 "cn_id_checksum" => { if !cn_id_check(matched) { return false; } }
76 "iban" => { if !iban_check(matched) { return false; } }
77 "de_tax_id" => { if !de_tax_id_check(matched) { return false; } }
78 "au_abn" => { if !au_abn_check(matched) { return false; } }
79 "au_tfn" => { if !au_tfn_check(matched) { return false; } }
80 "au_acn" => { if !au_acn_check(matched) { return false; } }
81 "au_medicare" => { if !au_medicare_check(matched) { return false; } }
82 "uk_driving_licence" => { if !uk_driving_licence_check(matched) { return false; } }
83 _ => {}
84 }
85 }
86 true
87 }
88}
89
90impl Recognizer for RegexRecognizer {
91 fn name(&self) -> &str {
92 &self.def.name
93 }
94
95 fn supported_entities(&self) -> &[EntityType] {
96 std::slice::from_ref(&self.entity)
97 }
98
99 fn analyze(&self, text: &str, entities: &[EntityType]) -> Vec<RecognizerResult> {
100 if !entities.is_empty() && !entities.contains(&self.entity) {
101 return Vec::new();
102 }
103
104 let mut results = Vec::new();
105 for (pat_name, re, base_score) in &self.compiled {
106 for m in re.find_iter(text) {
107 let matched = m.as_str();
108
109 if self.is_denied(matched) {
110 continue;
111 }
112
113 if !self.validate(matched) {
114 continue;
115 }
116
117 let mut score = *base_score;
118 if self.has_context(text, m.start(), m.end()) {
119 score = (score + self.def.context_score_boost).min(1.0);
120 }
121
122 results.push(RecognizerResult {
123 entity_type: self.entity.clone(),
124 start: m.start(),
125 end: m.end(),
126 score,
127 recognizer_name: Some(pat_name.clone()),
128 });
129 }
130 }
131 results
132 }
133}
134
/// Luhn (mod 10) checksum over the ASCII digits of `number`; other characters are ignored.
///
/// Returns `false` when fewer than two digits are present.
fn luhn_check(number: &str) -> bool {
    // `to_digit(10)` only accepts ASCII '0'..='9', so it doubles as the digit filter.
    let digits: Vec<u32> = number.chars().filter_map(|c| c.to_digit(10)).collect();
    if digits.len() < 2 {
        return false;
    }
    let total: u32 = digits
        .iter()
        .rev()
        .enumerate()
        .map(|(pos, &d)| {
            // Every second digit from the right is doubled; 10..=18 fold back to 1..=9.
            if pos % 2 == 1 {
                let doubled = d * 2;
                if doubled > 9 {
                    doubled - 9
                } else {
                    doubled
                }
            } else {
                d
            }
        })
        .sum();
    total % 10 == 0
}
159
/// Validates the check character of an 18-character PRC resident identity number.
///
/// The first 17 characters must be ASCII digits. The 18th must equal the ISO 7064
/// MOD 11-2 check character; a lowercase `x` is accepted for `X`.
fn cn_id_check(id: &str) -> bool {
    const WEIGHTS: [usize; 17] = [7, 9, 10, 5, 8, 4, 2, 1, 6, 3, 7, 9, 10, 5, 8, 4, 2];
    const CHECK_CHARS: [char; 11] = ['1', '0', 'X', '9', '8', '7', '6', '5', '4', '3', '2'];

    if id.len() != 18 {
        return false;
    }
    let mut chars = id.chars();
    let mut weighted = 0usize;
    for &w in WEIGHTS.iter() {
        match chars.next().and_then(|c| c.to_digit(10)) {
            Some(d) => weighted += d as usize * w,
            None => return false,
        }
    }
    // 17 ASCII digits occupy 17 bytes, so exactly one byte (one char) remains.
    match chars.next() {
        Some(last) => last.to_ascii_uppercase() == CHECK_CHARS[weighted % 11],
        None => false,
    }
}
178
/// Validates an IBAN (ISO 13616) using the ISO 7064 MOD 97-10 checksum.
///
/// Whitespace and `-` separators are ignored, and letters are accepted in either
/// case. After cleaning, the value must be 5–34 ASCII letters and digits, starting
/// with a two-letter country code followed by two check digits. Any other input is
/// rejected instead of being fed into the checksum.
fn iban_check(iban: &str) -> bool {
    // Map each character to its IBAN numeric value in one pass. Digits map to 0–9
    // and letters (either case) to 10–35, which is exactly `to_digit(36)`. Any other
    // character (punctuation, non-ASCII) yields `None` and rejects the input; the old
    // `c - 'A'` arithmetic underflowed on such characters.
    let values: Option<Vec<u32>> = iban
        .chars()
        .filter(|c| !c.is_whitespace() && *c != '-')
        .map(|c| c.to_digit(36))
        .collect();
    let values = match values {
        Some(values) => values,
        None => return false,
    };
    if values.len() < 5 || values.len() > 34 {
        return false;
    }
    let has_country_and_check_digits =
        values[0] >= 10 && values[1] >= 10 && values[2] < 10 && values[3] < 10;
    if !has_country_and_check_digits {
        return false;
    }
    // Move the first four characters to the end and reduce mod 97 incrementally.
    // A letter contributes two decimal digits and a digit contributes one. The
    // accumulator stays below 97, so `acc * 100 + 35` cannot overflow.
    let remainder = values[4..]
        .iter()
        .chain(&values[..4])
        .fold(0u32, |acc, &v| {
            let shift = if v < 10 { 10 } else { 100 };
            (acc * shift + v) % 97
        });
    remainder == 1
}
196
/// Validates a German tax identification number (Steuer-IdNr).
///
/// After ignoring non-digits, `id` must contain exactly 11 digits. The first digit
/// must be non-zero, and the first ten digits must not all be the same. The last
/// digit must equal the ISO 7064 MOD 11,10 check digit of the first ten.
fn de_tax_id_check(id: &str) -> bool {
    let digits: Vec<u32> = id.chars().filter_map(|c| c.to_digit(10)).collect();
    if digits.len() != 11 || digits[0] == 0 {
        return false;
    }
    let (body, check) = digits.split_at(10);
    if body.iter().all(|&d| d == body[0]) {
        return false;
    }
    let product = body.iter().fold(10u32, |product, &d| {
        let total = match (d + product) % 10 {
            0 => 10,
            t => t,
        };
        (total * 2) % 11
    });
    // `product` is always in 1..=10, so this subtraction cannot underflow.
    let expected = match 11 - product {
        10 => 0,
        c => c,
    };
    expected == check[0]
}
215
/// Validates an Australian Business Number (11 digits; non-digits are ignored).
///
/// The leading digit is reduced by one, the digits are weighted by
/// 10, 1, 3, 5, …, 19, and the weighted sum must be divisible by 89.
fn au_abn_check(abn: &str) -> bool {
    const WEIGHTS: [i64; 11] = [10, 1, 3, 5, 7, 9, 11, 13, 15, 17, 19];
    let digits: Vec<i64> = abn
        .chars()
        .filter_map(|c| c.to_digit(10))
        .map(i64::from)
        .collect();
    if digits.len() != WEIGHTS.len() {
        return false;
    }
    let weighted: i64 = digits
        .iter()
        .zip(WEIGHTS.iter())
        .enumerate()
        .map(|(pos, (&d, &w))| if pos == 0 { (d - 1) * w } else { d * w })
        .sum();
    weighted % 89 == 0
}
227
/// Validates a 9-digit Australian Tax File Number (non-digits are ignored).
///
/// The digits, weighted by 1, 4, 3, 7, 5, 8, 6, 9, 10, must sum to a multiple of 11.
fn au_tfn_check(tfn: &str) -> bool {
    const WEIGHTS: [u32; 9] = [1, 4, 3, 7, 5, 8, 6, 9, 10];
    let digits: Vec<u32> = tfn.chars().filter_map(|c| c.to_digit(10)).collect();
    if digits.len() != WEIGHTS.len() {
        return false;
    }
    let weighted: u32 = WEIGHTS.iter().zip(&digits).map(|(w, d)| w * d).sum();
    weighted % 11 == 0
}
237
/// Validates a 9-digit Australian Company Number (non-digits are ignored).
///
/// The first eight digits are weighted 8 down to 1. The ninth digit must equal the
/// ten's complement of the weighted sum, mod 10.
fn au_acn_check(acn: &str) -> bool {
    let digits: Vec<u32> = acn.chars().filter_map(|c| c.to_digit(10)).collect();
    if digits.len() != 9 {
        return false;
    }
    let (body, check) = digits.split_at(8);
    let weighted: u32 = body.iter().zip((1..=8u32).rev()).map(|(d, w)| d * w).sum();
    (10 - weighted % 10) % 10 == check[0]
}
248
/// Validates an Australian Medicare card number (non-digits are ignored).
///
/// There must be 10 or 11 digits, and the first must be 2–6. The ninth digit must
/// equal the weighted sum (weights 1, 3, 7, 9, 1, 3, 7, 9) of the first eight, mod 10.
fn au_medicare_check(medicare: &str) -> bool {
    const WEIGHTS: [u32; 8] = [1, 3, 7, 9, 1, 3, 7, 9];
    let digits: Vec<u32> = medicare.chars().filter_map(|c| c.to_digit(10)).collect();
    let length_ok = (10..=11).contains(&digits.len());
    // `||` short-circuits, so `digits[0]` is only read once the length is known good.
    if !length_ok || !(2..=6).contains(&digits[0]) {
        return false;
    }
    let weighted: u32 = WEIGHTS.iter().zip(&digits).map(|(w, d)| w * d).sum();
    weighted % 10 == digits[8]
}
261
/// Structural check of a 16-character UK driving licence number (case-insensitive).
///
/// The first five characters encode the surname. They may not be all `9`s, `9` may
/// appear only as trailing padding, and the first character must not be a digit.
/// Non-ASCII input is rejected.
fn uk_driving_licence_check(licence: &str) -> bool {
    // Licence numbers are ASCII-only. Rejecting anything else up front also keeps the
    // byte-length check and the `[..5]` slice on char boundaries. Previously, a
    // multi-byte char straddling byte 5 caused a panic.
    if !licence.is_ascii() || licence.len() != 16 {
        return false;
    }
    let text = licence.to_ascii_uppercase();
    let surname = &text[..5];
    if surname == "99999" {
        return false;
    }
    // Once a padding '9' is seen, every later surname character must also be '9'.
    let mut seen_nine = false;
    for c in surname.chars() {
        if c == '9' {
            seen_nine = true;
        } else if seen_nine {
            return false;
        }
    }
    !surname.starts_with(|c: char| c.is_ascii_digit())
}
286
287pub fn load_recognizers_from_dir(dir: &std::path::Path) -> Vec<Box<dyn Recognizer>> {
288 let mut recognizers: Vec<Box<dyn Recognizer>> = Vec::new();
289 if let Ok(entries) = std::fs::read_dir(dir) {
290 for entry in entries.flatten() {
291 let path = entry.path();
292 if path.extension().map_or(false, |e| e == "json") {
293 if let Ok(json) = std::fs::read_to_string(&path) {
294 match RegexRecognizer::from_json(&json) {
295 Ok(r) => recognizers.push(Box::new(r)),
296 Err(e) => eprintln!("Failed to load {:?}: {}", path, e),
297 }
298 }
299 }
300 }
301 }
302 recognizers
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Email recognizer definition shared by the recognizer-level tests.
    const EMAIL_JSON: &str = r#"{
        "name": "email_recognizer",
        "entity_type": "EMAIL_ADDRESS",
        "version": "1.0.0",
        "patterns": [{"name": "email", "regex": "[A-Za-z0-9._%+\\-]+@[A-Za-z0-9.\\-]+\\.[A-Za-z]{2,}", "score": 0.5}],
        "context_words": ["email"],
        "context_score_boost": 0.4
    }"#;

    #[test]
    fn test_luhn_valid() {
        assert!(luhn_check("4532015112830366"));
        assert!(luhn_check("4111111111111111"));
    }

    #[test]
    fn test_luhn_invalid() {
        assert!(!luhn_check("1234567890123456"));
    }

    #[test]
    fn test_cn_id_valid() {
        assert!(cn_id_check("11010519491231002X"));
    }

    #[test]
    fn test_cn_id_invalid() {
        assert!(!cn_id_check("110105194912310020"));
    }

    #[test]
    fn test_iban() {
        assert!(iban_check("GB82 WEST 1234 5698 7654 32"));
        assert!(iban_check("DE89370400440532013000"));
        assert!(!iban_check("GB82WEST12345698765433"));
    }

    #[test]
    fn test_de_tax_id() {
        assert!(de_tax_id_check("86095742719"));
        assert!(!de_tax_id_check("86095742718"));
        assert!(!de_tax_id_check("06095742719"));
        assert!(!de_tax_id_check("11111111111"));
    }

    #[test]
    fn test_au_abn() {
        assert!(au_abn_check("51 824 753 556"));
        assert!(!au_abn_check("51 824 753 557"));
    }

    #[test]
    fn test_au_tfn() {
        assert!(au_tfn_check("123 456 782"));
        assert!(!au_tfn_check("123 456 789"));
    }

    #[test]
    fn test_au_acn() {
        assert!(au_acn_check("000 000 019"));
        assert!(!au_acn_check("000 000 018"));
    }

    #[test]
    fn test_au_medicare() {
        assert!(au_medicare_check("2123 45670 1"));
        assert!(!au_medicare_check("1123 45670 1"));
    }

    #[test]
    fn test_uk_driving_licence() {
        assert!(uk_driving_licence_check("MORGA753116SM9IJ"));
        assert!(!uk_driving_licence_check("99999753116SM9IJ"));
        assert!(!uk_driving_licence_check("A9B99753116SM9IJ"));
    }

    #[test]
    fn test_regex_recognizer_email() {
        let rec = RegexRecognizer::from_json(EMAIL_JSON).unwrap();
        let text = "Contact me at test@example.com please";
        let results = rec.analyze(text, &[]);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].entity_type.as_str(), "EMAIL_ADDRESS");
        assert_eq!(&text[results[0].start..results[0].end], "test@example.com");
    }

    #[test]
    fn test_entity_filter_with_own_entity() {
        let rec = RegexRecognizer::from_json(EMAIL_JSON).unwrap();
        let results = rec.analyze("test@example.com", rec.supported_entities());
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_context_boost() {
        let rec = RegexRecognizer::from_json(EMAIL_JSON).unwrap();

        let with_ctx = rec.analyze("My email is test@example.com", &[]);
        let without_ctx = rec.analyze("test@example.com", &[]);

        assert!(with_ctx[0].score > without_ctx[0].score);
    }

    #[test]
    fn test_deny_list() {
        let json = r#"{
            "name": "ip_recognizer",
            "entity_type": "IP_ADDRESS",
            "version": "1.0.0",
            "patterns": [{"name": "ipv4", "regex": "\\b(?:(?:25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(?:25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\b", "score": 0.5}],
            "deny_list": ["0.0.0.0", "127.0.0.1"],
            "context_words": []
        }"#;
        let rec = RegexRecognizer::from_json(json).unwrap();
        let text = "Server at 127.0.0.1 and 192.168.1.1";
        let results = rec.analyze(text, &[]);
        assert_eq!(results.len(), 1);
        assert_eq!(&text[results[0].start..results[0].end], "192.168.1.1");
    }

    #[test]
    fn test_validator_rejects_bad_checksum() {
        let json = r#"{
            "name": "card_recognizer",
            "entity_type": "CREDIT_CARD",
            "version": "1.0.0",
            "patterns": [{"name": "card16", "regex": "\\b\\d{16}\\b", "score": 0.5}],
            "validators": ["luhn"]
        }"#;
        let rec = RegexRecognizer::from_json(json).unwrap();
        let text = "Cards 4111111111111111 and 4111111111111112";
        let results = rec.analyze(text, &[]);
        assert_eq!(results.len(), 1);
        assert_eq!(&text[results[0].start..results[0].end], "4111111111111111");
    }
}