use std::path::Path;
use xshell::Shell;
use crate::environment::{get_workspace_root, WorkspaceManifest};
use crate::quiet_cmd;
/// Which manifest table the pinned toolchains were found in.
#[derive(Debug)]
enum ToolchainsLocation {
    /// The `[workspace.metadata.rbmt.toolchains]` table.
    Workspace,
    /// The `[package.metadata.rbmt.toolchains]` table.
    Package,
}
impl ToolchainsLocation {
fn table_name(&self) -> &'static str {
match self {
Self::Workspace => "[workspace.metadata.rbmt.toolchains]",
Self::Package => "[package.metadata.rbmt.toolchains]",
}
}
}
/// Pinned toolchain versions read from a manifest, together with the table they came from.
struct ToolchainsConfigData {
    /// Value of the table's `nightly` key, if present.
    nightly: Option<String>,
    /// Value of the table's `stable` key, if present.
    stable: Option<String>,
    /// Which table (`workspace` or `package` metadata) supplied the values.
    location: ToolchainsLocation,
}
/// The `metadata.rbmt` table of a manifest, as deserialized here.
///
/// Only the `toolchains` sub-table is modelled. Other keys are ignored because the struct does
/// not use `deny_unknown_fields`.
#[derive(serde::Deserialize, Default)]
struct RbmtTable {
    /// The optional `toolchains` sub-table; `None` when absent.
    #[serde(default)]
    toolchains: Option<ToolchainsConfig>,
}
/// Raw contents of a `rbmt.toolchains` table. Both keys are optional.
#[derive(serde::Deserialize)]
struct ToolchainsConfig {
    /// Pinned nightly toolchain name (the `nightly` key).
    nightly: Option<String>,
    /// Pinned stable toolchain name (the `stable` key).
    stable: Option<String>,
}
/// Name of the environment variable rustup consults to override the active toolchain.
const RUSTUP_TOOLCHAIN: &str = "RUSTUP_TOOLCHAIN";
// The kinds of Rust toolchain these helpers can select and verify. The enum derives
// `clap::ValueEnum`, so it can be used as a command-line argument value.
// NOTE: plain `//` comments are deliberate. clap's derive turns `///` doc comments on
// `ValueEnum` variants into user-visible help text, so adding them would change CLI output.
#[derive(Debug, Clone, Copy, PartialEq, Eq, clap::ValueEnum)]
pub enum Toolchain {
    // Pinned nightly, read from the `nightly` key of the `rbmt.toolchains` table.
    Nightly,
    // Pinned stable, read from the `stable` key of the `rbmt.toolchains` table.
    Stable,
    // Workspace MSRV, derived from the packages' `rust-version` fields.
    Msrv,
}
impl Toolchain {
    /// Returns the toolchain version string configured for this toolchain kind.
    ///
    /// Nightly and stable come from the `nightly` / `stable` keys of the pinned-toolchains
    /// table in the workspace root `Cargo.toml`. The MSRV is computed by [`get_workspace_msrv`]
    /// from the packages' `rust-version` fields and does not need a toolchains table at all.
    ///
    /// # Errors
    ///
    /// Fails if the manifest cannot be read or parsed. For nightly/stable it also fails if no
    /// toolchains table exists or the requested key is missing. For the MSRV it fails if
    /// [`get_workspace_msrv`] fails.
    pub fn read_version(self, sh: &Shell) -> Result<String, Box<dyn std::error::Error>> {
        match self {
            // The MSRV never comes from the toolchains table, so don't require that table to
            // exist before answering.
            Self::Msrv => get_workspace_msrv(sh),
            Self::Nightly => {
                let config = Self::read_toolchains_config(sh)?;
                config.nightly.ok_or_else(|| {
                    format!("No pinned nightly toolchain found in {}", config.location.table_name())
                        .into()
                })
            }
            Self::Stable => {
                let config = Self::read_toolchains_config(sh)?;
                config.stable.ok_or_else(|| {
                    format!("No pinned stable toolchain found in {}", config.location.table_name())
                        .into()
                })
            }
        }
    }

