use crate::entity::{EntityType, RecognizerResult};
use regex::Regex;
use serde::{Deserialize, Serialize};
pub trait Recognizer: Send + Sync {
fn name(&self) -> &str;
fn supported_entities(&self) -> &[EntityType];
fn analyze(&self, text: &str, entities: &[EntityType]) -> Vec<RecognizerResult>;
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PatternDef {
pub name: String,
pub regex: String,
pub score: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecognizerDef {
pub name: String,
pub entity_type: String,
pub version: String,
pub patterns: Vec<PatternDef>,
#[serde(default)]
pub context_words: Vec<String>,
#[serde(default)]
pub context_score_boost: f64,
#[serde(default)]
pub deny_list: Vec<String>,
#[serde(default)]
pub validators: Vec<String>,
pub supported_languages: Option<Vec<String>>,
}
pub struct RegexRecognizer {
def: RecognizerDef,
compiled: Vec<(String, Regex, f64)>,
entity: EntityType,
}
impl RegexRecognizer {
pub fn from_def(def: RecognizerDef) -> Result<Self, regex::Error> {
let mut compiled = Vec::new();
for p in &def.patterns {
let re = Regex::new(&p.regex)?;
compiled.push((p.name.clone(), re, p.score));
}
let entity = EntityType::new(&def.entity_type);
Ok(Self { def, compiled, entity })
}
pub fn from_json(json: &str) -> Result<Self, Box<dyn std::error::Error>> {
let def: RecognizerDef = serde_json::from_str(json)?;
Ok(Self::from_def(def)?)
}
fn has_context(&self, text: &str, start: usize, end: usize) -> bool {
if self.def.context_words.is_empty() {
return false;
}
let window_start = start.saturating_sub(100);
let window_end = (end + 100).min(text.len());
let window = &text[window_start..window_end].to_lowercase();
self.def.context_words.iter().any(|w| window.contains(&w.to_lowercase()))
}
fn is_denied(&self, matched: &str) -> bool {
self.def.deny_list.iter().any(|d| matched == d)
}
fn validate(&self, matched: &str) -> bool {
for v in &self.def.validators {
match v.as_str() {
"luhn" => {
if !luhn_check(matched) {
return false;
}
}
"cn_id_checksum" => {
if !cn_id_check(matched) {
return false;
}
}
_ => {}
}
}
true
}
}
impl Recognizer for RegexRecognizer {
fn name(&self) -> &str {
&self.def.name
}
fn supported_entities(&self) -> &[EntityType] {
std::slice::from_ref(&self.entity)
}
fn analyze(&self, text: &str, entities: &[EntityType]) -> Vec<RecognizerResult> {
if !entities.is_empty() && !entities.contains(&self.entity) {
return Vec::new();
}
let mut results = Vec::new();
for (pat_name, re, base_score) in &self.compiled {
for m in re.find_iter(text) {
let matched = m.as_str();
if self.is_denied(matched) {
continue;
}
if !self.validate(matched) {
continue;
}
let mut score = *base_score;
if self.has_context(text, m.start(), m.end()) {
score = (score + self.def.context_score_boost).min(1.0);
}
results.push(RecognizerResult {
entity_type: self.entity.clone(),
start: m.start(),
end: m.end(),
score,
recognizer_name: Some(pat_name.clone()),
});
}
}
results
}
}
fn luhn_check(number: &str) -> bool {
let digits: Vec<u32> = number
.chars()
.filter(|c| c.is_ascii_digit())
.filter_map(|c| c.to_digit(10))
.collect();
if digits.len() < 2 {
return false;
}
let mut sum = 0u32;
let mut double = false;
for &d in digits.iter().rev() {
let mut val = d;
if double {
val *= 2;
if val > 9 {
val -= 9;
}
}
sum += val;
double = !double;
}
sum % 10 == 0
}
fn cn_id_check(id: &str) -> bool {
if id.len() != 18 {
return false;
}
let weights = [7, 9, 10, 5, 8, 4, 2, 1, 6, 3, 7, 9, 10, 5, 8, 4, 2];
let check_chars = ['1', '0', 'X', '9', '8', '7', '6', '5', '4', '3', '2'];
let chars: Vec<char> = id.chars().collect();
let mut sum = 0usize;
for i in 0..17 {
let d = match chars[i].to_digit(10) {
Some(d) => d as usize,
None => return false,
};
sum += d * weights[i];
}
let expected = check_chars[sum % 11];
chars[17].to_ascii_uppercase() == expected
}
pub fn load_recognizers_from_dir(dir: &std::path::Path) -> Vec<Box<dyn Recognizer>> {
let mut recognizers: Vec<Box<dyn Recognizer>> = Vec::new();
if let Ok(entries) = std::fs::read_dir(dir) {
for entry in entries.flatten() {
let path = entry.path();
if path.extension().map_or(false, |e| e == "json") {
if let Ok(json) = std::fs::read_to_string(&path) {
match RegexRecognizer::from_json(&json) {
Ok(r) => recognizers.push(Box::new(r)),
Err(e) => eprintln!("Failed to load {:?}: {}", path, e),
}
}
}
}
}
recognizers
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_luhn_valid() {
assert!(luhn_check("4532015112830366"));
assert!(luhn_check("4111111111111111"));
}
#[test]
fn test_luhn_invalid() {
assert!(!luhn_check("1234567890123456"));
}
#[test]
fn test_cn_id_valid() {
assert!(cn_id_check("11010519491231002X"));
}
#[test]
fn test_cn_id_invalid() {
assert!(!cn_id_check("110105194912310020"));
}
#[test]
fn test_regex_recognizer_email() {
let json = r#"{
"name": "email_recognizer",
"entity_type": "EMAIL_ADDRESS",
"version": "1.0.0",
"patterns": [{"name": "email", "regex": "[A-Za-z0-9._%+\\-]+@[A-Za-z0-9.\\-]+\\.[A-Za-z]{2,}", "score": 0.5}],
"context_words": ["email"],
"context_score_boost": 0.4
}"#;
let rec = RegexRecognizer::from_json(json).unwrap();
let results = rec.analyze("Contact me at test@example.com please", &[]);
assert_eq!(results.len(), 1);
assert_eq!(results[0].entity_type.as_str(), "EMAIL_ADDRESS");
assert_eq!(&"Contact me at test@example.com please"[results[0].start..results[0].end], "test@example.com");
}
#[test]
fn test_context_boost() {
let json = r#"{
"name": "email_recognizer",
"entity_type": "EMAIL_ADDRESS",
"version": "1.0.0",
"patterns": [{"name": "email", "regex": "[A-Za-z0-9._%+\\-]+@[A-Za-z0-9.\\-]+\\.[A-Za-z]{2,}", "score": 0.5}],
"context_words": ["email"],
"context_score_boost": 0.4
}"#;
let rec = RegexRecognizer::from_json(json).unwrap();
let with_ctx = rec.analyze("My email is test@example.com", &[]);
let without_ctx = rec.analyze("test@example.com", &[]);
assert!(with_ctx[0].score > without_ctx[0].score);
}
#[test]
fn test_deny_list() {
let json = r#"{
"name": "ip_recognizer",
"entity_type": "IP_ADDRESS",
"version": "1.0.0",
"patterns": [{"name": "ipv4", "regex": "\\b(?:(?:25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\.){3}(?:25[0-5]|2[0-4]\\d|[01]?\\d\\d?)\\b", "score": 0.5}],
"deny_list": ["0.0.0.0", "127.0.0.1"],
"context_words": []
}"#;
let rec = RegexRecognizer::from_json(json).unwrap();
let results = rec.analyze("Server at 127.0.0.1 and 192.168.1.1", &[]);
assert_eq!(results.len(), 1);
assert_eq!(&"Server at 127.0.0.1 and 192.168.1.1"[results[0].start..results[0].end], "192.168.1.1");
}
}