download_extract_progress/
download.rs

1use std::path::Path;
2
3use anyhow::Context;
4use async_stream::stream;
5use futures_core::Stream;
6use futures_util::StreamExt;
7use sha2::{Digest, Sha256};
8use tokio::{fs::File, io::AsyncWriteExt};
9
10use crate::{error::DownloadError, github_releases};
11
12/// Downloads a file from the given URL to the specified path, reporting progress as a stream.
13///
14/// # Arguments
15/// * `display_name` - A name for display in progress messages.
16/// * `url` - The URL to download from.
17/// * `path` - The destination file path.
18/// * `expected_hash` - Optional SHA256 hash to verify the download.
19///
20/// # Returns
21/// A stream yielding progress (0.0-1.0) and status messages, or errors.
22pub async fn download<T: AsRef<Path>>(
23    display_name: &str,
24    url: &str,
25    path: T,
26    expected_hash: Option<String>,
27) -> impl Stream<Item = Result<(f32, String), DownloadError>> + use<T> {
28    let display_name = display_name.to_string();
29    let url = url.to_string();
30
31    let client = reqwest::Client::new();
32    stream! {
33        yield Ok((0.0, format!("Downloading {display_name}")));
34
35        let res = client.get(url).send().await;
36        if let Err(e) = res {
37            yield Err(DownloadError::RequestError(e));
38            return;
39        }
40
41        let res = res.unwrap();
42        let length = res.content_length().unwrap_or(0);
43
44        let mut bytes_stream = res.bytes_stream();
45
46        let tmp_file = File::create_new(&path)
47            .await
48            .map_err(|e| DownloadError::IoError(e));
49
50        if let Err(e) = tmp_file {
51            yield Err(e);
52            return;
53        }
54
55        let mut curr_len = 0;
56        let mut hasher = Sha256::new();
57
58        let mut tmp_file = tmp_file.unwrap();
59        while let Some(chunk) = bytes_stream.next().await {
60            if let Err(e) = chunk {
61                yield Err(DownloadError::RequestError(e));
62                return;
63            }
64
65            let chunk = chunk.unwrap();
66            hasher.update(&chunk);
67            let r = tmp_file.write_all(&chunk).await;
68            if let Err(e) = r {
69                yield Err(DownloadError::IoError(e));
70                return;
71            }
72
73            curr_len = std::cmp::min(curr_len + chunk.len() as u64, length);
74            yield Ok((curr_len as  f32 / length as f32, format!("Downloading {display_name}")));
75        }
76
77        if let Some(expected_hash) = expected_hash {
78            let remote_hash = hex::decode(expected_hash);
79            if let Err(e) = remote_hash {
80                yield Err(DownloadError::InvalidHash(e));
81                return;
82            }
83
84            let remote_hash = remote_hash.unwrap();
85
86            // Calculating local hash
87            let local_hash = hasher.finalize();
88            if local_hash.as_slice() != remote_hash {
89                let remote_hash = hex::encode(remote_hash);
90                let local_hash = hex::encode(local_hash);
91
92                yield Err(DownloadError::HashMismatch(remote_hash, local_hash));
93                return;
94            }
95
96            log::trace!("Hashes match");
97        }
98    }
99}
100
101/// Downloads the latest GitHub release asset matching a predicate, with optional hash verification.
102///
103/// # Arguments
104/// * `repo` - GitHub repo in `owner/name` format.
105/// * `is_valid_file` - Predicate to select the asset.
106/// * `path` - Destination file path.
107/// * `hash_url` - Optional URL to a hash file.
108///
109/// # Returns
110/// A stream yielding progress and status, or errors.
111pub async fn download_github<T: AsRef<Path>, K>(
112    repo: &str,
113    is_valid_file: K,
114    path: T,
115    hash_url: Option<String>,
116) -> anyhow::Result<impl Stream<Item = Result<(f32, String), DownloadError>>>
117where
118    K: Fn(&str) -> bool,
119{
120    let display_name = repo.split('/').last().unwrap_or("unknown");
121
122    let client = reqwest::Client::builder()
123        .user_agent("github-releases-downloader/1.0")
124        .build()
125        .context("Creating HTTP client")?;
126
127    let releases: github_releases::Root = client
128        .get(format!("https://api.github.com/repos/{repo}/releases"))
129        .send().await?
130        .json().await?;
131
132    let latest_version = releases
133        .iter()
134        .max_by_key(|r| &r.published_at)
135        .context("Finding latest version")?;
136
137    let archive_url = latest_version
138        .assets
139        .iter()
140        .find(|a| is_valid_file(&a.name))
141        .context("Finding zip asset")?
142        .browser_download_url
143        .clone();
144
145    let display_name = display_name.to_string();
146    let expected_hash = if let Some(url) = hash_url {
147        let hash: String = reqwest::get(url).await?.text().await?;
148        Some(hash.trim().to_string())
149    } else {
150        None
151    };
152
153    Ok(download(&display_name, &archive_url, path, expected_hash).await)
154}