git_workspace/providers/github.rs

1use crate::providers::{
2    create_exclude_regex_set, create_include_regex_set, Provider, APP_USER_AGENT,
3};
4use crate::repository::Repository;
5use anyhow::{bail, Context};
6use console::style;
7use graphql_client::{GraphQLQuery, Response};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use std::env;
11use std::fmt;
12
13// See https://github.com/graphql-rust/graphql-client/blob/master/graphql_client/tests/custom_scalars.rs#L6
14type GitSSHRemote = String;
15#[allow(clippy::upper_case_acronyms)]
16type URI = String;
17
18#[derive(GraphQLQuery)]
19#[graphql(
20    schema_path = "src/providers/graphql/github/schema.graphql",
21    query_path = "src/providers/graphql/github/projects.graphql",
22    response_derives = "Debug"
23)]
24pub struct Repositories;
25
/// Name of the environment variable holding the token when the config omits `env_var`.
fn default_env_var() -> String {
    "GITHUB_TOKEN".to_owned()
}

/// GraphQL endpoint of the public github.com API.
static DEFAULT_GITHUB_URL: &str = "https://api.github.com/graphql";

/// Owned copy of [`DEFAULT_GITHUB_URL`], usable as a serde `default` function.
fn public_github_url() -> String {
    String::from(DEFAULT_GITHUB_URL)
}

