1use std::{
2 collections::HashMap,
3 future::Future,
4 ops::Deref,
5 sync::{
6 atomic::{AtomicBool, Ordering::Relaxed},
7 Arc, Mutex, RwLock,
8 },
9 time::{Duration, Instant},
10};
11
12use binstalk_downloader::{download::Download, remote};
13use compact_str::{format_compact, CompactString, ToCompactString};
14use tokio::sync::OnceCell;
15use tracing::{instrument, Level};
16use url::Url;
17use zeroize::Zeroizing;
18
19mod common;
20mod error;
21mod release_artifacts;
22mod repo_info;
23
24use common::{check_http_status_and_header, percent_decode_http_url_path};
25pub use error::{GhApiContextError, GhApiError, GhGraphQLErrors};
26pub use repo_info::RepoInfo;
27
/// Back-off applied after a rate-limit error that did not say when to retry
/// (ten minutes).
const DEFAULT_RETRY_DURATION: Duration = Duration::from_secs(600);
30
31#[derive(Clone, Eq, PartialEq, Hash, Debug)]
32pub struct GhRepo {
33 pub owner: CompactString,
34 pub repo: CompactString,
35}
36impl GhRepo {
37 pub fn repo_url(&self) -> Result<Url, url::ParseError> {
38 Url::parse(&format_compact!(
39 "https://github.com/{}/{}",
40 self.owner,
41 self.repo
42 ))
43 }
44
45 pub fn try_extract_from_url(url: &Url) -> Option<Self> {
46 if url.domain() != Some("github.com") {
47 return None;
48 }
49
50 let mut path_segments = url.path_segments()?;
51
52 Some(Self {
53 owner: path_segments.next()?.to_compact_string(),
54 repo: path_segments.next()?.to_compact_string(),
55 })
56 }
57}
58
59#[derive(Clone, Eq, PartialEq, Hash, Debug)]
61pub struct GhRelease {
62 pub repo: GhRepo,
63 pub tag: CompactString,
64}
65
66#[derive(Clone, Eq, PartialEq, Hash, Debug)]
68pub struct GhReleaseArtifact {
69 pub release: GhRelease,
70 pub artifact_name: CompactString,
71}
72
73impl GhReleaseArtifact {
74 pub fn try_extract_from_url(url: &remote::Url) -> Option<Self> {
76 if url.domain() != Some("github.com") {
77 return None;
78 }
79
80 let mut path_segments = url.path_segments()?;
81
82 let owner = path_segments.next()?;
83 let repo = path_segments.next()?;
84
85 if (path_segments.next()?, path_segments.next()?) != ("releases", "download") {
86 return None;
87 }
88
89 let tag = path_segments.next()?;
90 let artifact_name = path_segments.next()?;
91
92 (path_segments.next().is_none() && url.fragment().is_none() && url.query().is_none()).then(
93 || Self {
94 release: GhRelease {
95 repo: GhRepo {
96 owner: percent_decode_http_url_path(owner),
97 repo: percent_decode_http_url_path(repo),
98 },
99 tag: percent_decode_http_url_path(tag),
100 },
101 artifact_name: percent_decode_http_url_path(artifact_name),
102 },
103 )
104 }
105}
106
/// A thread-safe map that lazily creates a shared, default-initialised value
/// for each key on first access.
#[derive(Debug)]
struct Map<K, V>(RwLock<HashMap<K, Arc<V>>>);

impl<K, V> Default for Map<K, V> {
    fn default() -> Self {
        Map(RwLock::new(HashMap::new()))
    }
}

