use std::fs;
use std::path::{Path, PathBuf};
use serde::Deserialize;
use crate::error::TiktagError;
/// Name attached to the profile returned by `Profiles::resolve_default` and quoted in
/// validation error messages.
pub const BUILTIN_PROFILE_NAME: &str = "distilbert_ner_hrl";
/// A validated profile with its model directory already resolved, as produced by
/// `Profiles::resolve_default`.
#[derive(Debug, Clone)]
pub struct ResolvedProfile {
/// Profile name; `resolve_default` always sets this to `BUILTIN_PROFILE_NAME`.
pub name: String,
/// Value of the profile file's `hf_repo` key (validated to be non-blank).
/// NOTE(review): presumably a Hugging Face repository id — confirm against consumers.
pub hf_repo: String,
/// Model directory: used as-is when the configured path is absolute, otherwise joined
/// onto the directory the profile was loaded from.
pub model_dir: PathBuf,
/// Value of the profile file's `max_tokens` key (validated to be non-zero).
pub max_tokens: usize,
/// Value of the profile file's `overlap_tokens` key (validated to be less than
/// `max_tokens - 2`).
pub overlap_tokens: usize,
/// Value of `recognizers.email` in the profile file; `true` when the key is omitted.
pub email_recognizer: bool,
}
/// A validated profile specification together with the directory used to resolve its
/// relative `model_dir`.
#[derive(Debug)]
pub struct Profiles {
// Directory that a relative `model_dir` is joined onto in `resolve_default`.
base_dir: PathBuf,
// The profile values after validation.
profile: ProfileSpec,
}
/// Validated profile values as stored inside `Profiles`; `model_dir` is kept exactly as
/// written in the profile file and only resolved when a `ResolvedProfile` is built.
#[derive(Debug, Clone)]
struct ProfileSpec {
// Non-blank `hf_repo` value from the profile file.
hf_repo: String,
// Unresolved model directory (may be relative or absolute).
model_dir: PathBuf,
// Non-zero token limit.
max_tokens: usize,
// Overlap, guaranteed by validation to be less than `max_tokens - 2`.
overlap_tokens: usize,
// Copied from `recognizers.email`.
email_recognizer: bool,
}
/// Deserialization target for the profile TOML document, before any validation.
///
/// Unknown top-level keys are rejected (`deny_unknown_fields`). `hf_repo`, `model_dir`,
/// `max_tokens` and `overlap_tokens` are required; the `recognizers` table is optional.
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct ProfileFileRaw {
hf_repo: String,
model_dir: PathBuf,
max_tokens: usize,
overlap_tokens: usize,
// Falls back to `RecognizersRaw::default()` when the table is absent.
#[serde(default)]
recognizers: RecognizersRaw,
}
/// Deserialization target for the optional `[recognizers]` table of the profile file.
///
/// Unknown keys inside the table are rejected (`deny_unknown_fields`).
#[derive(Debug, Deserialize)]
#[serde(deny_unknown_fields)]
struct RecognizersRaw {
// Defaults to `true` (via `default_true`) when the table is present but omits `email`.
#[serde(default = "default_true")]
email: bool,
}
impl Default for RecognizersRaw {
    /// Recognizer settings used when the profile file has no `[recognizers]` table at all:
    /// the email recognizer is switched on.
    fn default() -> Self {
        let email = true;
        Self { email }
    }
}
/// Serde default provider for `RecognizersRaw::email`; always yields `true`.
fn default_true() -> bool {
true
}
impl Profiles {
    /// Loads, parses, and validates the TOML profile file at `path`.
    ///
    /// The parent directory of `path` is remembered as the base that a relative `model_dir`
    /// is joined onto by [`Profiles::resolve_default`]; `.` is used when `path.parent()` is
    /// `None`.
    ///
    /// # Errors
    ///
    /// - `TiktagError::ProfileRead` when the file cannot be read.
    /// - `TiktagError::ProfileParse` when the text is not a valid profile document
    ///   (including unknown or missing keys).
    /// - `TiktagError::ProfileInvalid` when the parsed values fail validation.
    pub fn load(path: &Path) -> Result<Self, TiktagError> {
        let raw_text = fs::read_to_string(path).map_err(|source| TiktagError::ProfileRead {
            path: path.to_path_buf(),
            source,
        })?;
        let base_dir = path
            .parent()
            .map(Path::to_path_buf)
            .unwrap_or_else(|| PathBuf::from("."));
        let raw: ProfileFileRaw =
            toml::from_str(&raw_text).map_err(|source| TiktagError::ProfileParse {
                path: path.to_path_buf(),
                source,
            })?;
        Self::validate_raw(&base_dir, raw).map_err(TiktagError::ProfileInvalid)
    }

