Skip to main content

nextest_runner/config/elements/
test_threads.rs

1// Copyright (c) The nextest Contributors
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::{config::core::get_num_cpus, errors::TestThreadsParseError};
5use serde::Deserialize;
6use std::{cmp::Ordering, fmt, str::FromStr};
7
/// Value of the `test-threads` configuration key.
///
/// Holds either an explicit thread count or a request to use one thread per
/// logical CPU. Call [`TestThreads::compute`] to turn it into a number.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TestThreads {
    /// Run tests with exactly this many threads.
    ///
    /// Negative inputs (offsets from the CPU count) are resolved into this
    /// variant at parse/deserialize time.
    Count(usize),

    /// Run tests with one thread per logical CPU, resolved when computed.
    NumCpus,
}
17
18impl TestThreads {
19    /// Gets the actual number of test threads computed at runtime.
20    pub fn compute(self) -> usize {
21        match self {
22            Self::Count(threads) => threads,
23            Self::NumCpus => get_num_cpus(),
24        }
25    }
26}
27
28impl FromStr for TestThreads {
29    type Err = TestThreadsParseError;
30
31    fn from_str(s: &str) -> Result<Self, Self::Err> {
32        if s == "num-cpus" {
33            return Ok(Self::NumCpus);
34        }
35
36        match s.parse::<isize>() {
37            Err(e) => Err(TestThreadsParseError::new(format!(
38                "Error: {e} parsing {s}"
39            ))),
40            Ok(0) => Err(TestThreadsParseError::new("jobs may not be 0")),
41            Ok(j) if j < 0 => Ok(TestThreads::Count(
42                (get_num_cpus() as isize + j).max(1) as usize
43            )),
44            Ok(j) => Ok(TestThreads::Count(j as usize)),
45        }
46    }
47}
48
49impl fmt::Display for TestThreads {
50    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51        match self {
52            Self::Count(threads) => write!(f, "{threads}"),
53            Self::NumCpus => write!(f, "num-cpus"),
54        }
55    }
56}
57
58impl<'de> Deserialize<'de> for TestThreads {
59    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
60    where
61        D: serde::Deserializer<'de>,
62    {
63        struct V;
64
65        impl serde::de::Visitor<'_> for V {
66            type Value = TestThreads;
67
68            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
69                write!(formatter, "an integer or the string \"num-cpus\"")
70            }
71
72            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
73            where
74                E: serde::de::Error,
75            {
76                if v == "num-cpus" {
77                    Ok(TestThreads::NumCpus)
78                } else {
79                    Err(serde::de::Error::invalid_value(
80                        serde::de::Unexpected::Str(v),
81                        &self,
82                    ))
83                }
84            }
85
86            // Note that TOML uses i64, not u64.
87            fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
88            where
89                E: serde::de::Error,
90            {
91                match v.cmp(&0) {
92                    Ordering::Greater => Ok(TestThreads::Count(v as usize)),
93                    Ordering::Less => Ok(TestThreads::Count(
94                        (get_num_cpus() as i64 + v).max(1) as usize
95                    )),
96                    Ordering::Equal => Err(serde::de::Error::invalid_value(
97                        serde::de::Unexpected::Signed(v),
98                        &self,
99                    )),
100                }
101            }
102        }
103
104        deserializer.deserialize_any(V)
105    }
106}
107
#[cfg(feature = "config-schema")]
impl schemars::JsonSchema for TestThreads {
    /// Name used for this type in generated schema definitions.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("TestThreads")
    }

    /// Accepts any non-zero integer (negative values are offsets from the CPU
    /// count) or the literal string `"num-cpus"`.
    fn json_schema(_schema_gen: &mut schemars::SchemaGenerator) -> schemars::Schema {
        schemars::json_schema!({
            "oneOf": [
                { "type": "integer", "not": { "const": 0 } },
                { "type": "string", "enum": ["num-cpus"] }
            ]
        })
    }
}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{core::NextestConfig, utils::test_helpers::*};
    use camino_tempfile::tempdir;
    use indoc::indoc;
    use nextest_filtering::ParseContext;
    use test_case::test_case;

    /// Expected thread count for a negative offset of `offset`.
    ///
    /// Mirrors the parser's clamp to a minimum of one thread. Without the
    /// clamp, expecting `get_num_cpus() - 1` fails on single-CPU machines,
    /// where the parser returns 1 rather than 0.
    fn num_cpus_minus(offset: usize) -> usize {
        get_num_cpus().saturating_sub(offset).max(1)
    }

    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = -1
        "#},
        Some(num_cpus_minus(1))

        ; "negative"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = -100000
        "#},
        Some(1)

        ; "negative beyond cpu count"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = 2
        "#},
        Some(2)

        ; "positive"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = 0
        "#},
        None

        ; "zero"
    )]
    #[test_case(
        indoc! {r#"
            [profile.custom]
            test-threads = "num-cpus"
        "#},
        Some(get_num_cpus())

        ; "num-cpus"
    )]
    fn parse_test_threads(config_contents: &str, n_threads: Option<usize>) {
        let workspace_dir = tempdir().unwrap();

        let graph = temp_workspace(&workspace_dir, config_contents);

        let pcx = ParseContext::new(&graph);
        let config = NextestConfig::from_sources(
            graph.workspace().root(),
            &pcx,
            None,
            [],
            &Default::default(),
        );
        match n_threads {
            None => assert!(config.is_err()),
            Some(n) => assert_eq!(
                config
                    .unwrap()
                    .profile("custom")
                    .unwrap()
                    .apply_build_platforms(&build_platforms())
                    .custom_profile()
                    .unwrap()
                    .test_threads()
                    .unwrap()
                    .compute(),
                n,
            ),
        }
    }

    #[test_case("num-cpus", Some(TestThreads::NumCpus) ; "num-cpus")]
    #[test_case("4", Some(TestThreads::Count(4)) ; "positive")]
    #[test_case("-1", Some(TestThreads::Count(num_cpus_minus(1))) ; "negative")]
    #[test_case("-100000", Some(TestThreads::Count(1)) ; "negative beyond cpu count")]
    #[test_case("0", None ; "zero")]
    #[test_case("abc", None ; "not a number")]
    #[test_case("", None ; "empty")]
    fn parse_test_threads_from_str(input: &str, expected: Option<TestThreads>) {
        assert_eq!(input.parse::<TestThreads>().ok(), expected);
    }

    #[test_case(TestThreads::Count(7), "7" ; "count")]
    #[test_case(TestThreads::NumCpus, "num-cpus" ; "num-cpus")]
    fn display_round_trips(threads: TestThreads, expected: &str) {
        assert_eq!(threads.to_string(), expected);
        assert_eq!(expected.parse::<TestThreads>().ok(), Some(threads));
    }
}