automatons_github/client/
mod.rs

1//! Client for GitHub's REST API
2
3use anyhow::{anyhow, Context};
4use reqwest::header::HeaderValue;
5use reqwest::{Client, Method, RequestBuilder};
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use serde_json::Value;
9
10use automatons::Error;
11
12use crate::resource::{AppId, InstallationId};
13use crate::{name, secret};
14
15use self::token::TokenFactory;
16pub use self::token::{AppScope, InstallationScope, Token};
17
18mod token;
19
name!(
    /// API endpoint for the client
    ///
    /// The GitHub client can be used with different GitHub instances, for example a self-hosted
    /// GitHub Enterprise Server. The `GitHubHost` sets the base URL that the client will use.
    ///
    /// The client builds request URLs by concatenating this value with the endpoint path (for
    /// example `/repos/devxbots/automatons`) without inserting a separator. The host should
    /// therefore not end with a trailing slash.
    GitHubHost
);
27
secret!(
    /// Private key of the GitHub App
    ///
    /// GitHub Apps have a private key that they use to sign authentication tokens. The client
    /// hands it to its token factory when it is constructed.
    ///
    /// NOTE(review): `GitHubClient` derives `Debug`; presumably `secret!` keeps the key out of
    /// `Debug` output — confirm against the macro definition.
    PrivateKey
);
34
/// Client for GitHub's REST API
///
/// The GitHub client can be used to send HTTP requests to GitHub's REST API. The client handles
/// authentication, serialization, and pagination.
///
/// Every request is authenticated with an installation access token that the token factory
/// issues for `installation_id`.
#[derive(Clone, Debug)]
pub struct GitHubClient {
    /// Base URL that endpoint paths are appended to when building request URLs.
    github_host: GitHubHost,
    /// Issues the installation access tokens that authenticate each request.
    token_factory: TokenFactory,
    /// Installation of the GitHub App on whose behalf requests are sent.
    installation_id: InstallationId,
}
45
46#[allow(dead_code)] // TODO: Remove when remaining tasks have been migrated from `github-parts`
47impl GitHubClient {
48    /// Initializes a new instance of the GitHub client
49    #[cfg_attr(feature = "tracing", tracing::instrument)]
50    pub fn new(
51        github_host: GitHubHost,
52        app_id: AppId,
53        private_key: PrivateKey,
54        installation_id: InstallationId,
55    ) -> Self {
56        let token_factory = TokenFactory::new(github_host.clone(), app_id, private_key);
57
58        Self {
59            github_host,
60            token_factory,
61            installation_id,
62        }
63    }
64
65    /// Send a GET request to GitHub
66    #[cfg_attr(feature = "tracing", tracing::instrument)]
67    pub async fn get<T>(&self, endpoint: &str) -> Result<T, Error>
68    where
69        T: DeserializeOwned,
70    {
71        // We need to explicitly declare the type of the body somewhere to silence a compiler error.
72        let body: Option<Value> = None;
73
74        self.send_request(Method::GET, endpoint, body).await
75    }
76
77    /// Send a POST request to GitHub
78    #[cfg_attr(feature = "tracing", tracing::instrument(skip(body)))]
79    pub async fn post<T>(&self, endpoint: &str, body: Option<impl Serialize>) -> Result<T, Error>
80    where
81        T: DeserializeOwned,
82    {
83        self.send_request(Method::POST, endpoint, body).await
84    }
85
86    /// Send a PATCH request to GitHub
87    #[cfg_attr(feature = "tracing", tracing::instrument(skip(body)))]
88    pub async fn patch<T>(&self, endpoint: &str, body: Option<impl Serialize>) -> Result<T, Error>
89    where
90        T: DeserializeOwned,
91    {
92        self.send_request(Method::PATCH, endpoint, body).await
93    }
94
95    #[cfg_attr(feature = "tracing", tracing::instrument(skip(body)))]
96    async fn send_request<T>(
97        &self,
98        method: Method,
99        endpoint: &str,
100        body: Option<impl Serialize>,
101    ) -> Result<T, Error>
102    where
103        T: DeserializeOwned,
104    {
105        let url = format!("{}{}", self.github_host.get(), endpoint);
106
107        let mut client = self.client(method.clone(), &url).await?;
108
109        if let Some(body) = body {
110            client = client.json(&body);
111        }
112
113        let response = client.send().await?;
114        let status = &response.status();
115
116        if !status.is_success() {
117            #[cfg(feature = "tracing")]
118            tracing::error!(
119                "failed to send {} request to GitHub: {:?}",
120                &method,
121                response.text().await?
122            );
123
124            return if status == &404 {
125                Err(Error::NotFound(String::from(endpoint)))
126            } else {
127                // TODO: Gracefully return status instead of error
128                Err(Error::Unknown(anyhow!(
129                    "failed to send {} request to GitHub",
130                    &method
131                )))
132            };
133        }
134
135        let data = response.json::<T>().await?;
136
137        Ok(data)
138    }
139
140    /// Send a paginated request to GitHub
141    #[cfg_attr(feature = "tracing", tracing::instrument)]
142    pub async fn paginate<T>(
143        &self,
144        method: Method,
145        endpoint: &str,
146        key: &str,
147    ) -> Result<Vec<T>, Error>
148    where
149        T: DeserializeOwned,
150    {
151        let url = format!("{}{}", self.github_host.get(), endpoint);
152
153        let mut collection = Vec::new();
154        let mut next_url = Some(url);
155
156        while next_url.is_some() {
157            let response = self
158                .client(method.clone(), &next_url.unwrap())
159                .await?
160                .send()
161                .await?;
162
163            next_url = self.get_next_url(response.headers().get("link"))?;
164            let body = &response.json::<Value>().await?;
165
166            let payload = body
167                .get(key)
168                .context("failed to find pagination key in HTTP response")?;
169
170            // TODO: Avoid cloning the payload
171            let mut entities: Vec<T> = serde_json::from_value(payload.clone())
172                .context("failed to deserialize paginated entities")?;
173
174            collection.append(&mut entities);
175        }
176
177        Ok(collection)
178    }
179
180    #[cfg_attr(feature = "tracing", tracing::instrument)]
181    async fn client(&self, method: Method, url: &str) -> Result<RequestBuilder, Error> {
182        let token = self
183            .token_factory
184            .installation(self.installation_id)
185            .await?;
186
187        let client = Client::new()
188            .request(method, url)
189            .header("Authorization", format!("Bearer {}", token.get()))
190            .header("Accept", "application/vnd.github.v3+json")
191            .header("User-Agent", "devxbots/github-parts");
192
193        Ok(client)
194    }
195
196    #[cfg_attr(feature = "tracing", tracing::instrument)]
197    fn get_next_url(&self, header: Option<&HeaderValue>) -> Result<Option<String>, Error> {
198        let header = match header {
199            Some(header) => header,
200            None => return Ok(None),
201        };
202
203        let relations: Vec<&str> = header
204            .to_str()
205            .context("failed to parse HTTP request header")?
206            .split(',')
207            .collect();
208
209        let next_rel = match relations.iter().find(|link| link.contains(r#"rel="next"#)) {
210            Some(link) => link,
211            None => return Ok(None),
212        };
213
214        let link_start_position = 1 + next_rel
215            .find('<')
216            .context("failed to extract next url from link header")?;
217        let link_end_position = next_rel
218            .find('>')
219            .context("failed to extract next url from link header")?;
220
221        let link = String::from(&next_rel[link_start_position..link_end_position]);
222
223        Ok(Some(link))
224    }
225}
226
#[cfg(test)]
mod tests {
    use mockito::mock;
    use reqwest::header::HeaderValue;
    use reqwest::Method;

    use crate::client::PrivateKey;
    use crate::resource::{AppId, InstallationId, Repository};

    use super::GitHubClient;

    /// Response body of the mocked installation access token endpoint.
    const TOKEN_BODY: &str = r#"{ "token": "ghs_16C7e42F292c6912E7710c838347Ae178B4a" }"#;

    /// Builds a client for app `1` / installation `1` that talks to the mockito server.
    fn test_client() -> GitHubClient {
        GitHubClient::new(
            mockito::server_url().into(),
            AppId::new(1),
            PrivateKey::new(include_str!("../../tests/fixtures/private-key.pem")),
            InstallationId::new(1),
        )
    }

    /// Renders one page of installation repositories holding a single repository fixture.
    fn repositories_page() -> String {
        format!(
            r#"
                {{
                    "total_count": 2,
                    "repositories": [
                        {}
                    ]
                }}
                "#,
            include_str!("../../tests/fixtures/resource/repository.json")
        )
    }

    #[tokio::test]
    async fn get_entity() {
        let _token_mock = mock("POST", "/app/installations/1/access_tokens")
            .with_status(200)
            .with_body(TOKEN_BODY)
            .create();
        let _content_mock = mock("GET", "/repos/devxbots/automatons")
            .with_status(200)
            .with_body_from_file("tests/fixtures/resource/repository.json")
            .create();

        let client = test_client();

        let repository: Repository = client.get("/repos/devxbots/automatons").await.unwrap();

        assert_eq!(518377950, repository.id().get());
    }

    #[tokio::test]
    async fn paginate_returns_all_entities() {
        let _token_mock = mock("POST", "/app/installations/1/access_tokens")
            .with_status(200)
            .with_body(TOKEN_BODY)
            .create();

        let next_link = format!(
            "<{}/installation/repositories?page=2>; rel=\"next\"",
            mockito::server_url()
        );
        let _first_page_mock = mock("GET", "/installation/repositories")
            .with_status(200)
            .with_header("link", &next_link)
            .with_body(repositories_page())
            .create();
        let _second_page_mock = mock("GET", "/installation/repositories?page=2")
            .with_status(200)
            .with_body(repositories_page())
            .create();

        let client = test_client();

        let repositories: Vec<Repository> = client
            .paginate(Method::GET, "/installation/repositories", "repositories")
            .await
            .unwrap();

        assert_eq!(2, repositories.len());
    }

    #[test]
    fn get_next_url_returns_url() {
        let client = test_client();

        let header = HeaderValue::from_str(r#"<https://api.github.com/search/code?q=addClass+user%3Amozilla&page=13>; rel="prev", <https://api.github.com/search/code?q=addClass+user%3Amozilla&page=15>; rel="next", <https://api.github.com/search/code?q=addClass+user%3Amozilla&page=34>; rel="last", <https://api.github.com/search/code?q=addClass+user%3Amozilla&page=1>; rel="first""#).unwrap();

        let next_url = client
            .get_next_url(Some(&header))
            .unwrap()
            .expect("header advertises a next page");

        assert_eq!(
            "https://api.github.com/search/code?q=addClass+user%3Amozilla&page=15",
            next_url
        );
    }

    #[test]
    fn get_next_url_returns_none() {
        let client = test_client();

        let header = HeaderValue::from_str(
            r#"<https://api.github.com/search/code?q=addClass+user%3Amozilla&page=13>; rel="prev""#,
        )
        .unwrap();

        assert!(client.get_next_url(Some(&header)).unwrap().is_none());
    }

    #[test]
    fn trait_send() {
        fn requires_send<T: Send>() {}
        requires_send::<GitHubClient>();
    }

    #[test]
    fn trait_sync() {
        fn requires_sync<T: Sync>() {}
        requires_sync::<GitHubClient>();
    }
}