#![warn(rust_2018_idioms)]
use std::borrow::Cow;
use std::env::set_current_dir;
use std::fs::{create_dir_all, rename};
use std::io::{Write, stderr, stdout};
use std::iter::once;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::process::exit;
use std::thread::sleep;
use std::time::Duration;
use anyhow::{Context, Error, anyhow, bail, ensure};
use clap::{Parser, crate_version};
use colored::Colorize;
use pbr::{ProgressBar, Units};
use remove_dir_all::remove_dir_all;
use reqwest::blocking::{Client, ClientBuilder};
use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_LENGTH, HeaderMap, HeaderValue, USER_AGENT};
use reqwest::{Proxy, StatusCode};
use tar::Archive;
use tee::TeeReader;
use tempfile::{tempdir, tempdir_in};
use xz2::read::XzDecoder;
/// Channels probed, in this order, when a commit's channel is not given on the command line
/// and its `package-version` file cannot be fetched.
static SUPPORTED_CHANNELS: &[&str] = &["nightly", "beta", "stable"];
/// Delay before the first retry of a failed download; it doubles after every retry.
const MIN_RETRY_SLEEP: Duration = Duration::from_millis(250);
/// Upper bound for the (doubling) delay between download retries.
const MAX_RETRY_SLEEP: Duration = Duration::from_secs(60);
// Command-line arguments.
//
// NOTE: plain `//` comments are used on purpose here: clap's derive turns `///` doc comments
// into about/help text, which would change the generated `--help` output.
//
// Several independent on/off CLI flags are naturally `bool`s, hence the clippy allowance.
#[allow(clippy::struct_excessive_bools)]
#[derive(Parser, Debug)]
#[command(term_width(0), version(crate_version!()))]
struct Args {
// Commits to install; when empty, `run` fetches the current HEAD commit and installs that.
#[arg(
help = "full commit hashes of the rustc builds, all 40 digits are needed; \
if omitted, the latest HEAD commit will be installed"
)]
commits: Vec<String>,
// Toolchain directory name; `run` rejects it when more than one commit is given.
#[arg(short = 'n', long = "name", help = "the name to call the toolchain")]
name: Option<String>,
// Selects the `rustc-builds-alt` prefix and appends `-alt` to the default toolchain name.
#[arg(
short = 'a',
long = "alt",
help = "download the alt build instead of normal build"
)]
alt: bool,
// Base URL; artifacts are fetched from `{server}/rustc-builds[-alt]/{commit}/...`.
#[arg(
short = 's',
long = "server",
help = "the server path which stores the compilers",
default_value = "https://ci-artifacts.rust-lang.org"
)]
server: String,
// Defaults to the compile-time `HOST` env var (presumably injected by a build script — confirm).
#[arg(
short = 'i',
long = "host",
help = "the triple of the host platform",
default_value = env!("HOST")
)]
host: String,
// rust-std is always downloaded for `host` as well; these are extra targets.
#[arg(
short = 't',
long = "targets",
help = "additional target platforms to install rust-std for, besides the host platform",
num_args = 1..,
)]
targets: Vec<String>,
// Extra components; `rust-src` is fetched without a target suffix.
#[arg(
short = 'c',
long = "component",
help = "additional components to install, besides rustc and rust-std",
num_args = 1..,
)]
components: Vec<String>,
// When set, skips the per-commit channel detection in `get_channel`.
#[arg(
long = "channel",
help = "specify the channel of the commits instead of detecting it automatically"
)]
channel: Option<String>,
// Passed to `reqwest::Proxy::all`, so it applies to every request of the client.
#[arg(
short = 'p',
long = "proxy",
help = "the HTTP proxy for all download requests"
)]
proxy: Option<String>,
// Only used by the GitHub HTTP fallback when looking up the HEAD commit.
#[arg(
long = "github-token",
help = "An authorization token to access GitHub APIs"
)]
github_token: Option<String>,
// Still performs channel detection / HEAD lookup; only the archive downloads are skipped.
#[arg(
long = "dry-run",
help = "Only log the URLs, without downloading the artifacts"
)]
dry_run: bool,
#[arg(
long = "force",
short = 'f',
help = "Replace an existing toolchain of the same name"
)]
force: bool,
// Failures are reported as warnings and a summary error is returned at the end.
#[arg(
long = "keep-going",
short = 'k',
help = "Continue downloading toolchains even if some of them failed"
)]
keep_going: bool,
// Only errors flagged as retryable (transport errors, 5xx, broken archive stream) are retried.
#[arg(
long = "retry",
short = 'r',
help = "Maximum number of retries for xz downloads",
default_value = "0"
)]
retry: u32,
}
/// A download failure paired with a flag telling whether trying again might help.
///
/// Conversions through `From` (and therefore `?`) produce a permanent, non-retryable failure;
/// [`RetryableError::retryable`] marks failures that look transient.
struct RetryableError {
    /// The underlying error that is eventually reported to the user.
    error: Error,
    /// Whether `TarXzDownloader::download` is allowed to retry after this failure.
    should_retry: bool,
}

