Skip to main content

hanzo_engine/pipeline/
hf.rs

1use std::{
2    env, fs,
3    io::Read,
4    path::{Path, PathBuf},
5};
6
7use anyhow::{anyhow, Result};
8use hf_hub::{
9    api::sync::{ApiError, ApiRepo},
10    Cache, Repo, RepoType,
11};
12use tracing::{trace, warn};
13
14use super::FileListCache;
15
/// Name of the environment variable that switches the Hugging Face Hub
/// integration into offline mode: with a truthy value no network calls are
/// made and only cached files are used.
pub const HF_HUB_OFFLINE_ENV: &str = "HF_HUB_OFFLINE";

/// Reports whether fully-offline operation was requested via `HF_HUB_OFFLINE`.
///
/// The value is trimmed and compared case-insensitively against `1`, `true`,
/// `yes` and `on`. An unset variable, a value that cannot be read as Unicode,
/// or any other value means "online".
pub fn is_hf_hub_offline() -> bool {
    let Ok(raw) = env::var(HF_HUB_OFFLINE_ENV) else {
        return false;
    };
    let normalized = raw.trim().to_ascii_lowercase();
    ["1", "true", "yes", "on"].contains(&normalized.as_str())
}
31
32fn offline_repo(model_id: &Path, revision: &str) -> Repo {
33    Repo::with_revision(
34        model_id.display().to_string(),
35        RepoType::Model,
36        revision.to_string(),
37    )
38}
39
40pub(crate) fn offline_cache_repo(model_id: &Path, revision: &str) -> hf_hub::CacheRepo {
41    let cache = hf_hub_cache_dir().map(Cache::new).unwrap_or_default();
42    cache.repo(offline_repo(model_id, revision))
43}
44
45pub(crate) fn offline_missing_file_error(
46    model_id: &Path,
47    file: &str,
48    revision: &str,
49) -> anyhow::Error {
50    anyhow!(
51        "`{HF_HUB_OFFLINE_ENV}` is set but `{file}` for `{}` (revision `{revision}`) was not found in the local Hugging Face cache. \
52         Unset `{HF_HUB_OFFLINE_ENV}` or pre-download the file (e.g. via `huggingface-cli download`).",
53        model_id.display()
54    )
55}
56
57fn offline_snapshot_files(model_id: &Path, revision: &str) -> Vec<String> {
58    fn walk(root: &Path, dir: &Path, out: &mut Vec<String>) -> std::io::Result<()> {
59        for entry in fs::read_dir(dir)? {
60            let entry = entry?;
61            let path = entry.path();
62            if path.is_dir() {
63                walk(root, &path, out)?;
64            } else if let Ok(rel) = path.strip_prefix(root) {
65                out.push(rel.to_string_lossy().replace('\\', "/"));
66            }
67        }
68        Ok(())
69    }
70
71    let Some(cache_dir) = hf_hub_cache_dir() else {
72        return Vec::new();
73    };
74    let repo = offline_repo(model_id, revision);
75    let folder = repo.folder_name();
76    let ref_path = cache_dir.join(&folder).join("refs").join(revision);
77    // Refs file is typically a branch/tag name; fall back to treating revision as a literal commit.
78    let commit = fs::read_to_string(&ref_path)
79        .map(|s| s.trim().to_string())
80        .unwrap_or_else(|_| revision.to_string());
81    let snapshot_dir = cache_dir.join(&folder).join("snapshots").join(&commit);
82    if !snapshot_dir.is_dir() {
83        return Vec::new();
84    }
85    let mut files = Vec::new();
86    let _ = walk(&snapshot_dir, &snapshot_dir, &mut files);
87    files
88}
89
/// A recorded failure to reach a model (or one of its files) remotely.
#[derive(Debug, Clone)]
pub(crate) struct RemoteAccessIssue {
    /// HTTP status code associated with the failure, when one is known.
    pub status_code: Option<u16>,
    /// Human-readable description of what failed.
    pub message: String,
}
95
96/// Resolve the Hugging Face home directory.
97///
98/// Precedence:
99/// 1. HF_HOME
100/// 2. ~/.cache/huggingface
101pub fn hf_home_dir() -> Option<PathBuf> {
102    let dir = env::var("HF_HOME")
103        .ok()
104        .map(PathBuf::from)
105        .or_else(|| dirs::home_dir().map(|home| home.join(".cache").join("huggingface")));
106
107    if let Some(ref dir) = dir {
108        if let Err(err) = fs::create_dir_all(dir) {
109            warn!(
110                "Could not create Hugging Face home directory `{}`: {err}",
111                dir.display()
112            );
113        }
114    }
115
116    dir
117}
118
119/// Resolve the Hugging Face Hub cache directory.
120///
121/// Precedence:
122/// 1. HF_HUB_CACHE
123/// 2. HF_HOME/hub
124/// 3. ~/.cache/huggingface/hub
125pub fn hf_hub_cache_dir() -> Option<PathBuf> {
126    let dir = env::var("HF_HUB_CACHE")
127        .ok()
128        .map(PathBuf::from)
129        .or_else(|| hf_home_dir().map(|home| home.join("hub")));
130
131    if let Some(ref dir) = dir {
132        if let Err(err) = fs::create_dir_all(dir) {
133            warn!(
134                "Could not create Hugging Face hub cache directory `{}`: {err}",
135                dir.display()
136            );
137        }
138    }
139
140    dir
141}
142
143/// Resolve the Hugging Face token file path.
144pub fn hf_token_path() -> Option<PathBuf> {
145    hf_home_dir().map(|home| home.join("token"))
146}
147
148fn cache_dir() -> PathBuf {
149    hf_hub_cache_dir().unwrap_or_else(|| PathBuf::from("./"))
150}
151
152fn cache_file_for_model(model_id: &Path) -> PathBuf {
153    let sanitized_id = model_id.display().to_string().replace('/', "-");
154    cache_dir().join(format!("{sanitized_id}_repo_list.json"))
155}
156
157fn read_cached_repo_files(cache_file: &Path) -> Option<Vec<String>> {
158    if !cache_file.exists() {
159        return None;
160    }
161
162    let mut file = match fs::File::open(cache_file) {
163        Ok(file) => file,
164        Err(err) => {
165            warn!(
166                "Could not open Hugging Face repo cache file `{}`: {err}",
167                cache_file.display()
168            );
169            return None;
170        }
171    };
172
173    let mut contents = String::new();
174    if let Err(err) = file.read_to_string(&mut contents) {
175        warn!(
176            "Could not read Hugging Face repo cache file `{}`: {err}",
177            cache_file.display()
178        );
179        return None;
180    }
181
182    match serde_json::from_str::<FileListCache>(&contents) {
183        Ok(cache) => {
184            trace!("Read from cache file `{}`", cache_file.display());
185            Some(cache.files)
186        }
187        Err(err) => {
188            warn!(
189                "Could not parse Hugging Face repo cache file `{}`: {err}",
190                cache_file.display()
191            );
192            None
193        }
194    }
195}
196
197fn write_cached_repo_files(cache_file: &Path, files: &[String]) {
198    let cache = FileListCache {
199        files: files.to_vec(),
200    };
201    match serde_json::to_string_pretty(&cache) {
202        Ok(json) => {
203            if let Err(err) = fs::write(cache_file, json) {
204                warn!(
205                    "Could not write Hugging Face repo cache file `{}`: {err}",
206                    cache_file.display()
207                );
208            } else {
209                trace!("Write to cache file `{}`", cache_file.display());
210            }
211        }
212        Err(err) => warn!(
213            "Could not serialize Hugging Face repo cache for `{}`: {err}",
214            cache_file.display()
215        ),
216    }
217}
218
/// Extracts an HTTP status code from an error message.
///
/// Looks for the first occurrence of `status code ` and parses the run of
/// ASCII digits that immediately follows it. Returns `None` when the marker
/// is absent, no digit follows it, or the digits do not fit in a `u16`.
pub(crate) fn parse_status_code(message: &str) -> Option<u16> {
    const MARKER: &str = "status code ";
    let start = message.find(MARKER)? + MARKER.len();
    let tail = &message[start..];
    // End of the leading digit run (or the whole tail if it is all digits).
    let end = tail
        .find(|c: char| !c.is_ascii_digit())
        .unwrap_or(tail.len());
    tail[..end].parse().ok()
}
228
229pub(crate) fn api_error_status_code(err: &ApiError) -> Option<u16> {
230    match err {
231        ApiError::TooManyRetries(inner) => api_error_status_code(inner),
232        _ => parse_status_code(&err.to_string()),
233    }
234}
235
236pub(crate) fn should_propagate_api_error(err: &ApiError) -> bool {
237    matches!(api_error_status_code(err), Some(401 | 403 | 404))
238}
239
240pub(crate) fn remote_issue_from_api_error(
241    model_id: &Path,
242    file: Option<&str>,
243    err: &ApiError,
244) -> RemoteAccessIssue {
245    let target = match file {
246        Some(file) => format!("`{file}` for `{}`", model_id.display()),
247        None => format!("`{}`", model_id.display()),
248    };
249    RemoteAccessIssue {
250        status_code: api_error_status_code(err),
251        message: format!("Failed to access {target}: {err}"),
252    }
253}
254
255pub(crate) fn hf_access_error(model_id: &Path, issue: &RemoteAccessIssue) -> anyhow::Error {
256    match issue.status_code {
257        Some(code @ (401 | 403)) => anyhow!(
258            "Could not access `{}` on Hugging Face (HTTP {code}). You may need to run `hanzo login` or set HF_TOKEN.",
259            model_id.display()
260        ),
261        Some(404) => anyhow!(
262            "Model `{}` was not found or is not accessible on Hugging Face (HTTP 404). Check the model ID and your access token.",
263            model_id.display()
264        ),
265        Some(code) => anyhow!(
266            "Failed to access `{}` on Hugging Face (HTTP {code}): {}",
267            model_id.display(),
268            issue.message
269        ),
270        None => anyhow!(
271            "Failed to access `{}` on Hugging Face: {}",
272            model_id.display(),
273            issue.message
274        ),
275    }
276}
277
278pub(crate) fn hf_api_error(model_id: &Path, file: Option<&str>, err: &ApiError) -> anyhow::Error {
279    let status_code = api_error_status_code(err);
280    let file_context = file
281        .map(|f| format!(" while fetching `{f}`"))
282        .unwrap_or_default();
283    match status_code {
284        Some(code @ (401 | 403)) => anyhow!(
285            "Could not access `{}` on Hugging Face (HTTP {code}){file_context}. You may need to run `hanzo login` or set HF_TOKEN.",
286            model_id.display()
287        ),
288        Some(404) => anyhow!(
289            "Model `{}` was not found or is not accessible on Hugging Face (HTTP 404){file_context}. Check the model ID and your access token.",
290            model_id.display()
291        ),
292        Some(code) => anyhow!(
293            "Failed to access `{}` on Hugging Face (HTTP {code}){file_context}: {err}",
294            model_id.display()
295        ),
296        None => anyhow!(
297            "Failed to access `{}` on Hugging Face{file_context}: {err}",
298            model_id.display()
299        ),
300    }
301}
302
303pub(crate) fn local_file_missing_error(model_id: &Path, file: &str) -> anyhow::Error {
304    anyhow!(
305        "File `{file}` was not found at local model path `{}`.",
306        model_id.display()
307    )
308}
309
310pub(crate) fn list_repo_files(
311    api: &ApiRepo,
312    model_id: &Path,
313    should_error: bool,
314    revision: &str,
315) -> Result<Vec<String>> {
316    if model_id.exists() {
317        let listing = fs::read_dir(model_id).map_err(|err| {
318            anyhow!(
319                "Cannot list local model directory `{}`: {err}",
320                model_id.display()
321            )
322        })?;
323        let files = listing
324            .filter_map(|entry| entry.ok())
325            .filter_map(|entry| {
326                entry
327                    .path()
328                    .file_name()
329                    .and_then(|name| name.to_str())
330                    .map(std::string::ToString::to_string)
331            })
332            .collect::<Vec<_>>();
333        return Ok(files);
334    }
335
336    let cache_file = cache_file_for_model(model_id);
337    if let Some(files) = read_cached_repo_files(&cache_file) {
338        return Ok(files);
339    }
340
341    if is_hf_hub_offline() {
342        let files = offline_snapshot_files(model_id, revision);
343        if !files.is_empty() {
344            write_cached_repo_files(&cache_file, &files);
345            return Ok(files);
346        }
347        if should_error {
348            return Err(anyhow!(
349                "`{HF_HUB_OFFLINE_ENV}` is set but no cached file list or snapshot was found for `{}` (revision `{revision}`).",
350                model_id.display()
351            ));
352        }
353        warn!(
354            "`{HF_HUB_OFFLINE_ENV}` is set and no local Hugging Face cache was found for `{}` (revision `{revision}`)",
355            model_id.display()
356        );
357        return Ok(Vec::new());
358    }
359
360    match api.info() {
361        Ok(repo) => {
362            let files = repo
363                .siblings
364                .iter()
365                .map(|x| x.rfilename.clone())
366                .collect::<Vec<_>>();
367            write_cached_repo_files(&cache_file, &files);
368            Ok(files)
369        }
370        Err(err) => {
371            if should_error || should_propagate_api_error(&err) {
372                Err(hf_api_error(model_id, None, &err))
373            } else {
374                warn!(
375                    "Could not get directory listing from Hugging Face for `{}`: {err}",
376                    model_id.display()
377                );
378                Ok(Vec::new())
379            }
380        }
381    }
382}
383
384pub(crate) fn get_file(
385    api: &ApiRepo,
386    model_id: &Path,
387    file: &str,
388    revision: &str,
389) -> Result<PathBuf> {
390    if model_id.exists() {
391        let path = model_id.join(file);
392        if !path.exists() {
393            return Err(local_file_missing_error(model_id, file));
394        }
395        trace!("Loading `{file}` locally at `{}`", path.display());
396        return Ok(path);
397    }
398
399    if is_hf_hub_offline() {
400        if let Some(path) = offline_cache_repo(model_id, revision).get(file) {
401            trace!(
402                "Loading `{file}` from local Hugging Face cache at `{}` (offline mode)",
403                path.display()
404            );
405            return Ok(path);
406        }
407        return Err(offline_missing_file_error(model_id, file, revision));
408    }
409
410    api.get(file)
411        .map_err(|err| hf_api_error(model_id, Some(file), &err))
412}
413
414/// Like [`get_file`] but returns `Ok(None)` (instead of an error) when the file is genuinely missing, and used with `HF_HUB_OFFLINE`.
415pub(crate) fn try_get_file(
416    api: &ApiRepo,
417    model_id: &Path,
418    file: &str,
419    revision: &str,
420) -> std::result::Result<Option<PathBuf>, ApiError> {
421    if model_id.exists() {
422        let path = model_id.join(file);
423        if path.exists() {
424            trace!("Loading `{file}` locally at `{}`", path.display());
425            return Ok(Some(path));
426        }
427        return Ok(None);
428    }
429
430    if is_hf_hub_offline() {
431        return Ok(offline_cache_repo(model_id, revision).get(file));
432    }
433
434    match api.get(file) {
435        Ok(p) => Ok(Some(p)),
436        Err(err) => match api_error_status_code(&err) {
437            Some(404) => Ok(None),
438            _ => Err(err),
439        },
440    }
441}
442
443/// Best-effort file listing for a HF repo. Returns `None` on 404, API failure,
444/// or offline-without-cache. Quiet by design: callers choose what to log.
445pub fn probe_hf_repo_files(
446    model_id: &str,
447    revision: &str,
448    token_source: &crate::pipeline::TokenSource,
449) -> Option<Vec<String>> {
450    use hf_hub::api::sync::ApiBuilder;
451
452    if is_hf_hub_offline() {
453        let files = offline_snapshot_files(Path::new(model_id), revision);
454        return (!files.is_empty()).then_some(files);
455    }
456
457    let token = crate::utils::tokens::get_token(token_source).ok().flatten();
458    let cache = hf_hub_cache_dir()
459        .map(Cache::new)
460        .unwrap_or_else(Cache::from_env);
461    let mut api = ApiBuilder::from_cache(cache)
462        .with_progress(false)
463        .with_token(token);
464    if let Some(cache_dir) = hf_hub_cache_dir() {
465        api = api.with_cache_dir(cache_dir);
466    }
467    let repo = api.build().ok()?.repo(Repo::with_revision(
468        model_id.to_string(),
469        RepoType::Model,
470        revision.to_string(),
471    ));
472    repo.info()
473        .ok()
474        .map(|info| info.siblings.into_iter().map(|s| s.rfilename).collect())
475}