Skip to main content

hf_fetch_model/
lib.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let outcome = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model at: {}", outcome.inner().display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//!     .filter("*.safetensors")
28//!     .filter("*.json")
29//!     .on_progress(|e| {
30//!         println!("{}: {:.1}%", e.filename, e.percent);
31//!     })
32//!     .build()?;
33//!
34//! let outcome = hf_fetch_model::download_with_config(
35//!     "google/gemma-2-2b".to_owned(),
36//!     &config,
37//! ).await?;
38//! // outcome.is_cached() tells you if it came from local cache
39//! let path = outcome.into_inner();
40//! # Ok(())
41//! # }
42//! ```
43//!
44//! ## Inspect Before Downloading
45//!
46//! Read tensor metadata from `.safetensors` headers via HTTP Range requests —
47//! no weight data downloaded. See [`examples/candle_inspect.rs`](https://github.com/PCfVW/hf-fetch-model/blob/main/examples/candle_inspect.rs)
48//! for a runnable example.
49//!
50//! ```rust,no_run
51//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
52//! let results = hf_fetch_model::inspect::inspect_repo_safetensors(
53//!     "EleutherAI/pythia-1.4b", None, None,
54//! ).await?;
55//!
56//! for (filename, header, _source) in &results {
57//!     println!("{filename}: {} tensors", header.tensors.len());
58//! }
59//! # Ok(())
60//! # }
61//! ```
62//!
63//! ## `HuggingFace` Cache
64//!
65//! Downloaded files are stored in the standard `HuggingFace` cache directory
66//! (`~/.cache/huggingface/hub/`), ensuring compatibility with Python tooling.
67//!
68//! ## Download Durability
69//!
70//! Multi-connection downloads survive interruption. When a download is
71//! aborted by [`FetchConfigBuilder::timeout_per_file`] (default 300 s),
72//! Ctrl-C, panic, or a transient chunk error, the partial `.chunked.part`
73//! file plus a small per-chunk progress sidecar are kept on disk. The next
74//! call to [`download_with_config`] for the same file picks up where it
75//! stopped — each parallel chunk sends a fresh `Range` request that skips
76//! the bytes it already has — provided the upstream etag still matches.
77//! On etag change, schema-version mismatch, or a different
78//! [`FetchConfigBuilder::connections_per_file`] count, the partial is
79//! discarded and a fresh download starts.
80//!
81//! For slow connections on multi-GiB files, raise the per-file budget to
82//! match real throughput:
83//!
84//! ```rust,no_run
85//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
86//! use std::time::Duration;
87//! use hf_fetch_model::FetchConfig;
88//!
89//! let config = FetchConfig::builder()
90//!     .timeout_per_file(Duration::from_secs(1800))
91//!     .build()?;
92//! # let _ = hf_fetch_model::download_with_config(
93//! #     "google/gemma-4-E2B-it".to_owned(),
94//! #     &config,
95//! # ).await?;
96//! # Ok(())
97//! # }
98//! ```
99//!
100//! ## Authentication
101//!
102//! Set the `HF_TOKEN` environment variable to access private or gated models,
103//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
104
105pub mod cache;
106pub mod cache_layout;
107pub mod checksum;
108mod chunked;
109mod chunked_state;
110pub mod config;
111pub mod discover;
112pub mod download;
113pub mod error;
114pub mod inspect;
115pub mod plan;
116pub mod progress;
117pub mod repo;
118mod retry;
119
120pub use chunked::build_client;
121pub use config::{
122    compile_glob_patterns, file_matches, has_glob_chars, FetchConfig, FetchConfigBuilder, Filter,
123};
124pub use discover::{DiscoveredFamily, GateStatus, ModelCardMetadata, SearchResult};
125pub use download::DownloadOutcome;
126pub use error::{FetchError, FileFailure};
127pub use inspect::AdapterConfig;
128pub use plan::{download_plan, DownloadPlan, FilePlan};
129pub use progress::{ProgressEvent, ProgressReceiver};
130
131use std::collections::HashMap;
132use std::path::PathBuf;
133
134use hf_hub::{Repo, RepoType};
135
136/// Pre-flight check for gated model access.
137///
138/// Two cases:
139/// - **No token**: checks the model metadata (unauthenticated) for gating
140///   status and rejects with a clear message if gated.
141/// - **Token present**: if the model is gated, makes one authenticated
142///   metadata request to verify the token actually grants access. Catches
143///   invalid tokens and unaccepted licenses before the download starts.
144///
145/// If the metadata request itself fails (network error, private repo),
146/// the check is silently skipped so that normal download error handling
147/// can take over.
148async fn preflight_gated_check(repo_id: &str, config: &FetchConfig) -> Result<(), FetchError> {
149    // Best-effort: if the metadata call fails, let the download proceed.
150    let Ok(metadata) = discover::fetch_model_card(repo_id).await else {
151        return Ok(());
152    };
153
154    if !metadata.gated.is_gated() {
155        return Ok(());
156    }
157
158    // Model is gated — check auth.
159    if config.token.is_none() {
160        return Err(FetchError::Auth {
161            reason: format!(
162                "{repo_id} is a gated model — accept the license at \
163                 https://huggingface.co/{repo_id} and set HF_TOKEN or pass --token"
164            ),
165        });
166    }
167
168    // Token is present — verify it grants access with a lightweight probe.
169    let probe_client = chunked::build_client(config.token.as_deref())?;
170    let probe = repo::list_repo_files_with_metadata(
171        repo_id,
172        config.token.as_deref(),
173        config.revision.as_deref(),
174        &probe_client,
175    )
176    .await;
177
178    if let Err(ref e) = probe {
179        // BORROW: explicit .to_string() for error Display formatting
180        let msg = e.to_string();
181        if msg.contains("401") || msg.contains("403") {
182            return Err(FetchError::Auth {
183                reason: format!(
184                    "{repo_id} is a gated model and your token was rejected — \
185                     accept the license at https://huggingface.co/{repo_id} \
186                     and check that your token is valid"
187                ),
188            });
189        }
190    }
191
192    Ok(())
193}
194
195/// Downloads all files from a `HuggingFace` model repository.
196///
197/// Uses high-throughput mode for maximum download speed, including
198/// auto-tuned concurrency, chunked multi-connection downloads for large
199/// files, and plan-optimized settings based on file size distribution.
200/// Files are stored in the standard `HuggingFace` cache layout
201/// (`~/.cache/huggingface/hub/`).
202///
203/// Authentication is handled via the `HF_TOKEN` environment variable when set.
204///
205/// For filtering, progress, and other options, use [`download_with_config()`].
206///
207/// # Arguments
208///
209/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
210///
211/// # Returns
212///
213/// The path to the snapshot directory containing all downloaded files.
214///
215/// # Errors
216///
217/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
218/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
219/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
220/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
221pub async fn download(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
222    let config = FetchConfig::builder().build()?;
223    download_with_config(repo_id, &config).await
224}
225
226/// Downloads files from a `HuggingFace` model repository using the given configuration.
227///
228/// Supports filtering, progress reporting, custom revision, authentication,
229/// and concurrency settings via [`FetchConfig`].
230///
231/// # Arguments
232///
233/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
234/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
235///
236/// # Returns
237///
238/// The path to the snapshot directory containing all downloaded files.
239///
240/// # Errors
241///
242/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
243/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
244/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
245pub async fn download_with_config(
246    repo_id: String,
247    config: &FetchConfig,
248) -> Result<DownloadOutcome<PathBuf>, FetchError> {
249    // BORROW: explicit .as_str() instead of Deref coercion
250    preflight_gated_check(repo_id.as_str(), config).await?;
251
252    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
253
254    if let Some(ref token) = config.token {
255        // BORROW: explicit .clone() to pass owned String
256        builder = builder.with_token(Some(token.clone()));
257    }
258
259    if let Some(ref dir) = config.output_dir {
260        // BORROW: explicit .clone() for owned PathBuf
261        builder = builder.with_cache_dir(dir.clone());
262    }
263
264    let api = builder.build().map_err(FetchError::Api)?;
265
266    let hf_repo = match config.revision {
267        Some(ref rev) => {
268            // BORROW: explicit .clone() for owned String arguments
269            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
270        }
271        None => Repo::new(repo_id.clone(), RepoType::Model),
272    };
273
274    let repo = api.repo(hf_repo);
275    download::download_all_files(repo, repo_id, Some(config)).await
276}
277
278/// Blocking version of [`download()`] for non-async callers.
279///
280/// Creates a Tokio runtime internally. Do not call from within
281/// an existing async context (use [`download()`] instead).
282///
283/// # Errors
284///
285/// Same as [`download()`].
286pub fn download_blocking(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
287    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
288        path: PathBuf::from("<runtime>"),
289        source: e,
290    })?;
291    rt.block_on(download(repo_id))
292}
293
294/// Blocking version of [`download_with_config()`] for non-async callers.
295///
296/// Creates a Tokio runtime internally. Do not call from within
297/// an existing async context (use [`download_with_config()`] instead).
298///
299/// # Errors
300///
301/// Same as [`download_with_config()`].
302pub fn download_with_config_blocking(
303    repo_id: String,
304    config: &FetchConfig,
305) -> Result<DownloadOutcome<PathBuf>, FetchError> {
306    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
307        path: PathBuf::from("<runtime>"),
308        source: e,
309    })?;
310    rt.block_on(download_with_config(repo_id, config))
311}
312
313/// Downloads all files from a `HuggingFace` model repository and returns
314/// a filename → path map.
315///
316/// Each key is the relative filename within the repository (e.g.,
317/// `"config.json"`, `"model.safetensors"`), and each value is the
318/// absolute local path to the downloaded file.
319///
320/// Uses the same high-throughput defaults as [`download()`]: auto-tuned
321/// concurrency and chunked multi-connection downloads for large files.
322///
323/// For filtering, progress, and other options, use
324/// [`download_files_with_config()`].
325///
326/// # Arguments
327///
328/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
329///
330/// # Errors
331///
332/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
333/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
334/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
335pub async fn download_files(
336    repo_id: String,
337) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
338    let config = FetchConfig::builder().build()?;
339    download_files_with_config(repo_id, &config).await
340}
341
342/// Downloads files from a `HuggingFace` model repository using the given
343/// configuration and returns a filename → path map.
344///
345/// Each key is the relative filename within the repository (e.g.,
346/// `"config.json"`, `"model.safetensors"`), and each value is the
347/// absolute local path to the downloaded file.
348///
349/// # Arguments
350///
351/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
352/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
353///
354/// # Errors
355///
356/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
357/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
358/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
359pub async fn download_files_with_config(
360    repo_id: String,
361    config: &FetchConfig,
362) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
363    // BORROW: explicit .as_str() instead of Deref coercion
364    preflight_gated_check(repo_id.as_str(), config).await?;
365
366    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
367
368    if let Some(ref token) = config.token {
369        // BORROW: explicit .clone() to pass owned String
370        builder = builder.with_token(Some(token.clone()));
371    }
372
373    if let Some(ref dir) = config.output_dir {
374        // BORROW: explicit .clone() for owned PathBuf
375        builder = builder.with_cache_dir(dir.clone());
376    }
377
378    let api = builder.build().map_err(FetchError::Api)?;
379
380    let hf_repo = match config.revision {
381        Some(ref rev) => {
382            // BORROW: explicit .clone() for owned String arguments
383            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
384        }
385        None => Repo::new(repo_id.clone(), RepoType::Model),
386    };
387
388    let repo = api.repo(hf_repo);
389    download::download_all_files_map(repo, repo_id, Some(config)).await
390}
391
392/// Blocking version of [`download_files()`] for non-async callers.
393///
394/// Creates a Tokio runtime internally. Do not call from within
395/// an existing async context (use [`download_files()`] instead).
396///
397/// # Errors
398///
399/// Same as [`download_files()`].
400pub fn download_files_blocking(
401    repo_id: String,
402) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
403    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
404        path: PathBuf::from("<runtime>"),
405        source: e,
406    })?;
407    rt.block_on(download_files(repo_id))
408}
409
410/// Downloads a single file from a `HuggingFace` model repository.
411///
412/// Returns the local cache path. If the file is already cached (and
413/// checksums match when `verify_checksums` is enabled), the download
414/// is skipped and the cached path is returned immediately.
415///
416/// Files at or above [`FetchConfig`]'s `chunk_threshold` (auto-tuned by
417/// the download plan optimizer, or 100 MiB fallback) are downloaded using
418/// multiple parallel HTTP Range connections (`connections_per_file`,
419/// auto-tuned or 8 fallback). Smaller files use a single connection.
420///
421/// # Arguments
422///
423/// * `repo_id` — Repository identifier (e.g., `"mntss/clt-gemma-2-2b-426k"`).
424/// * `filename` — Exact filename within the repository (e.g., `"W_enc_5.safetensors"`).
425/// * `config` — Shared configuration for auth, progress, checksums, retries, and chunking.
426///
427/// # Errors
428///
429/// * [`FetchError::Auth`] — if the repository is gated and access is denied (no token, invalid token, or license not accepted).
430/// * [`FetchError::Http`] — if the file does not exist in the repository.
431/// * [`FetchError::Api`] — on download failure (after retries).
432/// * [`FetchError::Checksum`] — if verification is enabled and fails.
433pub async fn download_file(
434    repo_id: String,
435    filename: &str,
436    config: &FetchConfig,
437) -> Result<DownloadOutcome<PathBuf>, FetchError> {
438    // BORROW: explicit .as_str() instead of Deref coercion
439    preflight_gated_check(repo_id.as_str(), config).await?;
440
441    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
442
443    if let Some(ref token) = config.token {
444        // BORROW: explicit .clone() to pass owned String
445        builder = builder.with_token(Some(token.clone()));
446    }
447
448    if let Some(ref dir) = config.output_dir {
449        // BORROW: explicit .clone() for owned PathBuf
450        builder = builder.with_cache_dir(dir.clone());
451    }
452
453    let api = builder.build().map_err(FetchError::Api)?;
454
455    let hf_repo = match config.revision {
456        Some(ref rev) => {
457            // BORROW: explicit .clone() for owned String arguments
458            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
459        }
460        None => Repo::new(repo_id.clone(), RepoType::Model),
461    };
462
463    let repo = api.repo(hf_repo);
464    download::download_file_by_name(repo, repo_id, filename, config).await
465}
466
467/// Blocking version of [`download_file()`] for non-async callers.
468///
469/// Creates a Tokio runtime internally. Do not call from within
470/// an existing async context (use [`download_file()`] instead).
471///
472/// # Errors
473///
474/// Same as [`download_file()`].
475pub fn download_file_blocking(
476    repo_id: String,
477    filename: &str,
478    config: &FetchConfig,
479) -> Result<DownloadOutcome<PathBuf>, FetchError> {
480    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
481        path: PathBuf::from("<runtime>"),
482        source: e,
483    })?;
484    rt.block_on(download_file(repo_id, filename, config))
485}
486
487/// Blocking version of [`download_files_with_config()`] for non-async callers.
488///
489/// Creates a Tokio runtime internally. Do not call from within
490/// an existing async context (use [`download_files_with_config()`] instead).
491///
492/// # Errors
493///
494/// Same as [`download_files_with_config()`].
495pub fn download_files_with_config_blocking(
496    repo_id: String,
497    config: &FetchConfig,
498) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
499    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
500        path: PathBuf::from("<runtime>"),
501        source: e,
502    })?;
503    rt.block_on(download_files_with_config(repo_id, config))
504}
505
506/// Downloads files according to an existing [`DownloadPlan`].
507///
508/// Only uncached files in the plan are downloaded. The `config` controls
509/// authentication, progress, timeouts, and performance settings.
510/// Use [`DownloadPlan::recommended_config()`] to compute an optimized config,
511/// or override specific fields via [`DownloadPlan::recommended_config_builder()`].
512///
513/// # Errors
514///
515/// Returns [`FetchError::Io`] if the cache directory cannot be resolved.
516/// Same error conditions as [`download_with_config()`] for the download itself.
517pub async fn download_with_plan(
518    plan: &DownloadPlan,
519    config: &FetchConfig,
520) -> Result<DownloadOutcome<PathBuf>, FetchError> {
521    if plan.fully_cached() {
522        // Resolve snapshot path from cache and return immediately.
523        let cache_dir = config
524            .output_dir
525            .clone()
526            .map_or_else(cache::hf_cache_dir, Ok)?;
527        let repo_dir = cache_layout::repo_dir(&cache_dir, plan.repo_id.as_str());
528        let snapshot_dir = cache_layout::snapshot_dir(&repo_dir, plan.revision.as_str());
529        return Ok(DownloadOutcome::Cached(snapshot_dir));
530    }
531
532    // Delegate to the standard download path which will re-check cache
533    // internally. The plan's value is the dry-run preview and the
534    // recommended config computed by the caller.
535    // BORROW: explicit .clone() for owned String argument
536    download_with_config(plan.repo_id.clone(), config).await
537}
538
539/// Blocking version of [`download_with_plan()`] for non-async callers.
540///
541/// Creates a Tokio runtime internally. Do not call from within
542/// an existing async context (use [`download_with_plan()`] instead).
543///
544/// # Errors
545///
546/// Same as [`download_with_plan()`].
547pub fn download_with_plan_blocking(
548    plan: &DownloadPlan,
549    config: &FetchConfig,
550) -> Result<DownloadOutcome<PathBuf>, FetchError> {
551    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
552        path: PathBuf::from("<runtime>"),
553        source: e,
554    })?;
555    rt.block_on(download_with_plan(plan, config))
556}