use std::collections::HashMap;
use crate::pipeline::PipelineTask;
/// Static metadata describing one model that can be registered in a [`ModelZoo`].
#[derive(Debug, Clone)]
pub struct ModelEntry {
    /// Identifier the entry is registered and looked up under (see [`ModelZoo::get`]).
    pub id: &'static str,
    /// Human-readable display name.
    pub name: &'static str,
    /// Pipeline task this model is intended for.
    pub task: PipelineTask,
    /// Fixed input dimensions, or `None` when not specified.
    ///
    /// NOTE(review): the `transnet-v2` default pairs `(48, 27)` with notes saying
    /// "48x27", which suggests `(width, height)` ordering — confirm against consumers.
    pub input_size: Option<(u32, u32)>,
    /// Number of output classes, or `None` when not specified.
    pub num_classes: Option<usize>,
    /// Free-form usage notes for whoever wires the model up.
    pub notes: &'static str,
}
/// In-memory registry of [`ModelEntry`] values, keyed by their `id`.
///
/// Start from [`ModelZoo::new`] for an empty registry or
/// [`ModelZoo::with_defaults`] for the built-in entries.
#[derive(Debug, Default, Clone)]
pub struct ModelZoo {
    /// Registered entries; the key is always the entry's own `id`.
    entries: HashMap<&'static str, ModelEntry>,
}
impl ModelZoo {
    /// Creates an empty zoo with no registered models.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a zoo pre-populated with the built-in entries
    /// `places365/resnet18` and `transnet-v2`.
    #[must_use]
    pub fn with_defaults() -> Self {
        let mut zoo = Self::new();
        zoo.register(ModelEntry {
            id: "places365/resnet18",
            name: "Places365 ResNet-18 scene classifier",
            task: PipelineTask::SceneClassification,
            input_size: Some((224, 224)),
            num_classes: Some(365),
            notes: "Bring your own ONNX export of the Places365 ResNet-18 model.",
        });
        zoo.register(ModelEntry {
            id: "transnet-v2",
            name: "TransNet V2 shot boundary detector",
            task: PipelineTask::ShotBoundary,
            input_size: Some((48, 27)),
            num_classes: Some(2),
            notes: "Sliding-window input of 100 frames at 48x27 RGB per window.",
        });
        zoo
    }

    /// Registers `entry` under its `id`.
    ///
    /// If an entry with the same `id` is already registered, it is replaced
    /// by `entry`; the zoo never holds two entries with the same id.
    pub fn register(&mut self, entry: ModelEntry) {
        self.entries.insert(entry.id, entry);
    }

    /// Returns the entry registered under `id`, or `None` if there is none.
    #[must_use]
    pub fn get(&self, id: &str) -> Option<&ModelEntry> {
        self.entries.get(id)
    }

    /// Iterates over every registered entry in ascending `id` order.
    ///
    /// The backing `HashMap` iterates in an arbitrary, per-process randomized
    /// order, so the values are collected and sorted here to give callers
    /// (e.g. model listings) deterministic output. Ids are unique keys, so an
    /// unstable sort cannot reorder ties.
    pub fn entries(&self) -> impl Iterator<Item = &ModelEntry> {
        let mut sorted: Vec<&ModelEntry> = self.entries.values().collect();
        sorted.sort_unstable_by_key(|entry| entry.id);
        sorted.into_iter()
    }

    /// Returns the number of registered entries.
    #[must_use]
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// Returns `true` when no entries are registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal entry registered under `demo/x` with the given display name.
    fn demo_entry(name: &'static str) -> ModelEntry {
        ModelEntry {
            id: "demo/x",
            name,
            task: PipelineTask::Custom,
            input_size: None,
            num_classes: None,
            notes: "",
        }
    }

    #[test]
    fn defaults_contain_scene_and_shot() {
        let zoo = ModelZoo::with_defaults();
        assert_eq!(zoo.len(), 2);
        assert!(!zoo.is_empty());

        let scene = zoo
            .get("places365/resnet18")
            .expect("scene classifier is a default entry");
        assert!(matches!(scene.task, PipelineTask::SceneClassification));
        assert_eq!(scene.input_size, Some((224, 224)));
        assert_eq!(scene.num_classes, Some(365));

        let shot = zoo
            .get("transnet-v2")
            .expect("shot boundary detector is a default entry");
        assert!(matches!(shot.task, PipelineTask::ShotBoundary));
        assert_eq!(shot.input_size, Some((48, 27)));
    }

    #[test]
    fn empty_zoo_reports_empty() {
        let zoo = ModelZoo::new();
        assert!(zoo.is_empty());
        assert_eq!(zoo.len(), 0);
        assert_eq!(zoo.entries().count(), 0);
        assert!(zoo.get("places365/resnet18").is_none());
    }

    #[test]
    fn register_adds_entry() {
        let mut zoo = ModelZoo::new();
        zoo.register(demo_entry("Demo"));
        assert_eq!(zoo.len(), 1);
        assert!(zoo.get("demo/x").is_some());
    }

    #[test]
    fn register_same_id_replaces_entry() {
        let mut zoo = ModelZoo::new();
        zoo.register(demo_entry("Demo"));
        zoo.register(demo_entry("Demo v2"));
        assert_eq!(zoo.len(), 1);
        assert_eq!(zoo.get("demo/x").map(|entry| entry.name), Some("Demo v2"));
    }

    #[test]
    fn get_unknown_id_returns_none() {
        let zoo = ModelZoo::with_defaults();
        assert!(zoo.get("does-not-exist").is_none());
        // Lookup takes `&str`, so owned strings work via deref.
        assert!(zoo.get(&String::from("transnet-v2")).is_some());
    }

    #[test]
    fn entries_yields_every_registered_model() {
        let zoo = ModelZoo::with_defaults();
        // Sort locally so this test does not depend on iteration order.
        let mut ids: Vec<&str> = zoo.entries().map(|entry| entry.id).collect();
        ids.sort_unstable();
        assert_eq!(ids, ["places365/resnet18", "transnet-v2"]);
    }
}