Skip to main content

oximedia_ml/
zoo.rs

1//! Lightweight model registry ("zoo").
2//!
3//! The zoo is a static, in-memory catalogue that maps stable model IDs
4//! to descriptive metadata. It intentionally does **not** embed weights
5//! or network fetch logic — users bring their own `.onnx` file. The zoo
6//! is purely a discovery surface so pipelines can advertise their
7//! expected input contract.
8//!
9//! ## Example
10//!
11//! ```
12//! use oximedia_ml::ModelZoo;
13//!
14//! let zoo = ModelZoo::with_defaults();
15//! let scene = zoo.get("places365/resnet18").expect("default entry");
16//! assert_eq!(scene.input_size, Some((224, 224)));
17//! assert_eq!(scene.num_classes, Some(365));
18//! ```
19//!
20//! Register custom entries with [`ModelZoo::register`]. IDs are unique
21//! — registering the same ID twice overwrites the previous record.
22
23use std::collections::HashMap;
24
25use crate::pipeline::PipelineTask;
26
27/// Metadata entry describing a model that can be plugged into a pipeline.
28///
29/// Entries are static — each field is `&'static str` / `Option<…>` so
30/// entries can live in a `const`-friendly table. Use
31/// [`ModelZoo::register`] to add a [`ModelEntry`] to a zoo instance.
32#[derive(Clone, Debug)]
33pub struct ModelEntry {
34    /// Stable unique ID, e.g. `"places365/resnet18"`.
35    pub id: &'static str,
36    /// Human-readable name.
37    pub name: &'static str,
38    /// Which pipeline task this model is intended for.
39    pub task: PipelineTask,
40    /// Expected `(width, height)` of the image input, if applicable.
41    pub input_size: Option<(u32, u32)>,
42    /// Number of output classes, if applicable.
43    pub num_classes: Option<usize>,
44    /// Short notes / citation for the user.
45    pub notes: &'static str,
46}
47
48/// In-memory registry of known models.
49#[derive(Clone, Debug, Default)]
50pub struct ModelZoo {
51    entries: HashMap<&'static str, ModelEntry>,
52}
53
54impl ModelZoo {
55    /// Create an empty zoo.
56    #[must_use]
57    pub fn new() -> Self {
58        Self::default()
59    }
60
61    /// Create the default zoo with built-in entries.
62    #[must_use]
63    pub fn with_defaults() -> Self {
64        let mut zoo = Self::new();
65        zoo.register(ModelEntry {
66            id: "places365/resnet18",
67            name: "Places365 ResNet-18 scene classifier",
68            task: PipelineTask::SceneClassification,
69            input_size: Some((224, 224)),
70            num_classes: Some(365),
71            notes: "Bring your own ONNX export of the Places365 ResNet-18 model.",
72        });
73        zoo.register(ModelEntry {
74            id: "transnet-v2",
75            name: "TransNet V2 shot boundary detector",
76            task: PipelineTask::ShotBoundary,
77            input_size: Some((48, 27)),
78            num_classes: Some(2),
79            notes: "Sliding-window input of 100 frames at 48x27 RGB per window.",
80        });
81        zoo
82    }
83
84    /// Register a new entry (overwrites any existing entry with the same ID).
85    pub fn register(&mut self, entry: ModelEntry) {
86        self.entries.insert(entry.id, entry);
87    }
88
89    /// Look up an entry by ID.
90    #[must_use]
91    pub fn get(&self, id: &str) -> Option<&ModelEntry> {
92        self.entries.get(id)
93    }
94
95    /// Return all registered entries.
96    pub fn entries(&self) -> impl Iterator<Item = &ModelEntry> {
97        self.entries.values()
98    }
99
100    /// Number of entries.
101    #[must_use]
102    pub fn len(&self) -> usize {
103        self.entries.len()
104    }
105
106    /// Whether the zoo has no entries.
107    #[must_use]
108    pub fn is_empty(&self) -> bool {
109        self.entries.is_empty()
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal custom entry for registry tests.
    fn custom_entry(id: &'static str, name: &'static str) -> ModelEntry {
        ModelEntry {
            id,
            name,
            task: PipelineTask::Custom,
            input_size: None,
            num_classes: None,
            notes: "",
        }
    }

    #[test]
    fn defaults_contain_scene_and_shot() {
        let zoo = ModelZoo::with_defaults();

        let scene = zoo.get("places365/resnet18").expect("scene entry");
        assert_eq!(scene.input_size, Some((224, 224)));
        assert_eq!(scene.num_classes, Some(365));

        let shot = zoo.get("transnet-v2").expect("shot entry");
        assert_eq!(shot.input_size, Some((48, 27)));
        assert_eq!(shot.num_classes, Some(2));
    }

    #[test]
    fn empty_zoo_reports_empty() {
        let zoo = ModelZoo::new();
        assert!(zoo.is_empty());
        assert_eq!(zoo.len(), 0);
        assert_eq!(zoo.entries().count(), 0);
    }

    #[test]
    fn register_adds_entry() {
        let mut zoo = ModelZoo::new();
        zoo.register(custom_entry("demo/x", "Demo"));
        assert_eq!(zoo.len(), 1);
        assert!(zoo.get("demo/x").is_some());
    }

    #[test]
    fn register_same_id_overwrites_previous_record() {
        let mut zoo = ModelZoo::new();
        zoo.register(custom_entry("demo/x", "First"));
        zoo.register(custom_entry("demo/x", "Second"));
        assert_eq!(zoo.len(), 1);
        assert_eq!(zoo.get("demo/x").map(|e| e.name), Some("Second"));
    }

    #[test]
    fn get_unknown_id_returns_none() {
        let zoo = ModelZoo::with_defaults();
        assert!(zoo.get("does-not-exist").is_none());
    }

    #[test]
    fn entries_yields_every_registered_entry() {
        let zoo = ModelZoo::with_defaults();
        assert_eq!(zoo.entries().count(), zoo.len());
    }
}