1use futures_util::StreamExt;
2use indicatif::{ProgressBar, ProgressStyle};
3use reqwest::Client;
4use std::cmp::min;
5use std::fs::File;
6use std::io::{Seek, Write};
7use std::time::Duration;
8
9#[derive(Debug, thiserror::Error)]
10pub enum DownloadError {
11 #[error(transparent)]
12 Network(#[from] reqwest::Error),
13
14 #[error(transparent)]
15 IO(#[from] std::io::Error),
16
17 #[error("{0}")]
18 Generic(String),
19}
20
21impl DownloadError {
22 fn from_string(s: String) -> Self {
23 DownloadError::Generic(s)
24 }
25}
26
27pub async fn download_file(
28 client: &Client,
29 url: &str,
30 path: &str,
31 title: &str,
32) -> Result<(), DownloadError> {
33 const RETRY_SECONDS: u64 = 5;
34 let mut retries = 3;
35
36 loop {
37 let res = _download_file(client, url, path, title).await;
38
39 retries -= 1;
40 if retries < 0 {
41 return res;
42 }
43
44 match res {
45 Err(DownloadError::Network(ref net_err))
46 if net_err.is_connect() || net_err.is_timeout() =>
47 {
48 println!(" Will retry in {} seconds...", RETRY_SECONDS);
49 tokio::time::sleep(Duration::from_secs(RETRY_SECONDS)).await;
50 continue;
51 }
52 _ => return res,
53 };
54 }
55}
56
57async fn _download_file(
58 client: &Client,
59 url: &str,
60 path: &str,
61 title: &str,
62) -> Result<(), DownloadError> {
63 let (mut file, mut downloaded) = open_file_for_write(path)?;
64 let total_size = get_content_length(client, url).await?;
65
66 if downloaded >= total_size {
67 println!(" Nothing to do. File already exists.");
68 return Ok(());
69 }
70
71 let res = client
73 .get(url)
74 .header("Range", format!("bytes={}-", downloaded))
75 .send()
76 .await?;
77
78 let mut stream = res.bytes_stream();
79
80 let pb = get_progress_bar(total_size);
81 pb.set_message(format!("Downloading {}", title));
82
83 while let Some(chunk) = stream.next().await {
84 let chunk = chunk?;
85 let _ = file.write(&chunk)?;
86
87 downloaded = min(downloaded + (chunk.len() as u64), total_size);
88 pb.set_position(downloaded);
89 }
90
91 pb.finish_and_clear();
92 println!(" Downloaded {}", title);
93 Ok(())
94}
95
/// Opens `path` for appending, creating it if it does not exist yet.
///
/// Returns the file handle together with the number of bytes already in the file. That
/// count is the offset from which a resumed download continues (0 for a new file).
///
/// # Errors
///
/// Propagates any I/O error from opening or seeking the file.
fn open_file_for_write(path: &str) -> Result<(File, u64), std::io::Error> {
    // A single open with `create(true)` avoids a check-then-create race between testing
    // for existence and opening/creating the file.
    let mut file = std::fs::OpenOptions::new()
        .read(true)
        .append(true)
        .create(true)
        .open(path)?;

    // Seeking to the end yields the current length straight from the open handle.
    let existing_len = file.seek(std::io::SeekFrom::End(0))?;
    Ok((file, existing_len))
}
111
112async fn get_content_length(client: &Client, url: &str) -> Result<u64, DownloadError> {
113 let res = client.get(url).send().await?;
114 res.content_length().ok_or_else(|| {
115 DownloadError::from_string(format!("Failed to get content length from '{}'", &url))
116 })
117}
118
119fn get_progress_bar(total_size: u64) -> ProgressBar {
120 let pb = ProgressBar::new(total_size);
121 let pb_template =
122 " {msg}\n {spinner:.green} [{elapsed}] [{bar}] {bytes} / {total_bytes} ({bytes_per_sec})";
123
124 pb.set_style(
125 ProgressStyle::default_bar()
126 .template(pb_template)
127 .expect("failed to parse progressbar template")
128 .progress_chars("=> "),
129 );
130 pb
131}