git_same/provider/github/
client.rs1use async_trait::async_trait;
4use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, USER_AGENT};
5use reqwest::Client;
6use tracing::{debug, trace};
7
8use super::pagination::fetch_all_pages;
9use super::GITHUB_API_URL;
10use crate::errors::ProviderError;
11use crate::provider::traits::*;
12use crate::types::{Org, OwnedRepo, ProviderKind, Repo};
13
/// Request timeout, in seconds, that `GitHubProvider::new` applies to its HTTP client
/// (sixty seconds).
const DEFAULT_TIMEOUT_SECS: u64 = 60;
17pub struct GitHubProvider {
21 client: Client,
23 credentials: Credentials,
25 display_name: String,
27}
28
29impl GitHubProvider {
30 pub fn new(
32 credentials: Credentials,
33 display_name: impl Into<String>,
34 ) -> Result<Self, ProviderError> {
35 Self::with_timeout(credentials, display_name, DEFAULT_TIMEOUT_SECS)
36 }
37
38 pub fn with_timeout(
40 credentials: Credentials,
41 display_name: impl Into<String>,
42 timeout_secs: u64,
43 ) -> Result<Self, ProviderError> {
44 let mut headers = HeaderMap::new();
45 headers.insert(USER_AGENT, HeaderValue::from_static("gisa-cli/0.1.0"));
46 headers.insert(
47 ACCEPT,
48 HeaderValue::from_static("application/vnd.github+json"),
49 );
50 headers.insert(
51 "X-GitHub-Api-Version",
52 HeaderValue::from_static("2022-11-28"),
53 );
54
55 let client = Client::builder()
56 .default_headers(headers)
57 .timeout(std::time::Duration::from_secs(timeout_secs))
58 .build()
59 .map_err(|e| ProviderError::Configuration(e.to_string()))?;
60
61 Ok(Self {
62 client,
63 credentials,
64 display_name: display_name.into(),
65 })
66 }
67
68 fn api_url(&self, path: &str) -> String {
70 format!("{}{}", self.credentials.api_base_url, path)
71 }
72
73 async fn get<T: serde::de::DeserializeOwned>(&self, url: &str) -> Result<T, ProviderError> {
75 trace!(url, "Making authenticated GET request");
76
77 let response = self
78 .client
79 .get(url)
80 .header(AUTHORIZATION, format!("Bearer {}", self.credentials.token))
81 .send()
82 .await
83 .map_err(|e| ProviderError::Network(e.to_string()))?;
84
85 let status = response.status();
86 trace!(url, status = %status, "Received response");
87
88 if !status.is_success() {
89 let body = response.text().await.unwrap_or_default();
90 debug!(url, status = %status, "API request failed");
91 return Err(ProviderError::from_status(status.as_u16(), body));
92 }
93
94 response
95 .json()
96 .await
97 .map_err(|e| ProviderError::Parse(e.to_string()))
98 }
99
100 async fn get_paginated<T: serde::de::DeserializeOwned>(
102 &self,
103 url: &str,
104 ) -> Result<Vec<T>, ProviderError> {
105 fetch_all_pages(&self.client, &self.credentials.token, url).await
106 }
107
108 fn is_github_com(&self) -> bool {
110 self.credentials.api_base_url == GITHUB_API_URL
111 }
112}
113
114#[async_trait]
115impl Provider for GitHubProvider {
116 fn kind(&self) -> ProviderKind {
117 if self.is_github_com() {
118 ProviderKind::GitHub
119 } else {
120 ProviderKind::GitHubEnterprise
121 }
122 }
123
124 fn display_name(&self) -> &str {
125 &self.display_name
126 }
127
128 async fn validate_credentials(&self) -> Result<(), ProviderError> {
129 self.get_username().await?;
131 Ok(())
132 }
133
134 async fn get_username(&self) -> Result<String, ProviderError> {
135 #[derive(serde::Deserialize)]
136 struct User {
137 login: String,
138 }
139
140 let url = self.api_url("/user");
141 let user: User = self.get(&url).await?;
142 Ok(user.login)
143 }
144
145 async fn get_organizations(&self) -> Result<Vec<Org>, ProviderError> {
146 let url = self.api_url("/user/orgs");
147 self.get_paginated(&url).await
148 }
149
150 async fn get_org_repos(&self, org: &str) -> Result<Vec<Repo>, ProviderError> {
151 let url = self.api_url(&format!("/orgs/{}/repos", org));
152 self.get_paginated(&url).await
153 }
154
155 async fn get_user_repos(&self) -> Result<Vec<Repo>, ProviderError> {
156 let url = self.api_url("/user/repos?affiliation=owner");
157 self.get_paginated(&url).await
158 }
159
160 async fn get_rate_limit(&self) -> Result<RateLimitInfo, ProviderError> {
161 #[derive(serde::Deserialize)]
162 struct RateLimitResponse {
163 rate: RateInfo,
164 }
165
166 #[derive(serde::Deserialize)]
167 struct RateInfo {
168 limit: u32,
169 remaining: u32,
170 reset: i64,
171 }
172
173 let url = self.api_url("/rate_limit");
174 let response: RateLimitResponse = self.get(&url).await?;
175
176 Ok(RateLimitInfo {
177 limit: response.rate.limit,
178 remaining: response.rate.remaining,
179 reset_at: Some(response.rate.reset),
180 })
181 }
182
183 async fn discover_repos(
184 &self,
185 options: &DiscoveryOptions,
186 progress: &dyn DiscoveryProgress,
187 ) -> Result<Vec<OwnedRepo>, ProviderError> {
188 debug!(provider = %self.display_name, "Starting repository discovery");
189
190 let username = self.get_username().await?;
191 debug!(username, "Authenticated user");
192
193 let mut all_repos = Vec::new();
194
195 let orgs = self.get_organizations().await?;
197 let orgs_count = orgs.len();
198 let filtered_orgs: Vec<_> = orgs
199 .into_iter()
200 .filter(|o| options.should_include_org(&o.login))
201 .collect();
202
203 debug!(
204 total_orgs = orgs_count,
205 filtered_orgs = filtered_orgs.len(),
206 "Discovered organizations"
207 );
208 progress.on_orgs_discovered(filtered_orgs.len());
209
210 for org in &filtered_orgs {
212 progress.on_org_started(&org.login);
213
214 match self.get_org_repos(&org.login).await {
215 Ok(repos) => {
216 let filtered: Vec<_> = repos
217 .into_iter()
218 .filter(|r| options.should_include(r))
219 .collect();
220
221 let count = filtered.len();
222 for repo in filtered {
223 all_repos.push(OwnedRepo::new(&org.login, repo));
224 }
225
226 progress.on_org_complete(&org.login, count);
227 }
228 Err(e) => {
229 progress.on_error(&format!("Error fetching repos for {}: {}", org.login, e));
230 progress.on_org_complete(&org.login, 0);
231 }
232 }
233 }
234
235 progress.on_personal_repos_started();
237
238 match self.get_user_repos().await {
239 Ok(repos) => {
240 let filtered: Vec<_> = repos
241 .into_iter()
242 .filter(|r| !all_repos.iter().any(|or| or.repo.id == r.id))
244 .filter(|r| options.should_include(r))
245 .collect();
246
247 let count = filtered.len();
248 for repo in filtered {
249 all_repos.push(OwnedRepo::new(&username, repo));
250 }
251
252 progress.on_personal_repos_complete(count);
253 }
254 Err(e) => {
255 progress.on_error(&format!("Error fetching personal repos: {}", e));
256 progress.on_personal_repos_complete(0);
257 }
258 }
259
260 Ok(all_repos)
261 }
262
263 fn get_clone_url(&self, repo: &Repo, prefer_ssh: bool) -> String {
264 if prefer_ssh {
265 repo.ssh_url.clone()
266 } else {
267 repo.clone_url.clone()
268 }
269 }
270}
271
// Unit tests for this module are loaded from `client_tests.rs`, in test builds only.
#[path = "client_tests.rs"]
#[cfg(test)]
mod tests;