1use std::{
2 fs::{create_dir_all, File},
3 io::Write,
4 path::PathBuf,
5 sync::{Arc, Mutex},
6};
7
8use anyhow::{anyhow, Result};
9use tokio::sync::broadcast::{channel, Sender};
10
/// One entry of the `tree` array returned by the GitHub `git/trees` endpoint.
#[derive(serde::Deserialize, Debug)]
struct Node {
    /// Path of the entry, as listed by the API (presumably relative to the repository root).
    path: String,
    /// Entry size (presumably in bytes); entries without a size are skipped by the
    /// downloader and never fetched.
    size: Option<isize>,
}
16
/// Response body of `GET /repos/{user}/{repo}/git/trees/{branch}?recursive=1`.
///
/// Only the `tree` list is deserialized; any other fields in the response are ignored.
#[derive(serde::Deserialize, Debug)]
struct FileTree {
    /// Every entry in the branch's tree, flattened because of `recursive=1`.
    tree: Vec<Node>,
}
21
/// Progress snapshot of a multi-file download.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct Process {
    /// Number of files that have finished downloading.
    pub current: usize,
    /// Total number of files scheduled for download.
    pub all: usize,
}

impl Process {
    /// Creates a shared progress counter for `n` scheduled files, starting at zero.
    fn new(n: usize) -> Arc<Mutex<Self>> {
        Arc::new(Mutex::new(Self { current: 0, all: n }))
    }

    /// Returns an independent copy of this snapshot.
    ///
    /// `Process` is `Copy`, so this is simply `*self`.
    fn deep_clone(&self) -> Self {
        *self
    }

    /// Records one more finished file.
    fn done(&mut self) {
        self.current += 1;
    }

    /// Returns the completed *fraction* (`current / all`), not a value scaled to 100.
    ///
    /// An empty download (`all == 0`) is reported as complete (`1.0`), matching
    /// [`Process::is_over`], instead of the `NaN` that `0.0 / 0.0` would yield.
    pub fn percent(&self) -> f64 {
        if self.all == 0 {
            1.0
        } else {
            self.current as f64 / self.all as f64
        }
    }

