// axonml_vision/hub.rs
//! Model Hub - Pretrained Weights Management
//!
//! Download, cache, and load pretrained model weights.
//!
//! @version 0.1.0
//! @author AutomataNexus Development Team

use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};

use axonml_tensor::Tensor;
14
// =============================================================================
// Error Type
// =============================================================================

/// Errors that can occur during hub operations (download, cache, load).
#[derive(Debug)]
pub enum HubError {
    /// Network error during download (HTTP failure or transport error).
    NetworkError(String),
    /// IO error while touching the cache directory or weight files.
    IoError(std::io::Error),
    /// Model name is not present in the pretrained-model registry.
    ModelNotFound(String),
    /// Invalid weight format (e.g. tensor reconstruction failed on load).
    InvalidFormat(String),
    /// Checksum mismatch between expected and actual hash.
    ChecksumMismatch {
        /// Expected checksum value.
        expected: String,
        /// Actual computed checksum.
        actual: String,
    },
}
38
39impl std::fmt::Display for HubError {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            HubError::NetworkError(e) => write!(f, "Network error: {}", e),
43            HubError::IoError(e) => write!(f, "IO error: {}", e),
44            HubError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
45            HubError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
46            HubError::ChecksumMismatch { expected, actual } => {
47                write!(f, "Checksum mismatch: expected {}, got {}", expected, actual)
48            }
49        }
50    }
51}
52
53impl std::error::Error for HubError {}
54
55impl From<std::io::Error> for HubError {
56    fn from(e: std::io::Error) -> Self {
57        HubError::IoError(e)
58    }
59}
60
/// Result type for hub operations; every fallible API here returns this.
pub type HubResult<T> = Result<T, HubError>;
63
// =============================================================================
// Pretrained Model Registry
// =============================================================================

