Skip to main content

hf_fetch_model/
lib.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! # hf-fetch-model
4//!
5//! Fast `HuggingFace` model downloads for Rust.
6//!
7//! An embeddable library for downloading `HuggingFace` model repositories
8//! with maximum throughput. Wraps [`hf_hub`] and adds repo-level orchestration.
9//!
10//! ## Quick Start
11//!
12//! ```rust,no_run
13//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
14//! let path = hf_fetch_model::download("julien-c/dummy-unknown".to_owned()).await?;
15//! println!("Model downloaded to: {}", path.display());
16//! # Ok(())
17//! # }
18//! ```
19//!
20//! ## Configured Download
21//!
22//! ```rust,no_run
23//! # async fn example() -> Result<(), hf_fetch_model::FetchError> {
24//! use hf_fetch_model::FetchConfig;
25//!
26//! let config = FetchConfig::builder()
27//!     .filter("*.safetensors")
28//!     .filter("*.json")
29//!     .on_progress(|e| {
30//!         println!("{}: {:.1}%", e.filename, e.percent);
31//!     })
32//!     .build()?;
33//!
34//! let path = hf_fetch_model::download_with_config(
35//!     "google/gemma-2-2b".to_owned(),
36//!     &config,
37//! ).await?;
38//! # Ok(())
39//! # }
40//! ```
41//!
42//! ## `HuggingFace` Cache
43//!
//! By default, downloaded files are stored in the standard `HuggingFace` cache
//! directory (`~/.cache/huggingface/hub/`), ensuring compatibility with Python
//! tooling. A different cache directory can be configured through [`FetchConfig`].
46//!
47//! ## Authentication
48//!
49//! Set the `HF_TOKEN` environment variable to access private or gated models,
50//! or use [`FetchConfig::builder().token()`](FetchConfigBuilder::token).
51
52pub mod cache;
53pub mod checksum;
54pub mod config;
55pub mod discover;
56pub mod download;
57pub mod error;
58pub mod progress;
59pub mod repo;
60mod retry;
61
62pub use config::{FetchConfig, FetchConfigBuilder, Filter};
63pub use error::{FetchError, FileFailure};
64pub use progress::ProgressEvent;
65
66use std::collections::HashMap;
67use std::path::PathBuf;
68
69use hf_hub::{Repo, RepoType};
70
71/// Downloads all files from a `HuggingFace` model repository.
72///
73/// Uses high-throughput mode for maximum download speed. Files are stored
74/// in the standard `HuggingFace` cache layout (`~/.cache/huggingface/hub/`).
75///
76/// Authentication is handled via the `HF_TOKEN` environment variable when set.
77///
78/// For filtering, progress, and other options, use [`download_with_config()`].
79///
80/// # Arguments
81///
82/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
83///
84/// # Returns
85///
86/// The path to the snapshot directory containing all downloaded files.
87///
88/// # Errors
89///
90/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails.
91/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
92/// * [`FetchError::Auth`] — if authentication is required but fails.
93pub async fn download(repo_id: String) -> Result<PathBuf, FetchError> {
94    let api = hf_hub::api::tokio::ApiBuilder::new()
95        .high()
96        .build()
97        .map_err(FetchError::Api)?;
98
99    let repo = api.model(repo_id.clone());
100    download::download_all_files(repo, repo_id, None).await
101}
102
103/// Downloads files from a `HuggingFace` model repository using the given configuration.
104///
105/// Supports filtering, progress reporting, custom revision, authentication,
106/// and concurrency settings via [`FetchConfig`].
107///
108/// # Arguments
109///
110/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
111/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
112///
113/// # Returns
114///
115/// The path to the snapshot directory containing all downloaded files.
116///
117/// # Errors
118///
119/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails.
120/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
121/// * [`FetchError::Auth`] — if authentication is required but fails.
122pub async fn download_with_config(
123    repo_id: String,
124    config: &FetchConfig,
125) -> Result<PathBuf, FetchError> {
126    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
127
128    if let Some(ref token) = config.token {
129        // BORROW: explicit .clone() to pass owned String
130        builder = builder.with_token(Some(token.clone()));
131    }
132
133    if let Some(ref dir) = config.output_dir {
134        // BORROW: explicit .clone() for owned PathBuf
135        builder = builder.with_cache_dir(dir.clone());
136    }
137
138    let api = builder.build().map_err(FetchError::Api)?;
139
140    let hf_repo = match config.revision {
141        Some(ref rev) => {
142            // BORROW: explicit .clone() for owned String arguments
143            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
144        }
145        None => Repo::new(repo_id.clone(), RepoType::Model),
146    };
147
148    let repo = api.repo(hf_repo);
149    download::download_all_files(repo, repo_id, Some(config)).await
150}
151
152/// Blocking version of [`download()`] for non-async callers.
153///
154/// Creates a Tokio runtime internally. Do not call from within
155/// an existing async context (use [`download()`] instead).
156///
157/// # Errors
158///
159/// Same as [`download()`].
160pub fn download_blocking(repo_id: String) -> Result<PathBuf, FetchError> {
161    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
162        path: PathBuf::from("<runtime>"),
163        source: e,
164    })?;
165    rt.block_on(download(repo_id))
166}
167
168/// Blocking version of [`download_with_config()`] for non-async callers.
169///
170/// Creates a Tokio runtime internally. Do not call from within
171/// an existing async context (use [`download_with_config()`] instead).
172///
173/// # Errors
174///
175/// Same as [`download_with_config()`].
176pub fn download_with_config_blocking(
177    repo_id: String,
178    config: &FetchConfig,
179) -> Result<PathBuf, FetchError> {
180    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
181        path: PathBuf::from("<runtime>"),
182        source: e,
183    })?;
184    rt.block_on(download_with_config(repo_id, config))
185}
186
187/// Downloads all files from a `HuggingFace` model repository and returns
188/// a filename → path map.
189///
190/// Each key is the relative filename within the repository (e.g.,
191/// `"config.json"`, `"model.safetensors"`), and each value is the
192/// absolute local path to the downloaded file.
193///
194/// For filtering, progress, and other options, use
195/// [`download_files_with_config()`].
196///
197/// # Arguments
198///
199/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
200///
201/// # Errors
202///
203/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails.
204/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
205/// * [`FetchError::Auth`] — if authentication is required but fails.
206pub async fn download_files(repo_id: String) -> Result<HashMap<String, PathBuf>, FetchError> {
207    let api = hf_hub::api::tokio::ApiBuilder::new()
208        .high()
209        .build()
210        .map_err(FetchError::Api)?;
211
212    let repo = api.model(repo_id.clone());
213    download::download_all_files_map(repo, repo_id, None).await
214}
215
216/// Downloads files from a `HuggingFace` model repository using the given
217/// configuration and returns a filename → path map.
218///
219/// Each key is the relative filename within the repository (e.g.,
220/// `"config.json"`, `"model.safetensors"`), and each value is the
221/// absolute local path to the downloaded file.
222///
223/// # Arguments
224///
225/// * `repo_id` — The repository identifier (e.g., `"google/gemma-2-2b-it"`).
226/// * `config` — Download configuration (see [`FetchConfig::builder()`]).
227///
228/// # Errors
229///
230/// * [`FetchError::Api`] — if the `HuggingFace` API or download fails.
231/// * [`FetchError::RepoNotFound`] — if the repository does not exist.
232/// * [`FetchError::Auth`] — if authentication is required but fails.
233pub async fn download_files_with_config(
234    repo_id: String,
235    config: &FetchConfig,
236) -> Result<HashMap<String, PathBuf>, FetchError> {
237    let mut builder = hf_hub::api::tokio::ApiBuilder::new().high();
238
239    if let Some(ref token) = config.token {
240        // BORROW: explicit .clone() to pass owned String
241        builder = builder.with_token(Some(token.clone()));
242    }
243
244    if let Some(ref dir) = config.output_dir {
245        // BORROW: explicit .clone() for owned PathBuf
246        builder = builder.with_cache_dir(dir.clone());
247    }
248
249    let api = builder.build().map_err(FetchError::Api)?;
250
251    let hf_repo = match config.revision {
252        Some(ref rev) => {
253            // BORROW: explicit .clone() for owned String arguments
254            Repo::with_revision(repo_id.clone(), RepoType::Model, rev.clone())
255        }
256        None => Repo::new(repo_id.clone(), RepoType::Model),
257    };
258
259    let repo = api.repo(hf_repo);
260    download::download_all_files_map(repo, repo_id, Some(config)).await
261}
262
263/// Blocking version of [`download_files()`] for non-async callers.
264///
265/// Creates a Tokio runtime internally. Do not call from within
266/// an existing async context (use [`download_files()`] instead).
267///
268/// # Errors
269///
270/// Same as [`download_files()`].
271pub fn download_files_blocking(repo_id: String) -> Result<HashMap<String, PathBuf>, FetchError> {
272    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
273        path: PathBuf::from("<runtime>"),
274        source: e,
275    })?;
276    rt.block_on(download_files(repo_id))
277}
278
279/// Blocking version of [`download_files_with_config()`] for non-async callers.
280///
281/// Creates a Tokio runtime internally. Do not call from within
282/// an existing async context (use [`download_files_with_config()`] instead).
283///
284/// # Errors
285///
286/// Same as [`download_files_with_config()`].
287pub fn download_files_with_config_blocking(
288    repo_id: String,
289    config: &FetchConfig,
290) -> Result<HashMap<String, PathBuf>, FetchError> {
291    let rt = tokio::runtime::Runtime::new().map_err(|e| FetchError::Io {
292        path: PathBuf::from("<runtime>"),
293        source: e,
294    })?;
295    rt.block_on(download_files_with_config(repo_id, config))
296}