use crate::cache;
use crate::cache_layout;
use crate::chunked;
use crate::config::{self, FetchConfig, FetchConfigBuilder};
use crate::error::FetchError;
use crate::repo;
/// Size in bytes (1 GiB) at or above which an uncached file counts as "large"
/// when choosing a recommended configuration.
const LARGE_FILE_THRESHOLD: u64 = 1_073_741_824;
/// Size in bytes (5 GiB) at or above which an uncached file counts as
/// "very large"; the recommendation then uses more connections per file.
const VERY_LARGE_FILE_THRESHOLD: u64 = 5_368_709_120;
/// Size in bytes (10 MiB) strictly below which an uncached file counts as "small".
const SMALL_FILE_THRESHOLD: u64 = 10_485_760;
/// Value (100 MiB) passed to the builder's `chunk_threshold` by the
/// recommendations that use multiple connections per file.
// NOTE(review): presumably the minimum file size for chunked download — confirm
// against `FetchConfigBuilder::chunk_threshold`.
const DEFAULT_CHUNK_THRESHOLD: u64 = 104_857_600;
/// Summary of what fetching a repository would involve: the selected files,
/// which of them are already present locally, and the byte totals.
///
/// The field descriptions below reflect how `download_plan` populates the
/// struct; all fields are public, so other constructors may differ.
#[derive(Debug, Clone)]
pub struct DownloadPlan {
/// Repository identifier the plan was computed for.
pub repo_id: String,
/// Commit hash read from the local cache ref when available, otherwise the
/// requested revision string.
pub revision: String,
/// One entry per file that passed the include/exclude filters.
pub files: Vec<FilePlan>,
/// Saturating sum of the sizes of all entries in `files`.
pub total_bytes: u64,
/// Saturating sum of the sizes of the entries marked as cached.
pub cached_bytes: u64,
/// `total_bytes` minus `cached_bytes` (saturating).
pub download_bytes: u64,
}
/// Per-file entry of a [`DownloadPlan`].
#[derive(Debug, Clone)]
pub struct FilePlan {
/// File name as reported by the repository listing.
pub filename: String,
/// Size in bytes; `download_plan` records `0` when the listing has no size.
pub size: u64,
/// Whether `download_plan` found the file in the local snapshot directory.
pub cached: bool,
}
impl DownloadPlan {
    /// Returns how many files in the plan are not marked as cached.
    #[must_use]
    pub fn files_to_download(&self) -> usize {
        let pending = self.files.iter().filter(|file| !file.cached);
        pending.count()
    }

    /// Returns `true` when `download_bytes` is zero.
    ///
    /// Only the byte counter is consulted; the per-file `cached` flags are not
    /// inspected, so uncached entries with a recorded size of `0` do not
    /// change the result.
    #[must_use]
    pub const fn fully_cached(&self) -> bool {
        self.download_bytes == 0
    }

    /// Builds the configuration suggested by [`Self::recommended_config_builder`].
    ///
    /// # Errors
    ///
    /// Returns whatever error `FetchConfigBuilder::build` reports.
    pub fn recommended_config(&self) -> Result<FetchConfig, FetchError> {
        let builder = self.recommended_config_builder();
        builder.build()
    }

