Skip to main content

git_same/provider/github/
client.rs

1//! GitHub API client implementation.
2
3use async_trait::async_trait;
4use reqwest::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, USER_AGENT};
5use reqwest::Client;
6use tracing::{debug, trace};
7
8use super::pagination::fetch_all_pages;
9use super::GITHUB_API_URL;
10use crate::errors::ProviderError;
11use crate::provider::traits::*;
12use crate::types::{Org, OwnedRepo, ProviderKind, Repo};
13
/// Request timeout, in seconds, applied when a provider is created through
/// `GitHubProvider::new` (i.e. without an explicit timeout).
const DEFAULT_TIMEOUT_SECS: u64 = 60;
16
17/// GitHub provider implementation.
18///
19/// Supports both github.com and GitHub Enterprise Server.
20pub struct GitHubProvider {
21    /// HTTP client
22    client: Client,
23    /// Authentication credentials
24    credentials: Credentials,
25    /// Display name for this provider instance
26    display_name: String,
27}
28
29impl GitHubProvider {
30    /// Creates a new GitHub provider with default timeout.
31    pub fn new(
32        credentials: Credentials,
33        display_name: impl Into<String>,
34    ) -> Result<Self, ProviderError> {
35        Self::with_timeout(credentials, display_name, DEFAULT_TIMEOUT_SECS)
36    }
37
38    /// Creates a new GitHub provider with custom timeout.
39    pub fn with_timeout(
40        credentials: Credentials,
41        display_name: impl Into<String>,
42        timeout_secs: u64,
43    ) -> Result<Self, ProviderError> {
44        let mut headers = HeaderMap::new();
45        headers.insert(USER_AGENT, HeaderValue::from_static("gisa-cli/0.1.0"));
46        headers.insert(
47            ACCEPT,
48            HeaderValue::from_static("application/vnd.github+json"),
49        );
50        headers.insert(
51            "X-GitHub-Api-Version",
52            HeaderValue::from_static("2022-11-28"),
53        );
54
55        let client = Client::builder()
56            .default_headers(headers)
57            .timeout(std::time::Duration::from_secs(timeout_secs))
58            .build()
59            .map_err(|e| ProviderError::Configuration(e.to_string()))?;
60
61        Ok(Self {
62            client,
63            credentials,
64            display_name: display_name.into(),
65        })
66    }
67
68    /// Constructs a full API URL from a path.
69    fn api_url(&self, path: &str) -> String {
70        format!("{}{}", self.credentials.api_base_url, path)
71    }
72
73    /// Makes an authenticated GET request.
74    async fn get<T: serde::de::DeserializeOwned>(&self, url: &str) -> Result<T, ProviderError> {
75        trace!(url, "Making authenticated GET request");
76
77        let response = self
78            .client
79            .get(url)
80            .header(AUTHORIZATION, format!("Bearer {}", self.credentials.token))
81            .send()
82            .await
83            .map_err(|e| ProviderError::Network(e.to_string()))?;
84
85        let status = response.status();
86        trace!(url, status = %status, "Received response");
87
88        if !status.is_success() {
89            let body = response.text().await.unwrap_or_default();
90            debug!(url, status = %status, "API request failed");
91            return Err(ProviderError::from_status(status.as_u16(), body));
92        }
93
94        response
95            .json()
96            .await
97            .map_err(|e| ProviderError::Parse(e.to_string()))
98    }
99
100    /// Fetches all pages from an endpoint.
101    async fn get_paginated<T: serde::de::DeserializeOwned>(
102        &self,
103        url: &str,
104    ) -> Result<Vec<T>, ProviderError> {
105        fetch_all_pages(&self.client, &self.credentials.token, url).await
106    }
107
108    /// Determines if this is GitHub.com or GitHub Enterprise.
109    fn is_github_com(&self) -> bool {
110        self.credentials.api_base_url == GITHUB_API_URL
111    }
112}
113
114#[async_trait]
115impl Provider for GitHubProvider {
116    fn kind(&self) -> ProviderKind {
117        if self.is_github_com() {
118            ProviderKind::GitHub
119        } else {
120            ProviderKind::GitHubEnterprise
121        }
122    }
123
124    fn display_name(&self) -> &str {
125        &self.display_name
126    }
127
128    async fn validate_credentials(&self) -> Result<(), ProviderError> {
129        // Make a simple API call to verify the token works
130        self.get_username().await?;
131        Ok(())
132    }
133
134    async fn get_username(&self) -> Result<String, ProviderError> {
135        #[derive(serde::Deserialize)]
136        struct User {
137            login: String,
138        }
139
140        let url = self.api_url("/user");
141        let user: User = self.get(&url).await?;
142        Ok(user.login)
143    }
144
145    async fn get_organizations(&self) -> Result<Vec<Org>, ProviderError> {
146        let url = self.api_url("/user/orgs");
147        self.get_paginated(&url).await
148    }
149
150    async fn get_org_repos(&self, org: &str) -> Result<Vec<Repo>, ProviderError> {
151        let url = self.api_url(&format!("/orgs/{}/repos", org));
152        self.get_paginated(&url).await
153    }
154
155    async fn get_user_repos(&self) -> Result<Vec<Repo>, ProviderError> {
156        let url = self.api_url("/user/repos?affiliation=owner");
157        self.get_paginated(&url).await
158    }
159
160    async fn get_rate_limit(&self) -> Result<RateLimitInfo, ProviderError> {
161        #[derive(serde::Deserialize)]
162        struct RateLimitResponse {
163            rate: RateInfo,
164        }
165
166        #[derive(serde::Deserialize)]
167        struct RateInfo {
168            limit: u32,
169            remaining: u32,
170            reset: i64,
171        }
172
173        let url = self.api_url("/rate_limit");
174        let response: RateLimitResponse = self.get(&url).await?;
175
176        Ok(RateLimitInfo {
177            limit: response.rate.limit,
178            remaining: response.rate.remaining,
179            reset_at: Some(response.rate.reset),
180        })
181    }
182
183    async fn discover_repos(
184        &self,
185        options: &DiscoveryOptions,
186        progress: &dyn DiscoveryProgress,
187    ) -> Result<Vec<OwnedRepo>, ProviderError> {
188        debug!(provider = %self.display_name, "Starting repository discovery");
189
190        let username = self.get_username().await?;
191        debug!(username, "Authenticated user");
192
193        let mut all_repos = Vec::new();
194
195        // Get organizations
196        let orgs = self.get_organizations().await?;
197        let orgs_count = orgs.len();
198        let filtered_orgs: Vec<_> = orgs
199            .into_iter()
200            .filter(|o| options.should_include_org(&o.login))
201            .collect();
202
203        debug!(
204            total_orgs = orgs_count,
205            filtered_orgs = filtered_orgs.len(),
206            "Discovered organizations"
207        );
208        progress.on_orgs_discovered(filtered_orgs.len());
209
210        // Fetch repos for each org
211        for org in &filtered_orgs {
212            progress.on_org_started(&org.login);
213
214            match self.get_org_repos(&org.login).await {
215                Ok(repos) => {
216                    let filtered: Vec<_> = repos
217                        .into_iter()
218                        .filter(|r| options.should_include(r))
219                        .collect();
220
221                    let count = filtered.len();
222                    for repo in filtered {
223                        all_repos.push(OwnedRepo::new(&org.login, repo));
224                    }
225
226                    progress.on_org_complete(&org.login, count);
227                }
228                Err(e) => {
229                    progress.on_error(&format!("Error fetching repos for {}: {}", org.login, e));
230                    progress.on_org_complete(&org.login, 0);
231                }
232            }
233        }
234
235        // Fetch personal repos
236        progress.on_personal_repos_started();
237
238        match self.get_user_repos().await {
239            Ok(repos) => {
240                let filtered: Vec<_> = repos
241                    .into_iter()
242                    // Skip repos already added via org
243                    .filter(|r| !all_repos.iter().any(|or| or.repo.id == r.id))
244                    .filter(|r| options.should_include(r))
245                    .collect();
246
247                let count = filtered.len();
248                for repo in filtered {
249                    all_repos.push(OwnedRepo::new(&username, repo));
250                }
251
252                progress.on_personal_repos_complete(count);
253            }
254            Err(e) => {
255                progress.on_error(&format!("Error fetching personal repos: {}", e));
256                progress.on_personal_repos_complete(0);
257            }
258        }
259
260        Ok(all_repos)
261    }
262
263    fn get_clone_url(&self, repo: &Repo, prefer_ssh: bool) -> String {
264        if prefer_ssh {
265            repo.ssh_url.clone()
266        } else {
267            repo.clone_url.clone()
268        }
269    }
270}
271
272#[cfg(test)]
273#[path = "client_tests.rs"]
274mod tests;