oma_fetch/
lib.rs

1use std::{borrow::Cow, cmp::Ordering, fmt::Debug, path::PathBuf, time::Duration};
2
3use bon::{Builder, builder};
4use checksum::Checksum;
5use download::{BuilderError, SingleDownloader, SuccessSummary};
6use futures::StreamExt;
7
8use reqwest::{Client, Method, RequestBuilder, Response};
9use tracing::debug;
10
11pub mod checksum;
12pub mod download;
13pub use crate::download::SingleDownloadError;
14
15pub use reqwest;
16
/// A single file to download: where it can come from, where it goes, and how
/// it is described and verified.
#[derive(Clone, Default, Builder)]
pub struct DownloadEntry {
    /// Candidate sources for this file.
    /// NOTE(review): presumably tried in order, moving to the next one on
    /// failure (cf. `Event::NextUrl`) — confirm in the `download` module.
    pub source: Vec<DownloadSource>,
    /// Name of the output file.
    pub filename: String,
    /// Destination directory; presumably joined with `filename` — confirm.
    dir: PathBuf,
    /// Expected checksum of the file, if one is known.
    hash: Option<Checksum>,
    /// Whether an interrupted download may be resumed (semantics live in
    /// the `download` module — TODO confirm).
    allow_resume: bool,
    /// Optional message; `DownloadManager::start_download` forwards it to the
    /// per-entry downloader as `msg`.
    msg: Option<Cow<'static, str>>,
    /// Compression format of the file; defaults to `CompressFile::Nothing`.
    /// Forwarded to the per-entry downloader as `file_type`.
    #[builder(default)]
    file_type: CompressFile,
}
28
29impl Debug for DownloadEntry {
30    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31        f.debug_struct("DownloadEntry")
32            .field("source", &self.source)
33            .field("filename", &self.filename)
34            .field("dir", &self.dir)
35            .field("hash", &self.hash.as_ref().map(|c| c.to_string()))
36            .field("allow_resume", &self.allow_resume)
37            .field("msg", &self.msg)
38            .field("file_type", &self.file_type)
39            .finish()
40    }
41}
42
/// Compression format of a downloaded file.
#[derive(Debug, Clone, Default, PartialEq, Eq, Copy)]
pub enum CompressFile {
    /// bzip2 (`bz2`).
    Bz2,
    /// gzip (`gz`).
    Gzip,
    /// xz (`xz`).
    Xz,
    /// Zstandard (`zst`).
    Zstd,
    /// Legacy LZMA (`lzma`).
    Lzma,
    /// LZ4 (`lz4`).
    Lz4,
    /// No (recognised) compression. This is the default.
    #[default]
    Nothing,
}

impl From<&str> for CompressFile {
    /// Maps a compression suffix such as `"xz"` or `"gz"` (no leading dot) to
    /// its format.
    ///
    /// Matching is exact and case-sensitive. Any unrecognised input maps to
    /// [`CompressFile::Nothing`].
    fn from(s: &str) -> Self {
        match s {
            "xz" => CompressFile::Xz,
            "gz" => CompressFile::Gzip,
            "bz2" => CompressFile::Bz2,
            "zst" => CompressFile::Zstd,
            // Cover every compressed variant so none is unreachable through
            // this conversion.
            "lzma" => CompressFile::Lzma,
            "lz4" => CompressFile::Lz4,
            _ => CompressFile::Nothing,
        }
    }
}
66
/// One location a file can be fetched from.
#[derive(Debug, Clone)]
pub struct DownloadSource {
    /// Location of the file.
    pub url: String,
    /// How `url` is to be accessed.
    pub source_type: DownloadSourceType,
}

/// Kind of a download source.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum DownloadSourceType {
    /// Remote source with optional `(user, password)` credentials.
    Http { auth: Option<(String, String)> },
    /// Local source.
    /// NOTE(review): the meaning of the flag is not visible in this file —
    /// confirm in the `download` module.
    Local(bool),
}

