cmsis_pack/update/
download.rs

1use std::fs::{create_dir_all, rename, OpenOptions};
2use std::io::Write;
3use std::path::{Path, PathBuf};
4
5use anyhow::{anyhow, Error};
6use futures::prelude::*;
7use futures::stream::futures_unordered::FuturesUnordered;
8use reqwest::{redirect, Url};
9use reqwest::{Client, ClientBuilder, Response};
10use tokio::task::JoinHandle;
11use tokio::time::{sleep, Duration};
12
13use crate::pack_index::{PdscRef, Vidx};
14use crate::pdsc::Package;
15use crate::utils::parse::FromElem;
16use futures::StreamExt;
17use std::collections::HashMap;
18
// Network timeouts handed to the HTTP client builder, in seconds.
/// Timeout for establishing a connection, in seconds.
const CONNECT_TIMEOUT: u64 = 15;
/// Timeout for each read from an open connection, in seconds.
const READ_TIMEOUT: u64 = 15;

// Scheduling limits for batch downloads.
/// Maximum number of downloads in flight at once, across all hosts.
const CONCURRENCY: usize = 32;
/// Maximum number of simultaneous downloads from any single host.
const HOST_LIMIT: usize = 6;
/// Number of attempts made to fetch an index URL before giving up on it.
const MAX_RETRIES: usize = 3;
24
25fn pdsc_url(pdsc: &mut PdscRef) -> String {
26    if pdsc.url.ends_with('/') {
27        format!("{}{}.{}.pdsc", pdsc.url, pdsc.vendor, pdsc.name)
28    } else {
29        format!("{}/{}.{}.pdsc", pdsc.url, pdsc.vendor, pdsc.name)
30    }
31}
32
/// Configuration consulted when deciding where downloaded files are stored.
pub trait DownloadConfig {
    /// Root directory under which downloaded files are placed.
    fn pack_store(&self) -> PathBuf;
}
36
37#[allow(clippy::wrong_self_convention)]
38pub trait IntoDownload {
39    fn into_uri(&self) -> Result<Url, Error>;
40    fn into_fd<D: DownloadConfig>(&self, _: &D) -> PathBuf;
41}
42
43impl IntoDownload for PdscRef {
44    fn into_uri(&self) -> Result<Url, Error> {
45        let PdscRef {
46            url, vendor, name, ..
47        } = self;
48        let uri = if url.ends_with('/') {
49            format!("{}{}.{}.pdsc", url, vendor, name)
50        } else {
51            format!("{}/{}.{}.pdsc", url, vendor, name)
52        }
53        .parse()?;
54        Ok(uri)
55    }
56
57    fn into_fd<D: DownloadConfig>(&self, config: &D) -> PathBuf {
58        let PdscRef {
59            vendor,
60            name,
61            version,
62            ..
63        } = self;
64        let mut filename = config.pack_store();
65        let pdscname = format!("{}.{}.{}.pdsc", vendor, name, version);
66        filename.push(pdscname);
67        filename
68    }
69}
70
71impl IntoDownload for &Package {
72    fn into_uri(&self) -> Result<Url, Error> {
73        let Package {
74            name,
75            vendor,
76            url,
77            releases,
78            ..
79        } = *self;
80        let version: &str = releases.latest_release().version.as_ref();
81        let uri = if url.ends_with('/') {
82            format!("{}{}.{}.{}.pack", url, vendor, name, version)
83        } else {
84            format!("{}/{}.{}.{}.pack", url, vendor, name, version)
85        }
86        .parse()?;
87        Ok(uri)
88    }
89
90    fn into_fd<D: DownloadConfig>(&self, config: &D) -> PathBuf {
91        let Package {
92            name,
93            vendor,
94            releases,
95            ..
96        } = *self;
97        let version: &str = releases.latest_release().version.as_ref();
98        let mut filename = config.pack_store();
99        filename.push(Path::new(vendor));
100        filename.push(Path::new(name));
101        filename.push(format!("{}.pack", version));
102        filename
103    }
104}
105
106async fn save_response(response: Response, dest: PathBuf) -> Result<(usize, PathBuf), Error> {
107    let temp = dest.with_extension("part");
108    let file = OpenOptions::new()
109        .write(true)
110        .create(true)
111        .truncate(true)
112        .open(&temp);
113
114    let mut file = match file {
115        Err(err) => return Err(anyhow!(err.to_string())),
116        Ok(f) => f,
117    };
118
119    let mut fsize: usize = 0;
120    let mut stream = response.bytes_stream();
121    while let Some(chunk) = stream.next().await {
122        match chunk {
123            Ok(bytes) => {
124                fsize += bytes.len();
125
126                if let Err(err) = file.write_all(bytes.as_ref()) {
127                    let _ = std::fs::remove_file(temp);
128                    return Err(anyhow!(err.to_string()));
129                }
130            }
131            Err(err) => {
132                let _ = std::fs::remove_file(temp);
133                return Err(anyhow!(err.to_string()));
134            }
135        }
136    }
137    if let Err(err) = rename(&temp, &dest) {
138        let _ = std::fs::remove_file(temp);
139        return Err(anyhow!(err.to_string()));
140    }
141    Ok((fsize, dest))
142}
143
/// Receives progress notifications from a batch download.
pub trait DownloadProgress: Send {
    /// Reports how many files the batch will process.
    fn size(&self, files: usize);
    /// Reports the number of bytes written for a finished download.
    fn progress(&self, bytes: usize);
    /// Reports that one file has been dealt with.
    fn complete(&self);
    /// Creates a reporter for the named file (presumably a per-file child reporter;
    /// no caller is visible in this module).
    fn for_file(&self, file: &str) -> Self;
}

