1use super::{Recognizer, RecognizerResult};
7use crate::types::EntityType;
8use anyhow::Result;
9use lazy_static::lazy_static;
10use regex::Regex;
11use std::collections::HashMap;
12
13#[derive(Debug, Clone)]
15pub struct PatternRecognizer {
16 name: String,
17 patterns: HashMap<EntityType, Vec<CompiledPattern>>,
18 min_score: f32,
19}
20
21#[derive(Debug, Clone)]
22struct CompiledPattern {
23 regex: Regex,
24 score: f32,
25 context_words: Vec<String>,
26}
27
28impl PatternRecognizer {
29 pub fn new() -> Self {
31 let mut recognizer = Self {
32 name: "PatternRecognizer".to_string(),
33 patterns: HashMap::new(),
34 min_score: 0.5,
35 };
36 recognizer.load_default_patterns();
37 recognizer
38 }
39
40 pub fn with_name(name: impl Into<String>) -> Self {
42 let mut recognizer = Self::new();
43 recognizer.name = name.into();
44 recognizer
45 }
46
47 pub fn with_min_score(mut self, min_score: f32) -> Self {
49 self.min_score = min_score;
50 self
51 }
52
53 pub fn add_pattern(
55 &mut self,
56 entity_type: EntityType,
57 pattern: &str,
58 score: f32,
59 ) -> Result<()> {
60 let regex = Regex::new(pattern)?;
61 let compiled = CompiledPattern {
62 regex,
63 score,
64 context_words: vec![],
65 };
66 self.patterns.entry(entity_type).or_default().push(compiled);
67 Ok(())
68 }
69
70 pub fn add_pattern_with_context(
72 &mut self,
73 entity_type: EntityType,
74 pattern: &str,
75 score: f32,
76 context_words: Vec<String>,
77 ) -> Result<()> {
78 let regex = Regex::new(pattern)?;
79 let compiled = CompiledPattern {
80 regex,
81 score,
82 context_words,
83 };
84 self.patterns.entry(entity_type).or_default().push(compiled);
85 Ok(())
86 }
87
88 fn load_default_patterns(&mut self) {
90 let _ = self.add_pattern(
92 EntityType::EmailAddress,
93 r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
94 0.8,
95 );
96
97 let _ = self.add_pattern(
100 EntityType::PhoneNumber,
101 r"\(?\d{3}\)?[-.\s]?\d{3}[-.\s]?\d{4}\b",
102 0.7,
103 );
104
105 let _ = self.add_pattern(
107 EntityType::CreditCard,
108 r"\b(?:4[0-9]{12}(?:[0-9]{3})?|5[1-5][0-9]{14}|3[47][0-9]{13}|6(?:011|5[0-9]{2})[0-9]{12})\b",
109 0.9,
110 );
111
112 let _ = self.add_pattern(EntityType::UsSsn, r"\b\d{3}-\d{2}-\d{4}\b", 0.9);
115
116 let _ = self.add_pattern(
118 EntityType::IpAddress,
119 r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
120 0.8,
121 );
122
123 let _ = self.add_pattern(
125 EntityType::Url,
126 r"\b(?:https?://|www\.)[a-zA-Z0-9][-a-zA-Z0-9]*(?:\.[a-zA-Z0-9][-a-zA-Z0-9]*)+(?:/[^\s]*)?\b",
127 0.7,
128 );
129
130 let _ = self.add_pattern(
132 EntityType::Guid,
133 r"\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b",
134 0.9,
135 );
136
137 let _ = self.add_pattern(
139 EntityType::MacAddress,
140 r"\b(?:[0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b",
141 0.9,
142 );
143
144 let _ = self.add_pattern_with_context(
146 EntityType::UkNhs,
147 r"\b(?:\d{3}\s?\d{3}\s?\d{4}|\d{10})\b",
148 0.6,
149 vec![
150 "NHS".to_string(),
151 "patient".to_string(),
152 "health".to_string(),
153 ],
154 );
155
156 let _ = self.add_pattern(
158 EntityType::UkNino,
159 r"\b[A-CEGHJ-PR-TW-Z]{1}[A-CEGHJ-NPR-TW-Z]{1}\d{6}[A-D]{1}\b",
160 0.85,
161 );
162
163 let _ = self.add_pattern(
165 EntityType::UkPostcode,
166 r"\b[A-Z]{1,2}\d[A-Z\d]?\s?\d[A-Z]{2}\b",
167 0.75,
168 );
169
170 let _ = self.add_pattern(EntityType::UkSortCode, r"\b\d{2}-\d{2}-\d{2}\b", 0.7);
172
173 let _ = self.add_pattern(
175 EntityType::IbanCode,
176 r"\b[A-Z]{2}\d{2}[A-Z0-9]{1,30}\b",
177 0.75,
178 );
179
180 let _ = self.add_pattern(
182 EntityType::BtcAddress,
183 r"\b(?:bc1|[13])[a-zA-HJ-NP-Z0-9]{25,62}\b",
184 0.85,
185 );
186
187 let _ = self.add_pattern(EntityType::EthAddress, r"\b0x[a-fA-F0-9]{40}\b", 0.9);
189
190 let _ = self.add_pattern(EntityType::Md5Hash, r"\b[a-fA-F0-9]{32}\b", 0.6);
192
193 let _ = self.add_pattern(EntityType::Sha1Hash, r"\b[a-fA-F0-9]{40}\b", 0.6);
195
196 let _ = self.add_pattern(EntityType::Sha256Hash, r"\b[a-fA-F0-9]{64}\b", 0.6);
198
199 let _ = self.add_pattern(
201 EntityType::UsZipCode,
202 r"\b\d{5}(?:-\d{4})?\b",
203 0.6, );
205
206 let _ = self.add_pattern_with_context(
208 EntityType::PoBox,
209 r"\b(?:P\.?\s?O\.?|POST\s+OFFICE)\s*BOX\s+\d+\b",
210 0.85,
211 vec![
212 "address".to_string(),
213 "mail".to_string(),
214 "ship".to_string(),
215 ],
216 );
217
218 let _ = self.add_pattern(
220 EntityType::Isbn,
221 r"\b(?:ISBN(?:-1[03])?:?\s*)?(?:\d{9}[\dX]|\d{13})\b",
222 0.8,
223 );
224
225 let _ = self.add_pattern_with_context(
227 EntityType::PassportNumber,
228 r"\b[A-Z]{1,2}\d{6,9}\b",
229 0.7,
230 vec!["passport".to_string(), "travel".to_string()],
231 );
232
233 let _ = self.add_pattern_with_context(
235 EntityType::MedicalRecordNumber,
236 r"\b(?:MRN|Medical\s*Record|Patient\s*ID):?\s*[A-Z0-9]{6,12}\b",
237 0.85,
238 vec![
239 "patient".to_string(),
240 "medical".to_string(),
241 "hospital".to_string(),
242 ],
243 );
244
245 let _ = self.add_pattern_with_context(
247 EntityType::Age,
248 r"\b(?:age|aged|years old):?\s*(\d{1,3})\b",
249 0.8,
250 vec!["years".to_string(), "old".to_string(), "age".to_string()],
251 );
252
253 let _ = self.add_pattern(
255 EntityType::DateTime,
256 r"\b\d{4}-\d{2}-\d{2}(?:[T\s]\d{2}:\d{2}(?::\d{2})?)?\b",
257 0.5,
258 );
259 }
260
261 fn check_context(&self, text: &str, start: usize, end: usize, context_words: &[String]) -> f32 {
263 if context_words.is_empty() {
264 return 0.0;
265 }
266
267 let context_start = start.saturating_sub(50);
269 let context_end = (end + 50).min(text.len());
270 let context = &text[context_start..context_end].to_lowercase();
271
272 let matches = context_words
274 .iter()
275 .filter(|word| context.contains(&word.to_lowercase()))
276 .count();
277
278 (matches as f32 / context_words.len() as f32) * 0.3
280 }
281}
282
283impl Default for PatternRecognizer {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289impl Recognizer for PatternRecognizer {
290 fn name(&self) -> &str {
291 &self.name
292 }
293
294 fn supported_entities(&self) -> &[EntityType] {
295 lazy_static! {
296 static ref SUPPORTED: Vec<EntityType> = vec![
297 EntityType::EmailAddress,
298 EntityType::PhoneNumber,
299 EntityType::CreditCard,
300 EntityType::UsSsn,
301 EntityType::IpAddress,
302 EntityType::Url,
303 EntityType::Guid,
304 EntityType::MacAddress,
305 EntityType::UkNhs,
306 EntityType::UkNino,
307 EntityType::UkPostcode,
308 EntityType::UkSortCode,
309 EntityType::IbanCode,
310 EntityType::BtcAddress,
311 EntityType::EthAddress,
312 EntityType::Md5Hash,
313 EntityType::Sha1Hash,
314 EntityType::Sha256Hash,
315 EntityType::DateTime,
316 ];
317 }
318 &SUPPORTED
319 }
320
321 fn analyze(&self, text: &str, _language: &str) -> Result<Vec<RecognizerResult>> {
322 let mut results = Vec::new();
323
324 for (entity_type, patterns) in &self.patterns {
325 for pattern in patterns {
326 for capture in pattern.regex.captures_iter(text) {
327 if let Some(matched) = capture.get(0) {
328 let start = matched.start();
329 let end = matched.end();
330
331 let mut score = pattern.score;
333
334 if !pattern.context_words.is_empty() {
336 score += self.check_context(text, start, end, &pattern.context_words);
337 score = score.min(1.0); }
339
340 if score >= self.min_score {
341 results.push(
342 RecognizerResult::new(
343 entity_type.clone(),
344 start,
345 end,
346 score,
347 self.name(),
348 )
349 .with_text(text),
350 );
351 }
352 }
353 }
354 }
355 }
356
357 Ok(results)
358 }
359
360 fn min_score(&self) -> f32 {
361 self.min_score
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyzes `text` with a freshly built default recognizer.
    fn detect(text: &str) -> Vec<RecognizerResult> {
        PatternRecognizer::new()
            .analyze(text, "en")
            .expect("analysis should succeed")
    }

    /// True when at least one result carries `wanted` as its entity type.
    fn contains_entity(results: &[RecognizerResult], wanted: EntityType) -> bool {
        results.iter().any(|result| result.entity_type == wanted)
    }

    #[test]
    fn test_email_detection() {
        let found = detect("Contact me at john.doe@example.com for details");

        assert_eq!(found.len(), 1);
        let email = &found[0];
        assert_eq!(email.entity_type, EntityType::EmailAddress);
        assert_eq!(email.text, Some(String::from("john.doe@example.com")));
        assert!(email.score >= 0.8);
    }

    #[test]
    fn test_phone_detection() {
        let found = detect("Call me at (555) 123-4567");
        assert!(contains_entity(&found, EntityType::PhoneNumber));
    }

    #[test]
    fn test_credit_card_detection() {
        let found = detect("Card number: 4532015112830366");
        assert!(contains_entity(&found, EntityType::CreditCard));
    }

    #[test]
    fn test_ssn_detection() {
        let found = detect("SSN: 123-45-6789");
        assert!(contains_entity(&found, EntityType::UsSsn));
    }

    #[test]
    fn test_uk_nhs_with_context() {
        let found = detect("NHS patient number is 123 456 7890");
        let nhs = found
            .iter()
            .find(|result| result.entity_type == EntityType::UkNhs)
            .expect("NHS number should be detected");
        // The nearby "NHS" and "patient" keywords lift the score above the 0.6 base.
        assert!(nhs.score > 0.6);
    }

    #[test]
    fn test_uk_nino_detection() {
        let found = detect("NINO: AB123456C");
        assert!(contains_entity(&found, EntityType::UkNino));
    }

    #[test]
    fn test_multiple_entities() {
        let found = detect("Email john@example.com, phone (555) 123-4567, SSN 123-45-6789");

        assert!(found.len() >= 3);
        assert!(contains_entity(&found, EntityType::EmailAddress));
        assert!(contains_entity(&found, EntityType::PhoneNumber));
        assert!(contains_entity(&found, EntityType::UsSsn));
    }

    #[test]
    fn test_custom_pattern() {
        let mut recognizer = PatternRecognizer::new();
        recognizer
            .add_pattern(
                EntityType::Custom("CUSTOM_ID".to_string()),
                r"\bCID-\d{6}\b",
                0.9,
            )
            .expect("custom pattern should compile");

        let found = recognizer
            .analyze("Your customer ID is CID-123456", "en")
            .expect("analysis should succeed");
        assert!(found
            .iter()
            .any(|result| matches!(result.entity_type, EntityType::Custom(_))));
    }

    #[test]
    fn test_min_score_filtering() {
        let recognizer = PatternRecognizer::new().with_min_score(0.9);
        let found = recognizer
            .analyze("Date: 2024-01-15", "en")
            .expect("analysis should succeed");
        // Dates carry a 0.5 base score, well under the 0.9 threshold.
        assert!(!contains_entity(&found, EntityType::DateTime));
    }
}