1use std::sync::Arc;
2
3use crate::auth::AuthStrategy;
4use crate::error::{GitHubError, Result};
5use crate::types::*;
6use base64::{engine::general_purpose::STANDARD, Engine};
7use reqwest::{header, Client as HttpClient, Response, StatusCode};
8use serde::de::DeserializeOwned;
9use serde_json::json;
10use tracing::debug;
11
/// Value sent in the `X-GitHub-Api-Version` default header, pinning the REST API version.
const GITHUB_API_VERSION: &str = "2022-11-28";
/// API root used by [`ClientBuilder::new`] unless overridden with `base_url`.
const GITHUB_API_URL: &str = "https://api.github.com";
14
/// Builder for [`Client`], generic over the authentication strategy `A`.
///
/// A builder starts with `A = ()` (no authentication) and receives a concrete strategy
/// through `auth`; only a builder whose `A` implements `AuthStrategy` offers `build`.
#[must_use = "a ClientBuilder does nothing until `build` is called"]
pub struct ClientBuilder<A> {
    /// Authentication strategy handed to the client on `build`.
    auth: A,
    /// API root that request paths are appended to.
    base_url: String,
    /// `User-Agent` header value sent with every request.
    user_agent: String,
}
20
21impl ClientBuilder<()> {
22 pub fn new() -> Self {
23 Self {
24 auth: (),
25 base_url: GITHUB_API_URL.to_string(),
26 user_agent: "lib-github-client".to_string(),
27 }
28 }
29}
30
31impl Default for ClientBuilder<()> {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl<A> ClientBuilder<A> {
38 pub fn auth<S: AuthStrategy + 'static>(self, auth: S) -> ClientBuilder<S> {
39 ClientBuilder {
40 auth,
41 base_url: self.base_url,
42 user_agent: self.user_agent,
43 }
44 }
45
46 pub fn base_url(mut self, url: impl Into<String>) -> Self {
47 self.base_url = url.into();
48 self
49 }
50
51 pub fn user_agent(mut self, agent: impl Into<String>) -> Self {
52 self.user_agent = agent.into();
53 self
54 }
55}
56
57impl<A: AuthStrategy + 'static> ClientBuilder<A> {
58 pub fn build(self) -> Result<Client> {
59 let mut headers = header::HeaderMap::new();
60 headers.insert(
61 header::ACCEPT,
62 header::HeaderValue::from_static("application/vnd.github+json"),
63 );
64 headers.insert(
65 "X-GitHub-Api-Version",
66 header::HeaderValue::from_static(GITHUB_API_VERSION),
67 );
68
69 let http = HttpClient::builder()
70 .default_headers(headers)
71 .user_agent(&self.user_agent)
72 .build()?;
73
74 Ok(Client {
75 http,
76 auth: Arc::new(self.auth),
77 base_url: self.base_url,
78 })
79 }
80}
81
82pub struct Client {
83 http: HttpClient,
84 auth: Arc<dyn AuthStrategy>,
85 base_url: String,
86}
87
88impl Client {
89 pub fn builder() -> ClientBuilder<()> {
90 ClientBuilder::new()
91 }
92
93 pub fn new(auth: impl AuthStrategy + 'static) -> Result<Self> {
94 Self::builder().auth(auth).build()
95 }
96
97 async fn request<T: DeserializeOwned>(&self, method: reqwest::Method, path: &str) -> Result<T> {
98 let url = self.url(path);
99 debug!("{} {}", method, url);
100
101 let mut headers = header::HeaderMap::new();
102 self.auth.apply(&mut headers).await?;
103
104 let response = self
105 .http
106 .request(method, &url)
107 .headers(headers)
108 .send()
109 .await?;
110
111 self.handle_response(response).await
112 }
113
114 async fn request_with_body<T: DeserializeOwned, B: serde::Serialize>(
115 &self,
116 method: reqwest::Method,
117 path: &str,
118 body: &B,
119 ) -> Result<T> {
120 let url = self.url(path);
121 debug!("{} {}", method, url);
122
123 let mut headers = header::HeaderMap::new();
124 self.auth.apply(&mut headers).await?;
125
126 let response = self
127 .http
128 .request(method, &url)
129 .headers(headers)
130 .json(body)
131 .send()
132 .await?;
133
134 self.handle_response(response).await
135 }
136
137 async fn handle_response<T: DeserializeOwned>(&self, response: Response) -> Result<T> {
138 let status = response.status();
139
140 if status.is_success() {
141 return Ok(response.json().await?);
142 }
143
144 let body = response.text().await.unwrap_or_default();
145 debug!("GitHub API error response: {}", body);
146
147 match status {
148 StatusCode::UNAUTHORIZED => Err(GitHubError::Unauthorized),
149 StatusCode::FORBIDDEN => {
150 if body.contains("rate limit") {
151 Err(GitHubError::RateLimited { retry_after: 60 })
152 } else {
153 Err(GitHubError::Forbidden)
154 }
155 }
156 StatusCode::NOT_FOUND => Err(GitHubError::NotFound(body)),
157 _ => Err(GitHubError::Api {
158 status: status.as_u16(),
159 message: body,
160 }),
161 }
162 }
163
164 fn url(&self, path: &str) -> String {
165 format!("{}{}", self.base_url, path)
166 }
167
168 pub async fn get_repo(&self, owner: &str, repo: &str) -> Result<Repository> {
171 self.request(reqwest::Method::GET, &format!("/repos/{}/{}", owner, repo))
172 .await
173 }
174
175 pub async fn list_branches(&self, owner: &str, repo: &str) -> Result<Vec<Branch>> {
176 self.request(
177 reqwest::Method::GET,
178 &format!("/repos/{}/{}/branches", owner, repo),
179 )
180 .await
181 }
182
183 pub async fn get_branch(&self, owner: &str, repo: &str, branch: &str) -> Result<Branch> {
184 self.request(
185 reqwest::Method::GET,
186 &format!("/repos/{}/{}/branches/{}", owner, repo, branch),
187 )
188 .await
189 }
190
191 pub async fn get_content(
194 &self,
195 owner: &str,
196 repo: &str,
197 path: &str,
198 git_ref: Option<&str>,
199 ) -> Result<FileContent> {
200 let path = match git_ref {
201 Some(r) => format!("/repos/{}/{}/contents/{}?ref={}", owner, repo, path, r),
202 None => format!("/repos/{}/{}/contents/{}", owner, repo, path),
203 };
204 self.request(reqwest::Method::GET, &path).await
205 }
206
207 #[allow(clippy::too_many_arguments)]
208 pub async fn create_or_update_file(
209 &self,
210 owner: &str,
211 repo: &str,
212 path: &str,
213 message: &str,
214 content: &str,
215 sha: Option<&str>,
216 branch: Option<&str>,
217 ) -> Result<serde_json::Value> {
218 let content_base64 = STANDARD.encode(content.as_bytes());
219
220 let mut body = json!({
221 "message": message,
222 "content": content_base64,
223 });
224
225 if let Some(s) = sha {
226 body["sha"] = json!(s);
227 }
228 if let Some(b) = branch {
229 body["branch"] = json!(b);
230 }
231
232 self.request_with_body(
233 reqwest::Method::PUT,
234 &format!("/repos/{}/{}/contents/{}", owner, repo, path),
235 &body,
236 )
237 .await
238 }
239
240 pub async fn get_ref(&self, owner: &str, repo: &str, ref_path: &str) -> Result<Reference> {
243 self.request(
244 reqwest::Method::GET,
245 &format!("/repos/{}/{}/git/ref/{}", owner, repo, ref_path),
246 )
247 .await
248 }
249
250 pub async fn create_ref(
251 &self,
252 owner: &str,
253 repo: &str,
254 ref_name: &str,
255 sha: &str,
256 ) -> Result<Reference> {
257 let body = json!({
258 "ref": ref_name,
259 "sha": sha,
260 });
261
262 self.request_with_body(
263 reqwest::Method::POST,
264 &format!("/repos/{}/{}/git/refs", owner, repo),
265 &body,
266 )
267 .await
268 }
269
270 pub async fn update_ref(
271 &self,
272 owner: &str,
273 repo: &str,
274 ref_path: &str,
275 sha: &str,
276 force: bool,
277 ) -> Result<Reference> {
278 let body = json!({
279 "sha": sha,
280 "force": force,
281 });
282
283 self.request_with_body(
284 reqwest::Method::PATCH,
285 &format!("/repos/{}/{}/git/refs/{}", owner, repo, ref_path),
286 &body,
287 )
288 .await
289 }
290
291 pub async fn get_tree(
292 &self,
293 owner: &str,
294 repo: &str,
295 tree_sha: &str,
296 recursive: bool,
297 ) -> Result<Tree> {
298 let path = if recursive {
299 format!(
300 "/repos/{}/{}/git/trees/{}?recursive=1",
301 owner, repo, tree_sha
302 )
303 } else {
304 format!("/repos/{}/{}/git/trees/{}", owner, repo, tree_sha)
305 };
306 self.request(reqwest::Method::GET, &path).await
307 }
308
309 pub async fn create_tree(
310 &self,
311 owner: &str,
312 repo: &str,
313 base_tree: Option<&str>,
314 entries: Vec<CreateTreeEntry>,
315 ) -> Result<Tree> {
316 let mut body = json!({ "tree": entries });
317 if let Some(base) = base_tree {
318 body["base_tree"] = json!(base);
319 }
320
321 self.request_with_body(
322 reqwest::Method::POST,
323 &format!("/repos/{}/{}/git/trees", owner, repo),
324 &body,
325 )
326 .await
327 }
328
329 pub async fn create_commit(
330 &self,
331 owner: &str,
332 repo: &str,
333 message: &str,
334 tree_sha: &str,
335 parents: Vec<&str>,
336 ) -> Result<serde_json::Value> {
337 let body = json!({
338 "message": message,
339 "tree": tree_sha,
340 "parents": parents,
341 });
342
343 self.request_with_body(
344 reqwest::Method::POST,
345 &format!("/repos/{}/{}/git/commits", owner, repo),
346 &body,
347 )
348 .await
349 }
350
351 pub async fn create_blob(
352 &self,
353 owner: &str,
354 repo: &str,
355 content: &str,
356 encoding: &str,
357 ) -> Result<Blob> {
358 let body = json!({
359 "content": content,
360 "encoding": encoding,
361 });
362
363 self.request_with_body(
364 reqwest::Method::POST,
365 &format!("/repos/{}/{}/git/blobs", owner, repo),
366 &body,
367 )
368 .await
369 }
370
371 pub async fn get_authenticated_user(&self) -> Result<User> {
374 self.request(reqwest::Method::GET, "/user").await
375 }
376
377 pub async fn list_releases(&self, owner: &str, repo: &str) -> Result<Vec<Release>> {
380 self.request(
381 reqwest::Method::GET,
382 &format!("/repos/{}/{}/releases", owner, repo),
383 )
384 .await
385 }
386
387 pub async fn get_latest_release(&self, owner: &str, repo: &str) -> Result<Release> {
388 self.request(
389 reqwest::Method::GET,
390 &format!("/repos/{}/{}/releases/latest", owner, repo),
391 )
392 .await
393 }
394
395 pub async fn get_release_by_tag(&self, owner: &str, repo: &str, tag: &str) -> Result<Release> {
396 self.request(
397 reqwest::Method::GET,
398 &format!("/repos/{}/{}/releases/tags/{}", owner, repo, tag),
399 )
400 .await
401 }
402
403 pub async fn list_release_assets(
404 &self,
405 owner: &str,
406 repo: &str,
407 release_id: u64,
408 ) -> Result<Vec<ReleaseAsset>> {
409 self.request(
410 reqwest::Method::GET,
411 &format!("/repos/{}/{}/releases/{}/assets", owner, repo, release_id),
412 )
413 .await
414 }
415
416 pub async fn download_asset(&self, url: &str) -> Result<bytes::Bytes> {
417 debug!("GET {} (binary)", url);
418
419 let mut headers = header::HeaderMap::new();
420 self.auth.apply(&mut headers).await?;
421 headers.insert(header::ACCEPT, "application/octet-stream".parse().unwrap());
422
423 let response = self.http.get(url).headers(headers).send().await?;
424
425 if !response.status().is_success() {
426 let status = response.status();
427 let body = response.text().await.unwrap_or_default();
428 return Err(GitHubError::Api {
429 status: status.as_u16(),
430 message: body,
431 });
432 }
433
434 Ok(response.bytes().await?)
435 }
436}