// hf_fetch_model/plan.rs
// SPDX-License-Identifier: MIT OR Apache-2.0

//! Download plan: metadata-only analysis of what needs downloading.
//!
//! A [`DownloadPlan`] describes which files in a remote `HuggingFace`
//! repository need downloading and which are already cached locally.
//! Use [`download_plan()`] to compute a plan, then inspect it or pass
//! it to [`recommended_config()`](DownloadPlan::recommended_config) for
//! optimized download settings.
10
use crate::cache;
use crate::chunked;
use crate::config::{self, FetchConfig, FetchConfigBuilder};
use crate::error::FetchError;
use crate::repo;
/// Size threshold for "large" files (1 GiB). Used by
/// [`DownloadPlan::recommended_config_builder`] to classify uncached files.
const LARGE_FILE_THRESHOLD: u64 = 1_073_741_824;

/// Size threshold for "very large" files (5 GiB). Files at or above this
/// size trigger the maximum per-file connection count.
const VERY_LARGE_FILE_THRESHOLD: u64 = 5_368_709_120;

/// Size threshold for "small" files (10 MiB). Files strictly below this
/// size count toward the "many small files" strategy.
const SMALL_FILE_THRESHOLD: u64 = 10_485_760;

/// Default chunk threshold (100 MiB) passed to the recommended config.
// NOTE(review): presumably the minimum size at which the `chunked` module
// splits a download across connections — confirm against `crate::chunked`.
const DEFAULT_CHUNK_THRESHOLD: u64 = 104_857_600;
28
/// A download plan describing which files need downloading and which are cached.
///
/// Created by [`download_plan()`]. Contains per-file metadata and aggregate
/// byte counts. Use [`recommended_config()`](Self::recommended_config) to
/// compute an optimized [`FetchConfig`] based on the file size distribution.
///
/// Files whose remote size is unknown are counted as 0 bytes in every
/// aggregate field below.
#[derive(Debug, Clone)]
pub struct DownloadPlan {
    /// Repository identifier (e.g., `"google/gemma-2-2b-it"`).
    pub repo_id: String,
    /// Resolved revision (commit hash when the cache has one, otherwise the
    /// requested branch name).
    pub revision: String,
    /// Per-file plan entries.
    pub files: Vec<FilePlan>,
    /// Total bytes across all files (cached + uncached).
    pub total_bytes: u64,
    /// Bytes already present in local cache.
    pub cached_bytes: u64,
    /// Bytes that need downloading (`total_bytes - cached_bytes`, saturating).
    pub download_bytes: u64,
}
49
/// Per-file entry within a [`DownloadPlan`].
#[derive(Debug, Clone)]
pub struct FilePlan {
    /// Filename within the repository (may include subdirectories).
    pub filename: String,
    /// File size in bytes (0 if the remote did not report a size).
    pub size: u64,
    /// Whether the file is already cached locally.
    pub cached: bool,
}
60
61impl DownloadPlan {
62    /// Number of files that still need downloading.
63    #[must_use]
64    pub fn files_to_download(&self) -> usize {
65        self.files.iter().filter(|f| !f.cached).count()
66    }
67
68    /// Whether all files are already cached (download would be a no-op).
69    #[must_use]
70    pub const fn fully_cached(&self) -> bool {
71        self.download_bytes == 0
72    }
73
74    /// Computes an optimized [`FetchConfig`] based on the size distribution
75    /// of uncached files.
76    ///
77    /// The returned config has no `token`, `revision`, `on_progress`, or
78    /// glob filters set — only the performance-tuning fields (`concurrency`,
79    /// `connections_per_file`, `chunk_threshold`). Merge with user config
80    /// before use.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`FetchError::InvalidPattern`] if the internal builder fails
85    /// (should not happen since no patterns are set).
86    pub fn recommended_config(&self) -> Result<FetchConfig, FetchError> {
87        self.recommended_config_builder().build()
88    }
89
90    /// Like [`recommended_config()`](Self::recommended_config) but returns a
91    /// [`FetchConfigBuilder`] so the caller can override specific fields.
92    #[must_use]
93    pub fn recommended_config_builder(&self) -> FetchConfigBuilder {
94        let uncached: Vec<u64> = self
95            .files
96            .iter()
97            .filter(|f| !f.cached)
98            .map(|f| f.size)
99            .collect();
100
101        let builder = FetchConfig::builder();
102
103        if uncached.is_empty() {
104            // All cached — defaults are fine, download will be a no-op.
105            return builder.concurrency(1);
106        }
107
108        let count = uncached.len();
109        let large_count = uncached
110            .iter()
111            .filter(|&&s| s >= LARGE_FILE_THRESHOLD)
112            .count();
113        let very_large = uncached.iter().any(|&s| s >= VERY_LARGE_FILE_THRESHOLD);
114        let small_count = uncached
115            .iter()
116            .filter(|&&s| s < SMALL_FILE_THRESHOLD)
117            .count();
118
119        // Strategy: few large files — maximize per-file parallelism.
120        if count <= 2 && large_count > 0 {
121            let connections = if very_large { 16 } else { 8 };
122            return builder
123                .concurrency(count.max(1))
124                .connections_per_file(connections)
125                .chunk_threshold(DEFAULT_CHUNK_THRESHOLD);
126        }
127
128        // Strategy: many small files — parallelize at file level.
129        // Only applies when there are NO large files; otherwise fall through
130        // to the mixed strategy which handles both small and large files.
131        if small_count > count / 2 && large_count == 0 {
132            return builder
133                .concurrency(8.min(count))
134                .connections_per_file(1)
135                .chunk_threshold(u64::MAX);
136        }
137
138        // Strategy: mixed — balanced defaults.
139        builder
140            .concurrency(4)
141            .connections_per_file(8)
142            .chunk_threshold(DEFAULT_CHUNK_THRESHOLD)
143    }
144}
145
146/// Computes a download plan for a repository without downloading anything.
147///
148/// Fetches remote file metadata and compares against the local cache to
149/// classify each file as cached or pending download. Glob filters from
150/// `config` are applied.
151///
152/// # Errors
153///
154/// Returns [`FetchError::Http`] if the `HuggingFace` API request fails.
155/// Returns [`FetchError::RepoNotFound`] if the repository does not exist.
156/// Returns [`FetchError::Io`] if the cache directory cannot be resolved.
157pub async fn download_plan(
158    repo_id: &str,
159    config: &FetchConfig,
160) -> Result<DownloadPlan, FetchError> {
161    let revision_str = config.revision.as_deref().unwrap_or("main");
162    let token = config.token.as_deref();
163
164    // Fetch remote file list with metadata.
165    let remote_files =
166        repo::list_repo_files_with_metadata(repo_id, token, Some(revision_str)).await?;
167
168    // Apply glob filters.
169    let filtered: Vec<_> = remote_files
170        .into_iter()
171        .filter(|f| {
172            // BORROW: explicit .as_str() instead of Deref coercion
173            config::file_matches(
174                f.filename.as_str(),
175                config.include.as_ref(),
176                config.exclude.as_ref(),
177            )
178        })
179        .collect();
180
181    // Resolve cache state.
182    let cache_dir = config
183        .output_dir
184        .clone()
185        .map_or_else(cache::hf_cache_dir, Ok)?;
186    let repo_folder = chunked::repo_folder_name(repo_id);
187    // BORROW: explicit .as_str() instead of Deref coercion
188    let repo_dir = cache_dir.join(repo_folder.as_str());
189    let commit_hash = cache::read_ref(&repo_dir, revision_str);
190    let snapshot_dir = commit_hash
191        .as_deref()
192        .map(|hash| repo_dir.join("snapshots").join(hash));
193
194    let mut total_bytes: u64 = 0;
195    let mut cached_bytes: u64 = 0;
196    let mut files = Vec::with_capacity(filtered.len());
197
198    for rf in &filtered {
199        let size = rf.size.unwrap_or(0);
200        total_bytes = total_bytes.saturating_add(size);
201
202        let cached = snapshot_dir
203            .as_ref()
204            // BORROW: explicit .as_str() instead of Deref coercion
205            .is_some_and(|dir| dir.join(rf.filename.as_str()).exists());
206
207        if cached {
208            cached_bytes = cached_bytes.saturating_add(size);
209        }
210
211        files.push(FilePlan {
212            // BORROW: explicit .clone() for owned String field
213            filename: rf.filename.clone(),
214            size,
215            cached,
216        });
217    }
218
219    let download_bytes = total_bytes.saturating_sub(cached_bytes);
220
221    Ok(DownloadPlan {
222        // BORROW: explicit .to_owned() for &str → owned String
223        repo_id: repo_id.to_owned(),
224        // BORROW: explicit .to_owned() for &str → owned String
225        revision: commit_hash.unwrap_or_else(|| revision_str.to_owned()),
226        files,
227        total_bytes,
228        cached_bytes,
229        download_bytes,
230    })
231}
232
#[cfg(test)]
mod tests {
    #![allow(clippy::panic, clippy::unwrap_used, clippy::expect_used)]