    /// Builds the [`ResolvedProfile`] for the loaded profile.
    ///
    /// The name is always [`BUILTIN_PROFILE_NAME`]; `model_dir` is resolved against the base
    /// directory when it is relative and returned unchanged when it is absolute.
    pub fn resolve_default(&self) -> ResolvedProfile {
        ResolvedProfile {
            name: BUILTIN_PROFILE_NAME.to_owned(),
            hf_repo: self.profile.hf_repo.clone(),
            model_dir: resolve_profile_model_dir(&self.base_dir, &self.profile.model_dir),
            max_tokens: self.profile.max_tokens,
            overlap_tokens: self.profile.overlap_tokens,
            email_recognizer: self.profile.email_recognizer,
        }
    }

    /// Checks the semantic constraints that deserialization alone cannot express and, when
    /// they hold, assembles the `Profiles` value.
    ///
    /// Rejected configurations (the `Err` carries a human-readable message):
    /// - `hf_repo` that is empty or whitespace-only;
    /// - `max_tokens` of 0, 1 or 2;
    /// - `overlap_tokens` that is not strictly less than `max_tokens - 2`.
    fn validate_raw(base_dir: &Path, raw: ProfileFileRaw) -> Result<Self, String> {
        if raw.hf_repo.trim().is_empty() {
            return Err(format!(
                "profile '{BUILTIN_PROFILE_NAME}' has empty hf_repo"
            ));
        }
        if raw.max_tokens == 0 {
            return Err(format!(
                "profile '{BUILTIN_PROFILE_NAME}' has invalid max_tokens=0"
            ));
        }
        // Two tokens are subtracted from every budget below, so a limit of 1 or 2 leaves no
        // usable tokens and no `overlap_tokens` value could ever be accepted. Report that as
        // a `max_tokens` problem instead of a confusing overlap error.
        // NOTE(review): the 2 reserved tokens are presumably the model's special tokens —
        // confirm against the tokenizer/windowing code.
        if raw.max_tokens <= 2 {
            return Err(format!(
                "profile '{BUILTIN_PROFILE_NAME}' has invalid max_tokens={} which must be greater than 2",
                raw.max_tokens
            ));
        }
        // Cannot underflow: `max_tokens` is at least 3 here.
        let content_tokens = raw.max_tokens - 2;
        if raw.overlap_tokens >= content_tokens {
            return Err(format!(
                "profile '{BUILTIN_PROFILE_NAME}' has overlap_tokens={} which must be less than max_tokens - 2 ({})",
                raw.overlap_tokens, content_tokens
            ));
        }
        Ok(Self {
            base_dir: base_dir.to_path_buf(),
            profile: ProfileSpec {
                hf_repo: raw.hf_repo,
                model_dir: raw.model_dir,
                max_tokens: raw.max_tokens,
                overlap_tokens: raw.overlap_tokens,
                email_recognizer: raw.recognizers.email,
            },
        })
    }

