1#[allow(unused_imports)]
20use std::collections::HashMap;
21
/// High-level grouping used to organize models in the unified registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModelCategory {
    /// Image models (the vision hub registry).
    Vision,
    /// Text / language models (the LLM hub registry).
    Language,
    /// Audio models. NOTE(review): not produced by the registries visible
    /// in this file — confirm a source exists elsewhere.
    Audio,
    /// Multimodal models. NOTE(review): likewise unused by the visible
    /// registries.
    Multimodal,
}

impl std::fmt::Display for ModelCategory {
    /// Renders the category as its variant name (e.g. "Vision").
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ModelCategory::Vision => "Vision",
            ModelCategory::Language => "Language",
            ModelCategory::Audio => "Audio",
            ModelCategory::Multimodal => "Multimodal",
        };
        f.write_str(label)
    }
}
49
50#[derive(Debug, Clone)]
56pub struct UnifiedModelInfo {
57 pub name: String,
59 pub category: ModelCategory,
61 pub architecture: String,
63 pub num_parameters: u64,
65 pub size_bytes: u64,
67 pub url: String,
69 pub dataset: String,
71 pub description: String,
73 pub tags: Vec<String>,
75}
76
77impl UnifiedModelInfo {
78 pub fn size_mb(&self) -> f64 {
80 self.size_bytes as f64 / 1_000_000.0
81 }
82
83 pub fn params_str(&self) -> String {
85 if self.num_parameters >= 1_000_000_000 {
86 format!("{:.1}B", self.num_parameters as f64 / 1_000_000_000.0)
87 } else if self.num_parameters >= 1_000_000 {
88 format!("{:.1}M", self.num_parameters as f64 / 1_000_000.0)
89 } else if self.num_parameters >= 1_000 {
90 format!("{:.1}K", self.num_parameters as f64 / 1_000.0)
91 } else {
92 format!("{}", self.num_parameters)
93 }
94 }
95}
96
/// Aggregates every model known to the enabled hub registries into one
/// uniform listing.
///
/// NOTE(review): the outer `cfg` requires *both* features, so the inner
/// per-feature `cfg` blocks below are always compiled here; confirm whether
/// the outer gate was meant to be `any(...)` for single-feature builds.
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn list_all_models() -> Vec<UnifiedModelInfo> {
    let mut models = Vec::new();

    #[cfg(feature = "vision")]
    {
        // Vision registry entries carry no explicit parameter count, so it
        // is estimated from the checkpoint size.
        models.extend(axonml_vision::hub::model_registry().into_iter().map(
            |(name, info)| UnifiedModelInfo {
                name: name.clone(),
                category: ModelCategory::Vision,
                architecture: extract_architecture(&name),
                num_parameters: estimate_params_from_size(info.size_bytes),
                size_bytes: info.size_bytes,
                url: info.url.clone(),
                dataset: info.dataset.clone(),
                description: format!(
                    "{} trained on {} (Top-1: {:.1}%)",
                    name, info.dataset, info.accuracy
                ),
                tags: generate_vision_tags(&name, &info),
            },
        ));
    }

    #[cfg(feature = "llm")]
    {
        // LLM registry entries expose architecture and parameter counts
        // directly.
        models.extend(axonml_llm::hub::llm_registry().into_iter().map(
            |(name, info)| UnifiedModelInfo {
                name: name.clone(),
                category: ModelCategory::Language,
                architecture: info.architecture.clone(),
                num_parameters: info.num_parameters,
                size_bytes: info.size_bytes,
                url: info.url.clone(),
                dataset: info.dataset.clone(),
                description: format!(
                    "{} ({} params, {} layers)",
                    name,
                    format_params(info.num_parameters),
                    info.num_layers
                ),
                tags: generate_llm_tags(&name, &info),
            },
        ));
    }

    models
}
152
/// Case-insensitive substring search over model name, architecture, and tags.
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn search_models(query: &str) -> Vec<UnifiedModelInfo> {
    let needle = query.to_lowercase();
    // Shared predicate: lowercase the candidate, then substring-match.
    let hit = |s: &str| s.to_lowercase().contains(&needle);

    list_all_models()
        .into_iter()
        .filter(|m| hit(&m.name) || hit(&m.architecture) || m.tags.iter().any(|t| hit(t)))
        .collect()
}
168
/// Returns every registered model belonging to the given category.
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn models_by_category(category: ModelCategory) -> Vec<UnifiedModelInfo> {
    let mut models = list_all_models();
    models.retain(|m| m.category == category);
    models
}
177
/// Returns all models at most `max_mb` (decimal) megabytes large, smallest
/// first.
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn models_by_max_size_mb(max_mb: f64) -> Vec<UnifiedModelInfo> {
    // `as u64` saturates, so a negative budget becomes 0 and matches nothing.
    let max_bytes = (max_mb * 1_000_000.0) as u64;
    let mut models = list_all_models();
    models.retain(|m| m.size_bytes <= max_bytes);
    models.sort_by_key(|m| m.size_bytes);
    models
}
189
/// Returns all models with at most `max_params` parameters, smallest first.
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn models_by_max_params(max_params: u64) -> Vec<UnifiedModelInfo> {
    let mut models = list_all_models();
    models.retain(|m| m.num_parameters <= max_params);
    models.sort_by_key(|m| m.num_parameters);
    models
}
200
/// Recommends up to five models for a free-form task description.
///
/// Keyword routing on the lowercased task string:
/// * "image" / "vision" / "classify"  -> smallest vision models (by size)
/// * "text" / "nlp" / "language"      -> smallest language models (by params)
/// * "chat" / "instruct" / "generate" -> models matching "instruct"
/// * anything else                    -> smallest model from each category
#[cfg(all(feature = "vision", feature = "llm"))]
pub fn recommended_models(task: &str) -> Vec<UnifiedModelInfo> {
    let task_lower = task.to_lowercase();

    if task_lower.contains("image")
        || task_lower.contains("vision")
        || task_lower.contains("classify")
    {
        let mut models = models_by_category(ModelCategory::Vision);
        // Fix: sort on the integer byte count directly. The previous code
        // cast `size_bytes` to f64 and compared with `partial_cmp` plus an
        // `Equal` fallback — lossy above 2^53 bytes and needlessly fallible.
        models.sort_by_key(|m| m.size_bytes);
        models.truncate(5);
        models
    } else if task_lower.contains("text")
        || task_lower.contains("nlp")
        || task_lower.contains("language")
    {
        let mut models = models_by_category(ModelCategory::Language);
        models.sort_by_key(|m| m.num_parameters);
        models.truncate(5);
        models
    } else if task_lower.contains("chat")
        || task_lower.contains("instruct")
        || task_lower.contains("generate")
    {
        search_models("instruct")
    } else {
        // Unknown task: offer the single smallest model from each major
        // category. `min_by_key` returns the first minimum on ties, which
        // matches the old stable-sort-then-take-first behavior.
        [ModelCategory::Vision, ModelCategory::Language]
            .into_iter()
            .filter_map(|category| {
                models_by_category(category)
                    .into_iter()
                    .min_by_key(|m| m.size_bytes)
            })
            .collect()
    }
}
250
#[allow(dead_code)]
/// Maps a vision-model name to its architecture family via its prefix;
/// unrecognized names yield "Unknown".
fn extract_architecture(name: &str) -> String {
    // Prefix table, checked in order; the first match wins.
    const FAMILIES: [(&str, &str); 8] = [
        ("resnet", "ResNet"),
        ("vgg", "VGG"),
        ("mobilenet", "MobileNet"),
        ("efficientnet", "EfficientNet"),
        ("densenet", "DenseNet"),
        ("vit", "ViT"),
        ("swin", "Swin"),
        ("convnext", "ConvNeXt"),
    ];

    FAMILIES
        .iter()
        .find(|(prefix, _)| name.starts_with(*prefix))
        .map(|(_, family)| (*family).to_string())
        .unwrap_or_else(|| "Unknown".to_string())
}
278
#[allow(dead_code)]
/// Rough parameter count derived from a checkpoint's byte size, assuming
/// ~4 bytes per parameter (consistent with f32 weights) and negligible
/// non-weight overhead.
fn estimate_params_from_size(size_bytes: u64) -> u64 {
    const BYTES_PER_PARAM: u64 = 4;
    size_bytes / BYTES_PER_PARAM
}
284
#[allow(dead_code)]
/// Formats a parameter count as a short human string: "1.5B", "110.0M",
/// "50.0K", or the plain number below one thousand.
///
/// Fix: values under one thousand previously fell through to the "K"
/// branch (500 -> "0.5K", 0 -> "0.0K"); they now print verbatim, matching
/// `UnifiedModelInfo::params_str`.
fn format_params(params: u64) -> String {
    if params >= 1_000_000_000 {
        format!("{:.1}B", params as f64 / 1_000_000_000.0)
    } else if params >= 1_000_000 {
        format!("{:.1}M", params as f64 / 1_000_000.0)
    } else if params >= 1_000 {
        format!("{:.1}K", params as f64 / 1_000.0)
    } else {
        params.to_string()
    }
}
295
#[cfg(feature = "vision")]
/// Builds searchable tag strings for a vision-hub model.
///
/// Every vision model gets the base tags; the rest are derived from the
/// model name only (the registry metadata parameter is currently unused).
fn generate_vision_tags(name: &str, _info: &axonml_vision::hub::PretrainedModel) -> Vec<String> {
    let mut tags: Vec<String> = ["vision", "image", "classification"]
        .into_iter()
        .map(String::from)
        .collect();

    if name.contains("mobile") {
        tags.extend(["mobile".to_string(), "efficient".to_string()]);
    }
    // A name matching both "mobile" and "efficient" tags "efficient" twice,
    // exactly as the original did.
    if name.contains("efficient") {
        tags.push("efficient".to_string());
    }
    if name.contains("vit") || name.contains("swin") {
        tags.push("transformer".to_string());
    }

    tags
}
317
#[cfg(feature = "llm")]
/// Builds searchable tag strings for an LLM-hub model from its name and
/// registry metadata.
fn generate_llm_tags(name: &str, info: &axonml_llm::hub::PretrainedLLM) -> Vec<String> {
    let mut tags: Vec<String> = ["language", "nlp", "text"]
        .into_iter()
        .map(String::from)
        .collect();

    // The lowercased architecture name doubles as a tag.
    tags.push(info.architecture.to_lowercase());

    if name.contains("instruct") || name.contains("chat") {
        tags.extend(["instruct".to_string(), "chat".to_string()]);
    }

    // Coarse size class from the parameter count (<1B / <10B / otherwise).
    let size_class = if info.num_parameters < 1_000_000_000 {
        "small"
    } else if info.num_parameters < 10_000_000_000 {
        "medium"
    } else {
        "large"
    };
    tags.push(size_class.to_string());

    tags
}
342
/// Latency/throughput summary for one benchmarked model.
#[derive(Debug, Clone)]
pub struct BenchmarkResult {
    /// Name of the benchmarked model.
    pub model_name: String,
    /// Mean latency over all iterations, in milliseconds.
    pub avg_latency_ms: f64,
    /// 95th-percentile latency, in milliseconds.
    pub p95_latency_ms: f64,
    /// Samples per second implied by the mean latency.
    pub throughput: f64,
    /// Peak memory observed during the run, in bytes.
    pub peak_memory_bytes: u64,
    /// Number of latency samples the summary was computed from.
    pub iterations: usize,
}