/// Information about a pretrained model available for download.
#[derive(Debug, Clone)]
pub struct PretrainedModel {
    /// Model name (e.g., "resnet18"); also the registry key.
    pub name: String,
    /// URL to download weights.
    pub url: String,
    /// Expected SHA256 checksum (optional; currently unset for all entries).
    pub checksum: Option<String>,
    /// File size in bytes (approximate; used for the download progress note).
    pub size_bytes: u64,
    /// Number of classes the model was trained on.
    pub num_classes: usize,
    /// Input image size (height, width).
    pub input_size: (usize, usize),
    /// Dataset trained on.
    pub dataset: String,
    /// Top-1 accuracy on validation set, in percent.
    pub accuracy: f32,
}
88
89/// Get the cache directory for pretrained weights.
90pub fn cache_dir() -> PathBuf {
91    let base = dirs::cache_dir()
92        .or_else(dirs::home_dir)
93        .unwrap_or_else(|| PathBuf::from("."));
94    base.join("axonml").join("hub").join("weights")
95}
96
97/// Get registry of available pretrained models.
98pub fn model_registry() -> HashMap<String, PretrainedModel> {
99    let mut registry = HashMap::new();
100
101    // ResNet models (ImageNet pretrained)
102    registry.insert(
103        "resnet18".to_string(),
104        PretrainedModel {
105            name: "resnet18".to_string(),
106            url: "https://huggingface.co/axonml-ml/resnet18-imagenet/resolve/main/resnet18.safetensors".to_string(),
107            checksum: None,
108            size_bytes: 44_700_000,
109            num_classes: 1000,
110            input_size: (224, 224),
111            dataset: "ImageNet-1K".to_string(),
112            accuracy: 69.76,
113        },
114    );
115
116    registry.insert(
117        "resnet34".to_string(),
118        PretrainedModel {
119            name: "resnet34".to_string(),
120            url: "https://huggingface.co/axonml-ml/resnet34-imagenet/resolve/main/resnet34.safetensors".to_string(),
121            checksum: None,
122            size_bytes: 83_300_000,
123            num_classes: 1000,
124            input_size: (224, 224),
125            dataset: "ImageNet-1K".to_string(),
126            accuracy: 73.31,
127        },
128    );
129
130    registry.insert(
131        "resnet50".to_string(),
132        PretrainedModel {
133            name: "resnet50".to_string(),
134            url: "https://huggingface.co/axonml-ml/resnet50-imagenet/resolve/main/resnet50.safetensors".to_string(),
135            checksum: None,
136            size_bytes: 97_800_000,
137            num_classes: 1000,
138            input_size: (224, 224),
139            dataset: "ImageNet-1K".to_string(),
140            accuracy: 76.13,
141        },
142    );
143
144    // VGG models (ImageNet pretrained)
145    registry.insert(
146        "vgg16".to_string(),
147        PretrainedModel {
148            name: "vgg16".to_string(),
149            url: "https://huggingface.co/axonml-ml/vgg16-imagenet/resolve/main/vgg16.safetensors".to_string(),
150            checksum: None,
151            size_bytes: 528_000_000,
152            num_classes: 1000,
153            input_size: (224, 224),
154            dataset: "ImageNet-1K".to_string(),
155            accuracy: 71.59,
156        },
157    );
158
159    registry.insert(
160        "vgg19".to_string(),
161        PretrainedModel {
162            name: "vgg19".to_string(),
163            url: "https://huggingface.co/axonml-ml/vgg19-imagenet/resolve/main/vgg19.safetensors".to_string(),
164            checksum: None,
165            size_bytes: 548_000_000,
166            num_classes: 1000,
167            input_size: (224, 224),
168            dataset: "ImageNet-1K".to_string(),
169            accuracy: 72.38,
170        },
171    );
172
173    registry.insert(
174        "vgg16_bn".to_string(),
175        PretrainedModel {
176            name: "vgg16_bn".to_string(),
177            url: "https://huggingface.co/axonml-ml/vgg16bn-imagenet/resolve/main/vgg16_bn.safetensors".to_string(),
178            checksum: None,
179            size_bytes: 528_000_000,
180            num_classes: 1000,
181            input_size: (224, 224),
182            dataset: "ImageNet-1K".to_string(),
183            accuracy: 73.36,
184        },
185    );
186
187    // Larger ResNet variants
188    registry.insert(
189        "resnet101".to_string(),
190        PretrainedModel {
191            name: "resnet101".to_string(),
192            url: "https://huggingface.co/axonml-ml/resnet101-imagenet/resolve/main/resnet101.safetensors".to_string(),
193            checksum: None,
194            size_bytes: 170_500_000,
195            num_classes: 1000,
196            input_size: (224, 224),
197            dataset: "ImageNet-1K".to_string(),
198            accuracy: 77.37,
199        },
200    );
201
202    registry.insert(
203        "resnet152".to_string(),
204        PretrainedModel {
205            name: "resnet152".to_string(),
206            url: "https://huggingface.co/axonml-ml/resnet152-imagenet/resolve/main/resnet152.safetensors".to_string(),
207            checksum: None,
208            size_bytes: 230_400_000,
209            num_classes: 1000,
210            input_size: (224, 224),
211            dataset: "ImageNet-1K".to_string(),
212            accuracy: 78.31,
213        },
214    );
215
216    // Mobile-optimized models
217    registry.insert(
218        "mobilenet_v2".to_string(),
219        PretrainedModel {
220            name: "mobilenet_v2".to_string(),
221            url: "https://huggingface.co/axonml-ml/mobilenetv2-imagenet/resolve/main/mobilenet_v2.safetensors".to_string(),
222            checksum: None,
223            size_bytes: 13_600_000,
224            num_classes: 1000,
225            input_size: (224, 224),
226            dataset: "ImageNet-1K".to_string(),
227            accuracy: 71.88,
228        },
229    );
230
231    registry.insert(
232        "mobilenet_v3_small".to_string(),
233        PretrainedModel {
234            name: "mobilenet_v3_small".to_string(),
235            url: "https://huggingface.co/axonml-ml/mobilenetv3-small-imagenet/resolve/main/mobilenet_v3_small.safetensors".to_string(),
236            checksum: None,
237            size_bytes: 9_800_000,
238            num_classes: 1000,
239            input_size: (224, 224),
240            dataset: "ImageNet-1K".to_string(),
241            accuracy: 67.67,
242        },
243    );
244
245    registry.insert(
246        "mobilenet_v3_large".to_string(),
247        PretrainedModel {
248            name: "mobilenet_v3_large".to_string(),
249            url: "https://huggingface.co/axonml-ml/mobilenetv3-large-imagenet/resolve/main/mobilenet_v3_large.safetensors".to_string(),
250            checksum: None,
251            size_bytes: 21_100_000,
252            num_classes: 1000,
253            input_size: (224, 224),
254            dataset: "ImageNet-1K".to_string(),
255            accuracy: 74.04,
256        },
257    );
258
259    // EfficientNet family
260    registry.insert(
261        "efficientnet_b0".to_string(),
262        PretrainedModel {
263            name: "efficientnet_b0".to_string(),
264            url: "https://huggingface.co/axonml-ml/efficientnet-b0-imagenet/resolve/main/efficientnet_b0.safetensors".to_string(),
265            checksum: None,
266            size_bytes: 20_300_000,
267            num_classes: 1000,
268            input_size: (224, 224),
269            dataset: "ImageNet-1K".to_string(),
270            accuracy: 77.10,
271        },
272    );
273
274    registry.insert(
275        "efficientnet_b1".to_string(),
276        PretrainedModel {
277            name: "efficientnet_b1".to_string(),
278            url: "https://huggingface.co/axonml-ml/efficientnet-b1-imagenet/resolve/main/efficientnet_b1.safetensors".to_string(),
279            checksum: None,
280            size_bytes: 30_100_000,
281            num_classes: 1000,
282            input_size: (240, 240),
283            dataset: "ImageNet-1K".to_string(),
284            accuracy: 78.80,
285        },
286    );
287
288    registry.insert(
289        "efficientnet_b2".to_string(),
290        PretrainedModel {
291            name: "efficientnet_b2".to_string(),
292            url: "https://huggingface.co/axonml-ml/efficientnet-b2-imagenet/resolve/main/efficientnet_b2.safetensors".to_string(),
293            checksum: None,
294            size_bytes: 35_200_000,
295            num_classes: 1000,
296            input_size: (260, 260),
297            dataset: "ImageNet-1K".to_string(),
298            accuracy: 79.80,
299        },
300    );
301
302    // DenseNet family
303    registry.insert(
304        "densenet121".to_string(),
305        PretrainedModel {
306            name: "densenet121".to_string(),
307            url: "https://huggingface.co/axonml-ml/densenet121-imagenet/resolve/main/densenet121.safetensors".to_string(),
308            checksum: None,
309            size_bytes: 30_800_000,
310            num_classes: 1000,
311            input_size: (224, 224),
312            dataset: "ImageNet-1K".to_string(),
313            accuracy: 74.43,
314        },
315    );
316
317    registry.insert(
318        "densenet169".to_string(),
319        PretrainedModel {
320            name: "densenet169".to_string(),
321            url: "https://huggingface.co/axonml-ml/densenet169-imagenet/resolve/main/densenet169.safetensors".to_string(),
322            checksum: None,
323            size_bytes: 54_700_000,
324            num_classes: 1000,
325            input_size: (224, 224),
326            dataset: "ImageNet-1K".to_string(),
327            accuracy: 75.60,
328        },
329    );
330
331    // Vision Transformer (ViT)
332    registry.insert(
333        "vit_b_16".to_string(),
334        PretrainedModel {
335            name: "vit_b_16".to_string(),
336            url: "https://huggingface.co/axonml-ml/vit-b16-imagenet/resolve/main/vit_b_16.safetensors".to_string(),
337            checksum: None,
338            size_bytes: 330_200_000,
339            num_classes: 1000,
340            input_size: (224, 224),
341            dataset: "ImageNet-1K".to_string(),
342            accuracy: 81.07,
343        },
344    );
345
346    registry.insert(
347        "vit_b_32".to_string(),
348        PretrainedModel {
349            name: "vit_b_32".to_string(),
350            url: "https://huggingface.co/axonml-ml/vit-b32-imagenet/resolve/main/vit_b_32.safetensors".to_string(),
351            checksum: None,
352            size_bytes: 337_500_000,
353            num_classes: 1000,
354            input_size: (224, 224),
355            dataset: "ImageNet-1K".to_string(),
356            accuracy: 75.91,
357        },
358    );
359
360    // Swin Transformer
361    registry.insert(
362        "swin_t".to_string(),
363        PretrainedModel {
364            name: "swin_t".to_string(),
365            url: "https://huggingface.co/axonml-ml/swin-tiny-imagenet/resolve/main/swin_t.safetensors".to_string(),
366            checksum: None,
367            size_bytes: 110_700_000,
368            num_classes: 1000,
369            input_size: (224, 224),
370            dataset: "ImageNet-1K".to_string(),
371            accuracy: 81.30,
372        },
373    );
374
375    registry.insert(
376        "swin_s".to_string(),
377        PretrainedModel {
378            name: "swin_s".to_string(),
379            url: "https://huggingface.co/axonml-ml/swin-small-imagenet/resolve/main/swin_s.safetensors".to_string(),
380            checksum: None,
381            size_bytes: 193_500_000,
382            num_classes: 1000,
383            input_size: (224, 224),
384            dataset: "ImageNet-1K".to_string(),
385            accuracy: 83.20,
386        },
387    );
388
389    // ConvNeXt
390    registry.insert(
391        "convnext_tiny".to_string(),
392        PretrainedModel {
393            name: "convnext_tiny".to_string(),
394            url: "https://huggingface.co/axonml-ml/convnext-tiny-imagenet/resolve/main/convnext_tiny.safetensors".to_string(),
395            checksum: None,
396            size_bytes: 109_100_000,
397            num_classes: 1000,
398            input_size: (224, 224),
399            dataset: "ImageNet-1K".to_string(),
400            accuracy: 82.10,
401        },
402    );
403
404    registry.insert(
405        "convnext_small".to_string(),
406        PretrainedModel {
407            name: "convnext_small".to_string(),
408            url: "https://huggingface.co/axonml-ml/convnext-small-imagenet/resolve/main/convnext_small.safetensors".to_string(),
409            checksum: None,
410            size_bytes: 195_600_000,
411            num_classes: 1000,
412            input_size: (224, 224),
413            dataset: "ImageNet-1K".to_string(),
414            accuracy: 83.10,
415        },
416    );
417
418    registry
419}
420
// =============================================================================
// Weight Loading
// =============================================================================

