cmsis_pack/update/
download.rs

use std::collections::{HashMap, HashSet};
use std::fs::{create_dir_all, rename, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};

use anyhow::{anyhow, Error};
use futures::prelude::*;
use futures::stream::futures_unordered::FuturesUnordered;
use futures::StreamExt;
use reqwest::{redirect, Url};
use reqwest::{Client, ClientBuilder, Response};
use tokio::task::JoinHandle;
use tokio::time::{sleep, Duration};

use crate::pack_index::{PdscRef, Vidx};
use crate::pdsc::Package;
use crate::utils::parse::FromElem;
18
/// Maximum number of downloads in flight at once, across all hosts.
const CONCURRENCY: usize = 32;
/// Maximum number of simultaneous downloads from any single host.
const HOST_LIMIT: usize = 6;
/// Number of attempts made for each index URL before it is given up on.
const MAX_RETRIES: usize = 3;
22
23fn pdsc_url(pdsc: &mut PdscRef) -> String {
24    if pdsc.url.ends_with('/') {
25        format!("{}{}.{}.pdsc", pdsc.url, pdsc.vendor, pdsc.name)
26    } else {
27        format!("{}/{}.{}.pdsc", pdsc.url, pdsc.vendor, pdsc.name)
28    }
29}
30
/// Configuration needed by the downloader.
pub trait DownloadConfig {
    /// Root directory that the `IntoDownload` implementations in this file
    /// place downloaded `.pdsc` and `.pack` files under.
    fn pack_store(&self) -> PathBuf;
}
34
/// An item that can be downloaded: it knows its remote URL and its local path.
// The `into_*` methods take `&self`, which clippy's `wrong_self_convention` flags.
#[allow(clippy::wrong_self_convention)]
pub trait IntoDownload {
    /// URL to fetch the item from. An `Err` makes
    /// `DownloadContext::download_iterator` skip the item.
    fn into_uri(&self) -> Result<Url, Error>;
    /// Local path that the downloaded item is written to.
    fn into_fd<D: DownloadConfig>(&self, _: &D) -> PathBuf;
}
40
41impl IntoDownload for PdscRef {
42    fn into_uri(&self) -> Result<Url, Error> {
43        let PdscRef {
44            url, vendor, name, ..
45        } = self;
46        let uri = if url.ends_with('/') {
47            format!("{}{}.{}.pdsc", url, vendor, name)
48        } else {
49            format!("{}/{}.{}.pdsc", url, vendor, name)
50        }
51        .parse()?;
52        Ok(uri)
53    }
54
55    fn into_fd<D: DownloadConfig>(&self, config: &D) -> PathBuf {
56        let PdscRef {
57            vendor,
58            name,
59            version,
60            ..
61        } = self;
62        let mut filename = config.pack_store();
63        let pdscname = format!("{}.{}.{}.pdsc", vendor, name, version);
64        filename.push(pdscname);
65        filename
66    }
67}
68
69impl IntoDownload for &Package {
70    fn into_uri(&self) -> Result<Url, Error> {
71        let Package {
72            name,
73            vendor,
74            url,
75            releases,
76            ..
77        } = *self;
78        let version: &str = releases.latest_release().version.as_ref();
79        let uri = if url.ends_with('/') {
80            format!("{}{}.{}.{}.pack", url, vendor, name, version)
81        } else {
82            format!("{}/{}.{}.{}.pack", url, vendor, name, version)
83        }
84        .parse()?;
85        Ok(uri)
86    }
87
88    fn into_fd<D: DownloadConfig>(&self, config: &D) -> PathBuf {
89        let Package {
90            name,
91            vendor,
92            releases,
93            ..
94        } = *self;
95        let version: &str = releases.latest_release().version.as_ref();
96        let mut filename = config.pack_store();
97        filename.push(Path::new(vendor));
98        filename.push(Path::new(name));
99        filename.push(format!("{}.pack", version));
100        filename
101    }
102}
103
104async fn save_response(response: Response, dest: PathBuf) -> Result<(usize, PathBuf), Error> {
105    let temp = dest.with_extension("part");
106    let file = OpenOptions::new()
107        .write(true)
108        .create(true)
109        .truncate(true)
110        .open(&temp);
111
112    let mut file = match file {
113        Err(err) => return Err(anyhow!(err.to_string())),
114        Ok(f) => f,
115    };
116
117    let mut fsize: usize = 0;
118    let mut stream = response.bytes_stream();
119    while let Some(chunk) = stream.next().await {
120        match chunk {
121            Ok(bytes) => {
122                fsize += bytes.len();
123
124                if let Err(err) = file.write_all(bytes.as_ref()) {
125                    let _ = std::fs::remove_file(temp);
126                    return Err(anyhow!(err.to_string()));
127                }
128            }
129            Err(err) => {
130                let _ = std::fs::remove_file(temp);
131                return Err(anyhow!(err.to_string()));
132            }
133        }
134    }
135    if let Err(err) = rename(&temp, &dest) {
136        let _ = std::fs::remove_file(temp);
137        return Err(anyhow!(err.to_string()));
138    }
139    Ok((fsize, dest))
140}
141
/// Receiver of download progress notifications.
pub trait DownloadProgress: Send {
    /// Receives the number of files in a download batch; called once at the
    /// start of `DownloadContext::download_iterator`.
    fn size(&self, files: usize);
    /// Receives the byte count of a finished download task (0 when it failed).
    fn progress(&self, bytes: usize);
    /// Called once for each download task that finishes, whether or not it
    /// succeeded.
    fn complete(&self);
    /// Derive a reporter for a single `file`. Not called in the code visible
    /// here; presumably used by callers for per-file reporting (TODO confirm).
    fn for_file(&self, file: &str) -> Self;
}
148
149impl DownloadProgress for () {
150    fn size(&self, _: usize) {}
151    fn progress(&self, _: usize) {}
152    fn complete(&self) {}
153    fn for_file(&self, _: &str) -> Self {}
154}
155
/// State shared by a batch of downloads: configuration, progress reporter and
/// HTTP client.
pub struct DownloadContext<'a, Conf, Prog>
where
    Conf: DownloadConfig,
    Prog: DownloadProgress + 'a,
{
    /// Passed to `IntoDownload::into_fd` to compute local destinations.
    config: &'a Conf,
    /// Receives size / progress / completion notifications.
    prog: Prog,
    /// HTTP client used for every request; cloned into spawned download tasks.
    client: Client,
}
165
166impl<'a, Conf, Prog> DownloadContext<'a, Conf, Prog>
167where
168    Conf: DownloadConfig,
169    Prog: DownloadProgress + 'a,
170{
171    pub fn new(config: &'a Conf, prog: Prog) -> Result<Self, Error> {
172        let client = ClientBuilder::new()
173            .redirect(redirect::Policy::limited(5))
174            .build()?;
175
176        Ok(DownloadContext {
177            config,
178            prog,
179            client,
180        })
181    }
182
183    pub async fn download_iterator<I>(&'a self, iter: I) -> Vec<PathBuf>
184    where
185        I: IntoIterator + 'a,
186        <I as IntoIterator>::Item: IntoDownload,
187    {
188        let mut to_dl: Vec<(Url, String, PathBuf)> = iter
189            .into_iter()
190            .filter_map(|i| {
191                if let Ok(uri) = i.into_uri() {
192                    let c = uri.clone();
193                    c.host_str()
194                        .map(|host| (uri, host.to_string(), i.into_fd(self.config)))
195                } else {
196                    None
197                }
198            })
199            .collect();
200        self.prog.size(to_dl.len());
201
202        let mut hosts: HashMap<String, usize> = HashMap::new();
203        let mut results: Vec<PathBuf> = vec![];
204        let mut started: usize = 0;
205        let mut handles: Vec<JoinHandle<(String, usize, Option<PathBuf>)>> = vec![];
206
207        while !to_dl.is_empty() || !handles.is_empty() {
208            let mut wait_list: Vec<(Url, String, PathBuf)> = vec![];
209            let mut next: Vec<JoinHandle<(String, usize, Option<PathBuf>)>> = vec![];
210
211            while let Some(handle) = handles.pop() {
212                if handle.is_finished() {
213                    let r = handle.await.unwrap();
214                    *hosts.entry(r.0).or_insert(1) -= 1;
215                    started -= 1;
216                    self.prog.progress(r.1);
217                    self.prog.complete();
218                    if let Some(path) = r.2 {
219                        results.push(path);
220                    }
221                } else {
222                    next.push(handle);
223                }
224            }
225
226            while !to_dl.is_empty() && started < CONCURRENCY {
227                let from = to_dl.pop().unwrap();
228                let host = from.1.clone();
229                let entry = hosts.entry(host).or_insert(0);
230                if *entry >= HOST_LIMIT {
231                    wait_list.push(from);
232                } else {
233                    let source = from.0.clone();
234                    let host = from.1.clone();
235                    let dest = from.2.clone();
236                    if dest.exists() {
237                        results.push(dest);
238                    } else {
239                        let client = self.client.clone();
240                        let handle: JoinHandle<(String, usize, Option<PathBuf>)> =
241                            tokio::spawn(async move {
242                                dest.parent().map(create_dir_all);
243                                let res = client.get(source.clone()).send().await;
244                                let res: Result<(usize, PathBuf), Error> = match res {
245                                    Ok(r) => {
246                                        let rc = r.status().as_u16();
247                                        if rc >= 400 {
248                                            Err(anyhow!(format!(
249                                                "Response code in invalid range: {}",
250                                                rc
251                                            )
252                                            .to_string()))
253                                        } else {
254                                            save_response(r, dest).await
255                                        }
256                                    }
257                                    Err(err) => Err(anyhow!(err.to_string())),
258                                };
259                                match res {
260                                    Ok(r) => (host, r.0, Some(r.1)),
261                                    Err(err) => {
262                                        log::warn!(
263                                            "Download of {} failed: {}",
264                                            source.to_string(),
265                                            err
266                                        );
267                                        (host, 0, None)
268                                    }
269                                }
270                            });
271                        handles.push(handle);
272                        started += 1;
273                        *entry += 1;
274                    }
275                }
276            }
277
278            for w in wait_list {
279                to_dl.push(w);
280            }
281
282            for w in next {
283                handles.push(w);
284            }
285            sleep(Duration::from_millis(100)).await;
286        }
287
288        results
289    }
290
291    pub(crate) async fn update_vidx<I>(&'a self, list: I) -> Result<Vec<PathBuf>, Error>
292    where
293        I: IntoIterator + 'a,
294        <I as IntoIterator>::Item: Into<String>,
295    {
296        let mut downloaded: HashMap<String, bool> = HashMap::new();
297        let mut failures: HashMap<String, usize> = HashMap::new();
298        let mut urls: Vec<String> = list.into_iter().map(|x| x.into()).collect();
299        let mut vidxs: Vec<Vidx> = Vec::new();
300        loop {
301            // Remove from list all duplicate URLs and those already downloaded
302            urls.dedup();
303            urls.retain(|u| !*downloaded.get(u).unwrap_or(&false));
304
305            // TODO: Make this section asynchronous
306            let mut next: Vec<String> = Vec::new();
307            for url in urls {
308                match self.download_vidx(url.clone()).await {
309                    Ok(t) => {
310                        log::info!("Downloaded {}", url);
311                        downloaded.insert(url, true);
312                        for v in &t.vendor_index {
313                            let u = format!("{}{}.pidx", v.url, v.vendor);
314                            if !downloaded.contains_key(&u) {
315                                downloaded.insert(u.clone(), false);
316                                next.push(u);
317                            }
318                        }
319                        vidxs.push(t);
320                    }
321                    Err(_err) => {
322                        let tries = failures.entry(url.clone()).or_insert(0);
323                        *tries += 1;
324                        if *tries < MAX_RETRIES {
325                            next.push(url);
326                        }
327                    }
328                }
329            }
330            if next.is_empty() {
331                break;
332            }
333            urls = next;
334        }
335
336        let mut pdscs: Vec<PdscRef> = Vec::new();
337        for mut v in vidxs {
338            pdscs.append(&mut v.pdsc_index);
339        }
340
341        pdscs.dedup_by_key(pdsc_url);
342        log::info!("Found {} Pdsc entries", pdscs.len());
343
344        Ok(self.download_iterator(pdscs.into_iter()).await)
345    }
346
347    pub(crate) async fn download_vidx<I: Into<String>>(
348        &'a self,
349        vidx_ref: I,
350    ) -> Result<Vidx, Error> {
351        let vidx = vidx_ref.into();
352        let uri = vidx.parse::<Url>().unwrap();
353
354        let req: reqwest::Response = self.client.get(uri).send().await?;
355        Vidx::from_string(req.text().await?.as_str())
356    }
357
358    #[allow(dead_code)]
359    pub(crate) fn download_vidx_list<I>(&'a self, list: I) -> impl Stream<Item = Option<Vidx>> + 'a
360    where
361        I: IntoIterator + 'a,
362        <I as IntoIterator>::Item: Into<String>,
363    {
364        list.into_iter()
365            .map(|vidx_ref| {
366                let vidx = vidx_ref.into();
367                println!("{}", vidx);
368                self.download_vidx(vidx.clone()).then(|r| async move {
369                    match r {
370                        Ok(v) => {
371                            println!("{} success", vidx);
372                            Some(v)
373                        }
374                        Err(e) => {
375                            log::error!("{}", format!("{}", e).replace("uri", &vidx));
376                            None
377                        }
378                    }
379                })
380            })
381            .collect::<FuturesUnordered<_>>()
382    }
383}