1use crate::cache;
12use crate::cache_layout;
13use crate::chunked;
14use crate::config::{self, FetchConfig, FetchConfigBuilder};
15use crate::error::FetchError;
16use crate::repo;
17
/// Size (1 GiB) at or above which an uncached file is counted as "large"
/// when recommending a download configuration.
const LARGE_FILE_THRESHOLD: u64 = 1024 * 1024 * 1024;

/// Size (5 GiB) at or above which an uncached file is counted as "very large";
/// the recommendation uses more connections per file when one is present.
const VERY_LARGE_FILE_THRESHOLD: u64 = 5 * 1024 * 1024 * 1024;

/// Size (10 MiB) strictly below which an uncached file is counted as "small".
const SMALL_FILE_THRESHOLD: u64 = 10 * 1024 * 1024;

/// Chunk threshold (100 MiB) handed to the config builder by every
/// recommendation that uses more than one connection per file.
const DEFAULT_CHUNK_THRESHOLD: u64 = 100 * 1024 * 1024;
29
30#[derive(Debug, Clone)]
36pub struct DownloadPlan {
37 pub repo_id: String,
39 pub revision: String,
41 pub files: Vec<FilePlan>,
43 pub total_bytes: u64,
45 pub cached_bytes: u64,
47 pub download_bytes: u64,
49}
50
/// Plan entry for a single file of a repository.
#[derive(Clone, Debug)]
pub struct FilePlan {
    /// File name as reported by the remote listing; `download_plan` joins it
    /// onto the snapshot directory to look for a cached copy.
    pub filename: String,
    /// File size in bytes. `download_plan` stores `0` when the remote
    /// listing does not report a size.
    pub size: u64,
    /// Whether the file was found in the local cache.
    pub cached: bool,
}
61
62impl DownloadPlan {
63 #[must_use]
65 pub fn files_to_download(&self) -> usize {
66 self.files.iter().filter(|f| !f.cached).count()
67 }
68
69 #[must_use]
71 pub const fn fully_cached(&self) -> bool {
72 self.download_bytes == 0
73 }
74
75 pub fn recommended_config(&self) -> Result<FetchConfig, FetchError> {
88 self.recommended_config_builder().build()
89 }
90
91 #[must_use]
94 pub fn recommended_config_builder(&self) -> FetchConfigBuilder {
95 let uncached: Vec<u64> = self
96 .files
97 .iter()
98 .filter(|f| !f.cached)
99 .map(|f| f.size)
100 .collect();
101
102 let builder = FetchConfig::builder();
103
104 if uncached.is_empty() {
105 return builder.concurrency(1);
107 }
108
109 let count = uncached.len();
110 let large_count = uncached
111 .iter()
112 .filter(|&&s| s >= LARGE_FILE_THRESHOLD)
113 .count();
114 let very_large = uncached.iter().any(|&s| s >= VERY_LARGE_FILE_THRESHOLD);
115 let small_count = uncached
116 .iter()
117 .filter(|&&s| s < SMALL_FILE_THRESHOLD)
118 .count();
119
120 if count <= 2 && large_count > 0 {
122 let connections = if very_large { 16 } else { 8 };
123 return builder
124 .concurrency(count.max(1))
125 .connections_per_file(connections)
126 .chunk_threshold(DEFAULT_CHUNK_THRESHOLD);
127 }
128
129 if small_count > count / 2 && large_count == 0 {
133 return builder
134 .concurrency(8.min(count))
135 .connections_per_file(1)
136 .chunk_threshold(u64::MAX);
137 }
138
139 builder
141 .concurrency(4)
142 .connections_per_file(8)
143 .chunk_threshold(DEFAULT_CHUNK_THRESHOLD)
144 }
145}
146
147pub async fn download_plan(
159 repo_id: &str,
160 config: &FetchConfig,
161) -> Result<DownloadPlan, FetchError> {
162 let revision_str = config.revision.as_deref().unwrap_or("main");
163 let token = config.token.as_deref();
164
165 let client = chunked::build_client(token)?;
167 let remote_files =
168 repo::list_repo_files_with_metadata(repo_id, token, Some(revision_str), &client).await?;
169
170 let filtered: Vec<_> = remote_files
172 .into_iter()
173 .filter(|f| {
174 config::file_matches(
176 f.filename.as_str(),
177 config.include.as_ref(),
178 config.exclude.as_ref(),
179 )
180 })
181 .collect();
182
183 let cache_dir = config
185 .output_dir
186 .clone()
187 .map_or_else(cache::hf_cache_dir, Ok)?;
188 let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
189 let commit_hash = cache::read_ref(&repo_dir, revision_str);
190 let snapshot_dir = commit_hash
191 .as_deref()
192 .map(|hash| cache_layout::snapshot_dir(&repo_dir, hash));
193
194 let mut total_bytes: u64 = 0;
195 let mut cached_bytes: u64 = 0;
196 let mut files = Vec::with_capacity(filtered.len());
197
198 for rf in &filtered {
199 let size = rf.size.unwrap_or(0);
200 total_bytes = total_bytes.saturating_add(size);
201
202 let cached = snapshot_dir
203 .as_ref()
204 .is_some_and(|dir| dir.join(rf.filename.as_str()).exists());
206
207 if cached {
208 cached_bytes = cached_bytes.saturating_add(size);
209 }
210
211 files.push(FilePlan {
212 filename: rf.filename.clone(),
214 size,
215 cached,
216 });
217 }
218
219 let download_bytes = total_bytes.saturating_sub(cached_bytes);
220
221 Ok(DownloadPlan {
222 repo_id: repo_id.to_owned(),
224 revision: commit_hash.unwrap_or_else(|| revision_str.to_owned()),
226 files,
227 total_bytes,
228 cached_bytes,
229 download_bytes,
230 })
231}
232
#[cfg(test)]
mod tests {
    // Test code is allowed to panic and unwrap: a failure is the signal.
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Builds a plan from `(size, cached)` pairs, naming the files
    /// `file_0.bin`, `file_1.bin`, ... and deriving the byte totals.
    fn make_plan(file_specs: &[(u64, bool)]) -> DownloadPlan {
        let mut files: Vec<FilePlan> = Vec::with_capacity(file_specs.len());
        let mut total_bytes = 0_u64;
        let mut cached_bytes = 0_u64;

        for (i, &(size, cached)) in file_specs.iter().enumerate() {
            total_bytes = total_bytes.saturating_add(size);
            if cached {
                cached_bytes = cached_bytes.saturating_add(size);
            }
            files.push(FilePlan {
                filename: format!("file_{i}.bin"),
                size,
                cached,
            });
        }

        DownloadPlan {
            repo_id: "test/repo".to_owned(),
            revision: "main".to_owned(),
            files,
            total_bytes,
            cached_bytes,
            download_bytes: total_bytes.saturating_sub(cached_bytes),
        }
    }