impl<K, V> Map<K, V>
where
    K: Eq + std::hash::Hash,
    V: Default,
{
    /// Returns the value stored under `k`, inserting `V::default()` first if absent.
    ///
    /// Every call with the same key returns a clone of the same `Arc`.
    ///
    /// # Panics
    ///
    /// Panics if the inner lock is poisoned.
    fn get(&self, k: K) -> Arc<V> {
        // Fast path: look the key up under a shared lock, released at the end
        // of this scope so the exclusive lock below cannot deadlock.
        let existing = {
            let map = self.0.read().unwrap();
            map.get(&k).map(Arc::clone)
        };

        match existing {
            Some(value) => value,
            // Slow path: `entry` re-checks under the exclusive lock, so a value
            // inserted by another thread in the meantime is reused.
            None => {
                let mut map = self.0.write().unwrap();
                Arc::clone(map.entry(k).or_default())
            }
        }
    }
}
126
/// Shared state behind every clone of `GhApiClient`.
#[derive(Debug)]
struct Inner {
    /// HTTP client used for every request.
    client: remote::Client,
    /// Per-release cache of fetched artifact lists, filled at most once per
    /// release by `has_release_artifact`; an inner `None` records that the
    /// release was not found.
    release_artifacts: Map<GhRelease, OnceCell<Option<release_artifacts::Artifacts>>>,
    /// Deadline before which requests fail fast with a rate-limit error
    /// (checked and cleared by `check_retry_after`).
    retry_after: Mutex<Option<Instant>>,

    /// Optional GitHub token, kept in a `Zeroizing` wrapper.
    auth_token: Option<Zeroizing<Box<str>>>,
    /// Set to `false` once GitHub answers `Unauthorized`; from then on the
    /// token is no longer used.
    is_auth_token_valid: AtomicBool,

