Skip to main content

rung_github/
client.rs

1//! GitHub API client.
2
3use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use serde::de::DeserializeOwned;
6
7use crate::auth::Auth;
8use crate::error::{Error, Result};
9use crate::types::{
10    CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, UpdatePullRequest,
11};
12
13/// GitHub API client.
14pub struct GitHubClient {
15    client: Client,
16    base_url: String,
17    token: String,
18}
19
20impl GitHubClient {
21    /// Default GitHub API URL.
22    pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
23
24    /// Create a new GitHub client.
25    ///
26    /// # Errors
27    /// Returns error if authentication fails.
28    pub fn new(auth: &Auth) -> Result<Self> {
29        Self::with_base_url(auth, Self::DEFAULT_API_URL)
30    }
31
32    /// Create a new GitHub client with a custom API URL (for GitHub Enterprise).
33    ///
34    /// # Errors
35    /// Returns error if authentication fails.
36    pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
37        let token = auth.resolve()?;
38
39        let mut headers = HeaderMap::new();
40        headers.insert(
41            ACCEPT,
42            HeaderValue::from_static("application/vnd.github+json"),
43        );
44        headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
45        headers.insert(
46            "X-GitHub-Api-Version",
47            HeaderValue::from_static("2022-11-28"),
48        );
49
50        let client = Client::builder().default_headers(headers).build()?;
51
52        Ok(Self {
53            client,
54            base_url: base_url.into(),
55            token,
56        })
57    }
58
59    /// Make a GET request.
60    async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
61        let url = format!("{}{}", self.base_url, path);
62        let response = self
63            .client
64            .get(&url)
65            .header(AUTHORIZATION, format!("Bearer {}", self.token))
66            .send()
67            .await?;
68
69        self.handle_response(response).await
70    }
71
72    /// Make a POST request.
73    async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
74        &self,
75        path: &str,
76        body: &B,
77    ) -> Result<T> {
78        let url = format!("{}{}", self.base_url, path);
79        let response = self
80            .client
81            .post(&url)
82            .header(AUTHORIZATION, format!("Bearer {}", self.token))
83            .json(body)
84            .send()
85            .await?;
86
87        self.handle_response(response).await
88    }
89
90    /// Make a PATCH request.
91    async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
92        &self,
93        path: &str,
94        body: &B,
95    ) -> Result<T> {
96        let url = format!("{}{}", self.base_url, path);
97        let response = self
98            .client
99            .patch(&url)
100            .header(AUTHORIZATION, format!("Bearer {}", self.token))
101            .json(body)
102            .send()
103            .await?;
104
105        self.handle_response(response).await
106    }
107
108    /// Make a PUT request.
109    async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
110        &self,
111        path: &str,
112        body: &B,
113    ) -> Result<T> {
114        let url = format!("{}{}", self.base_url, path);
115        let response = self
116            .client
117            .put(&url)
118            .header(AUTHORIZATION, format!("Bearer {}", self.token))
119            .json(body)
120            .send()
121            .await?;
122
123        self.handle_response(response).await
124    }
125
126    /// Make a DELETE request.
127    async fn delete(&self, path: &str) -> Result<()> {
128        let url = format!("{}{}", self.base_url, path);
129        let response = self
130            .client
131            .delete(&url)
132            .header(AUTHORIZATION, format!("Bearer {}", self.token))
133            .send()
134            .await?;
135
136        let status = response.status();
137        if status.is_success() || status.as_u16() == 204 {
138            return Ok(());
139        }
140
141        let status_code = status.as_u16();
142        match status_code {
143            401 => Err(Error::AuthenticationFailed),
144            403 if response
145                .headers()
146                .get("x-ratelimit-remaining")
147                .is_some_and(|v| v == "0") =>
148            {
149                Err(Error::RateLimited)
150            }
151            _ => {
152                let text = response.text().await.unwrap_or_default();
153                Err(Error::ApiError {
154                    status: status_code,
155                    message: text,
156                })
157            }
158        }
159    }
160
161    /// Handle API response.
162    async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
163        let status = response.status();
164
165        if status.is_success() {
166            let body = response.json().await?;
167            return Ok(body);
168        }
169
170        // Handle error responses
171        let status_code = status.as_u16();
172
173        match status_code {
174            401 => Err(Error::AuthenticationFailed),
175            403 if response
176                .headers()
177                .get("x-ratelimit-remaining")
178                .is_some_and(|v| v == "0") =>
179            {
180                Err(Error::RateLimited)
181            }
182            _ => {
183                let text = response.text().await.unwrap_or_default();
184                Err(Error::ApiError {
185                    status: status_code,
186                    message: text,
187                })
188            }
189        }
190    }
191
192    // === PR Operations ===
193
194    /// Get a pull request by number.
195    ///
196    /// # Errors
197    /// Returns error if PR not found or API call fails.
198    pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
199        #[derive(serde::Deserialize)]
200        struct ApiPr {
201            number: u64,
202            title: String,
203            body: Option<String>,
204            state: String,
205            draft: bool,
206            html_url: String,
207            head: Branch,
208            base: Branch,
209        }
210
211        #[derive(serde::Deserialize)]
212        struct Branch {
213            #[serde(rename = "ref")]
214            ref_name: String,
215        }
216
217        let api_pr: ApiPr = self
218            .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
219            .await?;
220
221        Ok(PullRequest {
222            number: api_pr.number,
223            title: api_pr.title,
224            body: api_pr.body,
225            state: match api_pr.state.as_str() {
226                "open" => crate::types::PullRequestState::Open,
227                "merged" => crate::types::PullRequestState::Merged,
228                _ => crate::types::PullRequestState::Closed,
229            },
230            draft: api_pr.draft,
231            head_branch: api_pr.head.ref_name,
232            base_branch: api_pr.base.ref_name,
233            html_url: api_pr.html_url,
234        })
235    }
236
237    /// Find a PR for a branch.
238    ///
239    /// # Errors
240    /// Returns error if API call fails.
241    pub async fn find_pr_for_branch(
242        &self,
243        owner: &str,
244        repo: &str,
245        branch: &str,
246    ) -> Result<Option<PullRequest>> {
247        #[derive(serde::Deserialize)]
248        struct ApiPr {
249            number: u64,
250            title: String,
251            body: Option<String>,
252            #[allow(dead_code)]
253            state: String,
254            draft: bool,
255            html_url: String,
256            head: Branch,
257            base: Branch,
258        }
259
260        #[derive(serde::Deserialize)]
261        struct Branch {
262            #[serde(rename = "ref")]
263            ref_name: String,
264        }
265
266        // We only query open PRs, so state is always Open
267        let prs: Vec<ApiPr> = self
268            .get(&format!(
269                "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
270            ))
271            .await?;
272
273        Ok(prs.into_iter().next().map(|api_pr| PullRequest {
274            number: api_pr.number,
275            title: api_pr.title,
276            body: api_pr.body,
277            state: crate::types::PullRequestState::Open,
278            draft: api_pr.draft,
279            head_branch: api_pr.head.ref_name,
280            base_branch: api_pr.base.ref_name,
281            html_url: api_pr.html_url,
282        }))
283    }
284
285    /// Create a pull request.
286    ///
287    /// # Errors
288    /// Returns error if PR creation fails.
289    pub async fn create_pr(
290        &self,
291        owner: &str,
292        repo: &str,
293        pr: CreatePullRequest,
294    ) -> Result<PullRequest> {
295        #[derive(serde::Deserialize)]
296        struct ApiPr {
297            number: u64,
298            title: String,
299            body: Option<String>,
300            #[allow(dead_code)]
301            state: String,
302            draft: bool,
303            html_url: String,
304            head: Branch,
305            base: Branch,
306        }
307
308        #[derive(serde::Deserialize)]
309        struct Branch {
310            #[serde(rename = "ref")]
311            ref_name: String,
312        }
313
314        // Newly created PRs are always open
315        let api_pr: ApiPr = self
316            .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
317            .await?;
318
319        Ok(PullRequest {
320            number: api_pr.number,
321            title: api_pr.title,
322            body: api_pr.body,
323            state: crate::types::PullRequestState::Open,
324            draft: api_pr.draft,
325            head_branch: api_pr.head.ref_name,
326            base_branch: api_pr.base.ref_name,
327            html_url: api_pr.html_url,
328        })
329    }
330
331    /// Update a pull request.
332    ///
333    /// # Errors
334    /// Returns error if PR update fails.
335    pub async fn update_pr(
336        &self,
337        owner: &str,
338        repo: &str,
339        number: u64,
340        update: UpdatePullRequest,
341    ) -> Result<PullRequest> {
342        #[derive(serde::Deserialize)]
343        struct ApiPr {
344            number: u64,
345            title: String,
346            body: Option<String>,
347            state: String,
348            draft: bool,
349            html_url: String,
350            head: Branch,
351            base: Branch,
352        }
353
354        #[derive(serde::Deserialize)]
355        struct Branch {
356            #[serde(rename = "ref")]
357            ref_name: String,
358        }
359
360        let api_pr: ApiPr = self
361            .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
362            .await?;
363
364        Ok(PullRequest {
365            number: api_pr.number,
366            title: api_pr.title,
367            body: api_pr.body,
368            state: match api_pr.state.as_str() {
369                "open" => crate::types::PullRequestState::Open,
370                "merged" => crate::types::PullRequestState::Merged,
371                _ => crate::types::PullRequestState::Closed,
372            },
373            draft: api_pr.draft,
374            head_branch: api_pr.head.ref_name,
375            base_branch: api_pr.base.ref_name,
376            html_url: api_pr.html_url,
377        })
378    }
379
380    // === Check Runs ===
381
382    /// Get check runs for a commit.
383    ///
384    /// # Errors
385    /// Returns error if API call fails.
386    pub async fn get_check_runs(
387        &self,
388        owner: &str,
389        repo: &str,
390        commit_sha: &str,
391    ) -> Result<Vec<CheckRun>> {
392        #[derive(serde::Deserialize)]
393        struct Response {
394            check_runs: Vec<ApiCheckRun>,
395        }
396
397        #[derive(serde::Deserialize)]
398        struct ApiCheckRun {
399            name: String,
400            status: String,
401            conclusion: Option<String>,
402            details_url: Option<String>,
403        }
404
405        let response: Response = self
406            .get(&format!(
407                "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
408            ))
409            .await?;
410
411        Ok(response
412            .check_runs
413            .into_iter()
414            .map(|cr| CheckRun {
415                name: cr.name,
416                status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
417                    ("queued", _) => crate::types::CheckStatus::Queued,
418                    ("in_progress", _) => crate::types::CheckStatus::InProgress,
419                    ("completed", Some("success")) => crate::types::CheckStatus::Success,
420                    ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
421                    ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
422                    // Any other status (failure, timed_out, action_required, etc.) treated as failure
423                    _ => crate::types::CheckStatus::Failure,
424                },
425                details_url: cr.details_url,
426            })
427            .collect())
428    }
429
430    // === Merge Operations ===
431
432    /// Merge a pull request.
433    ///
434    /// # Errors
435    /// Returns error if merge fails.
436    pub async fn merge_pr(
437        &self,
438        owner: &str,
439        repo: &str,
440        number: u64,
441        merge: MergePullRequest,
442    ) -> Result<MergeResult> {
443        self.put(
444            &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
445            &merge,
446        )
447        .await
448    }
449
450    // === Ref Operations ===
451
452    /// Delete a git reference (branch).
453    ///
454    /// # Errors
455    /// Returns error if deletion fails.
456    pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
457        self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
458            .await
459    }
460
461    // === Comment Operations ===
462
463    /// List comments on a pull request.
464    ///
465    /// # Errors
466    /// Returns error if request fails.
467    pub async fn list_pr_comments(
468        &self,
469        owner: &str,
470        repo: &str,
471        pr_number: u64,
472    ) -> Result<Vec<crate::types::IssueComment>> {
473        self.get(&format!(
474            "/repos/{owner}/{repo}/issues/{pr_number}/comments"
475        ))
476        .await
477    }
478
479    /// Create a comment on a pull request.
480    ///
481    /// # Errors
482    /// Returns error if request fails.
483    pub async fn create_pr_comment(
484        &self,
485        owner: &str,
486        repo: &str,
487        pr_number: u64,
488        comment: crate::types::CreateComment,
489    ) -> Result<crate::types::IssueComment> {
490        self.post(
491            &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
492            &comment,
493        )
494        .await
495    }
496
497    /// Update a comment on a pull request.
498    ///
499    /// # Errors
500    /// Returns error if request fails.
501    pub async fn update_pr_comment(
502        &self,
503        owner: &str,
504        repo: &str,
505        comment_id: u64,
506        comment: crate::types::UpdateComment,
507    ) -> Result<crate::types::IssueComment> {
508        self.patch(
509            &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
510            &comment,
511        )
512        .await
513    }
514}
515
516impl std::fmt::Debug for GitHubClient {
517    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
518        f.debug_struct("GitHubClient")
519            .field("base_url", &self.base_url)
520            .field("token", &"[redacted]")
521            .finish_non_exhaustive()
522    }
523}