1use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use secrecy::{ExposeSecret, SecretString};
6use serde::de::DeserializeOwned;
7
8use crate::auth::Auth;
9use crate::error::{Error, Result};
10use crate::types::{
11 CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, PullRequestState,
12 UpdatePullRequest,
13};
14
15#[derive(serde::Deserialize)]
19struct ApiPullRequest {
20 number: u64,
21 title: String,
22 body: Option<String>,
23 state: String,
24 #[serde(default)]
26 merged: bool,
27 draft: bool,
28 html_url: String,
29 head: ApiBranch,
30 base: ApiBranch,
31}
32
33#[derive(serde::Deserialize)]
35struct ApiBranch {
36 #[serde(rename = "ref")]
37 ref_name: String,
38}
39
40impl ApiPullRequest {
41 fn into_pull_request(self) -> PullRequest {
43 let state = if self.merged {
45 PullRequestState::Merged
46 } else {
47 match self.state.as_str() {
48 "open" => PullRequestState::Open,
49 _ => PullRequestState::Closed,
50 }
51 };
52
53 PullRequest {
54 number: self.number,
55 title: self.title,
56 body: self.body,
57 state,
58 draft: self.draft,
59 head_branch: self.head.ref_name,
60 base_branch: self.base.ref_name,
61 html_url: self.html_url,
62 }
63 }
64
65 fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
67 PullRequest {
68 number: self.number,
69 title: self.title,
70 body: self.body,
71 state,
72 draft: self.draft,
73 head_branch: self.head.ref_name,
74 base_branch: self.base.ref_name,
75 html_url: self.html_url,
76 }
77 }
78}
79
80pub struct GitHubClient {
82 client: Client,
83 base_url: String,
84 token: SecretString,
86}
87
88impl GitHubClient {
89 pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
91
92 pub fn new(auth: &Auth) -> Result<Self> {
97 Self::with_base_url(auth, Self::DEFAULT_API_URL)
98 }
99
100 pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
105 let token = auth.resolve()?;
106
107 let mut headers = HeaderMap::new();
108 headers.insert(
109 ACCEPT,
110 HeaderValue::from_static("application/vnd.github+json"),
111 );
112 headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
113 headers.insert(
114 "X-GitHub-Api-Version",
115 HeaderValue::from_static("2022-11-28"),
116 );
117
118 let client = Client::builder().default_headers(headers).build()?;
119
120 Ok(Self {
121 client,
122 base_url: base_url.into(),
123 token,
124 })
125 }
126
127 async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
129 let url = format!("{}{}", self.base_url, path);
130 let response = self
131 .client
132 .get(&url)
133 .header(
134 AUTHORIZATION,
135 format!("Bearer {}", self.token.expose_secret()),
136 )
137 .send()
138 .await?;
139
140 self.handle_response(response).await
141 }
142
143 async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
145 &self,
146 path: &str,
147 body: &B,
148 ) -> Result<T> {
149 let url = format!("{}{}", self.base_url, path);
150 let response = self
151 .client
152 .post(&url)
153 .header(
154 AUTHORIZATION,
155 format!("Bearer {}", self.token.expose_secret()),
156 )
157 .json(body)
158 .send()
159 .await?;
160
161 self.handle_response(response).await
162 }
163
164 async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
166 &self,
167 path: &str,
168 body: &B,
169 ) -> Result<T> {
170 let url = format!("{}{}", self.base_url, path);
171 let response = self
172 .client
173 .patch(&url)
174 .header(
175 AUTHORIZATION,
176 format!("Bearer {}", self.token.expose_secret()),
177 )
178 .json(body)
179 .send()
180 .await?;
181
182 self.handle_response(response).await
183 }
184
185 async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
187 &self,
188 path: &str,
189 body: &B,
190 ) -> Result<T> {
191 let url = format!("{}{}", self.base_url, path);
192 let response = self
193 .client
194 .put(&url)
195 .header(
196 AUTHORIZATION,
197 format!("Bearer {}", self.token.expose_secret()),
198 )
199 .json(body)
200 .send()
201 .await?;
202
203 self.handle_response(response).await
204 }
205
206 async fn delete(&self, path: &str) -> Result<()> {
208 let url = format!("{}{}", self.base_url, path);
209 let response = self
210 .client
211 .delete(&url)
212 .header(
213 AUTHORIZATION,
214 format!("Bearer {}", self.token.expose_secret()),
215 )
216 .send()
217 .await?;
218
219 let status = response.status();
220 if status.is_success() || status.as_u16() == 204 {
221 return Ok(());
222 }
223
224 let status_code = status.as_u16();
225 match status_code {
226 401 => Err(Error::AuthenticationFailed),
227 403 if response
228 .headers()
229 .get("x-ratelimit-remaining")
230 .is_some_and(|v| v == "0") =>
231 {
232 Err(Error::RateLimited)
233 }
234 _ => {
235 let text = response.text().await.unwrap_or_default();
236 Err(Error::ApiError {
237 status: status_code,
238 message: text,
239 })
240 }
241 }
242 }
243
244 async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
246 let status = response.status();
247
248 if status.is_success() {
249 let body = response.json().await?;
250 return Ok(body);
251 }
252
253 let status_code = status.as_u16();
255
256 match status_code {
257 401 => Err(Error::AuthenticationFailed),
258 403 if response
259 .headers()
260 .get("x-ratelimit-remaining")
261 .is_some_and(|v| v == "0") =>
262 {
263 Err(Error::RateLimited)
264 }
265 _ => {
266 let text = response.text().await.unwrap_or_default();
267 Err(Error::ApiError {
268 status: status_code,
269 message: text,
270 })
271 }
272 }
273 }
274
275 pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
282 let api_pr: ApiPullRequest = self
283 .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
284 .await?;
285
286 Ok(api_pr.into_pull_request())
287 }
288
289 pub async fn find_pr_for_branch(
294 &self,
295 owner: &str,
296 repo: &str,
297 branch: &str,
298 ) -> Result<Option<PullRequest>> {
299 let prs: Vec<ApiPullRequest> = self
301 .get(&format!(
302 "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
303 ))
304 .await?;
305
306 Ok(prs
307 .into_iter()
308 .next()
309 .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
310 }
311
312 pub async fn create_pr(
317 &self,
318 owner: &str,
319 repo: &str,
320 pr: CreatePullRequest,
321 ) -> Result<PullRequest> {
322 let api_pr: ApiPullRequest = self
324 .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
325 .await?;
326
327 Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
328 }
329
330 pub async fn update_pr(
335 &self,
336 owner: &str,
337 repo: &str,
338 number: u64,
339 update: UpdatePullRequest,
340 ) -> Result<PullRequest> {
341 let api_pr: ApiPullRequest = self
342 .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
343 .await?;
344
345 Ok(api_pr.into_pull_request())
346 }
347
348 pub async fn get_check_runs(
355 &self,
356 owner: &str,
357 repo: &str,
358 commit_sha: &str,
359 ) -> Result<Vec<CheckRun>> {
360 #[derive(serde::Deserialize)]
361 struct Response {
362 check_runs: Vec<ApiCheckRun>,
363 }
364
365 #[derive(serde::Deserialize)]
366 struct ApiCheckRun {
367 name: String,
368 status: String,
369 conclusion: Option<String>,
370 details_url: Option<String>,
371 }
372
373 let response: Response = self
374 .get(&format!(
375 "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
376 ))
377 .await?;
378
379 Ok(response
380 .check_runs
381 .into_iter()
382 .map(|cr| CheckRun {
383 name: cr.name,
384 status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
385 ("queued", _) => crate::types::CheckStatus::Queued,
386 ("in_progress", _) => crate::types::CheckStatus::InProgress,
387 ("completed", Some("success")) => crate::types::CheckStatus::Success,
388 ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
389 ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
390 _ => crate::types::CheckStatus::Failure,
392 },
393 details_url: cr.details_url,
394 })
395 .collect())
396 }
397
398 pub async fn merge_pr(
405 &self,
406 owner: &str,
407 repo: &str,
408 number: u64,
409 merge: MergePullRequest,
410 ) -> Result<MergeResult> {
411 self.put(
412 &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
413 &merge,
414 )
415 .await
416 }
417
418 pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
425 self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
426 .await
427 }
428
429 pub async fn list_pr_comments(
436 &self,
437 owner: &str,
438 repo: &str,
439 pr_number: u64,
440 ) -> Result<Vec<crate::types::IssueComment>> {
441 self.get(&format!(
442 "/repos/{owner}/{repo}/issues/{pr_number}/comments"
443 ))
444 .await
445 }
446
447 pub async fn create_pr_comment(
452 &self,
453 owner: &str,
454 repo: &str,
455 pr_number: u64,
456 comment: crate::types::CreateComment,
457 ) -> Result<crate::types::IssueComment> {
458 self.post(
459 &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
460 &comment,
461 )
462 .await
463 }
464
465 pub async fn update_pr_comment(
470 &self,
471 owner: &str,
472 repo: &str,
473 comment_id: u64,
474 comment: crate::types::UpdateComment,
475 ) -> Result<crate::types::IssueComment> {
476 self.patch(
477 &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
478 &comment,
479 )
480 .await
481 }
482}
483
484impl std::fmt::Debug for GitHubClient {
485 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486 f.debug_struct("GitHubClient")
487 .field("base_url", &self.base_url)
488 .field("token", &"[redacted]")
489 .finish_non_exhaustive()
490 }
491}