    use super::*;

    /// Builds a `DownloadPlan` whose files mirror `file_specs` as
    /// `(size, cached)` pairs; aggregate byte counts are derived.
    fn make_plan(file_specs: &[(u64, bool)]) -> DownloadPlan {
        let files: Vec<FilePlan> = file_specs
            .iter()
            .enumerate()
            .map(|(i, &(size, cached))| FilePlan {
                filename: format!("file_{i}.bin"),
                size,
                cached,
            })
            .collect();

        let total_bytes = files
            .iter()
            .fold(0_u64, |acc, f| acc.saturating_add(f.size));
        let cached_bytes = files
            .iter()
            .filter(|f| f.cached)
            .fold(0_u64, |acc, f| acc.saturating_add(f.size));

        DownloadPlan {
            repo_id: "test/repo".to_owned(),
            revision: "main".to_owned(),
            files,
            total_bytes,
            cached_bytes,
            download_bytes: total_bytes.saturating_sub(cached_bytes),
        }
    }

    #[test]
    fn all_cached_returns_concurrency_one() {
        let plan = make_plan(&[(1_000_000, true), (2_000_000, true)]);
        assert!(plan.fully_cached());
        assert_eq!(plan.files_to_download(), 0);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 1);
    }

    #[test]
    fn single_very_large_file_gets_sixteen_connections() {
        // One uncached 6 GiB file — above the very-large threshold.
        let plan = make_plan(&[(6_442_450_944, false)]);
        assert_eq!(plan.files_to_download(), 1);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 1);
        assert_eq!(cfg.connections_per_file(), 16);
    }

    #[test]
    fn two_large_files_get_eight_connections() {
        // Two uncached 2 GiB files — large, but not very large.
        let plan = make_plan(&[(2_147_483_648, false), (2_147_483_648, false)]);
        assert_eq!(plan.files_to_download(), 2);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 2);
        assert_eq!(cfg.connections_per_file(), 8);
    }

    #[test]
    fn many_small_files_get_high_concurrency_single_connection() {
        // Twenty uncached 1 MiB files — all below the small threshold.
        let specs: Vec<(u64, bool)> = (0..20).map(|_| (1_048_576, false)).collect();
        let plan = make_plan(&specs);
        assert_eq!(plan.files_to_download(), 20);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 8);
        assert_eq!(cfg.connections_per_file(), 1);
        assert_eq!(cfg.chunk_threshold(), u64::MAX);
    }

    #[test]
    fn mixed_sizes_get_balanced_defaults() {
        // Mix of large and medium files — fewer than half are small.
        let plan = make_plan(&[
            (2_147_483_648, false), // 2 GiB
            (104_857_600, false),   // 100 MiB
            (52_428_800, false),    // 50 MiB
            (1_073_741_824, false), // 1 GiB
            (20_971_520, false),    // 20 MiB
        ]);
        assert_eq!(plan.files_to_download(), 5);
        let cfg = plan.recommended_config().unwrap();
        assert_eq!(cfg.concurrency(), 4);
        assert_eq!(cfg.connections_per_file(), 8);
        assert_eq!(cfg.chunk_threshold(), DEFAULT_CHUNK_THRESHOLD);
    }

    #[test]
    fn mostly_small_with_large_files_uses_mixed_strategy() {
        // Mirrors the Ministral-3-3B case: 2 large files + 8 small files.
        // Must NOT pick the "many small files" strategy because large files
        // are present — falls through to "mixed" instead.
        let plan = make_plan(&[
            (4_672_561_152, false), // 4.35 GiB
            (4_672_561_152, false), // 4.35 GiB
            (2_355, false),         // 2.3 KiB
            (1_946, false),         // 1.9 KiB
            (131, false),           // 131 B
            (1_229, false),         // 1.2 KiB
            (976, false),           // 976 B
            (16_756_736, false),    // 16 MiB
            (17_081_344, false),    // 16.3 MiB
            (21_197, false),        // 20.7 KiB
        ]);
        assert_eq!(plan.files_to_download(), 10);
        let cfg = plan.recommended_config().unwrap();
        // Mixed strategy: balanced concurrency with chunked downloads enabled.
        assert_eq!(cfg.concurrency(), 4);
        assert_eq!(cfg.connections_per_file(), 8);
        assert_eq!(cfg.chunk_threshold(), DEFAULT_CHUNK_THRESHOLD);
    }
}