    /// Pins `version` as the nightly or stable toolchain in the workspace root `Cargo.toml`.
    ///
    /// The key is written into whichever toolchains table [`Toolchain::read_version`] would
    /// read from. The workspace table takes precedence over the package table. The manifest is
    /// edited through `toml_edit::DocumentMut` and written back in full.
    ///
    /// # Errors
    ///
    /// Always fails for [`Toolchain::Msrv`], because the MSRV is derived from `rust-version`
    /// rather than pinned. Also fails if the manifest cannot be read, parsed, or written, or if
    /// it has no toolchains table.
    pub fn write_version(
        self,
        sh: &Shell,
        version: &str,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // Reject MSRV before touching the filesystem. This avoids needless I/O and stops an
        // unrelated "missing table" error from masking the real reason.
        let key = match self {
            Self::Nightly => "nightly",
            Self::Stable => "stable",
            Self::Msrv =>
                return Err(
                    "Cannot update MSRV via write_version; it's derived from Cargo.toml".into()
                ),
        };
        let root = get_workspace_root(sh)?;
        let path = root.join("Cargo.toml");
        let contents = std::fs::read_to_string(&path)?;
        // Locate the table in the exact text being edited, instead of re-reading the file.
        let location = Self::parse_toolchains_config(&contents)?.location;
        let mut doc: toml_edit::DocumentMut = contents.parse()?;
        let table = match location {
            ToolchainsLocation::Workspace =>
                &mut doc["workspace"]["metadata"]["rbmt"]["toolchains"],
            ToolchainsLocation::Package => &mut doc["package"]["metadata"]["rbmt"]["toolchains"],
        };
        table[key] = toml_edit::value(version);
        std::fs::write(&path, doc.to_string())?;
        Ok(())
    }

    /// Reads the workspace root `Cargo.toml` and extracts its pinned-toolchains table.
    ///
    /// # Errors
    ///
    /// Fails if the workspace root cannot be found or the manifest cannot be read. It also
    /// fails for any reason listed on [`Toolchain::parse_toolchains_config`].
    fn read_toolchains_config(
        sh: &Shell,
    ) -> Result<ToolchainsConfigData, Box<dyn std::error::Error>> {
        let root = get_workspace_root(sh)?;
        let contents = std::fs::read_to_string(root.join("Cargo.toml"))?;
        Self::parse_toolchains_config(&contents)
    }

    /// Extracts the pinned-toolchains table from manifest text.
    ///
    /// `[workspace.metadata.rbmt.toolchains]` takes precedence over
    /// `[package.metadata.rbmt.toolchains]`. The result records which table was used.
    ///
    /// # Errors
    ///
    /// Fails if `contents` does not deserialize as a manifest, or if neither table exists.
    fn parse_toolchains_config(
        contents: &str,
    ) -> Result<ToolchainsConfigData, Box<dyn std::error::Error>> {
        let cargo_toml = toml::from_str::<WorkspaceManifest<RbmtTable>>(contents)?;
        if let Some(toolchains) = cargo_toml.workspace.metadata.rbmt.toolchains {
            return Ok(ToolchainsConfigData {
                nightly: toolchains.nightly,
                stable: toolchains.stable,
                location: ToolchainsLocation::Workspace,
            });
        }
        if let Some(toolchains) = cargo_toml.package.metadata.rbmt.toolchains {
            return Ok(ToolchainsConfigData {
                nightly: toolchains.nightly,
                stable: toolchains.stable,
                location: ToolchainsLocation::Package,
            });
        }
        Err("No [workspace.metadata.rbmt.toolchains] or [package.metadata.rbmt.toolchains] exists."
            .into())
    }
}
/// Checks that the `rustc` invoked through `sh` satisfies `required`.
///
/// * [`Toolchain::Nightly`]: the `rustc --version` output must contain `nightly`.
/// * [`Toolchain::Stable`]: the output must contain neither `nightly` nor `beta`.
/// * [`Toolchain::Msrv`]: the shell's current directory must contain a `Cargo.toml` whose
///   package declares a `rust-version`, and the compiler version must match it. A partial
///   `rust-version` such as `1.63` matches any `1.63.x` compiler; a full one such as `1.63.1`
///   must match exactly.
///
/// # Errors
///
/// Returns an error describing the mismatch. Also fails if running `rustc` or
/// `cargo metadata` fails, or if the compiler version cannot be parsed.
pub fn check_toolchain(sh: &Shell, required: Toolchain) -> Result<(), Box<dyn std::error::Error>> {
    let current = quiet_cmd!(sh, "rustc --version").read()?;
    match required {
        Toolchain::Nightly =>
            if !current.contains("nightly") {
                return Err(format!("Need a nightly compiler; have {}", current).into());
            },
        Toolchain::Stable =>
            if current.contains("nightly") || current.contains("beta") {
                return Err(format!("Need a stable compiler; have {}", current).into());
            },
        Toolchain::Msrv => {
            let manifest_path = sh.current_dir().join("Cargo.toml");
            if !manifest_path.exists() {
                return Err("Not in a crate directory (no Cargo.toml found)".into());
            }
            let msrv_version = get_msrv_from_manifest(sh, &manifest_path)?;
            let current_version =
                extract_version(&current).ok_or("Could not parse rustc version")?;
            if !msrv_matches(current_version, &msrv_version) {
                return Err(format!(
                    "Need Rust {} for MSRV testing in {}; have {}",
                    msrv_version,
                    manifest_path.display(),
                    current_version
                )
                .into());
            }
        }
    }
    Ok(())
}