    /// When `true`, the GraphQL API is skipped and only the RESTful API is used.
    only_use_restful_api: AtomicBool,
}
138
139#[derive(Clone, Debug)]
142pub struct GhApiClient(Arc<Inner>);
143
144impl GhApiClient {
145 pub fn new(client: remote::Client, auth_token: Option<Zeroizing<Box<str>>>) -> Self {
146 Self(Arc::new(Inner {
147 client,
148 release_artifacts: Default::default(),
149 retry_after: Default::default(),
150
151 auth_token,
152 is_auth_token_valid: AtomicBool::new(true),
153
154 only_use_restful_api: AtomicBool::new(false),
155 }))
156 }
157
158 pub fn set_only_use_restful_api(&self) {
160 self.0.only_use_restful_api.store(true, Relaxed);
161 }
162
163 pub fn remote_client(&self) -> &remote::Client {
164 &self.0.client
165 }
166}
167
168impl GhApiClient {
169 fn check_retry_after(&self) -> Result<(), GhApiError> {
170 let mut guard = self.0.retry_after.lock().unwrap();
171
172 if let Some(retry_after) = *guard {
173 if retry_after.elapsed().is_zero() {
174 return Err(GhApiError::RateLimit {
175 retry_after: Some(retry_after - Instant::now()),
176 });
177 } else {
178 *guard = None;
180 }
181 }
182
183 Ok(())
184 }
185
186 fn get_auth_token(&self) -> Option<&str> {
187 if self.0.is_auth_token_valid.load(Relaxed) {
188 self.0.auth_token.as_deref().map(|s| &**s)
189 } else {
190 None
191 }
192 }
193
194 pub fn has_gh_token(&self) -> bool {
195 self.get_auth_token().is_some()
196 }
197
198 async fn do_fetch<T, U, GraphQLFn, RestfulFn, GraphQLFut, RestfulFut>(
199 &self,
200 graphql_func: GraphQLFn,
201 restful_func: RestfulFn,
202 data: &T,
203 ) -> Result<U, GhApiError>
204 where
205 GraphQLFn: Fn(&remote::Client, &T, &str) -> GraphQLFut,
206 RestfulFn: Fn(&remote::Client, &T, Option<&str>) -> RestfulFut,
207 GraphQLFut: Future<Output = Result<U, GhApiError>> + Send + 'static,
208 RestfulFut: Future<Output = Result<U, GhApiError>> + Send + 'static,
209 {
210 self.check_retry_after()?;
211
212 if !self.0.only_use_restful_api.load(Relaxed) {
213 if let Some(auth_token) = self.get_auth_token() {
214 match graphql_func(&self.0.client, data, auth_token).await {
215 Err(GhApiError::Unauthorized) => {
216 self.0.is_auth_token_valid.store(false, Relaxed);
217 }
218 res => return res.map_err(|err| err.context("GraphQL API")),
219 }
220 }
221 }
222
223 restful_func(&self.0.client, data, self.get_auth_token())
224 .await
225 .map_err(|err| err.context("Restful API"))
226 }
227
228 #[instrument(skip(self), ret(level = Level::DEBUG))]
229 pub async fn get_repo_info(&self, repo: &GhRepo) -> Result<Option<RepoInfo>, GhApiError> {
230 match self
231 .do_fetch(
232 repo_info::fetch_repo_info_graphql_api,
233 repo_info::fetch_repo_info_restful_api,
234 repo,
235 )
236 .await
237 {
238 Ok(repo_info) => Ok(repo_info),
239 Err(GhApiError::NotFound) => Ok(None),
240 Err(err) => Err(err),
241 }
242 }
243}
244
245#[derive(Clone, Debug, Eq, PartialEq, Hash)]
246pub struct GhReleaseArtifactUrl(Url);
247
248impl GhApiClient {
249 #[instrument(skip(self), ret(level = Level::DEBUG))]
255 pub async fn has_release_artifact(
256 &self,
257 GhReleaseArtifact {
258 release,
259 artifact_name,
260 }: GhReleaseArtifact,
261 ) -> Result<Option<GhReleaseArtifactUrl>, GhApiError> {
262 let once_cell = self.0.release_artifacts.get(release.clone());
263 let res = once_cell
264 .get_or_try_init(|| {
265 Box::pin(async {
266 match self
267 .do_fetch(
268 release_artifacts::fetch_release_artifacts_graphql_api,
269 release_artifacts::fetch_release_artifacts_restful_api,
270 &release,
271 )
272 .await
273 {
274 Ok(artifacts) => Ok(Some(artifacts)),
275 Err(GhApiError::NotFound) => Ok(None),
276 Err(err) => Err(err),
277 }
278 })
279 })
280 .await;
281
282 match res {
283 Ok(Some(artifacts)) => Ok(artifacts
284 .get_artifact_url(&artifact_name)
285 .map(GhReleaseArtifactUrl)),
286 Ok(None) => Ok(None),
287 Err(GhApiError::RateLimit { retry_after }) => {
288 *self.0.retry_after.lock().unwrap() =
289 Some(Instant::now() + retry_after.unwrap_or(DEFAULT_RETRY_DURATION));
290
291 Err(GhApiError::RateLimit { retry_after })
292 }
293 Err(err) => Err(err),
294 }
295 }
296
297 pub async fn download_artifact(
298 &self,
299 artifact_url: GhReleaseArtifactUrl,
300 ) -> Result<Download<'static>, GhApiError> {
301 self.check_retry_after()?;
302
303 let Some(auth_token) = self.get_auth_token() else {
304 return Err(GhApiError::Unauthorized);
305 };
306
307 let response = self
308 .0
309 .client
310 .get(artifact_url.0)
311 .header("Accept", "application/octet-stream")
312 .bearer_auth(&auth_token)
313 .send(false)
314 .await?;
315
316 match check_http_status_and_header(response) {
317 Err(GhApiError::Unauthorized) => {
318 self.0.is_auth_token_valid.store(false, Relaxed);
319 Err(GhApiError::Unauthorized)
320 }
321 res => res.map(Download::from_response),
322 }
323 }
324}
325
// Unit tests for URL parsing, plus live tests against the GitHub API.
//
// The `rate_limited_test_*` tests perform real network requests and retry on
// `GhApiError::RateLimit`. The token-dependent paths only run when the
// `CI_UNIT_TEST_GITHUB_TOKEN` environment variable is set and non-empty.
#[cfg(test)]
mod test {
    use super::*;
    use compact_str::{CompactString, ToCompactString};
    use std::{env, num::NonZeroU16, time::Duration};
    use tokio::time::sleep;
    use tracing::subscriber::set_global_default;
    use tracing_subscriber::{filter::LevelFilter, fmt::fmt};

