use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use super::LanguageTag;
use crate::error::Result;
/// Kind of model a model file contains.
///
/// The `Display` impl in this file renders each variant as a lowercase name
/// (`ngram`, `embedding`, `hybrid`).
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum ModelType {
/// N-gram model (metadata carries an n-gram order).
Ngram,
/// Embedding model (metadata carries an embedding dimension).
Embedding,
/// Hybrid model (metadata carries both an n-gram order and an embedding dimension).
Hybrid,
}
impl std::fmt::Display for ModelType {
    /// Writes the lowercase name of the variant: `ngram`, `embedding` or `hybrid`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ModelType::Ngram => "ngram",
            ModelType::Embedding => "embedding",
            ModelType::Hybrid => "hybrid",
        };
        f.write_str(label)
    }
}
/// Descriptive metadata about a trained model, (de)serializable with serde.
///
/// `ModelMetadata::save` and `ModelMetadata::load` store it as JSON in a
/// `.meta.json` file derived from the model path.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ModelMetadata {
/// Language the model is for.
pub language: LanguageTag,
/// Kind of model.
pub model_type: ModelType,
/// Corpus source labels; `with_corpus_source` appends to this list.
pub corpus_sources: Vec<String>,
/// Timestamp in UTC; the constructors in this file set it to the time of construction.
pub trained_at: DateTime<Utc>,
/// Vocabulary size as supplied by the caller.
pub vocab_size: usize,
/// N-gram order; the constructors set it for n-gram and hybrid models only.
pub ngram_order: Option<usize>,
/// Embedding dimension; the constructors set it for embedding and hybrid models only.
pub embedding_dim: Option<usize>,
/// Free-form key/value pairs; an absent field deserializes to an empty map.
#[serde(default)]
pub extra: HashMap<String, String>,
}
impl ModelMetadata {
pub fn ngram(language: LanguageTag, vocab_size: usize, order: usize) -> Self {
Self {
language,
model_type: ModelType::Ngram,
corpus_sources: Vec::new(),
trained_at: Utc::now(),
vocab_size,
ngram_order: Some(order),
embedding_dim: None,
extra: HashMap::new(),
}
}
pub fn embedding(language: LanguageTag, vocab_size: usize, dim: usize) -> Self {
Self {
language,
model_type: ModelType::Embedding,
corpus_sources: Vec::new(),
trained_at: Utc::now(),
vocab_size,
ngram_order: None,
embedding_dim: Some(dim),
extra: HashMap::new(),
}
}
pub fn hybrid(language: LanguageTag, vocab_size: usize, order: usize, dim: usize) -> Self {
Self {
language,
model_type: ModelType::Hybrid,
corpus_sources: Vec::new(),
trained_at: Utc::now(),
vocab_size,
ngram_order: Some(order),
embedding_dim: Some(dim),
extra: HashMap::new(),
}
}
pub fn with_corpus_source(mut self, source: impl Into<String>) -> Self {
self.corpus_sources.push(source.into());
self
}
pub fn with_extra(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
self.extra.insert(key.into(), value.into());
self
}
pub fn save(&self, model_path: &Path) -> Result<()> {
let meta_path = model_path.with_extension("meta.json");
let content = serde_json::to_string_pretty(self)
.map_err(|e| crate::Error::SerializationMessage(e.to_string()))?;
fs::write(&meta_path, content)?;
log::debug!("Saved metadata to {:?}", meta_path);
Ok(())
}
pub fn load(model_path: &Path) -> Result<Option<Self>> {
let meta_path = model_path.with_extension("meta.json");
if !meta_path.exists() {
return Ok(None);
}
let content = fs::read_to_string(&meta_path)?;
let meta = serde_json::from_str(&content)
.map_err(|e| crate::Error::SerializationMessage(e.to_string()))?;
Ok(Some(meta))
}
}
/// One model file discovered by `ModelRegistry::scan`.
#[derive(Clone, Debug)]
pub struct ModelEntry {
/// Path of the model file.
pub path: PathBuf,
/// Language from the sidecar metadata when present, otherwise derived from the directory names.
pub language: LanguageTag,
/// Model type from the sidecar metadata when present, otherwise guessed from the file stem.
pub model_type: ModelType,
/// Size of the model file in bytes, as reported by the filesystem.
pub size_bytes: u64,
/// Sidecar metadata, if a sidecar was found and could be parsed.
pub metadata: Option<ModelMetadata>,
}
/// In-memory index of the model files found under a root directory.
#[derive(Debug)]
pub struct ModelRegistry {
/// Directory that was passed to `scan`.
root: PathBuf,
/// Discovered models, keyed by the name of the top-level language directory they were found in.
index: HashMap<String, Vec<ModelEntry>>,
}
impl ModelRegistry {
pub fn scan(root: &Path) -> Result<Self> {
let mut index: HashMap<String, Vec<ModelEntry>> = HashMap::new();
if !root.exists() {
return Ok(Self {
root: root.to_path_buf(),
index,
});
}
for lang_entry in fs::read_dir(root)? {
let lang_entry = lang_entry?;
if !lang_entry.file_type()?.is_dir() {
continue;
}
let lang_name = lang_entry.file_name().to_string_lossy().to_string();
for dialect_entry in fs::read_dir(lang_entry.path())? {
let dialect_entry = dialect_entry?;
let dialect_path = dialect_entry.path();
if dialect_entry.file_type()?.is_dir() {
Self::scan_model_files(&mut index, &dialect_path, &lang_name)?;
} else if dialect_path.extension().map_or(false, |e| e == "bin") {
Self::add_model_file(&mut index, &dialect_path, &lang_name, None)?;
}
}
}
Ok(Self {
root: root.to_path_buf(),
index,
})
}
fn scan_model_files(
index: &mut HashMap<String, Vec<ModelEntry>>,
dir: &Path,
lang: &str,
) -> Result<()> {
let dialect = dir.file_name().map(|n| n.to_string_lossy().to_string());
for entry in fs::read_dir(dir)? {
let entry = entry?;
let path = entry.path();
if path.extension().map_or(false, |e| e == "bin") {
Self::add_model_file(index, &path, lang, dialect.as_deref())?;
}
}
Ok(())
}
fn load_metadata_from_model(path: &Path) -> Option<ModelMetadata> {
let sidecar_path = path.with_extension("bin.meta.json");
if sidecar_path.exists() {
if let Ok(content) = fs::read_to_string(&sidecar_path) {
if let Ok(meta) = serde_json::from_str(&content) {
log::debug!("Loaded metadata from {:?}", sidecar_path);
return Some(meta);
}
}
}
let meta_path = path.with_extension("meta.json");
if meta_path.exists() {
if let Ok(content) = fs::read_to_string(&meta_path) {
if let Ok(meta) = serde_json::from_str(&content) {
log::debug!("Loaded metadata from {:?}", meta_path);
return Some(meta);
}
}
}
if let Some(stem) = path.file_stem() {
let parent = path.parent().unwrap_or(Path::new("."));
let stem_meta = parent.join(format!("{}.meta.json", stem.to_string_lossy()));
if stem_meta.exists() {
if let Ok(content) = fs::read_to_string(&stem_meta) {
if let Ok(meta) = serde_json::from_str(&content) {
log::debug!("Loaded metadata from {:?}", stem_meta);
return Some(meta);
}
}
}
}
None
}
fn add_model_file(
index: &mut HashMap<String, Vec<ModelEntry>>,
path: &Path,
lang: &str,
dialect: Option<&str>,
) -> Result<()> {
let file_metadata = fs::metadata(path)?;
let size_bytes = file_metadata.len();
let model_metadata = Self::load_metadata_from_model(path);
let model_type = model_metadata
.as_ref()
.map(|m| m.model_type.clone())
.unwrap_or_else(|| {
path.file_stem()
.and_then(|s| s.to_str())
.map(|s| {
if s.contains("hybrid") {
ModelType::Hybrid
} else if s.contains("embedding") || s.contains("embed") {
ModelType::Embedding
} else {
ModelType::Ngram
}
})
.unwrap_or(ModelType::Ngram)
});
let language = model_metadata
.as_ref()
.map(|m| m.language.clone())
.unwrap_or_else(|| {
if let Some(d) = dialect {
d.parse().unwrap_or_else(|_| LanguageTag::new(lang))
} else {
LanguageTag::new(lang)
}
});
let entry = ModelEntry {
path: path.to_path_buf(),
language: language.clone(),
model_type,
size_bytes,
metadata: model_metadata,
};
index.entry(lang.to_string()).or_default().push(entry);
Ok(())
}
pub fn root(&self) -> &Path {
&self.root
}
pub fn find(&self, lang: &LanguageTag) -> Vec<&ModelEntry> {
self.index
.get(lang.language())
.map(|entries| entries.iter().filter(|e| e.language == *lang).collect())
.unwrap_or_default()
}
pub fn find_best_match(&self, lang: &LanguageTag) -> Option<&ModelEntry> {
let exact = self.find(lang);
if !exact.is_empty() {
return exact
.iter()
.find(|e| e.model_type == ModelType::Hybrid)
.or_else(|| exact.iter().find(|e| e.model_type == ModelType::Ngram))
.or_else(|| exact.first())
.copied();
}
let base = lang.base();
if base != *lang {
return self.find_best_match(&base);
}
self.index.get(lang.language()).and_then(|entries| {
entries
.iter()
.find(|e| e.model_type == ModelType::Hybrid)
.or_else(|| entries.iter().find(|e| e.model_type == ModelType::Ngram))
.or_else(|| entries.first())
})
}
pub fn languages(&self) -> Vec<&str> {
self.index.keys().map(String::as_str).collect()
}
pub fn all_models(&self) -> Vec<&ModelEntry> {
self.index.values().flat_map(|v| v.iter()).collect()
}
pub fn count(&self) -> usize {
self.index.values().map(|v| v.len()).sum()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates `dir` (with parents) and writes a dummy model file for each name, in order.
    fn write_models(dir: &Path, names: &[&str]) {
        fs::create_dir_all(dir).unwrap();
        for name in names {
            fs::write(dir.join(name), b"test").unwrap();
        }
    }

    // The n-gram constructor records the order and leaves the embedding dimension unset.
    #[test]
    fn test_model_metadata_ngram() {
        let ModelMetadata {
            model_type,
            ngram_order,
            embedding_dim,
            ..
        } = ModelMetadata::ngram(LanguageTag::new("en"), 10000, 5);
        assert_eq!(model_type, ModelType::Ngram);
        assert_eq!(ngram_order, Some(5));
        assert_eq!(embedding_dim, None);
    }

    // The hybrid constructor records both the order and the embedding dimension.
    #[test]
    fn test_model_metadata_hybrid() {
        let ModelMetadata {
            model_type,
            ngram_order,
            embedding_dim,
            ..
        } = ModelMetadata::hybrid(LanguageTag::new("en"), 10000, 5, 100);
        assert_eq!(model_type, ModelType::Hybrid);
        assert_eq!(ngram_order, Some(5));
        assert_eq!(embedding_dim, Some(100));
    }

    // Scanning an empty directory yields an empty registry.
    #[test]
    fn test_empty_registry() {
        let workdir = TempDir::new().unwrap();
        let registry = ModelRegistry::scan(workdir.path()).unwrap();
        assert_eq!(registry.count(), 0);
        assert!(registry.languages().is_empty());
    }

    // Models in `<lang>/<dialect>/` directories are found and filed per language.
    #[test]
    fn test_registry_scan() {
        let workdir = TempDir::new().unwrap();
        let root = workdir.path();
        write_models(&root.join("en").join("en-US"), &["ngram.bin", "hybrid.bin"]);
        write_models(&root.join("de").join("de-DE"), &["ngram.bin"]);

        let registry = ModelRegistry::scan(root).unwrap();
        assert_eq!(registry.count(), 3);
        assert_eq!(registry.languages().len(), 2);

        let en_models = registry.find(&LanguageTag::with_region("en", "US"));
        assert_eq!(en_models.len(), 2);
    }
}