    /// Suggests download settings based on the sizes of the uncached files.
    ///
    /// The heuristics, evaluated in order:
    ///
    /// 1. Nothing left to download: concurrency of 1, everything else untouched.
    /// 2. At most two uncached files and at least one of them is large
    ///    (`>= LARGE_FILE_THRESHOLD`): one task per file, with 16 connections
    ///    per file if any file is very large, otherwise 8.
    /// 3. No large file and more than half of the uncached files are small
    ///    (`< SMALL_FILE_THRESHOLD`): up to 8 files at once over a single
    ///    connection each, with the chunk threshold set to `u64::MAX`.
    /// 4. Anything else: 4 files at once, 8 connections per file, default
    ///    chunk threshold.
    #[must_use]
    pub fn recommended_config_builder(&self) -> FetchConfigBuilder {
        let builder = FetchConfig::builder();

        let count = self.pending_sizes().count();
        if count == 0 {
            return builder.concurrency(1);
        }

        let large_count = self
            .pending_sizes()
            .filter(|&size| size >= LARGE_FILE_THRESHOLD)
            .count();
        let has_very_large = self
            .pending_sizes()
            .any(|size| size >= VERY_LARGE_FILE_THRESHOLD);
        let small_count = self
            .pending_sizes()
            .filter(|&size| size < SMALL_FILE_THRESHOLD)
            .count();
        let has_large = large_count > 0;

        if has_large && count <= 2 {
            // A handful of big files: spend the parallelism inside each file.
            let connections = if has_very_large { 16 } else { 8 };
            builder
                .concurrency(count.max(1))
                .connections_per_file(connections)
                .chunk_threshold(DEFAULT_CHUNK_THRESHOLD)
        } else if !has_large && small_count > count / 2 {
            // Mostly small files: spend the parallelism across files instead.
            builder
                .concurrency(count.min(8))
                .connections_per_file(1)
                .chunk_threshold(u64::MAX)
        } else {
            // Mixed workload: balanced defaults.
            builder
                .concurrency(4)
                .connections_per_file(8)
                .chunk_threshold(DEFAULT_CHUNK_THRESHOLD)
        }
    }

