hf_fetch_model/lib.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let outcome = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model at: {}", outcome.inner().display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//! .filter("*.safetensors")
28//! .filter("*.json")
29//! .on_progress(|e| {
30//! println!("{}: {:.1}%", e.filename, e.percent);
31//! })
32//! .build()?;
33//!
34//! let outcome = hf_fetch_model::download_with_config(
35//! "google/gemma-2-2b".to_owned(),
36//! &config,
37//! ).await?;
38//! // outcome.is_cached() tells you if it came from local cache
39//! let path = outcome.into_inner();
40//! # Ok(())
41//! # }
42//! ```
43//!
44//! ## Inspect Before Downloading
45//!
46//! Read tensor metadata from `.safetensors` headers via HTTP Range requests —
47//! no weight data downloaded. See [`examples/candle_inspect.rs`](https://github.com/PCfVW/hf-fetch-model/blob/main/examples/candle_inspect.rs)
48//! for a runnable example.
49//!
50//! ```rust,no_run
51//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
52//! let results = hf_fetch_model::inspect::inspect_repo_safetensors(
53//! "EleutherAI/pythia-1.4b", None, None,
54//! ).await?;
55//!
56//! for (filename, header, _source) in &results {
57//! println!("{filename}: {} tensors", header.tensors.len());
58//! }
59//! # Ok(())
60//! # }
61//! ```
62//!
63//! ## `HuggingFace` Cache
64//!
65//! Downloaded files are stored in the standard `HuggingFace` cache directory
66//! (`~/.cache/huggingface/hub/`), ensuring compatibility with Python tooling.
67//!
68//! ## Authentication
69//!
70//! Set the `HF_TOKEN` environment variable to access private or gated models,
71//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
72
73pub mod cache;
74pub mod checksum;
75mod chunked;
76pub mod config;
77pub mod discover;
78pub mod download;
79pub mod error;
80pub mod inspect;
81pub mod plan;
82pub mod progress;
83pub mod repo;
84mod retry;
85
86pub use config::{
87 compile_glob_patterns, file_matches, has_glob_chars, FetchConfig, FetchConfigBuilder, Filter,
88};
89pub use discover::{DiscoveredFamily, GateStatus, ModelCardMetadata, SearchResult};
90pub use download::DownloadOutcome;
91pub use error::{FetchError, FileFailure};
92pub use inspect::AdapterConfig;
93pub use plan::{download_plan, DownloadPlan, FilePlan};
94pub use progress::ProgressEvent;
95
96use std::collections::HashMap;
97use std::path::PathBuf;
98
99use hf_hub::{Repo, RepoType};
100
101/// Pre-flight check for gated model access.
102///
103/// Two cases:
104/// - **No token**: checks the model metadata (unauthenticated) for gating
105/// status and rejects with a clear message if gated.
106/// - **Token present**: if the model is gated, makes one authenticated
107/// metadata request to verify the token actually grants access. Catches
108/// invalid tokens and unaccepted licenses before the download starts.
109///
110/// If the metadata request itself fails (network error, private repo),
111/// the check is silently skipped so that normal download error handling
112/// can take over.
113async fn preflight_gated_check(repo_id: &str, config: &FetchConfig) -> Result<(), FetchError> {
114 // Best-effort: if the metadata call fails, let the download proceed.
115 let Ok(metadata) = discover::fetch_model_card(repo_id).await else {
116 return Ok(());
117 };
118
119 if !metadata.gated.is_gated() {
120 return Ok(());
121 }
122
123 // Model is gated — check auth.
124 if config.token.is_none() {
125 return Err(FetchError::Auth {
126 reason: format!(
127 "{repo_id} is a gated model — accept the license at \
128 https://huggingface.co/{repo_id} and set HF_TOKEN or pass --token"
129 ),
130 });
131 }
132
133 // Token is present — verify it grants access with a lightweight probe.
134 let probe = repo::list_repo_files_with_metadata(
135 repo_id,
136 config.token.as_deref(),
137 config.revision.as_deref(),
138 )
139 .await;
140
141 if let Err(ref e) = probe {
142 // BORROW: explicit .to_string() for error Display formatting
143 let msg = e.to_string();
144 if msg.contains("401") || msg.contains("403") {
145 return Err(FetchError::Auth {
146 reason: format!(
147 "{repo_id} is a gated model and your token was rejected — \
148 accept the license at https://huggingface.co/{repo_id} \
149 and check that your token is valid"
150 ),
151 });
152 }
153 }
154
155 Ok(())
156}
157
158/// Downloads all files from a `HuggingFace` model repository.
159///
160/// Uses high-throughput mode for maximum download speed, including
161/// auto-tuned concurrency, chunked multi-connection downloads for large
162/// files, and plan-optimized settings based on file size distribution.
163/// Files are stored in the standard `HuggingFace` cache layout
164/// (`~/.cache/huggingface/hub/`).
165///
166/// Authentication is handled via the `HF_TOKEN` environment variable when set.
167///
168/// For filtering, progress, and other options, use [`download_with_config()`].
169///
170/// # Arguments
171///
172/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
173///
174/// # Returns
175///
176/// The path to the snapshot directory containing all downloaded files.
177///
178/// # Errors
179///
180/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
181/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
182/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
183/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
184pub async fn download(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
185 let config = FetchConfig::builder().build()?;
186 download_with_config(repo_id, &config).await
187}
188
189/// Downloads files from a `HuggingFace` model repository using the given configuration.
190///
191/// Supports filtering, progress reporting, custom revision, authentication,
192/// and concurrency settings via [`FetchConfig`].
193///
194/// # Arguments
195///
196/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
197/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
198///
199/// # Returns
200///
201/// The path to the snapshot directory containing all downloaded files.
202///
203/// # Errors
204///
205/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
206/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
207/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
208pub async fn download_with_config(
209 repo_id: String,
210 config: &FetchConfig,
211) -> Result<DownloadOutcome<PathBuf>, FetchError> {
212 // BORROW: explicit .as_str() instead of Deref coercion
213 preflight_gated_check(repo_id.as_str(), config).await?;
214
215 let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
216
217 if let Some(ref token) = config.token {
218 // BORROW: explicit .clone() to pass owned String
219 builder = builder.with_token(Some(token.clone()));
220 }
221
222 if let Some(ref dir) = config.output_dir {
223 // BORROW: explicit .clone() for owned PathBuf
224 builder = builder.with_cache_dir(dir.clone());
225 }
226
227 let api = builder.build().map_err(FetchError::Api)?;
228
229 let hf_repo = match config.revision {
230 Some(ref rev) => {
231 // BORROW: explicit .clone() for owned String arguments
232 Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
233 }
234 None => Repo::new(repo_id.clone(), RepoType::Model),
235 };
236
237 let repo = api.repo(hf_repo);
238 download::download_all_files(repo, repo_id, Some(config)).await
239}
240
241/// Blocking version of [`download()`] for non-async callers.
242///
243/// Creates a Tokio runtime internally. Do not call from within
244/// an existing async context (use [`download()`] instead).
245///
246/// # Errors
247///
248/// Same as [`download()`].
249pub fn download_blocking(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
250 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
251 path: PathBuf::from("<runtime>"),
252 source: e,
253 })?;
254 rt.block_on(download(repo_id))
255}
256
257/// Blocking version of [`download_with_config()`] for non-async callers.
258///
259/// Creates a Tokio runtime internally. Do not call from within
260/// an existing async context (use [`download_with_config()`] instead).
261///
262/// # Errors
263///
264/// Same as [`download_with_config()`].
265pub fn download_with_config_blocking(
266 repo_id: String,
267 config: &FetchConfig,
268) -> Result<DownloadOutcome<PathBuf>, FetchError> {
269 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
270 path: PathBuf::from("<runtime>"),
271 source: e,
272 })?;
273 rt.block_on(download_with_config(repo_id, config))
274}
275
276/// Downloads all files from a `HuggingFace` model repository and returns
277/// a filename → path map.
278///
279/// Each key is the relative filename within the repository (e.g.,
280/// `"config.json"`, `"model.safetensors"`), and each value is the
281/// absolute local path to the downloaded file.
282///
283/// Uses the same high-throughput defaults as [`download()`]: auto-tuned
284/// concurrency and chunked multi-connection downloads for large files.
285///
286/// For filtering, progress, and other options, use
287/// [`download_files_with_config()`].
288///
289/// # Arguments
290///
291/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
292///
293/// # Errors
294///
295/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
296/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
297/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
298pub async fn download_files(
299 repo_id: String,
300) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
301 let config = FetchConfig::builder().build()?;
302 download_files_with_config(repo_id, &config).await
303}
304
305/// Downloads files from a `HuggingFace` model repository using the given
306/// configuration and returns a filename → path map.
307///
308/// Each key is the relative filename within the repository (e.g.,
309/// `"config.json"`, `"model.safetensors"`), and each value is the
310/// absolute local path to the downloaded file.
311///
312/// # Arguments
313///
314/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
315/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
316///
317/// # Errors
318///
319/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
320/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
321/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
322pub async fn download_files_with_config(
323 repo_id: String,
324 config: &FetchConfig,
325) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
326 // BORROW: explicit .as_str() instead of Deref coercion
327 preflight_gated_check(repo_id.as_str(), config).await?;
328
329 let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
330
331 if let Some(ref token) = config.token {
332 // BORROW: explicit .clone() to pass owned String
333 builder = builder.with_token(Some(token.clone()));
334 }
335
336 if let Some(ref dir) = config.output_dir {
337 // BORROW: explicit .clone() for owned PathBuf
338 builder = builder.with_cache_dir(dir.clone());
339 }
340
341 let api = builder.build().map_err(FetchError::Api)?;
342
343 let hf_repo = match config.revision {
344 Some(ref rev) => {
345 // BORROW: explicit .clone() for owned String arguments
346 Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
347 }
348 None => Repo::new(repo_id.clone(), RepoType::Model),
349 };
350
351 let repo = api.repo(hf_repo);
352 download::download_all_files_map(repo, repo_id, Some(config)).await
353}
354
355/// Blocking version of [`download_files()`] for non-async callers.
356///
357/// Creates a Tokio runtime internally. Do not call from within
358/// an existing async context (use [`download_files()`] instead).
359///
360/// # Errors
361///
362/// Same as [`download_files()`].
363pub fn download_files_blocking(
364 repo_id: String,
365) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
366 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
367 path: PathBuf::from("<runtime>"),
368 source: e,
369 })?;
370 rt.block_on(download_files(repo_id))
371}
372
373/// Downloads a single file from a `HuggingFace` model repository.
374///
375/// Returns the local cache path. If the file is already cached (and
376/// checksums match when `verify_checksums` is enabled), the download
377/// is skipped and the cached path is returned immediately.
378///
379/// Files at or above [`FetchConfig`]'s `chunk_threshold` (auto-tuned by
380/// the download plan optimizer, or 100 MiB fallback) are downloaded using
381/// multiple parallel HTTP Range connections (`connections_per_file`,
382/// auto-tuned or 8 fallback). Smaller files use a single connection.
383///
384/// # Arguments
385///
386/// * `repo_id` — Repository identifier (e.g., `"mntss/clt-gemma-2-2b-426k"`).
387/// * `filename` — Exact filename within the repository (e.g., `"W_enc_5.safetensors"`).
388/// * `config` — Shared configuration for auth, progress, checksums, retries, and chunking.
389///
390/// # Errors
391///
392/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
393/// * [`FetchError::Http`] — if the file does not exist in the repository.
394/// * [`FetchError::Api`] — on download failure (after retries).
395/// * [`FetchError::Checksum`] — if verification is enabled and fails.
396pub async fn download_file(
397 repo_id: String,
398 filename: &str,
399 config: &FetchConfig,
400) -> Result<DownloadOutcome<PathBuf>, FetchError> {
401 // BORROW: explicit .as_str() instead of Deref coercion
402 preflight_gated_check(repo_id.as_str(), config).await?;
403
404 let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
405
406 if let Some(ref token) = config.token {
407 // BORROW: explicit .clone() to pass owned String
408 builder = builder.with_token(Some(token.clone()));
409 }
410
411 if let Some(ref dir) = config.output_dir {
412 // BORROW: explicit .clone() for owned PathBuf
413 builder = builder.with_cache_dir(dir.clone());
414 }
415
416 let api = builder.build().map_err(FetchError::Api)?;
417
418 let hf_repo = match config.revision {
419 Some(ref rev) => {
420 // BORROW: explicit .clone() for owned String arguments
421 Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
422 }
423 None => Repo::new(repo_id.clone(), RepoType::Model),
424 };
425
426 let repo = api.repo(hf_repo);
427 download::download_file_by_name(repo, repo_id, filename, config).await
428}
429
430/// Blocking version of [`download_file()`] for non-async callers.
431///
432/// Creates a Tokio runtime internally. Do not call from within
433/// an existing async context (use [`download_file()`] instead).
434///
435/// # Errors
436///
437/// Same as [`download_file()`].
438pub fn download_file_blocking(
439 repo_id: String,
440 filename: &str,
441 config: &FetchConfig,
442) -> Result<DownloadOutcome<PathBuf>, FetchError> {
443 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
444 path: PathBuf::from("<runtime>"),
445 source: e,
446 })?;
447 rt.block_on(download_file(repo_id, filename, config))
448}
449
450/// Blocking version of [`download_files_with_config()`] for non-async callers.
451///
452/// Creates a Tokio runtime internally. Do not call from within
453/// an existing async context (use [`download_files_with_config()`] instead).
454///
455/// # Errors
456///
457/// Same as [`download_files_with_config()`].
458pub fn download_files_with_config_blocking(
459 repo_id: String,
460 config: &FetchConfig,
461) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
462 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
463 path: PathBuf::from("<runtime>"),
464 source: e,
465 })?;
466 rt.block_on(download_files_with_config(repo_id, config))
467}
468
469/// Downloads files according to an existing [`DownloadPlan`].
470///
471/// Only uncached files in the plan are downloaded. The `config` controls
472/// authentication, progress, timeouts, and performance settings.
473/// Use [`DownloadPlan::recommended_config()`] to compute an optimized config,
474/// or override specific fields via [`DownloadPlan::recommended_config_builder()`].
475///
476/// # Errors
477///
478/// Returns [`FetchError::Io`] if the cache directory cannot be resolved.
479/// Same error conditions as [`download_with_config()`] for the download itself.
480pub async fn download_with_plan(
481 plan: &DownloadPlan,
482 config: &FetchConfig,
483) -> Result<DownloadOutcome<PathBuf>, FetchError> {
484 if plan.fully_cached() {
485 // Resolve snapshot path from cache and return immediately.
486 let cache_dir = config
487 .output_dir
488 .clone()
489 .map_or_else(cache::hf_cache_dir, Ok)?;
490 let repo_folder = chunked::repo_folder_name(plan.repo_id.as_str());
491 // BORROW: explicit .as_str() instead of Deref coercion
492 let snapshot_dir = cache_dir
493 .join(repo_folder.as_str())
494 .join("snapshots")
495 .join(plan.revision.as_str());
496 return Ok(DownloadOutcome::Cached(snapshot_dir));
497 }
498
499 // Delegate to the standard download path which will re-check cache
500 // internally. The plan's value is the dry-run preview and the
501 // recommended config computed by the caller.
502 // BORROW: explicit .clone() for owned String argument
503 download_with_config(plan.repo_id.clone(), config).await
504}
505
506/// Blocking version of [`download_with_plan()`] for non-async callers.
507///
508/// Creates a Tokio runtime internally. Do not call from within
509/// an existing async context (use [`download_with_plan()`] instead).
510///
511/// # Errors
512///
513/// Same as [`download_with_plan()`].
514pub fn download_with_plan_blocking(
515 plan: &DownloadPlan,
516 config: &FetchConfig,
517) -> Result<DownloadOutcome<PathBuf>, FetchError> {
518 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
519 path: PathBuf::from("<runtime>"),
520 source: e,
521 })?;
522 rt.block_on(download_with_plan(plan, config))
523}