Skip to main content

hf_fetch_model/
cache.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! `HuggingFace` cache directory resolution and local model family scanning.
4//!
5//! [`hf_cache_dir()`] locates the local HF cache. [`list_cached_families()`]
6//! scans downloaded models and groups them by `model_type`.
7
8use std::collections::BTreeMap;
9use std::path::{Path, PathBuf};
10
11use crate::error::FetchError;
12
13/// Returns the `HuggingFace` Hub cache directory.
14///
15/// Resolution order:
16/// 1. `HF_HOME` environment variable + `/hub`
17/// 2. `~/.cache/huggingface/hub/` (via [`dirs::home_dir()`])
18///
19/// # Errors
20///
21/// Returns [`FetchError::Io`] if the home directory cannot be determined.
22pub fn hf_cache_dir() -> Result<PathBuf, FetchError> {
23    if let Ok(home) = std::env::var("HF_HOME") {
24        let mut path = PathBuf::from(home);
25        path.push("hub");
26        return Ok(path);
27    }
28
29    let home = dirs::home_dir().ok_or_else(|| FetchError::Io {
30        path: PathBuf::from("~"),
31        source: std::io::Error::new(std::io::ErrorKind::NotFound, "home directory not found"),
32    })?;
33
34    let mut path = home;
35    path.push(".cache");
36    path.push("huggingface");
37    path.push("hub");
38    Ok(path)
39}
40
41/// Scans the local HF cache for downloaded models and groups them by `model_type`.
42///
43/// Looks for `config.json` files inside model snapshot directories:
44/// `<cache>/models--<org>--<name>/snapshots/*/config.json`
45///
46/// Returns a map from `model_type` (e.g., `"llama"`) to a sorted list of
47/// repository identifiers (e.g., `["meta-llama/Llama-3.2-1B"]`).
48///
49/// Models without a `model_type` field in their `config.json` are skipped.
50///
51/// # Errors
52///
53/// Returns [`FetchError::Io`] if the cache directory cannot be read.
54pub fn list_cached_families() -> Result<BTreeMap<String, Vec<String>>, FetchError> {
55    let cache_dir = hf_cache_dir()?;
56
57    if !cache_dir.exists() {
58        return Ok(BTreeMap::new());
59    }
60
61    let entries = std::fs::read_dir(&cache_dir).map_err(|e| FetchError::Io {
62        path: cache_dir.clone(),
63        source: e,
64    })?;
65
66    let mut families: BTreeMap<String, Vec<String>> = BTreeMap::new();
67
68    for entry in entries {
69        let Ok(entry) = entry else { continue };
70
71        let dir_name = entry.file_name();
72        // BORROW: explicit .to_string_lossy() for OsString → str conversion
73        let dir_str = dir_name.to_string_lossy();
74
75        // Only process model directories (models--org--name)
76        let Some(repo_part) = dir_str.strip_prefix("models--") else {
77            continue;
78        };
79
80        // Reconstruct repo_id: replace first "--" with "/"
81        let repo_id = match repo_part.find("--") {
82            Some(pos) => {
83                let (org, name_with_sep) = repo_part.split_at(pos);
84                let name = name_with_sep.get(2..).unwrap_or_default();
85                format!("{org}/{name}")
86            }
87            None => repo_part.to_string(),
88        };
89
90        // Find the newest snapshot's config.json
91        let snapshots_dir = entry.path().join("snapshots");
92        if !snapshots_dir.exists() {
93            continue;
94        }
95
96        if let Some(model_type) = find_model_type_in_snapshots(&snapshots_dir) {
97            families.entry(model_type).or_default().push(repo_id);
98        }
99    }
100
101    // Sort repo lists within each family for stable output
102    for repos in families.values_mut() {
103        repos.sort();
104    }
105
106    Ok(families)
107}
108
109/// Searches snapshot directories for a `config.json` containing `model_type`.
110///
111/// Returns the first `model_type` value found, or `None`.
112fn find_model_type_in_snapshots(snapshots_dir: &std::path::Path) -> Option<String> {
113    let snapshots = std::fs::read_dir(snapshots_dir).ok()?;
114
115    for snap_entry in snapshots {
116        let Ok(snap_entry) = snap_entry else { continue };
117        let config_path = snap_entry.path().join("config.json");
118
119        if !config_path.exists() {
120            continue;
121        }
122
123        if let Some(model_type) = extract_model_type(&config_path) {
124            return Some(model_type);
125        }
126    }
127
128    None
129}
130
131/// Reads a `config.json` file and extracts the `model_type` field.
132fn extract_model_type(config_path: &std::path::Path) -> Option<String> {
133    let contents = std::fs::read_to_string(config_path).ok()?;
134    let value: serde_json::Value = serde_json::from_str(contents.as_str()).ok()?;
135    // BORROW: explicit .as_str() on serde_json Value
136    value.get("model_type")?.as_str().map(String::from)
137}
138
/// Per-file download state within the local cache.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum FileStatus {
    /// Fully downloaded: the local size matches the expected size, or no
    /// expected size is known.
    Complete {
        /// Local file size in bytes.
        local_size: u64,
    },
    /// Present but smaller than expected (interrupted download), or a
    /// `.chunked.part` temp file was found in the blobs directory
    /// (repo-level heuristic — may not correspond to this specific file).
    Partial {
        /// Local file size in bytes.
        local_size: u64,
        /// Expected file size in bytes.
        expected_size: u64,
    },
    /// Not present in the cache at all.
    Missing {
        /// Expected file size in bytes (0 if unknown).
        expected_size: u64,
    },
}

