Skip to main content

hf_fetch_model/
lib.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let outcome = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model at: {}", outcome.inner().display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//!     .filter("*.safetensors")
28//!     .filter("*.json")
29//!     .on_progress(|e| {
30//!         println!("{}: {:.1}%", e.filename, e.percent);
31//!     })
32//!     .build()?;
33//!
34//! let outcome = hf_fetch_model::download_with_config(
35//!     "google/gemma-2-2b".to_owned(),
36//!     &config,
37//! ).await?;
38//! // outcome.is_cached() tells you if it came from local cache
39//! let path = outcome.into_inner();
40//! # Ok(())
41//! # }
42//! ```
43//!
44//! ## Inspect Before Downloading
45//!
46//! Read tensor metadata from `.safetensors` headers via HTTP Range requests —
47//! no weight data downloaded. Sharded repos (those with
48//! `model.safetensors.index.json`) work transparently —
49//! [`inspect::inspect_repo_safetensors`] reads every shard's header in parallel
50//! and returns a flat per-file result list. See
51//! [`examples/candle_inspect.rs`](https://github.com/PCfVW/hf-fetch-model/blob/main/examples/candle_inspect.rs)
52//! for a runnable example, or the
53//! [Inspect tutorial](https://github.com/PCfVW/hf-fetch-model/blob/main/docs/tutorials/inspect-before-downloading.md)
54//! for a narrative walkthrough.
55//!
56//! ```rust,no_run
57//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
58//! let results = hf_fetch_model::inspect::inspect_repo_safetensors(
59//!     "EleutherAI/pythia-1.4b", None, None,
60//! ).await?;
61//!
62//! for (filename, header, _source) in &results {
63//!     println!("{filename}: {} tensors", header.tensors.len());
64//! }
65//! # Ok(())
66//! # }
67//! ```
68//!
69//! The CLI also exposes `hf-fm inspect <repo> [FILE] --check-gpu [N]` (v0.10.1)
70//! to print a one-line GPU-fit verdict against device `N` (default 0) using
71//! the `hypomnesis` crate (NVML on Linux/Windows, DXGI on Windows). The
72//! verdict is a binary-only feature today; no library equivalent is exposed
73//! — depend on `hypomnesis` directly if you need the device-info numbers
74//! from library code.
75//!
76//! ## `HuggingFace` Cache
77//!
78//! Downloaded files are stored in the standard `HuggingFace` cache directory
79//! (`~/.cache/huggingface/hub/`), ensuring compatibility with Python tooling.
80//!
81//! ## Cache Management
82//!
83//! v0.10.0 adds library APIs for inspecting, verifying, and pruning the local
84//! cache. [`cache::cache_summary`] enumerates every cached repo with size and
85//! file counts; [`cache::repo_status`] gives a per-file `Complete` / `Partial` /
86//! `Missing` breakdown for one repo; [`cache::verify_cache`] re-checks `SHA256`
87//! digests of cached files against `HuggingFace` LFS metadata; and
88//! [`cache::find_partial_files`] locates `.chunked.part` orphans from
89//! interrupted downloads.
90//!
91//! For long verifications (multi-GiB safetensors files), drive
92//! [`cache::verify_cache_with_progress`] with an [`Fn`] callback that receives
93//! [`cache::VerifyEvent`]s so a CLI or GUI can render a spinner or progress
94//! bar without polling.
95//!
96//! ```rust,no_run
97//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
98//! use hf_fetch_model::cache::{self, VerifyStatus};
99//!
100//! let results = cache::verify_cache("google/gemma-2-2b-it", None, None).await?;
101//! let ok = results
102//!     .iter()
103//!     .filter(|r| matches!(r.status, VerifyStatus::Ok))
104//!     .count();
105//! let mismatch = results
106//!     .iter()
107//!     .filter(|r| matches!(r.status, VerifyStatus::Mismatch { .. }))
108//!     .count();
109//! println!("{}/{} files verified, {} mismatches", ok, results.len(), mismatch);
110//! # Ok(())
111//! # }
112//! ```
113//!
114//! ## Download Durability
115//!
116//! Multi-connection downloads survive interruption. When a download is
117//! aborted by [`FetchConfigBuilder::timeout_per_file`] (default 300 s),
118//! Ctrl-C, panic, or a transient chunk error, the partial `.chunked.part`
119//! file plus a small per-chunk progress sidecar are kept on disk. The next
120//! call to [`download_with_config`] for the same file picks up where it
121//! stopped — each parallel chunk sends a fresh `Range` request that skips
122//! the bytes it already has — provided the upstream etag still matches.
123//! On etag change, schema-version mismatch, or a different
124//! [`FetchConfigBuilder::connections_per_file`] count, the partial is
125//! discarded and a fresh download starts.
126//!
127//! For slow connections on multi-GiB files, raise the per-file budget to
128//! match real throughput:
129//!
130//! ```rust,no_run
131//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
132//! use std::time::Duration;
133//! use hf_fetch_model::FetchConfig;
134//!
135//! let config = FetchConfig::builder()
136//!     .timeout_per_file(Duration::from_secs(1800))
137//!     .build()?;
138//! # let _ = hf_fetch_model::download_with_config(
139//! #     "google/gemma-4-E2B-it".to_owned(),
140//! #     &config,
141//! # ).await?;
142//! # Ok(())
143//! # }
144//! ```
145//!
146//! ## Authentication
147//!
148//! Set the `HF_TOKEN` environment variable to access private or gated models,
149//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
150
151pub mod cache;
152pub mod cache_layout;
153pub mod checksum;
154mod chunked;
155mod chunked_state;
156pub mod config;
157pub mod discover;
158pub mod download;
159pub mod error;
160pub mod inspect;
161pub mod plan;
162pub mod progress;
163pub mod repo;
164mod retry;
165
166pub use chunked::build_client;
167pub use config::{
168    compile_glob_patterns, file_matches, has_glob_chars, FetchConfig, FetchConfigBuilder, Filter,
169};
170pub use discover::{DiscoveredFamily, GateStatus, ModelCardMetadata, SearchResult};
171pub use download::DownloadOutcome;
172pub use error::{FetchError, FileFailure};
173pub use inspect::AdapterConfig;
174pub use plan::{download_plan, DownloadPlan, FilePlan};
175pub use progress::{ProgressEvent, ProgressReceiver};
176
177use std::collections::HashMap;
178use std::path::PathBuf;
179
180use hf_hub::{Repo, RepoType};
181
182/// Pre-flight check for gated model access.
183///
184/// Two cases:
185/// - **No token**: checks the model metadata (unauthenticated) for gating
186///   status and rejects with a clear message if gated.
187/// - **Token present**: if the model is gated, makes one authenticated
188///   metadata request to verify the token actually grants access. Catches
189///   invalid tokens and unaccepted licenses before the download starts.
190///
191/// If the metadata request itself fails (network error, private repo),
192/// the check is silently skipped so that normal download error handling
193/// can take over.
194async fn preflight_gated_check(repo_id: &str, config: &FetchConfig) -> Result<(), FetchError> {
195    // Best-effort: if the metadata call fails, let the download proceed.
196    let Ok(metadata) = discover::fetch_model_card(repo_id).await else {
197        return Ok(());
198    };
199
200    if !metadata.gated.is_gated() {
201        return Ok(());
202    }
203
204    // Model is gated — check auth.
205    if config.token.is_none() {
206        return Err(FetchError::Auth {
207            reason: format!(
208                "{repo_id} is a gated model — accept the license at \
209                 https://huggingface.co/{repo_id} and set HF_TOKEN or pass --token"
210            ),
211        });
212    }
213
214    // Token is present — verify it grants access with a lightweight probe.
215    let probe_client = chunked::build_client(config.token.as_deref())?;
216    let probe = repo::list_repo_files_with_metadata(
217        repo_id,
218        config.token.as_deref(),
219        config.revision.as_deref(),
220        &probe_client,
221    )
222    .await;
223
224    if let Err(ref e) = probe {
225        // BORROW: explicit .to_string() for error Display formatting
226        let msg = e.to_string();
227        if msg.contains("401") || msg.contains("403") {
228            return Err(FetchError::Auth {
229                reason: format!(
230                    "{repo_id} is a gated model and your token was rejected — \
231                     accept the license at https://huggingface.co/{repo_id} \
232                     and check that your token is valid"
233                ),
234            });
235        }
236    }
237
238    Ok(())
239}
240
241/// Downloads all files from a `HuggingFace` model repository.
242///
243/// Uses high-throughput mode for maximum download speed, including
244/// auto-tuned concurrency, chunked multi-connection downloads for large
245/// files, and plan-optimized settings based on file size distribution.
246/// Files are stored in the standard `HuggingFace` cache layout
247/// (`~/.cache/huggingface/hub/`).
248///
249/// Authentication is handled via the `HF_TOKEN` environment variable when set.
250///
251/// For filtering, progress, and other options, use [`download_with_config()`].
252///
253/// # Arguments
254///
255/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
256///
257/// # Returns
258///
259/// The path to the snapshot directory containing all downloaded files.
260///
261/// # Errors
262///
263/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
264/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
265/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
266/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
267pub async fn download(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
268    let config = FetchConfig::builder().build()?;
269    download_with_config(repo_id, &config).await
270}
271
272/// Downloads files from a `HuggingFace` model repository using the given configuration.
273///
274/// Supports filtering, progress reporting, custom revision, authentication,
275/// and concurrency settings via [`FetchConfig`].
276///
277/// # Arguments
278///
279/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
280/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
281///
282/// # Returns
283///
284/// The path to the snapshot directory containing all downloaded files.
285///
286/// # Errors
287///
288/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
289/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
290/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
291pub async fn download_with_config(
292    repo_id: String,
293    config: &FetchConfig,
294) -> Result<DownloadOutcome<PathBuf>, FetchError> {
295    // BORROW: explicit .as_str() instead of Deref coercion
296    preflight_gated_check(repo_id.as_str(), config).await?;
297
298    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
299
300    if let Some(ref token) = config.token {
301        // BORROW: explicit .clone() to pass owned String
302        builder = builder.with_token(Some(token.clone()));
303    }
304
305    if let Some(ref dir) = config.output_dir {
306        // BORROW: explicit .clone() for owned PathBuf
307        builder = builder.with_cache_dir(dir.clone());
308    }
309
310    let api = builder.build().map_err(FetchError::Api)?;
311
312    let hf_repo = match config.revision {
313        Some(ref rev) => {
314            // BORROW: explicit .clone() for owned String arguments
315            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
316        }
317        None => Repo::new(repo_id.clone(), RepoType::Model),
318    };
319
320    let repo = api.repo(hf_repo);
321    download::download_all_files(repo, repo_id, Some(config)).await
322}
323
324/// Blocking version of [`download()`] for non-async callers.
325///
326/// Creates a Tokio runtime internally. Do not call from within
327/// an existing async context (use [`download()`] instead).
328///
329/// # Errors
330///
331/// Same as [`download()`].
332pub fn download_blocking(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
333    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
334        path: PathBuf::from("<runtime>"),
335        source: e,
336    })?;
337    rt.block_on(download(repo_id))
338}
339
340/// Blocking version of [`download_with_config()`] for non-async callers.
341///
342/// Creates a Tokio runtime internally. Do not call from within
343/// an existing async context (use [`download_with_config()`] instead).
344///
345/// # Errors
346///
347/// Same as [`download_with_config()`].
348pub fn download_with_config_blocking(
349    repo_id: String,
350    config: &FetchConfig,
351) -> Result<DownloadOutcome<PathBuf>, FetchError> {
352    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
353        path: PathBuf::from("<runtime>"),
354        source: e,
355    })?;
356    rt.block_on(download_with_config(repo_id, config))
357}
358
359/// Downloads all files from a `HuggingFace` model repository and returns
360/// a filename → path map.
361///
362/// Each key is the relative filename within the repository (e.g.,
363/// `"config.json"`, `"model.safetensors"`), and each value is the
364/// absolute local path to the downloaded file.
365///
366/// Uses the same high-throughput defaults as [`download()`]: auto-tuned
367/// concurrency and chunked multi-connection downloads for large files.
368///
369/// For filtering, progress, and other options, use
370/// [`download_files_with_config()`].
371///
372/// # Arguments
373///
374/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
375///
376/// # Errors
377///
378/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
379/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
380/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
381pub async fn download_files(
382    repo_id: String,
383) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
384    let config = FetchConfig::builder().build()?;
385    download_files_with_config(repo_id, &config).await
386}
387
388/// Downloads files from a `HuggingFace` model repository using the given
389/// configuration and returns a filename → path map.
390///
391/// Each key is the relative filename within the repository (e.g.,
392/// `"config.json"`, `"model.safetensors"`), and each value is the
393/// absolute local path to the downloaded file.
394///
395/// # Arguments
396///
397/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
398/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
399///
400/// # Errors
401///
402/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
403/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
404/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
405pub async fn download_files_with_config(
406    repo_id: String,
407    config: &FetchConfig,
408) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
409    // BORROW: explicit .as_str() instead of Deref coercion
410    preflight_gated_check(repo_id.as_str(), config).await?;
411
412    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
413
414    if let Some(ref token) = config.token {
415        // BORROW: explicit .clone() to pass owned String
416        builder = builder.with_token(Some(token.clone()));
417    }
418
419    if let Some(ref dir) = config.output_dir {
420        // BORROW: explicit .clone() for owned PathBuf
421        builder = builder.with_cache_dir(dir.clone());
422    }
423
424    let api = builder.build().map_err(FetchError::Api)?;
425
426    let hf_repo = match config.revision {
427        Some(ref rev) => {
428            // BORROW: explicit .clone() for owned String arguments
429            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
430        }
431        None => Repo::new(repo_id.clone(), RepoType::Model),
432    };
433
434    let repo = api.repo(hf_repo);
435    download::download_all_files_map(repo, repo_id, Some(config)).await
436}
437
438/// Blocking version of [`download_files()`] for non-async callers.
439///
440/// Creates a Tokio runtime internally. Do not call from within
441/// an existing async context (use [`download_files()`] instead).
442///
443/// # Errors
444///
445/// Same as [`download_files()`].
446pub fn download_files_blocking(
447    repo_id: String,
448) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
449    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
450        path: PathBuf::from("<runtime>"),
451        source: e,
452    })?;
453    rt.block_on(download_files(repo_id))
454}
455
456/// Downloads a single file from a `HuggingFace` model repository.
457///
458/// Returns the local cache path. If the file is already cached (and
459/// checksums match when `verify_checksums` is enabled), the download
460/// is skipped and the cached path is returned immediately.
461///
462/// Files at or above [`FetchConfig`]'s `chunk_threshold` (auto-tuned by
463/// the download plan optimizer, or 100 MiB fallback) are downloaded using
464/// multiple parallel HTTP Range connections (`connections_per_file`,
465/// auto-tuned or 8 fallback). Smaller files use a single connection.
466///
467/// # Arguments
468///
469/// * `repo_id` — Repository identifier (e.g., `"mntss/clt-gemma-2-2b-426k"`).
470/// * `filename` — Exact filename within the repository (e.g., `"W_enc_5.safetensors"`).
471/// * `config` — Shared configuration for auth, progress, checksums, retries, and chunking.
472///
473/// # Errors
474///
475/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
476/// * [`FetchError::Http`] — if the file does not exist in the repository.
477/// * [`FetchError::Api`] — on download failure (after retries).
478/// * [`FetchError::Checksum`] — if verification is enabled and fails.
479pub async fn download_file(
480    repo_id: String,
481    filename: &str,
482    config: &FetchConfig,
483) -> Result<DownloadOutcome<PathBuf>, FetchError> {
484    // BORROW: explicit .as_str() instead of Deref coercion
485    preflight_gated_check(repo_id.as_str(), config).await?;
486
487    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
488
489    if let Some(ref token) = config.token {
490        // BORROW: explicit .clone() to pass owned String
491        builder = builder.with_token(Some(token.clone()));
492    }
493
494    if let Some(ref dir) = config.output_dir {
495        // BORROW: explicit .clone() for owned PathBuf
496        builder = builder.with_cache_dir(dir.clone());
497    }
498
499    let api = builder.build().map_err(FetchError::Api)?;
500
501    let hf_repo = match config.revision {
502        Some(ref rev) => {
503            // BORROW: explicit .clone() for owned String arguments
504            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
505        }
506        None => Repo::new(repo_id.clone(), RepoType::Model),
507    };
508
509    let repo = api.repo(hf_repo);
510    download::download_file_by_name(repo, repo_id, filename, config).await
511}
512
513/// Blocking version of [`download_file()`] for non-async callers.
514///
515/// Creates a Tokio runtime internally. Do not call from within
516/// an existing async context (use [`download_file()`] instead).
517///
518/// # Errors
519///
520/// Same as [`download_file()`].
521pub fn download_file_blocking(
522    repo_id: String,
523    filename: &str,
524    config: &FetchConfig,
525) -> Result<DownloadOutcome<PathBuf>, FetchError> {
526    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
527        path: PathBuf::from("<runtime>"),
528        source: e,
529    })?;
530    rt.block_on(download_file(repo_id, filename, config))
531}
532
533/// Blocking version of [`download_files_with_config()`] for non-async callers.
534///
535/// Creates a Tokio runtime internally. Do not call from within
536/// an existing async context (use [`download_files_with_config()`] instead).
537///
538/// # Errors
539///
540/// Same as [`download_files_with_config()`].
541pub fn download_files_with_config_blocking(
542    repo_id: String,
543    config: &FetchConfig,
544) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
545    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
546        path: PathBuf::from("<runtime>"),
547        source: e,
548    })?;
549    rt.block_on(download_files_with_config(repo_id, config))
550}
551
552/// Downloads files according to an existing [`DownloadPlan`].
553///
554/// Only uncached files in the plan are downloaded. The `config` controls
555/// authentication, progress, timeouts, and performance settings.
556/// Use [`DownloadPlan::recommended_config()`] to compute an optimized config,
557/// or override specific fields via [`DownloadPlan::recommended_config_builder()`].
558///
559/// # Errors
560///
561/// Returns [`FetchError::Io`] if the cache directory cannot be resolved.
562/// Same error conditions as [`download_with_config()`] for the download itself.
563pub async fn download_with_plan(
564    plan: &DownloadPlan,
565    config: &FetchConfig,
566) -> Result<DownloadOutcome<PathBuf>, FetchError> {
567    if plan.fully_cached() {
568        // Resolve snapshot path from cache and return immediately.
569        let cache_dir = config
570            .output_dir
571            .clone()
572            .map_or_else(cache::hf_cache_dir, Ok)?;
573        let repo_dir = cache_layout::repo_dir(&cache_dir, plan.repo_id.as_str());
574        let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, plan.revision.as_str());
575        return Ok(DownloadOutcome::Cached(snapshot_dir));
576    }
577
578    // Delegate to the standard download path which will re-check cache
579    // internally. The plan's value is the dry-run preview and the
580    // recommended config computed by the caller.
581    // BORROW: explicit .clone() for owned String argument
582    download_with_config(plan.repo_id.clone(), config).await
583}
584
585/// Blocking version of [`download_with_plan()`] for non-async callers.
586///
587/// Creates a Tokio runtime internally. Do not call from within
588/// an existing async context (use [`download_with_plan()`] instead).
589///
590/// # Errors
591///
592/// Same as [`download_with_plan()`].
593pub fn download_with_plan_blocking(
594    plan: &DownloadPlan,
595    config: &FetchConfig,
596) -> Result<DownloadOutcome<PathBuf>, FetchError> {
597    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
598        path: PathBuf::from("<runtime>"),
599        source: e,
600    })?;
601    rt.block_on(download_with_plan(plan, config))
602}