1use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use secrecy::{ExposeSecret, SecretString};
6use serde::de::DeserializeOwned;
7
8use crate::auth::Auth;
9use crate::error::{Error, Result};
10use crate::traits::GitHubApi;
11use crate::types::{
12 CheckRun, CreateComment, CreatePullRequest, IssueComment, MergePullRequest, MergeResult,
13 PullRequest, PullRequestState, UpdateComment, UpdatePullRequest,
14};
15
16#[derive(serde::Deserialize)]
20struct ApiPullRequest {
21 number: u64,
22 title: String,
23 body: Option<String>,
24 state: String,
25 #[serde(default)]
27 merged: bool,
28 draft: bool,
29 html_url: String,
30 head: ApiBranch,
31 base: ApiBranch,
32 mergeable: Option<bool>,
34 mergeable_state: Option<String>,
36}
37
38#[derive(serde::Deserialize)]
40struct ApiBranch {
41 #[serde(rename = "ref")]
42 ref_name: String,
43}
44
45impl ApiPullRequest {
46 fn into_pull_request(self) -> PullRequest {
48 let state = if self.merged {
50 PullRequestState::Merged
51 } else {
52 match self.state.as_str() {
53 "open" => PullRequestState::Open,
54 _ => PullRequestState::Closed,
55 }
56 };
57
58 PullRequest {
59 number: self.number,
60 title: self.title,
61 body: self.body,
62 state,
63 draft: self.draft,
64 head_branch: self.head.ref_name,
65 base_branch: self.base.ref_name,
66 html_url: self.html_url,
67 mergeable: self.mergeable,
68 mergeable_state: self.mergeable_state,
69 }
70 }
71
72 fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
74 PullRequest {
75 number: self.number,
76 title: self.title,
77 body: self.body,
78 state,
79 draft: self.draft,
80 head_branch: self.head.ref_name,
81 base_branch: self.base.ref_name,
82 html_url: self.html_url,
83 mergeable: self.mergeable,
84 mergeable_state: self.mergeable_state,
85 }
86 }
87}
88
89#[derive(serde::Serialize)]
93struct GraphQLRequest {
94 query: String,
95 variables: GraphQLVariables,
96}
97
98#[derive(serde::Serialize)]
100struct GraphQLVariables {
101 owner: String,
102 repo: String,
103}
104
105#[derive(serde::Deserialize)]
107#[serde(rename_all = "camelCase")]
108struct GraphQLPullRequest {
109 number: u64,
110 state: String,
111 merged: bool,
112 is_draft: bool,
113 head_ref_name: String,
114 base_ref_name: String,
115 url: String,
116}
117
118impl GraphQLPullRequest {
119 fn into_pull_request(self) -> PullRequest {
120 let state = if self.merged {
121 PullRequestState::Merged
122 } else if self.state == "OPEN" {
123 PullRequestState::Open
124 } else {
125 PullRequestState::Closed
126 };
127
128 PullRequest {
129 number: self.number,
130 title: String::new(), body: None,
132 state,
133 draft: self.is_draft,
134 head_branch: self.head_ref_name,
135 base_branch: self.base_ref_name,
136 html_url: self.url,
137 mergeable: None, mergeable_state: None,
139 }
140 }
141}
142
143#[derive(serde::Deserialize)]
144struct GraphQLResponse {
145 data: Option<GraphQLData>,
146 errors: Option<Vec<GraphQLError>>,
147}
148
149#[derive(serde::Deserialize)]
150struct GraphQLData {
151 repository: Option<serde_json::Value>,
152}
153
154#[derive(serde::Deserialize)]
155struct GraphQLError {
156 message: String,
157}
158
159pub struct GitHubClient {
161 client: Client,
162 base_url: String,
163 token: SecretString,
165}
166
167impl GitHubClient {
168 pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
170
171 pub fn new(auth: &Auth) -> Result<Self> {
176 Self::with_base_url(auth, Self::DEFAULT_API_URL)
177 }
178
179 pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
184 let token = auth.resolve()?;
185
186 let mut headers = HeaderMap::new();
187 headers.insert(
188 ACCEPT,
189 HeaderValue::from_static("application/vnd.github+json"),
190 );
191 headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
192 headers.insert(
193 "X-GitHub-Api-Version",
194 HeaderValue::from_static("2022-11-28"),
195 );
196
197 let client = Client::builder().default_headers(headers).build()?;
198
199 Ok(Self {
200 client,
201 base_url: base_url.into(),
202 token,
203 })
204 }
205
206 async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
208 let url = format!("{}{}", self.base_url, path);
209 let response = self
210 .client
211 .get(&url)
212 .header(
213 AUTHORIZATION,
214 format!("Bearer {}", self.token.expose_secret()),
215 )
216 .send()
217 .await?;
218
219 self.handle_response(response).await
220 }
221
222 async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
224 &self,
225 path: &str,
226 body: &B,
227 ) -> Result<T> {
228 let url = format!("{}{}", self.base_url, path);
229 let response = self
230 .client
231 .post(&url)
232 .header(
233 AUTHORIZATION,
234 format!("Bearer {}", self.token.expose_secret()),
235 )
236 .json(body)
237 .send()
238 .await?;
239
240 self.handle_response(response).await
241 }
242
243 async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
245 &self,
246 path: &str,
247 body: &B,
248 ) -> Result<T> {
249 let url = format!("{}{}", self.base_url, path);
250 let response = self
251 .client
252 .patch(&url)
253 .header(
254 AUTHORIZATION,
255 format!("Bearer {}", self.token.expose_secret()),
256 )
257 .json(body)
258 .send()
259 .await?;
260
261 self.handle_response(response).await
262 }
263
264 async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
266 &self,
267 path: &str,
268 body: &B,
269 ) -> Result<T> {
270 let url = format!("{}{}", self.base_url, path);
271 let response = self
272 .client
273 .put(&url)
274 .header(
275 AUTHORIZATION,
276 format!("Bearer {}", self.token.expose_secret()),
277 )
278 .json(body)
279 .send()
280 .await?;
281
282 self.handle_response(response).await
283 }
284
285 async fn delete(&self, path: &str) -> Result<()> {
287 let url = format!("{}{}", self.base_url, path);
288 let response = self
289 .client
290 .delete(&url)
291 .header(
292 AUTHORIZATION,
293 format!("Bearer {}", self.token.expose_secret()),
294 )
295 .send()
296 .await?;
297
298 let status = response.status();
299 if status.is_success() || status.as_u16() == 204 {
300 return Ok(());
301 }
302
303 let status_code = status.as_u16();
304 match status_code {
305 401 => Err(Error::AuthenticationFailed),
306 403 if response
307 .headers()
308 .get("x-ratelimit-remaining")
309 .is_some_and(|v| v == "0") =>
310 {
311 Err(Error::RateLimited)
312 }
313 _ => {
314 let text = response.text().await.unwrap_or_default();
315 Err(Error::ApiError {
316 status: status_code,
317 message: text,
318 })
319 }
320 }
321 }
322
323 async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
325 let status = response.status();
326
327 if status.is_success() {
328 let body = response.json().await?;
329 return Ok(body);
330 }
331
332 let status_code = status.as_u16();
334
335 match status_code {
336 401 => Err(Error::AuthenticationFailed),
337 403 if response
338 .headers()
339 .get("x-ratelimit-remaining")
340 .is_some_and(|v| v == "0") =>
341 {
342 Err(Error::RateLimited)
343 }
344 _ => {
345 let text = response.text().await.unwrap_or_default();
346 Err(Error::ApiError {
347 status: status_code,
348 message: text,
349 })
350 }
351 }
352 }
353
354 pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
361 let api_pr: ApiPullRequest = self
362 .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
363 .await?;
364
365 Ok(api_pr.into_pull_request())
366 }
367
368 pub async fn get_prs_batch(
379 &self,
380 owner: &str,
381 repo: &str,
382 numbers: &[u64],
383 ) -> Result<std::collections::HashMap<u64, PullRequest>> {
384 if numbers.is_empty() {
385 return Ok(std::collections::HashMap::new());
386 }
387
388 let query = build_graphql_pr_query(numbers);
389 let request = GraphQLRequest {
390 query,
391 variables: GraphQLVariables {
392 owner: owner.to_string(),
393 repo: repo.to_string(),
394 },
395 };
396 let url = format!("{}/graphql", self.base_url);
397
398 let response = self
399 .client
400 .post(&url)
401 .header(
402 AUTHORIZATION,
403 format!("Bearer {}", self.token.expose_secret()),
404 )
405 .json(&request)
406 .send()
407 .await?;
408
409 let status = response.status();
410 if !status.is_success() {
411 let status_code = status.as_u16();
412 return match status_code {
413 401 => Err(Error::AuthenticationFailed),
414 403 if response
415 .headers()
416 .get("x-ratelimit-remaining")
417 .is_some_and(|v| v == "0") =>
418 {
419 Err(Error::RateLimited)
420 }
421 _ => {
422 let text = response.text().await.unwrap_or_default();
423 Err(Error::ApiError {
424 status: status_code,
425 message: text,
426 })
427 }
428 };
429 }
430
431 let graphql_response: GraphQLResponse = response.json().await?;
432
433 if graphql_response.data.is_none() {
435 if let Some(errors) = graphql_response.errors
436 && !errors.is_empty()
437 {
438 let messages: Vec<_> = errors.iter().map(|e| e.message.as_str()).collect();
439 return Err(Error::ApiError {
440 status: 200,
441 message: messages.join("; "),
442 });
443 }
444 return Ok(std::collections::HashMap::new());
446 }
447
448 let mut result = std::collections::HashMap::new();
449
450 if let Some(data) = graphql_response.data {
451 if let Some(repo_data) = data.repository {
452 for (i, &num) in numbers.iter().enumerate() {
454 let key = format!("pr{i}");
455 if let Some(pr_value) = repo_data.get(&key)
456 && !pr_value.is_null()
457 && let Ok(pr) =
458 serde_json::from_value::<GraphQLPullRequest>(pr_value.clone())
459 {
460 result.insert(num, pr.into_pull_request());
461 }
462 }
463 } else if let Some(errors) = graphql_response.errors
464 && !errors.is_empty()
465 {
466 let messages: Vec<_> = errors.iter().map(|e| e.message.as_str()).collect();
468 return Err(Error::ApiError {
469 status: 200,
470 message: messages.join("; "),
471 });
472 }
473 }
474
475 Ok(result)
476 }
477
478 pub async fn find_pr_for_branch(
483 &self,
484 owner: &str,
485 repo: &str,
486 branch: &str,
487 ) -> Result<Option<PullRequest>> {
488 let prs: Vec<ApiPullRequest> = self
490 .get(&format!(
491 "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
492 ))
493 .await?;
494
495 Ok(prs
496 .into_iter()
497 .next()
498 .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
499 }
500
501 pub async fn create_pr(
506 &self,
507 owner: &str,
508 repo: &str,
509 pr: CreatePullRequest,
510 ) -> Result<PullRequest> {
511 let api_pr: ApiPullRequest = self
513 .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
514 .await?;
515
516 Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
517 }
518
519 pub async fn update_pr(
524 &self,
525 owner: &str,
526 repo: &str,
527 number: u64,
528 update: UpdatePullRequest,
529 ) -> Result<PullRequest> {
530 let api_pr: ApiPullRequest = self
531 .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
532 .await?;
533
534 Ok(api_pr.into_pull_request())
535 }
536
537 pub async fn get_check_runs(
544 &self,
545 owner: &str,
546 repo: &str,
547 commit_sha: &str,
548 ) -> Result<Vec<CheckRun>> {
549 #[derive(serde::Deserialize)]
550 struct Response {
551 check_runs: Vec<ApiCheckRun>,
552 }
553
554 #[derive(serde::Deserialize)]
555 struct ApiCheckRun {
556 name: String,
557 status: String,
558 conclusion: Option<String>,
559 details_url: Option<String>,
560 }
561
562 let response: Response = self
563 .get(&format!(
564 "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
565 ))
566 .await?;
567
568 Ok(response
569 .check_runs
570 .into_iter()
571 .map(|cr| CheckRun {
572 name: cr.name,
573 status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
574 ("queued", _) => crate::types::CheckStatus::Queued,
575 ("in_progress", _) => crate::types::CheckStatus::InProgress,
576 ("completed", Some("success")) => crate::types::CheckStatus::Success,
577 ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
578 ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
579 _ => crate::types::CheckStatus::Failure,
581 },
582 details_url: cr.details_url,
583 })
584 .collect())
585 }
586
587 pub async fn merge_pr(
594 &self,
595 owner: &str,
596 repo: &str,
597 number: u64,
598 merge: MergePullRequest,
599 ) -> Result<MergeResult> {
600 self.put(
601 &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
602 &merge,
603 )
604 .await
605 }
606
607 pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
614 self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
615 .await
616 }
617
618 pub async fn get_default_branch(&self, owner: &str, repo: &str) -> Result<String> {
625 #[derive(serde::Deserialize)]
626 struct RepoInfo {
627 default_branch: String,
628 }
629
630 let info: RepoInfo = self.get(&format!("/repos/{owner}/{repo}")).await?;
631 Ok(info.default_branch)
632 }
633
634 pub async fn list_pr_comments(
641 &self,
642 owner: &str,
643 repo: &str,
644 pr_number: u64,
645 ) -> Result<Vec<crate::types::IssueComment>> {
646 self.get(&format!(
647 "/repos/{owner}/{repo}/issues/{pr_number}/comments"
648 ))
649 .await
650 }
651
652 pub async fn create_pr_comment(
657 &self,
658 owner: &str,
659 repo: &str,
660 pr_number: u64,
661 comment: crate::types::CreateComment,
662 ) -> Result<crate::types::IssueComment> {
663 self.post(
664 &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
665 &comment,
666 )
667 .await
668 }
669
670 pub async fn update_pr_comment(
675 &self,
676 owner: &str,
677 repo: &str,
678 comment_id: u64,
679 comment: crate::types::UpdateComment,
680 ) -> Result<crate::types::IssueComment> {
681 self.patch(
682 &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
683 &comment,
684 )
685 .await
686 }
687}
688
689impl std::fmt::Debug for GitHubClient {
690 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
691 f.debug_struct("GitHubClient")
692 .field("base_url", &self.base_url)
693 .field("token", &"[redacted]")
694 .finish_non_exhaustive()
695 }
696}
697
/// Builds a GraphQL query that fetches every PR in `numbers` in one round-trip.
///
/// Each PR is selected under the alias `pr{index}` (index into `numbers`), and the
/// query expects `$owner` / `$repo` variables for the repository.
fn build_graphql_pr_query(numbers: &[u64]) -> String {
    const SELECTED_FIELDS: &str = "number state merged isDraft headRefName baseRefName url";

    let mut query = String::from(
        "query($owner: String!, $repo: String!) { repository(owner: $owner, name: $repo) { ",
    );
    for (alias_index, pr_number) in numbers.iter().enumerate() {
        if alias_index != 0 {
            query.push(' ');
        }
        query.push_str(&format!(
            "pr{alias_index}: pullRequest(number: {pr_number}) {{ {SELECTED_FIELDS} }}"
        ));
    }
    query.push_str(" } }");
    query
}
713
714impl GitHubApi for GitHubClient {
717 async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
718 self.get_pr(owner, repo, number).await
719 }
720
721 async fn get_prs_batch(
722 &self,
723 owner: &str,
724 repo: &str,
725 numbers: &[u64],
726 ) -> Result<std::collections::HashMap<u64, PullRequest>> {
727 self.get_prs_batch(owner, repo, numbers).await
728 }
729
730 async fn find_pr_for_branch(
731 &self,
732 owner: &str,
733 repo: &str,
734 branch: &str,
735 ) -> Result<Option<PullRequest>> {
736 self.find_pr_for_branch(owner, repo, branch).await
737 }
738
739 async fn create_pr(
740 &self,
741 owner: &str,
742 repo: &str,
743 pr: CreatePullRequest,
744 ) -> Result<PullRequest> {
745 self.create_pr(owner, repo, pr).await
746 }
747
748 async fn update_pr(
749 &self,
750 owner: &str,
751 repo: &str,
752 number: u64,
753 update: UpdatePullRequest,
754 ) -> Result<PullRequest> {
755 self.update_pr(owner, repo, number, update).await
756 }
757
758 async fn get_check_runs(
759 &self,
760 owner: &str,
761 repo: &str,
762 commit_sha: &str,
763 ) -> Result<Vec<CheckRun>> {
764 self.get_check_runs(owner, repo, commit_sha).await
765 }
766
767 async fn merge_pr(
768 &self,
769 owner: &str,
770 repo: &str,
771 number: u64,
772 merge: MergePullRequest,
773 ) -> Result<MergeResult> {
774 self.merge_pr(owner, repo, number, merge).await
775 }
776
777 async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
778 self.delete_ref(owner, repo, ref_name).await
779 }
780
781 async fn get_default_branch(&self, owner: &str, repo: &str) -> Result<String> {
782 self.get_default_branch(owner, repo).await
783 }
784
785 async fn list_pr_comments(
786 &self,
787 owner: &str,
788 repo: &str,
789 pr_number: u64,
790 ) -> Result<Vec<IssueComment>> {
791 self.list_pr_comments(owner, repo, pr_number).await
792 }
793
794 async fn create_pr_comment(
795 &self,
796 owner: &str,
797 repo: &str,
798 pr_number: u64,
799 comment: CreateComment,
800 ) -> Result<IssueComment> {
801 self.create_pr_comment(owner, repo, pr_number, comment)
802 .await
803 }
804
805 async fn update_pr_comment(
806 &self,
807 owner: &str,
808 repo: &str,
809 comment_id: u64,
810 comment: UpdateComment,
811 ) -> Result<IssueComment> {
812 self.update_pr_comment(owner, repo, comment_id, comment)
813 .await
814 }
815}
816
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::types::{CheckStatus, MergeMethod};
    use secrecy::SecretString;
    use wiremock::matchers::{header, method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Client pointed at the mock server, authenticated with a fixed test token.
    fn test_client(base_url: &str) -> GitHubClient {
        let auth = Auth::Token(SecretString::from("test-token"));
        GitHubClient::with_base_url(&auth, base_url).unwrap()
    }

    /// Minimal REST pull-request payload accepted by `ApiPullRequest`.
    fn pr_response_json(number: u64, state: &str, merged: bool) -> serde_json::Value {
        serde_json::json!({
            "number": number,
            "title": format!("PR #{number}"),
            "body": "Test body",
            "state": state,
            "merged": merged,
            "draft": false,
            "html_url": format!("https://github.com/owner/repo/pull/{number}"),
            "head": { "ref": "feature-branch" },
            "base": { "ref": "main" },
            "mergeable": true,
            "mergeable_state": "clean"
        })
    }

    // === get_pr ===

    #[tokio::test]
    async fn test_get_pr_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/pulls/123"))
            .and(header("authorization", "Bearer test-token"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(pr_response_json(123, "open", false)),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let pr = client.get_pr("owner", "repo", 123).await.unwrap();

        assert_eq!(pr.number, 123);
        assert_eq!(pr.title, "PR #123");
        assert_eq!(pr.state, PullRequestState::Open);
        assert_eq!(pr.head_branch, "feature-branch");
        assert_eq!(pr.base_branch, "main");
    }

    #[tokio::test]
    async fn test_get_pr_merged() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/pulls/456"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(pr_response_json(456, "closed", true)),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let pr = client.get_pr("owner", "repo", 456).await.unwrap();

        assert_eq!(pr.state, PullRequestState::Merged);
    }

    #[tokio::test]
    async fn test_get_pr_closed() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/pulls/789"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(pr_response_json(789, "closed", false)),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let pr = client.get_pr("owner", "repo", 789).await.unwrap();

        assert_eq!(pr.state, PullRequestState::Closed);
    }

    #[tokio::test]
    async fn test_get_pr_not_found() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/pulls/999"))
            .respond_with(ResponseTemplate::new(404).set_body_json(serde_json::json!({
                "message": "Not Found"
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.get_pr("owner", "repo", 999).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, Error::ApiError { status: 404, .. }));
    }

    // === Error mapping ===

    #[tokio::test]
    async fn test_unauthorized_error() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/pulls/123"))
            .respond_with(ResponseTemplate::new(401).set_body_json(serde_json::json!({
                "message": "Bad credentials"
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.get_pr("owner", "repo", 123).await;

        assert!(matches!(result, Err(Error::AuthenticationFailed)));
    }

    #[tokio::test]
    async fn test_rate_limited_error() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/pulls/123"))
            .respond_with(
                ResponseTemplate::new(403)
                    .insert_header("x-ratelimit-remaining", "0")
                    .set_body_json(serde_json::json!({
                        "message": "API rate limit exceeded"
                    })),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.get_pr("owner", "repo", 123).await;

        assert!(matches!(result, Err(Error::RateLimited)));
    }

    #[tokio::test]
    async fn test_forbidden_with_remaining_quota_is_api_error() {
        // A 403 is only a rate limit when the remaining quota is exhausted.
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/pulls/123"))
            .respond_with(
                ResponseTemplate::new(403)
                    .insert_header("x-ratelimit-remaining", "42")
                    .set_body_json(serde_json::json!({
                        "message": "Resource not accessible by integration"
                    })),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.get_pr("owner", "repo", 123).await;

        assert!(matches!(result, Err(Error::ApiError { status: 403, .. })));
    }

    // === find_pr_for_branch ===

    #[tokio::test]
    async fn test_find_pr_for_branch_found() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/pulls"))
            .and(query_param("head", "owner:feature"))
            .and(query_param("state", "open"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!([pr_response_json(42, "open", false)])),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let pr = client
            .find_pr_for_branch("owner", "repo", "feature")
            .await
            .unwrap();

        assert!(pr.is_some());
        assert_eq!(pr.unwrap().number, 42);
    }

    #[tokio::test]
    async fn test_find_pr_for_branch_not_found() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/pulls"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!([])))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let pr = client
            .find_pr_for_branch("owner", "repo", "nonexistent")
            .await
            .unwrap();

        assert!(pr.is_none());
    }

    // === create_pr / update_pr ===

    #[tokio::test]
    async fn test_create_pr_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/repos/owner/repo/pulls"))
            .and(header("authorization", "Bearer test-token"))
            .respond_with(
                ResponseTemplate::new(201).set_body_json(pr_response_json(100, "open", false)),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let create_pr = CreatePullRequest {
            title: "New Feature".into(),
            body: "Description".into(),
            head: "feature".into(),
            base: "main".into(),
            draft: false,
        };

        let pr = client.create_pr("owner", "repo", create_pr).await.unwrap();

        assert_eq!(pr.number, 100);
        assert_eq!(pr.state, PullRequestState::Open);
    }

    #[tokio::test]
    async fn test_update_pr_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("PATCH"))
            .and(path("/repos/owner/repo/pulls/123"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(pr_response_json(123, "open", false)),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let update = UpdatePullRequest {
            title: Some("Updated Title".into()),
            body: None,
            base: None,
        };

        let pr = client
            .update_pr("owner", "repo", 123, update)
            .await
            .unwrap();

        assert_eq!(pr.number, 123);
    }

    // === get_check_runs ===

    #[tokio::test]
    async fn test_get_check_runs_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/commits/abc123/check-runs"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "total_count": 3,
                "check_runs": [
                    {
                        "name": "CI",
                        "status": "completed",
                        "conclusion": "success",
                        "details_url": "https://example.com/ci"
                    },
                    {
                        "name": "Lint",
                        "status": "in_progress",
                        "conclusion": null,
                        "details_url": null
                    },
                    {
                        "name": "Deploy",
                        "status": "queued",
                        "conclusion": null,
                        "details_url": null
                    }
                ]
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let checks = client
            .get_check_runs("owner", "repo", "abc123")
            .await
            .unwrap();

        assert_eq!(checks.len(), 3);
        assert_eq!(checks[0].name, "CI");
        assert_eq!(checks[0].status, CheckStatus::Success);
        assert_eq!(checks[1].name, "Lint");
        assert_eq!(checks[1].status, CheckStatus::InProgress);
        assert_eq!(checks[2].name, "Deploy");
        assert_eq!(checks[2].status, CheckStatus::Queued);
    }

    #[tokio::test]
    async fn test_get_check_runs_various_statuses() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/commits/def456/check-runs"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "total_count": 4,
                "check_runs": [
                    { "name": "skipped", "status": "completed", "conclusion": "skipped", "details_url": null },
                    { "name": "cancelled", "status": "completed", "conclusion": "cancelled", "details_url": null },
                    { "name": "failure", "status": "completed", "conclusion": "failure", "details_url": null },
                    { "name": "timed_out", "status": "completed", "conclusion": "timed_out", "details_url": null }
                ]
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let checks = client
            .get_check_runs("owner", "repo", "def456")
            .await
            .unwrap();

        assert_eq!(checks[0].status, CheckStatus::Skipped);
        assert_eq!(checks[1].status, CheckStatus::Cancelled);
        assert_eq!(checks[2].status, CheckStatus::Failure);
        // Unrecognized failure-like conclusions (timed_out) fall back to Failure.
        assert_eq!(checks[3].status, CheckStatus::Failure);
    }

    // === merge_pr ===

    #[tokio::test]
    async fn test_merge_pr_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("PUT"))
            .and(path("/repos/owner/repo/pulls/123/merge"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "sha": "abc123def456",
                "merged": true,
                "message": "Pull Request successfully merged"
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let merge = MergePullRequest {
            commit_title: Some("Merge PR #123".into()),
            commit_message: None,
            merge_method: MergeMethod::Squash,
        };

        let result = client.merge_pr("owner", "repo", 123, merge).await.unwrap();

        assert!(result.merged);
        assert_eq!(result.sha, "abc123def456");
    }

    // === delete_ref ===

    #[tokio::test]
    async fn test_delete_ref_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("DELETE"))
            .and(path("/repos/owner/repo/git/refs/heads/feature-branch"))
            .respond_with(ResponseTemplate::new(204))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.delete_ref("owner", "repo", "feature-branch").await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_delete_ref_not_found() {
        let mock_server = MockServer::start().await;

        Mock::given(method("DELETE"))
            .and(path("/repos/owner/repo/git/refs/heads/nonexistent"))
            .respond_with(ResponseTemplate::new(422).set_body_json(serde_json::json!({
                "message": "Reference does not exist"
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.delete_ref("owner", "repo", "nonexistent").await;

        assert!(matches!(result, Err(Error::ApiError { status: 422, .. })));
    }

    #[tokio::test]
    async fn test_delete_ref_rate_limited() {
        let mock_server = MockServer::start().await;

        Mock::given(method("DELETE"))
            .and(path("/repos/owner/repo/git/refs/heads/branch"))
            .respond_with(
                ResponseTemplate::new(403)
                    .insert_header("x-ratelimit-remaining", "0")
                    .set_body_json(serde_json::json!({ "message": "Rate limited" })),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.delete_ref("owner", "repo", "branch").await;

        assert!(matches!(result, Err(Error::RateLimited)));
    }

    #[tokio::test]
    async fn test_delete_ref_unauthorized() {
        let mock_server = MockServer::start().await;

        Mock::given(method("DELETE"))
            .and(path("/repos/owner/repo/git/refs/heads/branch"))
            .respond_with(ResponseTemplate::new(401).set_body_json(serde_json::json!({
                "message": "Bad credentials"
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.delete_ref("owner", "repo", "branch").await;

        assert!(matches!(result, Err(Error::AuthenticationFailed)));
    }

    // === get_default_branch ===

    #[tokio::test]
    async fn test_get_default_branch_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "default_branch": "main"
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let branch = client.get_default_branch("owner", "repo").await.unwrap();

        assert_eq!(branch, "main");
    }

    // === PR comments ===

    #[tokio::test]
    async fn test_list_pr_comments_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/repos/owner/repo/issues/123/comments"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!([
                { "id": 1, "body": "First comment" },
                { "id": 2, "body": "Second comment" }
            ])))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let comments = client.list_pr_comments("owner", "repo", 123).await.unwrap();

        assert_eq!(comments.len(), 2);
        assert_eq!(comments[0].id, 1);
        assert_eq!(comments[0].body, Some("First comment".into()));
    }

    #[tokio::test]
    async fn test_create_pr_comment_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/repos/owner/repo/issues/123/comments"))
            .respond_with(ResponseTemplate::new(201).set_body_json(serde_json::json!({
                "id": 42,
                "body": "New comment"
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let comment = CreateComment {
            body: "New comment".into(),
        };

        let result = client
            .create_pr_comment("owner", "repo", 123, comment)
            .await
            .unwrap();

        assert_eq!(result.id, 42);
        assert_eq!(result.body, Some("New comment".into()));
    }

    #[tokio::test]
    async fn test_update_pr_comment_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("PATCH"))
            .and(path("/repos/owner/repo/issues/comments/42"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "id": 42,
                "body": "Updated comment"
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let update = UpdateComment {
            body: "Updated comment".into(),
        };

        let result = client
            .update_pr_comment("owner", "repo", 42, update)
            .await
            .unwrap();

        assert_eq!(result.body, Some("Updated comment".into()));
    }

    // === get_prs_batch (GraphQL) ===

    #[tokio::test]
    async fn test_get_prs_batch_empty() {
        // No mock is mounted: an empty input must not issue any request.
        let mock_server = MockServer::start().await;
        let client = test_client(&mock_server.uri());

        let result = client.get_prs_batch("owner", "repo", &[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_get_prs_batch_success() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/graphql"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": {
                    "repository": {
                        "pr0": {
                            "number": 1,
                            "state": "OPEN",
                            "merged": false,
                            "isDraft": false,
                            "headRefName": "feature-1",
                            "baseRefName": "main",
                            "url": "https://github.com/owner/repo/pull/1"
                        },
                        "pr1": {
                            "number": 2,
                            "state": "MERGED",
                            "merged": true,
                            "isDraft": false,
                            "headRefName": "feature-2",
                            "baseRefName": "main",
                            "url": "https://github.com/owner/repo/pull/2"
                        },
                        "pr2": null
                    }
                }
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client
            .get_prs_batch("owner", "repo", &[1, 2, 999])
            .await
            .unwrap();

        assert_eq!(result.len(), 2);
        assert_eq!(result.get(&1).unwrap().state, PullRequestState::Open);
        assert_eq!(result.get(&2).unwrap().state, PullRequestState::Merged);
        // A `null` node (nonexistent PR) is omitted rather than treated as an error.
        assert!(!result.contains_key(&999));
    }

    #[tokio::test]
    async fn test_get_prs_batch_graphql_error() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/graphql"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": null,
                "errors": [
                    { "message": "Something went wrong" }
                ]
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.get_prs_batch("owner", "repo", &[1]).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, Error::ApiError { status: 200, .. }));
    }

    #[tokio::test]
    async fn test_get_prs_batch_null_repository_joins_errors() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/graphql"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "data": { "repository": null },
                "errors": [
                    { "message": "Could not resolve to a Repository" },
                    { "message": "Second error" }
                ]
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let err = client
            .get_prs_batch("owner", "repo", &[1])
            .await
            .unwrap_err();

        match err {
            Error::ApiError { status, message } => {
                assert_eq!(status, 200);
                assert_eq!(message, "Could not resolve to a Repository; Second error");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_get_prs_batch_auth_error() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/graphql"))
            .respond_with(ResponseTemplate::new(401).set_body_json(serde_json::json!({
                "message": "Bad credentials"
            })))
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.get_prs_batch("owner", "repo", &[1]).await;

        assert!(matches!(result, Err(Error::AuthenticationFailed)));
    }

    #[tokio::test]
    async fn test_get_prs_batch_rate_limited() {
        let mock_server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/graphql"))
            .respond_with(
                ResponseTemplate::new(403)
                    .insert_header("x-ratelimit-remaining", "0")
                    .set_body_json(serde_json::json!({ "message": "API rate limit exceeded" })),
            )
            .mount(&mock_server)
            .await;

        let client = test_client(&mock_server.uri());
        let result = client.get_prs_batch("owner", "repo", &[1]).await;

        assert!(matches!(result, Err(Error::RateLimited)));
    }

    // === Helpers ===

    #[test]
    fn test_build_graphql_pr_query() {
        let query = build_graphql_pr_query(&[1, 42, 100]);

        assert!(query.contains("pr0: pullRequest(number: 1)"));
        assert!(query.contains("pr1: pullRequest(number: 42)"));
        assert!(query.contains("pr2: pullRequest(number: 100)"));
        assert!(query.contains("$owner: String!"));
        assert!(query.contains("$repo: String!"));
    }

    #[test]
    fn test_github_client_debug_redacts_token() {
        let auth = Auth::Token(SecretString::from("super-secret-token"));
        let client = GitHubClient::with_base_url(&auth, "https://api.example.com").unwrap();

        let debug_output = format!("{client:?}");

        assert!(debug_output.contains("[redacted]"));
        assert!(!debug_output.contains("super-secret-token"));
    }
}