    /// Returns `true` once every scheduled file has been recorded as done.
    pub fn is_over(&self) -> bool {
        self.current == self.all
    }
}
53
/// Unwraps a `Result`, or reports its error on a channel and returns early.
///
/// `$tx` is anything with a `send` method accepting `Err(String)`; `$result`
/// must be a `Result` whose error type implements `Display`. On `Ok(value)` the
/// macro evaluates to `value`. On `Err(err)` it sends `Err(err.to_string())` and
/// executes `return;`, so it may only be used in a function or async block whose
/// output is `()`.
///
/// `$result` is evaluated exactly once, so side-effecting expressions (such as an
/// awaited download) are not repeated. A failed `send` (no receiver left) is
/// ignored rather than panicking, since there is nobody left to report to.
macro_rules! send_if_err {
    ($tx:expr, $result:expr) => {
        match $result {
            Ok(value) => value,
            Err(err) => {
                let _ = $tx.send(Err(err.to_string()));
                return;
            }
        }
    };
}
63
64impl FileTree {
65 async fn download(&self, downloader: Arc<Downloader>, tx: Sender<Result<Process, String>>) {
66 let tasks: Vec<_> = self
67 .tree
68 .iter()
69 .filter(|node| node.size.is_some())
70 .map(|node| Arc::new(PathBuf::from(&node.path)))
71 .filter(|path| {
72 let src = PathBuf::from(&downloader.remote_path);
73 path.starts_with(src)
74 })
75 .collect();
76 let process = Process::new(tasks.len());
77 tasks.iter().for_each(|path| {
78 let src = PathBuf::from(&downloader.remote_path);
79 let dst = PathBuf::from(&downloader.local_path);
80 let path = path.clone();
81 let tx = tx.clone();
82 let downloader = downloader.clone();
83 let process = process.clone();
84 tokio::spawn(async move {
85 let dst = dst.join(path.strip_prefix(&src).unwrap());
90 let dst_dir = dst.parent().unwrap();
91 create_dir_all(dst_dir).unwrap();
92 send_if_err!(
93 tx,
94 downloader
95 .download_single(
96 path.as_os_str().to_str().unwrap(),
97 dst.to_str().unwrap().trim_end_matches("/")
98 )
99 .await
100 );
101 let mut lock = process.lock().unwrap();
102 lock.done();
103 let process = lock.deep_clone();
104 tx.send(Ok(process)).unwrap();
105 });
106 });
107 }
108}
109
/// Downloads a file or directory from a GitHub repository branch.
///
/// Built by [`DownloaderBuilder`]; run it with `Downloader::download`.
pub struct Downloader {
    /// GitHub user or organization owning the repository (used in the API and raw URLs).
    user: String,
    /// Repository name.
    repo: String,
    /// Branch whose tree is listed and whose file contents are fetched.
    branch: String,
    /// Path inside the repository; every sized tree entry under it is downloaded.
    remote_path: String,
    /// Local destination; files are written beneath it, mirroring their paths
    /// relative to `remote_path`.
    local_path: String,
    /// Called with every progress snapshot received while downloading.
    process_handler: fn(Process),
}
118
119impl Downloader {
120 const USER_AGENT:&'static str="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.5410.0 Safari/537.36";
121
122 async fn download_single(&self, path: &str, dst: &str) -> Result<()> {
123 let url = format!(
124 "https://raw.githubusercontent.com/{}/{}/{}/{path}",
125 &self.user, &self.repo, &self.branch
126 );
127 let client = reqwest::ClientBuilder::new()
128 .user_agent(Self::USER_AGENT)
129 .build()?;
130 let res = client.get(url).send().await?.text().await?;
131 let mut file = File::create(dst)?;
132 file.write_all(res.as_bytes())?;
133 Ok(())
134 }
135
136 pub async fn download(self) -> Result<()> {
137 let url = format!(
138 "https://api.github.com/repos/{}/{}/git/trees/{}?recursive=1",
139 &self.user, &self.repo, &self.branch
140 );
141 let client = reqwest::ClientBuilder::new()
142 .user_agent(Self::USER_AGENT)
143 .build()?;
144 let res = client.get(url).send().await.unwrap().text().await.unwrap();
145 let file_tree: FileTree = serde_json::from_str(&res)
146 .map_err(|_| anyhow!("Are you sure the repo really exists?"))?;
147
148 let (tx, mut rx) = channel::<Result<Process, String>>(5);
149
150 let me = Arc::new(self);
151
152 file_tree.download(me.clone(), tx).await;
153
154 loop {
155 let process = rx
156 .recv()
157 .await
158 .map_err(|_| anyhow!("Are you sure the target name is right?"))?
159 .unwrap();
160 (me.process_handler)(process);
161 if process.is_over() {
162 return Ok(());
163 }
164 }
165 }
166}
167
168#[derive(Default)]
169pub struct DownloaderBuilder {
170 user: String,
171 repo: String,
172 branch: Option<String>,
173 remote_path: String,
174 local_path: Option<String>,
175 process_handler: Option<fn(Process)>,
176}
177
178impl DownloaderBuilder {
179 pub fn new(user: &str, repo: &str, remote: &str) -> Self {
180 Self {
181 user: user.into(),
182 repo: repo.into(),
183 remote_path: remote.into(),
184 ..Default::default()
185 }
186 }
187
188 pub fn branch(mut self, branch: &str) -> Self {
189 self.branch = Some(branch.into());
190 self
191 }
192
193 pub fn local_path(mut self, local: &str) -> Self {
194 let remote = PathBuf::from(&self.remote_path);
195 let name = remote.file_name().unwrap().to_str().unwrap().to_string();
196 let local = PathBuf::from(local);
197 self.local_path = Some(local.join(name).to_str().unwrap().to_string());
198 self
199 }
200
201 pub fn on_process(mut self, f: fn(Process)) -> Self {
202 self.process_handler = Some(f);
203 self
204 }
205
206 pub fn build(self) -> Downloader {
207 let path = PathBuf::from(&self.remote_path);
208 let name = path.file_name().unwrap().to_str().unwrap().to_string();
209 Downloader {
210 user: self.user,
211 repo: self.repo,
212 remote_path: self.remote_path,
213 branch: self.branch.unwrap_or("main".into()),
214 local_path: self.local_path.unwrap_or(name),
215 process_handler: self.process_handler.unwrap_or(|_| {}),
216 }
217 }
218}