grm/provider/mod.rs

1pub mod github;
2pub mod gitlab;
3
4use std::{borrow::Cow, collections::HashMap, fmt};
5
6pub use github::Github;
7pub use gitlab::Gitlab;
8use thiserror::Error;
9
10use super::{RemoteName, RemoteUrl, auth, config, repo};
11
/// A provider API URL.
///
/// Backed by a `Cow` so that URLs known at compile time (see
/// [`Url::new_static`]) need no allocation, while URLs assembled at runtime
/// are owned.
pub struct Url(Cow<'static, str>);

impl Url {
    /// Wraps a URL that was built at runtime.
    pub fn new(from: String) -> Self {
        Url(Cow::Owned(from))
    }

    /// Wraps a `'static` URL without allocating; usable in `const` context.
    pub const fn new_static(from: &'static str) -> Self {
        Url(Cow::Borrowed(from))
    }

    /// Borrows the URL as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_ref()
    }
}
/// A user (account) name on a provider, e.g. one whose projects are listed
/// via `Provider::get_user_projects`.
#[derive(Clone)]
pub struct User(String);

impl User {
    /// Wraps `name` as-is; no validation is performed.
    pub fn new(name: String) -> Self {
        User(name)
    }
}
35
36impl From<super::config::User> for User {
37    fn from(value: super::config::User) -> Self {
38        Self(value.into_username())
39    }
40}
41
42impl fmt::Display for User {
43    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44        write!(f, "{}", self.0)
45    }
46}
47
/// A group (organization/namespace) name on a provider, e.g. one whose
/// projects are listed via `Provider::get_group_projects`.
#[derive(Clone)]
pub struct Group(String);

impl Group {
    /// Wraps `name` as-is; no validation is performed.
    pub fn new(name: String) -> Self {
        Group(name)
    }
}
56
57impl From<super::config::Group> for Group {
58    fn from(value: super::config::Group) -> Self {
59        Self(value.into_groupname())
60    }
61}
62
63impl fmt::Display for Group {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        write!(f, "{}", self.0)
66    }
67}
68
/// Remote name given to provider repositories when the caller of
/// `Provider::get_repos` does not specify one.
const DEFAULT_REMOTE_NAME: RemoteName = RemoteName::new_static("origin");
70
/// Errors surfaced by provider operations.
#[derive(Debug, Error)]
pub enum Error {
    /// An API request failed; carries the flattened API error message
    /// (produced by `Provider::get_repos`).
    #[error("response error: {0}")]
    Response(String),
    /// Any other provider failure. Not constructed in this module;
    /// presumably raised by provider implementations (e.g. `Provider::new`).
    #[error("provider error: {0}")]
    Provider(String),
}
78
79#[derive(Debug, clap::ValueEnum, Clone)]
80pub enum RemoteProvider {
81    Github,
82    Gitlab,
83}
84
85impl From<config::RemoteProvider> for RemoteProvider {
86    fn from(other: config::RemoteProvider) -> Self {
87        match other {
88            config::RemoteProvider::Github => Self::Github,
89            config::RemoteProvider::Gitlab => Self::Gitlab,
90        }
91    }
92}
93
94pub fn escape(s: &str) -> String {
95    url_escape::encode_component(s).to_string()
96}
97
/// Name of a project (repository) as reported by a provider.
#[derive(PartialEq, Eq)]
pub struct ProjectName(String);

impl ProjectName {
    /// Wraps a project name without validating it.
    pub fn new(from: String) -> Self {
        ProjectName(from)
    }

    /// Consumes the wrapper and returns the underlying name.
    pub fn into_string(self) -> String {
        let ProjectName(name) = self;
        name
    }
}
110
111impl From<repo::ProjectName> for ProjectName {
112    fn from(other: repo::ProjectName) -> Self {
113        Self(other.into_string())
114    }
115}
116
117impl From<ProjectName> for repo::ProjectName {
118    fn from(other: ProjectName) -> Self {
119        Self::new(other.into_string())
120    }
121}
122
/// Namespace (presumably the owning user or group path) of a project, as
/// reported by a provider. Hashable so it can key the map returned by
/// `Provider::get_repos`.
#[derive(PartialEq, Eq, Hash)]
pub struct ProjectNamespace(String);

impl ProjectNamespace {
    /// Wraps a namespace without validating it.
    pub fn new(from: String) -> Self {
        ProjectNamespace(from)
    }

    /// Consumes the wrapper and returns the underlying namespace.
    pub fn into_string(self) -> String {
        let ProjectNamespace(namespace) = self;
        namespace
    }

