rung_github/
client.rs

1//! GitHub API client.
2
3use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use serde::de::DeserializeOwned;
6
7use crate::auth::Auth;
8use crate::error::{Error, Result};
9use crate::types::{
10    CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, PullRequestState,
11    UpdatePullRequest,
12};
13
14// === Internal API response types (shared across methods) ===
15
16/// Internal representation of a PR from the GitHub API.
17#[derive(serde::Deserialize)]
18struct ApiPullRequest {
19    number: u64,
20    title: String,
21    body: Option<String>,
22    state: String,
23    /// Whether the PR was merged (GitHub returns state="closed" + merged=true for merged PRs).
24    #[serde(default)]
25    merged: bool,
26    draft: bool,
27    html_url: String,
28    head: ApiBranch,
29    base: ApiBranch,
30}
31
32/// Internal representation of a branch ref from the GitHub API.
33#[derive(serde::Deserialize)]
34struct ApiBranch {
35    #[serde(rename = "ref")]
36    ref_name: String,
37}
38
39impl ApiPullRequest {
40    /// Convert API response to domain type, parsing state string.
41    fn into_pull_request(self) -> PullRequest {
42        // GitHub API returns state="closed" + merged=true for merged PRs
43        let state = if self.merged {
44            PullRequestState::Merged
45        } else {
46            match self.state.as_str() {
47                "open" => PullRequestState::Open,
48                _ => PullRequestState::Closed,
49            }
50        };
51
52        PullRequest {
53            number: self.number,
54            title: self.title,
55            body: self.body,
56            state,
57            draft: self.draft,
58            head_branch: self.head.ref_name,
59            base_branch: self.base.ref_name,
60            html_url: self.html_url,
61        }
62    }
63
64    /// Convert API response to domain type with a known state.
65    fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
66        PullRequest {
67            number: self.number,
68            title: self.title,
69            body: self.body,
70            state,
71            draft: self.draft,
72            head_branch: self.head.ref_name,
73            base_branch: self.base.ref_name,
74            html_url: self.html_url,
75        }
76    }
77}
78
79/// GitHub API client.
80pub struct GitHubClient {
81    client: Client,
82    base_url: String,
83    token: String,
84}
85
86impl GitHubClient {
87    /// Default GitHub API URL.
88    pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
89
90    /// Create a new GitHub client.
91    ///
92    /// # Errors
93    /// Returns error if authentication fails.
94    pub fn new(auth: &Auth) -> Result<Self> {
95        Self::with_base_url(auth, Self::DEFAULT_API_URL)
96    }
97
98    /// Create a new GitHub client with a custom API URL (for GitHub Enterprise).
99    ///
100    /// # Errors
101    /// Returns error if authentication fails.
102    pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
103        let token = auth.resolve()?;
104
105        let mut headers = HeaderMap::new();
106        headers.insert(
107            ACCEPT,
108            HeaderValue::from_static("application/vnd.github+json"),
109        );
110        headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
111        headers.insert(
112            "X-GitHub-Api-Version",
113            HeaderValue::from_static("2022-11-28"),
114        );
115
116        let client = Client::builder().default_headers(headers).build()?;
117
118        Ok(Self {
119            client,
120            base_url: base_url.into(),
121            token,
122        })
123    }
124
125    /// Make a GET request.
126    async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
127        let url = format!("{}{}", self.base_url, path);
128        let response = self
129            .client
130            .get(&url)
131            .header(AUTHORIZATION, format!("Bearer {}", self.token))
132            .send()
133            .await?;
134
135        self.handle_response(response).await
136    }
137
138    /// Make a POST request.
139    async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
140        &self,
141        path: &str,
142        body: &B,
143    ) -> Result<T> {
144        let url = format!("{}{}", self.base_url, path);
145        let response = self
146            .client
147            .post(&url)
148            .header(AUTHORIZATION, format!("Bearer {}", self.token))
149            .json(body)
150            .send()
151            .await?;
152
153        self.handle_response(response).await
154    }
155
156    /// Make a PATCH request.
157    async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
158        &self,
159        path: &str,
160        body: &B,
161    ) -> Result<T> {
162        let url = format!("{}{}", self.base_url, path);
163        let response = self
164            .client
165            .patch(&url)
166            .header(AUTHORIZATION, format!("Bearer {}", self.token))
167            .json(body)
168            .send()
169            .await?;
170
171        self.handle_response(response).await
172    }
173
174    /// Make a PUT request.
175    async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
176        &self,
177        path: &str,
178        body: &B,
179    ) -> Result<T> {
180        let url = format!("{}{}", self.base_url, path);
181        let response = self
182            .client
183            .put(&url)
184            .header(AUTHORIZATION, format!("Bearer {}", self.token))
185            .json(body)
186            .send()
187            .await?;
188
189        self.handle_response(response).await
190    }
191
192    /// Make a DELETE request.
193    async fn delete(&self, path: &str) -> Result<()> {
194        let url = format!("{}{}", self.base_url, path);
195        let response = self
196            .client
197            .delete(&url)
198            .header(AUTHORIZATION, format!("Bearer {}", self.token))
199            .send()
200            .await?;
201
202        let status = response.status();
203        if status.is_success() || status.as_u16() == 204 {
204            return Ok(());
205        }
206
207        let status_code = status.as_u16();
208        match status_code {
209            401 => Err(Error::AuthenticationFailed),
210            403 if response
211                .headers()
212                .get("x-ratelimit-remaining")
213                .is_some_and(|v| v == "0") =>
214            {
215                Err(Error::RateLimited)
216            }
217            _ => {
218                let text = response.text().await.unwrap_or_default();
219                Err(Error::ApiError {
220                    status: status_code,
221                    message: text,
222                })
223            }
224        }
225    }
226
227    /// Handle API response.
228    async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
229        let status = response.status();
230
231        if status.is_success() {
232            let body = response.json().await?;
233            return Ok(body);
234        }
235
236        // Handle error responses
237        let status_code = status.as_u16();
238
239        match status_code {
240            401 => Err(Error::AuthenticationFailed),
241            403 if response
242                .headers()
243                .get("x-ratelimit-remaining")
244                .is_some_and(|v| v == "0") =>
245            {
246                Err(Error::RateLimited)
247            }
248            _ => {
249                let text = response.text().await.unwrap_or_default();
250                Err(Error::ApiError {
251                    status: status_code,
252                    message: text,
253                })
254            }
255        }
256    }
257
258    // === PR Operations ===
259
260    /// Get a pull request by number.
261    ///
262    /// # Errors
263    /// Returns error if PR not found or API call fails.
264    pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
265        let api_pr: ApiPullRequest = self
266            .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
267            .await?;
268
269        Ok(api_pr.into_pull_request())
270    }
271
272    /// Find a PR for a branch.
273    ///
274    /// # Errors
275    /// Returns error if API call fails.
276    pub async fn find_pr_for_branch(
277        &self,
278        owner: &str,
279        repo: &str,
280        branch: &str,
281    ) -> Result<Option<PullRequest>> {
282        // We only query open PRs, so state is always Open
283        let prs: Vec<ApiPullRequest> = self
284            .get(&format!(
285                "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
286            ))
287            .await?;
288
289        Ok(prs
290            .into_iter()
291            .next()
292            .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
293    }
294
295    /// Create a pull request.
296    ///
297    /// # Errors
298    /// Returns error if PR creation fails.
299    pub async fn create_pr(
300        &self,
301        owner: &str,
302        repo: &str,
303        pr: CreatePullRequest,
304    ) -> Result<PullRequest> {
305        // Newly created PRs are always open
306        let api_pr: ApiPullRequest = self
307            .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
308            .await?;
309
310        Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
311    }
312
313    /// Update a pull request.
314    ///
315    /// # Errors
316    /// Returns error if PR update fails.
317    pub async fn update_pr(
318        &self,
319        owner: &str,
320        repo: &str,
321        number: u64,
322        update: UpdatePullRequest,
323    ) -> Result<PullRequest> {
324        let api_pr: ApiPullRequest = self
325            .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
326            .await?;
327
328        Ok(api_pr.into_pull_request())
329    }
330
331    // === Check Runs ===
332
333    /// Get check runs for a commit.
334    ///
335    /// # Errors
336    /// Returns error if API call fails.
337    pub async fn get_check_runs(
338        &self,
339        owner: &str,
340        repo: &str,
341        commit_sha: &str,
342    ) -> Result<Vec<CheckRun>> {
343        #[derive(serde::Deserialize)]
344        struct Response {
345            check_runs: Vec<ApiCheckRun>,
346        }
347
348        #[derive(serde::Deserialize)]
349        struct ApiCheckRun {
350            name: String,
351            status: String,
352            conclusion: Option<String>,
353            details_url: Option<String>,
354        }
355
356        let response: Response = self
357            .get(&format!(
358                "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
359            ))
360            .await?;
361
362        Ok(response
363            .check_runs
364            .into_iter()
365            .map(|cr| CheckRun {
366                name: cr.name,
367                status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
368                    ("queued", _) => crate::types::CheckStatus::Queued,
369                    ("in_progress", _) => crate::types::CheckStatus::InProgress,
370                    ("completed", Some("success")) => crate::types::CheckStatus::Success,
371                    ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
372                    ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
373                    // Any other status (failure, timed_out, action_required, etc.) treated as failure
374                    _ => crate::types::CheckStatus::Failure,
375                },
376                details_url: cr.details_url,
377            })
378            .collect())
379    }
380
381    // === Merge Operations ===
382
383    /// Merge a pull request.
384    ///
385    /// # Errors
386    /// Returns error if merge fails.
387    pub async fn merge_pr(
388        &self,
389        owner: &str,
390        repo: &str,
391        number: u64,
392        merge: MergePullRequest,
393    ) -> Result<MergeResult> {
394        self.put(
395            &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
396            &merge,
397        )
398        .await
399    }
400
401    // === Ref Operations ===
402
403    /// Delete a git reference (branch).
404    ///
405    /// # Errors
406    /// Returns error if deletion fails.
407    pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
408        self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
409            .await
410    }
411
412    // === Comment Operations ===
413
414    /// List comments on a pull request.
415    ///
416    /// # Errors
417    /// Returns error if request fails.
418    pub async fn list_pr_comments(
419        &self,
420        owner: &str,
421        repo: &str,
422        pr_number: u64,
423    ) -> Result<Vec<crate::types::IssueComment>> {
424        self.get(&format!(
425            "/repos/{owner}/{repo}/issues/{pr_number}/comments"
426        ))
427        .await
428    }
429
430    /// Create a comment on a pull request.
431    ///
432    /// # Errors
433    /// Returns error if request fails.
434    pub async fn create_pr_comment(
435        &self,
436        owner: &str,
437        repo: &str,
438        pr_number: u64,
439        comment: crate::types::CreateComment,
440    ) -> Result<crate::types::IssueComment> {
441        self.post(
442            &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
443            &comment,
444        )
445        .await
446    }
447
448    /// Update a comment on a pull request.
449    ///
450    /// # Errors
451    /// Returns error if request fails.
452    pub async fn update_pr_comment(
453        &self,
454        owner: &str,
455        repo: &str,
456        comment_id: u64,
457        comment: crate::types::UpdateComment,
458    ) -> Result<crate::types::IssueComment> {
459        self.patch(
460            &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
461            &comment,
462        )
463        .await
464    }
465}
466
467impl std::fmt::Debug for GitHubClient {
468    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469        f.debug_struct("GitHubClient")
470            .field("base_url", &self.base_url)
471            .field("token", &"[redacted]")
472            .finish_non_exhaustive()
473    }
474}