1use super::{NextestConfig, ToolConfigFile, ToolName};
7use crate::errors::{ConfigParseError, ConfigParseErrorKind};
8use camino::Utf8Path;
9use semver::Version;
10use serde::{Deserialize, Deserializer};
11use std::{borrow::Cow, collections::BTreeSet, fmt, str::FromStr};
12
#[derive(Debug, Default, Clone, PartialEq, Eq)]
/// A minimal view of nextest configuration containing only the `nextest-version` requirements
/// and the `experimental` feature list.
///
/// Built by [`VersionOnlyConfig::from_sources`].
pub struct VersionOnlyConfig {
    /// The nextest version requirements, merged across tool configs and the repository config.
    nextest_version: NextestVersionConfig,

    /// Experimental features enabled in the repository config. (Tool configs are not allowed
    /// to enable experimental features.)
    experimental: BTreeSet<ConfigExperimental>,
}
25
26impl VersionOnlyConfig {
27 pub fn from_sources<'a, I>(
31 workspace_root: &Utf8Path,
32 config_file: Option<&Utf8Path>,
33 tool_config_files: impl IntoIterator<IntoIter = I>,
34 ) -> Result<Self, ConfigParseError>
35 where
36 I: Iterator<Item = &'a ToolConfigFile> + DoubleEndedIterator,
37 {
38 let tool_config_files_rev = tool_config_files.into_iter().rev();
39
40 Self::read_from_sources(workspace_root, config_file, tool_config_files_rev)
41 }
42
43 pub fn nextest_version(&self) -> &NextestVersionConfig {
45 &self.nextest_version
46 }
47
48 pub fn experimental(&self) -> &BTreeSet<ConfigExperimental> {
50 &self.experimental
51 }
52
53 fn read_from_sources<'a>(
54 workspace_root: &Utf8Path,
55 config_file: Option<&Utf8Path>,
56 tool_config_files_rev: impl Iterator<Item = &'a ToolConfigFile>,
57 ) -> Result<Self, ConfigParseError> {
58 let mut nextest_version = NextestVersionConfig::default();
59 let mut experimental = BTreeSet::new();
60
61 for ToolConfigFile { config_file, tool } in tool_config_files_rev {
63 if let Some(v) = Self::read_and_deserialize(config_file, Some(tool))?.nextest_version {
64 nextest_version.accumulate(v, Some(tool.clone()));
65 }
66 }
67
68 let config_file = match config_file {
70 Some(file) => Some(Cow::Borrowed(file)),
71 None => {
72 let config_file = workspace_root.join(NextestConfig::CONFIG_PATH);
73 config_file.exists().then_some(Cow::Owned(config_file))
74 }
75 };
76 if let Some(config_file) = config_file {
77 let d = Self::read_and_deserialize(&config_file, None)?;
78 if let Some(v) = d.nextest_version {
79 nextest_version.accumulate(v, None);
80 }
81
82 let unknown: BTreeSet<_> = d
84 .experimental
85 .into_iter()
86 .filter(|feature| {
87 if let Ok(feature) = feature.parse::<ConfigExperimental>() {
88 experimental.insert(feature);
89 false
90 } else {
91 true
92 }
93 })
94 .collect();
95 if !unknown.is_empty() {
96 let known = ConfigExperimental::known().collect();
97 return Err(ConfigParseError::new(
98 config_file.into_owned(),
99 None,
100 ConfigParseErrorKind::UnknownExperimentalFeatures { unknown, known },
101 ));
102 }
103 }
104
105 Ok(Self {
106 nextest_version,
107 experimental,
108 })
109 }
110
111 fn read_and_deserialize(
112 config_file: &Utf8Path,
113 tool: Option<&ToolName>,
114 ) -> Result<VersionOnlyDeserialize, ConfigParseError> {
115 let toml_str = std::fs::read_to_string(config_file.as_str()).map_err(|error| {
116 ConfigParseError::new(
117 config_file,
118 tool,
119 ConfigParseErrorKind::VersionOnlyReadError(error),
120 )
121 })?;
122 let toml_de = toml::de::Deserializer::parse(&toml_str).map_err(|error| {
123 ConfigParseError::new(
124 config_file,
125 tool,
126 ConfigParseErrorKind::TomlParseError(Box::new(error)),
127 )
128 })?;
129 let v: VersionOnlyDeserialize =
130 serde_path_to_error::deserialize(toml_de).map_err(|error| {
131 ConfigParseError::new(
132 config_file,
133 tool,
134 ConfigParseErrorKind::VersionOnlyDeserializeError(Box::new(error)),
135 )
136 })?;
137 if tool.is_some() && !v.experimental.is_empty() {
138 return Err(ConfigParseError::new(
139 config_file,
140 tool,
141 ConfigParseErrorKind::ExperimentalFeaturesInToolConfig {
142 features: v.experimental,
143 },
144 ));
145 }
146
147 Ok(v)
148 }
149}
150
#[derive(Debug, Default, Clone, PartialEq, Eq, Deserialize)]
/// Raw deserialization target for a single config file, reading only the `nextest-version`
/// and `experimental` keys. There is no `deny_unknown_fields`, so other keys are ignored.
#[serde(rename_all = "kebab-case")]
struct VersionOnlyDeserialize {
    // The `nextest-version` key, if present.
    #[serde(default)]
    nextest_version: Option<NextestVersionDeserialize>,
    // Raw experimental feature names; they are validated later, in
    // `VersionOnlyConfig::read_from_sources`.
    #[serde(default)]
    experimental: BTreeSet<String>,
}
160
#[derive(Debug, Default, Clone, PartialEq, Eq)]
/// Nextest version requirements, accumulated from one or more config sources.
///
/// Use [`NextestVersionConfig::eval`] to check them against the running version.
pub struct NextestVersionConfig {
    /// The minimum nextest version that must be running. If it is not met, `eval` returns
    /// `Error` (or `ErrorOverride` when the check is overridden).
    pub required: NextestVersionReq,

