1use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11
12use crate::ModelError;
13use crate::load_state_dict;
14use yscv_tensor::Tensor;
15
/// Describes one downloadable model file known to the hub.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HubEntry {
    /// Remote location the weight file is downloaded from.
    pub url: String,
    /// Exact size in bytes the file on disk must have to be accepted.
    pub expected_size: u64,
    /// File name used for the copy stored inside the hub's cache directory.
    pub filename: String,
}
30
31pub struct ModelHub {
33 cache_dir: PathBuf,
34 registry: HashMap<String, HubEntry>,
35}
36
/// Returns the directory used to cache downloaded model files.
///
/// Resolution order:
/// 1. the `RUSTCV_CACHE_DIR` environment variable, used verbatim when it is set
///    and non-empty;
/// 2. `$HOME/.yscv/models` when `HOME` is set and non-empty;
/// 3. `./.yscv/models` otherwise.
///
/// The directory is not created here.
pub fn default_cache_dir() -> PathBuf {
    // `var_os` keeps paths that are not valid UTF-8 (`env::var` would reject
    // them and silently fall through), and an empty value is treated as unset
    // so it cannot produce an empty cache path.
    let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

    if let Some(dir) = non_empty("RUSTCV_CACHE_DIR") {
        return PathBuf::from(dir);
    }

    non_empty("HOME")
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("."))
        .join(".yscv")
        .join("models")
}
54
55fn build_registry() -> HashMap<String, HubEntry> {
60 let mut m = HashMap::new();
61
62 m.insert(
63 "resnet18".into(),
64 HubEntry {
65 url: "https://huggingface.co/timm/resnet18.a1_in1k/resolve/main/model.safetensors"
66 .into(),
67 expected_size: 46_830_408,
68 filename: "resnet18.safetensors".into(),
69 },
70 );
71 m.insert(
72 "resnet34".into(),
73 HubEntry {
74 url: "https://huggingface.co/timm/resnet34.a1_in1k/resolve/main/model.safetensors"
75 .into(),
76 expected_size: 87_338_584,
77 filename: "resnet34.safetensors".into(),
78 },
79 );
80 m.insert(
81 "resnet50".into(),
82 HubEntry {
83 url: "https://huggingface.co/timm/resnet50.a1_in1k/resolve/main/model.safetensors"
84 .into(),
85 expected_size: 102_170_688,
86 filename: "resnet50.safetensors".into(),
87 },
88 );
89 m.insert(
90 "resnet101".into(),
91 HubEntry {
92 url: "https://huggingface.co/timm/resnet101.a1_in1k/resolve/main/model.safetensors"
93 .into(),
94 expected_size: 178_834_240,
95 filename: "resnet101.safetensors".into(),
96 },
97 );
98 m.insert(
99 "vgg16".into(),
100 HubEntry {
101 url: "https://huggingface.co/timm/vgg16.tv_in1k/resolve/main/model.safetensors".into(),
102 expected_size: 553_507_904,
103 filename: "vgg16.safetensors".into(),
104 },
105 );
106 m.insert(
107 "vgg19".into(),
108 HubEntry {
109 url: "https://huggingface.co/timm/vgg19.tv_in1k/resolve/main/model.safetensors".into(),
110 expected_size: 574_879_552,
111 filename: "vgg19.safetensors".into(),
112 },
113 );
114 m.insert(
115 "mobilenet_v2".into(),
116 HubEntry {
117 url:
118 "https://huggingface.co/timm/mobilenetv2_100.ra_in1k/resolve/main/model.safetensors"
119 .into(),
120 expected_size: 14_214_848,
121 filename: "mobilenet_v2.safetensors".into(),
122 },
123 );
124 m.insert(
125 "efficientnet_b0".into(),
126 HubEntry {
127 url:
128 "https://huggingface.co/timm/efficientnet_b0.ra_in1k/resolve/main/model.safetensors"
129 .into(),
130 expected_size: 21_388_928,
131 filename: "efficientnet_b0.safetensors".into(),
132 },
133 );
134 m.insert(
135 "alexnet".into(),
136 HubEntry {
137 url: "https://huggingface.co/pytorch/alexnet/resolve/main/model.safetensors".into(),
138 expected_size: 244_408_336,
139 filename: "alexnet.safetensors".into(),
140 },
141 );
142
143 m
144}
145
146impl ModelHub {
151 pub fn new() -> Self {
154 Self {
155 cache_dir: default_cache_dir(),
156 registry: build_registry(),
157 }
158 }
159
160 pub fn cache_dir(&self) -> &Path {
162 &self.cache_dir
163 }
164
165 pub fn registry(&self) -> &HashMap<String, HubEntry> {
167 &self.registry
168 }
169
170 pub fn download_if_missing(&self, name: &str) -> Result<PathBuf, ModelError> {
175 let entry = self
176 .registry
177 .get(name)
178 .ok_or_else(|| ModelError::DownloadFailed {
179 url: name.to_string(),
180 reason: format!("model '{name}' is not in the hub registry"),
181 })?;
182
183 let dest = self.cache_dir.join(&entry.filename);
184
185 if dest.is_file() {
187 validate_file_size(&dest, entry.expected_size)?;
188 return Ok(dest);
189 }
190
191 std::fs::create_dir_all(&self.cache_dir).map_err(|e| ModelError::DownloadFailed {
193 url: entry.url.clone(),
194 reason: format!(
195 "failed to create cache dir {}: {e}",
196 self.cache_dir.display()
197 ),
198 })?;
199
200 let output = Command::new("curl")
202 .args(["-fSL", "-o"])
203 .arg(&dest)
204 .arg(&entry.url)
205 .output()
206 .map_err(|e| ModelError::DownloadFailed {
207 url: entry.url.clone(),
208 reason: format!("failed to run curl: {e}"),
209 })?;
210
211 if !output.status.success() {
212 let _ = std::fs::remove_file(&dest);
214 let stderr = String::from_utf8_lossy(&output.stderr);
215 return Err(ModelError::DownloadFailed {
216 url: entry.url.clone(),
217 reason: format!("curl exited with {}: {stderr}", output.status),
218 });
219 }
220
221 validate_file_size(&dest, entry.expected_size)?;
222 Ok(dest)
223 }
224
225 pub fn load_weights(&self, name: &str) -> Result<HashMap<String, Tensor>, ModelError> {
228 let path = self.download_if_missing(name)?;
229 load_state_dict(&path)
230 }
231}
232
233impl Default for ModelHub {
234 fn default() -> Self {
235 Self::new()
236 }
237}
238
239fn validate_file_size(path: &Path, expected: u64) -> Result<(), ModelError> {
244 let meta = std::fs::metadata(path).map_err(|e| ModelError::DownloadFailed {
245 url: path.display().to_string(),
246 reason: format!("cannot stat downloaded file: {e}"),
247 })?;
248 let actual = meta.len();
249 if actual != expected {
250 return Err(ModelError::DownloadFailed {
251 url: path.display().to_string(),
252 reason: format!("file size mismatch: expected {expected} bytes, got {actual} bytes"),
253 });
254 }
255 Ok(())
256}