impl RetryableError {
    /// Wraps `err` and marks it as worth retrying.
    fn retryable<E: Into<Error>>(err: E) -> Self {
        Self {
            should_retry: true,
            ..Self::from(err)
        }
    }
}

impl<E: Into<Error>> From<E> for RetryableError {
    /// Wraps `err` as a permanent (non-retryable) failure.
    fn from(err: E) -> Self {
        let error = err.into();
        Self {
            should_retry: false,
            error,
        }
    }
}
/// Downloads `.tar.xz` component archives and extracts them into a toolchain directory.
#[derive(Debug)]
struct TarXzDownloader<'a> {
    /// HTTP client; `None` means dry run, where URLs are only logged.
    client: Option<&'a Client>,
    /// Toolchain whose `dest` directory receives the extracted files.
    toolchain: &'a Toolchain<'a>,
    /// Channel name, used only in the "missing component" error message.
    channel: &'a str,
    /// Maximum number of retries after a retryable failure.
    retry: u32,
}

impl<'a> TarXzDownloader<'a> {
    /// Downloads and extracts `url`, retrying retryable failures up to `self.retry` times.
    ///
    /// The delay between attempts starts at `MIN_RETRY_SLEEP` and doubles after every retry,
    /// capped at `MAX_RETRY_SLEEP`.
    ///
    /// # Errors
    /// Returns the last failure once it is not retryable or the retry budget is exhausted.
    fn download(&self, url: &str, component: &str, target: &str) -> Result<(), Error> {
        let mut attempt = 0;
        let mut sleep_duration = MIN_RETRY_SLEEP;
        loop {
            match self.download_once(url, component, target) {
                Ok(()) => return Ok(()),
                Err(err) if err.should_retry && attempt < self.retry => {
                    // Count the retry that is about to happen so the message reads 1/N..N/N
                    // (it previously started at 0/N). The total number of retries is unchanged.
                    attempt += 1;
                    report_warn(&err.error);
                    eprintln!(
                        "download failed, going to retry after {sleep_duration:?} ({attempt}/{})",
                        self.retry
                    );
                    sleep(sleep_duration);
                    sleep_duration = MAX_RETRY_SLEEP.min(sleep_duration * 2);
                }
                Err(err) => return Err(err.error),
            }
        }
    }