impl BenchmarkResult {
    /// Summarizes a set of per-iteration latencies (milliseconds).
    ///
    /// With an empty `latencies_ms` slice every statistic is 0.0.
    pub fn new(model_name: &str, latencies_ms: &[f64], peak_memory_bytes: u64) -> Self {
        let iterations = latencies_ms.len();

        let avg_latency_ms = match iterations {
            0 => 0.0,
            n => latencies_ms.iter().sum::<f64>() / n as f64,
        };

        // P95 via index floor(0.95 * n) into the sorted samples, clamped to
        // the last element; NaNs compare as Equal and keep their position.
        let mut sorted = latencies_ms.to_vec();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let p95_idx = ((iterations as f64 * 0.95) as usize).min(iterations.saturating_sub(1));
        let p95_latency_ms = sorted.get(p95_idx).copied().unwrap_or(0.0);

        // Samples per second; zero whenever the average is zero.
        let throughput = match avg_latency_ms {
            avg if avg > 0.0 => 1000.0 / avg,
            _ => 0.0,
        };

        Self {
            model_name: model_name.to_string(),
            avg_latency_ms,
            p95_latency_ms,
            throughput,
            peak_memory_bytes,
            iterations,
        }
    }

    /// Prints a human-readable multi-line summary to stdout.
    pub fn print_summary(&self) {
        let peak_mb = self.peak_memory_bytes as f64 / 1_000_000.0;
        println!("Benchmark: {}", self.model_name);
        println!(" Iterations: {}", self.iterations);
        println!(" Avg latency: {:.2} ms", self.avg_latency_ms);
        println!(" P95 latency: {:.2} ms", self.p95_latency_ms);
        println!(" Throughput: {:.1} samples/sec", self.throughput);
        println!(" Peak memory: {:.1} MB", peak_mb);
    }
}
412
413pub fn compare_benchmarks(results: &[BenchmarkResult]) {
415 if results.is_empty() {
416 println!("No benchmark results to compare.");
417 return;
418 }
419
420 println!(
421 "\n{:<25} {:>12} {:>12} {:>14} {:>12}",
422 "Model", "Avg (ms)", "P95 (ms)", "Throughput", "Memory (MB)"
423 );
424 println!("{}", "-".repeat(80));
425
426 for result in results {
427 println!(
428 "{:<25} {:>12.2} {:>12.2} {:>12.1}/s {:>12.1}",
429 result.model_name,
430 result.avg_latency_ms,
431 result.p95_latency_ms,
432 result.throughput,
433 result.peak_memory_bytes as f64 / 1_000_000.0
434 );
435 }
436}
437
#[cfg(test)]
mod tests {
    use super::*;

