use pii::analyzer::Analyzer;
use pii::config::PolicyConfig;
use pii::context::{ContextEnhancer, LemmaContextEnhancer};
use pii::nlp::SimpleNlpEngine;
use pii::presets::default_recognizers;
use pii::recognizers::ner::NerRecognizer;
use pii::recognizers::Recognizer;
use pii::types::{Detection, EntityType, Language, Token};
use pii::{Capabilities, NlpArtifacts};
use proptest::prelude::*;
use serde::Deserialize;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
/// One JSON test fixture: an input text plus the exact detections the
/// analyzer is expected to produce for it.
#[derive(Debug, Deserialize)]
struct Fixture {
// Display name used in failure messages; "unnamed" when absent.
name: Option<String>,
// Reporting bucket for timing stats; "uncategorized" when absent.
category: Option<String>,
// Language code passed to `Language::from` (e.g. "en").
language: String,
// The text fed to `Analyzer::analyze`.
text: String,
// Every detection that must appear, matched exactly (type, span, recognizer).
expected: Vec<ExpectedDetection>,
// Optional allow-list of entity names; when present it overrides
// `PolicyConfig::default()`'s enabled entities via `parse_entity`.
entities: Option<Vec<String>>,
}
/// A single expected detection from a fixture file. Compared field-for-field
/// against the analyzer output (see `test_fixtures`); `Debug` is derived so
/// mismatches can be printed in failure messages.
#[derive(Debug, Deserialize)]
struct ExpectedDetection {
// Entity name in the same spelling `parse_entity` accepts (e.g. "Email").
entity_type: String,
// Byte offsets of the detection span within the fixture text.
start: usize,
end: usize,
// Name of the recognizer expected to have produced the detection.
recognizer: String,
}
#[test]
fn test_fixtures() {
    // Runs every `*.json` fixture under tests/fixtures through the analyzer
    // and requires the detections to match the expected list exactly (same
    // count, and each expected entry matched on type, span, and recognizer).
    // Per-category timing is accumulated and reported at the end.
    let base = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
    let entries = fs::read_dir(&base).expect("fixtures directory missing");
    let mut stats: HashMap<String, CategoryStats> = HashMap::new();
    let mut failures: Vec<String> = Vec::new();
    for entry in entries {
        let entry = entry.expect("fixture read failed");
        // Skip non-JSON files (editor backups, READMEs, ...).
        if entry.path().extension().and_then(|s| s.to_str()) != Some("json") {
            continue;
        }
        let content = fs::read_to_string(entry.path()).expect("fixture read failed");
        let fixture: Fixture = serde_json::from_str(&content).expect("invalid fixture json");
        // Compute the display name once instead of re-cloning it per failure.
        let fixture_name = fixture
            .name
            .clone()
            .unwrap_or_else(|| "unnamed".to_string());
        let mut policy = PolicyConfig::default();
        if let Some(entities) = fixture.entities.as_ref() {
            policy.enabled_entities = entities.iter().map(|name| parse_entity(name)).collect();
        }
        // A fresh analyzer per fixture: the policy may differ between fixtures.
        let analyzer = Analyzer::new(
            Box::new(SimpleNlpEngine::default()),
            default_recognizers(),
            Vec::new(),
            policy,
        );
        let category = fixture
            .category
            .clone()
            .unwrap_or_else(|| "uncategorized".to_string());
        let start = std::time::Instant::now();
        let result = analyzer
            .analyze(&fixture.text, &Language::from(fixture.language.as_str()))
            .expect("analyze failed");
        let elapsed = start.elapsed();
        let mut passed = true;
        if result.entities.len() != fixture.expected.len() {
            failures.push(format!(
                "fixture {} expected {} detections, got {}",
                fixture_name,
                fixture.expected.len(),
                result.entities.len()
            ));
            passed = false;
        }
        for expected in &fixture.expected {
            let entity_type = parse_entity(&expected.entity_type);
            // An expected detection matches only on type AND span AND recognizer.
            let found = result.entities.iter().any(|det| {
                det.entity_type == entity_type
                    && det.start == expected.start
                    && det.end == expected.end
                    && det.recognizer == expected.recognizer
            });
            if !found {
                failures.push(format!(
                    "fixture {} expected {:?} not found in {:?}",
                    fixture_name, expected, result.entities
                ));
                passed = false;
            }
        }
        stats.entry(category).or_default().record(elapsed, passed);
    }
    report_stats(&stats);
    if !failures.is_empty() {
        // BUG FIX: was "\\n" (a literal backslash followed by 'n'), which
        // printed the escape sequence instead of putting each failure on its
        // own line. Use a real newline for both the header and the separator.
        panic!("fixture failures:\n{}", failures.join("\n"));
    }
}
#[test]
fn test_lemma_context_enhancement() {
    // A detection over "report" should receive a score boost because the
    // lemma "run" (from the neighboring token "running") is configured as a
    // context term for Email.
    let running = Token {
        text: "running".to_string(),
        start: 0,
        end: 7,
        lemma: Some("run".to_string()),
        pos: None,
    };
    let report = Token {
        text: "report".to_string(),
        start: 8,
        end: 14,
        lemma: Some("report".to_string()),
        pos: None,
    };
    // Artifacts advertise lemma support, so the enhancer should match on
    // lemmas rather than surface forms.
    let artifacts = NlpArtifacts {
        language: Language::from("en"),
        text_len: 14,
        tokens: vec![running, report],
        sentences: vec![(0, 14)],
        ner: Vec::new(),
        capabilities: Capabilities {
            token_offsets: true,
            lemma: true,
            pos: false,
            ner: false,
            sentences: true,
        },
    };
    let mut terms = HashMap::new();
    terms.insert(
        EntityType::Email,
        pii::ContextTerms {
            window_tokens: 2,
            boost: 0.2,
            terms: vec!["run".to_string()],
        },
    );
    let enhancer = LemmaContextEnhancer::new(terms);
    let mut detections = vec![Detection {
        entity_type: EntityType::Email,
        start: 8,
        end: 14,
        score: 0.5,
        recognizer: "test".to_string(),
        explanation: pii::types::DetectionExplanation::Regex {
            pattern_name: "email".to_string(),
        },
    }];
    enhancer.enhance(&mut detections, "running report", &artifacts);
    // Score must rise above the original 0.5 once context is applied.
    assert!(detections[0].score > 0.5);
}
#[test]
fn test_context_fallback_surface_terms() {
    // When the NLP artifacts provide no lemmas, the enhancer should fall
    // back to matching surface token text ("running") as a context term.
    let surface_only = Token {
        text: "running".to_string(),
        start: 0,
        end: 7,
        lemma: None,
        pos: None,
    };
    let artifacts = NlpArtifacts {
        language: Language::from("en"),
        text_len: 7,
        tokens: vec![surface_only],
        sentences: Vec::new(),
        ner: Vec::new(),
        // Capabilities deliberately report `lemma: false`.
        capabilities: Capabilities {
            token_offsets: true,
            lemma: false,
            pos: false,
            ner: false,
            sentences: false,
        },
    };
    let mut terms = HashMap::new();
    terms.insert(
        EntityType::Email,
        pii::ContextTerms {
            window_tokens: 2,
            boost: 0.2,
            terms: vec!["running".to_string()],
        },
    );
    let enhancer = LemmaContextEnhancer::new(terms);
    let mut detections = vec![Detection {
        entity_type: EntityType::Email,
        start: 0,
        end: 7,
        score: 0.5,
        recognizer: "test".to_string(),
        explanation: pii::types::DetectionExplanation::Regex {
            pattern_name: "email".to_string(),
        },
    }];
    enhancer.enhance(&mut detections, "running", &artifacts);
    // The surface-term fallback must still boost the score.
    assert!(detections[0].score > 0.5);
}
#[test]
fn test_ner_disabled_disables_ner_recognizer() {
    // With only basic capabilities (no NER artifacts), the NER recognizer
    // must produce no detections, even for an obvious person name.
    let artifacts = NlpArtifacts {
        language: Language::from("en"),
        text_len: 4,
        tokens: Vec::new(),
        sentences: Vec::new(),
        ner: Vec::new(),
        capabilities: Capabilities::basic(),
    };
    let mapping = vec![(EntityType::Person, EntityType::Person)];
    let recognizer = NerRecognizer::new("ner", mapping);
    assert!(recognizer.analyze("John", &artifacts).is_empty());
}
// Property test: the reported byte offsets of an email detection must be
// exact regardless of how much padding surrounds it.
proptest! {
// Disable failure persistence so CI runs don't write regression files.
#![proptest_config(ProptestConfig { failure_persistence: None, .. ProptestConfig::default() })]
#[test]
fn prop_email_offsets(prefix in "[a-z ]{0,40}", suffix in "[ ]{0,40}") {
// Require a word boundary before the email so the recognizer sees a
// clean token start (a non-empty prefix must end with a space).
prop_assume!(prefix.is_empty() || prefix.ends_with(' '));
let email = "user@example.com";
let text = format!("{}{}{}", prefix, email, suffix);
let analyzer = Analyzer::new(
Box::new(SimpleNlpEngine::default()),
default_recognizers(),
Vec::new(),
PolicyConfig::default(),
);
let result = analyzer.analyze(&text, &Language::from("en")).unwrap();
let detection = result.entities.iter().find(|det| det.entity_type == EntityType::Email).unwrap();
// Offsets are byte positions into `text`; the slice check below ensures
// they are internally consistent as well as numerically correct.
let expected_start = prefix.len();
let expected_end = expected_start + email.len();
assert_eq!(detection.start, expected_start);
assert_eq!(detection.end, expected_end);
assert_eq!(&text[detection.start..detection.end], email);
}
}
#[test]
#[ignore]
fn perf_smoke() {
    // Opt-in performance smoke test: set PII_PERF_MAX_MS to a positive
    // millisecond budget to enable it. Unset, zero, or unparseable values
    // all mean "skip".
    let budget_ms: u128 = std::env::var("PII_PERF_MAX_MS")
        .ok()
        .and_then(|raw| raw.parse().ok())
        .unwrap_or(0);
    if budget_ms == 0 {
        return;
    }
    // ~260 KB of repeated text containing one email per repetition.
    let input = "Contact john@example.com. ".repeat(10_000);
    let analyzer = Analyzer::new(
        Box::new(SimpleNlpEngine::default()),
        default_recognizers(),
        Vec::new(),
        PolicyConfig::default(),
    );
    let clock = std::time::Instant::now();
    let _ = analyzer.analyze(&input, &Language::from("en")).unwrap();
    let took_ms = clock.elapsed().as_millis();
    assert!(took_ms <= budget_ms, "perf regression: {} ms", took_ms);
}
/// Maps a fixture's string entity label to the library's `EntityType`.
/// Any label not in the known set becomes `EntityType::Custom` carrying the
/// raw string, so fixtures can exercise user-defined recognizers.
fn parse_entity(name: &str) -> EntityType {
    match name {
        // Contact / network identifiers.
        "Email" => EntityType::Email,
        "Phone" => EntityType::Phone,
        "IpAddress" => EntityType::IpAddress,
        "Ipv6" => EntityType::Ipv6,
        "MacAddress" => EntityType::MacAddress,
        "Url" => EntityType::Url,
        "Domain" => EntityType::Domain,
        "Hostname" => EntityType::Hostname,
        // Financial identifiers.
        "CreditCard" => EntityType::CreditCard,
        "Iban" => EntityType::Iban,
        "BankAccount" => EntityType::BankAccount,
        "RoutingNumber" => EntityType::RoutingNumber,
        "CryptoAddress" => EntityType::CryptoAddress,
        // Government-issued identifiers.
        "Ssn" => EntityType::Ssn,
        "Itin" => EntityType::Itin,
        "TaxId" => EntityType::TaxId,
        "Passport" => EntityType::Passport,
        "DriverLicense" => EntityType::DriverLicense,
        // Device / asset identifiers.
        "Uuid" => EntityType::Uuid,
        "Vin" => EntityType::Vin,
        "Imei" => EntityType::Imei,
        // NER-style entities.
        "Person" => EntityType::Person,
        "Location" => EntityType::Location,
        "Organization" => EntityType::Organization,
        other => EntityType::Custom(other.to_string()),
    }
}
/// Per-category aggregate of fixture runs: how many ran, total wall time in
/// milliseconds, and the pass/fail split. Built up by `test_fixtures` and
/// rendered by `report_stats`.
#[derive(Default)]
struct CategoryStats {
    count: usize,
    total_ms: u128,
    passed: usize,
    failed: usize,
}
impl CategoryStats {
    /// Folds one fixture run (its duration and outcome) into the aggregate.
    fn record(&mut self, duration: std::time::Duration, passed: bool) {
        self.count += 1;
        self.total_ms += duration.as_millis();
        // Pick the outcome bucket once, then bump it.
        let bucket = if passed {
            &mut self.passed
        } else {
            &mut self.failed
        };
        *bucket += 1;
    }
}
/// Prints a per-category timing/pass summary to stderr, sorted by category
/// name so the output is deterministic. No-op when there are no stats.
fn report_stats(stats: &HashMap<String, CategoryStats>) {
    if stats.is_empty() {
        return;
    }
    // Sort (key, stat) pairs once instead of sorting the keys and then
    // re-looking each one up with `get` (the old `if let Some` around that
    // lookup was redundant: the keys came from this same map).
    let mut entries: Vec<_> = stats.iter().collect();
    entries.sort_by(|(a, _), (b, _)| a.cmp(b));
    eprintln!("fixture timing summary:");
    for (key, stat) in entries {
        // `record` always bumps `count`, but keep the guard so an empty
        // entry can never cause a divide-by-zero.
        let avg = if stat.count == 0 {
            0
        } else {
            stat.total_ms / stat.count as u128
        };
        eprintln!(
            " category={} count={} passed={} failed={} total_ms={} avg_ms={}",
            key, stat.count, stat.passed, stat.failed, stat.total_ms, avg
        );
    }
}