//! Unified Model Hub - Central Registry for All Pretrained Models
//!
//! # File
//! `crates/axonml-train/src/hub.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 14, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// HashMap is used by feature-gated functions
#[allow(unused_imports)]
use std::collections::HashMap;

// =============================================================================
// Model Categories
// =============================================================================

/// High-level category a pretrained model belongs to.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelCategory {
    /// Vision models (ResNet, VGG, ViT, ...).
    Vision,
    /// Language models (BERT, GPT-2, LLaMA, ...).
    Language,
    /// Audio models (Wav2Vec, Whisper, ...).
    Audio,
    /// Multimodal models (CLIP, ...).
    Multimodal,
}

39impl std::fmt::Display for ModelCategory {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            ModelCategory::Vision => write!(f, "Vision"),
43            ModelCategory::Language => write!(f, "Language"),
44            ModelCategory::Audio => write!(f, "Audio"),
45            ModelCategory::Multimodal => write!(f, "Multimodal"),
46        }
47    }
48}
49
50// =============================================================================
51// Unified Model Info
52// =============================================================================
53
54/// Unified model information across all categories.
55#[derive(Debug, Clone)]
56pub struct UnifiedModelInfo {
57    /// Model name
58    pub name: String,
59    /// Category
60    pub category: ModelCategory,
61    /// Architecture (e.g., "ResNet", "BERT", "LLaMA")
62    pub architecture: String,
63    /// Number of parameters
64    pub num_parameters: u64,
65    /// File size in bytes
66    pub size_bytes: u64,
67    /// Download URL
68    pub url: String,
69    /// Training dataset
70    pub dataset: String,
71    /// Description
72    pub description: String,
73    /// Tags for search
74    pub tags: Vec<String>,
75}
76
77impl UnifiedModelInfo {
78    /// Returns size in megabytes.
79    pub fn size_mb(&self) -> f64 {
80        self.size_bytes as f64 / 1_000_000.0
81    }
82
83    /// Returns formatted parameter count.
84    pub fn params_str(&self) -> String {
85        if self.num_parameters >= 1_000_000_000 {
86            format!("{:.1}B", self.num_parameters as f64 / 1_000_000_000.0)
87        } else if self.num_parameters >= 1_000_000 {
88            format!("{:.1}M", self.num_parameters as f64 / 1_000_000.0)
89        } else if self.num_parameters >= 1_000 {
90            format!("{:.1}K", self.num_parameters as f64 / 1_000.0)
91        } else {
92            format!("{}", self.num_parameters)
93        }
94    }
95}
96
// =============================================================================
// Registry Functions
// =============================================================================

/// Get all available models across every enabled category.
///
/// Gated on `any(...)` rather than `all(...)`: the hub is useful with only one
/// model-family feature enabled, and the per-feature blocks inside already
/// select which registries contribute. (Under the old `all(...)` gate those
/// inner `#[cfg]` blocks were redundant.)
#[cfg(any(feature = "vision", feature = "llm"))]
pub fn list_all_models() -> Vec<UnifiedModelInfo> {
    let mut models = Vec::new();

    // Vision models from the vision hub registry.
    #[cfg(feature = "vision")]
    {
        for (name, info) in axonml_vision::hub::model_registry() {
            models.push(UnifiedModelInfo {
                name: name.clone(),
                category: ModelCategory::Vision,
                architecture: extract_architecture(&name),
                // The vision registry exposes no parameter count, so estimate
                // from checkpoint size (assumes fp32 weights; see helper).
                num_parameters: estimate_params_from_size(info.size_bytes),
                size_bytes: info.size_bytes,
                url: info.url.clone(),
                dataset: info.dataset.clone(),
                description: format!(
                    "{} trained on {} (Top-1: {:.1}%)",
                    name, info.dataset, info.accuracy
                ),
                tags: generate_vision_tags(&name, &info),
            });
        }
    }

    // Language models from the LLM hub registry.
    #[cfg(feature = "llm")]
    {
        for (name, info) in axonml_llm::hub::llm_registry() {
            models.push(UnifiedModelInfo {
                name: name.clone(),
                category: ModelCategory::Language,
                architecture: info.architecture.clone(),
                num_parameters: info.num_parameters,
                size_bytes: info.size_bytes,
                url: info.url.clone(),
                dataset: info.dataset.clone(),
                description: format!(
                    "{} ({} params, {} layers)",
                    name,
                    format_params(info.num_parameters),
                    info.num_layers
                ),
                tags: generate_llm_tags(&name, &info),
            });
        }
    }

    models
}

