rung_github/
client.rs

1//! GitHub API client.
2
3use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use serde::de::DeserializeOwned;
6
7use crate::auth::Auth;
8use crate::error::{Error, Result};
9use crate::types::{
10    CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, PullRequestState,
11    UpdatePullRequest,
12};
13
14// === Internal API response types (shared across methods) ===
15
16/// Internal representation of a PR from the GitHub API.
17#[derive(serde::Deserialize)]
18struct ApiPullRequest {
19    number: u64,
20    title: String,
21    body: Option<String>,
22    state: String,
23    draft: bool,
24    html_url: String,
25    head: ApiBranch,
26    base: ApiBranch,
27}
28
29/// Internal representation of a branch ref from the GitHub API.
30#[derive(serde::Deserialize)]
31struct ApiBranch {
32    #[serde(rename = "ref")]
33    ref_name: String,
34}
35
36impl ApiPullRequest {
37    /// Convert API response to domain type, parsing state string.
38    fn into_pull_request(self) -> PullRequest {
39        PullRequest {
40            number: self.number,
41            title: self.title,
42            body: self.body,
43            state: match self.state.as_str() {
44                "open" => PullRequestState::Open,
45                "merged" => PullRequestState::Merged,
46                _ => PullRequestState::Closed,
47            },
48            draft: self.draft,
49            head_branch: self.head.ref_name,
50            base_branch: self.base.ref_name,
51            html_url: self.html_url,
52        }
53    }
54
55    /// Convert API response to domain type with a known state.
56    fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
57        PullRequest {
58            number: self.number,
59            title: self.title,
60            body: self.body,
61            state,
62            draft: self.draft,
63            head_branch: self.head.ref_name,
64            base_branch: self.base.ref_name,
65            html_url: self.html_url,
66        }
67    }
68}
69
70/// GitHub API client.
71pub struct GitHubClient {
72    client: Client,
73    base_url: String,
74    token: String,
75}
76
77impl GitHubClient {
78    /// Default GitHub API URL.
79    pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
80
81    /// Create a new GitHub client.
82    ///
83    /// # Errors
84    /// Returns error if authentication fails.
85    pub fn new(auth: &Auth) -> Result<Self> {
86        Self::with_base_url(auth, Self::DEFAULT_API_URL)
87    }
88
89    /// Create a new GitHub client with a custom API URL (for GitHub Enterprise).
90    ///
91    /// # Errors
92    /// Returns error if authentication fails.
93    pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
94        let token = auth.resolve()?;
95
96        let mut headers = HeaderMap::new();
97        headers.insert(
98            ACCEPT,
99            HeaderValue::from_static("application/vnd.github+json"),
100        );
101        headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
102        headers.insert(
103            "X-GitHub-Api-Version",
104            HeaderValue::from_static("2022-11-28"),
105        );
106
107        let client = Client::builder().default_headers(headers).build()?;
108
109        Ok(Self {
110            client,
111            base_url: base_url.into(),
112            token,
113        })
114    }
115
116    /// Make a GET request.
117    async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
118        let url = format!("{}{}", self.base_url, path);
119        let response = self
120            .client
121            .get(&url)
122            .header(AUTHORIZATION, format!("Bearer {}", self.token))
123            .send()
124            .await?;
125
126        self.handle_response(response).await
127    }
128
129    /// Make a POST request.
130    async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
131        &self,
132        path: &str,
133        body: &B,
134    ) -> Result<T> {
135        let url = format!("{}{}", self.base_url, path);
136        let response = self
137            .client
138            .post(&url)
139            .header(AUTHORIZATION, format!("Bearer {}", self.token))
140            .json(body)
141            .send()
142            .await?;
143
144        self.handle_response(response).await
145    }
146
147    /// Make a PATCH request.
148    async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
149        &self,
150        path: &str,
151        body: &B,
152    ) -> Result<T> {
153        let url = format!("{}{}", self.base_url, path);
154        let response = self
155            .client
156            .patch(&url)
157            .header(AUTHORIZATION, format!("Bearer {}", self.token))
158            .json(body)
159            .send()
160            .await?;
161
162        self.handle_response(response).await
163    }
164
165    /// Make a PUT request.
166    async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
167        &self,
168        path: &str,
169        body: &B,
170    ) -> Result<T> {
171        let url = format!("{}{}", self.base_url, path);
172        let response = self
173            .client
174            .put(&url)
175            .header(AUTHORIZATION, format!("Bearer {}", self.token))
176            .json(body)
177            .send()
178            .await?;
179
180        self.handle_response(response).await
181    }
182
183    /// Make a DELETE request.
184    async fn delete(&self, path: &str) -> Result<()> {
185        let url = format!("{}{}", self.base_url, path);
186        let response = self
187            .client
188            .delete(&url)
189            .header(AUTHORIZATION, format!("Bearer {}", self.token))
190            .send()
191            .await?;
192
193        let status = response.status();
194        if status.is_success() || status.as_u16() == 204 {
195            return Ok(());
196        }
197
198        let status_code = status.as_u16();
199        match status_code {
200            401 => Err(Error::AuthenticationFailed),
201            403 if response
202                .headers()
203                .get("x-ratelimit-remaining")
204                .is_some_and(|v| v == "0") =>
205            {
206                Err(Error::RateLimited)
207            }
208            _ => {
209                let text = response.text().await.unwrap_or_default();
210                Err(Error::ApiError {
211                    status: status_code,
212                    message: text,
213                })
214            }
215        }
216    }
217
218    /// Handle API response.
219    async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
220        let status = response.status();
221
222        if status.is_success() {
223            let body = response.json().await?;
224            return Ok(body);
225        }
226
227        // Handle error responses
228        let status_code = status.as_u16();
229
230        match status_code {
231            401 => Err(Error::AuthenticationFailed),
232            403 if response
233                .headers()
234                .get("x-ratelimit-remaining")
235                .is_some_and(|v| v == "0") =>
236            {
237                Err(Error::RateLimited)
238            }
239            _ => {
240                let text = response.text().await.unwrap_or_default();
241                Err(Error::ApiError {
242                    status: status_code,
243                    message: text,
244                })
245            }
246        }
247    }
248
249    // === PR Operations ===
250
251    /// Get a pull request by number.
252    ///
253    /// # Errors
254    /// Returns error if PR not found or API call fails.
255    pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
256        let api_pr: ApiPullRequest = self
257            .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
258            .await?;
259
260        Ok(api_pr.into_pull_request())
261    }
262
263    /// Find a PR for a branch.
264    ///
265    /// # Errors
266    /// Returns error if API call fails.
267    pub async fn find_pr_for_branch(
268        &self,
269        owner: &str,
270        repo: &str,
271        branch: &str,
272    ) -> Result<Option<PullRequest>> {
273        // We only query open PRs, so state is always Open
274        let prs: Vec<ApiPullRequest> = self
275            .get(&format!(
276                "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
277            ))
278            .await?;
279
280        Ok(prs
281            .into_iter()
282            .next()
283            .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
284    }
285
286    /// Create a pull request.
287    ///
288    /// # Errors
289    /// Returns error if PR creation fails.
290    pub async fn create_pr(
291        &self,
292        owner: &str,
293        repo: &str,
294        pr: CreatePullRequest,
295    ) -> Result<PullRequest> {
296        // Newly created PRs are always open
297        let api_pr: ApiPullRequest = self
298            .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
299            .await?;
300
301        Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
302    }
303
304    /// Update a pull request.
305    ///
306    /// # Errors
307    /// Returns error if PR update fails.
308    pub async fn update_pr(
309        &self,
310        owner: &str,
311        repo: &str,
312        number: u64,
313        update: UpdatePullRequest,
314    ) -> Result<PullRequest> {
315        let api_pr: ApiPullRequest = self
316            .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
317            .await?;
318
319        Ok(api_pr.into_pull_request())
320    }
321
322    // === Check Runs ===
323
324    /// Get check runs for a commit.
325    ///
326    /// # Errors
327    /// Returns error if API call fails.
328    pub async fn get_check_runs(
329        &self,
330        owner: &str,
331        repo: &str,
332        commit_sha: &str,
333    ) -> Result<Vec<CheckRun>> {
334        #[derive(serde::Deserialize)]
335        struct Response {
336            check_runs: Vec<ApiCheckRun>,
337        }
338
339        #[derive(serde::Deserialize)]
340        struct ApiCheckRun {
341            name: String,
342            status: String,
343            conclusion: Option<String>,
344            details_url: Option<String>,
345        }
346
347        let response: Response = self
348            .get(&format!(
349                "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
350            ))
351            .await?;
352
353        Ok(response
354            .check_runs
355            .into_iter()
356            .map(|cr| CheckRun {
357                name: cr.name,
358                status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
359                    ("queued", _) => crate::types::CheckStatus::Queued,
360                    ("in_progress", _) => crate::types::CheckStatus::InProgress,
361                    ("completed", Some("success")) => crate::types::CheckStatus::Success,
362                    ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
363                    ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
364                    // Any other status (failure, timed_out, action_required, etc.) treated as failure
365                    _ => crate::types::CheckStatus::Failure,
366                },
367                details_url: cr.details_url,
368            })
369            .collect())
370    }
371
372    // === Merge Operations ===
373
374    /// Merge a pull request.
375    ///
376    /// # Errors
377    /// Returns error if merge fails.
378    pub async fn merge_pr(
379        &self,
380        owner: &str,
381        repo: &str,
382        number: u64,
383        merge: MergePullRequest,
384    ) -> Result<MergeResult> {
385        self.put(
386            &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
387            &merge,
388        )
389        .await
390    }
391
392    // === Ref Operations ===
393
394    /// Delete a git reference (branch).
395    ///
396    /// # Errors
397    /// Returns error if deletion fails.
398    pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
399        self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
400            .await
401    }
402
403    // === Comment Operations ===
404
405    /// List comments on a pull request.
406    ///
407    /// # Errors
408    /// Returns error if request fails.
409    pub async fn list_pr_comments(
410        &self,
411        owner: &str,
412        repo: &str,
413        pr_number: u64,
414    ) -> Result<Vec<crate::types::IssueComment>> {
415        self.get(&format!(
416            "/repos/{owner}/{repo}/issues/{pr_number}/comments"
417        ))
418        .await
419    }
420
421    /// Create a comment on a pull request.
422    ///
423    /// # Errors
424    /// Returns error if request fails.
425    pub async fn create_pr_comment(
426        &self,
427        owner: &str,
428        repo: &str,
429        pr_number: u64,
430        comment: crate::types::CreateComment,
431    ) -> Result<crate::types::IssueComment> {
432        self.post(
433            &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
434            &comment,
435        )
436        .await
437    }
438
439    /// Update a comment on a pull request.
440    ///
441    /// # Errors
442    /// Returns error if request fails.
443    pub async fn update_pr_comment(
444        &self,
445        owner: &str,
446        repo: &str,
447        comment_id: u64,
448        comment: crate::types::UpdateComment,
449    ) -> Result<crate::types::IssueComment> {
450        self.patch(
451            &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
452            &comment,
453        )
454        .await
455    }
456}
457
458impl std::fmt::Debug for GitHubClient {
459    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460        f.debug_struct("GitHubClient")
461            .field("base_url", &self.base_url)
462            .field("token", &"[redacted]")
463            .finish_non_exhaustive()
464    }
465}