    #[test]
    fn all_cached_returns_concurrency_one() {
        let plan = make_plan(&[(1_000_000, true), (2_000_000, true)]);
        assert!(plan.fully_cached());
        assert_eq!(plan.files_to_download(), 0);

        let config = plan.recommended_config().unwrap();
        assert_eq!(config.concurrency(), 1);
    }

    #[test]
    fn single_very_large_file_gets_sixteen_connections() {
        // One 6 GiB file.
        let plan = make_plan(&[(6_442_450_944, false)]);
        assert_eq!(plan.files_to_download(), 1);

        let config = plan.recommended_config().unwrap();
        assert_eq!(config.concurrency(), 1);
        assert_eq!(config.connections_per_file(), 16);
    }

    #[test]
    fn two_large_files_get_eight_connections() {
        // Two 2 GiB files.
        let plan = make_plan(&[(2_147_483_648, false), (2_147_483_648, false)]);
        assert_eq!(plan.files_to_download(), 2);

        let config = plan.recommended_config().unwrap();
        assert_eq!(config.concurrency(), 2);
        assert_eq!(config.connections_per_file(), 8);
    }

    #[test]
    fn many_small_files_get_high_concurrency_single_connection() {
        // Twenty 1 MiB files.
        let specs: Vec<(u64, bool)> = vec![(1_048_576, false); 20];
        let plan = make_plan(&specs);
        assert_eq!(plan.files_to_download(), 20);

        let config = plan.recommended_config().unwrap();
        assert_eq!(config.concurrency(), 8);
        assert_eq!(config.connections_per_file(), 1);
        assert_eq!(config.chunk_threshold(), u64::MAX);
    }

    #[test]
    fn mixed_sizes_get_balanced_defaults() {
        let plan = make_plan(&[
            (2_147_483_648, false), // 2 GiB
            (104_857_600, false),   // 100 MiB
            (52_428_800, false),    // 50 MiB
            (1_073_741_824, false), // 1 GiB
            (20_971_520, false),    // 20 MiB
        ]);
        assert_eq!(plan.files_to_download(), 5);

        let config = plan.recommended_config().unwrap();
        assert_eq!(config.concurrency(), 4);
        assert_eq!(config.connections_per_file(), 8);
        assert_eq!(config.chunk_threshold(), DEFAULT_CHUNK_THRESHOLD);
    }

    #[test]
    fn mostly_small_with_large_files_uses_mixed_strategy() {
        // More than half of the files are tiny, but the two multi-GiB files
        // rule out the small-file strategy.
        let plan = make_plan(&[
            (4_672_561_152, false),
            (4_672_561_152, false),
            (2_355, false),
            (1_946, false),
            (131, false),
            (1_229, false),
            (976, false),
            (16_756_736, false),
            (17_081_344, false),
            (21_197, false),
        ]);
        assert_eq!(plan.files_to_download(), 10);

        let config = plan.recommended_config().unwrap();
        assert_eq!(config.concurrency(), 4);
        assert_eq!(config.connections_per_file(), 8);
        assert_eq!(config.chunk_threshold(), DEFAULT_CHUNK_THRESHOLD);
    }
}