git_repos/provider/
mod.rs

1use serde::{Deserialize, Serialize};
2
3pub mod github;
4pub mod gitlab;
5
6use std::collections::HashMap;
7
8pub use github::Github;
9pub use gitlab::Gitlab;
10
11use super::{repo, token};
12
/// Remote name given to generated repository configs when the caller does not supply one.
const DEFAULT_REMOTE_NAME: &str = "origin";
14
15#[derive(Debug, Deserialize, Serialize, clap::ValueEnum, Clone)]
16pub enum RemoteProvider {
17    #[serde(alias = "github", alias = "GitHub")]
18    Github,
19    #[serde(alias = "gitlab", alias = "GitLab")]
20    Gitlab,
21}
22
23#[derive(Deserialize)]
24#[serde(untagged)]
25enum ProjectResponse<F> {
26    Success,
27    Failure(F),
28}
29
30pub fn escape(s: &str) -> String {
31    url_escape::encode_component(s).to_string()
32}
33
34pub trait Project {
35    fn into_repo_config(
36        self,
37        provider_name: &str,
38        worktree_setup: bool,
39        force_ssh: bool,
40    ) -> repo::Repo
41    where
42        Self: Sized,
43    {
44        repo::Repo {
45            name: self.name(),
46            namespace: self.namespace(),
47            worktree_setup,
48            remotes: Some(vec![repo::Remote {
49                name: String::from(provider_name),
50                url: if force_ssh || self.private() { self.ssh_url() } else { self.http_url() },
51                remote_type: if force_ssh || self.private() {
52                    repo::RemoteType::SSH
53                } else {
54                    repo::RemoteType::HTTPS
55                },
56            }]),
57        }
58    }
59
60    fn name(&self) -> String;
61    fn namespace(&self) -> Option<String>;
62    fn ssh_url(&self) -> String;
63    fn http_url(&self) -> String;
64    fn private(&self) -> bool;
65}
66
/// Selection criteria deciding which projects a provider fetches.
#[derive(Clone)]
pub struct Filter {
    /// Users whose projects are fetched.
    users: Vec<String>,
    /// Groups whose projects are fetched.
    groups: Vec<String>,
    /// Whether to fetch the projects of the current user.
    owner: bool,
    /// Whether to fetch every project the current user can access.
    access: bool,
}

impl Filter {
    /// Builds a filter from its individual criteria.
    pub fn new(users: Vec<String>, groups: Vec<String>, owner: bool, access: bool) -> Self {
        Self {
            users,
            groups,
            owner,
            access,
        }
    }

    /// Returns `true` when no criterion is set, i.e. the filter selects nothing.
    pub fn empty(&self) -> bool {
        let any_flag = self.owner || self.access;
        let any_list = !self.users.is_empty() || !self.groups.is_empty();
        !(any_flag || any_list)
    }
}
84
/// Error produced by provider API calls.
///
/// `Json` carries an error body that was decoded into the provider-specific type `T`;
/// `String` carries every other failure as a plain message.
pub enum ApiErrorResponse<T>
where
    T: JsonError,
{
    /// Structured error body returned by the API.
    Json(T),
    /// Free-form error message.
    String(String),
}

impl<T: JsonError> From<String> for ApiErrorResponse<T> {
    /// Wraps a plain message in the `String` variant.
    fn from(message: String) -> Self {
        Self::String(message)
    }
}

