humble_cli/
download.rs

1use futures_util::StreamExt;
2use indicatif::{ProgressBar, ProgressStyle};
3use reqwest::Client;
4use std::cmp::min;
5use std::fs::File;
6use std::io::{Seek, Write};
7use std::time::Duration;
8
9#[derive(Debug, thiserror::Error)]
10pub enum DownloadError {
11    #[error(transparent)]
12    Network(#[from] reqwest::Error),
13
14    #[error(transparent)]
15    IO(#[from] std::io::Error),
16
17    #[error("{0}")]
18    Generic(String),
19}
20
21impl DownloadError {
22    fn from_string(s: String) -> Self {
23        DownloadError::Generic(s)
24    }
25}
26
27pub async fn download_file(
28    client: &Client,
29    url: &str,
30    path: &str,
31    title: &str,
32) -> Result<(), DownloadError> {
33    const RETRY_SECONDS: u64 = 5;
34    let mut retries = 3;
35
36    loop {
37        let res = _download_file(client, url, path, title).await;
38
39        retries -= 1;
40        if retries < 0 {
41            return res;
42        }
43
44        match res {
45            Err(DownloadError::Network(ref net_err))
46                if net_err.is_connect() || net_err.is_timeout() =>
47            {
48                println!("  Will retry in {} seconds...", RETRY_SECONDS);
49                tokio::time::sleep(Duration::from_secs(RETRY_SECONDS)).await;
50                continue;
51            }
52            _ => return res,
53        };
54    }
55}
56
57async fn _download_file(
58    client: &Client,
59    url: &str,
60    path: &str,
61    title: &str,
62) -> Result<(), DownloadError> {
63    let (mut file, mut downloaded) = open_file_for_write(path)?;
64    let total_size = get_content_length(client, url).await?;
65
66    if downloaded >= total_size {
67        println!("  Nothing to do. File already exists.");
68        return Ok(());
69    }
70
71    // Start the download
72    let res = client
73        .get(url)
74        .header("Range", format!("bytes={}-", downloaded))
75        .send()
76        .await?;
77
78    let mut stream = res.bytes_stream();
79
80    let pb = get_progress_bar(total_size);
81    pb.set_message(format!("Downloading {}", title));
82
83    while let Some(chunk) = stream.next().await {
84        let chunk = chunk?;
85        let _ = file.write(&chunk)?;
86
87        downloaded = min(downloaded + (chunk.len() as u64), total_size);
88        pb.set_position(downloaded);
89    }
90
91    pb.finish_and_clear();
92    println!("  Downloaded {}", title);
93    Ok(())
94}
95
/// Opens `path` for appending, creating it if it does not exist.
///
/// Returns the open file together with its current length in bytes, which is
/// the offset a resumed download continues from (0 for a freshly created
/// file).
///
/// # Errors
/// Propagates any I/O error from opening the file or seeking within it.
fn open_file_for_write(path: &str) -> Result<(File, u64), std::io::Error> {
    // A single `open` with `create(true)` avoids the race between checking
    // `exists()` and opening. Append mode makes every write land at the end
    // of the file.
    let mut file = std::fs::OpenOptions::new()
        .read(true)
        .append(true)
        .create(true)
        .open(path)?;

    // Seeking to the end reports the length of the exact file we hold open,
    // rather than re-querying the path.
    let file_size = file.seek(std::io::SeekFrom::End(0))?;
    Ok((file, file_size))
}
111
112async fn get_content_length(client: &Client, url: &str) -> Result<u64, DownloadError> {
113    let res = client.get(url).send().await?;
114    res.content_length().ok_or_else(|| {
115        DownloadError::from_string(format!("Failed to get content length from '{}'", &url))
116    })
117}
118
119fn get_progress_bar(total_size: u64) -> ProgressBar {
120    let pb = ProgressBar::new(total_size);
121    let pb_template =
122        "  {msg}\n  {spinner:.green} [{elapsed}] [{bar}] {bytes} / {total_bytes} ({bytes_per_sec})";
123
124    pb.set_style(
125        ProgressStyle::default_bar()
126            .template(pb_template)
127            .expect("failed to parse progressbar template")
128            .progress_chars("=> "),
129    );
130    pb
131}