    /// Iterates over the sizes of the files that are not marked as cached.
    fn pending_sizes(&self) -> impl Iterator<Item = u64> + '_ {
        self.files
            .iter()
            .filter(|file| !file.cached)
            .map(|file| file.size)
    }
}
pub async fn download_plan(
repo_id: &str,
config: &FetchConfig,
) -> Result<DownloadPlan, FetchError> {
let revision_str = config.revision.as_deref().unwrap_or("main");
let token = config.token.as_deref();
let client = chunked::build_client(token)?;
let remote_files =
repo::list_repo_files_with_metadata(repo_id, token, Some(revision_str), &client).await?;
let filtered: Vec<_> = remote_files
.into_iter()
.filter(|f| {
config::file_matches(
f.filename.as_str(),
config.include.as_ref(),
config.exclude.as_ref(),
)
})
.collect();
let cache_dir = config
.output_dir
.clone()
.map_or_else(cache::hf_cache_dir, Ok)?;
let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
let commit_hash = cache::read_ref(&repo_dir, revision_str);
let snapshot_dir = commit_hash
.as_deref()
.map(|hash| cache_layout::snapshot_dir(&repo_dir, hash));
let mut total_bytes: u64 = 0;
let mut cached_bytes: u64 = 0;
let mut files = Vec::with_capacity(filtered.len());
for rf in &filtered {
let size = rf.size.unwrap_or(0);
total_bytes = total_bytes.saturating_add(size);
let cached = snapshot_dir
.as_ref()
.is_some_and(|dir| dir.join(rf.filename.as_str()).exists());
if cached {
cached_bytes = cached_bytes.saturating_add(size);
}
files.push(FilePlan {
filename: rf.filename.clone(),
size,
cached,
});
}
let download_bytes = total_bytes.saturating_sub(cached_bytes);
Ok(DownloadPlan {
repo_id: repo_id.to_owned(),
revision: commit_hash.unwrap_or_else(|| revision_str.to_owned()),
files,
total_bytes,
cached_bytes,
download_bytes,
})
}
#[cfg(test)]
mod tests {
    // Test code: a panic on an unexpected `Err`/`None` is the desired failure mode.
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]
    use super::*;

    /// Builds a plan from `(size, cached)` pairs, deriving the byte totals
    /// from the generated file entries.
    fn make_plan(file_specs: &[(u64, bool)]) -> DownloadPlan {
        let files: Vec<FilePlan> = file_specs
            .iter()
            .enumerate()
            .map(|(i, &(size, cached))| FilePlan {
                filename: format!("file_{i}.bin"),
                size,
                cached,
            })
            .collect();
        let total_bytes = files
            .iter()
            .fold(0_u64, |sum, file| sum.saturating_add(file.size));
        let cached_bytes = files
            .iter()
            .filter(|file| file.cached)
            .fold(0_u64, |sum, file| sum.saturating_add(file.size));

        DownloadPlan {
            repo_id: "test/repo".to_owned(),
            revision: "main".to_owned(),
            files,
            total_bytes,
            cached_bytes,
            download_bytes: total_bytes.saturating_sub(cached_bytes),
        }
    }

    #[test]
    fn all_cached_returns_concurrency_one() {
        let plan = make_plan(&[(1_000_000, true), (2_000_000, true)]);
        assert!(plan.fully_cached());
        assert_eq!(plan.files_to_download(), 0);

        let recommended = plan.recommended_config().unwrap();
        assert_eq!(recommended.concurrency(), 1);
    }

    #[test]
    fn single_very_large_file_gets_sixteen_connections() {
        // One 6 GiB file, above the very-large threshold.
        let plan = make_plan(&[(6_442_450_944, false)]);
        assert_eq!(plan.files_to_download(), 1);

        let recommended = plan.recommended_config().unwrap();
        assert_eq!(recommended.concurrency(), 1);
        assert_eq!(recommended.connections_per_file(), 16);
    }

    #[test]
    fn two_large_files_get_eight_connections() {
        // Two 2 GiB files: large, but not very large.
        let plan = make_plan(&[(2_147_483_648, false), (2_147_483_648, false)]);
        assert_eq!(plan.files_to_download(), 2);

        let recommended = plan.recommended_config().unwrap();
        assert_eq!(recommended.concurrency(), 2);
        assert_eq!(recommended.connections_per_file(), 8);
    }

    #[test]
    fn many_small_files_get_high_concurrency_single_connection() {
        // Twenty uncached 1 MiB files.
        let specs = [(1_048_576_u64, false); 20];
        let plan = make_plan(&specs);
        assert_eq!(plan.files_to_download(), 20);

        let recommended = plan.recommended_config().unwrap();
        assert_eq!(recommended.concurrency(), 8);
        assert_eq!(recommended.connections_per_file(), 1);
        assert_eq!(recommended.chunk_threshold(), u64::MAX);
    }

    #[test]
    fn mixed_sizes_get_balanced_defaults() {
        let plan = make_plan(&[
            (2_147_483_648, false), // 2 GiB
            (104_857_600, false),   // 100 MiB
            (52_428_800, false),    // 50 MiB
            (1_073_741_824, false), // 1 GiB
            (20_971_520, false),    // 20 MiB
        ]);
        assert_eq!(plan.files_to_download(), 5);

        let recommended = plan.recommended_config().unwrap();
        assert_eq!(recommended.concurrency(), 4);
        assert_eq!(recommended.connections_per_file(), 8);
        assert_eq!(recommended.chunk_threshold(), DEFAULT_CHUNK_THRESHOLD);
    }

    #[test]
    fn mostly_small_with_large_files_uses_mixed_strategy() {
        // Six of the ten files are small, but the two large files rule out
        // the small-file strategy.
        let plan = make_plan(&[
            (4_672_561_152, false),
            (4_672_561_152, false),
            (2_355, false),
            (1_946, false),
            (131, false),
            (1_229, false),
            (976, false),
            (16_756_736, false),
            (17_081_344, false),
            (21_197, false),
        ]);
        assert_eq!(plan.files_to_download(), 10);

        let recommended = plan.recommended_config().unwrap();
        assert_eq!(recommended.concurrency(), 4);
        assert_eq!(recommended.connections_per_file(), 8);
        assert_eq!(recommended.chunk_threshold(), DEFAULT_CHUNK_THRESHOLD);
    }
}