1use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use secrecy::{ExposeSecret, SecretString};
6use serde::de::DeserializeOwned;
7
8use crate::auth::Auth;
9use crate::error::{Error, Result};
10use crate::types::{
11 CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, PullRequestState,
12 UpdatePullRequest,
13};
14
15#[derive(serde::Deserialize)]
19struct ApiPullRequest {
20 number: u64,
21 title: String,
22 body: Option<String>,
23 state: String,
24 #[serde(default)]
26 merged: bool,
27 draft: bool,
28 html_url: String,
29 head: ApiBranch,
30 base: ApiBranch,
31 mergeable: Option<bool>,
33 mergeable_state: Option<String>,
35}
36
37#[derive(serde::Deserialize)]
39struct ApiBranch {
40 #[serde(rename = "ref")]
41 ref_name: String,
42}
43
44impl ApiPullRequest {
45 fn into_pull_request(self) -> PullRequest {
47 let state = if self.merged {
49 PullRequestState::Merged
50 } else {
51 match self.state.as_str() {
52 "open" => PullRequestState::Open,
53 _ => PullRequestState::Closed,
54 }
55 };
56
57 PullRequest {
58 number: self.number,
59 title: self.title,
60 body: self.body,
61 state,
62 draft: self.draft,
63 head_branch: self.head.ref_name,
64 base_branch: self.base.ref_name,
65 html_url: self.html_url,
66 mergeable: self.mergeable,
67 mergeable_state: self.mergeable_state,
68 }
69 }
70
71 fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
73 PullRequest {
74 number: self.number,
75 title: self.title,
76 body: self.body,
77 state,
78 draft: self.draft,
79 head_branch: self.head.ref_name,
80 base_branch: self.base.ref_name,
81 html_url: self.html_url,
82 mergeable: self.mergeable,
83 mergeable_state: self.mergeable_state,
84 }
85 }
86}
87
88#[derive(serde::Serialize)]
92struct GraphQLRequest {
93 query: String,
94 variables: GraphQLVariables,
95}
96
97#[derive(serde::Serialize)]
99struct GraphQLVariables {
100 owner: String,
101 repo: String,
102}
103
104#[derive(serde::Deserialize)]
106#[serde(rename_all = "camelCase")]
107struct GraphQLPullRequest {
108 number: u64,
109 state: String,
110 merged: bool,
111 is_draft: bool,
112 head_ref_name: String,
113 base_ref_name: String,
114 url: String,
115}
116
117impl GraphQLPullRequest {
118 fn into_pull_request(self) -> PullRequest {
119 let state = if self.merged {
120 PullRequestState::Merged
121 } else if self.state == "OPEN" {
122 PullRequestState::Open
123 } else {
124 PullRequestState::Closed
125 };
126
127 PullRequest {
128 number: self.number,
129 title: String::new(), body: None,
131 state,
132 draft: self.is_draft,
133 head_branch: self.head_ref_name,
134 base_branch: self.base_ref_name,
135 html_url: self.url,
136 mergeable: None, mergeable_state: None,
138 }
139 }
140}
141
142#[derive(serde::Deserialize)]
143struct GraphQLResponse {
144 data: Option<GraphQLData>,
145 errors: Option<Vec<GraphQLError>>,
146}
147
148#[derive(serde::Deserialize)]
149struct GraphQLData {
150 repository: Option<serde_json::Value>,
151}
152
153#[derive(serde::Deserialize)]
154struct GraphQLError {
155 message: String,
156}
157
158pub struct GitHubClient {
160 client: Client,
161 base_url: String,
162 token: SecretString,
164}
165
166impl GitHubClient {
167 pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
169
170 pub fn new(auth: &Auth) -> Result<Self> {
175 Self::with_base_url(auth, Self::DEFAULT_API_URL)
176 }
177
178 pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
183 let token = auth.resolve()?;
184
185 let mut headers = HeaderMap::new();
186 headers.insert(
187 ACCEPT,
188 HeaderValue::from_static("application/vnd.github+json"),
189 );
190 headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
191 headers.insert(
192 "X-GitHub-Api-Version",
193 HeaderValue::from_static("2022-11-28"),
194 );
195
196 let client = Client::builder().default_headers(headers).build()?;
197
198 Ok(Self {
199 client,
200 base_url: base_url.into(),
201 token,
202 })
203 }
204
205 async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
207 let url = format!("{}{}", self.base_url, path);
208 let response = self
209 .client
210 .get(&url)
211 .header(
212 AUTHORIZATION,
213 format!("Bearer {}", self.token.expose_secret()),
214 )
215 .send()
216 .await?;
217
218 self.handle_response(response).await
219 }
220
221 async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
223 &self,
224 path: &str,
225 body: &B,
226 ) -> Result<T> {
227 let url = format!("{}{}", self.base_url, path);
228 let response = self
229 .client
230 .post(&url)
231 .header(
232 AUTHORIZATION,
233 format!("Bearer {}", self.token.expose_secret()),
234 )
235 .json(body)
236 .send()
237 .await?;
238
239 self.handle_response(response).await
240 }
241
242 async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
244 &self,
245 path: &str,
246 body: &B,
247 ) -> Result<T> {
248 let url = format!("{}{}", self.base_url, path);
249 let response = self
250 .client
251 .patch(&url)
252 .header(
253 AUTHORIZATION,
254 format!("Bearer {}", self.token.expose_secret()),
255 )
256 .json(body)
257 .send()
258 .await?;
259
260 self.handle_response(response).await
261 }
262
263 async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
265 &self,
266 path: &str,
267 body: &B,
268 ) -> Result<T> {
269 let url = format!("{}{}", self.base_url, path);
270 let response = self
271 .client
272 .put(&url)
273 .header(
274 AUTHORIZATION,
275 format!("Bearer {}", self.token.expose_secret()),
276 )
277 .json(body)
278 .send()
279 .await?;
280
281 self.handle_response(response).await
282 }
283
284 async fn delete(&self, path: &str) -> Result<()> {
286 let url = format!("{}{}", self.base_url, path);
287 let response = self
288 .client
289 .delete(&url)
290 .header(
291 AUTHORIZATION,
292 format!("Bearer {}", self.token.expose_secret()),
293 )
294 .send()
295 .await?;
296
297 let status = response.status();
298 if status.is_success() || status.as_u16() == 204 {
299 return Ok(());
300 }
301
302 let status_code = status.as_u16();
303 match status_code {
304 401 => Err(Error::AuthenticationFailed),
305 403 if response
306 .headers()
307 .get("x-ratelimit-remaining")
308 .is_some_and(|v| v == "0") =>
309 {
310 Err(Error::RateLimited)
311 }
312 _ => {
313 let text = response.text().await.unwrap_or_default();
314 Err(Error::ApiError {
315 status: status_code,
316 message: text,
317 })
318 }
319 }
320 }
321
322 async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
324 let status = response.status();
325
326 if status.is_success() {
327 let body = response.json().await?;
328 return Ok(body);
329 }
330
331 let status_code = status.as_u16();
333
334 match status_code {
335 401 => Err(Error::AuthenticationFailed),
336 403 if response
337 .headers()
338 .get("x-ratelimit-remaining")
339 .is_some_and(|v| v == "0") =>
340 {
341 Err(Error::RateLimited)
342 }
343 _ => {
344 let text = response.text().await.unwrap_or_default();
345 Err(Error::ApiError {
346 status: status_code,
347 message: text,
348 })
349 }
350 }
351 }
352
353 pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
360 let api_pr: ApiPullRequest = self
361 .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
362 .await?;
363
364 Ok(api_pr.into_pull_request())
365 }
366
367 pub async fn get_prs_batch(
378 &self,
379 owner: &str,
380 repo: &str,
381 numbers: &[u64],
382 ) -> Result<std::collections::HashMap<u64, PullRequest>> {
383 if numbers.is_empty() {
384 return Ok(std::collections::HashMap::new());
385 }
386
387 let query = build_graphql_pr_query(numbers);
388 let request = GraphQLRequest {
389 query,
390 variables: GraphQLVariables {
391 owner: owner.to_string(),
392 repo: repo.to_string(),
393 },
394 };
395 let url = format!("{}/graphql", self.base_url);
396
397 let response = self
398 .client
399 .post(&url)
400 .header(
401 AUTHORIZATION,
402 format!("Bearer {}", self.token.expose_secret()),
403 )
404 .json(&request)
405 .send()
406 .await?;
407
408 let status = response.status();
409 if !status.is_success() {
410 let status_code = status.as_u16();
411 return match status_code {
412 401 => Err(Error::AuthenticationFailed),
413 403 if response
414 .headers()
415 .get("x-ratelimit-remaining")
416 .is_some_and(|v| v == "0") =>
417 {
418 Err(Error::RateLimited)
419 }
420 _ => {
421 let text = response.text().await.unwrap_or_default();
422 Err(Error::ApiError {
423 status: status_code,
424 message: text,
425 })
426 }
427 };
428 }
429
430 let graphql_response: GraphQLResponse = response.json().await?;
431
432 if let Some(errors) = graphql_response.errors {
434 if !errors.is_empty() {
435 let messages: Vec<_> = errors.iter().map(|e| e.message.as_str()).collect();
436 return Err(Error::ApiError {
437 status: 200,
438 message: messages.join("; "),
439 });
440 }
441 }
442
443 let mut result = std::collections::HashMap::new();
444
445 if let Some(data) = graphql_response.data {
446 if let Some(repo_data) = data.repository {
447 for (i, &num) in numbers.iter().enumerate() {
449 let key = format!("pr{i}");
450 if let Some(pr_value) = repo_data.get(&key) {
451 if !pr_value.is_null() {
453 if let Ok(pr) =
454 serde_json::from_value::<GraphQLPullRequest>(pr_value.clone())
455 {
456 result.insert(num, pr.into_pull_request());
457 }
458 }
459 }
460 }
461 }
462 }
463
464 Ok(result)
465 }
466
467 pub async fn find_pr_for_branch(
472 &self,
473 owner: &str,
474 repo: &str,
475 branch: &str,
476 ) -> Result<Option<PullRequest>> {
477 let prs: Vec<ApiPullRequest> = self
479 .get(&format!(
480 "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
481 ))
482 .await?;
483
484 Ok(prs
485 .into_iter()
486 .next()
487 .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
488 }
489
490 pub async fn create_pr(
495 &self,
496 owner: &str,
497 repo: &str,
498 pr: CreatePullRequest,
499 ) -> Result<PullRequest> {
500 let api_pr: ApiPullRequest = self
502 .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
503 .await?;
504
505 Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
506 }
507
508 pub async fn update_pr(
513 &self,
514 owner: &str,
515 repo: &str,
516 number: u64,
517 update: UpdatePullRequest,
518 ) -> Result<PullRequest> {
519 let api_pr: ApiPullRequest = self
520 .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
521 .await?;
522
523 Ok(api_pr.into_pull_request())
524 }
525
526 pub async fn get_check_runs(
533 &self,
534 owner: &str,
535 repo: &str,
536 commit_sha: &str,
537 ) -> Result<Vec<CheckRun>> {
538 #[derive(serde::Deserialize)]
539 struct Response {
540 check_runs: Vec<ApiCheckRun>,
541 }
542
543 #[derive(serde::Deserialize)]
544 struct ApiCheckRun {
545 name: String,
546 status: String,
547 conclusion: Option<String>,
548 details_url: Option<String>,
549 }
550
551 let response: Response = self
552 .get(&format!(
553 "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
554 ))
555 .await?;
556
557 Ok(response
558 .check_runs
559 .into_iter()
560 .map(|cr| CheckRun {
561 name: cr.name,
562 status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
563 ("queued", _) => crate::types::CheckStatus::Queued,
564 ("in_progress", _) => crate::types::CheckStatus::InProgress,
565 ("completed", Some("success")) => crate::types::CheckStatus::Success,
566 ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
567 ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
568 _ => crate::types::CheckStatus::Failure,
570 },
571 details_url: cr.details_url,
572 })
573 .collect())
574 }
575
576 pub async fn merge_pr(
583 &self,
584 owner: &str,
585 repo: &str,
586 number: u64,
587 merge: MergePullRequest,
588 ) -> Result<MergeResult> {
589 self.put(
590 &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
591 &merge,
592 )
593 .await
594 }
595
596 pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
603 self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
604 .await
605 }
606
607 pub async fn get_default_branch(&self, owner: &str, repo: &str) -> Result<String> {
614 #[derive(serde::Deserialize)]
615 struct RepoInfo {
616 default_branch: String,
617 }
618
619 let info: RepoInfo = self.get(&format!("/repos/{owner}/{repo}")).await?;
620 Ok(info.default_branch)
621 }
622
623 pub async fn list_pr_comments(
630 &self,
631 owner: &str,
632 repo: &str,
633 pr_number: u64,
634 ) -> Result<Vec<crate::types::IssueComment>> {
635 self.get(&format!(
636 "/repos/{owner}/{repo}/issues/{pr_number}/comments"
637 ))
638 .await
639 }
640
641 pub async fn create_pr_comment(
646 &self,
647 owner: &str,
648 repo: &str,
649 pr_number: u64,
650 comment: crate::types::CreateComment,
651 ) -> Result<crate::types::IssueComment> {
652 self.post(
653 &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
654 &comment,
655 )
656 .await
657 }
658
659 pub async fn update_pr_comment(
664 &self,
665 owner: &str,
666 repo: &str,
667 comment_id: u64,
668 comment: crate::types::UpdateComment,
669 ) -> Result<crate::types::IssueComment> {
670 self.patch(
671 &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
672 &comment,
673 )
674 .await
675 }
676}
677
678impl std::fmt::Debug for GitHubClient {
679 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
680 f.debug_struct("GitHubClient")
681 .field("base_url", &self.base_url)
682 .field("token", &"[redacted]")
683 .finish_non_exhaustive()
684 }
685}
686
/// Builds one GraphQL query that fetches every pull request in `numbers` at once.
///
/// Each PR is selected under the alias `pr{index}`, where `index` is its position in
/// `numbers`; callers use that alias to map results back to PR numbers. `$owner` and
/// `$repo` are declared as query variables rather than interpolated into the text.
fn build_graphql_pr_query(numbers: &[u64]) -> String {
    use std::fmt::Write as _;

    const PR_FIELDS: &str = "number state merged isDraft headRefName baseRefName url";

    let mut selections = String::new();
    for (index, number) in numbers.iter().enumerate() {
        if index > 0 {
            selections.push(' ');
        }
        // Writing into a `String` cannot fail.
        let _ = write!(
            selections,
            "pr{index}: pullRequest(number: {number}) {{ {PR_FIELDS} }}"
        );
    }

    format!(
        "query($owner: String!, $repo: String!) {{ repository(owner: $owner, name: $repo) {{ {selections} }} }}"
    )
}
701}