Skip to main content

minecraft_java_rs_core/net/
downloader.rs

1use std::collections::VecDeque;
2use std::path::PathBuf;
3use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use futures::StreamExt;
8use sha1::Digest;
9use tokio::io::AsyncWriteExt;
10use tokio::sync::Semaphore;
11use tokio::task::JoinSet;
12
13use crate::error::{DownloadError, LaunchError};
14use crate::launcher::events::LaunchEvent;
15
/// Number of retries `fetch_one` makes after a first attempt fails with a
/// retryable error, i.e. at most `DOWNLOAD_MAX_RETRIES + 1` requests per file.
const DOWNLOAD_MAX_RETRIES: u32 = 3;
/// Delay in milliseconds before the first retry; `fetch_one` doubles it after
/// every further failed attempt.
const DOWNLOAD_INITIAL_BACKOFF_MS: u64 = 500;
18
19// ── DownloadItem ──────────────────────────────────────────────────────────────
20
/// One file to download: where to fetch it from, where to store it, and the
/// metadata used for progress reporting and integrity verification.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DownloadItem {
    /// Full URL to fetch.
    pub url: String,
    /// Absolute path to write the file to.
    pub path: PathBuf,
    /// Parent directory; created with `create_dir_all` before writing.
    /// When empty the parent of `path` is used instead.
    pub folder: PathBuf,
    /// Human-readable name used in error messages (e.g. checksum mismatches).
    pub name: String,
    /// Expected file size in bytes (0 = unknown). Summed across a batch to
    /// estimate the remaining download time.
    pub size: u64,
    /// Category label emitted with `LaunchEvent::Progress` (e.g. "assets").
    /// Spelled as a raw identifier because `type` is a Rust keyword.
    pub r#type: Option<String>,
    /// Expected SHA-1 hex digest.  When `Some`, the file is verified after
    /// download; `DownloadError::ChecksumMismatch` is returned on mismatch.
    pub sha1: Option<String>,
}
41
42// ── Downloader ────────────────────────────────────────────────────────────────
43
44pub struct Downloader {
45    client: reqwest::Client,
46    /// Effective concurrency after applying the adaptive cap.
47    concurrency: usize,
48}
49
50impl Downloader {
51    pub fn new(timeout_secs: u64, concurrency: u32) -> Self {
52        let client = reqwest::Client::builder()
53            .timeout(Duration::from_secs(timeout_secs))
54            .build()
55            .expect("failed to build reqwest client");
56        Self {
57            client,
58            concurrency: adaptive_concurrency(concurrency),
59        }
60    }
61
62    /// Download a single file.  No progress events are emitted.
63    pub async fn download_file(&self, item: &DownloadItem) -> Result<(), LaunchError> {
64        let counter = Arc::new(AtomicU64::new(0));
65        fetch_one(self.client.clone(), item, &counter)
66            .await
67            .map_err(LaunchError::Download)
68    }
69
70    /// Download many files concurrently, emitting `LaunchEvent` progress
71    /// notifications on `event_tx`.
72    ///
73    /// Events emitted:
74    /// - `Progress { downloaded, total, kind }` — file-count progress after
75    ///   each file completes, where `downloaded` = files done, `total` = total
76    ///   files.
77    /// - `Speed(bytes_per_sec)` — rolling 5-second average.
78    /// - `Estimated(secs)` — ETA in seconds at the current speed.
79    pub async fn download_multiple(
80        &self,
81        items: Vec<DownloadItem>,
82        event_tx: tokio::sync::mpsc::Sender<LaunchEvent>,
83    ) -> Result<(), LaunchError> {
84        if items.is_empty() {
85            return Ok(());
86        }
87
88        let total_bytes: u64 = items.iter().map(|i| i.size).sum();
89        let total_count = items.len() as u64;
90        let downloaded = Arc::new(AtomicU64::new(0));
91        let completed = Arc::new(AtomicUsize::new(0));
92
93        let semaphore = Arc::new(Semaphore::new(self.concurrency));
94        let mut join_set: JoinSet<Result<(), LaunchError>> = JoinSet::new();
95
96        for item in items {
97            let sem = Arc::clone(&semaphore);
98            let dl = Arc::clone(&downloaded);
99            let comp = Arc::clone(&completed);
100            let client = self.client.clone();
101            let tx = event_tx.clone();
102
103            join_set.spawn(async move {
104                let _permit = sem
105                    .acquire_owned()
106                    .await
107                    .map_err(|e| LaunchError::Archive(e.to_string()))?;
108
109                fetch_one(client, &item, &dl)
110                    .await
111                    .map_err(LaunchError::Download)?;
112
113                let done = comp.fetch_add(1, Ordering::Relaxed) as u64 + 1;
114                tx.send(LaunchEvent::Progress {
115                    downloaded: done,
116                    total: total_count,
117                    kind: item.r#type.clone().unwrap_or_default(),
118                })
119                .await
120                .ok();
121
122                Ok(())
123            });
124        }
125
126        // Sliding-window speed tracker (pure coordinator state — no sharing needed).
127        let mut speed_window: VecDeque<(Instant, u64)> = VecDeque::new();
128
129        while let Some(result) = join_set.join_next().await {
130            result.map_err(|e| LaunchError::Archive(e.to_string()))??;
131
132            let now = Instant::now();
133            let dl = downloaded.load(Ordering::Relaxed);
134            speed_window.push_back((now, dl));
135
136            // Evict samples older than 5 seconds.
137            while speed_window
138                .front()
139                .map_or(false, |(t, _)| now.duration_since(*t).as_secs_f64() > 5.0)
140            {
141                speed_window.pop_front();
142            }
143
144            if let Some((t0, b0)) = speed_window.front() {
145                let dt = now.duration_since(*t0).as_secs_f64();
146                if dt > 0.1 {
147                    let speed = dl.saturating_sub(*b0) as f64 / dt;
148                    event_tx.send(LaunchEvent::Speed(speed)).await.ok();
149                    if speed > 0.0 && total_bytes > 0 {
150                        let remaining = total_bytes.saturating_sub(dl) as f64 / speed;
151                        event_tx.send(LaunchEvent::Estimated(remaining)).await.ok();
152                    }
153                }
154            }
155        }
156
157        Ok(())
158    }
159
160    /// Returns `true` if a HEAD request to `url` succeeds with a 2xx status.
161    pub async fn check_url(&self, url: &str) -> bool {
162        self.client
163            .head(url)
164            .send()
165            .await
166            .map(|r| r.status().is_success())
167            .unwrap_or(false)
168    }
169
170    /// Iterate `mirrors` in order, appending `path`, and return the first URL
171    /// that responds successfully to a HEAD request.  Returns `None` if all
172    /// mirrors are unreachable.
173    pub async fn check_mirror(&self, mirrors: &[&str], path: &str) -> Option<String> {
174        let path = path.trim_start_matches('/');
175        for mirror in mirrors {
176            let url = format!("{}/{}", mirror.trim_end_matches('/'), path);
177            if self.check_url(&url).await {
178                return Some(url);
179            }
180        }
181        None
182    }
183}
184
185// ── Internal helpers ──────────────────────────────────────────────────────────
186
/// Clamp user-requested concurrency to a system-aware upper bound.
///
/// Every in-flight download ties up a TCP connection, a read buffer and a
/// Tokio task, so very large requests (e.g. 400) mostly burn file descriptors
/// and socket memory without improving throughput. The effective value is
/// `clamp(requested, 1, min(cpu_cores * 8, 64))`, where `cpu_cores` falls
/// back to 4 when the platform cannot report it.
///
/// A 4-core machine therefore allows up to 32 parallel downloads and an
/// 8-core machine up to 64; smaller requests are honoured as-is and 0 is
/// raised to 1.
fn adaptive_concurrency(requested: u32) -> usize {
    const PER_CORE: usize = 8;
    const HARD_CAP: usize = 64;
    const FALLBACK_CORES: usize = 4;

    let cores = match std::thread::available_parallelism() {
        Ok(n) => n.get(),
        Err(_) => FALLBACK_CORES,
    };
    // Ceiling for this machine, never below 4.
    let ceiling = (cores * PER_CORE).clamp(4, HARD_CAP);
    let wanted = requested as usize;
    wanted.max(1).min(ceiling)
}
204
205/// Returns true for HTTP status codes worth retrying.
206/// 4xx client errors are not retried since they won't change.
207fn is_retryable_status(status: reqwest::StatusCode) -> bool {
208    status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
209}
210
211/// Download `item` to disk, updating `dl_counter` with each received chunk.
212///
213/// Uses a temporary file (`<path>.tmp`) and an atomic rename so a failed or
214/// interrupted download never leaves a corrupt file at the final path.
215///
216/// Retries up to `DOWNLOAD_MAX_RETRIES` times on network errors, 5xx, and 429,
217/// with exponential backoff starting at `DOWNLOAD_INITIAL_BACKOFF_MS`.
218/// Checksum mismatches and I/O errors are not retried.
219async fn fetch_one(
220    client: reqwest::Client,
221    item: &DownloadItem,
222    dl_counter: &Arc<AtomicU64>,
223) -> Result<(), DownloadError> {
224    let dir = if item.folder.as_os_str().is_empty() {
225        item.path
226            .parent()
227            .map(|p| p.to_path_buf())
228            .unwrap_or_else(|| PathBuf::from("."))
229    } else {
230        item.folder.clone()
231    };
232    tokio::fs::create_dir_all(&dir).await?;
233
234    // Temporary path: `foo.jar` → `foo.jar.tmp`
235    let tmp_path = {
236        let mut s = item.path.as_os_str().to_owned();
237        s.push(".tmp");
238        PathBuf::from(s)
239    };
240
241    let mut last_err: Option<DownloadError> = None;
242    let mut backoff = DOWNLOAD_INITIAL_BACKOFF_MS;
243
244    for attempt in 0..=DOWNLOAD_MAX_RETRIES {
245        if attempt > 0 {
246            let _ = tokio::fs::remove_file(&tmp_path).await;
247            tokio::time::sleep(Duration::from_millis(backoff)).await;
248            backoff = (backoff * 2).min(8_000);
249        }
250
251        // ── Send request ──────────────────────────────────────────────────────
252        let response = match client.get(&item.url).send().await {
253            Ok(r) => r,
254            Err(e) => {
255                last_err = Some(DownloadError::Http(e));
256                continue;
257            }
258        };
259
260        let status = response.status();
261        if is_retryable_status(status) {
262            last_err = Some(DownloadError::Http(
263                response.error_for_status().unwrap_err(),
264            ));
265            continue;
266        }
267        if !status.is_success() {
268            // 4xx — don't retry
269            return Err(DownloadError::Http(
270                response.error_for_status().unwrap_err(),
271            ));
272        }
273
274        // ── Stream body to temp file ──────────────────────────────────────────
275        let mut file = match tokio::fs::File::create(&tmp_path).await {
276            Ok(f) => f,
277            Err(e) => return Err(DownloadError::Io(e)),
278        };
279
280        let mut stream = response.bytes_stream();
281        let mut hasher = sha1::Sha1::new();
282        let verify = item.sha1.is_some();
283        let mut stream_err: Option<DownloadError> = None;
284
285        while let Some(chunk_result) = stream.next().await {
286            match chunk_result {
287                Ok(chunk) => {
288                    if let Err(e) = file.write_all(&chunk).await {
289                        return Err(DownloadError::Io(e));
290                    }
291                    if verify {
292                        hasher.update(&chunk);
293                    }
294                    dl_counter.fetch_add(chunk.len() as u64, Ordering::Relaxed);
295                }
296                Err(e) => {
297                    stream_err = Some(DownloadError::Http(e));
298                    break;
299                }
300            }
301        }
302
303        if let Some(e) = stream_err {
304            last_err = Some(e);
305            continue;
306        }
307
308        if let Err(e) = file.flush().await {
309            return Err(DownloadError::Io(e));
310        }
311
312        // ── Checksum ──────────────────────────────────────────────────────────
313        if let Some(expected) = &item.sha1 {
314            let actual: String = hasher
315                .finalize()
316                .iter()
317                .map(|b| format!("{b:02x}"))
318                .collect();
319            if actual != *expected {
320                let _ = tokio::fs::remove_file(&tmp_path).await;
321                return Err(DownloadError::ChecksumMismatch {
322                    file: item.name.clone(),
323                    expected: expected.clone(),
324                    actual,
325                });
326            }
327        }
328
329        // ── Atomic rename ─────────────────────────────────────────────────────
330        if let Err(e) = tokio::fs::rename(&tmp_path, &item.path).await {
331            return Err(DownloadError::Io(e));
332        }
333
334        return Ok(());
335    }
336
337    let _ = tokio::fs::remove_file(&tmp_path).await;
338    Err(last_err.unwrap_or(DownloadError::Timeout))
339}
340
341// ── Tests ─────────────────────────────────────────────────────────────────────
342
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use tokio::sync::mpsc;

    // Nothing listens on port 1, so every request to it fails quickly.
    const DEAD_HOST: &str = "http://127.0.0.1:1";

    fn make_downloader() -> Downloader {
        Downloader::new(5, 4)
    }

    /// An item aimed at `DEAD_HOST` that would be written to `<dir>/out.bin`.
    fn dead_item(dir: &TempDir, kind: Option<&str>) -> DownloadItem {
        DownloadItem {
            url: format!("{DEAD_HOST}/nonexistent"),
            path: dir.path().join("out.bin"),
            folder: dir.path().to_path_buf(),
            name: "out.bin".into(),
            size: 0,
            r#type: kind.map(str::to_owned),
            sha1: None,
        }
    }

    #[test]
    fn adaptive_concurrency_clamps_high_value() {
        // Whatever the CPU count is, 400 must be reduced.
        let effective = adaptive_concurrency(400);
        assert!(effective <= 64);
    }

    #[test]
    fn adaptive_concurrency_preserves_low_value() {
        for n in [2u32, 1] {
            assert_eq!(adaptive_concurrency(n), n as usize);
        }
    }

    #[test]
    fn adaptive_concurrency_floors_at_one() {
        assert_eq!(adaptive_concurrency(0), 1);
    }

    #[tokio::test]
    async fn download_multiple_empty_list() {
        let (tx, _rx) = mpsc::channel(16);
        make_downloader()
            .download_multiple(Vec::new(), tx)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn download_file_bad_url_returns_error() {
        let dir = TempDir::new().unwrap();
        let item = dead_item(&dir, None);
        // 1-second timeout, single slot.
        let outcome = Downloader::new(1, 1).download_file(&item).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn check_url_unreachable_returns_false() {
        let url = format!("{DEAD_HOST}/test");
        let reachable = Downloader::new(1, 1).check_url(&url).await;
        assert!(!reachable);
    }

    #[tokio::test]
    async fn check_mirror_all_bad_returns_none() {
        let found = Downloader::new(1, 1)
            .check_mirror(&[DEAD_HOST], "/some/path.jar")
            .await;
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn download_multiple_bad_url_propagates_error() {
        let dir = TempDir::new().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let outcome = Downloader::new(1, 1)
            .download_multiple(vec![dead_item(&dir, Some("test"))], tx)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn no_tmp_file_left_after_failed_download() {
        let dir = TempDir::new().unwrap();
        let item = dead_item(&dir, None);
        let _ = Downloader::new(1, 1).download_file(&item).await;

        let mut tmp = item.path.clone().into_os_string();
        tmp.push(".tmp");
        assert!(
            !PathBuf::from(tmp).exists(),
            ".tmp file should be cleaned up after failure"
        );
    }
}