1use reqwest::Client;
4use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderMap, HeaderValue, USER_AGENT};
5use serde::de::DeserializeOwned;
6
7use crate::auth::Auth;
8use crate::error::{Error, Result};
9use crate::types::{
10 CheckRun, CreatePullRequest, MergePullRequest, MergeResult, PullRequest, PullRequestState,
11 UpdatePullRequest,
12};
13
14#[derive(serde::Deserialize)]
18struct ApiPullRequest {
19 number: u64,
20 title: String,
21 body: Option<String>,
22 state: String,
23 #[serde(default)]
25 merged: bool,
26 draft: bool,
27 html_url: String,
28 head: ApiBranch,
29 base: ApiBranch,
30}
31
32#[derive(serde::Deserialize)]
34struct ApiBranch {
35 #[serde(rename = "ref")]
36 ref_name: String,
37}
38
39impl ApiPullRequest {
40 fn into_pull_request(self) -> PullRequest {
42 let state = if self.merged {
44 PullRequestState::Merged
45 } else {
46 match self.state.as_str() {
47 "open" => PullRequestState::Open,
48 _ => PullRequestState::Closed,
49 }
50 };
51
52 PullRequest {
53 number: self.number,
54 title: self.title,
55 body: self.body,
56 state,
57 draft: self.draft,
58 head_branch: self.head.ref_name,
59 base_branch: self.base.ref_name,
60 html_url: self.html_url,
61 }
62 }
63
64 fn into_pull_request_with_state(self, state: PullRequestState) -> PullRequest {
66 PullRequest {
67 number: self.number,
68 title: self.title,
69 body: self.body,
70 state,
71 draft: self.draft,
72 head_branch: self.head.ref_name,
73 base_branch: self.base.ref_name,
74 html_url: self.html_url,
75 }
76 }
77}
78
79pub struct GitHubClient {
81 client: Client,
82 base_url: String,
83 token: String,
84}
85
86impl GitHubClient {
87 pub const DEFAULT_API_URL: &'static str = "https://api.github.com";
89
90 pub fn new(auth: &Auth) -> Result<Self> {
95 Self::with_base_url(auth, Self::DEFAULT_API_URL)
96 }
97
98 pub fn with_base_url(auth: &Auth, base_url: impl Into<String>) -> Result<Self> {
103 let token = auth.resolve()?;
104
105 let mut headers = HeaderMap::new();
106 headers.insert(
107 ACCEPT,
108 HeaderValue::from_static("application/vnd.github+json"),
109 );
110 headers.insert(USER_AGENT, HeaderValue::from_static("rung-cli"));
111 headers.insert(
112 "X-GitHub-Api-Version",
113 HeaderValue::from_static("2022-11-28"),
114 );
115
116 let client = Client::builder().default_headers(headers).build()?;
117
118 Ok(Self {
119 client,
120 base_url: base_url.into(),
121 token,
122 })
123 }
124
125 async fn get<T: DeserializeOwned>(&self, path: &str) -> Result<T> {
127 let url = format!("{}{}", self.base_url, path);
128 let response = self
129 .client
130 .get(&url)
131 .header(AUTHORIZATION, format!("Bearer {}", self.token))
132 .send()
133 .await?;
134
135 self.handle_response(response).await
136 }
137
138 async fn post<T: DeserializeOwned, B: serde::Serialize + Sync>(
140 &self,
141 path: &str,
142 body: &B,
143 ) -> Result<T> {
144 let url = format!("{}{}", self.base_url, path);
145 let response = self
146 .client
147 .post(&url)
148 .header(AUTHORIZATION, format!("Bearer {}", self.token))
149 .json(body)
150 .send()
151 .await?;
152
153 self.handle_response(response).await
154 }
155
156 async fn patch<T: DeserializeOwned, B: serde::Serialize + Sync>(
158 &self,
159 path: &str,
160 body: &B,
161 ) -> Result<T> {
162 let url = format!("{}{}", self.base_url, path);
163 let response = self
164 .client
165 .patch(&url)
166 .header(AUTHORIZATION, format!("Bearer {}", self.token))
167 .json(body)
168 .send()
169 .await?;
170
171 self.handle_response(response).await
172 }
173
174 async fn put<T: DeserializeOwned, B: serde::Serialize + Sync>(
176 &self,
177 path: &str,
178 body: &B,
179 ) -> Result<T> {
180 let url = format!("{}{}", self.base_url, path);
181 let response = self
182 .client
183 .put(&url)
184 .header(AUTHORIZATION, format!("Bearer {}", self.token))
185 .json(body)
186 .send()
187 .await?;
188
189 self.handle_response(response).await
190 }
191
192 async fn delete(&self, path: &str) -> Result<()> {
194 let url = format!("{}{}", self.base_url, path);
195 let response = self
196 .client
197 .delete(&url)
198 .header(AUTHORIZATION, format!("Bearer {}", self.token))
199 .send()
200 .await?;
201
202 let status = response.status();
203 if status.is_success() || status.as_u16() == 204 {
204 return Ok(());
205 }
206
207 let status_code = status.as_u16();
208 match status_code {
209 401 => Err(Error::AuthenticationFailed),
210 403 if response
211 .headers()
212 .get("x-ratelimit-remaining")
213 .is_some_and(|v| v == "0") =>
214 {
215 Err(Error::RateLimited)
216 }
217 _ => {
218 let text = response.text().await.unwrap_or_default();
219 Err(Error::ApiError {
220 status: status_code,
221 message: text,
222 })
223 }
224 }
225 }
226
227 async fn handle_response<T: DeserializeOwned>(&self, response: reqwest::Response) -> Result<T> {
229 let status = response.status();
230
231 if status.is_success() {
232 let body = response.json().await?;
233 return Ok(body);
234 }
235
236 let status_code = status.as_u16();
238
239 match status_code {
240 401 => Err(Error::AuthenticationFailed),
241 403 if response
242 .headers()
243 .get("x-ratelimit-remaining")
244 .is_some_and(|v| v == "0") =>
245 {
246 Err(Error::RateLimited)
247 }
248 _ => {
249 let text = response.text().await.unwrap_or_default();
250 Err(Error::ApiError {
251 status: status_code,
252 message: text,
253 })
254 }
255 }
256 }
257
258 pub async fn get_pr(&self, owner: &str, repo: &str, number: u64) -> Result<PullRequest> {
265 let api_pr: ApiPullRequest = self
266 .get(&format!("/repos/{owner}/{repo}/pulls/{number}"))
267 .await?;
268
269 Ok(api_pr.into_pull_request())
270 }
271
272 pub async fn find_pr_for_branch(
277 &self,
278 owner: &str,
279 repo: &str,
280 branch: &str,
281 ) -> Result<Option<PullRequest>> {
282 let prs: Vec<ApiPullRequest> = self
284 .get(&format!(
285 "/repos/{owner}/{repo}/pulls?head={owner}:{branch}&state=open"
286 ))
287 .await?;
288
289 Ok(prs
290 .into_iter()
291 .next()
292 .map(|api_pr| api_pr.into_pull_request_with_state(PullRequestState::Open)))
293 }
294
295 pub async fn create_pr(
300 &self,
301 owner: &str,
302 repo: &str,
303 pr: CreatePullRequest,
304 ) -> Result<PullRequest> {
305 let api_pr: ApiPullRequest = self
307 .post(&format!("/repos/{owner}/{repo}/pulls"), &pr)
308 .await?;
309
310 Ok(api_pr.into_pull_request_with_state(PullRequestState::Open))
311 }
312
313 pub async fn update_pr(
318 &self,
319 owner: &str,
320 repo: &str,
321 number: u64,
322 update: UpdatePullRequest,
323 ) -> Result<PullRequest> {
324 let api_pr: ApiPullRequest = self
325 .patch(&format!("/repos/{owner}/{repo}/pulls/{number}"), &update)
326 .await?;
327
328 Ok(api_pr.into_pull_request())
329 }
330
331 pub async fn get_check_runs(
338 &self,
339 owner: &str,
340 repo: &str,
341 commit_sha: &str,
342 ) -> Result<Vec<CheckRun>> {
343 #[derive(serde::Deserialize)]
344 struct Response {
345 check_runs: Vec<ApiCheckRun>,
346 }
347
348 #[derive(serde::Deserialize)]
349 struct ApiCheckRun {
350 name: String,
351 status: String,
352 conclusion: Option<String>,
353 details_url: Option<String>,
354 }
355
356 let response: Response = self
357 .get(&format!(
358 "/repos/{owner}/{repo}/commits/{commit_sha}/check-runs"
359 ))
360 .await?;
361
362 Ok(response
363 .check_runs
364 .into_iter()
365 .map(|cr| CheckRun {
366 name: cr.name,
367 status: match (cr.status.as_str(), cr.conclusion.as_deref()) {
368 ("queued", _) => crate::types::CheckStatus::Queued,
369 ("in_progress", _) => crate::types::CheckStatus::InProgress,
370 ("completed", Some("success")) => crate::types::CheckStatus::Success,
371 ("completed", Some("skipped")) => crate::types::CheckStatus::Skipped,
372 ("completed", Some("cancelled")) => crate::types::CheckStatus::Cancelled,
373 _ => crate::types::CheckStatus::Failure,
375 },
376 details_url: cr.details_url,
377 })
378 .collect())
379 }
380
381 pub async fn merge_pr(
388 &self,
389 owner: &str,
390 repo: &str,
391 number: u64,
392 merge: MergePullRequest,
393 ) -> Result<MergeResult> {
394 self.put(
395 &format!("/repos/{owner}/{repo}/pulls/{number}/merge"),
396 &merge,
397 )
398 .await
399 }
400
401 pub async fn delete_ref(&self, owner: &str, repo: &str, ref_name: &str) -> Result<()> {
408 self.delete(&format!("/repos/{owner}/{repo}/git/refs/heads/{ref_name}"))
409 .await
410 }
411
412 pub async fn list_pr_comments(
419 &self,
420 owner: &str,
421 repo: &str,
422 pr_number: u64,
423 ) -> Result<Vec<crate::types::IssueComment>> {
424 self.get(&format!(
425 "/repos/{owner}/{repo}/issues/{pr_number}/comments"
426 ))
427 .await
428 }
429
430 pub async fn create_pr_comment(
435 &self,
436 owner: &str,
437 repo: &str,
438 pr_number: u64,
439 comment: crate::types::CreateComment,
440 ) -> Result<crate::types::IssueComment> {
441 self.post(
442 &format!("/repos/{owner}/{repo}/issues/{pr_number}/comments"),
443 &comment,
444 )
445 .await
446 }
447
448 pub async fn update_pr_comment(
453 &self,
454 owner: &str,
455 repo: &str,
456 comment_id: u64,
457 comment: crate::types::UpdateComment,
458 ) -> Result<crate::types::IssueComment> {
459 self.patch(
460 &format!("/repos/{owner}/{repo}/issues/comments/{comment_id}"),
461 &comment,
462 )
463 .await
464 }
465}
466
467impl std::fmt::Debug for GitHubClient {
468 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
469 f.debug_struct("GitHubClient")
470 .field("base_url", &self.base_url)
471 .field("token", &"[redacted]")
472 .finish_non_exhaustive()
473 }
474}