use super::{NextestConfig, ToolConfigFile, ToolName};
use crate::errors::{ConfigParseError, ConfigParseErrorKind};
use camino::{Utf8Path, Utf8PathBuf};
use semver::Version;
use serde::{
Deserialize, Deserializer,
de::{MapAccess, SeqAccess, Visitor},
};
use std::{borrow::Cow, collections::BTreeSet, fmt, str::FromStr};
/// The `nextest-version` and `experimental` sections of the nextest configuration, as read by
/// [`VersionOnlyConfig::from_sources`] from tool config files and the workspace config file.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct VersionOnlyConfig {
/// Version requirements accumulated across tool config files and the workspace config file.
nextest_version: NextestVersionConfig,
/// Experimental features; only the workspace config file may contribute these.
experimental: ExperimentalConfig,
}
impl VersionOnlyConfig {
pub fn from_sources<'a, I>(
workspace_root: &Utf8Path,
config_file: Option<&Utf8Path>,
tool_config_files: impl IntoIterator<IntoIter = I>,
) -> Result<Self, ConfigParseError>
where
I: Iterator<Item = &'a ToolConfigFile> + DoubleEndedIterator,
{
let tool_config_files_rev = tool_config_files.into_iter().rev();
Self::read_from_sources(workspace_root, config_file, tool_config_files_rev)
}
pub fn nextest_version(&self) -> &NextestVersionConfig {
&self.nextest_version
}
pub fn experimental(&self) -> &ExperimentalConfig {
&self.experimental
}
fn read_from_sources<'a>(
workspace_root: &Utf8Path,
config_file: Option<&Utf8Path>,
tool_config_files_rev: impl Iterator<Item = &'a ToolConfigFile>,
) -> Result<Self, ConfigParseError> {
let mut nextest_version = NextestVersionConfig::default();
let mut known = BTreeSet::new();
let mut unknown = BTreeSet::new();
for ToolConfigFile { config_file, tool } in tool_config_files_rev {
if let Some(v) = Self::read_and_deserialize(config_file, Some(tool))?.nextest_version {
nextest_version.accumulate(v, Some(tool.clone()));
}
}
let config_file = match config_file {
Some(file) => Some(Cow::Borrowed(file)),
None => {
let config_file = workspace_root.join(NextestConfig::CONFIG_PATH);
config_file.exists().then_some(Cow::Owned(config_file))
}
};
if let Some(config_file) = config_file {
let d = Self::read_and_deserialize(&config_file, None)?;
if let Some(v) = d.nextest_version {
nextest_version.accumulate(v, None);
}
known.extend(d.experimental.known);
unknown.extend(d.experimental.unknown);
}
Ok(Self {
nextest_version,
experimental: ExperimentalConfig { known, unknown },
})
}
fn read_and_deserialize(
config_file: &Utf8Path,
tool: Option<&ToolName>,
) -> Result<VersionOnlyDeserialize, ConfigParseError> {
let toml_str = std::fs::read_to_string(config_file.as_str()).map_err(|error| {
ConfigParseError::new(
config_file,
tool,
ConfigParseErrorKind::VersionOnlyReadError(error),
)
})?;
let toml_de = toml::de::Deserializer::parse(&toml_str).map_err(|error| {
ConfigParseError::new(
config_file,
tool,
ConfigParseErrorKind::TomlParseError(Box::new(error)),
)
})?;
let v: VersionOnlyDeserialize =
serde_path_to_error::deserialize(toml_de).map_err(|error| {
ConfigParseError::new(
config_file,
tool,
ConfigParseErrorKind::VersionOnlyDeserializeError(Box::new(error)),
)
})?;
if tool.is_some() && !v.experimental.is_empty() {
return Err(ConfigParseError::new(
config_file,
tool,
ConfigParseErrorKind::ExperimentalFeaturesInToolConfig {
features: v.experimental.feature_names(),
},
));
}
Ok(v)
}
}
/// Raw deserialized form of one config file, keeping only the `nextest-version` and
/// `experimental` keys (serde's derive ignores other keys unless `deny_unknown_fields` is set).
#[derive(Debug, Default, Clone, PartialEq, Eq, Deserialize)]
#[serde(rename_all = "kebab-case")]
struct VersionOnlyDeserialize {
/// The `nextest-version` key, if present.
#[serde(default)]
nextest_version: Option<NextestVersionDeserialize>,
/// The `experimental` key; empty if absent.
#[serde(default)]
experimental: ExperimentalDeserialize,
}
/// The `experimental` section of one config file, split into recognized features and the names
/// of unrecognized ones.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub(crate) struct ExperimentalDeserialize {
/// Features that parsed as a [`ConfigExperimental`] (and, in table form, were set to `true`).
known: BTreeSet<ConfigExperimental>,
/// Names that did not match any known feature.
unknown: BTreeSet<String>,
}
impl ExperimentalDeserialize {
    /// Returns `true` if no features, known or unknown, were declared.
    fn is_empty(&self) -> bool {
        self.unknown.is_empty() && self.known.is_empty()
    }

