// hf_fetch_model/config.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Configuration for model downloads.
4//!
5//! [`FetchConfig`] controls revision, authentication, file filtering,
6//! concurrency, timeouts, retry behavior, and progress reporting.
7
8use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11
12use globset::{Glob, GlobSet, GlobSetBuilder};
13
14use crate::error::FetchError;
15use crate::progress::ProgressEvent;
16
// TRAIT_OBJECT: heterogeneous progress handlers from different callers
/// Shared, thread-safe callback invoked with each [`ProgressEvent`].
/// `Arc` so the config (and clones of the callback) can cross task/thread
/// boundaries; `Send + Sync` because downloads run concurrently.
pub(crate) type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
19
/// Configuration for downloading a model repository.
///
/// Use [`FetchConfig::builder()`] to construct.
///
/// # Example
///
/// ```rust
/// # fn example() -> Result<(), hf_fetch_model::FetchError> {
/// use hf_fetch_model::FetchConfig;
///
/// let config = FetchConfig::builder()
///     .revision("main")
///     .filter("*.safetensors")
///     .concurrency(4)
///     .build()?;
/// # Ok(())
/// # }
/// ```
pub struct FetchConfig {
    // Git revision (branch, tag, or commit SHA); `None` means the
    // consumer applies the documented default of "main".
    pub(crate) revision: Option<String>,
    // Authentication token; redacted as "***" by the Debug impl.
    pub(crate) token: Option<String>,
    // Compiled include globs; `None` means "no include filtering".
    pub(crate) include: Option<GlobSet>,
    // Compiled exclude globs; `None` means "no exclude filtering".
    pub(crate) exclude: Option<GlobSet>,
    // Number of files downloaded concurrently (builder default: 4).
    pub(crate) concurrency: usize,
    // Custom output directory; `None` means the standard HF cache dir.
    pub(crate) output_dir: Option<PathBuf>,
    // Per-file download deadline; `None` disables the per-file timeout.
    pub(crate) timeout_per_file: Option<Duration>,
    // Whole-operation deadline; `None` disables the total timeout.
    pub(crate) timeout_total: Option<Duration>,
    // Retry attempts per file (builder default: 3; 0 disables retries).
    pub(crate) max_retries: u32,
    // Verify SHA256 against LFS metadata after download (default: true).
    pub(crate) verify_checksums: bool,
    // Minimum file size in bytes for chunked Range downloads
    // (builder default: 100 MiB).
    pub(crate) chunk_threshold: u64,
    // Parallel HTTP connections per large file (default: 8, clamped >= 1).
    pub(crate) connections_per_file: usize,
    // TRAIT_OBJECT: heterogeneous progress handlers from different callers
    pub(crate) on_progress: Option<ProgressCallback>,
}
54
55impl std::fmt::Debug for FetchConfig {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("FetchConfig")
58            .field("revision", &self.revision)
59            .field("token", &self.token.as_ref().map(|_| "***"))
60            .field("include", &self.include)
61            .field("exclude", &self.exclude)
62            .field("concurrency", &self.concurrency)
63            .field("output_dir", &self.output_dir)
64            .field("timeout_per_file", &self.timeout_per_file)
65            .field("timeout_total", &self.timeout_total)
66            .field("max_retries", &self.max_retries)
67            .field("verify_checksums", &self.verify_checksums)
68            .field("chunk_threshold", &self.chunk_threshold)
69            .field("connections_per_file", &self.connections_per_file)
70            .field(
71                "on_progress",
72                if self.on_progress.is_some() {
73                    &"Some(<fn>)"
74                } else {
75                    &"None"
76                },
77            )
78            .finish()
79    }
80}
81
82impl FetchConfig {
83    /// Creates a new [`FetchConfigBuilder`].
84    #[must_use]
85    pub fn builder() -> FetchConfigBuilder {
86        FetchConfigBuilder::default()
87    }
88}
89
/// Builder for [`FetchConfig`].
///
/// Every field is optional; [`FetchConfigBuilder::build`] supplies the
/// documented defaults for anything left unset.
#[derive(Default)]
pub struct FetchConfigBuilder {
    // Git revision to download; resolved to "main" downstream when unset.
    revision: Option<String>,
    // Authentication token (explicit or from the HF_TOKEN env var).
    token: Option<String>,
    // Raw include glob strings; compiled into a GlobSet in build().
    include_patterns: Vec<String>,
    // Raw exclude glob strings; compiled into a GlobSet in build().
    exclude_patterns: Vec<String>,
    // Concurrent file downloads; defaults to 4 in build().
    concurrency: Option<usize>,
    // Custom output directory; defaults to the HF cache when unset.
    output_dir: Option<PathBuf>,
    // Per-file timeout; no timeout when unset.
    timeout_per_file: Option<Duration>,
    // Whole-operation timeout; no timeout when unset.
    timeout_total: Option<Duration>,
    // Retry attempts per file; defaults to 3 in build().
    max_retries: Option<u32>,
    // SHA256 verification toggle; defaults to true in build().
    verify_checksums: Option<bool>,
    // Chunked-download size threshold; defaults to 100 MiB in build().
    chunk_threshold: Option<u64>,
    // Parallel connections per large file; defaults to 8 in build().
    connections_per_file: Option<usize>,
    // Optional progress callback, shared via Arc.
    on_progress: Option<ProgressCallback>,
}
107
108impl FetchConfigBuilder {
109    /// Sets the git revision (branch, tag, or commit SHA) to download.
110    ///
111    /// Defaults to `"main"` if not set.
112    #[must_use]
113    pub fn revision(mut self, revision: &str) -> Self {
114        self.revision = Some(revision.to_owned());
115        self
116    }
117
118    /// Sets the authentication token.
119    #[must_use]
120    pub fn token(mut self, token: &str) -> Self {
121        self.token = Some(token.to_owned());
122        self
123    }
124
125    /// Reads the authentication token from the `HF_TOKEN` environment variable.
126    #[must_use]
127    pub fn token_from_env(mut self) -> Self {
128        self.token = std::env::var("HF_TOKEN").ok();
129        self
130    }
131
132    /// Adds an include glob pattern. Only files matching at least one
133    /// include pattern will be downloaded.
134    ///
135    /// Can be called multiple times to add multiple patterns.
136    #[must_use]
137    pub fn filter(mut self, pattern: &str) -> Self {
138        self.include_patterns.push(pattern.to_owned());
139        self
140    }
141
142    /// Adds an exclude glob pattern. Files matching any exclude pattern
143    /// will be skipped, even if they match an include pattern.
144    ///
145    /// Can be called multiple times to add multiple patterns.
146    #[must_use]
147    pub fn exclude(mut self, pattern: &str) -> Self {
148        self.exclude_patterns.push(pattern.to_owned());
149        self
150    }
151
152    /// Sets the number of files to download concurrently.
153    ///
154    /// Defaults to 4.
155    #[must_use]
156    pub fn concurrency(mut self, concurrency: usize) -> Self {
157        self.concurrency = Some(concurrency);
158        self
159    }
160
161    /// Sets a custom output directory for downloaded files.
162    ///
163    /// By default, files are stored in the standard `HuggingFace` cache directory
164    /// (`~/.cache/huggingface/hub/`). When set, the `HuggingFace` cache hierarchy
165    /// is created inside this directory instead.
166    #[must_use]
167    pub fn output_dir(mut self, dir: PathBuf) -> Self {
168        self.output_dir = Some(dir);
169        self
170    }
171
172    /// Sets the maximum time allowed per file download.
173    ///
174    /// If a single file download exceeds this duration, it is aborted
175    /// and may be retried according to the retry policy.
176    #[must_use]
177    pub fn timeout_per_file(mut self, duration: Duration) -> Self {
178        self.timeout_per_file = Some(duration);
179        self
180    }
181
182    /// Sets the maximum total time for the entire download operation.
183    ///
184    /// If the total download time exceeds this duration, remaining files
185    /// are skipped and a [`FetchError::Timeout`] is returned.
186    #[must_use]
187    pub fn timeout_total(mut self, duration: Duration) -> Self {
188        self.timeout_total = Some(duration);
189        self
190    }
191
192    /// Sets the maximum number of retry attempts per file.
193    ///
194    /// Defaults to 3. Set to 0 to disable retries.
195    /// Uses exponential backoff with jitter (base 300ms, cap 10s).
196    #[must_use]
197    pub fn max_retries(mut self, retries: u32) -> Self {
198        self.max_retries = Some(retries);
199        self
200    }
201
202    /// Enables or disables SHA256 checksum verification after download.
203    ///
204    /// When enabled, downloaded files are verified against the SHA256 hash
205    /// from `HuggingFace` LFS metadata. Files without LFS metadata (small
206    /// config files stored directly in git) are skipped.
207    ///
208    /// Defaults to `true`.
209    #[must_use]
210    pub fn verify_checksums(mut self, verify: bool) -> Self {
211        self.verify_checksums = Some(verify);
212        self
213    }
214
215    /// Sets the minimum file size (in bytes) for chunked parallel download.
216    ///
217    /// Files at or above this threshold are downloaded using multiple HTTP
218    /// Range connections in parallel. Files below use the standard single
219    /// connection. Defaults to 100 MiB (104\_857\_600 bytes). Set to
220    /// `u64::MAX` to disable chunked downloads entirely.
221    #[must_use]
222    pub fn chunk_threshold(mut self, bytes: u64) -> Self {
223        self.chunk_threshold = Some(bytes);
224        self
225    }
226
227    /// Sets the number of parallel HTTP connections per large file.
228    ///
229    /// Only applies to files at or above `chunk_threshold`. Defaults to 8.
230    #[must_use]
231    pub fn connections_per_file(mut self, connections: usize) -> Self {
232        self.connections_per_file = Some(connections);
233        self
234    }
235
236    /// Sets a progress callback invoked for each progress event.
237    #[must_use]
238    pub fn on_progress<F>(mut self, callback: F) -> Self
239    where
240        F: Fn(&ProgressEvent) + Send + Sync + 'static,
241    {
242        self.on_progress = Some(Arc::new(callback));
243        self
244    }
245
246    /// Builds the [`FetchConfig`].
247    ///
248    /// # Errors
249    ///
250    /// Returns [`FetchError::InvalidPattern`] if any glob pattern is invalid.
251    pub fn build(self) -> Result<FetchConfig, FetchError> {
252        let include = build_globset(&self.include_patterns)?;
253        let exclude = build_globset(&self.exclude_patterns)?;
254
255        Ok(FetchConfig {
256            revision: self.revision,
257            token: self.token,
258            include,
259            exclude,
260            concurrency: self.concurrency.unwrap_or(4),
261            output_dir: self.output_dir,
262            timeout_per_file: self.timeout_per_file,
263            timeout_total: self.timeout_total,
264            max_retries: self.max_retries.unwrap_or(3),
265            verify_checksums: self.verify_checksums.unwrap_or(true),
266            chunk_threshold: self.chunk_threshold.unwrap_or(104_857_600),
267            connections_per_file: self.connections_per_file.unwrap_or(8).max(1),
268            on_progress: self.on_progress,
269        })
270    }
271}
272
/// Common filter presets for typical download patterns.
///
/// A unit struct used purely as a namespace for the preset constructors
/// in its `impl` block; it carries no data of its own.
#[non_exhaustive]
pub struct Filter;
276
277impl Filter {
278    /// Returns a builder pre-configured to download only `*.safetensors` files
279    /// plus common config files.
280    #[must_use]
281    pub fn safetensors() -> FetchConfigBuilder {
282        FetchConfigBuilder::default()
283            .filter("*.safetensors")
284            .filter("*.json")
285            .filter("*.txt")
286    }
287
288    /// Returns a builder pre-configured to download only GGUF files
289    /// plus common config files.
290    #[must_use]
291    pub fn gguf() -> FetchConfigBuilder {
292        FetchConfigBuilder::default()
293            .filter("*.gguf")
294            .filter("*.json")
295            .filter("*.txt")
296    }
297
298    /// Returns a builder pre-configured to download only config files
299    /// (no model weights).
300    #[must_use]
301    pub fn config_only() -> FetchConfigBuilder {
302        FetchConfigBuilder::default()
303            .filter("*.json")
304            .filter("*.txt")
305            .filter("*.md")
306    }
307}
308
309/// Returns whether a filename passes the include/exclude filters.
310#[must_use]
311pub(crate) fn file_matches(
312    filename: &str,
313    include: Option<&GlobSet>,
314    exclude: Option<&GlobSet>,
315) -> bool {
316    if let Some(exc) = exclude {
317        if exc.is_match(filename) {
318            return false;
319        }
320    }
321    if let Some(inc) = include {
322        return inc.is_match(filename);
323    }
324    true
325}
326
327fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
328    if patterns.is_empty() {
329        return Ok(None);
330    }
331    let mut builder = GlobSetBuilder::new();
332    for pattern in patterns {
333        // BORROW: explicit .as_str() instead of Deref coercion
334        let glob = Glob::new(pattern.as_str()).map_err(|e| FetchError::InvalidPattern {
335            pattern: pattern.clone(),
336            reason: e.to_string(),
337        })?;
338        builder.add(glob);
339    }
340    let set = builder.build().map_err(|e| FetchError::InvalidPattern {
341        pattern: patterns.join(", "),
342        reason: e.to_string(),
343    })?;
344    Ok(Some(set))
345}
346
#[cfg(test)]
mod tests {
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Compiles the given patterns, panicking on invalid globs (tests only).
    fn globs(patterns: &[&str]) -> Option<GlobSet> {
        let owned: Vec<String> = patterns.iter().map(|p| (*p).to_owned()).collect();
        build_globset(&owned).unwrap()
    }

    #[test]
    fn test_file_matches_no_filters() {
        // With neither filter set, everything passes.
        assert!(file_matches("model.safetensors", None, None));
    }

    #[test]
    fn test_file_matches_include() {
        let include = globs(&["*.safetensors"]);
        assert!(file_matches("model.safetensors", include.as_ref(), None));
        assert!(!file_matches("model.bin", include.as_ref(), None));
    }

    #[test]
    fn test_file_matches_exclude() {
        let exclude = globs(&["*.bin"]);
        assert!(file_matches("model.safetensors", None, exclude.as_ref()));
        assert!(!file_matches("model.bin", None, exclude.as_ref()));
    }

    #[test]
    fn test_exclude_overrides_include() {
        // A file matching both include and exclude must be rejected.
        let include = globs(&["*.safetensors", "*.bin"]);
        let exclude = globs(&["*.bin"]);
        assert!(file_matches(
            "model.safetensors",
            include.as_ref(),
            exclude.as_ref()
        ));
        assert!(!file_matches(
            "model.bin",
            include.as_ref(),
            exclude.as_ref()
        ));
    }
}