/// Search models by name, architecture, or tag (case-insensitive substring).
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn search_models(query: &str) -> Vec<UnifiedModelInfo> {
    let needle = query.to_lowercase();
    let is_hit = |m: &UnifiedModelInfo| -> bool {
        m.name.to_lowercase().contains(&needle)
            || m.architecture.to_lowercase().contains(&needle)
            || m.tags.iter().any(|tag| tag.to_lowercase().contains(&needle))
    };
    list_all_models().into_iter().filter(|m| is_hit(m)).collect()
}

/// Get all models belonging to the given category.
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn models_by_category(category: ModelCategory) -> Vec<UnifiedModelInfo> {
    let mut matching = Vec::new();
    for model in list_all_models() {
        if model.category == category {
            matching.push(model);
        }
    }
    matching
}

/// Get models whose checkpoint fits within `max_mb` megabytes, smallest first.
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn models_by_max_size_mb(max_mb: f64) -> Vec<UnifiedModelInfo> {
    // Decimal megabytes, matching `UnifiedModelInfo::size_mb`.
    let budget_bytes = (max_mb * 1_000_000.0) as u64;
    let mut within_budget = list_all_models();
    within_budget.retain(|m| m.size_bytes <= budget_bytes);
    within_budget.sort_by_key(|m| m.size_bytes);
    within_budget
}

/// Get models whose parameter count is at most `max_params`, smallest first.
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn models_by_max_params(max_params: u64) -> Vec<UnifiedModelInfo> {
    let mut candidates = list_all_models();
    candidates.retain(|m| m.num_parameters <= max_params);
    candidates.sort_by_key(|m| m.num_parameters);
    candidates
}

/// Get recommended models for a free-form task description.
///
/// Keywords in `task` select a category (checked in order: image/vision,
/// text/NLP, then chat/generate); recommendations are the smallest models
/// first, capped at five. Unrecognized tasks fall back to the single smallest
/// model from each category.
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn recommended_models(task: &str) -> Vec<UnifiedModelInfo> {
    let task_lower = task.to_lowercase();

    if task_lower.contains("image")
        || task_lower.contains("vision")
        || task_lower.contains("classify")
    {
        // Image classification: cheapest-to-run (smallest checkpoint) first.
        // NOTE(review): if per-model accuracy becomes available on
        // `UnifiedModelInfo`, rank by an accuracy/size ratio instead.
        let mut models = models_by_category(ModelCategory::Vision);
        models.sort_by_key(|m| m.size_bytes);
        models.truncate(5);
        models
    } else if task_lower.contains("text")
        || task_lower.contains("nlp")
        || task_lower.contains("language")
    {
        // NLP tasks: smallest parameter counts first.
        let mut models = models_by_category(ModelCategory::Language);
        models.sort_by_key(|m| m.num_parameters);
        models.truncate(5);
        models
    } else if task_lower.contains("chat")
        || task_lower.contains("instruct")
        || task_lower.contains("generate")
    {
        // Text generation: instruction-tuned models (tagged "instruct").
        search_models("instruct")
    } else {
        // Fallback: the single smallest model from each category.
        let mut result = Vec::new();
        for category in [ModelCategory::Vision, ModelCategory::Language] {
            let mut cat_models = models_by_category(category);
            cat_models.sort_by_key(|m| m.size_bytes);
            if let Some(m) = cat_models.into_iter().next() {
                result.push(m);
            }
        }
        result
    }
}

// =============================================================================
// Helper Functions (used by vision-gated callers below; appears dead to clippy
// when compiled without --features vision, but is reachable at runtime)
// =============================================================================