/// A provider-specific error body that can be turned into a message.
pub trait JsonError {
    /// Consumes the error and renders it as a message.
    fn to_string(self) -> String;
}
105
106pub trait Provider {
107    type Project: serde::de::DeserializeOwned + Project;
108    type Error: serde::de::DeserializeOwned + JsonError;
109
110    fn new(
111        filter: Filter,
112        secret_token: token::AuthToken,
113        api_url_override: Option<String>,
114    ) -> Result<Self, String>
115    where
116        Self: Sized;
117
118    fn filter(&self) -> &Filter;
119    fn secret_token(&self) -> &token::AuthToken;
120    fn auth_header_key() -> &'static str;
121
122    fn get_user_projects(
123        &self,
124        user: &str,
125    ) -> Result<Vec<Self::Project>, ApiErrorResponse<Self::Error>>;
126
127    fn get_group_projects(
128        &self,
129        group: &str,
130    ) -> Result<Vec<Self::Project>, ApiErrorResponse<Self::Error>>;
131
132    fn get_own_projects(&self) -> Result<Vec<Self::Project>, ApiErrorResponse<Self::Error>> {
133        self.get_user_projects(&self.get_current_user()?)
134    }
135
136    fn get_accessible_projects(&self) -> Result<Vec<Self::Project>, ApiErrorResponse<Self::Error>>;
137
138    fn get_current_user(&self) -> Result<String, ApiErrorResponse<Self::Error>>;
139
140    /// Calls the API at specific uri and expects a successful response of Vec<T> back, or an error
141    /// response U
142    ///
143    /// Handles paging with "link" HTTP headers properly and reads all pages to
144    /// the end.
145    fn call_list(
146        &self,
147        uri: &str,
148        accept_header: Option<&str>,
149    ) -> Result<Vec<Self::Project>, ApiErrorResponse<Self::Error>> {
150        let mut results = vec![];
151
152        match ureq::get(uri)
153            .set("accept", accept_header.unwrap_or("application/json"))
154            .set(
155                "authorization",
156                &format!("{} {}", Self::auth_header_key(), &self.secret_token().access()),
157            )
158            .call()
159        {
160            Err(ureq::Error::Transport(error)) => return Err(error.to_string())?,
161            Err(ureq::Error::Status(_code, response)) => {
162                let r: Self::Error = response
163                    .into_json()
164                    .map_err(|error| format!("Failed deserializing error response: {}", error))?;
165                return Err(ApiErrorResponse::Json(r));
166            },
167            Ok(response) => {
168                if let Some(link_header) = response.header("link") {
169                    let link_header =
170                        parse_link_header::parse(link_header).map_err(|error| error.to_string())?;
171
172                    let next_page = link_header.get(&Some(String::from("next")));
173
174                    if let Some(page) = next_page {
175                        let following_repos = self.call_list(&page.raw_uri, accept_header)?;
176                        results.extend(following_repos);
177                    }
178                }
179
180                let result: Vec<Self::Project> = response
181                    .into_json()
182                    .map_err(|error| format!("Failed deserializing response: {}", error))?;
183
184                results.extend(result);
185            },
186        }
187
188        Ok(results)
189    }
190
191    fn get_repos(
192        &self,
193        worktree_setup: bool,
194        force_ssh: bool,
195        remote_name: Option<String>,
196    ) -> Result<HashMap<Option<String>, Vec<repo::Repo>>, String> {
197        let mut repos = vec![];
198
199        if self.filter().owner {
200            repos.extend(self.get_own_projects().map_err(|error| {
201                match error {
202                    ApiErrorResponse::Json(x) => x.to_string(),
203                    ApiErrorResponse::String(s) => s,
204                }
205            })?);
206        }
207
208        if self.filter().access {
209            let accessible_projects = self.get_accessible_projects().map_err(|error| {
210                match error {
211                    ApiErrorResponse::Json(x) => x.to_string(),
212                    ApiErrorResponse::String(s) => s,
213                }
214            })?;
215
216            for accessible_project in accessible_projects {
217                let mut already_present = false;
218                for repo in &repos {
219                    if repo.name() == accessible_project.name()
220                        && repo.namespace() == accessible_project.namespace()
221                    {
222                        already_present = true;
223                    }
224                }
225                if !already_present {
226                    repos.push(accessible_project);
227                }
228            }
229        }
230
231        for user in &self.filter().users {
232            let user_projects = self.get_user_projects(user).map_err(|error| {
233                match error {
234                    ApiErrorResponse::Json(x) => x.to_string(),
235                    ApiErrorResponse::String(s) => s,
236                }
237            })?;
238
239            for user_project in user_projects {
240                let mut already_present = false;
241                for repo in &repos {
242                    if repo.name() == user_project.name()
243                        && repo.namespace() == user_project.namespace()
244                    {
245                        already_present = true;
246                    }
247                }
248                if !already_present {
249                    repos.push(user_project);
250                }
251            }
252        }
253
254        for group in &self.filter().groups {
255            let group_projects = self.get_group_projects(group).map_err(|error| {
256                format!("group \"{}\": {}", group, match error {
257                    ApiErrorResponse::Json(x) => x.to_string(),
258                    ApiErrorResponse::String(s) => s,
259                })
260            })?;
261            for group_project in group_projects {
262                let mut already_present = false;
263                for repo in &repos {
264                    if repo.name() == group_project.name()
265                        && repo.namespace() == group_project.namespace()
266                    {
267                        already_present = true;
268                    }
269                }
270
271                if !already_present {
272                    repos.push(group_project);
273                }
274            }
275        }
276
277        let mut ret: HashMap<Option<String>, Vec<repo::Repo>> = HashMap::new();
278
279        let remote_name = remote_name.unwrap_or_else(|| DEFAULT_REMOTE_NAME.to_string());
280
281        for repo in repos {
282            let namespace = repo.namespace();
283
284            let mut repo = repo.into_repo_config(&remote_name, worktree_setup, force_ssh);
285
286            // Namespace is already part of the hashmap key. I'm not too happy
287            // about the data exchange format here.
288            repo.remove_namespace();
289
290            ret.entry(namespace).or_default().push(repo);
291        }
292
293        Ok(ret)
294    }
295}
296
297fn call<T, U>(
298    uri: &str,
299    auth_header_key: &str,
300    secret_token: &token::AuthToken,
301    accept_header: Option<&str>,
302) -> Result<T, ApiErrorResponse<U>>
303where
304    T: serde::de::DeserializeOwned,
305    U: serde::de::DeserializeOwned + JsonError,
306{
307    let response = match ureq::get(uri)
308        .set("accept", accept_header.unwrap_or("application/json"))
309        .set("authorization", &format!("{} {}", &auth_header_key, &secret_token.access()))
310        .call()
311    {
312        Err(ureq::Error::Transport(error)) => return Err(error.to_string())?,
313        Err(ureq::Error::Status(_code, response)) => {
314            let response: U = response
315                .into_json()
316                .map_err(|error| format!("Failed deserializing error response: {}", error))?;
317            return Err(ApiErrorResponse::Json(response));
318        },
319        Ok(response) => {
320            response
321                .into_json()
322                .map_err(|error| format!("Failed deserializing response: {}", error))?
323        },
324    };
325
326    Ok(response)
327}