1use futures::{Stream, StreamExt, channel::mpsc};
42#[cfg(feature = "json")]
43use serde::de::DeserializeOwned;
44use sha2::{Digest as _, Sha256};
45use std::{
46 io,
47 path::{Path, PathBuf},
48 time::Duration,
49};
50use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
51
52pub use reqwest::IntoUrl;
53
54#[derive(Debug, Clone)]
74pub struct Downloader {
75 client: reqwest::Client,
76 cache_dir: PathBuf,
77}
78
79impl Downloader {
80 pub fn new<P: Into<PathBuf>>(cache_dir: P) -> io::Result<Self> {
82 let cache_dir = cache_dir.into();
83
84 if !cache_dir.exists() {
85 let _ = std::fs::create_dir_all(&cache_dir);
86 }
87
88 if cache_dir.exists() && !cache_dir.is_dir() {
89 return Err(io::Error::new(
90 io::ErrorKind::NotADirectory,
91 "cache_dir should be a directory",
92 ));
93 }
94
95 let client = reqwest::Client::builder()
96 .connect_timeout(Duration::from_secs(10))
97 .build()
98 .expect("Unsupported OS");
99
100 Ok(Self { client, cache_dir })
101 }
102
103 pub async fn check_cache_from_sha(&self, sha256: [u8; 32]) -> Option<PathBuf> {
105 let file_path = self.path_from_sha(sha256);
106
107 if file_path.exists() {
108 if let Ok(hash) = sha256_from_path(&file_path).await {
109 if hash == sha256 {
110 return Some(file_path);
111 }
112 }
113
114 let _ = tokio::fs::remove_file(&file_path).await;
116 }
117
118 None
119 }
120
121 pub fn check_cache_from_url<U: reqwest::IntoUrl>(&self, url: U) -> Option<PathBuf> {
126 let file_path = self.path_from_url(url.as_str());
128 if file_path.exists() {
129 Some(file_path)
130 } else {
131 None
132 }
133 }
134
135 #[cfg(feature = "json")]
138 pub async fn download_json_no_cache<T, U>(&self, url: U) -> io::Result<T>
139 where
140 T: DeserializeOwned,
141 U: reqwest::IntoUrl,
142 {
143 self.client
144 .get(url)
145 .send()
146 .await
147 .map_err(io::Error::other)?
148 .json()
149 .await
150 .map_err(io::Error::other)
151 }
152
153 pub async fn download<U: reqwest::IntoUrl>(
163 &self,
164 url: U,
165 chan: Option<mpsc::Sender<f32>>,
166 ) -> io::Result<PathBuf> {
167 let url = url.into_url().map_err(io::Error::other)?;
168
169 if let Some(p) = self.check_cache_from_url(url.clone()) {
171 return Ok(p);
172 }
173
174 self.download_no_cache(url, chan).await
175 }
176
177 pub async fn download_no_cache<U: reqwest::IntoUrl>(
191 &self,
192 url: U,
193 mut chan: Option<mpsc::Sender<f32>>,
194 ) -> io::Result<PathBuf> {
195 let url = url.into_url().map_err(io::Error::other)?;
196
197 let file_path = self.path_from_url(url.as_str());
198 chan_send(chan.as_mut(), 0.0);
199
200 let mut cur_pos = 0;
201 let mut file = AsyncTempFile::new()?;
202 {
203 let mut file = tokio::io::BufWriter::new(&mut file.0);
204
205 let response = self
206 .client
207 .get(url)
208 .send()
209 .await
210 .map_err(io::Error::other)?;
211 let response_size = response.content_length();
212 let mut response_stream = response.bytes_stream();
213
214 let response_size = match response_size {
215 Some(x) => x as usize,
216 None => response_stream.size_hint().0,
217 };
218
219 while let Some(x) = response_stream.next().await {
220 let mut data = x.map_err(io::Error::other)?;
221 cur_pos += data.len();
222 file.write_all_buf(&mut data).await?;
223 chan_send(chan.as_mut(), (cur_pos as f32) / (response_size as f32));
224 }
225
226 file.flush().await?
227 }
228
229 file.persist(&file_path).await?;
230 Ok(file_path)
231 }
232
233 pub async fn download_with_sha<U: reqwest::IntoUrl>(
242 &self,
243 url: U,
244 sha256: [u8; 32],
245 mut chan: Option<mpsc::Sender<f32>>,
246 ) -> io::Result<PathBuf> {
247 let url = url.into_url().map_err(io::Error::other)?;
248 tracing::info!(
249 "Download {:?} with sha256: {:?}",
250 url,
251 const_hex::encode(sha256)
252 );
253
254 if let Some(p) = self.check_cache_from_sha(sha256).await {
255 return Ok(p);
256 }
257
258 let file_path = self.path_from_sha(sha256);
259 chan_send(chan.as_mut(), 0.0);
260
261 let mut file = AsyncTempFile::new()?;
262 {
263 let mut file = tokio::io::BufWriter::new(&mut file.0);
264
265 let response = self
266 .client
267 .get(url)
268 .send()
269 .await
270 .map_err(io::Error::other)?;
271
272 let mut cur_pos = 0;
273 let response_size = response.content_length();
274
275 let mut response_stream = response.bytes_stream();
276
277 let response_size = match response_size {
278 Some(x) => x as usize,
279 None => response_stream.size_hint().0,
280 };
281
282 let mut hasher = Sha256::new();
283
284 while let Some(x) = response_stream.next().await {
285 let mut data = x.map_err(io::Error::other)?;
286 cur_pos += data.len();
287 hasher.update(&data);
288 file.write_all_buf(&mut data).await?;
289
290 chan_send(chan.as_mut(), (cur_pos as f32) / (response_size as f32));
291 }
292
293 let hash: [u8; 32] = hasher
294 .finalize()
295 .as_slice()
296 .try_into()
297 .expect("SHA-256 is 32 bytes");
298
299 if hash != sha256 {
300 tracing::warn!("{hash:?} != {sha256:?}");
301 return Err(io::Error::new(
302 io::ErrorKind::InvalidInput,
303 "Invalid SHA256",
304 ));
305 }
306 file.flush().await?;
307 }
308
309 file.persist(&file_path).await?;
310 Ok(file_path)
311 }
312
313 fn path_from_url(&self, url: &str) -> PathBuf {
314 let file_name: [u8; 32] = Sha256::new()
315 .chain_update(url)
316 .finalize()
317 .as_slice()
318 .try_into()
319 .expect("SHA-256 is 32 bytes");
320 self.path_from_sha(file_name)
321 }
322
323 fn path_from_sha(&self, sha256: [u8; 32]) -> PathBuf {
324 let file_name = const_hex::encode(sha256);
325 self.cache_dir.join(file_name)
326 }
327}
328
329async fn sha256_from_path(p: &Path) -> io::Result<[u8; 32]> {
330 let file = tokio::fs::File::open(p).await?;
331 let mut reader = tokio::io::BufReader::new(file);
332 let mut hasher = Sha256::new();
333 let mut buffer = [0; 512];
334
335 loop {
336 let count = reader.read(&mut buffer).await?;
337 if count == 0 {
338 break;
339 }
340
341 hasher.update(&buffer[..count]);
342 }
343
344 let hash = hasher
345 .finalize()
346 .as_slice()
347 .try_into()
348 .expect("SHA-256 is 32 bytes");
349
350 Ok(hash)
351}
352
353fn chan_send(chan: Option<&mut mpsc::Sender<f32>>, msg: f32) {
354 if let Some(c) = chan {
355 let _ = c.try_send(msg);
356 }
357}
358
359struct AsyncTempFile(tokio::fs::File);
360
361impl AsyncTempFile {
362 fn new() -> io::Result<Self> {
363 let f = tempfile::tempfile()?;
364 Ok(Self(tokio::fs::File::from_std(f)))
365 }
366
367 async fn persist(&mut self, path: &Path) -> io::Result<()> {
368 let mut f = tokio::fs::File::create(path).await?;
369 self.0.seek(io::SeekFrom::Start(0)).await?;
370 tokio::io::copy(&mut self.0, &mut f).await?;
371 Ok(())
372 }
373}