hf_fetch_model/lib.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let outcome = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model at: {}", outcome.inner().display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//! .filter("*.safetensors")
28//! .filter("*.json")
29//! .on_progress(|e| {
30//! println!("{}: {:.1}%", e.filename, e.percent);
31//! })
32//! .build()?;
33//!
34//! let outcome = hf_fetch_model::download_with_config(
35//! "google/gemma-2-2b".to_owned(),
36//! &config,
37//! ).await?;
38//! // outcome.is_cached() tells you if it came from local cache
39//! let path = outcome.into_inner();
40//! # Ok(())
41//! # }
42//! ```
43//!
44//! ## Inspect Before Downloading
45//!
46//! Read tensor metadata from `.safetensors` headers via HTTP Range requests —
47//! no weight data downloaded. Sharded repos (those with
48//! `model.safetensors.index.json`) work transparently —
49//! [`inspect::inspect_repo_safetensors`] reads every shard's header in parallel
50//! and returns a flat per-file result list. See
51//! [`examples/candle_inspect.rs`](https://github.com/PCfVW/hf-fetch-model/blob/main/examples/candle_inspect.rs)
52//! for a runnable example, or the
53//! [Inspect tutorial](https://github.com/PCfVW/hf-fetch-model/blob/main/docs/tutorials/inspect-before-downloading.md)
54//! for a narrative walkthrough.
55//!
56//! ```rust,no_run
57//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
58//! let results = hf_fetch_model::inspect::inspect_repo_safetensors(
59//! "EleutherAI/pythia-1.4b", None, None,
60//! ).await?;
61//!
62//! for (filename, header, _source) in &results {
63//! println!("{filename}: {} tensors", header.tensors.len());
64//! }
65//! # Ok(())
66//! # }
67//! ```
68//!
69//! The CLI also exposes `hf-fm inspect <repo> [FILE] --check-gpu [N]` (v0.10.1)
70//! to print a one-line GPU-fit verdict against device `N` (default 0) using
71//! the `hypomnesis` crate (NVML on Linux/Windows, DXGI on Windows). The
72//! verdict is a binary-only feature today; no library equivalent is exposed
73//! — depend on `hypomnesis` directly if you need the device-info numbers
74//! from library code.
75//!
76//! ## `HuggingFace` Cache
77//!
78//! Downloaded files are stored in the standard `HuggingFace` cache directory
79//! (`~/.cache/huggingface/hub/`), ensuring compatibility with Python tooling.
80//!
81//! ## Cache Management
82//!
83//! v0.10.0 adds library APIs for inspecting, verifying, and pruning the local
84//! cache. [`cache::cache_summary`] enumerates every cached repo with size and
85//! file counts; [`cache::repo_status`] gives a per-file `Complete` / `Partial` /
86//! `Missing` breakdown for one repo; [`cache::verify_cache`] re-checks `SHA256`
87//! digests of cached files against `HuggingFace` LFS metadata; and
88//! [`cache::find_partial_files`] locates `.chunked.part` orphans from
89//! interrupted downloads.
90//!
91//! For long verifications (multi-GiB safetensors files), drive
92//! [`cache::verify_cache_with_progress`] with an [`Fn`] callback that receives
93//! [`cache::VerifyEvent`]s so a CLI or GUI can render a spinner or progress
94//! bar without polling.
95//!
96//! ```rust,no_run
97//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
98//! use hf_fetch_model::cache::{self, VerifyStatus};
99//!
100//! let results = cache::verify_cache("google/gemma-2-2b-it", None, None).await?;
101//! let ok = results
102//! .iter()
103//! .filter(|r| matches!(r.status, VerifyStatus::Ok))
104//! .count();
105//! let mismatch = results
106//! .iter()
107//! .filter(|r| matches!(r.status, VerifyStatus::Mismatch { .. }))
108//! .count();
109//! println!("{}/{} files verified, {} mismatches", ok, results.len(), mismatch);
110//! # Ok(())
111//! # }
112//! ```
113//!
114//! ## Download Durability
115//!
116//! Multi-connection downloads survive interruption. When a download is
117//! aborted by [`FetchConfigBuilder::timeout_per_file`] (default 300 s),
118//! Ctrl-C, panic, or a transient chunk error, the partial `.chunked.part`
119//! file plus a small per-chunk progress sidecar are kept on disk. The next
120//! call to [`download_with_config`] for the same file picks up where it
121//! stopped — each parallel chunk sends a fresh `Range` request that skips
122//! the bytes it already has — provided the upstream etag still matches.
123//! On etag change, schema-version mismatch, or a different
124//! [`FetchConfigBuilder::connections_per_file`] count, the partial is
125//! discarded and a fresh download starts.
126//!
127//! For slow connections on multi-GiB files, raise the per-file budget to
128//! match real throughput:
129//!
130//! ```rust,no_run
131//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
132//! use std::time::Duration;
133//! use hf_fetch_model::FetchConfig;
134//!
135//! let config = FetchConfig::builder()
136//! .timeout_per_file(Duration::from_secs(1800))
137//! .build()?;
138//! # let _ = hf_fetch_model::download_with_config(
139//! # "google/gemma-4-E2B-it".to_owned(),
140//! # &config,
141//! # ).await?;
142//! # Ok(())
143//! # }
144//! ```
145//!
146//! ## Authentication
147//!
148//! Set the `HF_TOKEN` environment variable to access private or gated models,
149//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
150
151pub mod cache;
152pub mod cache_layout;
153pub mod checksum;
154mod chunked;
155mod chunked_state;
156pub mod config;
157pub mod discover;
158pub mod download;
159pub mod error;
160pub mod inspect;
161pub mod plan;
162pub mod progress;
163pub mod repo;
164mod retry;
165
166pub use chunked::build_client;
167pub use config::{
168 compile_glob_patterns, file_matches, has_glob_chars, FetchConfig, FetchConfigBuilder, Filter,
169};
170pub use discover::{DiscoveredFamily, GateStatus, ModelCardMetadata, SearchResult};
171pub use download::DownloadOutcome;
172pub use error::{FetchError, FileFailure};
173pub use inspect::AdapterConfig;
174pub use plan::{download_plan, DownloadPlan, FilePlan};
175pub use progress::{ProgressEvent, ProgressReceiver};
176
177use std::collections::HashMap;
178use std::path::PathBuf;
179
180use hf_hub::{Repo, RepoType};
181
182/// Pre-flight check for gated model access.
183///
184/// Two cases:
185/// - **No token**: checks the model metadata (unauthenticated) for gating
186/// status and rejects with a clear message if gated.
187/// - **Token present**: if the model is gated, makes one authenticated
188/// metadata request to verify the token actually grants access. Catches
189/// invalid tokens and unaccepted licenses before the download starts.
190///
191/// If the metadata request itself fails (network error, private repo),
192/// the check is silently skipped so that normal download error handling
193/// can take over.
194async fn preflight_gated_check(repo_id: &str, config: &FetchConfig) -> Result<(), FetchError> {
195 // Best-effort: if the metadata call fails, let the download proceed.
196 let Ok(metadata) = discover::fetch_model_card(repo_id).await else {
197 return Ok(());
198 };
199
200 if !metadata.gated.is_gated() {
201 return Ok(());
202 }
203
204 // Model is gated — check auth.
205 if config.token.is_none() {
206 return Err(FetchError::Auth {
207 reason: format!(
208 "{repo_id} is a gated model — accept the license at \
209 https://huggingface.co/{repo_id} and set HF_TOKEN or pass --token"
210 ),
211 });
212 }
213
214 // Token is present — verify it grants access with a lightweight probe.
215 let probe_client = chunked::build_client(config.token.as_deref())?;
216 let probe = repo::list_repo_files_with_metadata(
217 repo_id,
218 config.token.as_deref(),
219 config.revision.as_deref(),
220 &probe_client,
221 )
222 .await;
223
224 if let Err(ref e) = probe {
225 // BORROW: explicit .to_string() for error Display formatting
226 let msg = e.to_string();
227 if msg.contains("401") || msg.contains("403") {
228 return Err(FetchError::Auth {
229 reason: format!(
230 "{repo_id} is a gated model and your token was rejected — \
231 accept the license at https://huggingface.co/{repo_id} \
232 and check that your token is valid"
233 ),
234 });
235 }
236 }
237
238 Ok(())
239}
240
241/// Downloads all files from a `HuggingFace` model repository.
242///
243/// Uses high-throughput mode for maximum download speed, including
244/// auto-tuned concurrency, chunked multi-connection downloads for large
245/// files, and plan-optimized settings based on file size distribution.
246/// Files are stored in the standard `HuggingFace` cache layout
247/// (`~/.cache/huggingface/hub/`).
248///
249/// Authentication is handled via the `HF_TOKEN` environment variable when set.
250///
251/// For filtering, progress, and other options, use [`download_with_config()`].
252///
253/// # Arguments
254///
255/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
256///
257/// # Returns
258///
259/// The path to the snapshot directory containing all downloaded files.
260///
261/// # Errors
262///
263/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
264/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
265/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
266/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
267pub async fn download(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
268 let config = FetchConfig::builder().build()?;
269 download_with_config(repo_id, &config).await
270}
271
272/// Downloads files from a `HuggingFace` model repository using the given configuration.
273///
274/// Supports filtering, progress reporting, custom revision, authentication,
275/// and concurrency settings via [`FetchConfig`].
276///
277/// # Arguments
278///
279/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
280/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
281///
282/// # Returns
283///
284/// The path to the snapshot directory containing all downloaded files.
285///
286/// # Errors
287///
288/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
289/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
290/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
291pub async fn download_with_config(
292 repo_id: String,
293 config: &FetchConfig,
294) -> Result<DownloadOutcome<PathBuf>, FetchError> {
295 // BORROW: explicit .as_str() instead of Deref coercion
296 preflight_gated_check(repo_id.as_str(), config).await?;
297
298 let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
299
300 if let Some(ref token) = config.token {
301 // BORROW: explicit .clone() to pass owned String
302 builder = builder.with_token(Some(token.clone()));
303 }
304
305 if let Some(ref dir) = config.output_dir {
306 // BORROW: explicit .clone() for owned PathBuf
307 builder = builder.with_cache_dir(dir.clone());
308 }
309
310 let api = builder.build().map_err(FetchError::Api)?;
311
312 let hf_repo = match config.revision {
313 Some(ref rev) => {
314 // BORROW: explicit .clone() for owned String arguments
315 Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
316 }
317 None => Repo::new(repo_id.clone(), RepoType::Model),
318 };
319
320 let repo = api.repo(hf_repo);
321 download::download_all_files(repo, repo_id, Some(config)).await
322}
323
324/// Blocking version of [`download()`] for non-async callers.
325///
326/// Creates a Tokio runtime internally. Do not call from within
327/// an existing async context (use [`download()`] instead).
328///
329/// # Errors
330///
331/// Same as [`download()`].
332pub fn download_blocking(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
333 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
334 path: PathBuf::from("<runtime>"),
335 source: e,
336 })?;
337 rt.block_on(download(repo_id))
338}
339
340/// Blocking version of [`download_with_config()`] for non-async callers.
341///
342/// Creates a Tokio runtime internally. Do not call from within
343/// an existing async context (use [`download_with_config()`] instead).
344///
345/// # Errors
346///
347/// Same as [`download_with_config()`].
348pub fn download_with_config_blocking(
349 repo_id: String,
350 config: &FetchConfig,
351) -> Result<DownloadOutcome<PathBuf>, FetchError> {
352 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
353 path: PathBuf::from("<runtime>"),
354 source: e,
355 })?;
356 rt.block_on(download_with_config(repo_id, config))
357}
358
359/// Downloads all files from a `HuggingFace` model repository and returns
360/// a filename → path map.
361///
362/// Each key is the relative filename within the repository (e.g.,
363/// `"config.json"`, `"model.safetensors"`), and each value is the
364/// absolute local path to the downloaded file.
365///
366/// Uses the same high-throughput defaults as [`download()`]: auto-tuned
367/// concurrency and chunked multi-connection downloads for large files.
368///
369/// For filtering, progress, and other options, use
370/// [`download_files_with_config()`].
371///
372/// # Arguments
373///
374/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
375///
376/// # Errors
377///
378/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
379/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
380/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
381pub async fn download_files(
382 repo_id: String,
383) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
384 let config = FetchConfig::builder().build()?;
385 download_files_with_config(repo_id, &config).await
386}
387
388/// Downloads files from a `HuggingFace` model repository using the given
389/// configuration and returns a filename → path map.
390///
391/// Each key is the relative filename within the repository (e.g.,
392/// `"config.json"`, `"model.safetensors"`), and each value is the
393/// absolute local path to the downloaded file.
394///
395/// # Arguments
396///
397/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
398/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
399///
400/// # Errors
401///
402/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
403/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
404/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
405pub async fn download_files_with_config(
406 repo_id: String,
407 config: &FetchConfig,
408) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
409 // BORROW: explicit .as_str() instead of Deref coercion
410 preflight_gated_check(repo_id.as_str(), config).await?;
411
412 let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
413
414 if let Some(ref token) = config.token {
415 // BORROW: explicit .clone() to pass owned String
416 builder = builder.with_token(Some(token.clone()));
417 }
418
419 if let Some(ref dir) = config.output_dir {
420 // BORROW: explicit .clone() for owned PathBuf
421 builder = builder.with_cache_dir(dir.clone());
422 }
423
424 let api = builder.build().map_err(FetchError::Api)?;
425
426 let hf_repo = match config.revision {
427 Some(ref rev) => {
428 // BORROW: explicit .clone() for owned String arguments
429 Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
430 }
431 None => Repo::new(repo_id.clone(), RepoType::Model),
432 };
433
434 let repo = api.repo(hf_repo);
435 download::download_all_files_map(repo, repo_id, Some(config)).await
436}
437
438/// Blocking version of [`download_files()`] for non-async callers.
439///
440/// Creates a Tokio runtime internally. Do not call from within
441/// an existing async context (use [`download_files()`] instead).
442///
443/// # Errors
444///
445/// Same as [`download_files()`].
446pub fn download_files_blocking(
447 repo_id: String,
448) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
449 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
450 path: PathBuf::from("<runtime>"),
451 source: e,
452 })?;
453 rt.block_on(download_files(repo_id))
454}
455
456/// Downloads a single file from a `HuggingFace` model repository.
457///
458/// Returns the local cache path. If the file is already cached (and
459/// checksums match when `verify_checksums` is enabled), the download
460/// is skipped and the cached path is returned immediately.
461///
462/// Files at or above [`FetchConfig`]'s `chunk_threshold` (auto-tuned by
463/// the download plan optimizer, or 100 MiB fallback) are downloaded using
464/// multiple parallel HTTP Range connections (`connections_per_file`,
465/// auto-tuned or 8 fallback). Smaller files use a single connection.
466///
467/// # Arguments
468///
469/// * `repo_id` — Repository identifier (e.g., `"mntss/clt-gemma-2-2b-426k"`).
470/// * `filename` — Exact filename within the repository (e.g., `"W_enc_5.safetensors"`).
471/// * `config` — Shared configuration for auth, progress, checksums, retries, and chunking.
472///
473/// # Errors
474///
475/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
476/// * [`FetchError::Http`] — if the file does not exist in the repository.
477/// * [`FetchError::Api`] — on download failure (after retries).
478/// * [`FetchError::Checksum`] — if verification is enabled and fails.
479pub async fn download_file(
480 repo_id: String,
481 filename: &str,
482 config: &FetchConfig,
483) -> Result<DownloadOutcome<PathBuf>, FetchError> {
484 // BORROW: explicit .as_str() instead of Deref coercion
485 preflight_gated_check(repo_id.as_str(), config).await?;
486
487 let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
488
489 if let Some(ref token) = config.token {
490 // BORROW: explicit .clone() to pass owned String
491 builder = builder.with_token(Some(token.clone()));
492 }
493
494 if let Some(ref dir) = config.output_dir {
495 // BORROW: explicit .clone() for owned PathBuf
496 builder = builder.with_cache_dir(dir.clone());
497 }
498
499 let api = builder.build().map_err(FetchError::Api)?;
500
501 let hf_repo = match config.revision {
502 Some(ref rev) => {
503 // BORROW: explicit .clone() for owned String arguments
504 Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
505 }
506 None => Repo::new(repo_id.clone(), RepoType::Model),
507 };
508
509 let repo = api.repo(hf_repo);
510 download::download_file_by_name(repo, repo_id, filename, config).await
511}
512
513/// Blocking version of [`download_file()`] for non-async callers.
514///
515/// Creates a Tokio runtime internally. Do not call from within
516/// an existing async context (use [`download_file()`] instead).
517///
518/// # Errors
519///
520/// Same as [`download_file()`].
521pub fn download_file_blocking(
522 repo_id: String,
523 filename: &str,
524 config: &FetchConfig,
525) -> Result<DownloadOutcome<PathBuf>, FetchError> {
526 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
527 path: PathBuf::from("<runtime>"),
528 source: e,
529 })?;
530 rt.block_on(download_file(repo_id, filename, config))
531}
532
533/// Blocking version of [`download_files_with_config()`] for non-async callers.
534///
535/// Creates a Tokio runtime internally. Do not call from within
536/// an existing async context (use [`download_files_with_config()`] instead).
537///
538/// # Errors
539///
540/// Same as [`download_files_with_config()`].
541pub fn download_files_with_config_blocking(
542 repo_id: String,
543 config: &FetchConfig,
544) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
545 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
546 path: PathBuf::from("<runtime>"),
547 source: e,
548 })?;
549 rt.block_on(download_files_with_config(repo_id, config))
550}
551
552/// Downloads files according to an existing [`DownloadPlan`].
553///
554/// Only uncached files in the plan are downloaded. The `config` controls
555/// authentication, progress, timeouts, and performance settings.
556/// Use [`DownloadPlan::recommended_config()`] to compute an optimized config,
557/// or override specific fields via [`DownloadPlan::recommended_config_builder()`].
558///
559/// # Errors
560///
561/// Returns [`FetchError::Io`] if the cache directory cannot be resolved.
562/// Same error conditions as [`download_with_config()`] for the download itself.
563pub async fn download_with_plan(
564 plan: &DownloadPlan,
565 config: &FetchConfig,
566) -> Result<DownloadOutcome<PathBuf>, FetchError> {
567 if plan.fully_cached() {
568 // Resolve snapshot path from cache and return immediately.
569 let cache_dir = config
570 .output_dir
571 .clone()
572 .map_or_else(cache::hf_cache_dir, Ok)?;
573 let repo_dir = cache_layout::repo_dir(&cache_dir, plan.repo_id.as_str());
574 let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, plan.revision.as_str());
575 return Ok(DownloadOutcome::Cached(snapshot_dir));
576 }
577
578 // Delegate to the standard download path which will re-check cache
579 // internally. The plan's value is the dry-run preview and the
580 // recommended config computed by the caller.
581 // BORROW: explicit .clone() for owned String argument
582 download_with_config(plan.repo_id.clone(), config).await
583}
584
585/// Blocking version of [`download_with_plan()`] for non-async callers.
586///
587/// Creates a Tokio runtime internally. Do not call from within
588/// an existing async context (use [`download_with_plan()`] instead).
589///
590/// # Errors
591///
592/// Same as [`download_with_plan()`].
593pub fn download_with_plan_blocking(
594 plan: &DownloadPlan,
595 config: &FetchConfig,
596) -> Result<DownloadOutcome<PathBuf>, FetchError> {
597 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
598 path: PathBuf::from("<runtime>"),
599 source: e,
600 })?;
601 rt.block_on(download_with_plan(plan, config))
602}