    // `Display` should render the variant name verbatim.
    #[test]
    fn test_model_category_display() {
        assert_eq!(format!("{}", ModelCategory::Vision), "Vision");
        assert_eq!(format!("{}", ModelCategory::Language), "Language");
    }

    // size_mb() uses decimal megabytes; params_str() picks the "B" suffix
    // for counts of one billion and above.
    #[test]
    fn test_unified_model_info_size() {
        let info = UnifiedModelInfo {
            name: "test".to_string(),
            category: ModelCategory::Vision,
            architecture: "Test".to_string(),
            num_parameters: 1_500_000_000,
            size_bytes: 6_000_000_000,
            url: "https://example.com".to_string(),
            dataset: "Test".to_string(),
            description: "Test model".to_string(),
            tags: vec!["test".to_string()],
        };

        assert!((info.size_mb() - 6000.0).abs() < 0.1);
        assert_eq!(info.params_str(), "1.5B");
    }

    // Sanity-checks the derived statistics on a small latency sample.
    #[test]
    fn test_benchmark_result() {
        let latencies = vec![10.0, 12.0, 11.0, 15.0, 10.5];
        let result = BenchmarkResult::new("test_model", &latencies, 100_000_000);

        assert_eq!(result.iterations, 5);
        assert!(result.avg_latency_ms > 0.0);
        assert!(result.throughput > 0.0);
    }

