1use std::path::PathBuf;
9use std::sync::Arc;
10use std::time::Duration;
11
12use globset::{Glob, GlobSet, GlobSetBuilder};
13
14use crate::error::FetchError;
15use crate::progress::ProgressEvent;
16
/// Shared, thread-safe callback that is handed a reference to each [`ProgressEvent`].
///
/// Wrapped in an [`Arc`] so the same callback can be stored in the builder and then
/// moved into the final configuration without requiring the closure to be `Clone`.
type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
19
/// Validated fetch settings produced by [`FetchConfigBuilder::build`].
///
/// All fields are crate-private; construct a value through [`FetchConfig::builder`].
pub struct FetchConfig {
    /// Revision supplied through the builder, or `None` when it was never set.
    // NOTE(review): presumably a branch, tag, or commit of the remote repository — confirm.
    pub(crate) revision: Option<String>,
    /// Authentication token, if any. Masked as `***` in the `Debug` output.
    pub(crate) token: Option<String>,
    /// Compiled include globs; `None` when no include pattern was added.
    pub(crate) include: Option<GlobSet>,
    /// Compiled exclude globs; `None` when no exclude pattern was added.
    pub(crate) exclude: Option<GlobSet>,
    /// Concurrency level; the builder falls back to 4 when it is not set.
    // NOTE(review): presumably the number of simultaneous downloads — confirm at the use site.
    pub(crate) concurrency: usize,
    /// Output directory supplied through the builder, or `None` when it was never set.
    pub(crate) output_dir: Option<PathBuf>,
    /// Optional per-file timeout; `None` when it was never set.
    pub(crate) timeout_per_file: Option<Duration>,
    /// Optional overall timeout; `None` when it was never set.
    pub(crate) timeout_total: Option<Duration>,
    /// Retry budget; the builder falls back to 3 when it is not set.
    pub(crate) max_retries: u32,
    /// Whether checksums should be verified; the builder falls back to `true`.
    pub(crate) verify_checksums: bool,
    /// Optional progress callback registered through the builder.
    pub(crate) on_progress: Option<ProgressCallback>,
}
52
53impl std::fmt::Debug for FetchConfig {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 f.debug_struct("FetchConfig")
56 .field("revision", &self.revision)
57 .field("token", &self.token.as_ref().map(|_| "***"))
58 .field("include", &self.include)
59 .field("exclude", &self.exclude)
60 .field("concurrency", &self.concurrency)
61 .field("output_dir", &self.output_dir)
62 .field("timeout_per_file", &self.timeout_per_file)
63 .field("timeout_total", &self.timeout_total)
64 .field("max_retries", &self.max_retries)
65 .field("verify_checksums", &self.verify_checksums)
66 .field(
67 "on_progress",
68 if self.on_progress.is_some() {
69 &"Some(<fn>)"
70 } else {
71 &"None"
72 },
73 )
74 .finish()
75 }
76}
77
78impl FetchConfig {
79 #[must_use]
81 pub fn builder() -> FetchConfigBuilder {
82 FetchConfigBuilder::default()
83 }
84}
85
/// Builder for [`FetchConfig`].
///
/// Every field starts out empty (`None` / empty `Vec`); values left unset are
/// replaced by defaults in `build`.
#[derive(Default)]
pub struct FetchConfigBuilder {
    /// Revision to request, if one was given.
    revision: Option<String>,
    /// Authentication token, if one was given.
    token: Option<String>,
    /// Raw include glob patterns, compiled into a `GlobSet` by `build`.
    include_patterns: Vec<String>,
    /// Raw exclude glob patterns, compiled into a `GlobSet` by `build`.
    exclude_patterns: Vec<String>,
    /// Requested concurrency level, if one was given.
    concurrency: Option<usize>,
    /// Requested output directory, if one was given.
    output_dir: Option<PathBuf>,
    /// Requested per-file timeout, if one was given.
    timeout_per_file: Option<Duration>,
    /// Requested overall timeout, if one was given.
    timeout_total: Option<Duration>,
    /// Requested retry budget, if one was given.
    max_retries: Option<u32>,
    /// Requested checksum-verification setting, if one was given.
    verify_checksums: Option<bool>,
    /// Progress callback, if one was registered.
    on_progress: Option<ProgressCallback>,
}
101
102impl FetchConfigBuilder {
103 #[must_use]
107 pub fn revision(mut self, revision: &str) -> Self {
108 self.revision = Some(revision.to_owned());
109 self
110 }
111
112 #[must_use]
114 pub fn token(mut self, token: &str) -> Self {
115 self.token = Some(token.to_owned());
116 self
117 }
118
119 #[must_use]
121 pub fn token_from_env(mut self) -> Self {
122 self.token = std::env::var("HF_TOKEN").ok();
123 self
124 }
125
126 #[must_use]
131 pub fn filter(mut self, pattern: &str) -> Self {
132 self.include_patterns.push(pattern.to_owned());
133 self
134 }
135
136 #[must_use]
141 pub fn exclude(mut self, pattern: &str) -> Self {
142 self.exclude_patterns.push(pattern.to_owned());
143 self
144 }
145
146 #[must_use]
150 pub fn concurrency(mut self, concurrency: usize) -> Self {
151 self.concurrency = Some(concurrency);
152 self
153 }
154
155 #[must_use]
161 pub fn output_dir(mut self, dir: PathBuf) -> Self {
162 self.output_dir = Some(dir);
163 self
164 }
165
166 #[must_use]
171 pub fn timeout_per_file(mut self, duration: Duration) -> Self {
172 self.timeout_per_file = Some(duration);
173 self
174 }
175
176 #[must_use]
181 pub fn timeout_total(mut self, duration: Duration) -> Self {
182 self.timeout_total = Some(duration);
183 self
184 }
185
186 #[must_use]
191 pub fn max_retries(mut self, retries: u32) -> Self {
192 self.max_retries = Some(retries);
193 self
194 }
195
196 #[must_use]
204 pub fn verify_checksums(mut self, verify: bool) -> Self {
205 self.verify_checksums = Some(verify);
206 self
207 }
208
209 #[must_use]
211 pub fn on_progress<F>(mut self, callback: F) -> Self
212 where
213 F: Fn(&ProgressEvent) + Send + Sync + 'static,
214 {
215 self.on_progress = Some(Arc::new(callback));
216 self
217 }
218
219 pub fn build(self) -> Result<FetchConfig, FetchError> {
225 let include = build_globset(&self.include_patterns)?;
226 let exclude = build_globset(&self.exclude_patterns)?;
227
228 Ok(FetchConfig {
229 revision: self.revision,
230 token: self.token,
231 include,
232 exclude,
233 concurrency: self.concurrency.unwrap_or(4),
234 output_dir: self.output_dir,
235 timeout_per_file: self.timeout_per_file,
236 timeout_total: self.timeout_total,
237 max_retries: self.max_retries.unwrap_or(3),
238 verify_checksums: self.verify_checksums.unwrap_or(true),
239 on_progress: self.on_progress,
240 })
241 }
242}
243
/// Namespace for ready-made filter presets; see the associated functions.
///
/// Marked `#[non_exhaustive]` so code outside this crate cannot construct a
/// value of this type.
#[non_exhaustive]
pub struct Filter;
247
248impl Filter {
249 #[must_use]
252 pub fn safetensors() -> FetchConfigBuilder {
253 FetchConfigBuilder::default()
254 .filter("*.safetensors")
255 .filter("*.json")
256 .filter("*.txt")
257 }
258
259 #[must_use]
262 pub fn gguf() -> FetchConfigBuilder {
263 FetchConfigBuilder::default()
264 .filter("*.gguf")
265 .filter("*.json")
266 .filter("*.txt")
267 }
268
269 #[must_use]
272 pub fn config_only() -> FetchConfigBuilder {
273 FetchConfigBuilder::default()
274 .filter("*.json")
275 .filter("*.txt")
276 .filter("*.md")
277 }
278}
279
280#[must_use]
282pub(crate) fn file_matches(
283 filename: &str,
284 include: Option<&GlobSet>,
285 exclude: Option<&GlobSet>,
286) -> bool {
287 if let Some(exc) = exclude {
288 if exc.is_match(filename) {
289 return false;
290 }
291 }
292 if let Some(inc) = include {
293 return inc.is_match(filename);
294 }
295 true
296}
297
298fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
299 if patterns.is_empty() {
300 return Ok(None);
301 }
302 let mut builder = GlobSetBuilder::new();
303 for pattern in patterns {
304 let glob = Glob::new(pattern.as_str()).map_err(|e| FetchError::InvalidPattern {
306 pattern: pattern.clone(),
307 reason: e.to_string(),
308 })?;
309 builder.add(glob);
310 }
311 let set = builder.build().map_err(|e| FetchError::InvalidPattern {
312 pattern: patterns.join(", "),
313 reason: e.to_string(),
314 })?;
315 Ok(Some(set))
316}
317
#[cfg(test)]
mod tests {
    // Panicking on failure is the whole point of a test, so these lints are relaxed here.
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    #[test]
    fn test_file_matches_no_filters() {
        assert!(file_matches("model.safetensors", None, None));
    }