impl PartialOrd for DownloadSourceType {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for DownloadSourceType {
    /// Orders by kind only: every `Http` sorts before every `Local`. Two
    /// sources of the same kind compare `Equal` whatever their payload.
    ///
    /// NOTE(review): this disagrees with the derived `PartialEq`. For example,
    /// two `Http` values with different `auth` are `Equal` here but not `==`,
    /// which the `Ord` contract asks implementors to avoid.
    fn cmp(&self, other: &Self) -> Ordering {
        // Lower rank sorts first.
        fn rank(kind: &DownloadSourceType) -> u8 {
            match kind {
                DownloadSourceType::Http { .. } => 0,
                DownloadSourceType::Local(_) => 1,
            }
        }

        rank(self).cmp(&rank(other))
    }
}
99
/// Notifications delivered to the `callback` given to
/// `DownloadManager::start_download`.
///
/// Only `NewGlobalProgressBar` and `AllDone` are emitted in this file. The
/// other variants are presumably emitted by the `download` module, and the
/// notes on them below are inferred from their names.
/// NOTE(review): `index` presumably is the entry's position in the download
/// list (the `download_list_index` given to each downloader), and `total` the
/// list length — confirm in `download`.
#[derive(Debug)]
pub enum Event {
    /// Presumably: the file's checksum did not match; `times` looks like an
    /// attempt counter — confirm.
    ChecksumMismatch {
        index: usize,
        filename: String,
        times: usize,
    },
    /// Presumably: advance the global progress bar by the given amount.
    GlobalProgressAdd(u64),
    /// Presumably: roll the global progress bar back by the given amount.
    GlobalProgressSub(u64),
    /// Presumably: the per-entry progress indicator at this index is finished.
    ProgressDone(usize),
    /// Presumably: create a size-less (spinner) progress indicator for an entry.
    NewProgressSpinner {
        index: usize,
        total: usize,
        msg: String,
    },
    /// Presumably: create a progress bar of `size` units for an entry.
    NewProgressBar {
        index: usize,
        total: usize,
        msg: String,
        size: u64,
    },
    /// Presumably: advance the progress bar at `index` by `size`.
    ProgressInc {
        index: usize,
        size: u64,
    },
    /// Presumably: the current source failed with `err` and the next source
    /// of the entry will be tried.
    NextUrl {
        index: usize,
        file_name: String,
        err: SingleDownloadError,
    },
    /// Presumably: the entry at `index` finished downloading.
    DownloadDone {
        index: usize,
        msg: Box<str>,
    },
    /// Presumably: `file_name` could not be downloaded, with `error` as the cause.
    Failed {
        file_name: String,
        error: SingleDownloadError,
    },
    /// Emitted by `start_download` after every download future has completed.
    AllDone,
    /// Emitted by `start_download` before downloading starts, carrying the
    /// manager's `total_size`. It is only sent when that size is non-zero.
    NewGlobalProgressBar(u64),
}
141
/// Downloads a list of [`DownloadEntry`] values concurrently.
///
/// Construct it with the derived builder, then call `start_download`.
#[derive(Builder)]
pub struct DownloadManager<'a> {
    /// HTTP client handed to every per-entry downloader.
    client: &'a Client,
    /// Entries to download. Each entry's position in this slice is passed to
    /// its downloader as `download_list_index`.
    download_list: &'a [DownloadEntry],
    /// Maximum number of downloads in flight at once (default 4).
    /// `start_download` rejects 0 and values above 255.
    #[builder(default = 4)]
    threads: usize,
    /// Retry count handed to each per-entry downloader (default 3).
    #[builder(default = 3)]
    retry_times: usize,
    /// Total for the global progress bar (presumably bytes — confirm). When it
    /// is 0 (the default), no `NewGlobalProgressBar` event is emitted.
    #[builder(default)]
    total_size: u64,
    /// Timeout handed to each per-entry downloader (default 15 seconds).
    #[builder(default = Duration::from_secs(15))]
    timeout: Duration,
}
155
156#[derive(Debug)]
157pub struct Summary {
158    pub success: Vec<SuccessSummary>,
159    pub failed: Vec<String>,
160}
161
162impl Summary {
163    pub fn is_download_success(&self) -> bool {
164        self.failed.is_empty()
165    }
166
167    pub fn has_wrote(&self) -> bool {
168        self.success.iter().any(|x| x.wrote)
169    }
170}
171
172impl DownloadManager<'_> {
173    /// Start download
174    pub async fn start_download(
175        &self,
176        callback: impl AsyncFn(Event),
177    ) -> Result<Summary, BuilderError> {
178        if self.threads == 0 || self.threads > 255 {
179            return Err(BuilderError::IllegalDownloadThread {
180                count: self.threads,
181            });
182        }
183
184        let mut tasks = Vec::new();
185        let mut list = vec![];
186        for (i, c) in self.download_list.iter().enumerate() {
187            let msg = c.msg.clone();
188            let single = SingleDownloader::builder()
189                .client(self.client)
190                .maybe_msg(msg)
191                .download_list_index(i)
192                .entry(c)
193                .total(self.download_list.len())
194                .retry_times(self.retry_times)
195                .file_type(c.file_type)
196                .timeout(self.timeout)
197                .build()?;
198
199            list.push(single);
200        }
201
202        for single in list {
203            tasks.push(single.try_download(&callback));
204        }
205
206        if self.total_size != 0 {
207            callback(Event::NewGlobalProgressBar(self.total_size)).await;
208        }
209
210        let stream = futures::stream::iter(tasks).buffer_unordered(self.threads);
211        let res = stream.collect::<Vec<_>>().await;
212        callback(Event::AllDone).await;
213
214        let (mut success, mut failed) = (vec![], vec![]);
215
216        for i in res {
217            match i {
218                download::DownloadResult::Success(success_summary) => {
219                    success.push(success_summary);
220                }
221                download::DownloadResult::Failed { file_name } => {
222                    failed.push(file_name);
223                }
224            }
225        }
226
227        Ok(Summary { success, failed })
228    }
229}
230
231pub fn build_request_with_basic_auth(
232    client: &Client,
233    method: Method,
234    auth: &Option<(String, String)>,
235    url: &str,
236) -> RequestBuilder {
237    let mut req = client.request(method, url);
238
239    if let Some((user, password)) = auth {
240        debug!("Authenticating as user: {} ...", user);
241        req = req.basic_auth(user, Some(password));
242    }
243
244    req
245}
246
247pub async fn send_request(url: &str, request: RequestBuilder) -> Result<Response, reqwest::Error> {
248    let resp = request.send().await?;
249    let headers = resp.headers();
250
251    debug!(
252        "\nDownload URL: {url}\nStatus: {}\nHeaders: {headers:#?}",
253        resp.status()
254    );
255
256    let resp = resp.error_for_status()?;
257
258    Ok(resp)
259}