36#[derive(Deserialize, Serialize, Default, Debug, Eq, Ord, PartialEq, PartialOrd, clap::Parser)]
37#[serde(rename_all = "lowercase")]
38#[command(about = "Add a Github user or organization by name")]
39pub struct GithubProvider {
40    /// The name of the user or organisation to add.
41    pub name: String,
42    #[arg(long = "path", default_value = "github")]
43    /// Clone repositories to a specific base path
44    path: String,
45    #[arg(long = "env-name", short = 'e', default_value = "GITHUB_TOKEN")]
46    #[serde(default = "default_env_var")]
47    /// Environment variable containing the auth token
48    env_var: String,
49
50    #[arg(long = "skip-forks")]
51    #[serde(default)]
52    /// Don't clone forked repositories
53    skip_forks: bool,
54
55    #[arg(long = "include")]
56    #[serde(default)]
57    /// Only clone repositories that match these regular expressions. The repository name
58    /// includes the user or organisation name.
59    include: Vec<String>,
60
61    #[arg(long = "auth-http")]
62    #[serde(default)]
63    /// Use HTTP authentication instead of SSH
64    auth_http: bool,
65
66    #[arg(long = "exclude")]
67    #[serde(default)]
68    /// Don't clone repositories that match these regular expressions. The repository name
69    /// includes the user or organisation name.
70    exclude: Vec<String>,
71
72    #[serde(default = "public_github_url")]
73    #[arg(long = "url", default_value = DEFAULT_GITHUB_URL)]
74    /// Github instance URL, if using Github Enterprise this should be
75    /// http(s)://HOSTNAME/api/graphql
76    pub url: String,
77}
78
79impl fmt::Display for GithubProvider {
80    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81        write!(
82            f,
83            "Github user/org {} in directory {}, using the token stored in {}",
84            style(&self.name.to_lowercase()).green(),
85            style(&self.path.to_lowercase()).green(),
86            style(&self.env_var).green(),
87        )
88    }
89}
90
91impl GithubProvider {
92    fn parse_repo(
93        &self,
94        path: &str,
95        repo: &repositories::RepositoriesRepositoryOwnerRepositoriesNodes,
96    ) -> Repository {
97        let default_branch = repo
98            .default_branch_ref
99            .as_ref()
100            .map(|branch| branch.name.clone());
101        let upstream = repo.parent.as_ref().map(|parent| parent.ssh_url.clone());
102
103        Repository::new(
104            format!("{}/{}", path, repo.name_with_owner.clone()),
105            if self.auth_http {
106                repo.url.clone()
107            } else {
108                repo.ssh_url.clone()
109            },
110            default_branch,
111            upstream,
112        )
113    }
114}
115
116impl Provider for GithubProvider {
117    fn correctly_configured(&self) -> bool {
118        let token = env::var(&self.env_var);
119        if token.is_err() {
120            println!(
121                "{}",
122                style(format!(
123                    "Error: {} environment variable is not defined",
124                    self.env_var
125                ))
126                .red()
127            );
128            if self.url == public_github_url() {
129                println!(
130                    "Create a personal access token here: {}",
131                    style("https://github.com/settings/tokens").green()
132                );
133            } else {
134                println!(
135                    "Create a personal access token in your {}.",
136                    style("Github Enterprise server").green()
137                );
138            }
139
140            println!(
141                "Then set a {} environment variable with the value",
142                style(&self.env_var).green()
143            );
144            return false;
145        }
146        if self.name.ends_with('/') {
147            println!(
148                "{}",
149                style("Error: Ensure that names do not end in forward slashes").red()
150            );
151            println!("You specified: {}", self.name);
152            return false;
153        }
154        true
155    }
156
157    fn fetch_repositories(&self) -> anyhow::Result<Vec<Repository>> {
158        let github_token = env::var(&self.env_var)
159            .with_context(|| format!("Missing {} environment variable", self.env_var))?;
160
161        let auth_header = match github_token.as_str() {
162            "none" => "none".to_string(),
163            token => {
164                format!("Bearer {}", token)
165            }
166        };
167
168        let mut repositories = vec![];
169
170        let mut after = None;
171
172        let include_regex_set = create_include_regex_set(&self.include)?;
173        let exclude_regex_set = create_exclude_regex_set(&self.exclude)?;
174
175        // include_forks needs to be None instead of true, as the graphql parameter has three
176        // states: false - no forks, true - only forks, none - all repositories.
177        let include_forks: Option<bool> = if self.skip_forks { Some(false) } else { None };
178
179        let agent = ureq::AgentBuilder::new()
180            .https_only(true)
181            .user_agent(APP_USER_AGENT)
182            .build();
183
184        loop {
185            let q = Repositories::build_query(repositories::Variables {
186                login: self.name.to_lowercase(),
187                include_forks,
188                after,
189            });
190            let res = agent
191                .post(&self.url)
192                .set("Authorization", &auth_header)
193                .send_json(json!(&q));
194
195            let res = match res {
196                Ok(response) => response,
197                Err(ureq::Error::Status(status, response)) => match response.into_string() {
198                    Ok(resp) => {
199                        bail!("Got status code {status}. Body: {resp}")
200                    }
201                    Err(e) => {
202                        bail!("Got status code {status}. Error reading body: {e}")
203                    }
204                },
205                Err(e) => return Err(e.into()),
206            };
207
208            let body = res.into_string()?;
209            let response_data: Response<repositories::ResponseData> = serde_json::from_str(&body)?;
210
211            if let Some(errors) = response_data.errors {
212                let total_errors = errors.len();
213                let combined_errors: Vec<_> = errors.into_iter().map(|e| e.message).collect();
214                let combined_message = combined_errors.join("\n");
215                bail!(
216                    "Received {} errors. Errors:\n{}",
217                    total_errors,
218                    combined_message
219                );
220            }
221
222            let response_repositories = response_data
223                .data
224                .with_context(|| format!("Invalid response from GitHub: {}", body))?
225                .repository_owner
226                .with_context(|| format!("Invalid response from GitHub: {}", body))?
227                .repositories;
228
229            repositories.extend(
230                response_repositories
231                    .nodes
232                    .unwrap()
233                    .iter()
234                    .map(|r| r.as_ref().unwrap())
235                    .filter(|r| !r.is_archived)
236                    .filter(|r| include_regex_set.is_match(&r.name_with_owner))
237                    .filter(|r| !exclude_regex_set.is_match(&r.name_with_owner))
238                    .map(|repo| self.parse_repo(&self.path, repo)),
239            );
240
241            if !response_repositories.page_info.has_next_page {
242                break;
243            }
244            after = response_repositories.page_info.end_cursor;
245        }
246
247        Ok(repositories)
248    }
249}