use std::collections::HashMap;
use std::fs::{self, File};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};

use axonml_tensor::Tensor;
14
/// Errors produced by hub operations (registry lookup, download, and
/// state-dict serialization).
#[derive(Debug)]
pub enum HubError {
 /// Request construction, transport, or non-success HTTP status failures.
 NetworkError(String),
 /// Underlying filesystem failure (create, read, write, rename).
 IoError(std::io::Error),
 /// The requested name has no entry in `model_registry()`.
 ModelNotFound(String),
 /// The on-disk state-dict bytes could not be decoded.
 InvalidFormat(String),
 /// Downloaded bytes did not match the registry checksum.
 /// NOTE(review): no code path in this file constructs this variant yet —
 /// all registry entries have `checksum: None`.
 ChecksumMismatch {
 /// Checksum recorded in the registry.
 expected: String,
 /// Checksum computed from the downloaded bytes.
 actual: String,
 },
}
38
39impl std::fmt::Display for HubError {
40 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41 match self {
42 HubError::NetworkError(e) => write!(f, "Network error: {}", e),
43 HubError::IoError(e) => write!(f, "IO error: {}", e),
44 HubError::ModelNotFound(name) => write!(f, "Model not found: {}", name),
45 HubError::InvalidFormat(msg) => write!(f, "Invalid format: {}", msg),
46 HubError::ChecksumMismatch { expected, actual } => {
47 write!(f, "Checksum mismatch: expected {}, got {}", expected, actual)
48 }
49 }
50 }
51}
52
53impl std::error::Error for HubError {}
54
55impl From<std::io::Error> for HubError {
56 fn from(e: std::io::Error) -> Self {
57 HubError::IoError(e)
58 }
59}
60
61pub type HubResult<T> = Result<T, HubError>;
63
/// Metadata for a downloadable pretrained model.
#[derive(Debug, Clone)]
pub struct PretrainedModel {
 /// Registry key and base name of the cached weights file.
 pub name: String,
 /// Direct download URL for the weights.
 pub url: String,
 /// Expected checksum of the weights file; `None` when unpublished.
 pub checksum: Option<String>,
 /// Approximate download size in bytes (used for progress display).
 pub size_bytes: u64,
 /// Number of output classes the classifier head was trained with.
 pub num_classes: usize,
 /// Expected (height, width) of the input image in pixels.
 pub input_size: (usize, usize),
 /// Name of the training dataset.
 pub dataset: String,
 /// Reported top-1 accuracy in percent — presumably on the dataset's
 /// validation split; TODO confirm against the upstream model cards.
 pub accuracy: f32,
}
88
89pub fn cache_dir() -> PathBuf {
91 let base = dirs::cache_dir()
92 .or_else(dirs::home_dir)
93 .unwrap_or_else(|| PathBuf::from("."));
94 base.join("axonml").join("hub").join("weights")
95}
96
97pub fn model_registry() -> HashMap<String, PretrainedModel> {
99 let mut registry = HashMap::new();
100
101 registry.insert(
103 "resnet18".to_string(),
104 PretrainedModel {
105 name: "resnet18".to_string(),
106 url: "https://huggingface.co/axonml-ml/resnet18-imagenet/resolve/main/resnet18.safetensors".to_string(),
107 checksum: None,
108 size_bytes: 44_700_000,
109 num_classes: 1000,
110 input_size: (224, 224),
111 dataset: "ImageNet-1K".to_string(),
112 accuracy: 69.76,
113 },
114 );
115
116 registry.insert(
117 "resnet34".to_string(),
118 PretrainedModel {
119 name: "resnet34".to_string(),
120 url: "https://huggingface.co/axonml-ml/resnet34-imagenet/resolve/main/resnet34.safetensors".to_string(),
121 checksum: None,
122 size_bytes: 83_300_000,
123 num_classes: 1000,
124 input_size: (224, 224),
125 dataset: "ImageNet-1K".to_string(),
126 accuracy: 73.31,
127 },
128 );
129
130 registry.insert(
131 "resnet50".to_string(),
132 PretrainedModel {
133 name: "resnet50".to_string(),
134 url: "https://huggingface.co/axonml-ml/resnet50-imagenet/resolve/main/resnet50.safetensors".to_string(),
135 checksum: None,
136 size_bytes: 97_800_000,
137 num_classes: 1000,
138 input_size: (224, 224),
139 dataset: "ImageNet-1K".to_string(),
140 accuracy: 76.13,
141 },
142 );
143
144 registry.insert(
146 "vgg16".to_string(),
147 PretrainedModel {
148 name: "vgg16".to_string(),
149 url: "https://huggingface.co/axonml-ml/vgg16-imagenet/resolve/main/vgg16.safetensors".to_string(),
150 checksum: None,
151 size_bytes: 528_000_000,
152 num_classes: 1000,
153 input_size: (224, 224),
154 dataset: "ImageNet-1K".to_string(),
155 accuracy: 71.59,
156 },
157 );
158
159 registry.insert(
160 "vgg19".to_string(),
161 PretrainedModel {
162 name: "vgg19".to_string(),
163 url: "https://huggingface.co/axonml-ml/vgg19-imagenet/resolve/main/vgg19.safetensors".to_string(),
164 checksum: None,
165 size_bytes: 548_000_000,
166 num_classes: 1000,
167 input_size: (224, 224),
168 dataset: "ImageNet-1K".to_string(),
169 accuracy: 72.38,
170 },
171 );
172
173 registry.insert(
174 "vgg16_bn".to_string(),
175 PretrainedModel {
176 name: "vgg16_bn".to_string(),
177 url: "https://huggingface.co/axonml-ml/vgg16bn-imagenet/resolve/main/vgg16_bn.safetensors".to_string(),
178 checksum: None,
179 size_bytes: 528_000_000,
180 num_classes: 1000,
181 input_size: (224, 224),
182 dataset: "ImageNet-1K".to_string(),
183 accuracy: 73.36,
184 },
185 );
186
187 registry.insert(
189 "resnet101".to_string(),
190 PretrainedModel {
191 name: "resnet101".to_string(),
192 url: "https://huggingface.co/axonml-ml/resnet101-imagenet/resolve/main/resnet101.safetensors".to_string(),
193 checksum: None,
194 size_bytes: 170_500_000,
195 num_classes: 1000,
196 input_size: (224, 224),
197 dataset: "ImageNet-1K".to_string(),
198 accuracy: 77.37,
199 },
200 );
201
202 registry.insert(
203 "resnet152".to_string(),
204 PretrainedModel {
205 name: "resnet152".to_string(),
206 url: "https://huggingface.co/axonml-ml/resnet152-imagenet/resolve/main/resnet152.safetensors".to_string(),
207 checksum: None,
208 size_bytes: 230_400_000,
209 num_classes: 1000,
210 input_size: (224, 224),
211 dataset: "ImageNet-1K".to_string(),
212 accuracy: 78.31,
213 },
214 );
215
216 registry.insert(
218 "mobilenet_v2".to_string(),
219 PretrainedModel {
220 name: "mobilenet_v2".to_string(),
221 url: "https://huggingface.co/axonml-ml/mobilenetv2-imagenet/resolve/main/mobilenet_v2.safetensors".to_string(),
222 checksum: None,
223 size_bytes: 13_600_000,
224 num_classes: 1000,
225 input_size: (224, 224),
226 dataset: "ImageNet-1K".to_string(),
227 accuracy: 71.88,
228 },
229 );
230
231 registry.insert(
232 "mobilenet_v3_small".to_string(),
233 PretrainedModel {
234 name: "mobilenet_v3_small".to_string(),
235 url: "https://huggingface.co/axonml-ml/mobilenetv3-small-imagenet/resolve/main/mobilenet_v3_small.safetensors".to_string(),
236 checksum: None,
237 size_bytes: 9_800_000,
238 num_classes: 1000,
239 input_size: (224, 224),
240 dataset: "ImageNet-1K".to_string(),
241 accuracy: 67.67,
242 },
243 );
244
245 registry.insert(
246 "mobilenet_v3_large".to_string(),
247 PretrainedModel {
248 name: "mobilenet_v3_large".to_string(),
249 url: "https://huggingface.co/axonml-ml/mobilenetv3-large-imagenet/resolve/main/mobilenet_v3_large.safetensors".to_string(),
250 checksum: None,
251 size_bytes: 21_100_000,
252 num_classes: 1000,
253 input_size: (224, 224),
254 dataset: "ImageNet-1K".to_string(),
255 accuracy: 74.04,
256 },
257 );
258
259 registry.insert(
261 "efficientnet_b0".to_string(),
262 PretrainedModel {
263 name: "efficientnet_b0".to_string(),
264 url: "https://huggingface.co/axonml-ml/efficientnet-b0-imagenet/resolve/main/efficientnet_b0.safetensors".to_string(),
265 checksum: None,
266 size_bytes: 20_300_000,
267 num_classes: 1000,
268 input_size: (224, 224),
269 dataset: "ImageNet-1K".to_string(),
270 accuracy: 77.10,
271 },
272 );
273
274 registry.insert(
275 "efficientnet_b1".to_string(),
276 PretrainedModel {
277 name: "efficientnet_b1".to_string(),
278 url: "https://huggingface.co/axonml-ml/efficientnet-b1-imagenet/resolve/main/efficientnet_b1.safetensors".to_string(),
279 checksum: None,
280 size_bytes: 30_100_000,
281 num_classes: 1000,
282 input_size: (240, 240),
283 dataset: "ImageNet-1K".to_string(),
284 accuracy: 78.80,
285 },
286 );
287
288 registry.insert(
289 "efficientnet_b2".to_string(),
290 PretrainedModel {
291 name: "efficientnet_b2".to_string(),
292 url: "https://huggingface.co/axonml-ml/efficientnet-b2-imagenet/resolve/main/efficientnet_b2.safetensors".to_string(),
293 checksum: None,
294 size_bytes: 35_200_000,
295 num_classes: 1000,
296 input_size: (260, 260),
297 dataset: "ImageNet-1K".to_string(),
298 accuracy: 79.80,
299 },
300 );
301
302 registry.insert(
304 "densenet121".to_string(),
305 PretrainedModel {
306 name: "densenet121".to_string(),
307 url: "https://huggingface.co/axonml-ml/densenet121-imagenet/resolve/main/densenet121.safetensors".to_string(),
308 checksum: None,
309 size_bytes: 30_800_000,
310 num_classes: 1000,
311 input_size: (224, 224),
312 dataset: "ImageNet-1K".to_string(),
313 accuracy: 74.43,
314 },
315 );
316
317 registry.insert(
318 "densenet169".to_string(),
319 PretrainedModel {
320 name: "densenet169".to_string(),
321 url: "https://huggingface.co/axonml-ml/densenet169-imagenet/resolve/main/densenet169.safetensors".to_string(),
322 checksum: None,
323 size_bytes: 54_700_000,
324 num_classes: 1000,
325 input_size: (224, 224),
326 dataset: "ImageNet-1K".to_string(),
327 accuracy: 75.60,
328 },
329 );
330
331 registry.insert(
333 "vit_b_16".to_string(),
334 PretrainedModel {
335 name: "vit_b_16".to_string(),
336 url: "https://huggingface.co/axonml-ml/vit-b16-imagenet/resolve/main/vit_b_16.safetensors".to_string(),
337 checksum: None,
338 size_bytes: 330_200_000,
339 num_classes: 1000,
340 input_size: (224, 224),
341 dataset: "ImageNet-1K".to_string(),
342 accuracy: 81.07,
343 },
344 );
345
346 registry.insert(
347 "vit_b_32".to_string(),
348 PretrainedModel {
349 name: "vit_b_32".to_string(),
350 url: "https://huggingface.co/axonml-ml/vit-b32-imagenet/resolve/main/vit_b_32.safetensors".to_string(),
351 checksum: None,
352 size_bytes: 337_500_000,
353 num_classes: 1000,
354 input_size: (224, 224),
355 dataset: "ImageNet-1K".to_string(),
356 accuracy: 75.91,
357 },
358 );
359
360 registry.insert(
362 "swin_t".to_string(),
363 PretrainedModel {
364 name: "swin_t".to_string(),
365 url: "https://huggingface.co/axonml-ml/swin-tiny-imagenet/resolve/main/swin_t.safetensors".to_string(),
366 checksum: None,
367 size_bytes: 110_700_000,
368 num_classes: 1000,
369 input_size: (224, 224),
370 dataset: "ImageNet-1K".to_string(),
371 accuracy: 81.30,
372 },
373 );
374
375 registry.insert(
376 "swin_s".to_string(),
377 PretrainedModel {
378 name: "swin_s".to_string(),
379 url: "https://huggingface.co/axonml-ml/swin-small-imagenet/resolve/main/swin_s.safetensors".to_string(),
380 checksum: None,
381 size_bytes: 193_500_000,
382 num_classes: 1000,
383 input_size: (224, 224),
384 dataset: "ImageNet-1K".to_string(),
385 accuracy: 83.20,
386 },
387 );
388
389 registry.insert(
391 "convnext_tiny".to_string(),
392 PretrainedModel {
393 name: "convnext_tiny".to_string(),
394 url: "https://huggingface.co/axonml-ml/convnext-tiny-imagenet/resolve/main/convnext_tiny.safetensors".to_string(),
395 checksum: None,
396 size_bytes: 109_100_000,
397 num_classes: 1000,
398 input_size: (224, 224),
399 dataset: "ImageNet-1K".to_string(),
400 accuracy: 82.10,
401 },
402 );
403
404 registry.insert(
405 "convnext_small".to_string(),
406 PretrainedModel {
407 name: "convnext_small".to_string(),
408 url: "https://huggingface.co/axonml-ml/convnext-small-imagenet/resolve/main/convnext_small.safetensors".to_string(),
409 checksum: None,
410 size_bytes: 195_600_000,
411 num_classes: 1000,
412 input_size: (224, 224),
413 dataset: "ImageNet-1K".to_string(),
414 accuracy: 83.10,
415 },
416 );
417
418 registry
419}
420
421pub type StateDict = HashMap<String, Tensor<f32>>;
427
428pub fn is_cached(model_name: &str) -> bool {
430 let path = cache_dir().join(format!("{}.safetensors", model_name));
431 path.exists()
432}
433
434pub fn cached_path(model_name: &str) -> PathBuf {
436 cache_dir().join(format!("{}.safetensors", model_name))
437}
438
439pub fn download_weights(model_name: &str, force: bool) -> HubResult<PathBuf> {
448 let registry = model_registry();
449 let model_info = registry
450 .get(model_name)
451 .ok_or_else(|| HubError::ModelNotFound(model_name.to_string()))?;
452
453 let cache_path = cached_path(model_name);
454
455 if cache_path.exists() && !force {
457 return Ok(cache_path);
458 }
459
460 if let Some(parent) = cache_path.parent() {
462 fs::create_dir_all(parent)?;
463 }
464
465 println!("Downloading {} weights ({:.1} MB)...", model_name, model_info.size_bytes as f64 / 1_000_000.0);
467
468 let response = reqwest::blocking::get(&model_info.url)
469 .map_err(|e| HubError::NetworkError(e.to_string()))?;
470
471 if !response.status().is_success() {
472 return Err(HubError::NetworkError(format!(
473 "HTTP {}: {}",
474 response.status(),
475 model_info.url
476 )));
477 }
478
479 let bytes = response.bytes()
480 .map_err(|e| HubError::NetworkError(e.to_string()))?;
481
482 let mut file = File::create(&cache_path)?;
483 file.write_all(&bytes)?;
484
485 println!("Downloaded to {:?}", cache_path);
486
487 Ok(cache_path)
488}
489
490pub fn save_state_dict(state: &StateDict, path: &PathBuf) -> HubResult<()> {
504 use std::io::BufWriter;
505
506 let file = File::create(path)?;
507 let mut writer = BufWriter::new(file);
508
509 let num_tensors = state.len() as u32;
511 writer.write_all(&num_tensors.to_le_bytes())?;
512
513 for (name, tensor) in state {
514 let name_bytes = name.as_bytes();
516 let name_len = name_bytes.len() as u32;
517 writer.write_all(&name_len.to_le_bytes())?;
518 writer.write_all(name_bytes)?;
519
520 let shape = tensor.shape();
522 let ndim = shape.len() as u32;
523 writer.write_all(&ndim.to_le_bytes())?;
524 for &dim in shape {
525 writer.write_all(&(dim as u64).to_le_bytes())?;
526 }
527
528 let data = tensor.to_vec();
530 for val in data {
531 writer.write_all(&val.to_le_bytes())?;
532 }
533 }
534
535 Ok(())
536}
537
538pub fn load_state_dict(path: &PathBuf) -> HubResult<StateDict> {
540 use std::io::BufReader;
541
542 let file = File::open(path)?;
543 let mut reader = BufReader::new(file);
544
545 let mut buf4 = [0u8; 4];
547 reader.read_exact(&mut buf4)?;
548 let num_tensors = u32::from_le_bytes(buf4);
549
550 let mut state = HashMap::new();
551
552 for _ in 0..num_tensors {
553 reader.read_exact(&mut buf4)?;
555 let name_len = u32::from_le_bytes(buf4) as usize;
556 let mut name_bytes = vec![0u8; name_len];
557 reader.read_exact(&mut name_bytes)?;
558 let name = String::from_utf8_lossy(&name_bytes).to_string();
559
560 reader.read_exact(&mut buf4)?;
562 let ndim = u32::from_le_bytes(buf4) as usize;
563 let mut shape = Vec::with_capacity(ndim);
564 let mut buf8 = [0u8; 8];
565 for _ in 0..ndim {
566 reader.read_exact(&mut buf8)?;
567 shape.push(u64::from_le_bytes(buf8) as usize);
568 }
569
570 let numel: usize = shape.iter().product();
572 let mut data = Vec::with_capacity(numel);
573 for _ in 0..numel {
574 reader.read_exact(&mut buf4)?;
575 data.push(f32::from_le_bytes(buf4));
576 }
577
578 let tensor = Tensor::from_vec(data, &shape)
579 .map_err(|e| HubError::InvalidFormat(format!("{:?}", e)))?;
580 state.insert(name, tensor);
581 }
582
583 Ok(state)
584}
585
586pub fn list_models() -> Vec<PretrainedModel> {
588 model_registry().into_values().collect()
589}
590
591pub fn model_info(name: &str) -> Option<PretrainedModel> {
593 model_registry().get(name).cloned()
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_registry() {
        let registry = model_registry();
        assert!(registry.contains_key("resnet18"));
        assert!(registry.contains_key("vgg16"));
    }

    #[test]
    fn test_cache_dir() {
        let dir = cache_dir();
        assert!(dir.to_string_lossy().contains("axonml"));
    }

    #[test]
    fn test_list_models() {
        let models = list_models();
        assert!(!models.is_empty());
    }

    #[test]
    fn test_model_info() {
        let info = model_info("resnet18");
        assert!(info.is_some());
        let info = info.unwrap();
        assert_eq!(info.num_classes, 1000);
        assert_eq!(info.input_size, (224, 224));
    }

    #[test]
    fn test_model_urls() {
        let registry = model_registry();
        // Fixed mojibake: the loop previously read `in ®istry` ("&reg"
        // mis-encoded as the ® character), which does not compile.
        for (name, model) in &registry {
            assert!(!model.url.is_empty(), "Model {} has empty URL", name);
            assert!(model.url.starts_with("https://"), "Model {} URL should be HTTPS", name);
            assert!(model.size_bytes > 0, "Model {} has zero size", name);
        }
    }

    #[test]
    fn test_cached_path() {
        let path = cached_path("resnet18");
        assert!(path.to_string_lossy().contains("resnet18"));
        assert!(path.to_string_lossy().ends_with(".safetensors"));
    }

    #[test]
    fn test_save_load_state_dict() {
        let mut state = StateDict::new();
        state.insert("layer.weight".to_string(), Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap());
        state.insert("layer.bias".to_string(), Tensor::from_vec(vec![0.1, 0.2], &[2]).unwrap());

        let temp_path = std::env::temp_dir().join("test_weights.bin");
        save_state_dict(&state, &temp_path).unwrap();

        let loaded = load_state_dict(&temp_path).unwrap();
        assert_eq!(state.len(), loaded.len());

        let weight = loaded.get("layer.weight").unwrap();
        assert_eq!(weight.shape(), &[2, 2]);

        let bias = loaded.get("layer.bias").unwrap();
        assert_eq!(bias.shape(), &[2]);

        // Best-effort cleanup; ignore failure so cleanup can't fail the test.
        let _ = std::fs::remove_file(&temp_path);
    }
}
672}