    // Wait used by the retry loops below when a rate-limit error carries no
    // `retry_after` duration.
    static DEFAULT_RETRY_AFTER: Duration = Duration::from_secs(1);

    // Fixture: a cargo-binstall release and the artifact names it publishes.
    mod cargo_binstall_v0_20_1 {
        use super::{CompactString, GhRelease, GhRepo};

        pub(super) const RELEASE: GhRelease = GhRelease {
            repo: GhRepo {
                owner: CompactString::const_new("cargo-bins"),
                repo: CompactString::const_new("cargo-binstall"),
            },
            tag: CompactString::const_new("v0.20.1"),
        };

        pub(super) const ARTIFACTS: &[&str] = &[
            "cargo-binstall-aarch64-apple-darwin.full.zip",
            "cargo-binstall-aarch64-apple-darwin.zip",
            "cargo-binstall-aarch64-pc-windows-msvc.full.zip",
            "cargo-binstall-aarch64-pc-windows-msvc.zip",
            "cargo-binstall-aarch64-unknown-linux-gnu.full.tgz",
            "cargo-binstall-aarch64-unknown-linux-gnu.tgz",
            "cargo-binstall-aarch64-unknown-linux-musl.full.tgz",
            "cargo-binstall-aarch64-unknown-linux-musl.tgz",
            "cargo-binstall-armv7-unknown-linux-gnueabihf.full.tgz",
            "cargo-binstall-armv7-unknown-linux-gnueabihf.tgz",
            "cargo-binstall-armv7-unknown-linux-musleabihf.full.tgz",
            "cargo-binstall-armv7-unknown-linux-musleabihf.tgz",
            "cargo-binstall-universal-apple-darwin.full.zip",
            "cargo-binstall-universal-apple-darwin.zip",
            "cargo-binstall-x86_64-apple-darwin.full.zip",
            "cargo-binstall-x86_64-apple-darwin.zip",
            "cargo-binstall-x86_64-pc-windows-msvc.full.zip",
            "cargo-binstall-x86_64-pc-windows-msvc.zip",
            "cargo-binstall-x86_64-unknown-linux-gnu.full.tgz",
            "cargo-binstall-x86_64-unknown-linux-gnu.tgz",
            "cargo-binstall-x86_64-unknown-linux-musl.full.tgz",
            "cargo-binstall-x86_64-unknown-linux-musl.tgz",
        ];
    }

    // Fixture: a release whose tag contains a `/`, which appears
    // percent-encoded (`%2F`) in its download URLs.
    mod cargo_audit_v_0_17_6 {
        use super::*;

        pub(super) const RELEASE: GhRelease = GhRelease {
            repo: GhRepo {
                owner: CompactString::const_new("rustsec"),
                repo: CompactString::const_new("rustsec"),
            },
            tag: CompactString::const_new("cargo-audit/v0.17.6"),
        };

        // Not iterated by any test in this module; kept as reference data,
        // hence the `unused` allowance.
        #[allow(unused)]
        pub(super) const ARTIFACTS: &[&str] = &[
            "cargo-audit-aarch64-unknown-linux-gnu-v0.17.6.tgz",
            "cargo-audit-armv7-unknown-linux-gnueabihf-v0.17.6.tgz",
            "cargo-audit-x86_64-apple-darwin-v0.17.6.tgz",
            "cargo-audit-x86_64-pc-windows-msvc-v0.17.6.zip",
            "cargo-audit-x86_64-unknown-linux-gnu-v0.17.6.tgz",
            "cargo-audit-x86_64-unknown-linux-musl-v0.17.6.tgz",
        ];

