1use base64::{prelude::BASE64_STANDARD, Engine};
2use futures::{
3 future::{self, BoxFuture},
4 FutureExt,
5};
6use itertools::Itertools;
7use serde::Deserialize;
8use std::{
9 borrow::Cow,
10 env,
11 path::{Path, PathBuf},
12};
13
14use crate::{request::HttpRequest, Error, Filter, GithubBranchPath, SourceTree, TreeEntryType};
15
16#[derive(Debug)]
18pub enum DownloadEvent<'p> {
19 DownloadStarted {
21 path: &'p str,
23 },
24 DownloadCompleted {
26 path: &'p str,
28 },
29 DownloadFailed {
31 path: &'p str,
33 error: Error,
35 },
36}
37
38pub trait DownloadReporter: Sync {
40 fn on_event<'p>(&'p self, event: DownloadEvent<'p>) -> ();
42}
43
/// Reporter placeholder for callers that do not want progress notifications.
///
/// It carries no state, so it is cheap to copy and can be built with
/// `NullDownloadReporter {}` or `NullDownloadReporter::default()`.
#[derive(Debug, Default, Clone, Copy)]
pub struct NullDownloadReporter {}
46
47impl DownloadReporter for NullDownloadReporter {
48 fn on_event<'p>(&'p self, _event: DownloadEvent<'p>) -> () {}
49}
50
/// Value the `DownloadConfig` constructors assign to `max_simultaneous_downloads`.
const DEFAULT_MAX_DOWNLOADS: usize = 5;
52
53pub struct DownloadConfig<'download, Reporter>
55where
56 Reporter: DownloadReporter,
57{
58 pub output_path: &'download Path,
60 pub reporter: Option<&'download Reporter>,
62 pub max_simultaneous_downloads: usize,
65 pub access_token: Option<Cow<'download, str>>,
67}
68
69impl<'download, Reporter> DownloadConfig<'download, Reporter>
70where
71 Reporter: DownloadReporter,
72{
73 pub fn new(output_path: &'download Path) -> DownloadConfig<'download, Reporter> {
77 let access_token = env::var("GITHUB_ACCESS_TOKEN")
78 .ok()
79 .and_then(|s| Some(Cow::from(s)));
80
81 DownloadConfig {
82 output_path,
83 reporter: None,
84 max_simultaneous_downloads: DEFAULT_MAX_DOWNLOADS,
85 access_token: access_token,
86 }
87 }
88
89 pub fn new_with_reporter(
93 output_path: &'download Path,
94 reporter: &'download Reporter,
95 ) -> DownloadConfig<'download, Reporter> {
96 let access_token = env::var("GITHUB_ACCESS_TOKEN")
97 .ok()
98 .and_then(|s| Some(Cow::from(s)));
99
100 DownloadConfig {
101 output_path,
102 reporter: Some(reporter),
103 max_simultaneous_downloads: DEFAULT_MAX_DOWNLOADS,
104 access_token,
105 }
106 }
107}
108
109pub type DownloadConfigNoReporting<'download> = DownloadConfig<'download, NullDownloadReporter>;
111
/// Namespace type for the download entry points; it holds no state and all
/// of its functions are associated functions.
#[derive(Debug, Default, Clone, Copy)]
pub struct Downloader {}
114
115impl<'p> Downloader {
116 pub async fn download<Reporter: DownloadReporter>(
118 config: &'p DownloadConfig<'p, Reporter>,
119 path: &GithubBranchPath<'p>,
120 filter: &Filter<'p>,
121 ) -> Result<Vec<SourceTree>, Error> {
122 let tree = SourceTree::get(path).await?;
123 let files = Downloader::download_tree(config, &tree, filter).await?;
124 Ok(files.into_iter().map(|s| s.clone()).collect())
125 }
126
127 pub async fn download_tree<Reporter: DownloadReporter>(
129 config: &'p DownloadConfig<'p, Reporter>,
130 tree: &'p SourceTree,
131 filter: &Filter<'p>,
132 ) -> Result<Vec<&'p SourceTree>, Error> {
133 Ok(Downloader::download_tree_iter(config, tree.iter(), filter).await?)
134 }
135
136 pub async fn download_tree_iter<Reporter, Iter>(
138 config: &'p DownloadConfig<'p, Reporter>,
139 iter: Iter,
140 filter: &Filter<'p>,
141 ) -> Result<Vec<&'p SourceTree>, Error>
142 where
143 Reporter: DownloadReporter,
144 Iter: IntoIterator<Item = &'p SourceTree>,
145 {
146 let output_path = config.output_path;
147 let access_token = &config.access_token;
148
149 let files: Vec<&SourceTree> = iter
150 .into_iter()
151 .filter(|n| {
152 n.entry_type == TreeEntryType::Blob && filter.check(n.path.to_str().unwrap_or(""))
153 })
154 .collect();
155
156 let mut active: Vec<BoxFuture<'p, Result<(), Error>>> = Vec::new();
157
158 for f in &files {
159 if active.len() > config.max_simultaneous_downloads {
160 let (result, index, _) = future::select_all(&mut active).await;
162 result?;
163 let _future = active.remove(index);
164 }
165
166 let next = Downloader::download_node_wrapper(
167 &config.reporter,
168 &access_token,
169 output_path.to_path_buf(),
170 f,
171 );
172 active.push(next.boxed());
173 }
174
175 for r in future::join_all(active).await {
176 if let Err(e) = r {
177 return Err(e);
178 }
179 }
180
181 Ok(files)
182 }
183
184 async fn download_node_wrapper<Reporter: DownloadReporter>(
185 reporter: &'p Option<&'p Reporter>,
186 access_token: &'p Option<Cow<'p, str>>,
187 output_path: PathBuf,
188 tree: &'p SourceTree,
189 ) -> Result<(), Error> {
190 let path = tree.path.to_str().unwrap();
191 if let Some(reporter) = *reporter {
192 reporter.on_event(DownloadEvent::DownloadStarted { path })
193 }
194
195 let result = Downloader::download_node(&access_token, &output_path, &tree).await;
196
197 if let Some(reporter) = *reporter {
198 match result {
199 Ok(_) => reporter.on_event(DownloadEvent::DownloadCompleted { path }),
200 Err(ref e) => reporter.on_event(DownloadEvent::DownloadFailed {
201 path,
202 error: e.clone(),
203 }),
204 }
205 };
206
207 result
208 }
209
210 async fn download_node(
211 access_token: &'p Option<Cow<'p, str>>,
212 output_path: &'p Path,
213 tree: &'p SourceTree,
214 ) -> Result<(), Error> {
215 let client = HttpRequest::client(access_token)?;
216 let request = client.get(&tree.url).build()?;
217 let str = request
218 .headers()
219 .iter()
220 .map(|(name, val)| format!("{} = {:?}", name, val))
221 .join(", ");
222 println!("{}", str);
223 let response = client.execute(request).await?;
224 let body = response.text().await?;
225
226 #[derive(Deserialize)]
227 #[serde(untagged)]
228 enum BlobOrError {
229 Blob { content: String },
230 Error { message: String },
231 }
232
233 let model: BlobOrError = serde_json::from_str(&body)?;
234 match model {
235 BlobOrError::Error { message } => Err(Error::GithubError(message)),
236 BlobOrError::Blob { content } => {
237 let base64_str: String = content.chars().filter(|c| *c != '\n').collect();
238 let bytes = BASE64_STANDARD.decode(base64_str.as_bytes())?;
239
240 let output_path = output_path.to_path_buf().join(&tree.path);
241
242 Downloader::write_file(&output_path, &bytes).await?;
243 Ok(())
244 }
245 }
246 }
247
248 async fn write_file(path: &Path, bytes: &[u8]) -> Result<(), Error> {
249 Downloader::ensure_dir_exists(path).await?;
250
251 tokio::fs::write(path, &bytes).await?;
252
253 Ok(())
254 }
255
256 async fn ensure_dir_exists(path: &Path) -> Result<(), Error> {
257 let dirname = path.parent();
258 if dirname.is_none() {
259 return Ok(());
260 }
261
262 let dirname = dirname.unwrap();
263 if dirname.exists() {
264 return Ok(());
265 }
266
267 tokio::fs::create_dir_all(dirname).await?;
268
269 Ok(())
270 }
271}