use crate::utils::ProgressReader;
use anyhow::{Context, Result};
use bon::Builder;
use clap::Args;
use flate2::bufread::GzDecoder;
use kdam::{Column, RichProgress, Spinner, term, tqdm};
use reqwest::blocking::{Client, RequestBuilder};
use serde::{Deserialize, Serialize};
use std::{
fs::{self},
io::{BufReader, IsTerminal},
path::PathBuf,
};
use tar::Archive;
// Command-line arguments for pulling a scarfbench release and extracting it locally.
// NOTE: plain `//` comments are used on purpose here — clap derives help/about text from
// `///` doc comments, and these comments must not change the generated CLI help.
#[derive(Args, Debug)]
pub struct BenchPullArgs {
    // Destination directory; `run` forwards it as `PullScarfBench::dest_dir`, which is
    // created if missing and receives the extracted archive.
    #[arg(
        short,
        long,
        help = "Path to where the benchmark is to be saved.",
        value_name = "DIR"
    )]
    pub dest: PathBuf,
    // Release tag to fetch (used verbatim in the GitHub `releases/tags/{tag}` URL); when
    // omitted, the latest release is downloaded.
    #[arg(long, help = "Version of scarfbench to pull.", value_name = "VERSION")]
    pub version: Option<String>,
}
/// Subset of a GitHub "release" JSON object needed by this command.
///
/// Only the `assets` array is deserialized; serde ignores the remaining fields of the
/// API response because `deny_unknown_fields` is not set.
#[derive(Debug, Serialize, Deserialize)]
struct Release {
    /// Files attached to the release, searched for the benchmark archive.
    assets: Vec<Asset>,
}
/// A single file attached to a GitHub release (only the fields this command reads).
#[derive(Debug, Serialize, Deserialize)]
struct Asset {
    /// File name of the asset; compared against `PullScarfBench::asset_name_prefix`.
    name: String,
    /// URL from which the asset's contents are downloaded in `PullScarfBench::exec`.
    browser_download_url: String,
}
/// Downloads a scarfbench release archive from GitHub and extracts it into a directory.
///
/// Constructed through the `bon`-generated `PullScarfBench::builder()`; run it with `exec`.
#[derive(Debug, Builder)]
pub struct PullScarfBench {
    /// Name fragment used to pick the benchmark archive out of the release's assets
    /// (defaults to `"benchmark-v"`).
    #[builder(default= "benchmark-v".to_string())]
    pub asset_name_prefix: String,
    /// Release tag to download; `None` selects the repository's latest release.
    pub version: Option<String>,
    /// Directory the archive is extracted into; created if it does not exist yet.
    pub dest_dir: PathBuf,
}
impl PullScarfBench {
fn github_token() -> Option<String> {
std::env::var("SCARF_BENCH_GITHUB_TOKEN")
.ok()
.filter(|v| !v.trim().is_empty())
.or_else(|| {
std::env::var("GITHUB_TOKEN")
.ok()
.filter(|v| !v.trim().is_empty())
})
}
fn maybe_auth(request: RequestBuilder, token: Option<&str>) -> RequestBuilder {
match token {
Some(token) => request.bearer_auth(token),
None => request,
}
}
pub fn exec(&self) -> anyhow::Result<i32> {
let token = Self::github_token();
let client = Client::builder().user_agent("scarf-cli").build()?;
let api_url = match self.version.as_deref() {
Some(v) => format!(
"https://api.github.com/repos/scarfbench/benchmark/releases/tags/{}",
v
),
None => format!("https://api.github.com/repos/scarfbench/benchmark/releases/latest"),
};
log::info!("Downloading from {api_url}");
let release_response = Self::maybe_auth(client.get(&api_url), token.as_deref())
.header("User-Agent", "scarf")
.send()
.with_context(|| format!("Unable to fetch the release metadata from {api_url}"))?
.error_for_status()
.with_context(|| {
format!(
"GitHub API returned an error status while reading {api_url}. If this repo is private, set SCARF_BENCH_GITHUB_TOKEN (or GITHUB_TOKEN) with contents:read access to scarfbench/benchmark."
)
})?;
let releases: Release = release_response
.json()
.context("Failed to parse release JSON")?;
let asset = releases
.assets
.into_iter()
.find(|predicate| predicate.name.contains(&self.asset_name_prefix))
.with_context(|| {
return format!(
"There are no release assets that start with {} available at {}",
&self.asset_name_prefix, api_url
);
})?;
fs::create_dir_all(&self.dest_dir).with_context(|| {
return format!(
"Failed to create a directory at {}",
&self.dest_dir.to_string_lossy()
);
})?;
let response = Self::maybe_auth(client.get(&asset.browser_download_url), token.as_deref())
.send()
.with_context(|| {
return format!(
"Failed to download the asset from {}",
asset.browser_download_url
);
})?
.error_for_status()
.with_context(|| {
format!(
"Asset download returned an error status from {}. If this release is private, set SCARF_BENCH_GITHUB_TOKEN (or GITHUB_TOKEN) with access to scarfbench/benchmark.",
asset.browser_download_url
)
})?;
let total_size = response.content_length().map(|s| s as usize).unwrap_or(0);
term::init(std::io::stderr().is_terminal());
term::hide_cursor()?;
let pb = RichProgress::new(
tqdm!(
total = total_size as usize,
unit_scale = true,
unit_divisor = 1024,
unit = "B"
),
vec![
Column::Spinner(Spinner::new(
&["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"],
80.0,
1.0,
)),
Column::Text("[bold blue]Downloading scarfbench...".to_owned()),
Column::Animation,
Column::Percentage(1),
Column::Text("•".to_owned()),
Column::CountTotal,
Column::Text("•".to_owned()),
Column::Rate,
Column::Text("•".to_owned()),
Column::RemainingTime,
],
);
let pr = ProgressReader::new(BufReader::new(response), pb, Some(total_size));
let tar = GzDecoder::new(pr);
let mut archive = Archive::new(tar);
archive
.unpack(&self.dest_dir)
.with_context(|| format!("Failed to extract into {}", self.dest_dir.display()))?;
term::show_cursor()?;
Ok(0)
}
}
pub fn run(bench_pull_args: BenchPullArgs) -> Result<i32> {
PullScarfBench::builder()
.maybe_version(bench_pull_args.version)
.dest_dir(bench_pull_args.dest)
.build()
.exec()
}