1use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11
12use globset::{Glob, GlobSet, GlobSetBuilder};
13
14use crate::error::FetchError;
15use crate::progress::ProgressEvent;
16
17pub(crate) type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
19
20pub struct FetchConfig {
39 pub(crate) revision: Option<String>,
41 pub(crate) token: Option<String>,
43 pub(crate) include: Option<GlobSet>,
45 pub(crate) exclude: Option<GlobSet>,
47 pub(crate) concurrency: usize,
49 pub(crate) output_dir: Option<PathBuf>,
51 pub(crate) timeout_per_file: Option<Duration>,
53 pub(crate) timeout_total: Option<Duration>,
55 pub(crate) max_retries: u32,
57 pub(crate) verify_checksums: bool,
59 pub(crate) chunk_threshold: u64,
61 pub(crate) connections_per_file: usize,
63 pub(crate) on_progress: Option<ProgressCallback>,
66 pub(crate) explicit: ExplicitSettings,
68}
69
/// Flags recording which tuning values the caller supplied on the builder.
///
/// Each flag is `true` when the matching builder setter was called and
/// `false` when [`FetchConfigBuilder::build`] fell back to its default.
/// The derived `Default` yields all-`false`.
#[derive(Clone, Debug, Default)]
pub(crate) struct ExplicitSettings {
    /// `true` if `concurrency` was set on the builder.
    pub(crate) concurrency: bool,
    /// `true` if `connections_per_file` was set on the builder.
    pub(crate) connections_per_file: bool,
    /// `true` if `chunk_threshold` was set on the builder.
    pub(crate) chunk_threshold: bool,
}
83
84impl std::fmt::Debug for FetchConfig {
85 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86 f.debug_struct("FetchConfig")
87 .field("revision", &self.revision)
88 .field("token", &self.token.as_ref().map(|_| "***"))
89 .field("include", &self.include)
90 .field("exclude", &self.exclude)
91 .field("concurrency", &self.concurrency)
92 .field("output_dir", &self.output_dir)
93 .field("timeout_per_file", &self.timeout_per_file)
94 .field("timeout_total", &self.timeout_total)
95 .field("max_retries", &self.max_retries)
96 .field("verify_checksums", &self.verify_checksums)
97 .field("chunk_threshold", &self.chunk_threshold)
98 .field("connections_per_file", &self.connections_per_file)
99 .field(
100 "on_progress",
101 if self.on_progress.is_some() {
102 &"Some(<fn>)"
103 } else {
104 &"None"
105 },
106 )
107 .field("explicit", &self.explicit)
108 .finish()
109 }
110}
111
112impl FetchConfig {
113 #[must_use]
115 pub fn builder() -> FetchConfigBuilder {
116 FetchConfigBuilder::default()
117 }
118
119 #[must_use]
121 pub const fn concurrency(&self) -> usize {
122 self.concurrency
123 }
124
125 #[must_use]
127 pub const fn connections_per_file(&self) -> usize {
128 self.connections_per_file
129 }
130
131 #[must_use]
134 pub const fn chunk_threshold(&self) -> u64 {
135 self.chunk_threshold
136 }
137}
138
139#[derive(Default)]
141pub struct FetchConfigBuilder {
142 revision: Option<String>,
143 token: Option<String>,
144 include_patterns: Vec<String>,
145 exclude_patterns: Vec<String>,
146 concurrency: Option<usize>,
147 output_dir: Option<PathBuf>,
148 timeout_per_file: Option<Duration>,
149 timeout_total: Option<Duration>,
150 max_retries: Option<u32>,
151 verify_checksums: Option<bool>,
152 chunk_threshold: Option<u64>,
153 connections_per_file: Option<usize>,
154 on_progress: Option<ProgressCallback>,
155}
156
157impl FetchConfigBuilder {
158 #[must_use]
162 pub fn revision(mut self, revision: &str) -> Self {
163 self.revision = Some(revision.to_owned());
164 self
165 }
166
167 #[must_use]
169 pub fn token(mut self, token: &str) -> Self {
170 self.token = Some(token.to_owned());
171 self
172 }
173
174 #[must_use]
176 pub fn token_from_env(mut self) -> Self {
177 self.token = std::env::var("HF_TOKEN").ok();
178 self
179 }
180
181 #[must_use]
186 pub fn filter(mut self, pattern: &str) -> Self {
187 self.include_patterns.push(pattern.to_owned());
188 self
189 }
190
191 #[must_use]
196 pub fn exclude(mut self, pattern: &str) -> Self {
197 self.exclude_patterns.push(pattern.to_owned());
198 self
199 }
200
201 #[must_use]
207 pub fn concurrency(mut self, concurrency: usize) -> Self {
208 self.concurrency = Some(concurrency);
209 self
210 }
211
212 #[must_use]
218 pub fn output_dir(mut self, dir: PathBuf) -> Self {
219 self.output_dir = Some(dir);
220 self
221 }
222
223 #[must_use]
228 pub fn timeout_per_file(mut self, duration: Duration) -> Self {
229 self.timeout_per_file = Some(duration);
230 self
231 }
232
233 #[must_use]
238 pub fn timeout_total(mut self, duration: Duration) -> Self {
239 self.timeout_total = Some(duration);
240 self
241 }
242
243 #[must_use]
248 pub fn max_retries(mut self, retries: u32) -> Self {
249 self.max_retries = Some(retries);
250 self
251 }
252
253 #[must_use]
261 pub fn verify_checksums(mut self, verify: bool) -> Self {
262 self.verify_checksums = Some(verify);
263 self
264 }
265
266 #[must_use]
276 pub fn chunk_threshold(mut self, bytes: u64) -> Self {
277 self.chunk_threshold = Some(bytes);
278 self
279 }
280
281 #[must_use]
287 pub fn connections_per_file(mut self, connections: usize) -> Self {
288 self.connections_per_file = Some(connections);
289 self
290 }
291
292 #[must_use]
294 pub fn on_progress<F>(mut self, callback: F) -> Self
295 where
296 F: Fn(&ProgressEvent) + Send + Sync + 'static,
297 {
298 self.on_progress = Some(Arc::new(callback));
299 self
300 }
301
302 pub fn build(self) -> Result<FetchConfig, FetchError> {
308 let include = build_globset(&self.include_patterns)?;
309 let exclude = build_globset(&self.exclude_patterns)?;
310
311 Ok(FetchConfig {
312 revision: self.revision,
313 token: self.token,
314 include,
315 exclude,
316 concurrency: self.concurrency.unwrap_or(4),
317 output_dir: self.output_dir,
318 timeout_per_file: self.timeout_per_file,
319 timeout_total: self.timeout_total,
320 max_retries: self.max_retries.unwrap_or(3),
321 verify_checksums: self.verify_checksums.unwrap_or(true),
322 chunk_threshold: self.chunk_threshold.unwrap_or(104_857_600),
323 connections_per_file: self.connections_per_file.unwrap_or(8).max(1),
324 on_progress: self.on_progress,
325 explicit: ExplicitSettings {
326 concurrency: self.concurrency.is_some(),
327 connections_per_file: self.connections_per_file.is_some(),
328 chunk_threshold: self.chunk_threshold.is_some(),
329 },
330 })
331 }
332}
333
/// Namespace type for preset include filters.
///
/// It carries no data; its associated functions each return a
/// [`FetchConfigBuilder`] that already has a set of include patterns added.
#[non_exhaustive]
pub struct Filter;
337
338impl Filter {
339 #[must_use]
342 pub fn safetensors() -> FetchConfigBuilder {
343 FetchConfigBuilder::default()
344 .filter("*.safetensors")
345 .filter("*.json")
346 .filter("*.txt")
347 }
348
349 #[must_use]
352 pub fn gguf() -> FetchConfigBuilder {
353 FetchConfigBuilder::default()
354 .filter("*.gguf")
355 .filter("*.json")
356 .filter("*.txt")
357 }
358
359 #[must_use]
362 pub fn pth() -> FetchConfigBuilder {
363 FetchConfigBuilder::default()
364 .filter("pytorch_model*.bin")
365 .filter("*.json")
366 .filter("*.txt")
367 }
368
369 #[must_use]
372 pub fn config_only() -> FetchConfigBuilder {
373 FetchConfigBuilder::default()
374 .filter("*.json")
375 .filter("*.txt")
376 .filter("*.md")
377 }
378}
379
380#[must_use]
385pub fn file_matches(filename: &str, include: Option<&GlobSet>, exclude: Option<&GlobSet>) -> bool {
386 if let Some(exc) = exclude {
387 if exc.is_match(filename) {
388 return false;
389 }
390 }
391 if let Some(inc) = include {
392 return inc.is_match(filename);
393 }
394 true
395}
396
397fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
398 if patterns.is_empty() {
399 return Ok(None);
400 }
401 let mut builder = GlobSetBuilder::new();
402 for pattern in patterns {
403 let glob = Glob::new(pattern.as_str()).map_err(|e| FetchError::InvalidPattern {
405 pattern: pattern.clone(),
406 reason: e.to_string(),
407 })?;
408 builder.add(glob);
409 }
410 let set = builder.build().map_err(|e| FetchError::InvalidPattern {
411 pattern: patterns.join(", "),
412 reason: e.to_string(),
413 })?;
414 Ok(Some(set))
415}
416
417pub fn compile_glob_patterns(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
427 build_globset(patterns)
428}
429
/// Reports whether `s` contains any of the characters `*`, `?`, `[` or `{`.
///
/// Closing brackets/braces on their own do not count, and an empty string
/// returns `false`.
#[must_use]
pub fn has_glob_chars(s: &str) -> bool {
    // All four are ASCII, so a char search is equivalent to a byte scan.
    const GLOB_METACHARACTERS: &[char] = &['*', '?', '[', '{'];
    s.contains(GLOB_METACHARACTERS)
}
438
#[cfg(test)]
mod tests {
    // Tests may panic/unwrap freely; these lints apply to library code only.
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Compiles string-literal patterns into an optional glob set,
    /// panicking on invalid input.
    fn globs(patterns: &[&str]) -> Option<GlobSet> {
        let owned: Vec<String> = patterns.iter().map(|p| (*p).to_owned()).collect();
        build_globset(&owned).unwrap()
    }

