1use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use serde::de::DeserializeOwned;
6
7use crate::auth::Auth;
8use crate::error::{Error, Result};
9use crate::types::{
10 CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, UpdatePullRequest,
11};
12
/// Client for the GitHub REST API.
///
/// Instances are built through [`GitHubClient::new`] or
/// [`GitHubClient::with_base_url`]; every request method lives in the inherent
/// `impl` block.
pub struct GitHubClient {
    /// Underlying `reqwest` client, built with the default `Accept`,
    /// `User-Agent` and `X-GitHub-Api-Version` headers.
    client: Client,
    /// Root of the API; request paths are appended to this string verbatim.
    base_url: String,
    /// Token sent as `Authorization: Bearer <token>` on every request.
    token: String,
}
19
20impl GitHubClient {
21 pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
23
24 pub fn new(auth: &Auth) -> Result<Self> {
29 Self::with_base_url(auth, Self::DEFAULT_API_URL)
30 }
31
32 pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
37 let token = auth.resolve()?;
38
39 let mut headers = HeaderMap::new();
40 headers.insert(
41 ACCEPT,
42 HeaderValue::from_static("application/vnd.github+json"),
43 );
44 headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
45 headers.insert(
46 "X-GitHub-Api-Version",
47 HeaderValue::from_static("2022-11-28"),
48 );
49
50 let client = Client::builder().default_headers(headers).build()?;
51
52 Ok(Self {
53 client,
54 base_url: base_url.into(),
55 token,
56 })
57 }
58
59 async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
61 let url = format!("{}{}", self.base_url, path);
62 let response = self
63 .client
64 .get(&url)
65 .header(AUTHORIZATION, format!("Bearer {}", self.token))
66 .send()
67 .await?;
68
69 self.handle_response(response).await
70 }
71
72 async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
74 &self,
75 path: &str,
76 body: &B,
77 ) -> Result<T> {
78 let url = format!("{}{}", self.base_url, path);
79 let response = self
80 .client
81 .post(&url)
82 .header(AUTHORIZATION, format!("Bearer {}", self.token))
83 .json(body)
84 .send()
85 .await?;
86
87 self.handle_response(response).await
88 }
89
90 async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
92 &self,
93 path: &str,
94 body: &B,
95 ) -> Result<T> {
96 let url = format!("{}{}", self.base_url, path);
97 let response = self
98 .client
99 .patch(&url)
100 .header(AUTHORIZATION, format!("Bearer {}", self.token))
101 .json(body)
102 .send()
103 .await?;
104
105 self.handle_response(response).await
106 }
107
108 async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
110 &self,
111 path: &str,
112 body: &B,
113 ) -> Result<T> {
114 let url = format!("{}{}", self.base_url, path);
115 let response = self
116 .client
117 .put(&url)
118 .header(AUTHORIZATION, format!("Bearer {}", self.token))
119 .json(body)
120 .send()
121 .await?;
122
123 self.handle_response(response).await
124 }
125
126 async fn delete(&self, path: &str) -> Result<()> {
128 let url = format!("{}{}", self.base_url, path);
129 let response = self
130 .client
131 .delete(&url)
132 .header(AUTHORIZATION, format!("Bearer {}", self.token))
133 .send()
134 .await?;
135
136 let status = response.status();
137 if status.is_success() || status.as_u16() == 204 {
138 return Ok(());
139 }
140
141 let status_code = status.as_u16();
142 match status_code {
143 401 => Err(Error::AuthenticationFailed),
144 403 if response
145 .headers()
146 .get("x-ratelimit-remaining")
147 .is_some_and(|v| v == "0") =>
148 {
149 Err(Error::RateLimited)
150 }
151 _ => {
152 let text = response.text().await.unwrap_or_default();
153 Err(Error::ApiError {
154 status: status_code,
155 message: text,
156 })
157 }
158 }
159 }
160
161 async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
163 let status = response.status();
164
165 if status.is_success() {
166 let body = response.json().await?;
167 return Ok(body);
168 }
169
170 let status_code = status.as_u16();
172
173 match status_code {
174 401 => Err(Error::AuthenticationFailed),
175 403 if response
176 .headers()
177 .get("x-ratelimit-remaining")
178 .is_some_and(|v| v == "0") =>
179 {
180 Err(Error::RateLimited)
181 }
182 _ => {
183 let text = response.text().await.unwrap_or_default();
184 Err(Error::ApiError {
185 status: status_code,
186 message: text,
187 })
188 }
189 }
190 }
191
192 pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
199 #[derive(serde::Deserialize)]
200 struct ApiPr {
201 number: u64,
202 title: String,
203 body: Option<String>,
204 state: String,
205 draft: bool,
206 html_url: String,
207 head: Branch,
208 base: Branch,
209 }
210
211 #[derive(serde::Deserialize)]
212 struct Branch {
213 #[serde(rename = "ref")]
214 ref_name: String,
215 }
216
217 let api_pr: ApiPr = self
218 .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
219 .await?;
220
221 Ok(PullRequest {
222 number: api_pr.number,
223 title: api_pr.title,
224 body: api_pr.body,
225 state: match api_pr.state.as_str() {
226 "open" => crate::types::PullRequestState::Open,
227 "merged" => crate::types::PullRequestState::Merged,
228 _ => crate::types::PullRequestState::Closed,
229 },
230 draft: api_pr.draft,
231 head_branch: api_pr.head.ref_name,
232 base_branch: api_pr.base.ref_name,
233 html_url: api_pr.html_url,
234 })
235 }
236
237 pub async fn find_pr_for_branch(
242 &self,
243 owner: &str,
244 repo: &str,
245 branch: &str,
246 ) -> Result<Option<PullRequest>> {
247 #[derive(serde::Deserialize)]
248 struct ApiPr {
249 number: u64,
250 title: String,
251 body: Option<String>,
252 #[allow(dead_code)]
253 state: String,
254 draft: bool,
255 html_url: String,
256 head: Branch,
257 base: Branch,
258 }
259
260 #[derive(serde::Deserialize)]
261 struct Branch {
262 #[serde(rename = "ref")]
263 ref_name: String,
264 }
265
266 let prs: Vec<ApiPr> = self
268 .get(&format!(
269 "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
270 ))
271 .await?;
272
273 Ok(prs.into_iter().next().map(|api_pr| PullRequest {
274 number: api_pr.number,
275 title: api_pr.title,
276 body: api_pr.body,
277 state: crate::types::PullRequestState::Open,
278 draft: api_pr.draft,
279 head_branch: api_pr.head.ref_name,
280 base_branch: api_pr.base.ref_name,
281 html_url: api_pr.html_url,
282 }))
283 }
284
285 pub async fn create_pr(
290 &self,
291 owner: &str,
292 repo: &str,
293 pr: CreatePullRequest,
294 ) -> Result<PullRequest> {
295 #[derive(serde::Deserialize)]
296 struct ApiPr {
297 number: u64,
298 title: String,
299 body: Option<String>,
300 #[allow(dead_code)]
301 state: String,
302 draft: bool,
303 html_url: String,
304 head: Branch,
305 base: Branch,
306 }
307
308 #[derive(serde::Deserialize)]
309 struct Branch {
310 #[serde(rename = "ref")]
311 ref_name: String,
312 }
313
314 let api_pr: ApiPr = self
316 .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
317 .await?;
318
319 Ok(PullRequest {
320 number: api_pr.number,
321 title: api_pr.title,
322 body: api_pr.body,
323 state: crate::types::PullRequestState::Open,
324 draft: api_pr.draft,
325 head_branch: api_pr.head.ref_name,
326 base_branch: api_pr.base.ref_name,
327 html_url: api_pr.html_url,
328 })
329 }
330
331 pub async fn update_pr(
336 &self,
337 owner: &str,
338 repo: &str,
339 number: u64,
340 update: UpdatePullRequest,
341 ) -> Result<PullRequest> {
342 #[derive(serde::Deserialize)]
343 struct ApiPr {
344 number: u64,
345 title: String,
346 body: Option<String>,
347 state: String,
348 draft: bool,
349 html_url: String,
350 head: Branch,
351 base: Branch,
352 }
353
354 #[derive(serde::Deserialize)]
355 struct Branch {
356 #[serde(rename = "ref")]
357 ref_name: String,
358 }
359
360 let api_pr: ApiPr = self
361 .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
362 .await?;
363
364 Ok(PullRequest {
365 number: api_pr.number,
366 title: api_pr.title,
367 body: api_pr.body,
368 state: match api_pr.state.as_str() {
369 "open" => crate::types::PullRequestState::Open,
370 "merged" => crate::types::PullRequestState::Merged,
371 _ => crate::types::PullRequestState::Closed,
372 },
373 draft: api_pr.draft,
374 head_branch: api_pr.head.ref_name,
375 base_branch: api_pr.base.ref_name,
376 html_url: api_pr.html_url,
377 })
378 }
379
380 pub async fn get_check_runs(
387 &self,
388 owner: &str,
389 repo: &str,
390 commit_sha: &str,
391 ) -> Result<Vec<CheckRun>> {
392 #[derive(serde::Deserialize)]
393 struct Response {
394 check_runs: Vec<ApiCheckRun>,
395 }
396
397 #[derive(serde::Deserialize)]
398 struct ApiCheckRun {
399 name: String,
400 status: String,
401 conclusion: Option<String>,
402 details_url: Option<String>,
403 }
404
405 let response: Response = self
406 .get(&format!(
407 "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
408 ))
409 .await?;
410
411 Ok(response
412 .check_runs
413 .into_iter()
414 .map(|cr| CheckRun {
415 name: cr.name,
416 status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
417 ("queued", _) => crate::types::CheckStatus::Queued,
418 ("in_progress", _) => crate::types::CheckStatus::InProgress,
419 ("completed", Some("success")) => crate::types::CheckStatus::Success,
420 ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
421 ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
422 _ => crate::types::CheckStatus::Failure,
424 },
425 details_url: cr.details_url,
426 })
427 .collect())
428 }
429
430 pub async fn merge_pr(
437 &self,
438 owner: &str,
439 repo: &str,
440 number: u64,
441 merge: MergePullRequest,
442 ) -> Result<MergeResult> {
443 self.put(
444 &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
445 &merge,
446 )
447 .await
448 }
449
450 pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
457 self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
458 .await
459 }
460
461 pub async fn list_pr_comments(
468 &self,
469 owner: &str,
470 repo: &str,
471 pr_number: u64,
472 ) -> Result<Vec<crate::types::IssueComment>> {
473 self.get(&format!(
474 "/repos/{owner}/{repo}/issues/{pr_number}/comments"
475 ))
476 .await
477 }
478
479 pub async fn create_pr_comment(
484 &self,
485 owner: &str,
486 repo: &str,
487 pr_number: u64,
488 comment: crate::types::CreateComment,
489 ) -> Result<crate::types::IssueComment> {
490 self.post(
491 &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
492 &comment,
493 )
494 .await
495 }
496
497 pub async fn update_pr_comment(
502 &self,
503 owner: &str,
504 repo: &str,
505 comment_id: u64,
506 comment: crate::types::UpdateComment,
507 ) -> Result<crate::types::IssueComment> {
508 self.patch(
509 &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
510 &comment,
511 )
512 .await
513 }
514}
515
516impl std::fmt::Debug for GitHubClient {
517 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
518 f.debug_struct("GitHubClient")
519 .field("base_url", &self.base_url)
520 .field("token", &"[redacted]")
521 .finish_non_exhaustive()
522 }
523}