1use std::env;
8use std::fmt::Debug;
9use std::iter;
10use std::thread;
11use std::time::Duration;
12
13use graphql_client::{GraphQLQuery, QueryBody, Response};
14use itertools::Itertools;
15use log::{info, warn};
16use reqwest::blocking::Client;
17use reqwest::header::{self, HeaderMap, HeaderValue};
18use reqwest::Url;
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use thiserror::Error;
22
23use crate::authorization::{CurrentUser, GithubAuthError, GithubAuthorization};
24
/// Maximum number of attempts `retry_with_backoff` makes (reduced under
/// `cfg(test)` so the sleeping retry tests stay short).
#[cfg(test)]
const BACKOFF_LIMIT: usize = 2;
/// Maximum number of attempts `retry_with_backoff` makes.
#[cfg(not(test))]
const BACKOFF_LIMIT: usize = 5;
/// Delay applied before the first retry.
const BACKOFF_START: Duration = Duration::from_secs(1);
/// Factor by which the delay grows after each retry.
const BACKOFF_SCALE: u32 = 2;
31
32#[derive(Debug, Error)]
33#[non_exhaustive]
34pub enum GithubError {
35 #[error("url parse error: {}", source)]
36 UrlParse {
37 #[from]
38 source: url::ParseError,
39 },
40 #[error("invalid `GITHUB_TOKEN`: {}", source)]
41 InvalidToken {
42 #[source]
43 source: env::VarError,
44 },
45 #[error("invalid `GITHUB_ACTOR`: {}", source)]
46 InvalidActor {
47 #[source]
48 source: env::VarError,
49 },
50 #[error("failed to send request to {}: {}", endpoint, source)]
51 SendRequest {
52 endpoint: Url,
53 #[source]
54 source: reqwest::Error,
55 },
56 #[error("github error: {}", response)]
57 Github { response: String },
58 #[error("deserialize error: {}", source)]
59 Deserialize {
60 #[from]
61 source: serde_json::Error,
62 },
63 #[error("github service error: {}", status)]
64 GithubService { status: reqwest::StatusCode },
65 #[error("json response deserialize: {}", source)]
66 JsonResponse {
67 #[source]
68 source: reqwest::Error,
69 },
70 #[allow(clippy::upper_case_acronyms)]
71 #[error("graphql error: [\"{}\"]", message.iter().format("\", \""))]
72 GraphQL { message: Vec<graphql_client::Error> },
73 #[error("no response from github")]
74 NoResponse {},
75 #[error("failure even after exponential backoff")]
76 GithubBackoff {},
77 #[error("authorization error: {}", source)]
78 Authorization {
79 #[from]
80 source: GithubAuthError,
81 },
82}
83
84impl GithubError {
85 fn should_backoff(&self) -> bool {
86 matches!(self, GithubError::GithubService { .. })
87 }
88
89 pub(crate) fn send_request(endpoint: Url, source: reqwest::Error) -> Self {
90 GithubError::SendRequest {
91 endpoint,
92 source,
93 }
94 }
95
96 pub(crate) fn github(response: String) -> Self {
97 GithubError::Github {
98 response,
99 }
100 }
101
102 fn github_service(status: reqwest::StatusCode) -> Self {
103 GithubError::GithubService {
104 status,
105 }
106 }
107
108 pub(crate) fn json_response(source: reqwest::Error) -> Self {
109 GithubError::JsonResponse {
110 source,
111 }
112 }
113
114 pub(crate) fn invalid_token(source: env::VarError) -> Self {
115 GithubError::InvalidToken {
116 source,
117 }
118 }
119
120 pub(crate) fn invalid_actor(source: env::VarError) -> Self {
121 GithubError::InvalidActor {
122 source,
123 }
124 }
125
126 fn graphql(message: Vec<graphql_client::Error>) -> Self {
127 GithubError::GraphQL {
128 message,
129 }
130 }
131
132 fn no_response() -> Self {
133 GithubError::NoResponse {}
134 }
135
136 fn github_backoff() -> Self {
137 GithubError::GithubBackoff {}
138 }
139}
140
141pub(crate) type GithubResult<T> = Result<T, GithubError>;
142
143pub(crate) const USER_AGENT: &str =
145 concat!(env!("CARGO_PKG_NAME"), " v", env!("CARGO_PKG_VERSION"));
146
/// A client for the REST and GraphQL APIs of a single GitHub host.
#[derive(Clone)]
pub struct Github {
    /// HTTP client used to issue requests.
    client: Client,
    /// Root URL of the REST API (`https://{host}/`).
    rest_endpoint: Url,
    /// URL of the GraphQL endpoint (`https://{host}/graphql`).
    gql_endpoint: Url,

