Skip to main content

hf_fetch_model/
lib.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let outcome = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model at: {}", outcome.inner().display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//!     .filter("*.safetensors")
28//!     .filter("*.json")
29//!     .on_progress(|e| {
30//!         println!("{}: {:.1}%", e.filename, e.percent);
31//!     })
32//!     .build()?;
33//!
34//! let outcome = hf_fetch_model::download_with_config(
35//!     "google/gemma-2-2b".to_owned(),
36//!     &config,
37//! ).await?;
38//! // outcome.is_cached() tells you if it came from local cache
39//! let path = outcome.into_inner();
40//! # Ok(())
41//! # }
42//! ```
43//!
44//! ## `HuggingFace` Cache
45//!
46//! Downloaded files are stored in the standard `HuggingFace` cache directory
47//! (`~/.cache/huggingface/hub/`), ensuring compatibility with Python tooling.
48//!
49//! ## Authentication
50//!
51//! Set the `HF_TOKEN` environment variable to access private or gated models,
52//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
53
54pub mod cache;
55pub mod checksum;
56mod chunked;
57pub mod config;
58pub mod discover;
59pub mod download;
60pub mod error;
61pub mod progress;
62pub mod repo;
63mod retry;
64
65pub use config::{FetchConfig, FetchConfigBuilder, Filter};
66pub use download::DownloadOutcome;
67pub use error::{FetchError, FileFailure};
68pub use progress::ProgressEvent;
69
70use std::collections::HashMap;
71use std::path::PathBuf;
72
73use hf_hub::{Repo, RepoType};
74
75/// Downloads all files from a `HuggingFace` model repository.
76///
77/// Uses high-throughput mode for maximum download speed, including
78/// multi-connection chunked downloads for large files (≥100 MiB by default,
79/// 8 parallel connections per file). Files are stored in the standard
80/// `HuggingFace` cache layout (`~/.cache/huggingface/hub/`).
81///
82/// Authentication is handled via the `HF_TOKEN` environment variable when set.
83///
84/// For filtering, progress, and other options, use [`download_with_config()`].
85///
86/// # Arguments
87///
88/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
89///
90/// # Returns
91///
92/// The path to the snapshot directory containing all downloaded files.
93///
94/// # Errors
95///
96/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
97/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
98/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
99pub async fn download(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
100    let config = FetchConfig::builder().build()?;
101    download_with_config(repo_id, &config).await
102}
103
104/// Downloads files from a `HuggingFace` model repository using the given configuration.
105///
106/// Supports filtering, progress reporting, custom revision, authentication,
107/// and concurrency settings via [`FetchConfig`].
108///
109/// # Arguments
110///
111/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
112/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
113///
114/// # Returns
115///
116/// The path to the snapshot directory containing all downloaded files.
117///
118/// # Errors
119///
120/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
121/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
122pub async fn download_with_config(
123    repo_id: String,
124    config: &FetchConfig,
125) -> Result<DownloadOutcome<PathBuf>, FetchError> {
126    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
127
128    if let Some(ref token) = config.token {
129        // BORROW: explicit .clone() to pass owned String
130        builder = builder.with_token(Some(token.clone()));
131    }
132
133    if let Some(ref dir) = config.output_dir {
134        // BORROW: explicit .clone() for owned PathBuf
135        builder = builder.with_cache_dir(dir.clone());
136    }
137
138    let api = builder.build().map_err(FetchError::Api)?;
139
140    let hf_repo = match config.revision {
141        Some(ref rev) => {
142            // BORROW: explicit .clone() for owned String arguments
143            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
144        }
145        None => Repo::new(repo_id.clone(), RepoType::Model),
146    };
147
148    let repo = api.repo(hf_repo);
149    download::download_all_files(repo, repo_id, Some(config)).await
150}
151
152/// Blocking version of [`download()`] for non-async callers.
153///
154/// Creates a Tokio runtime internally. Do not call from within
155/// an existing async context (use [`download()`] instead).
156///
157/// # Errors
158///
159/// Same as [`download()`].
160pub fn download_blocking(repo_id: String) -> Result<DownloadOutcome<PathBuf>, FetchError> {
161    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
162        path: PathBuf::from("<runtime>"),
163        source: e,
164    })?;
165    rt.block_on(download(repo_id))
166}
167
168/// Blocking version of [`download_with_config()`] for non-async callers.
169///
170/// Creates a Tokio runtime internally. Do not call from within
171/// an existing async context (use [`download_with_config()`] instead).
172///
173/// # Errors
174///
175/// Same as [`download_with_config()`].
176pub fn download_with_config_blocking(
177    repo_id: String,
178    config: &FetchConfig,
179) -> Result<DownloadOutcome<PathBuf>, FetchError> {
180    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
181        path: PathBuf::from("<runtime>"),
182        source: e,
183    })?;
184    rt.block_on(download_with_config(repo_id, config))
185}
186
187/// Downloads all files from a `HuggingFace` model repository and returns
188/// a filename → path map.
189///
190/// Each key is the relative filename within the repository (e.g.,
191/// `"config.json"`, `"model.safetensors"`), and each value is the
192/// absolute local path to the downloaded file.
193///
194/// Uses the same high-throughput defaults as [`download()`]: multi-connection
195/// chunked downloads for large files (≥100 MiB, 8 parallel connections).
196///
197/// For filtering, progress, and other options, use
198/// [`download_files_with_config()`].
199///
200/// # Arguments
201///
202/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
203///
204/// # Errors
205///
206/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
207/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
208/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
209pub async fn download_files(
210    repo_id: String,
211) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
212    let config = FetchConfig::builder().build()?;
213    download_files_with_config(repo_id, &config).await
214}
215
216/// Downloads files from a `HuggingFace` model repository using the given
217/// configuration and returns a filename → path map.
218///
219/// Each key is the relative filename within the repository (e.g.,
220/// `"config.json"`, `"model.safetensors"`), and each value is the
221/// absolute local path to the downloaded file.
222///
223/// # Arguments
224///
225/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
226/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
227///
228/// # Errors
229///
230/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
231/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
232pub async fn download_files_with_config(
233    repo_id: String,
234    config: &FetchConfig,
235) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
236    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
237
238    if let Some(ref token) = config.token {
239        // BORROW: explicit .clone() to pass owned String
240        builder = builder.with_token(Some(token.clone()));
241    }
242
243    if let Some(ref dir) = config.output_dir {
244        // BORROW: explicit .clone() for owned PathBuf
245        builder = builder.with_cache_dir(dir.clone());
246    }
247
248    let api = builder.build().map_err(FetchError::Api)?;
249
250    let hf_repo = match config.revision {
251        Some(ref rev) => {
252            // BORROW: explicit .clone() for owned String arguments
253            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
254        }
255        None => Repo::new(repo_id.clone(), RepoType::Model),
256    };
257
258    let repo = api.repo(hf_repo);
259    download::download_all_files_map(repo, repo_id, Some(config)).await
260}
261
262/// Blocking version of [`download_files()`] for non-async callers.
263///
264/// Creates a Tokio runtime internally. Do not call from within
265/// an existing async context (use [`download_files()`] instead).
266///
267/// # Errors
268///
269/// Same as [`download_files()`].
270pub fn download_files_blocking(
271    repo_id: String,
272) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
273    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
274        path: PathBuf::from("<runtime>"),
275        source: e,
276    })?;
277    rt.block_on(download_files(repo_id))
278}
279
280/// Downloads a single file from a `HuggingFace` model repository.
281///
282/// Returns the local cache path. If the file is already cached (and
283/// checksums match when `verify_checksums` is enabled), the download
284/// is skipped and the cached path is returned immediately.
285///
286/// Files at or above [`FetchConfig`]'s `chunk_threshold` (default 100 MiB)
287/// are downloaded using multiple parallel HTTP Range connections
288/// (`connections_per_file`, default 8). Smaller files use a single
289/// connection.
290///
291/// # Arguments
292///
293/// * `repo_id` — Repository identifier (e.g., `"mntss/clt-gemma-2-2b-426k"`).
294/// * `filename` — Exact filename within the repository (e.g., `"W_enc_5.safetensors"`).
295/// * `config` — Shared configuration for auth, progress, checksums, retries, and chunking.
296///
297/// # Errors
298///
299/// * [`FetchError::Http`] — if the file does not exist in the repository.
300/// * [`FetchError::Api`] — on download failure (after retries).
301/// * [`FetchError::Checksum`] — if verification is enabled and fails.
302pub async fn download_file(
303    repo_id: String,
304    filename: &str,
305    config: &FetchConfig,
306) -> Result<DownloadOutcome<PathBuf>, FetchError> {
307    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
308
309    if let Some(ref token) = config.token {
310        // BORROW: explicit .clone() to pass owned String
311        builder = builder.with_token(Some(token.clone()));
312    }
313
314    if let Some(ref dir) = config.output_dir {
315        // BORROW: explicit .clone() for owned PathBuf
316        builder = builder.with_cache_dir(dir.clone());
317    }
318
319    let api = builder.build().map_err(FetchError::Api)?;
320
321    let hf_repo = match config.revision {
322        Some(ref rev) => {
323            // BORROW: explicit .clone() for owned String arguments
324            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
325        }
326        None => Repo::new(repo_id.clone(), RepoType::Model),
327    };
328
329    let repo = api.repo(hf_repo);
330    download::download_file_by_name(repo, repo_id, filename, config).await
331}
332
333/// Blocking version of [`download_file()`] for non-async callers.
334///
335/// Creates a Tokio runtime internally. Do not call from within
336/// an existing async context (use [`download_file()`] instead).
337///
338/// # Errors
339///
340/// Same as [`download_file()`].
341pub fn download_file_blocking(
342    repo_id: String,
343    filename: &str,
344    config: &FetchConfig,
345) -> Result<DownloadOutcome<PathBuf>, FetchError> {
346    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
347        path: PathBuf::from("<runtime>"),
348        source: e,
349    })?;
350    rt.block_on(download_file(repo_id, filename, config))
351}
352
353/// Blocking version of [`download_files_with_config()`] for non-async callers.
354///
355/// Creates a Tokio runtime internally. Do not call from within
356/// an existing async context (use [`download_files_with_config()`] instead).
357///
358/// # Errors
359///
360/// Same as [`download_files_with_config()`].
361pub fn download_files_with_config_blocking(
362    repo_id: String,
363    config: &FetchConfig,
364) -> Result<DownloadOutcome<HashMap<String, PathBuf>>, FetchError> {
365    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
366        path: PathBuf::from("<runtime>"),
367        source: e,
368    })?;
369    rt.block_on(download_files_with_config(repo_id, config))
370}