//! pii 0.1.0
//!
//! PII detection and anonymization with deterministic, capability-aware NLP pipelines.
//!
//! Documentation: see the crate-level docs for API details; this file holds the
//! integration tests.
use pii::analyzer::Analyzer;
use pii::config::PolicyConfig;
use pii::context::{ContextEnhancer, LemmaContextEnhancer};
use pii::nlp::SimpleNlpEngine;
use pii::presets::default_recognizers;
use pii::recognizers::ner::NerRecognizer;
use pii::recognizers::Recognizer;
use pii::types::{Detection, EntityType, Language, Token};
use pii::{Capabilities, NlpArtifacts};
use proptest::prelude::*;
use serde::Deserialize;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;

/// One JSON test fixture: an input text plus the exact detections the
/// analyzer is expected to produce for it.
#[derive(Debug, Deserialize)]
struct Fixture {
    // Human-readable label used in failure messages; "unnamed" when absent.
    name: Option<String>,
    // Grouping key for the timing summary; "uncategorized" when absent.
    category: Option<String>,
    // Language tag handed to the analyzer (e.g. "en").
    language: String,
    // The text to analyze.
    text: String,
    // Every detection the analyzer must report — count and contents are
    // compared exactly.
    expected: Vec<ExpectedDetection>,
    // Optional list of entity-type names; when present it replaces the
    // policy's default enabled set (names mapped via `parse_entity`).
    entities: Option<Vec<String>>,
}

/// A single detection a fixture expects the analyzer to emit.
#[derive(Debug, Deserialize)]
struct ExpectedDetection {
    // Entity-type name, mapped to `EntityType` by `parse_entity`.
    entity_type: String,
    // Start offset of the expected span; compared exactly against
    // `Detection::start`.
    start: usize,
    // End offset of the expected span; compared exactly against
    // `Detection::end`.
    end: usize,
    // Name of the recognizer that must have produced the detection.
    recognizer: String,
}

/// Runs every `*.json` fixture under `tests/fixtures`, comparing the
/// analyzer's detections against each fixture's expected list and recording
/// per-category timing. All mismatches are collected first and reported in a
/// single panic at the end, so one failing fixture does not hide the rest.
#[test]
fn test_fixtures() {
    let base = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures");
    let entries = fs::read_dir(&base).expect("fixtures directory missing");
    let mut stats: HashMap<String, CategoryStats> = HashMap::new();
    let mut failures: Vec<String> = Vec::new();
    for entry in entries {
        let entry = entry.expect("fixture read failed");
        // Ignore anything that is not a fixture file (READMEs, editor backups).
        if entry.path().extension().and_then(|s| s.to_str()) != Some("json") {
            continue;
        }
        let content = fs::read_to_string(entry.path()).expect("fixture read failed");
        let fixture: Fixture = serde_json::from_str(&content).expect("invalid fixture json");
        // Resolve the display name once instead of re-cloning per failure.
        let fixture_name = fixture
            .name
            .clone()
            .unwrap_or_else(|| "unnamed".to_string());
        let mut policy = PolicyConfig::default();
        // A fixture may restrict the enabled entity types; otherwise the
        // policy default applies.
        if let Some(entities) = fixture.entities.as_ref() {
            policy.enabled_entities = entities.iter().map(|name| parse_entity(name)).collect();
        }
        let analyzer = Analyzer::new(
            Box::new(SimpleNlpEngine::default()),
            default_recognizers(),
            Vec::new(),
            policy,
        );
        let category = fixture
            .category
            .clone()
            .unwrap_or_else(|| "uncategorized".to_string());
        let start = std::time::Instant::now();
        let result = analyzer
            .analyze(&fixture.text, &Language::from(fixture.language.as_str()))
            .expect("analyze failed");
        let elapsed = start.elapsed();
        let mut passed = true;

        // The detection count must match exactly, so spurious extra
        // detections fail the fixture too.
        if result.entities.len() != fixture.expected.len() {
            failures.push(format!(
                "fixture {} expected {} detections, got {}",
                fixture_name,
                fixture.expected.len(),
                result.entities.len()
            ));
            passed = false;
        }

        // Every expected detection must be present with identical type,
        // span, and originating recognizer.
        for expected in fixture.expected {
            let entity_type = parse_entity(&expected.entity_type);
            let found = result.entities.iter().find(|det| {
                det.entity_type == entity_type
                    && det.start == expected.start
                    && det.end == expected.end
                    && det.recognizer == expected.recognizer
            });
            if found.is_none() {
                failures.push(format!(
                    "fixture {} expected {:?} not found in {:?}",
                    fixture_name, expected, result.entities
                ));
                passed = false;
            }
        }

        stats
            .entry(category.clone())
            .or_default()
            .record(elapsed, passed);
    }

    report_stats(&stats);
    if !failures.is_empty() {
        // Real newlines (was the two-character sequence `\\n`, which printed
        // every failure on a single line) so each failure reads on its own line.
        panic!("fixture failures:\n{}", failures.join("\n"));
    }
}

