mecha10_core/
model.rs

//! Model configuration utilities for AI nodes
//!
//! This module provides shared utilities for loading and managing ONNX model configurations.
//! It is used both by the CLI (for model management) and by AI nodes (for runtime inference).
//!
//! # Design
//!
//! Model configurations are stored in `models/<name>/config.json` and include:
//! - Model metadata (name, task, repo, filename)
//! - Preprocessing parameters (input_size, mean, std, channel_order)
//! - Class label settings (labels_file, num_classes, custom_labels)
//!
//! Class labels are read from `models/<name>/labels.txt`. Both paths are relative, so they are
//! resolved against the process's current working directory.
//!
//! # Usage
//!
//! ```no_run
//! use mecha10_core::model::{load_model_config, load_labels};
//!
//! // Load complete model configuration
//! let config = load_model_config("mobilenet-v2")?;
//!
//! // Load class labels
//! let labels = load_labels("mobilenet-v2")?;
//! # Ok::<(), anyhow::Error>(())
//! ```
24
25use anyhow::{Context, Result};
26use serde::{Deserialize, Serialize};
27use std::path::PathBuf;
28
29/// Model configuration stored in models/<name>/config.json
30///
31/// This is the complete configuration for a model, including metadata,
32/// preprocessing parameters, and label information.
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct ModelConfig {
35    /// Model name
36    pub name: String,
37    /// Task type (e.g., "image-classification", "object-detection")
38    pub task: String,
39    /// HuggingFace repository
40    pub repo: String,
41    /// Filename within the repo
42    pub filename: String,
43    /// Input size [width, height]
44    pub input_size: [u32; 2],
45    /// Preprocessing configuration
46    pub preprocessing: PreprocessingConfig,
47    /// Number of output classes
48    pub num_classes: usize,
49    /// Path to labels file (relative to model directory)
50    pub labels_file: String,
51    /// Custom labels configuration (for fine-tuning)
52    #[serde(default)]
53    pub custom_labels: CustomLabelsConfig,
54}
55
56/// Preprocessing configuration
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct PreprocessingConfig {
59    /// Mean values for normalization [R, G, B]
60    pub mean: [f32; 3],
61    /// Standard deviation for normalization [R, G, B]
62    pub std: [f32; 3],
63    /// Channel order (RGB or BGR)
64    pub channel_order: String,
65}
66
67/// Custom labels for fine-tuned models
68#[derive(Debug, Clone, Default, Serialize, Deserialize)]
69pub struct CustomLabelsConfig {
70    /// Whether custom labels are enabled
71    #[serde(default)]
72    pub enabled: bool,
73    /// Additional classes for fine-tuned models
74    #[serde(default)]
75    pub additional_classes: Vec<String>,
76}
77
78/// Load model configuration from disk
79///
80/// Loads the complete model configuration from `models/<name>/config.json`.
81///
82/// # Arguments
83///
84/// * `model_name` - Name of the model (e.g., "mobilenet-v2")
85///
86/// # Returns
87///
88/// The complete `ModelConfig` with all metadata and preprocessing parameters.
89///
90/// # Errors
91///
92/// Returns an error if:
93/// - The config file doesn't exist
94/// - The config file is not valid JSON
95/// - The config file is missing required fields
96///
97/// # Example
98///
99/// ```no_run
100/// use mecha10_core::model::load_model_config;
101///
102/// let config = load_model_config("mobilenet-v2")?;
103/// println!("Model: {} ({})", config.name, config.task);
104/// println!("Input size: {:?}", config.input_size);
105/// println!("Preprocessing: mean={:?}, std={:?}", config.preprocessing.mean, config.preprocessing.std);
106/// # Ok::<(), anyhow::Error>(())
107/// ```
108pub fn load_model_config(model_name: &str) -> Result<ModelConfig> {
109    let config_path = PathBuf::from("models").join(model_name).join("config.json");
110
111    if !config_path.exists() {
112        anyhow::bail!(
113            "Model config not found for '{}' at {}. Run: mecha10 models pull {}",
114            model_name,
115            config_path.display(),
116            model_name
117        );
118    }
119
120    let content = std::fs::read_to_string(&config_path)
121        .context(format!("Failed to read model config: {}", config_path.display()))?;
122
123    let config: ModelConfig = serde_json::from_str(&content)
124        .context(format!("Failed to parse model config.json: {}", config_path.display()))?;
125
126    Ok(config)
127}
128
129/// Load class labels from model directory
130///
131/// Loads class labels from `models/<name>/labels.txt` with support for custom labels.
132/// If custom labels are enabled in the model config, they will be appended to the base labels.
133///
134/// # Arguments
135///
136/// * `model_name` - Name of the model (e.g., "mobilenet-v2")
137///
138/// # Returns
139///
140/// A vector of class label strings. The index in the vector corresponds to the class ID.
141///
142/// # Errors
143///
144/// Returns an error if:
145/// - The labels file doesn't exist
146/// - The labels file cannot be read
147/// - The model config cannot be loaded (when checking for custom labels)
148///
149/// # Example
150///
151/// ```no_run
152/// use mecha10_core::model::load_labels;
153///
154/// let labels = load_labels("mobilenet-v2")?;
155/// println!("Loaded {} class labels", labels.len());
156/// println!("First label: {}", labels[0]);
157/// # Ok::<(), anyhow::Error>(())
158/// ```
159pub fn load_labels(model_name: &str) -> Result<Vec<String>> {
160    // Load from models/<name>/labels.txt
161    let labels_path = PathBuf::from("models").join(model_name).join("labels.txt");
162
163    if !labels_path.exists() {
164        anyhow::bail!(
165            "Labels file not found for model '{}' at {}. Run: mecha10 models pull {}",
166            model_name,
167            labels_path.display(),
168            model_name
169        );
170    }
171
172    tracing::info!("📋 Loading labels from: {}", labels_path.display());
173
174    let content = std::fs::read_to_string(&labels_path)
175        .context(format!("Failed to read labels file: {}", labels_path.display()))?;
176
177    let mut labels: Vec<String> = content.lines().map(|s| s.trim().to_string()).collect();
178
179    // Load model config to check for custom labels
180    let config_path = PathBuf::from("models").join(model_name).join("config.json");
181    if config_path.exists() {
182        if let Ok(config_content) = std::fs::read_to_string(&config_path) {
183            if let Ok(config) = serde_json::from_str::<ModelConfig>(&config_content) {
184                if config.custom_labels.enabled {
185                    // Append custom labels
186                    let num_custom = config.custom_labels.additional_classes.len();
187                    labels.extend(config.custom_labels.additional_classes);
188                    tracing::info!("✅ Added {} custom labels (total: {} labels)", num_custom, labels.len());
189                }
190            }
191        }
192    }
193
194    tracing::info!("✅ Loaded {} labels", labels.len());
195    Ok(labels)
196}
197
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a representative config shared by the round-trip test.
    fn sample_config() -> ModelConfig {
        ModelConfig {
            name: "test-model".to_string(),
            task: "image-classification".to_string(),
            repo: "test/repo".to_string(),
            filename: "model.onnx".to_string(),
            input_size: [224, 224],
            preprocessing: PreprocessingConfig {
                mean: [0.485, 0.456, 0.406],
                std: [0.229, 0.224, 0.225],
                channel_order: "RGB".to_string(),
            },
            num_classes: 1000,
            labels_file: "labels.txt".to_string(),
            custom_labels: CustomLabelsConfig::default(),
        }
    }

    #[test]
    fn test_model_config_serialization() {
        let config = sample_config();

        let json = serde_json::to_string(&config).unwrap();
        let deserialized: ModelConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.name, deserialized.name);
        assert_eq!(config.task, deserialized.task);
        assert_eq!(config.repo, deserialized.repo);
        assert_eq!(config.filename, deserialized.filename);
        assert_eq!(config.input_size, deserialized.input_size);
        assert_eq!(config.preprocessing.mean, deserialized.preprocessing.mean);
        assert_eq!(config.preprocessing.std, deserialized.preprocessing.std);
        assert_eq!(config.preprocessing.channel_order, deserialized.preprocessing.channel_order);
        assert_eq!(config.num_classes, deserialized.num_classes);
        assert_eq!(config.labels_file, deserialized.labels_file);
        assert_eq!(config.custom_labels.enabled, deserialized.custom_labels.enabled);
        assert_eq!(
            config.custom_labels.additional_classes,
            deserialized.custom_labels.additional_classes
        );
    }

    #[test]
    fn test_custom_labels_default_when_absent() {
        // config.json files written without a `custom_labels` key must still load.
        let json = r#"{
            "name": "m", "task": "image-classification", "repo": "r", "filename": "model.onnx",
            "input_size": [224, 224],
            "preprocessing": {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5], "channel_order": "RGB"},
            "num_classes": 10, "labels_file": "labels.txt"
        }"#;

        let config: ModelConfig = serde_json::from_str(json).unwrap();

        assert!(!config.custom_labels.enabled);
        assert!(config.custom_labels.additional_classes.is_empty());
    }

    #[test]
    fn test_custom_labels_partial_object() {
        // Individual fields inside `custom_labels` are optional as well.
        let json = r#"{
            "name": "m", "task": "image-classification", "repo": "r", "filename": "model.onnx",
            "input_size": [224, 224],
            "preprocessing": {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5], "channel_order": "RGB"},
            "num_classes": 10, "labels_file": "labels.txt",
            "custom_labels": {"enabled": true}
        }"#;

        let config: ModelConfig = serde_json::from_str(json).unwrap();

        assert!(config.custom_labels.enabled);
        assert!(config.custom_labels.additional_classes.is_empty());
    }

    #[test]
    fn test_missing_required_field_is_error() {
        // `labels_file` has no serde default, so omitting it must be rejected.
        let json = r#"{
            "name": "m", "task": "image-classification", "repo": "r", "filename": "model.onnx",
            "input_size": [224, 224],
            "preprocessing": {"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5], "channel_order": "RGB"},
            "num_classes": 10
        }"#;

        assert!(serde_json::from_str::<ModelConfig>(json).is_err());
    }
}