    /// Returns the names of all declared features, known and unknown, as strings.
    fn feature_names(&self) -> BTreeSet<String> {
        self.known
            .iter()
            .map(ToString::to_string)
            .chain(self.unknown.iter().cloned())
            .collect()
    }
}
impl<'de> Deserialize<'de> for ExperimentalDeserialize {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ExperimentalVisitor;
impl<'de> Visitor<'de> for ExperimentalVisitor {
type Value = ExperimentalDeserialize;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str(
"a table ({ setup-scripts = true, benchmarks = true }) \
or an array ([\"setup-scripts\", \"benchmarks\"])",
)
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut known = BTreeSet::new();
let mut unknown = BTreeSet::new();
while let Some(feature_str) = seq.next_element::<String>()? {
if let Ok(feature) = feature_str.parse::<ConfigExperimental>() {
known.insert(feature);
} else {
unknown.insert(feature_str);
}
}
Ok(ExperimentalDeserialize { known, unknown })
}
fn visit_map<A>(self, map: A) -> Result<Self::Value, A::Error>
where
A: MapAccess<'de>,
{
#[derive(Deserialize)]
#[serde(rename_all = "kebab-case")]
struct TableConfig {
#[serde(default)]
setup_scripts: bool,
#[serde(default)]
wrapper_scripts: bool,
#[serde(default)]
benchmarks: bool,
}
let mut unknown = BTreeSet::new();
let de = serde::de::value::MapAccessDeserializer::new(map);
let mut cb = |path: serde_ignored::Path| {
unknown.insert(path.to_string());
};
let ignored_de = serde_ignored::Deserializer::new(de, &mut cb);
let TableConfig {
setup_scripts,
wrapper_scripts,
benchmarks,
} = Deserialize::deserialize(ignored_de).map_err(serde::de::Error::custom)?;
let mut known = BTreeSet::new();
if setup_scripts {
known.insert(ConfigExperimental::SetupScripts);
}
if wrapper_scripts {
known.insert(ConfigExperimental::WrapperScripts);
}
if benchmarks {
known.insert(ConfigExperimental::Benchmarks);
}
Ok(ExperimentalDeserialize { known, unknown })
}
}
deserializer.deserialize_any(ExperimentalVisitor)
}
}
/// Required and recommended nextest versions, each the highest value seen across all config
/// sources passed to [`NextestVersionConfig::accumulate`].
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct NextestVersionConfig {
/// Versions older than this fail the check in [`NextestVersionConfig::eval`].
pub required: NextestVersionReq,
/// Versions older than this produce a warning in [`NextestVersionConfig::eval`].
pub recommended: NextestVersionReq,
}
impl NextestVersionConfig {
pub(crate) fn accumulate(&mut self, v: NextestVersionDeserialize, v_tool: Option<ToolName>) {
if let Some(version) = v.required {
self.required.accumulate(version, v_tool.clone());
}
if let Some(version) = v.recommended {
self.recommended.accumulate(version, v_tool);
}
}
pub fn eval(
&self,
current_version: &Version,
override_version_check: bool,
) -> NextestVersionEval {
match self.required.satisfies(current_version) {
Ok(()) => {}
Err((required, tool)) => {
if override_version_check {
return NextestVersionEval::ErrorOverride {
required: required.clone(),
current: current_version.clone(),
tool: tool.cloned(),
};
} else {
return NextestVersionEval::Error {
required: required.clone(),
current: current_version.clone(),
tool: tool.cloned(),
};
}
}
}
match self.recommended.satisfies(current_version) {
Ok(()) => NextestVersionEval::Satisfied,
Err((recommended, tool)) => {
if override_version_check {
NextestVersionEval::WarnOverride {
recommended: recommended.clone(),
current: current_version.clone(),
tool: tool.cloned(),
}
} else {
NextestVersionEval::Warn {
recommended: recommended.clone(),
current: current_version.clone(),
tool: tool.cloned(),
}
}
}
}
}
}
/// Experimental features declared in the workspace config file.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct ExperimentalConfig {
/// Recognized features that were enabled.
known: BTreeSet<ConfigExperimental>,
/// Feature names that were not recognized; reported by [`ExperimentalConfig::eval`].
unknown: BTreeSet<String>,
}
impl ExperimentalConfig {
    /// Returns the recognized experimental features that were enabled.
    pub fn known(&self) -> &BTreeSet<ConfigExperimental> {
        &self.known
    }