#[test]
fn test_lemma_context_enhancement() {
    // NLP artifacts for the text "running report": lemmas are reported as
    // available, so the enhancer should match context terms against lemmas.
    let artifacts = NlpArtifacts {
        language: Language::from("en"),
        text_len: 14,
        tokens: vec![
            Token {
                text: "running".to_string(),
                start: 0,
                end: 7,
                lemma: Some("run".to_string()),
                pos: None,
            },
            Token {
                text: "report".to_string(),
                start: 8,
                end: 14,
                lemma: Some("report".to_string()),
                pos: None,
            },
        ],
        sentences: vec![(0, 14)],
        ner: Vec::new(),
        capabilities: Capabilities {
            token_offsets: true,
            lemma: true,
            pos: false,
            ner: false,
            sentences: true,
        },
    };

    // Context rule: the term "run" within two tokens of an Email detection
    // adds a 0.2 boost.
    let email_terms = pii::ContextTerms {
        window_tokens: 2,
        boost: 0.2,
        terms: vec!["run".to_string()],
    };
    let enhancer = LemmaContextEnhancer::new(HashMap::from([(EntityType::Email, email_terms)]));

    // A mid-confidence Email detection covering the "report" token.
    let mut detections = vec![Detection {
        entity_type: EntityType::Email,
        start: 8,
        end: 14,
        score: 0.5,
        recognizer: "test".to_string(),
        explanation: pii::types::DetectionExplanation::Regex {
            pattern_name: "email".to_string(),
        },
    }];
    enhancer.enhance(&mut detections, "running report", &artifacts);

    // The lemma "run" of the preceding token must have raised the score.
    assert!(detections[0].score > 0.5);
}

#[test]
fn test_context_fallback_surface_terms() {
    // No lemma data is available (lemma capability off, token lemma None),
    // so the enhancer must fall back to matching raw token surface forms.
    let artifacts = NlpArtifacts {
        language: Language::from("en"),
        text_len: 7,
        tokens: vec![Token {
            text: "running".to_string(),
            start: 0,
            end: 7,
            lemma: None,
            pos: None,
        }],
        sentences: Vec::new(),
        ner: Vec::new(),
        capabilities: Capabilities {
            token_offsets: true,
            lemma: false,
            pos: false,
            ner: false,
            sentences: false,
        },
    };

    // Context rule keyed on the surface form "running".
    let email_terms = pii::ContextTerms {
        window_tokens: 2,
        boost: 0.2,
        terms: vec!["running".to_string()],
    };
    let enhancer = LemmaContextEnhancer::new(HashMap::from([(EntityType::Email, email_terms)]));

    let mut detections = vec![Detection {
        entity_type: EntityType::Email,
        start: 0,
        end: 7,
        score: 0.5,
        recognizer: "test".to_string(),
        explanation: pii::types::DetectionExplanation::Regex {
            pattern_name: "email".to_string(),
        },
    }];
    enhancer.enhance(&mut detections, "running", &artifacts);

    // The surface-form match must still boost the score above baseline.
    assert!(detections[0].score > 0.5);
}

#[test]
fn test_ner_disabled_disables_ner_recognizer() {
    // `Capabilities::basic()` artifacts carry no NER results, so the
    // NER-backed recognizer must yield nothing even for name-like text.
    let recognizer = NerRecognizer::new("ner", vec![(EntityType::Person, EntityType::Person)]);
    let artifacts = NlpArtifacts {
        language: Language::from("en"),
        text_len: 4,
        tokens: Vec::new(),
        sentences: Vec::new(),
        ner: Vec::new(),
        capabilities: Capabilities::basic(),
    };

    assert!(recognizer.analyze("John", &artifacts).is_empty());
}

