// hf_fetch_model/config.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2
3//! Configuration for model downloads.
4//!
5//! [`FetchConfig`] controls revision, authentication, file filtering,
6//! concurrency, timeouts, retry behavior, and progress reporting.
7
8use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11
12use globset::{Glob, GlobSet, GlobSetBuilder};
13
14use crate::error::FetchError;
15use crate::progress::ProgressEvent;
16
// TRAIT_OBJECT: heterogeneous progress handlers from different callers
/// Shared, thread-safe progress callback invoked with each [`ProgressEvent`].
/// `Arc` so the same handler can be cloned into concurrent download tasks.
pub(crate) type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
19
/// Configuration for downloading a model repository.
///
/// Use [`FetchConfig::builder()`] to construct.
///
/// # Example
///
/// ```rust
/// # fn example() -> Result<(), hf_fetch_model::FetchError> {
/// use hf_fetch_model::FetchConfig;
///
/// let config = FetchConfig::builder()
///     .revision("main")
///     .filter("*.safetensors")
///     .concurrency(4)
///     .build()?;
/// # Ok(())
/// # }
/// ```
pub struct FetchConfig {
    /// Git revision (branch, tag, or commit SHA). `None` means `"main"`.
    pub(crate) revision: Option<String>,
    /// Authentication token for gated/private repositories.
    pub(crate) token: Option<String>,
    /// Compiled include glob patterns. Only matching files are downloaded.
    /// `None` means no include filtering (every file passes).
    pub(crate) include: Option<GlobSet>,
    /// Compiled exclude glob patterns. Matching files are skipped.
    /// `None` means no exclude filtering.
    pub(crate) exclude: Option<GlobSet>,
    /// Number of files to download in parallel.
    pub(crate) concurrency: usize,
    /// Custom cache directory (overrides the default HF cache).
    pub(crate) output_dir: Option<PathBuf>,
    /// Maximum time allowed for a single file download.
    pub(crate) timeout_per_file: Option<Duration>,
    /// Maximum total time for the entire download operation.
    pub(crate) timeout_total: Option<Duration>,
    /// Maximum retry attempts per file (exponential backoff with jitter).
    pub(crate) max_retries: u32,
    /// Whether to verify SHA256 checksums against HF LFS metadata.
    pub(crate) verify_checksums: bool,
    /// Minimum file size (bytes) for multi-connection chunked download.
    pub(crate) chunk_threshold: u64,
    /// Number of parallel HTTP Range connections per large file.
    pub(crate) connections_per_file: usize,
    // TRAIT_OBJECT: heterogeneous progress handlers from different callers
    /// Progress callback invoked for each download event.
    pub(crate) on_progress: Option<ProgressCallback>,
    /// Tracks which performance fields the user explicitly set via the builder.
    pub(crate) explicit: ExplicitSettings,
}
69
/// Tracks which performance fields the user explicitly set via the builder.
///
/// Used by the implicit plan retrofit: when a field was not explicitly set,
/// the plan-based optimizer may override it with a recommended value.
///
/// `Default` yields all-`false`, i.e. "nothing was explicitly set".
#[derive(Debug, Clone, Default)]
pub(crate) struct ExplicitSettings {
    /// Whether `concurrency` was explicitly set by the caller.
    pub(crate) concurrency: bool,
    /// Whether `connections_per_file` was explicitly set by the caller.
    pub(crate) connections_per_file: bool,
    /// Whether `chunk_threshold` was explicitly set by the caller.
    pub(crate) chunk_threshold: bool,
}
83
84impl std::fmt::Debug for FetchConfig {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("FetchConfig")
87 .field("revision", &self.revision)
88 .field("token", &self.token.as_ref().map(|_| "***"))
89 .field("include", &self.include)
90 .field("exclude", &self.exclude)
91 .field("concurrency", &self.concurrency)
92 .field("output_dir", &self.output_dir)
93 .field("timeout_per_file", &self.timeout_per_file)
94 .field("timeout_total", &self.timeout_total)
95 .field("max_retries", &self.max_retries)
96 .field("verify_checksums", &self.verify_checksums)
97 .field("chunk_threshold", &self.chunk_threshold)
98 .field("connections_per_file", &self.connections_per_file)
99 .field(
100 "on_progress",
101 if self.on_progress.is_some() {
102 &"Some(<fn>)"
103 } else {
104 &"None"
105 },
106 )
107 .field("explicit", &self.explicit)
108 .finish()
109 }
110}
111
112impl FetchConfig {
113 /// Creates a new [`FetchConfigBuilder`].
114 #[must_use]
115 pub fn builder() -> FetchConfigBuilder {
116 FetchConfigBuilder::default()
117 }
118
119 /// Returns the configured concurrency level (parallel file downloads).
120 #[must_use]
121 pub const fn concurrency(&self) -> usize {
122 self.concurrency
123 }
124
125 /// Returns the configured number of parallel HTTP connections per file.
126 #[must_use]
127 pub const fn connections_per_file(&self) -> usize {
128 self.connections_per_file
129 }
130
131 /// Returns the chunk threshold in bytes (minimum file size for
132 /// multi-connection chunked downloads).
133 #[must_use]
134 pub const fn chunk_threshold(&self) -> u64 {
135 self.chunk_threshold
136 }
137}
138
/// Builder for [`FetchConfig`].
///
/// Every field is optional; unset values receive their defaults in
/// [`FetchConfigBuilder::build`].
#[derive(Default)]
pub struct FetchConfigBuilder {
    // Fields are `Option` (rather than holding defaults directly) so that
    // `build()` can record which performance knobs the caller explicitly
    // set, via `ExplicitSettings`.
    revision: Option<String>,
    token: Option<String>,
    include_patterns: Vec<String>,
    exclude_patterns: Vec<String>,
    concurrency: Option<usize>,
    output_dir: Option<PathBuf>,
    timeout_per_file: Option<Duration>,
    timeout_total: Option<Duration>,
    max_retries: Option<u32>,
    verify_checksums: Option<bool>,
    chunk_threshold: Option<u64>,
    connections_per_file: Option<usize>,
    on_progress: Option<ProgressCallback>,
}
156
157impl FetchConfigBuilder {
158 /// Sets the git revision (branch, tag, or commit SHA) to download.
159 ///
160 /// Defaults to `"main"` if not set.
161 #[must_use]
162 pub fn revision(mut self, revision: &str) -> Self {
163 // BORROW: explicit .to_owned() for &str → owned String
164 self.revision = Some(revision.to_owned());
165 self
166 }
167
168 /// Sets the authentication token.
169 #[must_use]
170 pub fn token(mut self, token: &str) -> Self {
171 // BORROW: explicit .to_owned() for &str → owned String
172 self.token = Some(token.to_owned());
173 self
174 }
175
176 /// Reads the authentication token from the `HF_TOKEN` environment variable.
177 #[must_use]
178 pub fn token_from_env(mut self) -> Self {
179 self.token = std::env::var("HF_TOKEN").ok();
180 self
181 }
182
183 /// Adds an include glob pattern. Only files matching at least one
184 /// include pattern will be downloaded.
185 ///
186 /// Can be called multiple times to add multiple patterns.
187 #[must_use]
188 pub fn filter(mut self, pattern: &str) -> Self {
189 // BORROW: explicit .to_owned() for &str → owned String
190 self.include_patterns.push(pattern.to_owned());
191 self
192 }
193
194 /// Adds an exclude glob pattern. Files matching any exclude pattern
195 /// will be skipped, even if they match an include pattern.
196 ///
197 /// Can be called multiple times to add multiple patterns.
198 #[must_use]
199 pub fn exclude(mut self, pattern: &str) -> Self {
200 // BORROW: explicit .to_owned() for &str → owned String
201 self.exclude_patterns.push(pattern.to_owned());
202 self
203 }
204
205 /// Sets the number of files to download concurrently.
206 ///
207 /// When omitted, the download plan optimizer auto-tunes this value
208 /// based on file count and size distribution. Falls back to 4 if
209 /// no plan recommendation is available.
210 #[must_use]
211 pub fn concurrency(mut self, concurrency: usize) -> Self {
212 self.concurrency = Some(concurrency);
213 self
214 }
215
216 /// Sets a custom output directory for downloaded files.
217 ///
218 /// By default, files are stored in the standard `HuggingFace` cache directory
219 /// (`~/.cache/huggingface/hub/`). When set, the `HuggingFace` cache hierarchy
220 /// is created inside this directory instead.
221 #[must_use]
222 pub fn output_dir(mut self, dir: PathBuf) -> Self {
223 self.output_dir = Some(dir);
224 self
225 }
226
227 /// Sets the maximum time allowed per file download.
228 ///
229 /// Defaults to 300 seconds when not set. If a single file download
230 /// exceeds this duration, it is aborted and may be retried according
231 /// to the retry policy.
232 ///
233 /// For chunked (multi-connection) downloads, an abort here does **not**
234 /// discard work already on disk: the partial `.chunked.part` file and
235 /// its per-chunk progress sidecar are preserved, and a subsequent call
236 /// for the same file resumes from each chunk's last checkpoint via
237 /// `Range` requests — provided the upstream etag still matches. Raise
238 /// this when downloading large files on slow links (e.g.
239 /// `Duration::from_secs(1800)` for 5–15 GiB files at home-broadband
240 /// speeds).
241 #[must_use]
242 pub fn timeout_per_file(mut self, duration: Duration) -> Self {
243 self.timeout_per_file = Some(duration);
244 self
245 }
246
247 /// Sets the maximum total time for the entire download operation.
248 ///
249 /// Defaults to no limit when not set. If the total download time
250 /// exceeds this duration, remaining files are skipped and a
251 /// [`FetchError::Timeout`] is returned.
252 ///
253 /// Independent of [`timeout_per_file`](Self::timeout_per_file): the
254 /// per-file budget bounds any single transfer, while this bounds the
255 /// whole batch (including retries between transfers). Files that
256 /// completed before the total elapsed are kept; an interrupted
257 /// in-flight chunked transfer leaves its partial + sidecar on disk
258 /// for resume on a future call.
259 #[must_use]
260 pub fn timeout_total(mut self, duration: Duration) -> Self {
261 self.timeout_total = Some(duration);
262 self
263 }
264
265 /// Sets the maximum number of retry attempts per file.
266 ///
267 /// Defaults to 3. Set to 0 to disable retries.
268 /// Uses exponential backoff with jitter (base 300ms, cap 10s).
269 #[must_use]
270 pub fn max_retries(mut self, retries: u32) -> Self {
271 self.max_retries = Some(retries);
272 self
273 }
274
275 /// Enables or disables SHA256 checksum verification after download.
276 ///
277 /// When enabled, downloaded files are verified against the SHA256 hash
278 /// from `HuggingFace` LFS metadata. Files without LFS metadata (small
279 /// config files stored directly in git) are skipped.
280 ///
281 /// Defaults to `true`.
282 #[must_use]
283 pub fn verify_checksums(mut self, verify: bool) -> Self {
284 self.verify_checksums = Some(verify);
285 self
286 }
287
288 /// Sets the minimum file size (in bytes) for chunked parallel download.
289 ///
290 /// Files at or above this threshold are downloaded using multiple HTTP
291 /// Range connections in parallel. Files below use the standard single
292 /// connection. Set to `u64::MAX` to disable chunked downloads entirely.
293 ///
294 /// When omitted, the download plan optimizer auto-tunes this value
295 /// based on file size distribution. Falls back to 100 MiB
296 /// (104\_857\_600 bytes) if no plan recommendation is available.
297 #[must_use]
298 pub fn chunk_threshold(mut self, bytes: u64) -> Self {
299 self.chunk_threshold = Some(bytes);
300 self
301 }
302
303 /// Sets the number of parallel HTTP connections per large file.
304 ///
305 /// Only applies to files at or above `chunk_threshold`. When omitted,
306 /// the download plan optimizer auto-tunes this value based on file size
307 /// distribution. Falls back to 8 if no plan recommendation is available.
308 #[must_use]
309 pub fn connections_per_file(mut self, connections: usize) -> Self {
310 self.connections_per_file = Some(connections);
311 self
312 }
313
314 /// Sets a progress callback invoked for each progress event.
315 #[must_use]
316 pub fn on_progress<F>(mut self, callback: F) -> Self
317 where
318 F: Fn(&ProgressEvent) + Send + Sync + 'static,
319 {
320 self.on_progress = Some(Arc::new(callback));
321 self
322 }
323
324 /// Creates a `tokio::sync::watch` channel for async progress consumption.
325 ///
326 /// Returns `(self, receiver)`. The receiver yields the latest [`ProgressEvent`]
327 /// via `.changed().await` + `.borrow()`. Only the most recent event is retained.
328 ///
329 /// The channel is initialized with [`ProgressEvent::default()`] (all zeros,
330 /// empty filename). Use `.changed().await` rather than eager `.borrow()` to
331 /// avoid observing this sentinel value before the first real event.
332 ///
333 /// Composes with [`on_progress()`](Self::on_progress) — if a callback was
334 /// already set, both the callback and the watch channel fire for every event.
335 #[must_use]
336 pub fn progress_channel(mut self) -> (Self, crate::progress::ProgressReceiver) {
337 let (tx, rx) = tokio::sync::watch::channel(ProgressEvent::default());
338 let existing = self.on_progress.take();
339 // TRAIT_OBJECT: heterogeneous progress handlers composed with watch sender
340 self.on_progress = Some(Arc::new(move |event: &ProgressEvent| {
341 if let Some(ref cb) = existing {
342 cb(event);
343 }
344 // Skip clone + send if no receiver is listening.
345 if tx.receiver_count() > 0 {
346 // BORROW: explicit .clone() for owned ProgressEvent sent through watch channel
347 let _ = tx.send(event.clone());
348 }
349 }));
350 (self, rx)
351 }
352
353 /// Builds the [`FetchConfig`].
354 ///
355 /// # Errors
356 ///
357 /// Returns [`FetchError::InvalidPattern`] if any glob pattern is invalid.
358 pub fn build(self) -> Result<FetchConfig, FetchError> {
359 let include = build_globset(&self.include_patterns)?;
360 let exclude = build_globset(&self.exclude_patterns)?;
361
362 Ok(FetchConfig {
363 revision: self.revision,
364 token: self.token,
365 include,
366 exclude,
367 concurrency: self.concurrency.unwrap_or(4),
368 output_dir: self.output_dir,
369 timeout_per_file: self.timeout_per_file,
370 timeout_total: self.timeout_total,
371 max_retries: self.max_retries.unwrap_or(3),
372 verify_checksums: self.verify_checksums.unwrap_or(true),
373 chunk_threshold: self.chunk_threshold.unwrap_or(104_857_600),
374 connections_per_file: self.connections_per_file.unwrap_or(8).max(1),
375 on_progress: self.on_progress,
376 explicit: ExplicitSettings {
377 concurrency: self.concurrency.is_some(),
378 connections_per_file: self.connections_per_file.is_some(),
379 chunk_threshold: self.chunk_threshold.is_some(),
380 },
381 })
382 }
383}
384
/// Common filter presets for typical download patterns.
///
/// A unit struct used purely as a namespace: each associated function
/// returns a `FetchConfigBuilder` pre-populated with include patterns.
#[non_exhaustive]
pub struct Filter;
388
389impl Filter {
390 /// Returns a builder pre-configured to download only `*.safetensors` files
391 /// plus common config files.
392 #[must_use]
393 pub fn safetensors() -> FetchConfigBuilder {
394 FetchConfigBuilder::default()
395 .filter("*.safetensors")
396 .filter("*.json")
397 .filter("*.txt")
398 }
399
400 /// Returns a builder pre-configured to download only GGUF files
401 /// plus common config files.
402 #[must_use]
403 pub fn gguf() -> FetchConfigBuilder {
404 FetchConfigBuilder::default()
405 .filter("*.gguf")
406 .filter("*.json")
407 .filter("*.txt")
408 }
409
410 /// Returns a builder pre-configured to download only `.npz` and `.npy`
411 /// files plus common config files. Matches NumPy-based weight repos
412 /// such as Google's `GemmaScope` transcoders (`config.yaml` + many `.npz`).
413 #[must_use]
414 pub fn npz() -> FetchConfigBuilder {
415 FetchConfigBuilder::default()
416 .filter("*.npz")
417 .filter("*.npy")
418 .filter("config.yaml")
419 .filter("*.json")
420 .filter("*.txt")
421 }
422
423 /// Returns a builder pre-configured to download only `pytorch_model*.bin`
424 /// files plus common config files.
425 #[must_use]
426 pub fn pth() -> FetchConfigBuilder {
427 FetchConfigBuilder::default()
428 .filter("pytorch_model*.bin")
429 .filter("*.json")
430 .filter("*.txt")
431 }
432
433 /// Returns a builder pre-configured to download only config files
434 /// (no model weights).
435 #[must_use]
436 pub fn config_only() -> FetchConfigBuilder {
437 FetchConfigBuilder::default()
438 .filter("*.json")
439 .filter("*.txt")
440 .filter("*.md")
441 }
442}
443
444/// Returns whether `filename` passes the given include/exclude glob filters.
445///
446/// A file matches when it is not excluded by any `exclude` pattern **and**
447/// either there are no `include` patterns or it matches at least one.
448#[must_use]
449pub fn file_matches(filename: &str, include: Option<&GlobSet>, exclude: Option<&GlobSet>) -> bool {
450 if let Some(exc) = exclude {
451 if exc.is_match(filename) {
452 return false;
453 }
454 }
455 if let Some(inc) = include {
456 return inc.is_match(filename);
457 }
458 true
459}
460
461fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
462 if patterns.is_empty() {
463 return Ok(None);
464 }
465 let mut builder = GlobSetBuilder::new();
466 for pattern in patterns {
467 // BORROW: explicit .as_str() instead of Deref coercion
468 let glob = Glob::new(pattern.as_str()).map_err(|e| FetchError::InvalidPattern {
469 pattern: pattern.clone(),
470 reason: e.to_string(),
471 })?;
472 builder.add(glob);
473 }
474 let set = builder.build().map_err(|e| FetchError::InvalidPattern {
475 pattern: patterns.join(", "),
476 reason: e.to_string(),
477 })?;
478 Ok(Some(set))
479}
480
/// Builds a compiled [`GlobSet`] from a list of pattern strings.
///
/// Returns `None` if the pattern list is empty. This is useful for callers
/// that need glob filtering outside the download pipeline (e.g., the
/// `list-files` subcommand).
///
/// # Errors
///
/// Returns [`FetchError::InvalidPattern`] if any pattern fails to compile.
pub fn compile_glob_patterns(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
    // Thin public wrapper over the crate-private `build_globset` helper.
    build_globset(patterns)
}
493
/// Returns `true` if `s` contains glob metacharacters (`*`, `?`, `[`, `{`).
///
/// Useful for detecting whether a user-supplied filename should be treated
/// as a glob pattern or an exact match.
#[must_use]
pub fn has_glob_chars(s: &str) -> bool {
    // All four metacharacters are ASCII, so scanning chars matches exactly
    // the same inputs as a raw byte scan would.
    s.chars().any(|c| matches!(c, '*' | '?' | '[' | '{'))
}
502
#[cfg(test)]
mod tests {
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Compiles string-literal patterns into a `GlobSet`, panicking on error.
    fn globs(patterns: &[&str]) -> Option<GlobSet> {
        let owned: Vec<String> = patterns.iter().map(|p| (*p).to_owned()).collect();
        build_globset(&owned).unwrap()
    }

