// SPDX-License-Identifier: MIT OR Apache-2.0

//! Download plan: metadata-only analysis of what needs downloading.
//!
//! A [`DownloadPlan`] describes which files in a remote `HuggingFace`
//! repository need downloading and which are already cached locally.
//! Use [`download_plan()`] to compute a plan, then inspect it or pass
//! it to [`recommended_config()`](DownloadPlan::recommended_config) for
//! optimized download settings.

use crate::cache;
use crate::cache_layout;
use crate::chunked;
use crate::config::{self, FetchConfig, FetchConfigBuilder};
use crate::error::FetchError;
use crate::repo;

/// Size threshold for "large" files (1 GiB).
const LARGE_FILE_THRESHOLD: u64 = 1 << 30;

/// Size threshold for "very large" files (5 GiB).
const VERY_LARGE_FILE_THRESHOLD: u64 = 5 * (1 << 30);

/// Size threshold for "small" files (10 MiB).
const SMALL_FILE_THRESHOLD: u64 = 10 * (1 << 20);

/// Default chunk threshold (100 MiB).
const DEFAULT_CHUNK_THRESHOLD: u64 = 100 * (1 << 20);

30/// A download plan describing which files need downloading and which are cached.
31///
32/// Created by [`download_plan()`]. Contains per-file metadata and aggregate
33/// byte counts. Use [`recommended_config()`](Self::recommended_config) to
34/// compute an optimized [`FetchConfig`] based on the file size distribution.
35#[derive(Debug, Clone)]
36pub struct DownloadPlan {
37    /// Repository identifier (e.g., `"google/gemma-2-2b-it"`).
38    pub repo_id: String,
39    /// Resolved revision (commit hash or branch name).
40    pub revision: String,
41    /// Per-file plan entries.
42    pub files: Vec<FilePlan>,
43    /// Total bytes across all files (cached + uncached).
44    pub total_bytes: u64,
45    /// Bytes already present in local cache.
46    pub cached_bytes: u64,
47    /// Bytes that need downloading.
48    pub download_bytes: u64,
49}
50
/// Metadata for a single file inside a [`DownloadPlan`].
#[derive(Debug, Clone)]
pub struct FilePlan {
    /// Path of the file relative to the repository root.
    pub filename: String,
    /// Size in bytes; `0` when the size is unknown.
    pub size: u64,
    /// `true` when a cached copy already exists locally.
    pub cached: bool,
}