// Property test: for a fixed email embedded between generated lowercase/space
// padding, the Email detection's offsets must exactly frame the address.
proptest! {
    // Disable failure persistence so no regression files are written to disk.
    #![proptest_config(ProptestConfig { failure_persistence: None, .. ProptestConfig::default() })]
    #[test]
    fn prop_email_offsets(prefix in "[a-z ]{0,40}", suffix in "[ ]{0,40}") {
        // Keep the address detached from any preceding word: require an empty
        // prefix or one ending in a space.
        prop_assume!(prefix.is_empty() || prefix.ends_with(' '));
        let email = "user@example.com";
        let text = format!("{}{}{}", prefix, email, suffix);
        let analyzer = Analyzer::new(
            Box::new(SimpleNlpEngine::default()),
            default_recognizers(),
            Vec::new(),
            PolicyConfig::default(),
        );
        let result = analyzer.analyze(&text, &Language::from("en")).unwrap();
        let detection = result.entities.iter().find(|det| det.entity_type == EntityType::Email).unwrap();
        // The span must be byte-accurate: slicing the text with it yields
        // exactly the embedded address (both padding classes are ASCII, so
        // byte length equals the prefix's char count here).
        let expected_start = prefix.len();
        let expected_end = expected_start + email.len();
        assert_eq!(detection.start, expected_start);
        assert_eq!(detection.end, expected_end);
        assert_eq!(&text[detection.start..detection.end], email);
    }
}

/// Opt-in performance smoke test: set `PII_PERF_MAX_MS` to a positive
/// millisecond budget to enable it. `#[ignore]` keeps it out of normal runs.
#[test]
#[ignore]
fn perf_smoke() {
    // An unset, unparsable, or zero budget disables the check entirely.
    let budget_ms = std::env::var("PII_PERF_MAX_MS")
        .ok()
        .and_then(|raw| raw.parse::<u128>().ok())
        .unwrap_or(0);
    if budget_ms == 0 {
        return;
    }

    // 10k copies of a short sentence containing one email address.
    let text = "Contact john@example.com. ".repeat(10_000);
    let analyzer = Analyzer::new(
        Box::new(SimpleNlpEngine::default()),
        default_recognizers(),
        Vec::new(),
        PolicyConfig::default(),
    );
    let started = std::time::Instant::now();
    let _ = analyzer.analyze(&text, &Language::from("en")).unwrap();
    let elapsed = started.elapsed().as_millis();
    assert!(elapsed <= budget_ms, "perf regression: {} ms", elapsed);
}

/// Maps a fixture's entity-type name to the corresponding `EntityType`
/// variant. Any unrecognized name becomes `EntityType::Custom(name)`, so
/// fixtures can exercise custom recognizers without extending this table.
fn parse_entity(value: &str) -> EntityType {
    match value {
        "Email" => EntityType::Email,
        "Phone" => EntityType::Phone,
        "IpAddress" => EntityType::IpAddress,
        "Ipv6" => EntityType::Ipv6,
        "CreditCard" => EntityType::CreditCard,
        "Iban" => EntityType::Iban,
        "Ssn" => EntityType::Ssn,
        "Itin" => EntityType::Itin,
        "TaxId" => EntityType::TaxId,
        "Passport" => EntityType::Passport,
        "DriverLicense" => EntityType::DriverLicense,
        "BankAccount" => EntityType::BankAccount,
        "RoutingNumber" => EntityType::RoutingNumber,
        "CryptoAddress" => EntityType::CryptoAddress,
        "MacAddress" => EntityType::MacAddress,
        "Uuid" => EntityType::Uuid,
        "Vin" => EntityType::Vin,
        "Imei" => EntityType::Imei,
        "Url" => EntityType::Url,
        "Domain" => EntityType::Domain,
        "Hostname" => EntityType::Hostname,
        "Person" => EntityType::Person,
        "Location" => EntityType::Location,
        "Organization" => EntityType::Organization,
        other => EntityType::Custom(other.to_string()),
    }
}

/// Per-category aggregate for the fixture run: number of fixtures, total
/// wall time in milliseconds, and pass/fail tallies.
#[derive(Default)]
struct CategoryStats {
    count: usize,
    total_ms: u128,
    passed: usize,
    failed: usize,
}

impl CategoryStats {
    /// Folds a single fixture run (its duration and outcome) into the
    /// aggregate counters.
    fn record(&mut self, duration: std::time::Duration, passed: bool) {
        self.count += 1;
        self.total_ms += duration.as_millis();
        // Select the outcome bucket once, then bump it.
        let bucket = if passed { &mut self.passed } else { &mut self.failed };
        *bucket += 1;
    }
}

fn report_stats(stats: &HashMap<String, CategoryStats>) {
    if stats.is_empty() {
        return;
    }
    let mut keys: Vec<_> = stats.keys().collect();
    keys.sort();
    eprintln!("fixture timing summary:");
    for key in keys {
        if let Some(stat) = stats.get(key) {
            let avg = if stat.count == 0 {
                0
            } else {
                stat.total_ms / stat.count as u128
            };
            eprintln!(
                "  category={} count={} passed={} failed={} total_ms={} avg_ms={}",
                key, stat.count, stat.passed, stat.failed, stat.total_ms, avg
            );
        }
    }
}