Skip to main content

minecraft_java_rs_core/net/
downloader.rs

1use std::collections::VecDeque;
2use std::path::PathBuf;
3use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use futures::StreamExt;
8use sha1::Digest;
9use tokio::io::AsyncWriteExt;
10use tokio::sync::Semaphore;
11use tokio::task::JoinSet;
12
13use crate::error::{DownloadError, LaunchError};
14use crate::launcher::events::LaunchEvent;
15
/// Number of retries allowed after the first attempt of a download fails with
/// a retryable error, i.e. a file is requested at most
/// `DOWNLOAD_MAX_RETRIES + 1` times.
const DOWNLOAD_MAX_RETRIES: u32 = 3;
/// Pause before the first retry, in milliseconds. The pause doubles on every
/// further retry.
const DOWNLOAD_INITIAL_BACKOFF_MS: u64 = 500;
18
19// ── DownloadItem ──────────────────────────────────────────────────────────────
20
/// Description of a single file to download: where it comes from, where it
/// goes, and how it is reported and verified.
#[derive(Clone, Debug)]
pub struct DownloadItem {
    /// Source URL; requested exactly as given.
    pub url: String,
    /// Destination of the finished file. The body is first written to a
    /// sibling `<path>.tmp` and renamed here once complete.
    pub path: PathBuf,
    /// Directory created (recursively) before the download starts.
    /// Leave empty to have the parent directory of `path` created instead.
    pub folder: PathBuf,
    /// Display name of the file; reported in checksum-mismatch errors.
    pub name: String,
    /// Expected size in bytes. Sizes are summed over a batch to estimate the
    /// remaining time; use 0 when the size is unknown.
    pub size: u64,
    /// Category label sent as `kind` with `LaunchEvent::Progress`
    /// (e.g. "assets"). `None` is reported as an empty string.
    #[allow(clippy::pub_with_shorthand)]
    pub r#type: Option<String>,
    /// Expected SHA-1 digest in hexadecimal. When `Some`, the downloaded
    /// bytes are hashed and a difference is reported as
    /// `DownloadError::ChecksumMismatch`; when `None`, nothing is verified.
    pub sha1: Option<String>,
}
41
42// ── Downloader ────────────────────────────────────────────────────────────────
43
44pub struct Downloader {
45    client: reqwest::Client,
46    /// Effective concurrency after applying the adaptive cap.
47    concurrency: usize,
48}
49
50impl Downloader {
51    pub fn new(
52        timeout_secs: u64,
53        concurrency: u32,
54        force_ipv4: bool,
55        dns: Option<std::net::IpAddr>,
56    ) -> Self {
57        let client = crate::net::client::build_client(timeout_secs, force_ipv4, dns)
58            .expect("failed to build reqwest client");
59        Self {
60            client,
61            concurrency: adaptive_concurrency(concurrency),
62        }
63    }
64
65    /// Download a single file.  No progress events are emitted.
66    pub async fn download_file(&self, item: &DownloadItem) -> Result<(), LaunchError> {
67        let counter = Arc::new(AtomicU64::new(0));
68        fetch_one(self.client.clone(), item, &counter)
69            .await
70            .map_err(LaunchError::Download)
71    }
72
73    /// Download many files concurrently, emitting `LaunchEvent` progress
74    /// notifications on `event_tx`.
75    ///
76    /// Events emitted:
77    /// - `Progress { downloaded, total, kind }` — file-count progress after
78    ///   each file completes, where `downloaded` = files done, `total` = total
79    ///   files.
80    /// - `Speed(bytes_per_sec)` — rolling 5-second average.
81    /// - `Estimated(secs)` — ETA in seconds at the current speed.
82    pub async fn download_multiple(
83        &self,
84        items: Vec<DownloadItem>,
85        event_tx: tokio::sync::mpsc::Sender<LaunchEvent>,
86    ) -> Result<(), LaunchError> {
87        if items.is_empty() {
88            return Ok(());
89        }
90
91        let total_bytes: u64 = items.iter().map(|i| i.size).sum();
92        let total_count = items.len() as u64;
93        let downloaded = Arc::new(AtomicU64::new(0));
94        let completed = Arc::new(AtomicUsize::new(0));
95
96        let semaphore = Arc::new(Semaphore::new(self.concurrency));
97        let mut join_set: JoinSet<Result<(), LaunchError>> = JoinSet::new();
98
99        for item in items {
100            let sem = Arc::clone(&semaphore);
101            let dl = Arc::clone(&downloaded);
102            let comp = Arc::clone(&completed);
103            let client = self.client.clone();
104            let tx = event_tx.clone();
105
106            join_set.spawn(async move {
107                let _permit = sem
108                    .acquire_owned()
109                    .await
110                    .map_err(|e| LaunchError::Archive(e.to_string()))?;
111
112                fetch_one(client, &item, &dl)
113                    .await
114                    .map_err(LaunchError::Download)?;
115
116                let done = comp.fetch_add(1, Ordering::Relaxed) as u64 + 1;
117                tx.send(LaunchEvent::Progress {
118                    downloaded: done,
119                    total: total_count,
120                    kind: item.r#type.clone().unwrap_or_default(),
121                })
122                .await
123                .ok();
124
125                Ok(())
126            });
127        }
128
129        // Sliding-window speed tracker (pure coordinator state — no sharing needed).
130        let mut speed_window: VecDeque<(Instant, u64)> = VecDeque::new();
131
132        while let Some(result) = join_set.join_next().await {
133            result.map_err(|e| LaunchError::Archive(e.to_string()))??;
134
135            let now = Instant::now();
136            let dl = downloaded.load(Ordering::Relaxed);
137            speed_window.push_back((now, dl));
138
139            // Evict samples older than 5 seconds.
140            while speed_window
141                .front()
142                .map_or(false, |(t, _)| now.duration_since(*t).as_secs_f64() > 5.0)
143            {
144                speed_window.pop_front();
145            }
146
147            if let Some((t0, b0)) = speed_window.front() {
148                let dt = now.duration_since(*t0).as_secs_f64();
149                if dt > 0.1 {
150                    let speed = dl.saturating_sub(*b0) as f64 / dt;
151                    event_tx.send(LaunchEvent::Speed(speed)).await.ok();
152                    if speed > 0.0 && total_bytes > 0 {
153                        let remaining = total_bytes.saturating_sub(dl) as f64 / speed;
154                        event_tx.send(LaunchEvent::Estimated(remaining)).await.ok();
155                    }
156                }
157            }
158        }
159
160        Ok(())
161    }
162
163    /// Returns `true` if a HEAD request to `url` succeeds with a 2xx status.
164    pub async fn check_url(&self, url: &str) -> bool {
165        self.client
166            .head(url)
167            .send()
168            .await
169            .map(|r| r.status().is_success())
170            .unwrap_or(false)
171    }
172
173    /// Iterate `mirrors` in order, appending `path`, and return the first URL
174    /// that responds successfully to a HEAD request.  Returns `None` if all
175    /// mirrors are unreachable.
176    pub async fn check_mirror(&self, mirrors: &[&str], path: &str) -> Option<String> {
177        let path = path.trim_start_matches('/');
178        for mirror in mirrors {
179            let url = format!("{}/{}", mirror.trim_end_matches('/'), path);
180            if self.check_url(&url).await {
181                return Some(url);
182            }
183        }
184        None
185    }
186}
187
188// ── Internal helpers ──────────────────────────────────────────────────────────
189
/// Clamp user-requested concurrency to a system-aware upper bound.
///
/// Each active download holds roughly one TCP connection, a ~64 KB read buffer,
/// and a Tokio task. High values (e.g. 400) exhaust file descriptors and network
/// stack memory without any throughput gain. The result is:
///
///   clamp(requested, 1, min(cpu_cores × 8, 64))
///
/// This allows a 4-core machine to run up to 32 simultaneous downloads and an
/// 8-core machine up to 64, while still honouring smaller values the caller
/// sets. A request of 0 is raised to 1. When the core count cannot be
/// determined, 4 cores are assumed.
fn adaptive_concurrency(requested: u32) -> usize {
    const DOWNLOADS_PER_CORE: usize = 8;
    const MAX_CONCURRENCY: usize = 64;
    const FALLBACK_CORES: usize = 4;

    let cores = std::thread::available_parallelism().map_or(FALLBACK_CORES, |n| n.get());

    // `cores` is never zero, so the ceiling always lies in 8..=64 and no
    // additional lower floor is needed before clamping.
    let ceiling = cores
        .saturating_mul(DOWNLOADS_PER_CORE)
        .min(MAX_CONCURRENCY);

    (requested as usize).clamp(1, ceiling)
}
207
208/// Returns true for HTTP status codes worth retrying.
209/// 4xx client errors are not retried since they won't change.
210fn is_retryable_status(status: reqwest::StatusCode) -> bool {
211    status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS
212}
213
214/// Map a `reqwest::Error` to a `DownloadError`, distinguishing transport-level
215/// failures (no HTTP status — DNS, unreachable, reset, timeout) from
216/// server-side ones. Transport failures get a human-readable cause via
217/// [`describe_reqwest_error`](crate::net::client::describe_reqwest_error) so the
218/// user sees the real reason instead of "error sending request for url".
219fn classify_error(url: &str, e: reqwest::Error) -> DownloadError {
220    if e.status().is_some() {
221        DownloadError::Http(e)
222    } else {
223        DownloadError::Connection {
224            url: url.to_owned(),
225            detail: crate::net::client::describe_reqwest_error(&e),
226        }
227    }
228}
229
230/// Download `item` to disk, updating `dl_counter` with each received chunk.
231///
232/// Uses a temporary file (`<path>.tmp`) and an atomic rename so a failed or
233/// interrupted download never leaves a corrupt file at the final path.
234///
235/// Retries up to `DOWNLOAD_MAX_RETRIES` times on network errors, 5xx, and 429,
236/// with exponential backoff starting at `DOWNLOAD_INITIAL_BACKOFF_MS`.
237/// Checksum mismatches and I/O errors are not retried.
238async fn fetch_one(
239    client: reqwest::Client,
240    item: &DownloadItem,
241    dl_counter: &Arc<AtomicU64>,
242) -> Result<(), DownloadError> {
243    let dir = if item.folder.as_os_str().is_empty() {
244        item.path
245            .parent()
246            .map(|p| p.to_path_buf())
247            .unwrap_or_else(|| PathBuf::from("."))
248    } else {
249        item.folder.clone()
250    };
251    tokio::fs::create_dir_all(&dir).await?;
252
253    // Temporary path: `foo.jar` → `foo.jar.tmp`
254    let tmp_path = {
255        let mut s = item.path.as_os_str().to_owned();
256        s.push(".tmp");
257        PathBuf::from(s)
258    };
259
260    let mut last_err: Option<DownloadError> = None;
261    let mut backoff = DOWNLOAD_INITIAL_BACKOFF_MS;
262
263    for attempt in 0..=DOWNLOAD_MAX_RETRIES {
264        if attempt > 0 {
265            let _ = tokio::fs::remove_file(&tmp_path).await;
266            tokio::time::sleep(Duration::from_millis(backoff)).await;
267            backoff = (backoff * 2).min(8_000);
268        }
269
270        // ── Send request ──────────────────────────────────────────────────────
271        let response = match client.get(&item.url).send().await {
272            Ok(r) => r,
273            Err(e) => {
274                last_err = Some(classify_error(&item.url, e));
275                continue;
276            }
277        };
278
279        let status = response.status();
280        if is_retryable_status(status) {
281            last_err = Some(DownloadError::Http(
282                response.error_for_status().unwrap_err(),
283            ));
284            continue;
285        }
286        if !status.is_success() {
287            // 4xx — don't retry
288            return Err(DownloadError::Http(
289                response.error_for_status().unwrap_err(),
290            ));
291        }
292
293        // ── Stream body to temp file ──────────────────────────────────────────
294        let mut file = match tokio::fs::File::create(&tmp_path).await {
295            Ok(f) => f,
296            Err(e) => return Err(DownloadError::Io(e)),
297        };
298
299        let mut stream = response.bytes_stream();
300        let mut hasher = sha1::Sha1::new();
301        let verify = item.sha1.is_some();
302        let mut stream_err: Option<DownloadError> = None;
303
304        while let Some(chunk_result) = stream.next().await {
305            match chunk_result {
306                Ok(chunk) => {
307                    if let Err(e) = file.write_all(&chunk).await {
308                        return Err(DownloadError::Io(e));
309                    }
310                    if verify {
311                        hasher.update(&chunk);
312                    }
313                    dl_counter.fetch_add(chunk.len() as u64, Ordering::Relaxed);
314                }
315                Err(e) => {
316                    stream_err = Some(classify_error(&item.url, e));
317                    break;
318                }
319            }
320        }
321
322        if let Some(e) = stream_err {
323            last_err = Some(e);
324            continue;
325        }
326
327        if let Err(e) = file.flush().await {
328            return Err(DownloadError::Io(e));
329        }
330
331        // ── Checksum ──────────────────────────────────────────────────────────
332        if let Some(expected) = &item.sha1 {
333            let actual: String = hasher
334                .finalize()
335                .iter()
336                .map(|b| format!("{b:02x}"))
337                .collect();
338            if actual != *expected {
339                let _ = tokio::fs::remove_file(&tmp_path).await;
340                return Err(DownloadError::ChecksumMismatch {
341                    file: item.name.clone(),
342                    expected: expected.clone(),
343                    actual,
344                });
345            }
346        }
347
348        // ── Atomic rename ─────────────────────────────────────────────────────
349        if let Err(e) = tokio::fs::rename(&tmp_path, &item.path).await {
350            return Err(DownloadError::Io(e));
351        }
352
353        return Ok(());
354    }
355
356    let _ = tokio::fs::remove_file(&tmp_path).await;
357    Err(last_err.unwrap_or(DownloadError::Timeout))
358}
359
360// ── Tests ─────────────────────────────────────────────────────────────────────
361
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    use tokio::sync::mpsc;

    /// Downloader used by the tests that never reach the network.
    fn make_downloader() -> Downloader {
        Downloader::new(5, 4, false, None)
    }

    /// Downloader with a 1-second timeout and a single download slot, so
    /// tests against an unreachable address finish quickly.
    fn impatient_downloader() -> Downloader {
        Downloader::new(1, 1, false, None)
    }

    /// Item that targets `out.bin` inside `dir` and points at loopback
    /// port 1, where presumably nothing is listening.
    fn unreachable_item(dir: &TempDir, kind: Option<&str>) -> DownloadItem {
        DownloadItem {
            url: "http://127.0.0.1:1/nonexistent".into(),
            path: dir.path().join("out.bin"),
            folder: dir.path().to_path_buf(),
            name: "out.bin".into(),
            size: 0,
            r#type: kind.map(String::from),
            sha1: None,
        }
    }

    // ── adaptive_concurrency ──────────────────────────────────────────────

    #[test]
    fn adaptive_concurrency_clamps_high_value() {
        // Regardless of the number of cores, 400 has to be cut down.
        let effective = adaptive_concurrency(400);
        assert!(effective <= 64);
    }

    #[test]
    fn adaptive_concurrency_preserves_low_value() {
        for requested in [2u32, 1u32] {
            assert_eq!(adaptive_concurrency(requested), requested as usize);
        }
    }

    #[test]
    fn adaptive_concurrency_floors_at_one() {
        assert_eq!(adaptive_concurrency(0), 1);
    }

    // ── Downloader ────────────────────────────────────────────────────────

    #[tokio::test]
    async fn download_multiple_empty_list() {
        let (tx, _rx) = mpsc::channel(16);
        let outcome = make_downloader().download_multiple(Vec::new(), tx).await;
        outcome.unwrap();
    }

    #[tokio::test]
    async fn download_file_bad_url_returns_error() {
        let dir = TempDir::new().unwrap();
        let item = unreachable_item(&dir, None);
        let outcome = impatient_downloader().download_file(&item).await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn check_url_unreachable_returns_false() {
        let reachable = impatient_downloader()
            .check_url("http://127.0.0.1:1/test")
            .await;
        assert!(!reachable);
    }

    #[tokio::test]
    async fn check_mirror_all_bad_returns_none() {
        let mirrors = ["http://127.0.0.1:1"];
        let found = impatient_downloader()
            .check_mirror(&mirrors, "/some/path.jar")
            .await;
        assert!(found.is_none());
    }

    #[tokio::test]
    async fn download_multiple_bad_url_propagates_error() {
        let dir = TempDir::new().unwrap();
        let item = unreachable_item(&dir, Some("test"));
        // The receiver is kept alive so sends cannot fail for lack of one.
        let (tx, _rx) = mpsc::channel(16);
        let outcome = impatient_downloader()
            .download_multiple(vec![item], tx)
            .await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn no_tmp_file_left_after_failed_download() {
        let dir = TempDir::new().unwrap();
        let item = unreachable_item(&dir, None);
        // Only the side effect on disk matters here, not the result.
        let _ = impatient_downloader().download_file(&item).await;

        // Same naming rule the downloader uses: `<path>.tmp`.
        let mut tmp_name = item.path.as_os_str().to_owned();
        tmp_name.push(".tmp");
        let tmp = PathBuf::from(tmp_name);

        assert!(
            !tmp.exists(),
            ".tmp file should be cleaned up after failure"
        );
    }
}