        // The `%2F` in the tag segment must be decoded back to `/`.
        #[test]
        fn extract_with_escaped_characters() {
            let release_artifact = try_extract_artifact_from_str(
"https://github.com/rustsec/rustsec/releases/download/cargo-audit%2Fv0.17.6/cargo-audit-aarch64-unknown-linux-gnu-v0.17.6.tgz"
            ).unwrap();

            assert_eq!(
                release_artifact,
                GhReleaseArtifact {
                    release: RELEASE,
                    artifact_name: CompactString::from(
                        "cargo-audit-aarch64-unknown-linux-gnu-v0.17.6.tgz",
                    )
                }
            );
        }
    }

    // Round trip: URL -> `GhRepo` -> `repo_url()` must give back the same URL.
    #[test]
    fn gh_repo_extract_from_and_to_url() {
        [
            "https://github.com/cargo-bins/cargo-binstall",
            "https://github.com/rustsec/rustsec",
        ]
        .into_iter()
        .for_each(|url| {
            let url = Url::parse(url).unwrap();
            assert_eq!(
                GhRepo::try_extract_from_url(&url)
                    .unwrap()
                    .repo_url()
                    .unwrap(),
                url
            );
        })
    }

    // Parses `s` (panicking on an invalid URL) and extracts a release artifact from it.
    fn try_extract_artifact_from_str(s: &str) -> Option<GhReleaseArtifact> {
        GhReleaseArtifact::try_extract_from_url(&url::Url::parse(s).unwrap())
    }

    // Asserts that none of `urls` is recognised as a release-artifact URL.
    fn assert_extract_gh_release_artifacts_failures(urls: &[&str]) {
        for url in urls {
            assert_eq!(try_extract_artifact_from_str(url), None);
        }
    }

    // Wrong host, truncated paths, wrong path shape, extra segments, and
    // URLs with a fragment or query must all be rejected.
    #[test]
    fn extract_gh_release_artifacts_failure() {
        use cargo_binstall_v0_20_1::*;

        let GhRelease {
            repo: GhRepo { owner, repo },
            tag,
        } = RELEASE;

        assert_extract_gh_release_artifacts_failures(&[
            "https://examle.com",
            "https://github.com",
            &format!("https://github.com/{owner}"),
            &format!("https://github.com/{owner}/{repo}"),
            &format!("https://github.com/{owner}/{repo}/123e"),
            &format!("https://github.com/{owner}/{repo}/releases/21343"),
            &format!("https://github.com/{owner}/{repo}/releases/download"),
            &format!("https://github.com/{owner}/{repo}/releases/download/{tag}"),
            &format!("https://github.com/{owner}/{repo}/releases/download/{tag}/a/23"),
            &format!("https://github.com/{owner}/{repo}/releases/download/{tag}/a#a=12"),
            &format!("https://github.com/{owner}/{repo}/releases/download/{tag}/a?page=3"),
        ]);
    }

    // Every well-formed download URL of the fixture release parses back into
    // the same release and artifact name.
    #[test]
    fn extract_gh_release_artifacts_success() {
        use cargo_binstall_v0_20_1::*;

        let GhRelease {
            repo: GhRepo { owner, repo },
            tag,
        } = RELEASE;

        for artifact in ARTIFACTS {
            let GhReleaseArtifact {
                release,
                artifact_name,
            } = try_extract_artifact_from_str(&format!(
                "https://github.com/{owner}/{repo}/releases/download/{tag}/{artifact}"
            ))
            .unwrap();

            assert_eq!(release, RELEASE);
            assert_eq!(artifact_name, artifact);
        }
    }