    /// Borrows the namespace as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
139
140impl From<repo::ProjectNamespace> for ProjectNamespace {
141    fn from(other: repo::ProjectNamespace) -> Self {
142        Self(other.into_string())
143    }
144}
145
146impl From<ProjectNamespace> for repo::ProjectNamespace {
147    fn from(other: ProjectNamespace) -> Self {
148        Self::new(other.into_string())
149    }
150}
151
152pub trait Project {
153    fn into_repo_config(
154        self,
155        remote_name: &RemoteName,
156        worktree_setup: bool,
157        force_ssh: bool,
158    ) -> repo::Repo
159    where
160        Self: Sized,
161    {
162        repo::Repo {
163            name: self.name().into(),
164            namespace: self.namespace().map(Into::into),
165            worktree_setup,
166            remotes: vec![repo::Remote {
167                name: remote_name.clone(),
168                url: if force_ssh || self.private() {
169                    self.ssh_url()
170                } else {
171                    self.http_url()
172                },
173                remote_type: if force_ssh || self.private() {
174                    repo::RemoteType::Ssh
175                } else {
176                    repo::RemoteType::Https
177                },
178            }],
179        }
180    }
181
182    fn name(&self) -> ProjectName;
183    fn namespace(&self) -> Option<ProjectNamespace>;
184    fn ssh_url(&self) -> RemoteUrl;
185    fn http_url(&self) -> RemoteUrl;
186    fn private(&self) -> bool;
187}
188
189#[derive(Clone)]
190pub struct Filter {
191    users: Vec<User>,
192    groups: Vec<Group>,
193    owner: bool,
194    access: bool,
195}
196
197impl Filter {
198    pub fn new(users: Vec<User>, groups: Vec<Group>, owner: bool, access: bool) -> Self {
199        Self {
200            users,
201            groups,
202            owner,
203            access,
204        }
205    }
206
207    pub fn empty(&self) -> bool {
208        self.users.is_empty() && self.groups.is_empty() && !self.owner && !self.access
209    }
210}
211
212#[derive(Debug, Error)]
213pub enum ApiError<T>
214where
215    T: JsonError,
216{
217    Json(T),
218    String(String),
219}
220
221impl<T> From<String> for ApiError<T>
222where
223    T: JsonError,
224{
225    fn from(s: String) -> Self {
226        Self::String(s)
227    }
228}
229
230impl<T> From<ureq::http::header::ToStrError> for ApiError<T>
231where
232    T: JsonError,
233{
234    fn from(s: ureq::http::header::ToStrError) -> Self {
235        Self::String(s.to_string())
236    }
237}
238
/// A provider-specific error body that can be flattened into a message.
///
/// Note: `to_string` consumes `self`, and its name collides with
/// `ToString::to_string`.
pub trait JsonError {
    fn to_string(self) -> String;
}
242
243pub trait Provider {
244    type Project: serde::de::DeserializeOwned + Project;
245    type Error: serde::de::DeserializeOwned + JsonError;
246
247    fn new(
248        filter: Filter,
249        secret_token: auth::AuthToken,
250        api_url_override: Option<Url>,
251    ) -> Result<Self, Error>
252    where
253        Self: Sized;
254
255    fn filter(&self) -> &Filter;
256    fn secret_token(&self) -> &auth::AuthToken;
257    fn auth_header_key() -> &'static str;
258
259    fn get_user_projects(&self, user: &User) -> Result<Vec<Self::Project>, ApiError<Self::Error>>;
260
261    fn get_group_projects(
262        &self,
263        group: &Group,
264    ) -> Result<Vec<Self::Project>, ApiError<Self::Error>>;
265
266    fn get_own_projects(&self) -> Result<Vec<Self::Project>, ApiError<Self::Error>> {
267        self.get_user_projects(&self.get_current_user()?)
268    }
269
270    fn get_accessible_projects(&self) -> Result<Vec<Self::Project>, ApiError<Self::Error>>;
271
272    fn get_current_user(&self) -> Result<User, ApiError<Self::Error>>;
273
274    ///
275    /// Calls the API at specific uri and expects a successful response of
276    /// `Vec<T>` back, or an error response U
277    ///
278    /// Handles paging with "link" HTTP headers properly and reads all pages to
279    /// the end.
280    fn call_list(
281        &self,
282        uri: &Url,
283        accept_header: Option<&str>,
284    ) -> Result<Vec<Self::Project>, ApiError<Self::Error>> {
285        let mut results = vec![];
286
287        match ureq::get(uri.as_str())
288            .config()
289            .http_status_as_error(false)
290            .build()
291            .header("accept", accept_header.unwrap_or("application/json"))
292            .header(
293                "authorization",
294                &format!(
295                    "{} {}",
296                    Self::auth_header_key(),
297                    &self.secret_token().access()
298                ),
299            )
300            .call()
301        {
302            Err(ureq::Error::Http(error)) => return Err(format!("http error: {error}").into()),
303            Err(e) => return Err(format!("unknown error: {e}").into()),
304            Ok(mut response) => {
305                if !response.status().is_success() {
306                    let result: Self::Error = response
307                        .body_mut()
308                        .read_json()
309                        .map_err(|error| format!("Failed deserializing error response: {error}"))?;
310                    return Err(ApiError::Json(result));
311                } else {
312                    if let Some(link_header) = response.headers().get("link") {
313                        let link_header = parse_link_header::parse(link_header.to_str()?)
314                            .map_err(|error| error.to_string())?;
315
316                        let next_page = link_header.get(&Some(String::from("next")));
317
318                        if let Some(page) = next_page {
319                            let following_repos =
320                                self.call_list(&Url::new(page.raw_uri.clone()), accept_header)?;
321                            results.extend(following_repos);
322                        }
323                    }
324
325                    let result: Vec<Self::Project> = response
326                        .body_mut()
327                        .read_json()
328                        .map_err(|error| format!("Failed deserializing response: {error}"))?;
329
330                    results.extend(result);
331                }
332            }
333        }
334
335        Ok(results)
336    }
337
338    fn get_repos(
339        &self,
340        worktree_setup: bool,
341        force_ssh: bool,
342        remote_name: Option<RemoteName>,
343    ) -> Result<HashMap<Option<ProjectNamespace>, Vec<repo::Repo>>, Error> {
344        let mut repos = vec![];
345
346        if self.filter().owner {
347            repos.extend(self.get_own_projects().map_err(|error| {
348                Error::Response(match error {
349                    ApiError::Json(x) => x.to_string(),
350                    ApiError::String(s) => s,
351                })
352            })?);
353        }
354
355        if self.filter().access {
356            let accessible_projects = self.get_accessible_projects().map_err(|error| {
357                Error::Response(match error {
358                    ApiError::Json(x) => x.to_string(),
359                    ApiError::String(s) => s,
360                })
361            })?;
362
363            for accessible_project in accessible_projects {
364                let mut already_present = false;
365                for repo in &repos {
366                    if repo.name() == accessible_project.name()
367                        && repo.namespace() == accessible_project.namespace()
368                    {
369                        already_present = true;
370                    }
371                }
372                if !already_present {
373                    repos.push(accessible_project);
374                }
375            }
376        }
377
378        for user in &self.filter().users {
379            let user_projects = self.get_user_projects(user).map_err(|error| {
380                Error::Response(match error {
381                    ApiError::Json(x) => x.to_string(),
382                    ApiError::String(s) => s,
383                })
384            })?;
385
386            for user_project in user_projects {
387                let mut already_present = false;
388                for repo in &repos {
389                    if repo.name() == user_project.name()
390                        && repo.namespace() == user_project.namespace()
391                    {
392                        already_present = true;
393                    }
394                }
395                if !already_present {
396                    repos.push(user_project);
397                }
398            }
399        }
400
401        for group in &self.filter().groups {
402            let group_projects = self.get_group_projects(group).map_err(|error| {
403                Error::Response(format!(
404                    "group \"{}\": {}",
405                    group,
406                    match error {
407                        ApiError::Json(x) => x.to_string(),
408                        ApiError::String(s) => s,
409                    }
410                ))
411            })?;
412            for group_project in group_projects {
413                let mut already_present = false;
414                for repo in &repos {
415                    if repo.name() == group_project.name()
416                        && repo.namespace() == group_project.namespace()
417                    {
418                        already_present = true;
419                    }
420                }
421
422                if !already_present {
423                    repos.push(group_project);
424                }
425            }
426        }
427
428        let mut ret: HashMap<Option<ProjectNamespace>, Vec<repo::Repo>> = HashMap::new();
429
430        let remote_name = remote_name.unwrap_or(DEFAULT_REMOTE_NAME);
431
432        for repo in repos {
433            let namespace = repo.namespace();
434
435            let mut repo = repo.into_repo_config(&remote_name, worktree_setup, force_ssh);
436
437            // Namespace is already part of the hashmap key. I'm not too happy
438            // about the data exchange format here.
439            repo.remove_namespace();
440
441            ret.entry(namespace).or_default().push(repo);
442        }
443
444        Ok(ret)
445    }
446}
447
448fn call<T, U>(
449    uri: &str,
450    auth_header_key: &str,
451    secret_token: &auth::AuthToken,
452    accept_header: Option<&str>,
453) -> Result<T, ApiError<U>>
454where
455    T: serde::de::DeserializeOwned,
456    U: serde::de::DeserializeOwned + JsonError,
457{
458    let response = match ureq::get(uri)
459        .header("accept", accept_header.unwrap_or("application/json"))
460        .header(
461            "authorization",
462            &format!("{} {}", &auth_header_key, &secret_token.access()),
463        )
464        .call()
465    {
466        Err(ureq::Error::Http(error)) => return Err(format!("http error: {error}").into()),
467        Err(e) => return Err(format!("unknown error: {e}").into()),
468        Ok(mut response) => {
469            if !response.status().is_success() {
470                let result: U = response
471                    .body_mut()
472                    .read_json()
473                    .map_err(|error| format!("Failed deserializing error response: {error}"))?;
474                return Err(ApiError::Json(result));
475            } else {
476                response
477                    .body_mut()
478                    .read_json()
479                    .map_err(|error| format!("Failed deserializing response: {error}"))?
480            }
481        }
482    };
483
484    Ok(response)
485}