    /// Source of per-owner access tokens, the app ID, and the current user.
    authorization: GithubAuthorization,
}
160
161impl Github {
162 fn new_impl(host: &str, authorization: GithubAuthorization) -> GithubResult<Self> {
163 let rest_endpoint = Url::parse(&format!("https://{host}/"))?;
164 let gql_endpoint = Url::parse(&format!("https://{host}/graphql"))?;
165
166 Ok(Github {
167 client: Client::new(),
168 rest_endpoint,
169 gql_endpoint,
170 authorization,
171 })
172 }
173
174 pub fn new_app<H, P, I, S>(
184 host: H,
185 app_id: i64,
186 private_key: P,
187 installation_ids: I,
188 ) -> GithubResult<Self>
189 where
190 H: AsRef<str>,
191 P: AsRef<[u8]>,
192 I: IntoIterator<Item = (S, i64)>,
193 S: Into<String>,
194 {
195 let ids = installation_ids
196 .into_iter()
197 .map(|(s, i)| (s.into(), i))
198 .collect();
199 let authorization =
200 GithubAuthorization::new_app(host.as_ref(), app_id, private_key.as_ref(), ids)?;
201
202 Self::new_impl(host.as_ref(), authorization)
203 }
204
205 pub fn new_action<H>(host: H) -> GithubResult<Self>
215 where
216 H: AsRef<str>,
217 {
218 let authorization = GithubAuthorization::new_action()?;
219
220 Self::new_impl(host.as_ref(), authorization)
221 }
222
223 pub(crate) fn app_id(&self) -> Option<i64> {
224 self.authorization.app_id()
225 }
226
227 pub(crate) fn current_user(&self) -> GithubResult<CurrentUser> {
228 self.authorization.current_user(&self.client)
229 }
230
231 fn installation_auth_header(&self, owner: &str) -> GithubResult<HeaderMap> {
233 let token = self.authorization.token(&self.client, owner)?;
234 let mut header_value: HeaderValue = format!("token {token}").parse().unwrap();
235 header_value.set_sensitive(true);
236 Ok([(header::AUTHORIZATION, header_value)]
237 .iter()
238 .cloned()
239 .collect())
240 }
241
242 fn rest_accept_headers() -> HeaderMap {
244 [
245 (
247 header::ACCEPT,
248 "application/vnd.github.v3+json".parse().unwrap(),
249 ),
250 ]
251 .iter()
252 .cloned()
253 .collect()
254 }
255
256 fn gql_accept_headers() -> HeaderMap {
260 HeaderMap::new()
261 }
262
263 pub(crate) fn post<D>(&self, owner: &str, endpoint: &str, data: &D) -> GithubResult<Value>
264 where
265 D: Serialize,
266 {
267 let endpoint = Url::parse(&format!("{}{}", self.rest_endpoint, endpoint))?;
268 let rsp = self
269 .client
270 .post(endpoint.clone())
271 .headers(self.installation_auth_header(owner)?)
272 .headers(Self::rest_accept_headers())
273 .header(header::USER_AGENT, USER_AGENT)
274 .json(data)
275 .send()
276 .map_err(|err| GithubError::send_request(endpoint, err))?;
277 if !rsp.status().is_success() {
278 let err = rsp
279 .text()
280 .unwrap_or_else(|text_err| format!("failed to extract error body: {text_err:?}"));
281 return Err(GithubError::github(err));
282 }
283
284 rsp.json().map_err(GithubError::json_response)
285 }
286
287 fn send_impl<Q>(
289 &self,
290 owner: &str,
291 query: &QueryBody<Q::Variables>,
292 ) -> GithubResult<Q::ResponseData>
293 where
294 Q: GraphQLQuery,
295 Q::Variables: Debug,
296 for<'d> Q::ResponseData: Deserialize<'d>,
297 {
298 info!(
299 target: "github",
300 "sending GraphQL query '{}' {:?}",
301 query.operation_name,
302 query.variables,
303 );
304 let rsp = self
305 .client
306 .post(self.gql_endpoint.clone())
307 .headers(self.installation_auth_header(owner)?)
308 .headers(Self::gql_accept_headers())
309 .header(header::USER_AGENT, USER_AGENT)
310 .json(query)
311 .send()
312 .map_err(|err| GithubError::send_request(self.gql_endpoint.clone(), err))?;
313 if rsp.status().is_server_error() {
314 warn!(
315 target: "github",
316 "service error {} for query; retrying with backoff",
317 rsp.status().as_u16(),
318 );
319 return Err(GithubError::github_service(rsp.status()));
320 }
321 if !rsp.status().is_success() {
322 let err = rsp
323 .text()
324 .unwrap_or_else(|text_err| format!("failed to extract error body: {text_err:?}"));
325 return Err(GithubError::github(err));
326 }
327
328 let rsp: Response<Q::ResponseData> = rsp.json().map_err(GithubError::json_response)?;
329 if let Some(errs) = rsp.errors {
330 return Err(GithubError::graphql(errs));
331 }
332 rsp.data.ok_or_else(GithubError::no_response)
333 }
334
335 pub fn send<Q>(
337 &self,
338 owner: &str,
339 query: &QueryBody<Q::Variables>,
340 ) -> GithubResult<Q::ResponseData>
341 where
342 Q: GraphQLQuery,
343 Q::Variables: Debug,
344 for<'d> Q::ResponseData: Deserialize<'d>,
345 {
346 retry_with_backoff(|| self.send_impl::<Q>(owner, query))
347 }
348}
349
350fn retry_with_backoff<F, K>(mut tryf: F) -> GithubResult<K>
351where
352 F: FnMut() -> GithubResult<K>,
353{
354 iter::repeat_n((), BACKOFF_LIMIT)
355 .scan(BACKOFF_START, |timeout, _| {
356 match tryf() {
357 Ok(r) => Some(Some(Ok(r))),
358 Err(err) => {
359 if err.should_backoff() {
360 thread::sleep(*timeout);
361 *timeout *= BACKOFF_SCALE;
362 Some(None)
363 } else {
364 Some(Some(Err(err)))
365 }
366 },
367 }
368 })
369 .flatten()
370 .next()
371 .unwrap_or_else(|| Err(GithubError::github_backoff()))
372}
373
#[cfg(test)]
mod tests {
    use reqwest::{header, Client, StatusCode};

