1use std::collections::HashMap;
24
25use crate::pipeline::PipelineTask;
26
27#[derive(Clone, Debug)]
33pub struct ModelEntry {
34 pub id: &'static str,
36 pub name: &'static str,
38 pub task: PipelineTask,
40 pub input_size: Option<(u32, u32)>,
42 pub num_classes: Option<usize>,
44 pub notes: &'static str,
46}
47
48#[derive(Clone, Debug, Default)]
50pub struct ModelZoo {
51 entries: HashMap<&'static str, ModelEntry>,
52}
53
54impl ModelZoo {
55 #[must_use]
57 pub fn new() -> Self {
58 Self::default()
59 }
60
61 #[must_use]
63 pub fn with_defaults() -> Self {
64 let mut zoo = Self::new();
65 zoo.register(ModelEntry {
66 id: "places365/resnet18",
67 name: "Places365 ResNet-18 scene classifier",
68 task: PipelineTask::SceneClassification,
69 input_size: Some((224, 224)),
70 num_classes: Some(365),
71 notes: "Bring your own ONNX export of the Places365 ResNet-18 model.",
72 });
73 zoo.register(ModelEntry {
74 id: "transnet-v2",
75 name: "TransNet V2 shot boundary detector",
76 task: PipelineTask::ShotBoundary,
77 input_size: Some((48, 27)),
78 num_classes: Some(2),
79 notes: "Sliding-window input of 100 frames at 48x27 RGB per window.",
80 });
81 zoo
82 }
83
84 pub fn register(&mut self, entry: ModelEntry) {
86 self.entries.insert(entry.id, entry);
87 }
88
89 #[must_use]
91 pub fn get(&self, id: &str) -> Option<&ModelEntry> {
92 self.entries.get(id)
93 }
94
95 pub fn entries(&self) -> impl Iterator<Item = &ModelEntry> {
97 self.entries.values()
98 }
99
100 #[must_use]
102 pub fn len(&self) -> usize {
103 self.entries.len()
104 }
105
106 #[must_use]
108 pub fn is_empty(&self) -> bool {
109 self.entries.is_empty()
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    // Both built-in ids must resolve in the default catalogue.
    #[test]
    fn defaults_contain_scene_and_shot() {
        let catalogue = ModelZoo::with_defaults();

        let scene = catalogue.get("places365/resnet18");
        let shot = catalogue.get("transnet-v2");

        assert!(scene.is_some());
        assert!(shot.is_some());
    }

    // A freshly constructed zoo holds nothing.
    #[test]
    fn empty_zoo_reports_empty() {
        let fresh = ModelZoo::new();

        assert_eq!(fresh.len(), 0);
        assert!(fresh.is_empty());
    }

    // Registering one entry makes it countable and retrievable by id.
    #[test]
    fn register_adds_entry() {
        let demo = ModelEntry {
            id: "demo/x",
            name: "Demo",
            task: PipelineTask::Custom,
            input_size: None,
            num_classes: None,
            notes: "",
        };

        let mut catalogue = ModelZoo::new();
        catalogue.register(demo);

        assert_eq!(catalogue.len(), 1);
        assert!(catalogue.get("demo/x").is_some());
    }
}