    /// The minimum nextest version that should be running. If it is not met, `eval` returns
    /// `Warn` (or `WarnOverride` when the check is overridden).
    pub recommended: NextestVersionReq,
}
177
178impl NextestVersionConfig {
179 pub(crate) fn accumulate(&mut self, v: NextestVersionDeserialize, v_tool: Option<ToolName>) {
181 if let Some(version) = v.required {
182 self.required.accumulate(version, v_tool.clone());
183 }
184 if let Some(version) = v.recommended {
185 self.recommended.accumulate(version, v_tool);
186 }
187 }
188
189 pub fn eval(
191 &self,
192 current_version: &Version,
193 override_version_check: bool,
194 ) -> NextestVersionEval {
195 match self.required.satisfies(current_version) {
196 Ok(()) => {}
197 Err((required, tool)) => {
198 if override_version_check {
199 return NextestVersionEval::ErrorOverride {
200 required: required.clone(),
201 current: current_version.clone(),
202 tool: tool.cloned(),
203 };
204 } else {
205 return NextestVersionEval::Error {
206 required: required.clone(),
207 current: current_version.clone(),
208 tool: tool.cloned(),
209 };
210 }
211 }
212 }
213
214 match self.recommended.satisfies(current_version) {
215 Ok(()) => NextestVersionEval::Satisfied,
216 Err((recommended, tool)) => {
217 if override_version_check {
218 NextestVersionEval::WarnOverride {
219 recommended: recommended.clone(),
220 current: current_version.clone(),
221 tool: tool.cloned(),
222 }
223 } else {
224 NextestVersionEval::Warn {
225 recommended: recommended.clone(),
226 current: current_version.clone(),
227 tool: tool.cloned(),
228 }
229 }
230 }
231 }
232 }
233}
234
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
/// An experimental feature that can be listed in the repository config's `experimental` key.
///
/// Variant declaration order determines the derived `Ord`, and therefore the iteration order
/// of a `BTreeSet<ConfigExperimental>`.
#[non_exhaustive]
pub enum ConfigExperimental {
    /// Setup scripts (`"setup-scripts"` in config).
    SetupScripts,
    /// Wrapper scripts (`"wrapper-scripts"` in config).
    WrapperScripts,
    /// Benchmarks (`"benchmarks"` in config). This feature is also reported by
    /// `ConfigExperimental::from_env` when `NEXTEST_EXPERIMENTAL_BENCHMARKS=1`.
    Benchmarks,
}
246
247impl ConfigExperimental {
248 fn known() -> impl Iterator<Item = Self> {
249 vec![Self::SetupScripts, Self::WrapperScripts, Self::Benchmarks].into_iter()
250 }
251
252 pub fn env_var(self) -> Option<&'static str> {
254 match self {
255 Self::SetupScripts => None,
256 Self::WrapperScripts => None,
257 Self::Benchmarks => Some("NEXTEST_EXPERIMENTAL_BENCHMARKS"),
258 }
259 }
260
261 pub fn from_env() -> std::collections::BTreeSet<Self> {
263 let mut set = std::collections::BTreeSet::new();
264 for feature in Self::known() {
265 if let Some(env_var) = feature.env_var()
266 && std::env::var(env_var).as_deref() == Ok("1")
267 {
268 set.insert(feature);
269 }
270 }
271 set
272 }
273}
274
275impl FromStr for ConfigExperimental {
276 type Err = ();
277
278 fn from_str(s: &str) -> Result<Self, Self::Err> {
279 match s {
280 "setup-scripts" => Ok(Self::SetupScripts),
281 "wrapper-scripts" => Ok(Self::WrapperScripts),
282 "benchmarks" => Ok(Self::Benchmarks),
283 _ => Err(()),
284 }
285 }
286}
287
288impl fmt::Display for ConfigExperimental {
289 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290 match self {
291 Self::SetupScripts => write!(f, "setup-scripts"),
292 Self::WrapperScripts => write!(f, "wrapper-scripts"),
293 Self::Benchmarks => write!(f, "benchmarks"),
294 }
295 }
296}
297
#[derive(Debug, Default, Clone, PartialEq, Eq)]
/// A single minimum-version requirement, together with the config source that set it.
pub enum NextestVersionReq {
    /// A minimum version is in effect.
    Version {
        /// The minimum version. A running version equal to or greater than this satisfies it.
        version: Version,

        /// The tool whose config supplied this requirement, or `None` if it came from the
        /// repository config.
        tool: Option<ToolName>,
    },