    use crate::client::{retry_with_backoff, Github, GithubError, BACKOFF_LIMIT};

    /// A retryable error, as produced for a failing GitHub service.
    fn service_error() -> GithubError {
        GithubError::github_service(StatusCode::INTERNAL_SERVER_ERROR)
    }

    #[test]
    fn test_rest_accept_headers() {
        let headers = Github::rest_accept_headers();
        assert_eq!(headers.len(), 1);
        let accept = headers.get(header::ACCEPT).unwrap();
        assert_eq!(accept, "application/vnd.github.v3+json");
    }

    #[test]
    fn test_gql_accept_headers() {
        assert!(Github::gql_accept_headers().is_empty());
    }

    #[test]
    fn test_retry_with_backoff_first_success() {
        let mut calls = 0;
        let outcome = retry_with_backoff(|| {
            calls += 1;
            Ok(())
        });
        outcome.unwrap();
        assert_eq!(calls, 1);
    }

    #[test]
    fn test_retry_with_backoff_second_success() {
        let mut calls = 0;
        let outcome = retry_with_backoff(|| {
            calls += 1;
            // Fail exactly once, then succeed.
            if calls == 1 {
                Err(service_error())
            } else {
                Ok(())
            }
        });
        outcome.unwrap();
        assert_eq!(calls, 2);
    }

    #[test]
    fn test_retry_with_backoff_no_success() {
        let mut calls = 0;
        let err = retry_with_backoff::<_, ()>(|| {
            calls += 1;
            Err(service_error())
        })
        .unwrap_err();
        assert_eq!(calls, BACKOFF_LIMIT);
        assert!(
            matches!(err, GithubError::GithubBackoff {}),
            "unexpected error: {}",
            err,
        );
    }

    #[test]
    fn test_rest_headers_work() {
        let expected = Github::rest_accept_headers();
        let req = Client::new()
            .post("https://nowhere")
            .headers(expected.clone())
            .build()
            .unwrap();

        for (key, value) in &expected {
            assert!(
                req.headers().get_all(key).iter().any(|actual| actual == value),
                "REST request is missing HTTP header `{}: {:?}`",
                key,
                value,
            );
        }
    }

    #[test]
    fn test_graphql_headers_work() {
        let expected = Github::gql_accept_headers();
        let req = Client::new()
            .post("https://nowhere")
            .headers(expected.clone())
            .build()
            .unwrap();

        for (key, value) in &expected {
            assert!(
                req.headers().get_all(key).iter().any(|actual| actual == value),
                "GraphQL request is missing HTTP header `{}: {:?}`",
                key,
                value,
            );
        }
    }
}
479}