    /// Test-only constructor: parses `raw_text` as a profile document and validates it
    /// against `base_dir`, without touching the filesystem.
    #[cfg(test)]
    fn from_raw(base_dir: &Path, raw_text: &str) -> anyhow::Result<Self> {
        let raw: ProfileFileRaw = toml::from_str(raw_text)?;
        Self::validate_raw(base_dir, raw).map_err(|msg| anyhow::anyhow!(msg))
    }
}
/// Resolves a profile's `model_dir` against `base_dir`.
///
/// An absolute `model_dir` is returned unchanged; a relative one is appended to `base_dir`.
fn resolve_profile_model_dir(base_dir: &Path, model_dir: &Path) -> PathBuf {
    // `Path::join` discards the receiver when its argument is absolute, so a single call
    // covers both the absolute and the relative case.
    base_dir.join(model_dir)
}
#[cfg(test)]
mod tests {
    use std::path::{Path, PathBuf};

    use super::{BUILTIN_PROFILE_NAME, Profiles};

    /// Parses `raw_toml` as a profile rooted at the relative `models` directory, flattening
    /// any failure to its display text so tests can match on substrings.
    fn parse(raw_toml: &str) -> Result<Profiles, String> {
        Profiles::from_raw(Path::new("models"), raw_toml).map_err(|err| err.to_string())
    }

    /// A complete, valid document resolves to the built-in profile with defaults applied.
    #[test]
    fn loads_and_resolves_builtin_profile() {
        let resolved = parse(
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128
"#,
        )
        .expect("profiles should parse")
        .resolve_default();

        assert_eq!(resolved.name, BUILTIN_PROFILE_NAME);
        assert_eq!(
            resolved.hf_repo,
            "Xenova/distilbert-base-multilingual-cased-ner-hrl"
        );
        assert_eq!(
            resolved.model_dir,
            PathBuf::from("models/distilbert-base-multilingual-cased-ner-hrl")
        );
        assert_eq!(resolved.max_tokens, 512);
        assert_eq!(resolved.overlap_tokens, 128);
        assert!(resolved.email_recognizer);
    }

    /// `max_tokens = 0` is refused with a message naming the offending value.
    #[test]
    fn rejects_zero_max_tokens() {
        let message = parse(
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 0
overlap_tokens = 0
"#,
        )
        .expect_err("zero max_tokens should fail");

        assert!(message.contains("invalid max_tokens=0"));
    }

    /// Unrecognised top-level keys are a parse error.
    #[test]
    fn rejects_unknown_fields() {
        let message = parse(
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128
extra_key = "nope"
"#,
        )
        .expect_err("unknown fields should fail");

        assert!(message.contains("unknown field"));
    }

    /// Omitting the required `hf_repo` key is a parse error.
    #[test]
    fn rejects_missing_hf_repo() {
        let message = parse(
            r#"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128
"#,
        )
        .expect_err("missing hf_repo should fail");

        assert!(message.contains("missing field `hf_repo`"));
    }

    /// An overlap equal to `max_tokens - 2` is already too large.
    #[test]
    fn rejects_overlap_exceeding_limit() {
        let message = parse(
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 510
"#,
        )
        .expect_err("overlap >= max_tokens - 2 should fail");

        assert!(message.contains("overlap_tokens=510"));
    }

    /// `recognizers.email = false` switches the email recognizer off.
    #[test]
    fn recognizers_email_can_be_disabled() {
        let resolved = parse(
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128
[recognizers]
email = false
"#,
        )
        .expect("profiles should parse")
        .resolve_default();

        assert!(!resolved.email_recognizer);
    }

    /// Unrecognised keys inside `[recognizers]` are a parse error as well.
    #[test]
    fn rejects_unknown_recognizer_fields() {
        let message = parse(
            r#"
hf_repo = "Xenova/distilbert-base-multilingual-cased-ner-hrl"
model_dir = "distilbert-base-multilingual-cased-ner-hrl"
max_tokens = 512
overlap_tokens = 128
[recognizers]
email = true
unknown_key = true
"#,
        )
        .expect_err("unknown recognizer field should fail");

        assert!(message.contains("unknown field"));
    }
}