git_workspace/providers/
github.rs1use crate::providers::{
2 create_exclude_regex_set, create_include_regex_set, Provider, APP_USER_AGENT,
3};
4use crate::repository::Repository;
5use anyhow::{bail, Context};
6use console::style;
7use graphql_client::{GraphQLQuery, Response};
8use serde::{Deserialize, Serialize};
9use serde_json::json;
10use std::env;
11use std::fmt;
12
// String-backed aliases for the remote/URL values carried by repository nodes.
// NOTE(review): presumably these are the custom scalar names that the
// `#[derive(GraphQLQuery)]` on `Repositories` expects to find in scope — confirm
// against schema.graphql before renaming or removing either of them.
type GitSSHRemote = String;
// `URI` is written in capitals, so the clippy acronym lint is silenced for it.
#[allow(clippy::upper_case_acronyms)]
type URI = String;
17
18#[derive(GraphQLQuery)]
19#[graphql(
20 schema_path = "src/providers/graphql/github/schema.graphql",
21 query_path = "src/providers/graphql/github/projects.graphql",
22 response_derives = "Debug"
23)]
24pub struct Repositories;
25
/// Returns `"GITHUB_TOKEN"`, the environment variable name used when a
/// deserialized provider does not specify one (serde default for `env_var`).
fn default_env_var() -> String {
    "GITHUB_TOKEN".to_owned()
}
29
/// GraphQL endpoint used when no other URL is configured.
static DEFAULT_GITHUB_URL: &str = "https://api.github.com/graphql";

