grab_github/
download.rs

use base64::{prelude::BASE64_STANDARD, Engine};
use futures::{
    future::{self, BoxFuture},
    FutureExt,
};
use itertools::Itertools;
use serde::Deserialize;
use std::{
    borrow::Cow,
    env,
    path::{Path, PathBuf},
};

use crate::{request::HttpRequest, Error, Filter, GithubBranchPath, SourceTree, TreeEntryType};

/// An event involving a single download.
#[derive(Debug)]
pub enum DownloadEvent<'p> {
    /// A file has begun downloading.
    DownloadStarted {
        /// The path of the file relative to the root of the repository.
        path: &'p str,
    },
    /// A file has been downloaded successfully.
    DownloadCompleted {
        /// The path of the file relative to the root of the repository.
        path: &'p str,
    },
    /// A file has encountered an error and has failed to download.
    DownloadFailed {
        /// The path of the file relative to the root of the repository.
        path: &'p str,
        /// The [Error] that was encountered while attempting to download the file.
        error: Error,
    },
}

/// Implement this trait to receive events on the status of each download.
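///
/// # Example
///
/// A minimal sketch of an implementor that prints each event (this assumes
/// `DownloadReporter` and `DownloadEvent` are re-exported from the crate root):
///
/// ```ignore
/// use grab_github::{DownloadEvent, DownloadReporter};
///
/// struct PrintReporter;
///
/// impl DownloadReporter for PrintReporter {
///     fn on_event<'p>(&'p self, event: DownloadEvent<'p>) {
///         println!("{:?}", event);
///     }
/// }
/// ```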
pub trait DownloadReporter: Sync {
    /// Called with events for each download's status.
    fn on_event<'p>(&'p self, event: DownloadEvent<'p>);
}

/// An empty download reporter that does nothing.
pub struct NullDownloadReporter {}

impl DownloadReporter for NullDownloadReporter {
    fn on_event<'p>(&'p self, _event: DownloadEvent<'p>) {}
}

const DEFAULT_MAX_DOWNLOADS: usize = 5;

/// Contains the configuration for a downloading operation.
pub struct DownloadConfig<'download, Reporter>
where
    Reporter: DownloadReporter,
{
    /// The directory that the tree will be downloaded into.
    pub output_path: &'download Path,
    /// If provided, the reporter will receive events on the status of each download.
    pub reporter: Option<&'download Reporter>,
    /// The maximum number of simultaneous downloads allowed at once.
    /// The default is 5.
    pub max_simultaneous_downloads: usize,
    /// Your GitHub personal access token, if you have one.
    pub access_token: Option<Cow<'download, str>>,
}

impl<'download, Reporter> DownloadConfig<'download, Reporter>
where
    Reporter: DownloadReporter,
{
    /// Creates a new [DownloadConfig] with the given output path and default values.
    ///
    /// `access_token` will be read from the environment variable `GITHUB_ACCESS_TOKEN` if available.
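    ///
    /// # Example
    ///
    /// A minimal sketch (assumes these items are re-exported from the crate root):
    ///
    /// ```ignore
    /// use std::path::Path;
    /// use grab_github::DownloadConfigNoReporting;
    ///
    /// // The `DownloadConfigNoReporting` alias pins the reporter type when no reporter is used.
    /// let config = DownloadConfigNoReporting::new(Path::new("./output"));
    /// assert_eq!(config.max_simultaneous_downloads, 5);
    /// ```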
    pub fn new(output_path: &'download Path) -> DownloadConfig<'download, Reporter> {
        let access_token = env::var("GITHUB_ACCESS_TOKEN").ok().map(Cow::from);

        DownloadConfig {
            output_path,
            reporter: None,
            max_simultaneous_downloads: DEFAULT_MAX_DOWNLOADS,
            access_token,
        }
    }

    /// Creates a new [DownloadConfig] with the given output path, reporter, and default values.
    ///
    /// `access_token` will be read from the environment variable `GITHUB_ACCESS_TOKEN` if available.
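    ///
    /// # Example
    ///
    /// A minimal sketch (assumes crate-root re-exports); `NullDownloadReporter` stands in
    /// for a real reporter here:
    ///
    /// ```ignore
    /// use std::path::Path;
    /// use grab_github::{DownloadConfig, NullDownloadReporter};
    ///
    /// let reporter = NullDownloadReporter {};
    /// let config = DownloadConfig::new_with_reporter(Path::new("./output"), &reporter);
    /// assert!(config.reporter.is_some());
    /// ```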
    pub fn new_with_reporter(
        output_path: &'download Path,
        reporter: &'download Reporter,
    ) -> DownloadConfig<'download, Reporter> {
        let access_token = env::var("GITHUB_ACCESS_TOKEN").ok().map(Cow::from);

        DownloadConfig {
            output_path,
            reporter: Some(reporter),
            max_simultaneous_downloads: DEFAULT_MAX_DOWNLOADS,
            access_token,
        }
    }
}

/// A convenience type for a download config with no reporter.
pub type DownloadConfigNoReporting<'download> = DownloadConfig<'download, NullDownloadReporter>;

/// Contains methods for downloading a [SourceTree] into a directory of files.
pub struct Downloader {}

impl<'p> Downloader {
    /// Downloads an entire GitHub tree specified by `path`.
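    ///
    /// # Example
    ///
    /// A hedged sketch: the `GithubBranchPath` and `Filter` constructors shown below are
    /// assumptions and may not match the actual API.
    ///
    /// ```ignore
    /// use std::path::Path;
    /// use grab_github::{DownloadConfigNoReporting, Downloader, Filter, GithubBranchPath};
    ///
    /// # async fn run() -> Result<(), grab_github::Error> {
    /// let config = DownloadConfigNoReporting::new(Path::new("./output"));
    /// // Hypothetical constructors; verify against the real `GithubBranchPath` / `Filter` API.
    /// let path = GithubBranchPath::new("octocat", "hello-world", "main");
    /// let filter = Filter::default();
    /// let trees = Downloader::download(&config, &path, &filter).await?;
    /// println!("downloaded {} files", trees.len());
    /// # Ok(())
    /// # }
    /// ```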
    pub async fn download<Reporter: DownloadReporter>(
        config: &'p DownloadConfig<'p, Reporter>,
        path: &GithubBranchPath<'p>,
        filter: &Filter<'p>,
    ) -> Result<Vec<SourceTree>, Error> {
        let tree = SourceTree::get(path).await?;
        let files = Downloader::download_tree(config, &tree, filter).await?;
        Ok(files.into_iter().cloned().collect())
    }

    /// Downloads an entire [SourceTree] to a directory.
    pub async fn download_tree<Reporter: DownloadReporter>(
        config: &'p DownloadConfig<'p, Reporter>,
        tree: &'p SourceTree,
        filter: &Filter<'p>,
    ) -> Result<Vec<&'p SourceTree>, Error> {
        Downloader::download_tree_iter(config, tree.iter(), filter).await
    }

    /// Downloads an iterator of [SourceTree] nodes to a directory.
    pub async fn download_tree_iter<Reporter, Iter>(
        config: &'p DownloadConfig<'p, Reporter>,
        iter: Iter,
        filter: &Filter<'p>,
    ) -> Result<Vec<&'p SourceTree>, Error>
    where
        Reporter: DownloadReporter,
        Iter: IntoIterator<Item = &'p SourceTree>,
    {
        let output_path = config.output_path;
        let access_token = &config.access_token;

        // Keep only blob (file) entries whose paths pass the filter.
        let files: Vec<&SourceTree> = iter
            .into_iter()
            .filter(|n| {
                n.entry_type == TreeEntryType::Blob && filter.check(n.path.to_str().unwrap_or(""))
            })
            .collect();

        let mut active: Vec<BoxFuture<'p, Result<(), Error>>> = Vec::new();

        for f in &files {
            if active.len() >= config.max_simultaneous_downloads {
                // Wait for one of the active downloads to complete before starting a new one.
                let (result, index, _) = future::select_all(&mut active).await;
                result?;
                let _future = active.remove(index);
            }

            let next = Downloader::download_node_wrapper(
                &config.reporter,
                access_token,
                output_path.to_path_buf(),
                f,
            );
            active.push(next.boxed());
        }

        // Wait for the remaining downloads and surface the first error, if any.
        for r in future::join_all(active).await {
            r?;
        }

        Ok(files)
    }

    async fn download_node_wrapper<Reporter: DownloadReporter>(
        reporter: &'p Option<&'p Reporter>,
        access_token: &'p Option<Cow<'p, str>>,
        output_path: PathBuf,
        tree: &'p SourceTree,
    ) -> Result<(), Error> {
        let path = tree.path.to_str().unwrap();
        if let Some(reporter) = *reporter {
            reporter.on_event(DownloadEvent::DownloadStarted { path });
        }

        let result = Downloader::download_node(access_token, &output_path, tree).await;

        if let Some(reporter) = *reporter {
            match result {
                Ok(_) => reporter.on_event(DownloadEvent::DownloadCompleted { path }),
                Err(ref e) => reporter.on_event(DownloadEvent::DownloadFailed {
                    path,
                    error: e.clone(),
                }),
            }
        }

        result
    }

    async fn download_node(
        access_token: &'p Option<Cow<'p, str>>,
        output_path: &'p Path,
        tree: &'p SourceTree,
    ) -> Result<(), Error> {
        let client = HttpRequest::client(access_token)?;
        let request = client.get(&tree.url).build()?;

        // Debug output: log the outgoing request headers. Note that this includes the
        // Authorization header when an access token is set.
        let headers = request
            .headers()
            .iter()
            .map(|(name, val)| format!("{} = {:?}", name, val))
            .join(", ");
        println!("{}", headers);

        let response = client.execute(request).await?;
        let body = response.text().await?;

        // The blobs API returns either a blob body or an error message; the untagged
        // enum lets serde deserialize whichever shape matches.
        #[derive(Deserialize)]
        #[serde(untagged)]
        enum BlobOrError {
            Blob { content: String },
            Error { message: String },
        }

        let model: BlobOrError = serde_json::from_str(&body)?;
        match model {
            BlobOrError::Error { message } => Err(Error::GithubError(message)),
            BlobOrError::Blob { content } => {
                // Blob content is base64 with embedded newlines; strip them before decoding.
                let base64_str: String = content.chars().filter(|c| *c != '\n').collect();
                let bytes = BASE64_STANDARD.decode(base64_str.as_bytes())?;

                let output_path = output_path.join(&tree.path);

                Downloader::write_file(&output_path, &bytes).await?;
                Ok(())
            }
        }
    }

    async fn write_file(path: &Path, bytes: &[u8]) -> Result<(), Error> {
        Downloader::ensure_dir_exists(path).await?;

        tokio::fs::write(path, bytes).await?;

        Ok(())
    }

    async fn ensure_dir_exists(path: &Path) -> Result<(), Error> {
        if let Some(dirname) = path.parent() {
            if !dirname.exists() {
                tokio::fs::create_dir_all(dirname).await?;
            }
        }

        Ok(())
    }
}