// hf_fetch_model/cache.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! `HuggingFace` cache directory resolution, model family scanning, and disk usage.
4//!
5//! [`hf_cache_dir()`] locates the local HF cache. [`list_cached_families()`]
6//! scans downloaded models and groups them by `model_type`.
7//! [`cache_summary()`] provides per-repo size totals, and
8//! [`cache_repo_usage()`] returns per-file disk usage for a single repo.
9
10use std::collections::BTreeMap;
11use std::path::{Path, PathBuf};
12
13use crate::error::FetchError;
14
/// Reconstructs a repo ID from a `models--org--name` directory name.
///
/// Returns `None` if the directory name does not start with `models--`.
fn repo_id_from_folder_name(dir_name: &str) -> Option<String> {
    let rest = dir_name.strip_prefix("models--")?;

    // The first "--" separates org from model name; any later "--" runs
    // belong to the model name itself.
    let repo_id = match rest.split_once("--") {
        Some((org, name)) => format!("{org}/{name}"),
        None => rest.to_string(),
    };

    Some(repo_id)
}
33
34/// Returns the `HuggingFace` Hub cache directory.
35///
36/// Resolution order:
37/// 1. `HF_HOME` environment variable + `/hub`
38/// 2. `~/.cache/huggingface/hub/` (via [`dirs::home_dir()`])
39///
40/// # Errors
41///
42/// Returns [`FetchError::Io`] if the home directory cannot be determined.
43pub fn hf_cache_dir() -> Result<PathBuf, FetchError> {
44    if let Ok(home) = std::env::var("HF_HOME") {
45        let mut path = PathBuf::from(home);
46        path.push("hub");
47        return Ok(path);
48    }
49
50    let home = dirs::home_dir().ok_or_else(|| FetchError::Io {
51        path: PathBuf::from("~"),
52        source: std::io::Error::new(std::io::ErrorKind::NotFound, "home directory not found"),
53    })?;
54
55    let mut path = home;
56    path.push(".cache");
57    path.push("huggingface");
58    path.push("hub");
59    Ok(path)
60}
61
62/// Scans the local HF cache for downloaded models and groups them by `model_type`.
63///
64/// Looks for `config.json` files inside model snapshot directories:
65/// `<cache>/models--<org>--<name>/snapshots/*/config.json`
66///
67/// Returns a map from `model_type` (e.g., `"llama"`) to a sorted list of
68/// repository identifiers (e.g., `["meta-llama/Llama-3.2-1B"]`).
69///
70/// Models without a `model_type` field in their `config.json` are skipped.
71///
72/// # Errors
73///
74/// Returns [`FetchError::Io`] if the cache directory cannot be read.
75pub fn list_cached_families() -> Result<BTreeMap<String, Vec<String>>, FetchError> {
76    let cache_dir = hf_cache_dir()?;
77
78    if !cache_dir.exists() {
79        return Ok(BTreeMap::new());
80    }
81
82    let entries = std::fs::read_dir(&cache_dir).map_err(|e| FetchError::Io {
83        path: cache_dir.clone(),
84        source: e,
85    })?;
86
87    let mut families: BTreeMap<String, Vec<String>> = BTreeMap::new();
88
89    for entry in entries {
90        let Ok(entry) = entry else { continue };
91
92        let dir_name = entry.file_name();
93        // BORROW: explicit .to_string_lossy() for OsString → str conversion
94        let dir_str = dir_name.to_string_lossy();
95
96        let Some(repo_id) = repo_id_from_folder_name(&dir_str) else {
97            continue;
98        };
99
100        // Find the newest snapshot's config.json
101        let snapshots_dir = crate::cache_layout::snapshots_dir(&entry.path());
102        if !snapshots_dir.exists() {
103            continue;
104        }
105
106        if let Some(model_type) = find_model_type_in_snapshots(&snapshots_dir) {
107            families.entry(model_type).or_default().push(repo_id);
108        }
109    }
110
111    // Sort repo lists within each family for stable output
112    for repos in families.values_mut() {
113        repos.sort();
114    }
115
116    Ok(families)
117}
118
119/// Searches snapshot directories for a `config.json` containing `model_type`.
120///
121/// Returns the first `model_type` value found, or `None`.
122fn find_model_type_in_snapshots(snapshots_dir: &std::path::Path) -> Option<String> {
123    let snapshots = std::fs::read_dir(snapshots_dir).ok()?;
124
125    for snap_entry in snapshots {
126        let Ok(snap_entry) = snap_entry else { continue };
127        let config_path = snap_entry.path().join("config.json");
128
129        if !config_path.exists() {
130            continue;
131        }
132
133        if let Some(model_type) = extract_model_type(&config_path) {
134            return Some(model_type);
135        }
136    }
137
138    None
139}
140
141/// Reads a `config.json` file and extracts the `model_type` field.
142fn extract_model_type(config_path: &std::path::Path) -> Option<String> {
143    let contents = std::fs::read_to_string(config_path).ok()?;
144    // BORROW: explicit .as_str() instead of Deref coercion
145    let value: serde_json::Value = serde_json::from_str(contents.as_str()).ok()?;
146    // BORROW: explicit .as_str() on serde_json Value
147    value.get("model_type")?.as_str().map(String::from)
148}
149
/// Status of a single file in the cache.
///
/// Produced by [`repo_status`] when cross-referencing the remote file list
/// against the local snapshot directory.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum FileStatus {
    /// File is fully downloaded (local size matches expected size, or no expected size known).
    Complete {
        /// Local file size in bytes.
        local_size: u64,
    },
    /// File exists but is smaller than expected (interrupted download),
    /// or a `.chunked.part` temp file was found in the blobs directory
    /// (repo-level heuristic — may not correspond to this specific file).
    Partial {
        /// Local file size in bytes. When the file itself is absent and the
        /// partial was inferred from a `.chunked.part` blob, this is that
        /// blob's size instead.
        local_size: u64,
        /// Expected file size in bytes (from the remote file listing).
        expected_size: u64,
    },
    /// File is not present in the cache.
    Missing {
        /// Expected file size in bytes (0 if unknown).
        expected_size: u64,
    },
}
174
/// Cache status report for a repository.
///
/// Produced by [`repo_status`]; summarize with [`RepoStatus::complete_count`],
/// [`RepoStatus::partial_count`], and [`RepoStatus::missing_count`].
#[derive(Debug, Clone)]
pub struct RepoStatus {
    /// The repository identifier.
    pub repo_id: String,
    /// The resolved commit hash (if available from the local `refs` file).
    pub commit_hash: Option<String>,
    /// The cache directory for this repo.
    pub cache_path: PathBuf,
    /// Per-file status, sorted by filename.
    pub files: Vec<(String, FileStatus)>,
}
187
188impl RepoStatus {
189    /// Number of fully downloaded files.
190    #[must_use]
191    pub fn complete_count(&self) -> usize {
192        self.files
193            .iter()
194            .filter(|(_, s)| matches!(s, FileStatus::Complete { .. }))
195            .count()
196    }
197
198    /// Number of partially downloaded files.
199    #[must_use]
200    pub fn partial_count(&self) -> usize {
201        self.files
202            .iter()
203            .filter(|(_, s)| matches!(s, FileStatus::Partial { .. }))
204            .count()
205    }
206
207    /// Number of missing files.
208    #[must_use]
209    pub fn missing_count(&self) -> usize {
210        self.files
211            .iter()
212            .filter(|(_, s)| matches!(s, FileStatus::Missing { .. }))
213            .count()
214    }
215}
216
217/// Inspects the local cache for a repository and compares against the remote file list.
218///
219/// # Arguments
220///
221/// * `repo_id` — The repository identifier (e.g., `"RWKV/RWKV7-Goose-World3-1.5B-HF"`).
222/// * `token` — Optional authentication token.
223/// * `revision` — Optional revision (defaults to `"main"`).
224///
225/// # Notes
226///
227/// Partial download detection is a repo-level heuristic: if any
228/// `.chunked.part` file exists in the repo's `blobs/` directory, all
229/// missing files are reported as [`FileStatus::Partial`] with the partial
230/// file's size. This may overcount partials when multiple files are
231/// missing but only one has an incomplete blob. Exact blob-to-file
232/// mapping would require LFS metadata.
233///
234/// # Errors
235///
236/// Returns [`FetchError::Http`] if the API request fails.
237/// Returns [`FetchError::Io`] if the cache directory cannot be read.
238pub async fn repo_status(
239    repo_id: &str,
240    token: Option<&str>,
241    revision: Option<&str>,
242) -> Result<RepoStatus, FetchError> {
243    let revision = revision.unwrap_or("main");
244    let cache_dir = hf_cache_dir()?;
245    let repo_dir = crate::cache_layout::repo_dir(&cache_dir, repo_id);
246
247    // Read commit hash from refs file if available.
248    let commit_hash = read_ref(&repo_dir, revision);
249
250    // Fetch remote file list with sizes.
251    let client = crate::chunked::build_client(token)?;
252    let remote_files =
253        crate::repo::list_repo_files_with_metadata(repo_id, token, Some(revision), &client).await?;
254
255    // Determine snapshot directory.
256    // BORROW: explicit .as_deref() for Option<String> → Option<&str>
257    let snapshot_dir = commit_hash
258        .as_deref()
259        .map(|hash| crate::cache_layout::snapshot_dir(&repo_dir, hash));
260
261    // Pre-check for .chunked.part files in blobs directory (avoids re-scanning
262    // the blobs directory for every missing file in the loop below).
263    let blobs_dir = crate::cache_layout::blobs_dir(&repo_dir);
264    let has_any_partial = has_partial_blob(&blobs_dir);
265
266    // Cross-reference remote files against local state.
267    let mut files: Vec<(String, FileStatus)> = Vec::with_capacity(remote_files.len());
268
269    for remote in &remote_files {
270        let expected_size = remote.size.unwrap_or(0);
271
272        let local_path = snapshot_dir
273            .as_ref()
274            // BORROW: explicit .as_str() for path construction
275            .map(|dir| dir.join(remote.filename.as_str()));
276
277        let status = if let Some(ref path) = local_path {
278            if path.exists() {
279                let local_size = std::fs::metadata(path).map_or(0, |m| m.len());
280
281                if expected_size > 0 && local_size < expected_size {
282                    FileStatus::Partial {
283                        local_size,
284                        expected_size,
285                    }
286                } else {
287                    FileStatus::Complete { local_size }
288                }
289            } else if has_any_partial {
290                // Blobs directory has .chunked.part temp files
291                let part_size = find_partial_blob_size(&blobs_dir);
292                FileStatus::Partial {
293                    local_size: part_size,
294                    expected_size,
295                }
296            } else {
297                FileStatus::Missing { expected_size }
298            }
299        } else {
300            FileStatus::Missing { expected_size }
301        };
302
303        // BORROW: explicit .clone() for owned String
304        files.push((remote.filename.clone(), status));
305    }
306
307    files.sort_by(|(a, _), (b, _)| a.cmp(b));
308
309    // BORROW: explicit .to_owned() for &str → owned String field
310    Ok(RepoStatus {
311        repo_id: repo_id.to_owned(),
312        commit_hash,
313        cache_path: repo_dir,
314        files,
315    })
316}
317
/// Summary of a single cached model (local-only, no API calls).
///
/// Produced by [`cache_summary`].
#[derive(Debug, Clone)]
pub struct CachedModelSummary {
    /// The repository identifier (e.g., `"RWKV/RWKV7-Goose-World3-1.5B-HF"`).
    pub repo_id: String,
    /// Number of files in the snapshot directory.
    pub file_count: usize,
    /// Total size on disk in bytes.
    pub total_size: u64,
    /// Whether there are incomplete `.chunked.part` temp files in `blobs/`.
    pub has_partial: bool,
    /// Most recent modification time among files in the snapshot directory.
    ///
    /// `None` if no files were found or all metadata reads failed.
    pub last_modified: Option<std::time::SystemTime>,
}
334
335/// Scans the entire HF cache and returns a summary for each cached model.
336///
337/// This is a local-only operation (no API calls). It lists all `models--*`
338/// directories and counts files + sizes in each snapshot.
339///
340/// # Errors
341///
342/// Returns [`FetchError::Io`] if the cache directory cannot be read.
343pub fn cache_summary() -> Result<Vec<CachedModelSummary>, FetchError> {
344    let cache_dir = hf_cache_dir()?;
345
346    if !cache_dir.exists() {
347        return Ok(Vec::new());
348    }
349
350    let entries = std::fs::read_dir(&cache_dir).map_err(|e| FetchError::Io {
351        path: cache_dir.clone(),
352        source: e,
353    })?;
354
355    let mut summaries: Vec<CachedModelSummary> = Vec::new();
356
357    for entry in entries {
358        let Ok(entry) = entry else { continue };
359        let dir_name = entry.file_name();
360        // BORROW: explicit .to_string_lossy() for OsString → str conversion
361        let dir_str = dir_name.to_string_lossy();
362
363        let Some(repo_id) = repo_id_from_folder_name(&dir_str) else {
364            continue;
365        };
366
367        let repo_dir = entry.path();
368
369        // Count files and total size in snapshots.
370        let (file_count, total_size, last_modified) = count_snapshot_files(&repo_dir);
371
372        // Check for partial downloads.
373        let has_partial = find_partial_blob_size(&crate::cache_layout::blobs_dir(&repo_dir)) > 0;
374
375        summaries.push(CachedModelSummary {
376            repo_id,
377            file_count,
378            total_size,
379            has_partial,
380            last_modified,
381        });
382    }
383
384    summaries.sort_by(|a, b| a.repo_id.cmp(&b.repo_id));
385
386    Ok(summaries)
387}
388
389/// Returns the file count and total size for a single cached repo.
390///
391/// Avoids scanning the entire cache when only one repo's metrics are needed
392/// (e.g., for the `cache delete` preview).
393///
394/// # Errors
395///
396/// Returns [`FetchError::Io`] if the cache directory cannot be determined.
397pub fn repo_disk_usage(repo_id: &str) -> Result<(usize, u64), FetchError> {
398    let cache_dir = hf_cache_dir()?;
399    let repo_dir = crate::cache_layout::repo_dir(&cache_dir, repo_id);
400    let (file_count, total_size, _) = count_snapshot_files(&repo_dir);
401    Ok((file_count, total_size))
402}
403
404/// Checks whether a single cached repo has `.chunked.part` temp files.
405///
406/// Avoids scanning the entire cache when only one repo's partial status
407/// is needed (e.g., for the `du <REPO>` partial-download hint).
408///
409/// # Errors
410///
411/// Returns [`FetchError::Io`] if the cache directory cannot be determined.
412pub fn repo_has_partial(repo_id: &str) -> Result<bool, FetchError> {
413    let cache_dir = hf_cache_dir()?;
414    let repo_dir = crate::cache_layout::repo_dir(&cache_dir, repo_id);
415    let blobs_dir = crate::cache_layout::blobs_dir(&repo_dir);
416    Ok(find_partial_blob_size(&blobs_dir) > 0)
417}
418
419/// Counts files, total size, and most recent modification time across all
420/// snapshot directories for a repo.
421fn count_snapshot_files(repo_dir: &Path) -> (usize, u64, Option<std::time::SystemTime>) {
422    let snapshots_dir = crate::cache_layout::snapshots_dir(repo_dir);
423    let Ok(snapshots) = std::fs::read_dir(snapshots_dir) else {
424        return (0, 0, None);
425    };
426
427    let mut file_count: usize = 0;
428    let mut total_size: u64 = 0;
429    let mut latest: Option<std::time::SystemTime> = None;
430
431    for snap_entry in snapshots {
432        let Ok(snap_entry) = snap_entry else { continue };
433        let snap_path = snap_entry.path();
434        if !snap_path.is_dir() {
435            continue;
436        }
437        count_files_recursive(&snap_path, &mut file_count, &mut total_size, &mut latest);
438    }
439
440    (file_count, total_size, latest)
441}
442
/// Recursively counts files, accumulates sizes, and tracks the most recent
/// modification time in a directory.
///
/// Unreadable directories contribute nothing; files whose metadata cannot
/// be read are counted but add nothing to `total`.
fn count_files_recursive(
    dir: &Path,
    count: &mut usize,
    total: &mut u64,
    latest: &mut Option<std::time::SystemTime>,
) {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return;
    };

    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            count_files_recursive(&path, count, total, latest);
            continue;
        }

        // Every non-directory entry counts, even if metadata is unreadable.
        *count += 1;
        if let Ok(meta) = entry.metadata() {
            *total += meta.len();
            if let Ok(modified) = meta.modified() {
                // Keep whichever mtime is newest.
                if (*latest).map_or(true, |current| modified > current) {
                    *latest = Some(modified);
                }
            }
        }
    }
}
474
475/// Reads the commit hash from a refs file, if it exists.
476///
477/// Looks for `<repo_dir>/refs/<revision>` and returns the trimmed contents
478/// (a commit hash) or `None` if the file does not exist or is empty.
479#[must_use]
480pub fn read_ref(repo_dir: &Path, revision: &str) -> Option<String> {
481    let ref_path = crate::cache_layout::ref_path(repo_dir, revision);
482    std::fs::read_to_string(ref_path)
483        .ok()
484        // BORROW: explicit .to_owned() to convert trimmed &str → owned String
485        .map(|s| s.trim().to_owned())
486        .filter(|s| !s.is_empty())
487}
488
489/// Checks whether any `.chunked.part` temp file exists in the blobs directory.
490///
491/// This is a repo-level heuristic: it cannot map a specific filename to its
492/// blob without full LFS metadata, so it checks for any `.chunked.part` file.
493/// A `true` result means *some* file in the repo has a partial download.
494fn has_partial_blob(blobs_dir: &Path) -> bool {
495    find_partial_blob_size(blobs_dir) > 0
496}
497
/// Returns the size of the first `.chunked.part` file found in the blobs directory.
///
/// Returns 0 when the directory cannot be read, no such file exists, or the
/// found file's metadata is unreadable.
fn find_partial_blob_size(blobs_dir: &Path) -> u64 {
    std::fs::read_dir(blobs_dir)
        .ok()
        .and_then(|entries| {
            entries
                .flatten()
                .find(|entry| entry.file_name().to_string_lossy().ends_with(".chunked.part"))
        })
        .and_then(|entry| entry.metadata().ok())
        .map_or(0, |meta| meta.len())
}
515
/// A `.chunked.part` temp file left by an interrupted chunked download.
///
/// Produced by [`find_partial_files`].
#[derive(Debug, Clone)]
pub struct PartialFile {
    /// The repository identifier (e.g., `"meta-llama/Llama-3.2-1B"`).
    pub repo_id: String,
    /// The `.chunked.part` filename (e.g., `"abc123def456.chunked.part"`).
    pub filename: String,
    /// Absolute path to the `.chunked.part` file.
    pub path: PathBuf,
    /// Size of the partial file in bytes (0 if metadata could not be read).
    pub size: u64,
}
528
529/// Finds all `.chunked.part` temp files in the `HuggingFace` cache.
530///
531/// Walks `models--*/blobs/` directories and collects partial files.
532/// When `repo_filter` is `Some`, only the matching repo is scanned.
533///
534/// Returns an empty `Vec` if the cache directory does not exist.
535///
536/// # Errors
537///
538/// Returns [`FetchError::Io`] if the cache directory cannot be read.
539pub fn find_partial_files(repo_filter: Option<&str>) -> Result<Vec<PartialFile>, FetchError> {
540    let cache_dir = hf_cache_dir()?;
541
542    if !cache_dir.exists() {
543        return Ok(Vec::new());
544    }
545
546    let entries = std::fs::read_dir(&cache_dir).map_err(|e| FetchError::Io {
547        // BORROW: explicit .clone() for owned PathBuf
548        path: cache_dir.clone(),
549        source: e,
550    })?;
551
552    let mut partials: Vec<PartialFile> = Vec::new();
553
554    for entry in entries {
555        let Ok(entry) = entry else { continue };
556        let dir_name = entry.file_name();
557        // BORROW: explicit .to_string_lossy() for OsString → str conversion
558        let dir_str = dir_name.to_string_lossy();
559
560        let Some(repo_id) = repo_id_from_folder_name(&dir_str) else {
561            continue;
562        };
563
564        // Skip repos that don't match the filter.
565        // BORROW: explicit .as_str() instead of Deref coercion
566        if let Some(filter) = repo_filter {
567            if repo_id.as_str() != filter {
568                continue;
569            }
570        }
571
572        let blobs_dir = crate::cache_layout::blobs_dir(&entry.path());
573        let Ok(blob_entries) = std::fs::read_dir(&blobs_dir) else {
574            continue;
575        };
576
577        for blob_entry in blob_entries {
578            let Ok(blob_entry) = blob_entry else { continue };
579            let name = blob_entry.file_name();
580            // BORROW: explicit .to_string_lossy() for OsString → str conversion
581            let name_str = name.to_string_lossy();
582            if name_str.ends_with(".chunked.part") {
583                let size = blob_entry.metadata().map_or(0, |m| m.len());
584                partials.push(PartialFile {
585                    // BORROW: explicit .clone() for owned String
586                    repo_id: repo_id.clone(),
587                    // BORROW: explicit .to_string() for Cow<str> → owned String
588                    filename: name_str.to_string(),
589                    path: blob_entry.path(),
590                    size,
591                });
592            }
593        }
594    }
595
596    Ok(partials)
597}
598
/// Per-file disk usage entry within a cached repository.
///
/// Produced by [`cache_repo_usage`], which sorts entries by size descending.
#[derive(Debug, Clone)]
pub struct CacheFileUsage {
    /// Filename relative to the snapshot directory (subdirectories joined
    /// with `/`, e.g. `"tokenizer/vocab.json"`).
    pub filename: String,
    /// File size in bytes (0 if metadata could not be read).
    pub size: u64,
}
607
608/// Returns per-file disk usage for a specific cached repository.
609///
610/// Walks the snapshot directories under
611/// `<cache_dir>/models--<org>--<name>/snapshots/` and collects each file's
612/// relative path and size. Results are sorted by size descending.
613///
614/// Returns an empty `Vec` if the repository is not cached.
615///
616/// # Errors
617///
618/// Returns [`FetchError::Io`] if the cache directory cannot be determined.
619pub fn cache_repo_usage(repo_id: &str) -> Result<Vec<CacheFileUsage>, FetchError> {
620    let cache_dir = hf_cache_dir()?;
621    let repo_dir = crate::cache_layout::repo_dir(&cache_dir, repo_id);
622
623    if !repo_dir.exists() {
624        return Ok(Vec::new());
625    }
626
627    let snapshots_dir = crate::cache_layout::snapshots_dir(&repo_dir);
628    let Ok(snapshots) = std::fs::read_dir(&snapshots_dir) else {
629        return Ok(Vec::new());
630    };
631
632    let mut files: Vec<CacheFileUsage> = Vec::new();
633
634    for snap_entry in snapshots {
635        let Ok(snap_entry) = snap_entry else { continue };
636        let snap_path = snap_entry.path();
637        if !snap_path.is_dir() {
638            continue;
639        }
640        collect_snapshot_files(&snap_path, "", &mut files);
641    }
642
643    files.sort_by_key(|f| std::cmp::Reverse(f.size));
644
645    Ok(files)
646}
647
648/// Recursively collects files from a snapshot directory into `CacheFileUsage` entries.
649///
650/// The `prefix` parameter tracks the relative path from the snapshot root,
651/// so that files in subdirectories get paths like `"tokenizer/vocab.json"`.
652fn collect_snapshot_files(dir: &Path, prefix: &str, files: &mut Vec<CacheFileUsage>) {
653    let Ok(entries) = std::fs::read_dir(dir) else {
654        return;
655    };
656
657    for entry in entries {
658        let Ok(entry) = entry else { continue };
659        let path = entry.path();
660        // BORROW: explicit .to_string_lossy() for OsString → str conversion
661        let name = entry.file_name().to_string_lossy().to_string();
662
663        if path.is_dir() {
664            let child_prefix = if prefix.is_empty() {
665                name
666            } else {
667                format!("{prefix}/{name}")
668            };
669            collect_snapshot_files(&path, &child_prefix, files);
670        } else {
671            let filename = if prefix.is_empty() {
672                name
673            } else {
674                format!("{prefix}/{name}")
675            };
676            let size = entry.metadata().map_or(0, |m| m.len());
677            files.push(CacheFileUsage { filename, size });
678        }
679    }
680}