// hf_fetch_model/config.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Configuration for model downloads.
4//!
5//! [`FetchConfig`] controls revision, authentication, file filtering,
6//! concurrency, timeouts, retry behavior, and progress reporting.
7
8use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11
12use globset::{Glob, GlobSet, GlobSetBuilder};
13
14use crate::error::FetchError;
15use crate::progress::ProgressEvent;
16
// TRAIT_OBJECT: heterogeneous progress handlers from different callers
/// Shared, thread-safe progress callback invoked with each [`ProgressEvent`].
/// `Arc` lets the same handler be cloned into multiple download tasks.
pub(crate) type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
19
/// Configuration for downloading a model repository.
///
/// Use [`FetchConfig::builder()`] to construct.
///
/// # Example
///
/// ```rust
/// # fn example() -> Result<(), hf_fetch_model::FetchError> {
/// use hf_fetch_model::FetchConfig;
///
/// let config = FetchConfig::builder()
///     .revision("main")
///     .filter("*.safetensors")
///     .concurrency(4)
///     .build()?;
/// # Ok(())
/// # }
/// ```
pub struct FetchConfig {
    /// Git revision (branch, tag, or commit SHA). `None` means `"main"`.
    pub(crate) revision: Option<String>,
    /// Authentication token for gated/private repositories.
    /// Redacted as `"***"` by the manual [`Debug`] impl below.
    pub(crate) token: Option<String>,
    /// Compiled include glob patterns. Only matching files are downloaded;
    /// `None` means "include everything".
    pub(crate) include: Option<GlobSet>,
    /// Compiled exclude glob patterns. Matching files are skipped, even when
    /// they also match an include pattern.
    pub(crate) exclude: Option<GlobSet>,
    /// Number of files to download in parallel (builder default: 4).
    pub(crate) concurrency: usize,
    /// Custom cache directory (overrides the default HF cache).
    pub(crate) output_dir: Option<PathBuf>,
    /// Maximum time allowed for a single file download. `None` = no limit.
    pub(crate) timeout_per_file: Option<Duration>,
    /// Maximum total time for the entire download operation. `None` = no limit.
    pub(crate) timeout_total: Option<Duration>,
    /// Maximum retry attempts per file (exponential backoff with jitter;
    /// builder default: 3).
    pub(crate) max_retries: u32,
    /// Whether to verify SHA256 checksums against HF LFS metadata
    /// (builder default: `true`).
    pub(crate) verify_checksums: bool,
    /// Minimum file size (bytes) for multi-connection chunked download
    /// (builder default: 104_857_600 = 100 MiB).
    pub(crate) chunk_threshold: u64,
    /// Number of parallel HTTP Range connections per large file
    /// (builder default: 8, clamped to at least 1).
    pub(crate) connections_per_file: usize,
    // TRAIT_OBJECT: heterogeneous progress handlers from different callers
    /// Progress callback invoked for each download event.
    pub(crate) on_progress: Option<ProgressCallback>,
    /// Tracks which performance fields the user explicitly set via the builder.
    pub(crate) explicit: ExplicitSettings,
}
69
/// Records which performance-related knobs the caller set explicitly.
///
/// Consulted by the implicit plan retrofit: a flag left `false` (field not
/// explicitly set) permits the plan-based optimizer to substitute a
/// recommended value, while an explicitly chosen value is always respected.
#[derive(Debug, Clone, Default)]
pub(crate) struct ExplicitSettings {
    /// `true` when the caller set `concurrency` through the builder.
    pub(crate) concurrency: bool,
    /// `true` when the caller set `connections_per_file` through the builder.
    pub(crate) connections_per_file: bool,
    /// `true` when the caller set `chunk_threshold` through the builder.
    pub(crate) chunk_threshold: bool,
}
83
84impl std::fmt::Debug for FetchConfig {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        f.debug_struct("FetchConfig")
87            .field("revision", &self.revision)
88            .field("token", &self.token.as_ref().map(|_| "***"))
89            .field("include", &self.include)
90            .field("exclude", &self.exclude)
91            .field("concurrency", &self.concurrency)
92            .field("output_dir", &self.output_dir)
93            .field("timeout_per_file", &self.timeout_per_file)
94            .field("timeout_total", &self.timeout_total)
95            .field("max_retries", &self.max_retries)
96            .field("verify_checksums", &self.verify_checksums)
97            .field("chunk_threshold", &self.chunk_threshold)
98            .field("connections_per_file", &self.connections_per_file)
99            .field(
100                "on_progress",
101                if self.on_progress.is_some() {
102                    &"Some(<fn>)"
103                } else {
104                    &"None"
105                },
106            )
107            .field("explicit", &self.explicit)
108            .finish()
109    }
110}
111
112impl FetchConfig {
113    /// Creates a new [`FetchConfigBuilder`].
114    #[must_use]
115    pub fn builder() -> FetchConfigBuilder {
116        FetchConfigBuilder::default()
117    }
118
119    /// Returns the configured concurrency level (parallel file downloads).
120    #[must_use]
121    pub const fn concurrency(&self) -> usize {
122        self.concurrency
123    }
124
125    /// Returns the configured number of parallel HTTP connections per file.
126    #[must_use]
127    pub const fn connections_per_file(&self) -> usize {
128        self.connections_per_file
129    }
130
131    /// Returns the chunk threshold in bytes (minimum file size for
132    /// multi-connection chunked downloads).
133    #[must_use]
134    pub const fn chunk_threshold(&self) -> u64 {
135        self.chunk_threshold
136    }
137}
138
/// Builder for [`FetchConfig`].
///
/// All fields are optional; `build` compiles the glob patterns and fills
/// in defaults for anything left unset.
#[derive(Default)]
pub struct FetchConfigBuilder {
    /// Git revision; `None` means `"main"` downstream.
    revision: Option<String>,
    /// Auth token for gated/private repositories.
    token: Option<String>,
    /// Raw include glob patterns, compiled into a `GlobSet` in `build`.
    include_patterns: Vec<String>,
    /// Raw exclude glob patterns, compiled into a `GlobSet` in `build`.
    exclude_patterns: Vec<String>,
    /// Parallel file downloads; `None` = plan auto-tune, fallback 4.
    concurrency: Option<usize>,
    /// Custom cache directory override.
    output_dir: Option<PathBuf>,
    /// Per-file download timeout.
    timeout_per_file: Option<Duration>,
    /// Whole-operation timeout.
    timeout_total: Option<Duration>,
    /// Retry attempts per file; `None` = default 3.
    max_retries: Option<u32>,
    /// SHA256 verification toggle; `None` = default `true`.
    verify_checksums: Option<bool>,
    /// Chunked-download threshold (bytes); `None` = plan auto-tune, fallback 100 MiB.
    chunk_threshold: Option<u64>,
    /// Range connections per large file; `None` = plan auto-tune, fallback 8.
    connections_per_file: Option<usize>,
    /// Optional progress callback.
    on_progress: Option<ProgressCallback>,
}
156
157impl FetchConfigBuilder {
158    /// Sets the git revision (branch, tag, or commit SHA) to download.
159    ///
160    /// Defaults to `"main"` if not set.
161    #[must_use]
162    pub fn revision(mut self, revision: &str) -> Self {
163        self.revision = Some(revision.to_owned());
164        self
165    }
166
167    /// Sets the authentication token.
168    #[must_use]
169    pub fn token(mut self, token: &str) -> Self {
170        self.token = Some(token.to_owned());
171        self
172    }
173
174    /// Reads the authentication token from the `HF_TOKEN` environment variable.
175    #[must_use]
176    pub fn token_from_env(mut self) -> Self {
177        self.token = std::env::var("HF_TOKEN").ok();
178        self
179    }
180
181    /// Adds an include glob pattern. Only files matching at least one
182    /// include pattern will be downloaded.
183    ///
184    /// Can be called multiple times to add multiple patterns.
185    #[must_use]
186    pub fn filter(mut self, pattern: &str) -> Self {
187        self.include_patterns.push(pattern.to_owned());
188        self
189    }
190
191    /// Adds an exclude glob pattern. Files matching any exclude pattern
192    /// will be skipped, even if they match an include pattern.
193    ///
194    /// Can be called multiple times to add multiple patterns.
195    #[must_use]
196    pub fn exclude(mut self, pattern: &str) -> Self {
197        self.exclude_patterns.push(pattern.to_owned());
198        self
199    }
200
201    /// Sets the number of files to download concurrently.
202    ///
203    /// When omitted, the download plan optimizer auto-tunes this value
204    /// based on file count and size distribution. Falls back to 4 if
205    /// no plan recommendation is available.
206    #[must_use]
207    pub fn concurrency(mut self, concurrency: usize) -> Self {
208        self.concurrency = Some(concurrency);
209        self
210    }
211
212    /// Sets a custom output directory for downloaded files.
213    ///
214    /// By default, files are stored in the standard `HuggingFace` cache directory
215    /// (`~/.cache/huggingface/hub/`). When set, the `HuggingFace` cache hierarchy
216    /// is created inside this directory instead.
217    #[must_use]
218    pub fn output_dir(mut self, dir: PathBuf) -> Self {
219        self.output_dir = Some(dir);
220        self
221    }
222
223    /// Sets the maximum time allowed per file download.
224    ///
225    /// If a single file download exceeds this duration, it is aborted
226    /// and may be retried according to the retry policy.
227    #[must_use]
228    pub fn timeout_per_file(mut self, duration: Duration) -> Self {
229        self.timeout_per_file = Some(duration);
230        self
231    }
232
233    /// Sets the maximum total time for the entire download operation.
234    ///
235    /// If the total download time exceeds this duration, remaining files
236    /// are skipped and a [`FetchError::Timeout`] is returned.
237    #[must_use]
238    pub fn timeout_total(mut self, duration: Duration) -> Self {
239        self.timeout_total = Some(duration);
240        self
241    }
242
243    /// Sets the maximum number of retry attempts per file.
244    ///
245    /// Defaults to 3. Set to 0 to disable retries.
246    /// Uses exponential backoff with jitter (base 300ms, cap 10s).
247    #[must_use]
248    pub fn max_retries(mut self, retries: u32) -> Self {
249        self.max_retries = Some(retries);
250        self
251    }
252
253    /// Enables or disables SHA256 checksum verification after download.
254    ///
255    /// When enabled, downloaded files are verified against the SHA256 hash
256    /// from `HuggingFace` LFS metadata. Files without LFS metadata (small
257    /// config files stored directly in git) are skipped.
258    ///
259    /// Defaults to `true`.
260    #[must_use]
261    pub fn verify_checksums(mut self, verify: bool) -> Self {
262        self.verify_checksums = Some(verify);
263        self
264    }
265
266    /// Sets the minimum file size (in bytes) for chunked parallel download.
267    ///
268    /// Files at or above this threshold are downloaded using multiple HTTP
269    /// Range connections in parallel. Files below use the standard single
270    /// connection. Set to `u64::MAX` to disable chunked downloads entirely.
271    ///
272    /// When omitted, the download plan optimizer auto-tunes this value
273    /// based on file size distribution. Falls back to 100 MiB
274    /// (104\_857\_600 bytes) if no plan recommendation is available.
275    #[must_use]
276    pub fn chunk_threshold(mut self, bytes: u64) -> Self {
277        self.chunk_threshold = Some(bytes);
278        self
279    }
280
281    /// Sets the number of parallel HTTP connections per large file.
282    ///
283    /// Only applies to files at or above `chunk_threshold`. When omitted,
284    /// the download plan optimizer auto-tunes this value based on file size
285    /// distribution. Falls back to 8 if no plan recommendation is available.
286    #[must_use]
287    pub fn connections_per_file(mut self, connections: usize) -> Self {
288        self.connections_per_file = Some(connections);
289        self
290    }
291
292    /// Sets a progress callback invoked for each progress event.
293    #[must_use]
294    pub fn on_progress<F>(mut self, callback: F) -> Self
295    where
296        F: Fn(&ProgressEvent) + Send + Sync + 'static,
297    {
298        self.on_progress = Some(Arc::new(callback));
299        self
300    }
301
302    /// Builds the [`FetchConfig`].
303    ///
304    /// # Errors
305    ///
306    /// Returns [`FetchError::InvalidPattern`] if any glob pattern is invalid.
307    pub fn build(self) -> Result<FetchConfig, FetchError> {
308        let include = build_globset(&self.include_patterns)?;
309        let exclude = build_globset(&self.exclude_patterns)?;
310
311        Ok(FetchConfig {
312            revision: self.revision,
313            token: self.token,
314            include,
315            exclude,
316            concurrency: self.concurrency.unwrap_or(4),
317            output_dir: self.output_dir,
318            timeout_per_file: self.timeout_per_file,
319            timeout_total: self.timeout_total,
320            max_retries: self.max_retries.unwrap_or(3),
321            verify_checksums: self.verify_checksums.unwrap_or(true),
322            chunk_threshold: self.chunk_threshold.unwrap_or(104_857_600),
323            connections_per_file: self.connections_per_file.unwrap_or(8).max(1),
324            on_progress: self.on_progress,
325            explicit: ExplicitSettings {
326                concurrency: self.concurrency.is_some(),
327                connections_per_file: self.connections_per_file.is_some(),
328                chunk_threshold: self.chunk_threshold.is_some(),
329            },
330        })
331    }
332}
333
/// Common filter presets for typical download patterns.
///
/// A namespace of constructors returning pre-configured
/// [`FetchConfigBuilder`]s; `#[non_exhaustive]` prevents construction of
/// the unit struct outside this crate.
#[non_exhaustive]
pub struct Filter;
337
338impl Filter {
339    /// Returns a builder pre-configured to download only `*.safetensors` files
340    /// plus common config files.
341    #[must_use]
342    pub fn safetensors() -> FetchConfigBuilder {
343        FetchConfigBuilder::default()
344            .filter("*.safetensors")
345            .filter("*.json")
346            .filter("*.txt")
347    }
348
349    /// Returns a builder pre-configured to download only GGUF files
350    /// plus common config files.
351    #[must_use]
352    pub fn gguf() -> FetchConfigBuilder {
353        FetchConfigBuilder::default()
354            .filter("*.gguf")
355            .filter("*.json")
356            .filter("*.txt")
357    }
358
359    /// Returns a builder pre-configured to download only `pytorch_model*.bin`
360    /// files plus common config files.
361    #[must_use]
362    pub fn pth() -> FetchConfigBuilder {
363        FetchConfigBuilder::default()
364            .filter("pytorch_model*.bin")
365            .filter("*.json")
366            .filter("*.txt")
367    }
368
369    /// Returns a builder pre-configured to download only config files
370    /// (no model weights).
371    #[must_use]
372    pub fn config_only() -> FetchConfigBuilder {
373        FetchConfigBuilder::default()
374            .filter("*.json")
375            .filter("*.txt")
376            .filter("*.md")
377    }
378}
379
380/// Returns whether `filename` passes the given include/exclude glob filters.
381///
382/// A file matches when it is not excluded by any `exclude` pattern **and**
383/// either there are no `include` patterns or it matches at least one.
384#[must_use]
385pub fn file_matches(filename: &str, include: Option<&GlobSet>, exclude: Option<&GlobSet>) -> bool {
386    if let Some(exc) = exclude {
387        if exc.is_match(filename) {
388            return false;
389        }
390    }
391    if let Some(inc) = include {
392        return inc.is_match(filename);
393    }
394    true
395}
396
397fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
398    if patterns.is_empty() {
399        return Ok(None);
400    }
401    let mut builder = GlobSetBuilder::new();
402    for pattern in patterns {
403        // BORROW: explicit .as_str() instead of Deref coercion
404        let glob = Glob::new(pattern.as_str()).map_err(|e| FetchError::InvalidPattern {
405            pattern: pattern.clone(),
406            reason: e.to_string(),
407        })?;
408        builder.add(glob);
409    }
410    let set = builder.build().map_err(|e| FetchError::InvalidPattern {
411        pattern: patterns.join(", "),
412        reason: e.to_string(),
413    })?;
414    Ok(Some(set))
415}
416
417/// Builds a compiled [`GlobSet`] from a list of pattern strings.
418///
419/// Returns `None` if the pattern list is empty. This is useful for callers
420/// that need glob filtering outside the download pipeline (e.g., the
421/// `list-files` subcommand).
422///
423/// # Errors
424///
425/// Returns [`FetchError::InvalidPattern`] if any pattern fails to compile.
426pub fn compile_glob_patterns(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
427    build_globset(patterns)
428}
429
/// Returns `true` if `s` contains glob metacharacters (`*`, `?`, `[`, `{`).
///
/// Handy for deciding whether a user-supplied filename should be treated as
/// a glob pattern or as an exact match. Scanning `chars()` is equivalent to
/// scanning raw bytes here: all four metacharacters are ASCII, and UTF-8
/// continuation bytes never collide with ASCII code points.
#[must_use]
pub fn has_glob_chars(s: &str) -> bool {
    s.chars().any(|c| matches!(c, '*' | '?' | '[' | '{'))
}
438
#[cfg(test)]
mod tests {
    // Panics/unwraps are fine inside the test module.
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    // With neither include nor exclude sets, every file passes.
    #[test]
    fn test_file_matches_no_filters() {
        assert!(file_matches("model.safetensors", None, None));
    }

