ghostflow_github/
client.rs

1// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
2// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
3// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
4// option. This file may not be copied, modified, or distributed
5// except according to those terms.
6
7use std::env;
8use std::fmt::Debug;
9use std::iter;
10use std::thread;
11use std::time::Duration;
12
13use graphql_client::{GraphQLQuery, QueryBody, Response};
14use itertools::Itertools;
15use log::{info, warn};
16use reqwest::blocking::Client;
17use reqwest::header::{self, HeaderMap, HeaderValue};
18use reqwest::Url;
19use serde::{Deserialize, Serialize};
20use serde_json::Value;
21use thiserror::Error;
22
23use crate::authorization::{CurrentUser, GithubAuthError, GithubAuthorization};
24
/// The maximum number of attempts made for a query that keeps failing with server errors.
///
/// Kept small under test so that the backoff tests finish quickly.
#[cfg(test)]
const BACKOFF_LIMIT: usize = 2;
/// The maximum number of attempts made for a query that keeps failing with server errors.
#[cfg(not(test))]
const BACKOFF_LIMIT: usize = 5;
/// The delay before the first retry of a failed query.
const BACKOFF_START: Duration = Duration::from_millis(1000);
/// The factor by which the retry delay grows after each failed attempt.
const BACKOFF_SCALE: u32 = 2;
31
32#[derive(Debug, Error)]
33#[non_exhaustive]
34pub enum GithubError {
35    #[error("url parse error: {}", source)]
36    UrlParse {
37        #[from]
38        source: url::ParseError,
39    },
40    #[error("invalid `GITHUB_TOKEN`: {}", source)]
41    InvalidToken {
42        #[source]
43        source: env::VarError,
44    },
45    #[error("invalid `GITHUB_ACTOR`: {}", source)]
46    InvalidActor {
47        #[source]
48        source: env::VarError,
49    },
50    #[error("failed to send request to {}: {}", endpoint, source)]
51    SendRequest {
52        endpoint: Url,
53        #[source]
54        source: reqwest::Error,
55    },
56    #[error("github error: {}", response)]
57    Github { response: String },
58    #[error("deserialize error: {}", source)]
59    Deserialize {
60        #[from]
61        source: serde_json::Error,
62    },
63    #[error("github service error: {}", status)]
64    GithubService { status: reqwest::StatusCode },
65    #[error("json response deserialize: {}", source)]
66    JsonResponse {
67        #[source]
68        source: reqwest::Error,
69    },
70    #[allow(clippy::upper_case_acronyms)]
71    #[error("graphql error: [\"{}\"]", message.iter().format("\", \""))]
72    GraphQL { message: Vec<graphql_client::Error> },
73    #[error("no response from github")]
74    NoResponse {},
75    #[error("failure even after exponential backoff")]
76    GithubBackoff {},
77    #[error("authorization error: {}", source)]
78    Authorization {
79        #[from]
80        source: GithubAuthError,
81    },
82}
83
84impl GithubError {
85    fn should_backoff(&self) -> bool {
86        matches!(self, GithubError::GithubService { .. })
87    }
88
89    pub(crate) fn send_request(endpoint: Url, source: reqwest::Error) -> Self {
90        GithubError::SendRequest {
91            endpoint,
92            source,
93        }
94    }
95
96    pub(crate) fn github(response: String) -> Self {
97        GithubError::Github {
98            response,
99        }
100    }
101
102    fn github_service(status: reqwest::StatusCode) -> Self {
103        GithubError::GithubService {
104            status,
105        }
106    }
107
108    pub(crate) fn json_response(source: reqwest::Error) -> Self {
109        GithubError::JsonResponse {
110            source,
111        }
112    }
113
114    pub(crate) fn invalid_token(source: env::VarError) -> Self {
115        GithubError::InvalidToken {
116            source,
117        }
118    }
119
120    pub(crate) fn invalid_actor(source: env::VarError) -> Self {
121        GithubError::InvalidActor {
122            source,
123        }
124    }
125
126    fn graphql(message: Vec<graphql_client::Error>) -> Self {
127        GithubError::GraphQL {
128            message,
129        }
130    }
131
132    fn no_response() -> Self {
133        GithubError::NoResponse {}
134    }
135
136    fn github_backoff() -> Self {
137        GithubError::GithubBackoff {}
138    }
139}
140
141pub(crate) type GithubResult<T> = Result<T, GithubError>;
142
143// The user agent for all queries.
144pub(crate) const USER_AGENT: &str =
145    concat!(env!("CARGO_PKG_NAME"), " v", env!("CARGO_PKG_VERSION"));
146
147/// A client for communicating with a Github instance.
148#[derive(Clone)]
149pub struct Github {
150    /// The client used to communicate with Github.
151    client: Client,
152    /// The endpoint for REST queries.
153    rest_endpoint: Url,
154    /// The endpoint for GraphQL queries.
155    gql_endpoint: Url,
156
157    /// The authorization process for the client.
158    authorization: GithubAuthorization,
159}
160
161impl Github {
162    fn new_impl(host: &str, authorization: GithubAuthorization) -> GithubResult<Self> {
163        let rest_endpoint = Url::parse(&format!("https://{host}/"))?;
164        let gql_endpoint = Url::parse(&format!("https://{host}/graphql"))?;
165
166        Ok(Github {
167            client: Client::new(),
168            rest_endpoint,
169            gql_endpoint,
170            authorization,
171        })
172    }
173
174    /// Create a new Github client as a GitHub App.
175    ///
176    /// The `host` parameter is the API endpoint. For example `github.com` uses `api.github.com`.
177    ///
178    /// The `app_id` and `private_key` are provided when [registering the application][new-app].
179    /// The `installation_id` is an ID associated with a given installation of the application. Its
180    /// value is present in webhooks, but does not seem to be available generically.
181    ///
182    /// [new-app]: https://developer.github.com/apps/building-your-first-github-app/#register-a-new-app-with-github
183    pub fn new_app<H, P, I, S>(
184        host: H,
185        app_id: i64,
186        private_key: P,
187        installation_ids: I,
188    ) -> GithubResult<Self>
189    where
190        H: AsRef<str>,
191        P: AsRef<[u8]>,
192        I: IntoIterator<Item = (S, i64)>,
193        S: Into<String>,
194    {
195        let ids = installation_ids
196            .into_iter()
197            .map(|(s, i)| (s.into(), i))
198            .collect();
199        let authorization =
200            GithubAuthorization::new_app(host.as_ref(), app_id, private_key.as_ref(), ids)?;
201
202        Self::new_impl(host.as_ref(), authorization)
203    }
204
205    /// Create a new Github client as a GitHub Action.
206    ///
207    /// The `host` parameter is the API endpoint. For example `github.com` uses `api.github.com`.
208    ///
209    /// The `app_id` and `private_key` are provided when [registering the application][new-app].
210    /// The `installation_id` is an ID associated with a given installation of the application. Its
211    /// value is present in webhooks, but does not seem to be available generically.
212    ///
213    /// [new-app]: https://developer.github.com/apps/building-your-first-github-app/#register-a-new-app-with-github
214    pub fn new_action<H>(host: H) -> GithubResult<Self>
215    where
216        H: AsRef<str>,
217    {
218        let authorization = GithubAuthorization::new_action()?;
219
220        Self::new_impl(host.as_ref(), authorization)
221    }
222
223    pub(crate) fn app_id(&self) -> Option<i64> {
224        self.authorization.app_id()
225    }
226
227    pub(crate) fn current_user(&self) -> GithubResult<CurrentUser> {
228        self.authorization.current_user(&self.client)
229    }
230
231    /// The authorization header for GraphQL.
232    fn installation_auth_header(&self, owner: &str) -> GithubResult<HeaderMap> {
233        let token = self.authorization.token(&self.client, owner)?;
234        let mut header_value: HeaderValue = format!("token {token}").parse().unwrap();
235        header_value.set_sensitive(true);
236        Ok([(header::AUTHORIZATION, header_value)]
237            .iter()
238            .cloned()
239            .collect())
240    }
241
242    /// Accept headers for REST.
243    fn rest_accept_headers() -> HeaderMap {
244        [
245            // GitHub v3 API
246            (
247                header::ACCEPT,
248                "application/vnd.github.v3+json".parse().unwrap(),
249            ),
250        ]
251        .iter()
252        .cloned()
253        .collect()
254    }
255
256    /// Accept headers for GraphQL.
257    ///
258    /// We're using preview APIs and we need these to get access to them.
259    fn gql_accept_headers() -> HeaderMap {
260        HeaderMap::new()
261    }
262
263    pub(crate) fn post<D>(&self, owner: &str, endpoint: &str, data: &D) -> GithubResult<Value>
264    where
265        D: Serialize,
266    {
267        let endpoint = Url::parse(&format!("{}{}", self.rest_endpoint, endpoint))?;
268        let rsp = self
269            .client
270            .post(endpoint.clone())
271            .headers(self.installation_auth_header(owner)?)
272            .headers(Self::rest_accept_headers())
273            .header(header::USER_AGENT, USER_AGENT)
274            .json(data)
275            .send()
276            .map_err(|err| GithubError::send_request(endpoint, err))?;
277        if !rsp.status().is_success() {
278            let err = rsp
279                .text()
280                .unwrap_or_else(|text_err| format!("failed to extract error body: {text_err:?}"));
281            return Err(GithubError::github(err));
282        }
283
284        rsp.json().map_err(GithubError::json_response)
285    }
286
287    /// Send a GraphQL query.
288    fn send_impl<Q>(
289        &self,
290        owner: &str,
291        query: &QueryBody<Q::Variables>,
292    ) -> GithubResult<Q::ResponseData>
293    where
294        Q: GraphQLQuery,
295        Q::Variables: Debug,
296        for<'d> Q::ResponseData: Deserialize<'d>,
297    {
298        info!(
299            target: "github",
300            "sending GraphQL query '{}' {:?}",
301            query.operation_name,
302            query.variables,
303        );
304        let rsp = self
305            .client
306            .post(self.gql_endpoint.clone())
307            .headers(self.installation_auth_header(owner)?)
308            .headers(Self::gql_accept_headers())
309            .header(header::USER_AGENT, USER_AGENT)
310            .json(query)
311            .send()
312            .map_err(|err| GithubError::send_request(self.gql_endpoint.clone(), err))?;
313        if rsp.status().is_server_error() {
314            warn!(
315                target: "github",
316                "service error {} for query; retrying with backoff",
317                rsp.status().as_u16(),
318            );
319            return Err(GithubError::github_service(rsp.status()));
320        }
321        if !rsp.status().is_success() {
322            let err = rsp
323                .text()
324                .unwrap_or_else(|text_err| format!("failed to extract error body: {text_err:?}"));
325            return Err(GithubError::github(err));
326        }
327
328        let rsp: Response<Q::ResponseData> = rsp.json().map_err(GithubError::json_response)?;
329        if let Some(errs) = rsp.errors {
330            return Err(GithubError::graphql(errs));
331        }
332        rsp.data.ok_or_else(GithubError::no_response)
333    }
334
335    /// Send a GraphQL query.
336    pub fn send<Q>(
337        &self,
338        owner: &str,
339        query: &QueryBody<Q::Variables>,
340    ) -> GithubResult<Q::ResponseData>
341    where
342        Q: GraphQLQuery,
343        Q::Variables: Debug,
344        for<'d> Q::ResponseData: Deserialize<'d>,
345    {
346        retry_with_backoff(|| self.send_impl::<Q>(owner, query))
347    }
348}
349
350fn retry_with_backoff<F, K>(mut tryf: F) -> GithubResult<K>
351where
352    F: FnMut() -> GithubResult<K>,
353{
354    iter::repeat_n((), BACKOFF_LIMIT)
355        .scan(BACKOFF_START, |timeout, _| {
356            match tryf() {
357                Ok(r) => Some(Some(Ok(r))),
358                Err(err) => {
359                    if err.should_backoff() {
360                        thread::sleep(*timeout);
361                        *timeout *= BACKOFF_SCALE;
362                        Some(None)
363                    } else {
364                        Some(Some(Err(err)))
365                    }
366                },
367            }
368        })
369        .flatten()
370        .next()
371        .unwrap_or_else(|| Err(GithubError::github_backoff()))
372}
373
#[cfg(test)]
mod tests {
    use reqwest::{header, Client, StatusCode};

