use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use globset::{Glob, GlobSet, GlobSetBuilder};
use crate::error::FetchError;
use crate::progress::ProgressEvent;
/// Shared, thread-safe progress callback invoked with each [`ProgressEvent`].
pub(crate) type ProgressCallback = Arc<dyn Fn(&ProgressEvent) + Send + Sync>;
/// Resolved configuration for a fetch operation.
///
/// Construct via [`FetchConfig::builder`]; fields are crate-internal and are
/// populated with defaults by [`FetchConfigBuilder::build`].
pub struct FetchConfig {
    /// Git revision (branch, tag, or commit) to fetch; `None` = repo default.
    pub(crate) revision: Option<String>,
    /// Access token for authenticated requests; redacted in `Debug` output.
    pub(crate) token: Option<String>,
    /// Compiled include globs; `None` means "include everything".
    pub(crate) include: Option<GlobSet>,
    /// Compiled exclude globs; an exclude match always wins over include.
    pub(crate) exclude: Option<GlobSet>,
    /// Number of files downloaded in parallel.
    pub(crate) concurrency: usize,
    /// Destination directory; `None` leaves the choice to the fetcher.
    pub(crate) output_dir: Option<PathBuf>,
    /// Per-file download timeout, if any.
    pub(crate) timeout_per_file: Option<Duration>,
    /// Whole-operation timeout, if any.
    pub(crate) timeout_total: Option<Duration>,
    /// Maximum retry attempts per failed download.
    pub(crate) max_retries: u32,
    /// Whether downloaded files are checksum-verified.
    pub(crate) verify_checksums: bool,
    /// Files at least this many bytes use multi-connection chunked download.
    pub(crate) chunk_threshold: u64,
    /// Parallel connections per chunked file (builder clamps to >= 1).
    pub(crate) connections_per_file: usize,
    /// Optional progress callback; not `Debug`, so shown only by presence.
    pub(crate) on_progress: Option<ProgressCallback>,
}
impl std::fmt::Debug for FetchConfig {
    /// Manual `Debug`: redacts the token and reports only the presence of the
    /// (non-`Debug`) progress callback.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Never print the secret itself, only whether one is set.
        let token = self.token.as_ref().map(|_| "***");
        // Closures have no Debug impl; report presence only.
        let on_progress = if self.on_progress.is_some() {
            "Some(<fn>)"
        } else {
            "None"
        };
        f.debug_struct("FetchConfig")
            .field("revision", &self.revision)
            .field("token", &token)
            .field("include", &self.include)
            .field("exclude", &self.exclude)
            .field("concurrency", &self.concurrency)
            .field("output_dir", &self.output_dir)
            .field("timeout_per_file", &self.timeout_per_file)
            .field("timeout_total", &self.timeout_total)
            .field("max_retries", &self.max_retries)
            .field("verify_checksums", &self.verify_checksums)
            .field("chunk_threshold", &self.chunk_threshold)
            .field("connections_per_file", &self.connections_per_file)
            .field("on_progress", &on_progress)
            .finish()
    }
}
impl FetchConfig {
#[must_use]
pub fn builder() -> FetchConfigBuilder {
FetchConfigBuilder::default()
}
}
/// Builder for [`FetchConfig`]; obtain via [`FetchConfig::builder`].
///
/// All fields start unset; [`FetchConfigBuilder::build`] fills in defaults.
#[derive(Default)]
pub struct FetchConfigBuilder {
    // Git revision to pin the fetch to.
    revision: Option<String>,
    // Access token, set explicitly or read from the environment.
    token: Option<String>,
    // Raw include globs, compiled at build() time.
    include_patterns: Vec<String>,
    // Raw exclude globs, compiled at build() time.
    exclude_patterns: Vec<String>,
    // Parallel file downloads; default 4 at build().
    concurrency: Option<usize>,
    // Destination directory.
    output_dir: Option<PathBuf>,
    // Per-file timeout.
    timeout_per_file: Option<Duration>,
    // Whole-operation timeout.
    timeout_total: Option<Duration>,
    // Retry attempts per file; default 3 at build().
    max_retries: Option<u32>,
    // Checksum verification; default true at build().
    verify_checksums: Option<bool>,
    // Chunked-download size threshold; default 100 MiB at build().
    chunk_threshold: Option<u64>,
    // Connections per chunked file; default 8, clamped to >= 1 at build().
    connections_per_file: Option<usize>,
    // Optional progress callback.
    on_progress: Option<ProgressCallback>,
}
impl FetchConfigBuilder {
    /// Pin the fetch to a git revision (branch, tag, or commit hash).
    #[must_use]
    pub fn revision(mut self, revision: &str) -> Self {
        self.revision = Some(revision.to_owned());
        self
    }

