1use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11
12use globset::{Glob, GlobSet, GlobSetBuilder};
13
14use crate::error::FetchError;
15use crate::progress::ProgressEvent;
16
17pub(crate) type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
19
20pub struct FetchConfig {
39 pub(crate) revision: Option<String>,
40 pub(crate) token: Option<String>,
41 pub(crate) include: Option<GlobSet>,
42 pub(crate) exclude: Option<GlobSet>,
43 pub(crate) concurrency: usize,
44 pub(crate) output_dir: Option<PathBuf>,
45 pub(crate) timeout_per_file: Option<Duration>,
46 pub(crate) timeout_total: Option<Duration>,
47 pub(crate) max_retries: u32,
48 pub(crate) verify_checksums: bool,
49 pub(crate) chunk_threshold: u64,
50 pub(crate) connections_per_file: usize,
51 pub(crate) on_progress: Option<ProgressCallback>,
53}
54
55impl std::fmt::Debug for FetchConfig {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 f.debug_struct("FetchConfig")
58 .field("revision", &self.revision)
59 .field("token", &self.token.as_ref().map(|_| "***"))
60 .field("include", &self.include)
61 .field("exclude", &self.exclude)
62 .field("concurrency", &self.concurrency)
63 .field("output_dir", &self.output_dir)
64 .field("timeout_per_file", &self.timeout_per_file)
65 .field("timeout_total", &self.timeout_total)
66 .field("max_retries", &self.max_retries)
67 .field("verify_checksums", &self.verify_checksums)
68 .field("chunk_threshold", &self.chunk_threshold)
69 .field("connections_per_file", &self.connections_per_file)
70 .field(
71 "on_progress",
72 if self.on_progress.is_some() {
73 &"Some(<fn>)"
74 } else {
75 &"None"
76 },
77 )
78 .finish()
79 }
80}
81
82impl FetchConfig {
83 #[must_use]
85 pub fn builder() -> FetchConfigBuilder {
86 FetchConfigBuilder::default()
87 }
88}
89
90#[derive(Default)]
92pub struct FetchConfigBuilder {
93 revision: Option<String>,
94 token: Option<String>,
95 include_patterns: Vec<String>,
96 exclude_patterns: Vec<String>,
97 concurrency: Option<usize>,
98 output_dir: Option<PathBuf>,
99 timeout_per_file: Option<Duration>,
100 timeout_total: Option<Duration>,
101 max_retries: Option<u32>,
102 verify_checksums: Option<bool>,
103 chunk_threshold: Option<u64>,
104 connections_per_file: Option<usize>,
105 on_progress: Option<ProgressCallback>,
106}
107
108impl FetchConfigBuilder {
109 #[must_use]
113 pub fn revision(mut self, revision: &str) -> Self {
114 self.revision = Some(revision.to_owned());
115 self
116 }
117
118 #[must_use]
120 pub fn token(mut self, token: &str) -> Self {
121 self.token = Some(token.to_owned());
122 self
123 }
124
125 #[must_use]
127 pub fn token_from_env(mut self) -> Self {
128 self.token = std::env::var("HF_TOKEN").ok();
129 self
130 }
131
132 #[must_use]
137 pub fn filter(mut self, pattern: &str) -> Self {
138 self.include_patterns.push(pattern.to_owned());
139 self
140 }
141
142 #[must_use]
147 pub fn exclude(mut self, pattern: &str) -> Self {
148 self.exclude_patterns.push(pattern.to_owned());
149 self
150 }
151
152 #[must_use]
156 pub fn concurrency(mut self, concurrency: usize) -> Self {
157 self.concurrency = Some(concurrency);
158 self
159 }
160
161 #[must_use]
167 pub fn output_dir(mut self, dir: PathBuf) -> Self {
168 self.output_dir = Some(dir);
169 self
170 }
171
172 #[must_use]
177 pub fn timeout_per_file(mut self, duration: Duration) -> Self {
178 self.timeout_per_file = Some(duration);
179 self
180 }
181
182 #[must_use]
187 pub fn timeout_total(mut self, duration: Duration) -> Self {
188 self.timeout_total = Some(duration);
189 self
190 }
191
192 #[must_use]
197 pub fn max_retries(mut self, retries: u32) -> Self {
198 self.max_retries = Some(retries);
199 self
200 }
201
202 #[must_use]
210 pub fn verify_checksums(mut self, verify: bool) -> Self {
211 self.verify_checksums = Some(verify);
212 self
213 }
214
215 #[must_use]
222 pub fn chunk_threshold(mut self, bytes: u64) -> Self {
223 self.chunk_threshold = Some(bytes);
224 self
225 }
226
227 #[must_use]
231 pub fn connections_per_file(mut self, connections: usize) -> Self {
232 self.connections_per_file = Some(connections);
233 self
234 }
235
236 #[must_use]
238 pub fn on_progress<F>(mut self, callback: F) -> Self
239 where
240 F: Fn(&ProgressEvent) + Send + Sync + 'static,
241 {
242 self.on_progress = Some(Arc::new(callback));
243 self
244 }
245
246 pub fn build(self) -> Result<FetchConfig, FetchError> {
252 let include = build_globset(&self.include_patterns)?;
253 let exclude = build_globset(&self.exclude_patterns)?;
254
255 Ok(FetchConfig {
256 revision: self.revision,
257 token: self.token,
258 include,
259 exclude,
260 concurrency: self.concurrency.unwrap_or(4),
261 output_dir: self.output_dir,
262 timeout_per_file: self.timeout_per_file,
263 timeout_total: self.timeout_total,
264 max_retries: self.max_retries.unwrap_or(3),
265 verify_checksums: self.verify_checksums.unwrap_or(true),
266 chunk_threshold: self.chunk_threshold.unwrap_or(104_857_600),
267 connections_per_file: self.connections_per_file.unwrap_or(8).max(1),
268 on_progress: self.on_progress,
269 })
270 }
271}
272
273#[non_exhaustive]
275pub struct Filter;
276
277impl Filter {
278 #[must_use]
281 pub fn safetensors() -> FetchConfigBuilder {
282 FetchConfigBuilder::default()
283 .filter("*.safetensors")
284 .filter("*.json")
285 .filter("*.txt")
286 }
287
288 #[must_use]
291 pub fn gguf() -> FetchConfigBuilder {
292 FetchConfigBuilder::default()
293 .filter("*.gguf")
294 .filter("*.json")
295 .filter("*.txt")
296 }
297
298 #[must_use]
301 pub fn config_only() -> FetchConfigBuilder {
302 FetchConfigBuilder::default()
303 .filter("*.json")
304 .filter("*.txt")
305 .filter("*.md")
306 }
307}
308
309#[must_use]
311pub(crate) fn file_matches(
312 filename: &str,
313 include: Option<&GlobSet>,
314 exclude: Option<&GlobSet>,
315) -> bool {
316 if let Some(exc) = exclude {
317 if exc.is_match(filename) {
318 return false;
319 }
320 }
321 if let Some(inc) = include {
322 return inc.is_match(filename);
323 }
324 true
325}
326
327fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
328 if patterns.is_empty() {
329 return Ok(None);
330 }
331 let mut builder = GlobSetBuilder::new();
332 for pattern in patterns {
333 let glob = Glob::new(pattern.as_str()).map_err(|e| FetchError::InvalidPattern {
335 pattern: pattern.clone(),
336 reason: e.to_string(),
337 })?;
338 builder.add(glob);
339 }
340 let set = builder.build().map_err(|e| FetchError::InvalidPattern {
341 pattern: patterns.join(", "),
342 reason: e.to_string(),
343 })?;
344 Ok(Some(set))
345}
346
#[cfg(test)]
mod tests {
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Compiles string-literal patterns with the same helper `build()` uses.
    fn globs(patterns: &[&str]) -> Option<GlobSet> {
        let owned: Vec<String> = patterns.iter().map(|p| (*p).to_owned()).collect();
        build_globset(&owned).unwrap()
    }

