hf_fetch_model/lib.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let outcome = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model at: {}", outcome.inner().display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//! .filter("*.safetensors")
28//! .filter("*.json")
29//! .on_progress(|e| {
30//! println!("{}: {:.1}%", e.filename, e.percent);
31//! })
32//! .build()?;
33//!
34//! let outcome = hf_fetch_model::download_with_config(
35//! "google/gemma-2-2b".to_owned(),
36//! &config,
37//! ).await?;
38//! // outcome.is_cached() tells you if it came from local cache
39//! let path = outcome.into_inner();
40//! # Ok(())
41//! # }
42//! ```
43//!
44//! ## Inspect Before Downloading
45//!
46//! Read tensor metadata from `.safetensors` headers via HTTP Range requests —
47//! no weight data downloaded. See [`examples/candle_inspect.rs`](https://github.com/PCfVW/hf-fetch-model/blob/main/examples/candle_inspect.rs)
48//! for a runnable example.
49//!
50//! ```rust,no_run
51//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
52//! let results = hf_fetch_model::inspect::inspect_repo_safetensors(
53//! "EleutherAI/pythia-1.4b", None, None,
54//! ).await?;
55//!
56//! for (filename, header, _source) in &results {
57//! println!("{filename}: {} tensors", header.tensors.len());
58//! }
59//! # Ok(())
60//! # }
61//! ```
62//!
63//! ## `HuggingFace` Cache
64//!
65//! Downloaded files are stored in the standard `HuggingFace` cache directory
66//! (`~/.cache/huggingface/hub/`), ensuring compatibility with Python tooling.
67//!
68//! ## Authentication
69//!
70//! Set the `HF_TOKEN` environment variable to access private or gated models,
71//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
72
73pub mod cache;
74pub mod cache_layout;
75pub mod checksum;
76mod chunked;
77pub mod config;
78pub mod discover;
79pub mod download;
80pub mod error;
81pub mod inspect;
82pub mod plan;
83pub mod progress;
84pub mod repo;
85mod retry;
86
87pub use chunked::build_client;
88pub use config::{
89 compile_glob_patterns, file_matches, has_glob_chars, FetchConfig, FetchConfigBuilder, Filter,
90};
91pub use discover::{DiscoveredFamily, GateStatus, ModelCardMetadata, SearchResult};
92pub use download::DownloadOutcome;
93pub use error::{FetchError, FileFailure};
94pub use inspect::AdapterConfig;
95pub use plan::{download_plan, DownloadPlan, FilePlan};
96pub use progress::{ProgressEvent, ProgressReceiver};
97
98use std::collections::HashMap;
99use std::path::PathBuf;
100
101use hf_hub::{Repo, RepoType};
102
103/// Pre-flight check for gated model access.
104///
105/// Two cases:
106/// - **No token**: checks the model metadata (unauthenticated) for gating
107/// status and rejects with a clear message if gated.
108/// - **Token present**: if the model is gated, makes one authenticated
109/// metadata request to verify the token actually grants access. Catches
110/// invalid tokens and unaccepted licenses before the download starts.
111///
112/// If the metadata request itself fails (network error, private repo),
113/// the check is silently skipped so that normal download error handling
114/// can take over.
115async fn preflight_gated_check(repo_id: &str, config: &FetchConfig) -> Result<(), FetchError> {
116 // Best-effort: if the metadata call fails, let the download proceed.
117 let Ok(metadata) = discover::fetch_model_card(repo_id).await else {
118 return Ok(());
119 };
120
121 if !metadata.gated.is_gated() {
122 return Ok(());
123 }
124
125 // Model is gated — check auth.
126 if config.token.is_none() {
127 return Err(FetchError::Auth {
128 reason: format!(
129 "{repo_id} is a gated model — accept the license at \
130 https://huggingface.co/{repo_id} and set HF_TOKEN or pass --token"
131 ),
132 });
133 }
134
135 // Token is present — verify it grants access with a lightweight probe.
136 let probe_client = chunked::build_client(config.token.as_deref())?;
137 let probe = repo::list_repo_files_with_metadata(
138 repo_id,
139 config.token.as_deref(),
140 config.revision.as_deref(),
141 &probe_client,
142 )
143 .await;
144
145 if let Err(ref e) = probe {
146 // BORROW: explicit .to_string() for error Display formatting
147 let msg = e.to_string();
148 if msg.contains("401") || msg.contains("403") {
149 return Err(FetchError::Auth {
150 reason: format!(
151 "{repo_id} is a gated model and your token was rejected — \
152 accept the license at https://huggingface.co/{repo_id} \
153 and check that your token is valid"
154 ),
155 });
156 }
157 }
158
159 Ok(())
160}
161
162/// Downloads all files from a `HuggingFace` model repository.
163///
164/// Uses high-throughput mode for maximum download speed, including
165/// auto-tuned concurrency, chunked multi-connection downloads for large
166/// files, and plan-optimized settings based on file size distribution.
167/// Files are stored in the standard `HuggingFace` cache layout
168/// (`~/.cache/huggingface/hub/`).
169///
170/// Authentication is handled via the `HF_TOKEN` environment variable when set.
171///
172/// For filtering, progress, and other options, use [`download_with_config()`].
173///
174/// # Arguments
175///
176/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
177///
178/// # Returns
179///
180/// The path to the snapshot directory containing all downloaded files.
181///
182/// # Errors
183///
184/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
185/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
186/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
187/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
188pub async fn download(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
189 let config = FetchConfig::builder().build()?;
190 download_with_config(repo_id, &config).await
191}
192
193/// Downloads files from a `HuggingFace` model repository using the given configuration.
194///
195/// Supports filtering, progress reporting, custom revision, authentication,
196/// and concurrency settings via [`FetchConfig`].
197///
198/// # Arguments
199///
200/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
201/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
202///
203/// # Returns
204///
205/// The path to the snapshot directory containing all downloaded files.
206///
207/// # Errors
208///
209/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
210/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
211/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
212pub async fn download_with_config(
213 repo_id: String,
214 config: &FetchConfig,
215) -> Result<DownloadOutcome<PathBuf>, FetchError> {
216 // BORROW: explicit .as_str() instead of Deref coercion
217 preflight_gated_check(repo_id.as_str(), config).await?;
218
219 let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
220
221 if let Some(ref token) = config.token {
222 // BORROW: explicit .clone() to pass owned String
223 builder = builder.with_token(Some(token.clone()));
224 }
225
226 if let Some(ref dir) = config.output_dir {
227 // BORROW: explicit .clone() for owned PathBuf
228 builder = builder.with_cache_dir(dir.clone());
229 }
230
231 let api = builder.build().map_err(FetchError::Api)?;
232
233 let hf_repo = match config.revision {
234 Some(ref rev) => {
235 // BORROW: explicit .clone() for owned String arguments
236 Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
237 }
238 None => Repo::new(repo_id.clone(), RepoType::Model),
239 };
240
241 let repo = api.repo(hf_repo);
242 download::download_all_files(repo, repo_id, Some(config)).await
243}
244
245/// Blocking version of [`download()`] for non-async callers.
246///
247/// Creates a Tokio runtime internally. Do not call from within
248/// an existing async context (use [`download()`] instead).
249///
250/// # Errors
251///
252/// Same as [`download()`].
253pub fn download_blocking(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
254 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
255 path: PathBuf::from("<runtime>"),
256 source: e,
257 })?;
258 rt.block_on(download(repo_id))
259}
260
261/// Blocking version of [`download_with_config()`] for non-async callers.
262///
263/// Creates a Tokio runtime internally. Do not call from within
264/// an existing async context (use [`download_with_config()`] instead).
265///
266/// # Errors
267///
268/// Same as [`download_with_config()`].
269pub fn download_with_config_blocking(
270 repo_id: String,
271 config: &FetchConfig,
272) -> Result<DownloadOutcome<PathBuf>, FetchError> {
273 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
274 path: PathBuf::from("<runtime>"),
275 source: e,
276 })?;
277 rt.block_on(download_with_config(repo_id, config))
278}
279
280/// Downloads all files from a `HuggingFace` model repository and returns
281/// a filename → path map.
282///
283/// Each key is the relative filename within the repository (e.g.,
284/// `"config.json"`, `"model.safetensors"`), and each value is the
285/// absolute local path to the downloaded file.
286///
287/// Uses the same high-throughput defaults as [`download()`]: auto-tuned
288/// concurrency and chunked multi-connection downloads for large files.
289///
290/// For filtering, progress, and other options, use
291/// [`download_files_with_config()`].
292///
293/// # Arguments
294///
295/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
296///
297/// # Errors
298///
299/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
300/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
301/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
302pub async fn download_files(
303 repo_id: String,
304) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
305 let config = FetchConfig::builder().build()?;
306 download_files_with_config(repo_id, &config).await
307}
308
309/// Downloads files from a `HuggingFace` model repository using the given
310/// configuration and returns a filename → path map.
311///
312/// Each key is the relative filename within the repository (e.g.,
313/// `"config.json"`, `"model.safetensors"`), and each value is the
314/// absolute local path to the downloaded file.
315///
316/// # Arguments
317///
318/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
319/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
320///
321/// # Errors
322///
323/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
324/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
325/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
326pub async fn download_files_with_config(
327 repo_id: String,
328 config: &FetchConfig,
329) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
330 // BORROW: explicit .as_str() instead of Deref coercion
331 preflight_gated_check(repo_id.as_str(), config).await?;
332
333 let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
334
335 if let Some(ref token) = config.token {
336 // BORROW: explicit .clone() to pass owned String
337 builder = builder.with_token(Some(token.clone()));
338 }
339
340 if let Some(ref dir) = config.output_dir {
341 // BORROW: explicit .clone() for owned PathBuf
342 builder = builder.with_cache_dir(dir.clone());
343 }
344
345 let api = builder.build().map_err(FetchError::Api)?;
346
347 let hf_repo = match config.revision {
348 Some(ref rev) => {
349 // BORROW: explicit .clone() for owned String arguments
350 Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
351 }
352 None => Repo::new(repo_id.clone(), RepoType::Model),
353 };
354
355 let repo = api.repo(hf_repo);
356 download::download_all_files_map(repo, repo_id, Some(config)).await
357}
358
359/// Blocking version of [`download_files()`] for non-async callers.
360///
361/// Creates a Tokio runtime internally. Do not call from within
362/// an existing async context (use [`download_files()`] instead).
363///
364/// # Errors
365///
366/// Same as [`download_files()`].
367pub fn download_files_blocking(
368 repo_id: String,
369) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
370 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
371 path: PathBuf::from("<runtime>"),
372 source: e,
373 })?;
374 rt.block_on(download_files(repo_id))
375}
376
377/// Downloads a single file from a `HuggingFace` model repository.
378///
379/// Returns the local cache path. If the file is already cached (and
380/// checksums match when `verify_checksums` is enabled), the download
381/// is skipped and the cached path is returned immediately.
382///
383/// Files at or above [`FetchConfig`]'s `chunk_threshold` (auto-tuned by
384/// the download plan optimizer, or 100 MiB fallback) are downloaded using
385/// multiple parallel HTTP Range connections (`connections_per_file`,
386/// auto-tuned or 8 fallback). Smaller files use a single connection.
387///
388/// # Arguments
389///
390/// * `repo_id` — Repository identifier (e.g., `"mntss/clt-gemma-2-2b-426k"`).
391/// * `filename` — Exact filename within the repository (e.g., `"W_enc_5.safetensors"`).
392/// * `config` — Shared configuration for auth, progress, checksums, retries, and chunking.
393///
394/// # Errors
395///
396/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
397/// * [`FetchError::Http`] — if the file does not exist in the repository.
398/// * [`FetchError::Api`] — on download failure (after retries).
399/// * [`FetchError::Checksum`] — if verification is enabled and fails.
400pub async fn download_file(
401 repo_id: String,
402 filename: &str,
403 config: &FetchConfig,
404) -> Result<DownloadOutcome<PathBuf>, FetchError> {
405 // BORROW: explicit .as_str() instead of Deref coercion
406 preflight_gated_check(repo_id.as_str(), config).await?;
407
408 let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
409
410 if let Some(ref token) = config.token {
411 // BORROW: explicit .clone() to pass owned String
412 builder = builder.with_token(Some(token.clone()));
413 }
414
415 if let Some(ref dir) = config.output_dir {
416 // BORROW: explicit .clone() for owned PathBuf
417 builder = builder.with_cache_dir(dir.clone());
418 }
419
420 let api = builder.build().map_err(FetchError::Api)?;
421
422 let hf_repo = match config.revision {
423 Some(ref rev) => {
424 // BORROW: explicit .clone() for owned String arguments
425 Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
426 }
427 None => Repo::new(repo_id.clone(), RepoType::Model),
428 };
429
430 let repo = api.repo(hf_repo);
431 download::download_file_by_name(repo, repo_id, filename, config).await
432}
433
434/// Blocking version of [`download_file()`] for non-async callers.
435///
436/// Creates a Tokio runtime internally. Do not call from within
437/// an existing async context (use [`download_file()`] instead).
438///
439/// # Errors
440///
441/// Same as [`download_file()`].
442pub fn download_file_blocking(
443 repo_id: String,
444 filename: &str,
445 config: &FetchConfig,
446) -> Result<DownloadOutcome<PathBuf>, FetchError> {
447 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
448 path: PathBuf::from("<runtime>"),
449 source: e,
450 })?;
451 rt.block_on(download_file(repo_id, filename, config))
452}
453
454/// Blocking version of [`download_files_with_config()`] for non-async callers.
455///
456/// Creates a Tokio runtime internally. Do not call from within
457/// an existing async context (use [`download_files_with_config()`] instead).
458///
459/// # Errors
460///
461/// Same as [`download_files_with_config()`].
462pub fn download_files_with_config_blocking(
463 repo_id: String,
464 config: &FetchConfig,
465) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
466 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
467 path: PathBuf::from("<runtime>"),
468 source: e,
469 })?;
470 rt.block_on(download_files_with_config(repo_id, config))
471}
472
473/// Downloads files according to an existing [`DownloadPlan`].
474///
475/// Only uncached files in the plan are downloaded. The `config` controls
476/// authentication, progress, timeouts, and performance settings.
477/// Use [`DownloadPlan::recommended_config()`] to compute an optimized config,
478/// or override specific fields via [`DownloadPlan::recommended_config_builder()`].
479///
480/// # Errors
481///
482/// Returns [`FetchError::Io`] if the cache directory cannot be resolved.
483/// Same error conditions as [`download_with_config()`] for the download itself.
484pub async fn download_with_plan(
485 plan: &DownloadPlan,
486 config: &FetchConfig,
487) -> Result<DownloadOutcome<PathBuf>, FetchError> {
488 if plan.fully_cached() {
489 // Resolve snapshot path from cache and return immediately.
490 let cache_dir = config
491 .output_dir
492 .clone()
493 .map_or_else(cache::hf_cache_dir, Ok)?;
494 let repo_dir = cache_layout::repo_dir(&cache_dir, plan.repo_id.as_str());
495 let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, plan.revision.as_str());
496 return Ok(DownloadOutcome::Cached(snapshot_dir));
497 }
498
499 // Delegate to the standard download path which will re-check cache
500 // internally. The plan's value is the dry-run preview and the
501 // recommended config computed by the caller.
502 // BORROW: explicit .clone() for owned String argument
503 download_with_config(plan.repo_id.clone(), config).await
504}
505
506/// Blocking version of [`download_with_plan()`] for non-async callers.
507///
508/// Creates a Tokio runtime internally. Do not call from within
509/// an existing async context (use [`download_with_plan()`] instead).
510///
511/// # Errors
512///
513/// Same as [`download_with_plan()`].
514pub fn download_with_plan_blocking(
515 plan: &DownloadPlan,
516 config: &FetchConfig,
517) -> Result<DownloadOutcome<PathBuf>, FetchError> {
518 let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
519 path: PathBuf::from("<runtime>"),
520 source: e,
521 })?;
522 rt.block_on(download_with_plan(plan, config))
523}