/// No-op reporter for callers that do not track progress.
impl DownloadProgress for () {
    fn size(&self, _files: usize) {}
    fn progress(&self, _bytes: usize) {}
    fn complete(&self) {}
    fn for_file(&self, _file: &str) -> Self {}
}
157
158pub struct DownloadContext<'a, Conf, Prog>
159where
160    Conf: DownloadConfig,
161    Prog: DownloadProgress + 'a,
162{
163    config: &'a Conf,
164    prog: Prog,
165    client: Client,
166}
167
168impl<'a, Conf, Prog> DownloadContext<'a, Conf, Prog>
169where
170    Conf: DownloadConfig,
171    Prog: DownloadProgress + 'a,
172{
173    pub fn new(config: &'a Conf, prog: Prog) -> Result<Self, Error> {
174        let client = ClientBuilder::new()
175            .redirect(redirect::Policy::limited(5))
176            .connect_timeout(Duration::from_secs(CONNECT_TIMEOUT))
177            .read_timeout(Duration::from_secs(READ_TIMEOUT))
178            .build()?;
179
180        Ok(DownloadContext {
181            config,
182            prog,
183            client,
184        })
185    }
186
187    pub async fn download_iterator<I>(&'a self, iter: I) -> Vec<PathBuf>
188    where
189        I: IntoIterator + 'a,
190        <I as IntoIterator>::Item: IntoDownload,
191    {
192        let mut to_dl: Vec<(Url, String, PathBuf)> = iter
193            .into_iter()
194            .filter_map(|i| {
195                if let Ok(uri) = i.into_uri() {
196                    let c = uri.clone();
197                    c.host_str()
198                        .map(|host| (uri, host.to_string(), i.into_fd(self.config)))
199                } else {
200                    None
201                }
202            })
203            .collect();
204        self.prog.size(to_dl.len());
205
206        let mut hosts: HashMap<String, usize> = HashMap::new();
207        let mut results: Vec<PathBuf> = vec![];
208        let mut started: usize = 0;
209        let mut handles: Vec<JoinHandle<(String, usize, Option<PathBuf>)>> = vec![];
210
211        while !to_dl.is_empty() || !handles.is_empty() {
212            let mut wait_list: Vec<(Url, String, PathBuf)> = vec![];
213            let mut next: Vec<JoinHandle<(String, usize, Option<PathBuf>)>> = vec![];
214
215            while let Some(handle) = handles.pop() {
216                if handle.is_finished() {
217                    let r = handle.await.unwrap();
218                    *hosts.entry(r.0).or_insert(1) -= 1;
219                    started -= 1;
220                    self.prog.progress(r.1);
221                    self.prog.complete();
222                    if let Some(path) = r.2 {
223                        results.push(path);
224                    }
225                } else {
226                    next.push(handle);
227                }
228            }
229
230            while !to_dl.is_empty() && started < CONCURRENCY {
231                let from = to_dl.pop().unwrap();
232                let host = from.1.clone();
233                let entry = hosts.entry(host).or_insert(0);
234                if *entry >= HOST_LIMIT {
235                    wait_list.push(from);
236                } else {
237                    let source = from.0.clone();
238                    let host = from.1.clone();
239                    let dest = from.2.clone();
240                    if dest.exists() {
241                        self.prog.complete();
242                        results.push(dest);
243                    } else {
244                        let client = self.client.clone();
245                        let handle: JoinHandle<(String, usize, Option<PathBuf>)> =
246                            tokio::spawn(async move {
247                                dest.parent().map(create_dir_all);
248                                let res = client.get(source.clone()).send().await;
249                                let res: Result<(usize, PathBuf), Error> = match res {
250                                    Ok(r) => {
251                                        let rc = r.status().as_u16();
252                                        if rc >= 400 {
253                                            Err(anyhow!(format!(
254                                                "Response code in invalid range: {}",
255                                                rc
256                                            )
257                                            .to_string()))
258                                        } else {
259                                            save_response(r, dest).await
260                                        }
261                                    }
262                                    Err(err) => Err(anyhow!(err.to_string())),
263                                };
264                                match res {
265                                    Ok(r) => (host, r.0, Some(r.1)),
266                                    Err(err) => {
267                                        log::warn!(
268                                            "Download of {} failed: {}",
269                                            source.to_string(),
270                                            err
271                                        );
272                                        (host, 0, None)
273                                    }
274                                }
275                            });
276                        handles.push(handle);
277                        started += 1;
278                        *entry += 1;
279                    }
280                }
281            }
282
283            for w in wait_list {
284                to_dl.push(w);
285            }
286
287            for w in next {
288                handles.push(w);
289            }
290            sleep(Duration::from_millis(100)).await;
291        }
292
293        results
294    }
295
296    pub(crate) async fn update_vidx<I>(&'a self, list: I) -> Result<Vec<PathBuf>, Error>
297    where
298        I: IntoIterator + 'a,
299        <I as IntoIterator>::Item: Into<String>,
300    {
301        let mut downloaded: HashMap<String, bool> = HashMap::new();
302        let mut failures: HashMap<String, usize> = HashMap::new();
303        let mut urls: Vec<String> = list.into_iter().map(|x| x.into()).collect();
304        let mut vidxs: Vec<Vidx> = Vec::new();
305        loop {
306            // Remove from list all duplicate URLs and those already downloaded
307            urls.dedup();
308            urls.retain(|u| !*downloaded.get(u).unwrap_or(&false));
309
310            // TODO: Make this section asynchronous
311            let mut next: Vec<String> = Vec::new();
312            for url in urls {
313                match self.download_vidx(url.clone()).await {
314                    Ok(t) => {
315                        log::info!("Downloaded {}", url);
316                        downloaded.insert(url, true);
317                        for v in &t.vendor_index {
318                            let u = format!("{}{}.pidx", v.url, v.vendor);
319                            if !downloaded.contains_key(&u) {
320                                downloaded.insert(u.clone(), false);
321                                next.push(u);
322                            }
323                        }
324                        vidxs.push(t);
325                    }
326                    Err(_err) => {
327                        let tries = failures.entry(url.clone()).or_insert(0);
328                        *tries += 1;
329                        if *tries < MAX_RETRIES {
330                            next.push(url);
331                        }
332                    }
333                }
334            }
335            if next.is_empty() {
336                break;
337            }
338            urls = next;
339        }
340
341        let mut pdscs: Vec<PdscRef> = Vec::new();
342        for mut v in vidxs {
343            pdscs.append(&mut v.pdsc_index);
344        }
345
346        pdscs.dedup_by_key(pdsc_url);
347        log::info!("Found {} Pdsc entries", pdscs.len());
348
349        Ok(self.download_iterator(pdscs.into_iter()).await)
350    }
351
352    pub(crate) async fn download_vidx<I: Into<String>>(
353        &'a self,
354        vidx_ref: I,
355    ) -> Result<Vidx, Error> {
356        let vidx = vidx_ref.into();
357        let uri = vidx.parse::<Url>().unwrap();
358
359        let req: reqwest::Response = self.client.get(uri).send().await?;
360        Vidx::from_string(req.text().await?.as_str())
361    }
362
363    #[allow(dead_code)]
364    pub(crate) fn download_vidx_list<I>(&'a self, list: I) -> impl Stream<Item = Option<Vidx>> + 'a
365    where
366        I: IntoIterator + 'a,
367        <I as IntoIterator>::Item: Into<String>,
368    {
369        list.into_iter()
370            .map(|vidx_ref| {
371                let vidx = vidx_ref.into();
372                println!("{}", vidx);
373                self.download_vidx(vidx.clone()).then(|r| async move {
374                    match r {
375                        Ok(v) => {
376                            println!("{} success", vidx);
377                            Some(v)
378                        }
379                        Err(e) => {
380                            log::error!("{}", format!("{}", e).replace("uri", &vidx));
381                            None
382                        }
383                    }
384                })
385            })
386            .collect::<FuturesUnordered<_>>()
387    }
388}