1use super::{validation::validate_entity, Recognizer, RecognizerResult};
7use crate::types::EntityType;
8use anyhow::Result;
9use lazy_static::lazy_static;
10use regex::Regex;
11use std::collections::HashMap;
12
/// Regex-based recognizer that maps each [`EntityType`] to one or more compiled
/// patterns and reports matches whose adjusted score reaches `min_score`.
#[derive(Debug, Clone)]
pub struct PatternRecognizer {
    /// Identifier returned by `Recognizer::name` and passed to every result created.
    name: String,
    /// Compiled patterns, grouped by the entity type they detect.
    patterns: HashMap<EntityType, Vec<CompiledPattern>>,
    /// Matches whose final score is below this threshold are dropped by `analyze`.
    min_score: f32,
}
20
/// A single compiled detection rule for one entity type.
#[derive(Debug, Clone)]
struct CompiledPattern {
    /// Expression matched against the input text.
    regex: Regex,
    /// Base confidence given to every match, before the context boost and the
    /// validation factor are applied.
    score: f32,
    /// Words (compared case-insensitively) whose presence near a match raises its
    /// score; an empty list means no context boost is applied.
    context_words: Vec<String>,
}
27
28impl PatternRecognizer {
29 pub fn new() -> Self {
31 let mut recognizer = Self {
32 name: "PatternRecognizer".to_string(),
33 patterns: HashMap::new(),
34 min_score: 0.5,
35 };
36 recognizer.load_default_patterns();
37 recognizer
38 }
39
40 pub fn with_name(name: impl Into<String>) -> Self {
42 let mut recognizer = Self::new();
43 recognizer.name = name.into();
44 recognizer
45 }
46
47 pub fn with_min_score(mut self, min_score: f32) -> Self {
49 self.min_score = min_score;
50 self
51 }
52
53 pub fn add_pattern(
55 &mut self,
56 entity_type: EntityType,
57 pattern: &str,
58 score: f32,
59 ) -> Result<()> {
60 let regex = Regex::new(pattern)?;
61 let compiled = CompiledPattern {
62 regex,
63 score,
64 context_words: vec![],
65 };
66 self.patterns.entry(entity_type).or_default().push(compiled);
67 Ok(())
68 }
69
70 pub fn add_pattern_with_context(
72 &mut self,
73 entity_type: EntityType,
74 pattern: &str,
75 score: f32,
76 context_words: Vec<String>,
77 ) -> Result<()> {
78 let regex = Regex::new(pattern)?;
79 let compiled = CompiledPattern {
80 regex,
81 score,
82 context_words,
83 };
84 self.patterns.entry(entity_type).or_default().push(compiled);
85 Ok(())
86 }
87
88 fn load_default_patterns(&mut self) {
90 let _ = self.add_pattern(
92 EntityType::EmailAddress,
93 r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b",
94 0.8,
95 );
96
97 let _ = self.add_pattern(
103 EntityType::PhoneNumber,
104 r"\(\d{3}\)[-.\s]?\d{3}[-.\s]?\d{4}\b|\b\d{3}[-.\s]\d{3}[-.\s]?\d{4}\b",
105 0.7,
106 );
107
108 let _ = self.add_pattern(
110 EntityType::CreditCard,
111 r"\b(?:4[0-9]{12}(?:[0-9]{3})?|5[1-5][0-9]{14}|3[47][0-9]{13}|6(?:011|5[0-9]{2})[0-9]{12})\b",
112 0.9,
113 );
114
115 let _ = self.add_pattern(EntityType::UsSsn, r"\b\d{3}-\d{2}-\d{4}\b", 0.9);
118
119 let _ = self.add_pattern(
121 EntityType::IpAddress,
122 r"\b(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\b",
123 0.8,
124 );
125
126 let _ = self.add_pattern(
128 EntityType::Url,
129 r"\b(?:https?://|www\.)[a-zA-Z0-9][-a-zA-Z0-9]*(?:\.[a-zA-Z0-9][-a-zA-Z0-9]*)+(?:/[^\s]*)?\b",
130 0.7,
131 );
132
133 let _ = self.add_pattern(
135 EntityType::DomainName,
136 r"\b(?:[A-Za-z0-9](?:[A-Za-z0-9-]{0,61}[A-Za-z0-9])?\.)+[A-Za-z]{2,}\b",
137 0.7,
138 );
139
140 let _ = self.add_pattern(
142 EntityType::Guid,
143 r"\b[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}\b",
144 0.9,
145 );
146
147 let _ = self.add_pattern(
149 EntityType::MacAddress,
150 r"\b(?:[0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b",
151 0.9,
152 );
153
154 let _ = self.add_pattern_with_context(
156 EntityType::UkNhs,
157 r"\b(?:\d{3}\s?\d{3}\s?\d{4}|\d{10})\b",
158 0.6,
159 vec![
160 "NHS".to_string(),
161 "patient".to_string(),
162 "health".to_string(),
163 ],
164 );
165
166 let _ = self.add_pattern(
168 EntityType::UkNino,
169 r"\b[A-CEGHJ-PR-TW-Z]{1}[A-CEGHJ-NPR-TW-Z]{1}\d{6}[A-D]{1}\b",
170 0.85,
171 );
172
173 let _ = self.add_pattern(
175 EntityType::UkPostcode,
176 r"\b[A-Z]{1,2}\d[A-Z\d]?\s?\d[A-Z]{2}\b",
177 0.75,
178 );
179
180 let _ = self.add_pattern(EntityType::UkSortCode, r"\b\d{2}-\d{2}-\d{2}\b", 0.7);
182
183 let _ = self.add_pattern(
185 EntityType::IbanCode,
186 r"\b[A-Z]{2}\d{2}[A-Z0-9]{1,30}\b",
187 0.75,
188 );
189
190 let _ = self.add_pattern(
192 EntityType::BtcAddress,
193 r"\b(?:bc1|[13])[a-zA-HJ-NP-Z0-9]{25,62}\b",
194 0.85,
195 );
196
197 let _ = self.add_pattern(EntityType::EthAddress, r"\b0x[a-fA-F0-9]{40}\b", 0.9);
199
200 let _ = self.add_pattern(EntityType::Md5Hash, r"\b[a-fA-F0-9]{32}\b", 0.6);
202
203 let _ = self.add_pattern(EntityType::Sha1Hash, r"\b[a-fA-F0-9]{40}\b", 0.6);
205
206 let _ = self.add_pattern(EntityType::Sha256Hash, r"\b[a-fA-F0-9]{64}\b", 0.6);
208
209 let _ = self.add_pattern(
211 EntityType::UsZipCode,
212 r"\b\d{5}(?:-\d{4})?\b",
213 0.6, );
215
216 let _ = self.add_pattern_with_context(
218 EntityType::PoBox,
219 r"\b(?:P\.?\s?O\.?|POST\s+OFFICE)\s*BOX\s+\d+\b",
220 0.85,
221 vec![
222 "address".to_string(),
223 "mail".to_string(),
224 "ship".to_string(),
225 ],
226 );
227
228 let _ = self.add_pattern(
230 EntityType::Isbn,
231 r"\b(?:ISBN(?:-1[03])?:?\s*)?(?:\d{9}[\dX]|\d{13})\b",
232 0.8,
233 );
234
235 let _ = self.add_pattern_with_context(
237 EntityType::PassportNumber,
238 r"\b[A-Z]{1,2}\d{6,9}\b",
239 0.7,
240 vec!["passport".to_string(), "travel".to_string()],
241 );
242
243 let _ = self.add_pattern_with_context(
245 EntityType::MedicalRecordNumber,
246 r"\b(?:MRN|Medical\s*Record|Patient\s*ID):?\s*[A-Z0-9]{6,12}\b",
247 0.85,
248 vec![
249 "patient".to_string(),
250 "medical".to_string(),
251 "hospital".to_string(),
252 ],
253 );
254
255 let _ = self.add_pattern_with_context(
257 EntityType::Age,
258 r"\b(?:age|aged|years old):?\s*(\d{1,3})\b",
259 0.8,
260 vec!["years".to_string(), "old".to_string(), "age".to_string()],
261 );
262
263 let _ = self.add_pattern(
265 EntityType::DateTime,
266 r"\b\d{4}-\d{2}-\d{2}(?:[T\s]\d{2}:\d{2}(?::\d{2})?)?\b",
267 0.5,
268 );
269
270 let _ = self.add_pattern_with_context(
276 EntityType::UsDriverLicense,
277 r"\b[A-Z]\d{6,8}\b|\b[A-Z]\d{3}-\d{4}-\d{4}\b",
278 0.4,
279 vec![
280 "driver".to_string(),
281 "license".to_string(),
282 "DL".to_string(),
283 "DMV".to_string(),
284 ],
285 );
286
287 let _ = self.add_pattern_with_context(
290 EntityType::UsPassport,
291 r"\b[A-Z]?\d{9}\b",
292 0.4,
293 vec![
294 "passport".to_string(),
295 "travel".to_string(),
296 "state department".to_string(),
297 ],
298 );
299
300 let _ = self.add_pattern_with_context(
303 EntityType::UsBankNumber,
304 r"\b\d{8,17}\b",
305 0.3,
306 vec![
307 "account".to_string(),
308 "bank".to_string(),
309 "routing".to_string(),
310 "checking".to_string(),
311 "savings".to_string(),
312 ],
313 );
314
315 let _ = self.add_pattern(
318 EntityType::UkDriverLicense,
319 r"\b[A-Z]{5}\d{6}[A-Z0-9]{2}\d[A-Z]{2}\s?\d{2}\b",
320 0.85,
321 );
322
323 let _ = self.add_pattern_with_context(
326 EntityType::UkPassportNumber,
327 r"\b\d{9}\b",
328 0.3,
329 vec![
330 "passport".to_string(),
331 "travel".to_string(),
332 "HMPO".to_string(),
333 ],
334 );
335
336 let _ = self.add_pattern(
338 EntityType::UkPhoneNumber,
339 r"\b(?:0[1-3]\d{2,3}\s?\d{3}\s?\d{4}|0[1-3]\d{2,3}\s?\d{6,7})\b",
340 0.75,
341 );
342
343 let _ = self.add_pattern(
345 EntityType::UkMobileNumber,
346 r"\b07\d{3}\s?\d{3}\s?\d{3}\b",
347 0.8,
348 );
349
350 let _ = self.add_pattern_with_context(
353 EntityType::UkCompanyNumber,
354 r"\b(?:\d{8}|[A-Z]{2}\d{6})\b",
355 0.3,
356 vec![
357 "company".to_string(),
358 "companies house".to_string(),
359 "registration".to_string(),
360 "CRN".to_string(),
361 ],
362 );
363
364 let _ = self.add_pattern_with_context(
366 EntityType::MedicalLicense,
367 r"\b(?:MD|DO|NP|PA|RN|LPN)[-\s]?\d{5,10}\b",
368 0.8,
369 vec![
370 "license".to_string(),
371 "medical".to_string(),
372 "physician".to_string(),
373 "doctor".to_string(),
374 "nurse".to_string(),
375 ],
376 );
377
378 let _ = self.add_pattern_with_context(
381 EntityType::CryptoWallet,
382 r"\b[LMr3][a-km-zA-HJ-NP-Z1-9]{25,34}\b",
383 0.75,
384 vec![
385 "wallet".to_string(),
386 "crypto".to_string(),
387 "address".to_string(),
388 "coin".to_string(),
389 ],
390 );
391 }
392
393 fn check_context(&self, text: &str, start: usize, end: usize, context_words: &[String]) -> f32 {
395 if context_words.is_empty() {
396 return 0.0;
397 }
398
399 let context_start = start.saturating_sub(50);
401 let context_end = (end + 50).min(text.len());
402 let context = &text[context_start..context_end].to_lowercase();
403
404 let matches = context_words
406 .iter()
407 .filter(|word| context.contains(&word.to_lowercase()))
408 .count();
409
410 (matches as f32 / context_words.len() as f32) * 0.3
412 }
413}
414
415impl Default for PatternRecognizer {
416 fn default() -> Self {
417 Self::new()
418 }
419}
420
421impl Recognizer for PatternRecognizer {
422 fn name(&self) -> &str {
423 &self.name
424 }
425
426 fn supported_entities(&self) -> &[EntityType] {
427 lazy_static! {
428 static ref SUPPORTED: Vec<EntityType> = vec![
429 EntityType::EmailAddress,
431 EntityType::PhoneNumber,
432 EntityType::IpAddress,
433 EntityType::Url,
434 EntityType::DomainName,
435 EntityType::CreditCard,
437 EntityType::IbanCode,
438 EntityType::UsBankNumber,
439 EntityType::UsSsn,
441 EntityType::UsDriverLicense,
442 EntityType::UsPassport,
443 EntityType::UsZipCode,
444 EntityType::UkNhs,
446 EntityType::UkNino,
447 EntityType::UkPostcode,
448 EntityType::UkSortCode,
449 EntityType::UkDriverLicense,
450 EntityType::UkPassportNumber,
451 EntityType::UkPhoneNumber,
452 EntityType::UkMobileNumber,
453 EntityType::UkCompanyNumber,
454 EntityType::MedicalLicense,
456 EntityType::MedicalRecordNumber,
457 EntityType::PassportNumber,
459 EntityType::Age,
460 EntityType::Isbn,
461 EntityType::PoBox,
462 EntityType::DateTime,
463 EntityType::CryptoWallet,
465 EntityType::BtcAddress,
466 EntityType::EthAddress,
467 EntityType::Guid,
469 EntityType::MacAddress,
470 EntityType::Md5Hash,
471 EntityType::Sha1Hash,
472 EntityType::Sha256Hash,
473 ];
474 }
475 &SUPPORTED
476 }
477
478 fn analyze(&self, text: &str, _language: &str) -> Result<Vec<RecognizerResult>> {
479 let mut results = Vec::new();
480
481 for (entity_type, patterns) in &self.patterns {
482 for pattern in patterns {
483 for capture in pattern.regex.captures_iter(text) {
484 if let Some(matched) = capture.get(0) {
485 let start = matched.start();
486 let end = matched.end();
487 let matched_text = matched.as_str();
488
489 let mut score = pattern.score;
491
492 if !pattern.context_words.is_empty() {
494 score += self.check_context(text, start, end, &pattern.context_words);
495 score = score.min(1.0); }
497
498 let validation_factor = validate_entity(entity_type, matched_text);
501 score *= validation_factor;
502
503 if score >= self.min_score {
504 results.push(
505 RecognizerResult::new(
506 entity_type.clone(),
507 start,
508 end,
509 score,
510 self.name(),
511 )
512 .with_text(text),
513 );
514 }
515 }
516 }
517 }
518 }
519
520 Ok(results)
521 }
522
523 fn min_score(&self) -> f32 {
524 self.min_score
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    /// Analyzes `text` as English with a freshly built default recognizer.
    fn analyze_default(text: &str) -> Vec<RecognizerResult> {
        PatternRecognizer::new().analyze(text, "en").unwrap()
    }

    /// Whether any result in `results` has the entity type `expected`.
    fn contains_entity(results: &[RecognizerResult], expected: EntityType) -> bool {
        results.iter().any(|r| r.entity_type == expected)
    }

    #[test]
    fn test_email_detection() {
        let results = analyze_default("Contact me at john.doe@example.com for details");

        let emails: Vec<&RecognizerResult> = results
            .iter()
            .filter(|r| r.entity_type == EntityType::EmailAddress)
            .collect();
        assert_eq!(emails.len(), 1);

        let email = emails[0];
        assert_eq!(email.text, Some("john.doe@example.com".to_string()));
        assert!(email.score >= 0.8);
    }

    #[test]
    fn test_phone_detection() {
        let results = analyze_default("Call me at (555) 123-4567");
        assert!(!results.is_empty());
        assert!(contains_entity(&results, EntityType::PhoneNumber));
    }

    #[test]
    fn test_credit_card_detection() {
        let results = analyze_default("Card number: 4532015112830366");
        assert!(!results.is_empty());
        assert!(contains_entity(&results, EntityType::CreditCard));
    }

    #[test]
    fn test_ssn_detection() {
        let results = analyze_default("SSN: 123-45-6789");
        assert!(!results.is_empty());
        assert!(contains_entity(&results, EntityType::UsSsn));
    }

    #[test]
    fn test_uk_nhs_with_context() {
        let results = analyze_default("NHS patient number is 401 023 2137");
        assert!(!results.is_empty());

        match results.iter().find(|r| r.entity_type == EntityType::UkNhs) {
            // Context words should lift the score above the 0.6 base.
            Some(nhs) => assert!(nhs.score > 0.6),
            None => panic!("Should detect NHS number with context"),
        }
    }

    #[test]
    fn test_uk_nino_detection() {
        let results = analyze_default("NINO: AB123456C");
        assert!(!results.is_empty());
        assert!(contains_entity(&results, EntityType::UkNino));
    }

    #[test]
    fn test_multiple_entities() {
        let results =
            analyze_default("Email john@example.com, phone (555) 123-4567, SSN 123-45-6789");
        assert!(results.len() >= 3);

        for expected in [
            EntityType::EmailAddress,
            EntityType::PhoneNumber,
            EntityType::UsSsn,
        ] {
            assert!(contains_entity(&results, expected));
        }
    }

    #[test]
    fn test_custom_pattern() {
        let mut recognizer = PatternRecognizer::new();
        recognizer
            .add_pattern(
                EntityType::Custom("CUSTOM_ID".to_string()),
                r"\bCID-\d{6}\b",
                0.9,
            )
            .unwrap();

        let results = recognizer
            .analyze("Your customer ID is CID-123456", "en")
            .unwrap();
        assert!(results
            .iter()
            .any(|r| matches!(r.entity_type, EntityType::Custom(_))));
    }

    #[test]
    fn test_min_score_filtering() {
        let strict = PatternRecognizer::new().with_min_score(0.9);
        let results = strict.analyze("Date: 2024-01-15", "en").unwrap();
        assert!(!contains_entity(&results, EntityType::DateTime));
    }

    #[test]
    fn test_uk_driver_license_detection() {
        let results = analyze_default("UK DL: MORGA753116SM9IJ 35");
        assert!(
            contains_entity(&results, EntityType::UkDriverLicense),
            "Should detect UK driver's license"
        );
    }

    #[test]
    fn test_uk_mobile_detection() {
        let results = analyze_default("Call me on 07700 900123");
        assert!(
            contains_entity(&results, EntityType::UkMobileNumber),
            "Should detect UK mobile number"
        );
    }

    #[test]
    fn test_uk_phone_detection() {
        let results = analyze_default("Office: 0207 123 4567");
        assert!(
            contains_entity(&results, EntityType::UkPhoneNumber),
            "Should detect UK phone number"
        );
    }

    #[test]
    fn test_medical_license_detection() {
        let results = analyze_default("Medical license: MD-123456789");
        assert!(
            contains_entity(&results, EntityType::MedicalLicense),
            "Should detect medical license"
        );
    }

    #[test]
    fn test_supported_entities_count() {
        let count = PatternRecognizer::new().supported_entities().len();
        assert_eq!(
            count,
            36,
            "Should support 36 pattern-based entity types, got {}",
            count
        );
    }
}