automatons_github/client/
mod.rs1use anyhow::{anyhow, Context};
4use reqwest::header::HeaderValue;
5use reqwest::{Client, Method, RequestBuilder};
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use serde_json::Value;
9
10use automatons::Error;
11
12use crate::resource::{AppId, InstallationId};
13use crate::{name, secret};
14
15use self::token::TokenFactory;
16pub use self::token::{AppScope, InstallationScope, Token};
17
18mod token;
19
20name!(
21 GitHubHost
26);
27
28secret!(
29 PrivateKey
33);
34
35#[derive(Clone, Debug)]
40pub struct GitHubClient {
41 github_host: GitHubHost,
42 token_factory: TokenFactory,
43 installation_id: InstallationId,
44}
45
46#[allow(dead_code)] impl GitHubClient {
48 #[cfg_attr(feature = "tracing", tracing::instrument)]
50 pub fn new(
51 github_host: GitHubHost,
52 app_id: AppId,
53 private_key: PrivateKey,
54 installation_id: InstallationId,
55 ) -> Self {
56 let token_factory = TokenFactory::new(github_host.clone(), app_id, private_key);
57
58 Self {
59 github_host,
60 token_factory,
61 installation_id,
62 }
63 }
64
65 #[cfg_attr(feature = "tracing", tracing::instrument)]
67 pub async fn get<T>(&self, endpoint: &str) -> Result<T, Error>
68 where
69 T: DeserializeOwned,
70 {
71 let body: Option<Value> = None;
73
74 self.send_request(Method::GET, endpoint, body).await
75 }
76
77 #[cfg_attr(feature = "tracing", tracing::instrument(skip(body)))]
79 pub async fn post<T>(&self, endpoint: &str, body: Option<impl Serialize>) -> Result<T, Error>
80 where
81 T: DeserializeOwned,
82 {
83 self.send_request(Method::POST, endpoint, body).await
84 }
85
86 #[cfg_attr(feature = "tracing", tracing::instrument(skip(body)))]
88 pub async fn patch<T>(&self, endpoint: &str, body: Option<impl Serialize>) -> Result<T, Error>
89 where
90 T: DeserializeOwned,
91 {
92 self.send_request(Method::PATCH, endpoint, body).await
93 }
94
95 #[cfg_attr(feature = "tracing", tracing::instrument(skip(body)))]
96 async fn send_request<T>(
97 &self,
98 method: Method,
99 endpoint: &str,
100 body: Option<impl Serialize>,
101 ) -> Result<T, Error>
102 where
103 T: DeserializeOwned,
104 {
105 let url = format!("{}{}", self.github_host.get(), endpoint);
106
107 let mut client = self.client(method.clone(), &url).await?;
108
109 if let Some(body) = body {
110 client = client.json(&body);
111 }
112
113 let response = client.send().await?;
114 let status = &response.status();
115
116 if !status.is_success() {
117 #[cfg(feature = "tracing")]
118 tracing::error!(
119 "failed to send {} request to GitHub: {:?}",
120 &method,
121 response.text().await?
122 );
123
124 return if status == &404 {
125 Err(Error::NotFound(String::from(endpoint)))
126 } else {
127 Err(Error::Unknown(anyhow!(
129 "failed to send {} request to GitHub",
130 &method
131 )))
132 };
133 }
134
135 let data = response.json::<T>().await?;
136
137 Ok(data)
138 }
139
140 #[cfg_attr(feature = "tracing", tracing::instrument)]
142 pub async fn paginate<T>(
143 &self,
144 method: Method,
145 endpoint: &str,
146 key: &str,
147 ) -> Result<Vec<T>, Error>
148 where
149 T: DeserializeOwned,
150 {
151 let url = format!("{}{}", self.github_host.get(), endpoint);
152
153 let mut collection = Vec::new();
154 let mut next_url = Some(url);
155
156 while next_url.is_some() {
157 let response = self
158 .client(method.clone(), &next_url.unwrap())
159 .await?
160 .send()
161 .await?;
162
163 next_url = self.get_next_url(response.headers().get("link"))?;
164 let body = &response.json::<Value>().await?;
165
166 let payload = body
167 .get(key)
168 .context("failed to find pagination key in HTTP response")?;
169
170 let mut entities: Vec<T> = serde_json::from_value(payload.clone())
172 .context("failed to deserialize paginated entities")?;
173
174 collection.append(&mut entities);
175 }
176
177 Ok(collection)
178 }
179
180 #[cfg_attr(feature = "tracing", tracing::instrument)]
181 async fn client(&self, method: Method, url: &str) -> Result<RequestBuilder, Error> {
182 let token = self
183 .token_factory
184 .installation(self.installation_id)
185 .await?;
186
187 let client = Client::new()
188 .request(method, url)
189 .header("Authorization", format!("Bearer {}", token.get()))
190 .header("Accept", "application/vnd.github.v3+json")
191 .header("User-Agent", "devxbots/github-parts");
192
193 Ok(client)
194 }
195
196 #[cfg_attr(feature = "tracing", tracing::instrument)]
197 fn get_next_url(&self, header: Option<&HeaderValue>) -> Result<Option<String>, Error> {
198 let header = match header {
199 Some(header) => header,
200 None => return Ok(None),
201 };
202
203 let relations: Vec<&str> = header
204 .to_str()
205 .context("failed to parse HTTP request header")?
206 .split(',')
207 .collect();
208
209 let next_rel = match relations.iter().find(|link| link.contains(r#"rel="next"#)) {
210 Some(link) => link,
211 None => return Ok(None),
212 };
213
214 let link_start_position = 1 + next_rel
215 .find('<')
216 .context("failed to extract next url from link header")?;
217 let link_end_position = next_rel
218 .find('>')
219 .context("failed to extract next url from link header")?;
220
221 let link = String::from(&next_rel[link_start_position..link_end_position]);
222
223 Ok(Some(link))
224 }
225}
226
#[cfg(test)]
mod tests {
    use mockito::mock;
    use reqwest::header::HeaderValue;
    use reqwest::Method;

