rung_github/
client.rs

//! GitHub API client.

3use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use secrecy::{ExposeSecret, SecretString};
6use serde::de::DeserializeOwned;
7
8use crate::auth::Auth;
9use crate::error::{Error, Result};
10use crate::types::{
11    CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, PullRequestState,
12    UpdatePullRequest,
13};

// === Internal API response types (shared across methods) ===

17/// Internal representation of a PR from the GitHub API.
18#[derive(serde::Deserialize)]
19struct ApiPullRequest {
20    number: u64,
21    title: String,
22    body: Option<String>,
23    state: String,
24    /// Whether the PR was merged (GitHub returns state="closed" + merged=true for merged PRs).
25    #[serde(default)]
26    merged: bool,
27    draft: bool,
28    html_url: String,
29    head: ApiBranch,
30    base: ApiBranch,
31}
32
33/// Internal representation of a branch ref from the GitHub API.
34#[derive(serde::Deserialize)]
35struct ApiBranch {
36    #[serde(rename = "ref")]
37    ref_name: String,
38}
39
40impl ApiPullRequest {
41    /// Convert API response to domain type, parsing state string.
42    fn into_pull_request(self) -> PullRequest {
43        // GitHub API returns state="closed" + merged=true for merged PRs
44        let state = if self.merged {
45            PullRequestState::Merged
46        } else {
47            match self.state.as_str() {
48                "open" => PullRequestState::Open,
49                _ => PullRequestState::Closed,
50            }
51        };
52
53        PullRequest {
54            number: self.number,
55            title: self.title,
56            body: self.body,
57            state,
58            draft: self.draft,
59            head_branch: self.head.ref_name,
60            base_branch: self.base.ref_name,
61            html_url: self.html_url,
62        }
63    }
64
65    /// Convert API response to domain type with a known state.
66    fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
67        PullRequest {
68            number: self.number,
69            title: self.title,
70            body: self.body,
71            state,
72            draft: self.draft,
73            head_branch: self.head.ref_name,
74            base_branch: self.base.ref_name,
75            html_url: self.html_url,
76        }
77    }
78}
79
80/// GitHub API client.
81pub struct GitHubClient {
82    client: Client,
83    base_url: String,
84    /// Token stored as `SecretString` for automatic zeroization on drop.
85    token: SecretString,
86}
87
88impl GitHubClient {
89    /// Default GitHub API URL.
90    pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
91
92    /// Create a new GitHub client.
93    ///
94    /// # Errors
95    /// Returns error if authentication fails.
96    pub fn new(auth: &Auth) -> Result<Self> {
97        Self::with_base_url(auth, Self::DEFAULT_API_URL)
98    }
99
100    /// Create a new GitHub client with a custom API URL (for GitHub Enterprise).
101    ///
102    /// # Errors
103    /// Returns error if authentication fails.
104    pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
105        let token = auth.resolve()?;
106
107        let mut headers = HeaderMap::new();
108        headers.insert(
109            ACCEPT,
110            HeaderValue::from_static("application/vnd.github+json"),
111        );
112        headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
113        headers.insert(
114            "X-GitHub-Api-Version",
115            HeaderValue::from_static("2022-11-28"),
116        );
117
118        let client = Client::builder().default_headers(headers).build()?;
119
120        Ok(Self {
121            client,
122            base_url: base_url.into(),
123            token,
124        })
125    }
126
127    /// Make a GET request.
128    async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
129        let url = format!("{}{}", self.base_url, path);
130        let response = self
131            .client
132            .get(&url)
133            .header(
134                AUTHORIZATION,
135                format!("Bearer {}", self.token.expose_secret()),
136            )
137            .send()
138            .await?;
139
140        self.handle_response(response).await
141    }
142
143    /// Make a POST request.
144    async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
145        &self,
146        path: &str,
147        body: &B,
148    ) -> Result<T> {
149        let url = format!("{}{}", self.base_url, path);
150        let response = self
151            .client
152            .post(&url)
153            .header(
154                AUTHORIZATION,
155                format!("Bearer {}", self.token.expose_secret()),
156            )
157            .json(body)
158            .send()
159            .await?;
160
161        self.handle_response(response).await
162    }
163
164    /// Make a PATCH request.
165    async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
166        &self,
167        path: &str,
168        body: &B,
169    ) -> Result<T> {
170        let url = format!("{}{}", self.base_url, path);
171        let response = self
172            .client
173            .patch(&url)
174            .header(
175                AUTHORIZATION,
176                format!("Bearer {}", self.token.expose_secret()),
177            )
178            .json(body)
179            .send()
180            .await?;
181
182        self.handle_response(response).await
183    }
184
185    /// Make a PUT request.
186    async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
187        &self,
188        path: &str,
189        body: &B,
190    ) -> Result<T> {
191        let url = format!("{}{}", self.base_url, path);
192        let response = self
193            .client
194            .put(&url)
195            .header(
196                AUTHORIZATION,
197                format!("Bearer {}", self.token.expose_secret()),
198            )
199            .json(body)
200            .send()
201            .await?;
202
203        self.handle_response(response).await
204    }
205
206    /// Make a DELETE request.
207    async fn delete(&self, path: &str) -> Result<()> {
208        let url = format!("{}{}", self.base_url, path);
209        let response = self
210            .client
211            .delete(&url)
212            .header(
213                AUTHORIZATION,
214                format!("Bearer {}", self.token.expose_secret()),
215            )
216            .send()
217            .await?;
218
219        let status = response.status();
220        if status.is_success() || status.as_u16() == 204 {
221            return Ok(());
222        }
223
224        let status_code = status.as_u16();
225        match status_code {
226            401 => Err(Error::AuthenticationFailed),
227            403 if response
228                .headers()
229                .get("x-ratelimit-remaining")
230                .is_some_and(|v| v == "0") =>
231            {
232                Err(Error::RateLimited)
233            }
234            _ => {
235                let text = response.text().await.unwrap_or_default();
236                Err(Error::ApiError {
237                    status: status_code,
238                    message: text,
239                })
240            }
241        }
242    }
243
244    /// Handle API response.
245    async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
246        let status = response.status();
247
248        if status.is_success() {
249            let body = response.json().await?;
250            return Ok(body);
251        }
252
253        // Handle error responses
254        let status_code = status.as_u16();
255
256        match status_code {
257            401 => Err(Error::AuthenticationFailed),
258            403 if response
259                .headers()
260                .get("x-ratelimit-remaining")
261                .is_some_and(|v| v == "0") =>
262            {
263                Err(Error::RateLimited)
264            }
265            _ => {
266                let text = response.text().await.unwrap_or_default();
267                Err(Error::ApiError {
268                    status: status_code,
269                    message: text,
270                })
271            }
272        }
273    }
274
275    // === PR Operations ===
276
277    /// Get a pull request by number.
278    ///
279    /// # Errors
280    /// Returns error if PR not found or API call fails.
281    pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
282        let api_pr: ApiPullRequest = self
283            .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
284            .await?;
285
286        Ok(api_pr.into_pull_request())
287    }
288
289    /// Find a PR for a branch.
290    ///
291    /// # Errors
292    /// Returns error if API call fails.
293    pub async fn find_pr_for_branch(
294        &self,
295        owner: &str,
296        repo: &str,
297        branch: &str,
298    ) -> Result<Option<PullRequest>> {
299        // We only query open PRs, so state is always Open
300        let prs: Vec<ApiPullRequest> = self
301            .get(&format!(
302                "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
303            ))
304            .await?;
305
306        Ok(prs
307            .into_iter()
308            .next()
309            .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
310    }
311
312    /// Create a pull request.
313    ///
314    /// # Errors
315    /// Returns error if PR creation fails.
316    pub async fn create_pr(
317        &self,
318        owner: &str,
319        repo: &str,
320        pr: CreatePullRequest,
321    ) -> Result<PullRequest> {
322        // Newly created PRs are always open
323        let api_pr: ApiPullRequest = self
324            .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
325            .await?;
326
327        Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
328    }
329
330    /// Update a pull request.
331    ///
332    /// # Errors
333    /// Returns error if PR update fails.
334    pub async fn update_pr(
335        &self,
336        owner: &str,
337        repo: &str,
338        number: u64,
339        update: UpdatePullRequest,
340    ) -> Result<PullRequest> {
341        let api_pr: ApiPullRequest = self
342            .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
343            .await?;
344
345        Ok(api_pr.into_pull_request())
346    }
347
348    // === Check Runs ===
349
350    /// Get check runs for a commit.
351    ///
352    /// # Errors
353    /// Returns error if API call fails.
354    pub async fn get_check_runs(
355        &self,
356        owner: &str,
357        repo: &str,
358        commit_sha: &str,
359    ) -> Result<Vec<CheckRun>> {
360        #[derive(serde::Deserialize)]
361        struct Response {
362            check_runs: Vec<ApiCheckRun>,
363        }
364
365        #[derive(serde::Deserialize)]
366        struct ApiCheckRun {
367            name: String,
368            status: String,
369            conclusion: Option<String>,
370            details_url: Option<String>,
371        }
372
373        let response: Response = self
374            .get(&format!(
375                "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
376            ))
377            .await?;
378
379        Ok(response
380            .check_runs
381            .into_iter()
382            .map(|cr| CheckRun {
383                name: cr.name,
384                status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
385                    ("queued", _) => crate::types::CheckStatus::Queued,
386                    ("in_progress", _) => crate::types::CheckStatus::InProgress,
387                    ("completed", Some("success")) => crate::types::CheckStatus::Success,
388                    ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
389                    ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
390                    // Any other status (failure, timed_out, action_required, etc.) treated as failure
391                    _ => crate::types::CheckStatus::Failure,
392                },
393                details_url: cr.details_url,
394            })
395            .collect())
396    }
397
398    // === Merge Operations ===
399
400    /// Merge a pull request.
401    ///
402    /// # Errors
403    /// Returns error if merge fails.
404    pub async fn merge_pr(
405        &self,
406        owner: &str,
407        repo: &str,
408        number: u64,
409        merge: MergePullRequest,
410    ) -> Result<MergeResult> {
411        self.put(
412            &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
413            &merge,
414        )
415        .await
416    }
417
418    // === Ref Operations ===
419
420    /// Delete a git reference (branch).
421    ///
422    /// # Errors
423    /// Returns error if deletion fails.
424    pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
425        self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
426            .await
427    }
428
429    // === Comment Operations ===
430
431    /// List comments on a pull request.
432    ///
433    /// # Errors
434    /// Returns error if request fails.
435    pub async fn list_pr_comments(
436        &self,
437        owner: &str,
438        repo: &str,
439        pr_number: u64,
440    ) -> Result<Vec<crate::types::IssueComment>> {
441        self.get(&format!(
442            "/repos/{owner}/{repo}/issues/{pr_number}/comments"
443        ))
444        .await
445    }
446
447    /// Create a comment on a pull request.
448    ///
449    /// # Errors
450    /// Returns error if request fails.
451    pub async fn create_pr_comment(
452        &self,
453        owner: &str,
454        repo: &str,
455        pr_number: u64,
456        comment: crate::types::CreateComment,
457    ) -> Result<crate::types::IssueComment> {
458        self.post(
459            &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
460            &comment,
461        )
462        .await
463    }
464
465    /// Update a comment on a pull request.
466    ///
467    /// # Errors
468    /// Returns error if request fails.
469    pub async fn update_pr_comment(
470        &self,
471        owner: &str,
472        repo: &str,
473        comment_id: u64,
474        comment: crate::types::UpdateComment,
475    ) -> Result<crate::types::IssueComment> {
476        self.patch(
477            &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
478            &comment,
479        )
480        .await
481    }
482}
483
484impl std::fmt::Debug for GitHubClient {
485    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486        f.debug_struct("GitHubClient")
487            .field("base_url", &self.base_url)
488            .field("token", &"[redacted]")
489            .finish_non_exhaustive()
490    }
491}