    /// Returns `UnknownFeatures` (listing every known feature for reference) if any
    /// unrecognized feature names were declared, and `Satisfied` otherwise.
    pub fn eval(&self) -> ExperimentalConfigEval {
        if !self.unknown.is_empty() {
            return ExperimentalConfigEval::UnknownFeatures {
                unknown: self.unknown.clone(),
                known: ConfigExperimental::known_features().collect(),
            };
        }
        ExperimentalConfigEval::Satisfied
    }
}
/// The result of [`ExperimentalConfig::eval`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExperimentalConfigEval {
/// Every declared feature was recognized.
Satisfied,
/// Some declared feature names were not recognized.
UnknownFeatures {
/// The unrecognized names.
unknown: BTreeSet<String>,
/// All features this build knows about, for use in diagnostics.
known: BTreeSet<ConfigExperimental>,
},
}
impl ExperimentalConfigEval {
    /// Converts an `UnknownFeatures` result into a [`ConfigParseError`] attributed to
    /// `config_file` (with no tool). Returns `None` for `Satisfied`.
    pub fn into_error(self, config_file: impl Into<Utf8PathBuf>) -> Option<ConfigParseError> {
        match self {
            Self::Satisfied => None,
            Self::UnknownFeatures { unknown, known } => {
                let kind = ConfigParseErrorKind::UnknownExperimentalFeatures { unknown, known };
                Some(ConfigParseError::new(config_file, None, kind))
            }
        }
    }
}
/// An experimental feature that can be enabled in nextest configuration.
///
/// Variant order determines the derived `Ord`, and therefore the iteration order of the
/// `BTreeSet`s these values are stored in.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
#[non_exhaustive]
pub enum ConfigExperimental {
/// Setup scripts, spelled `setup-scripts` in config.
SetupScripts,
/// Wrapper scripts, spelled `wrapper-scripts` in config.
WrapperScripts,
/// Benchmarks, spelled `benchmarks` in config; `ConfigExperimental::from_env` also reports it
/// when `NEXTEST_EXPERIMENTAL_BENCHMARKS=1`.
Benchmarks,
}
impl ConfigExperimental {
pub fn known_features() -> impl Iterator<Item = Self> {
vec![Self::SetupScripts, Self::WrapperScripts, Self::Benchmarks].into_iter()
}
pub fn env_var(self) -> Option<&'static str> {
match self {
Self::SetupScripts => None,
Self::WrapperScripts => None,
Self::Benchmarks => Some("NEXTEST_EXPERIMENTAL_BENCHMARKS"),
}
}
pub fn from_env() -> std::collections::BTreeSet<Self> {
let mut set = std::collections::BTreeSet::new();
for feature in Self::known_features() {
if let Some(env_var) = feature.env_var()
&& std::env::var(env_var).as_deref() == Ok("1")
{
set.insert(feature);
}
}
set
}
}
impl FromStr for ConfigExperimental {
    type Err = ();

