use std::collections::BTreeMap;
use std::path::{Path, PathBuf};

use crate::error::FetchError;

13pub fn hf_cache_dir() -> Result<PathBuf, FetchError> {
23 if let Ok(home) = std::env::var("HF_HOME") {
24 let mut path = PathBuf::from(home);
25 path.push("hub");
26 return Ok(path);
27 }
28
29 let home = dirs::home_dir().ok_or_else(|| FetchError::Io {
30 path: PathBuf::from("~"),
31 source: std::io::Error::new(std::io::ErrorKind::NotFound, "home directory not found"),
32 })?;
33
34 let mut path = home;
35 path.push(".cache");
36 path.push("huggingface");
37 path.push("hub");
38 Ok(path)
39}
40
41pub fn list_cached_families() -> Result<BTreeMap<String, Vec<String>>, FetchError> {
55 let cache_dir = hf_cache_dir()?;
56
57 if !cache_dir.exists() {
58 return Ok(BTreeMap::new());
59 }
60
61 let entries = std::fs::read_dir(&cache_dir).map_err(|e| FetchError::Io {
62 path: cache_dir.clone(),
63 source: e,
64 })?;
65
66 let mut families: BTreeMap<String, Vec<String>> = BTreeMap::new();
67
68 for entry in entries {
69 let Ok(entry) = entry else { continue };
70
71 let dir_name = entry.file_name();
72 let dir_str = dir_name.to_string_lossy();
74
75 let Some(repo_part) = dir_str.strip_prefix("models--") else {
77 continue;
78 };
79
80 let repo_id = match repo_part.find("--") {
82 Some(pos) => {
83 let (org, name_with_sep) = repo_part.split_at(pos);
84 let name = name_with_sep.get(2..).unwrap_or_default();
85 format!("{org}/{name}")
86 }
87 None => repo_part.to_string(),
88 };
89
90 let snapshots_dir = entry.path().join("snapshots");
92 if !snapshots_dir.exists() {
93 continue;
94 }
95
96 if let Some(model_type) = find_model_type_in_snapshots(&snapshots_dir) {
97 families.entry(model_type).or_default().push(repo_id);
98 }
99 }
100
101 for repos in families.values_mut() {
103 repos.sort();
104 }
105
106 Ok(families)
107}
108
109fn find_model_type_in_snapshots(snapshots_dir: &std::path::Path) -> Option<String> {
113 let snapshots = std::fs::read_dir(snapshots_dir).ok()?;
114
115 for snap_entry in snapshots {
116 let Ok(snap_entry) = snap_entry else { continue };
117 let config_path = snap_entry.path().join("config.json");
118
119 if !config_path.exists() {
120 continue;
121 }
122
123 if let Some(model_type) = extract_model_type(&config_path) {
124 return Some(model_type);
125 }
126 }
127
128 None
129}
130
131fn extract_model_type(config_path: &std::path::Path) -> Option<String> {
133 let contents = std::fs::read_to_string(config_path).ok()?;
134 let value: serde_json::Value = serde_json::from_str(contents.as_str()).ok()?;
135 value.get("model_type")?.as_str().map(String::from)
137}
138
/// Download state of a single remote file relative to the local cache.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum FileStatus {
    /// The file exists locally and is not smaller than the expected size
    /// (or the expected size is unknown).
    Complete {
        /// Size of the local file in bytes.
        local_size: u64,
    },
    /// Some local data exists, but fewer bytes than the remote reports.
    Partial {
        /// Bytes currently present locally.
        local_size: u64,
        /// Size reported by the remote listing.
        expected_size: u64,
    },
    /// No local data was found for this file.
    Missing {
        /// Size reported by the remote listing (0 when unknown).
        expected_size: u64,
    },
}
163
/// Per-file cache status for one repo, as produced by [`repo_status`].
#[derive(Debug, Clone)]
pub struct RepoStatus {
    /// Repo id in `org/name` form.
    pub repo_id: String,
    /// Commit hash the requested revision resolves to locally, when the
    /// corresponding ref file exists and is non-empty.
    pub commit_hash: Option<String>,
    /// The repo's `models--…` folder inside the cache directory.
    pub cache_path: PathBuf,
    /// `(filename, status)` pairs, sorted by filename.
    pub files: Vec<(String, FileStatus)>,
}
176
177impl RepoStatus {
178 #[must_use]
180 pub fn complete_count(&self) -> usize {
181 self.files
182 .iter()
183 .filter(|(_, s)| matches!(s, FileStatus::Complete { .. }))
184 .count()
185 }
186
187 #[must_use]
189 pub fn partial_count(&self) -> usize {
190 self.files
191 .iter()
192 .filter(|(_, s)| matches!(s, FileStatus::Partial { .. }))
193 .count()
194 }
195
196 #[must_use]
198 pub fn missing_count(&self) -> usize {
199 self.files
200 .iter()
201 .filter(|(_, s)| matches!(s, FileStatus::Missing { .. }))
202 .count()
203 }
204}
205
206pub async fn repo_status(
219 repo_id: &str,
220 token: Option<&str>,
221 revision: Option<&str>,
222) -> Result<RepoStatus, FetchError> {
223 let revision = revision.unwrap_or("main");
224 let cache_dir = hf_cache_dir()?;
225 let repo_folder = format!("models--{}", repo_id.replace('/', "--"));
226 let repo_dir = cache_dir.join(repo_folder.as_str());
228
229 let commit_hash = read_ref(&repo_dir, revision);
231
232 let remote_files =
234 crate::repo::list_repo_files_with_metadata(repo_id, token, Some(revision)).await?;
235
236 let snapshot_dir = commit_hash
239 .as_deref()
240 .map(|hash| repo_dir.join("snapshots").join(hash));
241
242 let blobs_dir = repo_dir.join("blobs");
244
245 let mut files: Vec<(String, FileStatus)> = Vec::with_capacity(remote_files.len());
247
248 for remote in &remote_files {
249 let expected_size = remote.size.unwrap_or(0);
250
251 let local_path = snapshot_dir
252 .as_ref()
253 .map(|dir| dir.join(remote.filename.as_str()));
255
256 let status = if let Some(ref path) = local_path {
257 if path.exists() {
258 let local_size = std::fs::metadata(path).map(|m| m.len()).unwrap_or(0);
259
260 if expected_size > 0 && local_size < expected_size {
261 FileStatus::Partial {
262 local_size,
263 expected_size,
264 }
265 } else {
266 FileStatus::Complete { local_size }
267 }
268 } else if has_partial_blob(&blobs_dir) {
269 let part_size = find_partial_blob_size(&blobs_dir);
271 FileStatus::Partial {
272 local_size: part_size,
273 expected_size,
274 }
275 } else {
276 FileStatus::Missing { expected_size }
277 }
278 } else {
279 FileStatus::Missing { expected_size }
280 };
281
282 files.push((remote.filename.clone(), status));
284 }
285
286 files.sort_by(|(a, _), (b, _)| a.cmp(b));
287
288 Ok(RepoStatus {
289 repo_id: repo_id.to_owned(),
290 commit_hash,
291 cache_path: repo_dir,
292 files,
293 })
294}
295
/// Aggregate statistics for one cached repo, as produced by
/// [`cache_summary`].
#[derive(Debug, Clone)]
pub struct CachedModelSummary {
    /// Repo id in `org/name` form.
    pub repo_id: String,
    /// Number of files across all snapshot directories.
    pub file_count: usize,
    /// Total bytes across those snapshot files.
    pub total_size: u64,
    /// Whether an in-flight partial download blob was found.
    pub has_partial: bool,
}
308
309pub fn cache_summary() -> Result<Vec<CachedModelSummary>, FetchError> {
318 let cache_dir = hf_cache_dir()?;
319
320 if !cache_dir.exists() {
321 return Ok(Vec::new());
322 }
323
324 let entries = std::fs::read_dir(&cache_dir).map_err(|e| FetchError::Io {
325 path: cache_dir.clone(),
326 source: e,
327 })?;
328
329 let mut summaries: Vec<CachedModelSummary> = Vec::new();
330
331 for entry in entries {
332 let Ok(entry) = entry else { continue };
333 let dir_name = entry.file_name();
334 let dir_str = dir_name.to_string_lossy();
336
337 let Some(repo_part) = dir_str.strip_prefix("models--") else {
338 continue;
339 };
340
341 let repo_id = match repo_part.find("--") {
343 Some(pos) => {
344 let (org, name_with_sep) = repo_part.split_at(pos);
345 let name = name_with_sep.get(2..).unwrap_or_default();
346 format!("{org}/{name}")
347 }
348 None => repo_part.to_string(),
349 };
350
351 let repo_dir = entry.path();
352
353 let (file_count, total_size) = count_snapshot_files(&repo_dir);
355
356 let has_partial = find_partial_blob_size(&repo_dir.join("blobs")) > 0;
358
359 summaries.push(CachedModelSummary {
360 repo_id,
361 file_count,
362 total_size,
363 has_partial,
364 });
365 }
366
367 summaries.sort_by(|a, b| a.repo_id.cmp(&b.repo_id));
368
369 Ok(summaries)
370}
371
/// Count regular files and sum their sizes across every snapshot
/// directory under `repo_dir/snapshots`.
///
/// Non-directory entries directly under `snapshots` are ignored, and an
/// unreadable `snapshots` directory yields `(0, 0)`.
fn count_snapshot_files(repo_dir: &Path) -> (usize, u64) {
    // Depth-first walk accumulating (file count, total bytes); unreadable
    // directories and entries are skipped silently.
    fn walk(dir: &Path, count: &mut usize, total: &mut u64) {
        let Ok(entries) = std::fs::read_dir(dir) else {
            return;
        };
        for entry in entries.flatten() {
            let path = entry.path();
            if path.is_dir() {
                walk(&path, count, total);
            } else {
                *count += 1;
                *total += entry.metadata().map(|m| m.len()).unwrap_or(0);
            }
        }
    }

    let mut file_count = 0usize;
    let mut total_size = 0u64;

    if let Ok(snapshots) = std::fs::read_dir(repo_dir.join("snapshots")) {
        for snap in snapshots.flatten() {
            let snap_path = snap.path();
            if snap_path.is_dir() {
                walk(&snap_path, &mut file_count, &mut total_size);
            }
        }
    }

    (file_count, total_size)
}
393
/// Recursively tally regular files under `dir`: increments `count` once
/// per file and adds each file's size to `total`.
///
/// Unreadable directories and entries contribute nothing.
fn count_files_recursive(dir: &Path, count: &mut usize, total: &mut u64) {
    let entries = match std::fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(_) => return,
    };

    for entry in entries.flatten() {
        let path = entry.path();
        if path.is_dir() {
            count_files_recursive(&path, count, total);
        } else {
            *count += 1;
            // A failed metadata read simply adds zero bytes.
            if let Ok(meta) = entry.metadata() {
                *total += meta.len();
            }
        }
    }
}
411
/// Look up the commit hash that `revision` resolves to via the
/// `refs/<revision>` file inside `repo_dir`.
///
/// Returns `None` when the ref file is absent, unreadable, or empty
/// after trimming whitespace.
pub(crate) fn read_ref(repo_dir: &Path, revision: &str) -> Option<String> {
    let contents = std::fs::read_to_string(repo_dir.join("refs").join(revision)).ok()?;
    let hash = contents.trim();
    if hash.is_empty() {
        None
    } else {
        Some(hash.to_owned())
    }
}
420
/// Whether `blobs_dir` contains an in-flight chunked download, marked by
/// the presence of any `*.chunked.part` file.
///
/// Presence alone counts: a zero-byte part file (a download that has
/// started but written no data yet) is still a partial download. The
/// previous size-based check (`find_partial_blob_size(..) > 0`)
/// misreported such files as "no partial data", which made callers treat
/// the file as fully missing.
fn has_partial_blob(blobs_dir: &Path) -> bool {
    // Unreadable/missing blobs dir means no partial downloads.
    let Ok(entries) = std::fs::read_dir(blobs_dir) else {
        return false;
    };
    entries
        .flatten()
        .any(|e| e.file_name().to_string_lossy().ends_with(".chunked.part"))
}
429
/// Total bytes across every `*.chunked.part` file in `blobs_dir`, or 0
/// when the directory is missing/unreadable or holds no part files.
///
/// All matching part files are summed: `std::fs::read_dir` yields entries
/// in an unspecified, platform-dependent order, so returning only the
/// "first" match (as this function previously did) was nondeterministic
/// when several chunked downloads left part files behind — and could
/// report 0 even while another part file held data.
fn find_partial_blob_size(blobs_dir: &Path) -> u64 {
    let Ok(entries) = std::fs::read_dir(blobs_dir) else {
        return 0;
    };

    entries
        .flatten()
        .filter(|e| e.file_name().to_string_lossy().ends_with(".chunked.part"))
        // A failed metadata read contributes zero bytes.
        .map(|e| e.metadata().map(|m| m.len()).unwrap_or(0))
        .sum()
}