tiktag 0.1.3

Rust library and CLI for multilingual text anonymization with a built-in ONNX NER model.
Documentation
// Internal model config loading for the built-in multilingual DistilBERT NER model.
// The repo keeps the model path and token limits in models/profiles.toml
// as a flat, single-model TOML file (no selector, no nested map).

use std::fs;
use std::path::{Path, PathBuf};

use serde::Deserialize;

use crate::error::TiktagError;

/// Logical name of the built-in bundled profile.
///
/// Reported as [`ResolvedProfile::name`] and embedded in the validation error
/// messages produced while checking the profile file.
pub const BUILTIN_PROFILE_NAME: &str = "distilbert_ner_hrl";

/// The built-in model config after validation and path resolution.
///
/// Produced by [`Profiles::resolve_default`]. Every field is public, so the
/// guarantees noted on the fields hold only for values obtained that way.
#[derive(Debug, Clone)]
pub struct ResolvedProfile {
    /// Logical profile name used by the library and CLI output.
    ///
    /// [`Profiles::resolve_default`] always sets this to [`BUILTIN_PROFILE_NAME`].
    pub name: String,
    /// Hugging Face repo identifier for the bundled model.
    ///
    /// Validation rejects a value that is empty or whitespace-only.
    pub hf_repo: String,
    /// Resolved filesystem path to the model directory.
    ///
    /// An absolute `model_dir` from the profile file is kept as-is; a relative
    /// one is joined onto the profile's base directory (the parent of the
    /// profile file when loaded through [`Profiles::load`]). The result is
    /// therefore relative whenever that base directory is relative.
    pub model_dir: PathBuf,
    /// Maximum token count allowed by the profile.
    ///
    /// Validation rejects zero; the budget includes the two slots reserved for
    /// `[CLS]` and `[SEP]`.
    pub max_tokens: usize,
    /// Overlap between adjacent sliding windows.
    ///
    /// Validated to be strictly less than `max_tokens - 2`.
    pub overlap_tokens: usize,
    /// Whether the built-in email recognizer is enabled.
    ///
    /// `true` unless the profile file sets `recognizers.email = false`.
    pub email_recognizer: bool,
}

/// Parsed built-in profile configuration before resolution.
///
/// This type is mainly useful for advanced callers that want to validate or
/// inspect the bundled profile settings before constructing [`crate::Tiktag`].
///
/// Both fields are private; within this file values are only built by the
/// validation step, so the contained settings have already been checked.
#[derive(Debug)]
pub struct Profiles {
    /// Directory against which a relative `model_dir` is resolved.
    base_dir: PathBuf,
    /// Validated settings of the single built-in profile.
    profile: ProfileSpec,
}

/// Validated profile settings, stored as written in the profile file.
///
/// `model_dir` is still unresolved here and may be relative; it is combined
/// with the base directory only when the profile is resolved.
#[derive(Debug, Clone)]
struct ProfileSpec {
    /// Hugging Face repo identifier (checked to be non-blank).
    hf_repo: String,
    /// Model directory exactly as configured, absolute or relative.
    model_dir: PathBuf,
    /// Token budget per window (checked to be non-zero).
    max_tokens: usize,
    /// Sliding-window overlap (checked to be below `max_tokens - 2`).
    overlap_tokens: usize,
    /// Email recognizer toggle taken from the `[recognizers]` table.
    email_recognizer: bool,
}

/// Serde mirror of the profile TOML file, before semantic validation.
///
/// `deny_unknown_fields` turns unexpected top-level keys into a parse error
/// instead of silently ignoring them.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct ProfileFileRaw {
    /// Required: Hugging Face repo identifier.
    hf_repo: String,
    /// Required: model directory, absolute or relative to the profile file.
    model_dir: PathBuf,
    /// Required: maximum token count per window.
    max_tokens: usize,
    /// Required: overlap between adjacent sliding windows.
    overlap_tokens: usize,
    /// Optional `[recognizers]` table; `RecognizersRaw::default()` is used
    /// when the table is omitted.
    #[serde(default)]
    recognizers: RecognizersRaw,
}

/// Serde mirror of the optional `[recognizers]` table in the profile file.
///
/// Unknown keys inside the table are rejected via `deny_unknown_fields`.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct RecognizersRaw {
    /// Toggle for the email recognizer; `true` when the key is omitted.
    #[serde(default = "default_true")]
    email: bool,
}

impl Default for RecognizersRaw {
    /// Settings used when the profile file has no `[recognizers]` table.
    ///
    /// The email recognizer is switched on, matching the per-field serde
    /// default applied when the table exists but omits `email`.
    fn default() -> Self {
        let email = true;
        Self { email }
    }
}