62impl DownloadPlan {
63    /// Number of files that still need downloading.
64    #[must_use]
65    pub fn files_to_download(&self) -> usize {
66        self.files.iter().filter(|f| !f.cached).count()
67    }
68
69    /// Whether all files are already cached (download would be a no-op).
70    #[must_use]
71    pub const fn fully_cached(&self) -> bool {
72        self.download_bytes == 0
73    }
74
75    /// Computes an optimized [`FetchConfig`] based on the size distribution
76    /// of uncached files.
77    ///
78    /// The returned config has no `token`, `revision`, `on_progress`, or
79    /// glob filters set — only the performance-tuning fields (`concurrency`,
80    /// `connections_per_file`, `chunk_threshold`). Merge with user config
81    /// before use.
82    ///
83    /// # Errors
84    ///
85    /// Returns [`FetchError::InvalidPattern`] if the internal builder fails
86    /// (should not happen since no patterns are set).
87    pub fn recommended_config(&self) -> Result<FetchConfig, FetchError> {
88        self.recommended_config_builder().build()
89    }
90
91    /// Like [`recommended_config()`](Self::recommended_config) but returns a
92    /// [`FetchConfigBuilder`] so the caller can override specific fields.
93    #[must_use]
94    pub fn recommended_config_builder(&self) -> FetchConfigBuilder {
95        let uncached: Vec<u64> = self
96            .files
97            .iter()
98            .filter(|f| !f.cached)
99            .map(|f| f.size)
100            .collect();
101
102        let builder = FetchConfig::builder();
103
104        if uncached.is_empty() {
105            // All cached — defaults are fine, download will be a no-op.
106            return builder.concurrency(1);
107        }
108
109        let count = uncached.len();
110        let large_count = uncached
111            .iter()
112            .filter(|&&s| s >= LARGE_FILE_THRESHOLD)
113            .count();
114        let very_large = uncached.iter().any(|&s| s >= VERY_LARGE_FILE_THRESHOLD);
115        let small_count = uncached
116            .iter()
117            .filter(|&&s| s < SMALL_FILE_THRESHOLD)
118            .count();
119
120        // Strategy: few large files — maximize per-file parallelism.
121        if count <= 2 && large_count > 0 {
122            let connections = if very_large { 16 } else { 8 };
123            return builder
124                .concurrency(count.max(1))
125                .connections_per_file(connections)
126                .chunk_threshold(DEFAULT_CHUNK_THRESHOLD);
127        }
128
129        // Strategy: many small files — parallelize at file level.
130        // Only applies when there are NO large files; otherwise fall through
131        // to the mixed strategy which handles both small and large files.
132        if small_count > count / 2 && large_count == 0 {
133            return builder
134                .concurrency(8.min(count))
135                .connections_per_file(1)
136                .chunk_threshold(u64::MAX);
137        }
138
139        // Strategy: mixed — balanced defaults.
140        builder
141            .concurrency(4)
142            .connections_per_file(8)
143            .chunk_threshold(DEFAULT_CHUNK_THRESHOLD)
144    }
145}
146
147/// Computes a download plan for a repository without downloading anything.
148///
149/// Fetches remote file metadata and compares against the local cache to
150/// classify each file as cached or pending download. Glob filters from
151/// `config` are applied.
152///
153/// # Errors
154///
155/// Returns [`FetchError::Http`] if the `HuggingFace` API request fails.
156/// Returns [`FetchError::RepoNotFound`] if the repository does not exist.
157/// Returns [`FetchError::Io`] if the cache directory cannot be resolved.
158pub async fn download_plan(
159    repo_id: &str,
160    config: &FetchConfig,
161) -> Result<DownloadPlan, FetchError> {
162    let revision_str = config.revision.as_deref().unwrap_or("main");
163    let token = config.token.as_deref();
164
165    // Fetch remote file list with metadata.
166    let client = chunked::build_client(token)?;
167    let remote_files =
168        repo::list_repo_files_with_metadata(repo_id, token, Some(revision_str), &client).await?;
169
170    // Apply glob filters.
171    let filtered: Vec<_> = remote_files
172        .into_iter()
173        .filter(|f| {
174            // BORROW: explicit .as_str() instead of Deref coercion
175            config::file_matches(
176                f.filename.as_str(),
177                config.include.as_ref(),
178                config.exclude.as_ref(),
179            )
180        })
181        .collect();
182
183    // Resolve cache state.
184    let cache_dir = config
185        .output_dir
186        .clone()
187        .map_or_else(cache::hf_cache_dir, Ok)?;
188    let repo_dir = cache_layout::repo_dir(&cache_dir, repo_id);
189    let commit_hash = cache::read_ref(&repo_dir, revision_str);
190    let snapshot_dir = commit_hash
191        .as_deref()
192        .map(|hash| cache_layout::snapshot_dir(&repo_dir, hash));
193
194    let mut total_bytes: u64 = 0;
195    let mut cached_bytes: u64 = 0;
196    let mut files = Vec::with_capacity(filtered.len());
197
198    for rf in &filtered {
199        let size = rf.size.unwrap_or(0);
200        total_bytes = total_bytes.saturating_add(size);
201
202        let cached = snapshot_dir
203            .as_ref()
204            // BORROW: explicit .as_str() instead of Deref coercion
205            .is_some_and(|dir| dir.join(rf.filename.as_str()).exists());
206
207        if cached {
208            cached_bytes = cached_bytes.saturating_add(size);
209        }
210
211        files.push(FilePlan {
212            // BORROW: explicit .clone() for owned String field
213            filename: rf.filename.clone(),
214            size,
215            cached,
216        });
217    }
218
219    let download_bytes = total_bytes.saturating_sub(cached_bytes);
220
221    Ok(DownloadPlan {
222        // BORROW: explicit .to_owned() for &str → owned String
223        repo_id: repo_id.to_owned(),
224        // BORROW: explicit .to_owned() for &str → owned String
225        revision: commit_hash.unwrap_or_else(|| revision_str.to_owned()),
226        files,
227        total_bytes,
228        cached_bytes,
229        download_bytes,
230    })
231}
232
#[cfg(test)]
mod tests {
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Constructs a `DownloadPlan` whose files follow `file_specs`, given
    /// as `(size, cached)` pairs.
    fn make_plan(file_specs: &[(u64, bool)]) -> DownloadPlan {
        let mut files = Vec::with_capacity(file_specs.len());
        let mut total_bytes: u64 = 0;
        let mut cached_bytes: u64 = 0;

        for (i, &(size, cached)) in file_specs.iter().enumerate() {
            total_bytes = total_bytes.saturating_add(size);
            if cached {
                cached_bytes = cached_bytes.saturating_add(size);
            }
            files.push(FilePlan {
                filename: format!("file_{i}.bin"),
                size,
                cached,
            });
        }

        DownloadPlan {
            repo_id: "test/repo".to_owned(),
            revision: "main".to_owned(),
            files,
            total_bytes,
            cached_bytes,
            download_bytes: total_bytes.saturating_sub(cached_bytes),
        }
    }