    // Installs a DEBUG-level tracing subscriber that writes to the test output.
    // The result is ignored because several tests call this in one process and
    // only the first `set_global_default` can succeed.
    fn init_logger() {
        let subscriber = fmt()
            .without_time()
            .with_target(false)
            .with_file(false)
            .with_line_number(false)
            .with_thread_names(false)
            .with_thread_ids(false)
            .with_test_writer()
            .with_max_level(LevelFilter::DEBUG)
            .finish();

        let _ = set_global_default(subscriber);
    }

    // Builds the HTTP client used by the live tests. NOTE(review): the
    // meaning of the `300` / `1` arguments is defined by
    // `remote::Client::new` (presumably rate-limiting parameters) — confirm there.
    fn create_remote_client() -> remote::Client {
        remote::Client::new(
            concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")),
            None,
            NonZeroU16::new(300).unwrap(),
            1.try_into().unwrap(),
            [],
        )
        .unwrap()
    }

    // Returns a RESTful-only client. When `CI_UNIT_TEST_GITHUB_TOKEN` is set
    // and non-empty, it also returns a second client, with GraphQL enabled,
    // that shares the same token.
    fn create_client() -> Vec<GhApiClient> {
        let client = create_remote_client();

        let auth_token = match env::var("CI_UNIT_TEST_GITHUB_TOKEN") {
            Ok(auth_token) if !auth_token.is_empty() => {
                Some(zeroize::Zeroizing::new(auth_token.into_boxed_str()))
            }
            _ => None,
        };

        let gh_client = GhApiClient::new(client.clone(), auth_token.clone());
        gh_client.set_only_use_restful_api();

        let mut gh_clients = vec![gh_client];

        if auth_token.is_some() {
            gh_clients.push(GhApiClient::new(client, auth_token));
        }

        gh_clients
    }

    // Live test: public repos resolve, unknown repos give `None`, and private
    // repos resolve only when a token is available.
    #[tokio::test]
    async fn rate_limited_test_get_repo_info() {
        const PUBLIC_REPOS: [GhRepo; 1] = [GhRepo {
            owner: CompactString::const_new("cargo-bins"),
            repo: CompactString::const_new("cargo-binstall"),
        }];
        const PRIVATE_REPOS: [GhRepo; 1] = [GhRepo {
            owner: CompactString::const_new("cargo-bins"),
            repo: CompactString::const_new("private-repo-for-testing"),
        }];
        const NON_EXISTENT_REPOS: [GhRepo; 1] = [GhRepo {
            owner: CompactString::const_new("cargo-bins"),
            repo: CompactString::const_new("ttt"),
        }];

        init_logger();

        let mut tests: Vec<(_, _)> = Vec::new();

        for client in create_client() {
            // Spawns a task that keeps retrying `get_repo_info` while rate-limited.
            let spawn_get_repo_info_task = |repo| {
                let client = client.clone();
                tokio::spawn(async move {
                    loop {
                        match client.get_repo_info(&repo).await {
                            Err(GhApiError::RateLimit { retry_after }) => {
                                sleep(retry_after.unwrap_or(DEFAULT_RETRY_AFTER)).await
                            }
                            res => break res,
                        }
                    }
                })
            };

            for repo in PUBLIC_REPOS {
                tests.push((
                    Some(RepoInfo::new(repo.clone(), false)),
                    spawn_get_repo_info_task(repo),
                ));
            }

            for repo in NON_EXISTENT_REPOS {
                tests.push((None, spawn_get_repo_info_task(repo)));
            }

            if client.has_gh_token() {
                for repo in PRIVATE_REPOS {
                    tests.push((
                        Some(RepoInfo::new(repo.clone(), true)),
                        spawn_get_repo_info_task(repo),
                    ));
                }
            }
        }

        for (expected, task) in tests {
            assert_eq!(task.await.unwrap().unwrap(), expected);
        }
    }

