// SPDX-License-Identifier: MIT OR Apache-2.0
// hf_fetch_model/cache.rs

//! `HuggingFace` cache directory resolution, model family scanning, and disk usage.
//!
//! [`hf_cache_dir()`] locates the local HF cache. [`list_cached_families()`]
//! scans downloaded models and groups them by `model_type`.
//! [`cache_summary()`] provides per-repo size totals, and
//! [`cache_repo_usage()`] returns per-file disk usage for a single repo.

use std::collections::BTreeMap;
use std::path::{Path, PathBuf};

use crate::error::FetchError;

15/// Returns the `HuggingFace` Hub cache directory.
16///
17/// Resolution order:
18/// 1. `HF_HOME` environment variable + `/hub`
19/// 2. `~/.cache/huggingface/hub/` (via [`dirs::home_dir()`])
20///
21/// # Errors
22///
23/// Returns [`FetchError::Io`] if the home directory cannot be determined.
24pub fn hf_cache_dir() -> Result<PathBuf, FetchError> {
25    if let Ok(home) = std::env::var("HF_HOME") {
26        let mut path = PathBuf::from(home);
27        path.push("hub");
28        return Ok(path);
29    }
30
31    let home = dirs::home_dir().ok_or_else(|| FetchError::Io {
32        path: PathBuf::from("~"),
33        source: std::io::Error::new(std::io::ErrorKind::NotFound, "home directory not found"),
34    })?;
35
36    let mut path = home;
37    path.push(".cache");
38    path.push("huggingface");
39    path.push("hub");
40    Ok(path)
41}
42
43/// Scans the local HF cache for downloaded models and groups them by `model_type`.
44///
45/// Looks for `config.json` files inside model snapshot directories:
46/// `<cache>/models--<org>--<name>/snapshots/*/config.json`
47///
48/// Returns a map from `model_type` (e.g., `"llama"`) to a sorted list of
49/// repository identifiers (e.g., `["meta-llama/Llama-3.2-1B"]`).
50///
51/// Models without a `model_type` field in their `config.json` are skipped.
52///
53/// # Errors
54///
55/// Returns [`FetchError::Io`] if the cache directory cannot be read.
56pub fn list_cached_families() -> Result<BTreeMap<String, Vec<String>>, FetchError> {
57    let cache_dir = hf_cache_dir()?;
58
59    if !cache_dir.exists() {
60        return Ok(BTreeMap::new());
61    }
62
63    let entries = std::fs::read_dir(&cache_dir).map_err(|e| FetchError::Io {
64        path: cache_dir.clone(),
65        source: e,
66    })?;
67
68    let mut families: BTreeMap<String, Vec<String>> = BTreeMap::new();
69
70    for entry in entries {
71        let Ok(entry) = entry else { continue };
72
73        let dir_name = entry.file_name();
74        // BORROW: explicit .to_string_lossy() for OsString → str conversion
75        let dir_str = dir_name.to_string_lossy();
76
77        // Only process model directories (models--org--name)
78        let Some(repo_part) = dir_str.strip_prefix("models--") else {
79            continue;
80        };
81
82        // Reconstruct repo_id: replace first "--" with "/"
83        let repo_id = match repo_part.find("--") {
84            Some(pos) => {
85                let (org, name_with_sep) = repo_part.split_at(pos);
86                let name = name_with_sep.get(2..).unwrap_or_default();
87                format!("{org}/{name}")
88            }
89            None => repo_part.to_string(),
90        };
91
92        // Find the newest snapshot's config.json
93        let snapshots_dir = entry.path().join("snapshots");
94        if !snapshots_dir.exists() {
95            continue;
96        }
97
98        if let Some(model_type) = find_model_type_in_snapshots(&snapshots_dir) {
99            families.entry(model_type).or_default().push(repo_id);
100        }
101    }
102
103    // Sort repo lists within each family for stable output
104    for repos in families.values_mut() {
105        repos.sort();
106    }
107
108    Ok(families)
109}
110
111/// Searches snapshot directories for a `config.json` containing `model_type`.
112///
113/// Returns the first `model_type` value found, or `None`.
114fn find_model_type_in_snapshots(snapshots_dir: &std::path::Path) -> Option<String> {
115    let snapshots = std::fs::read_dir(snapshots_dir).ok()?;
116
117    for snap_entry in snapshots {
118        let Ok(snap_entry) = snap_entry else { continue };
119        let config_path = snap_entry.path().join("config.json");
120
121        if !config_path.exists() {
122            continue;
123        }
124
125        if let Some(model_type) = extract_model_type(&config_path) {
126            return Some(model_type);
127        }
128    }
129
130    None
131}
132
133/// Reads a `config.json` file and extracts the `model_type` field.
134fn extract_model_type(config_path: &std::path::Path) -> Option<String> {
135    let contents = std::fs::read_to_string(config_path).ok()?;
136    let value: serde_json::Value = serde_json::from_str(contents.as_str()).ok()?;
137    // BORROW: explicit .as_str() on serde_json Value
138    value.get("model_type")?.as_str().map(String::from)
139}
140
/// Status of a single file in the cache, as computed by [`repo_status()`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum FileStatus {
    /// File is fully downloaded (local size matches expected size, or no expected size known).
    Complete {
        /// Local file size in bytes.
        local_size: u64,
    },
    /// File exists but is smaller than expected (interrupted download),
    /// or a `.chunked.part` temp file was found in the blobs directory
    /// (repo-level heuristic — may not correspond to this specific file).
    Partial {
        /// Local file size in bytes.
        local_size: u64,
        /// Expected file size in bytes.
        expected_size: u64,
    },
    /// File is not present in the cache.
    Missing {
        /// Expected file size in bytes (0 if unknown).
        expected_size: u64,
    },
}