/// Serde default hook for flags that stay enabled unless explicitly disabled.
///
/// `#[serde(default = "...")]` takes a function path rather than a literal,
/// hence this tiny helper.
fn default_true() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}

impl Profiles {
    /// Loads and validates the built-in profile file from disk.
    ///
    /// The parent directory of `path` is remembered as the base directory
    /// against which a relative `model_dir` is later resolved.
    ///
    /// # Errors
    ///
    /// * [`TiktagError::ProfileRead`] when the file cannot be read.
    /// * [`TiktagError::ProfileParse`] when the text does not deserialize into
    ///   the expected TOML schema (including unknown or missing keys).
    /// * [`TiktagError::ProfileInvalid`] when the parsed values fail semantic
    ///   validation.
    pub fn load(path: &Path) -> Result<Self, TiktagError> {
        let raw_text = fs::read_to_string(path).map_err(|source| TiktagError::ProfileRead {
            path: path.to_path_buf(),
            source,
        })?;
        // A bare file name has an empty parent, which `Path::join` treats like
        // the current directory; the "." fallback only covers paths that have
        // no parent at all.
        let base_dir = path
            .parent()
            .map(Path::to_path_buf)
            .unwrap_or_else(|| PathBuf::from("."));

        // Parse errors get a dedicated typed variant so callers can distinguish
        // TOML decoding failures from semantic validation.
        let raw: ProfileFileRaw =
            toml::from_str(&raw_text).map_err(|source| TiktagError::ProfileParse {
                path: path.to_path_buf(),
                source,
            })?;

        Self::validate_raw(&base_dir, raw).map_err(TiktagError::ProfileInvalid)
    }

    /// Resolves the single built-in profile into runtime settings.
    ///
    /// An absolute `model_dir` is returned unchanged; a relative one is joined
    /// onto the base directory. The resulting path is therefore absolute only
    /// if `model_dir` or the base directory is absolute.
    pub fn resolve_default(&self) -> ResolvedProfile {
        ResolvedProfile {
            name: BUILTIN_PROFILE_NAME.to_owned(),
            hf_repo: self.profile.hf_repo.clone(),
            model_dir: resolve_profile_model_dir(&self.base_dir, &self.profile.model_dir),
            max_tokens: self.profile.max_tokens,
            overlap_tokens: self.profile.overlap_tokens,
            email_recognizer: self.profile.email_recognizer,
        }
    }

    /// Applies semantic checks to a freshly parsed profile.
    ///
    /// Returns the validated [`Profiles`] anchored at `base_dir`, or a
    /// human-readable message naming the offending setting:
    ///
    /// * `hf_repo` must not be empty or whitespace-only;
    /// * `max_tokens` must exceed the two slots reserved for `[CLS]` and `[SEP]`;
    /// * `overlap_tokens` must be strictly less than `max_tokens - 2`.
    fn validate_raw(base_dir: &Path, raw: ProfileFileRaw) -> Result<Self, String> {
        // Slots taken by [CLS] and [SEP] in every window.
        const RESERVED_TOKENS: usize = 2;

        if raw.hf_repo.trim().is_empty() {
            return Err(format!(
                "profile '{BUILTIN_PROFILE_NAME}' has empty hf_repo"
            ));
        }
        if raw.max_tokens == 0 {
            return Err(format!(
                "profile '{BUILTIN_PROFILE_NAME}' has invalid max_tokens=0"
            ));
        }
        // With only the reserved slots available there is no content budget,
        // so no overlap value could ever be valid. Report the real culprit
        // instead of letting the overlap check blame `overlap_tokens`.
        if raw.max_tokens <= RESERVED_TOKENS {
            return Err(format!(
                "profile '{BUILTIN_PROFILE_NAME}' has invalid max_tokens={}; it must be greater than 2 to leave room for [CLS] and [SEP]",
                raw.max_tokens
            ));
        }
        // Reserve 2 tokens for [CLS] and [SEP]; overlap must fit inside the
        // remaining content budget or sliding-window stride makes no progress.
        // The subtraction cannot underflow thanks to the guard above.
        let content_tokens = raw.max_tokens - RESERVED_TOKENS;
        if raw.overlap_tokens >= content_tokens {
            return Err(format!(
                "profile '{BUILTIN_PROFILE_NAME}' has overlap_tokens={} which must be less than max_tokens - 2 ({})",
                raw.overlap_tokens, content_tokens
            ));
        }

        Ok(Self {
            base_dir: base_dir.to_path_buf(),
            profile: ProfileSpec {
                hf_repo: raw.hf_repo,
                model_dir: raw.model_dir,
                max_tokens: raw.max_tokens,
                overlap_tokens: raw.overlap_tokens,
                email_recognizer: raw.recognizers.email,
            },
        })
    }

