Skip to main content

hf_fetch_model/
lib.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let path = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model downloaded to: {}", path.display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//!     .filter("*.safetensors")
28//!     .filter("*.json")
29//!     .on_progress(|e| {
30//!         println!("{}: {:.1}%", e.filename, e.percent);
31//!     })
32//!     .build()?;
33//!
34//! let path = hf_fetch_model::download_with_config(
35//!     "google/gemma-2-2b".to_owned(),
36//!     &config,
37//! ).await?;
38//! # Ok(())
39//! # }
40//! ```
41//!
42//! ## `HuggingFace` Cache
43//!
44//! Downloaded files are stored in the standard `HuggingFace` cache directory
45//! (`~/.cache/huggingface/hub/`), ensuring compatibility with Python tooling.
46//!
47//! ## Authentication
48//!
49//! Set the `HF_TOKEN` environment variable to access private or gated models,
50//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
51
52pub mod cache;
53pub mod checksum;
54mod chunked;
55pub mod config;
56pub mod discover;
57pub mod download;
58pub mod error;
59pub mod progress;
60pub mod repo;
61mod retry;
62
63pub use config::{FetchConfig, FetchConfigBuilder, Filter};
64pub use error::{FetchError, FileFailure};
65pub use progress::ProgressEvent;
66
67use std::collections::HashMap;
68use std::path::PathBuf;
69
70use hf_hub::{Repo, RepoType};
71
72/// Downloads all files from a `HuggingFace` model repository.
73///
74/// Uses high-throughput mode for maximum download speed. Files are stored
75/// in the standard `HuggingFace` cache layout (`~/.cache/huggingface/hub/`).
76///
77/// Authentication is handled via the `HF_TOKEN` environment variable when set.
78///
79/// For filtering, progress, and other options, use [`download_with_config()`].
80///
81/// # Arguments
82///
83/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
84///
85/// # Returns
86///
87/// The path to the snapshot directory containing all downloaded files.
88///
89/// # Errors
90///
91/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails.
92/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
93/// * [`FetchError::Auth`] — if authentication is required but fails.
94pub async fn download(repo_id: String) -> Result<PathBuf, FetchError> {
95    let api = hf_hub::api::tokio::ApiBuilder::new()
96        .high()
97        .build()
98        .map_err(FetchError::Api)?;
99
100    let repo = api.model(repo_id.clone());
101    download::download_all_files(repo, repo_id, None).await
102}
103
104/// Downloads files from a `HuggingFace` model repository using the given configuration.
105///
106/// Supports filtering, progress reporting, custom revision, authentication,
107/// and concurrency settings via [`FetchConfig`].
108///
109/// # Arguments
110///
111/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
112/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
113///
114/// # Returns
115///
116/// The path to the snapshot directory containing all downloaded files.
117///
118/// # Errors
119///
120/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails.
121/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
122/// * [`FetchError::Auth`] — if authentication is required but fails.
123pub async fn download_with_config(
124    repo_id: String,
125    config: &FetchConfig,
126) -> Result<PathBuf, FetchError> {
127    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
128
129    if let Some(ref token) = config.token {
130        // BORROW: explicit .clone() to pass owned String
131        builder = builder.with_token(Some(token.clone()));
132    }
133
134    if let Some(ref dir) = config.output_dir {
135        // BORROW: explicit .clone() for owned PathBuf
136        builder = builder.with_cache_dir(dir.clone());
137    }
138
139    let api = builder.build().map_err(FetchError::Api)?;
140
141    let hf_repo = match config.revision {
142        Some(ref rev) => {
143            // BORROW: explicit .clone() for owned String arguments
144            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
145        }
146        None => Repo::new(repo_id.clone(), RepoType::Model),
147    };
148
149    let repo = api.repo(hf_repo);
150    download::download_all_files(repo, repo_id, Some(config)).await
151}
152
153/// Blocking version of [`download()`] for non-async callers.
154///
155/// Creates a Tokio runtime internally. Do not call from within
156/// an existing async context (use [`download()`] instead).
157///
158/// # Errors
159///
160/// Same as [`download()`].
161pub fn download_blocking(repo_id: String) -> Result<PathBuf, FetchError> {
162    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
163        path: PathBuf::from("<runtime>"),
164        source: e,
165    })?;
166    rt.block_on(download(repo_id))
167}
168
169/// Blocking version of [`download_with_config()`] for non-async callers.
170///
171/// Creates a Tokio runtime internally. Do not call from within
172/// an existing async context (use [`download_with_config()`] instead).
173///
174/// # Errors
175///
176/// Same as [`download_with_config()`].
177pub fn download_with_config_blocking(
178    repo_id: String,
179    config: &FetchConfig,
180) -> Result<PathBuf, FetchError> {
181    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
182        path: PathBuf::from("<runtime>"),
183        source: e,
184    })?;
185    rt.block_on(download_with_config(repo_id, config))
186}
187
188/// Downloads all files from a `HuggingFace` model repository and returns
189/// a filename → path map.
190///
191/// Each key is the relative filename within the repository (e.g.,
192/// `"config.json"`, `"model.safetensors"`), and each value is the
193/// absolute local path to the downloaded file.
194///
195/// For filtering, progress, and other options, use
196/// [`download_files_with_config()`].
197///
198/// # Arguments
199///
200/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
201///
202/// # Errors
203///
204/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails.
205/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
206/// * [`FetchError::Auth`] — if authentication is required but fails.
207pub async fn download_files(repo_id: String) -> Result<HashMap<String, PathBuf>, FetchError> {
208    let api = hf_hub::api::tokio::ApiBuilder::new()
209        .high()
210        .build()
211        .map_err(FetchError::Api)?;
212
213    let repo = api.model(repo_id.clone());
214    download::download_all_files_map(repo, repo_id, None).await
215}
216
217/// Downloads files from a `HuggingFace` model repository using the given
218/// configuration and returns a filename → path map.
219///
220/// Each key is the relative filename within the repository (e.g.,
221/// `"config.json"`, `"model.safetensors"`), and each value is the
222/// absolute local path to the downloaded file.
223///
224/// # Arguments
225///
226/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
227/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
228///
229/// # Errors
230///
231/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails.
232/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
233/// * [`FetchError::Auth`] — if authentication is required but fails.
234pub async fn download_files_with_config(
235    repo_id: String,
236    config: &FetchConfig,
237) -> Result<HashMap<String, PathBuf>, FetchError> {
238    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
239
240    if let Some(ref token) = config.token {
241        // BORROW: explicit .clone() to pass owned String
242        builder = builder.with_token(Some(token.clone()));
243    }
244
245    if let Some(ref dir) = config.output_dir {
246        // BORROW: explicit .clone() for owned PathBuf
247        builder = builder.with_cache_dir(dir.clone());
248    }
249
250    let api = builder.build().map_err(FetchError::Api)?;
251
252    let hf_repo = match config.revision {
253        Some(ref rev) => {
254            // BORROW: explicit .clone() for owned String arguments
255            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
256        }
257        None => Repo::new(repo_id.clone(), RepoType::Model),
258    };
259
260    let repo = api.repo(hf_repo);
261    download::download_all_files_map(repo, repo_id, Some(config)).await
262}
263
264/// Blocking version of [`download_files()`] for non-async callers.
265///
266/// Creates a Tokio runtime internally. Do not call from within
267/// an existing async context (use [`download_files()`] instead).
268///
269/// # Errors
270///
271/// Same as [`download_files()`].
272pub fn download_files_blocking(repo_id: String) -> Result<HashMap<String, PathBuf>, FetchError> {
273    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
274        path: PathBuf::from("<runtime>"),
275        source: e,
276    })?;
277    rt.block_on(download_files(repo_id))
278}
279
280/// Downloads a single file from a `HuggingFace` model repository.
281///
282/// Returns the local cache path. If the file is already cached (and
283/// checksums match when `verify_checksums` is enabled), the download
284/// is skipped and the cached path is returned immediately.
285///
286/// Files at or above [`FetchConfig`]'s `chunk_threshold` (default 100 MiB)
287/// are downloaded using multiple parallel HTTP Range connections
288/// (`connections_per_file`, default 8). Smaller files use a single
289/// connection.
290///
291/// # Arguments
292///
293/// * `repo_id` — Repository identifier (e.g., `"mntss/clt-gemma-2-2b-426k"`).
294/// * `filename` — Exact filename within the repository (e.g., `"W_enc_5.safetensors"`).
295/// * `config` — Shared configuration for auth, progress, checksums, retries, and chunking.
296///
297/// # Errors
298///
299/// * [`FetchError::Http`] — if the file does not exist in the repository.
300/// * [`FetchError::Api`] — on download failure (after retries).
301/// * [`FetchError::Checksum`] — if verification is enabled and fails.
302pub async fn download_file(
303    repo_id: String,
304    filename: &str,
305    config: &FetchConfig,
306) -> Result<PathBuf, FetchError> {
307    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
308
309    if let Some(ref token) = config.token {
310        // BORROW: explicit .clone() to pass owned String
311        builder = builder.with_token(Some(token.clone()));
312    }
313
314    if let Some(ref dir) = config.output_dir {
315        // BORROW: explicit .clone() for owned PathBuf
316        builder = builder.with_cache_dir(dir.clone());
317    }
318
319    let api = builder.build().map_err(FetchError::Api)?;
320
321    let hf_repo = match config.revision {
322        Some(ref rev) => {
323            // BORROW: explicit .clone() for owned String arguments
324            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
325        }
326        None => Repo::new(repo_id.clone(), RepoType::Model),
327    };
328
329    let repo = api.repo(hf_repo);
330    download::download_file_by_name(repo, repo_id, filename, config).await
331}
332
333/// Blocking version of [`download_file()`] for non-async callers.
334///
335/// Creates a Tokio runtime internally. Do not call from within
336/// an existing async context (use [`download_file()`] instead).
337///
338/// # Errors
339///
340/// Same as [`download_file()`].
341pub fn download_file_blocking(
342    repo_id: String,
343    filename: &str,
344    config: &FetchConfig,
345) -> Result<PathBuf, FetchError> {
346    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
347        path: PathBuf::from("<runtime>"),
348        source: e,
349    })?;
350    rt.block_on(download_file(repo_id, filename, config))
351}
352
353/// Blocking version of [`download_files_with_config()`] for non-async callers.
354///
355/// Creates a Tokio runtime internally. Do not call from within
356/// an existing async context (use [`download_files_with_config()`] instead).
357///
358/// # Errors
359///
360/// Same as [`download_files_with_config()`].
361pub fn download_files_with_config_blocking(
362    repo_id: String,
363    config: &FetchConfig,
364) -> Result<HashMap<String, PathBuf>, FetchError> {
365    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
366        path: PathBuf::from("<runtime>"),
367        source: e,
368    })?;
369    rt.block_on(download_files_with_config(repo_id, config))
370}