use crate::huggingface::cache::HfCache;
use crate::huggingface::client::HfClient;
use anyhow::{Context, Result};
use std::path::PathBuf;
/// Options controlling which revision and which files [`download_model`] fetches.
///
/// `include` / `exclude` entries are glob patterns where `*` matches any run of characters
/// (including `/` and the empty run) and `?` matches exactly one character.
#[derive(Debug, Clone)]
pub struct DownloadOptions {
    /// Revision name handed to `HfClient::resolve_revision` (defaults to `"main"`).
    pub revision: String,
    /// When non-empty, only paths matching at least one of these patterns are downloaded.
    pub include: Vec<String>,
    /// Paths matching any of these patterns are skipped, even if they matched `include`.
    pub exclude: Vec<String>,
}

impl Default for DownloadOptions {
    /// Targets the `"main"` revision with no include/exclude filtering.
    fn default() -> Self {
        Self {
            revision: String::from("main"),
            include: Vec::new(),
            exclude: Vec::new(),
        }
    }
}
/// Outcome of a [`download_model`] call.
#[derive(Debug, Clone)]
pub struct DownloadResult {
    /// Repository id that was requested.
    pub model_id: String,
    /// Revision name exactly as requested in [`DownloadOptions::revision`].
    pub revision: String,
    /// Commit SHA that `revision` resolved to.
    pub commit_sha: String,
    /// Snapshot directory for `commit_sha` inside the cache.
    pub snapshot_path: PathBuf,
    /// Number of files downloaded by this call (0 when served from cache).
    pub files_downloaded: usize,
    /// Total number of bytes downloaded by this call (0 when served from cache).
    pub total_bytes: u64,
    /// `true` when the snapshot was already cached and nothing was downloaded.
    pub from_cache: bool,
}
/// Downloads a snapshot of `model_id` at `opts.revision` into `cache`.
///
/// The requested revision is first resolved to a commit SHA, and every later request (file
/// listing and per-file downloads) is pinned to that SHA. This keeps the downloaded files
/// consistent with the snapshot directory (`cache.snapshot_dir(model_id, commit_sha)`) they
/// are linked into, even if the revision is a branch that moves during the download.
///
/// If the cache reports the resolved commit as already cached, nothing is downloaded and the
/// result has `from_cache == true`. On both paths the ref for `opts.revision` is written to
/// point at the resolved commit.
///
/// Only listing entries whose `entry_type` is `"file"` and whose path passes the
/// `opts.include` / `opts.exclude` glob filters are downloaded.
///
/// # Errors
///
/// Returns an error, with context naming the failing step and path, if resolving the
/// revision, listing files, downloading a file, storing a blob, linking it into the snapshot,
/// or writing the ref fails. On a mid-loop failure, files processed earlier may already have
/// been stored and linked.
pub async fn download_model(
    client: &HfClient,
    cache: &HfCache,
    model_id: &str,
    opts: &DownloadOptions,
) -> Result<DownloadResult> {
    let commit_sha = client
        .resolve_revision(model_id, &opts.revision)
        .await
        .with_context(|| {
            format!(
                "Failed to resolve revision '{}' for '{}'",
                opts.revision, model_id
            )
        })?;
    let snapshot_path = cache.snapshot_dir(model_id, &commit_sha);

    // NOTE(review): this check is keyed only by model and commit, not by include/exclude, so
    // a snapshot fetched earlier with narrower filters may count as a hit — confirm
    // `is_cached` accounts for this.
    if cache.is_cached(model_id, &commit_sha) {
        // The snapshot may have been populated under a different revision name that resolved
        // to the same commit, so make sure this revision's ref points at it too.
        cache
            .write_ref(model_id, &opts.revision, &commit_sha)
            .with_context(|| format!("Failed to write ref '{}' for '{}'", opts.revision, model_id))?;
        return Ok(DownloadResult {
            model_id: model_id.to_string(),
            revision: opts.revision.clone(),
            commit_sha,
            snapshot_path,
            files_downloaded: 0,
            total_bytes: 0,
            from_cache: true,
        });
    }

    // List and download against the resolved commit, not `opts.revision`: a branch can move
    // between requests, and mixing commits would corrupt the snapshot for `commit_sha`.
    let entries = client
        .list_files(model_id, &commit_sha)
        .await
        .with_context(|| format!("Failed to list files for '{}'", model_id))?;
    let wanted = entries
        .iter()
        .filter(|e| e.entry_type == "file")
        .filter(|e| matches_filters(&e.path, &opts.include, &opts.exclude));

    let mut files_downloaded: usize = 0;
    let mut total_bytes: u64 = 0;
    for entry in wanted {
        let url = client.file_download_url(model_id, &commit_sha, &entry.path);
        let data = client
            .download_file(&url)
            .await
            .with_context(|| format!("Failed to download '{}'", entry.path))?;
        let byte_count = data.len() as u64;
        let blob_path = cache
            .store_blob(model_id, &data)
            .with_context(|| format!("Failed to store blob for '{}'", entry.path))?;
        cache
            .link_snapshot(model_id, &commit_sha, &entry.path, &blob_path)
            .with_context(|| format!("Failed to link snapshot for '{}'", entry.path))?;
        files_downloaded += 1;
        total_bytes += byte_count;
    }

    // Only point the ref at the commit once every selected file is in place.
    cache
        .write_ref(model_id, &opts.revision, &commit_sha)
        .with_context(|| format!("Failed to write ref '{}' for '{}'", opts.revision, model_id))?;
    Ok(DownloadResult {
        model_id: model_id.to_string(),
        revision: opts.revision.clone(),
        commit_sha,
        snapshot_path,
        files_downloaded,
        total_bytes,
        from_cache: false,
    })
}
/// Decides whether `path` survives the include/exclude glob filters.
///
/// An empty `include` list admits every path; otherwise at least one include pattern must
/// match. A path admitted that way is still rejected if any `exclude` pattern matches it.
fn matches_filters(path: &str, include: &[String], exclude: &[String]) -> bool {
    let included = include.is_empty() || include.iter().any(|pattern| glob_match(pattern, path));
    included && !exclude.iter().any(|pattern| glob_match(pattern, path))
}
/// Returns `true` when the whole of `text` matches the glob `pattern`.
///
/// `*` matches any run of characters (including `/` and the empty run) and `?` matches
/// exactly one character. Every other character matches itself literally, with no escaping.
/// Matching operates on `char`s, not bytes.
fn glob_match(pattern: &str, text: &str) -> bool {
    let pattern_chars = pattern.chars().collect::<Vec<char>>();
    let text_chars = text.chars().collect::<Vec<char>>();
    glob_match_inner(pattern_chars.as_slice(), text_chars.as_slice())
}
/// Matches the full `txt` against the glob `pat` (`*` = any run, `?` = any single char).
///
/// Uses the iterative "backtrack to the most recent `*`" algorithm. This runs in
/// O(|pat| · |txt|) time and constant extra space. A naive recursive matcher is exponential
/// for patterns with several stars (e.g. `*a*a*a*b` against a long run of `a`s), and the
/// patterns here come from user-supplied include/exclude lists.
fn glob_match_inner(pat: &[char], txt: &[char]) -> bool {
    let mut p = 0;
    let mut t = 0;
    // The most recent `*` seen: (pattern index just after it, text index where the rest of
    // the pattern is currently being tried). On a mismatch, the star absorbs one more char.
    let mut backtrack: Option<(usize, usize)> = None;

    while t < txt.len() {
        match pat.get(p) {
            Some('*') => {
                // Start by letting the star match the empty run.
                backtrack = Some((p + 1, t));
                p += 1;
            }
            Some(&c) if c == '?' || c == txt[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                Some((after_star, resume)) => {
                    let resume = resume + 1;
                    backtrack = Some((after_star, resume));
                    p = after_star;
                    t = resume;
                }
                None => return false,
            },
        }
    }

    // Text exhausted: any remaining pattern must be stars, which can match the empty run.
    pat[p..].iter().all(|&c| c == '*')
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `glob_match(pattern, text) == expected` for every `(text, expected)` pair.
    fn check_glob(pattern: &str, cases: &[(&str, bool)]) {
        for &(text, expected) in cases {
            assert_eq!(
                glob_match(pattern, text),
                expected,
                "glob_match({:?}, {:?})",
                pattern,
                text
            );
        }
    }

    /// Builds the owned pattern list that `matches_filters` takes.
    fn owned(patterns: &[&str]) -> Vec<String> {
        patterns.iter().map(|&p| String::from(p)).collect()
    }

    #[test]
    fn glob_exact_match() {
        check_glob("config.json", &[("config.json", true)]);
    }

    #[test]
    fn glob_exact_mismatch() {
        check_glob("config.json", &[("model.bin", false)]);
    }

    #[test]
    fn glob_star_suffix() {
        check_glob(
            "*.json",
            &[("config.json", true), ("a.json", true), ("config.txt", false)],
        );
    }

    #[test]
    fn glob_star_prefix() {
        check_glob(
            "config.*",
            &[("config.json", true), ("config.yaml", true), ("other.json", false)],
        );
    }

    #[test]
    fn glob_star_middle() {
        check_glob(
            "model*.bin",
            &[("model.bin", true), ("model-00001-of-00002.bin", true)],
        );
    }

    #[test]
    fn glob_star_matches_empty() {
        check_glob("model*", &[("model", true)]);
    }

    #[test]
    fn glob_question_mark() {
        check_glob("?.json", &[("a.json", true), ("ab.json", false)]);
    }

    #[test]
    fn glob_only_star() {
        check_glob("*", &[("anything", true), ("", true)]);
    }

    #[test]
    fn glob_double_star() {
        check_glob("**", &[("foo/bar/baz", true)]);
    }

    #[test]
    fn glob_star_slash_pattern() {
        check_glob(
            "*/*.json",
            &[("subdir/config.json", true), ("config.json", false)],
        );
    }

    #[test]
    fn filters_empty_includes_everything() {
        assert!(matches_filters("any/file.txt", &[], &[]));
    }

    #[test]
    fn filters_include_only() {
        let include = owned(&["*.json"]);
        assert!(matches_filters("config.json", &include, &[]));
        assert!(!matches_filters("model.bin", &include, &[]));
    }

    #[test]
    fn filters_exclude_only() {
        let exclude = owned(&["*.bin"]);
        assert!(matches_filters("config.json", &[], &exclude));
        assert!(!matches_filters("model.bin", &[], &exclude));
    }

    #[test]
    fn filters_include_and_exclude() {
        let include = owned(&["*.json", "*.bin"]);
        let exclude = owned(&["*.bin"]);
        assert!(!matches_filters("model.bin", &include, &exclude));
        assert!(matches_filters("config.json", &include, &exclude));
    }

    #[test]
    fn filters_no_include_match() {
        let include = owned(&["*.safetensors"]);
        assert!(!matches_filters("config.json", &include, &[]));
    }
}