/// Infer the architecture family from a vision model's registry name prefix.
///
/// Unknown prefixes map to `"Unknown"`.
#[allow(dead_code)]
fn extract_architecture(name: &str) -> String {
    // (name prefix, family label) — first match wins.
    const FAMILIES: [(&str, &str); 8] = [
        ("resnet", "ResNet"),
        ("vgg", "VGG"),
        ("mobilenet", "MobileNet"),
        ("efficientnet", "EfficientNet"),
        ("densenet", "DenseNet"),
        ("vit", "ViT"),
        ("swin", "Swin"),
        ("convnext", "ConvNeXt"),
    ];
    FAMILIES
        .iter()
        .find(|(prefix, _)| name.starts_with(prefix))
        .map(|(_, family)| (*family).to_string())
        .unwrap_or_else(|| "Unknown".to_string())
}

/// Rough parameter-count estimate from checkpoint size, assuming fp32 weights.
#[allow(dead_code)]
fn estimate_params_from_size(size_bytes: u64) -> u64 {
    // One float32 parameter occupies four bytes.
    const BYTES_PER_F32_PARAM: u64 = 4;
    size_bytes / BYTES_PER_F32_PARAM
}

/// Format a parameter count as "1.5B" / "110.0M" / "50.0K".
///
/// Counts below 1,000 are printed verbatim — previously they fell into the
/// "K" branch (e.g. 500 -> "0.5K"), inconsistent with
/// `UnifiedModelInfo::params_str`.
#[allow(dead_code)]
fn format_params(params: u64) -> String {
    if params >= 1_000_000_000 {
        format!("{:.1}B", params as f64 / 1_000_000_000.0)
    } else if params >= 1_000_000 {
        format!("{:.1}M", params as f64 / 1_000_000.0)
    } else if params >= 1_000 {
        format!("{:.1}K", params as f64 / 1_000.0)
    } else {
        params.to_string()
    }
}

/// Build search tags for a vision model from its registry name.
#[cfg(feature = "vision")]
fn generate_vision_tags(name: &str, _info: &axonml_vision::hub::PretrainedModel) -> Vec<String> {
    // Base tags shared by every vision model.
    let mut tags: Vec<String> = ["vision", "image", "classification"]
        .iter()
        .map(|t| (*t).to_string())
        .collect();

    let has = |needle: &str| name.contains(needle);

    if has("mobile") {
        tags.push("mobile".to_string());
        tags.push("efficient".to_string());
    }
    if has("efficient") {
        // A name matching both "mobile" and "efficient" repeats the tag,
        // mirroring the original push order exactly.
        tags.push("efficient".to_string());
    }
    if has("vit") || has("swin") {
        tags.push("transformer".to_string());
    }

    tags
}

/// Build search tags for a language model from its name and registry info.
#[cfg(feature = "llm")]
fn generate_llm_tags(name: &str, info: &axonml_llm::hub::PretrainedLLM) -> Vec<String> {
    let mut tags = Vec::with_capacity(7);
    for base in ["language", "nlp", "text"] {
        tags.push(base.to_string());
    }

    // Lowercased architecture name doubles as a tag (e.g. "llama").
    tags.push(info.architecture.to_lowercase());

    if name.contains("instruct") || name.contains("chat") {
        tags.push("instruct".to_string());
        tags.push("chat".to_string());
    }

    // Coarse size class: <1B small, <10B medium, otherwise large.
    let size_class = if info.num_parameters < 1_000_000_000 {
        "small"
    } else if info.num_parameters < 10_000_000_000 {
        "medium"
    } else {
        "large"
    };
    tags.push(size_class.to_string());

    tags
}

// =============================================================================
// Model Benchmark Utilities
// =============================================================================

/// Summary statistics collected from benchmarking one model.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Name of the benchmarked model.
    pub model_name: String,
    /// Mean inference latency in milliseconds.
    pub avg_latency_ms: f64,
    /// 95th-percentile (nearest-rank) latency in milliseconds.
    pub p95_latency_ms: f64,
    /// Throughput in samples per second.
    pub throughput: f64,
    /// Peak memory usage in bytes.
    pub peak_memory_bytes: u64,
    /// Number of benchmark iterations that were run.
    pub iterations: usize,
}