/// Cache status report for a repository, produced by [`repo_status()`].
#[derive(Debug, Clone)]
pub struct RepoStatus {
    /// The repository identifier (e.g., `"org/name"`).
    pub repo_id: String,
    /// The resolved commit hash (if available from the local `refs` file).
    pub commit_hash: Option<String>,
    /// The cache directory for this repo (`<cache>/models--<org>--<name>`).
    pub cache_path: PathBuf,
    /// Per-file status, sorted by filename.
    pub files: Vec<(String, FileStatus)>,
}

179impl RepoStatus {
180    /// Number of fully downloaded files.
181    #[must_use]
182    pub fn complete_count(&self) -> usize {
183        self.files
184            .iter()
185            .filter(|(_, s)| matches!(s, FileStatus::Complete { .. }))
186            .count()
187    }
188
189    /// Number of partially downloaded files.
190    #[must_use]
191    pub fn partial_count(&self) -> usize {
192        self.files
193            .iter()
194            .filter(|(_, s)| matches!(s, FileStatus::Partial { .. }))
195            .count()
196    }
197
198    /// Number of missing files.
199    #[must_use]
200    pub fn missing_count(&self) -> usize {
201        self.files
202            .iter()
203            .filter(|(_, s)| matches!(s, FileStatus::Missing { .. }))
204            .count()
205    }
206}
207
208/// Inspects the local cache for a repository and compares against the remote file list.
209///
210/// # Arguments
211///
212/// * `repo_id` — The repository identifier (e.g., `"RWKV/RWKV7-Goose-World3-1.5B-HF"`).
213/// * `token` — Optional authentication token.
214/// * `revision` — Optional revision (defaults to `"main"`).
215///
216/// # Errors
217///
218/// Returns [`FetchError::Http`] if the API request fails.
219/// Returns [`FetchError::Io`] if the cache directory cannot be read.
220pub async fn repo_status(
221    repo_id: &str,
222    token: Option<&str>,
223    revision: Option<&str>,
224) -> Result<RepoStatus, FetchError> {
225    let revision = revision.unwrap_or("main");
226    let cache_dir = hf_cache_dir()?;
227    let repo_folder = format!("models--{}", repo_id.replace('/', "--"));
228    // BORROW: explicit .as_str() for path construction
229    let repo_dir = cache_dir.join(repo_folder.as_str());
230
231    // Read commit hash from refs file if available.
232    let commit_hash = read_ref(&repo_dir, revision);
233
234    // Fetch remote file list with sizes.
235    let remote_files =
236        crate::repo::list_repo_files_with_metadata(repo_id, token, Some(revision)).await?;
237
238    // Determine snapshot directory.
239    // BORROW: explicit .as_deref() for Option<String> → Option<&str>
240    let snapshot_dir = commit_hash
241        .as_deref()
242        .map(|hash| repo_dir.join("snapshots").join(hash));
243
244    // Also check for .chunked.part files in blobs directory.
245    let blobs_dir = repo_dir.join("blobs");
246
247    // Cross-reference remote files against local state.
248    let mut files: Vec<(String, FileStatus)> = Vec::with_capacity(remote_files.len());
249
250    for remote in &remote_files {
251        let expected_size = remote.size.unwrap_or(0);
252
253        let local_path = snapshot_dir
254            .as_ref()
255            // BORROW: explicit .as_str() for path construction
256            .map(|dir| dir.join(remote.filename.as_str()));
257
258        let status = if let Some(ref path) = local_path {
259            if path.exists() {
260                let local_size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
261
262                if expected_size > 0 && local_size < expected_size {
263                    FileStatus::Partial {
264                        local_size,
265                        expected_size,
266                    }
267                } else {
268                    FileStatus::Complete { local_size }
269                }
270            } else if has_partial_blob(&blobs_dir) {
271                // Check blobs for .chunked.part temp files
272                let part_size = find_partial_blob_size(&blobs_dir);
273                FileStatus::Partial {
274                    local_size: part_size,
275                    expected_size,
276                }
277            } else {
278                FileStatus::Missing { expected_size }
279            }
280        } else {
281            FileStatus::Missing { expected_size }
282        };
283
284        // BORROW: explicit .clone() for owned String
285        files.push((remote.filename.clone(), status));
286    }
287
288    files.sort_by(|(a, _), (b, _)| a.cmp(b));
289
290    Ok(RepoStatus {
291        repo_id: repo_id.to_owned(),
292        commit_hash,
293        cache_path: repo_dir,
294        files,
295    })
296}
297
/// Summary of a single cached model (local-only, no API calls).
///
/// Produced by [`cache_summary()`].
#[derive(Debug, Clone)]
pub struct CachedModelSummary {
    /// The repository identifier (e.g., `"RWKV/RWKV7-Goose-World3-1.5B-HF"`).
    pub repo_id: String,
    /// Number of files in the snapshot directory.
    pub file_count: usize,
    /// Total size on disk in bytes.
    pub total_size: u64,
    /// Whether there are incomplete `.chunked.part` temp files in `blobs/`.
    pub has_partial: bool,
}