    // Include-only: a file must match at least one include pattern.
    #[test]
    fn test_file_matches_include() {
        let include = build_globset(&["*.safetensors".to_owned()]).unwrap();
        assert!(file_matches("model.safetensors", include.as_ref(), None));
        assert!(!file_matches("model.bin", include.as_ref(), None));
    }

    // Exclude-only: files matching an exclude pattern are rejected.
    #[test]
    fn test_file_matches_exclude() {
        let exclude = build_globset(&["*.bin".to_owned()]).unwrap();
        assert!(file_matches("model.safetensors", None, exclude.as_ref()));
        assert!(!file_matches("model.bin", None, exclude.as_ref()));
    }

    // Exclusion wins even when the file also matches an include pattern.
    #[test]
    fn test_exclude_overrides_include() {
        let include = build_globset(&["*.safetensors".to_owned(), "*.bin".to_owned()]).unwrap();
        let exclude = build_globset(&["*.bin".to_owned()]).unwrap();
        assert!(file_matches(
            "model.safetensors",
            include.as_ref(),
            exclude.as_ref()
        ));
        assert!(!file_matches(
            "model.bin",
            include.as_ref(),
            exclude.as_ref()
        ));
    }

    // Covers each of the four metacharacters plus plain-name negatives.
    #[test]
    fn test_has_glob_chars() {
        assert!(has_glob_chars("*.safetensors"));
        assert!(has_glob_chars("model-[0-9].bin"));
        assert!(has_glob_chars("model?.bin"));
        assert!(has_glob_chars("{a,b}.bin"));
        assert!(!has_glob_chars("model.safetensors"));
        assert!(!has_glob_chars("config.json"));
        assert!(!has_glob_chars("pytorch_model.bin"));
    }
}