    /// Parses the kebab-case config name of a feature; any other string yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let feature = match s {
            "setup-scripts" => Self::SetupScripts,
            "wrapper-scripts" => Self::WrapperScripts,
            "benchmarks" => Self::Benchmarks,
            _ => return Err(()),
        };
        Ok(feature)
    }
}
impl fmt::Display for ConfigExperimental {
    /// Writes the kebab-case config name (the inverse of the `FromStr` impl).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::SetupScripts => "setup-scripts",
            Self::WrapperScripts => "wrapper-scripts",
            Self::Benchmarks => "benchmarks",
        };
        f.write_str(name)
    }
}
/// A single nextest version requirement (required or recommended), if any source set one.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub enum NextestVersionReq {
/// A minimum version, along with where it came from.
Version {
/// The minimum version.
version: Version,
/// The tool config file that set it, or `None` for the workspace config file.
tool: Option<ToolName>,
},
/// No source specified a version.
#[default]
None,
}
impl NextestVersionReq {
pub fn version(&self) -> Option<&Version> {
match self {
NextestVersionReq::Version { version, .. } => Some(version),
NextestVersionReq::None => None,
}
}
fn accumulate(&mut self, v: Version, v_tool: Option<ToolName>) {
match self {
NextestVersionReq::Version { version, tool } => {
if &v >= version {
*version = v;
*tool = v_tool;
}
}
NextestVersionReq::None => {
*self = NextestVersionReq::Version {
version: v,
tool: v_tool,
};
}
}
}
fn satisfies(&self, version: &Version) -> Result<(), (&Version, Option<&ToolName>)> {
match self {
NextestVersionReq::Version {
version: required,
tool,
} => {
if version >= required {
Ok(())
} else {
Err((required, tool.as_ref()))
}
}
NextestVersionReq::None => Ok(()),
}
}
}
/// The result of [`NextestVersionConfig::eval`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NextestVersionEval {
/// The checked version meets both the required and the recommended version (if any).
Satisfied,
/// The checked version is older than the required version, and no override was requested.
Error {
/// The unmet required version.
required: Version,
/// The version that was checked.
current: Version,
/// The tool config file that set the requirement, or `None` for the workspace config.
tool: Option<ToolName>,
},
/// The checked version meets the required version but is older than the recommended one,
/// and no override was requested.
Warn {
/// The unmet recommended version.
recommended: Version,
/// The version that was checked.
current: Version,
/// The tool config file that set the recommendation, or `None` for the workspace config.
tool: Option<ToolName>,
},
/// Like `Error`, but `override_version_check` was set.
ErrorOverride {
/// The unmet required version.
required: Version,
/// The version that was checked.
current: Version,
/// The tool config file that set the requirement, or `None` for the workspace config.
tool: Option<ToolName>,
},
/// Like `Warn`, but `override_version_check` was set.
WarnOverride {
/// The unmet recommended version.
recommended: Version,
/// The version that was checked.
current: Version,
/// The tool config file that set the recommendation, or `None` for the workspace config.
tool: Option<ToolName>,
},
}
/// The deserialized `nextest-version` value of a single config file.
///
/// A bare string sets only `required`; the table form may set either or both.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct NextestVersionDeserialize {
/// Minimum version; older versions are an error.
required: Option<Version>,
/// Recommended version; older versions produce a warning.
recommended: Option<Version>,
}
impl<'de> Deserialize<'de> for NextestVersionDeserialize {
    /// Deserializes `nextest-version` from either a version string (taken as the required
    /// version) or a table with optional `required` and `recommended` keys.
    ///
    /// Fails if both versions are given and `required` is greater than `recommended`.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct V;