    use crate::client::{retry_with_backoff, Github, GithubError, BACKOFF_LIMIT};

    /// The REST headers carry exactly the GitHub v3 media type.
    #[test]
    fn test_rest_accept_headers() {
        let headers = Github::rest_accept_headers();
        assert_eq!(headers.len(), 1);
        let accept = headers.get(header::ACCEPT).unwrap();
        assert_eq!(accept, "application/vnd.github.v3+json");
    }

    /// No GraphQL-specific headers are currently required.
    #[test]
    fn test_gql_accept_headers() {
        assert!(Github::gql_accept_headers().is_empty());
    }

    /// An immediate success is returned after a single call.
    #[test]
    fn test_retry_with_backoff_first_success() {
        let mut calls = 0;
        let result = retry_with_backoff(|| {
            calls += 1;
            Ok(())
        });
        result.unwrap();
        assert_eq!(calls, 1);
    }

    /// A single server error is retried and the following success is returned.
    #[test]
    fn test_retry_with_backoff_second_success() {
        let mut calls = 0;
        let result = retry_with_backoff(|| {
            calls += 1;
            if calls == 1 {
                Err(GithubError::github_service(StatusCode::INTERNAL_SERVER_ERROR))
            } else {
                Ok(())
            }
        });
        result.unwrap();
        assert_eq!(calls, 2);
    }