/// Cache status report for a repository, cross-referencing local files
/// against the remote file list.
#[derive(Debug, Clone)]
pub struct RepoStatus {
    /// The repository identifier.
    pub repo_id: String,
    /// The resolved commit hash (if available).
    pub commit_hash: Option<String>,
    /// The cache directory for this repo.
    pub cache_path: PathBuf,
    /// Per-file status, sorted by filename.
    pub files: Vec<(String, FileStatus)>,
}

impl RepoStatus {
    /// Counts files whose status satisfies `pred`.
    fn count_where(&self, pred: fn(&FileStatus) -> bool) -> usize {
        self.files.iter().filter(|(_, status)| pred(status)).count()
    }

    /// Number of fully downloaded files.
    #[must_use]
    pub fn complete_count(&self) -> usize {
        self.count_where(|status| matches!(status, FileStatus::Complete { .. }))
    }

    /// Number of partially downloaded files.
    #[must_use]
    pub fn partial_count(&self) -> usize {
        self.count_where(|status| matches!(status, FileStatus::Partial { .. }))
    }

    /// Number of missing files.
    #[must_use]
    pub fn missing_count(&self) -> usize {
        self.count_where(|status| matches!(status, FileStatus::Missing { .. }))
    }
}
205
206/// Inspects the local cache for a repository and compares against the remote file list.
207///
208/// # Arguments
209///
210/// * `repo_id` — The repository identifier (e.g., `"RWKV/RWKV7-Goose-World3-1.5B-HF"`).
211/// * `token` — Optional authentication token.
212/// * `revision` — Optional revision (defaults to `"main"`).
213///
214/// # Errors
215///
216/// Returns [`FetchError::Http`] if the API request fails.
217/// Returns [`FetchError::Io`] if the cache directory cannot be read.
218pub async fn repo_status(
219    repo_id: &str,
220    token: Option<&str>,
221    revision: Option<&str>,
222) -> Result<RepoStatus, FetchError> {
223    let revision = revision.unwrap_or("main");
224    let cache_dir = hf_cache_dir()?;
225    let repo_folder = format!("models--{}", repo_id.replace('/', "--"));
226    // BORROW: explicit .as_str() for path construction
227    let repo_dir = cache_dir.join(repo_folder.as_str());
228
229    // Read commit hash from refs file if available.
230    let commit_hash = read_ref(&repo_dir, revision);
231
232    // Fetch remote file list with sizes.
233    let remote_files =
234        crate::repo::list_repo_files_with_metadata(repo_id, token, Some(revision)).await?;
235
236    // Determine snapshot directory.
237    // BORROW: explicit .as_deref() for Option<String> → Option<&str>
238    let snapshot_dir = commit_hash
239        .as_deref()
240        .map(|hash| repo_dir.join("snapshots").join(hash));
241
242    // Also check for .chunked.part files in blobs directory.
243    let blobs_dir = repo_dir.join("blobs");
244
245    // Cross-reference remote files against local state.
246    let mut files: Vec<(String, FileStatus)> = Vec::with_capacity(remote_files.len());
247
248    for remote in &remote_files {
249        let expected_size = remote.size.unwrap_or(0);
250
251        let local_path = snapshot_dir
252            .as_ref()
253            // BORROW: explicit .as_str() for path construction
254            .map(|dir| dir.join(remote.filename.as_str()));
255
256        let status = if let Some(ref path) = local_path {
257            if path.exists() {
258                let local_size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
259
260                if expected_size > 0 && local_size < expected_size {
261                    FileStatus::Partial {
262                        local_size,
263                        expected_size,
264                    }
265                } else {
266                    FileStatus::Complete { local_size }
267                }
268            } else if has_partial_blob(&blobs_dir) {
269                // Check blobs for .chunked.part temp files
270                let part_size = find_partial_blob_size(&blobs_dir);
271                FileStatus::Partial {
272                    local_size: part_size,
273                    expected_size,
274                }
275            } else {
276                FileStatus::Missing { expected_size }
277            }
278        } else {
279            FileStatus::Missing { expected_size }
280        };
281
282        // BORROW: explicit .clone() for owned String
283        files.push((remote.filename.clone(), status));
284    }
285
286    files.sort_by(|(a, _), (b, _)| a.cmp(b));
287
288    Ok(RepoStatus {
289        repo_id: repo_id.to_owned(),
290        commit_hash,
291        cache_path: repo_dir,
292        files,
293    })
294}
295
/// Summary of a single cached model (local-only, no API calls).
///
/// Produced by [`cache_summary()`], one entry per `models--*` directory
/// in the HF cache.
#[derive(Debug, Clone)]
pub struct CachedModelSummary {
    /// The repository identifier (e.g., `"RWKV/RWKV7-Goose-World3-1.5B-HF"`).
    pub repo_id: String,
    /// Number of files counted recursively across all snapshot directories.
    pub file_count: usize,
    /// Total size on disk in bytes, summed over all snapshot directories.
    /// NOTE(review): snapshot entries are typically symlinks into `blobs/`,
    /// so multiple snapshots may double-count the same blob — confirm.
    pub total_size: u64,
    /// Whether there are incomplete `.chunked.part` temp files in `blobs/`.
    pub has_partial: bool,
}
308
309/// Scans the entire HF cache and returns a summary for each cached model.
310///
311/// This is a local-only operation (no API calls). It lists all `models--*`
312/// directories and counts files + sizes in each snapshot.
313///
314/// # Errors
315///
316/// Returns [`FetchError::Io`] if the cache directory cannot be read.
317pub fn cache_summary() -> Result<Vec<CachedModelSummary>, FetchError> {
318    let cache_dir = hf_cache_dir()?;
319
320    if !cache_dir.exists() {
321        return Ok(Vec::new());
322    }
323
324    let entries = std::fs::read_dir(&cache_dir).map_err(|e| FetchError::Io {
325        path: cache_dir.clone(),
326        source: e,
327    })?;
328
329    let mut summaries: Vec<CachedModelSummary> = Vec::new();
330
331    for entry in entries {
332        let Ok(entry) = entry else { continue };
333        let dir_name = entry.file_name();
334        // BORROW: explicit .to_string_lossy() for OsString → str conversion
335        let dir_str = dir_name.to_string_lossy();
336
337        let Some(repo_part) = dir_str.strip_prefix("models--") else {
338            continue;
339        };
340
341        // Reconstruct repo_id: replace first "--" with "/"
342        let repo_id = match repo_part.find("--") {
343            Some(pos) => {
344                let (org, name_with_sep) = repo_part.split_at(pos);
345                let name = name_with_sep.get(2..).unwrap_or_default();
346                format!("{org}/{name}")
347            }
348            None => repo_part.to_string(),
349        };
350
351        let repo_dir = entry.path();
352
353        // Count files and total size in snapshots.
354        let (file_count, total_size) = count_snapshot_files(&repo_dir);
355
356        // Check for partial downloads.
357        let has_partial = find_partial_blob_size(&repo_dir.join("blobs")) > 0;
358
359        summaries.push(CachedModelSummary {
360            repo_id,
361            file_count,
362            total_size,
363            has_partial,
364        });
365    }
366
367    summaries.sort_by(|a, b| a.repo_id.cmp(&b.repo_id));
368
369    Ok(summaries)
370}
371
372/// Counts files and total size across all snapshot directories for a repo.
373fn count_snapshot_files(repo_dir: &Path) -> (usize, u64) {
374    let snapshots_dir = repo_dir.join("snapshots");
375    let Ok(snapshots) = std::fs::read_dir(snapshots_dir) else {
376        return (0, 0);
377    };
378
379    let mut file_count: usize = 0;
380    let mut total_size: u64 = 0;
381
382    for snap_entry in snapshots {
383        let Ok(snap_entry) = snap_entry else { continue };
384        let snap_path = snap_entry.path();
385        if !snap_path.is_dir() {
386            continue;
387        }
388        count_files_recursive(&snap_path, &mut file_count, &mut total_size);
389    }
390
391    (file_count, total_size)
392}
393
/// Recursively counts files and accumulates sizes in a directory.
///
/// Unreadable directories and entries are silently skipped; a file whose
/// metadata cannot be read contributes 0 bytes but is still counted.
fn count_files_recursive(dir: &Path, count: &mut usize, total: &mut u64) {
    if let Ok(entries) = std::fs::read_dir(dir) {
        for entry in entries.flatten() {
            let entry_path = entry.path();
            if entry_path.is_dir() {
                // Descend into subdirectories.
                count_files_recursive(&entry_path, count, total);
            } else {
                *count += 1;
                *total += entry.metadata().map_or(0, |m| m.len());
            }
        }
    }
}
411
/// Reads the commit hash from a refs file, if it exists.
///
/// Returns `None` when the file is missing, unreadable, or blank after
/// trimming whitespace.
pub(crate) fn read_ref(repo_dir: &Path, revision: &str) -> Option<String> {
    let contents = std::fs::read_to_string(repo_dir.join("refs").join(revision)).ok()?;
    let hash = contents.trim();
    if hash.is_empty() {
        None
    } else {
        Some(hash.to_owned())
    }
}
420
421/// Checks whether any `.chunked.part` temp file exists in the blobs directory.
422///
423/// This is a repo-level heuristic: it cannot map a specific filename to its
424/// blob without full LFS metadata, so it checks for any `.chunked.part` file.
425/// A `true` result means *some* file in the repo has a partial download.
426fn has_partial_blob(blobs_dir: &Path) -> bool {
427    find_partial_blob_size(blobs_dir) > 0
428}
429
/// Returns the size of the first `.chunked.part` file found in the blobs directory.
///
/// Returns 0 when the directory is missing/unreadable, no temp file exists,
/// or the matching file's metadata cannot be read.
fn find_partial_blob_size(blobs_dir: &Path) -> u64 {
    let Ok(entries) = std::fs::read_dir(blobs_dir) else {
        return 0;
    };

    entries
        .flatten()
        .find(|entry| entry.file_name().to_string_lossy().ends_with(".chunked.part"))
        .map_or(0, |entry| entry.metadata().map_or(0, |m| m.len()))
}