use std::collections::HashMap;
use std::path::Path;
use once_cell::sync::Lazy;
use crate::model_zoo::loader::ModelZooError;
use crate::model_zoo::manifest::ModelManifest;
// Raw text of the built-in model manifests, embedded into the binary at compile
// time by `include_str!` (paths are resolved relative to this source file).
// Each constant is later parsed as TOML into a `ModelManifest` by
// `ModelZoo::from_embedded`.
const TRANSE_FB15K237: &str = include_str!("manifests/transe-fb15k237.toml");
const TRANSE_WN18RR: &str = include_str!("manifests/transe-wn18rr.toml");
const ROTATE_FB15K237: &str = include_str!("manifests/rotate-fb15k237.toml");
const COMPLEX_WN18RR: &str = include_str!("manifests/complex-wn18rr.toml");
const DISTMULT_FB15K237: &str = include_str!("manifests/distmult-fb15k237.toml");
/// Every embedded manifest that `ModelZoo::from_embedded` iterates over.
///
/// An embedded manifest constant only becomes part of the global registry if
/// it is listed here.
const BUILTIN_MANIFESTS: &[&str] = &[
TRANSE_FB15K237,
TRANSE_WN18RR,
ROTATE_FB15K237,
COMPLEX_WN18RR,
DISTMULT_FB15K237,
];
/// An in-memory collection of model manifests that can be looked up by name,
/// listed, and searched.
///
/// Obtain the shared built-in instance with `ModelZoo::registry`, or build one
/// from manifest files on disk with `ModelZoo::with_manifest_dir`.
pub struct ModelZoo {
// Manifests keyed by their own `name` field; inserting a manifest whose name
// is already present replaces the earlier one.
entries: HashMap<String, ModelManifest>,
}
/// Process-wide registry of the built-in manifests, constructed on first access.
///
/// The embedded manifests ship with the binary, so a parse failure is treated
/// as a programming error and aborts with a panic instead of being returned.
static GLOBAL_ZOO: Lazy<ModelZoo> = Lazy::new(|| match ModelZoo::from_embedded() {
    Ok(zoo) => zoo,
    Err(e) => panic!("Failed to parse built-in model zoo manifests: {e}"),
});
impl ModelZoo {
    /// Builds a registry from the manifests listed in `BUILTIN_MANIFESTS`.
    ///
    /// # Errors
    ///
    /// Returns `ModelZooError::ManifestParse` (carrying the parser's message)
    /// for the first embedded manifest that fails to deserialize.
    fn from_embedded() -> Result<Self, ModelZooError> {
        let mut entries = HashMap::with_capacity(BUILTIN_MANIFESTS.len());
        for raw in BUILTIN_MANIFESTS {
            let manifest: ModelManifest =
                toml::from_str(raw).map_err(|e| ModelZooError::ManifestParse(e.to_string()))?;
            entries.insert(manifest.name.clone(), manifest);
        }
        Ok(Self { entries })
    }