    #[test]
    fn test_file_matches_include() {
        let include = build_globset(&["*.safetensors".to_owned()]).unwrap();
        assert!(file_matches("model.safetensors", include.as_ref(), None));
        assert!(!file_matches("model.bin", include.as_ref(), None));
    }

    #[test]
    fn test_file_matches_exclude() {
        let exclude = build_globset(&["*.bin".to_owned()]).unwrap();
        assert!(file_matches("model.safetensors", None, exclude.as_ref()));
        assert!(!file_matches("model.bin", None, exclude.as_ref()));
    }

    #[test]
    fn test_exclude_overrides_include() {
        let include = build_globset(&["*.safetensors".to_owned(), "*.bin".to_owned()]).unwrap();
        let exclude = build_globset(&["*.bin".to_owned()]).unwrap();
        assert!(file_matches(
            "model.safetensors",
            include.as_ref(),
            exclude.as_ref()
        ));
        assert!(!file_matches(
            "model.bin",
            include.as_ref(),
            exclude.as_ref()
        ));
    }

    /// No patterns means "no filter", which is represented as `None`.
    #[test]
    fn test_build_globset_empty_is_none() {
        assert!(build_globset(&[]).unwrap().is_none());
    }

    /// A malformed glob must surface as `InvalidPattern`, not a panic.
    #[test]
    fn test_build_globset_rejects_invalid_pattern() {
        let result = build_globset(&["[unclosed".to_owned()]);
        assert!(matches!(result, Err(FetchError::InvalidPattern { .. })));
    }