/// Returns whether the full `X.Y.Z` compiler version `current` satisfies `msrv`.
///
/// Cargo accepts `rust-version` values that omit trailing components (e.g. `1.63`), so only
/// the components present in `msrv` are compared. For example, `1.63` matches `1.63.0` and
/// `1.63.2`, while `1.63.1` matches only `1.63.1`.
fn msrv_matches(current: &str, msrv: &str) -> bool {
    let mut current_parts = current.split('.');
    msrv.split('.').all(|wanted| current_parts.next() == Some(wanted))
}
/// Points rustup at the configured toolchain for `required` (best effort), then verifies that
/// the resulting `rustc` satisfies it.
///
/// # Errors
///
/// Returns whatever [`check_toolchain`] reports. Failing to select a toolchain is not itself an
/// error here.
pub fn prepare_toolchain(
    sh: &Shell,
    required: Toolchain,
) -> Result<(), Box<dyn std::error::Error>> {
    // Selection is best effort; the check below is what enforces the requirement.
    maybe_set_rustup_toolchain(sh, required);
    check_toolchain(sh, required)?;
    Ok(())
}
/// Best effort: sets `RUSTUP_TOOLCHAIN` in `sh`'s environment to the version configured for
/// `required`.
///
/// Does nothing if `rustup --version` fails, or if no version can be read for `required`.
/// Errors are deliberately swallowed; callers such as `prepare_toolchain` verify the active
/// compiler afterwards.
fn maybe_set_rustup_toolchain(sh: &Shell, required: Toolchain) {
    let rustup_missing = quiet_cmd!(sh, "rustup --version").ignore_stderr().read().is_err();
    if rustup_missing {
        return;
    }
    let Ok(pinned) = required.read_version(sh) else {
        return;
    };
    sh.set_var(RUSTUP_TOOLCHAIN, pinned);
}
/// Returns the single MSRV (`rust-version`) shared by the workspace packages that declare one.
///
/// Packages without a `rust-version` are ignored.
///
/// # Errors
///
/// Fails if `cargo metadata` fails, or if no package declares a `rust-version`. Also fails if
/// packages declare different values; the conflicting values are listed in sorted order.
pub fn get_workspace_msrv(sh: &Shell) -> Result<String, Box<dyn std::error::Error>> {
    // A BTreeSet de-duplicates and orders the versions in one step.
    let distinct: std::collections::BTreeSet<String> = collect_msrvs(sh)?
        .into_iter()
        .filter_map(|(_manifest, rust_version)| rust_version)
        .collect();
    let mut versions = distinct.into_iter();
    match (versions.next(), versions.next()) {
        (None, _) => Err("No MSRV (rust-version) found in any Cargo.toml in the workspace".into()),
        (Some(only), None) => Ok(only),
        (Some(first), Some(second)) => {
            let mut all = vec![first, second];
            all.extend(versions);
            Err(format!("Workspace packages have conflicting MSRVs: {}", all.join(", ")).into())
        }
    }
}
/// Returns the `rust-version` declared by the package whose manifest is `manifest_path`.
///
/// The package is found among the `cargo metadata` results by manifest path. Paths are
/// compared with [`Path`] equality, which is component-wise, rather than as raw strings.
/// This means spelling differences such as an interior `.` component or a repeated separator
/// (e.g. `/ws/./crate/Cargo.toml` vs `/ws/crate/Cargo.toml`) still match.
///
/// # Errors
///
/// Fails if `cargo metadata` fails, or if no package with this manifest path declares a
/// `rust-version`.
fn get_msrv_from_manifest(
    sh: &Shell,
    manifest_path: &Path,
) -> Result<String, Box<dyn std::error::Error>> {
    collect_msrvs(sh)?
        .into_iter()
        .find(|(path, _)| Path::new(path.as_str()) == manifest_path)
        .and_then(|(_, rust_version)| rust_version)
        .ok_or_else(|| {
            format!("No MSRV (rust-version) specified in {}", manifest_path.display()).into()
        })
}
/// One package's `(manifest_path, rust_version)` pair as reported by `cargo metadata`.
/// `rust_version` is `None` when the package declares no `rust-version`.
type ManifestMsrv = (String, Option<String>);
/// Lists `(manifest_path, rust_version)` for every package reported by
/// `cargo metadata --format-version 1 --no-deps`.
///
/// Packages whose `manifest_path` is not a string are skipped. If the output has no
/// `packages` array, an empty list is returned.
///
/// # Errors
///
/// Fails if `cargo metadata` fails or its output is not valid JSON.
fn collect_msrvs(sh: &Shell) -> Result<Vec<ManifestMsrv>, Box<dyn std::error::Error>> {
    let raw = quiet_cmd!(sh, "cargo metadata --format-version 1 --no-deps").read()?;
    let metadata: serde_json::Value = serde_json::from_str(&raw)?;
    let packages = match metadata["packages"].as_array() {
        Some(packages) => packages,
        None => return Ok(Vec::new()),
    };
    let mut found = Vec::with_capacity(packages.len());
    for package in packages {
        if let Some(path) = package["manifest_path"].as_str() {
            let msrv = package["rust_version"].as_str().map(String::from);
            found.push((path.to_owned(), msrv));
        }
    }
    Ok(found)
}
/// Extracts the `X.Y.Z` version number from `rustc --version` output.
///
/// The output is scanned token by token (whitespace-separated). Any `-suffix` is dropped, so
/// `1.75.0-nightly` yields `1.75.0`. The first token with exactly three dot-separated,
/// non-empty, all-digit components is returned. Tokens such as `..` or `1..2` are rejected.
/// Returns `None` if no token qualifies.
fn extract_version(rustc_version: &str) -> Option<&str> {
    // An empty component would otherwise pass the all-digits check vacuously.
    let is_number = |c: &&str| !c.is_empty() && c.bytes().all(|b| b.is_ascii_digit());
    rustc_version.split_whitespace().find_map(|token| {
        let candidate = token.split('-').next()?;
        let components: Vec<&str> = candidate.split('.').collect();
        (components.len() == 3 && components.iter().all(is_number)).then_some(candidate)
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_version() {
        // Each case pairs `rustc --version`-style text with the version expected from it.
        let cases = [
            ("rustc 1.74.0 (79e9716c9 2023-11-13)", Some("1.74.0")),
            ("rustc 1.75.0-nightly (12345abcd 2023-11-20)", Some("1.75.0")),
            ("rustc 1.74.0", Some("1.74.0")),
            ("rustc unknown version", None),
            ("no version here", None),
        ];
        for &(input, expected) in &cases {
            assert_eq!(extract_version(input), expected, "input: {input:?}");
        }
    }
}