Skip to main content

yscv_model/
hub.rs

1//! Pretrained model hub: download and cache weights from remote sources.
2//!
3//! Uses `curl` via `std::process::Command` to avoid adding heavy HTTP
4//! dependencies.  Downloaded `.safetensors` files are cached under
5//! `$RUSTCV_CACHE_DIR` (or `~/.yscv/models/` by default) and validated
6//! by expected file size.
7
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10use std::process::Command;
11
12use crate::ModelError;
13use crate::load_state_dict;
14use yscv_tensor::Tensor;
15
16// ---------------------------------------------------------------------------
17// Public types
18// ---------------------------------------------------------------------------
19
/// Registry entry describing one downloadable pretrained model.
///
/// Two entries compare equal when their URL, expected size, and filename
/// are all equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HubEntry {
    /// URL to download the `.safetensors` file.
    pub url: String,
    /// Expected file size in bytes (used for validation after download).
    pub expected_size: u64,
    /// Local filename inside the cache directory.
    pub filename: String,
}
30
31/// Model hub for downloading and caching pretrained weights.
32pub struct ModelHub {
33    cache_dir: PathBuf,
34    registry: HashMap<String, HubEntry>,
35}
36
37// ---------------------------------------------------------------------------
38// Helpers
39// ---------------------------------------------------------------------------
40
/// Returns the default cache directory for downloaded model weights.
///
/// Uses `$RUSTCV_CACHE_DIR` if it is set to a non-empty value, otherwise
/// `$HOME/.yscv/models/`. When `$HOME` is unset or empty, the current
/// directory (`.`) stands in for the home directory.
pub fn default_cache_dir() -> PathBuf {
    // `var_os` instead of `var`: a directory path does not have to be valid
    // UTF-8, and `var` would silently discard such a value. An empty value is
    // treated as "unset" so it cannot turn into an empty (cwd-relative) path.
    let non_empty = |key: &str| std::env::var_os(key).filter(|value| !value.is_empty());

    if let Some(dir) = non_empty("RUSTCV_CACHE_DIR") {
        return PathBuf::from(dir);
    }

    // Fall back to ~/.yscv/models/
    let home = non_empty("HOME")
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("."));
    home.join(".yscv").join("models")
}
54
55// ---------------------------------------------------------------------------
56// Registry population
57// ---------------------------------------------------------------------------
58
59fn build_registry() -> HashMap<String, HubEntry> {
60    let mut m = HashMap::new();
61
62    m.insert(
63        "resnet18".into(),
64        HubEntry {
65            url: "https://huggingface.co/timm/resnet18.a1_in1k/resolve/main/model.safetensors"
66                .into(),
67            expected_size: 46_830_408,
68            filename: "resnet18.safetensors".into(),
69        },
70    );
71    m.insert(
72        "resnet34".into(),
73        HubEntry {
74            url: "https://huggingface.co/timm/resnet34.a1_in1k/resolve/main/model.safetensors"
75                .into(),
76            expected_size: 87_338_584,
77            filename: "resnet34.safetensors".into(),
78        },
79    );
80    m.insert(
81        "resnet50".into(),
82        HubEntry {
83            url: "https://huggingface.co/timm/resnet50.a1_in1k/resolve/main/model.safetensors"
84                .into(),
85            expected_size: 102_170_688,
86            filename: "resnet50.safetensors".into(),
87        },
88    );
89    m.insert(
90        "resnet101".into(),
91        HubEntry {
92            url: "https://huggingface.co/timm/resnet101.a1_in1k/resolve/main/model.safetensors"
93                .into(),
94            expected_size: 178_834_240,
95            filename: "resnet101.safetensors".into(),
96        },
97    );
98    m.insert(
99        "vgg16".into(),
100        HubEntry {
101            url: "https://huggingface.co/timm/vgg16.tv_in1k/resolve/main/model.safetensors".into(),
102            expected_size: 553_507_904,
103            filename: "vgg16.safetensors".into(),
104        },
105    );
106    m.insert(
107        "vgg19".into(),
108        HubEntry {
109            url: "https://huggingface.co/timm/vgg19.tv_in1k/resolve/main/model.safetensors".into(),
110            expected_size: 574_879_552,
111            filename: "vgg19.safetensors".into(),
112        },
113    );
114    m.insert(
115        "mobilenet_v2".into(),
116        HubEntry {
117            url:
118                "https://huggingface.co/timm/mobilenetv2_100.ra_in1k/resolve/main/model.safetensors"
119                    .into(),
120            expected_size: 14_214_848,
121            filename: "mobilenet_v2.safetensors".into(),
122        },
123    );
124    m.insert(
125        "efficientnet_b0".into(),
126        HubEntry {
127            url:
128                "https://huggingface.co/timm/efficientnet_b0.ra_in1k/resolve/main/model.safetensors"
129                    .into(),
130            expected_size: 21_388_928,
131            filename: "efficientnet_b0.safetensors".into(),
132        },
133    );
134    m.insert(
135        "alexnet".into(),
136        HubEntry {
137            url: "https://huggingface.co/pytorch/alexnet/resolve/main/model.safetensors".into(),
138            expected_size: 244_408_336,
139            filename: "alexnet.safetensors".into(),
140        },
141    );
142
143    m
144}
145
146// ---------------------------------------------------------------------------
147// ModelHub implementation
148// ---------------------------------------------------------------------------
149
150impl ModelHub {
151    /// Creates a new hub with the default cache directory and built-in
152    /// registry of known pretrained models.
153    pub fn new() -> Self {
154        Self {
155            cache_dir: default_cache_dir(),
156            registry: build_registry(),
157        }
158    }
159
160    /// Returns the cache directory path.
161    pub fn cache_dir(&self) -> &Path {
162        &self.cache_dir
163    }
164
165    /// Returns a reference to the internal registry.
166    pub fn registry(&self) -> &HashMap<String, HubEntry> {
167        &self.registry
168    }
169
170    /// Ensures the weight file for `name` is present in the local cache,
171    /// downloading it via `curl` if necessary.
172    ///
173    /// Returns the path to the cached file on success.
174    pub fn download_if_missing(&self, name: &str) -> Result<PathBuf, ModelError> {
175        let entry = self
176            .registry
177            .get(name)
178            .ok_or_else(|| ModelError::DownloadFailed {
179                url: name.to_string(),
180                reason: format!("model '{name}' is not in the hub registry"),
181            })?;
182
183        let dest = self.cache_dir.join(&entry.filename);
184
185        // Already cached — validate size and return.
186        if dest.is_file() {
187            validate_file_size(&dest, entry.expected_size)?;
188            return Ok(dest);
189        }
190
191        // Ensure cache directory exists.
192        std::fs::create_dir_all(&self.cache_dir).map_err(|e| ModelError::DownloadFailed {
193            url: entry.url.clone(),
194            reason: format!(
195                "failed to create cache dir {}: {e}",
196                self.cache_dir.display()
197            ),
198        })?;
199
200        // Download with curl.
201        let output = Command::new("curl")
202            .args(["-fSL", "-o"])
203            .arg(&dest)
204            .arg(&entry.url)
205            .output()
206            .map_err(|e| ModelError::DownloadFailed {
207                url: entry.url.clone(),
208                reason: format!("failed to run curl: {e}"),
209            })?;
210
211        if !output.status.success() {
212            // Clean up partial file.
213            let _ = std::fs::remove_file(&dest);
214            let stderr = String::from_utf8_lossy(&output.stderr);
215            return Err(ModelError::DownloadFailed {
216                url: entry.url.clone(),
217                reason: format!("curl exited with {}: {stderr}", output.status),
218            });
219        }
220
221        validate_file_size(&dest, entry.expected_size)?;
222        Ok(dest)
223    }
224
225    /// Downloads (if needed) and loads all tensors from the safetensors
226    /// weight file for the given model name.
227    pub fn load_weights(&self, name: &str) -> Result<HashMap<String, Tensor>, ModelError> {
228        let path = self.download_if_missing(name)?;
229        load_state_dict(&path)
230    }
231}
232
233impl Default for ModelHub {
234    fn default() -> Self {
235        Self::new()
236    }
237}
238
239// ---------------------------------------------------------------------------
240// Internal helpers
241// ---------------------------------------------------------------------------
242
243fn validate_file_size(path: &Path, expected: u64) -> Result<(), ModelError> {
244    let meta = std::fs::metadata(path).map_err(|e| ModelError::DownloadFailed {
245        url: path.display().to_string(),
246        reason: format!("cannot stat downloaded file: {e}"),
247    })?;
248    let actual = meta.len();
249    if actual != expected {
250        return Err(ModelError::DownloadFailed {
251            url: path.display().to_string(),
252            reason: format!("file size mismatch: expected {expected} bytes, got {actual} bytes"),
253        });
254    }
255    Ok(())
256}