Skip to main content

hf_fetch_model/
lib.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let path = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model downloaded to: {}", path.display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//!     .filter("*.safetensors")
28//!     .filter("*.json")
29//!     .on_progress(|e| {
30//!         println!("{}: {:.1}%", e.filename, e.percent);
31//!     })
32//!     .build()?;
33//!
34//! let path = hf_fetch_model::download_with_config(
35//!     "google/gemma-2-2b".to_owned(),
36//!     &config,
37//! ).await?;
38//! # Ok(())
39//! # }
40//! ```
41//!
42//! ## `HuggingFace` Cache
43//!
44//! Downloaded files are stored in the standard `HuggingFace` cache directory
45//! (`~/.cache/huggingface/hub/`), ensuring compatibility with Python tooling.
46//!
47//! ## Authentication
48//!
49//! Set the `HF_TOKEN` environment variable to access private or gated models,
50//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
51
52pub mod cache;
53pub mod checksum;
54mod chunked;
55pub mod config;
56pub mod discover;
57pub mod download;
58pub mod error;
59pub mod progress;
60pub mod repo;
61mod retry;
62
63pub use config::{FetchConfig, FetchConfigBuilder, Filter};
64pub use error::{FetchError, FileFailure};
65pub use progress::ProgressEvent;
66
67use std::collections::HashMap;
68use std::path::PathBuf;
69
70use hf_hub::{Repo, RepoType};
71
72/// Downloads all files from a `HuggingFace` model repository.
73///
74/// Uses high-throughput mode for maximum download speed, including
75/// multi-connection chunked downloads for large files (≥100 MiB by default,
76/// 8 parallel connections per file). Files are stored in the standard
77/// `HuggingFace` cache layout (`~/.cache/huggingface/hub/`).
78///
79/// Authentication is handled via the `HF_TOKEN` environment variable when set.
80///
81/// For filtering, progress, and other options, use [`download_with_config()`].
82///
83/// # Arguments
84///
85/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
86///
87/// # Returns
88///
89/// The path to the snapshot directory containing all downloaded files.
90///
91/// # Errors
92///
93/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
94/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
95/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
96pub async fn download(repo_id: String) -> Result<PathBuf, FetchError> {
97    let config = FetchConfig::builder().build()?;
98    download_with_config(repo_id, &config).await
99}
100
101/// Downloads files from a `HuggingFace` model repository using the given configuration.
102///
103/// Supports filtering, progress reporting, custom revision, authentication,
104/// and concurrency settings via [`FetchConfig`].
105///
106/// # Arguments
107///
108/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
109/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
110///
111/// # Returns
112///
113/// The path to the snapshot directory containing all downloaded files.
114///
115/// # Errors
116///
117/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
118/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
119pub async fn download_with_config(
120    repo_id: String,
121    config: &FetchConfig,
122) -> Result<PathBuf, FetchError> {
123    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
124
125    if let Some(ref token) = config.token {
126        // BORROW: explicit .clone() to pass owned String
127        builder = builder.with_token(Some(token.clone()));
128    }
129
130    if let Some(ref dir) = config.output_dir {
131        // BORROW: explicit .clone() for owned PathBuf
132        builder = builder.with_cache_dir(dir.clone());
133    }
134
135    let api = builder.build().map_err(FetchError::Api)?;
136
137    let hf_repo = match config.revision {
138        Some(ref rev) => {
139            // BORROW: explicit .clone() for owned String arguments
140            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
141        }
142        None => Repo::new(repo_id.clone(), RepoType::Model),
143    };
144
145    let repo = api.repo(hf_repo);
146    download::download_all_files(repo, repo_id, Some(config)).await
147}
148
149/// Blocking version of [`download()`] for non-async callers.
150///
151/// Creates a Tokio runtime internally. Do not call from within
152/// an existing async context (use [`download()`] instead).
153///
154/// # Errors
155///
156/// Same as [`download()`].
157pub fn download_blocking(repo_id: String) -> Result<PathBuf, FetchError> {
158    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
159        path: PathBuf::from("<runtime>"),
160        source: e,
161    })?;
162    rt.block_on(download(repo_id))
163}
164
165/// Blocking version of [`download_with_config()`] for non-async callers.
166///
167/// Creates a Tokio runtime internally. Do not call from within
168/// an existing async context (use [`download_with_config()`] instead).
169///
170/// # Errors
171///
172/// Same as [`download_with_config()`].
173pub fn download_with_config_blocking(
174    repo_id: String,
175    config: &FetchConfig,
176) -> Result<PathBuf, FetchError> {
177    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
178        path: PathBuf::from("<runtime>"),
179        source: e,
180    })?;
181    rt.block_on(download_with_config(repo_id, config))
182}
183
184/// Downloads all files from a `HuggingFace` model repository and returns
185/// a filename → path map.
186///
187/// Each key is the relative filename within the repository (e.g.,
188/// `"config.json"`, `"model.safetensors"`), and each value is the
189/// absolute local path to the downloaded file.
190///
191/// Uses the same high-throughput defaults as [`download()`]: multi-connection
192/// chunked downloads for large files (≥100 MiB, 8 parallel connections).
193///
194/// For filtering, progress, and other options, use
195/// [`download_files_with_config()`].
196///
197/// # Arguments
198///
199/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
200///
201/// # Errors
202///
203/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
204/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
205/// * [`FetchError::InvalidPattern`] — if the default config fails to build (should not happen).
206pub async fn download_files(repo_id: String) -> Result<HashMap<String, PathBuf>, FetchError> {
207    let config = FetchConfig::builder().build()?;
208    download_files_with_config(repo_id, &config).await
209}
210
211/// Downloads files from a `HuggingFace` model repository using the given
212/// configuration and returns a filename → path map.
213///
214/// Each key is the relative filename within the repository (e.g.,
215/// `"config.json"`, `"model.safetensors"`), and each value is the
216/// absolute local path to the downloaded file.
217///
218/// # Arguments
219///
220/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
221/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
222///
223/// # Errors
224///
225/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails (includes auth failures).
226/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
227pub async fn download_files_with_config(
228    repo_id: String,
229    config: &FetchConfig,
230) -> Result<HashMap<String, PathBuf>, FetchError> {
231    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
232
233    if let Some(ref token) = config.token {
234        // BORROW: explicit .clone() to pass owned String
235        builder = builder.with_token(Some(token.clone()));
236    }
237
238    if let Some(ref dir) = config.output_dir {
239        // BORROW: explicit .clone() for owned PathBuf
240        builder = builder.with_cache_dir(dir.clone());
241    }
242
243    let api = builder.build().map_err(FetchError::Api)?;
244
245    let hf_repo = match config.revision {
246        Some(ref rev) => {
247            // BORROW: explicit .clone() for owned String arguments
248            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
249        }
250        None => Repo::new(repo_id.clone(), RepoType::Model),
251    };
252
253    let repo = api.repo(hf_repo);
254    download::download_all_files_map(repo, repo_id, Some(config)).await
255}
256
257/// Blocking version of [`download_files()`] for non-async callers.
258///
259/// Creates a Tokio runtime internally. Do not call from within
260/// an existing async context (use [`download_files()`] instead).
261///
262/// # Errors
263///
264/// Same as [`download_files()`].
265pub fn download_files_blocking(repo_id: String) -> Result<HashMap<String, PathBuf>, FetchError> {
266    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
267        path: PathBuf::from("<runtime>"),
268        source: e,
269    })?;
270    rt.block_on(download_files(repo_id))
271}
272
273/// Downloads a single file from a `HuggingFace` model repository.
274///
275/// Returns the local cache path. If the file is already cached (and
276/// checksums match when `verify_checksums` is enabled), the download
277/// is skipped and the cached path is returned immediately.
278///
279/// Files at or above [`FetchConfig`]'s `chunk_threshold` (default 100 MiB)
280/// are downloaded using multiple parallel HTTP Range connections
281/// (`connections_per_file`, default 8). Smaller files use a single
282/// connection.
283///
284/// # Arguments
285///
286/// * `repo_id` — Repository identifier (e.g., `"mntss/clt-gemma-2-2b-426k"`).
287/// * `filename` — Exact filename within the repository (e.g., `"W_enc_5.safetensors"`).
288/// * `config` — Shared configuration for auth, progress, checksums, retries, and chunking.
289///
290/// # Errors
291///
292/// * [`FetchError::Http`] — if the file does not exist in the repository.
293/// * [`FetchError::Api`] — on download failure (after retries).
294/// * [`FetchError::Checksum`] — if verification is enabled and fails.
295pub async fn download_file(
296    repo_id: String,
297    filename: &str,
298    config: &FetchConfig,
299) -> Result<PathBuf, FetchError> {
300    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
301
302    if let Some(ref token) = config.token {
303        // BORROW: explicit .clone() to pass owned String
304        builder = builder.with_token(Some(token.clone()));
305    }
306
307    if let Some(ref dir) = config.output_dir {
308        // BORROW: explicit .clone() for owned PathBuf
309        builder = builder.with_cache_dir(dir.clone());
310    }
311
312    let api = builder.build().map_err(FetchError::Api)?;
313
314    let hf_repo = match config.revision {
315        Some(ref rev) => {
316            // BORROW: explicit .clone() for owned String arguments
317            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
318        }
319        None => Repo::new(repo_id.clone(), RepoType::Model),
320    };
321
322    let repo = api.repo(hf_repo);
323    download::download_file_by_name(repo, repo_id, filename, config).await
324}
325
326/// Blocking version of [`download_file()`] for non-async callers.
327///
328/// Creates a Tokio runtime internally. Do not call from within
329/// an existing async context (use [`download_file()`] instead).
330///
331/// # Errors
332///
333/// Same as [`download_file()`].
334pub fn download_file_blocking(
335    repo_id: String,
336    filename: &str,
337    config: &FetchConfig,
338) -> Result<PathBuf, FetchError> {
339    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
340        path: PathBuf::from("<runtime>"),
341        source: e,
342    })?;
343    rt.block_on(download_file(repo_id, filename, config))
344}
345
346/// Blocking version of [`download_files_with_config()`] for non-async callers.
347///
348/// Creates a Tokio runtime internally. Do not call from within
349/// an existing async context (use [`download_files_with_config()`] instead).
350///
351/// # Errors
352///
353/// Same as [`download_files_with_config()`].
354pub fn download_files_with_config_blocking(
355    repo_id: String,
356    config: &FetchConfig,
357) -> Result<HashMap<String, PathBuf>, FetchError> {
358    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
359        path: PathBuf::from("<runtime>"),
360        source: e,
361    })?;
362    rt.block_on(download_files_with_config(repo_id, config))
363}