mecha10_core/model.rs
1//! Model configuration utilities for AI nodes
2//!
3//! This module provides shared utilities for loading and managing ONNX model configurations.
4//! It's used by both the CLI (for model management) and AI nodes (for runtime inference).
5//!
6//! # Design
7//!
//! Model configurations are stored in `models/<name>/config.json` (resolved relative to the
//! current working directory) and include:
9//! - Model metadata (name, task, repo, filename)
10//! - Preprocessing parameters (input_size, mean, std, channel_order)
11//! - Class labels configuration
12//!
13//! # Usage
14//!
//! ```no_run
//! use mecha10_core::model::{load_model_config, load_labels};
//!
//! // Load complete model configuration
//! let config = load_model_config("mobilenet-v2")?;
//!
//! // Load class labels
//! let labels = load_labels("mobilenet-v2")?;
//! # Ok::<(), anyhow::Error>(())
//! ```
24
25use anyhow::{Context, Result};
26use serde::{Deserialize, Serialize};
27use std::path::PathBuf;
28
29/// Model configuration stored in models/<name>/config.json
30///
31/// This is the complete configuration for a model, including metadata,
32/// preprocessing parameters, and label information.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ModelConfig {
35 /// Model name
36 pub name: String,
37 /// Task type (e.g., "image-classification", "object-detection")
38 pub task: String,
39 /// HuggingFace repository
40 pub repo: String,
41 /// Filename within the repo
42 pub filename: String,
43 /// Input size [width, height]
44 pub input_size: [u32; 2],
45 /// Preprocessing configuration
46 pub preprocessing: PreprocessingConfig,
47 /// Number of output classes
48 pub num_classes: usize,
49 /// Path to labels file (relative to model directory)
50 pub labels_file: String,
51 /// Custom labels configuration (for fine-tuning)
52 #[serde(default)]
53 pub custom_labels: CustomLabelsConfig,
54}
55
56/// Preprocessing configuration
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct PreprocessingConfig {
59 /// Mean values for normalization [R, G, B]
60 pub mean: [f32; 3],
61 /// Standard deviation for normalization [R, G, B]
62 pub std: [f32; 3],
63 /// Channel order (RGB or BGR)
64 pub channel_order: String,
65}
66
67/// Custom labels for fine-tuned models
68#[derive(Debug, Clone, Default, Serialize, Deserialize)]
69pub struct CustomLabelsConfig {
70 /// Whether custom labels are enabled
71 #[serde(default)]
72 pub enabled: bool,
73 /// Additional classes for fine-tuned models
74 #[serde(default)]
75 pub additional_classes: Vec<String>,
76}
77
78/// Load model configuration from disk
79///
80/// Loads the complete model configuration from `models/<name>/config.json`.
81///
82/// # Arguments
83///
84/// * `model_name` - Name of the model (e.g., "mobilenet-v2")
85///
86/// # Returns
87///
88/// The complete `ModelConfig` with all metadata and preprocessing parameters.
89///
90/// # Errors
91///
92/// Returns an error if:
93/// - The config file doesn't exist
94/// - The config file is not valid JSON
95/// - The config file is missing required fields
96///
97/// # Example
98///
99/// ```no_run
100/// use mecha10_core::model::load_model_config;
101///
102/// let config = load_model_config("mobilenet-v2")?;
103/// println!("Model: {} ({})", config.name, config.task);
104/// println!("Input size: {:?}", config.input_size);
105/// println!("Preprocessing: mean={:?}, std={:?}", config.preprocessing.mean, config.preprocessing.std);
106/// # Ok::<(), anyhow::Error>(())
107/// ```
108pub fn load_model_config(model_name: &str) -> Result<ModelConfig> {
109 let config_path = PathBuf::from("models").join(model_name).join("config.json");
110
111 if !config_path.exists() {
112 anyhow::bail!(
113 "Model config not found for '{}' at {}. Run: mecha10 models pull {}",
114 model_name,
115 config_path.display(),
116 model_name
117 );
118 }
119
120 let content = std::fs::read_to_string(&config_path)
121 .context(format!("Failed to read model config: {}", config_path.display()))?;
122
123 let config: ModelConfig = serde_json::from_str(&content)
124 .context(format!("Failed to parse model config.json: {}", config_path.display()))?;
125
126 Ok(config)
127}
128
129/// Load class labels from model directory
130///
131/// Loads class labels from `models/<name>/labels.txt` with support for custom labels.
132/// If custom labels are enabled in the model config, they will be appended to the base labels.
133///
134/// # Arguments
135///
136/// * `model_name` - Name of the model (e.g., "mobilenet-v2")
137///
138/// # Returns
139///
140/// A vector of class label strings. The index in the vector corresponds to the class ID.
141///
142/// # Errors
143///
144/// Returns an error if:
145/// - The labels file doesn't exist
146/// - The labels file cannot be read
147/// - The model config cannot be loaded (when checking for custom labels)
148///
149/// # Example
150///
151/// ```no_run
152/// use mecha10_core::model::load_labels;
153///
154/// let labels = load_labels("mobilenet-v2")?;
155/// println!("Loaded {} class labels", labels.len());
156/// println!("First label: {}", labels[0]);
157/// # Ok::<(), anyhow::Error>(())
158/// ```
159pub fn load_labels(model_name: &str) -> Result<Vec<String>> {
160 // Load from models/<name>/labels.txt
161 let labels_path = PathBuf::from("models").join(model_name).join("labels.txt");
162
163 if !labels_path.exists() {
164 anyhow::bail!(
165 "Labels file not found for model '{}' at {}. Run: mecha10 models pull {}",
166 model_name,
167 labels_path.display(),
168 model_name
169 );
170 }
171
172 tracing::info!("📋 Loading labels from: {}", labels_path.display());
173
174 let content = std::fs::read_to_string(&labels_path)
175 .context(format!("Failed to read labels file: {}", labels_path.display()))?;
176
177 let mut labels: Vec<String> = content.lines().map(|s| s.trim().to_string()).collect();
178
179 // Load model config to check for custom labels
180 let config_path = PathBuf::from("models").join(model_name).join("config.json");
181 if config_path.exists() {
182 if let Ok(config_content) = std::fs::read_to_string(&config_path) {
183 if let Ok(config) = serde_json::from_str::<ModelConfig>(&config_content) {
184 if config.custom_labels.enabled {
185 // Append custom labels
186 let num_custom = config.custom_labels.additional_classes.len();
187 labels.extend(config.custom_labels.additional_classes);
188 tracing::info!("✅ Added {} custom labels (total: {} labels)", num_custom, labels.len());
189 }
190 }
191 }
192 }
193
194 tracing::info!("✅ Loaded {} labels", labels.len());
195 Ok(labels)
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully-populated config (including non-default custom labels) for round-trip tests.
    fn sample_config() -> ModelConfig {
        ModelConfig {
            name: "test-model".to_string(),
            task: "image-classification".to_string(),
            repo: "test/repo".to_string(),
            filename: "model.onnx".to_string(),
            input_size: [224, 224],
            preprocessing: PreprocessingConfig {
                mean: [0.485, 0.456, 0.406],
                std: [0.229, 0.224, 0.225],
                channel_order: "RGB".to_string(),
            },
            num_classes: 1000,
            labels_file: "labels.txt".to_string(),
            custom_labels: CustomLabelsConfig {
                enabled: true,
                additional_classes: vec!["widget".to_string(), "gadget".to_string()],
            },
        }
    }

    /// Minimal valid config JSON; `extra` is spliced in before the closing brace.
    fn minimal_json(extra: &str) -> String {
        format!(
            r#"{{
                "name": "m",
                "task": "image-classification",
                "repo": "r",
                "filename": "model.onnx",
                "input_size": [224, 224],
                "preprocessing": {{
                    "mean": [0.5, 0.5, 0.5],
                    "std": [0.5, 0.5, 0.5],
                    "channel_order": "RGB"
                }},
                "num_classes": 10,
                "labels_file": "labels.txt"{extra}
            }}"#
        )
    }

    #[test]
    fn test_model_config_serialization() {
        let config = sample_config();

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ModelConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.name, deserialized.name);
        assert_eq!(config.task, deserialized.task);
        assert_eq!(config.repo, deserialized.repo);
        assert_eq!(config.filename, deserialized.filename);
        assert_eq!(config.input_size, deserialized.input_size);
        assert_eq!(config.preprocessing.mean, deserialized.preprocessing.mean);
        assert_eq!(config.preprocessing.std, deserialized.preprocessing.std);
        assert_eq!(config.preprocessing.channel_order, deserialized.preprocessing.channel_order);
        assert_eq!(config.num_classes, deserialized.num_classes);
        assert_eq!(config.labels_file, deserialized.labels_file);
        assert_eq!(config.custom_labels.enabled, deserialized.custom_labels.enabled);
        assert_eq!(
            config.custom_labels.additional_classes,
            deserialized.custom_labels.additional_classes
        );
    }

    #[test]
    fn test_custom_labels_default_when_missing() {
        let config: ModelConfig = serde_json::from_str(&minimal_json("")).unwrap();

        assert!(!config.custom_labels.enabled);
        assert!(config.custom_labels.additional_classes.is_empty());
    }

    #[test]
    fn test_custom_labels_partial_fields_default() {
        let json = minimal_json(r#", "custom_labels": { "enabled": true }"#);
        let config: ModelConfig = serde_json::from_str(&json).unwrap();

        assert!(config.custom_labels.enabled);
        assert!(config.custom_labels.additional_classes.is_empty());
    }
}
227}