    /// Use an explicit access token for authenticated requests.
    #[must_use]
    pub fn token(mut self, token: &str) -> Self {
        self.token = Some(token.to_owned());
        self
    }

    /// Read the access token from the `HF_TOKEN` environment variable.
    ///
    /// If the variable is set it takes precedence; if it is unset, a token
    /// previously supplied via [`Self::token`] is kept rather than silently
    /// discarded (the original assignment clobbered it with `None`).
    #[must_use]
    pub fn token_from_env(mut self) -> Self {
        self.token = std::env::var("HF_TOKEN").ok().or(self.token.take());
        self
    }

    /// Restrict the download to files matching `pattern` (glob syntax).
    /// May be called repeatedly; a file matching any pattern is included.
    #[must_use]
    pub fn filter(mut self, pattern: &str) -> Self {
        self.include_patterns.push(pattern.to_owned());
        self
    }

    /// Skip files matching `pattern` (glob syntax). An exclude match always
    /// wins over an include match.
    #[must_use]
    pub fn exclude(mut self, pattern: &str) -> Self {
        self.exclude_patterns.push(pattern.to_owned());
        self
    }

    /// Number of files downloaded in parallel (default 4, clamped to >= 1).
    #[must_use]
    pub fn concurrency(mut self, concurrency: usize) -> Self {
        self.concurrency = Some(concurrency);
        self
    }

    /// Directory to write downloaded files into.
    #[must_use]
    pub fn output_dir(mut self, dir: PathBuf) -> Self {
        self.output_dir = Some(dir);
        self
    }

    /// Abort any individual file download that exceeds `duration`.
    #[must_use]
    pub fn timeout_per_file(mut self, duration: Duration) -> Self {
        self.timeout_per_file = Some(duration);
        self
    }

    /// Abort the entire fetch if it exceeds `duration`.
    #[must_use]
    pub fn timeout_total(mut self, duration: Duration) -> Self {
        self.timeout_total = Some(duration);
        self
    }

    /// Retry each failed download up to `retries` times (default 3).
    #[must_use]
    pub fn max_retries(mut self, retries: u32) -> Self {
        self.max_retries = Some(retries);
        self
    }

    /// Enable or disable checksum verification (default enabled).
    #[must_use]
    pub fn verify_checksums(mut self, verify: bool) -> Self {
        self.verify_checksums = Some(verify);
        self
    }

    /// Files of at least `bytes` are downloaded with multiple parallel
    /// connections (default 100 MiB).
    #[must_use]
    pub fn chunk_threshold(mut self, bytes: u64) -> Self {
        self.chunk_threshold = Some(bytes);
        self
    }

    /// Parallel connections per chunked file (default 8, clamped to >= 1).
    #[must_use]
    pub fn connections_per_file(mut self, connections: usize) -> Self {
        self.connections_per_file = Some(connections);
        self
    }

    /// Register a callback invoked with every [`ProgressEvent`].
    #[must_use]
    pub fn on_progress<F>(mut self, callback: F) -> Self
    where
        F: Fn(&ProgressEvent) + Send + Sync + 'static,
    {
        self.on_progress = Some(Arc::new(callback));
        self
    }

    /// Compile the glob patterns and produce the final [`FetchConfig`],
    /// applying defaults for every unset option.
    ///
    /// # Errors
    ///
    /// Returns [`FetchError::InvalidPattern`] if any include or exclude
    /// pattern is not valid glob syntax.
    pub fn build(self) -> Result<FetchConfig, FetchError> {
        let include = build_globset(&self.include_patterns)?;
        let exclude = build_globset(&self.exclude_patterns)?;
        Ok(FetchConfig {
            revision: self.revision,
            token: self.token,
            include,
            exclude,
            // Clamp to 1: a concurrency of 0 would permit no downloads at
            // all. Mirrors the existing clamp on `connections_per_file`.
            concurrency: self.concurrency.unwrap_or(4).max(1),
            output_dir: self.output_dir,
            timeout_per_file: self.timeout_per_file,
            timeout_total: self.timeout_total,
            max_retries: self.max_retries.unwrap_or(3),
            verify_checksums: self.verify_checksums.unwrap_or(true),
            // 100 MiB default threshold for chunked downloads.
            chunk_threshold: self.chunk_threshold.unwrap_or(104_857_600),
            connections_per_file: self.connections_per_file.unwrap_or(8).max(1),
            on_progress: self.on_progress,
        })
    }
}
/// Namespace of preset builders for common download profiles.
///
/// Each associated function returns a [`FetchConfigBuilder`] pre-populated
/// with include patterns; chain further builder calls before `build()`.
#[non_exhaustive]
pub struct Filter;
impl Filter {
    // Builds a builder whose include list is exactly `patterns`, in order.
    fn preset(patterns: &[&str]) -> FetchConfigBuilder {
        patterns
            .iter()
            .fold(FetchConfigBuilder::default(), |builder, pattern| {
                builder.filter(pattern)
            })
    }

