use crate::config::core::get_num_cpus;
use serde::Deserialize;
use std::{cmp::Ordering, fmt};
/// The value of a `threads-required` setting.
///
/// This is either a fixed count or a symbolic value that is resolved to a concrete number
/// by [`ThreadsRequired::compute`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ThreadsRequired {
    /// A fixed number of threads (deserialized from a positive integer).
    Count(usize),
    /// Resolves to [`get_num_cpus`] (deserialized from the string `"num-cpus"`).
    NumCpus,
    /// Resolves to the `test_threads` value passed to [`ThreadsRequired::compute`]
    /// (deserialized from the string `"num-test-threads"`).
    NumTestThreads,
}
impl ThreadsRequired {
pub fn compute(self, test_threads: usize) -> usize {
match self {
Self::Count(threads) => threads,
Self::NumCpus => get_num_cpus(),
Self::NumTestThreads => test_threads,
}
}
}
impl<'de> Deserialize<'de> for ThreadsRequired {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct V;
impl serde::de::Visitor<'_> for V {
type Value = ThreadsRequired;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(
formatter,
"an integer, the string \"num-cpus\" or the string \"num-test-threads\""
)
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
if v == "num-cpus" {
Ok(ThreadsRequired::NumCpus)
} else if v == "num-test-threads" {
Ok(ThreadsRequired::NumTestThreads)
} else {
Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Str(v),
&self,
))
}
}
fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
match v.cmp(&0) {
Ordering::Greater => Ok(ThreadsRequired::Count(v as usize)),
Ordering::Equal | Ordering::Less => Err(serde::de::Error::invalid_value(
serde::de::Unexpected::Signed(v),
&self,
)),
}
}
}
deserializer.deserialize_any(V)
}
}
#[cfg(feature = "config-schema")]
impl schemars::JsonSchema for ThreadsRequired {
    /// Name under which this schema is registered.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("ThreadsRequired")
    }

    /// Describes the forms accepted by the `Deserialize` impl: an integer of at least 1,
    /// or one of the strings `"num-cpus"` / `"num-test-threads"`.
    fn json_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        schemars::json_schema!({
            "oneOf": [
                { "type": "integer", "minimum": 1 },
                { "type": "string", "enum": ["num-cpus", "num-test-threads"] }
            ]
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{core::NextestConfig, utils::test_helpers::*};
    use camino_tempfile::tempdir;
    use indoc::indoc;
    use nextest_filtering::ParseContext;
    use test_case::test_case;

    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = 2
        "#},
        Some(2)
        ; "positive"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = 0
        "#},
        None
        ; "zero"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = -1
        "#},
        None
        ; "negative"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = "all-cpus"
        "#},
        None
        ; "unknown-string"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = 1.5
        "#},
        None
        ; "float"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = "num-cpus"
        "#},
        Some(get_num_cpus())
        ; "num-cpus"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = 1
            threads-required = "num-cpus"
        "#},
        Some(get_num_cpus())
        ; "num-cpus-with-custom-test-threads"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            threads-required = "num-test-threads"
        "#},
        Some(get_num_cpus())
        ; "num-test-threads"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = 1
            threads-required = "num-test-threads"
        "#},
        Some(1)
        ; "num-test-threads-with-custom-test-threads"
    )]
    /// Loads `config_contents` as a workspace config and checks the `custom` profile's
    /// resolved `threads-required` value.
    ///
    /// `threads_required == None` means loading the config must fail. `Some(t)` means
    /// the value resolved against the profile's test-thread count must equal `t`.
    fn parse_threads_required(config_contents: &str, threads_required: Option<usize>) {
        let workspace_dir = tempdir().unwrap();
        let graph = temp_workspace(&workspace_dir, config_contents);
        let pcx = ParseContext::new(&graph);
        let config = NextestConfig::from_sources(
            graph.workspace().root(),
            &pcx,
            None,
            [],
            &Default::default(),
        );
        match threads_required {
            None => assert!(config.is_err()),
            Some(t) => {
                let config = config.unwrap();
                let profile = config
                    .profile("custom")
                    .unwrap()
                    .apply_build_platforms(&build_platforms());
                let test_threads = profile.test_threads().compute();
                let threads_required = profile.threads_required().compute(test_threads);
                assert_eq!(threads_required, t);
            }
        }
    }
}