    #[default]
    /// No requirement has been specified.
    None,
}
314
315impl NextestVersionReq {
316 fn accumulate(&mut self, v: Version, v_tool: Option<ToolName>) {
317 match self {
318 NextestVersionReq::Version { version, tool } => {
319 if &v >= version {
322 *version = v;
323 *tool = v_tool;
324 }
325 }
326 NextestVersionReq::None => {
327 *self = NextestVersionReq::Version {
328 version: v,
329 tool: v_tool,
330 };
331 }
332 }
333 }
334
335 fn satisfies(&self, version: &Version) -> Result<(), (&Version, Option<&ToolName>)> {
336 match self {
337 NextestVersionReq::Version {
338 version: required,
339 tool,
340 } => {
341 if version >= required {
342 Ok(())
343 } else {
344 Err((required, tool.as_ref()))
345 }
346 }
347 NextestVersionReq::None => Ok(()),
348 }
349 }
350}
351
#[derive(Debug, Clone, PartialEq, Eq)]
/// The result of checking the running nextest version against configured requirements.
///
/// Produced by [`NextestVersionConfig::eval`].
pub enum NextestVersionEval {
    /// Both the required and recommended versions (where specified) are satisfied.
    Satisfied,

    /// The running version is older than the required version.
    Error {
        /// The required minimum version.
        required: Version,
        /// The running version.
        current: Version,
        /// The tool that set the requirement, or `None` for the repository config.
        tool: Option<ToolName>,
    },

    /// The required version is satisfied, but the running version is older than the
    /// recommended version.
    Warn {
        /// The recommended minimum version.
        recommended: Version,
        /// The running version.
        current: Version,
        /// The tool that set the recommendation, or `None` for the repository config.
        tool: Option<ToolName>,
    },

    /// Same condition as `Error`, but the version check was overridden by the caller.
    ErrorOverride {
        /// The required minimum version.
        required: Version,
        /// The running version.
        current: Version,
        /// The tool that set the requirement, or `None` for the repository config.
        tool: Option<ToolName>,
    },