    /// Returns the shared registry of built-in manifests.
    ///
    /// # Panics
    ///
    /// Panics on first access if any embedded manifest cannot be parsed.
    pub fn registry() -> &'static ModelZoo {
        &GLOBAL_ZOO
    }

    /// Builds a registry from every `*.toml` file directly inside `dir`.
    ///
    /// Files that cannot be read or parsed are skipped with a warning rather
    /// than failing the whole load. Files are processed in sorted path order,
    /// so when two files declare the same manifest name the outcome is
    /// deterministic: the later path wins and a warning is logged.
    ///
    /// # Errors
    ///
    /// Returns `ModelZooError::Io` if the directory itself, or one of its
    /// entries, cannot be enumerated.
    pub fn with_manifest_dir(dir: &Path) -> Result<ModelZoo, ModelZooError> {
        // Pass the original I/O error through untouched so callers keep the
        // raw OS error code in addition to the kind and message.
        let mut paths = std::fs::read_dir(dir)
            .map_err(ModelZooError::Io)?
            .map(|entry| entry.map(|e| e.path()))
            .collect::<Result<Vec<_>, _>>()
            .map_err(ModelZooError::Io)?;
        paths.retain(|p| p.extension().and_then(|s| s.to_str()) == Some("toml"));
        // `read_dir` yields entries in a platform-dependent order; sorting makes
        // duplicate-name resolution reproducible.
        paths.sort_unstable();

        let mut entries = HashMap::with_capacity(paths.len());
        for path in paths {
            let raw = match std::fs::read_to_string(&path) {
                Ok(s) => s,
                Err(e) => {
                    tracing::warn!("Skipping unreadable manifest {:?}: {}", path, e);
                    continue;
                }
            };
            match toml::from_str::<ModelManifest>(&raw) {
                Ok(manifest) => {
                    if let Some(previous) = entries.insert(manifest.name.clone(), manifest) {
                        tracing::warn!(
                            "Manifest {:?} replaces an earlier manifest with the same name {:?}",
                            path,
                            previous.name
                        );
                    }
                }
                Err(e) => {
                    tracing::warn!("Skipping invalid manifest {:?}: {}", path, e);
                }
            }
        }
        Ok(Self { entries })
    }

    /// Looks up a manifest by its exact (case-sensitive) name.
    pub fn get(&self, name: &str) -> Option<&ModelManifest> {
        self.entries.get(name)
    }

    /// Returns every manifest in the registry, sorted by name.
    pub fn list(&self) -> Vec<&ModelManifest> {
        Self::sorted_by_name(self.entries.values().collect())
    }

    /// Returns the manifests whose name, dataset, or model type contains
    /// `query`, compared case-insensitively, sorted by name.
    ///
    /// An empty `query` matches every manifest.
    pub fn search(&self, query: &str) -> Vec<&ModelManifest> {
        let q = query.to_lowercase();
        let hits = self
            .entries
            .values()
            .filter(|m| {
                m.name.to_lowercase().contains(&q)
                    || m.dataset.to_lowercase().contains(&q)
                    || m.model_type.to_lowercase().contains(&q)
            })
            .collect();
        Self::sorted_by_name(hits)
    }

    /// Number of manifests in the registry.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when the registry holds no manifests.
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Orders manifests by name so results do not depend on `HashMap`
    /// iteration order, which varies between runs.
    fn sorted_by_name(mut manifests: Vec<&ModelManifest>) -> Vec<&ModelManifest> {
        // Names are the map keys and therefore unique, so an unstable sort is enough.
        manifests.sort_unstable_by(|a, b| a.name.cmp(&b.name));
        manifests
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scratch directory for a single test.
    ///
    /// The directory name includes the process id so concurrent test processes
    /// do not share state, leftovers from an aborted earlier run are cleared on
    /// creation, and the directory is removed on drop, which also happens when
    /// an assertion in the test panics.
    struct ScratchDir(std::path::PathBuf);

    impl ScratchDir {
        fn new(tag: &str) -> Self {
            let path = std::env::temp_dir().join(format!("{tag}_{}", std::process::id()));
            // Best effort: the directory usually does not exist yet.
            std::fs::remove_dir_all(&path).ok();
            std::fs::create_dir_all(&path).expect("create temp dir");
            Self(path)
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            std::fs::remove_dir_all(&self.0).ok();
        }
    }

    #[test]
    fn test_global_registry_has_five_entries() {
        let zoo = ModelZoo::registry();
        assert_eq!(zoo.len(), 5, "Expected 5 built-in entries");
    }

    #[test]
    fn test_registry_get_existing() {
        let zoo = ModelZoo::registry();
        let m = zoo.get("transe-fb15k237").expect("entry should exist");
        assert_eq!(m.model_type, "TransE");
        assert_eq!(m.dimensions, 200);
    }

    #[test]
    fn test_registry_get_missing() {
        let zoo = ModelZoo::registry();
        assert!(zoo.get("nonexistent-model").is_none());
    }

    #[test]
    fn test_registry_list() {
        let zoo = ModelZoo::registry();
        let list = zoo.list();
        assert_eq!(list.len(), 5);
    }

    #[test]
    fn test_registry_search_by_dataset() {
        let zoo = ModelZoo::registry();
        let fb = zoo.search("FB15k");
        assert_eq!(fb.len(), 3, "FB15k-237 should have 3 entries");
    }

    #[test]
    fn test_registry_search_by_model_type() {
        let zoo = ModelZoo::registry();
        let transe_models = zoo.search("transe");
        assert_eq!(transe_models.len(), 2);
    }

    #[test]
    fn test_registry_search_case_insensitive() {
        let zoo = ModelZoo::registry();
        let upper = zoo.search("TRANSE");
        let lower = zoo.search("transe");
        assert_eq!(upper.len(), lower.len());
    }

    #[test]
    fn test_registry_search_no_results() {
        let zoo = ModelZoo::registry();
        let results = zoo.search("zzzznotexist");
        assert!(results.is_empty());
    }

    #[test]
    fn test_with_manifest_dir_empty() {
        let tmp = ScratchDir::new("oxirs_zoo_test_empty");
        let zoo = ModelZoo::with_manifest_dir(tmp.path()).expect("empty dir ok");
        assert!(zoo.is_empty());
    }

    #[test]
    fn test_with_manifest_dir_single_manifest() {
        let tmp = ScratchDir::new("oxirs_zoo_test_single");
        let manifest = crate::model_zoo::manifest::ModelManifest {
            name: "custom-model".to_string(),
            model_type: "TransE".to_string(),
            dataset: "CustomDS".to_string(),
            dimensions: 64,
            entities: 100,
            relations: 5,
            sha256: "PLACEHOLDER".to_string(),
            source: "file:///tmp/custom.ckpt".to_string(),
            license: "Apache-2.0".to_string(),
            citation: "Custom 2026".to_string(),
            version: "1.0.0".to_string(),
            created: "2026-05-01".to_string(),
            notes: None,
        };
        let toml_str = toml::to_string(&manifest).expect("serialize");
        std::fs::write(tmp.path().join("custom-model.toml"), toml_str).expect("write");
        let zoo = ModelZoo::with_manifest_dir(tmp.path()).expect("parse manifest dir");
        assert_eq!(zoo.len(), 1);
        let m = zoo.get("custom-model").expect("should be present");
        assert_eq!(m.model_type, "TransE");
    }
}