Skip to main content

hf_fetch_model/
config.rs

1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Configuration for model downloads.
4//!
5//! [`FetchConfig`] controls revision, authentication, file filtering,
6//! concurrency, timeouts, retry behavior, and progress reporting.
7
8use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11
12use globset::{Glob, GlobSet, GlobSetBuilder};
13
14use crate::error::FetchError;
15use crate::progress::ProgressEvent;
16
17// TRAIT_OBJECT: heterogeneous progress handlers from different callers
18type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
19
20/// Configuration for downloading a model repository.
21///
22/// Use [`FetchConfig::builder()`] to construct.
23///
24/// # Example
25///
26/// ```rust
27/// # fn example() -> Result<(), hf_fetch_model::FetchError> {
28/// use hf_fetch_model::FetchConfig;
29///
30/// let config = FetchConfig::builder()
31///     .revision("main")
32///     .filter("*.safetensors")
33///     .concurrency(4)
34///     .build()?;
35/// # Ok(())
36/// # }
37/// ```
38pub struct FetchConfig {
39    pub(crate) revision: Option<String>,
40    pub(crate) token: Option<String>,
41    pub(crate) include: Option<GlobSet>,
42    pub(crate) exclude: Option<GlobSet>,
43    pub(crate) concurrency: usize,
44    pub(crate) output_dir: Option<PathBuf>,
45    pub(crate) timeout_per_file: Option<Duration>,
46    pub(crate) timeout_total: Option<Duration>,
47    pub(crate) max_retries: u32,
48    pub(crate) verify_checksums: bool,
49    // TRAIT_OBJECT: heterogeneous progress handlers from different callers
50    pub(crate) on_progress: Option<ProgressCallback>,
51}
52
53impl std::fmt::Debug for FetchConfig {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        f.debug_struct("FetchConfig")
56            .field("revision", &self.revision)
57            .field("token", &self.token.as_ref().map(|_| "***"))
58            .field("include", &self.include)
59            .field("exclude", &self.exclude)
60            .field("concurrency", &self.concurrency)
61            .field("output_dir", &self.output_dir)
62            .field("timeout_per_file", &self.timeout_per_file)
63            .field("timeout_total", &self.timeout_total)
64            .field("max_retries", &self.max_retries)
65            .field("verify_checksums", &self.verify_checksums)
66            .field(
67                "on_progress",
68                if self.on_progress.is_some() {
69                    &"Some(<fn>)"
70                } else {
71                    &"None"
72                },
73            )
74            .finish()
75    }
76}
77
78impl FetchConfig {
79    /// Creates a new [`FetchConfigBuilder`].
80    #[must_use]
81    pub fn builder() -> FetchConfigBuilder {
82        FetchConfigBuilder::default()
83    }
84}
85
86/// Builder for [`FetchConfig`].
87#[derive(Default)]
88pub struct FetchConfigBuilder {
89    revision: Option<String>,
90    token: Option<String>,
91    include_patterns: Vec<String>,
92    exclude_patterns: Vec<String>,
93    concurrency: Option<usize>,
94    output_dir: Option<PathBuf>,
95    timeout_per_file: Option<Duration>,
96    timeout_total: Option<Duration>,
97    max_retries: Option<u32>,
98    verify_checksums: Option<bool>,
99    on_progress: Option<ProgressCallback>,
100}
101
102impl FetchConfigBuilder {
103    /// Sets the git revision (branch, tag, or commit SHA) to download.
104    ///
105    /// Defaults to `"main"` if not set.
106    #[must_use]
107    pub fn revision(mut self, revision: &str) -> Self {
108        self.revision = Some(revision.to_owned());
109        self
110    }
111
112    /// Sets the authentication token.
113    #[must_use]
114    pub fn token(mut self, token: &str) -> Self {
115        self.token = Some(token.to_owned());
116        self
117    }
118
119    /// Reads the authentication token from the `HF_TOKEN` environment variable.
120    #[must_use]
121    pub fn token_from_env(mut self) -> Self {
122        self.token = std::env::var("HF_TOKEN").ok();
123        self
124    }
125
126    /// Adds an include glob pattern. Only files matching at least one
127    /// include pattern will be downloaded.
128    ///
129    /// Can be called multiple times to add multiple patterns.
130    #[must_use]
131    pub fn filter(mut self, pattern: &str) -> Self {
132        self.include_patterns.push(pattern.to_owned());
133        self
134    }
135
136    /// Adds an exclude glob pattern. Files matching any exclude pattern
137    /// will be skipped, even if they match an include pattern.
138    ///
139    /// Can be called multiple times to add multiple patterns.
140    #[must_use]
141    pub fn exclude(mut self, pattern: &str) -> Self {
142        self.exclude_patterns.push(pattern.to_owned());
143        self
144    }
145
146    /// Sets the number of files to download concurrently.
147    ///
148    /// Defaults to 4.
149    #[must_use]
150    pub fn concurrency(mut self, concurrency: usize) -> Self {
151        self.concurrency = Some(concurrency);
152        self
153    }
154
155    /// Sets a custom output directory for downloaded files.
156    ///
157    /// By default, files are stored in the standard `HuggingFace` cache directory
158    /// (`~/.cache/huggingface/hub/`). When set, the `HuggingFace` cache hierarchy
159    /// is created inside this directory instead.
160    #[must_use]
161    pub fn output_dir(mut self, dir: PathBuf) -> Self {
162        self.output_dir = Some(dir);
163        self
164    }
165
166    /// Sets the maximum time allowed per file download.
167    ///
168    /// If a single file download exceeds this duration, it is aborted
169    /// and may be retried according to the retry policy.
170    #[must_use]
171    pub fn timeout_per_file(mut self, duration: Duration) -> Self {
172        self.timeout_per_file = Some(duration);
173        self
174    }
175
176    /// Sets the maximum total time for the entire download operation.
177    ///
178    /// If the total download time exceeds this duration, remaining files
179    /// are skipped and a [`FetchError::Timeout`] is returned.
180    #[must_use]
181    pub fn timeout_total(mut self, duration: Duration) -> Self {
182        self.timeout_total = Some(duration);
183        self
184    }
185
186    /// Sets the maximum number of retry attempts per file.
187    ///
188    /// Defaults to 3. Set to 0 to disable retries.
189    /// Uses exponential backoff with jitter (base 300ms, cap 10s).
190    #[must_use]
191    pub fn max_retries(mut self, retries: u32) -> Self {
192        self.max_retries = Some(retries);
193        self
194    }
195
196    /// Enables or disables SHA256 checksum verification after download.
197    ///
198    /// When enabled, downloaded files are verified against the SHA256 hash
199    /// from `HuggingFace` LFS metadata. Files without LFS metadata (small
200    /// config files stored directly in git) are skipped.
201    ///
202    /// Defaults to `true`.
203    #[must_use]
204    pub fn verify_checksums(mut self, verify: bool) -> Self {
205        self.verify_checksums = Some(verify);
206        self
207    }
208
209    /// Sets a progress callback invoked for each progress event.
210    #[must_use]
211    pub fn on_progress<F>(mut self, callback: F) -> Self
212    where
213        F: Fn(&ProgressEvent) + Send + Sync + 'static,
214    {
215        self.on_progress = Some(Arc::new(callback));
216        self
217    }
218
219    /// Builds the [`FetchConfig`].
220    ///
221    /// # Errors
222    ///
223    /// Returns [`FetchError::InvalidPattern`] if any glob pattern is invalid.
224    pub fn build(self) -> Result<FetchConfig, FetchError> {
225        let include = build_globset(&self.include_patterns)?;
226        let exclude = build_globset(&self.exclude_patterns)?;
227
228        Ok(FetchConfig {
229            revision: self.revision,
230            token: self.token,
231            include,
232            exclude,
233            concurrency: self.concurrency.unwrap_or(4),
234            output_dir: self.output_dir,
235            timeout_per_file: self.timeout_per_file,
236            timeout_total: self.timeout_total,
237            max_retries: self.max_retries.unwrap_or(3),
238            verify_checksums: self.verify_checksums.unwrap_or(true),
239            on_progress: self.on_progress,
240        })
241    }
242}
243
244/// Common filter presets for typical download patterns.
245#[non_exhaustive]
246pub struct Filter;
247
248impl Filter {
249    /// Returns a builder pre-configured to download only `*.safetensors` files
250    /// plus common config files.
251    #[must_use]
252    pub fn safetensors() -> FetchConfigBuilder {
253        FetchConfigBuilder::default()
254            .filter("*.safetensors")
255            .filter("*.json")
256            .filter("*.txt")
257    }
258
259    /// Returns a builder pre-configured to download only GGUF files
260    /// plus common config files.
261    #[must_use]
262    pub fn gguf() -> FetchConfigBuilder {
263        FetchConfigBuilder::default()
264            .filter("*.gguf")
265            .filter("*.json")
266            .filter("*.txt")
267    }
268
269    /// Returns a builder pre-configured to download only config files
270    /// (no model weights).
271    #[must_use]
272    pub fn config_only() -> FetchConfigBuilder {
273        FetchConfigBuilder::default()
274            .filter("*.json")
275            .filter("*.txt")
276            .filter("*.md")
277    }
278}
279
280/// Returns whether a filename passes the include/exclude filters.
281#[must_use]
282pub(crate) fn file_matches(
283    filename: &str,
284    include: Option<&GlobSet>,
285    exclude: Option<&GlobSet>,
286) -> bool {
287    if let Some(exc) = exclude {
288        if exc.is_match(filename) {
289            return false;
290        }
291    }
292    if let Some(inc) = include {
293        return inc.is_match(filename);
294    }
295    true
296}
297
298fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
299    if patterns.is_empty() {
300        return Ok(None);
301    }
302    let mut builder = GlobSetBuilder::new();
303    for pattern in patterns {
304        // BORROW: explicit .as_str() instead of Deref coercion
305        let glob = Glob::new(pattern.as_str()).map_err(|e| FetchError::InvalidPattern {
306            pattern: pattern.clone(),
307            reason: e.to_string(),
308        })?;
309        builder.add(glob);
310    }
311    let set = builder.build().map_err(|e| FetchError::InvalidPattern {
312        pattern: patterns.join(", "),
313        reason: e.to_string(),
314    })?;
315    Ok(Some(set))
316}
317
#[cfg(test)]
mod tests {
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Compiles `patterns` into a glob set, panicking if any is invalid.
    fn globs(patterns: &[&str]) -> Option<GlobSet> {
        let owned: Vec<String> = patterns.iter().map(|p| (*p).to_owned()).collect();
        build_globset(&owned).unwrap()
    }

    /// Without any filter every file is accepted.
    #[test]
    fn test_file_matches_no_filters() {
        let accepted = file_matches("model.safetensors", None, None);
        assert!(accepted);
    }

    /// An include set accepts only the names it matches.
    #[test]
    fn test_file_matches_include() {
        let only_safetensors = globs(&["*.safetensors"]);
        let include = only_safetensors.as_ref();

        assert!(file_matches("model.safetensors", include, None));
        assert!(!file_matches("model.bin", include, None));
    }

    /// An exclude set rejects the names it matches and nothing else.
    #[test]
    fn test_file_matches_exclude() {
        let no_bin = globs(&["*.bin"]);
        let exclude = no_bin.as_ref();

        assert!(file_matches("model.safetensors", None, exclude));
        assert!(!file_matches("model.bin", None, exclude));
    }

    /// A name matched by both sets is rejected: exclusion wins.
    #[test]
    fn test_exclude_overrides_include() {
        let weights = globs(&["*.safetensors", "*.bin"]);
        let no_bin = globs(&["*.bin"]);
        let (include, exclude) = (weights.as_ref(), no_bin.as_ref());

        assert!(file_matches("model.safetensors", include, exclude));
        assert!(!file_matches("model.bin", include, exclude));
    }
}
358}