    /// Persistent server errors exhaust every attempt and yield `GithubBackoff`.
    #[test]
    fn test_retry_with_backoff_no_success() {
        let mut calls = 0;
        let result = retry_with_backoff::<_, ()>(|| {
            calls += 1;
            Err(GithubError::github_service(StatusCode::INTERNAL_SERVER_ERROR))
        });
        let err = result.unwrap_err();
        assert_eq!(calls, BACKOFF_LIMIT);
        assert!(
            matches!(err, GithubError::GithubBackoff {}),
            "unexpected error: {}",
            err,
        );
    }

    /// Every REST accept header survives into a built request.
    #[test]
    fn test_rest_headers_work() {
        let request = Client::new()
            .post("https://nowhere")
            .headers(Github::rest_accept_headers())
            .build()
            .unwrap();
        let sent = request.headers();

        for (key, value) in Github::rest_accept_headers().iter() {
            let present = sent.get_all(key).iter().any(|actual| actual == value);
            assert!(
                present,
                "REST request is missing HTTP header `{}: {:?}`",
                key, value,
            );
        }
    }

    /// Every GraphQL accept header survives into a built request.
    #[test]
    fn test_graphql_headers_work() {
        let request = Client::new()
            .post("https://nowhere")
            .headers(Github::gql_accept_headers())
            .build()
            .unwrap();
        let sent = request.headers();

        for (key, value) in Github::gql_accept_headers().iter() {
            let present = sent.get_all(key).iter().any(|actual| actual == value);
            assert!(
                present,
                "GraphQL request is missing HTTP header `{}: {:?}`",
                key, value,
            );
        }
    }
}