    #[test]
    fn test_file_matches_no_filters() {
        assert!(file_matches("model.safetensors", None, None));
    }

    #[test]
    fn test_file_matches_include() {
        let include = globs(&["*.safetensors"]);
        assert!(file_matches("model.safetensors", include.as_ref(), None));
        assert!(!file_matches("model.bin", include.as_ref(), None));
    }

    #[test]
    fn test_file_matches_exclude() {
        let exclude = globs(&["*.bin"]);
        assert!(file_matches("model.safetensors", None, exclude.as_ref()));
        assert!(!file_matches("model.bin", None, exclude.as_ref()));
    }

    #[test]
    fn test_exclude_overrides_include() {
        let include = globs(&["*.safetensors", "*.bin"]);
        let exclude = globs(&["*.bin"]);
        assert!(file_matches(
            "model.safetensors",
            include.as_ref(),
            exclude.as_ref()
        ));
        assert!(!file_matches(
            "model.bin",
            include.as_ref(),
            exclude.as_ref()
        ));
    }

    #[test]
    fn test_empty_patterns_build_no_globset() {
        assert!(build_globset(&[]).unwrap().is_none());
    }

    #[test]
    fn test_invalid_pattern_is_rejected() {
        let err = build_globset(&["[".to_owned()]).unwrap_err();
        assert!(matches!(
            &err,
            FetchError::InvalidPattern { pattern, .. } if pattern == "["
        ));
        assert!(FetchConfig::builder().exclude("[").build().is_err());
    }

    #[test]
    fn test_builder_defaults() {
        let config = FetchConfig::builder().build().unwrap();
        assert_eq!(config.concurrency, 4);
        assert_eq!(config.max_retries, 3);
        assert!(config.verify_checksums);
        assert_eq!(config.chunk_threshold, 100 * 1024 * 1024);
        assert_eq!(config.connections_per_file, 8);
        assert!(config.include.is_none());
        assert!(config.exclude.is_none());
    }

    #[test]
    fn test_zero_connections_per_file_is_clamped() {
        let config = FetchConfig::builder()
            .connections_per_file(0)
            .build()
            .unwrap();
        assert_eq!(config.connections_per_file, 1);
    }

    #[test]
    fn test_debug_masks_token() {
        let config = FetchConfig::builder().token("hf_secret").build().unwrap();
        let rendered = format!("{config:?}");
        assert!(!rendered.contains("hf_secret"));
        assert!(rendered.contains("***"));
    }

    #[test]
    fn test_safetensors_preset() {
        let config = Filter::safetensors().build().unwrap();
        let include = config.include.as_ref();
        assert!(file_matches("model.safetensors", include, None));
        assert!(file_matches("config.json", include, None));
        assert!(!file_matches("pytorch_model.bin", include, None));
    }
}