    use crate::client::PrivateKey;
    use crate::resource::{AppId, InstallationId, Repository};

    use super::GitHubClient;

    /// Builds a client for app `1` and installation `1` that talks to the mockito server.
    fn client() -> GitHubClient {
        GitHubClient::new(
            mockito::server_url().into(),
            AppId::new(1),
            PrivateKey::new(include_str!("../../tests/fixtures/private-key.pem")),
            InstallationId::new(1),
        )
    }

    /// Registers the access-token endpoint for installation `1`.
    ///
    /// The returned mock has to be kept alive for as long as the test sends requests.
    fn mock_access_token() -> mockito::Mock {
        mock("POST", "/app/installations/1/access_tokens")
            .with_status(200)
            .with_body(r#"{ "token": "ghs_16C7e42F292c6912E7710c838347Ae178B4a" }"#)
            .create()
    }

    /// Body of one page of the paginated endpoint, holding a single repository fixture.
    fn repositories_page() -> String {
        format!(
            r#"
            {{
                "total_count": 2,
                "repositories": [
                    {}
                ]
            }}
            "#,
            include_str!("../../tests/fixtures/resource/repository.json")
        )
    }

    #[tokio::test]
    async fn get_entity() {
        let _token_mock = mock_access_token();
        let _content_mock = mock("GET", "/repos/devxbots/automatons")
            .with_status(200)
            .with_body_from_file("tests/fixtures/resource/repository.json")
            .create();

        let repository: Repository = client()
            .get("/repos/devxbots/automatons")
            .await
            .unwrap();

        assert_eq!(518377950, repository.id().get());
    }

    #[tokio::test]
    async fn paginate_returns_all_entities() {
        let _token_mock = mock_access_token();

        // The first page points at the second one through its `link` header.
        let next_page_link = format!(
            "<{}/installation/repositories?page=2>; rel=\"next\"",
            mockito::server_url()
        );
        let _first_page_mock = mock("GET", "/installation/repositories")
            .with_status(200)
            .with_header("link", &next_page_link)
            .with_body(repositories_page())
            .create();

        // The second page has no `link` header, which ends the pagination.
        let _second_page_mock = mock("GET", "/installation/repositories?page=2")
            .with_status(200)
            .with_body(repositories_page())
            .create();

        let repository: Vec<Repository> = client()
            .paginate(Method::GET, "/installation/repositories", "repositories")
            .await
            .unwrap();

        assert_eq!(2, repository.len());
    }

    #[test]
    fn get_next_url_returns_url() {
        let header = HeaderValue::from_str(r#"<https://api.github.com/search/code?q=addClass+user%3Amozilla&page=13>; rel="prev", <https://api.github.com/search/code?q=addClass+user%3Amozilla&page=15>; rel="next", <https://api.github.com/search/code?q=addClass+user%3Amozilla&page=34>; rel="last", <https://api.github.com/search/code?q=addClass+user%3Amozilla&page=1>; rel="first""#).unwrap();

        let next_url = client().get_next_url(Some(&header)).unwrap().unwrap();

        assert_eq!(
            "https://api.github.com/search/code?q=addClass+user%3Amozilla&page=15",
            next_url
        );
    }

    #[test]
    fn get_next_url_returns_none() {
        let header = HeaderValue::from_str(
            r#"<https://api.github.com/search/code?q=addClass+user%3Amozilla&page=13>; rel="prev""#,
        )
        .unwrap();

        let next_url = client().get_next_url(Some(&header)).unwrap();

        assert!(next_url.is_none());
    }

    #[test]
    fn trait_send() {
        fn assert_send<T: Send>() {}
        assert_send::<GitHubClient>();
    }

    #[test]
    fn trait_sync() {
        fn assert_sync<T: Sync>() {}
        assert_sync::<GitHubClient>();
    }
}