    // Prefix-based architecture detection for representative model names.
    #[test]
    fn test_extract_architecture() {
        assert_eq!(extract_architecture("resnet50"), "ResNet");
        assert_eq!(extract_architecture("vgg16"), "VGG");
        assert_eq!(extract_architecture("mobilenet_v2"), "MobileNet");
        assert_eq!(extract_architecture("vit_b_16"), "ViT");
    }

    // B/M/K suffix selection for parameter counts.
    #[test]
    fn test_format_params() {
        assert_eq!(format_params(1_500_000_000), "1.5B");
        assert_eq!(format_params(110_000_000), "110.0M");
        assert_eq!(format_params(50_000), "50.0K");
    }

    // Only meaningful when both hub registries are compiled in.
    #[cfg(all(feature = "vision", feature = "llm"))]
    #[test]
    fn test_list_all_models() {
        let models = list_all_models();
        assert!(!models.is_empty());
    }

    // Every hit must match the query in its name or architecture; note
    // that tag-only matches would not satisfy this assertion.
    #[cfg(all(feature = "vision", feature = "llm"))]
    #[test]
    fn test_search_models() {
        let results = search_models("resnet");
        for model in &results {
            assert!(
                model.name.to_lowercase().contains("resnet")
                    || model.architecture.to_lowercase().contains("resnet")
            );
        }
    }
}