    #[test]
    fn test_file_matches_no_filters() {
        // No include and no exclude: everything passes.
        assert!(file_matches("model.safetensors", None, None));
    }

    #[test]
    fn test_file_matches_include() {
        let only_safetensors = globs(&["*.safetensors"]);
        let include = only_safetensors.as_ref();

        assert!(file_matches("model.safetensors", include, None));
        assert!(!file_matches("model.bin", include, None));
    }

    #[test]
    fn test_file_matches_exclude() {
        let no_bin = globs(&["*.bin"]);
        let exclude = no_bin.as_ref();

        assert!(file_matches("model.safetensors", None, exclude));
        assert!(!file_matches("model.bin", None, exclude));
    }

    #[test]
    fn test_exclude_overrides_include() {
        let wanted = globs(&["*.safetensors", "*.bin"]);
        let unwanted = globs(&["*.bin"]);
        let (include, exclude) = (wanted.as_ref(), unwanted.as_ref());

        assert!(file_matches("model.safetensors", include, exclude));
        // Included by `*.bin` but excluded by the same pattern: exclude wins.
        assert!(!file_matches("model.bin", include, exclude));
    }

    #[test]
    fn test_has_glob_chars() {
        let with_meta = [
            "*.safetensors",
            "model-[0-9].bin",
            "model?.bin",
            "{a,b}.bin",
        ];
        for candidate in with_meta {
            assert!(has_glob_chars(candidate));
        }

        let plain = ["model.safetensors", "config.json", "pytorch_model.bin"];
        for candidate in plain {
            assert!(!has_glob_chars(candidate));
        }
    }
}