Skip to main content

minecraft_java_rs_core/net/
downloader.rs

1use std::collections::VecDeque;
2use std::path::PathBuf;
3use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use futures::StreamExt;
8use sha1::Digest;
9use tokio::io::AsyncWriteExt;
10use tokio::sync::Semaphore;
11use tokio::task::JoinSet;
12
13use crate::error::{DownloadError, LaunchError};
14use crate::launcher::events::LaunchEvent;
15
/// Extra attempts made after the first failed request, i.e. at most
/// `DOWNLOAD_MAX_RETRIES + 1` requests per file.
const DOWNLOAD_MAX_RETRIES: u32 = 3;
/// Delay, in milliseconds, before the first retry; doubled after every retry.
const DOWNLOAD_INITIAL_BACKOFF_MS: u64 = 500;
18
19// ── DownloadItem ──────────────────────────────────────────────────────────────
20
/// One file to fetch: where it comes from, where it goes, and how to check it.
#[derive(Debug, Clone)]
pub struct DownloadItem {
    /// Full URL to fetch.
    pub url: String,
    /// Absolute path to write the file to.
    pub path: PathBuf,
    /// Parent directory; created with `create_dir_all` before writing.
    /// When empty the parent of `path` is used instead.
    pub folder: PathBuf,
    /// Human-readable name used in error messages and progress events.
    pub name: String,
    /// Expected file size in bytes (used for progress totals; 0 = unknown).
    pub size: u64,
    /// Category label emitted with `LaunchEvent::Progress` (e.g. "assets").
    /// Raw identifier because `type` is a keyword.
    pub r#type: Option<String>,
    /// Expected SHA-1 hex digest.  When `Some`, the file is verified after
    /// download; `DownloadError::ChecksumMismatch` is returned on mismatch.
    pub sha1: Option<String>,
}
41
42// ── Downloader ────────────────────────────────────────────────────────────────
43
44pub struct Downloader {
45    client: reqwest::Client,
46    /// Effective concurrency after applying the adaptive cap.
47    concurrency: usize,
48}
49
50impl Downloader {
51    pub fn new(timeout_secs: u64, concurrency: u32, force_ipv4: bool) -> Self {
52        let client = crate::net::client::build_client(timeout_secs, force_ipv4)
53            .expect("failed to build reqwest client");
54        Self {
55            client,
56            concurrency: adaptive_concurrency(concurrency),
57        }
58    }
59
60    /// Download a single file.  No progress events are emitted.
61    pub async fn download_file(&self, item: &DownloadItem) -> Result<(), LaunchError> {
62        let counter = Arc::new(AtomicU64::new(0));
63        fetch_one(self.client.clone(), item, &counter)
64            .await
65            .map_err(LaunchError::Download)
66    }
67
68    /// Download many files concurrently, emitting `LaunchEvent` progress
69    /// notifications on `event_tx`.
70    ///
71    /// Events emitted:
72    /// - `Progress { downloaded, total, kind }` — file-count progress after
73    ///   each file completes, where `downloaded` = files done, `total` = total
74    ///   files.
75    /// - `Speed(bytes_per_sec)` — rolling 5-second average.
76    /// - `Estimated(secs)` — ETA in seconds at the current speed.
77    pub async fn download_multiple(
78        &self,
79        items: Vec<DownloadItem>,
80        event_tx: tokio::sync::mpsc::Sender<LaunchEvent>,
81    ) -> Result<(), LaunchError> {
82        if items.is_empty() {
83            return Ok(());
84        }
85
86        let total_bytes: u64 = items.iter().map(|i| i.size).sum();
87        let total_count = items.len() as u64;
88        let downloaded = Arc::new(AtomicU64::new(0));
89        let completed = Arc::new(AtomicUsize::new(0));
90
91        let semaphore = Arc::new(Semaphore::new(self.concurrency));
92        let mut join_set: JoinSet<Result<(), LaunchError>> = JoinSet::new();
93
94        for item in items {
95            let sem = Arc::clone(&semaphore);
96            let dl = Arc::clone(&downloaded);
97            let comp = Arc::clone(&completed);
98            let client = self.client.clone();
99            let tx = event_tx.clone();
100
101            join_set.spawn(async move {
102                let _permit = sem
103                    .acquire_owned()
104                    .await
105                    .map_err(|e| LaunchError::Archive(e.to_string()))?;
106
107                fetch_one(client, &item, &dl)
108                    .await
109                    .map_err(LaunchError::Download)?;
110
111                let done = comp.fetch_add(1, Ordering::Relaxed) as u64 + 1;
112                tx.send(LaunchEvent::Progress {
113                    downloaded: done,
114                    total: total_count,
115                    kind: item.r#type.clone().unwrap_or_default(),
116                })
117                .await
118                .ok();
119
120                Ok(())
121            });
122        }
123
124        // Sliding-window speed tracker (pure coordinator state — no sharing needed).
125        let mut speed_window: VecDeque<(Instant, u64)> = VecDeque::new();
126
127        while let Some(result) = join_set.join_next().await {
128            result.map_err(|e| LaunchError::Archive(e.to_string()))??;
129
130            let now = Instant::now();
131            let dl = downloaded.load(Ordering::Relaxed);
132            speed_window.push_back((now, dl));
133
134            // Evict samples older than 5 seconds.
135            while speed_window
136                .front()
137                .map_or(false, |(t, _)| now.duration_since(*t).as_secs_f64() > 5.0)
138            {
139                speed_window.pop_front();
140            }
141
142            if let Some((t0, b0)) = speed_window.front() {
143                let dt = now.duration_since(*t0).as_secs_f64();
144                if dt > 0.1 {
145                    let speed = dl.saturating_sub(*b0) as f64 / dt;
146                    event_tx.send(LaunchEvent::Speed(speed)).await.ok();
147                    if speed > 0.0 && total_bytes > 0 {
148                        let remaining = total_bytes.saturating_sub(dl) as f64 / speed;
149                        event_tx.send(LaunchEvent::Estimated(remaining)).await.ok();
150                    }
151                }
152            }
153        }
154
155        Ok(())
156    }
157
158    /// Returns `true` if a HEAD request to `url` succeeds with a 2xx status.
159    pub async fn check_url(&self, url: &str) -> bool {
160        self.client
161            .head(url)
162            .send()
163            .await
164            .map(|r| r.status().is_success())
165            .unwrap_or(false)
166    }
167
168    /// Iterate `mirrors` in order, appending `path`, and return the first URL
169    /// that responds successfully to a HEAD request.  Returns `None` if all
170    /// mirrors are unreachable.
171    pub async fn check_mirror(&self, mirrors: &[&str], path: &str) -> Option<String> {
172        let path = path.trim_start_matches('/');
173        for mirror in mirrors {
174            let url = format!("{}/{}", mirror.trim_end_matches('/'), path);
175            if self.check_url(&url).await {
176                return Some(url);
177            }
178        }
179        None
180    }
181}
182
183// ── Internal helpers ──────────────────────────────────────────────────────────
184
/// Clamp user-requested concurrency to a system-aware upper bound.
///
/// Each active download holds roughly one TCP connection, a read buffer and a
/// Tokio task, so very high values (e.g. 400) waste file descriptors and
/// memory without improving throughput. The result is
///
///   clamp(requested, 1, max(min(cpu_cores × 8, 64), 4))
///
/// where `cpu_cores` falls back to 4 when it cannot be determined. A 4-core
/// machine may therefore run up to 32 simultaneous downloads and an 8-core
/// machine up to 64, while smaller caller values are kept as-is.
fn adaptive_concurrency(requested: u32) -> usize {
    let cores = match std::thread::available_parallelism() {
        Ok(n) => n.get(),
        Err(_) => 4,
    };
    let upper = std::cmp::max((cores * 8).min(64), 4);
    let wanted = requested as usize;

    if wanted < 1 {
        1
    } else if wanted > upper {
        upper
    } else {
        wanted
    }
}
202
203/// Returns true for HTTP status codes worth retrying.
204/// 4xx client errors are not retried since they won't change.
205fn is_retryable_status(status: reqwest::StatusCode) -> bool {
206    status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
207}
208
209/// Map a `reqwest::Error` to a `DownloadError`, distinguishing transport-level
210/// failures (no HTTP status — DNS, unreachable, reset, timeout) from
211/// server-side ones. Transport failures get a human-readable cause via
212/// [`describe_reqwest_error`](crate::net::client::describe_reqwest_error) so the
213/// user sees the real reason instead of "error sending request for url".
214fn classify_error(url: &str, e: reqwest::Error) -> DownloadError {
215    if e.status().is_some() {
216        DownloadError::Http(e)
217    } else {
218        DownloadError::Connection {
219            url: url.to_owned(),
220            detail: crate::net::client::describe_reqwest_error(&e),
221        }
222    }
223}
224
225/// Download `item` to disk, updating `dl_counter` with each received chunk.
226///
227/// Uses a temporary file (`<path>.tmp`) and an atomic rename so a failed or
228/// interrupted download never leaves a corrupt file at the final path.
229///
230/// Retries up to `DOWNLOAD_MAX_RETRIES` times on network errors, 5xx, and 429,
231/// with exponential backoff starting at `DOWNLOAD_INITIAL_BACKOFF_MS`.
232/// Checksum mismatches and I/O errors are not retried.
233async fn fetch_one(
234    client: reqwest::Client,
235    item: &DownloadItem,
236    dl_counter: &Arc<AtomicU64>,
237) -> Result<(), DownloadError> {
238    let dir = if item.folder.as_os_str().is_empty() {
239        item.path
240            .parent()
241            .map(|p| p.to_path_buf())
242            .unwrap_or_else(|| PathBuf::from("."))
243    } else {
244        item.folder.clone()
245    };
246    tokio::fs::create_dir_all(&dir).await?;
247
248    // Temporary path: `foo.jar` → `foo.jar.tmp`
249    let tmp_path = {
250        let mut s = item.path.as_os_str().to_owned();
251        s.push(".tmp");
252        PathBuf::from(s)
253    };
254
255    let mut last_err: Option<DownloadError> = None;
256    let mut backoff = DOWNLOAD_INITIAL_BACKOFF_MS;
257
258    for attempt in 0..=DOWNLOAD_MAX_RETRIES {
259        if attempt > 0 {
260            let _ = tokio::fs::remove_file(&tmp_path).await;
261            tokio::time::sleep(Duration::from_millis(backoff)).await;
262            backoff = (backoff * 2).min(8_000);
263        }
264
265        // ── Send request ──────────────────────────────────────────────────────
266        let response = match client.get(&item.url).send().await {
267            Ok(r) => r,
268            Err(e) => {
269                last_err = Some(classify_error(&item.url, e));
270                continue;
271            }
272        };
273
274        let status = response.status();
275        if is_retryable_status(status) {
276            last_err = Some(DownloadError::Http(
277                response.error_for_status().unwrap_err(),
278            ));
279            continue;
280        }
281        if !status.is_success() {
282            // 4xx — don't retry
283            return Err(DownloadError::Http(
284                response.error_for_status().unwrap_err(),
285            ));
286        }
287
288        // ── Stream body to temp file ──────────────────────────────────────────
289        let mut file = match tokio::fs::File::create(&tmp_path).await {
290            Ok(f) => f,
291            Err(e) => return Err(DownloadError::Io(e)),
292        };
293
294        let mut stream = response.bytes_stream();
295        let mut hasher = sha1::Sha1::new();
296        let verify = item.sha1.is_some();
297        let mut stream_err: Option<DownloadError> = None;
298
299        while let Some(chunk_result) = stream.next().await {
300            match chunk_result {
301                Ok(chunk) => {
302                    if let Err(e) = file.write_all(&chunk).await {
303                        return Err(DownloadError::Io(e));
304                    }
305                    if verify {
306                        hasher.update(&chunk);
307                    }
308                    dl_counter.fetch_add(chunk.len() as u64, Ordering::Relaxed);
309                }
310                Err(e) => {
311                    stream_err = Some(classify_error(&item.url, e));
312                    break;
313                }
314            }
315        }
316
317        if let Some(e) = stream_err {
318            last_err = Some(e);
319            continue;
320        }
321
322        if let Err(e) = file.flush().await {
323            return Err(DownloadError::Io(e));
324        }
325
326        // ── Checksum ──────────────────────────────────────────────────────────
327        if let Some(expected) = &item.sha1 {
328            let actual: String = hasher
329                .finalize()
330                .iter()
331                .map(|b| format!("{b:02x}"))
332                .collect();
333            if actual != *expected {
334                let _ = tokio::fs::remove_file(&tmp_path).await;
335                return Err(DownloadError::ChecksumMismatch {
336                    file: item.name.clone(),
337                    expected: expected.clone(),
338                    actual,
339                });
340            }
341        }
342
343        // ── Atomic rename ─────────────────────────────────────────────────────
344        if let Err(e) = tokio::fs::rename(&tmp_path, &item.path).await {
345            return Err(DownloadError::Io(e));
346        }
347
348        return Ok(());
349    }
350
351    let _ = tokio::fs::remove_file(&tmp_path).await;
352    Err(last_err.unwrap_or(DownloadError::Timeout))
353}
354
355// ── Tests ─────────────────────────────────────────────────────────────────────
356
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use tokio::sync::mpsc;

    /// Loopback port 1; nothing is expected to be listening there.
    const UNREACHABLE_URL: &str = "http://127.0.0.1:1/nonexistent";

    fn make_downloader() -> Downloader {
        Downloader::new(5, 4, false)
    }

    /// Item pointing at `UNREACHABLE_URL`, targeting `out.bin` inside `dir`.
    fn unreachable_item(dir: &TempDir, kind: Option<&str>) -> DownloadItem {
        DownloadItem {
            url: UNREACHABLE_URL.into(),
            path: dir.path().join("out.bin"),
            folder: dir.path().to_path_buf(),
            name: "out.bin".into(),
            size: 0,
            r#type: kind.map(String::from),
            sha1: None,
        }
    }

    #[test]
    fn adaptive_concurrency_clamps_high_value() {
        // Whatever the CPU count is, 400 must be reduced.
        assert!(adaptive_concurrency(400) <= 64);
    }

    #[test]
    fn adaptive_concurrency_preserves_low_value() {
        for requested in [1u32, 2] {
            assert_eq!(adaptive_concurrency(requested), requested as usize);
        }
    }

    #[test]
    fn adaptive_concurrency_floors_at_one() {
        assert_eq!(adaptive_concurrency(0), 1);
    }

    #[tokio::test]
    async fn download_multiple_empty_list() {
        let (tx, _rx) = mpsc::channel(16);
        make_downloader()
            .download_multiple(Vec::new(), tx)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn download_file_bad_url_returns_error() {
        let dir = TempDir::new().unwrap();
        let item = unreachable_item(&dir, None);
        // 1-second timeout
        let outcome = Downloader::new(1, 1, false).download_file(&item).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn check_url_unreachable_returns_false() {
        let d = Downloader::new(1, 1, false);
        assert!(!d.check_url("http://127.0.0.1:1/test").await);
    }

    #[tokio::test]
    async fn check_mirror_all_bad_returns_none() {
        let found = Downloader::new(1, 1, false)
            .check_mirror(&["http://127.0.0.1:1"], "/some/path.jar")
            .await;
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn download_multiple_bad_url_propagates_error() {
        let dir = TempDir::new().unwrap();
        let (tx, _rx) = mpsc::channel(16);
        let outcome = Downloader::new(1, 1, false)
            .download_multiple(vec![unreachable_item(&dir, Some("test"))], tx)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn no_tmp_file_left_after_failed_download() {
        let dir = TempDir::new().unwrap();
        let item = unreachable_item(&dir, None);
        let _ = Downloader::new(1, 1, false).download_file(&item).await;

        let mut tmp = item.path.clone().into_os_string();
        tmp.push(".tmp");
        assert!(
            !PathBuf::from(tmp).exists(),
            ".tmp file should be cleaned up after failure"
        );
    }
}
467}