Skip to main content

rung_github/
client.rs

1//! GitHub API client.
2
3use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use secrecy::{ExposeSecret, SecretString};
6use serde::de::DeserializeOwned;
7
8use crate::auth::Auth;
9use crate::error::{Error, Result};
10use crate::traits::GitHubApi;
11use crate::types::{
12    CheckRun, CreateComment, CreatePullRequest, IssueComment, MergePullRequest, MergeResult,
13    PullRequest, PullRequestState, UpdateComment, UpdatePullRequest,
14};
15
16// === Internal API response types (shared across methods) ===
17
18/// Internal representation of a PR from the GitHub API.
19#[derive(serde::Deserialize)]
20struct ApiPullRequest {
21    number: u64,
22    title: String,
23    body: Option<String>,
24    state: String,
25    /// Whether the PR was merged (GitHub returns state="closed" + merged=true for merged PRs).
26    #[serde(default)]
27    merged: bool,
28    draft: bool,
29    html_url: String,
30    head: ApiBranch,
31    base: ApiBranch,
32    /// Whether the PR is mergeable (None if GitHub is still computing).
33    mergeable: Option<bool>,
34    /// The mergeable state (e.g., "clean", "dirty", "blocked", "behind").
35    mergeable_state: Option<String>,
36}
37
38/// Internal representation of a branch ref from the GitHub API.
39#[derive(serde::Deserialize)]
40struct ApiBranch {
41    #[serde(rename = "ref")]
42    ref_name: String,
43}
44
45impl ApiPullRequest {
46    /// Convert API response to domain type, parsing state string.
47    fn into_pull_request(self) -> PullRequest {
48        // GitHub API returns state="closed" + merged=true for merged PRs
49        let state = if self.merged {
50            PullRequestState::Merged
51        } else {
52            match self.state.as_str() {
53                "open" => PullRequestState::Open,
54                _ => PullRequestState::Closed,
55            }
56        };
57
58        PullRequest {
59            number: self.number,
60            title: self.title,
61            body: self.body,
62            state,
63            draft: self.draft,
64            head_branch: self.head.ref_name,
65            base_branch: self.base.ref_name,
66            html_url: self.html_url,
67            mergeable: self.mergeable,
68            mergeable_state: self.mergeable_state,
69        }
70    }
71
72    /// Convert API response to domain type with a known state.
73    fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
74        PullRequest {
75            number: self.number,
76            title: self.title,
77            body: self.body,
78            state,
79            draft: self.draft,
80            head_branch: self.head.ref_name,
81            base_branch: self.base.ref_name,
82            html_url: self.html_url,
83            mergeable: self.mergeable,
84            mergeable_state: self.mergeable_state,
85        }
86    }
87}
88
89// === GraphQL types for batch PR fetching ===
90
91/// GraphQL request wrapper.
92#[derive(serde::Serialize)]
93struct GraphQLRequest {
94    query: String,
95    variables: GraphQLVariables,
96}
97
98/// GraphQL variables for PR batch query.
99#[derive(serde::Serialize)]
100struct GraphQLVariables {
101    owner: String,
102    repo: String,
103}
104
105/// GraphQL PR response (different field names than REST API).
106#[derive(serde::Deserialize)]
107#[serde(rename_all = "camelCase")]
108struct GraphQLPullRequest {
109    number: u64,
110    state: String,
111    merged: bool,
112    is_draft: bool,
113    head_ref_name: String,
114    base_ref_name: String,
115    url: String,
116}
117
118impl GraphQLPullRequest {
119    fn into_pull_request(self) -> PullRequest {
120        let state = if self.merged {
121            PullRequestState::Merged
122        } else if self.state == "OPEN" {
123            PullRequestState::Open
124        } else {
125            PullRequestState::Closed
126        };
127
128        PullRequest {
129            number: self.number,
130            title: String::new(), // Not fetched in batch query
131            body: None,
132            state,
133            draft: self.is_draft,
134            head_branch: self.head_ref_name,
135            base_branch: self.base_ref_name,
136            html_url: self.url,
137            mergeable: None, // Not fetched in batch query
138            mergeable_state: None,
139        }
140    }
141}
142
143#[derive(serde::Deserialize)]
144struct GraphQLResponse {
145    data: Option<GraphQLData>,
146    errors: Option<Vec<GraphQLError>>,
147}
148
149#[derive(serde::Deserialize)]
150struct GraphQLData {
151    repository: Option<serde_json::Value>,
152}
153
154#[derive(serde::Deserialize)]
155struct GraphQLError {
156    message: String,
157}
158
159/// GitHub API client.
160pub struct GitHubClient {
161    client: Client,
162    base_url: String,
163    /// Token stored as `SecretString` for automatic zeroization on drop.
164    token: SecretString,
165}
166
167impl GitHubClient {
168    /// Default GitHub API URL.
169    pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
170
171    /// Create a new GitHub client.
172    ///
173    /// # Errors
174    /// Returns error if authentication fails.
175    pub fn new(auth: &Auth) -> Result<Self> {
176        Self::with_base_url(auth, Self::DEFAULT_API_URL)
177    }
178
179    /// Create a new GitHub client with a custom API URL (for GitHub Enterprise).
180    ///
181    /// # Errors
182    /// Returns error if authentication fails.
183    pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
184        let token = auth.resolve()?;
185
186        let mut headers = HeaderMap::new();
187        headers.insert(
188            ACCEPT,
189            HeaderValue::from_static("application/vnd.github+json"),
190        );
191        headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
192        headers.insert(
193            "X-GitHub-Api-Version",
194            HeaderValue::from_static("2022-11-28"),
195        );
196
197        let client = Client::builder().default_headers(headers).build()?;
198
199        Ok(Self {
200            client,
201            base_url: base_url.into(),
202            token,
203        })
204    }
205
206    /// Make a GET request.
207    async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
208        let url = format!("{}{}", self.base_url, path);
209        let response = self
210            .client
211            .get(&url)
212            .header(
213                AUTHORIZATION,
214                format!("Bearer {}", self.token.expose_secret()),
215            )
216            .send()
217            .await?;
218
219        self.handle_response(response).await
220    }
221
222    /// Make a POST request.
223    async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
224        &self,
225        path: &str,
226        body: &B,
227    ) -> Result<T> {
228        let url = format!("{}{}", self.base_url, path);
229        let response = self
230            .client
231            .post(&url)
232            .header(
233                AUTHORIZATION,
234                format!("Bearer {}", self.token.expose_secret()),
235            )
236            .json(body)
237            .send()
238            .await?;
239
240        self.handle_response(response).await
241    }
242
243    /// Make a PATCH request.
244    async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
245        &self,
246        path: &str,
247        body: &B,
248    ) -> Result<T> {
249        let url = format!("{}{}", self.base_url, path);
250        let response = self
251            .client
252            .patch(&url)
253            .header(
254                AUTHORIZATION,
255                format!("Bearer {}", self.token.expose_secret()),
256            )
257            .json(body)
258            .send()
259            .await?;
260
261        self.handle_response(response).await
262    }
263
264    /// Make a PUT request.
265    async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
266        &self,
267        path: &str,
268        body: &B,
269    ) -> Result<T> {
270        let url = format!("{}{}", self.base_url, path);
271        let response = self
272            .client
273            .put(&url)
274            .header(
275                AUTHORIZATION,
276                format!("Bearer {}", self.token.expose_secret()),
277            )
278            .json(body)
279            .send()
280            .await?;
281
282        self.handle_response(response).await
283    }
284
285    /// Make a DELETE request.
286    async fn delete(&self, path: &str) -> Result<()> {
287        let url = format!("{}{}", self.base_url, path);
288        let response = self
289            .client
290            .delete(&url)
291            .header(
292                AUTHORIZATION,
293                format!("Bearer {}", self.token.expose_secret()),
294            )
295            .send()
296            .await?;
297
298        let status = response.status();
299        if status.is_success() || status.as_u16() == 204 {
300            return Ok(());
301        }
302
303        let status_code = status.as_u16();
304        match status_code {
305            401 => Err(Error::AuthenticationFailed),
306            403 if response
307                .headers()
308                .get("x-ratelimit-remaining")
309                .is_some_and(|v| v == "0") =>
310            {
311                Err(Error::RateLimited)
312            }
313            _ => {
314                let text = response.text().await.unwrap_or_default();
315                Err(Error::ApiError {
316                    status: status_code,
317                    message: text,
318                })
319            }
320        }
321    }
322
323    /// Handle API response.
324    async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
325        let status = response.status();
326
327        if status.is_success() {
328            let body = response.json().await?;
329            return Ok(body);
330        }
331
332        // Handle error responses
333        let status_code = status.as_u16();
334
335        match status_code {
336            401 => Err(Error::AuthenticationFailed),
337            403 if response
338                .headers()
339                .get("x-ratelimit-remaining")
340                .is_some_and(|v| v == "0") =>
341            {
342                Err(Error::RateLimited)
343            }
344            _ => {
345                let text = response.text().await.unwrap_or_default();
346                Err(Error::ApiError {
347                    status: status_code,
348                    message: text,
349                })
350            }
351        }
352    }
353
354    // === PR Operations ===
355
356    /// Get a pull request by number.
357    ///
358    /// # Errors
359    /// Returns error if PR not found or API call fails.
360    pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
361        let api_pr: ApiPullRequest = self
362            .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
363            .await?;
364
365        Ok(api_pr.into_pull_request())
366    }
367
368    /// Get multiple pull requests by number using GraphQL (single API call).
369    ///
370    /// This is more efficient than calling `get_pr` multiple times when fetching
371    /// many PRs, as it uses a single GraphQL query instead of N REST calls.
372    ///
373    /// Returns a map of PR number to PR data. PRs that don't exist or can't be
374    /// fetched are omitted from the result (no error is returned for missing PRs).
375    ///
376    /// # Errors
377    /// Returns error if the GraphQL request fails entirely.
378    pub async fn get_prs_batch(
379        &self,
380        owner: &str,
381        repo: &str,
382        numbers: &[u64],
383    ) -> Result<std::collections::HashMap<u64, PullRequest>> {
384        if numbers.is_empty() {
385            return Ok(std::collections::HashMap::new());
386        }
387
388        let query = build_graphql_pr_query(numbers);
389        let request = GraphQLRequest {
390            query,
391            variables: GraphQLVariables {
392                owner: owner.to_string(),
393                repo: repo.to_string(),
394            },
395        };
396        let url = format!("{}/graphql", self.base_url);
397
398        let response = self
399            .client
400            .post(&url)
401            .header(
402                AUTHORIZATION,
403                format!("Bearer {}", self.token.expose_secret()),
404            )
405            .json(&request)
406            .send()
407            .await?;
408
409        let status = response.status();
410        if !status.is_success() {
411            let status_code = status.as_u16();
412            return match status_code {
413                401 => Err(Error::AuthenticationFailed),
414                403 if response
415                    .headers()
416                    .get("x-ratelimit-remaining")
417                    .is_some_and(|v| v == "0") =>
418                {
419                    Err(Error::RateLimited)
420                }
421                _ => {
422                    let text = response.text().await.unwrap_or_default();
423                    Err(Error::ApiError {
424                        status: status_code,
425                        message: text,
426                    })
427                }
428            };
429        }
430
431        let graphql_response: GraphQLResponse = response.json().await?;
432
433        // Only fail if there's no data at all; allow partial results with errors
434        if graphql_response.data.is_none() {
435            if let Some(errors) = graphql_response.errors
436                && !errors.is_empty()
437            {
438                let messages: Vec<_> = errors.iter().map(|e| e.message.as_str()).collect();
439                return Err(Error::ApiError {
440                    status: 200,
441                    message: messages.join("; "),
442                });
443            }
444            // No data and no errors - return empty result
445            return Ok(std::collections::HashMap::new());
446        }
447
448        let mut result = std::collections::HashMap::new();
449
450        if let Some(data) = graphql_response.data {
451            if let Some(repo_data) = data.repository {
452                // Parse each pr0, pr1, pr2... field (null entries are skipped for partial results)
453                for (i, &num) in numbers.iter().enumerate() {
454                    let key = format!("pr{i}");
455                    if let Some(pr_value) = repo_data.get(&key)
456                        && !pr_value.is_null()
457                        && let Ok(pr) =
458                            serde_json::from_value::<GraphQLPullRequest>(pr_value.clone())
459                    {
460                        result.insert(num, pr.into_pull_request());
461                    }
462                }
463            } else if let Some(errors) = graphql_response.errors
464                && !errors.is_empty()
465            {
466                // data exists but repository is null, and there are errors
467                let messages: Vec<_> = errors.iter().map(|e| e.message.as_str()).collect();
468                return Err(Error::ApiError {
469                    status: 200,
470                    message: messages.join("; "),
471                });
472            }
473        }
474
475        Ok(result)
476    }
477
478    /// Find a PR for a branch.
479    ///
480    /// # Errors
481    /// Returns error if API call fails.
482    pub async fn find_pr_for_branch(
483        &self,
484        owner: &str,
485        repo: &str,
486        branch: &str,
487    ) -> Result<Option<PullRequest>> {
488        // We only query open PRs, so state is always Open
489        let prs: Vec<ApiPullRequest> = self
490            .get(&format!(
491                "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
492            ))
493            .await?;
494
495        Ok(prs
496            .into_iter()
497            .next()
498            .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
499    }
500
501    /// Create a pull request.
502    ///
503    /// # Errors
504    /// Returns error if PR creation fails.
505    pub async fn create_pr(
506        &self,
507        owner: &str,
508        repo: &str,
509        pr: CreatePullRequest,
510    ) -> Result<PullRequest> {
511        // Newly created PRs are always open
512        let api_pr: ApiPullRequest = self
513            .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
514            .await?;
515
516        Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
517    }
518
519    /// Update a pull request.
520    ///
521    /// # Errors
522    /// Returns error if PR update fails.
523    pub async fn update_pr(
524        &self,
525        owner: &str,
526        repo: &str,
527        number: u64,
528        update: UpdatePullRequest,
529    ) -> Result<PullRequest> {
530        let api_pr: ApiPullRequest = self
531            .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
532            .await?;
533
534        Ok(api_pr.into_pull_request())
535    }
536
537    // === Check Runs ===
538
539    /// Get check runs for a commit.
540    ///
541    /// # Errors
542    /// Returns error if API call fails.
543    pub async fn get_check_runs(
544        &self,
545        owner: &str,
546        repo: &str,
547        commit_sha: &str,
548    ) -> Result<Vec<CheckRun>> {
549        #[derive(serde::Deserialize)]
550        struct Response {
551            check_runs: Vec<ApiCheckRun>,
552        }
553
554        #[derive(serde::Deserialize)]
555        struct ApiCheckRun {
556            name: String,
557            status: String,
558            conclusion: Option<String>,
559            details_url: Option<String>,
560        }
561
562        let response: Response = self
563            .get(&format!(
564                "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
565            ))
566            .await?;
567
568        Ok(response
569            .check_runs
570            .into_iter()
571            .map(|cr| CheckRun {
572                name: cr.name,
573                status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
574                    ("queued", _) => crate::types::CheckStatus::Queued,
575                    ("in_progress", _) => crate::types::CheckStatus::InProgress,
576                    ("completed", Some("success")) => crate::types::CheckStatus::Success,
577                    ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
578                    ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
579                    // Any other status (failure, timed_out, action_required, etc.) treated as failure
580                    _ => crate::types::CheckStatus::Failure,
581                },
582                details_url: cr.details_url,
583            })
584            .collect())
585    }
586
587    // === Merge Operations ===
588
589    /// Merge a pull request.
590    ///
591    /// # Errors
592    /// Returns error if merge fails.
593    pub async fn merge_pr(
594        &self,
595        owner: &str,
596        repo: &str,
597        number: u64,
598        merge: MergePullRequest,
599    ) -> Result<MergeResult> {
600        self.put(
601            &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
602            &merge,
603        )
604        .await
605    }
606
607    // === Ref Operations ===
608
609    /// Delete a git reference (branch).
610    ///
611    /// # Errors
612    /// Returns error if deletion fails.
613    pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
614        self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
615            .await
616    }
617
618    // === Repository Operations ===
619
620    /// Get the repository's default branch name.
621    ///
622    /// # Errors
623    /// Returns error if API call fails.
624    pub async fn get_default_branch(&self, owner: &str, repo: &str) -> Result<String> {
625        #[derive(serde::Deserialize)]
626        struct RepoInfo {
627            default_branch: String,
628        }
629
630        let info: RepoInfo = self.get(&format!("/repos/{owner}/{repo}")).await?;
631        Ok(info.default_branch)
632    }
633
634    // === Comment Operations ===
635
636    /// List comments on a pull request.
637    ///
638    /// # Errors
639    /// Returns error if request fails.
640    pub async fn list_pr_comments(
641        &self,
642        owner: &str,
643        repo: &str,
644        pr_number: u64,
645    ) -> Result<Vec<crate::types::IssueComment>> {
646        self.get(&format!(
647            "/repos/{owner}/{repo}/issues/{pr_number}/comments"
648        ))
649        .await
650    }
651
652    /// Create a comment on a pull request.
653    ///
654    /// # Errors
655    /// Returns error if request fails.
656    pub async fn create_pr_comment(
657        &self,
658        owner: &str,
659        repo: &str,
660        pr_number: u64,
661        comment: crate::types::CreateComment,
662    ) -> Result<crate::types::IssueComment> {
663        self.post(
664            &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
665            &comment,
666        )
667        .await
668    }
669
670    /// Update a comment on a pull request.
671    ///
672    /// # Errors
673    /// Returns error if request fails.
674    pub async fn update_pr_comment(
675        &self,
676        owner: &str,
677        repo: &str,
678        comment_id: u64,
679        comment: crate::types::UpdateComment,
680    ) -> Result<crate::types::IssueComment> {
681        self.patch(
682            &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
683            &comment,
684        )
685        .await
686    }
687}
688
689impl std::fmt::Debug for GitHubClient {
690    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
691        f.debug_struct("GitHubClient")
692            .field("base_url", &self.base_url)
693            .field("token", &"[redacted]")
694            .finish_non_exhaustive()
695    }
696}
697
/// Build a GraphQL query that fetches every PR in `numbers` in one request.
///
/// Each PR is selected under the alias `pr{i}`, where `i` is its index in
/// `numbers`; callers use the same index to read results back. The query
/// declares `$owner` and `$repo` as variables.
fn build_graphql_pr_query(numbers: &[u64]) -> String {
    // Fields selected for every aliased pull request.
    const PR_FIELDS: &str = "number state merged isDraft headRefName baseRefName url";

    let mut selections = String::new();
    for (index, number) in numbers.iter().enumerate() {
        if index > 0 {
            selections.push(' ');
        }
        selections += &format!("pr{index}: pullRequest(number: {number}) {{ {PR_FIELDS} }}");
    }

    let header = "query($owner: String!, $repo: String!)";
    format!("{header} {{ repository(owner: $owner, name: $repo) {{ {selections} }} }}")
}
713
714// === Trait Implementation ===
715
716impl GitHubApi for GitHubClient {
717    async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
718        self.get_pr(owner, repo, number).await
719    }
720
721    async fn get_prs_batch(
722        &self,
723        owner: &str,
724        repo: &str,
725        numbers: &[u64],
726    ) -> Result<std::collections::HashMap<u64, PullRequest>> {
727        self.get_prs_batch(owner, repo, numbers).await
728    }
729
730    async fn find_pr_for_branch(
731        &self,
732        owner: &str,
733        repo: &str,
734        branch: &str,
735    ) -> Result<Option<PullRequest>> {
736        self.find_pr_for_branch(owner, repo, branch).await
737    }
738
739    async fn create_pr(
740        &self,
741        owner: &str,
742        repo: &str,
743        pr: CreatePullRequest,
744    ) -> Result<PullRequest> {
745        self.create_pr(owner, repo, pr).await
746    }
747
748    async fn update_pr(
749        &self,
750        owner: &str,
751        repo: &str,
752        number: u64,
753        update: UpdatePullRequest,
754    ) -> Result<PullRequest> {
755        self.update_pr(owner, repo, number, update).await
756    }
757
758    async fn get_check_runs(
759        &self,
760        owner: &str,
761        repo: &str,
762        commit_sha: &str,
763    ) -> Result<Vec<CheckRun>> {
764        self.get_check_runs(owner, repo, commit_sha).await
765    }
766
767    async fn merge_pr(
768        &self,
769        owner: &str,
770        repo: &str,
771        number: u64,
772        merge: MergePullRequest,
773    ) -> Result<MergeResult> {
774        self.merge_pr(owner, repo, number, merge).await
775    }
776
777    async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
778        self.delete_ref(owner, repo, ref_name).await
779    }
780
781    async fn get_default_branch(&self, owner: &str, repo: &str) -> Result<String> {
782        self.get_default_branch(owner, repo).await
783    }
784
785    async fn list_pr_comments(
786        &self,
787        owner: &str,
788        repo: &str,
789        pr_number: u64,
790    ) -> Result<Vec<IssueComment>> {
791        self.list_pr_comments(owner, repo, pr_number).await
792    }
793
794    async fn create_pr_comment(
795        &self,
796        owner: &str,
797        repo: &str,
798        pr_number: u64,
799        comment: CreateComment,
800    ) -> Result<IssueComment> {
801        self.create_pr_comment(owner, repo, pr_number, comment)
802            .await
803    }
804
805    async fn update_pr_comment(
806        &self,
807        owner: &str,
808        repo: &str,
809        comment_id: u64,
810        comment: UpdateComment,
811    ) -> Result<IssueComment> {
812        self.update_pr_comment(owner, repo, comment_id, comment)
813            .await
814    }
815}
816
817#[cfg(test)]
818#[allow(clippy::unwrap_used)]
819mod tests {
820    use super::*;
821    use crate::types::{CheckStatus, MergeMethod};
822    use secrecy::SecretString;
823    use wiremock::matchers::{header, method, path, query_param};
824    use wiremock::{Mock, MockServer, ResponseTemplate};
825
826    /// Create a test client pointing to the mock server.
827    fn test_client(base_url: &str) -> GitHubClient {
828        let auth = Auth::Token(SecretString::from("test-token"));
829        GitHubClient::with_base_url(&auth, base_url).unwrap()
830    }
831
832    /// Standard PR response JSON for testing.
833    fn pr_response_json(number: u64, state: &str, merged: bool) -> serde_json::Value {
834        serde_json::json!({
835            "number": number,
836            "title": format!("PR #{number}"),
837            "body": "Test body",
838            "state": state,
839            "merged": merged,
840            "draft": false,
841            "html_url": format!("https://github.com/owner/repo/pull/{number}"),
842            "head": { "ref": "feature-branch" },
843            "base": { "ref": "main" },
844            "mergeable": true,
845            "mergeable_state": "clean"
846        })
847    }
848
849    // === GET PR Tests ===
850
851    #[tokio::test]
852    async fn test_get_pr_success() {
853        let mock_server = MockServer::start().await;
854
855        Mock::given(method("GET"))
856            .and(path("/repos/owner/repo/pulls/123"))
857            .and(header("authorization", "Bearer test-token"))
858            .respond_with(
859                ResponseTemplate::new(200).set_body_json(pr_response_json(123, "open", false)),
860            )
861            .mount(&mock_server)
862            .await;
863
864        let client = test_client(&mock_server.uri());
865        let pr = client.get_pr("owner", "repo", 123).await.unwrap();
866
867        assert_eq!(pr.number, 123);
868        assert_eq!(pr.title, "PR #123");
869        assert_eq!(pr.state, PullRequestState::Open);
870        assert_eq!(pr.head_branch, "feature-branch");
871        assert_eq!(pr.base_branch, "main");
872    }
873
874    #[tokio::test]
875    async fn test_get_pr_merged() {
876        let mock_server = MockServer::start().await;
877
878        Mock::given(method("GET"))
879            .and(path("/repos/owner/repo/pulls/456"))
880            .respond_with(
881                ResponseTemplate::new(200).set_body_json(pr_response_json(456, "closed", true)),
882            )
883            .mount(&mock_server)
884            .await;
885
886        let client = test_client(&mock_server.uri());
887        let pr = client.get_pr("owner", "repo", 456).await.unwrap();
888
889        assert_eq!(pr.state, PullRequestState::Merged);
890    }
891
892    #[tokio::test]
893    async fn test_get_pr_closed() {
894        let mock_server = MockServer::start().await;
895
896        Mock::given(method("GET"))
897            .and(path("/repos/owner/repo/pulls/789"))
898            .respond_with(
899                ResponseTemplate::new(200).set_body_json(pr_response_json(789, "closed", false)),
900            )
901            .mount(&mock_server)
902            .await;
903
904        let client = test_client(&mock_server.uri());
905        let pr = client.get_pr("owner", "repo", 789).await.unwrap();
906
907        assert_eq!(pr.state, PullRequestState::Closed);
908    }
909
910    #[tokio::test]
911    async fn test_get_pr_not_found() {
912        let mock_server = MockServer::start().await;
913
914        Mock::given(method("GET"))
915            .and(path("/repos/owner/repo/pulls/999"))
916            .respond_with(ResponseTemplate::new(404).set_body_json(serde_json::json!({
917                "message": "Not Found"
918            })))
919            .mount(&mock_server)
920            .await;
921
922        let client = test_client(&mock_server.uri());
923        let result = client.get_pr("owner", "repo", 999).await;
924
925        assert!(result.is_err());
926        let err = result.unwrap_err();
927        assert!(matches!(err, Error::ApiError { status: 404, .. }));
928    }
929
930    // === Authentication Error Tests ===
931
932    #[tokio::test]
933    async fn test_unauthorized_error() {
934        let mock_server = MockServer::start().await;
935
936        Mock::given(method("GET"))
937            .and(path("/repos/owner/repo/pulls/123"))
938            .respond_with(ResponseTemplate::new(401).set_body_json(serde_json::json!({
939                "message": "Bad credentials"
940            })))
941            .mount(&mock_server)
942            .await;
943
944        let client = test_client(&mock_server.uri());
945        let result = client.get_pr("owner", "repo", 123).await;
946
947        assert!(matches!(result, Err(Error::AuthenticationFailed)));
948    }
949
950    #[tokio::test]
951    async fn test_rate_limited_error() {
952        let mock_server = MockServer::start().await;
953
954        Mock::given(method("GET"))
955            .and(path("/repos/owner/repo/pulls/123"))
956            .respond_with(
957                ResponseTemplate::new(403)
958                    .insert_header("x-ratelimit-remaining", "0")
959                    .set_body_json(serde_json::json!({
960                        "message": "API rate limit exceeded"
961                    })),
962            )
963            .mount(&mock_server)
964            .await;
965
966        let client = test_client(&mock_server.uri());
967        let result = client.get_pr("owner", "repo", 123).await;
968
969        assert!(matches!(result, Err(Error::RateLimited)));
970    }
971
972    // === Find PR for Branch Tests ===
973
974    #[tokio::test]
975    async fn test_find_pr_for_branch_found() {
976        let mock_server = MockServer::start().await;
977
978        Mock::given(method("GET"))
979            .and(path("/repos/owner/repo/pulls"))
980            .and(query_param("head", "owner:feature"))
981            .and(query_param("state", "open"))
982            .respond_with(
983                ResponseTemplate::new(200)
984                    .set_body_json(serde_json::json!([pr_response_json(42, "open", false)])),
985            )
986            .mount(&mock_server)
987            .await;
988
989        let client = test_client(&mock_server.uri());
990        let pr = client
991            .find_pr_for_branch("owner", "repo", "feature")
992            .await
993            .unwrap();
994
995        assert!(pr.is_some());
996        assert_eq!(pr.unwrap().number, 42);
997    }
998
999    #[tokio::test]
1000    async fn test_find_pr_for_branch_not_found() {
1001        let mock_server = MockServer::start().await;
1002
1003        Mock::given(method("GET"))
1004            .and(path("/repos/owner/repo/pulls"))
1005            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!([])))
1006            .mount(&mock_server)
1007            .await;
1008
1009        let client = test_client(&mock_server.uri());
1010        let pr = client
1011            .find_pr_for_branch("owner", "repo", "nonexistent")
1012            .await
1013            .unwrap();
1014
1015        assert!(pr.is_none());
1016    }
1017
1018    // === Create PR Tests ===
1019
1020    #[tokio::test]
1021    async fn test_create_pr_success() {
1022        let mock_server = MockServer::start().await;
1023
1024        Mock::given(method("POST"))
1025            .and(path("/repos/owner/repo/pulls"))
1026            .and(header("authorization", "Bearer test-token"))
1027            .respond_with(
1028                ResponseTemplate::new(201).set_body_json(pr_response_json(100, "open", false)),
1029            )
1030            .mount(&mock_server)
1031            .await;
1032
1033        let client = test_client(&mock_server.uri());
1034        let create_pr = CreatePullRequest {
1035            title: "New Feature".into(),
1036            body: "Description".into(),
1037            head: "feature".into(),
1038            base: "main".into(),
1039            draft: false,
1040        };
1041
1042        let pr = client.create_pr("owner", "repo", create_pr).await.unwrap();
1043
1044        assert_eq!(pr.number, 100);
1045        assert_eq!(pr.state, PullRequestState::Open);
1046    }
1047
1048    // === Update PR Tests ===
1049
1050    #[tokio::test]
1051    async fn test_update_pr_success() {
1052        let mock_server = MockServer::start().await;
1053
1054        Mock::given(method("PATCH"))
1055            .and(path("/repos/owner/repo/pulls/123"))
1056            .respond_with(
1057                ResponseTemplate::new(200).set_body_json(pr_response_json(123, "open", false)),
1058            )
1059            .mount(&mock_server)
1060            .await;
1061
1062        let client = test_client(&mock_server.uri());
1063        let update = UpdatePullRequest {
1064            title: Some("Updated Title".into()),
1065            body: None,
1066            base: None,
1067        };
1068
1069        let pr = client
1070            .update_pr("owner", "repo", 123, update)
1071            .await
1072            .unwrap();
1073
1074        assert_eq!(pr.number, 123);
1075    }
1076
1077    // === Get Check Runs Tests ===
1078
1079    #[tokio::test]
1080    async fn test_get_check_runs_success() {
1081        let mock_server = MockServer::start().await;
1082
1083        Mock::given(method("GET"))
1084            .and(path("/repos/owner/repo/commits/abc123/check-runs"))
1085            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1086                "total_count": 3,
1087                "check_runs": [
1088                    {
1089                        "name": "CI",
1090                        "status": "completed",
1091                        "conclusion": "success",
1092                        "details_url": "https://example.com/ci"
1093                    },
1094                    {
1095                        "name": "Lint",
1096                        "status": "in_progress",
1097                        "conclusion": null,
1098                        "details_url": null
1099                    },
1100                    {
1101                        "name": "Deploy",
1102                        "status": "queued",
1103                        "conclusion": null,
1104                        "details_url": null
1105                    }
1106                ]
1107            })))
1108            .mount(&mock_server)
1109            .await;
1110
1111        let client = test_client(&mock_server.uri());
1112        let checks = client
1113            .get_check_runs("owner", "repo", "abc123")
1114            .await
1115            .unwrap();
1116
1117        assert_eq!(checks.len(), 3);
1118        assert_eq!(checks[0].name, "CI");
1119        assert_eq!(checks[0].status, CheckStatus::Success);
1120        assert_eq!(checks[1].name, "Lint");
1121        assert_eq!(checks[1].status, CheckStatus::InProgress);
1122        assert_eq!(checks[2].name, "Deploy");
1123        assert_eq!(checks[2].status, CheckStatus::Queued);
1124    }
1125
1126    #[tokio::test]
1127    async fn test_get_check_runs_various_statuses() {
1128        let mock_server = MockServer::start().await;
1129
1130        Mock::given(method("GET"))
1131            .and(path("/repos/owner/repo/commits/def456/check-runs"))
1132            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1133                "total_count": 4,
1134                "check_runs": [
1135                    { "name": "skipped", "status": "completed", "conclusion": "skipped", "details_url": null },
1136                    { "name": "cancelled", "status": "completed", "conclusion": "cancelled", "details_url": null },
1137                    { "name": "failure", "status": "completed", "conclusion": "failure", "details_url": null },
1138                    { "name": "timed_out", "status": "completed", "conclusion": "timed_out", "details_url": null }
1139                ]
1140            })))
1141            .mount(&mock_server)
1142            .await;
1143
1144        let client = test_client(&mock_server.uri());
1145        let checks = client
1146            .get_check_runs("owner", "repo", "def456")
1147            .await
1148            .unwrap();
1149
1150        assert_eq!(checks[0].status, CheckStatus::Skipped);
1151        assert_eq!(checks[1].status, CheckStatus::Cancelled);
1152        assert_eq!(checks[2].status, CheckStatus::Failure);
1153        assert_eq!(checks[3].status, CheckStatus::Failure); // timed_out maps to failure
1154    }
1155
1156    // === Merge PR Tests ===
1157
1158    #[tokio::test]
1159    async fn test_merge_pr_success() {
1160        let mock_server = MockServer::start().await;
1161
1162        Mock::given(method("PUT"))
1163            .and(path("/repos/owner/repo/pulls/123/merge"))
1164            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1165                "sha": "abc123def456",
1166                "merged": true,
1167                "message": "Pull Request successfully merged"
1168            })))
1169            .mount(&mock_server)
1170            .await;
1171
1172        let client = test_client(&mock_server.uri());
1173        let merge = MergePullRequest {
1174            commit_title: Some("Merge PR #123".into()),
1175            commit_message: None,
1176            merge_method: MergeMethod::Squash,
1177        };
1178
1179        let result = client.merge_pr("owner", "repo", 123, merge).await.unwrap();
1180
1181        assert!(result.merged);
1182        assert_eq!(result.sha, "abc123def456");
1183    }
1184
1185    // === Delete Ref Tests ===
1186
1187    #[tokio::test]
1188    async fn test_delete_ref_success() {
1189        let mock_server = MockServer::start().await;
1190
1191        Mock::given(method("DELETE"))
1192            .and(path("/repos/owner/repo/git/refs/heads/feature-branch"))
1193            .respond_with(ResponseTemplate::new(204))
1194            .mount(&mock_server)
1195            .await;
1196
1197        let client = test_client(&mock_server.uri());
1198        let result = client.delete_ref("owner", "repo", "feature-branch").await;
1199
1200        assert!(result.is_ok());
1201    }
1202
1203    #[tokio::test]
1204    async fn test_delete_ref_not_found() {
1205        let mock_server = MockServer::start().await;
1206
1207        Mock::given(method("DELETE"))
1208            .and(path("/repos/owner/repo/git/refs/heads/nonexistent"))
1209            .respond_with(ResponseTemplate::new(422).set_body_json(serde_json::json!({
1210                "message": "Reference does not exist"
1211            })))
1212            .mount(&mock_server)
1213            .await;
1214
1215        let client = test_client(&mock_server.uri());
1216        let result = client.delete_ref("owner", "repo", "nonexistent").await;
1217
1218        assert!(result.is_err());
1219    }
1220
1221    #[tokio::test]
1222    async fn test_delete_ref_rate_limited() {
1223        let mock_server = MockServer::start().await;
1224
1225        Mock::given(method("DELETE"))
1226            .and(path("/repos/owner/repo/git/refs/heads/branch"))
1227            .respond_with(
1228                ResponseTemplate::new(403)
1229                    .insert_header("x-ratelimit-remaining", "0")
1230                    .set_body_json(serde_json::json!({ "message": "Rate limited" })),
1231            )
1232            .mount(&mock_server)
1233            .await;
1234
1235        let client = test_client(&mock_server.uri());
1236        let result = client.delete_ref("owner", "repo", "branch").await;
1237
1238        assert!(matches!(result, Err(Error::RateLimited)));
1239    }
1240
1241    // === Get Default Branch Tests ===
1242
1243    #[tokio::test]
1244    async fn test_get_default_branch_success() {
1245        let mock_server = MockServer::start().await;
1246
1247        Mock::given(method("GET"))
1248            .and(path("/repos/owner/repo"))
1249            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1250                "default_branch": "main"
1251            })))
1252            .mount(&mock_server)
1253            .await;
1254
1255        let client = test_client(&mock_server.uri());
1256        let branch = client.get_default_branch("owner", "repo").await.unwrap();
1257
1258        assert_eq!(branch, "main");
1259    }
1260
1261    // === Comment Tests ===
1262
1263    #[tokio::test]
1264    async fn test_list_pr_comments_success() {
1265        let mock_server = MockServer::start().await;
1266
1267        Mock::given(method("GET"))
1268            .and(path("/repos/owner/repo/issues/123/comments"))
1269            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!([
1270                { "id": 1, "body": "First comment" },
1271                { "id": 2, "body": "Second comment" }
1272            ])))
1273            .mount(&mock_server)
1274            .await;
1275
1276        let client = test_client(&mock_server.uri());
1277        let comments = client.list_pr_comments("owner", "repo", 123).await.unwrap();
1278
1279        assert_eq!(comments.len(), 2);
1280        assert_eq!(comments[0].id, 1);
1281        assert_eq!(comments[0].body, Some("First comment".into()));
1282    }
1283
1284    #[tokio::test]
1285    async fn test_create_pr_comment_success() {
1286        let mock_server = MockServer::start().await;
1287
1288        Mock::given(method("POST"))
1289            .and(path("/repos/owner/repo/issues/123/comments"))
1290            .respond_with(ResponseTemplate::new(201).set_body_json(serde_json::json!({
1291                "id": 42,
1292                "body": "New comment"
1293            })))
1294            .mount(&mock_server)
1295            .await;
1296
1297        let client = test_client(&mock_server.uri());
1298        let comment = CreateComment {
1299            body: "New comment".into(),
1300        };
1301
1302        let result = client
1303            .create_pr_comment("owner", "repo", 123, comment)
1304            .await
1305            .unwrap();
1306
1307        assert_eq!(result.id, 42);
1308        assert_eq!(result.body, Some("New comment".into()));
1309    }
1310
1311    #[tokio::test]
1312    async fn test_update_pr_comment_success() {
1313        let mock_server = MockServer::start().await;
1314
1315        Mock::given(method("PATCH"))
1316            .and(path("/repos/owner/repo/issues/comments/42"))
1317            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1318                "id": 42,
1319                "body": "Updated comment"
1320            })))
1321            .mount(&mock_server)
1322            .await;
1323
1324        let client = test_client(&mock_server.uri());
1325        let update = UpdateComment {
1326            body: "Updated comment".into(),
1327        };
1328
1329        let result = client
1330            .update_pr_comment("owner", "repo", 42, update)
1331            .await
1332            .unwrap();
1333
1334        assert_eq!(result.body, Some("Updated comment".into()));
1335    }
1336
1337    // === GraphQL Batch Tests ===
1338
1339    #[tokio::test]
1340    async fn test_get_prs_batch_empty() {
1341        let mock_server = MockServer::start().await;
1342        let client = test_client(&mock_server.uri());
1343
1344        // Empty input should return empty map without making any requests
1345        let result = client.get_prs_batch("owner", "repo", &[]).await.unwrap();
1346        assert!(result.is_empty());
1347    }
1348
1349    #[tokio::test]
1350    async fn test_get_prs_batch_success() {
1351        let mock_server = MockServer::start().await;
1352
1353        Mock::given(method("POST"))
1354            .and(path("/graphql"))
1355            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1356                "data": {
1357                    "repository": {
1358                        "pr0": {
1359                            "number": 1,
1360                            "state": "OPEN",
1361                            "merged": false,
1362                            "isDraft": false,
1363                            "headRefName": "feature-1",
1364                            "baseRefName": "main",
1365                            "url": "https://github.com/owner/repo/pull/1"
1366                        },
1367                        "pr1": {
1368                            "number": 2,
1369                            "state": "MERGED",
1370                            "merged": true,
1371                            "isDraft": false,
1372                            "headRefName": "feature-2",
1373                            "baseRefName": "main",
1374                            "url": "https://github.com/owner/repo/pull/2"
1375                        },
1376                        "pr2": null
1377                    }
1378                }
1379            })))
1380            .mount(&mock_server)
1381            .await;
1382
1383        let client = test_client(&mock_server.uri());
1384        let result = client
1385            .get_prs_batch("owner", "repo", &[1, 2, 999])
1386            .await
1387            .unwrap();
1388
1389        assert_eq!(result.len(), 2);
1390        assert_eq!(result.get(&1).unwrap().state, PullRequestState::Open);
1391        assert_eq!(result.get(&2).unwrap().state, PullRequestState::Merged);
1392        assert!(!result.contains_key(&999)); // PR 999 was null
1393    }
1394
1395    #[tokio::test]
1396    async fn test_get_prs_batch_graphql_error() {
1397        let mock_server = MockServer::start().await;
1398
1399        Mock::given(method("POST"))
1400            .and(path("/graphql"))
1401            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
1402                "data": null,
1403                "errors": [
1404                    { "message": "Something went wrong" }
1405                ]
1406            })))
1407            .mount(&mock_server)
1408            .await;
1409
1410        let client = test_client(&mock_server.uri());
1411        let result = client.get_prs_batch("owner", "repo", &[1]).await;
1412
1413        assert!(result.is_err());
1414        let err = result.unwrap_err();
1415        assert!(matches!(err, Error::ApiError { status: 200, .. }));
1416    }
1417
1418    #[tokio::test]
1419    async fn test_get_prs_batch_auth_error() {
1420        let mock_server = MockServer::start().await;
1421
1422        Mock::given(method("POST"))
1423            .and(path("/graphql"))
1424            .respond_with(ResponseTemplate::new(401).set_body_json(serde_json::json!({
1425                "message": "Bad credentials"
1426            })))
1427            .mount(&mock_server)
1428            .await;
1429
1430        let client = test_client(&mock_server.uri());
1431        let result = client.get_prs_batch("owner", "repo", &[1]).await;
1432
1433        assert!(matches!(result, Err(Error::AuthenticationFailed)));
1434    }
1435
1436    // === Helper Function Tests ===
1437
1438    #[test]
1439    fn test_build_graphql_pr_query() {
1440        let query = build_graphql_pr_query(&[1, 42, 100]);
1441
1442        assert!(query.contains("pr0: pullRequest(number: 1)"));
1443        assert!(query.contains("pr1: pullRequest(number: 42)"));
1444        assert!(query.contains("pr2: pullRequest(number: 100)"));
1445        assert!(query.contains("$owner: String!"));
1446        assert!(query.contains("$repo: String!"));
1447    }
1448
1449    // === Debug Implementation Test ===
1450
1451    #[test]
1452    fn test_github_client_debug_redacts_token() {
1453        let auth = Auth::Token(SecretString::from("super-secret-token"));
1454        let client = GitHubClient::with_base_url(&auth, "https://api.example.com").unwrap();
1455
1456        let debug_output = format!("{client:?}");
1457
1458        assert!(debug_output.contains("[redacted]"));
1459        assert!(!debug_output.contains("super-secret-token"));
1460    }
1461}