    // Live test: every fixture artifact is found. With a token, the API
    // download is compared byte-for-byte with the public browser download
    // URL. Unknown artifacts and unknown releases give `None`.
    #[tokio::test]
    async fn rate_limited_test_has_release_artifact_and_download_artifacts() {
        const RELEASES: [(GhRelease, &[&str]); 1] = [(
            cargo_binstall_v0_20_1::RELEASE,
            cargo_binstall_v0_20_1::ARTIFACTS,
        )];
        const NON_EXISTENT_RELEASES: [GhRelease; 1] = [GhRelease {
            repo: GhRepo {
                owner: CompactString::const_new("cargo-bins"),
                repo: CompactString::const_new("cargo-binstall"),
            },
            tag: CompactString::const_new("v0.18.2"),
        }];

        init_logger();

        let mut tasks = Vec::new();

        for client in create_client() {
            // `has_release_artifact`, retried for as long as it is rate-limited.
            async fn has_release_artifact(
                client: &GhApiClient,
                artifact: &GhReleaseArtifact,
            ) -> Result<Option<GhReleaseArtifactUrl>, GhApiError> {
                loop {
                    match client.has_release_artifact(artifact.clone()).await {
                        Err(GhApiError::RateLimit { retry_after }) => {
                            sleep(retry_after.unwrap_or(DEFAULT_RETRY_AFTER)).await
                        }
                        res => break res,
                    }
                }
            }

            for (release, artifacts) in RELEASES {
                for artifact_name in artifacts {
                    let client = client.clone();
                    let release = release.clone();
                    tasks.push(tokio::spawn(async move {
                        let artifact = GhReleaseArtifact {
                            release,
                            artifact_name: artifact_name.to_compact_string(),
                        };

                        // Only when a token is available: fetch the same asset
                        // via the public download URL, for comparison below.
                        let browser_download_task = client.get_auth_token().map(|_| {
                            tokio::spawn(
                                Download::new(
                                    client.remote_client().clone(),
                                    Url::parse(&format!(
                                        "https://github.com/{}/{}/releases/download/{}/{}",
                                        artifact.release.repo.owner,
                                        artifact.release.repo.repo,
                                        artifact.release.tag,
                                        artifact.artifact_name,
                                    ))
                                    .unwrap(),
                                )
                                .into_bytes(),
                            )
                        });
                        let artifact_url = has_release_artifact(&client, &artifact)
                            .await
                            .unwrap()
                            .unwrap();

                        if let Some(browser_download_task) = browser_download_task {
                            let artifact_download_data = loop {
                                match client.download_artifact(artifact_url.clone()).await {
                                    Err(GhApiError::RateLimit { retry_after }) => {
                                        sleep(retry_after.unwrap_or(DEFAULT_RETRY_AFTER)).await
                                    }
                                    res => break res.unwrap(),
                                }
                            }
                            .into_bytes()
                            .await
                            .unwrap();

                            let browser_download_data =
                                browser_download_task.await.unwrap().unwrap();

                            assert_eq!(artifact_download_data, browser_download_data);
                        }
                    }));
                }

                // An artifact name that the existing release does not contain.
                let client = client.clone();
                tasks.push(tokio::spawn(async move {
                    assert_eq!(
                        has_release_artifact(
                            &client,
                            &GhReleaseArtifact {
                                release,
                                artifact_name: "123z".to_compact_string(),
                            }
                        )
                        .await
                        .unwrap(),
                        None
                    );
                }));
            }

            // A release that does not exist at all.
            for release in NON_EXISTENT_RELEASES {
                let client = client.clone();

                tasks.push(tokio::spawn(async move {
                    assert_eq!(
                        has_release_artifact(
                            &client,
                            &GhReleaseArtifact {
                                release,
                                artifact_name: "1234".to_compact_string(),
                            }
                        )
                        .await
                        .unwrap(),
                        None
                    );
                }));
            }
        }

        for task in tasks {
            task.await.unwrap();
        }
    }
}
730}