        impl<'de2> serde::de::Visitor<'de2> for V {
            type Value = NextestVersionDeserialize;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                // `write_str` emits its argument verbatim (it is not a format string), so the
                // braces must not be doubled here or users see `{{ … }}` in error messages.
                formatter.write_str(
                    "a table ({ required = \"0.9.20\", recommended = \"0.9.30\" }) or a string (\"0.9.50\")",
                )
            }

            fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                let required = parse_version::<E>(s.to_owned())?;
                Ok(NextestVersionDeserialize {
                    required: Some(required),
                    recommended: None,
                })
            }

            fn visit_map<A>(self, map: A) -> std::result::Result<Self::Value, A::Error>
            where
                A: serde::de::MapAccess<'de2>,
            {
                #[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
                struct NextestVersionMap {
                    #[serde(default, deserialize_with = "deserialize_version_opt")]
                    required: Option<Version>,
                    #[serde(default, deserialize_with = "deserialize_version_opt")]
                    recommended: Option<Version>,
                }

                let NextestVersionMap {
                    required,
                    recommended,
                } = NextestVersionMap::deserialize(serde::de::value::MapAccessDeserializer::new(
                    map,
                ))?;

                if let (Some(required), Some(recommended)) = (&required, &recommended)
                    && required > recommended
                {
                    return Err(serde::de::Error::custom(format!(
                        "required version ({required}) must not be greater than recommended version ({recommended})"
                    )));
                }

                Ok(NextestVersionDeserialize {
                    required,
                    recommended,
                })
            }
        }

        deserializer.deserialize_any(V)
    }
}

/// Deserializes an optional version string using [`parse_version`].
fn deserialize_version_opt<'de, D>(
    deserializer: D,
) -> std::result::Result<Option<Version>, D::Error>
where
    D: Deserializer<'de>,
{
    let s = Option::<String>::deserialize(deserializer)?;
    s.map(parse_version::<D::Error>).transpose()
}

/// Parses a `nextest-version` value.
///
/// Rejects pre-release (`-`) and build-metadata (`+`) suffixes; whichever character appears
/// first determines the error. A `major.minor` value with exactly one dot is completed with
/// `.0` before being parsed as a semver [`Version`].
fn parse_version<E>(mut s: String) -> std::result::Result<Version, E>
where
    E: serde::de::Error,
{
    for ch in s.chars() {
        if ch == '-' {
            return Err(E::custom(
                "pre-release identifiers are not supported in nextest-version",
            ));
        } else if ch == '+' {
            return Err(E::custom(
                "build metadata is not supported in nextest-version",
            ));
        }
    }

    if s.matches('.').count() == 1 {
        s.push_str(".0");
    }

    Version::parse(&s).map_err(E::custom)
}

#[cfg(test)]
mod tests {
    use super::*;
    use test_case::test_case;