    /// Test-only constructor that parses and validates TOML text directly,
    /// skipping the filesystem read performed by [`Profiles::load`].
    #[cfg(test)]
    fn from_raw(base_dir: &Path, raw_text: &str) -> anyhow::Result<Self> {
        let raw: ProfileFileRaw = toml::from_str(raw_text)?;
        Self::validate_raw(base_dir, raw).map_err(|msg| anyhow::anyhow!(msg))
    }
}

/// Turns the configured `model_dir` into the path used at runtime.
///
/// A relative `model_dir` is anchored at `base_dir`; an absolute one is
/// returned untouched.
fn resolve_profile_model_dir(base_dir: &Path, model_dir: &Path) -> PathBuf {
    if !model_dir.is_absolute() {
        return base_dir.join(model_dir);
    }
    model_dir.to_path_buf()
}

// Unit tests for profile parsing and validation. They go through
// `Profiles::from_raw`, so no profile file is read from disk.
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::{BUILTIN_PROFILE_NAME, Profiles};

    // A minimal valid profile parses, its relative `model_dir` is joined onto
    // the base directory, and the email recognizer is enabled when no
    // `[recognizers]` table is present.
    #[test]
    fn loads_and_resolves_builtin_profile() {
        let profiles = Profiles::from_raw(
            &PathBuf::from("models"),
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128
"#,
        )
        .expect("profiles should parse");

        let resolved = profiles.resolve_default();
        assert_eq!(resolved.name, BUILTIN_PROFILE_NAME);
        assert_eq!(
            resolved.hf_repo,
            "Xenova/distilbert-base-multilingual-cased-ner-hrl"
        );
        assert_eq!(
            resolved.model_dir,
            PathBuf::from("models/distilbert-base-multilingual-cased-ner-hrl")
        );
        assert_eq!(resolved.max_tokens, 512);
        assert_eq!(resolved.overlap_tokens, 128);
        assert!(resolved.email_recognizer);
    }

    // `max_tokens = 0` is rejected with its own message rather than falling
    // through to the overlap check.
    #[test]
    fn rejects_zero_max_tokens() {
        let err = Profiles::from_raw(
            &PathBuf::from("models"),
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 0
overlap_tokens = 0
"#,
        )
        .expect_err("zero max_tokens should fail");

        assert!(err.to_string().contains("invalid max_tokens=0"));
    }

    // A top-level key outside the schema is a parse error because the raw
    // struct is declared with `deny_unknown_fields`.
    #[test]
    fn rejects_unknown_fields() {
        let err = Profiles::from_raw(
            &PathBuf::from("models"),
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128
extra_key = "nope"
"#,
        )
        .expect_err("unknown fields should fail");

        assert!(err.to_string().contains("unknown field"));
    }

    // `hf_repo` has no serde default, so omitting it fails during parsing.
    #[test]
    fn rejects_missing_hf_repo() {
        let err = Profiles::from_raw(
            &PathBuf::from("models"),
            r#"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128
"#,
        )
        .expect_err("missing hf_repo should fail");

        assert!(err.to_string().contains("missing field `hf_repo`"));
    }

    // Boundary case: 510 equals `max_tokens - 2`, the smallest overlap that
    // must be rejected for a 512-token window.
    #[test]
    fn rejects_overlap_exceeding_limit() {
        let err = Profiles::from_raw(
            &PathBuf::from("models"),
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 510
"#,
        )
        .expect_err("overlap >= max_tokens - 2 should fail");

        assert!(err.to_string().contains("overlap_tokens=510"));
    }

    // An explicit `email = false` in `[recognizers]` overrides the default.
    #[test]
    fn recognizers_email_can_be_disabled() {
        let profiles = Profiles::from_raw(
            &PathBuf::from("models"),
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128

[recognizers]
email = false
"#,
        )
        .expect("profiles should parse");

        let resolved = profiles.resolve_default();
        assert!(!resolved.email_recognizer);
    }

    // Unknown keys are also rejected inside the nested `[recognizers]` table.
    #[test]
    fn rejects_unknown_recognizer_fields() {
        let err = Profiles::from_raw(
            &PathBuf::from("models"),
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128

[recognizers]
email = true
unknown_key = true
"#,
        )
        .expect_err("unknown recognizer field should fail");

        assert!(err.to_string().contains("unknown field"));
    }
}