gitload/
lib.rs

1use std::{
2    fs::{create_dir_all, File},
3    io::Write,
4    path::PathBuf,
5    sync::{Arc, Mutex},
6};
7
8use anyhow::{anyhow, Result};
9use tokio::sync::broadcast::{channel, Sender};
10
11#[derive(serde::Deserialize, Debug)]
12struct Node {
13    path: String,
14    size: Option<isize>,
15}
16
17#[derive(serde::Deserialize, Debug)]
18struct FileTree {
19    tree: Vec<Node>,
20}
21
/// Snapshot of download progress: how many items are finished out of the total.
#[derive(Clone, Copy, Debug)]
pub struct Process {
    /// Number of items finished so far.
    pub current: usize,
    /// Total number of items scheduled.
    pub all: usize,
}

impl Process {
    /// Creates a shared, lock-protected counter starting at zero out of `n` items.
    fn new(n: usize) -> Arc<Mutex<Self>> {
        Arc::new(Mutex::new(Self { current: 0, all: n }))
    }

    /// Returns an independent copy of the current state.
    fn deep_clone(&self) -> Self {
        // `Process` is `Copy`, so a plain dereference is a full copy.
        *self
    }

    /// Records one more finished item.
    fn done(&mut self) {
        self.current += 1;
    }

    /// Returns the finished fraction in the range `0.0..=1.0` (a ratio, not a
    /// value scaled to 100).
    ///
    /// An empty job (`all == 0`) reports `1.0` rather than the NaN that `0 / 0`
    /// would produce, which matches [`Process::is_over`] being `true` for it.
    pub fn percent(&self) -> f64 {
        if self.all == 0 {
            return 1.0;
        }
        self.current as f64 / self.all as f64
    }

