Skip to main content

rung_github/
client.rs

1//! GitHub API client.
2
3use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use secrecy::{ExposeSecret, SecretString};
6use serde::de::DeserializeOwned;
7
8use crate::auth::Auth;
9use crate::error::{Error, Result};
10use crate::types::{
11    CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, PullRequestState,
12    UpdatePullRequest,
13};
14
15// === Internal API response types (shared across methods) ===
16
/// Internal representation of a PR from the GitHub API.
///
/// Deserialized from REST responses and turned into the domain `PullRequest`
/// by `into_pull_request` / `into_pull_request_with_state`.
#[derive(serde::Deserialize)]
struct ApiPullRequest {
    /// PR number, copied unchanged into the domain type.
    number: u64,
    /// PR title.
    title: String,
    /// Optional PR description.
    body: Option<String>,
    /// Raw state string; `into_pull_request` maps `"open"` to `Open` and
    /// every other value to `Closed` (unless `merged` is set).
    state: String,
    /// Whether the PR was merged (GitHub returns state="closed" + merged=true for merged PRs).
    /// Falls back to `false` when the field is missing from the payload.
    #[serde(default)]
    merged: bool,
    /// Whether the PR is a draft.
    draft: bool,
    /// Web URL of the PR.
    html_url: String,
    /// Source branch of the PR; only its `ref` name is kept.
    head: ApiBranch,
    /// Target branch of the PR; only its `ref` name is kept.
    base: ApiBranch,
    /// Whether the PR is mergeable (None if GitHub is still computing).
    mergeable: Option<bool>,
    /// The mergeable state (e.g., "clean", "dirty", "blocked", "behind").
    mergeable_state: Option<String>,
}
36
/// Internal representation of a branch ref from the GitHub API.
///
/// Used for the `head` and `base` fields of `ApiPullRequest`.
#[derive(serde::Deserialize)]
struct ApiBranch {
    /// Branch name; the JSON key is `ref`, which is a Rust keyword, hence the rename.
    #[serde(rename = "ref")]
    ref_name: String,
}
43
44impl ApiPullRequest {
45    /// Convert API response to domain type, parsing state string.
46    fn into_pull_request(self) -> PullRequest {
47        // GitHub API returns state="closed" + merged=true for merged PRs
48        let state = if self.merged {
49            PullRequestState::Merged
50        } else {
51            match self.state.as_str() {
52                "open" => PullRequestState::Open,
53                _ => PullRequestState::Closed,
54            }
55        };
56
57        PullRequest {
58            number: self.number,
59            title: self.title,
60            body: self.body,
61            state,
62            draft: self.draft,
63            head_branch: self.head.ref_name,
64            base_branch: self.base.ref_name,
65            html_url: self.html_url,
66            mergeable: self.mergeable,
67            mergeable_state: self.mergeable_state,
68        }
69    }
70
71    /// Convert API response to domain type with a known state.
72    fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
73        PullRequest {
74            number: self.number,
75            title: self.title,
76            body: self.body,
77            state,
78            draft: self.draft,
79            head_branch: self.head.ref_name,
80            base_branch: self.base.ref_name,
81            html_url: self.html_url,
82            mergeable: self.mergeable,
83            mergeable_state: self.mergeable_state,
84        }
85    }
86}
87
88// === GraphQL types for batch PR fetching ===
89
/// GraphQL request wrapper.
///
/// Serialized as the JSON body of the POST sent by `get_prs_batch`.
#[derive(serde::Serialize)]
struct GraphQLRequest {
    /// Query document text (built by `build_graphql_pr_query`).
    query: String,
    /// Values for the `$owner` / `$repo` variables declared by the query.
    variables: GraphQLVariables,
}
96
/// GraphQL variables for PR batch query.
///
/// Field names match the `$owner` and `$repo` variables of the query text.
#[derive(serde::Serialize)]
struct GraphQLVariables {
    /// Repository owner (bound to `$owner`).
    owner: String,
    /// Repository name (bound to `$repo`).
    repo: String,
}
103
/// GraphQL PR response (different field names than REST API).
///
/// `rename_all = "camelCase"` maps the snake_case fields below onto the
/// camelCase names requested by the batch query (`isDraft`, `headRefName`,
/// `baseRefName`).
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
struct GraphQLPullRequest {
    /// PR number.
    number: u64,
    /// Raw state string; the conversion treats `"OPEN"` as open and any
    /// other value as closed (unless `merged` is set).
    state: String,
    /// Whether the PR was merged; takes precedence over `state`.
    merged: bool,
    /// Whether the PR is a draft.
    is_draft: bool,
    /// Name of the source branch.
    head_ref_name: String,
    /// Name of the target branch.
    base_ref_name: String,
    /// URL of the PR; stored as `html_url` in the domain type.
    url: String,
}
116
117impl GraphQLPullRequest {
118    fn into_pull_request(self) -> PullRequest {
119        let state = if self.merged {
120            PullRequestState::Merged
121        } else if self.state == "OPEN" {
122            PullRequestState::Open
123        } else {
124            PullRequestState::Closed
125        };
126
127        PullRequest {
128            number: self.number,
129            title: String::new(), // Not fetched in batch query
130            body: None,
131            state,
132            draft: self.is_draft,
133            head_branch: self.head_ref_name,
134            base_branch: self.base_ref_name,
135            html_url: self.url,
136            mergeable: None, // Not fetched in batch query
137            mergeable_state: None,
138        }
139    }
140}
141
/// Top-level envelope of a GraphQL response.
#[derive(serde::Deserialize)]
struct GraphQLResponse {
    /// Result payload, if the server returned one.
    data: Option<GraphQLData>,
    /// GraphQL-level errors, if the server reported any.
    errors: Option<Vec<GraphQLError>>,
}
147
/// The `data` object of a GraphQL response.
#[derive(serde::Deserialize)]
struct GraphQLData {
    /// Kept as untyped JSON because the batch query uses generated aliases
    /// (`pr0`, `pr1`, ...) that are looked up by key after deserialization.
    repository: Option<serde_json::Value>,
}
152
/// A single entry of the GraphQL `errors` array; only the message is read.
#[derive(serde::Deserialize)]
struct GraphQLError {
    /// Human-readable error text, joined into `Error::ApiError` messages.
    message: String,
}
157
/// GitHub API client.
///
/// Holds a configured HTTP client, the API base URL that request paths are
/// appended to, and the token sent as a `Bearer` authorization header.
pub struct GitHubClient {
    /// HTTP client carrying the default Accept / User-Agent / API-version headers.
    client: Client,
    /// API base URL; request paths are appended to it verbatim.
    base_url: String,
    /// Token stored as `SecretString` for automatic zeroization on drop.
    token: SecretString,
}
165
166impl GitHubClient {
167    /// Default GitHub API URL.
168    pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
169
170    /// Create a new GitHub client.
171    ///
172    /// # Errors
173    /// Returns error if authentication fails.
174    pub fn new(auth: &Auth) -> Result<Self> {
175        Self::with_base_url(auth, Self::DEFAULT_API_URL)
176    }
177
178    /// Create a new GitHub client with a custom API URL (for GitHub Enterprise).
179    ///
180    /// # Errors
181    /// Returns error if authentication fails.
182    pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
183        let token = auth.resolve()?;
184
185        let mut headers = HeaderMap::new();
186        headers.insert(
187            ACCEPT,
188            HeaderValue::from_static("application/vnd.github+json"),
189        );
190        headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
191        headers.insert(
192            "X-GitHub-Api-Version",
193            HeaderValue::from_static("2022-11-28"),
194        );
195
196        let client = Client::builder().default_headers(headers).build()?;
197
198        Ok(Self {
199            client,
200            base_url: base_url.into(),
201            token,
202        })
203    }
204
205    /// Make a GET request.
206    async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
207        let url = format!("{}{}", self.base_url, path);
208        let response = self
209            .client
210            .get(&url)
211            .header(
212                AUTHORIZATION,
213                format!("Bearer {}", self.token.expose_secret()),
214            )
215            .send()
216            .await?;
217
218        self.handle_response(response).await
219    }
220
221    /// Make a POST request.
222    async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
223        &self,
224        path: &str,
225        body: &B,
226    ) -> Result<T> {
227        let url = format!("{}{}", self.base_url, path);
228        let response = self
229            .client
230            .post(&url)
231            .header(
232                AUTHORIZATION,
233                format!("Bearer {}", self.token.expose_secret()),
234            )
235            .json(body)
236            .send()
237            .await?;
238
239        self.handle_response(response).await
240    }
241
242    /// Make a PATCH request.
243    async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
244        &self,
245        path: &str,
246        body: &B,
247    ) -> Result<T> {
248        let url = format!("{}{}", self.base_url, path);
249        let response = self
250            .client
251            .patch(&url)
252            .header(
253                AUTHORIZATION,
254                format!("Bearer {}", self.token.expose_secret()),
255            )
256            .json(body)
257            .send()
258            .await?;
259
260        self.handle_response(response).await
261    }
262
263    /// Make a PUT request.
264    async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
265        &self,
266        path: &str,
267        body: &B,
268    ) -> Result<T> {
269        let url = format!("{}{}", self.base_url, path);
270        let response = self
271            .client
272            .put(&url)
273            .header(
274                AUTHORIZATION,
275                format!("Bearer {}", self.token.expose_secret()),
276            )
277            .json(body)
278            .send()
279            .await?;
280
281        self.handle_response(response).await
282    }
283
284    /// Make a DELETE request.
285    async fn delete(&self, path: &str) -> Result<()> {
286        let url = format!("{}{}", self.base_url, path);
287        let response = self
288            .client
289            .delete(&url)
290            .header(
291                AUTHORIZATION,
292                format!("Bearer {}", self.token.expose_secret()),
293            )
294            .send()
295            .await?;
296
297        let status = response.status();
298        if status.is_success() || status.as_u16() == 204 {
299            return Ok(());
300        }
301
302        let status_code = status.as_u16();
303        match status_code {
304            401 => Err(Error::AuthenticationFailed),
305            403 if response
306                .headers()
307                .get("x-ratelimit-remaining")
308                .is_some_and(|v| v == "0") =>
309            {
310                Err(Error::RateLimited)
311            }
312            _ => {
313                let text = response.text().await.unwrap_or_default();
314                Err(Error::ApiError {
315                    status: status_code,
316                    message: text,
317                })
318            }
319        }
320    }
321
322    /// Handle API response.
323    async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
324        let status = response.status();
325
326        if status.is_success() {
327            let body = response.json().await?;
328            return Ok(body);
329        }
330
331        // Handle error responses
332        let status_code = status.as_u16();
333
334        match status_code {
335            401 => Err(Error::AuthenticationFailed),
336            403 if response
337                .headers()
338                .get("x-ratelimit-remaining")
339                .is_some_and(|v| v == "0") =>
340            {
341                Err(Error::RateLimited)
342            }
343            _ => {
344                let text = response.text().await.unwrap_or_default();
345                Err(Error::ApiError {
346                    status: status_code,
347                    message: text,
348                })
349            }
350        }
351    }
352
353    // === PR Operations ===
354
355    /// Get a pull request by number.
356    ///
357    /// # Errors
358    /// Returns error if PR not found or API call fails.
359    pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
360        let api_pr: ApiPullRequest = self
361            .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
362            .await?;
363
364        Ok(api_pr.into_pull_request())
365    }
366
367    /// Get multiple pull requests by number using GraphQL (single API call).
368    ///
369    /// This is more efficient than calling `get_pr` multiple times when fetching
370    /// many PRs, as it uses a single GraphQL query instead of N REST calls.
371    ///
372    /// Returns a map of PR number to PR data. PRs that don't exist or can't be
373    /// fetched are omitted from the result (no error is returned for missing PRs).
374    ///
375    /// # Errors
376    /// Returns error if the GraphQL request fails entirely.
377    pub async fn get_prs_batch(
378        &self,
379        owner: &str,
380        repo: &str,
381        numbers: &[u64],
382    ) -> Result<std::collections::HashMap<u64, PullRequest>> {
383        if numbers.is_empty() {
384            return Ok(std::collections::HashMap::new());
385        }
386
387        let query = build_graphql_pr_query(numbers);
388        let request = GraphQLRequest {
389            query,
390            variables: GraphQLVariables {
391                owner: owner.to_string(),
392                repo: repo.to_string(),
393            },
394        };
395        let url = format!("{}/graphql", self.base_url);
396
397        let response = self
398            .client
399            .post(&url)
400            .header(
401                AUTHORIZATION,
402                format!("Bearer {}", self.token.expose_secret()),
403            )
404            .json(&request)
405            .send()
406            .await?;
407
408        let status = response.status();
409        if !status.is_success() {
410            let status_code = status.as_u16();
411            return match status_code {
412                401 => Err(Error::AuthenticationFailed),
413                403 if response
414                    .headers()
415                    .get("x-ratelimit-remaining")
416                    .is_some_and(|v| v == "0") =>
417                {
418                    Err(Error::RateLimited)
419                }
420                _ => {
421                    let text = response.text().await.unwrap_or_default();
422                    Err(Error::ApiError {
423                        status: status_code,
424                        message: text,
425                    })
426                }
427            };
428        }
429
430        let graphql_response: GraphQLResponse = response.json().await?;
431
432        // Check for GraphQL-level errors
433        if let Some(errors) = graphql_response.errors {
434            if !errors.is_empty() {
435                let messages: Vec<_> = errors.iter().map(|e| e.message.as_str()).collect();
436                return Err(Error::ApiError {
437                    status: 200,
438                    message: messages.join("; "),
439                });
440            }
441        }
442
443        let mut result = std::collections::HashMap::new();
444
445        if let Some(data) = graphql_response.data {
446            if let Some(repo_data) = data.repository {
447                // Parse each pr0, pr1, pr2... field
448                for (i, &num) in numbers.iter().enumerate() {
449                    let key = format!("pr{i}");
450                    if let Some(pr_value) = repo_data.get(&key) {
451                        // Skip null values (PR doesn't exist)
452                        if !pr_value.is_null() {
453                            if let Ok(pr) =
454                                serde_json::from_value::<GraphQLPullRequest>(pr_value.clone())
455                            {
456                                result.insert(num, pr.into_pull_request());
457                            }
458                        }
459                    }
460                }
461            }
462        }
463
464        Ok(result)
465    }
466
467    /// Find a PR for a branch.
468    ///
469    /// # Errors
470    /// Returns error if API call fails.
471    pub async fn find_pr_for_branch(
472        &self,
473        owner: &str,
474        repo: &str,
475        branch: &str,
476    ) -> Result<Option<PullRequest>> {
477        // We only query open PRs, so state is always Open
478        let prs: Vec<ApiPullRequest> = self
479            .get(&format!(
480                "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
481            ))
482            .await?;
483
484        Ok(prs
485            .into_iter()
486            .next()
487            .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
488    }
489
490    /// Create a pull request.
491    ///
492    /// # Errors
493    /// Returns error if PR creation fails.
494    pub async fn create_pr(
495        &self,
496        owner: &str,
497        repo: &str,
498        pr: CreatePullRequest,
499    ) -> Result<PullRequest> {
500        // Newly created PRs are always open
501        let api_pr: ApiPullRequest = self
502            .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
503            .await?;
504
505        Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
506    }
507
508    /// Update a pull request.
509    ///
510    /// # Errors
511    /// Returns error if PR update fails.
512    pub async fn update_pr(
513        &self,
514        owner: &str,
515        repo: &str,
516        number: u64,
517        update: UpdatePullRequest,
518    ) -> Result<PullRequest> {
519        let api_pr: ApiPullRequest = self
520            .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
521            .await?;
522
523        Ok(api_pr.into_pull_request())
524    }
525
526    // === Check Runs ===
527
528    /// Get check runs for a commit.
529    ///
530    /// # Errors
531    /// Returns error if API call fails.
532    pub async fn get_check_runs(
533        &self,
534        owner: &str,
535        repo: &str,
536        commit_sha: &str,
537    ) -> Result<Vec<CheckRun>> {
538        #[derive(serde::Deserialize)]
539        struct Response {
540            check_runs: Vec<ApiCheckRun>,
541        }
542
543        #[derive(serde::Deserialize)]
544        struct ApiCheckRun {
545            name: String,
546            status: String,
547            conclusion: Option<String>,
548            details_url: Option<String>,
549        }
550
551        let response: Response = self
552            .get(&format!(
553                "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
554            ))
555            .await?;
556
557        Ok(response
558            .check_runs
559            .into_iter()
560            .map(|cr| CheckRun {
561                name: cr.name,
562                status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
563                    ("queued", _) => crate::types::CheckStatus::Queued,
564                    ("in_progress", _) => crate::types::CheckStatus::InProgress,
565                    ("completed", Some("success")) => crate::types::CheckStatus::Success,
566                    ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
567                    ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
568                    // Any other status (failure, timed_out, action_required, etc.) treated as failure
569                    _ => crate::types::CheckStatus::Failure,
570                },
571                details_url: cr.details_url,
572            })
573            .collect())
574    }
575
576    // === Merge Operations ===
577
578    /// Merge a pull request.
579    ///
580    /// # Errors
581    /// Returns error if merge fails.
582    pub async fn merge_pr(
583        &self,
584        owner: &str,
585        repo: &str,
586        number: u64,
587        merge: MergePullRequest,
588    ) -> Result<MergeResult> {
589        self.put(
590            &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
591            &merge,
592        )
593        .await
594    }
595
596    // === Ref Operations ===
597
598    /// Delete a git reference (branch).
599    ///
600    /// # Errors
601    /// Returns error if deletion fails.
602    pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
603        self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
604            .await
605    }
606
607    // === Repository Operations ===
608
609    /// Get the repository's default branch name.
610    ///
611    /// # Errors
612    /// Returns error if API call fails.
613    pub async fn get_default_branch(&self, owner: &str, repo: &str) -> Result<String> {
614        #[derive(serde::Deserialize)]
615        struct RepoInfo {
616            default_branch: String,
617        }
618
619        let info: RepoInfo = self.get(&format!("/repos/{owner}/{repo}")).await?;
620        Ok(info.default_branch)
621    }
622
623    // === Comment Operations ===
624
625    /// List comments on a pull request.
626    ///
627    /// # Errors
628    /// Returns error if request fails.
629    pub async fn list_pr_comments(
630        &self,
631        owner: &str,
632        repo: &str,
633        pr_number: u64,
634    ) -> Result<Vec<crate::types::IssueComment>> {
635        self.get(&format!(
636            "/repos/{owner}/{repo}/issues/{pr_number}/comments"
637        ))
638        .await
639    }
640
641    /// Create a comment on a pull request.
642    ///
643    /// # Errors
644    /// Returns error if request fails.
645    pub async fn create_pr_comment(
646        &self,
647        owner: &str,
648        repo: &str,
649        pr_number: u64,
650        comment: crate::types::CreateComment,
651    ) -> Result<crate::types::IssueComment> {
652        self.post(
653            &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
654            &comment,
655        )
656        .await
657    }
658
659    /// Update a comment on a pull request.
660    ///
661    /// # Errors
662    /// Returns error if request fails.
663    pub async fn update_pr_comment(
664        &self,
665        owner: &str,
666        repo: &str,
667        comment_id: u64,
668        comment: crate::types::UpdateComment,
669    ) -> Result<crate::types::IssueComment> {
670        self.patch(
671            &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
672            &comment,
673        )
674        .await
675    }
676}
677
678impl std::fmt::Debug for GitHubClient {
679    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
680        f.debug_struct("GitHubClient")
681            .field("base_url", &self.base_url)
682            .field("token", &"[redacted]")
683            .finish_non_exhaustive()
684    }
685}
686
/// Assemble one GraphQL document that requests every PR in `numbers`.
///
/// The PR at index `i` is aliased as `pr{i}` so results can be matched back
/// to their position in `numbers`. Owner and repository are passed through
/// the `$owner` / `$repo` variables rather than being embedded in the text.
fn build_graphql_pr_query(numbers: &[u64]) -> String {
    const PR_FIELDS: &str = "number state merged isDraft headRefName baseRefName url";

    let mut selections = String::new();
    for (index, number) in numbers.iter().enumerate() {
        if index > 0 {
            selections.push(' ');
        }
        selections.push_str(&format!(
            "pr{index}: pullRequest(number: {number}) {{ {PR_FIELDS} }}"
        ));
    }

    format!(
        "query($owner: String!, $repo: String!) {{ repository(owner: $owner, name: $repo) {{ {selections} }} }}"
    )
}