bb_downloader/
lib.rs

//! A simple downloader library for use in applications, with support for caching downloaded
//! assets.
3//!
4//! # Features
5//!
6//! - Async
//! - Caches downloaded files in a directory on the filesystem.
//! - Checks whether a file is already available in the cache.
//! - Uses SHA256 to verify cached files.
//! - Optional support for downloading JSON files without caching them (`json` feature).
11//!
12//! # Sample Usage
13//!
14//! ```no_run
15//! #[tokio::main]
16//! async fn main() {
17//!     let downloader = bb_downloader::Downloader::new("/tmp").unwrap();
18//!
19//!     let sha = [0u8; 32];
20//!     let url = "https://example.com/img.jpg";
21//!
22//!     // Download with just URL
23//!     downloader.download(url, None).await.unwrap();
24//!
25//!     // Check if the file is in cache
26//!     assert!(downloader.check_cache_from_url(url).is_some());
27//!
28//!     // Will fetch directly from cache instead of re-downloading
29//!     downloader.download(url, None).await.unwrap();
30//!
//!     // A file cached by URL cannot be found through a SHA256 lookup
32//!     assert!(!downloader.check_cache_from_sha(sha).await.is_some());
33//!
34//!     // Will re-download the file
35//!     downloader.download_with_sha(url, sha, None).await.unwrap();
36//!
37//!     assert!(downloader.check_cache_from_sha(sha).await.is_some());
38//! }
39//! ```
40
41use futures::{Stream, StreamExt, channel::mpsc};
42#[cfg(feature = "json")]
43use serde::de::DeserializeOwned;
44use sha2::{Digest as _, Sha256};
45use std::{
46    io,
47    path::{Path, PathBuf},
48    time::Duration,
49};
50use tokio::io::{AsyncReadExt, AsyncSeekExt, AsyncWriteExt};
51
52pub use reqwest::IntoUrl;
53
54/// Simple downloader that caches files in the provided directory. Uses SHA256 to determine if the
55/// file is already downloaded.
56///
57/// Either SHA256 or URL can be used for caching files. However, both are not interchangable. If
58/// SHA256 cannot be used to check files that were downloaded with just URL, and vice versa.
59///
60/// # Invalidate Cache
61///
62/// Using SHA256 should be prefered when it is known in advance since it allows performing SHA256
63/// verficiation on the downloaded file. Additionally, it also adds capability to invalidate cached
64/// file.
65///
66/// Files downloaded with just URL cannot be invalidated without changing the URL, or deleting the
67/// file manually.
68///
69/// # Thread Safety
70///
71/// You do not have to wrap the Client in an Rc or Arc to reuse it, because it already uses an Arc
72/// internally.
73#[derive(Debug, Clone)]
74pub struct Downloader {
75    client: reqwest::Client,
76    cache_dir: PathBuf,
77}
78
79impl Downloader {
80    /// Create a new downloader that uses a directory for storing cached files.
81    pub fn new<P: Into<PathBuf>>(cache_dir: P) -> io::Result<Self> {
82        let cache_dir = cache_dir.into();
83
84        if !cache_dir.exists() {
85            let _ = std::fs::create_dir_all(&cache_dir);
86        }
87
88        if cache_dir.exists() && !cache_dir.is_dir() {
89            return Err(io::Error::new(
90                io::ErrorKind::NotADirectory,
91                "cache_dir should be a directory",
92            ));
93        }
94
95        let client = reqwest::Client::builder()
96            .connect_timeout(Duration::from_secs(10))
97            .build()
98            .expect("Unsupported OS");
99
100        Ok(Self { client, cache_dir })
101    }
102
103    /// Check if a downloaded file with a particular SHA256 is already in cache.
104    pub async fn check_cache_from_sha(&self, sha256: [u8; 32]) -> Option<PathBuf> {
105        let file_path = self.path_from_sha(sha256);
106
107        if file_path.exists() {
108            if let Ok(hash) = sha256_from_path(&file_path).await {
109                if hash == sha256 {
110                    return Some(file_path);
111                }
112            }
113
114            // Delete old file
115            let _ = tokio::fs::remove_file(&file_path).await;
116        }
117
118        None
119    }
120
121    /// Check if a downloaded file is already in cache.
122    ///
123    /// [`check_cache_from_sha`](Self::check_cache_from_sha) should be prefered in cases when SHA256
124    /// of the file to download is already known.
125    pub fn check_cache_from_url<U: reqwest::IntoUrl>(&self, url: U) -> Option<PathBuf> {
126        // Use hash of url for file name
127        let file_path = self.path_from_url(url.as_str());
128        if file_path.exists() {
129            Some(file_path)
130        } else {
131            None
132        }
133    }
134
135    /// Download a JSON file without caching the contents. Should be used when there is no point in
136    /// caching the file.
137    #[cfg(feature = "json")]
138    pub async fn download_json_no_cache<T, U>(&self, url: U) -> io::Result<T>
139    where
140        T: DeserializeOwned,
141        U: reqwest::IntoUrl,
142    {
143        self.client
144            .get(url)
145            .send()
146            .await
147            .map_err(io::Error::other)?
148            .json()
149            .await
150            .map_err(io::Error::other)
151    }
152
153    /// Checks if the file is present in cache. If the file is present, returns path to it. Else
154    /// downloads the file.
155    ///
156    /// [`download_with_sha`](Self::download_with_sha) should be prefered when the SHA256 of the
157    /// file is known in advance.
158    ///
159    /// # Progress
160    ///
161    /// Download progress can be optionally tracked using a [`futures::channel::mpsc`].
162    pub async fn download<U: reqwest::IntoUrl>(
163        &self,
164        url: U,
165        chan: Option<mpsc::Sender<f32>>,
166    ) -> io::Result<PathBuf> {
167        let url = url.into_url().map_err(io::Error::other)?;
168
169        // Check cache
170        if let Some(p) = self.check_cache_from_url(url.clone()) {
171            return Ok(p);
172        }
173
174        self.download_no_cache(url, chan).await
175    }
176
177    /// Downloads the file without checking cache.
178    ///
179    /// [`download_with_sha`](Self::download_with_sha) should be prefered when the SHA256 of the
180    /// file is known in advance.
181    ///
182    /// # Progress
183    ///
184    /// Download progress can be optionally tracked using a [`futures::channel::mpsc`].
185    ///
186    /// # Differences from [Self::download]
187    ///
188    /// This function does not check if the file is present in cache, and will ovewrite the old
189    /// cached file. The file is still cached in the end.
190    pub async fn download_no_cache<U: reqwest::IntoUrl>(
191        &self,
192        url: U,
193        mut chan: Option<mpsc::Sender<f32>>,
194    ) -> io::Result<PathBuf> {
195        let url = url.into_url().map_err(io::Error::other)?;
196
197        let file_path = self.path_from_url(url.as_str());
198        chan_send(chan.as_mut(), 0.0);
199
200        let mut cur_pos = 0;
201        let mut file = AsyncTempFile::new()?;
202        {
203            let mut file = tokio::io::BufWriter::new(&mut file.0);
204
205            let response = self
206                .client
207                .get(url)
208                .send()
209                .await
210                .map_err(io::Error::other)?;
211            let response_size = response.content_length();
212            let mut response_stream = response.bytes_stream();
213
214            let response_size = match response_size {
215                Some(x) => x as usize,
216                None => response_stream.size_hint().0,
217            };
218
219            while let Some(x) = response_stream.next().await {
220                let mut data = x.map_err(io::Error::other)?;
221                cur_pos += data.len();
222                file.write_all_buf(&mut data).await?;
223                chan_send(chan.as_mut(), (cur_pos as f32) / (response_size as f32));
224            }
225
226            file.flush().await?
227        }
228
229        file.persist(&file_path).await?;
230        Ok(file_path)
231    }
232
233    /// Checks if the file is present in cache. If the file is present, returns path to it. Else
234    /// downloads the file.
235    ///
236    /// Uses SHA256 to verify that the file in cache is valid.
237    ///
238    /// # Progress
239    ///
240    /// Download progress can be optionally tracked using a [`futures::channel::mpsc`].
241    pub async fn download_with_sha<U: reqwest::IntoUrl>(
242        &self,
243        url: U,
244        sha256: [u8; 32],
245        mut chan: Option<mpsc::Sender<f32>>,
246    ) -> io::Result<PathBuf> {
247        let url = url.into_url().map_err(io::Error::other)?;
248        tracing::info!(
249            "Download {:?} with sha256: {:?}",
250            url,
251            const_hex::encode(sha256)
252        );
253
254        if let Some(p) = self.check_cache_from_sha(sha256).await {
255            return Ok(p);
256        }
257
258        let file_path = self.path_from_sha(sha256);
259        chan_send(chan.as_mut(), 0.0);
260
261        let mut file = AsyncTempFile::new()?;
262        {
263            let mut file = tokio::io::BufWriter::new(&mut file.0);
264
265            let response = self
266                .client
267                .get(url)
268                .send()
269                .await
270                .map_err(io::Error::other)?;
271
272            let mut cur_pos = 0;
273            let response_size = response.content_length();
274
275            let mut response_stream = response.bytes_stream();
276
277            let response_size = match response_size {
278                Some(x) => x as usize,
279                None => response_stream.size_hint().0,
280            };
281
282            let mut hasher = Sha256::new();
283
284            while let Some(x) = response_stream.next().await {
285                let mut data = x.map_err(io::Error::other)?;
286                cur_pos += data.len();
287                hasher.update(&data);
288                file.write_all_buf(&mut data).await?;
289
290                chan_send(chan.as_mut(), (cur_pos as f32) / (response_size as f32));
291            }
292
293            let hash: [u8; 32] = hasher
294                .finalize()
295                .as_slice()
296                .try_into()
297                .expect("SHA-256 is 32 bytes");
298
299            if hash != sha256 {
300                tracing::warn!("{hash:?} != {sha256:?}");
301                return Err(io::Error::new(
302                    io::ErrorKind::InvalidInput,
303                    "Invalid SHA256",
304                ));
305            }
306            file.flush().await?;
307        }
308
309        file.persist(&file_path).await?;
310        Ok(file_path)
311    }
312
313    fn path_from_url(&self, url: &str) -> PathBuf {
314        let file_name: [u8; 32] = Sha256::new()
315            .chain_update(url)
316            .finalize()
317            .as_slice()
318            .try_into()
319            .expect("SHA-256 is 32 bytes");
320        self.path_from_sha(file_name)
321    }
322
323    fn path_from_sha(&self, sha256: [u8; 32]) -> PathBuf {
324        let file_name = const_hex::encode(sha256);
325        self.cache_dir.join(file_name)
326    }
327}
328
329async fn sha256_from_path(p: &Path) -> io::Result<[u8; 32]> {
330    let file = tokio::fs::File::open(p).await?;
331    let mut reader = tokio::io::BufReader::new(file);
332    let mut hasher = Sha256::new();
333    let mut buffer = [0; 512];
334
335    loop {
336        let count = reader.read(&mut buffer).await?;
337        if count == 0 {
338            break;
339        }
340
341        hasher.update(&buffer[..count]);
342    }
343
344    let hash = hasher
345        .finalize()
346        .as_slice()
347        .try_into()
348        .expect("SHA-256 is 32 bytes");
349
350    Ok(hash)
351}
352
353fn chan_send(chan: Option<&mut mpsc::Sender<f32>>, msg: f32) {
354    if let Some(c) = chan {
355        let _ = c.try_send(msg);
356    }
357}
358
359struct AsyncTempFile(tokio::fs::File);
360
361impl AsyncTempFile {
362    fn new() -> io::Result<Self> {
363        let f = tempfile::tempfile()?;
364        Ok(Self(tokio::fs::File::from_std(f)))
365    }
366
367    async fn persist(&mut self, path: &Path) -> io::Result<()> {
368        let mut f = tokio::fs::File::create(path).await?;
369        self.0.seek(io::SeekFrom::Start(0)).await?;
370        tokio::io::copy(&mut self.0, &mut f).await?;
371        Ok(())
372    }
373}