    /// Performs a single download-and-extract attempt of `url`.
    ///
    /// Each archive entry must consist only of normal path components. Its first two
    /// components (the archive's top-level directory and the component directory) are
    /// stripped, and the rest is placed under `self.toolchain.dest`. Only directories and
    /// regular files are supported. In dry-run mode (`client` is `None`) only the URL is logged.
    ///
    /// # Errors
    /// Transport errors, 5xx responses and read/unpack failures are flagged retryable.
    /// 404 (missing component), other non-OK statuses, unsafe paths and unsupported entry
    /// kinds are not.
    fn download_once(
        &self,
        url: &str,
        component: &str,
        target: &str,
    ) -> Result<(), RetryableError> {
        eprintln!("downloading <{url}>...");
        if let Some(client) = self.client {
            let response = client.get(url).send().map_err(RetryableError::retryable)?;
            let status = response.status();
            if status != StatusCode::OK {
                return Err(RetryableError {
                    error: if status == StatusCode::NOT_FOUND {
                        anyhow!(
                            "missing component `{component}` on toolchain `{}` on channel `{}` for target `{target}`",
                            self.toolchain.commit,
                            self.channel,
                        )
                    } else {
                        anyhow!("received status {status} for GET {url}")
                    },
                    should_retry: status.is_server_error(),
                });
            }
            // A missing or unparsable Content-Length just gives an unbounded progress bar.
            let length = response
                .headers()
                .get(CONTENT_LENGTH)
                .and_then(|h| h.to_str().ok())
                .and_then(|h| h.parse().ok())
                .unwrap_or(0);
            let err = stderr();
            let lock = err.lock();
            let mut progress_bar = ProgressBar::on(lock, length);
            progress_bar.set_units(Units::Bytes);
            progress_bar.set_max_refresh_rate(Some(Duration::from_secs(1)));
            // Tee the compressed bytes into the progress bar while decompressing and untarring.
            let response = TeeReader::new(response, &mut progress_bar);
            let response = XzDecoder::new(response);
            for entry in Archive::new(response)
                .entries()
                .map_err(RetryableError::retryable)?
            {
                let mut entry = entry.map_err(RetryableError::retryable)?;
                let relpath = entry.path()?;
                let mut components = relpath.components();
                // Reject absolute paths, `..`, `.` and prefixes so nothing escapes `dest`.
                for part in components.clone() {
                    match part {
                        std::path::Component::Normal(_) => {}
                        _ => return Err(anyhow!("bad path in tar: {}", relpath.display()).into()),
                    }
                }
                // Strip the archive's top-level directory and the component directory.
                components.next();
                components.next();
                let full_path = self.toolchain.dest.join(components.as_path());
                // Entries at or above the component directory map to `dest` itself; skip them.
                if full_path == self.toolchain.dest {
                    continue;
                }
                let kind = entry.header().entry_type();
                match kind {
                    tar::EntryType::Directory => {
                        create_dir_all(full_path)?;
                    }
                    tar::EntryType::Regular => {
                        entry.unpack(full_path).map_err(RetryableError::retryable)?;
                    }
                    _ => return Err(anyhow!("unsupported tar entry: {kind:?}").into()),
                }
            }
            progress_bar.finish();
            eprintln!();
        }
        Ok(())
    }
}
/// A single toolchain (one rustc commit) to download and install.
#[derive(Debug)]
struct Toolchain<'a> {
/// Commit hash of the rustc build; used in artifact URLs and in error messages.
commit: &'a str,
/// Target triple used for host-specific components (`rustc` and extra components).
host_target: &'a str,
/// Targets for which a `rust-std` archive is downloaded.
rust_std_targets: &'a [&'a str],
/// Extra components to download besides `rustc` (and the `rust-std` archives).
components: &'a [&'a str],
/// Relative directory that archives are extracted into, and the toolchain's name under
/// rustup's `toolchains` directory. Resolved against the current directory, which `run` sets
/// to a temporary directory.
dest: PathBuf,
}
struct Installer<'a> {
client: &'a Client,
actually_install: bool,
override_channel: Option<&'a str>,
prefix: &'a str,
toolchains_path: &'a Path,
force: bool,
retry: u32,
}
impl<'a> Installer<'a> {
fn install_single_toolchain(&self, toolchain: &Toolchain<'_>) -> Result<(), Error> {
let toolchain_path = self.toolchains_path.join(&toolchain.dest);
if toolchain_path.is_dir() {
if self.force {
if self.actually_install {
remove_dir_all(&toolchain_path)?;
}
} else {
eprintln!(
"toolchain `{}` is already installed",
toolchain.dest.display()
);
return Ok(());
}
}
let channel = if let Some(channel) = self.override_channel {
Cow::Borrowed(channel)
} else {
Cow::Owned(get_channel(self.client, self.prefix, toolchain.commit)?)
};
let downloader = TarXzDownloader {
client: self.actually_install.then_some(self.client),
toolchain,
channel: &channel,
retry: self.retry,
};
for component in once(&"rustc").chain(toolchain.components) {
let component_filename = if *component == "rust-src" {
format_args!("{component}-{channel}")
} else {
format_args!("{component}-{channel}-{}", toolchain.host_target)
};
downloader.download(
&format!(
"{}/{}/{component_filename}.tar.xz",
self.prefix, toolchain.commit
),
component,
toolchain.host_target,
)?;
}
for target in toolchain.rust_std_targets {
downloader.download(
&format!(
"{}/{}/rust-std-{channel}-{target}.tar.xz",
self.prefix, toolchain.commit,
),
"rust-std",
target,
)?;
}
if self.actually_install {
rename(&toolchain.dest, toolchain_path)?;
eprintln!(
"toolchain `{}` is successfully installed!",
toolchain.dest.display()
);
} else {
eprintln!(
"toolchain `{}` will be installed to `{}` on real run",
toolchain.dest.display(),
toolchain_path.display()
);
}
Ok(())
}
}
/// Resolves the current HEAD commit of rust-lang/rust.
///
/// Tries `git ls-remote` first. If that fails, the failure is reported as a warning and the
/// GitHub REST API is queried instead (optionally authenticated with `github_token`).
fn fetch_master_commit(client: &Client, github_token: Option<&str>) -> Result<String, Error> {
    eprintln!("fetching HEAD commit hash... ");
    let via_git =
        fetch_master_commit_via_git().context("unable to fetch HEAD commit via git, falling back to HTTP");
    match via_git {
        Ok(hash) => Ok(hash),
        Err(git_failure) => {
            report_warn(&git_failure);
            fetch_master_commit_via_http(client, github_token)
        }
    }
}
fn fetch_master_commit_via_git() -> Result<String, Error> {
let mut output = Command::new("git")
.args(["ls-remote", "https://github.com/rust-lang/rust.git", "HEAD"])
.output()?;
ensure!(output.status.success(), "git ls-remote exited with error");
ensure!(
output
.stdout
.get(..40)
.is_some_and(|h| h.iter().all(u8::is_ascii_hexdigit)),
"git ls-remote does not return a commit"
);
output.stdout.truncate(40);
Ok(unsafe { String::from_utf8_unchecked(output.stdout) })
}
/// Queries the GitHub REST API for the HEAD commit hash of rust-lang/rust.
///
/// On success the hash is also written to stdout, followed by a newline on stderr.
///
/// # Errors
/// Fails on transport errors, on any non-200 status (with a dedicated message when the API
/// rate limit is exhausted), or when the body is not a 40-character lowercase hex hash.
fn fetch_master_commit_via_http(
    client: &Client,
    github_token: Option<&str>,
) -> Result<String, Error> {
    static URL: &str = "https://api.github.com/repos/rust-lang/rust/commits/HEAD";
    static MEDIA_TYPE: &str = "application/vnd.github.VERSION.sha";

    let request = client.get(URL).header(ACCEPT, MEDIA_TYPE);
    let request = match github_token {
        Some(token) => request.header(AUTHORIZATION, format!("token {token}")),
        None => request,
    };
    let response = request.send()?;

    let status = response.status();
    if status == StatusCode::FORBIDDEN {
        // A missing or unparsable header is treated as "no requests left".
        let remaining = response
            .headers()
            .get("X-RateLimit-Remaining")
            .and_then(|value| value.to_str().ok())
            .and_then(|value| value.parse::<u32>().ok())
            .unwrap_or(0);
        if remaining == 0 {
            bail!("GitHub API rate limit exceeded");
        }
        bail!("status: {} with rate limit: {}", status, remaining);
    }
    if status != StatusCode::OK {
        bail!("received status {} for URL {}", status, URL);
    }

    let master_commit = response.text()?;
    let looks_like_commit = master_commit.len() == 40
        && master_commit
            .bytes()
            .all(|b| matches!(b, b'0'..=b'9' | b'a'..=b'f'));
    if !looks_like_commit {
        bail!("unable to parse `{}` as a commit", master_commit)
    }

    let mut out = stdout().lock();
    out.write_all(master_commit.as_bytes())?;
    out.flush()?;
    eprintln!();
    Ok(master_commit)
}
/// Determines which release channel the build of `commit` belongs to.
///
/// Reads `{prefix}/{commit}/package-version` when available. On 404/403 it falls back to
/// probing, in `SUPPORTED_CHANNELS` order, for a `rust-src-{channel}.tar.xz` archive via HEAD.
///
/// # Errors
/// Fails on transport errors, on any unexpected status code, or when no channel matches.
fn get_channel(client: &Client, prefix: &str, commit: &str) -> Result<String, Error> {
    eprintln!("detecting the channel of the `{commit}` toolchain...");

    let version_url = format!("{prefix}/{commit}/package-version");
    let version_resp = client.get(&version_url).send()?;
    let version_status = version_resp.status();
    if version_status == StatusCode::OK {
        return Ok(version_resp.text()?.trim().to_owned());
    }
    if !matches!(version_status, StatusCode::NOT_FOUND | StatusCode::FORBIDDEN) {
        bail!("unexpected status code {} for GET {}", version_status, version_url);
    }

    for candidate in SUPPORTED_CHANNELS.iter().copied() {
        let probe_url = format!("{prefix}/{commit}/rust-src-{candidate}.tar.xz");
        match client.head(&probe_url).send()?.status() {
            StatusCode::OK => return Ok(candidate.to_owned()),
            StatusCode::NOT_FOUND | StatusCode::FORBIDDEN => continue,
            other => bail!("unexpected status code {} for HEAD {}", other, probe_url),
        }
    }

    bail!("toolchain `{}` doesn't exist in any channel", commit)
}
fn run() -> Result<(), Error> {
let mut args = Args::parse();
let mut headers = HeaderMap::new();
headers.insert(
USER_AGENT,
HeaderValue::from_static("rustup-toolchain-install-master"),
);
let mut client_builder = ClientBuilder::new().default_headers(headers);
if let Some(proxy) = args.proxy {
client_builder = client_builder.proxy(Proxy::all(&proxy)?);
}
let client = client_builder.build()?;
let rustup_home = home::rustup_home().expect("$RUSTUP_HOME is undefined?");
let toolchains_path = rustup_home.join("toolchains");
if !toolchains_path.is_dir() {
bail!(
"`{}` is not a directory. please reinstall rustup.",
toolchains_path.display()
);
}
if args.commits.len() > 1 && args.name.is_some() {
return Err(Error::msg(
"name argument can only be provided with a single commit",
));
}
let components = args.components.iter().map(Deref::deref).collect::<Vec<_>>();
let rust_std_targets = args
.targets
.iter()
.map(Deref::deref)
.chain(once(&*args.host))
.collect::<Vec<_>>();
let toolchains_dir = {
let path = rustup_home.join("tmp");
if !path.exists() {
create_dir_all(&path)?;
}
if path.is_dir() {
tempdir_in(&path)
} else {
tempdir()
}
}?;
set_current_dir(toolchains_dir.path())?;
let prefix = format!(
"{}/rustc-builds{}",
args.server,
if args.alt { "-alt" } else { "" }
);
if args.commits.is_empty() {
args.commits
.push(fetch_master_commit(&client, args.github_token.as_deref())?);
}
let mut failed = false;
let installer = Installer {
client: &client,
actually_install: !args.dry_run,
override_channel: args.channel.as_deref(),
prefix: &prefix,
toolchains_path: &toolchains_path,
force: args.force,
retry: args.retry,
};
for commit in args.commits {
let dest = if let Some(name) = args.name.as_deref() {
PathBuf::from(name)
} else if args.alt {
PathBuf::from(format!("{commit}-alt"))
} else {
PathBuf::from(&commit)
};
let result = installer.install_single_toolchain(&Toolchain {
commit: &commit,
host_target: &args.host,
rust_std_targets: &rust_std_targets,
components: &components,
dest,
});
if args.keep_going {
if let Err(err) = result {
report_warn(
&err.context(format!("skipping toolchain `{commit}` due to a failure")),
);
failed = true;
}
} else {
result?;
}
}
if failed {
Err(Error::msg("failed to download some toolchains"))
} else {
Ok(())
}
}
fn report_error(err: &Error) {
eprintln!("{} {}", "error:".red().bold(), err);
for cause in err.chain().skip(1) {
eprintln!("{} {}", "caused by:".red().bold(), cause);
}
exit(1);
}
/// Prints `warn` and each of its underlying causes to stderr in yellow, followed by a blank
/// line. Unlike `report_error`, this does not terminate the process.
fn report_warn(warn: &Error) {
    let heading = "warn:".yellow().bold();
    eprintln!("{heading} {warn}");
    warn.chain()
        .skip(1)
        .for_each(|cause| eprintln!("{} {cause}", "caused by:".yellow().bold()));
    eprintln!();
}
/// Runs the installer; any error is printed with its causes and the process exits with status 1.
fn main() {
    run().unwrap_or_else(|err| report_error(&err));
}