    #[test]
    fn all_cached_returns_concurrency_one() {
        let plan = make_plan(&[(1_000_000, true), (2_000_000, true)]);
        assert!(plan.fully_cached());
        assert_eq!(plan.files_to_download(), 0);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 1);
    }

    #[test]
    fn single_very_large_file_gets_sixteen_connections() {
        // One uncached 6 GiB file.
        let plan = make_plan(&[(6_442_450_944, false)]);
        assert_eq!(plan.files_to_download(), 1);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 1);
        assert_eq!(cfg.connections_per_file(), 16);
    }

    #[test]
    fn two_large_files_get_eight_connections() {
        // Two uncached 2 GiB files.
        let plan = make_plan(&[(2_147_483_648, false), (2_147_483_648, false)]);
        assert_eq!(plan.files_to_download(), 2);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 2);
        assert_eq!(cfg.connections_per_file(), 8);
    }

    #[test]
    fn many_small_files_get_high_concurrency_single_connection() {
        // Twenty uncached 1 MiB files.
        let specs = vec![(1_048_576_u64, false); 20];
        let plan = make_plan(&specs);
        assert_eq!(plan.files_to_download(), 20);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 8);
        assert_eq!(cfg.connections_per_file(), 1);
        assert_eq!(cfg.chunk_threshold(), u64::MAX);
    }

    #[test]
    fn mixed_sizes_get_balanced_defaults() {
        // Large plus medium files — fewer than half qualify as small.
        let plan = make_plan(&[
            (2_147_483_648, false), // 2 GiB
            (104_857_600, false),   // 100 MiB
            (52_428_800, false),    // 50 MiB
            (1_073_741_824, false), // 1 GiB
            (20_971_520, false),    // 20 MiB
        ]);
        assert_eq!(plan.files_to_download(), 5);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 4);
        assert_eq!(cfg.connections_per_file(), 8);
        assert_eq!(cfg.chunk_threshold(), DEFAULT_CHUNK_THRESHOLD);
    }

    #[test]
    fn mostly_small_with_large_files_uses_mixed_strategy() {
        // Mirrors the Ministral-3-3B case: 2 large files + 8 small files.
        // The "many small files" strategy must NOT be chosen while large
        // files are present — the plan falls through to "mixed" instead.
        let plan = make_plan(&[
            (4_672_561_152, false), // 4.35 GiB
            (4_672_561_152, false), // 4.35 GiB
            (2_355, false),         // 2.3 KiB
            (1_946, false),         // 1.9 KiB
            (131, false),           // 131 B
            (1_229, false),         // 1.2 KiB
            (976, false),           // 976 B
            (16_756_736, false),    // 16 MiB
            (17_081_344, false),    // 16.3 MiB
            (21_197, false),        // 20.7 KiB
        ]);
        assert_eq!(plan.files_to_download(), 10);
        let cfg = plan.recommended_config().unwrap();
        // Mixed strategy: balanced concurrency with chunked downloads on.
        assert_eq!(cfg.concurrency(), 4);
        assert_eq!(cfg.connections_per_file(), 8);
        assert_eq!(cfg.chunk_threshold(), DEFAULT_CHUNK_THRESHOLD);
    }
}