1use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use serde::de::DeserializeOwned;
6
7use crate::auth::Auth;
8use crate::error::{Error, Result};
9use crate::types::{
10 CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, PullRequestState,
11 UpdatePullRequest,
12};
13
14#[derive(serde::Deserialize)]
18struct ApiPullRequest {
19 number: u64,
20 title: String,
21 body: Option<String>,
22 state: String,
23 draft: bool,
24 html_url: String,
25 head: ApiBranch,
26 base: ApiBranch,
27}
28
29#[derive(serde::Deserialize)]
31struct ApiBranch {
32 #[serde(rename = "ref")]
33 ref_name: String,
34}
35
36impl ApiPullRequest {
37 fn into_pull_request(self) -> PullRequest {
39 PullRequest {
40 number: self.number,
41 title: self.title,
42 body: self.body,
43 state: match self.state.as_str() {
44 "open" => PullRequestState::Open,
45 "merged" => PullRequestState::Merged,
46 _ => PullRequestState::Closed,
47 },
48 draft: self.draft,
49 head_branch: self.head.ref_name,
50 base_branch: self.base.ref_name,
51 html_url: self.html_url,
52 }
53 }
54
55 fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
57 PullRequest {
58 number: self.number,
59 title: self.title,
60 body: self.body,
61 state,
62 draft: self.draft,
63 head_branch: self.head.ref_name,
64 base_branch: self.base.ref_name,
65 html_url: self.html_url,
66 }
67 }
68}
69
70pub struct GitHubClient {
72 client: Client,
73 base_url: String,
74 token: String,
75}
76
77impl GitHubClient {
78 pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
80
81 pub fn new(auth: &Auth) -> Result<Self> {
86 Self::with_base_url(auth, Self::DEFAULT_API_URL)
87 }
88
89 pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
94 let token = auth.resolve()?;
95
96 let mut headers = HeaderMap::new();
97 headers.insert(
98 ACCEPT,
99 HeaderValue::from_static("application/vnd.github+json"),
100 );
101 headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
102 headers.insert(
103 "X-GitHub-Api-Version",
104 HeaderValue::from_static("2022-11-28"),
105 );
106
107 let client = Client::builder().default_headers(headers).build()?;
108
109 Ok(Self {
110 client,
111 base_url: base_url.into(),
112 token,
113 })
114 }
115
116 async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
118 let url = format!("{}{}", self.base_url, path);
119 let response = self
120 .client
121 .get(&url)
122 .header(AUTHORIZATION, format!("Bearer {}", self.token))
123 .send()
124 .await?;
125
126 self.handle_response(response).await
127 }
128
129 async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
131 &self,
132 path: &str,
133 body: &B,
134 ) -> Result<T> {
135 let url = format!("{}{}", self.base_url, path);
136 let response = self
137 .client
138 .post(&url)
139 .header(AUTHORIZATION, format!("Bearer {}", self.token))
140 .json(body)
141 .send()
142 .await?;
143
144 self.handle_response(response).await
145 }
146
147 async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
149 &self,
150 path: &str,
151 body: &B,
152 ) -> Result<T> {
153 let url = format!("{}{}", self.base_url, path);
154 let response = self
155 .client
156 .patch(&url)
157 .header(AUTHORIZATION, format!("Bearer {}", self.token))
158 .json(body)
159 .send()
160 .await?;
161
162 self.handle_response(response).await
163 }
164
165 async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
167 &self,
168 path: &str,
169 body: &B,
170 ) -> Result<T> {
171 let url = format!("{}{}", self.base_url, path);
172 let response = self
173 .client
174 .put(&url)
175 .header(AUTHORIZATION, format!("Bearer {}", self.token))
176 .json(body)
177 .send()
178 .await?;
179
180 self.handle_response(response).await
181 }
182
183 async fn delete(&self, path: &str) -> Result<()> {
185 let url = format!("{}{}", self.base_url, path);
186 let response = self
187 .client
188 .delete(&url)
189 .header(AUTHORIZATION, format!("Bearer {}", self.token))
190 .send()
191 .await?;
192
193 let status = response.status();
194 if status.is_success() || status.as_u16() == 204 {
195 return Ok(());
196 }
197
198 let status_code = status.as_u16();
199 match status_code {
200 401 => Err(Error::AuthenticationFailed),
201 403 if response
202 .headers()
203 .get("x-ratelimit-remaining")
204 .is_some_and(|v| v == "0") =>
205 {
206 Err(Error::RateLimited)
207 }
208 _ => {
209 let text = response.text().await.unwrap_or_default();
210 Err(Error::ApiError {
211 status: status_code,
212 message: text,
213 })
214 }
215 }
216 }
217
218 async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
220 let status = response.status();
221
222 if status.is_success() {
223 let body = response.json().await?;
224 return Ok(body);
225 }
226
227 let status_code = status.as_u16();
229
230 match status_code {
231 401 => Err(Error::AuthenticationFailed),
232 403 if response
233 .headers()
234 .get("x-ratelimit-remaining")
235 .is_some_and(|v| v == "0") =>
236 {
237 Err(Error::RateLimited)
238 }
239 _ => {
240 let text = response.text().await.unwrap_or_default();
241 Err(Error::ApiError {
242 status: status_code,
243 message: text,
244 })
245 }
246 }
247 }
248
249 pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
256 let api_pr: ApiPullRequest = self
257 .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
258 .await?;
259
260 Ok(api_pr.into_pull_request())
261 }
262
263 pub async fn find_pr_for_branch(
268 &self,
269 owner: &str,
270 repo: &str,
271 branch: &str,
272 ) -> Result<Option<PullRequest>> {
273 let prs: Vec<ApiPullRequest> = self
275 .get(&format!(
276 "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
277 ))
278 .await?;
279
280 Ok(prs
281 .into_iter()
282 .next()
283 .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
284 }
285
286 pub async fn create_pr(
291 &self,
292 owner: &str,
293 repo: &str,
294 pr: CreatePullRequest,
295 ) -> Result<PullRequest> {
296 let api_pr: ApiPullRequest = self
298 .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
299 .await?;
300
301 Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
302 }
303
304 pub async fn update_pr(
309 &self,
310 owner: &str,
311 repo: &str,
312 number: u64,
313 update: UpdatePullRequest,
314 ) -> Result<PullRequest> {
315 let api_pr: ApiPullRequest = self
316 .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
317 .await?;
318
319 Ok(api_pr.into_pull_request())
320 }
321
322 pub async fn get_check_runs(
329 &self,
330 owner: &str,
331 repo: &str,
332 commit_sha: &str,
333 ) -> Result<Vec<CheckRun>> {
334 #[derive(serde::Deserialize)]
335 struct Response {
336 check_runs: Vec<ApiCheckRun>,
337 }
338
339 #[derive(serde::Deserialize)]
340 struct ApiCheckRun {
341 name: String,
342 status: String,
343 conclusion: Option<String>,
344 details_url: Option<String>,
345 }
346
347 let response: Response = self
348 .get(&format!(
349 "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
350 ))
351 .await?;
352
353 Ok(response
354 .check_runs
355 .into_iter()
356 .map(|cr| CheckRun {
357 name: cr.name,
358 status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
359 ("queued", _) => crate::types::CheckStatus::Queued,
360 ("in_progress", _) => crate::types::CheckStatus::InProgress,
361 ("completed", Some("success")) => crate::types::CheckStatus::Success,
362 ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
363 ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
364 _ => crate::types::CheckStatus::Failure,
366 },
367 details_url: cr.details_url,
368 })
369 .collect())
370 }
371
372 pub async fn merge_pr(
379 &self,
380 owner: &str,
381 repo: &str,
382 number: u64,
383 merge: MergePullRequest,
384 ) -> Result<MergeResult> {
385 self.put(
386 &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
387 &merge,
388 )
389 .await
390 }
391
392 pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
399 self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
400 .await
401 }
402
403 pub async fn list_pr_comments(
410 &self,
411 owner: &str,
412 repo: &str,
413 pr_number: u64,
414 ) -> Result<Vec<crate::types::IssueComment>> {
415 self.get(&format!(
416 "/repos/{owner}/{repo}/issues/{pr_number}/comments"
417 ))
418 .await
419 }
420
421 pub async fn create_pr_comment(
426 &self,
427 owner: &str,
428 repo: &str,
429 pr_number: u64,
430 comment: crate::types::CreateComment,
431 ) -> Result<crate::types::IssueComment> {
432 self.post(
433 &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
434 &comment,
435 )
436 .await
437 }
438
439 pub async fn update_pr_comment(
444 &self,
445 owner: &str,
446 repo: &str,
447 comment_id: u64,
448 comment: crate::types::UpdateComment,
449 ) -> Result<crate::types::IssueComment> {
450 self.patch(
451 &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
452 &comment,
453 )
454 .await
455 }
456}
457
458impl std::fmt::Debug for GitHubClient {
459 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
460 f.debug_struct("GitHubClient")
461 .field("base_url", &self.base_url)
462 .field("token", &"[redacted]")
463 .finish_non_exhaustive()
464 }
465}