    /// An untouched builder yields the documented defaults.
    #[test]
    fn test_builder_defaults() {
        let config = FetchConfig::builder().build().unwrap();
        assert_eq!(config.concurrency, 4);
        assert_eq!(config.max_retries, 3);
        assert!(config.verify_checksums);
        assert!(config.revision.is_none());
        assert!(config.include.is_none());
        assert!(config.exclude.is_none());
        assert!(config.output_dir.is_none());
        assert!(config.timeout_per_file.is_none());
        assert!(config.timeout_total.is_none());
        assert!(config.on_progress.is_none());
    }

    /// The token must never leak through the `Debug` representation.
    #[test]
    fn test_debug_masks_token() {
        let config = FetchConfig::builder()
            .token("super-secret-value")
            .build()
            .unwrap();
        let rendered = format!("{config:?}");
        assert!(!rendered.contains("super-secret-value"));
        assert!(rendered.contains("***"));
    }

    /// Presets only add include patterns; everything else stays at its default.
    #[test]
    fn test_filter_presets() {
        let config = Filter::safetensors().build().unwrap();
        assert!(file_matches(
            "model.safetensors",
            config.include.as_ref(),
            config.exclude.as_ref()
        ));
        assert!(file_matches(
            "config.json",
            config.include.as_ref(),
            config.exclude.as_ref()
        ));
        assert!(!file_matches(
            "model.bin",
            config.include.as_ref(),
            config.exclude.as_ref()
        ));
        assert!(config.exclude.is_none());
    }
}