    /// Returns `true` once every scheduled item has been finished.
    ///
    /// Uses `>=` so a counter that somehow overshoots still counts as finished.
    pub fn is_over(&self) -> bool {
        self.current >= self.all
    }
}
53
/// Unwraps a `Result` inside a task, or reports the error and leaves the task.
///
/// On `Ok(value)` the macro evaluates to `value`. On `Err(err)` it sends
/// `Err(err.to_string())` through `$tx` and `return`s from the enclosing
/// function, closure or async block (which must therefore return `()`).
///
/// `$result` is expanded exactly once. Expanding it twice would run the
/// wrapped expression (for example a network request) a second time.
/// A failed send is ignored: it only means nobody is listening any more.
macro_rules! send_if_err {
    ($tx:expr, $result:expr) => {
        match $result {
            Ok(value) => value,
            Err(err) => {
                let _ = $tx.send(Err(err.to_string()));
                return;
            }
        }
    };
}
63
64impl FileTree {
65    async fn download(&self, downloader: Arc<Downloader>, tx: Sender<Result<Process, String>>) {
66        let tasks: Vec<_> = self
67            .tree
68            .iter()
69            .filter(|node| node.size.is_some())
70            .map(|node| Arc::new(PathBuf::from(&node.path)))
71            .filter(|path| {
72                let src = PathBuf::from(&downloader.remote_path);
73                path.starts_with(src)
74            })
75            .collect();
76        let process = Process::new(tasks.len());
77        tasks.iter().for_each(|path| {
78            let src = PathBuf::from(&downloader.remote_path);
79            let dst = PathBuf::from(&downloader.local_path);
80            let path = path.clone();
81            let tx = tx.clone();
82            let downloader = downloader.clone();
83            let process = process.clone();
84            tokio::spawn(async move {
85                // src is remote path, such as nvim/init.lua
86                // dst is local path such as src
87                // path is the exact remote path, on the situation of single file, path equals with src
88                // the final dst is the download path, such as src/init.lua
89                let dst = dst.join(path.strip_prefix(&src).unwrap());
90                let dst_dir = dst.parent().unwrap();
91                create_dir_all(dst_dir).unwrap();
92                send_if_err!(
93                    tx,
94                    downloader
95                        .download_single(
96                            path.as_os_str().to_str().unwrap(),
97                            dst.to_str().unwrap().trim_end_matches("/")
98                        )
99                        .await
100                );
101                let mut lock = process.lock().unwrap();
102                lock.done();
103                let process = lock.deep_clone();
104                tx.send(Ok(process)).unwrap();
105            });
106        });
107    }
108}
109
110pub struct Downloader {
111    user: String,
112    repo: String,
113    branch: String,
114    remote_path: String,
115    local_path: String,
116    process_handler: fn(Process),
117}
118
119impl Downloader {
120    const USER_AGENT:&'static str="Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/109.0.5410.0 Safari/537.36";
121
122    async fn download_single(&self, path: &str, dst: &str) -> Result<()> {
123        let url = format!(
124            "https://raw.githubusercontent.com/{}/{}/{}/{path}",
125            &self.user, &self.repo, &self.branch
126        );
127        let client = reqwest::ClientBuilder::new()
128            .user_agent(Self::USER_AGENT)
129            .build()?;
130        let res = client.get(url).send().await?.text().await?;
131        let mut file = File::create(dst)?;
132        file.write_all(res.as_bytes())?;
133        Ok(())
134    }
135
136    pub async fn download(self) -> Result<()> {
137        let url = format!(
138            "https://api.github.com/repos/{}/{}/git/trees/{}?recursive=1",
139            &self.user, &self.repo, &self.branch
140        );
141        let client = reqwest::ClientBuilder::new()
142            .user_agent(Self::USER_AGENT)
143            .build()?;
144        let res = client.get(url).send().await.unwrap().text().await.unwrap();
145        let file_tree: FileTree = serde_json::from_str(&res)
146            .map_err(|_| anyhow!("Are you sure the repo really exists?"))?;
147
148        let (tx, mut rx) = channel::<Result<Process, String>>(5);
149
150        let me = Arc::new(self);
151
152        file_tree.download(me.clone(), tx).await;
153
154        loop {
155            let process = rx
156                .recv()
157                .await
158                .map_err(|_| anyhow!("Are you sure the target name is right?"))?
159                .unwrap();
160            (me.process_handler)(process);
161            if process.is_over() {
162                return Ok(());
163            }
164        }
165    }
166}
167
168#[derive(Default)]
169pub struct DownloaderBuilder {
170    user: String,
171    repo: String,
172    branch: Option<String>,
173    remote_path: String,
174    local_path: Option<String>,
175    process_handler: Option<fn(Process)>,
176}
177
178impl DownloaderBuilder {
179    pub fn new(user: &str, repo: &str, remote: &str) -> Self {
180        Self {
181            user: user.into(),
182            repo: repo.into(),
183            remote_path: remote.into(),
184            ..Default::default()
185        }
186    }
187
188    pub fn branch(mut self, branch: &str) -> Self {
189        self.branch = Some(branch.into());
190        self
191    }
192
193    pub fn local_path(mut self, local: &str) -> Self {
194        let remote = PathBuf::from(&self.remote_path);
195        let name = remote.file_name().unwrap().to_str().unwrap().to_string();
196        let local = PathBuf::from(local);
197        self.local_path = Some(local.join(name).to_str().unwrap().to_string());
198        self
199    }
200
201    pub fn on_process(mut self, f: fn(Process)) -> Self {
202        self.process_handler = Some(f);
203        self
204    }
205
206    pub fn build(self) -> Downloader {
207        let path = PathBuf::from(&self.remote_path);
208        let name = path.file_name().unwrap().to_str().unwrap().to_string();
209        Downloader {
210            user: self.user,
211            repo: self.repo,
212            remote_path: self.remote_path,
213            branch: self.branch.unwrap_or("main".into()),
214            local_path: self.local_path.unwrap_or(name),
215            process_handler: self.process_handler.unwrap_or(|_| {}),
216        }
217    }
218}