Skip to main content

hf_fetch_model/
config.rs

// SPDX-License-Identifier: MIT OR Apache-2.0

//! Configuration for model downloads.
//!
//! [`FetchConfig`] controls revision, authentication, file filtering,
//! concurrency, timeouts, retry behavior, and progress reporting.

use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;

use globset::{Glob, GlobSet, GlobSetBuilder};

use crate::error::FetchError;
use crate::progress::ProgressEvent;

17// TRAIT_OBJECT: heterogeneous progress handlers from different callers
18pub(crate) type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
19
20/// Configuration for downloading a model repository.
21///
22/// Use [`FetchConfig::builder()`] to construct.
23///
24/// # Example
25///
26/// ```rust
27/// # fn example() -> Result<(), hf_fetch_model::FetchError> {
28/// use hf_fetch_model::FetchConfig;
29///
30/// let config = FetchConfig::builder()
31///     .revision("main")
32///     .filter("*.safetensors")
33///     .concurrency(4)
34///     .build()?;
35/// # Ok(())
36/// # }
37/// ```
38pub struct FetchConfig {
39    /// Git revision (branch, tag, or commit SHA). `None` means `"main"`.
40    pub(crate) revision: Option<String>,
41    /// Authentication token for gated/private repositories.
42    pub(crate) token: Option<String>,
43    /// Compiled include glob patterns. Only matching files are downloaded.
44    pub(crate) include: Option<GlobSet>,
45    /// Compiled exclude glob patterns. Matching files are skipped.
46    pub(crate) exclude: Option<GlobSet>,
47    /// Number of files to download in parallel.
48    pub(crate) concurrency: usize,
49    /// Custom cache directory (overrides the default HF cache).
50    pub(crate) output_dir: Option<PathBuf>,
51    /// Maximum time allowed for a single file download.
52    pub(crate) timeout_per_file: Option<Duration>,
53    /// Maximum total time for the entire download operation.
54    pub(crate) timeout_total: Option<Duration>,
55    /// Maximum retry attempts per file (exponential backoff with jitter).
56    pub(crate) max_retries: u32,
57    /// Whether to verify SHA256 checksums against HF LFS metadata.
58    pub(crate) verify_checksums: bool,
59    /// Minimum file size (bytes) for multi-connection chunked download.
60    pub(crate) chunk_threshold: u64,
61    /// Number of parallel HTTP Range connections per large file.
62    pub(crate) connections_per_file: usize,
63    // TRAIT_OBJECT: heterogeneous progress handlers from different callers
64    /// Progress callback invoked for each download event.
65    pub(crate) on_progress: Option<ProgressCallback>,
66}
67
68impl std::fmt::Debug for FetchConfig {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        f.debug_struct("FetchConfig")
71            .field("revision", &self.revision)
72            .field("token", &self.token.as_ref().map(|_| "***"))
73            .field("include", &self.include)
74            .field("exclude", &self.exclude)
75            .field("concurrency", &self.concurrency)
76            .field("output_dir", &self.output_dir)
77            .field("timeout_per_file", &self.timeout_per_file)
78            .field("timeout_total", &self.timeout_total)
79            .field("max_retries", &self.max_retries)
80            .field("verify_checksums", &self.verify_checksums)
81            .field("chunk_threshold", &self.chunk_threshold)
82            .field("connections_per_file", &self.connections_per_file)
83            .field(
84                "on_progress",
85                if self.on_progress.is_some() {
86                    &"Some(<fn>)"
87                } else {
88                    &"None"
89                },
90            )
91            .finish()
92    }
93}
94
95impl FetchConfig {
96    /// Creates a new [`FetchConfigBuilder`].
97    #[must_use]
98    pub fn builder() -> FetchConfigBuilder {
99        FetchConfigBuilder::default()
100    }
101}
102
103/// Builder for [`FetchConfig`].
104#[derive(Default)]
105pub struct FetchConfigBuilder {
106    revision: Option<String>,
107    token: Option<String>,
108    include_patterns: Vec<String>,
109    exclude_patterns: Vec<String>,
110    concurrency: Option<usize>,
111    output_dir: Option<PathBuf>,
112    timeout_per_file: Option<Duration>,
113    timeout_total: Option<Duration>,
114    max_retries: Option<u32>,
115    verify_checksums: Option<bool>,
116    chunk_threshold: Option<u64>,
117    connections_per_file: Option<usize>,
118    on_progress: Option<ProgressCallback>,
119}
120
121impl FetchConfigBuilder {
122    /// Sets the git revision (branch, tag, or commit SHA) to download.
123    ///
124    /// Defaults to `"main"` if not set.
125    #[must_use]
126    pub fn revision(mut self, revision: &str) -> Self {
127        self.revision = Some(revision.to_owned());
128        self
129    }
130
131    /// Sets the authentication token.
132    #[must_use]
133    pub fn token(mut self, token: &str) -> Self {
134        self.token = Some(token.to_owned());
135        self
136    }
137
138    /// Reads the authentication token from the `HF_TOKEN` environment variable.
139    #[must_use]
140    pub fn token_from_env(mut self) -> Self {
141        self.token = std::env::var("HF_TOKEN").ok();
142        self
143    }
144
145    /// Adds an include glob pattern. Only files matching at least one
146    /// include pattern will be downloaded.
147    ///
148    /// Can be called multiple times to add multiple patterns.
149    #[must_use]
150    pub fn filter(mut self, pattern: &str) -> Self {
151        self.include_patterns.push(pattern.to_owned());
152        self
153    }
154
155    /// Adds an exclude glob pattern. Files matching any exclude pattern
156    /// will be skipped, even if they match an include pattern.
157    ///
158    /// Can be called multiple times to add multiple patterns.
159    #[must_use]
160    pub fn exclude(mut self, pattern: &str) -> Self {
161        self.exclude_patterns.push(pattern.to_owned());
162        self
163    }
164
165    /// Sets the number of files to download concurrently.
166    ///
167    /// Defaults to 4.
168    #[must_use]
169    pub fn concurrency(mut self, concurrency: usize) -> Self {
170        self.concurrency = Some(concurrency);
171        self
172    }
173
174    /// Sets a custom output directory for downloaded files.
175    ///
176    /// By default, files are stored in the standard `HuggingFace` cache directory
177    /// (`~/.cache/huggingface/hub/`). When set, the `HuggingFace` cache hierarchy
178    /// is created inside this directory instead.
179    #[must_use]
180    pub fn output_dir(mut self, dir: PathBuf) -> Self {
181        self.output_dir = Some(dir);
182        self
183    }
184
185    /// Sets the maximum time allowed per file download.
186    ///
187    /// If a single file download exceeds this duration, it is aborted
188    /// and may be retried according to the retry policy.
189    #[must_use]
190    pub fn timeout_per_file(mut self, duration: Duration) -> Self {
191        self.timeout_per_file = Some(duration);
192        self
193    }
194
195    /// Sets the maximum total time for the entire download operation.
196    ///
197    /// If the total download time exceeds this duration, remaining files
198    /// are skipped and a [`FetchError::Timeout`] is returned.
199    #[must_use]
200    pub fn timeout_total(mut self, duration: Duration) -> Self {
201        self.timeout_total = Some(duration);
202        self
203    }
204
205    /// Sets the maximum number of retry attempts per file.
206    ///
207    /// Defaults to 3. Set to 0 to disable retries.
208    /// Uses exponential backoff with jitter (base 300ms, cap 10s).
209    #[must_use]
210    pub fn max_retries(mut self, retries: u32) -> Self {
211        self.max_retries = Some(retries);
212        self
213    }
214
215    /// Enables or disables SHA256 checksum verification after download.
216    ///
217    /// When enabled, downloaded files are verified against the SHA256 hash
218    /// from `HuggingFace` LFS metadata. Files without LFS metadata (small
219    /// config files stored directly in git) are skipped.
220    ///
221    /// Defaults to `true`.
222    #[must_use]
223    pub fn verify_checksums(mut self, verify: bool) -> Self {
224        self.verify_checksums = Some(verify);
225        self
226    }
227
228    /// Sets the minimum file size (in bytes) for chunked parallel download.
229    ///
230    /// Files at or above this threshold are downloaded using multiple HTTP
231    /// Range connections in parallel. Files below use the standard single
232    /// connection. Defaults to 100 MiB (104\_857\_600 bytes). Set to
233    /// `u64::MAX` to disable chunked downloads entirely.
234    #[must_use]
235    pub fn chunk_threshold(mut self, bytes: u64) -> Self {
236        self.chunk_threshold = Some(bytes);
237        self
238    }
239
240    /// Sets the number of parallel HTTP connections per large file.
241    ///
242    /// Only applies to files at or above `chunk_threshold`. Defaults to 8.
243    #[must_use]
244    pub fn connections_per_file(mut self, connections: usize) -> Self {
245        self.connections_per_file = Some(connections);
246        self
247    }
248
249    /// Sets a progress callback invoked for each progress event.
250    #[must_use]
251    pub fn on_progress<F>(mut self, callback: F) -> Self
252    where
253        F: Fn(&ProgressEvent) + Send + Sync + 'static,
254    {
255        self.on_progress = Some(Arc::new(callback));
256        self
257    }
258
259    /// Builds the [`FetchConfig`].
260    ///
261    /// # Errors
262    ///
263    /// Returns [`FetchError::InvalidPattern`] if any glob pattern is invalid.
264    pub fn build(self) -> Result<FetchConfig, FetchError> {
265        let include = build_globset(&self.include_patterns)?;
266        let exclude = build_globset(&self.exclude_patterns)?;
267
268        Ok(FetchConfig {
269            revision: self.revision,
270            token: self.token,
271            include,
272            exclude,
273            concurrency: self.concurrency.unwrap_or(4),
274            output_dir: self.output_dir,
275            timeout_per_file: self.timeout_per_file,
276            timeout_total: self.timeout_total,
277            max_retries: self.max_retries.unwrap_or(3),
278            verify_checksums: self.verify_checksums.unwrap_or(true),
279            chunk_threshold: self.chunk_threshold.unwrap_or(104_857_600),
280            connections_per_file: self.connections_per_file.unwrap_or(8).max(1),
281            on_progress: self.on_progress,
282        })
283    }
284}
285
286/// Common filter presets for typical download patterns.
287#[non_exhaustive]
288pub struct Filter;
289
290impl Filter {
291    /// Returns a builder pre-configured to download only `*.safetensors` files
292    /// plus common config files.
293    #[must_use]
294    pub fn safetensors() -> FetchConfigBuilder {
295        FetchConfigBuilder::default()
296            .filter("*.safetensors")
297            .filter("*.json")
298            .filter("*.txt")
299    }
300
301    /// Returns a builder pre-configured to download only GGUF files
302    /// plus common config files.
303    #[must_use]
304    pub fn gguf() -> FetchConfigBuilder {
305        FetchConfigBuilder::default()
306            .filter("*.gguf")
307            .filter("*.json")
308            .filter("*.txt")
309    }
310
311    /// Returns a builder pre-configured to download only config files
312    /// (no model weights).
313    #[must_use]
314    pub fn config_only() -> FetchConfigBuilder {
315        FetchConfigBuilder::default()
316            .filter("*.json")
317            .filter("*.txt")
318            .filter("*.md")
319    }
320}
321
322/// Returns whether a filename passes the include/exclude filters.
323#[must_use]
324pub(crate) fn file_matches(
325    filename: &str,
326    include: Option<&GlobSet>,
327    exclude: Option<&GlobSet>,
328) -> bool {
329    if let Some(exc) = exclude {
330        if exc.is_match(filename) {
331            return false;
332        }
333    }
334    if let Some(inc) = include {
335        return inc.is_match(filename);
336    }
337    true
338}
339
340fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
341    if patterns.is_empty() {
342        return Ok(None);
343    }
344    let mut builder = GlobSetBuilder::new();
345    for pattern in patterns {
346        // BORROW: explicit .as_str() instead of Deref coercion
347        let glob = Glob::new(pattern.as_str()).map_err(|e| FetchError::InvalidPattern {
348            pattern: pattern.clone(),
349            reason: e.to_string(),
350        })?;
351        builder.add(glob);
352    }
353    let set = builder.build().map_err(|e| FetchError::InvalidPattern {
354        pattern: patterns.join(", "),
355        reason: e.to_string(),
356    })?;
357    Ok(Some(set))
358}
359
#[cfg(test)]
mod tests {
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Compiles `patterns` into a glob set, panicking on invalid input.
    fn globs(patterns: &[&str]) -> Option<GlobSet> {
        let owned: Vec<String> = patterns.iter().map(|&p| p.to_owned()).collect();
        build_globset(owned.as_slice()).unwrap()
    }

    #[test]
    fn test_file_matches_no_filters() {
        let accepted = file_matches("model.safetensors", None, None);
        assert!(accepted);
    }

    #[test]
    fn test_file_matches_include() {
        let only_safetensors = globs(&["*.safetensors"]);
        let inc = only_safetensors.as_ref();

        assert!(file_matches("model.safetensors", inc, None));
        assert!(!file_matches("model.bin", inc, None));
    }

    #[test]
    fn test_file_matches_exclude() {
        let no_bin = globs(&["*.bin"]);
        let exc = no_bin.as_ref();

        assert!(file_matches("model.safetensors", None, exc));
        assert!(!file_matches("model.bin", None, exc));
    }

    #[test]
    fn test_exclude_overrides_include() {
        let weights = globs(&["*.safetensors", "*.bin"]);
        let no_bin = globs(&["*.bin"]);
        let (inc, exc) = (weights.as_ref(), no_bin.as_ref());

        assert!(file_matches("model.safetensors", inc, exc));
        assert!(!file_matches("model.bin", inc, exc));
    }
}