311/// Scans the entire HF cache and returns a summary for each cached model.
312///
313/// This is a local-only operation (no API calls). It lists all `models--*`
314/// directories and counts files + sizes in each snapshot.
315///
316/// # Errors
317///
318/// Returns [`FetchError::Io`] if the cache directory cannot be read.
319pub fn cache_summary() -> Result<Vec<CachedModelSummary>, FetchError> {
320    let cache_dir = hf_cache_dir()?;
321
322    if !cache_dir.exists() {
323        return Ok(Vec::new());
324    }
325
326    let entries = std::fs::read_dir(&cache_dir).map_err(|e| FetchError::Io {
327        path: cache_dir.clone(),
328        source: e,
329    })?;
330
331    let mut summaries: Vec<CachedModelSummary> = Vec::new();
332
333    for entry in entries {
334        let Ok(entry) = entry else { continue };
335        let dir_name = entry.file_name();
336        // BORROW: explicit .to_string_lossy() for OsString → str conversion
337        let dir_str = dir_name.to_string_lossy();
338
339        let Some(repo_part) = dir_str.strip_prefix("models--") else {
340            continue;
341        };
342
343        // Reconstruct repo_id: replace first "--" with "/"
344        let repo_id = match repo_part.find("--") {
345            Some(pos) => {
346                let (org, name_with_sep) = repo_part.split_at(pos);
347                let name = name_with_sep.get(2..).unwrap_or_default();
348                format!("{org}/{name}")
349            }
350            None => repo_part.to_string(),
351        };
352
353        let repo_dir = entry.path();
354
355        // Count files and total size in snapshots.
356        let (file_count, total_size) = count_snapshot_files(&repo_dir);
357
358        // Check for partial downloads.
359        let has_partial = find_partial_blob_size(&repo_dir.join("blobs")) > 0;
360
361        summaries.push(CachedModelSummary {
362            repo_id,
363            file_count,
364            total_size,
365            has_partial,
366        });
367    }
368
369    summaries.sort_by(|a, b| a.repo_id.cmp(&b.repo_id));
370
371    Ok(summaries)
372}
373
374/// Counts files and total size across all snapshot directories for a repo.
375fn count_snapshot_files(repo_dir: &Path) -> (usize, u64) {
376    let snapshots_dir = repo_dir.join("snapshots");
377    let Ok(snapshots) = std::fs::read_dir(snapshots_dir) else {
378        return (0, 0);
379    };
380
381    let mut file_count: usize = 0;
382    let mut total_size: u64 = 0;
383
384    for snap_entry in snapshots {
385        let Ok(snap_entry) = snap_entry else { continue };
386        let snap_path = snap_entry.path();
387        if !snap_path.is_dir() {
388            continue;
389        }
390        count_files_recursive(&snap_path, &mut file_count, &mut total_size);
391    }
392
393    (file_count, total_size)
394}
395
/// Recursively counts files and accumulates sizes in a directory.
///
/// Unreadable directories/entries are silently skipped; files whose metadata
/// cannot be read contribute size 0.
fn count_files_recursive(dir: &Path, count: &mut usize, total: &mut u64) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };

    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            // Descend into subdirectories.
            count_files_recursive(&path, count, total);
        } else {
            *count += 1;
            *total += entry.metadata().map_or(0, |m| m.len());
        }
    }
}