364impl BenchmarkResult {
365    /// Create a new benchmark result.
366    pub fn new(model_name: &str, latencies_ms: &[f64], peak_memory_bytes: u64) -> Self {
367        let iterations = latencies_ms.len();
368        let avg_latency_ms = if iterations > 0 {
369            latencies_ms.iter().sum::<f64>() / iterations as f64
370        } else {
371            0.0
372        };
373
374        // Calculate p95
375        let mut sorted = latencies_ms.to_vec();
376        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
377        let p95_idx = (iterations as f64 * 0.95) as usize;
378        let p95_latency_ms = sorted
379            .get(p95_idx.min(iterations.saturating_sub(1)))
380            .copied()
381            .unwrap_or(0.0);
382
383        let throughput = if avg_latency_ms > 0.0 {
384            1000.0 / avg_latency_ms
385        } else {
386            0.0
387        };
388
389        Self {
390            model_name: model_name.to_string(),
391            avg_latency_ms,
392            p95_latency_ms,
393            throughput,
394            peak_memory_bytes,
395            iterations,
396        }
397    }
398
399    /// Print a formatted summary.
400    pub fn print_summary(&self) {
401        println!("Benchmark: {}", self.model_name);
402        println!("  Iterations: {}", self.iterations);
403        println!("  Avg latency: {:.2} ms", self.avg_latency_ms);
404        println!("  P95 latency: {:.2} ms", self.p95_latency_ms);
405        println!("  Throughput: {:.1} samples/sec", self.throughput);
406        println!(
407            "  Peak memory: {:.1} MB",
408            self.peak_memory_bytes as f64 / 1_000_000.0
409        );
410    }
411}
412
413/// Compare multiple benchmark results.
414pub fn compare_benchmarks(results: &[BenchmarkResult]) {
415    if results.is_empty() {
416        println!("No benchmark results to compare.");
417        return;
418    }
419
420    println!(
421        "\n{:<25} {:>12} {:>12} {:>14} {:>12}",
422        "Model", "Avg (ms)", "P95 (ms)", "Throughput", "Memory (MB)"
423    );
424    println!("{}", "-".repeat(80));
425
426    for result in results {
427        println!(
428            "{:<25} {:>12.2} {:>12.2} {:>12.1}/s {:>12.1}",
429            result.model_name,
430            result.avg_latency_ms,
431            result.p95_latency_ms,
432            result.throughput,
433            result.peak_memory_bytes as f64 / 1_000_000.0
434        );
435    }
436}
437
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_category_display() {
        assert_eq!(ModelCategory::Vision.to_string(), "Vision");
        assert_eq!(ModelCategory::Language.to_string(), "Language");
    }

    #[test]
    fn test_unified_model_info_size() {
        let info = UnifiedModelInfo {
            name: "test".to_string(),
            category: ModelCategory::Vision,
            architecture: "Test".to_string(),
            num_parameters: 1_500_000_000,
            size_bytes: 6_000_000_000,
            url: "https://example.com".to_string(),
            dataset: "Test".to_string(),
            description: "Test model".to_string(),
            tags: vec!["test".to_string()],
        };

        assert!((info.size_mb() - 6000.0).abs() < 0.1);
        assert_eq!(info.params_str(), "1.5B");
    }

    #[test]
    fn test_benchmark_result() {
        let latencies = [10.0, 12.0, 11.0, 15.0, 10.5];
        let result = BenchmarkResult::new("test_model", &latencies, 100_000_000);

        assert_eq!(result.iterations, 5);
        assert!(result.avg_latency_ms > 0.0);
        assert!(result.throughput > 0.0);
    }

    #[test]
    fn test_extract_architecture() {
        for (input, expected) in [
            ("resnet50", "ResNet"),
            ("vgg16", "VGG"),
            ("mobilenet_v2", "MobileNet"),
            ("vit_b_16", "ViT"),
        ] {
            assert_eq!(extract_architecture(input), expected);
        }
    }

    #[test]
    fn test_format_params() {
        assert_eq!(format_params(1_500_000_000), "1.5B");
        assert_eq!(format_params(110_000_000), "110.0M");
        assert_eq!(format_params(50_000), "50.0K");
    }

    #[cfg(all(feature = "vision", feature = "llm"))]
    #[test]
    fn test_list_all_models() {
        assert!(!list_all_models().is_empty());
    }

    #[cfg(all(feature = "vision", feature = "llm"))]
    #[test]
    fn test_search_models() {
        for model in search_models("resnet") {
            let hit = model.name.to_lowercase().contains("resnet")
                || model.architecture.to_lowercase().contains("resnet");
            assert!(hit);
        }
    }
}