    #[test]
    fn no_filters_accepts_everything() {
        assert!(file_matches("model.safetensors", None, None));
    }

    #[test]
    fn include_admits_only_matching_files() {
        let include = globs(&["*.safetensors"]);
        assert!(file_matches("model.safetensors", include.as_ref(), None));
        assert!(!file_matches("model.bin", include.as_ref(), None));
    }

    #[test]
    fn exclude_skips_matching_files() {
        let exclude = globs(&["*.bin"]);
        assert!(file_matches("model.safetensors", None, exclude.as_ref()));
        assert!(!file_matches("model.bin", None, exclude.as_ref()));
    }

    #[test]
    fn exclude_wins_over_include() {
        let include = globs(&["*.safetensors", "*.bin"]);
        let exclude = globs(&["*.bin"]);
        assert!(file_matches(
            "model.safetensors",
            include.as_ref(),
            exclude.as_ref()
        ));
        assert!(!file_matches(
            "model.bin",
            include.as_ref(),
            exclude.as_ref()
        ));
    }

    #[test]
    fn detects_glob_metacharacters() {
        for pattern in ["*.safetensors", "model-[0-9].bin", "model?.bin", "{a,b}.bin"] {
            assert!(has_glob_chars(pattern), "expected glob chars in {pattern}");
        }
        for plain in ["model.safetensors", "config.json", "pytorch_model.bin"] {
            assert!(!has_glob_chars(plain), "unexpected glob chars in {plain}");
        }
    }
}
554}