use std::path::PathBuf;
use globset::{Glob, GlobMatcher};
pub use hf_hub::{
Repo, RepoType,
api::{
RepoInfo,
sync::{Api, ApiBuilder, ApiError, ApiRepo},
},
};
/// Error type for the Hugging Face download helpers defined in this file.
#[derive(thiserror::Error, Debug)]
pub enum HuggingfaceError {
/// Wraps an error from the synchronous `hf_hub` API.
///
/// The `#[from]` attribute generates a `From` conversion, so `?` can turn an
/// `hf_hub::api::sync::ApiError` into this variant.
#[error("fail to download: {0}")]
ApiError(#[from] hf_hub::api::sync::ApiError),
}
/// Compiles every glob pattern string into a ready-to-use [`GlobMatcher`].
///
/// The matchers are returned in the same order as `patterns`.
///
/// # Errors
///
/// Returns the [`globset::Error`] of the first pattern that fails to parse;
/// no later patterns are examined once one fails.
pub fn compile_glob_pattern(patterns: &[&str]) -> Result<Vec<GlobMatcher>, globset::Error> {
    let mut matchers = Vec::with_capacity(patterns.len());
    for pattern in patterns {
        // `?` stops at the first invalid pattern, mirroring a fallible collect.
        let glob = Glob::new(pattern)?;
        matchers.push(glob.compile_matcher());
    }
    Ok(matchers)
}
/// Optional filename filters consulted by `Params::is_matched`.
///
/// Both fields are `None` by default (via `Default`), in which case the
/// corresponding check is skipped.
#[derive(Debug, Clone, Default)]
pub struct Params {
/// When `Some`, a filename must match at least one of these globs to be
/// accepted. Note that `Some` with an empty list therefore accepts nothing.
pub allow_patterns: Option<Vec<GlobMatcher>>,
/// When `Some`, a filename matching any of these globs is rejected, even if
/// it also matched an allow pattern.
pub ignore_patterns: Option<Vec<GlobMatcher>>,
}
impl Params {
pub fn is_matched(&self, filename: &str) -> bool {
if let Some(patterns) = &self.allow_patterns {
if !patterns.iter().any(|glob| glob.is_match(filename)) {
return false;
}
}
if let Some(patterns) = &self.ignore_patterns {
if patterns.iter().any(|glob| glob.is_match(filename)) {
return false;
}
}
true
}
}
pub fn snapshot_download(repo: Repo, options: Option<Params>) -> Result<PathBuf, ApiError> {
let api: Api = if let Ok(token) = std::env::var("HF_TOKEN") {
ApiBuilder::new().with_token(Some(token)).build()?
} else {
ApiBuilder::from_env().build()?
};
let api_repo: ApiRepo = api.repo(repo.clone());
let repo_info: RepoInfo = api_repo.info()?;
for sibling in repo_info.siblings {
if let Some(options) = &options {
if !options.is_matched(&sibling.rfilename) {
continue;
}
}
api_repo.get(&sibling.rfilename)?;
}
let config_json_path: PathBuf = api_repo.get("config.json")?;
config_json_path
.parent()
.ok_or(ApiError::IoError(std::io::Error::new(
std::io::ErrorKind::NotFound,
"Parent directory not found", )))
.map(PathBuf::from)
}