/// Returns [`DEFAULT_GITHUB_URL`] as an owned `String`.
///
/// Serves as the serde default for `GithubProvider::url`.
fn public_github_url() -> String {
    String::from(DEFAULT_GITHUB_URL)
}
35
36#[derive(Deserialize, Serialize, Default, Debug, Eq, Ord, PartialEq, PartialOrd, clap::Parser)]
37#[serde(rename_all = "lowercase")]
38#[command(about = "Add a Github user or organization by name")]
39pub struct GithubProvider {
40 pub name: String,
42 #[arg(long = "path", default_value = "github")]
43 path: String,
45 #[arg(long = "env-name", short = 'e', default_value = "GITHUB_TOKEN")]
46 #[serde(default = "default_env_var")]
47 env_var: String,
49
50 #[arg(long = "skip-forks")]
51 #[serde(default)]
52 skip_forks: bool,
54
55 #[arg(long = "include")]
56 #[serde(default)]
57 include: Vec<String>,
60
61 #[arg(long = "auth-http")]
62 #[serde(default)]
63 auth_http: bool,
65
66 #[arg(long = "exclude")]
67 #[serde(default)]
68 exclude: Vec<String>,
71
72 #[serde(default = "public_github_url")]
73 #[arg(long = "url", default_value = DEFAULT_GITHUB_URL)]
74 pub url: String,
77}
78
79impl fmt::Display for GithubProvider {
80 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
81 write!(
82 f,
83 "Github user/org {} in directory {}, using the token stored in {}",
84 style(&self.name.to_lowercase()).green(),
85 style(&self.path.to_lowercase()).green(),
86 style(&self.env_var).green(),
87 )
88 }
89}
90
91impl GithubProvider {
92 fn parse_repo(
93 &self,
94 path: &str,
95 repo: &repositories::RepositoriesRepositoryOwnerRepositoriesNodes,
96 ) -> Repository {
97 let default_branch = repo
98 .default_branch_ref
99 .as_ref()
100 .map(|branch| branch.name.clone());
101 let upstream = repo.parent.as_ref().map(|parent| parent.ssh_url.clone());
102
103 Repository::new(
104 format!("{}/{}", path, repo.name_with_owner.clone()),
105 if self.auth_http {
106 repo.url.clone()
107 } else {
108 repo.ssh_url.clone()
109 },
110 default_branch,
111 upstream,
112 )
113 }
114}
115
116impl Provider for GithubProvider {
117 fn correctly_configured(&self) -> bool {
118 let token = env::var(&self.env_var);
119 if token.is_err() {
120 println!(
121 "{}",
122 style(format!(
123 "Error: {} environment variable is not defined",
124 self.env_var
125 ))
126 .red()
127 );
128 if self.url == public_github_url() {
129 println!(
130 "Create a personal access token here: {}",
131 style("https://github.com/settings/tokens").green()
132 );
133 } else {
134 println!(
135 "Create a personal access token in your {}.",
136 style("Github Enterprise server").green()
137 );
138 }
139
140 println!(
141 "Then set a {} environment variable with the value",
142 style(&self.env_var).green()
143 );
144 return false;
145 }
146 if self.name.ends_with('/') {
147 println!(
148 "{}",
149 style("Error: Ensure that names do not end in forward slashes").red()
150 );
151 println!("You specified: {}", self.name);
152 return false;
153 }
154 true
155 }
156
157 fn fetch_repositories(&self) -> anyhow::Result<Vec<Repository>> {
158 let github_token = env::var(&self.env_var)
159 .with_context(|| format!("Missing {} environment variable", self.env_var))?;
160
161 let auth_header = match github_token.as_str() {
162 "none" => "none".to_string(),
163 token => {
164 format!("Bearer {}", token)
165 }
166 };
167
168 let mut repositories = vec![];
169
170 let mut after = None;
171
172 let include_regex_set = create_include_regex_set(&self.include)?;
173 let exclude_regex_set = create_exclude_regex_set(&self.exclude)?;
174
175 let include_forks: Option<bool> = if self.skip_forks { Some(false) } else { None };
178
179 let agent = ureq::AgentBuilder::new()
180 .https_only(true)
181 .user_agent(APP_USER_AGENT)
182 .build();
183
184 loop {
185 let q = Repositories::build_query(repositories::Variables {
186 login: self.name.to_lowercase(),
187 include_forks,
188 after,
189 });
190 let res = {
191 let max_retries = 3;
192 let mut last_err = None;
193 let mut response = None;
194 for attempt in 0..max_retries {
195 let result = agent
196 .post(&self.url)
197 .set("Authorization", &auth_header)
198 .send_json(json!(&q));
199 match result {
200 Ok(resp) => {
201 response = Some(resp);
202 break;
203 }
204 Err(e) => {
205 last_err = Some(e);
206 if attempt < max_retries - 1 {
207 std::thread::sleep(std::time::Duration::from_secs(1));
208 }
209 }
210 }
211 }
212 match response {
213 Some(resp) => resp,
214 None => {
215 let err = last_err.unwrap();
216 match err {
217 ureq::Error::Status(status, response) => match response.into_string() {
218 Ok(resp) => {
219 bail!("Got status code {status}. Body: {resp}")
220 }
221 Err(e) => {
222 bail!("Got status code {status}. Error reading body: {e}")
223 }
224 },
225 e => return Err(e.into()),
226 }
227 }
228 }
229 };
230
231 let body = res.into_string()?;
232 let response_data: Response<repositories::ResponseData> = serde_json::from_str(&body)?;
233
234 if let Some(errors) = response_data.errors {
235 let total_errors = errors.len();
236 let combined_errors: Vec<_> = errors
237 .into_iter()
238 .map(|e| {
239 let mut message_str = e.message;
240 if let Some(path) = e.path {
241 let path_strings: Vec<String> =
242 path.iter().map(|p| p.to_string()).collect();
243 message_str.push_str(format!(" ({})", path_strings.join(".")).as_str());
244 }
245 message_str
246 })
247 .collect();
248 let combined_message = combined_errors.join("\n");
249 bail!(
250 "Received {} errors. Errors:\n{}",
251 total_errors,
252 combined_message
253 );
254 }
255
256 let response_repositories = response_data
257 .data
258 .with_context(|| format!("Invalid response from GitHub: {}", body))?
259 .repository_owner
260 .with_context(|| format!("Invalid response from GitHub: {}", body))?
261 .repositories;
262
263 repositories.extend(
264 response_repositories
265 .nodes
266 .unwrap()
267 .iter()
268 .map(|r| r.as_ref().unwrap())
269 .filter(|r| !r.is_archived)
270 .filter(|r| include_regex_set.is_match(&r.name_with_owner))
271 .filter(|r| !exclude_regex_set.is_match(&r.name_with_owner))
272 .map(|repo| self.parse_repo(&self.path, repo)),
273 );
274
275 if !response_repositories.page_info.has_next_page {
276 break;
277 }
278 after = response_repositories.page_info.end_cursor;
279 }
280
281 Ok(repositories)
282 }
283}