1use std::{
2 env, fs,
3 io::Read,
4 path::{Path, PathBuf},
5};
6
7use anyhow::{anyhow, Result};
8use hf_hub::{
9 api::sync::{ApiError, ApiRepo},
10 Cache, Repo, RepoType,
11};
12use tracing::{trace, warn};
13
14use super::FileListCache;
15
/// Name of the environment variable that switches Hugging Face Hub access to offline mode.
///
/// When it holds a truthy value (see [`is_hf_hub_offline`]), the lookups in this module resolve
/// files only from the local cache instead of calling the Hub API.
pub const HF_HUB_OFFLINE_ENV: &str = "HF_HUB_OFFLINE";
19
20pub fn is_hf_hub_offline() -> bool {
24 matches!(
25 env::var(HF_HUB_OFFLINE_ENV)
26 .ok()
27 .map(|v| v.trim().to_ascii_lowercase()),
28 Some(ref v) if matches!(v.as_str(), "1" | "true" | "yes" | "on")
29 )
30}
31
32fn offline_repo(model_id: &Path, revision: &str) -> Repo {
33 Repo::with_revision(
34 model_id.display().to_string(),
35 RepoType::Model,
36 revision.to_string(),
37 )
38}
39
40pub(crate) fn offline_cache_repo(model_id: &Path, revision: &str) -> hf_hub::CacheRepo {
41 let cache = hf_hub_cache_dir().map(Cache::new).unwrap_or_default();
42 cache.repo(offline_repo(model_id, revision))
43}
44
45pub(crate) fn offline_missing_file_error(
46 model_id: &Path,
47 file: &str,
48 revision: &str,
49) -> anyhow::Error {
50 anyhow!(
51 "`{HF_HUB_OFFLINE_ENV}` is set but `{file}` for `{}` (revision `{revision}`) was not found in the local Hugging Face cache. \
52 Unset `{HF_HUB_OFFLINE_ENV}` or pre-download the file (e.g. via `huggingface-cli download`).",
53 model_id.display()
54 )
55}
56
57fn offline_snapshot_files(model_id: &Path, revision: &str) -> Vec<String> {
58 fn walk(root: &Path, dir: &Path, out: &mut Vec<String>) -> std::io::Result<()> {
59 for entry in fs::read_dir(dir)? {
60 let entry = entry?;
61 let path = entry.path();
62 if path.is_dir() {
63 walk(root, &path, out)?;
64 } else if let Ok(rel) = path.strip_prefix(root) {
65 out.push(rel.to_string_lossy().replace('\\', "/"));
66 }
67 }
68 Ok(())
69 }
70
71 let Some(cache_dir) = hf_hub_cache_dir() else {
72 return Vec::new();
73 };
74 let repo = offline_repo(model_id, revision);
75 let folder = repo.folder_name();
76 let ref_path = cache_dir.join(&folder).join("refs").join(revision);
77 let commit = fs::read_to_string(&ref_path)
79 .map(|s| s.trim().to_string())
80 .unwrap_or_else(|_| revision.to_string());
81 let snapshot_dir = cache_dir.join(&folder).join("snapshots").join(&commit);
82 if !snapshot_dir.is_dir() {
83 return Vec::new();
84 }
85 let mut files = Vec::new();
86 let _ = walk(&snapshot_dir, &snapshot_dir, &mut files);
87 files
88}
89
/// A failed attempt to reach a Hugging Face resource, as consumed by `hf_access_error`.
#[derive(Debug, Clone)]
pub(crate) struct RemoteAccessIssue {
    /// HTTP status code recovered from the underlying error, when one could be parsed.
    pub status_code: Option<u16>,
    /// Human-readable description of what failed.
    pub message: String,
}
95
96pub fn hf_home_dir() -> Option<PathBuf> {
102 let dir = env::var("HF_HOME")
103 .ok()
104 .map(PathBuf::from)
105 .or_else(|| dirs::home_dir().map(|home| home.join(".cache").join("huggingface")));
106
107 if let Some(ref dir) = dir {
108 if let Err(err) = fs::create_dir_all(dir) {
109 warn!(
110 "Could not create Hugging Face home directory `{}`: {err}",
111 dir.display()
112 );
113 }
114 }
115
116 dir
117}
118
119pub fn hf_hub_cache_dir() -> Option<PathBuf> {
126 let dir = env::var("HF_HUB_CACHE")
127 .ok()
128 .map(PathBuf::from)
129 .or_else(|| hf_home_dir().map(|home| home.join("hub")));
130
131 if let Some(ref dir) = dir {
132 if let Err(err) = fs::create_dir_all(dir) {
133 warn!(
134 "Could not create Hugging Face hub cache directory `{}`: {err}",
135 dir.display()
136 );
137 }
138 }
139
140 dir
141}
142
143pub fn hf_token_path() -> Option<PathBuf> {
145 hf_home_dir().map(|home| home.join("token"))
146}
147
148fn cache_dir() -> PathBuf {
149 hf_hub_cache_dir().unwrap_or_else(|| PathBuf::from("./"))
150}
151
152fn cache_file_for_model(model_id: &Path) -> PathBuf {
153 let sanitized_id = model_id.display().to_string().replace('/', "-");
154 cache_dir().join(format!("{sanitized_id}_repo_list.json"))
155}
156
157fn read_cached_repo_files(cache_file: &Path) -> Option<Vec<String>> {
158 if !cache_file.exists() {
159 return None;
160 }
161
162 let mut file = match fs::File::open(cache_file) {
163 Ok(file) => file,
164 Err(err) => {
165 warn!(
166 "Could not open Hugging Face repo cache file `{}`: {err}",
167 cache_file.display()
168 );
169 return None;
170 }
171 };
172
173 let mut contents = String::new();
174 if let Err(err) = file.read_to_string(&mut contents) {
175 warn!(
176 "Could not read Hugging Face repo cache file `{}`: {err}",
177 cache_file.display()
178 );
179 return None;
180 }
181
182 match serde_json::from_str::<FileListCache>(&contents) {
183 Ok(cache) => {
184 trace!("Read from cache file `{}`", cache_file.display());
185 Some(cache.files)
186 }
187 Err(err) => {
188 warn!(
189 "Could not parse Hugging Face repo cache file `{}`: {err}",
190 cache_file.display()
191 );
192 None
193 }
194 }
195}
196
197fn write_cached_repo_files(cache_file: &Path, files: &[String]) {
198 let cache = FileListCache {
199 files: files.to_vec(),
200 };
201 match serde_json::to_string_pretty(&cache) {
202 Ok(json) => {
203 if let Err(err) = fs::write(cache_file, json) {
204 warn!(
205 "Could not write Hugging Face repo cache file `{}`: {err}",
206 cache_file.display()
207 );
208 } else {
209 trace!("Write to cache file `{}`", cache_file.display());
210 }
211 }
212 Err(err) => warn!(
213 "Could not serialize Hugging Face repo cache for `{}`: {err}",
214 cache_file.display()
215 ),
216 }
217}
218
/// Extracts the HTTP status from an error message of the form `"... status code NNN ..."`.
///
/// Only the first `status code ` occurrence is considered. Returns `None` when the marker is
/// missing, is not immediately followed by ASCII digits, or the number does not fit in `u16`.
pub(crate) fn parse_status_code(message: &str) -> Option<u16> {
    const MARKER: &str = "status code ";

    let start = message.find(MARKER)? + MARKER.len();
    let rest = &message[start..];
    let end = rest
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(rest.len());
    rest[..end].parse::<u16>().ok()
}
228
229pub(crate) fn api_error_status_code(err: &ApiError) -> Option<u16> {
230 match err {
231 ApiError::TooManyRetries(inner) => api_error_status_code(inner),
232 _ => parse_status_code(&err.to_string()),
233 }
234}
235
236pub(crate) fn should_propagate_api_error(err: &ApiError) -> bool {
237 matches!(api_error_status_code(err), Some(401 | 403 | 404))
238}
239
240pub(crate) fn remote_issue_from_api_error(
241 model_id: &Path,
242 file: Option<&str>,
243 err: &ApiError,
244) -> RemoteAccessIssue {
245 let target = match file {
246 Some(file) => format!("`{file}` for `{}`", model_id.display()),
247 None => format!("`{}`", model_id.display()),
248 };
249 RemoteAccessIssue {
250 status_code: api_error_status_code(err),
251 message: format!("Failed to access {target}: {err}"),
252 }
253}
254
255pub(crate) fn hf_access_error(model_id: &Path, issue: &RemoteAccessIssue) -> anyhow::Error {
256 match issue.status_code {
257 Some(code @ (401 | 403)) => anyhow!(
258 "Could not access `{}` on Hugging Face (HTTP {code}). You may need to run `hanzo login` or set HF_TOKEN.",
259 model_id.display()
260 ),
261 Some(404) => anyhow!(
262 "Model `{}` was not found or is not accessible on Hugging Face (HTTP 404). Check the model ID and your access token.",
263 model_id.display()
264 ),
265 Some(code) => anyhow!(
266 "Failed to access `{}` on Hugging Face (HTTP {code}): {}",
267 model_id.display(),
268 issue.message
269 ),
270 None => anyhow!(
271 "Failed to access `{}` on Hugging Face: {}",
272 model_id.display(),
273 issue.message
274 ),
275 }
276}
277
278pub(crate) fn hf_api_error(model_id: &Path, file: Option<&str>, err: &ApiError) -> anyhow::Error {
279 let status_code = api_error_status_code(err);
280 let file_context = file
281 .map(|f| format!(" while fetching `{f}`"))
282 .unwrap_or_default();
283 match status_code {
284 Some(code @ (401 | 403)) => anyhow!(
285 "Could not access `{}` on Hugging Face (HTTP {code}){file_context}. You may need to run `hanzo login` or set HF_TOKEN.",
286 model_id.display()
287 ),
288 Some(404) => anyhow!(
289 "Model `{}` was not found or is not accessible on Hugging Face (HTTP 404){file_context}. Check the model ID and your access token.",
290 model_id.display()
291 ),
292 Some(code) => anyhow!(
293 "Failed to access `{}` on Hugging Face (HTTP {code}){file_context}: {err}",
294 model_id.display()
295 ),
296 None => anyhow!(
297 "Failed to access `{}` on Hugging Face{file_context}: {err}",
298 model_id.display()
299 ),
300 }
301}
302
303pub(crate) fn local_file_missing_error(model_id: &Path, file: &str) -> anyhow::Error {
304 anyhow!(
305 "File `{file}` was not found at local model path `{}`.",
306 model_id.display()
307 )
308}
309
310pub(crate) fn list_repo_files(
311 api: &ApiRepo,
312 model_id: &Path,
313 should_error: bool,
314 revision: &str,
315) -> Result<Vec<String>> {
316 if model_id.exists() {
317 let listing = fs::read_dir(model_id).map_err(|err| {
318 anyhow!(
319 "Cannot list local model directory `{}`: {err}",
320 model_id.display()
321 )
322 })?;
323 let files = listing
324 .filter_map(|entry| entry.ok())
325 .filter_map(|entry| {
326 entry
327 .path()
328 .file_name()
329 .and_then(|name| name.to_str())
330 .map(std::string::ToString::to_string)
331 })
332 .collect::<Vec<_>>();
333 return Ok(files);
334 }
335
336 let cache_file = cache_file_for_model(model_id);
337 if let Some(files) = read_cached_repo_files(&cache_file) {
338 return Ok(files);
339 }
340
341 if is_hf_hub_offline() {
342 let files = offline_snapshot_files(model_id, revision);
343 if !files.is_empty() {
344 write_cached_repo_files(&cache_file, &files);
345 return Ok(files);
346 }
347 if should_error {
348 return Err(anyhow!(
349 "`{HF_HUB_OFFLINE_ENV}` is set but no cached file list or snapshot was found for `{}` (revision `{revision}`).",
350 model_id.display()
351 ));
352 }
353 warn!(
354 "`{HF_HUB_OFFLINE_ENV}` is set and no local Hugging Face cache was found for `{}` (revision `{revision}`)",
355 model_id.display()
356 );
357 return Ok(Vec::new());
358 }
359
360 match api.info() {
361 Ok(repo) => {
362 let files = repo
363 .siblings
364 .iter()
365 .map(|x| x.rfilename.clone())
366 .collect::<Vec<_>>();
367 write_cached_repo_files(&cache_file, &files);
368 Ok(files)
369 }
370 Err(err) => {
371 if should_error || should_propagate_api_error(&err) {
372 Err(hf_api_error(model_id, None, &err))
373 } else {
374 warn!(
375 "Could not get directory listing from Hugging Face for `{}`: {err}",
376 model_id.display()
377 );
378 Ok(Vec::new())
379 }
380 }
381 }
382}
383
384pub(crate) fn get_file(
385 api: &ApiRepo,
386 model_id: &Path,
387 file: &str,
388 revision: &str,
389) -> Result<PathBuf> {
390 if model_id.exists() {
391 let path = model_id.join(file);
392 if !path.exists() {
393 return Err(local_file_missing_error(model_id, file));
394 }
395 trace!("Loading `{file}` locally at `{}`", path.display());
396 return Ok(path);
397 }
398
399 if is_hf_hub_offline() {
400 if let Some(path) = offline_cache_repo(model_id, revision).get(file) {
401 trace!(
402 "Loading `{file}` from local Hugging Face cache at `{}` (offline mode)",
403 path.display()
404 );
405 return Ok(path);
406 }
407 return Err(offline_missing_file_error(model_id, file, revision));
408 }
409
410 api.get(file)
411 .map_err(|err| hf_api_error(model_id, Some(file), &err))
412}
413
414pub(crate) fn try_get_file(
416 api: &ApiRepo,
417 model_id: &Path,
418 file: &str,
419 revision: &str,
420) -> std::result::Result<Option<PathBuf>, ApiError> {
421 if model_id.exists() {
422 let path = model_id.join(file);
423 if path.exists() {
424 trace!("Loading `{file}` locally at `{}`", path.display());
425 return Ok(Some(path));
426 }
427 return Ok(None);
428 }
429
430 if is_hf_hub_offline() {
431 return Ok(offline_cache_repo(model_id, revision).get(file));
432 }
433
434 match api.get(file) {
435 Ok(p) => Ok(Some(p)),
436 Err(err) => match api_error_status_code(&err) {
437 Some(404) => Ok(None),
438 _ => Err(err),
439 },
440 }
441}
442
443pub fn probe_hf_repo_files(
446 model_id: &str,
447 revision: &str,
448 token_source: &crate::pipeline::TokenSource,
449) -> Option<Vec<String>> {
450 use hf_hub::api::sync::ApiBuilder;
451
452 if is_hf_hub_offline() {
453 let files = offline_snapshot_files(Path::new(model_id), revision);
454 return (!files.is_empty()).then_some(files);
455 }
456
457 let token = crate::utils::tokens::get_token(token_source).ok().flatten();
458 let cache = hf_hub_cache_dir()
459 .map(Cache::new)
460 .unwrap_or_else(Cache::from_env);
461 let mut api = ApiBuilder::from_cache(cache)
462 .with_progress(false)
463 .with_token(token);
464 if let Some(cache_dir) = hf_hub_cache_dir() {
465 api = api.with_cache_dir(cache_dir);
466 }
467 let repo = api.build().ok()?.repo(Repo::with_revision(
468 model_id.to_string(),
469 RepoType::Model,
470 revision.to_string(),
471 ));
472 repo.info()
473 .ok()
474 .map(|info| info.siblings.into_iter().map(|s| s.rfilename).collect())
475}