/// Reads the commit hash from a refs file, if it exists.
///
/// Looks for `<repo_dir>/refs/<revision>` and returns the trimmed contents
/// (a commit hash) or `None` if the file does not exist or is empty.
#[must_use]
pub fn read_ref(repo_dir: &Path, revision: &str) -> Option<String> {
    let contents = std::fs::read_to_string(repo_dir.join("refs").join(revision)).ok()?;
    let hash = contents.trim();
    // A whitespace-only refs file is treated the same as a missing one.
    if hash.is_empty() {
        None
    } else {
        Some(hash.to_owned())
    }
}

427/// Checks whether any `.chunked.part` temp file exists in the blobs directory.
428///
429/// This is a repo-level heuristic: it cannot map a specific filename to its
430/// blob without full LFS metadata, so it checks for any `.chunked.part` file.
431/// A `true` result means *some* file in the repo has a partial download.
432fn has_partial_blob(blobs_dir: &Path) -> bool {
433    find_partial_blob_size(blobs_dir) > 0
434}
435
/// Returns the size of the first `.chunked.part` file found in the blobs directory.
///
/// Returns 0 if the directory cannot be read, no temp file exists, or the
/// temp file's metadata cannot be read.
fn find_partial_blob_size(blobs_dir: &Path) -> u64 {
    let Ok(entries) = std::fs::read_dir(blobs_dir) else {
        return 0;
    };

    entries
        .flatten()
        .find(|entry| entry.file_name().to_string_lossy().ends_with(".chunked.part"))
        .map_or(0, |entry| entry.metadata().map_or(0, |m| m.len()))
}

/// Per-file disk usage entry within a cached repository.
///
/// Produced by [`cache_repo_usage()`].
#[derive(Debug, Clone)]
pub struct CacheFileUsage {
    /// Filename relative to the snapshot directory (e.g., `"tokenizer/vocab.json"`).
    pub filename: String,
    /// File size in bytes.
    pub size: u64,
}

463/// Returns per-file disk usage for a specific cached repository.
464///
465/// Walks the snapshot directories under
466/// `<cache_dir>/models--<org>--<name>/snapshots/` and collects each file's
467/// relative path and size. Results are sorted by size descending.
468///
469/// Returns an empty `Vec` if the repository is not cached.
470///
471/// # Errors
472///
473/// Returns [`FetchError::Io`] if the cache directory cannot be determined.
474pub fn cache_repo_usage(repo_id: &str) -> Result<Vec<CacheFileUsage>, FetchError> {
475    let cache_dir = hf_cache_dir()?;
476    let repo_folder = format!("models--{}", repo_id.replace('/', "--"));
477    let repo_dir = cache_dir.join(&repo_folder);
478
479    if !repo_dir.exists() {
480        return Ok(Vec::new());
481    }
482
483    let snapshots_dir = repo_dir.join("snapshots");
484    let Ok(snapshots) = std::fs::read_dir(&snapshots_dir) else {
485        return Ok(Vec::new());
486    };
487
488    let mut files: Vec<CacheFileUsage> = Vec::new();
489
490    for snap_entry in snapshots {
491        let Ok(snap_entry) = snap_entry else { continue };
492        let snap_path = snap_entry.path();
493        if !snap_path.is_dir() {
494            continue;
495        }
496        collect_snapshot_files(&snap_path, "", &mut files);
497    }
498
499    files.sort_by(|a, b| b.size.cmp(&a.size));
500
501    Ok(files)
502}
503
504/// Recursively collects files from a snapshot directory into `CacheFileUsage` entries.
505///
506/// The `prefix` parameter tracks the relative path from the snapshot root,
507/// so that files in subdirectories get paths like `"tokenizer/vocab.json"`.
508fn collect_snapshot_files(dir: &Path, prefix: &str, files: &mut Vec<CacheFileUsage>) {
509    let Ok(entries) = std::fs::read_dir(dir) else {
510        return;
511    };
512
513    for entry in entries {
514        let Ok(entry) = entry else { continue };
515        let path = entry.path();
516        // BORROW: explicit .to_string_lossy() for OsString → str conversion
517        let name = entry.file_name().to_string_lossy().to_string();
518
519        if path.is_dir() {
520            let child_prefix = if prefix.is_empty() {
521                name
522            } else {
523                format!("{prefix}/{name}")
524            };
525            collect_snapshot_files(&path, &child_prefix, files);
526        } else {
527            let filename = if prefix.is_empty() {
528                name
529            } else {
530                format!("{prefix}/{name}")
531            };
532            let size = entry.metadata().map(|m| m.len()).unwrap_or(0);
533            files.push(CacheFileUsage { filename, size });
534        }
535    }
536}