1use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11
12use globset::{Glob, GlobSet, GlobSetBuilder};
13
14use crate::error::FetchError;
15use crate::progress::ProgressEvent;
16
/// Shared progress callback: a reference-counted closure that takes a
/// borrowed [`ProgressEvent`] and returns nothing.
///
/// The `Send + Sync` bounds allow the same callback to be held and called
/// from multiple threads; `Arc` lets it be shared without copying the closure.
pub(crate) type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
19
20pub struct FetchConfig {
39 pub(crate) revision: Option<String>,
41 pub(crate) token: Option<String>,
43 pub(crate) include: Option<GlobSet>,
45 pub(crate) exclude: Option<GlobSet>,
47 pub(crate) concurrency: usize,
49 pub(crate) output_dir: Option<PathBuf>,
51 pub(crate) timeout_per_file: Option<Duration>,
53 pub(crate) timeout_total: Option<Duration>,
55 pub(crate) max_retries: u32,
57 pub(crate) verify_checksums: bool,
59 pub(crate) chunk_threshold: u64,
61 pub(crate) connections_per_file: usize,
63 pub(crate) on_progress: Option<ProgressCallback>,
66}
67
68impl std::fmt::Debug for FetchConfig {
69 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70 f.debug_struct("FetchConfig")
71 .field("revision", &self.revision)
72 .field("token", &self.token.as_ref().map(|_| "***"))
73 .field("include", &self.include)
74 .field("exclude", &self.exclude)
75 .field("concurrency", &self.concurrency)
76 .field("output_dir", &self.output_dir)
77 .field("timeout_per_file", &self.timeout_per_file)
78 .field("timeout_total", &self.timeout_total)
79 .field("max_retries", &self.max_retries)
80 .field("verify_checksums", &self.verify_checksums)
81 .field("chunk_threshold", &self.chunk_threshold)
82 .field("connections_per_file", &self.connections_per_file)
83 .field(
84 "on_progress",
85 if self.on_progress.is_some() {
86 &"Some(<fn>)"
87 } else {
88 &"None"
89 },
90 )
91 .finish()
92 }
93}
94
95impl FetchConfig {
96 #[must_use]
98 pub fn builder() -> FetchConfigBuilder {
99 FetchConfigBuilder::default()
100 }
101}
102
103#[derive(Default)]
105pub struct FetchConfigBuilder {
106 revision: Option<String>,
107 token: Option<String>,
108 include_patterns: Vec<String>,
109 exclude_patterns: Vec<String>,
110 concurrency: Option<usize>,
111 output_dir: Option<PathBuf>,
112 timeout_per_file: Option<Duration>,
113 timeout_total: Option<Duration>,
114 max_retries: Option<u32>,
115 verify_checksums: Option<bool>,
116 chunk_threshold: Option<u64>,
117 connections_per_file: Option<usize>,
118 on_progress: Option<ProgressCallback>,
119}
120
121impl FetchConfigBuilder {
122 #[must_use]
126 pub fn revision(mut self, revision: &str) -> Self {
127 self.revision = Some(revision.to_owned());
128 self
129 }
130
131 #[must_use]
133 pub fn token(mut self, token: &str) -> Self {
134 self.token = Some(token.to_owned());
135 self
136 }
137
138 #[must_use]
140 pub fn token_from_env(mut self) -> Self {
141 self.token = std::env::var("HF_TOKEN").ok();
142 self
143 }
144
145 #[must_use]
150 pub fn filter(mut self, pattern: &str) -> Self {
151 self.include_patterns.push(pattern.to_owned());
152 self
153 }
154
155 #[must_use]
160 pub fn exclude(mut self, pattern: &str) -> Self {
161 self.exclude_patterns.push(pattern.to_owned());
162 self
163 }
164
165 #[must_use]
169 pub fn concurrency(mut self, concurrency: usize) -> Self {
170 self.concurrency = Some(concurrency);
171 self
172 }
173
174 #[must_use]
180 pub fn output_dir(mut self, dir: PathBuf) -> Self {
181 self.output_dir = Some(dir);
182 self
183 }
184
185 #[must_use]
190 pub fn timeout_per_file(mut self, duration: Duration) -> Self {
191 self.timeout_per_file = Some(duration);
192 self
193 }
194
195 #[must_use]
200 pub fn timeout_total(mut self, duration: Duration) -> Self {
201 self.timeout_total = Some(duration);
202 self
203 }
204
205 #[must_use]
210 pub fn max_retries(mut self, retries: u32) -> Self {
211 self.max_retries = Some(retries);
212 self
213 }
214
215 #[must_use]
223 pub fn verify_checksums(mut self, verify: bool) -> Self {
224 self.verify_checksums = Some(verify);
225 self
226 }
227
228 #[must_use]
235 pub fn chunk_threshold(mut self, bytes: u64) -> Self {
236 self.chunk_threshold = Some(bytes);
237 self
238 }
239
240 #[must_use]
244 pub fn connections_per_file(mut self, connections: usize) -> Self {
245 self.connections_per_file = Some(connections);
246 self
247 }
248
249 #[must_use]
251 pub fn on_progress<F>(mut self, callback: F) -> Self
252 where
253 F: Fn(&ProgressEvent) + Send + Sync + 'static,
254 {
255 self.on_progress = Some(Arc::new(callback));
256 self
257 }
258
259 pub fn build(self) -> Result<FetchConfig, FetchError> {
265 let include = build_globset(&self.include_patterns)?;
266 let exclude = build_globset(&self.exclude_patterns)?;
267
268 Ok(FetchConfig {
269 revision: self.revision,
270 token: self.token,
271 include,
272 exclude,
273 concurrency: self.concurrency.unwrap_or(4),
274 output_dir: self.output_dir,
275 timeout_per_file: self.timeout_per_file,
276 timeout_total: self.timeout_total,
277 max_retries: self.max_retries.unwrap_or(3),
278 verify_checksums: self.verify_checksums.unwrap_or(true),
279 chunk_threshold: self.chunk_threshold.unwrap_or(104_857_600),
280 connections_per_file: self.connections_per_file.unwrap_or(8).max(1),
281 on_progress: self.on_progress,
282 })
283 }
284}
285
286#[non_exhaustive]
288pub struct Filter;
289
290impl Filter {
291 #[must_use]
294 pub fn safetensors() -> FetchConfigBuilder {
295 FetchConfigBuilder::default()
296 .filter("*.safetensors")
297 .filter("*.json")
298 .filter("*.txt")
299 }
300
301 #[must_use]
304 pub fn gguf() -> FetchConfigBuilder {
305 FetchConfigBuilder::default()
306 .filter("*.gguf")
307 .filter("*.json")
308 .filter("*.txt")
309 }
310
311 #[must_use]
314 pub fn config_only() -> FetchConfigBuilder {
315 FetchConfigBuilder::default()
316 .filter("*.json")
317 .filter("*.txt")
318 .filter("*.md")
319 }
320}
321
322#[must_use]
324pub(crate) fn file_matches(
325 filename: &str,
326 include: Option<&GlobSet>,
327 exclude: Option<&GlobSet>,
328) -> bool {
329 if let Some(exc) = exclude {
330 if exc.is_match(filename) {
331 return false;
332 }
333 }
334 if let Some(inc) = include {
335 return inc.is_match(filename);
336 }
337 true
338}
339
340fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
341 if patterns.is_empty() {
342 return Ok(None);
343 }
344 let mut builder = GlobSetBuilder::new();
345 for pattern in patterns {
346 let glob = Glob::new(pattern.as_str()).map_err(|e| FetchError::InvalidPattern {
348 pattern: pattern.clone(),
349 reason: e.to_string(),
350 })?;
351 builder.add(glob);
352 }
353 let set = builder.build().map_err(|e| FetchError::InvalidPattern {
354 pattern: patterns.join(", "),
355 reason: e.to_string(),
356 })?;
357 Ok(Some(set))
358}
359
#[cfg(test)]
mod tests {
    // Tests are allowed to panic/unwrap: a failure here *is* the signal.
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Compiles `patterns` through `build_globset`, panicking on invalid globs.
    fn globs(patterns: &[&str]) -> Option<GlobSet> {
        let owned: Vec<String> = patterns.iter().map(|p| (*p).to_owned()).collect();
        build_globset(&owned).unwrap()
    }

    /// With no filters at all, every file is accepted.
    #[test]
    fn test_file_matches_no_filters() {
        assert!(file_matches("model.safetensors", None, None));
    }

    /// An include set accepts only matching files.
    #[test]
    fn test_file_matches_include() {
        let allow = globs(&["*.safetensors"]);

        assert!(file_matches("model.safetensors", allow.as_ref(), None));
        assert!(!file_matches("model.bin", allow.as_ref(), None));
    }

    /// An exclude set rejects matching files and lets the rest through.
    #[test]
    fn test_file_matches_exclude() {
        let deny = globs(&["*.bin"]);

        assert!(file_matches("model.safetensors", None, deny.as_ref()));
        assert!(!file_matches("model.bin", None, deny.as_ref()));
    }

    /// A file matched by both sets is rejected: exclusion takes precedence.
    #[test]
    fn test_exclude_overrides_include() {
        let allow = globs(&["*.safetensors", "*.bin"]);
        let deny = globs(&["*.bin"]);

        let keeps_safetensors = file_matches("model.safetensors", allow.as_ref(), deny.as_ref());
        let keeps_bin = file_matches("model.bin", allow.as_ref(), deny.as_ref());

        assert!(keeps_safetensors);
        assert!(!keeps_bin);
    }
}