    #[test_case(
        r#"
            nextest-version = "0.9"
        "#,
        NextestVersionDeserialize { required: Some("0.9.0".parse().unwrap()), recommended: None } ; "basic"
    )]
    #[test_case(
        r#"
            nextest-version = "0.9.30"
        "#,
        NextestVersionDeserialize { required: Some("0.9.30".parse().unwrap()), recommended: None } ; "basic with patch"
    )]
    #[test_case(
        r#"
            nextest-version = { recommended = "0.9.20" }
        "#,
        NextestVersionDeserialize { required: None, recommended: Some("0.9.20".parse().unwrap()) } ; "with warning"
    )]
    #[test_case(
        r#"
            nextest-version = { required = "0.9.20", recommended = "0.9.25" }
        "#,
        NextestVersionDeserialize {
            required: Some("0.9.20".parse().unwrap()),
            recommended: Some("0.9.25".parse().unwrap()),
        } ; "with error and warning"
    )]
    fn test_valid_nextest_version(input: &str, expected: NextestVersionDeserialize) {
        let actual: VersionOnlyDeserialize = toml::from_str(input).unwrap();
        assert_eq!(actual.nextest_version.unwrap(), expected);
    }

    #[test_case(
        r#"
            nextest-version = 42
        "#,
        "a table ({ required = \"0.9.20\", recommended = \"0.9.30\" }) or a string (\"0.9.50\")" ; "wrong type"
    )]
    #[test_case(
        r#"
            nextest-version = "0.9.30-rc.1"
        "#,
        "pre-release identifiers are not supported in nextest-version" ; "pre-release"
    )]
    #[test_case(
        r#"
            nextest-version = "0.9.40+mybuild"
        "#,
        "build metadata is not supported in nextest-version" ; "build metadata"
    )]
    #[test_case(
        r#"
            nextest-version = { required = "0.9.20", recommended = "0.9.10" }
        "#,
        "required version (0.9.20) must not be greater than recommended version (0.9.10)" ; "error greater than warning"
    )]
    fn test_invalid_nextest_version(input: &str, error_message: &str) {
        let err = toml::from_str::<VersionOnlyDeserialize>(input).unwrap_err();
        assert!(
            err.to_string().contains(error_message),
            "expected error `{err}` to contain `{error_message}`"
        );
    }

    fn tool_name(s: &str) -> ToolName {
        ToolName::new(s.into()).unwrap()
    }

    #[test]
    fn test_accumulate() {
        let mut nextest_version = NextestVersionConfig::default();
        nextest_version.accumulate(
            NextestVersionDeserialize {
                required: Some("0.9.20".parse().unwrap()),
                recommended: None,
            },
            Some(tool_name("tool1")),
        );
        nextest_version.accumulate(
            NextestVersionDeserialize {
                required: Some("0.9.30".parse().unwrap()),
                recommended: Some("0.9.35".parse().unwrap()),
            },
            Some(tool_name("tool2")),
        );
        nextest_version.accumulate(
            NextestVersionDeserialize {
                required: None,
                recommended: Some("0.9.25".parse().unwrap()),
            },
            Some(tool_name("tool3")),
        );
        nextest_version.accumulate(
            NextestVersionDeserialize {
                required: Some("0.9.30".parse().unwrap()),
                recommended: None,
            },
            Some(tool_name("tool4")),
        );

        assert_eq!(
            nextest_version,
            NextestVersionConfig {
                required: NextestVersionReq::Version {
                    version: "0.9.30".parse().unwrap(),
                    tool: Some(tool_name("tool4")),
                },
                recommended: NextestVersionReq::Version {
                    version: "0.9.35".parse().unwrap(),
                    tool: Some(tool_name("tool2")),
                },
            }
        );
    }

    #[test]
    fn test_from_env_benchmarks() {
        const VAR: &str = "NEXTEST_EXPERIMENTAL_BENCHMARKS";

        // Restores the variable's original value when the test ends (including on panic), so
        // the value set here doesn't leak into other tests sharing this process.
        struct RestoreEnv(Option<std::ffi::OsString>);
        impl Drop for RestoreEnv {
            fn drop(&mut self) {
                // SAFETY: see the note on the `set_var` calls below.
                match self.0.take() {
                    Some(value) => unsafe { std::env::set_var(VAR, value) },
                    None => unsafe { std::env::remove_var(VAR) },
                }
            }
        }
        let _restore = RestoreEnv(std::env::var_os(VAR));

        // SAFETY: `set_var`/`remove_var` are only sound if no other thread accesses the
        // environment concurrently. NOTE(review): this holds under nextest's process-per-test
        // model; under plain `cargo test` other tests may run on parallel threads.
        unsafe { std::env::set_var(VAR, "1") };
        assert!(ConfigExperimental::from_env().contains(&ConfigExperimental::Benchmarks));

        unsafe { std::env::set_var(VAR, "0") };
        assert!(!ConfigExperimental::from_env().contains(&ConfigExperimental::Benchmarks));

        unsafe { std::env::set_var(VAR, "true") };
        assert!(!ConfigExperimental::from_env().contains(&ConfigExperimental::Benchmarks));

        unsafe { std::env::set_var(VAR, "1") };
        let set = ConfigExperimental::from_env();
        assert!(!set.contains(&ConfigExperimental::SetupScripts));
        assert!(!set.contains(&ConfigExperimental::WrapperScripts));

        unsafe { std::env::remove_var(VAR) };
        assert!(ConfigExperimental::from_env().is_empty());
    }

    #[test]
    fn test_experimental_formats() {
        let input = r#"experimental = ["setup-scripts", "benchmarks"]"#;
        let d: VersionOnlyDeserialize = toml::from_str(input).unwrap();
        assert_eq!(
            d.experimental.known,
            BTreeSet::from([
                ConfigExperimental::SetupScripts,
                ConfigExperimental::Benchmarks
            ]),
            "expected 2 known features"
        );
        assert!(d.experimental.unknown.is_empty());

        let input = r#"experimental = []"#;
        let d: VersionOnlyDeserialize = toml::from_str(input).unwrap();
        assert!(
            d.experimental.is_empty(),
            "expected empty, got {:?}",
            d.experimental
        );

        let input = r#"experimental = ["setup-scripts", "unknown-feature"]"#;
        let d: VersionOnlyDeserialize = toml::from_str(input).unwrap();
        assert_eq!(
            d.experimental.known,
            BTreeSet::from([ConfigExperimental::SetupScripts])
        );
        assert_eq!(
            d.experimental.unknown,
            BTreeSet::from(["unknown-feature".to_owned()])
        );

        let input = r#"
            [experimental]
            setup-scripts = true
            benchmarks = true
        "#;
        let d: VersionOnlyDeserialize = toml::from_str(input).unwrap();
        assert_eq!(
            d.experimental.known,
            BTreeSet::from([
                ConfigExperimental::SetupScripts,
                ConfigExperimental::Benchmarks
            ])
        );
        assert!(d.experimental.unknown.is_empty());

        let input = r#"[experimental]"#;
        let d: VersionOnlyDeserialize = toml::from_str(input).unwrap();
        assert!(
            d.experimental.is_empty(),
            "expected empty, got {:?}",
            d.experimental
        );

        let input = r#"
            [experimental]
            setup-scripts = false
        "#;
        let d: VersionOnlyDeserialize = toml::from_str(input).unwrap();
        assert!(
            d.experimental.is_empty(),
            "expected empty, got {:?}",
            d.experimental
        );

        let input = r#"
            [experimental]
            setup-scripts = true
            unknown-feature = true
        "#;
        let d: VersionOnlyDeserialize = toml::from_str(input).unwrap();
        assert_eq!(
            d.experimental.known,
            BTreeSet::from([ConfigExperimental::SetupScripts])
        );
        assert!(d.experimental.unknown.contains("unknown-feature"));

        let input = r#"experimental = 42"#;
        let err = toml::from_str::<VersionOnlyDeserialize>(input).unwrap_err();
        let err_str = err.to_string();
        assert!(
            err_str.contains("expected a table") && err_str.contains("or an array"),
            "expected error to mention both formats, got: {}",
            err_str
        );

        let input = r#"experimental = "setup-scripts""#;
        let err = toml::from_str::<VersionOnlyDeserialize>(input).unwrap_err();
        let err_str = err.to_string();
        assert!(
            err_str.contains("expected a table") && err_str.contains("or an array"),
            "expected error to mention both formats, got: {}",
            err_str
        );
    }
}