/// State dictionary - mapping from parameter name to its f32 tensor.
pub type StateDict = HashMap<String, Tensor<f32>>;
427
428/// Check if pretrained weights are cached.
429pub fn is_cached(model_name: &str) -> bool {
430    let path = cache_dir().join(format!("{}.safetensors", model_name));
431    path.exists()
432}
433
434/// Get cached weight path.
435pub fn cached_path(model_name: &str) -> PathBuf {
436    cache_dir().join(format!("{}.safetensors", model_name))
437}
438
439/// Download pretrained weights if not cached.
440///
441/// # Arguments
442/// * `model_name` - Name of the model (e.g., "resnet18")
443/// * `force` - Force re-download even if cached
444///
445/// # Returns
446/// Path to the downloaded weights file
447pub fn download_weights(model_name: &str, force: bool) -> HubResult<PathBuf> {
448    let registry = model_registry();
449    let model_info = registry
450        .get(model_name)
451        .ok_or_else(|| HubError::ModelNotFound(model_name.to_string()))?;
452
453    let cache_path = cached_path(model_name);
454
455    // Return cached path if exists and not forcing
456    if cache_path.exists() && !force {
457        return Ok(cache_path);
458    }
459
460    // Ensure cache directory exists
461    if let Some(parent) = cache_path.parent() {
462        fs::create_dir_all(parent)?;
463    }
464
465    // Download weights from pretrained model hub
466    println!("Downloading {} weights ({:.1} MB)...", model_name, model_info.size_bytes as f64 / 1_000_000.0);
467
468    let response = reqwest::blocking::get(&model_info.url)
469        .map_err(|e| HubError::NetworkError(e.to_string()))?;
470
471    if !response.status().is_success() {
472        return Err(HubError::NetworkError(format!(
473            "HTTP {}: {}",
474            response.status(),
475            model_info.url
476        )));
477    }
478
479    let bytes = response.bytes()
480        .map_err(|e| HubError::NetworkError(e.to_string()))?;
481
482    let mut file = File::create(&cache_path)?;
483    file.write_all(&bytes)?;
484
485    println!("Downloaded to {:?}", cache_path);
486
487    Ok(cache_path)
488}
489
490/// Save state dict to file (simple binary format).
491///
492/// # Arguments
493/// * `state` - The state dictionary to save
494/// * `path` - Path where the file will be saved
495///
496/// # Example
497/// ```ignore
498/// use axonml_vision::hub::{save_state_dict, StateDict};
499/// let mut state = StateDict::new();
500/// // ... populate state dict ...
501/// save_state_dict(&state, &PathBuf::from("model.bin")).unwrap();
502/// ```
503pub fn save_state_dict(state: &StateDict, path: &PathBuf) -> HubResult<()> {
504    use std::io::BufWriter;
505
506    let file = File::create(path)?;
507    let mut writer = BufWriter::new(file);
508
509    // Write number of tensors
510    let num_tensors = state.len() as u32;
511    writer.write_all(&num_tensors.to_le_bytes())?;
512
513    for (name, tensor) in state {
514        // Write name length and name
515        let name_bytes = name.as_bytes();
516        let name_len = name_bytes.len() as u32;
517        writer.write_all(&name_len.to_le_bytes())?;
518        writer.write_all(name_bytes)?;
519
520        // Write shape
521        let shape = tensor.shape();
522        let ndim = shape.len() as u32;
523        writer.write_all(&ndim.to_le_bytes())?;
524        for &dim in shape {
525            writer.write_all(&(dim as u64).to_le_bytes())?;
526        }
527
528        // Write data
529        let data = tensor.to_vec();
530        for val in data {
531            writer.write_all(&val.to_le_bytes())?;
532        }
533    }
534
535    Ok(())
536}
537
538/// Load state dict from file.
539pub fn load_state_dict(path: &PathBuf) -> HubResult<StateDict> {
540    use std::io::BufReader;
541
542    let file = File::open(path)?;
543    let mut reader = BufReader::new(file);
544
545    // Read number of tensors
546    let mut buf4 = [0u8; 4];
547    reader.read_exact(&mut buf4)?;
548    let num_tensors = u32::from_le_bytes(buf4);
549
550    let mut state = HashMap::new();
551
552    for _ in 0..num_tensors {
553        // Read name
554        reader.read_exact(&mut buf4)?;
555        let name_len = u32::from_le_bytes(buf4) as usize;
556        let mut name_bytes = vec![0u8; name_len];
557        reader.read_exact(&mut name_bytes)?;
558        let name = String::from_utf8_lossy(&name_bytes).to_string();
559
560        // Read shape
561        reader.read_exact(&mut buf4)?;
562        let ndim = u32::from_le_bytes(buf4) as usize;
563        let mut shape = Vec::with_capacity(ndim);
564        let mut buf8 = [0u8; 8];
565        for _ in 0..ndim {
566            reader.read_exact(&mut buf8)?;
567            shape.push(u64::from_le_bytes(buf8) as usize);
568        }
569
570        // Read data
571        let numel: usize = shape.iter().product();
572        let mut data = Vec::with_capacity(numel);
573        for _ in 0..numel {
574            reader.read_exact(&mut buf4)?;
575            data.push(f32::from_le_bytes(buf4));
576        }
577
578        let tensor = Tensor::from_vec(data, &shape)
579            .map_err(|e| HubError::InvalidFormat(format!("{:?}", e)))?;
580        state.insert(name, tensor);
581    }
582
583    Ok(state)
584}
585
586/// List available pretrained models.
587pub fn list_models() -> Vec<PretrainedModel> {
588    model_registry().into_values().collect()
589}
590
591/// Get info for a specific model.
592pub fn model_info(name: &str) -> Option<PretrainedModel> {
593    model_registry().get(name).cloned()
594}
595
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    // The registry must contain the core CNN architectures by key.
    #[test]
    fn test_model_registry() {
        let registry = model_registry();
        assert!(registry.contains_key("resnet18"));
        assert!(registry.contains_key("vgg16"));
    }

    // The cache directory always ends under an "axonml" subtree.
    #[test]
    fn test_cache_dir() {
        let dir = cache_dir();
        assert!(dir.to_string_lossy().contains("axonml"));
    }

    // Listing models surfaces every registry entry.
    #[test]
    fn test_list_models() {
        let models = list_models();
        assert!(!models.is_empty());
    }

    // Lookup by name returns the registered metadata.
    #[test]
    fn test_model_info() {
        let info = model_info("resnet18");
        assert!(info.is_some());
        let info = info.unwrap();
        assert_eq!(info.num_classes, 1000);
        assert_eq!(info.input_size, (224, 224));
    }

    // Every registry entry has a non-empty HTTPS URL and a positive size.
    #[test]
    fn test_model_urls() {
        let registry = model_registry();
        for (name, model) in &registry {
            assert!(!model.url.is_empty(), "Model {} has empty URL", name);
            assert!(model.url.starts_with("https://"), "Model {} URL should be HTTPS", name);
            assert!(model.size_bytes > 0, "Model {} has zero size", name);
        }
    }

    // Cached paths follow the "<name>.safetensors" convention.
    #[test]
    fn test_cached_path() {
        let path = cached_path("resnet18");
        assert!(path.to_string_lossy().contains("resnet18"));
        assert!(path.to_string_lossy().ends_with(".safetensors"));
    }

    // Round-trip: saving then loading a state dict preserves entries and shapes.
    #[test]
    fn test_save_load_state_dict() {
        // Create a simple state dict for testing
        let mut state = StateDict::new();
        state.insert("layer.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap());
        state.insert("layer.bias".to_string(), Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap());

        let temp_path = std::env::temp_dir().join("test_weights.bin");
        save_state_dict(&state, &temp_path).unwrap();

        let loaded = load_state_dict(&temp_path).unwrap();
        assert_eq!(state.len(), loaded.len());

        // Verify tensor shapes
        let weight = loaded.get("layer.weight").unwrap();
        assert_eq!(weight.shape(), &[2, 2]);

        let bias = loaded.get("layer.bias").unwrap();
        assert_eq!(bias.shape(), &[2]);

        // Clean up
        let _ = std::fs::remove_file(&temp_path);
    }
}