download_extract_progress/
download.rs1use std::path::Path;
2
3use anyhow::Context;
4use async_stream::stream;
5use futures_core::Stream;
6use futures_util::StreamExt;
7use sha2::{Digest, Sha256};
8use tokio::{fs::File, io::AsyncWriteExt};
9
10use crate::{error::DownloadError, github_releases};
11
12pub async fn download<T: AsRef<Path>>(
23 display_name: &str,
24 url: &str,
25 path: T,
26 expected_hash: Option<String>,
27) -> impl Stream<Item = Result<(f32, String), DownloadError>> + use<T> {
28 let display_name = display_name.to_string();
29 let url = url.to_string();
30
31 let client = reqwest::Client::new();
32 stream! {
33 yield Ok((0.0, format!("Downloading {display_name}")));
34
35 let res = client.get(url).send().await;
36 if let Err(e) = res {
37 yield Err(DownloadError::RequestError(e));
38 return;
39 }
40
41 let res = res.unwrap();
42 let length = res.content_length().unwrap_or(0);
43
44 let mut bytes_stream = res.bytes_stream();
45
46 let tmp_file = File::create_new(&path)
47 .await
48 .map_err(|e| DownloadError::IoError(e));
49
50 if let Err(e) = tmp_file {
51 yield Err(e);
52 return;
53 }
54
55 let mut curr_len = 0;
56 let mut hasher = Sha256::new();
57
58 let mut tmp_file = tmp_file.unwrap();
59 while let Some(chunk) = bytes_stream.next().await {
60 if let Err(e) = chunk {
61 yield Err(DownloadError::RequestError(e));
62 return;
63 }
64
65 let chunk = chunk.unwrap();
66 hasher.update(&chunk);
67 let r = tmp_file.write_all(&chunk).await;
68 if let Err(e) = r {
69 yield Err(DownloadError::IoError(e));
70 return;
71 }
72
73 curr_len = std::cmp::min(curr_len + chunk.len() as u64, length);
74 yield Ok((curr_len as f32 / length as f32, format!("Downloading {display_name}")));
75 }
76
77 if let Some(expected_hash) = expected_hash {
78 let remote_hash = hex::decode(expected_hash);
79 if let Err(e) = remote_hash {
80 yield Err(DownloadError::InvalidHash(e));
81 return;
82 }
83
84 let remote_hash = remote_hash.unwrap();
85
86 let local_hash = hasher.finalize();
88 if local_hash.as_slice() != remote_hash {
89 let remote_hash = hex::encode(remote_hash);
90 let local_hash = hex::encode(local_hash);
91
92 yield Err(DownloadError::HashMismatch(remote_hash, local_hash));
93 return;
94 }
95
96 log::trace!("Hashes match");
97 }
98 }
99}
100
101pub async fn download_github<T: AsRef<Path>, K>(
112 repo: &str,
113 is_valid_file: K,
114 path: T,
115 hash_url: Option<String>,
116) -> anyhow::Result<impl Stream<Item = Result<(f32, String), DownloadError>>>
117where
118 K: Fn(&str) -> bool,
119{
120 let display_name = repo.split('/').last().unwrap_or("unknown");
121
122 let client = reqwest::Client::builder()
123 .user_agent("github-releases-downloader/1.0")
124 .build()
125 .context("Creating HTTP client")?;
126
127 let releases: github_releases::Root = client
128 .get(format!("https://api.github.com/repos/{repo}/releases"))
129 .send().await?
130 .json().await?;
131
132 let latest_version = releases
133 .iter()
134 .max_by_key(|r| &r.published_at)
135 .context("Finding latest version")?;
136
137 let archive_url = latest_version
138 .assets
139 .iter()
140 .find(|a| is_valid_file(&a.name))
141 .context("Finding zip asset")?
142 .browser_download_url
143 .clone();
144
145 let display_name = display_name.to_string();
146 let expected_hash = if let Some(url) = hash_url {
147 let hash: String = reqwest::get(url).await?.text().await?;
148 Some(hash.trim().to_string())
149 } else {
150 None
151 };
152
153 Ok(download(&display_name, &archive_url, path, expected_hash).await)
154}