Skip to main content

hf_fetch_model/
lib.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let outcome = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model at: {}", outcome.inner().display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//!     .filter("*.safetensors")
28//!     .filter("*.json")
29//!     .on_progress(|e| {
30//!         println!("{}: {:.1}%", e.filename, e.percent);
31//!     })
32//!     .build()?;
33//!
34//! let outcome = hf_fetch_model::download_with_config(
35//!     "google/gemma-2-2b".to_owned(),
36//!     &config,
37//! ).await?;
38//! // outcome.is_cached() tells you if it came from local cache
39//! let path = outcome.into_inner();
40//! # Ok(())
41//! # }
42//! ```
43//!
44//! ## `HuggingFace` Cache
45//!
46//! Downloaded files are stored in the standard `HuggingFace` cache directory
47//! (`~/.cache/huggingface/hub/`), ensuring compatibility with Python tooling.
48//!
49//! ## Authentication
50//!
51//! Set the `HF_TOKEN` environment variable to access private or gated models,
52//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
53
54pub mod cache;
55pub mod checksum;
56mod chunked;
57pub mod config;
58pub mod discover;
59pub mod download;
60pub mod error;
61pub mod progress;
62pub mod repo;
63mod retry;
64
65pub use config::{FetchConfig, FetchConfigBuilder, Filter};
66pub use discover::{GateStatus, ModelCardMetadata, SearchResult};
67pub use download::DownloadOutcome;
68pub use error::{FetchError, FileFailure};
69pub use progress::ProgressEvent;
70
71use std::collections::HashMap;
72use std::path::PathBuf;
73
74use hf_hub::{Repo, RepoType};
75
76/// Downloads all files from a `HuggingFace` model repository.
77///
78/// Uses high-throughput mode for maximum download speed, including
79/// multi-connection chunked downloads for large files (≥100 MiB by default,
80/// 8 parallel connections per file). Files are stored in the standard
81/// `HuggingFace` cache layout (`~/.cache/huggingface/hub/`).
82///
83/// Authentication is handled via the `HF_TOKEN` environment variable when set.
84///
85/// For filtering, progress, and other options, use [`download_with_config()`].
86///
87/// # Arguments
88///
89/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
90///
91/// # Returns
92///
93/// The path to the snapshot directory containing all downloaded files.
94///
95/// # Errors
96///
97/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
98/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
99/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
100pub async fn download(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
101    let config = FetchConfig::builder().build()?;
102    download_with_config(repo_id, &config).await
103}
104
105/// Downloads files from a `HuggingFace` model repository using the given configuration.
106///
107/// Supports filtering, progress reporting, custom revision, authentication,
108/// and concurrency settings via [`FetchConfig`].
109///
110/// # Arguments
111///
112/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
113/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
114///
115/// # Returns
116///
117/// The path to the snapshot directory containing all downloaded files.
118///
119/// # Errors
120///
121/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
122/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
123pub async fn download_with_config(
124    repo_id: String,
125    config: &FetchConfig,
126) -> Result<DownloadOutcome<PathBuf>, FetchError> {
127    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
128
129    if let Some(ref token) = config.token {
130        // BORROW: explicit .clone() to pass owned String
131        builder = builder.with_token(Some(token.clone()));
132    }
133
134    if let Some(ref dir) = config.output_dir {
135        // BORROW: explicit .clone() for owned PathBuf
136        builder = builder.with_cache_dir(dir.clone());
137    }
138
139    let api = builder.build().map_err(FetchError::Api)?;
140
141    let hf_repo = match config.revision {
142        Some(ref rev) => {
143            // BORROW: explicit .clone() for owned String arguments
144            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
145        }
146        None => Repo::new(repo_id.clone(), RepoType::Model),
147    };
148
149    let repo = api.repo(hf_repo);
150    download::download_all_files(repo, repo_id, Some(config)).await
151}
152
153/// Blocking version of [`download()`] for non-async callers.
154///
155/// Creates a Tokio runtime internally. Do not call from within
156/// an existing async context (use [`download()`] instead).
157///
158/// # Errors
159///
160/// Same as [`download()`].
161pub fn download_blocking(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
162    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
163        path: PathBuf::from("<runtime>"),
164        source: e,
165    })?;
166    rt.block_on(download(repo_id))
167}
168
169/// Blocking version of [`download_with_config()`] for non-async callers.
170///
171/// Creates a Tokio runtime internally. Do not call from within
172/// an existing async context (use [`download_with_config()`] instead).
173///
174/// # Errors
175///
176/// Same as [`download_with_config()`].
177pub fn download_with_config_blocking(
178    repo_id: String,
179    config: &FetchConfig,
180) -> Result<DownloadOutcome<PathBuf>, FetchError> {
181    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
182        path: PathBuf::from("<runtime>"),
183        source: e,
184    })?;
185    rt.block_on(download_with_config(repo_id, config))
186}
187
188/// Downloads all files from a `HuggingFace` model repository and returns
189/// a filename → path map.
190///
191/// Each key is the relative filename within the repository (e.g.,
192/// `"config.json"`, `"model.safetensors"`), and each value is the
193/// absolute local path to the downloaded file.
194///
195/// Uses the same high-throughput defaults as [`download()`]: multi-connection
196/// chunked downloads for large files (≥100 MiB, 8 parallel connections).
197///
198/// For filtering, progress, and other options, use
199/// [`download_files_with_config()`].
200///
201/// # Arguments
202///
203/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
204///
205/// # Errors
206///
207/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
208/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
209/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
210pub async fn download_files(
211    repo_id: String,
212) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
213    let config = FetchConfig::builder().build()?;
214    download_files_with_config(repo_id, &config).await
215}
216
217/// Downloads files from a `HuggingFace` model repository using the given
218/// configuration and returns a filename → path map.
219///
220/// Each key is the relative filename within the repository (e.g.,
221/// `"config.json"`, `"model.safetensors"`), and each value is the
222/// absolute local path to the downloaded file.
223///
224/// # Arguments
225///
226/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
227/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
228///
229/// # Errors
230///
231/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
232/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
233pub async fn download_files_with_config(
234    repo_id: String,
235    config: &FetchConfig,
236) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
237    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
238
239    if let Some(ref token) = config.token {
240        // BORROW: explicit .clone() to pass owned String
241        builder = builder.with_token(Some(token.clone()));
242    }
243
244    if let Some(ref dir) = config.output_dir {
245        // BORROW: explicit .clone() for owned PathBuf
246        builder = builder.with_cache_dir(dir.clone());
247    }
248
249    let api = builder.build().map_err(FetchError::Api)?;
250
251    let hf_repo = match config.revision {
252        Some(ref rev) => {
253            // BORROW: explicit .clone() for owned String arguments
254            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
255        }
256        None => Repo::new(repo_id.clone(), RepoType::Model),
257    };
258
259    let repo = api.repo(hf_repo);
260    download::download_all_files_map(repo, repo_id, Some(config)).await
261}
262
263/// Blocking version of [`download_files()`] for non-async callers.
264///
265/// Creates a Tokio runtime internally. Do not call from within
266/// an existing async context (use [`download_files()`] instead).
267///
268/// # Errors
269///
270/// Same as [`download_files()`].
271pub fn download_files_blocking(
272    repo_id: String,
273) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
274    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
275        path: PathBuf::from("<runtime>"),
276        source: e,
277    })?;
278    rt.block_on(download_files(repo_id))
279}
280
281/// Downloads a single file from a `HuggingFace` model repository.
282///
283/// Returns the local cache path. If the file is already cached (and
284/// checksums match when `verify_checksums` is enabled), the download
285/// is skipped and the cached path is returned immediately.
286///
287/// Files at or above [`FetchConfig`]'s `chunk_threshold` (default 100 MiB)
288/// are downloaded using multiple parallel HTTP Range connections
289/// (`connections_per_file`, default 8). Smaller files use a single
290/// connection.
291///
292/// # Arguments
293///
294/// * `repo_id` — Repository identifier (e.g., `"mntss/clt-gemma-2-2b-426k"`).
295/// * `filename` — Exact filename within the repository (e.g., `"W_enc_5.safetensors"`).
296/// * `config` — Shared configuration for auth, progress, checksums, retries, and chunking.
297///
298/// # Errors
299///
300/// * [`FetchError::Http`] — if the file does not exist in the repository.
301/// * [`FetchError::Api`] — on download failure (after retries).
302/// * [`FetchError::Checksum`] — if verification is enabled and fails.
303pub async fn download_file(
304    repo_id: String,
305    filename: &str,
306    config: &FetchConfig,
307) -> Result<DownloadOutcome<PathBuf>, FetchError> {
308    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
309
310    if let Some(ref token) = config.token {
311        // BORROW: explicit .clone() to pass owned String
312        builder = builder.with_token(Some(token.clone()));
313    }
314
315    if let Some(ref dir) = config.output_dir {
316        // BORROW: explicit .clone() for owned PathBuf
317        builder = builder.with_cache_dir(dir.clone());
318    }
319
320    let api = builder.build().map_err(FetchError::Api)?;
321
322    let hf_repo = match config.revision {
323        Some(ref rev) => {
324            // BORROW: explicit .clone() for owned String arguments
325            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
326        }
327        None => Repo::new(repo_id.clone(), RepoType::Model),
328    };
329
330    let repo = api.repo(hf_repo);
331    download::download_file_by_name(repo, repo_id, filename, config).await
332}
333
334/// Blocking version of [`download_file()`] for non-async callers.
335///
336/// Creates a Tokio runtime internally. Do not call from within
337/// an existing async context (use [`download_file()`] instead).
338///
339/// # Errors
340///
341/// Same as [`download_file()`].
342pub fn download_file_blocking(
343    repo_id: String,
344    filename: &str,
345    config: &FetchConfig,
346) -> Result<DownloadOutcome<PathBuf>, FetchError> {
347    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
348        path: PathBuf::from("<runtime>"),
349        source: e,
350    })?;
351    rt.block_on(download_file(repo_id, filename, config))
352}
353
354/// Blocking version of [`download_files_with_config()`] for non-async callers.
355///
356/// Creates a Tokio runtime internally. Do not call from within
357/// an existing async context (use [`download_files_with_config()`] instead).
358///
359/// # Errors
360///
361/// Same as [`download_files_with_config()`].
362pub fn download_files_with_config_blocking(
363    repo_id: String,
364    config: &FetchConfig,
365) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
366    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
367        path: PathBuf::from("<runtime>"),
368        source: e,
369    })?;
370    rt.block_on(download_files_with_config(repo_id, config))
371}