    /// Preset: safetensors weights plus JSON/text metadata.
    #[must_use]
    pub fn safetensors() -> FetchConfigBuilder {
        Self::preset(&["*.safetensors", "*.json", "*.txt"])
    }

    /// Preset: GGUF weights plus JSON/text metadata.
    #[must_use]
    pub fn gguf() -> FetchConfigBuilder {
        Self::preset(&["*.gguf", "*.json", "*.txt"])
    }

    /// Preset: configuration and documentation files only (no weights).
    #[must_use]
    pub fn config_only() -> FetchConfigBuilder {
        Self::preset(&["*.json", "*.txt", "*.md"])
    }
}
/// Decide whether `filename` passes the configured filters.
///
/// An exclude match always rejects the file, even if it also matches an
/// include pattern. With no include set, every non-excluded file passes;
/// with an include set, the file must match it.
#[must_use]
pub(crate) fn file_matches(
    filename: &str,
    include: Option<&GlobSet>,
    exclude: Option<&GlobSet>,
) -> bool {
    // Exclusion wins unconditionally.
    if exclude.map_or(false, |set| set.is_match(filename)) {
        return false;
    }
    // No include filter means "accept everything not excluded".
    include.map_or(true, |set| set.is_match(filename))
}
/// Compile `patterns` into a single [`GlobSet`], or `None` when the list is
/// empty (meaning "no filter configured").
///
/// Any malformed pattern yields [`FetchError::InvalidPattern`] carrying the
/// offending pattern and the underlying parse error.
fn build_globset(patterns: &[String]) -> Result<Option<GlobSet>, FetchError> {
    if patterns.is_empty() {
        return Ok(None);
    }
    let mut builder = GlobSetBuilder::new();
    for pattern in patterns {
        match Glob::new(pattern) {
            Ok(glob) => {
                builder.add(glob);
            }
            Err(e) => {
                return Err(FetchError::InvalidPattern {
                    pattern: pattern.clone(),
                    reason: e.to_string(),
                });
            }
        }
    }
    match builder.build() {
        Ok(set) => Ok(Some(set)),
        Err(e) => Err(FetchError::InvalidPattern {
            // No single pattern is at fault here; report the whole set.
            pattern: patterns.join(", "),
            reason: e.to_string(),
        }),
    }
}
#[cfg(test)]
mod tests {
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
    use super::*;

    // Convenience: compile a slice of glob patterns into Option<GlobSet>.
    fn globs(patterns: &[&str]) -> Option<GlobSet> {
        let owned: Vec<String> = patterns.iter().map(|p| (*p).to_owned()).collect();
        build_globset(&owned).unwrap()
    }

    #[test]
    fn test_file_matches_no_filters() {
        // With neither filter configured, every file passes.
        assert!(file_matches("model.safetensors", None, None));
    }

    #[test]
    fn test_file_matches_include() {
        let include = globs(&["*.safetensors"]);
        assert!(file_matches("model.safetensors", include.as_ref(), None));
        assert!(!file_matches("model.bin", include.as_ref(), None));
    }

    #[test]
    fn test_file_matches_exclude() {
        let exclude = globs(&["*.bin"]);
        assert!(file_matches("model.safetensors", None, exclude.as_ref()));
        assert!(!file_matches("model.bin", None, exclude.as_ref()));
    }

    #[test]
    fn test_exclude_overrides_include() {
        // A file matching both filters must be rejected.
        let include = globs(&["*.safetensors", "*.bin"]);
        let exclude = globs(&["*.bin"]);
        assert!(file_matches("model.safetensors", include.as_ref(), exclude.as_ref()));
        assert!(!file_matches("model.bin", include.as_ref(), exclude.as_ref()));
    }
}