    /// Same condition as `Warn`, but the version check was overridden by the caller.
    WarnOverride {
        /// The recommended minimum version.
        recommended: Version,
        /// The running version.
        current: Version,
        /// The tool that set the recommendation, or `None` for the repository config.
        tool: Option<ToolName>,
    },
}
400
#[derive(Debug, Clone, PartialEq, Eq)]
/// The raw `nextest-version` value from a single config file.
///
/// Two forms deserialize into this type:
///
/// * a version string, which sets only `required`;
/// * a table with optional `required` and `recommended` keys. In the table form, `required`
///   must not be greater than `recommended`.
pub(crate) struct NextestVersionDeserialize {
    /// The minimum required version, if specified.
    required: Option<Version>,

    /// The minimum recommended version, if specified.
    recommended: Option<Version>,
}
414
415impl<'de> Deserialize<'de> for NextestVersionDeserialize {
416 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
417 where
418 D: Deserializer<'de>,
419 {
420 struct V;
421
422 impl<'de2> serde::de::Visitor<'de2> for V {
423 type Value = NextestVersionDeserialize;
424
425 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
426 formatter.write_str(
427 "a table ({{ required = \"0.9.20\", recommended = \"0.9.30\" }}) or a string (\"0.9.50\")",
428 )
429 }
430
431 fn visit_str<E>(self, s: &str) -> std::result::Result<Self::Value, E>
432 where
433 E: serde::de::Error,
434 {
435 let required = parse_version::<E>(s.to_owned())?;
436 Ok(NextestVersionDeserialize {
437 required: Some(required),
438 recommended: None,
439 })
440 }
441
442 fn visit_map<A>(self, map: A) -> std::result::Result<Self::Value, A::Error>
443 where
444 A: serde::de::MapAccess<'de2>,
445 {
446 #[derive(Deserialize, Debug, Clone, PartialEq, Eq)]
447 struct NextestVersionMap {
448 #[serde(default, deserialize_with = "deserialize_version_opt")]
449 required: Option<Version>,
450 #[serde(default, deserialize_with = "deserialize_version_opt")]
451 recommended: Option<Version>,
452 }
453
454 let NextestVersionMap {
455 required,
456 recommended,
457 } = NextestVersionMap::deserialize(serde::de::value::MapAccessDeserializer::new(
458 map,
459 ))?;
460
461 if let (Some(required), Some(recommended)) = (&required, &recommended)
462 && required > recommended
463 {
464 return Err(serde::de::Error::custom(format!(
465 "required version ({required}) must not be greater than recommended version ({recommended})"
466 )));
467 }
468
469 Ok(NextestVersionDeserialize {
470 required,
471 recommended,
472 })
473 }
474 }
475
476 deserializer.deserialize_any(V)
477 }
478}
479
480fn deserialize_version_opt<'de, D>(
485 deserializer: D,
486) -> std::result::Result<Option<Version>, D::Error>
487where
488 D: Deserializer<'de>,
489{
490 let s = Option::<String>::deserialize(deserializer)?;
491 s.map(parse_version::<D::Error>).transpose()
492}
493
494fn parse_version<E>(mut s: String) -> std::result::Result<Version, E>
495where
496 E: serde::de::Error,
497{
498 for ch in s.chars() {
499 if ch == '-' {
500 return Err(E::custom(
501 "pre-release identifiers are not supported in nextest-version",
502 ));
503 } else if ch == '+' {
504 return Err(E::custom(
505 "build metadata is not supported in nextest-version",
506 ));
507 }
508 }
509
510 if s.matches('.').count() == 1 {
513 s.push_str(".0");
515 }
516
517 Version::parse(&s).map_err(E::custom)
518}
519
520#[cfg(test)]
521mod tests {
522 use super::*;
523 use test_case::test_case;
524
525 #[test_case(
526 r#"
527 nextest-version = "0.9"
528 "#,
529 NextestVersionDeserialize { required: Some("0.9.0".parse().unwrap()), recommended: None } ; "basic"
530 )]
531 #[test_case(
532 r#"
533 nextest-version = "0.9.30"
534 "#,
535 NextestVersionDeserialize { required: Some("0.9.30".parse().unwrap()), recommended: None } ; "basic with patch"
536 )]
537 #[test_case(
538 r#"
539 nextest-version = { recommended = "0.9.20" }
540 "#,
541 NextestVersionDeserialize { required: None, recommended: Some("0.9.20".parse().unwrap()) } ; "with warning"
542 )]
543 #[test_case(
544 r#"
545 nextest-version = { required = "0.9.20", recommended = "0.9.25" }
546 "#,
547 NextestVersionDeserialize {
548 required: Some("0.9.20".parse().unwrap()),
549 recommended: Some("0.9.25".parse().unwrap()),
550 } ; "with error and warning"
551 )]
552 fn test_valid_nextest_version(input: &str, expected: NextestVersionDeserialize) {
553 let actual: VersionOnlyDeserialize = toml::from_str(input).unwrap();
554 assert_eq!(actual.nextest_version.unwrap(), expected);
555 }
556
557 #[test_case(
558 r#"
559 nextest-version = 42
560 "#,
561 "a table ({{ required = \"0.9.20\", recommended = \"0.9.30\" }}) or a string (\"0.9.50\")" ; "empty"
562 )]
563 #[test_case(
564 r#"
565 nextest-version = "0.9.30-rc.1"
566 "#,
567 "pre-release identifiers are not supported in nextest-version" ; "pre-release"
568 )]
569 #[test_case(
570 r#"
571 nextest-version = "0.9.40+mybuild"
572 "#,
573 "build metadata is not supported in nextest-version" ; "build metadata"
574 )]
575 #[test_case(
576 r#"
577 nextest-version = { required = "0.9.20", recommended = "0.9.10" }
578 "#,
579 "required version (0.9.20) must not be greater than recommended version (0.9.10)" ; "error greater than warning"
580 )]
581 fn test_invalid_nextest_version(input: &str, error_message: &str) {
582 let err = toml::from_str::<VersionOnlyDeserialize>(input).unwrap_err();
583 assert!(
584 err.to_string().contains(error_message),
585 "error `{err}` contains `{error_message}`"
586 );
587 }
588
589 fn tool_name(s: &str) -> ToolName {
590 ToolName::new(s.into()).unwrap()
591 }
592
593 #[test]
594 fn test_accumulate() {
595 let mut nextest_version = NextestVersionConfig::default();
596 nextest_version.accumulate(
597 NextestVersionDeserialize {
598 required: Some("0.9.20".parse().unwrap()),
599 recommended: None,
600 },
601 Some(tool_name("tool1")),
602 );
603 nextest_version.accumulate(
604 NextestVersionDeserialize {
605 required: Some("0.9.30".parse().unwrap()),
606 recommended: Some("0.9.35".parse().unwrap()),
607 },
608 Some(tool_name("tool2")),
609 );
610 nextest_version.accumulate(
611 NextestVersionDeserialize {
612 required: None,
613 recommended: Some("0.9.25".parse().unwrap()),
616 },
617 Some(tool_name("tool3")),
618 );
619 nextest_version.accumulate(
620 NextestVersionDeserialize {
621 required: Some("0.9.30".parse().unwrap()),
624 recommended: None,
625 },
626 Some(tool_name("tool4")),
627 );
628
629 assert_eq!(
630 nextest_version,
631 NextestVersionConfig {
632 required: NextestVersionReq::Version {
633 version: "0.9.30".parse().unwrap(),
634 tool: Some(tool_name("tool4")),
635 },
636 recommended: NextestVersionReq::Version {
637 version: "0.9.35".parse().unwrap(),
638 tool: Some(tool_name("tool2")),
639 },
640 }
641 );
642 }
643
644 #[test]
645 fn test_from_env_benchmarks() {
646 unsafe { std::env::set_var("NEXTEST_EXPERIMENTAL_BENCHMARKS", "1") };
649 assert!(ConfigExperimental::from_env().contains(&ConfigExperimental::Benchmarks));
650
651 unsafe { std::env::set_var("NEXTEST_EXPERIMENTAL_BENCHMARKS", "0") };
655 assert!(!ConfigExperimental::from_env().contains(&ConfigExperimental::Benchmarks));
656
657 unsafe { std::env::set_var("NEXTEST_EXPERIMENTAL_BENCHMARKS", "true") };
660 assert!(!ConfigExperimental::from_env().contains(&ConfigExperimental::Benchmarks));
661
662 unsafe { std::env::set_var("NEXTEST_EXPERIMENTAL_BENCHMARKS", "1") };
667 let set = ConfigExperimental::from_env();
668 assert!(!set.contains(&ConfigExperimental::SetupScripts));
669 assert!(!set.contains(&ConfigExperimental::WrapperScripts));
670 }
671}